fkat 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/data/sharded.py ADDED
@@ -0,0 +1,718 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import atexit
4
+ import fnmatch
5
+ import logging
6
+ import multiprocessing as mp
7
+ from multiprocessing.synchronize import Event
8
+ from concurrent.futures import ThreadPoolExecutor, Future
9
+
10
+ import os
11
+ import random
12
+ import shutil
13
+ import signal
14
+ from pathlib import Path
15
+ from typing import Any, Generic, TypeVar
16
+ from collections import deque
17
+ from collections.abc import Callable, Iterable, Iterator
18
+
19
+ import numpy as np
20
+ from pyarrow.fs import FileSelector, FileSystem
21
+ from lightning.pytorch.core.hooks import CheckpointHooks
22
+ from lightning.pytorch.profilers import Profiler
23
+ from lightning.pytorch.utilities import move_data_to_device
24
+ import torch
25
+ import torch.distributed as dist
26
+ from torch.utils.data import DataLoader, Dataset, Sampler
27
+ from typing_extensions import override
28
+
29
+ from fkat.data import PersistStates, RestoreStates
30
+ from fkat.utils import shm
31
+ from fkat.utils.pool import ThreadPool, NoDaemonPool
32
+ from fkat.utils.profiler import profile_until_exit
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ _shutdown: Event | None = None
37
+
38
+ DEFAULT_SHUTDOWN_TIMEOUT = 60 # time for shard workers to shut down gracefully
39
+
40
+
41
+ def initialize(seed: int, dp_rank: int, shutdown: Event, profiler: Profiler | None = None) -> None:
42
+ # signal handlers are inherited; we use the shutdown flag to terminate child processes gracefully
43
+ try:
44
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
45
+ except ValueError:
46
+ pass # won't work from non-MainThread
47
+ # this allows the worker function to access `shutdown` even though it is
48
+ # not passed as an argument to the function.
49
+ global _shutdown
50
+ _shutdown = shutdown
51
+ pid = os.getpid()
52
+ logger.debug(f"shard worker init {pid} ...")
53
+ if profiler:
54
+ action = f"ShardedDataLoader[worker_pid={pid}]"
55
+ profile_until_exit(profiler, action=action, filename_suffix=f"_{pid}")
56
+
57
+ # Set the RNG seed to ensure TP ranks within the same DP group load and iterate the same data
58
+ # in the same order with consistent RNG states
59
+ rng_seed = seed + dp_rank
60
+ np.random.seed(rng_seed)
61
+ random.seed(rng_seed)
62
+ torch.manual_seed(rng_seed)
63
+ logger.info(f"RNG seed is set with {rng_seed}")
64
+ logger.debug(f"shard worker init {pid} complete")
65
+
66
+
67
+ T_co = TypeVar("T_co", covariant=True)
68
+ Shard = str | list[str]
69
+
70
+
71
+ class ShardSampler(Iterable[Shard], PersistStates, RestoreStates):
72
+ def __init__(self) -> None:
73
+ self.shards: list[Shard] = []
74
+ self.index: int = -1
75
+ self.all_rank_indices: list[int] = []
76
+
77
+ def __iter__(self) -> Iterator[Shard]:
78
+ # Iterate over shards and yield them one by one
79
+ yield from self.shards
80
+
81
+ def state_dict(self) -> dict[str, Any]:
82
+ """Converts the current state to a dictionary saving the sampler states.
83
+
84
+ Returns:
85
+ Dict[str, Any]: dict object representing the state.
86
+ """
87
+ # Raise error if torch.distributed is not initialized
88
+ if not dist.is_initialized(): # type: ignore[possibly-unbound-attribute]
89
+ raise RuntimeError("torch.distributed is not initialized.")
90
+
91
+ # Get local rank and world size
92
+ world_size = dist.get_world_size() # type: ignore[possibly-unbound-attribute]
93
+
94
+ # Get device
95
+ device = "cpu" if dist.get_backend() == "gloo" else "cuda" # type: ignore[possibly-unbound-attribute]
96
+
97
+ # Create a torch tensor with index defined
98
+ local_index = torch.tensor(self.index, dtype=torch.int, device=device)
99
+
100
+ # Prepare a list of tensors to hold the indices from all ranks
101
+ # i.e. rank 0 to rank 3 would have access to
102
+ # [torch.tensor(1), torch.tensor(1), torch.tensor(5), torch.tensor(6)]
103
+ all_rank_indices = [torch.zeros_like(local_index) for _ in range(world_size)]
104
+
105
+ # Gather the indices from all ranks so all ranks have access to the same list of indices
106
+ dist.all_gather(all_rank_indices, local_index) # type: ignore[possibly-unbound-attribute]
107
+
108
+ # Return all rank indices
109
+ sampler_states = {"all_rank_indices": all_rank_indices}
110
+ if logger.isEnabledFor(logging.DEBUG):
111
+ logger.debug(f"DataModule sampler states are {sampler_states}")
112
+
113
+ return sampler_states
114
+
115
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
116
+ """Load the sampler state dict with serialized state_dict
117
+
118
+ Args:
119
+ state_dict (Dict[str, Any]): serialized sampler states
120
+ """
121
+ # Raise error if torch.distributed is not initialized
122
+ if not dist.is_initialized(): # type: ignore[possibly-unbound-attribute]
123
+ raise RuntimeError("torch.distributed is not initialized.")
124
+
125
+ # Get the all_rank_indices
126
+ all_rank_indices = state_dict["all_rank_indices"]
127
+
128
+ # Convert list of tensor indices to list of integer indices
129
+ # if all_rank_indices = [torch.tensor(1), torch.tensor(1), torch.tensor(5), torch.tensor(6)]
130
+ # then self.all_rank_indices = [1, 1, 5, 6]
131
+ self.all_rank_indices = [tensor_index.item() for tensor_index in all_rank_indices]
132
+ logger.info(f"All rank indices are {self.all_rank_indices}")
133
+
134
+ # Check if the number of ranks in the state_dict matches the current distributed setup.
135
+ # This ensures consistency when resuming training, preventing issues from mismatched
136
+ # configurations (e.g., different number of nodes or devices).
137
+ world_size = dist.get_world_size() # type: ignore[possibly-unbound-attribute]
138
+ num_saved_ranks = len(self.all_rank_indices)
139
+ if num_saved_ranks != world_size:
140
+ raise ValueError(
141
+ f"Inconsistent distributed training configuration: the loaded state_dict contains "
142
+ f"checkpoint data for {num_saved_ranks} ranks, but the current world size is {world_size}. "
143
+ "Ensure that you are resuming from a checkpoint with the same distributed setup "
144
+ "(number of nodes and devices)."
145
+ )
146
+
147
+ # Get the local rank and set the corresponding index from saved state_dict
148
+ # Each rank loads its corresponding saved index from self.all_rank_indices
149
+ local_rank = dist.get_rank() # type: ignore[possibly-unbound-attribute]
150
+ self.index = self.all_rank_indices[local_rank]
151
+ logger.info(f"Rank {local_rank} has {self.index}")
152
+
153
+ # Set self.reset to False so the index is not reset on the next iteration
154
+ self.reset = False
155
+
156
+
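The state_dict/load_state_dict pair above gathers every rank's current shard index via all_gather and, on resume, hands each rank back its own saved position. A minimal standalone illustration of the restore step, with invented indices and the torch.distributed plumbing omitted (not part of fkat):

    # Toy illustration of what load_state_dict recovers per rank (values invented).
    all_rank_indices = [1, 1, 5, 6]           # gathered shard indices, one entry per rank
    world_size = len(all_rank_indices)        # must match the current world size on resume
    for local_rank in range(world_size):
        index = all_rank_indices[local_rank]  # each rank picks up its own saved position
        print(f"rank {local_rank} resumes after shard index {index}")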
157
+ class DataLoaderFactory(Generic[T_co]):
158
+ """Factory class for creating DataLoaders.
159
+
160
+ Args:
161
+ dataset_generator (Callable): A function that generates a dataset given a shard.
162
+ batch_size (int, optional): The batch size. Defaults to 1.
163
+ sampler_generator (Optional[Callable]): An optional function that generates a sampler for the dataset.
164
+ batch_sampler_generator (Optional[Callable]): An optional function that generates a batch sampler.
165
+ dataloader_generator (Optional[Callable]): An optional function that generates a DataLoader instance.
166
+ """
167
+
168
+ def __init__(
169
+ self,
170
+ dataset_generator: Callable[[Shard], Dataset[T_co]],
171
+ sampler_generator: Callable[[Dataset[T_co]], Sampler[T_co]] | None = None,
172
+ batch_sampler_generator: Callable[[Sampler[Any] | Dataset[T_co]], Iterable[list[Any]]] | None = None,
173
+ dataloader_generator: Callable[[Any], Iterable[list[T_co]]] | None = None,
174
+ batch_size: int = 1,
175
+ ) -> None:
176
+ assert batch_size or sampler_generator or batch_sampler_generator, (
177
+ "either batch_size, sampler_generator or batch_sampler_generation must be provided"
178
+ )
179
+ self.dataset_generator = dataset_generator
180
+ self.sampler_generator = sampler_generator
181
+ self.batch_sampler_generator = batch_sampler_generator
182
+ self.dataloader_generator = dataloader_generator or DataLoader
183
+ self.batch_size = batch_size
184
+
185
+ def __call__(self, shard: Shard) -> Iterable[list[T_co]]:
186
+ """Generates a DataLoader for the given shard.
187
+
188
+ Args:
189
+ shard (Shard): Represents a subset of the dataset.
190
+
191
+ Returns:
192
+ Iterable[List[T_co]]: An iterable of batches of data.
193
+ """
194
+ # Generate dataset using dataset_generator
195
+ dataset = self.dataset_generator(shard)
196
+
197
+ # Generate sampler if sampler_generator is provided
198
+ sampler = self.sampler_generator(dataset) if self.sampler_generator else None
199
+
200
+ # Generate batch sampler if batch_sampler_generator is provided
201
+ if self.batch_sampler_generator:
202
+ batch_sampler = self.batch_sampler_generator(sampler if sampler else dataset)
203
+ sampler = None # mutually exclusive
204
+ else:
205
+ batch_sampler = None
206
+
207
+ # Generate DataLoader instance using dataloader_generator
208
+ dataloader = self.dataloader_generator( # type: ignore[call-arg]
209
+ dataset, batch_size=self.batch_size, shuffle=None, sampler=sampler, batch_sampler=batch_sampler
210
+ )
211
+ return dataloader
212
+
213
+
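As a usage sketch for the factory above: the dataset generator receives a Shard (a URI or a list of URIs) and returns a torch Dataset. The ToyShardDataset and the S3 paths below are invented for illustration; a real generator would typically build, for example, a parquet- or JSON-backed dataset from fkat.data.datasets.

    from torch.utils.data import Dataset

    class ToyShardDataset(Dataset):
        """Hypothetical dataset: one dummy sample per URI in the shard."""
        def __init__(self, shard):
            self.uris = [shard] if isinstance(shard, str) else list(shard)
        def __len__(self):
            return len(self.uris)
        def __getitem__(self, idx):
            return {"uri": self.uris[idx], "sample_id": idx}

    factory = DataLoaderFactory(dataset_generator=ToyShardDataset, batch_size=2)
    dataloader = factory(["s3://bucket/part-000.parquet", "s3://bucket/part-001.parquet"])
    for batch in dataloader:
        print(batch["sample_id"])  # default collation batches the dict fields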
214
+ class DataLoaderIterGenerator(Generic[T_co]):
215
+ """Generates and saves an iterator over DataLoaders.
216
+
217
+ Args:
218
+ dataloader_factory (DataLoaderFactory): An instance of DataLoaderFactory responsible for generating DataLoaders.
219
+ num_microbatch_prefetches (int, optional): The number of microbatches to prefetch.
220
+ Defaults to -1.
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ dataloader_factory: DataLoaderFactory[T_co],
226
+ num_microbatch_prefetches: int = -1,
227
+ ) -> None:
228
+ """Initializes the DataLoaderIterGenerator.
229
+
230
+ Args:
231
+ dataloader_factory (DataLoaderFactory): An instance of DataLoaderFactory responsible
232
+ for generating DataLoaders.
233
+ num_microbatch_prefetches (int, optional): The number of microbatches to prefetch.
234
+ Defaults to -1.
235
+ """
236
+ self.dataloader_factory = dataloader_factory
237
+ self.num_microbatch_prefetches = num_microbatch_prefetches
238
+
239
+ def __call__(self, shard: Shard, path: Path) -> None:
240
+ """Generates and saves an iterator over DataLoaders.
241
+
242
+ Args:
243
+ shard (Shard): A subset of the dataset.
244
+ path (Path): The path where the iterator will be saved.
245
+ """
246
+ # Log debug message indicating the start of the process
247
+ logger.debug("shard generate ...")
248
+
249
+ # Generate a DataLoader using the provided shard
250
+ dataloader = self.dataloader_factory(shard)
251
+
252
+ # Create an iterator over the DataLoader
253
+ dataloader_iter = iter(dataloader)
254
+
255
+ # Access global variable _shutdown
256
+ global _shutdown
257
+
258
+ # Save the iterator using shm.save_iter
259
+ shm.save_iter(
260
+ dataloader_iter,
261
+ path=path,
262
+ max_items=self.num_microbatch_prefetches,
263
+ should_stop=lambda: _shutdown is not None and _shutdown.is_set(),
264
+ )
265
+
266
+ # Log debug message indicating the completion of the process
267
+ logger.debug("shard generate complete")
268
+
269
+
270
+ class DistributedDataParallelShardSampler(ShardSampler, CheckpointHooks):
271
+ """Distributed Data Parallel Shard Sampler.
272
+
273
+ This sampler distributes shards evenly among processes in a distributed training setup.
274
+
275
+ Args:
276
+ sampler (ShardSampler): An instance of ShardSampler containing shards.
277
+ dp_size (int): Total number of processes in the distributed training setup.
278
+ dp_rank (int): Rank of the current process among the distributed processes.
279
+ state_dict (dict, optional): A dictionary object serialized from a Sampler object. If provided,
280
+ the sampler will be reconstructed using the dictionary state object and recovered to the previous state.
281
+ drop_last (bool, optional): Whether to drop the last shards if the number of shards is not evenly divisible by dp_size.
282
+ Setting this to False is recommended for evaluation and prediction tasks.
283
+ Defaults to ``True``.
284
+ num_uri_merge (int, optional): how many URIs to merge into one shard.
285
+ Defaults to ``0``; if set to ``-1``, all URIs are merged into a single shard.
286
+ """
287
+
288
+ def __init__(
289
+ self,
290
+ sampler: ShardSampler,
291
+ dp_size: int,
292
+ dp_rank: int,
293
+ drop_last: bool = True,
294
+ num_uri_merge: int = 0,
295
+ ) -> None:
296
+ super().__init__()
297
+ # Convert sampler to list and determine the total number of shards
298
+ shards: list[str | list[str]] = list(sampler)
299
+ num_shards = len(shards)
300
+
301
+ # Ensure that the number of shards is compatible with dp_size
302
+ if num_shards < dp_size:
303
+ raise ValueError(
304
+ f"Only datasets with num_shards >= dp_size are supported, "
305
+ f"got num_shards={num_shards}, dp_size={dp_size}"
306
+ )
307
+
308
+ # Distribute shards evenly among each DP group
309
+ dp_shards, rem = divmod(num_shards, dp_size)
310
+ if rem > 0 and drop_last:
311
+ logger.warning(f"Truncating not even distribution of {num_shards} shards across dp_size={dp_size}")
312
+ shards = shards[:-rem]
313
+ else:
314
+ dp_shards, _ = divmod(num_shards + dp_size - 1, dp_size)
315
+ shards = shards[dp_rank * dp_shards : (dp_rank + 1) * dp_shards] # offset
316
+
317
+ if num_uri_merge != 0:
318
+ merged_shards: list[list[str]] = []
319
+ sub_shards: list[str] = []
320
+ for i in range(len(shards)):
321
+ if isinstance(shards[i], str):
322
+ sub_shards.append(str(shards[i]))
323
+ else:
324
+ sub_shards.extend(shards[i])
325
+ if len(sub_shards) == num_uri_merge:
326
+ merged_shards.append(sub_shards.copy())
327
+ sub_shards.clear()
328
+ if sub_shards:
329
+ merged_shards.append(sub_shards)
330
+ self.shards = merged_shards # type: ignore[assignment]
331
+ else:
332
+ self.shards = shards
333
+
334
+ # Ensure that assigned shards are not empty
335
+ assert self.shards
336
+
337
+ # Initialize vars
338
+ self.reset = True
339
+ self.index = -1
340
+
341
+ @override
342
+ def __iter__(self) -> Iterator[Shard]:
343
+ """Returns an iterator over the shards.
344
+
345
+ Returns:
346
+ Iterator[Shard]: Iterator over the shards.
347
+ """
348
+ # Reset iterator if reset flag is set
349
+ if self.reset:
350
+ self.index = -1
351
+ self.reset = True
352
+ return self
353
+
354
+ def __next__(self) -> Shard:
355
+ """Returns the next shard.
356
+
357
+ Returns:
358
+ Shard: Next shard.
359
+ """
360
+ # Increment index and return the corresponding shard
361
+ self.index += 1
362
+ if self.index >= len(self.shards):
363
+ raise StopIteration
364
+ return self.shards[self.index]
365
+
366
+
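To make the even-distribution logic above concrete, here is a small standalone walk-through with invented numbers (10 shards across a DP size of 4 with drop_last=True), mirroring the arithmetic in __init__:

    # Hypothetical numbers: 10 shards, dp_size=4, drop_last=True.
    num_shards, dp_size = 10, 4
    shards = [f"shard-{i}" for i in range(num_shards)]
    dp_shards, rem = divmod(num_shards, dp_size)  # 2 shards per rank, 2 left over
    if rem:
        shards = shards[:-rem]                    # drop the 2 trailing shards
    for dp_rank in range(dp_size):
        print(dp_rank, shards[dp_rank * dp_shards : (dp_rank + 1) * dp_shards])
    # rank 0 -> ['shard-0', 'shard-1'], rank 1 -> ['shard-2', 'shard-3'], ...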
367
+ class ShuffledShardSampler(ShardSampler, CheckpointHooks):
368
+ """Sampler for shuffling shards.
369
+
370
+ Args:
371
+ sampler (ShardSampler): An instance of ShardSampler containing shards.
372
+ state_dict (Dict[str, Any], optional): A dictionary object containing the state of the sampler.
373
+ Defaults to None.
374
+ """
375
+
376
+ def __init__(self, sampler: ShardSampler) -> None:
377
+ # Convert sampler to list and assert non-emptiness
378
+ self.shards = list(sampler)
379
+ assert self.shards
380
+
381
+ # Initialize variables
382
+ self.reset = True
383
+ self.index = -1
384
+
385
+ # Create indices for shuffling
386
+ self.indices = list(range(len(self.shards)))
387
+
388
+ @override
389
+ def __iter__(self) -> Iterator[Shard]:
390
+ """Returns an iterator over the shards.
391
+
392
+ Returns:
393
+ Iterator[Shard]: Iterator over the shards.
394
+ """
395
+ # If reset flag is set, shuffle indices
396
+ if self.reset:
397
+ self.index = -1
398
+ random.shuffle(self.indices)
399
+ self.reset = True
400
+ return self
401
+
402
+ def __next__(self) -> Shard:
403
+ """Returns the next shard.
404
+
405
+ Returns:
406
+ Shard: Next shard.
407
+ """
408
+ # Increment index and return the corresponding shard
409
+ self.index += 1
410
+ if self.index >= len(self.indices):
411
+ raise StopIteration
412
+ return self.shards[self.indices[self.index]]
413
+
414
+
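A brief usage sketch for the shuffling wrapper above; the source sampler and its URIs are invented here, and any ShardSampler (for example the filesystem-backed sampler defined next) could serve as the source:

    # Hypothetical: a trivial in-memory source wrapped to visit shards in a new order per pass.
    source = ShardSampler()
    source.shards = [f"s3://bucket/part-{i}.parquet" for i in range(4)]  # made-up URIs
    shuffled = ShuffledShardSampler(source)
    for epoch in range(2):
        print(epoch, list(shuffled))  # reshuffles whenever its reset flag is set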
415
+ class FsShardSampler(ShardSampler, CheckpointHooks):
416
+ """Sampler for shuffling shards based on file system paths.
417
+
418
+ Args:
419
+ uri (str): The URI specifying the file system path.
420
+ glob (str, optional): A glob pattern to filter files.
421
+ Defaults to None.
422
+ recursive (bool, optional): Whether to search for files recursively.
423
+ Defaults to True.
424
+ state_dict (Optional[Dict[str, Any]], optional): A dictionary object containing the state of the sampler.
425
+ Defaults to None.
426
+ num_uri_merge (int, optional): how many URIs to merge into one shard.
427
+ Defaults to ``0``; if set to ``-1``, all URIs are merged into a single shard.
428
+ """
429
+
430
+ def __init__(self, uri: str, glob: str | None = None, recursive: bool = True, num_uri_merge: int = 0) -> None:
431
+ # from_uri is a static method, but pyarrow-stubs says it's an instance one
432
+ # Extract file system and path from URI
433
+ fs: FileSystem
434
+ path: str
435
+ fs, path = FileSystem.from_uri(uri)
436
+
437
+ # Define a selector for files in the specified path
438
+ selector = FileSelector(path, recursive=recursive)
439
+
440
+ # Initialize shards list
441
+ self.shards = []
442
+ # Populate shards list with files matching the glob pattern (if provided)
443
+ for file in fs.get_file_info(selector):
444
+ path = f"{fs.type_name}://{file.path}"
445
+ if not glob or fnmatch.fnmatch(path, glob):
446
+ self.shards.append(path)
447
+ if num_uri_merge != 0:
448
+ merged_shards: list[list[str]] = []
449
+ shards: list[str] = []
450
+ for i in range(len(self.shards)):
451
+ shards.append(str(self.shards[i]))
452
+ if len(shards) == num_uri_merge:
453
+ merged_shards.append(shards.copy())
454
+ shards.clear()
455
+ if shards:
456
+ merged_shards.append(shards)
457
+ self.shards = merged_shards # type: ignore[assignment]
458
+ # Ensure that shards list is not empty
459
+ assert self.shards
460
+
461
+ # Initialize reset and index vars
462
+ self.reset = True
463
+ self.index = -1
464
+
465
+ @override
466
+ def __iter__(self) -> Iterator[Shard]:
467
+ """Returns an iterator over the shards.
468
+
469
+ Returns:
470
+ Iterator[Shard]: Iterator over the shards.
471
+ """
472
+ # Reset iterator if reset flag is set
473
+ if self.reset:
474
+ self.index = -1
475
+ self.reset = True
476
+ return self
477
+
478
+ def __next__(self) -> Shard:
479
+ """Returns the next shard.
480
+
481
+ Returns:
482
+ Shard: Next shard.
483
+ """
484
+ # Increment index and return the corresponding shard
485
+ self.index += 1
486
+ if self.index >= len(self.shards):
487
+ raise StopIteration
488
+ return self.shards[self.index]
489
+
490
+
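A small usage sketch for the filesystem-backed sampler above. The directory is invented and must already contain matching files (the constructor asserts a non-empty shard list); the URI is resolved through pyarrow's FileSystem.from_uri, so S3 or other pyarrow-supported filesystems work the same way:

    # Hypothetical local directory of parquet shards; merge every 4 files into one shard.
    sampler = FsShardSampler("file:///tmp/shards", glob="*.parquet", num_uri_merge=4)
    for shard in sampler:
        print(shard)  # each shard is a list of up to 4 file URIs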
491
+ class ShardedDataLoader(Iterable[list[T_co]]):
492
+ """A :class:`DataLoader` that processes data in shards, designed for distributed training scenarios.
493
+
494
+ Enables double-buffered micro-batch processing and fetching that overlaps with model
495
+ forward/backward passes, minimizing dataloading overhead.
496
+
497
+ Args:
498
+ seed (int): Random seed for reproducibility. Use ${seed} at top level in config.yaml.
499
+ shard_sampler (ShardSampler): Sampler for generating shards.
500
+ dataloader_factory (DataLoaderFactory[T_co]): Factory for creating DataLoaders.
501
+ num_shard_prefetches (int, optional): Number of shards to prefetch.
502
+ Defaults to 0.
503
+ num_microbatch_prefetches (int, optional): Number of microbatches to prefetch.
504
+ Defaults to -1.
505
+ dp_rank (int, optional): Rank of the current process.
506
+ Defaults to 0.
507
+ profiler (Profiler, optional): Profiler for profiling.
508
+ Defaults to None.
509
+ device (Optional[torch.device]): device to move the microbatches to in the background
510
+ multiprocessing (bool, optional): whether to instantiate the DataLoader in a separate process.
511
+ Defaults to True to relieve pressure on the training process; use False to debug and profile.
512
+ """
513
+
514
+ def __init__(
515
+ self,
516
+ seed: int,
517
+ shard_sampler: ShardSampler,
518
+ dataloader_factory: DataLoaderFactory[T_co],
519
+ num_shard_prefetches: int = 0,
520
+ num_microbatch_prefetches: int = -1,
521
+ dp_rank: int = 0,
522
+ profiler: Profiler | None = None,
523
+ device: torch.device | None = None,
524
+ multiprocessing: bool = True,
525
+ ) -> None:
526
+ # Initialize
527
+ self.microbatches: Iterator[list[T_co]] | None = None
528
+ self.path: Path | None = None
529
+ self.device: torch.device | None = device
530
+ self.cleanup: set[Path] = set()
531
+ self.shutdown = mp.Event()
532
+ self.shard_sampler = shard_sampler
533
+ self.shard_sampler_iter: Iterator[Shard]
534
+ self.dataloader_factory = dataloader_factory
535
+ self.dataloader_iter_generator = DataLoaderIterGenerator(
536
+ dataloader_factory,
537
+ num_microbatch_prefetches,
538
+ )
539
+ self.data_jobs: deque[tuple[Shard, Path, mp.pool.AsyncResult[Any]]] = deque() # type: ignore[unresolved-attribute]
540
+
541
+ # Initialize a new ProcessPoolExecutor instance for prefetching shards if necessary
542
+ signal.signal(signal.SIGTERM, self.teardown) # terminate signal
543
+ signal.signal(signal.SIGINT, self.teardown) # keyboard interrupt
544
+ atexit.register(self.teardown)
545
+ self.writing_pool: NoDaemonPool | ThreadPool = (
546
+ NoDaemonPool(
547
+ max(1, num_shard_prefetches),
548
+ initializer=initialize,
549
+ initargs=(seed, dp_rank, self.shutdown, profiler),
550
+ )
551
+ if multiprocessing
552
+ else ThreadPool(
553
+ max_workers=1,
554
+ thread_name_prefix="ShardedDataWriter",
555
+ initializer=initialize,
556
+ initargs=(seed, dp_rank, self.shutdown, profiler),
557
+ )
558
+ )
559
+ self.num_shard_prefetches = num_shard_prefetches
560
+ self.reading_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ShardedDataReader")
561
+ self.next_batch: Future[Any] | None = None
562
+
563
+ def __iter__(self) -> Iterator[list[T_co]]:
564
+ # all reading needs to go through the same TPE to avoid contention
565
+ if not self.reading_pool._shutdown:
566
+ self.reading_pool.submit(self._cleanup).result()
567
+ self.shard_sampler_iter = iter(self.shard_sampler)
568
+ return self
569
+
570
+ def __next__(self) -> list[T_co]:
571
+ if not self.next_batch:
572
+ self.load_batch()
573
+ assert self.next_batch
574
+ batch = self.next_batch.result()
575
+ self.load_batch() # double-buffering
576
+ return batch
577
+
578
+ def load_batch(self) -> None:
579
+ if not self.reading_pool._shutdown:
580
+ self.next_batch = self.reading_pool.submit(self.load_batch_sync)
581
+
582
+ def load_batch_sync(self) -> list[T_co]:
583
+ while True:
584
+ if self.microbatches:
585
+ # Fetch the next microbatch if available
586
+ try:
587
+ microbatch = next(self.microbatches)
588
+ # Move to target device in advance
589
+ if self.device:
590
+ microbatch = move_data_to_device(microbatch, self.device)
591
+ return microbatch
592
+ # StopIteration means all microbatches from the current shard are exhausted
593
+ except StopIteration:
594
+ if self.path:
595
+ self.cleanup.remove(self.path)
596
+ self.microbatches = None
597
+ self.path = None
598
+
599
+ if len(self.data_jobs) == 0:
600
+ logger.debug("load iter scheduling ...")
601
+ self.prefetch_shards(max(1, self.num_shard_prefetches))
602
+ if len(self.data_jobs) == 0:
603
+ raise StopIteration
604
+ shard, path, data_job = self.data_jobs.popleft()
605
+ logger.debug(f"load iter {shard} to {path} ...")
606
+ self.prefetch_shards(min(1, self.num_shard_prefetches)) # prefetch next shard in parallel
607
+
608
+ def wait_callback() -> None:
609
+ if not data_job.ready(): # noqa: B023
610
+ # Job is still running
611
+ return None
612
+ else:
613
+ # Job is finished, raise exception if job failed.
614
+ data_job.get() # noqa: B023
615
+ # Return whether the call completed without raising an exception.
616
+ assert data_job.successful() # noqa: B023
617
+
618
+ self.microbatches = shm.load_iter(path, wait_callback=wait_callback)
619
+ self.cleanup.add(path)
620
+ self.path = path
621
+
622
+ def state_dict(self) -> dict[str, Any]:
623
+ """
624
+ Returns the shard sampler state dict with adjusted shard indices,
625
+ accounting for shard prefetches and prefetch backfill in parallel.
626
+
627
+ Example:
628
+ If num_shard_prefetches is 3 and the original state dict is
629
+ {"all_rank_indices": [torch.tensor(4), torch.tensor(5)]},
630
+ it will be updated to {"all_rank_indices": [torch.tensor(0), torch.tensor(1)]}.
631
+ This ensures that each rank resumes training from the correct shard index,
632
+ preventing reprocessing of shards that have already been trained on.
633
+ """
634
+ # Invoke the shard_sampler's state_dict method when saving the data shard indices for each rank
635
+ shard_sampler_state_dict = self.shard_sampler.state_dict()
636
+
637
+ # Define the minimum prefetched shard backfill
638
+ min_prefetch_shard_backfill = 1
639
+
640
+ # Adjust each rank's shard index in place
641
+ for _, rank_indices in shard_sampler_state_dict.items():
642
+ for i, idx_tensor in enumerate(rank_indices):
643
+ rank_indices[i] = idx_tensor - self.num_shard_prefetches - min_prefetch_shard_backfill
644
+
645
+ return shard_sampler_state_dict
646
+
647
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
648
+ """
649
+ Loads the state dict into the shard sampler, restoring the data shard indices for each rank.
650
+
651
+ The state dict should look like {"all_rank_indices":
652
+ [torch.tensor(1), torch.tensor(1), torch.tensor(5), torch.tensor(6)]},
653
+ where each tensor corresponds to the indices of data shards for specific ranks.
654
+ """
655
+ # Restore the shard sampler's state from the given state_dict
656
+ self.shard_sampler.load_state_dict(state_dict)
657
+
658
+ def prefetch_shards(self, count: int) -> None:
659
+ try:
660
+ for _ in range(count):
661
+ shard = next(self.shard_sampler_iter)
662
+ path = shm.generate_path()
663
+ # append data_job to job pool
664
+ data_job = self.writing_pool.apply_async(self.dataloader_iter_generator, (shard, path))
665
+ self.data_jobs.append((shard, path, data_job))
666
+ if logger.isEnabledFor(logging.DEBUG):
667
+ logger.debug(f"queued {path}, {len(self.data_jobs)}")
668
+ except StopIteration:
669
+ pass
670
+
671
+ def _cleanup(self, stop_pool: bool = False) -> None:
672
+ self.microbatches = None
673
+ self.shutdown.set() # signal running tasks to stop
674
+ if self.next_batch:
675
+ try:
676
+ # if called from teardown/on_exception/__del__, this waits for the pending work to finish
677
+ # if called from __init__, it is already done
678
+ self.next_batch.result()
679
+ except Exception:
680
+ pass
681
+ self.next_batch = None
682
+ if stop_pool:
683
+ self.writing_pool.close() # no new tasks can run
684
+ self.reading_pool.shutdown()
685
+ for _, path, result in self.data_jobs:
686
+ logger.debug(f"waiting for {path} to stop ...")
687
+ self.cleanup.add(path)
688
+ try:
689
+ result.wait(timeout=DEFAULT_SHUTDOWN_TIMEOUT)
690
+ except Exception:
691
+ pass
692
+ logger.debug(f"{path} stopped ...")
693
+ self.data_jobs.clear()
694
+ self.shutdown.clear()
695
+ if stop_pool:
696
+ self.writing_pool.join() # make sure atexit is triggered in each subprocess
697
+ for path in self.cleanup:
698
+ logger.debug(f"removing {path} ...")
699
+ shutil.rmtree(path, ignore_errors=True)
700
+
701
+ def set_device(self, device: torch.device | None) -> None:
702
+ self.device = device
703
+
704
+ # called when fit/validate/predict/test is complete
705
+ def teardown(self, *args: Any) -> None:
706
+ logger.debug("teardown ...")
707
+ self._cleanup(stop_pool=True)
708
+ logger.debug("teardown complete")
709
+
710
+ # will be used once https://github.com/Lightning-AI/pytorch-lightning/pull/19601 is in effect
711
+ # once the below callback is operational we no longer need __del__ override
712
+ def on_exception(self, exception: BaseException) -> None:
713
+ self.teardown()
714
+
715
+ # called when the iterable link pointing to this object goes out of scope
716
+ # e.g. when an exception happens
717
+ def __del__(self) -> None:
718
+ self.teardown()
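Putting the pieces of this file together, a hedged end-to-end sketch: the paths, seed, and the ToyShardDataset generator from the earlier sketch are invented, and a real setup would derive dp_size/dp_rank from the distributed environment rather than hard-coding them.

    shard_sampler = DistributedDataParallelShardSampler(
        sampler=FsShardSampler("file:///tmp/shards", glob="*.parquet"),
        dp_size=2,
        dp_rank=0,
    )
    loader = ShardedDataLoader(
        seed=1234,
        shard_sampler=shard_sampler,
        dataloader_factory=DataLoaderFactory(dataset_generator=ToyShardDataset, batch_size=8),
        num_shard_prefetches=1,
        multiprocessing=False,  # keep shard workers in-process for debugging/profiling
    )
    for batch in loader:
        ...                     # training step; shard prefetching overlaps in the background
    loader.teardown()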