wavedl 1.6.3-py3-none-any.whl → 1.7.0-py3-none-any.whl
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +115 -9
- wavedl/models/_pretrained_utils.py +72 -0
- wavedl/models/_template.py +7 -6
- wavedl/models/cnn.py +20 -0
- wavedl/models/convnext.py +3 -70
- wavedl/models/convnext_v2.py +1 -18
- wavedl/models/mamba.py +126 -38
- wavedl/models/resnet3d.py +23 -5
- wavedl/models/unireplknet.py +1 -18
- wavedl/models/vit.py +18 -8
- wavedl/test.py +5 -23
- wavedl/train.py +492 -26
- wavedl/utils/__init__.py +49 -9
- wavedl/utils/config.py +6 -8
- wavedl/utils/cross_validation.py +17 -4
- wavedl/utils/data.py +140 -174
- wavedl/utils/metrics.py +26 -5
- wavedl/utils/schedulers.py +2 -2
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/METADATA +35 -14
- wavedl-1.7.0.dist-info/RECORD +46 -0
- wavedl-1.6.3.dist-info/RECORD +0 -46
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/LICENSE +0 -0
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/WHEEL +0 -0
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/top_level.txt +0 -0
wavedl/utils/cross_validation.py
CHANGED

@@ -36,7 +36,9 @@ from torch.utils.data import DataLoader
 class CVDataset(torch.utils.data.Dataset):
     """Simple in-memory dataset for cross-validation."""

-    def __init__(
+    def __init__(
+        self, X: np.ndarray, y: np.ndarray, expected_spatial_ndim: int | None = None
+    ):
         """
         Initialize CV dataset with explicit channel dimension handling.

@@ -51,6 +53,11 @@ class CVDataset(torch.utils.data.Dataset):
         - If X.ndim == expected_spatial_ndim + 1: Add channel dim (N, *spatial) -> (N, 1, *spatial)
         - If X.ndim == expected_spatial_ndim + 2: Already has channel (N, C, *spatial)
         - If expected_spatial_ndim is None: Use legacy ndim-based inference
+
+        Warning:
+            Legacy mode (expected_spatial_ndim=None) may misinterpret multichannel
+            3D data as single-channel 4D data. Always pass expected_spatial_ndim
+            explicitly for 3D volumes with >1 channel.
         """
         if expected_spatial_ndim is not None:
             # Explicit mode: use expected_spatial_ndim to determine if channel exists
@@ -100,6 +107,7 @@ def train_fold(
     device: torch.device,
     epochs: int,
     patience: int,
+    grad_clip: float,
     scaler: StandardScaler,
     logger: logging.Logger,
 ) -> dict[str, Any]:
@@ -147,7 +155,8 @@ def train_fold(
             pred = model(x)
             loss = criterion(pred, y)
             loss.backward()
-
+            if grad_clip > 0:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
             optimizer.step()

             # Per-batch LR scheduling (OneCycleLR)
@@ -289,6 +298,7 @@ def run_cross_validation(
     output_dir: str = "./cv_results",
     workers: int = 4,
     seed: int = 2025,
+    grad_clip: float = 1.0,
     logger: logging.Logger | None = None,
 ) -> dict[str, Any]:
     """
@@ -330,8 +340,10 @@
     )
     logger = logging.getLogger("CV-Trainer")

-    # Set seeds
-
+    # Set seeds for reproducibility
+    # Note: sklearn KFold uses random_state parameter directly, not global numpy RNG
+    rng = np.random.default_rng(seed)  # Local RNG for any numpy operations
+    _ = rng  # Silence unused variable warning (available for future use)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
@@ -444,6 +456,7 @@
         device=device,
         epochs=epochs,
         patience=patience,
+        grad_clip=grad_clip,
         scaler=scaler,
         logger=logger,
     )
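The new `expected_spatial_ndim` argument exists because pure shape-based inference is ambiguous: a batch of two-channel 3D volumes and a batch of single-channel 4D data can have the same `ndim`. A minimal sketch of the documented rule, assuming only the shapes described in the docstring (the `add_channel_dim` helper is illustrative, not a wavedl internal):

```python
import numpy as np

def add_channel_dim(X: np.ndarray, expected_spatial_ndim: int) -> np.ndarray:
    """Apply the documented rule: add a channel axis only when it is missing."""
    if X.ndim == expected_spatial_ndim + 1:
        return X[:, None, ...]  # (N, *spatial) -> (N, 1, *spatial)
    if X.ndim == expected_spatial_ndim + 2:
        return X                # already (N, C, *spatial)
    raise ValueError(f"ndim={X.ndim} incompatible with spatial_ndim={expected_spatial_ndim}")

# Two-channel 3D volumes: legacy ndim-only inference could read this
# (N, 2, D, H, W) array as single-channel 4D data; explicit mode keeps it intact.
X = np.zeros((8, 2, 16, 16, 16))
assert add_channel_dim(X, expected_spatial_ndim=3).shape == (8, 2, 16, 16, 16)

X_no_channel = np.zeros((8, 16, 16, 16))  # (N, D, H, W), no channel axis
assert add_channel_dim(X_no_channel, expected_spatial_ndim=3).shape == (8, 1, 16, 16, 16)
```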
wavedl/utils/data.py
CHANGED

@@ -798,6 +798,128 @@ def load_outputs_only(path: str, format: str = "auto") -> np.ndarray:
     return source.load_outputs_only(path)


+def _load_npz_for_test(
+    path: str,
+    input_keys: list[str],
+    output_keys: list[str],
+    explicit_input_key: str | None,
+    explicit_output_key: str | None,
+) -> tuple[np.ndarray, np.ndarray | None]:
+    """Load NPZ file for test/inference with key validation."""
+    with np.load(path, allow_pickle=False) as probe:
+        keys = list(probe.keys())
+
+    inp_key = DataSource._find_key(keys, input_keys)
+    out_key = DataSource._find_key(keys, output_keys)
+
+    # Strict validation for explicit keys
+    if explicit_input_key is not None and explicit_input_key not in keys:
+        raise KeyError(
+            f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+        )
+    if inp_key is None:
+        raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+    if explicit_output_key is not None and explicit_output_key not in keys:
+        raise KeyError(
+            f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+        )
+
+    data = NPZSource._load_and_copy(path, [inp_key] + ([out_key] if out_key else []))
+    inp = data[inp_key]
+    if inp.dtype == object:
+        inp = np.array([x.toarray() if hasattr(x, "toarray") else x for x in inp])
+    outp = data[out_key] if out_key else None
+    return inp, outp
+
+
+def _load_hdf5_for_test(
+    path: str,
+    input_keys: list[str],
+    output_keys: list[str],
+    explicit_input_key: str | None,
+    explicit_output_key: str | None,
+) -> tuple[np.ndarray, np.ndarray | None]:
+    """Load HDF5 file for test/inference with key validation and OOM guard."""
+    with h5py.File(path, "r") as f:
+        keys = list(f.keys())
+        inp_key = DataSource._find_key(keys, input_keys)
+        out_key = DataSource._find_key(keys, output_keys)
+
+        # Strict validation
+        if explicit_input_key is not None and explicit_input_key not in keys:
+            raise KeyError(
+                f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+            )
+        if inp_key is None:
+            raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+        if explicit_output_key is not None and explicit_output_key not in keys:
+            raise KeyError(
+                f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+            )
+
+        # OOM guard
+        n_samples = f[inp_key].shape[0]
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+            )
+
+        inp = f[inp_key][:]
+        outp = f[out_key][:] if out_key else None
+    return inp, outp
+
+
+def _load_mat_for_test(
+    path: str,
+    input_keys: list[str],
+    output_keys: list[str],
+    explicit_input_key: str | None,
+    explicit_output_key: str | None,
+) -> tuple[np.ndarray, np.ndarray | None]:
+    """Load MAT v7.3 file for test/inference with sparse support."""
+    mat_source = MATSource()
+    with h5py.File(path, "r") as f:
+        keys = list(f.keys())
+        inp_key = DataSource._find_key(keys, input_keys)
+        out_key = DataSource._find_key(keys, output_keys)
+
+        # Strict validation
+        if explicit_input_key is not None and explicit_input_key not in keys:
+            raise KeyError(
+                f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+            )
+        if inp_key is None:
+            raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+        if explicit_output_key is not None and explicit_output_key not in keys:
+            raise KeyError(
+                f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+            )
+
+        # OOM guard
+        n_samples = f[inp_key].shape[-1]  # MAT is transposed
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with MATSource.load_mmap() instead."
+            )
+
+        inp = mat_source._load_dataset(f, inp_key)
+        outp = None
+        if out_key:
+            outp = mat_source._load_dataset(f, out_key)
+            # Handle MATLAB transpose
+            num_samples = inp.shape[0]
+            if outp.ndim == 2:
+                if (outp.shape[0] == 1 and outp.shape[1] == num_samples) or (
+                    outp.shape[1] == 1 and outp.shape[0] != num_samples
+                ):
+                    outp = outp.T
+    return inp, outp
+
+
 def load_test_data(
     path: str,
     format: str = "auto",
@@ -865,112 +987,17 @@ def load_test_data(
     # We detect keys first to ensure input_test/output_test are used when present
     try:
         if format == "npz":
-
-
-            inp_key = DataSource._find_key(keys, custom_input_keys)
-            out_key = DataSource._find_key(keys, custom_output_keys)
-            # Strict validation: if user explicitly specified input_key, it must exist exactly
-            if input_key is not None and input_key not in keys:
-                raise KeyError(
-                    f"Explicit --input_key '{input_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            if inp_key is None:
-                raise KeyError(
-                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                )
-            # Strict validation: if user explicitly specified output_key, it must exist exactly
-            if output_key is not None and output_key not in keys:
-                raise KeyError(
-                    f"Explicit --output_key '{output_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            data = NPZSource._load_and_copy(
-                path, [inp_key] + ([out_key] if out_key else [])
+            inp, outp = _load_npz_for_test(
+                path, custom_input_keys, custom_output_keys, input_key, output_key
             )
-            inp = data[inp_key]
-            if inp.dtype == object:
-                inp = np.array(
-                    [x.toarray() if hasattr(x, "toarray") else x for x in inp]
-                )
-            outp = data[out_key] if out_key else None
         elif format == "hdf5":
-
-
-
-            out_key = DataSource._find_key(keys, custom_output_keys)
-            # Strict validation: if user explicitly specified input_key, it must exist exactly
-            if input_key is not None and input_key not in keys:
-                raise KeyError(
-                    f"Explicit --input_key '{input_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            if inp_key is None:
-                raise KeyError(
-                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                )
-            # Strict validation: if user explicitly specified output_key, it must exist exactly
-            if output_key is not None and output_key not in keys:
-                raise KeyError(
-                    f"Explicit --output_key '{output_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            # OOM guard: warn if dataset is very large
-            n_samples = f[inp_key].shape[0]
-            if n_samples > 100000:
-                raise ValueError(
-                    f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                    f"everything into RAM which may cause OOM. For large inference "
-                    f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
-                )
-            inp = f[inp_key][:]
-            outp = f[out_key][:] if out_key else None
+            inp, outp = _load_hdf5_for_test(
+                path, custom_input_keys, custom_output_keys, input_key, output_key
+            )
         elif format == "mat":
-
-
-
-            inp_key = DataSource._find_key(keys, custom_input_keys)
-            out_key = DataSource._find_key(keys, custom_output_keys)
-            # Strict validation: if user explicitly specified input_key, it must exist exactly
-            if input_key is not None and input_key not in keys:
-                raise KeyError(
-                    f"Explicit --input_key '{input_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            if inp_key is None:
-                raise KeyError(
-                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                )
-            # Strict validation: if user explicitly specified output_key, it must exist exactly
-            if output_key is not None and output_key not in keys:
-                raise KeyError(
-                    f"Explicit --output_key '{output_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            # OOM guard: warn if dataset is very large (MAT is transposed)
-            n_samples = f[inp_key].shape[-1]
-            if n_samples > 100000:
-                raise ValueError(
-                    f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                    f"everything into RAM which may cause OOM. For large inference "
-                    f"sets, use a DataLoader with MATSource.load_mmap() instead."
-                )
-            inp = mat_source._load_dataset(f, inp_key)
-            if out_key:
-                outp = mat_source._load_dataset(f, out_key)
-                # Handle transposed outputs from MATLAB
-                # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
-                # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
-                num_samples = inp.shape[0]
-                if outp.ndim == 2:
-                    if outp.shape[0] == 1 and outp.shape[1] == num_samples:
-                        # 1D vector: (1, N) → (N, 1)
-                        outp = outp.T
-                    elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
-                        # Single sample with multiple targets: (T, 1) → (1, T)
-                        outp = outp.T
-            else:
-                outp = None
+            inp, outp = _load_mat_for_test(
+                path, custom_input_keys, custom_output_keys, input_key, output_key
+            )
         else:
             # Fallback to default source.load() for unknown formats
             inp, outp = source.load(path)
@@ -993,80 +1020,19 @@ def load_test_data(
     ) from e

     # Legitimate fallback: no explicit keys, outputs just not present
+    # Re-call helpers with None for explicit keys (validation already done above)
     if format == "npz":
-
-
-
-        inp_key = DataSource._find_key(keys, custom_input_keys)
-        if inp_key is None:
-            raise KeyError(
-                f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-            )
-        out_key = DataSource._find_key(keys, custom_output_keys)
-        keys_to_probe = [inp_key] + ([out_key] if out_key else [])
-        data = NPZSource._load_and_copy(path, keys_to_probe)
-        inp = data[inp_key]
-        if inp.dtype == object:
-            inp = np.array(
-                [x.toarray() if hasattr(x, "toarray") else x for x in inp]
-            )
-        outp = data[out_key] if out_key else None
+        inp, outp = _load_npz_for_test(
+            path, custom_input_keys, custom_output_keys, None, None
+        )
     elif format == "hdf5":
-
-
-
-        inp_key = DataSource._find_key(keys, custom_input_keys)
-        if inp_key is None:
-            raise KeyError(
-                f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-            )
-        # Check size - load_test_data is eager, large files should use DataLoader
-        n_samples = f[inp_key].shape[0]
-        if n_samples > 100000:
-            raise ValueError(
-                f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                f"everything into RAM which may cause OOM. For large inference "
-                f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
-            )
-        inp = f[inp_key][:]
-        out_key = DataSource._find_key(keys, custom_output_keys)
-        outp = f[out_key][:] if out_key else None
+        inp, outp = _load_hdf5_for_test(
+            path, custom_input_keys, custom_output_keys, None, None
+        )
     elif format == "mat":
-
-
-
-        keys = list(f.keys())
-        inp_key = DataSource._find_key(keys, custom_input_keys)
-        if inp_key is None:
-            raise KeyError(
-                f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-            )
-        # Check size - load_test_data is eager, large files should use DataLoader
-        n_samples = f[inp_key].shape[-1]  # MAT is transposed
-        if n_samples > 100000:
-            raise ValueError(
-                f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                f"everything into RAM which may cause OOM. For large inference "
-                f"sets, use a DataLoader with MATSource.load_mmap() instead."
-            )
-        # Use _load_dataset for sparse support and proper transpose
-        inp = mat_source._load_dataset(f, inp_key)
-        out_key = DataSource._find_key(keys, custom_output_keys)
-        if out_key:
-            outp = mat_source._load_dataset(f, out_key)
-            # Handle transposed outputs from MATLAB
-            # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
-            # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
-            num_samples = inp.shape[0]
-            if outp.ndim == 2:
-                if outp.shape[0] == 1 and outp.shape[1] == num_samples:
-                    # 1D vector: (1, N) → (N, 1)
-                    outp = outp.T
-                elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
-                    # Single sample with multiple targets: (T, 1) → (1, T)
-                    outp = outp.T
-        else:
-            outp = None
+        inp, outp = _load_mat_for_test(
+            path, custom_input_keys, custom_output_keys, None, None
+        )
     else:
         raise
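The repeated OOM guard is the design driver here: load_test_data() stays eager and refuses datasets past 100,000 samples, pointing users to a DataLoader-based path instead. A minimal sketch of such a streaming reader using plain h5py (this `StreamingHDF5Dataset` is illustrative, not a wavedl class; wavedl's own entry point is the `load_mmap()` method named in the error message):

```python
import h5py
import torch
from torch.utils.data import DataLoader, Dataset


class StreamingHDF5Dataset(Dataset):
    """Reads one sample at a time from disk instead of materializing the array."""

    def __init__(self, path: str, input_key: str):
        self.path = path
        self.input_key = input_key
        with h5py.File(path, "r") as f:
            self.n = f[input_key].shape[0]
        self._file = None  # opened lazily so each DataLoader worker gets its own handle

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> torch.Tensor:
        if self._file is None:
            self._file = h5py.File(self.path, "r")
        return torch.from_numpy(self._file[self.input_key][idx])


# Batches stream from disk; peak RAM is one batch, not the whole test set.
loader = DataLoader(StreamingHDF5Dataset("test.h5", "input_test"),
                    batch_size=256, num_workers=4)
```

The per-worker lazy open matters because h5py file handles do not survive the fork into DataLoader worker processes.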
|
wavedl/utils/metrics.py
CHANGED

@@ -106,8 +106,20 @@ def configure_matplotlib_style():
     )


-#
-
+# Lazy style initialization flag
+_style_configured = False
+
+
+def _ensure_style_configured():
+    """Apply matplotlib style on first use (lazy initialization).
+
+    This avoids modifying global plt.rcParams at import time, which could
+    unexpectedly override user configurations.
+    """
+    global _style_configured
+    if not _style_configured:
+        configure_matplotlib_style()
+        _style_configured = True


 # ==============================================================================
@@ -115,11 +127,14 @@ configure_matplotlib_style()
 # ==============================================================================
 class MetricTracker:
     """
-    Tracks running averages of metrics
+    Tracks running averages of metrics during training.

     Useful for tracking loss, accuracy, or any scalar metric across batches.
     Handles division-by-zero safely by returning 0.0 when count is zero.

+    Note:
+        Not thread-safe. Intended for single-threaded use within training loops.
+
     Attributes:
         val: Most recent value
         avg: Running average
@@ -341,6 +356,7 @@ def plot_scientific_scatter(
     Returns:
         Matplotlib Figure object (can be saved or logged to WandB)
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names, max_samples
     )
@@ -415,6 +431,7 @@ def plot_error_histogram(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names
     )
@@ -485,6 +502,7 @@ def plot_residuals(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names, max_samples
     )
@@ -552,6 +570,7 @@ def create_training_curves(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     epochs = [h.get("epoch", i + 1) for i, h in enumerate(history)]

     fig, ax1 = plt.subplots(figsize=(FIGURE_WIDTH_INCH * 0.7, FIGURE_WIDTH_INCH * 0.4))
@@ -605,8 +624,8 @@
         if not valid_data:
             return
         vmin, vmax = min(valid_data), max(valid_data)
-        # Get decade range that covers data (
-        log_min = int(np.
+        # Get decade range that covers data (floor for min to include lowest data)
+        log_min = int(np.floor(np.log10(vmin)))
         log_max = int(np.ceil(np.log10(vmax)))
         # Generate ticks at each power of 10
         ticks = [10.0**i for i in range(log_min, log_max + 1)]
@@ -691,6 +710,7 @@ def plot_bland_altman(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names, max_samples
     )
@@ -764,6 +784,7 @@ def plot_qq(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
    from scipy import stats

     num_params = y_true.shape[1] if y_true.ndim > 1 else 1
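The `log_min` change is more than cosmetic: `int()` truncates toward zero, so for values below 1 it would drop the bottom decade tick. A quick check of the reconstructed arithmetic (sample values are illustrative):

```python
import numpy as np

vmin, vmax = 3.2e-4, 0.07
log_min = int(np.floor(np.log10(vmin)))  # floor(-3.49...) = -4
log_max = int(np.ceil(np.log10(vmax)))   # ceil(-1.15...)  = -1
ticks = [10.0**i for i in range(log_min, log_max + 1)]
print(ticks)  # [0.0001, 0.001, 0.01, 0.1] — spans [vmin, vmax]

# Without the fix, int(np.log10(vmin)) truncates -3.49 to -3,
# losing the 1e-4 tick that sits below the smallest data point.
```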
wavedl/utils/schedulers.py
CHANGED

@@ -76,7 +76,7 @@ def get_scheduler(
     # Step/MultiStep parameters
     step_size: int = 30,
     milestones: list[int] | None = None,
-    gamma: float = 0.
+    gamma: float = 0.99,
     # Linear warmup parameters
     warmup_epochs: int = 5,
     start_factor: float = 0.1,
@@ -100,7 +100,7 @@ def get_scheduler(
         pct_start: Percentage of cycle spent increasing LR (OneCycleLR)
         step_size: Period for StepLR
         milestones: Epochs to decay LR for MultiStepLR
-        gamma: Decay factor for step/multistep
+        gamma: Decay factor (0.1 for step/multistep, 0.99 for exponential)
         warmup_epochs: Number of warmup epochs for linear_warmup
         start_factor: Starting LR factor for warmup (LR * start_factor)
         **kwargs: Additional arguments passed to scheduler
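The split documented for `gamma` reflects how the two scheduler families consume it: StepLR applies the factor once every `step_size` epochs, while ExponentialLR applies it every epoch, so they need decay factors on very different scales. A sketch with plain PyTorch schedulers using the docstring's values (not a tuning recommendation):

```python
import torch

# StepLR: large decay applied rarely (LR × 0.1 every 30 epochs)
p1 = [torch.nn.Parameter(torch.zeros(1))]
opt1 = torch.optim.SGD(p1, lr=1e-3)
step = torch.optim.lr_scheduler.StepLR(opt1, step_size=30, gamma=0.1)

# ExponentialLR: small decay applied every epoch; 0.99**69 ≈ 0.50,
# so gamma=0.99 halves the LR roughly every 69 epochs
p2 = [torch.nn.Parameter(torch.zeros(1))]
opt2 = torch.optim.SGD(p2, lr=1e-3)
exp = torch.optim.lr_scheduler.ExponentialLR(opt2, gamma=0.99)

for epoch in range(90):
    opt1.step(); opt2.step()
    step.step(); exp.step()

print(step.get_last_lr())  # 1e-3 × 0.1**3 = 1e-6 (decays at epochs 30, 60, 90)
print(exp.get_last_lr())   # 1e-3 × 0.99**90 ≈ 4.05e-4
```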
{wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.6.3
+Version: 1.7.0
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -907,18 +907,24 @@ Automatically find the best training configuration using [Optuna](https://optuna.org):
 **Run HPO:**

 ```bash
-# Basic HPO (auto-detects GPUs)
-wavedl-hpo --data_path train.npz --
+# Basic HPO (50 trials, auto-detects GPUs)
+wavedl-hpo --data_path train.npz --n_trials 50

-#
-wavedl-hpo --data_path train.npz --
+# Quick search (minimal search space, fastest)
+wavedl-hpo --data_path train.npz --n_trials 30 --quick

-#
-wavedl-hpo --data_path train.npz --
+# Medium search (balanced between quick and full)
+wavedl-hpo --data_path train.npz --n_trials 50 --medium
+
+# Full search with specific models
+wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0
+
+# In-process mode (enables pruning, faster, single-GPU)
+wavedl-hpo --data_path train.npz --n_trials 50 --inprocess
 ```

 > [!TIP]
-> **
+> **GPU Detection**: HPO auto-detects GPUs and runs one trial per GPU in parallel. Use `--inprocess` for single-GPU with pruning support (early stopping of bad trials).

 **Train with best parameters**
@@ -940,10 +946,23 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
 | Learning rate | 1e-5 → 1e-2 | (always searched) |
 | Batch size | 16, 32, 64, 128 | (always searched) |

-**
-
-
-
+**Search Presets:**
+
+| Mode | Models | Optimizers | Schedulers | Use Case |
+|------|--------|------------|------------|----------|
+| Full (default) | cnn, resnet18, resnet34 | all 6 | all 8 | Production search |
+| `--medium` | cnn, resnet18 | adamw, adam, sgd | plateau, cosine, onecycle | Balanced exploration |
+| `--quick` | cnn | adamw | plateau | Fast validation |
+
+**Execution Modes:**
+
+| Mode | Flag | Pruning | GPU Memory | Best For |
+|------|------|---------|------------|----------|
+| Subprocess (default) | — | ❌ No | Isolated | Multi-GPU parallel trials |
+| In-process | `--inprocess` | ✅ Yes | Shared | Single-GPU with early stopping |
+
+> [!TIP]
+> Use `--inprocess` when running single-GPU trials. It enables MedianPruner to stop unpromising trials early, reducing total search time.

 ---
@@ -954,7 +973,9 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
 | `--data_path` | (required) | Training data file |
 | `--models` | 3 defaults | Models to search (specify any number) |
 | `--n_trials` | `50` | Number of trials to run |
-| `--quick` | `False` |
+| `--quick` | `False` | Quick mode: minimal search space |
+| `--medium` | `False` | Medium mode: balanced search space |
+| `--inprocess` | `False` | Run trials in-process (enables pruning) |
 | `--optimizers` | all 6 | Optimizers to search |
 | `--schedulers` | all 8 | Schedulers to search |
 | `--losses` | all 6 | Losses to search |
@@ -963,7 +984,7 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
 | `--output` | `hpo_results.json` | Output file |


-> See [Available Models](#available-models) for all
+> See [Available Models](#available-models) for all 69 architectures you can search.

 </details>
wavedl-1.7.0.dist-info/RECORD
ADDED

@@ -0,0 +1,46 @@
+wavedl/__init__.py,sha256=Ol1M5mok2rnnUKDPBZLQDBDwKn7_LV9iTxds5obDeJk,1177
+wavedl/hpo.py,sha256=TyEWubL-adQJyRmSv1M1SnJvH7_vTlXJATI1ElxXVUU,18991
+wavedl/launcher.py,sha256=_CFlgpKgHrtZebl1yQbJZJEcob06Y9-fqnRYzwW7UJQ,11776
+wavedl/test.py,sha256=_6i6F1KkO24cC0RtvxcwAGxDUbU6VV12efDWFGojkeE,38466
+wavedl/train.py,sha256=U0-hHo4wba1dhZ2jjaCFURiS6aGHuPDGz0zPwPm2Kc0,72722
+wavedl/models/__init__.py,sha256=hyR__h_D8PsUQCBSM5tj94yYK00uG8ABjEmj_RR8SGE,5719
+wavedl/models/_pretrained_utils.py,sha256=br--ZrgqndYCO_iAeQOvUDg6ZxzGDyLZFeWu1Qj_DrI,14756
+wavedl/models/_template.py,sha256=nEixVS8e82Tud08Uk8jkXtriGhk_WFbqSaGDq_Mj4ak,4684
+wavedl/models/base.py,sha256=bDoHYFli-aR8amcFYXbF98QYaKSCEwZWpvOhN21ODro,9075
+wavedl/models/caformer.py,sha256=ufPM-HzQ-qUZcXgnOulurY6jBUlMUzokC01whtPeVMg,7922
+wavedl/models/cnn.py,sha256=dOmCrHGXd8Md8ixbJ_-An9t80tm36sVY84je2EDmnZA,8256
+wavedl/models/convnext.py,sha256=GoLId2HClsOksuL3XLscEIytrmOBPGhO6UhGn04yDp4,13354
+wavedl/models/convnext_v2.py,sha256=jPPXTZbQQ8zE9yGVWTNUaI5g1d0xIxBjrLuUHUKc5mM,14349
+wavedl/models/densenet.py,sha256=V_caGd0wsG_Q3Q38I4MEgYmU0v4j8mDyvv7Rn3Bk7Ac,12667
+wavedl/models/efficientnet.py,sha256=HWfhqSX57lC5Xug5TrQ3r-uFqkksoIKjmQ5Zr5njkEA,8264
+wavedl/models/efficientnetv2.py,sha256=hVSnVId8T1rjqaKlckLqWFwvo2J-qASX7o9lMbXbP-s,10947
+wavedl/models/efficientvit.py,sha256=KqFoZq9YHBMnTue6aMdPKgBOMczeBPryY_F6ip0hoEI,11630
+wavedl/models/fastvit.py,sha256=S0SF0iC-9ZJrP-9YUTLPhMJMV-W9r2--V3hVAmSSVKI,7083
+wavedl/models/mamba.py,sha256=2mqBxUKCLJNkRc87QzpsOj6hKzrEh5tchGSyYZCSUcQ,20031
+wavedl/models/maxvit.py,sha256=I6TFGrLRcyMU-nU7u5VhOaXZWWdwmNJwHsMqbJh_g_o,7548
+wavedl/models/mobilenetv3.py,sha256=LZxCg599kGP6-XI_l3PpT8jzh4oTAdWH3Y7GH097o28,10242
+wavedl/models/registry.py,sha256=InYAXX2xbRvsFDFnYUPCptJh0F9lHlFPN77A9kqHRT0,2980
+wavedl/models/regnet.py,sha256=6Yjo2wZzdjK8VpOMagbCrHqmsfRmGkuiURmc-MesYvA,13777
+wavedl/models/resnet.py,sha256=3i4zfE15qF4cd0qbTKX-Wdy2Kd0f4mLcdd316FAcVCo,16720
+wavedl/models/resnet3d.py,sha256=NS6UBmvITO3NbdBNfe39bxViFqrIgeSXBnDgYi8QsC8,9247
+wavedl/models/swin.py,sha256=39Gwn5hNEw3-tndc8qFFzV-VZ7pJMMKey2oZONAZ8MU,14980
+wavedl/models/tcn.py,sha256=XzojpuMFG4lu_0oQHbQnkLAb7AnW-D7_6KoBlQDPLnQ,12367
+wavedl/models/unet.py,sha256=oi7eBONSe0ALpJKsYda3jRGwu-LuSiFgNdURebnGGt0,7712
+wavedl/models/unireplknet.py,sha256=sbiYcc2NeB0-_VAmeoe9Vi5hQzhYz03knG7o2Qk0WYE,14634
+wavedl/models/vit.py,sha256=o-zWT2GBCTs9vD3jUFwlcwxK53XqEn_x4iPaRuEQe10,15219
+wavedl/utils/__init__.py,sha256=CYqD3Bcwcub2mSrW05x8wvd2n1Co_3N9ajyKPyBswjo,4887
+wavedl/utils/config.py,sha256=yAKuuhM-oxvHFXomkkek4IGihsVO5yZxc4b2noQ1amE,10523
+wavedl/utils/constraints.py,sha256=V9Gyi8-uIMbLUWb2cOaHZD0SliWLxVrHZHFyo4HWK7g,18031
+wavedl/utils/cross_validation.py,sha256=evp8EsHGJcxMHpfRdFSParDltUTyMhKQxUkcn5-osI4,18556
+wavedl/utils/data.py,sha256=f-LLIiiv74iiKrR8TQ9oeKODF29_jzeUUp4iMBuj_H4,60875
+wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
+wavedl/utils/losses.py,sha256=KWpU5S5noFzp3bLbcH9RNpkFPajy6fyTIh5cNjI-BYA,7038
+wavedl/utils/metrics.py,sha256=tLzKG2zINyXit-KvYZSJg-1nG6rST54GH6k4ALonToU,40935
+wavedl/utils/optimizers.py,sha256=ZoETDSOK1fWUT2dx69PyYebeM8Vcqf9zOIKUERWk5HY,6107
+wavedl/utils/schedulers.py,sha256=_aFTQ8kuvdZIxOoXPHQRu_N9XBuTVSjU6dmBbzH430o,7425
+wavedl-1.7.0.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+wavedl-1.7.0.dist-info/METADATA,sha256=f5UMARudJtBdt5Pu3E18Rfk2sD4h77A2mt8Mh8_cJk4,48527
+wavedl-1.7.0.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+wavedl-1.7.0.dist-info/entry_points.txt,sha256=NuAvdiG93EYYpqv-_1wf6PN0WqBfABanDKalNKe2GOs,148
+wavedl-1.7.0.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+wavedl-1.7.0.dist-info/RECORD,,