fkat 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/data/shm.py
ADDED
@@ -0,0 +1,364 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import atexit
+import logging
+import multiprocessing as mp
+from multiprocessing.synchronize import Event
+from concurrent.futures import ThreadPoolExecutor, Future
+
+# mp.pool is not eagerly imported, needs an explicit import
+import os
+import random
+import shutil
+import signal
+from collections import deque
+from pathlib import Path
+from typing import Any, Generic, TypeVar
+from collections.abc import Callable, Iterable, Iterator
+
+import numpy as np
+from lightning.pytorch.profilers import Profiler
+from lightning.pytorch.utilities import move_data_to_device
+import torch
+from torch.utils.data import DataLoader, Dataset, Sampler
+
+from fkat.utils import shm
+from fkat.utils.pool import ThreadPool
+from fkat.utils.profiler import profile_until_exit
+
+logger = logging.getLogger(__name__)
+
+_shutdown: Event | None = None
+
+DEFAULT_SHUTDOWN_TIMEOUT = 60  # time for workers to gracefully shutdown
+
+
+def initialize(
+    seed: int,
+    dp_rank: int,
+    shutdown: Event,
+    profiler: Profiler | None = None,
+) -> None:
+    # signal handlers are inherited; we use the shutdown flag to gracefully terminate child processes
+    try:
+        signal.signal(signal.SIGINT, signal.SIG_IGN)
+    except ValueError:
+        pass  # won't work from a non-MainThread
+    # this allows the worker function to access `shutdown` even though it is
+    # not passed as an argument to the function.
+    global _shutdown
+    _shutdown = shutdown
+    pid = os.getpid()
+    logger.debug(f"worker init {pid} ...")
+    if profiler:
+        action = f"ShmDataLoader[worker_pid={pid}]"
+        profile_until_exit(profiler, action=action, filename_suffix=f"_{pid}")
+
+    # Set the RNG seed to ensure TP ranks within the same DP group load and iterate
+    # the same data in the same order with consistent RNG states
+    rng_seed = seed + dp_rank
+    np.random.seed(rng_seed)
+    random.seed(rng_seed)
+    torch.manual_seed(rng_seed)
+    logger.info(f"RNG seed set to {rng_seed}")
+    logger.debug(f"worker init {pid} complete")
+
+
+T_co = TypeVar("T_co", covariant=True)
+
+
+class DataLoaderFactory(Generic[T_co]):
+    """Factory class for creating DataLoaders.
+
+    Args:
+        dataset_generator (Callable): A function that generates a dataset.
+        sampler_generator (Optional[Callable]): An optional function that generates a sampler for the dataset.
+        batch_sampler_generator (Optional[Callable]): An optional function that generates a batch sampler.
+        dataloader_generator (Optional[Callable]): An optional function that generates a DataLoader instance.
+    """
+
+    def __init__(
+        self,
+        dataset_generator: Callable[[], Dataset[T_co]],
+        sampler_generator: Callable[[Dataset[T_co]], Sampler[T_co]] | None = None,
+        batch_sampler_generator: Callable[[Sampler[Any] | Dataset[T_co]], Iterable[list[Any]]] | None = None,
+        dataloader_generator: Callable[[Any], Iterable[list[T_co]]] | None = None,
+    ) -> None:
+        # Assert that either sampler_generator or batch_sampler_generator is provided
+        assert sampler_generator or batch_sampler_generator, (
+            "either sampler_generator or batch_sampler_generator must be provided"
+        )
+
+        # Initialize instance variables
+        self.dataset_generator = dataset_generator
+        self.sampler_generator = sampler_generator
+        self.batch_sampler_generator = batch_sampler_generator
+        self.dataloader_generator = dataloader_generator or DataLoader
+
+    def __call__(self) -> Iterable[list[T_co]]:
+        """Generates a DataLoader.
+
+        Returns:
+            Iterable[List[T_co]]: An iterable of batches of data.
+        """
+        # Generate dataset using dataset_generator
+        dataset = self.dataset_generator()
+
+        # Generate sampler if sampler_generator is provided
+        sampler = self.sampler_generator(dataset) if self.sampler_generator else None
+
+        # Generate batch sampler if batch_sampler_generator is provided (None otherwise, to avoid a NameError below)
+        batch_sampler = self.batch_sampler_generator(sampler if sampler else dataset) if self.batch_sampler_generator else None
+        if batch_sampler is not None:
+            sampler = None  # mutually exclusive
+
+        # Generate DataLoader instance using dataloader_generator
+        dataloader = self.dataloader_generator(  # type: ignore[call-arg]
+            dataset, batch_size=1, shuffle=None, sampler=sampler, batch_sampler=batch_sampler
+        )
+        return dataloader
+
+
+class DataLoaderIterGenerator(Generic[T_co]):
+    """Generates and saves an iterator over DataLoaders.
+
+    Args:
+        dataloader_factory (DataLoaderFactory): An instance of DataLoaderFactory responsible for generating DataLoaders.
+        num_microbatch_prefetches (int, optional): The number of microbatches to prefetch. Defaults to -1.
+    """
+
+    def __init__(
+        self,
+        dataloader_factory: DataLoaderFactory[T_co],
+        num_microbatch_prefetches: int = -1,
+    ) -> None:
+        """Initializes the DataLoaderIterGenerator.
+
+        Args:
+            dataloader_factory (DataLoaderFactory): DataLoaders provider.
+            num_microbatch_prefetches (int, optional): The number of microbatches to prefetch.
+                Defaults to -1.
+        """
+        self.dataloader_factory = dataloader_factory
+        self.num_microbatch_prefetches = num_microbatch_prefetches
+
+    def __call__(self, path: Path) -> None:
+        """Generates and saves an iterator over DataLoaders.
+
+        Args:
+            path (Path): The path where the iterator will be saved.
+        """
+        # Log debug message indicating the start of the process
+        logger.debug("generate ...")
+
+        # Generate a DataLoader
+        dataloader = self.dataloader_factory()
+
+        # Create an iterator over the DataLoader
+        dataloader_iter = iter(dataloader)
+
+        # Access global variable _shutdown
+        global _shutdown
+
+        # Save the iterator using shm.save_iter
+        shm.save_iter(
+            dataloader_iter,
+            path=path,
+            max_items=self.num_microbatch_prefetches,
+            should_stop=lambda: _shutdown is not None and _shutdown.is_set(),
+        )
+
+        # Log debug message indicating the completion of the process
+        logger.debug("generate complete")
+
+
+# Sub-class multiprocessing.Process to make sure it's not started in daemon mode by the Pool
+class NoDaemonProcess(mp.Process):
+    @property
+    def daemon(self) -> bool:
+        return False
+
+    @daemon.setter
+    def daemon(self, value: bool) -> None:
+        pass
+
+
+class NoDaemonContext(type(mp.get_context())):  # type: ignore[misc]
+    Process = NoDaemonProcess
+
+
+# We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool
+# because the latter is only a wrapper function, not a proper class.
+class NoDaemonPool(mp.pool.Pool):  # type: ignore[unresolved-attribute]
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        kwargs["context"] = NoDaemonContext()
+        super().__init__(*args, **kwargs)
+
+
+class ShmDataLoader(Iterable[list[T_co]]):
+    """A :class:`DataLoader` that uses shared memory to efficiently manage and prefetch data batches.
+
+    Enables double-buffered micro-batch processing and fetching that overlaps with model
+    forward/backward passes, minimizing dataloading overhead.
+
+    Args:
+        seed (int): Random seed for reproducibility. Use ${seed} at the top level in config.yaml.
+        dataloader_factory (DataLoaderFactory[T_co]): Factory for creating DataLoaders.
+        num_microbatch_prefetches (int, optional): Number of microbatches to prefetch.
+            Defaults to -1.
+        dp_rank (int, optional): Rank of the current process. Defaults to 0.
+        profiler (Optional[Profiler], optional): Profiler for profiling.
+            Defaults to None.
+        device (Optional[torch.device]): Device to move the microbatches to in the background.
+        multiprocessing (bool, optional): Whether to instantiate the DataLoader in a separate process.
+            Defaults to True to relieve pressure from the training process; use False to debug and profile.
+    """
+
+    def __init__(
+        self,
+        seed: int,
+        dataloader_factory: DataLoaderFactory[T_co],
+        num_microbatch_prefetches: int = -1,
+        dp_rank: int = 0,
+        profiler: Profiler | None = None,
+        device: torch.device | None = None,
+        multiprocessing: bool = True,
+    ) -> None:
+        self.microbatches: Iterator[list[T_co]] | None = None
+        self.path: Path | None = None
+        self.device: torch.device | None = device
+        self.cleanup: set[Path] = set()
+        self.shutdown = mp.Event()
+        self.dataloader_factory = dataloader_factory
+        self.dataloader_iter_generator = DataLoaderIterGenerator(
+            dataloader_factory,
+            num_microbatch_prefetches,
+        )
+        self.data_jobs: deque[tuple[Path, mp.pool.AsyncResult[Any]]] = deque()  # type: ignore[unresolved-attribute]
+
+        # Register teardown handlers and initialize the pools used for prefetching
+        signal.signal(signal.SIGTERM, self.teardown)  # terminate signal
+        signal.signal(signal.SIGINT, self.teardown)  # keyboard interrupt
+        atexit.register(self.teardown)
+        self.writing_pool: NoDaemonPool | ThreadPool = (
+            NoDaemonPool(1, initializer=initialize, initargs=(seed, dp_rank, self.shutdown, profiler))
+            if multiprocessing
+            else ThreadPool(
+                max_workers=1,
+                thread_name_prefix="ShmDataWriter",
+                initializer=initialize,
+                initargs=(seed, dp_rank, self.shutdown, profiler),
+            )
+        )
+        self.reading_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ShmDataReader")
+        self.next_batch: Future[Any] | None = None
+
+    def __iter__(self) -> Iterator[list[T_co]]:
+        # all reading needs to go through the same TPE to avoid contention
+        if not self.reading_pool._shutdown:
+            self.reading_pool.submit(self._cleanup).result()
+        return self
+
+    def __next__(self) -> list[T_co]:
+        if not self.next_batch:
+            self.load_batch()
+        assert self.next_batch
+        batch = self.next_batch.result()
+        self.load_batch()  # double-buffering
+        return batch
+
+    def load_batch(self) -> None:
+        if not self.reading_pool._shutdown:
+            self.next_batch = self.reading_pool.submit(self.load_batch_sync)
+
+    def load_batch_sync(self) -> list[T_co]:
+        while True:
+            if self.microbatches:
+                # Fetch the next microbatch if available
+                try:
+                    microbatch = next(self.microbatches)
+                    if self.device:
+                        microbatch = move_data_to_device(microbatch, self.device)
+                    return microbatch
+                # StopIteration means the current iterator is exhausted; clear state and propagate
+                except StopIteration:
+                    if self.path:
+                        self.cleanup.remove(self.path)
+                        self.path = None
+                    self.microbatches = None
+                    raise
+
+            if len(self.data_jobs) == 0:
+                logger.debug("load iter scheduling ...")
+                self.prefetch()
+            path, data_job = self.data_jobs.popleft()
+            logger.debug(f"load iter to {path} ...")
+
+            def wait_callback() -> None:
+                if not data_job.ready():  # noqa: B023
+                    # Job is still running
+                    return None
+                else:
+                    # Job is finished, raise exception if job failed.
+                    data_job.get()  # noqa: B023
+                    # successful() returns whether the call completed without raising an exception.
+                    assert data_job.successful()  # noqa: B023
+
+            self.microbatches = shm.load_iter(path, wait_callback=wait_callback)
+            self.cleanup.add(path)
+            self.path = path
+
+    def set_device(self, device: torch.device | None) -> None:
+        self.device = device
+
+    def prefetch(self) -> None:
+        path = shm.generate_path()
+        data_job = self.writing_pool.apply_async(self.dataloader_iter_generator, (path,))
+        self.data_jobs.append((path, data_job))
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug(f"queued {path}, {len(self.data_jobs)}")

+    def _cleanup(self, stop_pool: bool = False) -> None:
+        self.microbatches = None
+        if self.next_batch:
+            try:
+                # if called from teardown/on_exception/__del__ it will wait for the pending work to finish
+                # if called from __iter__ it's already done
+                self.next_batch.result()
+            except Exception:
+                pass
+        self.next_batch = None
+        self.shutdown.set()  # signal running tasks to stop
+        if stop_pool:
+            self.writing_pool.close()  # no new tasks can run
+            self.reading_pool.shutdown()
+        for path, result in self.data_jobs:
+            logger.debug(f"waiting for {path} to stop ...")
+            self.cleanup.add(path)
+            try:
+                result.wait(timeout=DEFAULT_SHUTDOWN_TIMEOUT)
+            except Exception:
+                pass
+            logger.debug(f"{path} stopped ...")
+        self.data_jobs.clear()
+        self.shutdown.clear()
+        for path in self.cleanup:
+            logger.debug(f"removing {path} ...")
+            shutil.rmtree(path, ignore_errors=True)
+
+    # called when fit/validate/predict/test is complete
+    def teardown(self, *args: Any) -> None:
+        logger.debug("teardown ...")
+        # all reading needs to go through the same TPE to avoid contention
+        self._cleanup(stop_pool=True)
+        logger.debug("teardown complete")
+
+    # will be used once https://github.com/Lightning-AI/pytorch-lightning/pull/19601 is in effect
+    # once the below callback is operational we no longer need the __del__ override
+    def on_exception(self, exception: BaseException) -> None:
+        self.teardown()
+
+    # called when the iterator reference pointing to this object goes out of scope,
+    # e.g. when an exception happens
+    def __del__(self) -> None:
+        self.teardown()
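Usage note: below is a minimal sketch of wiring DataLoaderFactory and ShmDataLoader together, mirroring the constructor signatures above; ToyDataset and all literal values are hypothetical.

    import torch
    from torch.utils.data import Dataset, RandomSampler

    from fkat.data.shm import DataLoaderFactory, ShmDataLoader

    class ToyDataset(Dataset[torch.Tensor]):
        def __len__(self) -> int:
            return 128

        def __getitem__(self, idx: int) -> torch.Tensor:
            return torch.full((4,), float(idx))

    factory = DataLoaderFactory(
        dataset_generator=ToyDataset,     # invoked lazily in the writer worker
        sampler_generator=RandomSampler,  # satisfies the sampler-or-batch-sampler assert
    )
    loader = ShmDataLoader(
        seed=42,
        dataloader_factory=factory,
        num_microbatch_prefetches=8,
        multiprocessing=False,  # per the docstring: easier to debug and profile
    )
    for microbatch in loader:  # prefetching overlaps with consumption
        pass
    loader.teardown()  # stop the pools and remove shared-memory artifacts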
fkat/predict.py
ADDED
@@ -0,0 +1,32 @@
+#!/usr/bin/env python
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+The ``fkat.predict`` entrypoint processes the provided config,
+instantiates the ``trainer``, ``model`` and ``data`` sections and calls ``trainer.predict()``.
+"""
+
+import hydra
+import lightning as L
+from omegaconf import DictConfig
+
+from fkat import initialize, run_main
+
+
+@hydra.main(version_base="1.3")
+def main(cfg: DictConfig) -> None:
+    s = initialize(cfg)
+    kwargs = {
+        "ckpt_path": s.ckpt_path,
+        "return_predictions": s.return_predictions,
+    }
+    if isinstance(s.data, L.LightningDataModule):
+        kwargs["datamodule"] = s.data
+    else:
+        kwargs["dataloaders"] = s.data.predict_dataloader() if s.data else None  # Trainer.predict takes ``dataloaders``
+    s.trainer.predict(s.model, **kwargs)
+
+
+if __name__ == "__main__":
+    run_main(main)
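The config this entrypoint consumes is not shown in the diff; below is a hedged sketch of the shape implied by the attributes read above (``trainer``, ``model``, ``data``, ``ckpt_path``, ``return_predictions``). The use of Hydra ``_target_`` instantiation and the section contents are assumptions, not confirmed by this package.

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "trainer": {"_target_": "lightning.Trainer", "devices": 1},
            "model": {"_target_": "my_project.MyLightningModule"},  # hypothetical
            "data": {"_target_": "my_project.MyDataModule"},  # hypothetical
            "ckpt_path": "checkpoints/last.ckpt",  # hypothetical
            "return_predictions": True,
        }
    )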
fkat/py.typed
ADDED
File without changes
fkat/pytorch/actions/__init__.py
ADDED
@@ -0,0 +1,11 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from typing import Protocol, Any
+
+
+class LightningAction(Protocol):
+    """A generic action to be executed given the context provided via key-value arguments."""
+
+    def perform(self, **kwargs: Any) -> Any:
+        """Performs the action with the context provided via key-value arguments."""
+        ...
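Because LightningAction is a Protocol, conforming actions need no inheritance; a minimal sketch (LogContext is hypothetical):

    from typing import Any

    from fkat.pytorch.actions import LightningAction

    class LogContext:
        def perform(self, **kwargs: Any) -> Any:
            # print the context this action was invoked with
            for key, value in kwargs.items():
                print(f"{key}={value}")

    action: LightningAction = LogContext()  # structural typing: no subclassing needed
    action.perform(reason="xid_error", instance_id="i-0123456789abcdef0")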
fkat/pytorch/actions/aws/batch.py
ADDED
@@ -0,0 +1,29 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Any
+from typing_extensions import override
+
+from fkat.utils import boto3
+from fkat.pytorch.actions import LightningAction
+
+if TYPE_CHECKING:
+    from types_boto3_batch import BatchClient
+
+
+class TerminateJob(LightningAction):
+    """This action calls Batch.TerminateJob."""
+
+    def __init__(self, job_id: str | None = None) -> None:
+        self.job_id = job_id
+
+    @override
+    def perform(self, **kwargs: Any) -> Any:
+        """Calls Batch.TerminateJob."""
+        job_id = self.job_id or os.getenv("AWS_BATCH_JOB_ID")
+        if job_id:
+            reason = ",".join(f"{k}={v}" for k, v in kwargs.items() if isinstance(v, str))
+            batch: BatchClient = boto3.session().client("batch")  # type: ignore[assignment]
+            batch.terminate_job(jobId=job_id, reason=reason)
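A hedged usage sketch: inside an AWS Batch job the service sets AWS_BATCH_JOB_ID, so no explicit job_id is needed; string-valued kwargs are joined into the termination reason.

    from fkat.pytorch.actions.aws.batch import TerminateJob

    # results in reason="cause=preflight health check failed"
    TerminateJob().perform(cause="preflight health check failed")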
fkat/pytorch/actions/aws/ec2.py
ADDED
@@ -0,0 +1,61 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+from typing_extensions import override
+
+from fkat.utils import boto3, assert_not_none
+from fkat.utils.aws import imds
+from fkat.pytorch.loggers import LightningLogger, CompositeLogger
+from fkat.pytorch.actions import LightningAction
+
+if TYPE_CHECKING:
+    from types_boto3_ec2 import EC2Client
+
+
+class TerminateInstances(LightningAction):
+    """This action calls EC2.TerminateInstances."""
+
+    def __init__(self, instance_ids: list[str] | None = None) -> None:
+        self.instance_ids = instance_ids or []
+
+    @override
+    def perform(self, **kwargs: Any) -> Any:
+        """Calls EC2.TerminateInstances with the provided ``instance_ids`` or the current node's instance_id."""
+        instance_ids = self.instance_ids or kwargs.get("instance_ids") or [imds.instance_metadata().instance_id]
+        ec2: EC2Client = boto3.session().client("ec2")  # type: ignore[assignment]
+        ec2.terminate_instances(InstanceIds=instance_ids)
+
+
+class RebootInstances(LightningAction):
+    """This action calls EC2.RebootInstances."""
+
+    def __init__(self, instance_ids: list[str] | None = None) -> None:
+        self.instance_ids = instance_ids or []
+
+    @override
+    def perform(self, **kwargs: Any) -> Any:
+        """Calls EC2.RebootInstances with the provided ``instance_ids`` or the current node's instance_id."""
+        instance_ids = self.instance_ids or kwargs.get("instance_ids") or [imds.instance_metadata().instance_id]
+        ec2: EC2Client = boto3.session().client("ec2")  # type: ignore[assignment]
+        ec2.reboot_instances(InstanceIds=instance_ids)
+
+
+class LogInstanceTags(LightningAction):
+    """This action logs tags suffixed with the EC2 instance-id."""
+
+    def __init__(self, instance_id: str | None = None, tags: list[str] | None = None) -> None:
+        self.instance_id = instance_id
+        self.tags = tags or []
+        self.logger: LightningLogger | None = None
+
+    @override
+    def perform(self, **kwargs: Any) -> Any:
+        """Calls LightningLogger.log_tag with all string-valued keys; requires ``trainer`` to be provided."""
+        instance_id = self.instance_id or kwargs.get("instance_id")
+        for t in self.tags:
+            if v := kwargs.get(t):
+                instance_id = instance_id or imds.instance_metadata().instance_id
+                self.logger = self.logger or CompositeLogger(assert_not_none(kwargs.get("trainer"), "trainer"))
+                self.logger.log_tag(f"{instance_id}/{t}/{v}", "True")
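A hedged sketch of the EC2 actions above; the instance ID is hypothetical. With no IDs configured or passed, both actions fall back to the current node's instance ID resolved via IMDS.

    from fkat.pytorch.actions.aws.ec2 import RebootInstances, TerminateInstances

    RebootInstances(instance_ids=["i-0123456789abcdef0"]).perform()
    TerminateInstances().perform()  # terminates the node this process runs on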
fkat/pytorch/callbacks/cuda/__init__.py
ADDED
@@ -0,0 +1,16 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from .xid import Xid
+from .nsys import Nsys
+from .nvtx import Nvtx
+from .memory import MemoryObserver
+from .cache import EmptyCache
+
+
+__all__ = [
+    "Xid",
+    "Nsys",
+    "Nvtx",
+    "MemoryObserver",
+    "EmptyCache",
+]
fkat/pytorch/callbacks/cuda/cache.py
ADDED
@@ -0,0 +1,115 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, TYPE_CHECKING
+from typing_extensions import override
+
+import torch
+import lightning as L
+
+if TYPE_CHECKING:
+    from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+from fkat.pytorch.schedule import (
+    Schedule,
+    Never,
+)
+
+
+class EmptyCache(L.Callback):
+    def __init__(self, schedule: Schedule | None = None) -> None:
+        """
+        PyTorch Lightning callback to trigger ``torch.cuda.empty_cache()``.
+
+        This callback allows fine-grained control over CUDA's Caching Allocator during training,
+        validation, testing, and prediction.
+
+        Args:
+            schedule (Optional[Schedule]): When to invoke ``torch.cuda.empty_cache()``; defaults to :class:`Never`.
+
+        Example:
+            >>> trainer = Trainer(callbacks=[EmptyCache()])
+        """
+        self.schedule = schedule or Never()
+
+    @override
+    def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        """Perform empty cache after training epoch."""
+        self.maybe_empty_cache(trainer, "train")
+
+    @override
+    def on_train_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: "STEP_OUTPUT",
+        batch: Any,
+        batch_idx: int,
+    ) -> None:
+        """Perform empty cache after training batch if needed."""
+        self.maybe_empty_cache(trainer, "train", batch_idx)
+
+    @override
+    def on_validation_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        """Perform empty cache after validation epoch."""
+        self.maybe_empty_cache(trainer, "validation")
+
+    @override
+    def on_validation_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: "STEP_OUTPUT",
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Perform empty cache after validation batch if needed."""
+        self.maybe_empty_cache(trainer, "validation", batch_idx)
+
+    @override
+    def on_test_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        """Perform empty cache after test epoch."""
+        self.maybe_empty_cache(trainer, "test")
+
+    @override
+    def on_predict_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Perform empty cache after prediction batch if needed."""
+        self.maybe_empty_cache(trainer, "predict", batch_idx)
+
+    @override
+    def on_predict_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        """Perform empty cache after predict epoch."""
+        self.maybe_empty_cache(trainer, "predict")
+
+    @override
+    def on_test_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: "STEP_OUTPUT",
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Perform empty cache after test batch if needed."""
+        self.maybe_empty_cache(trainer, "test", batch_idx)
+
+    def maybe_empty_cache(self, trainer: "L.Trainer", stage: str, batch_idx: int | None = None) -> None:
+        """
+        Perform empty cache if conditions are met.
+
+        Args:
+            trainer (L.Trainer): Lightning Trainer
+            stage (str): training stage
+            batch_idx (int | None): Current batch index if available
+        """
+        if self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+            torch.cuda.empty_cache()