fkat 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/data/sharded.py
ADDED
|
@@ -0,0 +1,718 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
import atexit
|
|
4
|
+
import fnmatch
|
|
5
|
+
import logging
|
|
6
|
+
import multiprocessing as mp
|
|
7
|
+
from multiprocessing.synchronize import Event
|
|
8
|
+
from concurrent.futures import ThreadPoolExecutor, Future
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
import random
|
|
12
|
+
import shutil
|
|
13
|
+
import signal
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any, Generic, TypeVar
|
|
16
|
+
from collections import deque
|
|
17
|
+
from collections.abc import Callable, Iterable, Iterator
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
from pyarrow.fs import FileSelector, FileSystem
|
|
21
|
+
from lightning.pytorch.core.hooks import CheckpointHooks
|
|
22
|
+
from lightning.pytorch.profilers import Profiler
|
|
23
|
+
from lightning.pytorch.utilities import move_data_to_device
|
|
24
|
+
import torch
|
|
25
|
+
import torch.distributed as dist
|
|
26
|
+
from torch.utils.data import DataLoader, Dataset, Sampler
|
|
27
|
+
from typing_extensions import override
|
|
28
|
+
|
|
29
|
+
from fkat.data import PersistStates, RestoreStates
|
|
30
|
+
from fkat.utils import shm
|
|
31
|
+
from fkat.utils.pool import ThreadPool, NoDaemonPool
|
|
32
|
+
from fkat.utils.profiler import profile_until_exit
|
|
33
|
+
|
|
34
|
+
# Module-level logger.
logger = logging.getLogger(__name__)

# Shutdown event shared with the shard writer workers. ``initialize`` stores the event it receives here, and
# ``DataLoaderIterGenerator.__call__`` polls it through its ``should_stop`` callback. Stays ``None`` in any
# process/thread context where ``initialize`` has not run.
_shutdown: Event | None = None

# NOTE(review): presumably seconds; the consumer of this constant is not visible in this section.
DEFAULT_SHUTDOWN_TIMEOUT = 60 # time for shard workers to gracefully shutdown
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def initialize(seed: int, dp_rank: int, shutdown: Event, profiler: Profiler | None = None) -> None:
    """Initializer executed by every worker of the shard-writing pool.

    Steps:
        1. Ignore SIGINT. Signal handlers are inherited, and child termination is coordinated through
           ``shutdown`` instead. ``signal.signal`` raises ``ValueError`` off the main thread; that case is
           ignored.
        2. Publish ``shutdown`` through the module-level ``_shutdown``. Tasks running later in this worker can
           then poll it without receiving it as an argument.
        3. If ``profiler`` is given, profile this worker until it exits.
        4. Seed ``numpy``, ``random`` and ``torch`` with ``seed + dp_rank``. Every process sharing a DP rank
           (e.g. the TP ranks of one DP group) then loads and iterates the same data in the same order.

    NOTE(review): these RNGs are process-global. When the pool is a thread pool, step 4 presumably also reseeds
    the training process itself. Confirm this is intended.

    Args:
        seed: base random seed.
        dp_rank: data-parallel rank, added to ``seed``.
        shutdown: event used to tell workers to stop.
        profiler: optional Lightning profiler.
    """
    # Signal handlers are inherited; the shutdown flag is what terminates children gracefully.
    try:
        signal.signal(signal.SIGINT, signal.SIG_IGN)
    except ValueError:
        pass  # signal.signal() only works from the main thread

    global _shutdown
    _shutdown = shutdown

    worker_pid = os.getpid()
    logger.debug(f"shard worker init {worker_pid} ...")
    if profiler:
        profile_until_exit(
            profiler,
            action=f"ShardedDataLoader[worker_pid={worker_pid}]",
            filename_suffix=f"_{worker_pid}",
        )

    rank_seed = seed + dp_rank
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(rank_seed)
    logger.info(f"RNG seed is set with {rank_seed}")
    logger.debug(f"shard worker init {worker_pid} complete")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# Covariant type variable for the items produced by the per-shard datasets / DataLoaders.
T_co = TypeVar("T_co", covariant=True)
# A shard is either one string (e.g. a file URI) or a list of strings (several URIs merged via num_uri_merge).
Shard = str | list[str]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class ShardSampler(Iterable[Shard], PersistStates, RestoreStates):
    """Base shard sampler: iterates over ``self.shards`` and checkpoints per-rank positions.

    ``state_dict`` and ``load_state_dict`` require an initialized ``torch.distributed`` process group.
    On save, every rank's ``index`` is all-gathered. On load, each rank picks its own entry back out.
    """

    def __init__(self) -> None:
        self.shards: list[Shard] = []
        # Position of the last shard handed out by a subclass ``__next__``; -1 means none yet.
        self.index: int = -1
        # Per-rank indices restored from a checkpoint by ``load_state_dict``.
        self.all_rank_indices: list[int] = []

    def __iter__(self) -> Iterator[Shard]:
        # Iterate over shards and yield them one by one
        yield from self.shards

    def state_dict(self) -> dict[str, Any]:
        """Gather every rank's current shard index into a dictionary.

        This is a collective (``all_gather``), so every rank must call it.

        Returns:
            Dict[str, Any]: ``{"all_rank_indices": [tensor(index_of_rank_0), tensor(index_of_rank_1), ...]}``.

        Raises:
            RuntimeError: if torch.distributed is not initialized.
        """
        if not dist.is_initialized():  # type: ignore[possibly-unbound-attribute]
            raise RuntimeError("torch.distributed is not initialized.")

        world_size = dist.get_world_size()  # type: ignore[possibly-unbound-attribute]

        # gloo communicates CPU tensors; any other backend (presumably NCCL) is given CUDA tensors.
        device = "cpu" if dist.get_backend() == "gloo" else "cuda"  # type: ignore[possibly-unbound-attribute]

        local_index = torch.tensor(self.index, dtype=torch.int, device=device)

        # One slot per rank; e.g. for 4 ranks the result could be
        # [torch.tensor(1), torch.tensor(1), torch.tensor(5), torch.tensor(6)]
        all_rank_indices = [torch.zeros_like(local_index) for _ in range(world_size)]
        dist.all_gather(all_rank_indices, local_index)  # type: ignore[possibly-unbound-attribute]

        sampler_states = {"all_rank_indices": all_rank_indices}
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"DataModule sampler states are {sampler_states}")

        return sampler_states

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore this rank's shard index from a state dict produced by :meth:`state_dict`.

        Args:
            state_dict (Dict[str, Any]): serialized sampler states; ``state_dict["all_rank_indices"]`` holds
                one index per rank, as 0-dim tensors or plain ints.

        Raises:
            RuntimeError: if torch.distributed is not initialized.
            ValueError: if the checkpoint was written with a different world size.
        """
        if not dist.is_initialized():  # type: ignore[possibly-unbound-attribute]
            raise RuntimeError("torch.distributed is not initialized.")

        # int() accepts both 0-dim tensors (what state_dict stores) and plain ints,
        # e.g. [torch.tensor(1), torch.tensor(1), torch.tensor(5), torch.tensor(6)] -> [1, 1, 5, 6]
        self.all_rank_indices = [int(index) for index in state_dict["all_rank_indices"]]
        logger.info(f"All rank indices are {self.all_rank_indices}")

        # Resuming with a different number of ranks would map saved indices onto the wrong ranks.
        world_size = dist.get_world_size()  # type: ignore[possibly-unbound-attribute]
        num_saved_ranks = len(self.all_rank_indices)
        if num_saved_ranks != world_size:
            raise ValueError(
                f"Inconsistent distributed training configuration: the loaded state_dict contains "
                f"checkpoint data for {num_saved_ranks} ranks, but the current world size is {world_size}. "
                "Ensure that you are resuming from a checkpoint with the same distributed setup "
                "(number of nodes and devices)."
            )

        # Each rank restores its own (global-rank indexed) entry.
        rank = dist.get_rank()  # type: ignore[possibly-unbound-attribute]
        # Subclass ``__next__`` increments the index and then indexes ``self.shards``. A saved value below -1
        # (e.g. an early checkpoint adjusted for prefetching) would become a negative list index and wrap
        # around to shards at the end of the list, so clamp to -1 ("nothing consumed yet").
        self.index = max(self.all_rank_indices[rank], -1)
        logger.info(f"Rank {rank} has {self.index}")

        # Keep the restored index on the next __iter__ instead of resetting it.
        self.reset = False
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class DataLoaderFactory(Generic[T_co]):
    """Factory that builds the DataLoader for a single shard.

    Args:
        dataset_generator (Callable): A function that generates a dataset given a shard.
        sampler_generator (Optional[Callable]): An optional function that generates a sampler for the dataset.
        batch_sampler_generator (Optional[Callable]): An optional function that generates a batch sampler.
            It receives the sampler when ``sampler_generator`` is provided, otherwise the dataset. When set,
            the DataLoader receives only ``batch_sampler`` (with ``sampler=None`` and ``batch_size=1``).
            :class:`torch.utils.data.DataLoader` treats ``batch_sampler`` as mutually exclusive with
            ``sampler`` and a non-default ``batch_size``.
        dataloader_generator (Optional[Callable]): An optional function that generates a DataLoader instance.
            Defaults to :class:`torch.utils.data.DataLoader`.
        batch_size (int): The batch size. Defaults to 1. Not used when ``batch_sampler_generator`` is provided.
    """

    def __init__(
        self,
        dataset_generator: Callable[[Shard], Dataset[T_co]],
        sampler_generator: Callable[[Dataset[T_co]], Sampler[T_co]] | None = None,
        batch_sampler_generator: Callable[[Sampler[Any] | Dataset[T_co]], Iterable[list[Any]]] | None = None,
        dataloader_generator: Callable[[Any], Iterable[list[T_co]]] | None = None,
        batch_size: int = 1,
    ) -> None:
        assert batch_size or sampler_generator or batch_sampler_generator, (
            "either batch_size, sampler_generator or batch_sampler_generator must be provided"
        )
        self.dataset_generator = dataset_generator
        self.sampler_generator = sampler_generator
        self.batch_sampler_generator = batch_sampler_generator
        self.dataloader_generator = dataloader_generator or DataLoader
        self.batch_size = batch_size

    def __call__(self, shard: Shard) -> Iterable[list[T_co]]:
        """Generate a DataLoader for the given shard.

        Args:
            shard (Shard): Represents a subset of the dataset.

        Returns:
            Iterable[List[T_co]]: An iterable of batches of data.
        """
        dataset = self.dataset_generator(shard)
        sampler = self.sampler_generator(dataset) if self.sampler_generator is not None else None

        batch_sampler = None
        batch_size = self.batch_size
        if self.batch_sampler_generator is not None:
            # Compare against None rather than using truthiness: a sized Sampler with no elements is falsy
            # and must not be silently replaced by the dataset.
            batch_sampler = self.batch_sampler_generator(sampler if sampler is not None else dataset)
            # DataLoader raises if sampler or batch_size != 1 is combined with batch_sampler.
            sampler = None
            batch_size = 1

        return self.dataloader_generator(  # type: ignore[call-arg]
            dataset, batch_size=batch_size, shuffle=None, sampler=sampler, batch_sampler=batch_sampler
        )
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class DataLoaderIterGenerator(Generic[T_co]):
    """Builds a shard's DataLoader and saves its iterator to a path with ``shm.save_iter``.

    Args:
        dataloader_factory (DataLoaderFactory): An instance of DataLoaderFactory responsible for generating
            DataLoaders.
        num_microbatch_prefetches (int, optional): The number of microbatches to prefetch. It is forwarded to
            ``shm.save_iter`` as ``max_items``. Defaults to -1 (presumably "no limit"; see ``shm.save_iter``).
    """

    def __init__(
        self,
        dataloader_factory: DataLoaderFactory[T_co],
        num_microbatch_prefetches: int = -1,
    ) -> None:
        """Initialize the DataLoaderIterGenerator.

        Args:
            dataloader_factory (DataLoaderFactory): An instance of DataLoaderFactory responsible
                for generating DataLoaders.
            num_microbatch_prefetches (int, optional): The number of microbatches to prefetch.
                Defaults to -1.
        """
        self.dataloader_factory = dataloader_factory
        self.num_microbatch_prefetches = num_microbatch_prefetches

    def __call__(self, shard: Shard, path: Path) -> None:
        """Generate the DataLoader for ``shard`` and save its iterator to ``path``.

        Args:
            shard (Shard): A subset of the dataset.
            path (Path): The path where the iterator will be saved.
        """
        logger.debug("shard generate ...")

        dataloader_iter = iter(self.dataloader_factory(shard))

        # ``_shutdown`` is only read (never assigned) here, so no ``global`` statement is needed. The lambda
        # looks the module global up at call time, so it sees the event installed by ``initialize`` in this
        # worker. Presumably ``shm.save_iter`` polls ``should_stop`` to end writing early.
        shm.save_iter(
            dataloader_iter,
            path=path,
            max_items=self.num_microbatch_prefetches,
            should_stop=lambda: _shutdown is not None and _shutdown.is_set(),
        )

        logger.debug("shard generate complete")
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class DistributedDataParallelShardSampler(ShardSampler, CheckpointHooks):
    """Distributed Data Parallel Shard Sampler.

    Splits the shards of ``sampler`` into ``dp_size`` contiguous partitions. This process iterates over the
    partition that belongs to ``dp_rank``.

    * ``num_shards`` divisible by ``dp_size``: every rank gets ``num_shards // dp_size`` shards.
    * Otherwise, with ``drop_last=True``: the trailing ``num_shards % dp_size`` shards are dropped, so every
      rank gets the same number of shards.
    * Otherwise, with ``drop_last=False``: the first ``num_shards % dp_size`` ranks get one extra shard. Every
      rank still receives at least one shard.

    Args:
        sampler (ShardSampler): An instance of ShardSampler containing shards.
        dp_size (int): Total number of data-parallel processes.
        dp_rank (int): Data-parallel rank of the current process, in ``[0, dp_size)``.
        drop_last (bool, optional): Whether to drop the last shards if the number of shards can't be divided
            by dp_size. Recommended to set this to False for evaluation and prediction tasks.
            Defaults to ``True``.
        num_uri_merge (int, optional): How many URIs to merge into one shard. Defaults to ``0`` (no merging).
            ``-1`` (or any negative value) merges all URIs of this rank into a single shard.

    Raises:
        ValueError: if ``sampler`` yields fewer shards than ``dp_size``.
    """

    def __init__(
        self,
        sampler: ShardSampler,
        dp_size: int,
        dp_rank: int,
        drop_last: bool = True,
        num_uri_merge: int = 0,
    ) -> None:
        super().__init__()
        shards: list[str | list[str]] = list(sampler)
        num_shards = len(shards)

        if num_shards < dp_size:
            raise ValueError(
                f"Only datasets with num_shards >= dp_size are supported, "
                f"got num_shards={num_shards}, dp_size={dp_size}"
            )

        per_rank, rem = divmod(num_shards, dp_size)
        if rem > 0 and drop_last:
            logger.warning(f"Truncating not even distribution of {num_shards} shards across dp_size={dp_size}")
            start = dp_rank * per_rank
            stop = start + per_rank
        else:
            # Balanced contiguous split (like numpy.array_split): the first ``rem`` ranks take one extra shard.
            # A fixed ceil(num_shards / dp_size) chunk can leave trailing ranks without any shard,
            # e.g. 6 shards over 4 ranks -> 2, 2, 2, 0.
            start = dp_rank * per_rank + min(dp_rank, rem)
            stop = start + per_rank + (1 if dp_rank < rem else 0)
        shards = shards[start:stop]

        self.shards = self._merge_uris(shards, num_uri_merge) if num_uri_merge != 0 else shards

        # num_shards >= dp_size guarantees every valid dp_rank at least one shard.
        assert self.shards

        self.reset = True
        self.index = -1

    @staticmethod
    def _merge_uris(shards: list[Shard], num_uri_merge: int) -> list[Shard]:
        """Regroup the URIs of ``shards`` into lists of ``num_uri_merge`` (negative: a single list).

        A shard that is already a list is added whole. A group is closed as soon as it holds at least
        ``num_uri_merge`` URIs, so a group may exceed the target by the tail of such a list. An exact-size
        check would never close the group again after overshooting.
        """
        merged: list[Shard] = []
        group: list[str] = []
        for shard in shards:
            if isinstance(shard, str):
                group.append(shard)
            else:
                group.extend(shard)
            if num_uri_merge > 0 and len(group) >= num_uri_merge:
                merged.append(group)
                group = []
        if group:
            merged.append(group)
        return merged

    @override
    def __iter__(self) -> Iterator[Shard]:
        """Returns an iterator over the shards.

        Restarts from the first shard unless ``reset`` was cleared (e.g. by ``load_state_dict``).

        Returns:
            Iterator[Shard]: Iterator over the shards.
        """
        if self.reset:
            self.index = -1
        self.reset = True
        return self

    def __next__(self) -> Shard:
        """Returns the next shard.

        Returns:
            Shard: Next shard.
        """
        self.index += 1
        if self.index >= len(self.shards):
            raise StopIteration
        return self.shards[self.index]
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
class ShuffledShardSampler(ShardSampler, CheckpointHooks):
    """Sampler that yields the shards of another sampler in random order.

    A new order is drawn with the process-global ``random`` module whenever a fresh iteration starts.

    NOTE(review): the permutation in ``self.indices`` is not part of the checkpoint. ``load_state_dict`` sets
    ``reset = False``, so the next ``__iter__`` skips the shuffle. The restored ``index`` is then applied to
    whatever order ``self.indices`` holds at that moment (the identity order in a freshly constructed sampler),
    not to the permutation active when the checkpoint was taken.

    Args:
        sampler (ShardSampler): An instance of ShardSampler containing shards.
    """

    def __init__(self, sampler: ShardSampler) -> None:
        # Initialize base attributes (e.g. all_rank_indices) like the other samplers do.
        super().__init__()

        # Materialize the wrapped sampler's shards; an empty source is a configuration error.
        self.shards = list(sampler)
        assert self.shards

        self.reset = True
        self.index = -1

        # Permutation of shard positions, reshuffled in place at the start of every fresh iteration.
        self.indices = list(range(len(self.shards)))

    @override
    def __iter__(self) -> Iterator[Shard]:
        """Returns an iterator over the shards.

        Reshuffles and restarts unless ``reset`` was cleared (e.g. by ``load_state_dict``).

        Returns:
            Iterator[Shard]: Iterator over the shards.
        """
        if self.reset:
            self.index = -1
            random.shuffle(self.indices)
        self.reset = True
        return self

    def __next__(self) -> Shard:
        """Returns the next shard.

        Returns:
            Shard: Next shard.
        """
        self.index += 1
        if self.index >= len(self.indices):
            raise StopIteration
        return self.shards[self.indices[self.index]]
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
class FsShardSampler(ShardSampler, CheckpointHooks):
    """Sampler over the files found under a file system URI.

    Only regular files become shards; directory entries returned by the listing are skipped. File URIs are
    sorted, so every rank builds the same shard list regardless of the order in which the file system
    returns entries.

    Args:
        uri (str): The URI specifying the file system path (anything ``pyarrow.fs.FileSystem.from_uri`` accepts).
        glob (Optional[str], optional): An ``fnmatch`` pattern used to filter files. It is matched against each
            file's full URI (``"<fs.type_name>://<path>"``). Defaults to None (keep all files).
        recursive (Optional[bool], optional): Whether to recursively search for files. Defaults to True.
        num_uri_merge (int, optional): How many URIs to merge into one shard. Defaults to ``0`` (no merging).
            ``-1`` merges all URIs into a single shard.
    """

    def __init__(self, uri: str, glob: str | None = None, recursive: bool = True, num_uri_merge: int = 0) -> None:
        # Initialize base attributes (e.g. all_rank_indices) like the other samplers do.
        super().__init__()

        # from_uri is a static method, but pyarrow-stubs says it's an instance one
        fs: FileSystem
        root: str
        fs, root = FileSystem.from_uri(uri)

        selector = FileSelector(root, recursive=recursive)

        uris: list[str] = []
        for info in fs.get_file_info(selector):
            if not info.is_file:
                continue  # directories and other non-file entries cannot be read as data shards
            file_uri = f"{fs.type_name}://{info.path}"
            if not glob or fnmatch.fnmatch(file_uri, glob):
                uris.append(file_uri)
        # Deterministic order: ranks slice this list by position (see DistributedDataParallelShardSampler).
        uris.sort()

        if num_uri_merge != 0:
            merged: list[list[str]] = []
            group: list[str] = []
            for file_uri in uris:
                group.append(file_uri)
                # A negative num_uri_merge never matches, so all URIs end up in a single group.
                if len(group) == num_uri_merge:
                    merged.append(group)
                    group = []
            if group:
                merged.append(group)
            self.shards = merged  # type: ignore[assignment]
        else:
            self.shards = uris  # type: ignore[assignment]

        assert self.shards, f"no files found under {uri!r}"

        self.reset = True
        self.index = -1

    @override
    def __iter__(self) -> Iterator[Shard]:
        """Returns an iterator over the shards.

        Restarts from the first shard unless ``reset`` was cleared (e.g. by ``load_state_dict``).

        Returns:
            Iterator[Shard]: Iterator over the shards.
        """
        if self.reset:
            self.index = -1
        self.reset = True
        return self

    def __next__(self) -> Shard:
        """Returns the next shard.

        Returns:
            Shard: Next shard.
        """
        self.index += 1
        if self.index >= len(self.shards):
            raise StopIteration
        return self.shards[self.index]
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
class ShardedDataLoader(Iterable[list[T_co]]):
|
|
492
|
+
"""A :class:`DataLoader` that processes data in shards, designed for distributed training scenarios.
|
|
493
|
+
|
|
494
|
+
Enables double-buffered micro-batch processing and fetching that overlaps with model
|
|
495
|
+
forward/backward passes, minimizing dataloading overhead.
|
|
496
|
+
|
|
497
|
+
Args:
|
|
498
|
+
seed (int): Random seed for reproducibility. Use ${seed} at top level in config.yaml.
|
|
499
|
+
shard_sampler (ShardSampler): Sampler for generating shards.
|
|
500
|
+
dataloader_factory (DataLoaderFactory[T_co]): Factory for creating DataLoaders.
|
|
501
|
+
num_shard_prefetches (int, optional): Number of shards to prefetch.
|
|
502
|
+
Defaults to 0.
|
|
503
|
+
num_microbatch_prefetches (int, optional): Number of microbatches to prefetch.
|
|
504
|
+
Defaults to -1.
|
|
505
|
+
dp_rank (int, optional): Rank of the current process.
|
|
506
|
+
Defaults to 0.
|
|
507
|
+
profiler (Profiler, optional): Profiler for profiling.
|
|
508
|
+
Defaults to None.
|
|
509
|
+
device (Optional[torch.device]): device to move the microbatches to in the background
|
|
510
|
+
multiprocessing (Optional[True]): whether to instantiate DataLoader in a separate process.
|
|
511
|
+
Defaults to True to relieve pressure from the training process, use False to debug and profile
|
|
512
|
+
"""
|
|
513
|
+
|
|
514
|
+
def __init__(
    self,
    seed: int,
    shard_sampler: ShardSampler,
    dataloader_factory: DataLoaderFactory[T_co],
    num_shard_prefetches: int = 0,
    num_microbatch_prefetches: int = -1,
    dp_rank: int = 0,
    profiler: Profiler | None = None,
    device: torch.device | None = None,
    multiprocessing: bool = True,
) -> None:
    """Create the shard writer pool, the batch reader thread and the teardown hooks.

    Arguments are described in the class docstring.

    NOTE(review): ``signal.signal`` raises ``ValueError`` outside the main thread, so this constructor
    presumably has to run on the main thread. The SIGTERM/SIGINT handlers installed here replace any handlers
    registered earlier.
    """
    # Initialize
    # Microbatch iterator of the shard currently being consumed (None between shards).
    self.microbatches: Iterator[list[T_co]] | None = None
    # Path the current shard's microbatches are loaded from.
    self.path: Path | None = None
    # If set, microbatches are moved to this device on the reading thread.
    self.device: torch.device | None = device
    # Paths of shards loaded but not yet exhausted (see load_batch_sync); presumably used by _cleanup.
    self.cleanup: set[Path] = set()
    # Event handed to every writer worker through ``initialize``; writer tasks poll it via should_stop.
    self.shutdown = mp.Event()
    self.shard_sampler = shard_sampler
    # Assigned in __iter__.
    self.shard_sampler_iter: Iterator[Shard]
    self.dataloader_factory = dataloader_factory
    self.dataloader_iter_generator = DataLoaderIterGenerator(
        dataloader_factory,
        num_microbatch_prefetches,
    )
    # Scheduled shard jobs: (shard, output path, async result of the writer task).
    self.data_jobs: deque[tuple[Shard, Path, mp.pool.AsyncResult[Any]]] = deque() # type: ignore[unresolved-attribute]

    # Route termination signals and interpreter exit to ``teardown``.
    signal.signal(signal.SIGTERM, self.teardown) # terminate signal
    signal.signal(signal.SIGINT, self.teardown) # keyboard interrupt
    atexit.register(self.teardown)
    # Shard writer pool: a NoDaemonPool with max(1, num_shard_prefetches) workers when ``multiprocessing`` is
    # True, otherwise a single-thread ThreadPool. Both run ``initialize`` in each worker.
    self.writing_pool: NoDaemonPool | ThreadPool = (
        NoDaemonPool(
            max(1, num_shard_prefetches),
            initializer=initialize,
            initargs=(seed, dp_rank, self.shutdown, profiler),
        )
        if multiprocessing
        else ThreadPool(
            max_workers=1,
            thread_name_prefix="ShardedDataWriter",
            initializer=initialize,
            initargs=(seed, dp_rank, self.shutdown, profiler),
        )
    )
    self.num_shard_prefetches = num_shard_prefetches
    # Single reader thread: every read (and _cleanup) is serialized through it.
    self.reading_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ShardedDataReader")
    # Future of the batch being loaded in the background (double buffering).
    self.next_batch: Future[Any] | None = None
|
|
562
|
+
|
|
563
|
+
def __iter__(self) -> Iterator[list[T_co]]:
    """Start a new pass over the shards.

    ``_cleanup`` is executed on the reading thread, and this call waits for it, so it never overlaps with
    a background read. It is skipped once the reading pool has been shut down. The shard sampler iterator
    is then (re)created.
    """
    reader = self.reading_pool
    # all reading needs to go through the same TPE to avoid contention
    if not reader._shutdown:
        pending_cleanup = reader.submit(self._cleanup)
        pending_cleanup.result()
    self.shard_sampler_iter = iter(self.shard_sampler)
    return self
|
|
569
|
+
|
|
570
|
+
def __next__(self) -> list[T_co]:
    """Return the batch loaded in the background, then immediately schedule the next one.

    On the very first call nothing is buffered yet, so a load is scheduled before waiting on it.
    Exceptions raised while loading propagate from ``Future.result()``. This includes ``StopIteration``
    once ``load_batch_sync`` runs out of shards.
    """
    if self.next_batch is None:
        self.load_batch()
    pending = self.next_batch
    assert pending
    current = pending.result()
    # double-buffering: start fetching the following batch while the caller processes this one
    self.load_batch()
    return current
|
|
577
|
+
|
|
578
|
+
def load_batch(self) -> None:
    """Schedule ``load_batch_sync`` on the reading thread and keep its future in ``next_batch``.

    Once the reading pool has been shut down this is a no-op and ``next_batch`` is left untouched.
    """
    reader = self.reading_pool
    if reader._shutdown:
        return
    self.next_batch = reader.submit(self.load_batch_sync)
|
|
581
|
+
|
|
582
|
+
def load_batch_sync(self) -> list[T_co]:
    """Return the next microbatch, moving on to the next shard when the current one is exhausted.

    Runs on the reading thread (scheduled by ``load_batch``).

    When no shard job is queued, ``prefetch_shards(max(1, num_shard_prefetches))`` is called to schedule
    more. If the queue is still empty afterwards, the data is exhausted. After dequeuing a job,
    ``prefetch_shards(min(1, num_shard_prefetches))`` is called so the following shard can be prepared in
    parallel. ``prefetch_shards`` is presumably what appends to ``data_jobs``.

    NOTE(review): a negative ``num_shard_prefetches`` makes ``min(1, ...)`` pass a negative count.

    Raises:
        StopIteration: when no further shard jobs can be scheduled.
    """
    while True:
        if self.microbatches:
            # Fetch the next microbatch if available
            try:
                microbatch = next(self.microbatches)
                # Move to target device in advance
                if self.device:
                    microbatch = move_data_to_device(microbatch, self.device)
                return microbatch
            # If no microbatches are available, which means all microbatches from current shard are exhausted
            except StopIteration:
                # The shard is fully consumed: forget its path and drop the exhausted iterator.
                if self.path:
                    self.cleanup.remove(self.path)
                self.microbatches = None
                self.path = None

        if len(self.data_jobs) == 0:
            logger.debug("load iter scheduling ...")
            self.prefetch_shards(max(1, self.num_shard_prefetches))
        if len(self.data_jobs) == 0:
            # Nothing left to schedule: end of data.
            raise StopIteration
        shard, path, data_job = self.data_jobs.popleft()
        logger.debug(f"load iter {shard} to {path} ...")
        self.prefetch_shards(min(1, self.num_shard_prefetches)) # prefetch next shard in parallel

        # Handed to shm.load_iter, presumably polled while waiting for the writer to produce data.
        # Closing over the loop variable is safe here: the iterator built from it is dropped
        # (microbatches = None) before data_job is rebound on a later loop iteration.
        def wait_callback() -> None:
            if not data_job.ready(): # noqa: B023
                # Job is still running
                return None
            else:
                # Job is finished, raise exception if job failed.
                data_job.get() # noqa: B023
                # get() would have re-raised a failure, so the job must have succeeded.
                assert data_job.successful() # noqa: B023

        self.microbatches = shm.load_iter(path, wait_callback=wait_callback)
        self.cleanup.add(path)
        self.path = path
|
|
621
|
+
|
|
622
|
+
def state_dict(self) -> dict[str, Any]:
    """
    Returns the shard sampler state dict with adjusted shard indices,
    accounting for shard prefetches and prefetch backfill in parallel.

    Every index reported by ``self.shard_sampler.state_dict()`` is reduced by
    ``self.num_shard_prefetches`` plus a minimum prefetched-shard backfill of 1.

    Example:
        If num_shard_prefetches is 3 and the original state dict is
        {"all_rank_indices": [torch.tensor(4), torch.tensor(5)]},
        it will be updated to {"all_rank_indices": [torch.tensor(0), torch.tensor(1)]}.
        This ensures that each rank resumes training from the correct shard index,
        preventing reprocessing of shards that have already been trained on.

    Returns:
        A new dict with the same keys as the sampler's state dict; each value is a new
        list of adjusted indices. The containers returned by the sampler are left
        untouched.
    """
    # Define the minimum prefetched shard backfill
    min_prefetch_shard_backfill = 1
    rewind = self.num_shard_prefetches + min_prefetch_shard_backfill

    # Invoke the shard_sampler's state_dict method when saving the data shard indices for each rank
    shard_sampler_state_dict = self.shard_sampler.state_dict()

    # Build fresh lists rather than assigning into the sampler's containers. If the
    # sampler returns references to its live index lists, in-place updates would
    # silently rewind the running sampler every time a checkpoint is taken.
    return {
        key: [idx_tensor - rewind for idx_tensor in rank_indices]
        for key, rank_indices in shard_sampler_state_dict.items()
    }
|
|
646
|
+
|
|
647
|
+
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """
    Restore the shard sampler from a previously saved state dict.

    ``state_dict`` is passed as-is to ``self.shard_sampler.load_state_dict``. No prefetch
    adjustment is applied here. The dict is expected to look like
    {"all_rank_indices": [torch.tensor(1), torch.tensor(1), torch.tensor(5), torch.tensor(6)]},
    where each tensor holds the data-shard index of one rank.
    """
    sampler = self.shard_sampler
    sampler.load_state_dict(state_dict)
|
|
657
|
+
|
|
658
|
+
def prefetch_shards(self, count: int) -> None:
    """
    Queue up to ``count`` shards for asynchronous loading on the writing pool.

    For each shard drawn from ``self.shard_sampler_iter``:
    - a fresh path is obtained from ``shm.generate_path()``;
    - ``self.dataloader_iter_generator(shard, path)`` is scheduled with ``apply_async``;
    - the ``(shard, path, async_result)`` triple is appended to ``self.data_jobs``.

    When the sampler iterator runs out, the method stops early without raising.
    """
    try:
        for _slot in range(count):
            shard = next(self.shard_sampler_iter)
            path = shm.generate_path()
            self.data_jobs.append(
                (shard, path, self.writing_pool.apply_async(self.dataloader_iter_generator, (shard, path)))
            )
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"queued {path}, {len(self.data_jobs)}")
    except StopIteration:
        # shard sampler exhausted: fewer than ``count`` shards were queued
        pass
|
|
670
|
+
|
|
671
|
+
def _cleanup(self, stop_pool: bool = False) -> None:
    """
    Abandon in-flight loading work and reset per-iteration loader state.

    Steps, in order:
    1. Drop the current microbatch iterator and set ``self.shutdown`` to signal
       running tasks to stop.
    2. Wait for the pending ``self.next_batch`` future, ignoring any exception it
       carries, then clear it.
    3. Wait for every queued shard job in ``self.data_jobs``. Each wait is bounded by
       ``DEFAULT_SHUTDOWN_TIMEOUT``, and each job's output path is recorded in
       ``self.cleanup``. The queue is then emptied and ``self.shutdown`` is cleared.

    Args:
        stop_pool: when True, also close the writing pool and shut down the reading
            pool before waiting on jobs. Afterwards the writing pool is joined and every
            path recorded in ``self.cleanup`` is removed with ``shutil.rmtree``. When
            False, the pools stay usable and no paths are removed here.
    """
    self.microbatches = None
    self.shutdown.set()  # signal running tasks to stop
    if self.next_batch:
        try:
            # if called on teardown/on_exception/__del__ will wait for the pending work to finish
            # if called on __init__ it's already done
            self.next_batch.result()
        except Exception:
            # best-effort: a failed (or StopIteration-terminated) pending load is irrelevant here
            pass
        self.next_batch = None
    if stop_pool:
        self.writing_pool.close()  # no new tasks can run
        self.reading_pool.shutdown()  # blocks until already-submitted reads finish (wait=True default)
    for _, path, result in self.data_jobs:
        logger.debug(f"waiting for {path} to stop ...")
        # record the path so it can be removed later, whether or not the job completed
        self.cleanup.add(path)
        try:
            result.wait(timeout=DEFAULT_SHUTDOWN_TIMEOUT)
        except Exception:
            pass
        logger.debug(f"{path} stopped ...")
    self.data_jobs.clear()
    # presumably reset so tasks scheduled on a reused pool are not told to stop — confirm
    self.shutdown.clear()
    if stop_pool:
        # NOTE(review): Pool.join() waits for worker processes to exit. A job that outlives
        # DEFAULT_SHUTDOWN_TIMEOUT above can therefore still block here.
        self.writing_pool.join()  # make sure atexit is triggered in each subprocess
        # NOTE(review): self.cleanup is not cleared afterwards. A second teardown re-runs
        # rmtree on the same (already removed) paths, which is harmless with ignore_errors.
        for path in self.cleanup:
            logger.debug(f"removing {path} ...")
            shutil.rmtree(path, ignore_errors=True)
|
|
700
|
+
|
|
701
|
+
def set_device(self, device: torch.device | None) -> None:
|
|
702
|
+
self.device = device
|
|
703
|
+
|
|
704
|
+
# invoked once fit/validate/predict/test has finished
def teardown(self, *args: Any) -> None:
    """
    Release all loader resources for good.

    Delegates to ``_cleanup(stop_pool=True)``, which also closes the writing pool and
    shuts down the reading pool. Any positional arguments (presumably the stage passed
    by the caller) are accepted and ignored.
    """
    logger.debug("teardown ...")

    self._cleanup(stop_pool=True)

    logger.debug("teardown complete")
|
|
709
|
+
|
|
710
|
+
# Takes effect once https://github.com/Lightning-AI/pytorch-lightning/pull/19601 lands;
# when this hook is actually invoked, the __del__ override is no longer required.
def on_exception(self, exception: BaseException) -> None:
    """Tear the loader down after a failure; ``exception`` itself is not inspected."""
    self.teardown()
|
|
714
|
+
|
|
715
|
+
# Finalizer: runs when the last reference to this object is dropped,
# for example when an exception unwinds the code that held the iterable.
def __del__(self) -> None:
    """Fall back to a full teardown so worker pools and output paths are released."""
    self.teardown()
|