wavedl-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/utils/data.py ADDED
@@ -0,0 +1,1220 @@
1
+ """
2
+ Data Loading and Preprocessing Utilities
3
+ =========================================
4
+
5
+ Provides memory-efficient data loading for large-scale datasets with:
6
+ - Memory-mapped file support for datasets exceeding RAM
7
+ - DDP-safe data preparation with proper synchronization
8
+ - Thread-safe DataLoader worker initialization
9
+ - Multi-format support (NPZ, HDF5, MAT)
10
+
11
+ Author: Ductho Le (ductho.le@outlook.com)
12
+ Version: 1.0.0
13
+ """
14
+
15
+ import gc
16
+ import logging
17
+ import os
18
+ import pickle
19
+ from abc import ABC, abstractmethod
20
+ from pathlib import Path
21
+ from typing import Any
22
+
23
+ import h5py
24
+ import numpy as np
25
+ import torch
26
+ from accelerate import Accelerator
27
+ from scipy.sparse import issparse
28
+ from sklearn.model_selection import train_test_split
29
+ from sklearn.preprocessing import StandardScaler
30
+ from torch.utils.data import DataLoader, Dataset
31
+ from tqdm.auto import tqdm
32
+
33
+
34
+ # Optional scipy.io for MATLAB files
35
+ try:
36
+ import scipy.io
37
+
38
+ HAS_SCIPY_IO = True
39
+ except ImportError:
40
+ HAS_SCIPY_IO = False
41
+
42
+
43
+ # ==============================================================================
44
+ # DATA SOURCE ABSTRACTION
45
+ # ==============================================================================
46
+
47
+ # Supported key names for input/output arrays (priority order, pairwise aligned)
48
+ INPUT_KEYS = ["input_train", "input_test", "X", "data", "inputs", "features", "x"]
49
+ OUTPUT_KEYS = ["output_train", "output_test", "Y", "labels", "outputs", "targets", "y"]
50
+
51
+
52
+ class LazyDataHandle:
53
+ """
54
+ Context manager wrapper for memory-mapped data handles.
55
+
56
+ Provides proper cleanup of file handles returned by load_mmap() methods.
57
+ Can be used either as a context manager (recommended) or with explicit close().
58
+
59
+ Usage:
60
+ # Context manager (recommended)
61
+ with source.load_mmap(path) as (inputs, outputs):
62
+ # Use inputs and outputs
63
+ pass # File automatically closed
64
+
65
+ # Manual cleanup
66
+ handle = source.load_mmap(path)
67
+ inputs, outputs = handle.inputs, handle.outputs
68
+ # ... use data ...
69
+ handle.close()
70
+
71
+ Attributes:
72
+ inputs: Input data array/dataset
73
+ outputs: Output data array/dataset
74
+ """
75
+
76
+ def __init__(self, inputs, outputs, file_handle=None):
77
+ """
78
+ Initialize the handle.
79
+
80
+ Args:
81
+ inputs: Input array or lazy dataset
82
+ outputs: Output array or lazy dataset
83
+ file_handle: Optional file handle to close on cleanup
84
+ """
85
+ self.inputs = inputs
86
+ self.outputs = outputs
87
+ self._file = file_handle
88
+ self._closed = False
89
+
90
+ def __enter__(self):
91
+ """Return (inputs, outputs) tuple for unpacking."""
92
+ return self.inputs, self.outputs
93
+
94
+ def __exit__(self, exc_type, exc_val, exc_tb):
95
+ """Close file handle on context exit."""
96
+ self.close()
97
+ return False # Don't suppress exceptions
98
+
99
+ def close(self):
100
+ """
101
+ Close the underlying file handle.
102
+
103
+ Safe to call multiple times.
104
+ """
105
+ if self._closed:
106
+ return
107
+ self._closed = True
108
+
109
+ # Close the file handle if we have one
110
+ if self._file is not None:
111
+ try:
112
+ self._file.close()
113
+ except Exception:
114
+ pass
115
+ self._file = None
116
+
117
+ # Also close any _TransposedH5Dataset wrappers
118
+ for data in (self.inputs, self.outputs):
119
+ if hasattr(data, "close"):
120
+ try:
121
+ data.close()
122
+ except Exception:
123
+ pass
124
+
125
+ def __repr__(self) -> str:
126
+ status = "closed" if self._closed else "open"
127
+ return f"LazyDataHandle(status={status})"
128
+
129
+
130
+ class DataSource(ABC):
131
+ """
132
+ Abstract base class for data loaders supporting multiple file formats.
133
+
134
+ Subclasses must implement the `load()` method to return input/output arrays,
135
+ and `load_outputs_only()` for memory-efficient target loading (both are abstract).
136
+ """
137
+
138
+ @abstractmethod
139
+ def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
140
+ """
141
+ Load input and output arrays from a file.
142
+
143
+ Args:
144
+ path: Path to the data file
145
+
146
+ Returns:
147
+ Tuple of (inputs, outputs) as numpy arrays
148
+ """
149
+ pass
150
+
151
+ @abstractmethod
152
+ def load_outputs_only(self, path: str) -> np.ndarray:
153
+ """
154
+ Load only output/target arrays from a file (memory-efficient).
155
+
156
+ This avoids loading large input arrays when only targets are needed,
157
+ which is critical for HPC environments with memory constraints.
158
+
159
+ Args:
160
+ path: Path to the data file
161
+
162
+ Returns:
163
+ Output/target array
164
+ """
165
+ pass
166
+
167
+ @staticmethod
168
+ def detect_format(path: str) -> str:
169
+ """
170
+ Auto-detect file format from extension.
171
+
172
+ Args:
173
+ path: Path to data file
174
+
175
+ Returns:
176
+ Format string: 'npz', 'hdf5', or 'mat'
177
+ """
178
+ ext = Path(path).suffix.lower()
179
+ format_map = {
180
+ ".npz": "npz",
181
+ ".h5": "hdf5",
182
+ ".hdf5": "hdf5",
183
+ ".mat": "mat",
184
+ }
185
+ if ext not in format_map:
186
+ raise ValueError(
187
+ f"Unsupported file extension: '{ext}'. "
188
+ f"Supported formats: .npz, .h5, .hdf5, .mat"
189
+ )
190
+ return format_map[ext]
191
+
192
+ @staticmethod
193
+ def _find_key(available_keys: list[str], candidates: list[str]) -> str | None:
194
+ """Find first matching key from candidates in available keys."""
195
+ for key in candidates:
196
+ if key in available_keys:
197
+ return key
198
+ return None
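+
+ # Sketch of a custom source for another format (hypothetical ".pt" files holding
+ # a dict with "inputs"/"targets" keys; not shipped with this module):
+ #
+ #   class TorchSource(DataSource):
+ #       def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
+ #           blob = torch.load(path)
+ #           return np.asarray(blob["inputs"]), np.asarray(blob["targets"])
+ #
+ #       def load_outputs_only(self, path: str) -> np.ndarray:
+ #           return np.asarray(torch.load(path)["targets"])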
199
+
200
+
201
+ class NPZSource(DataSource):
202
+ """Load data from NumPy .npz archives."""
203
+
204
+ def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
205
+ data = np.load(path, allow_pickle=True)
206
+ keys = list(data.keys())
207
+
208
+ input_key = self._find_key(keys, INPUT_KEYS)
209
+ output_key = self._find_key(keys, OUTPUT_KEYS)
210
+
211
+ if input_key is None or output_key is None:
212
+ raise KeyError(
213
+ f"NPZ must contain input and output arrays. "
214
+ f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
215
+ f"Found: {keys}"
216
+ )
217
+
218
+ inp = data[input_key]
219
+ outp = data[output_key]
220
+
221
+ # Handle object arrays (e.g., sparse matrices stored as objects)
222
+ if inp.dtype == object:
223
+ inp = np.array([x.toarray() if hasattr(x, "toarray") else x for x in inp])
224
+
225
+ return inp, outp
226
+
227
+ def load_mmap(self, path: str) -> tuple[np.ndarray, np.ndarray]:
228
+ """
229
+ Load input/output arrays from an NPZ archive (lazy per-array decompression).
230
+
231
+ Note: numpy cannot memory-map members of a .npz archive, so arrays are still
232
+ read fully into memory on access; use the HDF5/MAT loaders for on-demand reads.
233
+
234
+ Treat the returned arrays as read-only; do NOT modify them.
235
+ """
236
+ data = np.load(path, allow_pickle=True, mmap_mode="r")
237
+ keys = list(data.keys())
238
+
239
+ input_key = self._find_key(keys, INPUT_KEYS)
240
+ output_key = self._find_key(keys, OUTPUT_KEYS)
241
+
242
+ if input_key is None or output_key is None:
243
+ raise KeyError(
244
+ f"NPZ must contain input and output arrays. "
245
+ f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
246
+ f"Found: {keys}"
247
+ )
248
+
249
+ inp = data[input_key]
250
+ outp = data[output_key]
251
+
252
+ return inp, outp
253
+
254
+ def load_outputs_only(self, path: str) -> np.ndarray:
255
+ """Load only targets from NPZ (avoids loading large input arrays)."""
256
+ data = np.load(path, allow_pickle=True)
257
+ keys = list(data.keys())
258
+
259
+ output_key = self._find_key(keys, OUTPUT_KEYS)
260
+ if output_key is None:
261
+ raise KeyError(
262
+ f"NPZ must contain output array. "
263
+ f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
264
+ )
265
+
266
+ return data[output_key]
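+
+ # Example: a compatible archive can be written with plain numpy (names assumed):
+ #
+ #   np.savez("dataset.npz", X=inputs, Y=targets)
+ #   inp, outp = NPZSource().load("dataset.npz")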
267
+
268
+
269
+ class HDF5Source(DataSource):
270
+ """Load data from HDF5 (.h5, .hdf5) files."""
271
+
272
+ def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
273
+ with h5py.File(path, "r") as f:
274
+ keys = list(f.keys())
275
+
276
+ input_key = self._find_key(keys, INPUT_KEYS)
277
+ output_key = self._find_key(keys, OUTPUT_KEYS)
278
+
279
+ if input_key is None or output_key is None:
280
+ raise KeyError(
281
+ f"HDF5 must contain input and output datasets. "
282
+ f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
283
+ f"Found: {keys}"
284
+ )
285
+
286
+ # Load into memory (HDF5 datasets are lazy by default)
287
+ inp = f[input_key][:]
288
+ outp = f[output_key][:]
289
+
290
+ return inp, outp
291
+
292
+ def load_mmap(self, path: str) -> LazyDataHandle:
293
+ """
294
+ Load HDF5 file with lazy/memory-mapped access.
295
+
296
+ Returns a LazyDataHandle that reads from disk on-demand,
297
+ avoiding loading the entire file into RAM.
298
+
299
+ Usage:
300
+ with source.load_mmap(path) as (inputs, outputs):
301
+ # Use inputs and outputs
302
+ pass # File automatically closed
303
+ """
304
+ f = h5py.File(path, "r") # Keep file open for lazy access
305
+ keys = list(f.keys())
306
+
307
+ input_key = self._find_key(keys, INPUT_KEYS)
308
+ output_key = self._find_key(keys, OUTPUT_KEYS)
309
+
310
+ if input_key is None or output_key is None:
311
+ f.close()
312
+ raise KeyError(
313
+ f"HDF5 must contain input and output datasets. "
314
+ f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
315
+ f"Found: {keys}"
316
+ )
317
+
318
+ # Return wrapped handle for proper cleanup
319
+ return LazyDataHandle(f[input_key], f[output_key], file_handle=f)
320
+
321
+ def load_outputs_only(self, path: str) -> np.ndarray:
322
+ """Load only targets from HDF5 (avoids loading large input arrays)."""
323
+ with h5py.File(path, "r") as f:
324
+ keys = list(f.keys())
325
+
326
+ output_key = self._find_key(keys, OUTPUT_KEYS)
327
+ if output_key is None:
328
+ raise KeyError(
329
+ f"HDF5 must contain output dataset. "
330
+ f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
331
+ )
332
+
333
+ outp = f[output_key][:]
334
+
335
+ return outp
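+
+ # Example, assuming "train.h5" contains datasets "X" (N, H, W) and "Y" (N, T):
+ #
+ #   src = HDF5Source()
+ #   with src.load_mmap("train.h5") as (X, Y):
+ #       first = X[:32]   # h5py reads only this slice from disk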
336
+
337
+
338
+ class _TransposedH5Dataset:
339
+ """
340
+ Lazy transpose wrapper for h5py datasets.
341
+
342
+ MATLAB stores arrays in column-major (Fortran) order, while Python/NumPy
343
+ expects row-major (C) order. This wrapper provides a transposed view
344
+ without loading the entire dataset into memory.
345
+
346
+ Supports:
347
+ - len(): Returns the transposed first dimension
348
+ - []: Returns slices with automatic transpose
349
+ - shape: Returns the transposed shape
350
+ - dtype: Returns the underlying dtype
351
+
352
+ This is critical for MATSource.load_mmap() to return consistent axis
353
+ ordering with the eager loader (MATSource.load()).
354
+
355
+ IMPORTANT: Holds a strong reference to the h5py.File to prevent
356
+ premature garbage collection while datasets are in use.
357
+ """
358
+
359
+ def __init__(self, h5_dataset, file_handle=None):
360
+ """
361
+ Args:
362
+ h5_dataset: The h5py dataset to wrap
363
+ file_handle: Optional h5py.File reference to keep alive
364
+ """
365
+ self._dataset = h5_dataset
366
+ self._file = file_handle # Keep file alive to prevent GC
367
+ # Transpose shape: MATLAB (cols, rows, ...) -> Python (rows, cols, ...)
368
+ self.shape = tuple(reversed(h5_dataset.shape))
369
+ self.dtype = h5_dataset.dtype
370
+
371
+ # Precompute transpose axis order for efficiency
372
+ # For shape (A, B, C) -> reversed (C, B, A), transpose axes are (2, 1, 0)
373
+ self._transpose_axes = tuple(range(len(h5_dataset.shape) - 1, -1, -1))
374
+
375
+ def __len__(self) -> int:
376
+ return self.shape[0]
377
+
378
+ def __getitem__(self, idx):
379
+ """
380
+ Fetch data with automatic full transpose.
381
+
382
+ Handles integer indexing, slices, and fancy indexing.
383
+ All operations return data with fully reversed axes to match .T behavior.
384
+ """
385
+ if isinstance(idx, int | np.integer):
386
+ # Single sample: index into last axis of h5py dataset (column-major)
387
+ # Result needs full transpose of remaining dimensions
388
+ data = self._dataset[..., idx]
389
+ if data.ndim == 0:
390
+ return data
391
+ elif data.ndim == 1:
392
+ return data # 1D doesn't need transpose
393
+ else:
394
+ # Full transpose: reverse all axes
395
+ return np.transpose(data)
396
+
397
+ elif isinstance(idx, slice):
398
+ # Slice indexing: fetch from last axis, then fully transpose
399
+ start, stop, step = idx.indices(self.shape[0])
400
+ data = self._dataset[..., start:stop:step]
401
+
402
+ # Handle special case: 1D result (e.g., row vector)
403
+ if data.ndim == 1:
404
+ return data
405
+
406
+ # Full transpose: reverse ALL axes (not just moveaxis)
407
+ # This matches the behavior of .T on a numpy array
408
+ return np.transpose(data, axes=self._transpose_axes)
409
+
410
+ elif isinstance(idx, list | np.ndarray):
411
+ # Fancy indexing: load samples one at a time (h5py limitation)
412
+ # This is slower but necessary for compatibility
413
+ samples = [self[i] for i in idx]
414
+ return np.stack(samples, axis=0)
415
+
416
+ else:
417
+ raise TypeError(f"Unsupported index type: {type(idx)}")
418
+
419
+ def close(self):
420
+ """Close the underlying file handle if we own it."""
421
+ if self._file is not None:
422
+ try:
423
+ self._file.close()
424
+ except Exception:
425
+ pass
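+
+ # For instance, a MATLAB array saved with size [N, H, W] appears in h5py with
+ # shape (W, H, N); wrapping that dataset here exposes shape (N, H, W), and
+ # wrapper[i] returns sample i as an (H, W) array read on demand from disk.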
426
+
427
+
428
+ class MATSource(DataSource):
429
+ """
430
+ Load data from MATLAB .mat files (v7.3+ only, which use the HDF5 format).
431
+
432
+ Note: MAT v7.3 files are HDF5 files under the hood, so we use h5py for
433
+ memory-efficient lazy loading. Save with: save('file.mat', '-v7.3')
434
+
435
+ Supports MATLAB sparse matrices (automatically converted to dense).
436
+
437
+ For older MAT files (v5/v7), convert to NPZ or save with -v7.3 flag.
438
+ """
439
+
440
+ @staticmethod
441
+ def _is_sparse_dataset(dataset) -> bool:
442
+ """Check if an HDF5 dataset/group represents a MATLAB sparse matrix."""
443
+ # MATLAB v7.3 stores sparse matrices as groups with 'data', 'ir', 'jc' keys
444
+ if hasattr(dataset, "keys"):
445
+ keys = set(dataset.keys())
446
+ return {"data", "ir", "jc"}.issubset(keys)
447
+ return False
448
+
449
+ @staticmethod
450
+ def _load_sparse_to_dense(group) -> np.ndarray:
451
+ """Convert MATLAB sparse matrix (CSC format in HDF5) to dense numpy array."""
452
+ from scipy.sparse import csc_matrix
453
+
454
+ data = np.array(group["data"])
455
+ ir = np.array(group["ir"]) # row indices
456
+ jc = np.array(group["jc"]) # column pointers
457
+
458
+ # Get shape from MATLAB attributes or infer
459
+ if "MATLAB_sparse" in group.attrs:
460
+ nrows = group.attrs["MATLAB_sparse"]
461
+ else:
462
+ nrows = ir.max() + 1 if len(ir) > 0 else 0
463
+ ncols = len(jc) - 1
464
+
465
+ sparse_mat = csc_matrix((data, ir, jc), shape=(nrows, ncols))
466
+ return sparse_mat.toarray()
467
+
468
+ def _load_dataset(self, f, key: str) -> np.ndarray:
469
+ """Load a dataset, handling sparse matrices automatically."""
470
+ dataset = f[key]
471
+
472
+ if self._is_sparse_dataset(dataset):
473
+ # Sparse matrix: convert to dense
474
+ arr = self._load_sparse_to_dense(dataset)
475
+ else:
476
+ # Regular dense array
477
+ arr = np.array(dataset)
478
+
479
+ # Transpose for MATLAB column-major -> Python row-major
480
+ return arr.T
481
+
482
+ def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
483
+ """Load MAT v7.3 file using h5py."""
484
+ try:
485
+ with h5py.File(path, "r") as f:
486
+ keys = list(f.keys())
487
+
488
+ input_key = self._find_key(keys, INPUT_KEYS)
489
+ output_key = self._find_key(keys, OUTPUT_KEYS)
490
+
491
+ if input_key is None or output_key is None:
492
+ raise KeyError(
493
+ f"MAT file must contain input and output arrays. "
494
+ f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
495
+ f"Found: {keys}"
496
+ )
497
+
498
+ # Load with sparse matrix support
499
+ inp = self._load_dataset(f, input_key)
500
+ outp = self._load_dataset(f, output_key)
501
+
502
+ # Handle 1D outputs that become (1, N) after transpose
503
+ if outp.ndim == 2 and outp.shape[0] == 1:
504
+ outp = outp.T
505
+
506
+ except OSError as e:
507
+ raise ValueError(
508
+ f"Failed to load MAT file: {path}. "
509
+ f"Ensure it's saved as v7.3: save('file.mat', '-v7.3'). "
510
+ f"Original error: {e}"
511
+ )
512
+
513
+ return inp, outp
514
+
515
+ def load_mmap(self, path: str) -> LazyDataHandle:
516
+ """
517
+ Load MAT v7.3 file with lazy/memory-mapped access.
518
+
519
+ Returns a LazyDataHandle that reads from disk on-demand,
520
+ avoiding loading the entire file into RAM.
521
+
522
+ Note: For sparse matrices, this will load and convert them.
523
+ For dense arrays, returns a transposed view wrapper for consistent axis ordering.
524
+
525
+ Usage:
526
+ with source.load_mmap(path) as (inputs, outputs):
527
+ # Use inputs and outputs
528
+ pass # File automatically closed
529
+ """
530
+ try:
531
+ f = h5py.File(path, "r") # Keep file open for lazy access
532
+ keys = list(f.keys())
533
+
534
+ input_key = self._find_key(keys, INPUT_KEYS)
535
+ output_key = self._find_key(keys, OUTPUT_KEYS)
536
+
537
+ if input_key is None or output_key is None:
538
+ f.close()
539
+ raise KeyError(
540
+ f"MAT file must contain input and output arrays. "
541
+ f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
542
+ f"Found: {keys}"
543
+ )
544
+
545
+ # Check for sparse matrices - must load them eagerly
546
+ inp_dataset = f[input_key]
547
+ outp_dataset = f[output_key]
548
+
549
+ if self._is_sparse_dataset(inp_dataset):
550
+ inp = self._load_sparse_to_dense(inp_dataset).T
551
+ else:
552
+ # Wrap h5py dataset with transpose view for consistent axis order
553
+ # MATLAB stores column-major, Python expects row-major
554
+ # Pass file handle to keep it alive
555
+ inp = _TransposedH5Dataset(inp_dataset, file_handle=f)
556
+
557
+ if self._is_sparse_dataset(outp_dataset):
558
+ outp = self._load_sparse_to_dense(outp_dataset).T
559
+ else:
560
+ # Wrap h5py dataset with transpose view (shares same file handle)
561
+ outp = _TransposedH5Dataset(outp_dataset, file_handle=f)
562
+
563
+ # Return wrapped handle for proper cleanup
564
+ return LazyDataHandle(inp, outp, file_handle=f)
565
+
566
+ except OSError as e:
567
+ raise ValueError(
568
+ f"Failed to load MAT file: {path}. "
569
+ f"Ensure it's saved as v7.3: save('file.mat', '-v7.3'). "
570
+ f"Original error: {e}"
571
+ )
572
+
573
+ def load_outputs_only(self, path: str) -> np.ndarray:
574
+ """Load only targets from MAT v7.3 file (avoids loading large input arrays)."""
575
+ try:
576
+ with h5py.File(path, "r") as f:
577
+ keys = list(f.keys())
578
+
579
+ output_key = self._find_key(keys, OUTPUT_KEYS)
580
+ if output_key is None:
581
+ raise KeyError(
582
+ f"MAT file must contain output array. "
583
+ f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
584
+ )
585
+
586
+ # Load with sparse matrix support
587
+ outp = self._load_dataset(f, output_key)
588
+
589
+ # Handle 1D outputs
590
+ if outp.ndim == 2 and outp.shape[0] == 1:
591
+ outp = outp.T
592
+
593
+ except OSError as e:
594
+ raise ValueError(
595
+ f"Failed to load MAT file: {path}. "
596
+ f"Ensure it's saved as v7.3: save('file.mat', '-v7.3'). "
597
+ f"Original error: {e}"
598
+ )
599
+
600
+ return outp
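+
+ # Example round trip (variable names assumed). In MATLAB:
+ #   save('train.mat', 'X', 'Y', '-v7.3');
+ # then in Python:
+ #   X, Y = MATSource().load('train.mat')   # arrays come back in row-major order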
601
+
602
+
603
+ def get_data_source(format: str) -> DataSource:
604
+ """
605
+ Factory function to get the appropriate DataSource for a format.
606
+
607
+ Args:
608
+ format: One of 'npz', 'hdf5', 'mat'
609
+
610
+ Returns:
611
+ DataSource instance
612
+ """
613
+ sources = {
614
+ "npz": NPZSource,
615
+ "hdf5": HDF5Source,
616
+ "mat": MATSource,
617
+ }
618
+
619
+ if format not in sources:
620
+ raise ValueError(
621
+ f"Unsupported format: {format}. Supported: {list(sources.keys())}"
622
+ )
623
+
624
+ return sources[format]()
625
+
626
+
627
+ def load_training_data(
628
+ path: str, format: str = "auto"
629
+ ) -> tuple[np.ndarray, np.ndarray]:
630
+ """
631
+ Load training data from file with automatic format detection.
632
+
633
+ Supports:
634
+ - NPZ: NumPy compressed archives (.npz)
635
+ - HDF5: Hierarchical Data Format (.h5, .hdf5)
636
+ - MAT: MATLAB files (.mat)
637
+
638
+ Flexible key detection accepts any INPUT_KEYS / OUTPUT_KEYS name (e.g., input_train/X/data and output_train/Y/labels).
639
+
640
+ Args:
641
+ path: Path to data file
642
+ format: Format hint ('npz', 'hdf5', 'mat', or 'auto' for detection)
643
+
644
+ Returns:
645
+ Tuple of (inputs, outputs) arrays
646
+ """
647
+ if format == "auto":
648
+ format = DataSource.detect_format(path)
649
+
650
+ source = get_data_source(format)
651
+ return source.load(path)
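+
+ # Example (file name assumed):
+ #
+ #   X, y = load_training_data("dataset.npz")   # format inferred from ".npz"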
652
+
653
+
654
+ def load_outputs_only(path: str, format: str = "auto") -> np.ndarray:
655
+ """
656
+ Load only output/target arrays from file (memory-efficient).
657
+
658
+ This function avoids loading large input arrays when only targets are needed,
659
+ which is critical for HPC environments with memory constraints during DDP.
660
+
661
+ Args:
662
+ path: Path to data file
663
+ format: Format hint ('npz', 'hdf5', 'mat', or 'auto' for detection)
664
+
665
+ Returns:
666
+ Output/target array
667
+ """
668
+ if format == "auto":
669
+ format = DataSource.detect_format(path)
670
+
671
+ source = get_data_source(format)
672
+ return source.load_outputs_only(path)
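+
+ # Example (file name assumed):
+ #
+ #   y = load_outputs_only("dataset.h5")   # reads only the target dataset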
673
+
674
+
675
+ def load_test_data(
676
+ path: str,
677
+ format: str = "auto",
678
+ input_key: str | None = None,
679
+ output_key: str | None = None,
680
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
681
+ """
682
+ Load test/inference data and return PyTorch tensors ready for model input.
683
+
684
+ This is the unified data loading function for inference. It:
685
+ - Auto-detects file format from extension
686
+ - Handles custom key names for non-standard datasets
687
+ - Adds channel dimension if missing (dimension-agnostic)
688
+ - Returns None for targets if not present in file
689
+
690
+ Supports any input dimensionality:
691
+ - 1D: (N, L) → (N, 1, L)
692
+ - 2D: (N, H, W) → (N, 1, H, W)
693
+ - 3D: (N, D, H, W) → (N, 1, D, H, W)
694
+ - Already has channel: (N, C, ...) → unchanged
695
+
696
+ Args:
697
+ path: Path to data file (NPZ, HDF5, or MAT v7.3)
698
+ format: Format hint ('npz', 'hdf5', 'mat', or 'auto' for detection)
699
+ input_key: Custom key for input data (overrides auto-detection)
700
+ output_key: Custom key for output data (overrides auto-detection)
701
+
702
+ Returns:
703
+ Tuple of:
704
+ - X: Input tensor with channel dimension (N, 1, *spatial_dims)
705
+ - y: Target tensor (N, T) or None if targets not present
706
+
707
+ Example:
708
+ >>> X, y = load_test_data("test_data.npz")
709
+ >>> X, y = load_test_data(
710
+ ... "data.mat", input_key="waveforms", output_key="params"
711
+ ... )
712
+ """
713
+ if format == "auto":
714
+ format = DataSource.detect_format(path)
715
+
716
+ source = get_data_source(format)
717
+
718
+ # Build custom key lists if provided
719
+ if input_key:
720
+ custom_input_keys = [input_key] + INPUT_KEYS
721
+ else:
722
+ # Prioritize test keys for inference
723
+ custom_input_keys = ["input_test"] + [
724
+ k for k in INPUT_KEYS if k != "input_test"
725
+ ]
726
+
727
+ if output_key:
728
+ custom_output_keys = [output_key] + OUTPUT_KEYS
729
+ else:
730
+ custom_output_keys = ["output_test"] + [
731
+ k for k in OUTPUT_KEYS if k != "output_test"
732
+ ]
733
+
734
+ # Load data using appropriate source
735
+ try:
736
+ inp, outp = source.load(path)
737
+ except KeyError:
738
+ # Try with just inputs if outputs not found
739
+ if format == "npz":
740
+ data = np.load(path, allow_pickle=True)
741
+ keys = list(data.keys())
742
+ inp_key = DataSource._find_key(keys, custom_input_keys)
743
+ if inp_key is None:
744
+ raise KeyError(
745
+ f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
746
+ )
747
+ inp = data[inp_key]
748
+ if inp.dtype == object:
749
+ inp = np.array(
750
+ [x.toarray() if hasattr(x, "toarray") else x for x in inp]
751
+ )
752
+ out_key = DataSource._find_key(keys, custom_output_keys)
753
+ outp = data[out_key] if out_key else None
754
+ else:
755
+ raise
756
+
757
+ # Handle sparse matrices
758
+ if issparse(inp):
759
+ inp = inp.toarray()
760
+ if outp is not None and issparse(outp):
761
+ outp = outp.toarray()
762
+
763
+ # Convert to tensors
764
+ X = torch.tensor(np.asarray(inp), dtype=torch.float32)
765
+
766
+ if outp is not None:
767
+ y = torch.tensor(np.asarray(outp), dtype=torch.float32)
768
+ # Normalize target shape: (N,) → (N, 1)
769
+ if y.ndim == 1:
770
+ y = y.unsqueeze(1)
771
+ else:
772
+ y = None
773
+
774
+ # Add channel dimension if needed (dimension-agnostic)
775
+ # X.ndim == 2: 1D data (N, L) → (N, 1, L)
776
+ # X.ndim == 3: 2D data (N, H, W) → (N, 1, H, W)
777
+ # X.ndim == 4: Check if already has channel dim (C <= 16 heuristic)
778
+ if X.ndim == 2:
779
+ X = X.unsqueeze(1) # 1D signal: (N, L) → (N, 1, L)
780
+ elif X.ndim == 3:
781
+ X = X.unsqueeze(1) # 2D image: (N, H, W) → (N, 1, H, W)
782
+ elif X.ndim == 4:
783
+ # Could be 3D volume (N, D, H, W) or 2D with channel (N, C, H, W)
784
+ # Heuristic: if dim 1 is small (<=16), assume it's already a channel dim
785
+ if X.shape[1] > 16:
786
+ X = X.unsqueeze(1) # 3D volume: (N, D, H, W) → (N, 1, D, H, W)
787
+ # X.ndim >= 5: assume channel dimension already exists
788
+
789
+ return X, y
790
+
791
+
792
+ # ==============================================================================
793
+ # WORKER INITIALIZATION
794
+ # ==============================================================================
795
+ def memmap_worker_init_fn(worker_id: int):
796
+ """
797
+ Worker initialization function for proper memmap handling in multi-worker DataLoader.
798
+
799
+ Each DataLoader worker process runs this function after forking. It:
800
+ 1. Resets the memmap file handle to None, forcing each worker to open its own
801
+ read-only handle (prevents file descriptor sharing issues and race conditions)
802
+ 2. Seeds numpy's random state per worker to ensure statistical diversity in
803
+ random augmentations (prevents all workers from applying identical "random"
804
+ transformations to their batches)
805
+
806
+ Args:
807
+ worker_id: Worker index (0 to num_workers-1), provided by DataLoader
808
+
809
+ Usage:
810
+ DataLoader(dataset, num_workers=8, worker_init_fn=memmap_worker_init_fn)
811
+ """
812
+ worker_info = torch.utils.data.get_worker_info()
813
+ if worker_info is not None:
814
+ dataset = worker_info.dataset
815
+ # Force re-initialization of memmap in each worker
816
+ dataset.data = None
817
+
818
+ # Seed numpy RNG per worker using PyTorch's worker seed for reproducibility
819
+ # This ensures random augmentations (noise, shifts, etc.) are unique per worker
820
+ np.random.seed(worker_info.seed % (2**32 - 1))
821
+
822
+
823
+ # ==============================================================================
824
+ # MEMORY-MAPPED DATASET
825
+ # ==============================================================================
826
+ class MemmapDataset(Dataset):
827
+ """
828
+ Zero-copy memory-mapped dataset for large-scale training.
829
+
830
+ Uses numpy memory mapping to load data directly from disk, allowing training
831
+ on datasets that exceed available RAM. The memmap is only opened when first
832
+ accessed (lazy initialization), and each DataLoader worker maintains its own
833
+ file handle for thread safety.
834
+
835
+ Args:
836
+ memmap_path: Path to the memory-mapped data file
837
+ targets: Pre-loaded target tensor (small enough to fit in memory)
838
+ shape: Full shape of the memmap array (N, 1, *spatial_dims)
839
+ indices: Indices into the memmap for this split (train/val)
840
+
841
+ Thread Safety:
842
+ When using with DataLoader num_workers > 0, must use memmap_worker_init_fn
843
+ as the worker_init_fn to ensure each worker gets its own file handle.
844
+
845
+ Example:
846
+ dataset = MemmapDataset("cache.dat", y_tensor, (10000, 1, 500, 500), train_indices)
847
+ loader = DataLoader(dataset, num_workers=8, worker_init_fn=memmap_worker_init_fn)
848
+ """
849
+
850
+ def __init__(
851
+ self,
852
+ memmap_path: str,
853
+ targets: torch.Tensor,
854
+ shape: tuple[int, ...],
855
+ indices: np.ndarray,
856
+ ):
857
+ self.memmap_path = memmap_path
858
+ self.targets = targets
859
+ self.shape = shape
860
+ self.indices = indices
861
+ self.data: np.memmap | None = None # Lazy initialization
862
+
863
+ def __len__(self) -> int:
864
+ return len(self.indices)
865
+
866
+ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
867
+ if self.data is None:
868
+ # Mode 'r' = read-only, prevents accidental data modification
869
+ self.data = np.memmap(
870
+ self.memmap_path, dtype="float32", mode="r", shape=self.shape
871
+ )
872
+
873
+ real_idx = self.indices[idx]
874
+
875
+ # .copy() detaches from mmap buffer - essential for PyTorch pinned memory
876
+ x = torch.from_numpy(self.data[real_idx].copy()).contiguous()
877
+ y = self.targets[real_idx]
878
+
879
+ return x, y
880
+
881
+ def __repr__(self) -> str:
882
+ return (
883
+ f"MemmapDataset(path='{self.memmap_path}', "
884
+ f"samples={len(self)}, shape={self.shape})"
885
+ )
886
+
887
+
888
+ # ==============================================================================
889
+ # DATA PREPARATION
890
+ # ==============================================================================
891
+ def prepare_data(
892
+ args: Any,
893
+ logger: logging.Logger,
894
+ accelerator: Accelerator,
895
+ cache_dir: str = ".",
896
+ val_split: float = 0.2,
897
+ ) -> tuple[DataLoader, DataLoader, StandardScaler, tuple[int, ...], int]:
898
+ """
899
+ Prepare DataLoaders with DDP synchronization guarantees.
900
+
901
+ This function handles:
902
+ 1. Loading raw data and creating memmap cache (Rank 0 only)
903
+ 2. Fitting StandardScaler on training set only (no data leakage)
904
+ 3. Synchronizing all ranks before proceeding
905
+ 4. Creating thread-safe DataLoaders for DDP training
906
+
907
+ Supports any input dimensionality:
908
+ - 1D: (N, L) → returns in_shape = (L,)
909
+ - 2D: (N, H, W) → returns in_shape = (H, W)
910
+ - 3D: (N, D, H, W) → returns in_shape = (D, H, W)
911
+
912
+ Args:
913
+ args: Argument namespace with data_path, seed, batch_size, workers
914
+ logger: Logger instance for status messages
915
+ accelerator: Accelerator instance for DDP coordination
916
+ cache_dir: Directory for cache files (default: current directory)
917
+ val_split: Validation set fraction (default: 0.2)
918
+
919
+ Returns:
920
+ Tuple of:
921
+ - train_dl: Training DataLoader
922
+ - val_dl: Validation DataLoader
923
+ - scaler: Fitted StandardScaler (for inverse transforms)
924
+ - in_shape: Input spatial dimensions - (L,), (H, W), or (D, H, W)
925
+ - out_dim: Number of output targets
926
+
927
+ Cache Files Created:
928
+ - train_data_cache.dat: Memory-mapped input data
929
+ - scaler.pkl: Fitted StandardScaler
930
+ - data_metadata.pkl: Shape and dimension metadata
931
+ """
932
+ CACHE_FILE = os.path.join(cache_dir, "train_data_cache.dat")
933
+ SCALER_FILE = os.path.join(cache_dir, "scaler.pkl")
934
+ META_FILE = os.path.join(cache_dir, "data_metadata.pkl")
935
+
936
+ # ==========================================================================
937
+ # PHASE 1: DATA GENERATION (Rank 0 Only)
938
+ # ==========================================================================
939
+ # Check cache existence and validity (data path must match)
940
+ cache_exists = (
941
+ os.path.exists(CACHE_FILE)
942
+ and os.path.exists(SCALER_FILE)
943
+ and os.path.exists(META_FILE)
944
+ )
945
+
946
+ # Validate cache matches current data_path (prevents stale cache corruption)
947
+ if cache_exists:
948
+ try:
949
+ with open(META_FILE, "rb") as f:
950
+ meta = pickle.load(f)
951
+ cached_data_path = meta.get("data_path", None)
952
+ if cached_data_path != os.path.abspath(args.data_path):
953
+ if accelerator.is_main_process:
954
+ logger.warning(
955
+ f"⚠️ Cache was created from different data file!\n"
956
+ f" Cached: {cached_data_path}\n"
957
+ f" Current: {os.path.abspath(args.data_path)}\n"
958
+ f" Invalidating cache and regenerating..."
959
+ )
960
+ cache_exists = False
961
+ except Exception:
962
+ cache_exists = False
963
+
964
+ if not cache_exists:
965
+ if accelerator.is_main_process:
966
+ # RANK 0: Create cache (can take a long time for large datasets)
967
+ # Other ranks will wait at the barrier below
968
+
969
+ # Detect format from extension
970
+ data_format = DataSource.detect_format(args.data_path)
971
+ logger.info(
972
+ f"⚡ [Rank 0] Initializing Data Processing from: {args.data_path} (format: {data_format})"
973
+ )
974
+
975
+ # Validate data file exists
976
+ if not os.path.exists(args.data_path):
977
+ raise FileNotFoundError(
978
+ f"CRITICAL: Data file not found: {args.data_path}"
979
+ )
980
+
981
+ # Load raw data lazily where the format allows it (HDF5/MAT are read on demand,
982
+ # NPZ per array) instead of loading the entire dataset into RAM at once
983
+ try:
984
+ if data_format == "npz":
985
+ source = NPZSource()
986
+ inp, outp = source.load_mmap(args.data_path)
987
+ elif data_format == "hdf5":
988
+ source = HDF5Source()
989
+ _lazy_handle = source.load_mmap(args.data_path)
990
+ inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
991
+ elif data_format == "mat":
992
+ source = MATSource()
993
+ _lazy_handle = source.load_mmap(args.data_path)
994
+ inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
995
+ else:
996
+ inp, outp = load_training_data(args.data_path, format=data_format)
997
+ logger.info(" Using memory-mapped loading (low memory mode)")
998
+ except Exception as e:
999
+ logger.error(f"Failed to load data file: {e}")
1000
+ raise
1001
+
1002
+ # Detect shape (handle sparse matrices) - DIMENSION AGNOSTIC
1003
+ num_samples = len(inp)
1004
+
1005
+ # Handle 1D targets: (N,) -> treat as single output
1006
+ if outp.ndim == 1:
1007
+ out_dim = 1
1008
+ else:
1009
+ out_dim = outp.shape[1]
1010
+
1011
+ sample_0 = inp[0]
1012
+ if issparse(sample_0) or hasattr(sample_0, "toarray"):
1013
+ sample_0 = sample_0.toarray()
1014
+
1015
+ # Detect dimensionality and validate single-channel assumption
1016
+ raw_shape = sample_0.shape
1017
+
1018
+ # Heuristic for ambiguous multi-channel detection:
1019
+ # Trigger when shape is EXACTLY 3D (could be C,H,W or D,H,W) with:
1020
+ # - First dim small (<=16) - looks like channels
1021
+ # - BOTH remaining dims large (>16) - confirming it's an image, not a tiny patch
1022
+ # This distinguishes (3, 256, 256) from (8, 1024) or (128, 128)
1023
+ # Use --single_channel to confirm shallow 3D volumes like (8, 128, 128)
1024
+ is_ambiguous_shape = (
1025
+ len(raw_shape) == 3 # Exactly 3D: could be (C, H, W) or (D, H, W)
1026
+ and raw_shape[0] <= 16 # First dim looks like channels
1027
+ and raw_shape[1] > 16
1028
+ and raw_shape[2] > 16 # Both spatial dims are large
1029
+ )
1030
+
1031
+ # Check for user confirmation via --single_channel flag
1032
+ user_confirmed_single_channel = getattr(args, "single_channel", False)
1033
+
1034
+ if is_ambiguous_shape and not user_confirmed_single_channel:
1035
+ raise ValueError(
1036
+ f"Ambiguous input shape detected: sample shape {raw_shape}. "
1037
+ f"This could be either:\n"
1038
+ f" - Multi-channel 2D data (C={raw_shape[0]}, H={raw_shape[1]}, W={raw_shape[2]})\n"
1039
+ f" - Single-channel 3D volume (D={raw_shape[0]}, H={raw_shape[1]}, W={raw_shape[2]})\n\n"
1040
+ f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
1041
+ f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
1042
+ )
1043
+
1044
+ spatial_shape = raw_shape
1045
+ full_shape = (
1046
+ num_samples,
1047
+ 1,
1048
+ ) + spatial_shape # Add channel dim: (N, 1, ...)
1049
+
1050
+ dim_names = {1: "1D (L)", 2: "2D (H, W)", 3: "3D (D, H, W)"}
1051
+ dim_type = dim_names.get(len(spatial_shape), f"{len(spatial_shape)}D")
1052
+ logger.info(
1053
+ f" Shape Detected: {full_shape} [{dim_type}] | Output Dim: {out_dim}"
1054
+ )
1055
+
1056
+ # Save metadata (including data path for cache validation)
1057
+ with open(META_FILE, "wb") as f:
1058
+ pickle.dump(
1059
+ {
1060
+ "shape": full_shape,
1061
+ "out_dim": out_dim,
1062
+ "data_path": os.path.abspath(args.data_path),
1063
+ },
1064
+ f,
1065
+ )
1066
+
1067
+ # Create memmap cache
1068
+ if not os.path.exists(CACHE_FILE):
1069
+ logger.info(" Writing Memmap Cache (one-time operation)...")
1070
+ fp = np.memmap(CACHE_FILE, dtype="float32", mode="w+", shape=full_shape)
1071
+
1072
+ chunk_size = 2000
1073
+ pbar = tqdm(
1074
+ range(0, num_samples, chunk_size),
1075
+ desc="Caching",
1076
+ disable=not accelerator.is_main_process,
1077
+ )
1078
+
1079
+ for i in pbar:
1080
+ end = min(i + chunk_size, num_samples)
1081
+ batch = inp[i:end]
1082
+
1083
+ # Handle sparse/dense conversion
1084
+ if issparse(batch[0]) or hasattr(batch[0], "toarray"):
1085
+ data_chunk = np.stack(
1086
+ [x.toarray().astype(np.float32) for x in batch]
1087
+ )
1088
+ else:
1089
+ data_chunk = np.array(batch).astype(np.float32)
1090
+
1091
+ # Add channel dimension if needed (handles 1D, 2D, and 3D spatial data)
1092
+ # data_chunk shape: (batch, *spatial) -> need (batch, 1, *spatial)
1093
+ # full_shape is (N, 1, *spatial), so expected ndim = len(full_shape)
1094
+ if data_chunk.ndim == len(full_shape) - 1:
1095
+ # Missing channel dim: (batch, *spatial) -> (batch, 1, *spatial)
1096
+ data_chunk = np.expand_dims(data_chunk, 1)
1097
+
1098
+ fp[i:end] = data_chunk
1099
+
1100
+ # Periodic flush to disk
1101
+ if i % 10000 == 0:
1102
+ fp.flush()
1103
+
1104
+ fp.flush()
1105
+ del fp
1106
+ gc.collect()
1107
+
1108
+ # Train/Val split and scaler fitting
1109
+ indices = np.arange(num_samples)
1110
+ tr_idx, val_idx = train_test_split(
1111
+ indices, test_size=val_split, random_state=args.seed
1112
+ )
1113
+
1114
+ if not os.path.exists(SCALER_FILE):
1115
+ logger.info(" Fitting StandardScaler (training set only)...")
1116
+
1117
+ # Convert lazy datasets to numpy for reliable indexing
1118
+ # (h5py and _TransposedH5Dataset may not support fancy indexing)
1119
+ if hasattr(outp, "_dataset") or hasattr(outp, "file"):
1120
+ # Lazy h5py or _TransposedH5Dataset - load training subset
1121
+ outp_train = np.array([outp[i] for i in tr_idx])
1122
+ else:
1123
+ # Already numpy array
1124
+ outp_train = outp[tr_idx]
1125
+
1126
+ # Ensure 2D for StandardScaler: (N,) -> (N, 1)
1127
+ if outp_train.ndim == 1:
1128
+ outp_train = outp_train.reshape(-1, 1)
1129
+
1130
+ scaler = StandardScaler()
1131
+ scaler.fit(outp_train)
1132
+ with open(SCALER_FILE, "wb") as f:
1133
+ pickle.dump(scaler, f)
1134
+
1135
+ # Cleanup: close file handles BEFORE deleting references
1136
+ if "_lazy_handle" in dir() and _lazy_handle is not None:
1137
+ try:
1138
+ _lazy_handle.close()
1139
+ except Exception:
1140
+ pass
1141
+ del inp, outp
1142
+ gc.collect()
1143
+
1144
+ logger.info(" ✔ Cache creation complete, synchronizing ranks...")
1145
+ else:
1146
+ # NON-MAIN RANKS: Wait for cache creation
1147
+ # Log that we're waiting (helps with debugging)
1148
+ import time
1149
+
1150
+ wait_start = time.time()
1151
+ while not (
1152
+ os.path.exists(CACHE_FILE)
1153
+ and os.path.exists(SCALER_FILE)
1154
+ and os.path.exists(META_FILE)
1155
+ ):
1156
+ time.sleep(5) # Check every 5 seconds
1157
+ elapsed = time.time() - wait_start
1158
+ if elapsed > 60 and int(elapsed) % 60 < 5: # Log every ~minute
1159
+ logger.info(
1160
+ f" [Rank {accelerator.process_index}] Waiting for cache creation... ({int(elapsed)}s)"
1161
+ )
1162
+ # Small delay to ensure files are fully written
1163
+ time.sleep(2)
1164
+
1165
+ # ==========================================================================
1166
+ # PHASE 2: SYNCHRONIZED LOADING (All Ranks)
1167
+ # ==========================================================================
1168
+ accelerator.wait_for_everyone()
1169
+
1170
+ # Load metadata
1171
+ with open(META_FILE, "rb") as f:
1172
+ meta = pickle.load(f)
1173
+ full_shape = meta["shape"]
1174
+ out_dim = meta["out_dim"]
1175
+
1176
+ # Load and validate scaler
1177
+ with open(SCALER_FILE, "rb") as f:
1178
+ scaler = pickle.load(f)
1179
+
1180
+ if not hasattr(scaler, "scale_") or scaler.scale_ is None:
1181
+ raise RuntimeError("CRITICAL: Scaler is not properly fitted (scale_ is None)")
1182
+
1183
+ # Load targets only (memory-efficient - avoids loading large input arrays)
1184
+ # This is critical for HPC environments with memory constraints during DDP
1185
+ outp = load_outputs_only(args.data_path)
1186
+
1187
+ # Ensure 2D for StandardScaler: (N,) -> (N, 1)
1188
+ if outp.ndim == 1:
1189
+ outp = outp.reshape(-1, 1)
1190
+
1191
+ y_scaled = scaler.transform(outp).astype(np.float32)
1192
+ y_tensor = torch.tensor(y_scaled)
1193
+
1194
+ # Regenerate indices (deterministic with same seed)
1195
+ indices = np.arange(full_shape[0])
1196
+ tr_idx, val_idx = train_test_split(
1197
+ indices, test_size=val_split, random_state=args.seed
1198
+ )
1199
+
1200
+ # Create datasets
1201
+ tr_ds = MemmapDataset(CACHE_FILE, y_tensor, full_shape, tr_idx)
1202
+ val_ds = MemmapDataset(CACHE_FILE, y_tensor, full_shape, val_idx)
1203
+
1204
+ # Create DataLoaders with thread-safe configuration
1205
+ loader_kwargs = {
1206
+ "batch_size": args.batch_size,
1207
+ "num_workers": args.workers,
1208
+ "pin_memory": True,
1209
+ "persistent_workers": (args.workers > 0),
1210
+ "prefetch_factor": 2 if args.workers > 0 else None,
1211
+ "worker_init_fn": memmap_worker_init_fn if args.workers > 0 else None,
1212
+ }
1213
+
1214
+ train_dl = DataLoader(tr_ds, shuffle=True, **loader_kwargs)
1215
+ val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)
1216
+
1217
+ # Return spatial shape, e.g. (L,), (H, W), or (D, H, W), for model initialization
1218
+ in_shape = full_shape[2:]
1219
+
1220
+ return train_dl, val_dl, scaler, in_shape, out_dim
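+
+ # Typical wiring, assuming `args` provides data_path, seed, batch_size and
+ # workers as described above (a minimal sketch; names are illustrative):
+ #
+ #   accelerator = Accelerator()
+ #   logger = logging.getLogger(__name__)
+ #   train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
+ #       args, logger, accelerator, cache_dir="./cache", val_split=0.2
+ #   )
+ #   # in_shape sizes the model; scaler.inverse_transform() maps predictions
+ #   # back to the original target units.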