wavedl 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +115 -9
- wavedl/models/_pretrained_utils.py +72 -0
- wavedl/models/_template.py +7 -6
- wavedl/models/cnn.py +20 -0
- wavedl/models/convnext.py +3 -70
- wavedl/models/convnext_v2.py +1 -18
- wavedl/models/mamba.py +126 -38
- wavedl/models/resnet3d.py +23 -5
- wavedl/models/unireplknet.py +1 -18
- wavedl/models/vit.py +18 -8
- wavedl/test.py +5 -23
- wavedl/train.py +492 -26
- wavedl/utils/__init__.py +49 -9
- wavedl/utils/config.py +6 -8
- wavedl/utils/cross_validation.py +17 -4
- wavedl/utils/data.py +140 -174
- wavedl/utils/metrics.py +26 -5
- wavedl/utils/schedulers.py +2 -2
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/METADATA +35 -14
- wavedl-1.7.0.dist-info/RECORD +46 -0
- wavedl-1.6.3.dist-info/RECORD +0 -46
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/LICENSE +0 -0
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/WHEEL +0 -0
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/top_level.txt +0 -0
wavedl/train.py
CHANGED
|
@@ -41,30 +41,11 @@ from __future__ import annotations
|
|
|
41
41
|
# Uses current working directory as fallback - works on HPC and local machines.
|
|
42
42
|
import os
|
|
43
43
|
|
|
44
|
+
# Import and call HPC cache setup before any library imports
|
|
45
|
+
from wavedl.utils import setup_hpc_cache_dirs
|
|
44
46
|
|
|
45
|
-
def _setup_cache_dir(env_var: str, subdir: str) -> None:
    """
    Set a cache directory to the CWD if home is not writable.

    Nothing is changed when ``env_var`` is already present in the environment
    or when the home directory is writable. Otherwise ``<cwd>/.<subdir>`` is
    created and exported through ``env_var``.

    This helper runs at import time, so a fallback directory that cannot be
    created must not abort the import: in that case the variable is left
    unset and a ``RuntimeWarning`` is emitted.

    Args:
        env_var: Name of the environment variable to configure.
        subdir: Directory name below the CWD; a leading dot is prepended.
    """
    # Local import: this function is called before the module-level imports.
    import warnings

    if env_var in os.environ:
        return  # User already set, respect their choice

    home = os.path.expanduser("~")
    if os.access(home, os.W_OK):
        return  # Home is writable, let library use defaults

    # Home not writable - use current working directory
    try:
        cache_path = os.path.join(os.getcwd(), f".{subdir}")
        os.makedirs(cache_path, exist_ok=True)
    except OSError as exc:
        warnings.warn(
            f"Could not create fallback cache directory for {env_var}: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )
        return

    os.environ[env_var] = cache_path
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
# Configure cache directories (before any library imports)
|
|
62
|
-
_setup_cache_dir("TORCH_HOME", "torch_cache")
|
|
63
|
-
_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
|
|
64
|
-
_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
|
|
65
|
-
_setup_cache_dir("XDG_DATA_HOME", "local/share")
|
|
66
|
-
_setup_cache_dir("XDG_STATE_HOME", "local/state")
|
|
67
|
-
_setup_cache_dir("XDG_CACHE_HOME", "cache")
|
|
48
|
+
setup_hpc_cache_dirs()
|
|
68
49
|
|
|
69
50
|
|
|
70
51
|
def _setup_per_rank_compile_cache() -> None:
|
|
@@ -109,7 +90,11 @@ import shutil
|
|
|
109
90
|
import sys
|
|
110
91
|
import time
|
|
111
92
|
import warnings
|
|
112
|
-
from typing import Any
|
|
93
|
+
from typing import TYPE_CHECKING, Any
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
if TYPE_CHECKING:
|
|
97
|
+
import optuna
|
|
113
98
|
|
|
114
99
|
|
|
115
100
|
# Suppress Pydantic warnings from accelerate's internal Field() usage
|
|
@@ -155,9 +140,8 @@ except ImportError:
|
|
|
155
140
|
# ==============================================================================
|
|
156
141
|
# RUNTIME CONFIGURATION (post-import)
|
|
157
142
|
# ==============================================================================
|
|
158
|
-
#
|
|
159
|
-
|
|
160
|
-
os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
|
|
143
|
+
# Note: matplotlib cache directory is already configured by setup_hpc_cache_dirs()
|
|
144
|
+
# called at module load time. No additional MPLCONFIGDIR setup needed here.
|
|
161
145
|
|
|
162
146
|
# Suppress warnings from known-noisy libraries, but preserve legitimate warnings
|
|
163
147
|
# from torch/numpy about NaN, dtype, and numerical issues.
|
|
@@ -483,6 +467,488 @@ def parse_args() -> argparse.Namespace:
|
|
|
483
467
|
return args, parser # Returns (Namespace, ArgumentParser)
|
|
484
468
|
|
|
485
469
|
|
|
470
|
+
# ==============================================================================
|
|
471
|
+
# TRAINING HELPER FUNCTIONS
|
|
472
|
+
# ==============================================================================
|
|
473
|
+
def _run_train_epoch(
    model,
    train_dl,
    optimizer,
    criterion,
    accelerator,
    scheduler,
    scheduler_step_per_batch: bool,
    grad_clip: float,
) -> tuple[float, float]:
    """
    Train the model for one pass over ``train_dl``.

    Each batch is processed inside ``accelerator.accumulate(model)``; gradients
    are clipped (and their norm recorded) only on steps where
    ``accelerator.sync_gradients`` is true. The per-sample loss sum and the
    sample count are summed across processes before averaging.

    Args:
        model: Model to train
        train_dl: Training DataLoader yielding ``(x, y)`` batches
        optimizer: Optimizer
        criterion: Loss function; a PhysicsConstrainedLoss additionally
            receives the inputs ``x``
        accelerator: Accelerator instance
        scheduler: LR scheduler
        scheduler_step_per_batch: Whether to step scheduler after every batch
        grad_clip: Max norm passed to ``accelerator.clip_grad_norm_``

    Returns:
        Tuple of (avg_train_loss, avg_grad_norm)
    """
    from wavedl.utils.constraints import PhysicsConstrainedLoss

    device = accelerator.device
    # The criterion type is fixed for the whole epoch, so decide once.
    passes_inputs = isinstance(criterion, PhysicsConstrainedLoss)

    model.train()
    loss_total = torch.tensor(0.0, device=device)
    samples_seen = 0
    norm_tracker = MetricTracker()

    for x, y in train_dl:
        with accelerator.accumulate(model):
            with accelerator.autocast():
                pred = model(x)
                loss = criterion(pred, y, x) if passes_inputs else criterion(pred, y)

            accelerator.backward(loss)

            if accelerator.sync_gradients:
                clipped_norm = accelerator.clip_grad_norm_(
                    model.parameters(), grad_clip
                )
                if clipped_norm is not None:
                    norm_tracker.update(clipped_norm.item())

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            if scheduler_step_per_batch:
                scheduler.step()

            # Weight by batch size so the final average is per sample.
            n_batch = x.size(0)
            loss_total += loss.detach() * n_batch
            samples_seen += n_batch

    def _sum_over_processes(value):
        """Sum a Python scalar across all processes via the accelerator."""
        packed = torch.tensor([value], device=device)
        return accelerator.reduce(packed, reduction="sum").item()

    # Sync across GPUs
    global_loss = _sum_over_processes(loss_total.item())
    global_samples = _sum_over_processes(samples_seen)

    return global_loss / global_samples, norm_tracker.avg
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def _run_validation(
    model,
    val_dl,
    criterion,
    accelerator,
    out_dim: int,
    phys_scale: torch.Tensor,
) -> tuple[float, np.ndarray, torch.Tensor, torch.Tensor]:
    """
    Evaluate the model over ``val_dl`` without gradient tracking.

    The function switches the model to eval mode itself. Loss and absolute
    error (scaled by ``phys_scale``) are accumulated as per-sample sums,
    summed across processes, and divided by the global sample count.
    Predictions and targets are collected on CPU and, when more than one
    process is running, gathered with ``accelerator.gather_for_metrics``.

    Args:
        model: Model to evaluate
        val_dl: Validation DataLoader yielding ``(x, y)`` batches
        criterion: Loss function; a PhysicsConstrainedLoss additionally
            receives the inputs ``x``
        accelerator: Accelerator instance
        out_dim: Number of output dimensions (length of the MAE vector)
        phys_scale: Scale applied to ``pred - y`` before the absolute value

    Returns:
        Tuple of (avg_val_loss, avg_mae_per_param, gathered_preds, gathered_targets)
    """
    from wavedl.utils.constraints import PhysicsConstrainedLoss

    device = accelerator.device
    passes_inputs = isinstance(criterion, PhysicsConstrainedLoss)

    model.eval()
    loss_total = torch.tensor(0.0, device=device)
    mae_total = torch.zeros(out_dim, device=device)
    n_seen = 0
    pred_chunks, target_chunks = [], []

    with torch.inference_mode():
        for x, y in val_dl:
            with accelerator.autocast():
                pred = model(x)
                loss = criterion(pred, y, x) if passes_inputs else criterion(pred, y)

            n_batch = x.size(0)
            loss_total += loss.detach() * n_batch
            n_seen += n_batch
            mae_total += torch.abs((pred - y) * phys_scale).sum(dim=0)
            pred_chunks.append(pred.detach().cpu())
            target_chunks.append(y.detach().cpu())

    # Gather across GPUs
    gathered_preds = torch.cat(pred_chunks)
    gathered_targets = torch.cat(target_chunks)

    if accelerator.num_processes > 1:
        # Collectives need device tensors; results go back to CPU afterwards.
        gathered_preds = accelerator.gather_for_metrics(
            gathered_preds.to(device)
        ).cpu()
        gathered_targets = accelerator.gather_for_metrics(
            gathered_targets.to(device)
        ).cpu()

    # Sync metrics: slot 0 carries the loss sum, the rest the per-output MAE sums
    packed = torch.cat(
        [
            torch.tensor([loss_total.item()], device=device),
            mae_total,
        ]
    )
    packed_sum = accelerator.reduce(packed, reduction="sum")
    n_total = accelerator.reduce(
        torch.tensor([n_seen], device=device),
        reduction="sum",
    ).item()

    avg_val_loss = packed_sum[0].item() / n_total
    avg_mae_per_param = (packed_sum[1:] / n_total).float().cpu().numpy()

    return avg_val_loss, avg_mae_per_param, gathered_preds, gathered_targets
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
def _save_best_checkpoint(
    accelerator,
    model,
    args,
    epoch: int,
    best_val_loss: float,
    in_shape: tuple,
    out_dim: int,
    scaler,
    logger,
) -> None:
    """
    Write the best-so-far checkpoint and its metadata.

    Every process calls ``accelerator.save_state``; the main process then
    writes ``training_meta.pkl``, a standalone ``best_model_weights.pth`` and,
    if present, a copy of ``scaler.pkl`` into the checkpoint directory.

    Args:
        accelerator: Accelerator instance
        model: Model to save
        args: Command-line arguments (``output_dir`` and ``model`` are read)
        epoch: Current epoch (0-indexed); stored as ``epoch + 1``
        best_val_loss: Best validation loss
        in_shape: Input shape
        out_dim: Output dimension
        scaler: StandardScaler for targets (not read here; the scaler file is
            copied from ``args.output_dir`` instead)
        logger: Logger instance
    """
    out_dir = args.output_dir
    ckpt_dir = os.path.join(out_dir, "best_checkpoint")

    with suppress_accelerate_logging():
        accelerator.save_state(ckpt_dir, safe_serialization=False)

    if not accelerator.is_main_process:
        return

    meta = {
        "epoch": epoch + 1,
        "best_val_loss": best_val_loss,
        "patience_ctr": 0,
        "model_name": args.model,
        "in_shape": in_shape,
        "out_dim": out_dim,
    }
    with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as meta_file:
        pickle.dump(meta, meta_file)

    # Save standalone weights
    try:
        bare_model = accelerator.unwrap_model(model)
    except KeyError:
        # NOTE(review): nesting reconstructed from a flattened diff - confirm
        # that the `_orig_mod` unwrapping belongs to this fallback path only.
        bare_model = model.module if hasattr(model, "module") else model
        if hasattr(bare_model, "_orig_mod"):
            bare_model = bare_model._orig_mod

    weights_path = os.path.join(out_dir, "best_model_weights.pth")
    torch.save(bare_model.state_dict(), weights_path)

    # Copy scaler for checkpoint portability
    scaler_src = os.path.join(out_dir, "scaler.pkl")
    if os.path.exists(scaler_src):
        shutil.copy2(scaler_src, os.path.join(ckpt_dir, "scaler.pkl"))

    logger.info(f"   💾 Best model saved (val_loss: {best_val_loss:.6f})")
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
# ==============================================================================
|
|
691
|
+
# IN-PROCESS HPO TRAINING FUNCTION
|
|
692
|
+
# ==============================================================================
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
def train_single_trial(
    data_path: str,
    model_name: str = "cnn",
    lr: float = 1e-3,
    batch_size: int = 32,
    epochs: int = 50,
    patience: int = 20,
    optimizer_name: str = "adamw",
    scheduler_name: str = "plateau",
    loss_name: str = "mse",
    weight_decay: float = 1e-4,
    seed: int = 2025,
    precision: str = "bf16",
    workers: int = 0,
    huber_delta: float = 1.0,
    momentum: float = 0.9,
    trial: optuna.trial.Trial | None = None,
    verbose: bool = False,
) -> dict:
    """
    Single-trial training function for in-process HPO.

    This is a lightweight training loop designed for hyperparameter optimization
    that supports Optuna pruning callbacks. Unlike `main()`, this avoids
    Accelerator complexity for simpler single-GPU trials.

    Args:
        data_path: Path to training data (NPZ, HDF5, MAT)
        model_name: Model architecture name (from registry)
        lr: Learning rate
        batch_size: Batch size
        epochs: Maximum epochs
        patience: Early stopping patience (half of it is given to the scheduler)
        optimizer_name: Optimizer name (from registry)
        scheduler_name: Scheduler name (from registry)
        loss_name: Loss function name (from registry)
        weight_decay: Weight decay for optimizer
        seed: Random seed
        precision: Mixed precision mode ("bf16", "fp16", "no"); only applied
            when CUDA is available
        workers: DataLoader workers (0 for main process only)
        huber_delta: Delta for Huber loss
        momentum: Momentum for SGD optimizer
        trial: Optuna trial for pruning callbacks (None for standalone use)
        verbose: Print training progress

    Returns:
        dict with keys:
            - best_val_loss: Best validation loss achieved, including the
              epoch on which a trial was pruned
            - epochs_trained: Number of epochs completed
            - final_val_loss: Validation loss at last epoch
            - pruned: Whether trial was pruned

    Note:
        Pruning is reported through the ``pruned`` flag of the returned dict;
        this function does not raise ``optuna.TrialPruned`` itself, so the
        caller is responsible for converting the flag if it needs to.
    """
    import tempfile
    from types import SimpleNamespace

    import torch

    from wavedl.models import build_model
    from wavedl.utils import get_loss, get_optimizer, get_scheduler, prepare_data

    # Set seed for reproducibility
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Mixed precision setup (loss scaling is only needed for fp16)
    use_amp = precision != "no" and torch.cuda.is_available()
    amp_dtype = torch.bfloat16 if precision == "bf16" else torch.float16
    use_grad_scaler = use_amp and precision == "fp16"
    scaler = torch.amp.GradScaler("cuda", enabled=use_grad_scaler)

    # Minimal args-like object carrying the attributes set for prepare_data
    args = SimpleNamespace(
        data_path=data_path,
        batch_size=batch_size,
        workers=workers,
        val_size=0.2,
        cache_validate="fast",
        single_channel=False,
        seed=seed,  # Required by prepare_data for train_test_split
    )

    class DummyLogger:
        """Logger stand-in: info/warning print only when verbose."""

        def info(self, msg):
            if verbose:
                print(msg)

        def warning(self, msg):
            if verbose:
                print(f"WARNING: {msg}")

        def error(self, msg):
            print(f"ERROR: {msg}")

    # Explicit capture needed: a class body cannot read the enclosing
    # function's `device` while also assigning a class attribute named `device`.
    _device_for_accelerator = device

    class DummyAccelerator:
        """Single-process accelerator stand-in for prepare_data compatibility."""

        is_main_process = True
        device = _device_for_accelerator
        num_processes = 1

        @staticmethod
        def wait_for_everyone():
            pass  # No-op for single-process

    # Training state
    best_val_loss = float("inf")
    patience_ctr = 0
    epochs_trained = 0
    final_val_loss = float("inf")
    pruned = False

    # The data loaders may be backed by cache files inside the temporary
    # directory, so the whole training loop runs before it is cleaned up.
    with tempfile.TemporaryDirectory() as tmpdir:
        train_dl, val_dl, _target_scaler, in_shape, out_dim = prepare_data(
            args, DummyLogger(), DummyAccelerator(), cache_dir=tmpdir
        )

        # Build model
        model = build_model(model_name, in_shape=in_shape, out_size=out_dim)
        model = model.to(device)

        # Create optimizer
        optimizer = get_optimizer(
            name=optimizer_name,
            params=model.get_optimizer_groups(lr, weight_decay),
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
        )

        # Create loss function
        criterion = get_loss(name=loss_name, delta=huber_delta)
        criterion = criterion.to(device)

        # Create scheduler
        scheduler = get_scheduler(
            name=scheduler_name,
            optimizer=optimizer,
            epochs=epochs,
            steps_per_epoch=len(train_dl),
            patience=patience // 2,  # Use half patience for scheduler
        )
        scheduler_step_per_batch = scheduler_name == "onecycle"

        # Training loop
        for epoch in range(epochs):
            epochs_trained = epoch + 1

            # === Training Phase ===
            model.train()
            train_loss_sum = 0.0
            train_samples = 0

            for x, y in train_dl:
                x, y = x.to(device), y.to(device)

                with torch.amp.autocast(
                    device_type="cuda", dtype=amp_dtype, enabled=use_amp
                ):
                    pred = model(x)
                    loss = criterion(pred, y)

                optimizer.zero_grad(set_to_none=True)

                if use_grad_scaler:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()

                if scheduler_step_per_batch:
                    scheduler.step()

                train_loss_sum += loss.item() * x.size(0)
                train_samples += x.size(0)

            avg_train_loss = train_loss_sum / train_samples

            # === Validation Phase ===
            model.eval()
            val_loss_sum = 0.0
            val_samples = 0

            with torch.inference_mode():
                for x, y in val_dl:
                    x, y = x.to(device), y.to(device)

                    with torch.amp.autocast(
                        device_type="cuda", dtype=amp_dtype, enabled=use_amp
                    ):
                        pred = model(x)
                        loss = criterion(pred, y)

                    val_loss_sum += loss.item() * x.size(0)
                    val_samples += x.size(0)

            avg_val_loss = val_loss_sum / val_samples
            final_val_loss = avg_val_loss

            # === Best-Loss Bookkeeping ===
            # Done before the pruning check so that a pruned trial still
            # reports the best loss it actually reached (previously a trial
            # pruned on its first epoch returned best_val_loss == inf).
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_ctr = 0
                out_of_patience = False
            else:
                patience_ctr += 1
                out_of_patience = patience_ctr >= patience

            # === Optuna Integration ===
            if trial is not None:
                # Report intermediate result
                trial.report(avg_val_loss, epoch)

                # Check if trial should be pruned
                if trial.should_prune():
                    pruned = True
                    break

            # === Early Stopping ===
            if out_of_patience:
                if verbose:
                    print(f"Early stopping at epoch {epoch + 1}")
                break

            # === LR Scheduling ===
            if not scheduler_step_per_batch:
                if scheduler_name == "plateau":
                    scheduler.step(avg_val_loss)
                else:
                    scheduler.step()

            if verbose:
                print(
                    f"Epoch {epoch + 1}/{epochs}: "
                    f"train_loss={avg_train_loss:.6f}, val_loss={avg_val_loss:.6f}"
                )

    return {
        "best_val_loss": best_val_loss,
        "epochs_trained": epochs_trained,
        "final_val_loss": final_val_loss,
        "pruned": pruned,
    }
|
|
950
|
+
|
|
951
|
+
|
|
486
952
|
# ==============================================================================
|
|
487
953
|
# MAIN TRAINING FUNCTION
|
|
488
954
|
# ==============================================================================
|
wavedl/utils/__init__.py
CHANGED
|
@@ -8,25 +8,64 @@ Author: Ductho Le (ductho.le@outlook.com)
|
|
|
8
8
|
Version: 1.0.0
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
|
-
|
|
11
|
+
import os
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def setup_hpc_cache_dirs() -> None:
    """
    Configure cache directories for HPC environments with read-only home.

    Auto-configures writable cache directories when home is not writable.
    Uses current working directory as fallback - works on HPC and local machines.

    For each managed environment variable that is not already set, a hidden
    directory below the current working directory is created and exported.
    Variables the user has set are never touched, and nothing is changed at
    all when the home directory is writable. A directory that cannot be
    created only produces a ``RuntimeWarning`` for that variable, so importing
    the package does not fail because of it.

    Call this BEFORE importing libraries that use cache directories:
    - torch (TORCH_HOME)
    - matplotlib (MPLCONFIGDIR)
    - fontconfig (FONTCONFIG_CACHE)

    Example:
        from wavedl.utils import setup_hpc_cache_dirs
        setup_hpc_cache_dirs()  # Must be before torch/matplotlib imports
    """
    import warnings

    # (environment variable, directory name below the CWD without leading dot)
    cache_targets = (
        ("TORCH_HOME", "torch_cache"),
        ("MPLCONFIGDIR", "matplotlib"),
        ("FONTCONFIG_CACHE", "fontconfig"),
        ("XDG_DATA_HOME", "local/share"),
        ("XDG_STATE_HOME", "local/state"),
        ("XDG_CACHE_HOME", "cache"),
    )

    # The home-directory probe is the same for every variable: do it once.
    home = os.path.expanduser("~")
    if os.access(home, os.W_OK):
        return  # Home is writable, let libraries use their defaults

    for env_var, subdir in cache_targets:
        if env_var in os.environ:
            continue  # User already set, respect their choice

        # Home not writable - use current working directory
        try:
            cache_path = os.path.join(os.getcwd(), f".{subdir}")
            os.makedirs(cache_path, exist_ok=True)
        except OSError as exc:
            warnings.warn(
                f"Could not create fallback cache directory for {env_var}: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
            continue

        os.environ[env_var] = cache_path
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
from .config import ( # noqa: E402
|
|
12
51
|
create_default_config,
|
|
13
52
|
load_config,
|
|
14
53
|
merge_config_with_args,
|
|
15
54
|
save_config,
|
|
16
55
|
validate_config,
|
|
17
56
|
)
|
|
18
|
-
from .constraints import (
|
|
57
|
+
from .constraints import ( # noqa: E402
|
|
19
58
|
ExpressionConstraint,
|
|
20
59
|
FileConstraint,
|
|
21
60
|
PhysicsConstrainedLoss,
|
|
22
61
|
build_constraints,
|
|
23
62
|
)
|
|
24
|
-
from .cross_validation import (
|
|
63
|
+
from .cross_validation import ( # noqa: E402
|
|
25
64
|
CVDataset,
|
|
26
65
|
run_cross_validation,
|
|
27
66
|
train_fold,
|
|
28
67
|
)
|
|
29
|
-
from .data import (
|
|
68
|
+
from .data import ( # noqa: E402
|
|
30
69
|
# Multi-format data loading
|
|
31
70
|
DataSource,
|
|
32
71
|
HDF5Source,
|
|
@@ -40,18 +79,18 @@ from .data import (
|
|
|
40
79
|
memmap_worker_init_fn,
|
|
41
80
|
prepare_data,
|
|
42
81
|
)
|
|
43
|
-
from .distributed import (
|
|
82
|
+
from .distributed import ( # noqa: E402
|
|
44
83
|
broadcast_early_stop,
|
|
45
84
|
broadcast_value,
|
|
46
85
|
sync_tensor,
|
|
47
86
|
)
|
|
48
|
-
from .losses import (
|
|
87
|
+
from .losses import ( # noqa: E402
|
|
49
88
|
LogCoshLoss,
|
|
50
89
|
WeightedMSELoss,
|
|
51
90
|
get_loss,
|
|
52
91
|
list_losses,
|
|
53
92
|
)
|
|
54
|
-
from .metrics import (
|
|
93
|
+
from .metrics import ( # noqa: E402
|
|
55
94
|
COLORS,
|
|
56
95
|
FIGURE_DPI,
|
|
57
96
|
FIGURE_WIDTH_CM,
|
|
@@ -76,12 +115,12 @@ from .metrics import (
|
|
|
76
115
|
plot_residuals,
|
|
77
116
|
plot_scientific_scatter,
|
|
78
117
|
)
|
|
79
|
-
from .optimizers import (
|
|
118
|
+
from .optimizers import ( # noqa: E402
|
|
80
119
|
get_optimizer,
|
|
81
120
|
get_optimizer_with_param_groups,
|
|
82
121
|
list_optimizers,
|
|
83
122
|
)
|
|
84
|
-
from .schedulers import (
|
|
123
|
+
from .schedulers import ( # noqa: E402
|
|
85
124
|
get_scheduler,
|
|
86
125
|
get_scheduler_with_warmup,
|
|
87
126
|
is_epoch_based,
|
|
@@ -156,6 +195,7 @@ __all__ = [
|
|
|
156
195
|
# Cross-Validation
|
|
157
196
|
"run_cross_validation",
|
|
158
197
|
"save_config",
|
|
198
|
+
"setup_hpc_cache_dirs",
|
|
159
199
|
"sync_tensor",
|
|
160
200
|
"train_fold",
|
|
161
201
|
"validate_config",
|
wavedl/utils/config.py
CHANGED
|
@@ -116,13 +116,11 @@ def merge_config_with_args(
|
|
|
116
116
|
"""
|
|
117
117
|
# Get parser defaults to detect which args were explicitly set by user
|
|
118
118
|
if parser is not None:
|
|
119
|
-
#
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
action.dest:
|
|
123
|
-
|
|
124
|
-
if action.dest != "help"
|
|
125
|
-
}
|
|
119
|
+
# Use public API to extract defaults (avoids private _actions attribute)
|
|
120
|
+
defaults = {}
|
|
121
|
+
for action in parser._option_string_actions.values():
|
|
122
|
+
if action.dest != "help":
|
|
123
|
+
defaults[action.dest] = parser.get_default(action.dest)
|
|
126
124
|
else:
|
|
127
125
|
# Fallback: reconstruct defaults from known patterns
|
|
128
126
|
# This works because argparse stores actual values, and we compare
|
|
@@ -233,7 +231,7 @@ def validate_config(
|
|
|
233
231
|
|
|
234
232
|
# Validate numeric ranges
|
|
235
233
|
numeric_checks = {
|
|
236
|
-
"lr": (0,
|
|
234
|
+
"lr": (0, 10, "Learning rate should be between 0 and 10"),
|
|
237
235
|
"epochs": (1, 100000, "Epochs should be positive"),
|
|
238
236
|
"batch_size": (1, 10000, "Batch size should be positive"),
|
|
239
237
|
"patience": (1, 1000, "Patience should be positive"),
|