wavedl 1.6.2__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,7 +36,9 @@ from torch.utils.data import DataLoader
 class CVDataset(torch.utils.data.Dataset):
     """Simple in-memory dataset for cross-validation."""

-    def __init__(self, X: np.ndarray, y: np.ndarray, expected_spatial_ndim: int = None):
+    def __init__(
+        self, X: np.ndarray, y: np.ndarray, expected_spatial_ndim: int | None = None
+    ):
         """
         Initialize CV dataset with explicit channel dimension handling.

@@ -51,6 +53,11 @@ class CVDataset(torch.utils.data.Dataset):
         - If X.ndim == expected_spatial_ndim + 1: Add channel dim (N, *spatial) -> (N, 1, *spatial)
         - If X.ndim == expected_spatial_ndim + 2: Already has channel (N, C, *spatial)
         - If expected_spatial_ndim is None: Use legacy ndim-based inference
+
+        Warning:
+            Legacy mode (expected_spatial_ndim=None) may misinterpret multichannel
+            3D data as single-channel 4D data. Always pass expected_spatial_ndim
+            explicitly for 3D volumes with >1 channel.
         """
         if expected_spatial_ndim is not None:
             # Explicit mode: use expected_spatial_ndim to determine if channel exists
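The channel-dimension rule documented above can be sketched in isolation. The helper below is illustrative only (it is not CVDataset's implementation); it shows why a 2-channel 3D volume is unambiguous once expected_spatial_ndim is given, whereas an ndim-only guess could bolt on a spurious channel axis.

    import numpy as np

    def to_channels_first(X: np.ndarray, expected_spatial_ndim: int) -> np.ndarray:
        """Return X shaped (N, C, *spatial), adding a channel axis only when it is missing."""
        if X.ndim == expected_spatial_ndim + 1:    # (N, *spatial): no channel axis yet
            return X[:, None, ...]
        if X.ndim == expected_spatial_ndim + 2:    # (N, C, *spatial): channel already present
            return X
        raise ValueError(f"ndim={X.ndim} is inconsistent with spatial rank {expected_spatial_ndim}")

    X = np.zeros((8, 2, 16, 16, 16))               # 2-channel 3D volumes: (N, C, D, H, W)
    print(to_channels_first(X, expected_spatial_ndim=3).shape)   # (8, 2, 16, 16, 16), unchanged
    # An ndim-only guess could instead treat the same array as channel-less 4D data and
    # produce (8, 1, 2, 16, 16, 16) -- the exact ambiguity the new Warning describes.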
@@ -100,6 +107,7 @@ def train_fold(
     device: torch.device,
     epochs: int,
     patience: int,
+    grad_clip: float,
     scaler: StandardScaler,
     logger: logging.Logger,
 ) -> dict[str, Any]:
@@ -147,7 +155,8 @@ def train_fold(
             pred = model(x)
             loss = criterion(pred, y)
             loss.backward()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            if grad_clip > 0:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
             optimizer.step()

             # Per-batch LR scheduling (OneCycleLR)
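The conditional clipping introduced here follows a common PyTorch pattern; a minimal sketch of a generic training step (not wavedl's train_fold) makes the semantics of a non-positive grad_clip explicit.

    import torch

    def training_step(model, batch, criterion, optimizer, grad_clip: float) -> float:
        x, y = batch
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        if grad_clip > 0:  # 0 (or a negative value) now means "no clipping" rather than "clip to 0"
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        return loss.item()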
@@ -289,6 +298,7 @@ def run_cross_validation(
     output_dir: str = "./cv_results",
     workers: int = 4,
     seed: int = 2025,
+    grad_clip: float = 1.0,
     logger: logging.Logger | None = None,
 ) -> dict[str, Any]:
     """
@@ -330,8 +340,10 @@ def run_cross_validation(
     )
     logger = logging.getLogger("CV-Trainer")

-    # Set seeds
-    np.random.seed(seed)
+    # Set seeds for reproducibility
+    # Note: sklearn KFold uses random_state parameter directly, not global numpy RNG
+    rng = np.random.default_rng(seed)  # Local RNG for any numpy operations
+    _ = rng  # Silence unused variable warning (available for future use)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
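A minimal sketch of the seeding scheme the new comments describe, using only numpy, torch, and scikit-learn (the fold splitter is assumed to be sklearn's KFold, as the comment suggests):

    import numpy as np
    import torch
    from sklearn.model_selection import KFold

    seed = 2025
    rng = np.random.default_rng(seed)        # local Generator; global numpy state stays untouched
    torch.manual_seed(seed)                  # seeds torch CPU (and, separately, CUDA) RNGs
    kfold = KFold(n_splits=5, shuffle=True, random_state=seed)  # splits are seeded explicitly
    noise = rng.normal(size=3)               # any numpy randomness goes through the local rng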
@@ -444,6 +456,7 @@
             device=device,
             epochs=epochs,
             patience=patience,
+            grad_clip=grad_clip,
             scaler=scaler,
             logger=logger,
         )
wavedl/utils/data.py CHANGED
@@ -798,6 +798,128 @@ def load_outputs_only(path: str, format: str = "auto") -> np.ndarray:
     return source.load_outputs_only(path)


+def _load_npz_for_test(
+    path: str,
+    input_keys: list[str],
+    output_keys: list[str],
+    explicit_input_key: str | None,
+    explicit_output_key: str | None,
+) -> tuple[np.ndarray, np.ndarray | None]:
+    """Load NPZ file for test/inference with key validation."""
+    with np.load(path, allow_pickle=False) as probe:
+        keys = list(probe.keys())
+
+    inp_key = DataSource._find_key(keys, input_keys)
+    out_key = DataSource._find_key(keys, output_keys)
+
+    # Strict validation for explicit keys
+    if explicit_input_key is not None and explicit_input_key not in keys:
+        raise KeyError(
+            f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+        )
+    if inp_key is None:
+        raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+    if explicit_output_key is not None and explicit_output_key not in keys:
+        raise KeyError(
+            f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+        )
+
+    data = NPZSource._load_and_copy(path, [inp_key] + ([out_key] if out_key else []))
+    inp = data[inp_key]
+    if inp.dtype == object:
+        inp = np.array([x.toarray() if hasattr(x, "toarray") else x for x in inp])
+    outp = data[out_key] if out_key else None
+    return inp, outp
+
+
+def _load_hdf5_for_test(
+    path: str,
+    input_keys: list[str],
+    output_keys: list[str],
+    explicit_input_key: str | None,
+    explicit_output_key: str | None,
+) -> tuple[np.ndarray, np.ndarray | None]:
+    """Load HDF5 file for test/inference with key validation and OOM guard."""
+    with h5py.File(path, "r") as f:
+        keys = list(f.keys())
+        inp_key = DataSource._find_key(keys, input_keys)
+        out_key = DataSource._find_key(keys, output_keys)
+
+        # Strict validation
+        if explicit_input_key is not None and explicit_input_key not in keys:
+            raise KeyError(
+                f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+            )
+        if inp_key is None:
+            raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+        if explicit_output_key is not None and explicit_output_key not in keys:
+            raise KeyError(
+                f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+            )
+
+        # OOM guard
+        n_samples = f[inp_key].shape[0]
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+            )
+
+        inp = f[inp_key][:]
+        outp = f[out_key][:] if out_key else None
+    return inp, outp
+
+
+def _load_mat_for_test(
+    path: str,
+    input_keys: list[str],
+    output_keys: list[str],
+    explicit_input_key: str | None,
+    explicit_output_key: str | None,
+) -> tuple[np.ndarray, np.ndarray | None]:
+    """Load MAT v7.3 file for test/inference with sparse support."""
+    mat_source = MATSource()
+    with h5py.File(path, "r") as f:
+        keys = list(f.keys())
+        inp_key = DataSource._find_key(keys, input_keys)
+        out_key = DataSource._find_key(keys, output_keys)
+
+        # Strict validation
+        if explicit_input_key is not None and explicit_input_key not in keys:
+            raise KeyError(
+                f"Explicit --input_key '{explicit_input_key}' not found. Available: {keys}"
+            )
+        if inp_key is None:
+            raise KeyError(f"Input key not found. Tried: {input_keys}. Found: {keys}")
+        if explicit_output_key is not None and explicit_output_key not in keys:
+            raise KeyError(
+                f"Explicit --output_key '{explicit_output_key}' not found. Available: {keys}"
+            )
+
+        # OOM guard
+        n_samples = f[inp_key].shape[-1]  # MAT is transposed
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with MATSource.load_mmap() instead."
+            )
+
+        inp = mat_source._load_dataset(f, inp_key)
+        outp = None
+        if out_key:
+            outp = mat_source._load_dataset(f, out_key)
+            # Handle MATLAB transpose
+            num_samples = inp.shape[0]
+            if outp.ndim == 2:
+                if (outp.shape[0] == 1 and outp.shape[1] == num_samples) or (
+                    outp.shape[1] == 1 and outp.shape[0] != num_samples
+                ):
+                    outp = outp.T
+    return inp, outp
+
+
 def load_test_data(
     path: str,
     format: str = "auto",
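All three helpers share the same key-resolution contract: candidate keys are searched in order, but an explicitly requested key must match exactly or the call fails loudly. The standalone sketch below mimics that contract with a stand-in _find_key; it is illustrative, not the package's DataSource._find_key.

    def _find_key(available: list[str], candidates: list[str]) -> str | None:
        # Return the first candidate present in the file, or None.
        return next((k for k in candidates if k in available), None)

    def resolve_input_key(available: list[str], candidates: list[str], explicit: str | None) -> str:
        if explicit is not None and explicit not in available:
            raise KeyError(f"Explicit --input_key '{explicit}' not found. Available: {available}")
        key = _find_key(available, candidates)
        if key is None:
            raise KeyError(f"Input key not found. Tried: {candidates}. Found: {available}")
        return key

    print(resolve_input_key(["input_test", "output_test"], ["input_test", "X"], None))  # input_test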
@@ -865,112 +987,17 @@ def load_test_data(
     # We detect keys first to ensure input_test/output_test are used when present
     try:
         if format == "npz":
-            with np.load(path, allow_pickle=False) as probe:
-                keys = list(probe.keys())
-            inp_key = DataSource._find_key(keys, custom_input_keys)
-            out_key = DataSource._find_key(keys, custom_output_keys)
-            # Strict validation: if user explicitly specified input_key, it must exist exactly
-            if input_key is not None and input_key not in keys:
-                raise KeyError(
-                    f"Explicit --input_key '{input_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            if inp_key is None:
-                raise KeyError(
-                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                )
-            # Strict validation: if user explicitly specified output_key, it must exist exactly
-            if output_key is not None and output_key not in keys:
-                raise KeyError(
-                    f"Explicit --output_key '{output_key}' not found. "
-                    f"Available keys: {keys}"
-                )
-            data = NPZSource._load_and_copy(
-                path, [inp_key] + ([out_key] if out_key else [])
+            inp, outp = _load_npz_for_test(
+                path, custom_input_keys, custom_output_keys, input_key, output_key
             )
-            inp = data[inp_key]
-            if inp.dtype == object:
-                inp = np.array(
-                    [x.toarray() if hasattr(x, "toarray") else x for x in inp]
-                )
-            outp = data[out_key] if out_key else None
         elif format == "hdf5":
-            with h5py.File(path, "r") as f:
-                keys = list(f.keys())
-                inp_key = DataSource._find_key(keys, custom_input_keys)
-                out_key = DataSource._find_key(keys, custom_output_keys)
-                # Strict validation: if user explicitly specified input_key, it must exist exactly
-                if input_key is not None and input_key not in keys:
-                    raise KeyError(
-                        f"Explicit --input_key '{input_key}' not found. "
-                        f"Available keys: {keys}"
-                    )
-                if inp_key is None:
-                    raise KeyError(
-                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                    )
-                # Strict validation: if user explicitly specified output_key, it must exist exactly
-                if output_key is not None and output_key not in keys:
-                    raise KeyError(
-                        f"Explicit --output_key '{output_key}' not found. "
-                        f"Available keys: {keys}"
-                    )
-                # OOM guard: warn if dataset is very large
-                n_samples = f[inp_key].shape[0]
-                if n_samples > 100000:
-                    raise ValueError(
-                        f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                        f"everything into RAM which may cause OOM. For large inference "
-                        f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
-                    )
-                inp = f[inp_key][:]
-                outp = f[out_key][:] if out_key else None
+            inp, outp = _load_hdf5_for_test(
+                path, custom_input_keys, custom_output_keys, input_key, output_key
+            )
         elif format == "mat":
-            mat_source = MATSource()
-            with h5py.File(path, "r") as f:
-                keys = list(f.keys())
-                inp_key = DataSource._find_key(keys, custom_input_keys)
-                out_key = DataSource._find_key(keys, custom_output_keys)
-                # Strict validation: if user explicitly specified input_key, it must exist exactly
-                if input_key is not None and input_key not in keys:
-                    raise KeyError(
-                        f"Explicit --input_key '{input_key}' not found. "
-                        f"Available keys: {keys}"
-                    )
-                if inp_key is None:
-                    raise KeyError(
-                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                    )
-                # Strict validation: if user explicitly specified output_key, it must exist exactly
-                if output_key is not None and output_key not in keys:
-                    raise KeyError(
-                        f"Explicit --output_key '{output_key}' not found. "
-                        f"Available keys: {keys}"
-                    )
-                # OOM guard: warn if dataset is very large (MAT is transposed)
-                n_samples = f[inp_key].shape[-1]
-                if n_samples > 100000:
-                    raise ValueError(
-                        f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                        f"everything into RAM which may cause OOM. For large inference "
-                        f"sets, use a DataLoader with MATSource.load_mmap() instead."
-                    )
-                inp = mat_source._load_dataset(f, inp_key)
-                if out_key:
-                    outp = mat_source._load_dataset(f, out_key)
-                    # Handle transposed outputs from MATLAB
-                    # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
-                    # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
-                    num_samples = inp.shape[0]
-                    if outp.ndim == 2:
-                        if outp.shape[0] == 1 and outp.shape[1] == num_samples:
-                            # 1D vector: (1, N) → (N, 1)
-                            outp = outp.T
-                        elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
-                            # Single sample with multiple targets: (T, 1) → (1, T)
-                            outp = outp.T
-                else:
-                    outp = None
+            inp, outp = _load_mat_for_test(
+                path, custom_input_keys, custom_output_keys, input_key, output_key
+            )
         else:
             # Fallback to default source.load() for unknown formats
             inp, outp = source.load(path)
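A hypothetical call into the refactored loader, assuming load_test_data is importable from wavedl.utils.data and accepts input_key/output_key overrides as the surrounding diff suggests (the public keyword names may differ):

    from wavedl.utils.data import load_test_data

    # Outputs may legitimately be absent from an inference-only file, hence the None check.
    inputs, outputs = load_test_data("test_set.npz", format="npz", input_key="input_test")
    print(inputs.shape, None if outputs is None else outputs.shape)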
@@ -984,81 +1011,28 @@ def load_test_data(
                 f"Available keys depend on file format. Original error: {e}"
             ) from e

-        # Legitimate fallback: no explicit output_key, outputs just not present
+        # Also fail-fast if explicit input_key was provided but not found
+        # This prevents silently loading a different tensor when user mistyped key
+        if input_key is not None:
+            raise KeyError(
+                f"Explicit --input_key '{input_key}' not found in file. "
+                f"Original error: {e}"
+            ) from e
+
+        # Legitimate fallback: no explicit keys, outputs just not present
+        # Re-call helpers with None for explicit keys (validation already done above)
         if format == "npz":
-            # First pass to find keys
-            with np.load(path, allow_pickle=False) as probe:
-                keys = list(probe.keys())
-            inp_key = DataSource._find_key(keys, custom_input_keys)
-            if inp_key is None:
-                raise KeyError(
-                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                )
-            out_key = DataSource._find_key(keys, custom_output_keys)
-            keys_to_probe = [inp_key] + ([out_key] if out_key else [])
-            data = NPZSource._load_and_copy(path, keys_to_probe)
-            inp = data[inp_key]
-            if inp.dtype == object:
-                inp = np.array(
-                    [x.toarray() if hasattr(x, "toarray") else x for x in inp]
-                )
-            outp = data[out_key] if out_key else None
+            inp, outp = _load_npz_for_test(
+                path, custom_input_keys, custom_output_keys, None, None
+            )
         elif format == "hdf5":
-            # HDF5: input-only loading for inference
-            with h5py.File(path, "r") as f:
-                keys = list(f.keys())
-                inp_key = DataSource._find_key(keys, custom_input_keys)
-                if inp_key is None:
-                    raise KeyError(
-                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                    )
-                # Check size - load_test_data is eager, large files should use DataLoader
-                n_samples = f[inp_key].shape[0]
-                if n_samples > 100000:
-                    raise ValueError(
-                        f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                        f"everything into RAM which may cause OOM. For large inference "
-                        f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
-                    )
-                inp = f[inp_key][:]
-                out_key = DataSource._find_key(keys, custom_output_keys)
-                outp = f[out_key][:] if out_key else None
+            inp, outp = _load_hdf5_for_test(
+                path, custom_input_keys, custom_output_keys, None, None
+            )
         elif format == "mat":
-            # MAT v7.3: input-only loading with proper sparse handling
-            mat_source = MATSource()
-            with h5py.File(path, "r") as f:
-                keys = list(f.keys())
-                inp_key = DataSource._find_key(keys, custom_input_keys)
-                if inp_key is None:
-                    raise KeyError(
-                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
-                    )
-                # Check size - load_test_data is eager, large files should use DataLoader
-                n_samples = f[inp_key].shape[-1]  # MAT is transposed
-                if n_samples > 100000:
-                    raise ValueError(
-                        f"Dataset has {n_samples:,} samples. load_test_data() loads "
-                        f"everything into RAM which may cause OOM. For large inference "
-                        f"sets, use a DataLoader with MATSource.load_mmap() instead."
-                    )
-                # Use _load_dataset for sparse support and proper transpose
-                inp = mat_source._load_dataset(f, inp_key)
-                out_key = DataSource._find_key(keys, custom_output_keys)
-                if out_key:
-                    outp = mat_source._load_dataset(f, out_key)
-                    # Handle transposed outputs from MATLAB
-                    # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
-                    # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
-                    num_samples = inp.shape[0]
-                    if outp.ndim == 2:
-                        if outp.shape[0] == 1 and outp.shape[1] == num_samples:
-                            # 1D vector: (1, N) → (N, 1)
-                            outp = outp.T
-                        elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
-                            # Single sample with multiple targets: (T, 1) → (1, T)
-                            outp = outp.T
-                else:
-                    outp = None
+            inp, outp = _load_mat_for_test(
+                path, custom_input_keys, custom_output_keys, None, None
+            )
         else:
             raise

@@ -1524,21 +1498,43 @@ def prepare_data(

         logger.info(" ✔ Cache creation complete, synchronizing ranks...")
     else:
-        # NON-MAIN RANKS: Wait for cache creation
-        # Log that we're waiting (helps with debugging)
+        # NON-MAIN RANKS: Wait for cache creation with timeout
+        # Use monotonic clock (immune to system clock changes)
         import time

-        wait_start = time.time()
+        wait_start = time.monotonic()
+
+        # Robust env parsing with guards for invalid/non-positive values
+        DEFAULT_CACHE_TIMEOUT = 3600  # 1 hour default
+        try:
+            env_timeout = os.environ.get("WAVEDL_CACHE_TIMEOUT", "")
+            CACHE_TIMEOUT = (
+                int(env_timeout) if env_timeout else DEFAULT_CACHE_TIMEOUT
+            )
+            if CACHE_TIMEOUT <= 0:
+                CACHE_TIMEOUT = DEFAULT_CACHE_TIMEOUT
+        except ValueError:
+            CACHE_TIMEOUT = DEFAULT_CACHE_TIMEOUT
+
         while not (
             os.path.exists(CACHE_FILE)
             and os.path.exists(SCALER_FILE)
             and os.path.exists(META_FILE)
         ):
             time.sleep(5)  # Check every 5 seconds
-            elapsed = time.time() - wait_start
+            elapsed = time.monotonic() - wait_start
+
+            if elapsed > CACHE_TIMEOUT:
+                raise RuntimeError(
+                    f"[Rank {accelerator.process_index}] Timeout waiting for cache "
+                    f"files after {CACHE_TIMEOUT}s. Rank 0 may have failed during "
+                    f"cache generation. Check rank 0 logs for errors."
+                )
+
             if elapsed > 60 and int(elapsed) % 60 < 5:  # Log every ~minute
                 logger.info(
-                    f" [Rank {accelerator.process_index}] Waiting for cache creation... ({int(elapsed)}s)"
+                    f" [Rank {accelerator.process_index}] Waiting for cache "
+                    f"creation... ({int(elapsed)}s / {CACHE_TIMEOUT}s max)"
                 )
         # Small delay to ensure files are fully written
         time.sleep(2)
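The two ingredients added above, a defensively parsed timeout and a monotonic-clock wait, generalize to any file-based rank barrier. A self-contained sketch (WAVEDL_CACHE_TIMEOUT and the 3600 s default come from the hunk; the function names are illustrative):

    import os
    import time

    def read_timeout(env_var: str = "WAVEDL_CACHE_TIMEOUT", default: int = 3600) -> int:
        raw = os.environ.get(env_var, "")
        try:
            value = int(raw) if raw else default
        except ValueError:          # non-numeric value falls back to the default
            return default
        return value if value > 0 else default   # non-positive values also fall back

    def wait_for_files(paths: list[str], timeout: int, poll_s: float = 5.0) -> None:
        start = time.monotonic()    # immune to wall-clock adjustments
        while not all(os.path.exists(p) for p in paths):
            time.sleep(poll_s)
            if time.monotonic() - start > timeout:
                raise RuntimeError(f"Timed out after {timeout}s waiting for {paths}")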
wavedl/utils/metrics.py CHANGED
@@ -106,8 +106,20 @@ def configure_matplotlib_style():
     )


-# Apply style on import
-configure_matplotlib_style()
+# Lazy style initialization flag
+_style_configured = False
+
+
+def _ensure_style_configured():
+    """Apply matplotlib style on first use (lazy initialization).
+
+    This avoids modifying global plt.rcParams at import time, which could
+    unexpectedly override user configurations.
+    """
+    global _style_configured
+    if not _style_configured:
+        configure_matplotlib_style()
+        _style_configured = True


 # ==============================================================================
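The pattern is ordinary lazy initialization behind a module-level flag; a generic sketch (placeholder rcParams value, not wavedl's actual style) shows how the first plotting call pays the one-time cost instead of import doing it:

    import matplotlib.pyplot as plt

    _configured = False

    def _ensure_configured() -> None:
        global _configured
        if not _configured:
            plt.rcParams["figure.dpi"] = 150   # stand-in for configure_matplotlib_style()
            _configured = True

    def make_plot(values):
        _ensure_configured()                   # applied lazily, never at import time
        fig, ax = plt.subplots()
        ax.plot(values)
        return fig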
@@ -115,11 +127,14 @@ configure_matplotlib_style()
 # ==============================================================================
 class MetricTracker:
     """
-    Tracks running averages of metrics with thread-safe accumulation.
+    Tracks running averages of metrics during training.

     Useful for tracking loss, accuracy, or any scalar metric across batches.
     Handles division-by-zero safely by returning 0.0 when count is zero.

+    Note:
+        Not thread-safe. Intended for single-threaded use within training loops.
+
     Attributes:
         val: Most recent value
         avg: Running average
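A standalone sketch of the documented behaviour (val/avg attributes, 0.0 average when nothing has been recorded); MetricTracker's actual method names may differ:

    class RunningAverage:
        def __init__(self) -> None:
            self.val, self.sum, self.count = 0.0, 0.0, 0

        def update(self, value: float, n: int = 1) -> None:
            self.val = value
            self.sum += value * n
            self.count += n

        @property
        def avg(self) -> float:
            return self.sum / self.count if self.count else 0.0   # safe when empty

    loss_meter = RunningAverage()
    for batch_loss in (0.9, 0.7, 0.5):
        loss_meter.update(batch_loss)
    print(round(loss_meter.avg, 3))   # 0.7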
@@ -341,6 +356,7 @@ def plot_scientific_scatter(
     Returns:
         Matplotlib Figure object (can be saved or logged to WandB)
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names, max_samples
     )
@@ -415,6 +431,7 @@ def plot_error_histogram(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names
     )
@@ -485,6 +502,7 @@ def plot_residuals(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names, max_samples
     )
@@ -552,6 +570,7 @@ def create_training_curves(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     epochs = [h.get("epoch", i + 1) for i, h in enumerate(history)]

     fig, ax1 = plt.subplots(figsize=(FIGURE_WIDTH_INCH * 0.7, FIGURE_WIDTH_INCH * 0.4))
@@ -605,8 +624,8 @@ def create_training_curves(
         if not valid_data:
             return
         vmin, vmax = min(valid_data), max(valid_data)
-        # Get decade range that covers data (ceil for min to avoid going too low)
-        log_min = int(np.ceil(np.log10(vmin)))
+        # Get decade range that covers data (floor for min to include lowest data)
+        log_min = int(np.floor(np.log10(vmin)))
         log_max = int(np.ceil(np.log10(vmax)))
         # Generate ticks at each power of 10
         ticks = [10.0**i for i in range(log_min, log_max + 1)]
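Why the floor matters for the lower decade: with values spanning, say, 3e-4 to 2e-1, ceil(log10(3e-4)) = -3 starts the ticks at 1e-3 and the lowest points fall below the first tick, while floor keeps the 1e-4 decade.

    import numpy as np

    vmin, vmax = 3e-4, 2e-1
    log_min_ceil = int(np.ceil(np.log10(vmin)))    # -3 -> first tick 1e-3, misses vmin
    log_min_floor = int(np.floor(np.log10(vmin)))  # -4 -> first tick 1e-4, covers vmin
    log_max = int(np.ceil(np.log10(vmax)))         # 0
    ticks = [10.0**i for i in range(log_min_floor, log_max + 1)]
    print(ticks)   # [0.0001, 0.001, 0.01, 0.1, 1.0]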
@@ -691,6 +710,7 @@ def plot_bland_altman(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     y_true, y_pred, param_names, num_params = _prepare_plot_data(
         y_true, y_pred, param_names, max_samples
     )
@@ -764,6 +784,7 @@ def plot_qq(
     Returns:
         Matplotlib Figure object
     """
+    _ensure_style_configured()
     from scipy import stats

     num_params = y_true.shape[1] if y_true.ndim > 1 else 1
@@ -76,7 +76,7 @@ def get_scheduler(
     # Step/MultiStep parameters
     step_size: int = 30,
     milestones: list[int] | None = None,
-    gamma: float = 0.1,
+    gamma: float = 0.99,
     # Linear warmup parameters
     warmup_epochs: int = 5,
     start_factor: float = 0.1,
@@ -100,7 +100,7 @@ def get_scheduler(
         pct_start: Percentage of cycle spent increasing LR (OneCycleLR)
         step_size: Period for StepLR
         milestones: Epochs to decay LR for MultiStepLR
-        gamma: Decay factor for step/multistep/exponential
+        gamma: Decay factor (0.1 for step/multistep, 0.99 for exponential)
         warmup_epochs: Number of warmup epochs for linear_warmup
         start_factor: Starting LR factor for warmup (LR * start_factor)
         **kwargs: Additional arguments passed to scheduler
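The split documented for gamma reflects how the two scheduler families consume it: StepLR applies gamma once every step_size epochs, while ExponentialLR applies it every epoch. A plain-PyTorch comparison over 30 epochs (not a call into wavedl's get_scheduler):

    import torch

    opt_step = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
    opt_exp = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)

    step_sched = torch.optim.lr_scheduler.StepLR(opt_step, step_size=30, gamma=0.1)
    exp_sched = torch.optim.lr_scheduler.ExponentialLR(opt_exp, gamma=0.99)

    for _ in range(30):
        opt_step.step(); opt_exp.step()
        step_sched.step(); exp_sched.step()

    print(opt_step.param_groups[0]["lr"])            # 0.1: one sharp cut after the 30-epoch period
    print(round(opt_exp.param_groups[0]["lr"], 4))   # ~0.7397: gentle per-epoch decay of 0.99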