wavedl 1.4.6__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +9 -1
- wavedl/models/vit.py +21 -0
- wavedl/test.py +28 -5
- wavedl/train.py +122 -15
- wavedl/utils/__init__.py +11 -0
- wavedl/utils/config.py +10 -0
- wavedl/utils/constraints.py +470 -0
- wavedl/utils/cross_validation.py +12 -2
- wavedl/utils/data.py +26 -7
- wavedl/utils/metrics.py +49 -2
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/METADATA +122 -19
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/RECORD +17 -16
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/LICENSE +0 -0
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/WHEEL +0 -0
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpo.py
CHANGED
@@ -89,7 +89,8 @@ def create_objective(args):
         # Suggest hyperparameters
         model = trial.suggest_categorical("model", models)
         lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
-
+        batch_sizes = args.batch_sizes or [16, 32, 64, 128]
+        batch_size = trial.suggest_categorical("batch_size", batch_sizes)
         optimizer = trial.suggest_categorical("optimizer", optimizers)
         scheduler = trial.suggest_categorical("scheduler", schedulers)
         loss = trial.suggest_categorical("loss", losses)
@@ -317,6 +318,13 @@ Examples:
         default=None,
         help=f"Losses to search (default: {DEFAULT_LOSSES})",
     )
+    parser.add_argument(
+        "--batch_sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Batch sizes to search (default: 16 32 64 128)",
+    )

     # Training settings for each trial
     parser.add_argument(
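
The net effect is that batch size joins the HPO search space instead of being fixed per study. A minimal, self-contained sketch of the same Optuna pattern (hypothetical objective, not wavedl's actual trial runner):

import optuna

def objective(trial):
    # Same categorical / log-uniform suggestions the new wavedl search space uses
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64, 128])
    # A real objective would train a model with (lr, batch_size) and return its
    # validation loss; this placeholder only lets the sketch run end to end.
    return lr * batch_size

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)
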
wavedl/models/vit.py
CHANGED
@@ -54,6 +54,16 @@ class PatchEmbed(nn.Module):
         if self.dim == 1:
             # 1D: segment patches
             L = in_shape[0]
+            if L % patch_size != 0:
+                import warnings
+
+                warnings.warn(
+                    f"Input length {L} not divisible by patch_size {patch_size}. "
+                    f"Last {L % patch_size} elements will be dropped. "
+                    f"Consider padding input to {((L // patch_size) + 1) * patch_size}.",
+                    UserWarning,
+                    stacklevel=2,
+                )
             self.num_patches = L // patch_size
             self.proj = nn.Conv1d(
                 1, embed_dim, kernel_size=patch_size, stride=patch_size
@@ -61,6 +71,17 @@ class PatchEmbed(nn.Module):
         elif self.dim == 2:
             # 2D: grid patches
             H, W = in_shape
+            if H % patch_size != 0 or W % patch_size != 0:
+                import warnings
+
+                warnings.warn(
+                    f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
+                    f"Border pixels will be dropped (H: {H % patch_size}, W: {W % patch_size}). "
+                    f"Consider padding to ({((H // patch_size) + 1) * patch_size}, "
+                    f"{((W // patch_size) + 1) * patch_size}).",
+                    UserWarning,
+                    stacklevel=2,
+                )
             self.num_patches = (H // patch_size) * (W // patch_size)
             self.proj = nn.Conv2d(
                 1, embed_dim, kernel_size=patch_size, stride=patch_size
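
The new warning only reports the truncation; callers still have to pad their inputs themselves. One way to do that, as a sketch (the pad_to_multiple helper below is illustrative, not part of wavedl):

import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Right-/bottom-pad a (B, C, L) or (B, C, H, W) tensor so every spatial
    dimension is divisible by patch_size, avoiding the PatchEmbed truncation warning."""
    pad = []
    for size in reversed(x.shape[2:]):          # F.pad expects the last dim first
        remainder = size % patch_size
        pad.extend([0, (patch_size - remainder) % patch_size])
    return F.pad(x, pad)

signal = torch.randn(8, 1, 1000)                # 1000 % 16 == 8
print(pad_to_multiple(signal, 16).shape)        # torch.Size([8, 1, 1008])
image = torch.randn(4, 1, 250, 250)             # 250 % 16 == 10
print(pad_to_multiple(image, 16).shape)         # torch.Size([4, 1, 256, 256])
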
wavedl/test.py
CHANGED
@@ -166,6 +166,13 @@ def parse_args() -> argparse.Namespace:
         default=None,
         help="Parameter names for output (e.g., 'h' 'v11' 'v12')",
     )
+    parser.add_argument(
+        "--input_channels",
+        type=int,
+        default=None,
+        help="Explicit number of input channels. Bypasses auto-detection heuristics "
+        "for ambiguous 4D shapes (e.g., 3D volumes with small depth).",
+    )

     # Inference options
     parser.add_argument(
@@ -235,6 +242,7 @@ def load_data_for_inference(
     format: str = "auto",
     input_key: str | None = None,
     output_key: str | None = None,
+    input_channels: int | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Load test data for inference using the unified data loading pipeline.
@@ -278,7 +286,11 @@

     # Use the unified loader from utils.data
     X, y = load_test_data(
-        file_path,
+        file_path,
+        format=format,
+        input_key=input_key,
+        output_key=output_key,
+        input_channels=input_channels,
     )

     # Log results
@@ -452,7 +464,12 @@ def run_inference(
         predictions: Numpy array (N, out_size) - still in normalized space
     """
     if device is None:
-
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")

     model = model.to(device)
     model.eval()
@@ -463,7 +480,7 @@
         batch_size=batch_size,
         shuffle=False,
         num_workers=num_workers,
-        pin_memory=device.type
+        pin_memory=device.type in ("cuda", "mps"),
     )

     predictions = []
@@ -919,8 +936,13 @@ def main():
     )
     logger = logging.getLogger("Tester")

-    # Device
-
+    # Device (CUDA > MPS > CPU)
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
     logger.info(f"Using device: {device}")

     # Load test data
@@ -929,6 +951,7 @@
         format=args.format,
         input_key=args.input_key,
         output_key=args.output_key,
+        input_channels=args.input_channels,
     )
     in_shape = tuple(X_test.shape[2:])
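
The same CUDA > MPS > CPU fallback now appears in both run_inference() and main(). As a standalone sketch of that selection logic (the resolve_device helper is my own framing; wavedl inlines the branches):

import torch

def resolve_device(device: torch.device | None = None) -> torch.device:
    """Prefer CUDA, then Apple Silicon MPS, then CPU."""
    if device is not None:
        return device
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = resolve_device()
pin_memory = device.type in ("cuda", "mps")   # mirrors the pin_memory change above
print(device, pin_memory)
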
wavedl/train.py
CHANGED
@@ -375,6 +375,36 @@ def parse_args() -> argparse.Namespace:
         help=argparse.SUPPRESS,  # Hidden: use --precision instead
     )

+    # Physical Constraints
+    parser.add_argument(
+        "--constraint",
+        type=str,
+        nargs="+",
+        default=[],
+        help="Soft constraint expressions: 'y0 - y1*y2' (penalize violations)",
+    )
+
+    parser.add_argument(
+        "--constraint_file",
+        type=str,
+        default=None,
+        help="Python file with constraint(pred, inputs) function",
+    )
+    parser.add_argument(
+        "--constraint_weight",
+        type=float,
+        nargs="+",
+        default=[0.1],
+        help="Weight(s) for soft constraints (one per constraint, or single shared weight)",
+    )
+    parser.add_argument(
+        "--constraint_reduction",
+        type=str,
+        default="mse",
+        choices=["mse", "mae"],
+        help="Reduction mode for constraint penalties",
+    )
+
     # Logging
     parser.add_argument(
         "--wandb", action="store_true", help="Enable Weights & Biases logging"
@@ -553,7 +583,7 @@ def main():
         return

     # ==========================================================================
-    #
+    # SYSTEM INITIALIZATION
     # ==========================================================================
     # Initialize Accelerator for DDP and mixed precision
     accelerator = Accelerator(
@@ -609,7 +639,7 @@
     )

     # ==========================================================================
-    #
+    # DATA & MODEL LOADING
     # ==========================================================================
     train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
         args, logger, accelerator, cache_dir=args.output_dir
@@ -663,7 +693,7 @@
     )

     # ==========================================================================
-    #
+    # OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
     # ==========================================================================
     # Parse comma-separated arguments with validation
     try:
@@ -707,6 +737,43 @@
     # Move criterion to device (important for WeightedMSELoss buffer)
     criterion = criterion.to(accelerator.device)

+    # ==========================================================================
+    # PHYSICAL CONSTRAINTS INTEGRATION
+    # ==========================================================================
+    from wavedl.utils.constraints import (
+        PhysicsConstrainedLoss,
+        build_constraints,
+    )
+
+    # Build soft constraints
+    soft_constraints = build_constraints(
+        expressions=args.constraint,
+        file_path=args.constraint_file,
+        reduction=args.constraint_reduction,
+    )
+
+    # Wrap criterion with PhysicsConstrainedLoss if we have soft constraints
+    if soft_constraints:
+        # Pass output scaler so constraints can be evaluated in physical space
+        output_mean = scaler.mean_ if hasattr(scaler, "mean_") else None
+        output_std = scaler.scale_ if hasattr(scaler, "scale_") else None
+        criterion = PhysicsConstrainedLoss(
+            criterion,
+            soft_constraints,
+            weights=args.constraint_weight,
+            output_mean=output_mean,
+            output_std=output_std,
+        )
+        if accelerator.is_main_process:
+            logger.info(
+                f" 🔬 Physical constraints: {len(soft_constraints)} constraint(s) "
+                f"with weight(s) {args.constraint_weight}"
+            )
+            if output_mean is not None:
+                logger.info(
+                    " 📐 Constraints evaluated in physical space (denormalized)"
+                )
+
     # Track if scheduler should step per batch (OneCycleLR) or per epoch
     scheduler_step_per_batch = not is_epoch_based(args.scheduler)

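For readers who want to see what a soft-constraint penalty of this kind looks like in isolation, here is a minimal, self-contained sketch. It is not wavedl's PhysicsConstrainedLoss (whose actual internals live in the new wavedl/utils/constraints.py); it only illustrates the idea of adding a weighted residual penalty for a --constraint style expression such as 'y0 - y1*y2':

import torch
import torch.nn as nn

class SoftConstraintLoss(nn.Module):
    """Illustrative wrapper: base loss plus weighted penalties on constraint residuals."""

    def __init__(self, base_loss, constraints, weights, reduction="mse"):
        super().__init__()
        self.base_loss = base_loss
        self.constraints = constraints    # callables: (pred, inputs) -> residual tensor
        self.weights = weights
        self.reduction = reduction

    def forward(self, pred, target, inputs=None):
        loss = self.base_loss(pred, target)
        for g, w in zip(self.constraints, self.weights):
            residual = g(pred, inputs)
            penalty = residual.pow(2).mean() if self.reduction == "mse" else residual.abs().mean()
            loss = loss + w * penalty
        return loss

# Penalize violations of y0 ≈ y1 * y2, mirroring --constraint 'y0 - y1*y2'
constraint = lambda pred, inputs: pred[:, 0] - pred[:, 1] * pred[:, 2]
criterion = SoftConstraintLoss(nn.MSELoss(), [constraint], weights=[0.1])

pred = torch.randn(4, 3, requires_grad=True)
target = torch.randn(4, 3)
loss = criterion(pred, target)
loss.backward()
print(loss.item())
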
@@ -762,7 +829,7 @@
     )

     # ==========================================================================
-    #
+    # AUTO-RESUME / RESUME FROM CHECKPOINT
     # ==========================================================================
     start_epoch = 0
     best_val_loss = float("inf")
@@ -818,7 +885,7 @@
         raise FileNotFoundError(f"Checkpoint not found: {args.resume}")

     # ==========================================================================
-    #
+    # PHYSICAL METRIC SETUP
     # ==========================================================================
     # Physical MAE = normalized MAE * scaler.scale_
     phys_scale = torch.tensor(
@@ -826,7 +893,7 @@
     )

     # ==========================================================================
-    #
+    # TRAINING LOOP
     # ==========================================================================
     # Dynamic console header
     if accelerator.is_main_process:
@@ -864,7 +931,11 @@
         for x, y in pbar:
             with accelerator.accumulate(model):
                 pred = model(x)
-                loss = criterion(pred, y)
+                # Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
+                if isinstance(criterion, PhysicsConstrainedLoss):
+                    loss = criterion(pred, y, x)
+                else:
+                    loss = criterion(pred, y)

                 accelerator.backward(loss)

@@ -914,7 +985,11 @@
         with torch.inference_mode():
             for x, y in val_dl:
                 pred = model(x)
-                loss = criterion(pred, y)
+                # Pass inputs for input-dependent constraints
+                if isinstance(criterion, PhysicsConstrainedLoss):
+                    loss = criterion(pred, y, x)
+                else:
+                    loss = criterion(pred, y)

                 val_loss_sum += loss.detach() * x.size(0)
                 val_samples += x.size(0)
@@ -931,13 +1006,45 @@
         cpu_preds = torch.cat(local_preds)
         cpu_targets = torch.cat(local_targets)

-        # Gather predictions and targets
-        #
-
-
-
-
-
+        # Gather predictions and targets to rank 0 only (memory-efficient)
+        # Avoids duplicating full validation set on every GPU
+        if torch.distributed.is_initialized():
+            # DDP mode: gather only to rank 0
+            # NCCL backend requires CUDA tensors for collective ops
+            gpu_preds = cpu_preds.to(accelerator.device)
+            gpu_targets = cpu_targets.to(accelerator.device)
+
+            if accelerator.is_main_process:
+                # Rank 0: allocate gather buffers on GPU
+                all_preds_list = [
+                    torch.zeros_like(gpu_preds)
+                    for _ in range(accelerator.num_processes)
+                ]
+                all_targets_list = [
+                    torch.zeros_like(gpu_targets)
+                    for _ in range(accelerator.num_processes)
+                ]
+                torch.distributed.gather(
+                    gpu_preds, gather_list=all_preds_list, dst=0
+                )
+                torch.distributed.gather(
+                    gpu_targets, gather_list=all_targets_list, dst=0
+                )
+                # Move back to CPU for metric computation
+                gathered = [
+                    (
+                        torch.cat(all_preds_list).cpu(),
+                        torch.cat(all_targets_list).cpu(),
+                    )
+                ]
+            else:
+                # Other ranks: send to rank 0, don't allocate gather buffers
+                torch.distributed.gather(gpu_preds, gather_list=None, dst=0)
+                torch.distributed.gather(gpu_targets, gather_list=None, dst=0)
+                gathered = [(cpu_preds, cpu_targets)]  # Placeholder, not used
+        else:
+            # Single-GPU mode: no gathering needed
+            gathered = [(cpu_preds, cpu_targets)]

         # Synchronize validation metrics (scalars only - efficient)
         val_loss_scalar = val_loss_sum.item()
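
The replacement collects validation outputs on rank 0 only instead of all-gathering onto every process. A condensed sketch of that collective pattern (the helper name and standalone structure are mine; wavedl inlines this in the training loop), meant to run under torchrun with a process group already initialized:

import torch
import torch.distributed as dist

def gather_to_rank0(local: torch.Tensor) -> torch.Tensor | None:
    """Collect equally-shaped per-rank tensors on rank 0 only; other ranks get None."""
    if not dist.is_initialized():
        return local                     # single-process run: nothing to gather
    if dist.get_rank() == 0:
        buffers = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
        dist.gather(local, gather_list=buffers, dst=0)
        return torch.cat(buffers)        # only rank 0 pays the full-dataset memory cost
    dist.gather(local, gather_list=None, dst=0)
    return None

Compared with dist.all_gather, only the destination rank allocates world_size buffers, which is the memory saving the comment in the diff refers to.
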
wavedl/utils/__init__.py
CHANGED
@@ -15,6 +15,12 @@ from .config import (
     save_config,
     validate_config,
 )
+from .constraints import (
+    ExpressionConstraint,
+    FileConstraint,
+    PhysicsConstrainedLoss,
+    build_constraints,
+)
 from .cross_validation import (
     CVDataset,
     run_cross_validation,
@@ -91,8 +97,11 @@ __all__ = [
     "FIGURE_WIDTH_INCH",
     "FONT_SIZE_TEXT",
     "FONT_SIZE_TICKS",
+    # Constraints
     "CVDataset",
     "DataSource",
+    "ExpressionConstraint",
+    "FileConstraint",
     "HDF5Source",
     "LogCoshLoss",
     "MATSource",
@@ -101,10 +110,12 @@ __all__ = [
     # Metrics
     "MetricTracker",
     "NPZSource",
+    "PhysicsConstrainedLoss",
     "WeightedMSELoss",
     # Distributed
     "broadcast_early_stop",
     "broadcast_value",
+    "build_constraints",
     "calc_pearson",
     "calc_per_target_r2",
     "configure_matplotlib_style",
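
These re-exports make the constraint utilities available directly from wavedl.utils. A usage sketch that mirrors the call site in the train.py hunk above (the exact behavior of these objects is defined in the new wavedl/utils/constraints.py):

import torch.nn as nn
from wavedl.utils import PhysicsConstrainedLoss, build_constraints

constraints = build_constraints(
    expressions=["y0 - y1*y2"],
    file_path=None,
    reduction="mse",
)
if constraints:
    criterion = PhysicsConstrainedLoss(
        nn.MSELoss(),
        constraints,
        weights=[0.1],
        output_mean=None,
        output_std=None,
    )
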
wavedl/utils/config.py
CHANGED
@@ -306,6 +306,16 @@ def validate_config(
         # Config
         "config",
         "list_models",
+        # Physical Constraints
+        "constraint",
+        "bounds",
+        "constraint_file",
+        "constraint_weight",
+        "constraint_reduction",
+        "positive",
+        "output_bounds",
+        "output_transform",
+        "output_formula",
         # Metadata (internal)
         "_metadata",
     }