fkat 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/data/shm.py ADDED
@@ -0,0 +1,364 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import atexit
+ import logging
+ import multiprocessing as mp
+ from multiprocessing.synchronize import Event
+ from concurrent.futures import ThreadPoolExecutor, Future
+
+ # mp.pool is not eagerly imported, needs an explicit import
+ import os
+ import random
+ import shutil
+ import signal
+ from collections import deque
+ from pathlib import Path
+ from typing import Any, Generic, TypeVar
+ from collections.abc import Callable, Iterable, Iterator
+
+ import numpy as np
+ from lightning.pytorch.profilers import Profiler
+ from lightning.pytorch.utilities import move_data_to_device
+ import torch
+ from torch.utils.data import DataLoader, Dataset, Sampler
+
+ from fkat.utils import shm
+ from fkat.utils.pool import ThreadPool
+ from fkat.utils.profiler import profile_until_exit
+
+ logger = logging.getLogger(__name__)
+
+ _shutdown: Event | None = None
+
+ DEFAULT_SHUTDOWN_TIMEOUT = 60  # time for workers to gracefully shut down
+
+
+ def initialize(
+     seed: int,
+     dp_rank: int,
+     shutdown: Event,
+     profiler: Profiler | None = None,
+ ) -> None:
+     # signal handlers are inherited; we use the shutdown flag to gracefully terminate child processes
+     try:
+         signal.signal(signal.SIGINT, signal.SIG_IGN)
+     except ValueError:
+         pass  # won't work from non-MainThread
+     # this allows the worker function to access `shutdown` even though it is
+     # not passed as an argument to the function.
+     global _shutdown
+     _shutdown = shutdown
+     pid = os.getpid()
+     logger.debug(f"worker init {pid} ...")
+     if profiler:
+         action = f"ShmDataLoader[worker_pid={pid}]"
+         profile_until_exit(profiler, action=action, filename_suffix=f"_{pid}")
+
+     # Set the RNG seed to ensure TP ranks within the same DP group load and iterate
+     # the same data in the same order with consistent RNG states
+     rng_seed = seed + dp_rank
+     np.random.seed(rng_seed)
+     random.seed(rng_seed)
+     torch.manual_seed(rng_seed)
+     logger.info(f"RNG seed is set to {rng_seed}")
+     logger.debug(f"worker init {pid} complete")
+
+
+ T_co = TypeVar("T_co", covariant=True)
+
+
+ class DataLoaderFactory(Generic[T_co]):
+     """Factory class for creating DataLoaders.
+
+     Args:
+         dataset_generator (Callable): A function that generates a dataset.
+         sampler_generator (Optional[Callable]): An optional function that generates a sampler for the dataset.
+         batch_sampler_generator (Optional[Callable]): An optional function that generates a batch sampler.
+         dataloader_generator (Optional[Callable]): An optional function that generates a DataLoader instance.
+     """
+
+     def __init__(
+         self,
+         dataset_generator: Callable[[], Dataset[T_co]],
+         sampler_generator: Callable[[Dataset[T_co]], Sampler[T_co]] | None = None,
+         batch_sampler_generator: Callable[[Sampler[Any] | Dataset[T_co]], Iterable[list[Any]]] | None = None,
+         dataloader_generator: Callable[[Any], Iterable[list[T_co]]] | None = None,
+     ) -> None:
+         # Assert that either sampler_generator or batch_sampler_generator is provided
+         assert sampler_generator or batch_sampler_generator, (
+             "either sampler_generator or batch_sampler_generator must be provided"
+         )
+
+         # Initialize instance variables
+         self.dataset_generator = dataset_generator
+         self.sampler_generator = sampler_generator
+         self.batch_sampler_generator = batch_sampler_generator
+         self.dataloader_generator = dataloader_generator or DataLoader
+
+     def __call__(self) -> Iterable[list[T_co]]:
+         """Generates a DataLoader.
+
+         Returns:
+             Iterable[List[T_co]]: An iterable of batches of data.
+         """
+         # Generate dataset using dataset_generator
+         dataset = self.dataset_generator()
+
+         # Generate sampler if sampler_generator is provided
+         sampler = self.sampler_generator(dataset) if self.sampler_generator else None
+
+         # Generate batch sampler if batch_sampler_generator is provided (mutually exclusive with sampler)
+         batch_sampler = None
+         if self.batch_sampler_generator:
+             batch_sampler = self.batch_sampler_generator(sampler if sampler else dataset)
+             sampler = None  # mutually exclusive
+
+         # Generate DataLoader instance using dataloader_generator
+         dataloader = self.dataloader_generator(  # type: ignore[call-arg]
+             dataset, batch_size=1, shuffle=None, sampler=sampler, batch_sampler=batch_sampler
+         )
+         return dataloader
+
+
+ class DataLoaderIterGenerator(Generic[T_co]):
+     """Generates and saves an iterator over DataLoaders.
+
+     Args:
+         dataloader_factory (DataLoaderFactory): An instance of DataLoaderFactory responsible for generating DataLoaders.
+         num_microbatch_prefetches (int, optional): The number of microbatches to prefetch. Defaults to -1.
+     """
+
+     def __init__(
+         self,
+         dataloader_factory: DataLoaderFactory[T_co],
+         num_microbatch_prefetches: int = -1,
+     ) -> None:
+         """Initializes the DataLoaderIterGenerator.
+
+         Args:
+             dataloader_factory (DataLoaderFactory): DataLoaders provider.
+             num_microbatch_prefetches (int, optional): The number of microbatches to prefetch.
+                 Defaults to -1.
+         """
+         self.dataloader_factory = dataloader_factory
+         self.num_microbatch_prefetches = num_microbatch_prefetches
+
+     def __call__(self, path: Path) -> None:
+         """Generates and saves an iterator over DataLoaders.
+
+         Args:
+             path (Path): The path where the iterator will be saved.
+         """
+         # Log debug message indicating the start of the process
+         logger.debug("generate ...")
+
+         # Generate a DataLoader
+         dataloader = self.dataloader_factory()
+
+         # Create an iterator over the DataLoader
+         dataloader_iter = iter(dataloader)
+
+         # Access global variable _shutdown
+         global _shutdown
+
+         # Save the iterator using shm.save_iter
+         shm.save_iter(
+             dataloader_iter,
+             path=path,
+             max_items=self.num_microbatch_prefetches,
+             should_stop=lambda: _shutdown is not None and _shutdown.is_set(),
+         )
+
+         # Log debug message indicating the completion of the process
+         logger.debug("generate complete")
+
+
+ # Sub-class multiprocessing.Process to make sure it's not started in daemon mode by the Pool
+ class NoDaemonProcess(mp.Process):
+     @property
+     def daemon(self) -> bool:
+         return False
+
+     @daemon.setter
+     def daemon(self, value: bool) -> None:
+         pass
+
+
+ class NoDaemonContext(type(mp.get_context())):  # type: ignore[misc]
+     Process = NoDaemonProcess
+
+
+ # We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool
+ # because the latter is only a wrapper function, not a proper class.
+ class NoDaemonPool(mp.pool.Pool):  # type: ignore[unresolved-attribute]
+     def __init__(self, *args: Any, **kwargs: Any) -> None:
+         kwargs["context"] = NoDaemonContext()
+         super().__init__(*args, **kwargs)
+
+
+ class ShmDataLoader(Iterable[list[T_co]]):
+     """A :class:`DataLoader` that uses shared memory to efficiently manage and prefetch data batches.
+
+     Enables double-buffered micro-batch processing and fetching that overlaps with model
+     forward/backward passes, minimizing data-loading overhead.
+
+     Args:
+         seed (int): Random seed for reproducibility. Use ${seed} at top level in config.yaml.
+         dataloader_factory (DataLoaderFactory[T_co]): Factory for creating DataLoaders.
+         num_microbatch_prefetches (int, optional): Number of microbatches to prefetch.
+             Defaults to -1.
+         dp_rank (int, optional): Data-parallel rank of the current process. Defaults to 0.
+         profiler (Optional[Profiler], optional): Profiler for profiling.
+             Defaults to None.
+         device (Optional[torch.device]): device to move the microbatches to in the background.
+         multiprocessing (bool, optional): whether to instantiate the DataLoader in a separate process.
+             Defaults to True to relieve pressure on the training process; use False to debug and profile.
+     """
+
+     def __init__(
+         self,
+         seed: int,
+         dataloader_factory: DataLoaderFactory[T_co],
+         num_microbatch_prefetches: int = -1,
+         dp_rank: int = 0,
+         profiler: Profiler | None = None,
+         device: torch.device | None = None,
+         multiprocessing: bool = True,
+     ) -> None:
+         self.microbatches: Iterator[list[T_co]] | None = None
+         self.path: Path | None = None
+         self.device: torch.device | None = device
+         self.cleanup: set[Path] = set()
+         self.shutdown = mp.Event()
+         self.dataloader_factory = dataloader_factory
+         self.dataloader_iter_generator = DataLoaderIterGenerator(
+             dataloader_factory,
+             num_microbatch_prefetches,
+         )
+         self.data_jobs: deque[tuple[Path, mp.pool.AsyncResult[Any]]] = deque()  # type: ignore[unresolved-attribute]
+
+         # Register teardown handlers and initialize the prefetching pool (process- or thread-based)
+         signal.signal(signal.SIGTERM, self.teardown)  # terminate signal
+         signal.signal(signal.SIGINT, self.teardown)  # keyboard interrupt
+         atexit.register(self.teardown)
+         self.writing_pool: NoDaemonPool | ThreadPool = (
+             NoDaemonPool(1, initializer=initialize, initargs=(seed, dp_rank, self.shutdown, profiler))
+             if multiprocessing
+             else ThreadPool(
+                 max_workers=1,
+                 thread_name_prefix="ShmDataWriter",
+                 initializer=initialize,
+                 initargs=(seed, dp_rank, self.shutdown, profiler),
+             )
+         )
+         self.reading_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ShmDataReader")
+         self.next_batch: Future[Any] | None = None
+
+     def __iter__(self) -> Iterator[list[T_co]]:
+         # all reading needs to go through the same TPE to avoid contention
+         if not self.reading_pool._shutdown:
+             self.reading_pool.submit(self._cleanup).result()
+         return self
+
+     def __next__(self) -> list[T_co]:
+         if not self.next_batch:
+             self.load_batch()
+         assert self.next_batch
+         batch = self.next_batch.result()
+         self.load_batch()  # double-buffering
+         return batch
+
+     def load_batch(self) -> None:
+         if not self.reading_pool._shutdown:
+             self.next_batch = self.reading_pool.submit(self.load_batch_sync)
+
+     def load_batch_sync(self) -> list[T_co]:
+         while True:
+             if self.microbatches:
+                 # Fetch the next microbatch if available
+                 try:
+                     microbatch = next(self.microbatches)
+                     if self.device:
+                         microbatch = move_data_to_device(microbatch, self.device)
+                     return microbatch
+                 # StopIteration means all microbatches from the current iterator are exhausted
+                 except StopIteration:
+                     if self.path:
+                         self.cleanup.remove(self.path)
+                         self.path = None
+                     self.microbatches = None
+                     raise
+
+             if len(self.data_jobs) == 0:
+                 logger.debug("load iter scheduling ...")
+                 self.prefetch()
+             path, data_job = self.data_jobs.popleft()
+             logger.debug(f"load iter to {path} ...")
+
+             def wait_callback() -> None:
+                 if not data_job.ready():  # noqa: B023
+                     # Job is still running
+                     return None
+                 else:
+                     # Job is finished, raise exception if job failed.
+                     data_job.get()  # noqa: B023
+                     # Return whether the call completed without raising an exception.
+                     assert data_job.successful()  # noqa: B023
+
+             self.microbatches = shm.load_iter(path, wait_callback=wait_callback)
+             self.cleanup.add(path)
+             self.path = path
+
+     def set_device(self, device: torch.device | None) -> None:
+         self.device = device
+
+     def prefetch(self) -> None:
+         path = shm.generate_path()
+         data_job = self.writing_pool.apply_async(self.dataloader_iter_generator, (path,))
+         self.data_jobs.append((path, data_job))
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(f"queued {path}, {len(self.data_jobs)}")
+
+     def _cleanup(self, stop_pool: bool = False) -> None:
+         self.microbatches = None
+         if self.next_batch:
+             try:
+                 # if called from teardown/on_exception/__del__ this waits for the pending work to finish;
+                 # if called from __init__ it's already done
+                 self.next_batch.result()
+             except Exception:
+                 pass
+             self.next_batch = None
+         self.shutdown.set()  # signal running tasks to stop
+         if stop_pool:
+             self.writing_pool.close()  # no new tasks can run
+             self.reading_pool.shutdown()
+         for path, result in self.data_jobs:
+             logger.debug(f"waiting for {path} to stop ...")
+             self.cleanup.add(path)
+             try:
+                 result.wait(timeout=DEFAULT_SHUTDOWN_TIMEOUT)
+             except Exception:
+                 pass
+             logger.debug(f"{path} stopped ...")
+         self.data_jobs.clear()
+         self.shutdown.clear()
+         for path in self.cleanup:
+             logger.debug(f"removing {path} ...")
+             shutil.rmtree(path, ignore_errors=True)
+
+     # called when fit/validate/predict/test is complete
+     def teardown(self, *args: Any) -> None:
+         logger.debug("teardown ...")
+         # all reading needs to go through the same TPE to avoid contention
+         self._cleanup(stop_pool=True)
+         logger.debug("teardown complete")
+
+     # will be used once https://github.com/Lightning-AI/pytorch-lightning/pull/19601 is in effect
+     # once the callback below is operational we no longer need the __del__ override
+     def on_exception(self, exception: BaseException) -> None:
+         self.teardown()
+
+     # called when the iterable link pointing to this object goes out of scope
+     # e.g. when an exception happens
+     def __del__(self) -> None:
+         self.teardown()
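
A minimal usage sketch (not part of the package diff), assuming a hypothetical toy dataset `MyDataset` and illustrative parameter values: `DataLoaderFactory` describes how the inner `DataLoader` is built, and `ShmDataLoader` iterates its micro-batches through shared memory with double-buffered prefetching, per the classes defined above.

# Hypothetical sketch; MyDataset, batch size, and prefetch depth are made up for illustration.
import torch
from torch.utils.data import BatchSampler, Dataset, RandomSampler

from fkat.data.shm import DataLoaderFactory, ShmDataLoader


class MyDataset(Dataset):
    """Toy map-style dataset returning a fixed-size tensor per index."""

    def __len__(self) -> int:
        return 1024

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.full((8,), float(idx))


factory = DataLoaderFactory(
    dataset_generator=MyDataset,
    sampler_generator=lambda ds: RandomSampler(ds),
    # receives the sampler (since one is provided); mutually exclusive with `sampler` in the DataLoader
    batch_sampler_generator=lambda s: BatchSampler(s, batch_size=16, drop_last=False),
)

loader = ShmDataLoader(
    seed=0,
    dataloader_factory=factory,
    num_microbatch_prefetches=4,
    multiprocessing=False,  # keep the writer in-process for easier debugging/profiling
)

for batch in loader:
    ...  # micro-batches are prefetched via shared memory and double-buffered

Keeping `multiprocessing=False` here is only for the sketch; the default spawns a non-daemonic writer process so prefetching does not compete with the training process.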
fkat/predict.py ADDED
@@ -0,0 +1,32 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #!/usr/bin/env python
+
+ """
+ The ``fkat.predict`` entrypoint processes the provided config,
+ instantiates the ``trainer``, ``model`` and ``data`` sections and calls ``trainer.predict()``.
+ """
+
+ import hydra
+ import lightning as L
+ from omegaconf import DictConfig
+
+ from fkat import initialize, run_main
+
+
+ @hydra.main(version_base="1.3")
+ def main(cfg: DictConfig) -> None:
+     s = initialize(cfg)
+     kwargs = {
+         "ckpt_path": s.ckpt_path,
+         "return_predictions": s.return_predictions,
+     }
+     if isinstance(s.data, L.LightningDataModule):
+         kwargs["datamodule"] = s.data
+     else:
+         kwargs["predict_dataloader"] = s.data.predict_dataloader() if s.data else None
+     s.trainer.predict(s.model, **kwargs)
+
+
+ if __name__ == "__main__":
+     run_main(main)
fkat/py.typed ADDED
File without changes
fkat/pytorch/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
fkat/pytorch/actions/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import Protocol, Any
+
+
+ class LightningAction(Protocol):
+     """A generic action to be executed given the context provided via key-value arguments."""
+
+     def perform(self, **kwargs: Any) -> Any:
+         """Performs the action with the context provided via key-value arguments."""
+         ...
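
For orientation, a hedged sketch of a custom action: because `LightningAction` is a `Protocol`, any object exposing a compatible `perform(**kwargs)` satisfies it structurally (subclassing, as the AWS actions below do, also works). The `LogContextKeys` class and the keyword arguments are hypothetical.

# Hypothetical example; not part of the package.
from typing import Any

from fkat.pytorch.actions import LightningAction


class LogContextKeys:
    """Toy action that prints the context keys it receives."""

    def perform(self, **kwargs: Any) -> Any:
        print(sorted(kwargs))


action: LightningAction = LogContextKeys()  # structural typing: no inheritance required
action.perform(trainer=None, instance_id="i-0123456789abcdef0")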
fkat/pytorch/actions/aws/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
fkat/pytorch/actions/aws/batch.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from __future__ import annotations
+
+ import os
+ from typing import TYPE_CHECKING, Any
+ from typing_extensions import override
+
+ from fkat.utils import boto3
+ from fkat.pytorch.actions import LightningAction
+
+ if TYPE_CHECKING:
+     from types_boto3_batch import BatchClient
+
+
+ class TerminateJob(LightningAction):
+     """This action calls Batch.TerminateJob."""
+
+     def __init__(self, job_id: str | None = None) -> None:
+         self.job_id = job_id
+
+     @override
+     def perform(self, **kwargs: Any) -> Any:
+         """Calls Batch.TerminateJob."""
+         job_id = self.job_id or os.getenv("AWS_BATCH_JOB_ID")
+         if job_id:
+             reason = ",".join(f"{k}={v}" for k, v in kwargs.items() if isinstance(v, str))
+             batch: BatchClient = boto3.session().client("batch")  # type: ignore[assignment]
+             batch.terminate_job(jobId=job_id, reason=reason)
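
A short illustration of the behavior above: when no `job_id` is given, `TerminateJob` falls back to the `AWS_BATCH_JOB_ID` environment variable, and it composes the termination reason from the string-valued context keys only. The context keys (`error`, `node_rank`) are hypothetical.

# Illustrative only; calling perform() issues a real Batch.TerminateJob request.
from fkat.pytorch.actions.aws.batch import TerminateJob

action = TerminateJob()  # falls back to the AWS_BATCH_JOB_ID environment variable
# Non-string values in the context are skipped when composing the reason.
action.perform(error="CUDA OOM", node_rank="3")
# expected reason string: "error=CUDA OOM,node_rank=3"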
fkat/pytorch/actions/aws/ec2.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any
+ from typing_extensions import override
+
+ from fkat.utils import boto3, assert_not_none
+ from fkat.utils.aws import imds
+ from fkat.pytorch.loggers import LightningLogger, CompositeLogger
+ from fkat.pytorch.actions import LightningAction
+
+ if TYPE_CHECKING:
+     from types_boto3_ec2 import EC2Client
+
+
+ class TerminateInstances(LightningAction):
+     """This action calls EC2.TerminateInstances."""
+
+     def __init__(self, instance_ids: list[str] | None = None) -> None:
+         self.instance_ids = instance_ids or []
+
+     @override
+     def perform(self, **kwargs: Any) -> Any:
+         """Calls EC2.TerminateInstances with the provided ``instance_ids`` or the current node's instance_id."""
+         instance_ids = self.instance_ids or kwargs.get("instance_ids") or [imds.instance_metadata().instance_id]
+         ec2: EC2Client = boto3.session().client("ec2")  # type: ignore[assignment]
+         ec2.terminate_instances(InstanceIds=instance_ids)
+
+
+ class RebootInstances(LightningAction):
+     """This action calls EC2.RebootInstances."""
+
+     def __init__(self, instance_ids: list[str] | None = None) -> None:
+         self.instance_ids = instance_ids or []
+
+     @override
+     def perform(self, **kwargs: Any) -> Any:
+         """Calls EC2.RebootInstances with the provided ``instance_ids`` or the current node's instance_id."""
+         instance_ids = self.instance_ids or kwargs.get("instance_ids") or [imds.instance_metadata().instance_id]
+         ec2: EC2Client = boto3.session().client("ec2")  # type: ignore[assignment]
+         ec2.reboot_instances(InstanceIds=instance_ids)
+
+
+ class LogInstanceTags(LightningAction):
+     """This action logs tags suffixed with the EC2 instance-id."""
+
+     def __init__(self, instance_id: str | None = None, tags: list[str] | None = None) -> None:
+         self.instance_id = instance_id
+         self.tags = tags or []
+         self.logger: LightningLogger | None = None
+
+     @override
+     def perform(self, **kwargs: Any) -> Any:
+         """Calls LightningLogger.log_tag with all string-valued keys; requires ``trainer`` to be provided."""
+         instance_id = self.instance_id or kwargs.get("instance_id")
+         for t in self.tags:
+             if v := kwargs.get(t):
+                 instance_id = instance_id or imds.instance_metadata().instance_id
+                 self.logger = self.logger or CompositeLogger(assert_not_none(kwargs.get("trainer"), "trainer"))
+                 self.logger.log_tag(f"{instance_id}/{t}/{v}", "True")
fkat/pytorch/callbacks/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
fkat/pytorch/callbacks/cuda/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from .xid import Xid
+ from .nsys import Nsys
+ from .nvtx import Nvtx
+ from .memory import MemoryObserver
+ from .cache import EmptyCache
+
+
+ __all__ = [
+     "Xid",
+     "Nsys",
+     "Nvtx",
+     "MemoryObserver",
+     "EmptyCache",
+ ]
fkat/pytorch/callbacks/cuda/cache.py ADDED
@@ -0,0 +1,115 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import Any, TYPE_CHECKING
+ from typing_extensions import override
+
+ import torch
+ import lightning as L
+
+ if TYPE_CHECKING:
+     from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from fkat.pytorch.schedule import (
+     Schedule,
+     Never,
+ )
+
+
+ class EmptyCache(L.Callback):
+     def __init__(self, schedule: Schedule | None = None) -> None:
+         """
+         PyTorch Lightning callback to trigger ``torch.cuda.empty_cache()``.
+
+         This callback allows fine-grained control over CUDA's caching allocator during training,
+         validation, testing, and prediction.
+
+         Args:
+             schedule (Optional[Schedule]): When to invoke ``torch.cuda.empty_cache()``, defaults to :class:`Never`
+
+         Example:
+             >>> trainer = Trainer(callbacks=[EmptyCache()])
+         """
+         self.schedule = schedule or Never()
+
+     @override
+     def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         """Perform empty cache after the training epoch."""
+         self.maybe_empty_cache(trainer, "train")
+
+     @override
+     def on_train_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+     ) -> None:
+         """Perform empty cache after a training batch if needed."""
+         self.maybe_empty_cache(trainer, "train", batch_idx)
+
+     @override
+     def on_validation_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         """Perform empty cache after the validation epoch."""
+         self.maybe_empty_cache(trainer, "validation")
+
+     @override
+     def on_validation_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         """Perform empty cache after a validation batch if needed."""
+         self.maybe_empty_cache(trainer, "validation", batch_idx)
+
+     @override
+     def on_test_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         """Perform empty cache after the test epoch."""
+         self.maybe_empty_cache(trainer, "test")
+
+     @override
+     def on_predict_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         """Perform empty cache after a prediction batch if needed."""
+         self.maybe_empty_cache(trainer, "predict", batch_idx)
+
+     @override
+     def on_predict_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         """Perform empty cache after the predict epoch."""
+         self.maybe_empty_cache(trainer, "predict")
+
+     @override
+     def on_test_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         """Perform empty cache after a test batch if needed."""
+         self.maybe_empty_cache(trainer, "test", batch_idx)
+
+     def maybe_empty_cache(self, trainer: "L.Trainer", stage: str, batch_idx: int | None = None) -> None:
+         """
+         Perform empty cache if the schedule's conditions are met.
+
+         Args:
+             trainer (L.Trainer): Lightning Trainer
+             stage (str): training stage
+             batch_idx (int | None): Current batch index if available
+         """
+         if self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+             torch.cuda.empty_cache()
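
To show how the schedule hook is consumed, here is a hedged sketch of wiring EmptyCache into a Trainer with a custom schedule. The `EveryNTrainBatches` class is hypothetical and only mirrors the keyword arguments passed by `maybe_empty_cache` above; the package's own `Schedule` implementations (under fkat/pytorch/schedule) are not shown in this diff, and a real schedule should subclass or implement `Schedule` as that module defines it.

# Hypothetical sketch; EveryNTrainBatches is not part of the package.
from typing import Any

import lightning as L

from fkat.pytorch.callbacks.cuda import EmptyCache


class EveryNTrainBatches:
    """Toy schedule: trigger on every n-th training batch."""

    def __init__(self, n: int = 100) -> None:
        self.n = n

    def check(self, **kwargs: Any) -> bool:
        # receives stage, batch_idx, step, and trainer from maybe_empty_cache()
        batch_idx = kwargs.get("batch_idx")
        return kwargs.get("stage") == "train" and batch_idx is not None and (batch_idx + 1) % self.n == 0


trainer = L.Trainer(callbacks=[EmptyCache(schedule=EveryNTrainBatches(200))])

Emptying the cache frequently can hurt throughput, since freed blocks must be re-requested from the driver, so a sparse schedule such as the one sketched here is the typical trade-off.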