wavedl-1.6.3-py3-none-any.whl → wavedl-1.7.0-py3-none-any.whl

wavedl/utils/cross_validation.py CHANGED
@@ -36,7 +36,9 @@ from torch.utils.data import DataLoader
  class CVDataset(torch.utils.data.Dataset):
  """Simple in-memory dataset for cross-validation."""
 
- def __init__(self, X: np.ndarray, y: np.ndarray, expected_spatial_ndim: int = None):
+ def __init__(
+ self, X: np.ndarray, y: np.ndarray, expected_spatial_ndim: int | None = None
+ ):
  """
  Initialize CV dataset with explicit channel dimension handling.
 
@@ -51,6 +53,11 @@ class CVDataset(torch.utils.data.Dataset):
  - If X.ndim == expected_spatial_ndim + 1: Add channel dim (N, *spatial) -> (N, 1, *spatial)
  - If X.ndim == expected_spatial_ndim + 2: Already has channel (N, C, *spatial)
  - If expected_spatial_ndim is None: Use legacy ndim-based inference
+
+ Warning:
+ Legacy mode (expected_spatial_ndim=None) may misinterpret multichannel
+ 3D data as single-channel 4D data. Always pass expected_spatial_ndim
+ explicitly for 3D volumes with >1 channel.
  """
  if expected_spatial_ndim is not None:
  # Explicit mode: use expected_spatial_ndim to determine if channel exists
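The docstring rules map directly onto a small shape check. A standalone sketch of the explicit mode (illustrative only, not the package's actual `__init__` body):

```python
import numpy as np

def infer_channel_layout(X: np.ndarray, expected_spatial_ndim: int) -> np.ndarray:
    """Hypothetical illustration of the rule described in the docstring above."""
    if X.ndim == expected_spatial_ndim + 1:
        # (N, *spatial) -> add a singleton channel axis: (N, 1, *spatial)
        return X[:, np.newaxis, ...]
    if X.ndim == expected_spatial_ndim + 2:
        # Already (N, C, *spatial); leave untouched.
        return X
    raise ValueError(f"Unexpected ndim {X.ndim} for spatial_ndim={expected_spatial_ndim}")

# The ambiguity the Warning describes: a batch of 3-channel 3D volumes
# (N, 3, D, H, W) has ndim == 5, exactly what single-channel 4D data
# would produce, so legacy ndim-only inference cannot tell them apart.
vols = np.zeros((8, 3, 16, 16, 16))
print(infer_channel_layout(vols, expected_spatial_ndim=3).shape)  # (8, 3, 16, 16, 16)
```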
@@ -100,6 +107,7 @@ def train_fold(
  device: torch.device,
  epochs: int,
  patience: int,
+ grad_clip: float,
  scaler: StandardScaler,
  logger: logging.Logger,
  ) -> dict[str, Any]:
@@ -147,7 +155,8 @@ def train_fold(
  pred = model(x)
  loss = criterion(pred, y)
  loss.backward()
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ if grad_clip > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
  optimizer.step()
 
  # Per-batch LR scheduling (OneCycleLR)
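With this change the clip threshold is caller-controlled, and any non-positive `grad_clip` skips clipping entirely. A minimal sketch of the resulting step pattern, with illustrative names:

```python
import torch

def train_step(model, batch, criterion, optimizer, grad_clip: float = 1.0):
    # Sketch of the pattern above: clip only when grad_clip > 0,
    # so grad_clip=0 disables clipping. Names here are illustrative.
    x, y = batch
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    if grad_clip > 0:
        # Rescales gradients in place so their global L2 norm is <= grad_clip.
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    return loss.item()
```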
@@ -289,6 +298,7 @@ def run_cross_validation(
  output_dir: str = "./cv_results",
  workers: int = 4,
  seed: int = 2025,
+ grad_clip: float = 1.0,
  logger: logging.Logger | None = None,
  ) -> dict[str, Any]:
  """
@@ -330,8 +340,10 @@ def run_cross_validation(
  )
  logger = logging.getLogger("CV-Trainer")
 
- # Set seeds
- np.random.seed(seed)
+ # Set seeds for reproducibility
+ # Note: sklearn KFold uses random_state parameter directly, not global numpy RNG
+ rng = np.random.default_rng(seed) # Local RNG for any numpy operations
+ _ = rng # Silence unused variable warning (available for future use)
  torch.manual_seed(seed)
  if torch.cuda.is_available():
  torch.cuda.manual_seed_all(seed)
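The new comment fixes a subtle reproducibility trap: `np.random.seed()` never influenced sklearn's `KFold`, which is seeded through its own `random_state` argument. A sketch of a fully seeded setup consistent with the code above:

```python
import numpy as np
import torch
from sklearn.model_selection import KFold

seed = 2025

# Torch RNGs cover weight init, dropout, and DataLoader shuffling.
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# KFold is seeded through its own argument; the global numpy RNG is irrelevant.
kf = KFold(n_splits=5, shuffle=True, random_state=seed)

# A local Generator keeps any numpy randomness reproducible without mutating
# global state that other libraries might read.
rng = np.random.default_rng(seed)

X = np.zeros((10, 1))
train_idx, val_idx = next(iter(kf.split(X)))
print(val_idx)  # identical across runs thanks to random_state
```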
@@ -444,6 +456,7 @@ def run_cross_validation(
  device=device,
  epochs=epochs,
  patience=patience,
+ grad_clip=grad_clip,
  scaler=scaler,
  logger=logger,
  )
wavedl/utils/data.py CHANGED
@@ -798,6 +798,128 @@ def load_outputs_only(path: str, format: str = "auto") -> np.ndarray:
  return source.load_outputs_only(path)
 
 
+ def _load_npz_for_test(
+ path: str,
+ input_keys: list[str],
+ output_keys: list[str],
+ explicit_input_key: str | None,
+ explicit_output_key: str | None,
+ ) -> tuple[np.ndarray, np.ndarray | None]:
+ """Load NPZ file for test/inference with key validation."""
+ with np.load(path, allow_pickle=False) as probe:
+ keys = list(probe.keys())
+
+ inp_key = DataSource._find_key(keys, input_keys)
+ out_key = DataSource._find_key(keys, output_keys)
+
+ # Strict validation for explicit keys
+ if explicit_input_key is not None and explicit_input_key not in keys:
+ raise KeyError(
+ f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+ )
+ if inp_key is None:
+ raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+ if explicit_output_key is not None and explicit_output_key not in keys:
+ raise KeyError(
+ f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+ )
+
+ data = NPZSource._load_and_copy(path, [inp_key] + ([out_key] if out_key else []))
+ inp = data[inp_key]
+ if inp.dtype == object:
+ inp = np.array([x.toarray() if hasattr(x, "toarray") else x for x in inp])
+ outp = data[out_key] if out_key else None
+ return inp, outp
+
+
+ def _load_hdf5_for_test(
+ path: str,
+ input_keys: list[str],
+ output_keys: list[str],
+ explicit_input_key: str | None,
+ explicit_output_key: str | None,
+ ) -> tuple[np.ndarray, np.ndarray | None]:
+ """Load HDF5 file for test/inference with key validation and OOM guard."""
+ with h5py.File(path, "r") as f:
+ keys = list(f.keys())
+ inp_key = DataSource._find_key(keys, input_keys)
+ out_key = DataSource._find_key(keys, output_keys)
+
+ # Strict validation
+ if explicit_input_key is not None and explicit_input_key not in keys:
+ raise KeyError(
+ f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+ )
+ if inp_key is None:
+ raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+ if explicit_output_key is not None and explicit_output_key not in keys:
+ raise KeyError(
+ f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+ )
+
+ # OOM guard
+ n_samples = f[inp_key].shape[0]
+ if n_samples > 100000:
+ raise ValueError(
+ f"Dataset has {n_samples:,} samples. load_test_data() loads "
+ f"everything into RAM which may cause OOM. For large inference "
+ f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+ )
+
+ inp = f[inp_key][:]
+ outp = f[out_key][:] if out_key else None
+ return inp, outp
+
+
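The OOM guard refuses eager loads beyond 100,000 samples and points users at a memory-mapped DataLoader path. As a rough illustration of what batched, non-eager HDF5 access looks like with plain h5py (this sketch is independent of wavedl's `HDF5Source.load_mmap()`, whose signature is not shown in this diff):

```python
import h5py
import numpy as np

def iter_hdf5_batches(path: str, key: str, batch_size: int = 256):
    """Yield successive batches from an HDF5 dataset without loading it all."""
    with h5py.File(path, "r") as f:
        ds = f[key]
        for start in range(0, ds.shape[0], batch_size):
            # h5py reads only the requested slice from disk.
            yield np.asarray(ds[start : start + batch_size])

# Hypothetical usage against a file with an "input_test" dataset:
# for batch in iter_hdf5_batches("test.h5", "input_test"):
#     preds = model(torch.from_numpy(batch))
```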
+ def _load_mat_for_test(
+ path: str,
+ input_keys: list[str],
+ output_keys: list[str],
+ explicit_input_key: str | None,
+ explicit_output_key: str | None,
+ ) -> tuple[np.ndarray, np.ndarray | None]:
+ """Load MAT v7.3 file for test/inference with sparse support."""
+ mat_source = MATSource()
+ with h5py.File(path, "r") as f:
+ keys = list(f.keys())
+ inp_key = DataSource._find_key(keys, input_keys)
+ out_key = DataSource._find_key(keys, output_keys)
+
+ # Strict validation
+ if explicit_input_key is not None and explicit_input_key not in keys:
+ raise KeyError(
+ f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+ )
+ if inp_key is None:
+ raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+ if explicit_output_key is not None and explicit_output_key not in keys:
+ raise KeyError(
+ f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+ )
+
+ # OOM guard
+ n_samples = f[inp_key].shape[-1] # MAT is transposed
+ if n_samples > 100000:
+ raise ValueError(
+ f"Dataset has {n_samples:,} samples. load_test_data() loads "
+ f"everything into RAM which may cause OOM. For large inference "
+ f"sets, use a DataLoader with MATSource.load_mmap() instead."
+ )
+
+ inp = mat_source._load_dataset(f, inp_key)
+ outp = None
+ if out_key:
+ outp = mat_source._load_dataset(f, out_key)
+ # Handle MATLAB transpose
+ num_samples = inp.shape[0]
+ if outp.ndim == 2:
+ if (outp.shape[0] == 1 and outp.shape[1] == num_samples) or (
+ outp.shape[1] == 1 and outp.shape[0] != num_samples
+ ):
+ outp = outp.T
+ return inp, outp
+
+
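The transpose heuristic disambiguates MATLAB's column-major exports. A quick numpy check of the two cases the original comments spell out:

```python
import numpy as np

# Case 1: (1, N) with N == num_samples -> transpose to (N, 1)
outp = np.zeros((1, 256))          # 256 samples, one target each
print(outp.T.shape)                # (256, 1)

# Case 2: (T, 1) with T != num_samples -> transpose to (1, T)
outp = np.zeros((5, 1))            # one sample with 5 targets
print(outp.T.shape)                # (1, 5)

# A genuine (N, T) target matrix, e.g. (256, 5), matches neither
# condition and passes through untouched.
```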
  def load_test_data(
  path: str,
  format: str = "auto",
@@ -865,112 +987,17 @@ def load_test_data(
  # We detect keys first to ensure input_test/output_test are used when present
  try:
  if format == "npz":
- with np.load(path, allow_pickle=False) as probe:
- keys = list(probe.keys())
- inp_key = DataSource._find_key(keys, custom_input_keys)
- out_key = DataSource._find_key(keys, custom_output_keys)
- # Strict validation: if user explicitly specified input_key, it must exist exactly
- if input_key is not None and input_key not in keys:
- raise KeyError(
- f"Explicit --input_key '{input_key}' not found. "
- f"Available keys: {keys}"
- )
- if inp_key is None:
- raise KeyError(
- f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
- )
- # Strict validation: if user explicitly specified output_key, it must exist exactly
- if output_key is not None and output_key not in keys:
- raise KeyError(
- f"Explicit --output_key '{output_key}' not found. "
- f"Available keys: {keys}"
- )
- data = NPZSource._load_and_copy(
- path, [inp_key] + ([out_key] if out_key else [])
+ inp, outp = _load_npz_for_test(
+ path, custom_input_keys, custom_output_keys, input_key, output_key
  )
- inp = data[inp_key]
- if inp.dtype == object:
- inp = np.array(
- [x.toarray() if hasattr(x, "toarray") else x for x in inp]
- )
- outp = data[out_key] if out_key else None
  elif format == "hdf5":
- with h5py.File(path, "r") as f:
- keys = list(f.keys())
- inp_key = DataSource._find_key(keys, custom_input_keys)
- out_key = DataSource._find_key(keys, custom_output_keys)
- # Strict validation: if user explicitly specified input_key, it must exist exactly
- if input_key is not None and input_key not in keys:
- raise KeyError(
- f"Explicit --input_key '{input_key}' not found. "
- f"Available keys: {keys}"
- )
- if inp_key is None:
- raise KeyError(
- f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
- )
- # Strict validation: if user explicitly specified output_key, it must exist exactly
- if output_key is not None and output_key not in keys:
- raise KeyError(
- f"Explicit --output_key '{output_key}' not found. "
- f"Available keys: {keys}"
- )
- # OOM guard: warn if dataset is very large
- n_samples = f[inp_key].shape[0]
- if n_samples > 100000:
- raise ValueError(
- f"Dataset has {n_samples:,} samples. load_test_data() loads "
- f"everything into RAM which may cause OOM. For large inference "
- f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
- )
- inp = f[inp_key][:]
- outp = f[out_key][:] if out_key else None
+ inp, outp = _load_hdf5_for_test(
+ path, custom_input_keys, custom_output_keys, input_key, output_key
+ )
  elif format == "mat":
- mat_source = MATSource()
- with h5py.File(path, "r") as f:
- keys = list(f.keys())
- inp_key = DataSource._find_key(keys, custom_input_keys)
- out_key = DataSource._find_key(keys, custom_output_keys)
- # Strict validation: if user explicitly specified input_key, it must exist exactly
- if input_key is not None and input_key not in keys:
- raise KeyError(
- f"Explicit --input_key '{input_key}' not found. "
- f"Available keys: {keys}"
- )
- if inp_key is None:
- raise KeyError(
- f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
- )
- # Strict validation: if user explicitly specified output_key, it must exist exactly
- if output_key is not None and output_key not in keys:
- raise KeyError(
- f"Explicit --output_key '{output_key}' not found. "
- f"Available keys: {keys}"
- )
- # OOM guard: warn if dataset is very large (MAT is transposed)
- n_samples = f[inp_key].shape[-1]
- if n_samples > 100000:
- raise ValueError(
- f"Dataset has {n_samples:,} samples. load_test_data() loads "
- f"everything into RAM which may cause OOM. For large inference "
- f"sets, use a DataLoader with MATSource.load_mmap() instead."
- )
- inp = mat_source._load_dataset(f, inp_key)
- if out_key:
- outp = mat_source._load_dataset(f, out_key)
- # Handle transposed outputs from MATLAB
- # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
- # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
- num_samples = inp.shape[0]
- if outp.ndim == 2:
- if outp.shape[0] == 1 and outp.shape[1] == num_samples:
- # 1D vector: (1, N) → (N, 1)
- outp = outp.T
- elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
- # Single sample with multiple targets: (T, 1) → (1, T)
- outp = outp.T
- else:
- outp = None
+ inp, outp = _load_mat_for_test(
+ path, custom_input_keys, custom_output_keys, input_key, output_key
+ )
  else:
  # Fallback to default source.load() for unknown formats
  inp, outp = source.load(path)
@@ -993,80 +1020,19 @@ def load_test_data(
  ) from e
 
  # Legitimate fallback: no explicit keys, outputs just not present
+ # Re-call helpers with None for explicit keys (validation already done above)
  if format == "npz":
- # First pass to find keys
- with np.load(path, allow_pickle=False) as probe:
- keys = list(probe.keys())
- inp_key = DataSource._find_key(keys, custom_input_keys)
- if inp_key is None:
- raise KeyError(
- f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
- )
- out_key = DataSource._find_key(keys, custom_output_keys)
- keys_to_probe = [inp_key] + ([out_key] if out_key else [])
- data = NPZSource._load_and_copy(path, keys_to_probe)
- inp = data[inp_key]
- if inp.dtype == object:
- inp = np.array(
- [x.toarray() if hasattr(x, "toarray") else x for x in inp]
- )
- outp = data[out_key] if out_key else None
+ inp, outp = _load_npz_for_test(
+ path, custom_input_keys, custom_output_keys, None, None
+ )
  elif format == "hdf5":
- # HDF5: input-only loading for inference
- with h5py.File(path, "r") as f:
- keys = list(f.keys())
- inp_key = DataSource._find_key(keys, custom_input_keys)
- if inp_key is None:
- raise KeyError(
- f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
- )
- # Check size - load_test_data is eager, large files should use DataLoader
- n_samples = f[inp_key].shape[0]
- if n_samples > 100000:
- raise ValueError(
- f"Dataset has {n_samples:,} samples. load_test_data() loads "
- f"everything into RAM which may cause OOM. For large inference "
- f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
- )
- inp = f[inp_key][:]
- out_key = DataSource._find_key(keys, custom_output_keys)
- outp = f[out_key][:] if out_key else None
+ inp, outp = _load_hdf5_for_test(
+ path, custom_input_keys, custom_output_keys, None, None
+ )
  elif format == "mat":
- # MAT v7.3: input-only loading with proper sparse handling
- mat_source = MATSource()
- with h5py.File(path, "r") as f:
- keys = list(f.keys())
- inp_key = DataSource._find_key(keys, custom_input_keys)
- if inp_key is None:
- raise KeyError(
- f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
- )
- # Check size - load_test_data is eager, large files should use DataLoader
- n_samples = f[inp_key].shape[-1] # MAT is transposed
- if n_samples > 100000:
- raise ValueError(
- f"Dataset has {n_samples:,} samples. load_test_data() loads "
- f"everything into RAM which may cause OOM. For large inference "
- f"sets, use a DataLoader with MATSource.load_mmap() instead."
- )
- # Use _load_dataset for sparse support and proper transpose
- inp = mat_source._load_dataset(f, inp_key)
- out_key = DataSource._find_key(keys, custom_output_keys)
- if out_key:
- outp = mat_source._load_dataset(f, out_key)
- # Handle transposed outputs from MATLAB
- # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
- # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
- num_samples = inp.shape[0]
- if outp.ndim == 2:
- if outp.shape[0] == 1 and outp.shape[1] == num_samples:
- # 1D vector: (1, N) → (N, 1)
- outp = outp.T
- elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
- # Single sample with multiple targets: (T, 1) → (1, T)
- outp = outp.T
- else:
- outp = None
+ inp, outp = _load_mat_for_test(
+ path, custom_input_keys, custom_output_keys, None, None
+ )
  else:
  raise
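After the refactor, all three formats funnel through the shared helpers, and both the strict-validation path and the outputs-absent fallback reuse the same code. A hedged usage sketch of the public entry point, assuming `load_test_data` is importable from `wavedl.utils.data` as the file path suggests and that it returns the `(inputs, outputs)` pair the helpers produce (the full signature may include more parameters than this diff shows):

```python
from wavedl.utils.data import load_test_data

# Format is inferred from the extension when "auto"; outputs may be None
# for inference-only files that contain no ground-truth key.
X_test, y_test = load_test_data("test.npz", format="auto")
print(X_test.shape)
if y_test is None:
    print("No ground truth present; running inference only.")
```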
 
wavedl/utils/metrics.py CHANGED
@@ -106,8 +106,20 @@ def configure_matplotlib_style():
  )
 
 
- # Apply style on import
- configure_matplotlib_style()
+ # Lazy style initialization flag
+ _style_configured = False
+
+
+ def _ensure_style_configured():
+ """Apply matplotlib style on first use (lazy initialization).
+
+ This avoids modifying global plt.rcParams at import time, which could
+ unexpectedly override user configurations.
+ """
+ global _style_configured
+ if not _style_configured:
+ configure_matplotlib_style()
+ _style_configured = True
 
 
  # ==============================================================================
@@ -115,11 +127,14 @@ configure_matplotlib_style()
  # ==============================================================================
  class MetricTracker:
  """
- Tracks running averages of metrics with thread-safe accumulation.
+ Tracks running averages of metrics during training.
 
  Useful for tracking loss, accuracy, or any scalar metric across batches.
  Handles division-by-zero safely by returning 0.0 when count is zero.
 
+ Note:
+ Not thread-safe. Intended for single-threaded use within training loops.
+
  Attributes:
  val: Most recent value
  avg: Running average
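For readers unfamiliar with the pattern, a minimal tracker with the documented attributes might look like the following. The real class's update method is not shown in this diff, so the `update(val, n)` signature here is an assumption based on common practice:

```python
class RunningAverage:
    """Minimal sketch of a running-average tracker (not wavedl's MetricTracker)."""

    def __init__(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1):
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self) -> float:
        # Division-by-zero guard mirroring the documented behavior.
        return self.sum / self.count if self.count else 0.0

loss_meter = RunningAverage()
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    loss_meter.update(batch_loss, batch_size)
print(round(loss_meter.avg, 4))  # 0.74, weighted over 80 samples
```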
@@ -341,6 +356,7 @@ def plot_scientific_scatter(
  Returns:
  Matplotlib Figure object (can be saved or logged to WandB)
  """
+ _ensure_style_configured()
  y_true, y_pred, param_names, num_params = _prepare_plot_data(
  y_true, y_pred, param_names, max_samples
  )
@@ -415,6 +431,7 @@ def plot_error_histogram(
  Returns:
  Matplotlib Figure object
  """
+ _ensure_style_configured()
  y_true, y_pred, param_names, num_params = _prepare_plot_data(
  y_true, y_pred, param_names
  )
@@ -485,6 +502,7 @@ def plot_residuals(
  Returns:
  Matplotlib Figure object
  """
+ _ensure_style_configured()
  y_true, y_pred, param_names, num_params = _prepare_plot_data(
  y_true, y_pred, param_names, max_samples
  )
@@ -552,6 +570,7 @@ def create_training_curves(
  Returns:
  Matplotlib Figure object
  """
+ _ensure_style_configured()
  epochs = [h.get("epoch", i + 1) for i, h in enumerate(history)]
 
  fig, ax1 = plt.subplots(figsize=(FIGURE_WIDTH_INCH * 0.7, FIGURE_WIDTH_INCH * 0.4))
@@ -605,8 +624,8 @@
  if not valid_data:
  return
  vmin, vmax = min(valid_data), max(valid_data)
- # Get decade range that covers data (ceil for min to avoid going too low)
- log_min = int(np.ceil(np.log10(vmin)))
+ # Get decade range that covers data (floor for min to include lowest data)
+ log_min = int(np.floor(np.log10(vmin)))
  log_max = int(np.ceil(np.log10(vmax)))
  # Generate ticks at each power of 10
  ticks = [10.0**i for i in range(log_min, log_max + 1)]
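The ceil-to-floor switch matters for data that sits between decades. A quick check with plausible loss values:

```python
import numpy as np

vmin, vmax = 0.003, 2.5
ceil_min = int(np.ceil(np.log10(vmin)))    # -2: first tick 1e-2, above vmin
floor_min = int(np.floor(np.log10(vmin)))  # -3: first tick 1e-3, covers vmin
log_max = int(np.ceil(np.log10(vmax)))     # 1

print([10.0**i for i in range(ceil_min, log_max + 1)])
# [0.01, 0.1, 1.0, 10.0]  -> values near 0.003 fall below the lowest tick
print([10.0**i for i in range(floor_min, log_max + 1)])
# [0.001, 0.01, 0.1, 1.0, 10.0]  -> the lowest data is inside the tick range
```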
@@ -691,6 +710,7 @@ def plot_bland_altman(
  Returns:
  Matplotlib Figure object
  """
+ _ensure_style_configured()
  y_true, y_pred, param_names, num_params = _prepare_plot_data(
  y_true, y_pred, param_names, max_samples
  )
@@ -764,6 +784,7 @@ def plot_qq(
  Returns:
  Matplotlib Figure object
  """
+ _ensure_style_configured()
  from scipy import stats
 
  num_params = y_true.shape[1] if y_true.ndim > 1 else 1
wavedl/utils/schedulers.py CHANGED
@@ -76,7 +76,7 @@ def get_scheduler(
  # Step/MultiStep parameters
  step_size: int = 30,
  milestones: list[int] | None = None,
- gamma: float = 0.1,
+ gamma: float = 0.99,
  # Linear warmup parameters
  warmup_epochs: int = 5,
  start_factor: float = 0.1,
@@ -100,7 +100,7 @@ def get_scheduler(
  pct_start: Percentage of cycle spent increasing LR (OneCycleLR)
  step_size: Period for StepLR
  milestones: Epochs to decay LR for MultiStepLR
- gamma: Decay factor for step/multistep/exponential
+ gamma: Decay factor (0.1 for step/multistep, 0.99 for exponential)
  warmup_epochs: Number of warmup epochs for linear_warmup
  start_factor: Starting LR factor for warmup (LR * start_factor)
  **kwargs: Additional arguments passed to scheduler
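The default change is easy to motivate: `ExponentialLR` multiplies the LR by gamma every epoch, so gamma=0.1 collapses the LR almost immediately, while 0.99 gives a gentle decay. A quick comparison with PyTorch:

```python
import torch

model = torch.nn.Linear(4, 1)
for gamma in (0.1, 0.99):
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)
    sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=gamma)
    for _ in range(10):
        opt.step()
        sched.step()
    print(f"gamma={gamma}: LR after 10 epochs = {sched.get_last_lr()[0]:.2e}")
# gamma=0.1:  1.00e-13  (effectively zero after ten epochs)
# gamma=0.99: 9.04e-04  (still a useful learning rate)
```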
wavedl-1.7.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: wavedl
- Version: 1.6.3
+ Version: 1.7.0
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
  Author: Ductho Le
  License: MIT
@@ -907,18 +907,24 @@ Automatically find the best training configuration using [Optuna](https://optuna
  **Run HPO:**
 
  ```bash
- # Basic HPO (auto-detects GPUs for parallel trials)
- wavedl-hpo --data_path train.npz --models cnn --n_trials 100
+ # Basic HPO (50 trials, auto-detects GPUs)
+ wavedl-hpo --data_path train.npz --n_trials 50
 
- # Search multiple models
- wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
+ # Quick search (minimal search space, fastest)
+ wavedl-hpo --data_path train.npz --n_trials 30 --quick
 
- # Quick mode (fewer parameters, faster)
- wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
+ # Medium search (balanced between quick and full)
+ wavedl-hpo --data_path train.npz --n_trials 50 --medium
+
+ # Full search with specific models
+ wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0
+
+ # In-process mode (enables pruning, faster, single-GPU)
+ wavedl-hpo --data_path train.npz --n_trials 50 --inprocess
  ```
 
  > [!TIP]
- > **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+ > **GPU Detection**: HPO auto-detects GPUs and runs one trial per GPU in parallel. Use `--inprocess` for single-GPU with pruning support (early stopping of bad trials).
 
  **Train with best parameters:**
 
@@ -940,10 +946,23 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
  | Learning rate | 1e-5 → 1e-2 | (always searched) |
  | Batch size | 16, 32, 64, 128 | (always searched) |
 
- **Quick Mode** (`--quick`):
- - Uses minimal defaults: cnn + adamw + plateau + mse
- - Faster for testing your setup before running full search
- - You can still override any option with the flags above
+ **Search Presets:**
+
+ | Mode | Models | Optimizers | Schedulers | Use Case |
+ |------|--------|------------|------------|----------|
+ | Full (default) | cnn, resnet18, resnet34 | all 6 | all 8 | Production search |
+ | `--medium` | cnn, resnet18 | adamw, adam, sgd | plateau, cosine, onecycle | Balanced exploration |
+ | `--quick` | cnn | adamw | plateau | Fast validation |
+
+ **Execution Modes:**
+
+ | Mode | Flag | Pruning | GPU Memory | Best For |
+ |------|------|---------|------------|----------|
+ | Subprocess (default) | — | ❌ No | Isolated | Multi-GPU parallel trials |
+ | In-process | `--inprocess` | ✅ Yes | Shared | Single-GPU with early stopping |
+
+ > [!TIP]
+ > Use `--inprocess` when running single-GPU trials. It enables MedianPruner to stop unpromising trials early, reducing total search time.
 
  ---
 
@@ -954,7 +973,9 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
  | `--data_path` | (required) | Training data file |
  | `--models` | 3 defaults | Models to search (specify any number) |
  | `--n_trials` | `50` | Number of trials to run |
- | `--quick` | `False` | Use minimal defaults (faster) |
+ | `--quick` | `False` | Quick mode: minimal search space |
+ | `--medium` | `False` | Medium mode: balanced search space |
+ | `--inprocess` | `False` | Run trials in-process (enables pruning) |
  | `--optimizers` | all 6 | Optimizers to search |
  | `--schedulers` | all 8 | Schedulers to search |
  | `--losses` | all 6 | Losses to search |
@@ -963,7 +984,7 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
  | `--output` | `hpo_results.json` | Output file |
 
 
- > See [Available Models](#available-models) for all 38 architectures you can search.
+ > See [Available Models](#available-models) for all 69 architectures you can search.
 
  </details>
 
wavedl-1.7.0.dist-info/RECORD ADDED
@@ -0,0 +1,46 @@
+ wavedl/__init__.py,sha256=Ol1M5mok2rnnUKDPBZLQDBDwKn7_LV9iTxds5obDeJk,1177
+ wavedl/hpo.py,sha256=TyEWubL-adQJyRmSv1M1SnJvH7_vTlXJATI1ElxXVUU,18991
+ wavedl/launcher.py,sha256=_CFlgpKgHrtZebl1yQbJZJEcob06Y9-fqnRYzwW7UJQ,11776
+ wavedl/test.py,sha256=_6i6F1KkO24cC0RtvxcwAGxDUbU6VV12efDWFGojkeE,38466
+ wavedl/train.py,sha256=U0-hHo4wba1dhZ2jjaCFURiS6aGHuPDGz0zPwPm2Kc0,72722
+ wavedl/models/__init__.py,sha256=hyR__h_D8PsUQCBSM5tj94yYK00uG8ABjEmj_RR8SGE,5719
+ wavedl/models/_pretrained_utils.py,sha256=br--ZrgqndYCO_iAeQOvUDg6ZxzGDyLZFeWu1Qj_DrI,14756
+ wavedl/models/_template.py,sha256=nEixVS8e82Tud08Uk8jkXtriGhk_WFbqSaGDq_Mj4ak,4684
+ wavedl/models/base.py,sha256=bDoHYFli-aR8amcFYXbF98QYaKSCEwZWpvOhN21ODro,9075
+ wavedl/models/caformer.py,sha256=ufPM-HzQ-qUZcXgnOulurY6jBUlMUzokC01whtPeVMg,7922
+ wavedl/models/cnn.py,sha256=dOmCrHGXd8Md8ixbJ_-An9t80tm36sVY84je2EDmnZA,8256
+ wavedl/models/convnext.py,sha256=GoLId2HClsOksuL3XLscEIytrmOBPGhO6UhGn04yDp4,13354
+ wavedl/models/convnext_v2.py,sha256=jPPXTZbQQ8zE9yGVWTNUaI5g1d0xIxBjrLuUHUKc5mM,14349
+ wavedl/models/densenet.py,sha256=V_caGd0wsG_Q3Q38I4MEgYmU0v4j8mDyvv7Rn3Bk7Ac,12667
+ wavedl/models/efficientnet.py,sha256=HWfhqSX57lC5Xug5TrQ3r-uFqkksoIKjmQ5Zr5njkEA,8264
+ wavedl/models/efficientnetv2.py,sha256=hVSnVId8T1rjqaKlckLqWFwvo2J-qASX7o9lMbXbP-s,10947
+ wavedl/models/efficientvit.py,sha256=KqFoZq9YHBMnTue6aMdPKgBOMczeBPryY_F6ip0hoEI,11630
+ wavedl/models/fastvit.py,sha256=S0SF0iC-9ZJrP-9YUTLPhMJMV-W9r2--V3hVAmSSVKI,7083
+ wavedl/models/mamba.py,sha256=2mqBxUKCLJNkRc87QzpsOj6hKzrEh5tchGSyYZCSUcQ,20031
+ wavedl/models/maxvit.py,sha256=I6TFGrLRcyMU-nU7u5VhOaXZWWdwmNJwHsMqbJh_g_o,7548
+ wavedl/models/mobilenetv3.py,sha256=LZxCg599kGP6-XI_l3PpT8jzh4oTAdWH3Y7GH097o28,10242
+ wavedl/models/registry.py,sha256=InYAXX2xbRvsFDFnYUPCptJh0F9lHlFPN77A9kqHRT0,2980
+ wavedl/models/regnet.py,sha256=6Yjo2wZzdjK8VpOMagbCrHqmsfRmGkuiURmc-MesYvA,13777
+ wavedl/models/resnet.py,sha256=3i4zfE15qF4cd0qbTKX-Wdy2Kd0f4mLcdd316FAcVCo,16720
+ wavedl/models/resnet3d.py,sha256=NS6UBmvITO3NbdBNfe39bxViFqrIgeSXBnDgYi8QsC8,9247
+ wavedl/models/swin.py,sha256=39Gwn5hNEw3-tndc8qFFzV-VZ7pJMMKey2oZONAZ8MU,14980
+ wavedl/models/tcn.py,sha256=XzojpuMFG4lu_0oQHbQnkLAb7AnW-D7_6KoBlQDPLnQ,12367
+ wavedl/models/unet.py,sha256=oi7eBONSe0ALpJKsYda3jRGwu-LuSiFgNdURebnGGt0,7712
+ wavedl/models/unireplknet.py,sha256=sbiYcc2NeB0-_VAmeoe9Vi5hQzhYz03knG7o2Qk0WYE,14634
+ wavedl/models/vit.py,sha256=o-zWT2GBCTs9vD3jUFwlcwxK53XqEn_x4iPaRuEQe10,15219
+ wavedl/utils/__init__.py,sha256=CYqD3Bcwcub2mSrW05x8wvd2n1Co_3N9ajyKPyBswjo,4887
+ wavedl/utils/config.py,sha256=yAKuuhM-oxvHFXomkkek4IGihsVO5yZxc4b2noQ1amE,10523
+ wavedl/utils/constraints.py,sha256=V9Gyi8-uIMbLUWb2cOaHZD0SliWLxVrHZHFyo4HWK7g,18031
+ wavedl/utils/cross_validation.py,sha256=evp8EsHGJcxMHpfRdFSParDltUTyMhKQxUkcn5-osI4,18556
+ wavedl/utils/data.py,sha256=f-LLIiiv74iiKrR8TQ9oeKODF29_jzeUUp4iMBuj_H4,60875
+ wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
+ wavedl/utils/losses.py,sha256=KWpU5S5noFzp3bLbcH9RNpkFPajy6fyTIh5cNjI-BYA,7038
+ wavedl/utils/metrics.py,sha256=tLzKG2zINyXit-KvYZSJg-1nG6rST54GH6k4ALonToU,40935
+ wavedl/utils/optimizers.py,sha256=ZoETDSOK1fWUT2dx69PyYebeM8Vcqf9zOIKUERWk5HY,6107
+ wavedl/utils/schedulers.py,sha256=_aFTQ8kuvdZIxOoXPHQRu_N9XBuTVSjU6dmBbzH430o,7425
+ wavedl-1.7.0.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+ wavedl-1.7.0.dist-info/METADATA,sha256=f5UMARudJtBdt5Pu3E18Rfk2sD4h77A2mt8Mh8_cJk4,48527
+ wavedl-1.7.0.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+ wavedl-1.7.0.dist-info/entry_points.txt,sha256=NuAvdiG93EYYpqv-_1wf6PN0WqBfABanDKalNKe2GOs,148
+ wavedl-1.7.0.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+ wavedl-1.7.0.dist-info/RECORD,,