wavedl 1.6.2-py3-none-any.whl → 1.7.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +115 -9
- wavedl/models/__init__.py +22 -0
- wavedl/models/_pretrained_utils.py +72 -0
- wavedl/models/_template.py +7 -6
- wavedl/models/cnn.py +20 -0
- wavedl/models/convnext.py +3 -70
- wavedl/models/convnext_v2.py +1 -18
- wavedl/models/mamba.py +126 -38
- wavedl/models/resnet3d.py +23 -5
- wavedl/models/unireplknet.py +1 -18
- wavedl/models/vit.py +18 -8
- wavedl/test.py +13 -23
- wavedl/train.py +494 -28
- wavedl/utils/__init__.py +49 -9
- wavedl/utils/config.py +6 -8
- wavedl/utils/cross_validation.py +17 -4
- wavedl/utils/data.py +176 -180
- wavedl/utils/metrics.py +26 -5
- wavedl/utils/schedulers.py +2 -2
- {wavedl-1.6.2.dist-info → wavedl-1.7.0.dist-info}/METADATA +37 -18
- wavedl-1.7.0.dist-info/RECORD +46 -0
- wavedl-1.6.2.dist-info/RECORD +0 -46
- {wavedl-1.6.2.dist-info → wavedl-1.7.0.dist-info}/LICENSE +0 -0
- {wavedl-1.6.2.dist-info → wavedl-1.7.0.dist-info}/WHEEL +0 -0
- {wavedl-1.6.2.dist-info → wavedl-1.7.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.2.dist-info → wavedl-1.7.0.dist-info}/top_level.txt +0 -0
wavedl/utils/cross_validation.py
CHANGED
@@ -36,7 +36,9 @@ from torch.utils.data import DataLoader
 class CVDataset(torch.utils.data.Dataset):
     """Simple in-memory dataset for cross-validation."""
 
-    def __init__(
+    def __init__(
+        self, X: np.ndarray, y: np.ndarray, expected_spatial_ndim: int | None = None
+    ):
         """
         Initialize CV dataset with explicit channel dimension handling.
 
@@ -51,6 +53,11 @@ class CVDataset(torch.utils.data.Dataset):
         - If X.ndim == expected_spatial_ndim + 1: Add channel dim (N, *spatial) -> (N, 1, *spatial)
         - If X.ndim == expected_spatial_ndim + 2: Already has channel (N, C, *spatial)
         - If expected_spatial_ndim is None: Use legacy ndim-based inference
+
+        Warning:
+            Legacy mode (expected_spatial_ndim=None) may misinterpret multichannel
+            3D data as single-channel 4D data. Always pass expected_spatial_ndim
+            explicitly for 3D volumes with >1 channel.
         """
         if expected_spatial_ndim is not None:
             # Explicit mode: use expected_spatial_ndim to determine if channel exists
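The new `expected_spatial_ndim` argument removes the channel-dimension guesswork the warning describes. A minimal sketch of the intended call, assuming `CVDataset` is imported from `wavedl.utils.cross_validation` and using made-up array shapes:

```python
import numpy as np
from wavedl.utils.cross_validation import CVDataset

# Hypothetical 2-channel 3D volumes: (N, C, D, H, W) = (100, 2, 16, 16, 16).
# Legacy ndim-based inference could read this as single-channel 4D data;
# expected_spatial_ndim=3 makes the layout explicit (ndim == 3 + 2 -> channel present).
X = np.random.rand(100, 2, 16, 16, 16).astype(np.float32)
y = np.random.rand(100, 4).astype(np.float32)

dataset = CVDataset(X, y, expected_spatial_ndim=3)
```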
@@ -100,6 +107,7 @@ def train_fold(
     device: torch.device,
     epochs: int,
     patience: int,
+    grad_clip: float,
     scaler: StandardScaler,
     logger: logging.Logger,
 ) -> dict[str, Any]:
@@ -147,7 +155,8 @@ def train_fold(
             pred = model(x)
             loss = criterion(pred, y)
             loss.backward()
-
+            if grad_clip > 0:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
             optimizer.step()
 
             # Per-batch LR scheduling (OneCycleLR)
@@ -289,6 +298,7 @@ def run_cross_validation(
     output_dir: str = "./cv_results",
     workers: int = 4,
     seed: int = 2025,
+    grad_clip: float = 1.0,
     logger: logging.Logger | None = None,
 ) -> dict[str, Any]:
     """
@@ -330,8 +340,10 @@ def run_cross_validation(
     )
     logger = logging.getLogger("CV-Trainer")
 
-    # Set seeds
-
+    # Set seeds for reproducibility
+    # Note: sklearn KFold uses random_state parameter directly, not global numpy RNG
+    rng = np.random.default_rng(seed)  # Local RNG for any numpy operations
+    _ = rng  # Silence unused variable warning (available for future use)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
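The reworked seeding comment is worth spelling out: scikit-learn's `KFold` is seeded through its own `random_state` argument rather than the global NumPy RNG, so only the torch seeds need to be set globally. A short illustration using standard scikit-learn/NumPy calls (not wavedl-specific):

```python
import numpy as np
from sklearn.model_selection import KFold

seed = 2025
rng = np.random.default_rng(seed)  # local generator; leaves np.random global state alone
kf = KFold(n_splits=5, shuffle=True, random_state=seed)  # splits are reproducible on their own

first = [test.tolist() for _, test in kf.split(np.arange(20))]
second = [test.tolist() for _, test in kf.split(np.arange(20))]
assert first == second  # same seed, same folds, no global seeding required
```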
@@ -444,6 +456,7 @@
             device=device,
             epochs=epochs,
             patience=patience,
+            grad_clip=grad_clip,
             scaler=scaler,
             logger=logger,
         )
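Gradient clipping is now threaded from `run_cross_validation(grad_clip=1.0)` into `train_fold`, where `torch.nn.utils.clip_grad_norm_` runs between `backward()` and `step()`; a non-positive value disables it. A self-contained sketch of the same pattern in plain PyTorch (hypothetical model, criterion, and batch names, not the wavedl training loop itself):

```python
import torch

def training_step(model, batch, optimizer, criterion, grad_clip: float = 1.0) -> float:
    """One optimizer step with optional global-norm gradient clipping."""
    x, y = batch
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    if grad_clip > 0:  # grad_clip <= 0 disables clipping, mirroring train_fold
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
    optimizer.step()
    return loss.item()
```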
wavedl/utils/data.py
CHANGED
@@ -798,6 +798,128 @@ def load_outputs_only(path: str, format: str = "auto") -> np.ndarray:
     return source.load_outputs_only(path)
 
 
+def _load_npz_for_test(
+    path: str,
+    input_keys: list[str],
+    output_keys: list[str],
+    explicit_input_key: str | None,
+    explicit_output_key: str | None,
+) -> tuple[np.ndarray, np.ndarray | None]:
+    """Load NPZ file for test/inference with key validation."""
+    with np.load(path, allow_pickle=False) as probe:
+        keys = list(probe.keys())
+
+    inp_key = DataSource._find_key(keys, input_keys)
+    out_key = DataSource._find_key(keys, output_keys)
+
+    # Strict validation for explicit keys
+    if explicit_input_key is not None and explicit_input_key not in keys:
+        raise KeyError(
+            f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+        )
+    if inp_key is None:
+        raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+    if explicit_output_key is not None and explicit_output_key not in keys:
+        raise KeyError(
+            f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+        )
+
+    data = NPZSource._load_and_copy(path, [inp_key] + ([out_key] if out_key else []))
+    inp = data[inp_key]
+    if inp.dtype == object:
+        inp = np.array([x.toarray() if hasattr(x, "toarray") else x for x in inp])
+    outp = data[out_key] if out_key else None
+    return inp, outp
+
+
+def _load_hdf5_for_test(
+    path: str,
+    input_keys: list[str],
+    output_keys: list[str],
+    explicit_input_key: str | None,
+    explicit_output_key: str | None,
+) -> tuple[np.ndarray, np.ndarray | None]:
+    """Load HDF5 file for test/inference with key validation and OOM guard."""
+    with h5py.File(path, "r") as f:
+        keys = list(f.keys())
+        inp_key = DataSource._find_key(keys, input_keys)
+        out_key = DataSource._find_key(keys, output_keys)
+
+        # Strict validation
+        if explicit_input_key is not None and explicit_input_key not in keys:
+            raise KeyError(
+                f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+            )
+        if inp_key is None:
+            raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+        if explicit_output_key is not None and explicit_output_key not in keys:
+            raise KeyError(
+                f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+            )
+
+        # OOM guard
+        n_samples = f[inp_key].shape[0]
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+            )
+
+        inp = f[inp_key][:]
+        outp = f[out_key][:] if out_key else None
+    return inp, outp
+
+
+def _load_mat_for_test(
+    path: str,
+    input_keys: list[str],
+    output_keys: list[str],
+    explicit_input_key: str | None,
+    explicit_output_key: str | None,
+) -> tuple[np.ndarray, np.ndarray | None]:
+    """Load MAT v7.3 file for test/inference with sparse support."""
+    mat_source = MATSource()
+    with h5py.File(path, "r") as f:
+        keys = list(f.keys())
+        inp_key = DataSource._find_key(keys, input_keys)
+        out_key = DataSource._find_key(keys, output_keys)
+
+        # Strict validation
+        if explicit_input_key is not None and explicit_input_key not in keys:
+            raise KeyError(
+                f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+            )
+        if inp_key is None:
+            raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+        if explicit_output_key is not None and explicit_output_key not in keys:
+            raise KeyError(
+                f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+            )
+
+        # OOM guard
+        n_samples = f[inp_key].shape[-1]  # MAT is transposed
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with MATSource.load_mmap() instead."
+            )
+
+        inp = mat_source._load_dataset(f, inp_key)
+        outp = None
+        if out_key:
+            outp = mat_source._load_dataset(f, out_key)
+            # Handle MATLAB transpose
+            num_samples = inp.shape[0]
+            if outp.ndim == 2:
+                if (outp.shape[0] == 1 and outp.shape[1] == num_samples) or (
+                    outp.shape[1] == 1 and outp.shape[0] != num_samples
+                ):
+                    outp = outp.T
+    return inp, outp
+
+
 def load_test_data(
     path: str,
     format: str = "auto",
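The new helpers keep the 100,000-sample OOM guard and point users at a `DataLoader` over `HDF5Source.load_mmap()` for larger inference sets. That method's signature is not part of this diff, so the sketch below shows the general idea with plain h5py lazy slicing instead (hypothetical file and key names):

```python
import h5py
import torch
from torch.utils.data import DataLoader, Dataset

class LazyHDF5Dataset(Dataset):
    """Reads one sample at a time from disk instead of materializing the full array."""

    def __init__(self, path: str, input_key: str = "input_test"):
        self.path, self.input_key = path, input_key
        with h5py.File(path, "r") as f:  # open briefly just to record the length
            self.length = f[input_key].shape[0]
        self._file = None

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx):
        if self._file is None:  # open lazily so each DataLoader worker gets its own handle
            self._file = h5py.File(self.path, "r")
        return torch.from_numpy(self._file[self.input_key][idx])

loader = DataLoader(LazyHDF5Dataset("huge_test_set.h5"), batch_size=256, num_workers=4)
```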
@@ -865,112 +987,17 @@ def load_test_data(
     # We detect keys first to ensure input_test/output_test are used when present
     try:
         if format == "npz":
-
-
-            inp_key = DataSource._find_key(keys, custom_input_keys)
-            out_key = DataSource._find_key(keys, custom_output_keys)
-            # Strict validation: if user explicitly specified input_key, it must exist exactly
-            if input_key is not None and input_key not in keys:
-                raise KeyError(
-                    f"Explicit --input_key '{input_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            if inp_key is None:
-                raise KeyError(
-                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                )
-            # Strict validation: if user explicitly specified output_key, it must exist exactly
-            if output_key is not None and output_key not in keys:
-                raise KeyError(
-                    f"Explicit --output_key '{output_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            data = NPZSource._load_and_copy(
-                path, [inp_key] + ([out_key] if out_key else [])
+            inp, outp = _load_npz_for_test(
+                path, custom_input_keys, custom_output_keys, input_key, output_key
             )
-            inp = data[inp_key]
-            if inp.dtype == object:
-                inp = np.array(
-                    [x.toarray() if hasattr(x, "toarray") else x for x in inp]
-                )
-            outp = data[out_key] if out_key else None
         elif format == "hdf5":
-
-
-
-            out_key = DataSource._find_key(keys, custom_output_keys)
-            # Strict validation: if user explicitly specified input_key, it must exist exactly
-            if input_key is not None and input_key not in keys:
-                raise KeyError(
-                    f"Explicit --input_key '{input_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            if inp_key is None:
-                raise KeyError(
-                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                )
-            # Strict validation: if user explicitly specified output_key, it must exist exactly
-            if output_key is not None and output_key not in keys:
-                raise KeyError(
-                    f"Explicit --output_key '{output_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            # OOM guard: warn if dataset is very large
-            n_samples = f[inp_key].shape[0]
-            if n_samples > 100000:
-                raise ValueError(
-                    f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                    f"everything into RAM which may cause OOM. For large inference "
-                    f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
-                )
-            inp = f[inp_key][:]
-            outp = f[out_key][:] if out_key else None
+            inp, outp = _load_hdf5_for_test(
+                path, custom_input_keys, custom_output_keys, input_key, output_key
+            )
         elif format == "mat":
-
-
-
-            inp_key = DataSource._find_key(keys, custom_input_keys)
-            out_key = DataSource._find_key(keys, custom_output_keys)
-            # Strict validation: if user explicitly specified input_key, it must exist exactly
-            if input_key is not None and input_key not in keys:
-                raise KeyError(
-                    f"Explicit --input_key '{input_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            if inp_key is None:
-                raise KeyError(
-                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                )
-            # Strict validation: if user explicitly specified output_key, it must exist exactly
-            if output_key is not None and output_key not in keys:
-                raise KeyError(
-                    f"Explicit --output_key '{output_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            # OOM guard: warn if dataset is very large (MAT is transposed)
-            n_samples = f[inp_key].shape[-1]
-            if n_samples > 100000:
-                raise ValueError(
-                    f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                    f"everything into RAM which may cause OOM. For large inference "
-                    f"sets, use a DataLoader with MATSource.load_mmap() instead."
-                )
-            inp = mat_source._load_dataset(f, inp_key)
-            if out_key:
-                outp = mat_source._load_dataset(f, out_key)
-                # Handle transposed outputs from MATLAB
-                # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
-                # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
-                num_samples = inp.shape[0]
-                if outp.ndim == 2:
-                    if outp.shape[0] == 1 and outp.shape[1] == num_samples:
-                        # 1D vector: (1, N) → (N, 1)
-                        outp = outp.T
-                    elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
-                        # Single sample with multiple targets: (T, 1) → (1, T)
-                        outp = outp.T
-            else:
-                outp = None
+            inp, outp = _load_mat_for_test(
+                path, custom_input_keys, custom_output_keys, input_key, output_key
+            )
         else:
             # Fallback to default source.load() for unknown formats
             inp, outp = source.load(path)
@@ -984,81 +1011,28 @@ def load_test_data(
                 f"Available keys depend on file format. Original error: {e}"
             ) from e
 
-        #
+        # Also fail-fast if explicit input_key was provided but not found
+        # This prevents silently loading a different tensor when user mistyped key
+        if input_key is not None:
+            raise KeyError(
+                f"Explicit --input_key '{input_key}' not found in file. "
+                f"Original error: {e}"
+            ) from e
+
+        # Legitimate fallback: no explicit keys, outputs just not present
+        # Re-call helpers with None for explicit keys (validation already done above)
         if format == "npz":
-
-
-
-            inp_key = DataSource._find_key(keys, custom_input_keys)
-            if inp_key is None:
-                raise KeyError(
-                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                )
-            out_key = DataSource._find_key(keys, custom_output_keys)
-            keys_to_probe = [inp_key] + ([out_key] if out_key else [])
-            data = NPZSource._load_and_copy(path, keys_to_probe)
-            inp = data[inp_key]
-            if inp.dtype == object:
-                inp = np.array(
-                    [x.toarray() if hasattr(x, "toarray") else x for x in inp]
-                )
-            outp = data[out_key] if out_key else None
+            inp, outp = _load_npz_for_test(
+                path, custom_input_keys, custom_output_keys, None, None
+            )
         elif format == "hdf5":
-
-
-
-            inp_key = DataSource._find_key(keys, custom_input_keys)
-            if inp_key is None:
-                raise KeyError(
-                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                )
-            # Check size - load_test_data is eager, large files should use DataLoader
-            n_samples = f[inp_key].shape[0]
-            if n_samples > 100000:
-                raise ValueError(
-                    f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                    f"everything into RAM which may cause OOM. For large inference "
-                    f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
-                )
-            inp = f[inp_key][:]
-            out_key = DataSource._find_key(keys, custom_output_keys)
-            outp = f[out_key][:] if out_key else None
+            inp, outp = _load_hdf5_for_test(
+                path, custom_input_keys, custom_output_keys, None, None
+            )
         elif format == "mat":
-
-
-
-            keys = list(f.keys())
-            inp_key = DataSource._find_key(keys, custom_input_keys)
-            if inp_key is None:
-                raise KeyError(
-                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                )
-            # Check size - load_test_data is eager, large files should use DataLoader
-            n_samples = f[inp_key].shape[-1]  # MAT is transposed
-            if n_samples > 100000:
-                raise ValueError(
-                    f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                    f"everything into RAM which may cause OOM. For large inference "
-                    f"sets, use a DataLoader with MATSource.load_mmap() instead."
-                )
-            # Use _load_dataset for sparse support and proper transpose
-            inp = mat_source._load_dataset(f, inp_key)
-            out_key = DataSource._find_key(keys, custom_output_keys)
-            if out_key:
-                outp = mat_source._load_dataset(f, out_key)
-                # Handle transposed outputs from MATLAB
-                # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
-                # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
-                num_samples = inp.shape[0]
-                if outp.ndim == 2:
-                    if outp.shape[0] == 1 and outp.shape[1] == num_samples:
-                        # 1D vector: (1, N) → (N, 1)
-                        outp = outp.T
-                    elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
-                        # Single sample with multiple targets: (T, 1) → (1, T)
-                        outp = outp.T
-            else:
-                outp = None
+            inp, outp = _load_mat_for_test(
+                path, custom_input_keys, custom_output_keys, None, None
+            )
         else:
             raise
 
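The MATLAB transpose handling above covers two shapes that come out of column-major `.mat` files: a per-sample scalar target saved as `(1, N)` and a single sample with `T` targets saved as `(T, 1)`. A quick NumPy check of both cases:

```python
import numpy as np

num_samples = 5

# Case 1: (1, N) -> (N, 1): one target per sample, stored row-wise by MATLAB
outp = np.arange(num_samples, dtype=float).reshape(1, num_samples)
assert outp.T.shape == (num_samples, 1)

# Case 2: (T, 1) -> (1, T): a single sample with T targets
single = np.arange(3, dtype=float).reshape(3, 1)
assert single.T.shape == (1, 3)
```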
@@ -1524,21 +1498,43 @@ def prepare_data(
 
             logger.info(" ✔ Cache creation complete, synchronizing ranks...")
         else:
-            # NON-MAIN RANKS: Wait for cache creation
-            #
+            # NON-MAIN RANKS: Wait for cache creation with timeout
+            # Use monotonic clock (immune to system clock changes)
             import time
 
-            wait_start = time.
+            wait_start = time.monotonic()
+
+            # Robust env parsing with guards for invalid/non-positive values
+            DEFAULT_CACHE_TIMEOUT = 3600  # 1 hour default
+            try:
+                env_timeout = os.environ.get("WAVEDL_CACHE_TIMEOUT", "")
+                CACHE_TIMEOUT = (
+                    int(env_timeout) if env_timeout else DEFAULT_CACHE_TIMEOUT
+                )
+                if CACHE_TIMEOUT <= 0:
+                    CACHE_TIMEOUT = DEFAULT_CACHE_TIMEOUT
+            except ValueError:
+                CACHE_TIMEOUT = DEFAULT_CACHE_TIMEOUT
+
             while not (
                 os.path.exists(CACHE_FILE)
                 and os.path.exists(SCALER_FILE)
                 and os.path.exists(META_FILE)
             ):
                 time.sleep(5)  # Check every 5 seconds
-                elapsed = time.
+                elapsed = time.monotonic() - wait_start
+
+                if elapsed > CACHE_TIMEOUT:
+                    raise RuntimeError(
+                        f"[Rank {accelerator.process_index}] Timeout waiting for cache "
+                        f"files after {CACHE_TIMEOUT}s. Rank 0 may have failed during "
+                        f"cache generation. Check rank 0 logs for errors."
+                    )
+
                 if elapsed > 60 and int(elapsed) % 60 < 5:  # Log every ~minute
                     logger.info(
-                        f" [Rank {accelerator.process_index}] Waiting for cache
+                        f" [Rank {accelerator.process_index}] Waiting for cache "
+                        f"creation... ({int(elapsed)}s / {CACHE_TIMEOUT}s max)"
                     )
             # Small delay to ensure files are fully written
             time.sleep(2)
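Non-main ranks now give up after `WAVEDL_CACHE_TIMEOUT` seconds (default 3600, with invalid or non-positive values falling back to the default), and the elapsed time is measured with `time.monotonic()` so wall-clock adjustments cannot shorten or extend the wait. For illustration, raising the limit from Python before training starts; exporting the variable in the job script works the same way (the launch command itself is not shown in this diff):

```python
import os
import time

# Allow up to two hours for rank 0 to build the cache.
os.environ["WAVEDL_CACHE_TIMEOUT"] = "7200"

# time.monotonic() never goes backwards, so elapsed-time checks stay valid
# even if NTP or an operator adjusts the system clock mid-run.
start = time.monotonic()
time.sleep(0.1)
assert time.monotonic() - start > 0
```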
wavedl/utils/metrics.py
CHANGED
@@ -106,8 +106,20 @@ def configure_matplotlib_style():
     )
 
 
-#
-
+# Lazy style initialization flag
+_style_configured = False
+
+
+def _ensure_style_configured():
+    """Apply matplotlib style on first use (lazy initialization).
+
+    This avoids modifying global plt.rcParams at import time, which could
+    unexpectedly override user configurations.
+    """
+    global _style_configured
+    if not _style_configured:
+        configure_matplotlib_style()
+        _style_configured = True
 
 
 # ==============================================================================
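With the lazy flag, importing `wavedl.utils.metrics` no longer rewrites `plt.rcParams`; the package style is applied the first time one of the plotting helpers runs. Users who want the same style for their own figures can opt in explicitly (import path taken from this file's location):

```python
import matplotlib.pyplot as plt
from wavedl.utils.metrics import configure_matplotlib_style

# Importing the module leaves plt.rcParams untouched at import time.
# Opt in explicitly to the package's plotting style for your own figures:
configure_matplotlib_style()
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])
```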
@@ -115,11 +127,14 @@ configure_matplotlib_style()
 # ==============================================================================
 class MetricTracker:
     """
-    Tracks running averages of metrics
+    Tracks running averages of metrics during training.
 
     Useful for tracking loss, accuracy, or any scalar metric across batches.
     Handles division-by-zero safely by returning 0.0 when count is zero.
 
+    Note:
+        Not thread-safe. Intended for single-threaded use within training loops.
+
     Attributes:
         val: Most recent value
         avg: Running average
@@ -341,6 +356,7 @@ def plot_scientific_scatter(
     Returns:
         Matplotlib Figure object (can be saved or logged to WandB)
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names, max_samples
     )
@@ -415,6 +431,7 @@ def plot_error_histogram(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names
     )
@@ -485,6 +502,7 @@ def plot_residuals(
     Returns:
        Matplotlib Figure object
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names, max_samples
     )
@@ -552,6 +570,7 @@ def create_training_curves(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     epochs = [h.get("epoch", i + 1) for i, h in enumerate(history)]
 
     fig, ax1 = plt.subplots(figsize=(FIGURE_WIDTH_INCH * 0.7, FIGURE_WIDTH_INCH * 0.4))
@@ -605,8 +624,8 @@
     if not valid_data:
         return
     vmin, vmax = min(valid_data), max(valid_data)
-    # Get decade range that covers data (
-    log_min = int(np.
+    # Get decade range that covers data (floor for min to include lowest data)
+    log_min = int(np.floor(np.log10(vmin)))
     log_max = int(np.ceil(np.log10(vmax)))
     # Generate ticks at each power of 10
     ticks = [10.0**i for i in range(log_min, log_max + 1)]
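The switch from plain `int()` truncation to `np.floor` matters for loss values below 1: truncation rounds negative logarithms toward zero and drops the lowest decade from the tick range. A quick worked check:

```python
import numpy as np

vmin = 3e-4                            # smallest value on the log-scaled axis
print(int(np.log10(vmin)))             # -3: first tick at 1e-3, excludes vmin
print(int(np.floor(np.log10(vmin))))   # -4: first tick at 1e-4, covers vmin
```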
@@ -691,6 +710,7 @@ def plot_bland_altman(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names, max_samples
     )
@@ -764,6 +784,7 @@ def plot_qq(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     from scipy import stats
 
     num_params = y_true.shape[1] if y_true.ndim > 1 else 1
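Every public plotting helper in this module now calls `_ensure_style_configured()` before drawing, so callers get the styled output without any setup. A hedged usage sketch; the full signatures are not shown in this diff, so the argument names below follow the internal `_prepare_plot_data` call and the parameter values are made up:

```python
import numpy as np
from wavedl.utils.metrics import plot_scientific_scatter

y_true = np.random.rand(500, 2)
y_pred = y_true + 0.05 * np.random.randn(500, 2)

# Style is applied lazily inside the call; nothing to configure beforehand.
fig = plot_scientific_scatter(y_true, y_pred, param_names=["thickness", "velocity"])
fig.savefig("scatter.png", dpi=300)
```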
wavedl/utils/schedulers.py
CHANGED
@@ -76,7 +76,7 @@ def get_scheduler(
     # Step/MultiStep parameters
     step_size: int = 30,
     milestones: list[int] | None = None,
-    gamma: float = 0.
+    gamma: float = 0.99,
     # Linear warmup parameters
     warmup_epochs: int = 5,
     start_factor: float = 0.1,
@@ -100,7 +100,7 @@ def get_scheduler(
         pct_start: Percentage of cycle spent increasing LR (OneCycleLR)
         step_size: Period for StepLR
         milestones: Epochs to decay LR for MultiStepLR
-        gamma: Decay factor for step/multistep
+        gamma: Decay factor (0.1 for step/multistep, 0.99 for exponential)
         warmup_epochs: Number of warmup epochs for linear_warmup
         start_factor: Starting LR factor for warmup (LR * start_factor)
         **kwargs: Additional arguments passed to scheduler