wavedl 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/train.py CHANGED
@@ -41,30 +41,11 @@ from __future__ import annotations
 # Uses current working directory as fallback - works on HPC and local machines.
 import os
 
+# Import and call HPC cache setup before any library imports
+from wavedl.utils import setup_hpc_cache_dirs
 
-def _setup_cache_dir(env_var: str, subdir: str) -> None:
-    """Set cache directory to CWD if home is not writable."""
-    if env_var in os.environ:
-        return  # User already set, respect their choice
 
-    # Check if home is writable
-    home = os.path.expanduser("~")
-    if os.access(home, os.W_OK):
-        return  # Home is writable, let library use defaults
-
-    # Home not writable - use current working directory
-    cache_path = os.path.join(os.getcwd(), f".{subdir}")
-    os.makedirs(cache_path, exist_ok=True)
-    os.environ[env_var] = cache_path
-
-
-# Configure cache directories (before any library imports)
-_setup_cache_dir("TORCH_HOME", "torch_cache")
-_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
-_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
-_setup_cache_dir("XDG_DATA_HOME", "local/share")
-_setup_cache_dir("XDG_STATE_HOME", "local/state")
-_setup_cache_dir("XDG_CACHE_HOME", "cache")
+setup_hpc_cache_dirs()
 
 
 def _setup_per_rank_compile_cache() -> None:
@@ -109,7 +90,11 @@ import shutil
 import sys
 import time
 import warnings
-from typing import Any
+from typing import TYPE_CHECKING, Any
+
+
+if TYPE_CHECKING:
+    import optuna
 
 
 # Suppress Pydantic warnings from accelerate's internal Field() usage
@@ -155,9 +140,8 @@ except ImportError:
 # ==============================================================================
 # RUNTIME CONFIGURATION (post-import)
 # ==============================================================================
-# Configure matplotlib paths for HPC systems without writable home directories
-os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
-os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
+# Note: matplotlib cache directory is already configured by setup_hpc_cache_dirs()
+# called at module load time. No additional MPLCONFIGDIR setup needed here.
 
 # Suppress warnings from known-noisy libraries, but preserve legitimate warnings
 # from torch/numpy about NaN, dtype, and numerical issues.
@@ -483,6 +467,488 @@ def parse_args() -> argparse.Namespace:
     return args, parser  # Returns (Namespace, ArgumentParser)
 
 
+# ==============================================================================
+# TRAINING HELPER FUNCTIONS
+# ==============================================================================
+def _run_train_epoch(
+    model,
+    train_dl,
+    optimizer,
+    criterion,
+    accelerator,
+    scheduler,
+    scheduler_step_per_batch: bool,
+    grad_clip: float,
+) -> tuple[float, float]:
+    """
+    Run one training epoch.
+
+    Args:
+        model: Model to train
+        train_dl: Training DataLoader
+        optimizer: Optimizer
+        criterion: Loss function (may be PhysicsConstrainedLoss)
+        accelerator: Accelerator instance
+        scheduler: LR scheduler
+        scheduler_step_per_batch: Whether to step scheduler per batch
+        grad_clip: Gradient clipping norm
+
+    Returns:
+        Tuple of (avg_train_loss, avg_grad_norm)
+    """
+    from wavedl.utils.constraints import PhysicsConstrainedLoss
+
+    model.train()
+    train_loss_sum = torch.tensor(0.0, device=accelerator.device)
+    train_samples = 0
+    grad_norm_tracker = MetricTracker()
+
+    for x, y in train_dl:
+        with accelerator.accumulate(model):
+            with accelerator.autocast():
+                pred = model(x)
+                if isinstance(criterion, PhysicsConstrainedLoss):
+                    loss = criterion(pred, y, x)
+                else:
+                    loss = criterion(pred, y)
+
+            accelerator.backward(loss)
+
+            if accelerator.sync_gradients:
+                grad_norm = accelerator.clip_grad_norm_(model.parameters(), grad_clip)
+                if grad_norm is not None:
+                    grad_norm_tracker.update(grad_norm.item())
+
+            optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
+
+            if scheduler_step_per_batch:
+                scheduler.step()
+
+        train_loss_sum += loss.detach() * x.size(0)
+        train_samples += x.size(0)
+
+    # Sync across GPUs
+    train_loss_scalar = train_loss_sum.item()
+    global_loss = accelerator.reduce(
+        torch.tensor([train_loss_scalar], device=accelerator.device),
+        reduction="sum",
+    ).item()
+    global_samples = accelerator.reduce(
+        torch.tensor([train_samples], device=accelerator.device),
+        reduction="sum",
+    ).item()
+
+    return global_loss / global_samples, grad_norm_tracker.avg
+
+
+def _run_validation(
+    model,
+    val_dl,
+    criterion,
+    accelerator,
+    out_dim: int,
+    phys_scale: torch.Tensor,
+) -> tuple[float, np.ndarray, torch.Tensor, torch.Tensor]:
+    """
+    Run validation epoch.
+
+    Args:
+        model: Model in eval mode
+        val_dl: Validation DataLoader
+        criterion: Loss function
+        accelerator: Accelerator instance
+        out_dim: Number of output dimensions
+        phys_scale: Physical scale tensor for MAE computation
+
+    Returns:
+        Tuple of (avg_val_loss, avg_mae_per_param, gathered_preds, gathered_targets)
+    """
+    from wavedl.utils.constraints import PhysicsConstrainedLoss
+
+    model.eval()
+    val_loss_sum = torch.tensor(0.0, device=accelerator.device)
+    val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
+    val_samples = 0
+    local_preds, local_targets = [], []
+
+    with torch.inference_mode():
+        for x, y in val_dl:
+            with accelerator.autocast():
+                pred = model(x)
+                if isinstance(criterion, PhysicsConstrainedLoss):
+                    loss = criterion(pred, y, x)
+                else:
+                    loss = criterion(pred, y)
+
+            val_loss_sum += loss.detach() * x.size(0)
+            val_samples += x.size(0)
+            val_mae_sum += torch.abs((pred - y) * phys_scale).sum(dim=0)
+            local_preds.append(pred.detach().cpu())
+            local_targets.append(y.detach().cpu())
+
+    # Gather across GPUs
+    local_preds_cat = torch.cat(local_preds)
+    local_targets_cat = torch.cat(local_targets)
+
+    if accelerator.num_processes > 1:
+        gathered_preds = accelerator.gather_for_metrics(
+            local_preds_cat.to(accelerator.device)
+        ).cpu()
+        gathered_targets = accelerator.gather_for_metrics(
+            local_targets_cat.to(accelerator.device)
+        ).cpu()
+    else:
+        gathered_preds = local_preds_cat
+        gathered_targets = local_targets_cat
+
+    # Sync metrics
+    val_loss_scalar = val_loss_sum.item()
+    val_metrics = torch.cat(
+        [
+            torch.tensor([val_loss_scalar], device=accelerator.device),
+            val_mae_sum,
+        ]
+    )
+    val_metrics_sync = accelerator.reduce(val_metrics, reduction="sum")
+    total_val_samples = accelerator.reduce(
+        torch.tensor([val_samples], device=accelerator.device),
+        reduction="sum",
+    ).item()
+
+    avg_val_loss = val_metrics_sync[0].item() / total_val_samples
+    avg_mae_per_param = (val_metrics_sync[1:] / total_val_samples).float().cpu().numpy()
+
+    return avg_val_loss, avg_mae_per_param, gathered_preds, gathered_targets
+
+
+def _save_best_checkpoint(
+    accelerator,
+    model,
+    args,
+    epoch: int,
+    best_val_loss: float,
+    in_shape: tuple,
+    out_dim: int,
+    scaler,
+    logger,
+) -> None:
+    """
+    Save best checkpoint with metadata.
+
+    Args:
+        accelerator: Accelerator instance
+        model: Model to save
+        args: Command-line arguments
+        epoch: Current epoch (0-indexed)
+        best_val_loss: Best validation loss
+        in_shape: Input shape
+        out_dim: Output dimension
+        scaler: StandardScaler for targets
+        logger: Logger instance
+    """
+    ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
+    with suppress_accelerate_logging():
+        accelerator.save_state(ckpt_dir, safe_serialization=False)
+
+    if accelerator.is_main_process:
+        with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
+            pickle.dump(
+                {
+                    "epoch": epoch + 1,
+                    "best_val_loss": best_val_loss,
+                    "patience_ctr": 0,
+                    "model_name": args.model,
+                    "in_shape": in_shape,
+                    "out_dim": out_dim,
+                },
+                f,
+            )
+
+        # Save standalone weights
+        try:
+            unwrapped = accelerator.unwrap_model(model)
+        except KeyError:
+            unwrapped = model.module if hasattr(model, "module") else model
+        if hasattr(unwrapped, "_orig_mod"):
+            unwrapped = unwrapped._orig_mod
+
+        torch.save(
+            unwrapped.state_dict(),
+            os.path.join(args.output_dir, "best_model_weights.pth"),
+        )
+
+        # Copy scaler for checkpoint portability
+        scaler_src = os.path.join(args.output_dir, "scaler.pkl")
+        scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
+        if os.path.exists(scaler_src):
+            shutil.copy2(scaler_src, scaler_dst)
+
+        logger.info(f" 💾 Best model saved (val_loss: {best_val_loss:.6f})")
+
+
+# ==============================================================================
+# IN-PROCESS HPO TRAINING FUNCTION
+# ==============================================================================
+
+
+def train_single_trial(
+    data_path: str,
+    model_name: str = "cnn",
+    lr: float = 1e-3,
+    batch_size: int = 32,
+    epochs: int = 50,
+    patience: int = 20,
+    optimizer_name: str = "adamw",
+    scheduler_name: str = "plateau",
+    loss_name: str = "mse",
+    weight_decay: float = 1e-4,
+    seed: int = 2025,
+    precision: str = "bf16",
+    workers: int = 0,
+    huber_delta: float = 1.0,
+    momentum: float = 0.9,
+    trial: optuna.trial.Trial | None = None,
+    verbose: bool = False,
+) -> dict:
+    """
+    Single-trial training function for in-process HPO.
+
+    This is a lightweight training loop designed for hyperparameter optimization
+    that supports Optuna pruning callbacks. Unlike `main()`, this avoids
+    Accelerator complexity for simpler single-GPU trials.
+
+    Args:
+        data_path: Path to training data (NPZ, HDF5, MAT)
+        model_name: Model architecture name (from registry)
+        lr: Learning rate
+        batch_size: Batch size
+        epochs: Maximum epochs
+        patience: Early stopping patience
+        optimizer_name: Optimizer name (from registry)
+        scheduler_name: Scheduler name (from registry)
+        loss_name: Loss function name (from registry)
+        weight_decay: Weight decay for optimizer
+        seed: Random seed
+        precision: Mixed precision mode ("bf16", "fp16", "no")
+        workers: DataLoader workers (0 for main process only)
+        huber_delta: Delta for Huber loss
+        momentum: Momentum for SGD optimizer
+        trial: Optuna trial for pruning callbacks (None for standalone use)
+        verbose: Print training progress
+
+    Returns:
+        dict with keys:
+            - best_val_loss: Best validation loss achieved
+            - epochs_trained: Number of epochs completed
+            - final_val_loss: Validation loss at last epoch
+            - pruned: Whether trial was pruned
+
+    Raises:
+        optuna.TrialPruned: If trial should be pruned (only when trial is provided)
+    """
+    import tempfile
+
+    import torch
+    from torch.utils.data import DataLoader, TensorDataset
+
+    from wavedl.models import build_model
+    from wavedl.utils import get_loss, get_optimizer, get_scheduler, prepare_data
+
+    # Set seed for reproducibility
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+    # Device setup
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Mixed precision setup
+    use_amp = precision != "no" and torch.cuda.is_available()
+    amp_dtype = torch.bfloat16 if precision == "bf16" else torch.float16
+    scaler = torch.amp.GradScaler("cuda", enabled=(use_amp and precision == "fp16"))
+
+    # Load and prepare data using temporary directory
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # Create a minimal args-like object for prepare_data
+        class Args:
+            pass
+
+        args = Args()
+        args.data_path = data_path
+        args.batch_size = batch_size
+        args.workers = workers
+        args.val_size = 0.2
+        args.cache_validate = "fast"
+        args.single_channel = False
+        args.seed = seed  # Required by prepare_data for train_test_split
+
+        # Create a dummy logger
+        class DummyLogger:
+            def info(self, msg):
+                if verbose:
+                    print(msg)
+
+            def warning(self, msg):
+                if verbose:
+                    print(f"WARNING: {msg}")
+
+            def error(self, msg):
+                print(f"ERROR: {msg}")
+
+        # Create a dummy accelerator for prepare_data compatibility
+        # Note: explicit device capture needed since class body scope differs from function scope
+        _device_for_accelerator = device
+
+        class DummyAccelerator:
+            is_main_process = True
+            device = _device_for_accelerator
+            num_processes = 1
+
+            @staticmethod
+            def wait_for_everyone():
+                pass  # No-op for single-process
+
+        train_dl, val_dl, _target_scaler, in_shape, out_dim = prepare_data(
+            args, DummyLogger(), DummyAccelerator(), cache_dir=tmpdir
+        )
+
+        # Build model
+        model = build_model(model_name, in_shape=in_shape, out_size=out_dim)
+        model = model.to(device)
+
+        # Create optimizer
+        optimizer = get_optimizer(
+            name=optimizer_name,
+            params=model.get_optimizer_groups(lr, weight_decay),
+            lr=lr,
+            weight_decay=weight_decay,
+            momentum=momentum,
+        )
+
+        # Create loss function
+        criterion = get_loss(name=loss_name, delta=huber_delta)
+        criterion = criterion.to(device)
+
+        # Create scheduler
+        scheduler = get_scheduler(
+            name=scheduler_name,
+            optimizer=optimizer,
+            epochs=epochs,
+            steps_per_epoch=len(train_dl),
+            patience=patience // 2,  # Use half patience for scheduler
+        )
+        scheduler_step_per_batch = scheduler_name == "onecycle"
+
+        # Training state
+        best_val_loss = float("inf")
+        patience_ctr = 0
+        epochs_trained = 0
+        final_val_loss = float("inf")
+
+        # Training loop
+        for epoch in range(epochs):
+            epochs_trained = epoch + 1
+
+            # === Training Phase ===
+            model.train()
+            train_loss_sum = 0.0
+            train_samples = 0
+
+            for x, y in train_dl:
+                x, y = x.to(device), y.to(device)
+
+                with torch.amp.autocast(
+                    device_type="cuda", dtype=amp_dtype, enabled=use_amp
+                ):
+                    pred = model(x)
+                    loss = criterion(pred, y)
+
+                optimizer.zero_grad(set_to_none=True)
+
+                if use_amp and precision == "fp16":
+                    scaler.scale(loss).backward()
+                    scaler.step(optimizer)
+                    scaler.update()
+                else:
+                    loss.backward()
+                    optimizer.step()
+
+                if scheduler_step_per_batch:
+                    scheduler.step()
+
+                train_loss_sum += loss.item() * x.size(0)
+                train_samples += x.size(0)
+
+            avg_train_loss = train_loss_sum / train_samples
+
+            # === Validation Phase ===
+            model.eval()
+            val_loss_sum = 0.0
+            val_samples = 0
+
+            with torch.inference_mode():
+                for x, y in val_dl:
+                    x, y = x.to(device), y.to(device)
+
+                    with torch.amp.autocast(
+                        device_type="cuda", dtype=amp_dtype, enabled=use_amp
+                    ):
+                        pred = model(x)
+                        loss = criterion(pred, y)
+
+                    val_loss_sum += loss.item() * x.size(0)
+                    val_samples += x.size(0)
+
+            avg_val_loss = val_loss_sum / val_samples
+            final_val_loss = avg_val_loss
+
+            # === Optuna Integration ===
+            if trial is not None:
+                # Report intermediate result
+                trial.report(avg_val_loss, epoch)
+
+                # Check if trial should be pruned
+                if trial.should_prune():
+                    return {
+                        "best_val_loss": best_val_loss,
+                        "epochs_trained": epochs_trained,
+                        "final_val_loss": final_val_loss,
+                        "pruned": True,
+                    }
+
+            # === Early Stopping ===
+            if avg_val_loss < best_val_loss:
+                best_val_loss = avg_val_loss
+                patience_ctr = 0
+            else:
+                patience_ctr += 1
+                if patience_ctr >= patience:
+                    if verbose:
+                        print(f"Early stopping at epoch {epoch + 1}")
+                    break
+
+            # === LR Scheduling ===
+            if not scheduler_step_per_batch:
+                if scheduler_name == "plateau":
+                    scheduler.step(avg_val_loss)
+                else:
+                    scheduler.step()
+
+            if verbose:
+                print(
+                    f"Epoch {epoch + 1}/{epochs}: "
+                    f"train_loss={avg_train_loss:.6f}, val_loss={avg_val_loss:.6f}"
+                )
+
+        return {
+            "best_val_loss": best_val_loss,
+            "epochs_trained": epochs_trained,
+            "final_val_loss": final_val_loss,
+            "pruned": False,
+        }
+
+
 # ==============================================================================
 # MAIN TRAINING FUNCTION
 # ==============================================================================
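The new train_single_trial() entry point is designed to be driven in-process by an Optuna study, reporting the validation loss each epoch so unpromising trials can be pruned. A minimal sketch of how a caller might wire it up (the data path, search ranges, and trial count below are illustrative assumptions, not part of the package):

    import optuna

    from wavedl.train import train_single_trial


    def objective(trial: optuna.trial.Trial) -> float:
        # Hypothetical search space; the ranges are assumptions for illustration.
        result = train_single_trial(
            data_path="train_data.npz",  # assumed dataset path
            model_name="cnn",
            lr=trial.suggest_float("lr", 1e-5, 1e-2, log=True),
            batch_size=trial.suggest_categorical("batch_size", [16, 32, 64]),
            epochs=30,
            trial=trial,  # enables per-epoch trial.report() / trial.should_prune()
        )
        if result["pruned"]:
            raise optuna.TrialPruned()
        return result["best_val_loss"]


    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=20)

Because the trial argument is optional, the same function can also be called directly for a one-off run, with best_val_loss and epochs_trained read from the returned dictionary.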
wavedl/utils/__init__.py CHANGED
@@ -8,25 +8,64 @@ Author: Ductho Le (ductho.le@outlook.com)
 Version: 1.0.0
 """
 
-from .config import (
+import os
+
+
+def setup_hpc_cache_dirs() -> None:
+    """
+    Configure cache directories for HPC environments with read-only home.
+
+    Auto-configures writable cache directories when home is not writable.
+    Uses current working directory as fallback - works on HPC and local machines.
+
+    Call this BEFORE importing libraries that use cache directories:
+    - torch (TORCH_HOME)
+    - matplotlib (MPLCONFIGDIR)
+    - fontconfig (FONTCONFIG_CACHE)
+
+    Example:
+        from wavedl.utils import setup_hpc_cache_dirs
+        setup_hpc_cache_dirs()  # Must be before torch/matplotlib imports
+    """
+
+    def _setup_cache_dir(env_var: str, subdir: str) -> None:
+        if env_var in os.environ:
+            return  # User already set, respect their choice
+        home = os.path.expanduser("~")
+        if os.access(home, os.W_OK):
+            return  # Home is writable, let library use defaults
+        # Home not writable - use current working directory
+        cache_path = os.path.join(os.getcwd(), f".{subdir}")
+        os.makedirs(cache_path, exist_ok=True)
+        os.environ[env_var] = cache_path
+
+    _setup_cache_dir("TORCH_HOME", "torch_cache")
+    _setup_cache_dir("MPLCONFIGDIR", "matplotlib")
+    _setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
+    _setup_cache_dir("XDG_DATA_HOME", "local/share")
+    _setup_cache_dir("XDG_STATE_HOME", "local/state")
+    _setup_cache_dir("XDG_CACHE_HOME", "cache")
+
+
+from .config import (  # noqa: E402
     create_default_config,
     load_config,
     merge_config_with_args,
     save_config,
     validate_config,
 )
-from .constraints import (
+from .constraints import (  # noqa: E402
     ExpressionConstraint,
     FileConstraint,
     PhysicsConstrainedLoss,
     build_constraints,
 )
-from .cross_validation import (
+from .cross_validation import (  # noqa: E402
     CVDataset,
     run_cross_validation,
     train_fold,
 )
-from .data import (
+from .data import (  # noqa: E402
     # Multi-format data loading
     DataSource,
     HDF5Source,
@@ -40,18 +79,18 @@ from .data import (
     memmap_worker_init_fn,
     prepare_data,
 )
-from .distributed import (
+from .distributed import (  # noqa: E402
     broadcast_early_stop,
     broadcast_value,
     sync_tensor,
 )
-from .losses import (
+from .losses import (  # noqa: E402
     LogCoshLoss,
     WeightedMSELoss,
     get_loss,
     list_losses,
 )
-from .metrics import (
+from .metrics import (  # noqa: E402
     COLORS,
     FIGURE_DPI,
     FIGURE_WIDTH_CM,
@@ -76,12 +115,12 @@ from .metrics import (
     plot_residuals,
     plot_scientific_scatter,
 )
-from .optimizers import (
+from .optimizers import (  # noqa: E402
     get_optimizer,
     get_optimizer_with_param_groups,
     list_optimizers,
 )
-from .schedulers import (
+from .schedulers import (  # noqa: E402
     get_scheduler,
     get_scheduler_with_warmup,
     is_epoch_based,
@@ -156,6 +195,7 @@ __all__ = [
     # Cross-Validation
     "run_cross_validation",
     "save_config",
+    "setup_hpc_cache_dirs",
     "sync_tensor",
     "train_fold",
     "validate_config",
wavedl/utils/config.py CHANGED
@@ -116,13 +116,11 @@ def merge_config_with_args(
     """
     # Get parser defaults to detect which args were explicitly set by user
     if parser is not None:
-        # Safe extraction: iterate actions instead of parse_args([])
-        # This avoids failures if required arguments are added later
-        defaults = {
-            action.dest: action.default
-            for action in parser._actions
-            if action.dest != "help"
-        }
+        # Use public API to extract defaults (avoids private _actions attribute)
+        defaults = {}
+        for action in parser._option_string_actions.values():
+            if action.dest != "help":
+                defaults[action.dest] = parser.get_default(action.dest)
     else:
         # Fallback: reconstruct defaults from known patterns
         # This works because argparse stores actual values, and we compare
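The rewritten branch identifies user-supplied options by comparing parsed values against per-option defaults obtained through parser.get_default(). A standalone sketch of the same comparison on a throwaway parser (the options are placeholders, not wavedl's actual CLI):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=50)

    args = parser.parse_args(["--lr", "5e-4"])

    # Mirror the extraction above: one default per destination, skipping "help".
    defaults = {}
    for action in parser._option_string_actions.values():
        if action.dest != "help":
            defaults[action.dest] = parser.get_default(action.dest)

    # An option counts as explicitly set when its parsed value differs from its default.
    user_set = {k for k, v in vars(args).items() if defaults.get(k) != v}
    print(user_set)  # {'lr'}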
@@ -233,7 +231,7 @@ def validate_config(
 
     # Validate numeric ranges
     numeric_checks = {
-        "lr": (0, 1, "Learning rate should be between 0 and 1"),
+        "lr": (0, 10, "Learning rate should be between 0 and 10"),
        "epochs": (1, 100000, "Epochs should be positive"),
        "batch_size": (1, 10000, "Batch size should be positive"),
        "patience": (1, 1000, "Patience should be positive"),