wavedl-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +43 -0
- wavedl/hpo.py +366 -0
- wavedl/models/__init__.py +86 -0
- wavedl/models/_template.py +157 -0
- wavedl/models/base.py +173 -0
- wavedl/models/cnn.py +249 -0
- wavedl/models/convnext.py +425 -0
- wavedl/models/densenet.py +406 -0
- wavedl/models/efficientnet.py +236 -0
- wavedl/models/registry.py +104 -0
- wavedl/models/resnet.py +555 -0
- wavedl/models/unet.py +304 -0
- wavedl/models/vit.py +372 -0
- wavedl/test.py +1069 -0
- wavedl/train.py +1079 -0
- wavedl/utils/__init__.py +151 -0
- wavedl/utils/config.py +269 -0
- wavedl/utils/cross_validation.py +509 -0
- wavedl/utils/data.py +1220 -0
- wavedl/utils/distributed.py +138 -0
- wavedl/utils/losses.py +216 -0
- wavedl/utils/metrics.py +1236 -0
- wavedl/utils/optimizers.py +216 -0
- wavedl/utils/schedulers.py +251 -0
- wavedl-1.2.0.dist-info/LICENSE +21 -0
- wavedl-1.2.0.dist-info/METADATA +991 -0
- wavedl-1.2.0.dist-info/RECORD +30 -0
- wavedl-1.2.0.dist-info/WHEEL +5 -0
- wavedl-1.2.0.dist-info/entry_points.txt +4 -0
- wavedl-1.2.0.dist-info/top_level.txt +1 -0
wavedl/utils/data.py
ADDED
@@ -0,0 +1,1220 @@
"""
Data Loading and Preprocessing Utilities
=========================================

Provides memory-efficient data loading for large-scale datasets with:
- Memory-mapped file support for datasets exceeding RAM
- DDP-safe data preparation with proper synchronization
- Thread-safe DataLoader worker initialization
- Multi-format support (NPZ, HDF5, MAT)

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""

import gc
import logging
import os
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

import h5py
import numpy as np
import torch
from accelerate import Accelerator
from scipy.sparse import issparse
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm


# Optional scipy.io for MATLAB files
try:
    import scipy.io

    HAS_SCIPY_IO = True
except ImportError:
    HAS_SCIPY_IO = False


# ==============================================================================
# DATA SOURCE ABSTRACTION
# ==============================================================================

# Supported key names for input/output arrays (priority order, pairwise aligned)
INPUT_KEYS = ["input_train", "input_test", "X", "data", "inputs", "features", "x"]
OUTPUT_KEYS = ["output_train", "output_test", "Y", "labels", "outputs", "targets", "y"]


class LazyDataHandle:
    """
    Context manager wrapper for memory-mapped data handles.

    Provides proper cleanup of file handles returned by load_mmap() methods.
    Can be used either as a context manager (recommended) or with explicit close().

    Usage:
        # Context manager (recommended)
        with source.load_mmap(path) as (inputs, outputs):
            # Use inputs and outputs
            pass  # File automatically closed

        # Manual cleanup
        handle = source.load_mmap(path)
        inputs, outputs = handle.inputs, handle.outputs
        # ... use data ...
        handle.close()

    Attributes:
        inputs: Input data array/dataset
        outputs: Output data array/dataset
    """

    def __init__(self, inputs, outputs, file_handle=None):
        """
        Initialize the handle.

        Args:
            inputs: Input array or lazy dataset
            outputs: Output array or lazy dataset
            file_handle: Optional file handle to close on cleanup
        """
        self.inputs = inputs
        self.outputs = outputs
        self._file = file_handle
        self._closed = False

    def __enter__(self):
        """Return (inputs, outputs) tuple for unpacking."""
        return self.inputs, self.outputs

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close file handle on context exit."""
        self.close()
        return False  # Don't suppress exceptions

    def close(self):
        """
        Close the underlying file handle.

        Safe to call multiple times.
        """
        if self._closed:
            return
        self._closed = True

        # Close the file handle if we have one
        if self._file is not None:
            try:
                self._file.close()
            except Exception:
                pass
            self._file = None

        # Also close any _TransposedH5Dataset wrappers
        for data in (self.inputs, self.outputs):
            if hasattr(data, "close"):
                try:
                    data.close()
                except Exception:
                    pass

    def __repr__(self) -> str:
        status = "closed" if self._closed else "open"
        return f"LazyDataHandle(status={status})"


class DataSource(ABC):
    """
    Abstract base class for data loaders supporting multiple file formats.

    Subclasses must implement the `load()` method to return input/output arrays,
    and optionally `load_outputs_only()` for memory-efficient target loading.
    """

    @abstractmethod
    def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
        """
        Load input and output arrays from a file.

        Args:
            path: Path to the data file

        Returns:
            Tuple of (inputs, outputs) as numpy arrays
        """
        pass

    @abstractmethod
    def load_outputs_only(self, path: str) -> np.ndarray:
        """
        Load only output/target arrays from a file (memory-efficient).

        This avoids loading large input arrays when only targets are needed,
        which is critical for HPC environments with memory constraints.

        Args:
            path: Path to the data file

        Returns:
            Output/target array
        """
        pass

    @staticmethod
    def detect_format(path: str) -> str:
        """
        Auto-detect file format from extension.

        Args:
            path: Path to data file

        Returns:
            Format string: 'npz', 'hdf5', or 'mat'
        """
        ext = Path(path).suffix.lower()
        format_map = {
            ".npz": "npz",
            ".h5": "hdf5",
            ".hdf5": "hdf5",
            ".mat": "mat",
        }
        if ext not in format_map:
            raise ValueError(
                f"Unsupported file extension: '{ext}'. "
                f"Supported formats: .npz, .h5, .hdf5, .mat"
            )
        return format_map[ext]

    @staticmethod
    def _find_key(available_keys: list[str], candidates: list[str]) -> str | None:
        """Find first matching key from candidates in available keys."""
        for key in candidates:
            if key in available_keys:
                return key
        return None


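# ------------------------------------------------------------------------------
# Illustrative sketch (not part of this module): how format detection and key
# resolution behave. The file names below are hypothetical.
#
#     DataSource.detect_format("scans/run01.npz")   # -> "npz"
#     DataSource.detect_format("scans/run01.h5")    # -> "hdf5"
#     DataSource.detect_format("scans/run01.mat")   # -> "mat"
#     DataSource.detect_format("scans/run01.csv")   # raises ValueError
#
#     # Key resolution walks INPUT_KEYS / OUTPUT_KEYS in priority order:
#     DataSource._find_key(["y", "X", "extra"], INPUT_KEYS)   # -> "X"
#     DataSource._find_key(["y", "X", "extra"], OUTPUT_KEYS)  # -> "y"
# ------------------------------------------------------------------------------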
class NPZSource(DataSource):
    """Load data from NumPy .npz archives."""

    def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
        data = np.load(path, allow_pickle=True)
        keys = list(data.keys())

        input_key = self._find_key(keys, INPUT_KEYS)
        output_key = self._find_key(keys, OUTPUT_KEYS)

        if input_key is None or output_key is None:
            raise KeyError(
                f"NPZ must contain input and output arrays. "
                f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
                f"Found: {keys}"
            )

        inp = data[input_key]
        outp = data[output_key]

        # Handle object arrays (e.g., sparse matrices stored as objects)
        if inp.dtype == object:
            inp = np.array([x.toarray() if hasattr(x, "toarray") else x for x in inp])

        return inp, outp

    def load_mmap(self, path: str) -> tuple[np.ndarray, np.ndarray]:
        """
        Load data using memory-mapped mode for zero-copy access.

        This allows processing large datasets without loading them entirely
        into RAM. Critical for HPC environments with memory constraints.

        Note: Returns memory-mapped arrays - do NOT modify them.
        """
        data = np.load(path, allow_pickle=True, mmap_mode="r")
        keys = list(data.keys())

        input_key = self._find_key(keys, INPUT_KEYS)
        output_key = self._find_key(keys, OUTPUT_KEYS)

        if input_key is None or output_key is None:
            raise KeyError(
                f"NPZ must contain input and output arrays. "
                f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
                f"Found: {keys}"
            )

        inp = data[input_key]
        outp = data[output_key]

        return inp, outp

    def load_outputs_only(self, path: str) -> np.ndarray:
        """Load only targets from NPZ (avoids loading large input arrays)."""
        data = np.load(path, allow_pickle=True)
        keys = list(data.keys())

        output_key = self._find_key(keys, OUTPUT_KEYS)
        if output_key is None:
            raise KeyError(
                f"NPZ must contain output array. "
                f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
            )

        return data[output_key]


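# ------------------------------------------------------------------------------
# Illustrative sketch (not part of this module): writing a small NPZ archive that
# NPZSource can read. The array names "X"/"y" are members of INPUT_KEYS and
# OUTPUT_KEYS; the file name is hypothetical.
#
#     import numpy as np
#
#     X = np.random.rand(128, 256).astype(np.float32)      # 128 samples of a 1D signal
#     y = np.random.rand(128, 3).astype(np.float32)        # 3 regression targets each
#     np.savez("train.npz", X=X, y=y)
#
#     inp, outp = NPZSource().load("train.npz")             # shapes (128, 256), (128, 3)
#     targets = NPZSource().load_outputs_only("train.npz")  # reads only the "y" array
# ------------------------------------------------------------------------------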
class HDF5Source(DataSource):
    """Load data from HDF5 (.h5, .hdf5) files."""

    def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
        with h5py.File(path, "r") as f:
            keys = list(f.keys())

            input_key = self._find_key(keys, INPUT_KEYS)
            output_key = self._find_key(keys, OUTPUT_KEYS)

            if input_key is None or output_key is None:
                raise KeyError(
                    f"HDF5 must contain input and output datasets. "
                    f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
                    f"Found: {keys}"
                )

            # Load into memory (HDF5 datasets are lazy by default)
            inp = f[input_key][:]
            outp = f[output_key][:]

        return inp, outp

    def load_mmap(self, path: str) -> LazyDataHandle:
        """
        Load HDF5 file with lazy/memory-mapped access.

        Returns a LazyDataHandle that reads from disk on-demand,
        avoiding loading the entire file into RAM.

        Usage:
            with source.load_mmap(path) as (inputs, outputs):
                # Use inputs and outputs
                pass  # File automatically closed
        """
        f = h5py.File(path, "r")  # Keep file open for lazy access
        keys = list(f.keys())

        input_key = self._find_key(keys, INPUT_KEYS)
        output_key = self._find_key(keys, OUTPUT_KEYS)

        if input_key is None or output_key is None:
            f.close()
            raise KeyError(
                f"HDF5 must contain input and output datasets. "
                f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
                f"Found: {keys}"
            )

        # Return wrapped handle for proper cleanup
        return LazyDataHandle(f[input_key], f[output_key], file_handle=f)

    def load_outputs_only(self, path: str) -> np.ndarray:
        """Load only targets from HDF5 (avoids loading large input arrays)."""
        with h5py.File(path, "r") as f:
            keys = list(f.keys())

            output_key = self._find_key(keys, OUTPUT_KEYS)
            if output_key is None:
                raise KeyError(
                    f"HDF5 must contain output dataset. "
                    f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
                )

            outp = f[output_key][:]

        return outp


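# ------------------------------------------------------------------------------
# Illustrative sketch (not part of this module): lazy HDF5 access through
# LazyDataHandle. Dataset names "X"/"y" match the supported key lists; the file
# name is hypothetical.
#
#     import h5py
#     import numpy as np
#
#     with h5py.File("train.h5", "w") as f:
#         f.create_dataset("X", data=np.random.rand(1000, 64, 64).astype("float32"))
#         f.create_dataset("y", data=np.random.rand(1000, 2).astype("float32"))
#
#     source = HDF5Source()
#     with source.load_mmap("train.h5") as (inputs, outputs):
#         first = inputs[0]   # only this sample is read from disk
#         n = len(inputs)     # 1000; the rest of the file stays on disk
# ------------------------------------------------------------------------------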
class _TransposedH5Dataset:
    """
    Lazy transpose wrapper for h5py datasets.

    MATLAB stores arrays in column-major (Fortran) order, while Python/NumPy
    expects row-major (C) order. This wrapper provides a transposed view
    without loading the entire dataset into memory.

    Supports:
    - len(): Returns the transposed first dimension
    - []: Returns slices with automatic transpose
    - shape: Returns the transposed shape
    - dtype: Returns the underlying dtype

    This is critical for MATSource.load_mmap() to return consistent axis
    ordering with the eager loader (MATSource.load()).

    IMPORTANT: Holds a strong reference to the h5py.File to prevent
    premature garbage collection while datasets are in use.
    """

    def __init__(self, h5_dataset, file_handle=None):
        """
        Args:
            h5_dataset: The h5py dataset to wrap
            file_handle: Optional h5py.File reference to keep alive
        """
        self._dataset = h5_dataset
        self._file = file_handle  # Keep file alive to prevent GC
        # Transpose shape: MATLAB (cols, rows, ...) -> Python (rows, cols, ...)
        self.shape = tuple(reversed(h5_dataset.shape))
        self.dtype = h5_dataset.dtype

        # Precompute transpose axis order for efficiency
        # For shape (A, B, C) -> reversed (C, B, A), transpose axes are (2, 1, 0)
        self._transpose_axes = tuple(range(len(h5_dataset.shape) - 1, -1, -1))

    def __len__(self) -> int:
        return self.shape[0]

    def __getitem__(self, idx):
        """
        Fetch data with automatic full transpose.

        Handles integer indexing, slices, and fancy indexing.
        All operations return data with fully reversed axes to match .T behavior.
        """
        if isinstance(idx, int | np.integer):
            # Single sample: index into last axis of h5py dataset (column-major)
            # Result needs full transpose of remaining dimensions
            data = self._dataset[..., idx]
            if data.ndim == 0:
                return data
            elif data.ndim == 1:
                return data  # 1D doesn't need transpose
            else:
                # Full transpose: reverse all axes
                return np.transpose(data)

        elif isinstance(idx, slice):
            # Slice indexing: fetch from last axis, then fully transpose
            start, stop, step = idx.indices(self.shape[0])
            data = self._dataset[..., start:stop:step]

            # Handle special case: 1D result (e.g., row vector)
            if data.ndim == 1:
                return data

            # Full transpose: reverse ALL axes (not just moveaxis)
            # This matches the behavior of .T on a numpy array
            return np.transpose(data, axes=self._transpose_axes)

        elif isinstance(idx, list | np.ndarray):
            # Fancy indexing: load samples one at a time (h5py limitation)
            # This is slower but necessary for compatibility
            samples = [self[i] for i in idx]
            return np.stack(samples, axis=0)

        else:
            raise TypeError(f"Unsupported index type: {type(idx)}")

    def close(self):
        """Close the underlying file handle if we own it."""
        if self._file is not None:
            try:
                self._file.close()
            except Exception:
                pass


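# ------------------------------------------------------------------------------
# Illustrative sketch (not part of this module): what the transpose wrapper does.
# Assume "f" is an open h5py.File on a hypothetical MAT v7.3 file whose MATLAB
# array had shape (N, H, W) = (1000, 128, 64); on disk it appears as (64, 128, 1000).
#
#     wrapped = _TransposedH5Dataset(f["input_train"], file_handle=f)
#     f["input_train"].shape   # (64, 128, 1000) - column-major layout on disk
#     wrapped.shape            # (1000, 128, 64) - reversed, row-major view
#     wrapped[0].shape         # (128, 64)       - one sample, fully transposed
#     wrapped[10:20].shape     # (10, 128, 64)   - slices transpose the same way
# ------------------------------------------------------------------------------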
class MATSource(DataSource):
    """
    Load data from MATLAB .mat files (v7.3+ only, which uses HDF5 format).

    Note: MAT v7.3 files are HDF5 files under the hood, so we use h5py for
    memory-efficient lazy loading. Save with: save('file.mat', '-v7.3')

    Supports MATLAB sparse matrices (automatically converted to dense).

    For older MAT files (v5/v7), convert to NPZ or save with -v7.3 flag.
    """

    @staticmethod
    def _is_sparse_dataset(dataset) -> bool:
        """Check if an HDF5 dataset/group represents a MATLAB sparse matrix."""
        # MATLAB v7.3 stores sparse matrices as groups with 'data', 'ir', 'jc' keys
        if hasattr(dataset, "keys"):
            keys = set(dataset.keys())
            return {"data", "ir", "jc"}.issubset(keys)
        return False

    @staticmethod
    def _load_sparse_to_dense(group) -> np.ndarray:
        """Convert MATLAB sparse matrix (CSC format in HDF5) to dense numpy array."""
        from scipy.sparse import csc_matrix

        data = np.array(group["data"])
        ir = np.array(group["ir"])  # row indices
        jc = np.array(group["jc"])  # column pointers

        # Get shape from MATLAB attributes or infer
        if "MATLAB_sparse" in group.attrs:
            nrows = group.attrs["MATLAB_sparse"]
        else:
            nrows = ir.max() + 1 if len(ir) > 0 else 0
        ncols = len(jc) - 1

        sparse_mat = csc_matrix((data, ir, jc), shape=(nrows, ncols))
        return sparse_mat.toarray()

    def _load_dataset(self, f, key: str) -> np.ndarray:
        """Load a dataset, handling sparse matrices automatically."""
        dataset = f[key]

        if self._is_sparse_dataset(dataset):
            # Sparse matrix: convert to dense
            arr = self._load_sparse_to_dense(dataset)
        else:
            # Regular dense array
            arr = np.array(dataset)

        # Transpose for MATLAB column-major -> Python row-major
        return arr.T

    def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
        """Load MAT v7.3 file using h5py."""
        try:
            with h5py.File(path, "r") as f:
                keys = list(f.keys())

                input_key = self._find_key(keys, INPUT_KEYS)
                output_key = self._find_key(keys, OUTPUT_KEYS)

                if input_key is None or output_key is None:
                    raise KeyError(
                        f"MAT file must contain input and output arrays. "
                        f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
                        f"Found: {keys}"
                    )

                # Load with sparse matrix support
                inp = self._load_dataset(f, input_key)
                outp = self._load_dataset(f, output_key)

                # Handle 1D outputs that become (1, N) after transpose
                if outp.ndim == 2 and outp.shape[0] == 1:
                    outp = outp.T

        except OSError as e:
            raise ValueError(
                f"Failed to load MAT file: {path}. "
                f"Ensure it's saved as v7.3: save('file.mat', '-v7.3'). "
                f"Original error: {e}"
            )

        return inp, outp

    def load_mmap(self, path: str) -> LazyDataHandle:
        """
        Load MAT v7.3 file with lazy/memory-mapped access.

        Returns a LazyDataHandle that reads from disk on-demand,
        avoiding loading the entire file into RAM.

        Note: For sparse matrices, this will load and convert them.
        For dense arrays, returns a transposed view wrapper for consistent axis ordering.

        Usage:
            with source.load_mmap(path) as (inputs, outputs):
                # Use inputs and outputs
                pass  # File automatically closed
        """
        try:
            f = h5py.File(path, "r")  # Keep file open for lazy access
            keys = list(f.keys())

            input_key = self._find_key(keys, INPUT_KEYS)
            output_key = self._find_key(keys, OUTPUT_KEYS)

            if input_key is None or output_key is None:
                f.close()
                raise KeyError(
                    f"MAT file must contain input and output arrays. "
                    f"Supported keys: {INPUT_KEYS} / {OUTPUT_KEYS}. "
                    f"Found: {keys}"
                )

            # Check for sparse matrices - must load them eagerly
            inp_dataset = f[input_key]
            outp_dataset = f[output_key]

            if self._is_sparse_dataset(inp_dataset):
                inp = self._load_sparse_to_dense(inp_dataset).T
            else:
                # Wrap h5py dataset with transpose view for consistent axis order
                # MATLAB stores column-major, Python expects row-major
                # Pass file handle to keep it alive
                inp = _TransposedH5Dataset(inp_dataset, file_handle=f)

            if self._is_sparse_dataset(outp_dataset):
                outp = self._load_sparse_to_dense(outp_dataset).T
            else:
                # Wrap h5py dataset with transpose view (shares same file handle)
                outp = _TransposedH5Dataset(outp_dataset, file_handle=f)

            # Return wrapped handle for proper cleanup
            return LazyDataHandle(inp, outp, file_handle=f)

        except OSError as e:
            raise ValueError(
                f"Failed to load MAT file: {path}. "
                f"Ensure it's saved as v7.3: save('file.mat', '-v7.3'). "
                f"Original error: {e}"
            )

    def load_outputs_only(self, path: str) -> np.ndarray:
        """Load only targets from MAT v7.3 file (avoids loading large input arrays)."""
        try:
            with h5py.File(path, "r") as f:
                keys = list(f.keys())

                output_key = self._find_key(keys, OUTPUT_KEYS)
                if output_key is None:
                    raise KeyError(
                        f"MAT file must contain output array. "
                        f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
                    )

                # Load with sparse matrix support
                outp = self._load_dataset(f, output_key)

                # Handle 1D outputs
                if outp.ndim == 2 and outp.shape[0] == 1:
                    outp = outp.T

        except OSError as e:
            raise ValueError(
                f"Failed to load MAT file: {path}. "
                f"Ensure it's saved as v7.3: save('file.mat', '-v7.3'). "
                f"Original error: {e}"
            )

        return outp


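# ------------------------------------------------------------------------------
# Illustrative sketch (not part of this module): loading a hypothetical MAT v7.3
# file. On the MATLAB side the data would be written with, e.g.:
#     save('train.mat', 'input_train', 'output_train', '-v7.3')
#
#     source = MATSource()
#
#     # Eager: arrays are read fully and transposed to row-major order.
#     inp, outp = source.load("train.mat")
#
#     # Lazy: samples are transposed on access; the file stays open until close().
#     with source.load_mmap("train.mat") as (inputs, outputs):
#         x0 = inputs[0]
# ------------------------------------------------------------------------------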
def get_data_source(format: str) -> DataSource:
    """
    Factory function to get the appropriate DataSource for a format.

    Args:
        format: One of 'npz', 'hdf5', 'mat'

    Returns:
        DataSource instance
    """
    sources = {
        "npz": NPZSource,
        "hdf5": HDF5Source,
        "mat": MATSource,
    }

    if format not in sources:
        raise ValueError(
            f"Unsupported format: {format}. Supported: {list(sources.keys())}"
        )

    return sources[format]()


def load_training_data(
    path: str, format: str = "auto"
) -> tuple[np.ndarray, np.ndarray]:
    """
    Load training data from file with automatic format detection.

    Supports:
    - NPZ: NumPy compressed archives (.npz)
    - HDF5: Hierarchical Data Format (.h5, .hdf5)
    - MAT: MATLAB files (.mat)

    Flexible key detection supports: input_train/X/data and output_train/y/labels.

    Args:
        path: Path to data file
        format: Format hint ('npz', 'hdf5', 'mat', or 'auto' for detection)

    Returns:
        Tuple of (inputs, outputs) arrays
    """
    if format == "auto":
        format = DataSource.detect_format(path)

    source = get_data_source(format)
    return source.load(path)


def load_outputs_only(path: str, format: str = "auto") -> np.ndarray:
    """
    Load only output/target arrays from file (memory-efficient).

    This function avoids loading large input arrays when only targets are needed,
    which is critical for HPC environments with memory constraints during DDP.

    Args:
        path: Path to data file
        format: Format hint ('npz', 'hdf5', 'mat', or 'auto' for detection)

    Returns:
        Output/target array
    """
    if format == "auto":
        format = DataSource.detect_format(path)

    source = get_data_source(format)
    return source.load_outputs_only(path)


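# ------------------------------------------------------------------------------
# Illustrative sketch (not part of this module): the two convenience loaders with
# automatic format detection. File names are hypothetical.
#
#     X, y = load_training_data("train.h5")                 # format inferred from extension
#     y_only = load_outputs_only("train.h5")                # targets only, inputs untouched
#     X, y = load_training_data("train.mat", format="mat")  # explicit format hint
# ------------------------------------------------------------------------------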
def load_test_data(
    path: str,
    format: str = "auto",
    input_key: str | None = None,
    output_key: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Load test/inference data and return PyTorch tensors ready for model input.

    This is the unified data loading function for inference. It:
    - Auto-detects file format from extension
    - Handles custom key names for non-standard datasets
    - Adds channel dimension if missing (dimension-agnostic)
    - Returns None for targets if not present in file

    Supports any input dimensionality:
    - 1D: (N, L) → (N, 1, L)
    - 2D: (N, H, W) → (N, 1, H, W)
    - 3D: (N, D, H, W) → (N, 1, D, H, W)
    - Already has channel: (N, C, ...) → unchanged

    Args:
        path: Path to data file (NPZ, HDF5, or MAT v7.3)
        format: Format hint ('npz', 'hdf5', 'mat', or 'auto' for detection)
        input_key: Custom key for input data (overrides auto-detection)
        output_key: Custom key for output data (overrides auto-detection)

    Returns:
        Tuple of:
        - X: Input tensor with channel dimension (N, 1, *spatial_dims)
        - y: Target tensor (N, T) or None if targets not present

    Example:
        >>> X, y = load_test_data("test_data.npz")
        >>> X, y = load_test_data(
        ...     "data.mat", input_key="waveforms", output_key="params"
        ... )
    """
    if format == "auto":
        format = DataSource.detect_format(path)

    source = get_data_source(format)

    # Build custom key lists if provided
    if input_key:
        custom_input_keys = [input_key] + INPUT_KEYS
    else:
        # Prioritize test keys for inference
        custom_input_keys = ["input_test"] + [
            k for k in INPUT_KEYS if k != "input_test"
        ]

    if output_key:
        custom_output_keys = [output_key] + OUTPUT_KEYS
    else:
        custom_output_keys = ["output_test"] + [
            k for k in OUTPUT_KEYS if k != "output_test"
        ]

    # Load data using appropriate source
    try:
        inp, outp = source.load(path)
    except KeyError:
        # Try with just inputs if outputs not found
        if format == "npz":
            data = np.load(path, allow_pickle=True)
            keys = list(data.keys())
            inp_key = DataSource._find_key(keys, custom_input_keys)
            if inp_key is None:
                raise KeyError(
                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
                )
            inp = data[inp_key]
            if inp.dtype == object:
                inp = np.array(
                    [x.toarray() if hasattr(x, "toarray") else x for x in inp]
                )
            out_key = DataSource._find_key(keys, custom_output_keys)
            outp = data[out_key] if out_key else None
        else:
            raise

    # Handle sparse matrices
    if issparse(inp):
        inp = inp.toarray()
    if outp is not None and issparse(outp):
        outp = outp.toarray()

    # Convert to tensors
    X = torch.tensor(np.asarray(inp), dtype=torch.float32)

    if outp is not None:
        y = torch.tensor(np.asarray(outp), dtype=torch.float32)
        # Normalize target shape: (N,) → (N, 1)
        if y.ndim == 1:
            y = y.unsqueeze(1)
    else:
        y = None

    # Add channel dimension if needed (dimension-agnostic)
    # X.ndim == 2: 1D data (N, L) → (N, 1, L)
    # X.ndim == 3: 2D data (N, H, W) → (N, 1, H, W)
    # X.ndim == 4: Check if already has channel dim (C <= 16 heuristic)
    if X.ndim == 2:
        X = X.unsqueeze(1)  # 1D signal: (N, L) → (N, 1, L)
    elif X.ndim == 3:
        X = X.unsqueeze(1)  # 2D image: (N, H, W) → (N, 1, H, W)
    elif X.ndim == 4:
        # Could be 3D volume (N, D, H, W) or 2D with channel (N, C, H, W)
        # Heuristic: if dim 1 is small (<=16), assume it's already a channel dim
        if X.shape[1] > 16:
            X = X.unsqueeze(1)  # 3D volume: (N, D, H, W) → (N, 1, D, H, W)
    # X.ndim >= 5: assume channel dimension already exists

    return X, y


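# ------------------------------------------------------------------------------
# Illustrative sketch (not part of this module): how load_test_data normalizes
# shapes. File names and shapes are hypothetical.
#
#     X, y = load_test_data("test_1d.npz")    # inputs (N, 500)        -> X: (N, 1, 500)
#     X, y = load_test_data("test_2d.npz")    # inputs (N, 128, 128)   -> X: (N, 1, 128, 128)
#     X, y = load_test_data("test_3d.npz")    # inputs (N, 32, 64, 64) -> X: (N, 1, 32, 64, 64)
#     X, y = load_test_data("unlabeled.npz")  # no output key present  -> y is None
# ------------------------------------------------------------------------------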
# ==============================================================================
# WORKER INITIALIZATION
# ==============================================================================
def memmap_worker_init_fn(worker_id: int):
    """
    Worker initialization function for proper memmap handling in multi-worker DataLoader.

    Each DataLoader worker process runs this function after forking. It:
    1. Resets the memmap file handle to None, forcing each worker to open its own
       read-only handle (prevents file descriptor sharing issues and race conditions)
    2. Seeds numpy's random state per worker to ensure statistical diversity in
       random augmentations (prevents all workers from applying identical "random"
       transformations to their batches)

    Args:
        worker_id: Worker index (0 to num_workers-1), provided by DataLoader

    Usage:
        DataLoader(dataset, num_workers=8, worker_init_fn=memmap_worker_init_fn)
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        dataset = worker_info.dataset
        # Force re-initialization of memmap in each worker
        dataset.data = None

        # Seed numpy RNG per worker using PyTorch's worker seed for reproducibility
        # This ensures random augmentations (noise, shifts, etc.) are unique per worker
        np.random.seed(worker_info.seed % (2**32 - 1))


# ==============================================================================
# MEMORY-MAPPED DATASET
# ==============================================================================
class MemmapDataset(Dataset):
    """
    Zero-copy memory-mapped dataset for large-scale training.

    Uses numpy memory mapping to load data directly from disk, allowing training
    on datasets that exceed available RAM. The memmap is only opened when first
    accessed (lazy initialization), and each DataLoader worker maintains its own
    file handle for thread safety.

    Args:
        memmap_path: Path to the memory-mapped data file
        targets: Pre-loaded target tensor (small enough to fit in memory)
        shape: Full shape of the memmap array (N, C, H, W)
        indices: Indices into the memmap for this split (train/val)

    Thread Safety:
        When using with DataLoader num_workers > 0, must use memmap_worker_init_fn
        as the worker_init_fn to ensure each worker gets its own file handle.

    Example:
        dataset = MemmapDataset("cache.dat", y_tensor, (10000, 1, 500, 500), train_indices)
        loader = DataLoader(dataset, num_workers=8, worker_init_fn=memmap_worker_init_fn)
    """

    def __init__(
        self,
        memmap_path: str,
        targets: torch.Tensor,
        shape: tuple[int, ...],
        indices: np.ndarray,
    ):
        self.memmap_path = memmap_path
        self.targets = targets
        self.shape = shape
        self.indices = indices
        self.data: np.memmap | None = None  # Lazy initialization

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        if self.data is None:
            # Mode 'r' = read-only, prevents accidental data modification
            self.data = np.memmap(
                self.memmap_path, dtype="float32", mode="r", shape=self.shape
            )

        real_idx = self.indices[idx]

        # .copy() detaches from mmap buffer - essential for PyTorch pinned memory
        x = torch.from_numpy(self.data[real_idx].copy()).contiguous()
        y = self.targets[real_idx]

        return x, y

    def __repr__(self) -> str:
        return (
            f"MemmapDataset(path='{self.memmap_path}', "
            f"samples={len(self)}, shape={self.shape})"
        )


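# ------------------------------------------------------------------------------
# Illustrative sketch (not part of this module): wiring MemmapDataset to a
# DataLoader by hand. The cache file, shape, and split below are hypothetical;
# during training this wiring is performed by prepare_data() further down.
#
#     import numpy as np
#     import torch
#     from torch.utils.data import DataLoader
#
#     shape = (10_000, 1, 128, 128)          # (N, C, H, W) of the cache file
#     targets = torch.zeros(10_000, 3)       # scaled targets kept in memory
#     train_idx = np.arange(8_000)
#
#     ds = MemmapDataset("train_data_cache.dat", targets, shape, train_idx)
#     dl = DataLoader(ds, batch_size=64, num_workers=8,
#                     worker_init_fn=memmap_worker_init_fn)
#     x, y = ds[0]                           # x: (1, 128, 128), y: (3,)
# ------------------------------------------------------------------------------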
# ==============================================================================
# DATA PREPARATION
# ==============================================================================
def prepare_data(
    args: Any,
    logger: logging.Logger,
    accelerator: Accelerator,
    cache_dir: str = ".",
    val_split: float = 0.2,
) -> tuple[DataLoader, DataLoader, StandardScaler, tuple[int, ...], int]:
    """
    Prepare DataLoaders with DDP synchronization guarantees.

    This function handles:
    1. Loading raw data and creating memmap cache (Rank 0 only)
    2. Fitting StandardScaler on training set only (no data leakage)
    3. Synchronizing all ranks before proceeding
    4. Creating thread-safe DataLoaders for DDP training

    Supports any input dimensionality:
    - 1D: (N, L) → returns in_shape = (L,)
    - 2D: (N, H, W) → returns in_shape = (H, W)
    - 3D: (N, D, H, W) → returns in_shape = (D, H, W)

    Args:
        args: Argument namespace with data_path, seed, batch_size, workers
        logger: Logger instance for status messages
        accelerator: Accelerator instance for DDP coordination
        cache_dir: Directory for cache files (default: current directory)
        val_split: Validation set fraction (default: 0.2)

    Returns:
        Tuple of:
        - train_dl: Training DataLoader
        - val_dl: Validation DataLoader
        - scaler: Fitted StandardScaler (for inverse transforms)
        - in_shape: Input spatial dimensions - (L,), (H, W), or (D, H, W)
        - out_dim: Number of output targets

    Cache Files Created:
        - train_data_cache.dat: Memory-mapped input data
        - scaler.pkl: Fitted StandardScaler
        - data_metadata.pkl: Shape and dimension metadata
    """
    CACHE_FILE = os.path.join(cache_dir, "train_data_cache.dat")
    SCALER_FILE = os.path.join(cache_dir, "scaler.pkl")
    META_FILE = os.path.join(cache_dir, "data_metadata.pkl")

    # ==========================================================================
    # PHASE 1: DATA GENERATION (Rank 0 Only)
    # ==========================================================================
    # Check cache existence and validity (data path must match)
    cache_exists = (
        os.path.exists(CACHE_FILE)
        and os.path.exists(SCALER_FILE)
        and os.path.exists(META_FILE)
    )

    # Validate cache matches current data_path (prevents stale cache corruption)
    if cache_exists:
        try:
            with open(META_FILE, "rb") as f:
                meta = pickle.load(f)
            cached_data_path = meta.get("data_path", None)
            if cached_data_path != os.path.abspath(args.data_path):
                if accelerator.is_main_process:
                    logger.warning(
                        f"⚠️ Cache was created from different data file!\n"
                        f" Cached: {cached_data_path}\n"
                        f" Current: {os.path.abspath(args.data_path)}\n"
                        f" Invalidating cache and regenerating..."
                    )
                cache_exists = False
        except Exception:
            cache_exists = False

    if not cache_exists:
        if accelerator.is_main_process:
            # RANK 0: Create cache (can take a long time for large datasets)
            # Other ranks will wait at the barrier below

            # Detect format from extension
            data_format = DataSource.detect_format(args.data_path)
            logger.info(
                f"⚡ [Rank 0] Initializing Data Processing from: {args.data_path} (format: {data_format})"
            )

            # Validate data file exists
            if not os.path.exists(args.data_path):
                raise FileNotFoundError(
                    f"CRITICAL: Data file not found: {args.data_path}"
                )

            # Load raw data using memory-mapped mode for all formats
            # This avoids loading the entire dataset into RAM at once
            try:
                if data_format == "npz":
                    source = NPZSource()
                    inp, outp = source.load_mmap(args.data_path)
                elif data_format == "hdf5":
                    source = HDF5Source()
                    _lazy_handle = source.load_mmap(args.data_path)
                    inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
                elif data_format == "mat":
                    source = MATSource()
                    _lazy_handle = source.load_mmap(args.data_path)
                    inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
                else:
                    inp, outp = load_training_data(args.data_path, format=data_format)
                logger.info(" Using memory-mapped loading (low memory mode)")
            except Exception as e:
                logger.error(f"Failed to load data file: {e}")
                raise

            # Detect shape (handle sparse matrices) - DIMENSION AGNOSTIC
            num_samples = len(inp)

            # Handle 1D targets: (N,) -> treat as single output
            if outp.ndim == 1:
                out_dim = 1
            else:
                out_dim = outp.shape[1]

            sample_0 = inp[0]
            if issparse(sample_0) or hasattr(sample_0, "toarray"):
                sample_0 = sample_0.toarray()

            # Detect dimensionality and validate single-channel assumption
            raw_shape = sample_0.shape

            # Heuristic for ambiguous multi-channel detection:
            # Trigger when shape is EXACTLY 3D (could be C,H,W or D,H,W) with:
            #   - First dim small (<=16) - looks like channels
            #   - BOTH remaining dims large (>16) - confirming it's an image, not a tiny patch
            # This distinguishes (3, 256, 256) from (8, 1024) or (128, 128)
            # Use --single_channel to confirm shallow 3D volumes like (8, 128, 128)
            is_ambiguous_shape = (
                len(raw_shape) == 3  # Exactly 3D: could be (C, H, W) or (D, H, W)
                and raw_shape[0] <= 16  # First dim looks like channels
                and raw_shape[1] > 16
                and raw_shape[2] > 16  # Both spatial dims are large
            )

            # Check for user confirmation via --single_channel flag
            user_confirmed_single_channel = getattr(args, "single_channel", False)

            if is_ambiguous_shape and not user_confirmed_single_channel:
                raise ValueError(
                    f"Ambiguous input shape detected: sample shape {raw_shape}. "
                    f"This could be either:\n"
                    f"  - Multi-channel 2D data (C={raw_shape[0]}, H={raw_shape[1]}, W={raw_shape[2]})\n"
                    f"  - Single-channel 3D volume (D={raw_shape[0]}, H={raw_shape[1]}, W={raw_shape[2]})\n\n"
                    f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
                    f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
                )

            spatial_shape = raw_shape
            full_shape = (
                num_samples,
                1,
            ) + spatial_shape  # Add channel dim: (N, 1, ...)

            dim_names = {1: "1D (L)", 2: "2D (H, W)", 3: "3D (D, H, W)"}
            dim_type = dim_names.get(len(spatial_shape), f"{len(spatial_shape)}D")
            logger.info(
                f" Shape Detected: {full_shape} [{dim_type}] | Output Dim: {out_dim}"
            )

            # Save metadata (including data path for cache validation)
            with open(META_FILE, "wb") as f:
                pickle.dump(
                    {
                        "shape": full_shape,
                        "out_dim": out_dim,
                        "data_path": os.path.abspath(args.data_path),
                    },
                    f,
                )

            # Create memmap cache
            if not os.path.exists(CACHE_FILE):
                logger.info(" Writing Memmap Cache (one-time operation)...")
                fp = np.memmap(CACHE_FILE, dtype="float32", mode="w+", shape=full_shape)

                chunk_size = 2000
                pbar = tqdm(
                    range(0, num_samples, chunk_size),
                    desc="Caching",
                    disable=not accelerator.is_main_process,
                )

                for i in pbar:
                    end = min(i + chunk_size, num_samples)
                    batch = inp[i:end]

                    # Handle sparse/dense conversion
                    if issparse(batch[0]) or hasattr(batch[0], "toarray"):
                        data_chunk = np.stack(
                            [x.toarray().astype(np.float32) for x in batch]
                        )
                    else:
                        data_chunk = np.array(batch).astype(np.float32)

                    # Add channel dimension if needed (handles 1D, 2D, and 3D spatial data)
                    # data_chunk shape: (batch, *spatial) -> need (batch, 1, *spatial)
                    # full_shape is (N, 1, *spatial), so expected ndim = len(full_shape)
                    if data_chunk.ndim == len(full_shape) - 1:
                        # Missing channel dim: (batch, *spatial) -> (batch, 1, *spatial)
                        data_chunk = np.expand_dims(data_chunk, 1)

                    fp[i:end] = data_chunk

                    # Periodic flush to disk
                    if i % 10000 == 0:
                        fp.flush()

                fp.flush()
                del fp
                gc.collect()

            # Train/Val split and scaler fitting
            indices = np.arange(num_samples)
            tr_idx, val_idx = train_test_split(
                indices, test_size=val_split, random_state=args.seed
            )

            if not os.path.exists(SCALER_FILE):
                logger.info(" Fitting StandardScaler (training set only)...")

                # Convert lazy datasets to numpy for reliable indexing
                # (h5py and _TransposedH5Dataset may not support fancy indexing)
                if hasattr(outp, "_dataset") or hasattr(outp, "file"):
                    # Lazy h5py or _TransposedH5Dataset - load training subset
                    outp_train = np.array([outp[i] for i in tr_idx])
                else:
                    # Already numpy array
                    outp_train = outp[tr_idx]

                # Ensure 2D for StandardScaler: (N,) -> (N, 1)
                if outp_train.ndim == 1:
                    outp_train = outp_train.reshape(-1, 1)

                scaler = StandardScaler()
                scaler.fit(outp_train)
                with open(SCALER_FILE, "wb") as f:
                    pickle.dump(scaler, f)

            # Cleanup: close file handles BEFORE deleting references
            if "_lazy_handle" in dir() and _lazy_handle is not None:
                try:
                    _lazy_handle.close()
                except Exception:
                    pass
            del inp, outp
            gc.collect()

            logger.info(" ✔ Cache creation complete, synchronizing ranks...")
        else:
            # NON-MAIN RANKS: Wait for cache creation
            # Log that we're waiting (helps with debugging)
            import time

            wait_start = time.time()
            while not (
                os.path.exists(CACHE_FILE)
                and os.path.exists(SCALER_FILE)
                and os.path.exists(META_FILE)
            ):
                time.sleep(5)  # Check every 5 seconds
                elapsed = time.time() - wait_start
                if elapsed > 60 and int(elapsed) % 60 < 5:  # Log every ~minute
                    logger.info(
                        f" [Rank {accelerator.process_index}] Waiting for cache creation... ({int(elapsed)}s)"
                    )
            # Small delay to ensure files are fully written
            time.sleep(2)

    # ==========================================================================
    # PHASE 2: SYNCHRONIZED LOADING (All Ranks)
    # ==========================================================================
    accelerator.wait_for_everyone()

    # Load metadata
    with open(META_FILE, "rb") as f:
        meta = pickle.load(f)
    full_shape = meta["shape"]
    out_dim = meta["out_dim"]

    # Load and validate scaler
    with open(SCALER_FILE, "rb") as f:
        scaler = pickle.load(f)

    if not hasattr(scaler, "scale_") or scaler.scale_ is None:
        raise RuntimeError("CRITICAL: Scaler is not properly fitted (scale_ is None)")

    # Load targets only (memory-efficient - avoids loading large input arrays)
    # This is critical for HPC environments with memory constraints during DDP
    outp = load_outputs_only(args.data_path)

    # Ensure 2D for StandardScaler: (N,) -> (N, 1)
    if outp.ndim == 1:
        outp = outp.reshape(-1, 1)

    y_scaled = scaler.transform(outp).astype(np.float32)
    y_tensor = torch.tensor(y_scaled)

    # Regenerate indices (deterministic with same seed)
    indices = np.arange(full_shape[0])
    tr_idx, val_idx = train_test_split(
        indices, test_size=val_split, random_state=args.seed
    )

    # Create datasets
    tr_ds = MemmapDataset(CACHE_FILE, y_tensor, full_shape, tr_idx)
    val_ds = MemmapDataset(CACHE_FILE, y_tensor, full_shape, val_idx)

    # Create DataLoaders with thread-safe configuration
    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": args.workers,
        "pin_memory": True,
        "persistent_workers": (args.workers > 0),
        "prefetch_factor": 2 if args.workers > 0 else None,
        "worker_init_fn": memmap_worker_init_fn if args.workers > 0 else None,
    }

    train_dl = DataLoader(tr_ds, shuffle=True, **loader_kwargs)
    val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    # Return spatial shape (L,), (H, W), or (D, H, W) for model initialization
    in_shape = full_shape[2:]

    return train_dl, val_dl, scaler, in_shape, out_dim