wavedl 1.4.6__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.4.6"
21
+ __version__ = "1.5.1"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
wavedl/hpo.py CHANGED
@@ -89,7 +89,8 @@ def create_objective(args):
89
89
  # Suggest hyperparameters
90
90
  model = trial.suggest_categorical("model", models)
91
91
  lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
92
- batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
92
+ batch_sizes = args.batch_sizes or [16, 32, 64, 128]
93
+ batch_size = trial.suggest_categorical("batch_size", batch_sizes)
93
94
  optimizer = trial.suggest_categorical("optimizer", optimizers)
94
95
  scheduler = trial.suggest_categorical("scheduler", schedulers)
95
96
  loss = trial.suggest_categorical("loss", losses)
@@ -317,6 +318,13 @@ Examples:
317
318
  default=None,
318
319
  help=f"Losses to search (default: {DEFAULT_LOSSES})",
319
320
  )
321
+ parser.add_argument(
322
+ "--batch_sizes",
323
+ type=int,
324
+ nargs="+",
325
+ default=None,
326
+ help="Batch sizes to search (default: 16 32 64 128)",
327
+ )
320
328
 
321
329
  # Training settings for each trial
322
330
  parser.add_argument(
wavedl/models/vit.py CHANGED
@@ -54,6 +54,16 @@ class PatchEmbed(nn.Module):
54
54
  if self.dim == 1:
55
55
  # 1D: segment patches
56
56
  L = in_shape[0]
57
+ if L % patch_size != 0:
58
+ import warnings
59
+
60
+ warnings.warn(
61
+ f"Input length {L} not divisible by patch_size {patch_size}. "
62
+ f"Last {L % patch_size} elements will be dropped. "
63
+ f"Consider padding input to {((L // patch_size) + 1) * patch_size}.",
64
+ UserWarning,
65
+ stacklevel=2,
66
+ )
57
67
  self.num_patches = L // patch_size
58
68
  self.proj = nn.Conv1d(
59
69
  1, embed_dim, kernel_size=patch_size, stride=patch_size
@@ -61,6 +71,17 @@ class PatchEmbed(nn.Module):
61
71
  elif self.dim == 2:
62
72
  # 2D: grid patches
63
73
  H, W = in_shape
74
+ if H % patch_size != 0 or W % patch_size != 0:
75
+ import warnings
76
+
77
+ warnings.warn(
78
+ f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
79
+ f"Border pixels will be dropped (H: {H % patch_size}, W: {W % patch_size}). "
80
+ f"Consider padding to ({((H // patch_size) + 1) * patch_size}, "
81
+ f"{((W // patch_size) + 1) * patch_size}).",
82
+ UserWarning,
83
+ stacklevel=2,
84
+ )
64
85
  self.num_patches = (H // patch_size) * (W // patch_size)
65
86
  self.proj = nn.Conv2d(
66
87
  1, embed_dim, kernel_size=patch_size, stride=patch_size
wavedl/test.py CHANGED
@@ -166,6 +166,13 @@ def parse_args() -> argparse.Namespace:
166
166
  default=None,
167
167
  help="Parameter names for output (e.g., 'h' 'v11' 'v12')",
168
168
  )
169
+ parser.add_argument(
170
+ "--input_channels",
171
+ type=int,
172
+ default=None,
173
+ help="Explicit number of input channels. Bypasses auto-detection heuristics "
174
+ "for ambiguous 4D shapes (e.g., 3D volumes with small depth).",
175
+ )
169
176
 
170
177
  # Inference options
171
178
  parser.add_argument(
@@ -235,6 +242,7 @@ def load_data_for_inference(
235
242
  format: str = "auto",
236
243
  input_key: str | None = None,
237
244
  output_key: str | None = None,
245
+ input_channels: int | None = None,
238
246
  ) -> tuple[torch.Tensor, torch.Tensor | None]:
239
247
  """
240
248
  Load test data for inference using the unified data loading pipeline.
@@ -278,7 +286,11 @@ def load_data_for_inference(
278
286
 
279
287
  # Use the unified loader from utils.data
280
288
  X, y = load_test_data(
281
- file_path, format=format, input_key=input_key, output_key=output_key
289
+ file_path,
290
+ format=format,
291
+ input_key=input_key,
292
+ output_key=output_key,
293
+ input_channels=input_channels,
282
294
  )
283
295
 
284
296
  # Log results
@@ -452,7 +464,12 @@ def run_inference(
452
464
  predictions: Numpy array (N, out_size) - still in normalized space
453
465
  """
454
466
  if device is None:
455
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
467
+ if torch.cuda.is_available():
468
+ device = torch.device("cuda")
469
+ elif torch.backends.mps.is_available():
470
+ device = torch.device("mps")
471
+ else:
472
+ device = torch.device("cpu")
456
473
 
457
474
  model = model.to(device)
458
475
  model.eval()
@@ -463,7 +480,7 @@ def run_inference(
463
480
  batch_size=batch_size,
464
481
  shuffle=False,
465
482
  num_workers=num_workers,
466
- pin_memory=device.type == "cuda",
483
+ pin_memory=device.type in ("cuda", "mps"),
467
484
  )
468
485
 
469
486
  predictions = []
@@ -919,8 +936,13 @@ def main():
919
936
  )
920
937
  logger = logging.getLogger("Tester")
921
938
 
922
- # Device
923
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
939
+ # Device (CUDA > MPS > CPU)
940
+ if torch.cuda.is_available():
941
+ device = torch.device("cuda")
942
+ elif torch.backends.mps.is_available():
943
+ device = torch.device("mps")
944
+ else:
945
+ device = torch.device("cpu")
924
946
  logger.info(f"Using device: {device}")
925
947
 
926
948
  # Load test data
@@ -929,6 +951,7 @@ def main():
929
951
  format=args.format,
930
952
  input_key=args.input_key,
931
953
  output_key=args.output_key,
954
+ input_channels=args.input_channels,
932
955
  )
933
956
  in_shape = tuple(X_test.shape[2:])
934
957
 
wavedl/train.py CHANGED
@@ -375,6 +375,36 @@ def parse_args() -> argparse.Namespace:
375
375
  help=argparse.SUPPRESS, # Hidden: use --precision instead
376
376
  )
377
377
 
378
+ # Physical Constraints
379
+ parser.add_argument(
380
+ "--constraint",
381
+ type=str,
382
+ nargs="+",
383
+ default=[],
384
+ help="Soft constraint expressions: 'y0 - y1*y2' (penalize violations)",
385
+ )
386
+
387
+ parser.add_argument(
388
+ "--constraint_file",
389
+ type=str,
390
+ default=None,
391
+ help="Python file with constraint(pred, inputs) function",
392
+ )
393
+ parser.add_argument(
394
+ "--constraint_weight",
395
+ type=float,
396
+ nargs="+",
397
+ default=[0.1],
398
+ help="Weight(s) for soft constraints (one per constraint, or single shared weight)",
399
+ )
400
+ parser.add_argument(
401
+ "--constraint_reduction",
402
+ type=str,
403
+ default="mse",
404
+ choices=["mse", "mae"],
405
+ help="Reduction mode for constraint penalties",
406
+ )
407
+
378
408
  # Logging
379
409
  parser.add_argument(
380
410
  "--wandb", action="store_true", help="Enable Weights & Biases logging"
@@ -553,7 +583,7 @@ def main():
553
583
  return
554
584
 
555
585
  # ==========================================================================
556
- # 1. SYSTEM INITIALIZATION
586
+ # SYSTEM INITIALIZATION
557
587
  # ==========================================================================
558
588
  # Initialize Accelerator for DDP and mixed precision
559
589
  accelerator = Accelerator(
@@ -609,7 +639,7 @@ def main():
609
639
  )
610
640
 
611
641
  # ==========================================================================
612
- # 2. DATA & MODEL LOADING
642
+ # DATA & MODEL LOADING
613
643
  # ==========================================================================
614
644
  train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
615
645
  args, logger, accelerator, cache_dir=args.output_dir
@@ -663,7 +693,7 @@ def main():
663
693
  )
664
694
 
665
695
  # ==========================================================================
666
- # 2.5. OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
696
+ # OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
667
697
  # ==========================================================================
668
698
  # Parse comma-separated arguments with validation
669
699
  try:
@@ -707,6 +737,43 @@ def main():
707
737
  # Move criterion to device (important for WeightedMSELoss buffer)
708
738
  criterion = criterion.to(accelerator.device)
709
739
 
740
+ # ==========================================================================
741
+ # PHYSICAL CONSTRAINTS INTEGRATION
742
+ # ==========================================================================
743
+ from wavedl.utils.constraints import (
744
+ PhysicsConstrainedLoss,
745
+ build_constraints,
746
+ )
747
+
748
+ # Build soft constraints
749
+ soft_constraints = build_constraints(
750
+ expressions=args.constraint,
751
+ file_path=args.constraint_file,
752
+ reduction=args.constraint_reduction,
753
+ )
754
+
755
+ # Wrap criterion with PhysicsConstrainedLoss if we have soft constraints
756
+ if soft_constraints:
757
+ # Pass output scaler so constraints can be evaluated in physical space
758
+ output_mean = scaler.mean_ if hasattr(scaler, "mean_") else None
759
+ output_std = scaler.scale_ if hasattr(scaler, "scale_") else None
760
+ criterion = PhysicsConstrainedLoss(
761
+ criterion,
762
+ soft_constraints,
763
+ weights=args.constraint_weight,
764
+ output_mean=output_mean,
765
+ output_std=output_std,
766
+ )
767
+ if accelerator.is_main_process:
768
+ logger.info(
769
+ f" 🔬 Physical constraints: {len(soft_constraints)} constraint(s) "
770
+ f"with weight(s) {args.constraint_weight}"
771
+ )
772
+ if output_mean is not None:
773
+ logger.info(
774
+ " 📐 Constraints evaluated in physical space (denormalized)"
775
+ )
776
+
710
777
  # Track if scheduler should step per batch (OneCycleLR) or per epoch
711
778
  scheduler_step_per_batch = not is_epoch_based(args.scheduler)
712
779
 
@@ -762,7 +829,7 @@ def main():
762
829
  )
763
830
 
764
831
  # ==========================================================================
765
- # 3. AUTO-RESUME / RESUME FROM CHECKPOINT
832
+ # AUTO-RESUME / RESUME FROM CHECKPOINT
766
833
  # ==========================================================================
767
834
  start_epoch = 0
768
835
  best_val_loss = float("inf")
@@ -818,7 +885,7 @@ def main():
818
885
  raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
819
886
 
820
887
  # ==========================================================================
821
- # 4. PHYSICAL METRIC SETUP
888
+ # PHYSICAL METRIC SETUP
822
889
  # ==========================================================================
823
890
  # Physical MAE = normalized MAE * scaler.scale_
824
891
  phys_scale = torch.tensor(
@@ -826,7 +893,7 @@ def main():
826
893
  )
827
894
 
828
895
  # ==========================================================================
829
- # 5. TRAINING LOOP
896
+ # TRAINING LOOP
830
897
  # ==========================================================================
831
898
  # Dynamic console header
832
899
  if accelerator.is_main_process:
@@ -864,7 +931,11 @@ def main():
864
931
  for x, y in pbar:
865
932
  with accelerator.accumulate(model):
866
933
  pred = model(x)
867
- loss = criterion(pred, y)
934
+ # Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
935
+ if isinstance(criterion, PhysicsConstrainedLoss):
936
+ loss = criterion(pred, y, x)
937
+ else:
938
+ loss = criterion(pred, y)
868
939
 
869
940
  accelerator.backward(loss)
870
941
 
@@ -914,7 +985,11 @@ def main():
914
985
  with torch.inference_mode():
915
986
  for x, y in val_dl:
916
987
  pred = model(x)
917
- loss = criterion(pred, y)
988
+ # Pass inputs for input-dependent constraints
989
+ if isinstance(criterion, PhysicsConstrainedLoss):
990
+ loss = criterion(pred, y, x)
991
+ else:
992
+ loss = criterion(pred, y)
918
993
 
919
994
  val_loss_sum += loss.detach() * x.size(0)
920
995
  val_samples += x.size(0)
@@ -931,13 +1006,45 @@ def main():
931
1006
  cpu_preds = torch.cat(local_preds)
932
1007
  cpu_targets = torch.cat(local_targets)
933
1008
 
934
- # Gather predictions and targets across all ranks
935
- # Use accelerator.gather (works with all accelerate versions)
936
- gpu_preds = cpu_preds.to(accelerator.device)
937
- gpu_targets = cpu_targets.to(accelerator.device)
938
- all_preds_gathered = accelerator.gather(gpu_preds).cpu()
939
- all_targets_gathered = accelerator.gather(gpu_targets).cpu()
940
- gathered = [(all_preds_gathered, all_targets_gathered)]
1009
+ # Gather predictions and targets to rank 0 only (memory-efficient)
1010
+ # Avoids duplicating full validation set on every GPU
1011
+ if torch.distributed.is_initialized():
1012
+ # DDP mode: gather only to rank 0
1013
+ # NCCL backend requires CUDA tensors for collective ops
1014
+ gpu_preds = cpu_preds.to(accelerator.device)
1015
+ gpu_targets = cpu_targets.to(accelerator.device)
1016
+
1017
+ if accelerator.is_main_process:
1018
+ # Rank 0: allocate gather buffers on GPU
1019
+ all_preds_list = [
1020
+ torch.zeros_like(gpu_preds)
1021
+ for _ in range(accelerator.num_processes)
1022
+ ]
1023
+ all_targets_list = [
1024
+ torch.zeros_like(gpu_targets)
1025
+ for _ in range(accelerator.num_processes)
1026
+ ]
1027
+ torch.distributed.gather(
1028
+ gpu_preds, gather_list=all_preds_list, dst=0
1029
+ )
1030
+ torch.distributed.gather(
1031
+ gpu_targets, gather_list=all_targets_list, dst=0
1032
+ )
1033
+ # Move back to CPU for metric computation
1034
+ gathered = [
1035
+ (
1036
+ torch.cat(all_preds_list).cpu(),
1037
+ torch.cat(all_targets_list).cpu(),
1038
+ )
1039
+ ]
1040
+ else:
1041
+ # Other ranks: send to rank 0, don't allocate gather buffers
1042
+ torch.distributed.gather(gpu_preds, gather_list=None, dst=0)
1043
+ torch.distributed.gather(gpu_targets, gather_list=None, dst=0)
1044
+ gathered = [(cpu_preds, cpu_targets)] # Placeholder, not used
1045
+ else:
1046
+ # Single-GPU mode: no gathering needed
1047
+ gathered = [(cpu_preds, cpu_targets)]
941
1048
 
942
1049
  # Synchronize validation metrics (scalars only - efficient)
943
1050
  val_loss_scalar = val_loss_sum.item()
wavedl/utils/__init__.py CHANGED
@@ -15,6 +15,12 @@ from .config import (
15
15
  save_config,
16
16
  validate_config,
17
17
  )
18
+ from .constraints import (
19
+ ExpressionConstraint,
20
+ FileConstraint,
21
+ PhysicsConstrainedLoss,
22
+ build_constraints,
23
+ )
18
24
  from .cross_validation import (
19
25
  CVDataset,
20
26
  run_cross_validation,
@@ -91,8 +97,11 @@ __all__ = [
91
97
  "FIGURE_WIDTH_INCH",
92
98
  "FONT_SIZE_TEXT",
93
99
  "FONT_SIZE_TICKS",
100
+ # Constraints
94
101
  "CVDataset",
95
102
  "DataSource",
103
+ "ExpressionConstraint",
104
+ "FileConstraint",
96
105
  "HDF5Source",
97
106
  "LogCoshLoss",
98
107
  "MATSource",
@@ -101,10 +110,12 @@ __all__ = [
101
110
  # Metrics
102
111
  "MetricTracker",
103
112
  "NPZSource",
113
+ "PhysicsConstrainedLoss",
104
114
  "WeightedMSELoss",
105
115
  # Distributed
106
116
  "broadcast_early_stop",
107
117
  "broadcast_value",
118
+ "build_constraints",
108
119
  "calc_pearson",
109
120
  "calc_per_target_r2",
110
121
  "configure_matplotlib_style",
wavedl/utils/config.py CHANGED
@@ -306,6 +306,16 @@ def validate_config(
306
306
  # Config
307
307
  "config",
308
308
  "list_models",
309
+ # Physical Constraints
310
+ "constraint",
311
+ "bounds",
312
+ "constraint_file",
313
+ "constraint_weight",
314
+ "constraint_reduction",
315
+ "positive",
316
+ "output_bounds",
317
+ "output_transform",
318
+ "output_formula",
309
319
  # Metadata (internal)
310
320
  "_metadata",
311
321
  }