wavedl-1.4.5-py3-none-any.whl → wavedl-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
     # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
 """
 
-__version__ = "1.4.5"
+__version__ = "1.5.0"
 __author__ = "Ductho Le"
 __email__ = "ductho.le@outlook.com"
 
wavedl/hpc.py CHANGED
@@ -174,7 +174,9 @@ Environment Variables:
     return args, remaining
 
 
-def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
+def print_summary(
+    exit_code: int, wandb_enabled: bool, wandb_mode: str, wandb_dir: str
+) -> None:
     """Print post-training summary and instructions."""
     print()
     print("=" * 40)
@@ -183,7 +185,8 @@ def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
         print("✅ Training completed successfully!")
         print("=" * 40)
 
-    if wandb_mode == "offline":
+    # Only show WandB sync instructions if user enabled wandb
+    if wandb_enabled and wandb_mode == "offline":
         print()
         print("📊 WandB Sync Instructions:")
         print("   From the login node, run:")
@@ -237,6 +240,10 @@ def main() -> int:
         f"--dynamo_backend={args.dynamo_backend}",
     ]
 
+    # Explicitly set multi_gpu to suppress accelerate auto-detection warning
+    if num_gpus > 1:
+        cmd.append("--multi_gpu")
+
     # Add multi-node networking args if specified (required for some clusters)
     if args.main_process_ip:
        cmd.append(f"--main_process_ip={args.main_process_ip}")
@@ -263,8 +270,10 @@ def main() -> int:
         exit_code = 130
 
     # Print summary
+    wandb_enabled = "--wandb" in train_args
     print_summary(
         exit_code,
+        wandb_enabled,
         os.environ.get("WANDB_MODE", "offline"),
         os.environ.get("WANDB_DIR", "/tmp/wandb"),
     )
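The gating change above is subtle in diff form: sync instructions used to print whenever `WANDB_MODE` resolved to `"offline"`, even for runs that never enabled logging. A minimal standalone sketch of the new behavior (the argument list is assumed for illustration):

```python
# Sketch only (not part of wavedl): sync instructions are now gated on the
# user actually passing --wandb, not just on the offline-mode default.
train_args = ["--data_path", "train.npz", "--model", "cnn"]  # assumed example
wandb_enabled = "--wandb" in train_args  # False for this run
wandb_mode = "offline"                   # the HPC default
if wandb_enabled and wandb_mode == "offline":
    print("📊 WandB Sync Instructions: ...")  # skipped: --wandb was not passed
```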
wavedl/hpo.py CHANGED
@@ -31,7 +31,7 @@ try:
     import optuna
     from optuna.trial import TrialState
 except ImportError:
-    print("Error: Optuna not installed. Run: pip install -e '.[hpo]'")
+    print("Error: Optuna not installed. Run: pip install wavedl")
     sys.exit(1)
 
 
@@ -89,7 +89,8 @@ def create_objective(args):
     # Suggest hyperparameters
     model = trial.suggest_categorical("model", models)
     lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
-    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
+    batch_sizes = args.batch_sizes or [16, 32, 64, 128]
+    batch_size = trial.suggest_categorical("batch_size", batch_sizes)
     optimizer = trial.suggest_categorical("optimizer", optimizers)
     scheduler = trial.suggest_categorical("scheduler", schedulers)
     loss = trial.suggest_categorical("loss", losses)
@@ -147,6 +148,32 @@ def create_objective(args):
         cmd.extend(["--output_dir", tmpdir])
         history_file = Path(tmpdir) / "training_history.csv"
 
+        # GPU isolation for parallel trials: assign each trial to a specific GPU
+        # This prevents multiple trials from competing for all GPUs
+        env = None
+        if args.n_jobs > 1:
+            import os
+
+            # Detect available GPUs
+            n_gpus = 1
+            try:
+                import subprocess as sp
+
+                result_gpu = sp.run(
+                    ["nvidia-smi", "--list-gpus"],
+                    capture_output=True,
+                    text=True,
+                )
+                if result_gpu.returncode == 0:
+                    n_gpus = len(result_gpu.stdout.strip().split("\n"))
+            except Exception:
+                pass
+
+            # Assign trial to a specific GPU (round-robin)
+            gpu_id = trial.number % n_gpus
+            env = os.environ.copy()
+            env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+
         # Run training
         try:
             result = subprocess.run(
@@ -155,6 +182,7 @@ def create_objective(args):
                 text=True,
                 timeout=args.timeout,
                 cwd=Path(__file__).parent,
+                env=env,
             )
 
             # Read best val_loss from training_history.csv (reliable machine-readable)
@@ -248,7 +276,10 @@ Examples:
         "--n_trials", type=int, default=50, help="Number of HPO trials (default: 50)"
     )
     parser.add_argument(
-        "--n_jobs", type=int, default=1, help="Parallel trials (default: 1)"
+        "--n_jobs",
+        type=int,
+        default=-1,
+        help="Parallel trials (-1 = auto-detect GPUs, default: -1)",
     )
     parser.add_argument(
         "--quick",
@@ -287,6 +318,13 @@ Examples:
         default=None,
         help=f"Losses to search (default: {DEFAULT_LOSSES})",
     )
+    parser.add_argument(
+        "--batch_sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Batch sizes to search (default: 16 32 64 128)",
+    )
 
     # Training settings for each trial
     parser.add_argument(
@@ -315,11 +353,30 @@ Examples:
 
     args = parser.parse_args()
 
+    # Convert to absolute path (child processes may run in different cwd)
+    args.data_path = str(Path(args.data_path).resolve())
+
     # Validate data path
     if not Path(args.data_path).exists():
         print(f"Error: Data file not found: {args.data_path}")
         sys.exit(1)
 
+    # Auto-detect GPUs for n_jobs if not specified
+    if args.n_jobs == -1:
+        try:
+            result_gpu = subprocess.run(
+                ["nvidia-smi", "--list-gpus"],
+                capture_output=True,
+                text=True,
+            )
+            if result_gpu.returncode == 0:
+                args.n_jobs = max(1, len(result_gpu.stdout.strip().split("\n")))
+            else:
+                args.n_jobs = 1
+        except Exception:
+            args.n_jobs = 1
+        print(f"Auto-detected {args.n_jobs} GPU(s) for parallel trials")
+
     # Create study
     print("=" * 60)
     print("WaveDL Hyperparameter Optimization")
wavedl/test.py CHANGED
@@ -366,13 +366,19 @@ def load_checkpoint(
     logging.info(f"  Building model: {model_name}")
     model = build_model(model_name, in_shape=in_shape, out_size=out_size)
 
-    # Load weights (prefer safetensors)
-    weight_path = checkpoint_dir / "model.safetensors"
-    if not weight_path.exists():
-        weight_path = checkpoint_dir / "pytorch_model.bin"
-
-    if not weight_path.exists():
-        raise FileNotFoundError(f"No model weights found in {checkpoint_dir}")
+    # Load weights (check multiple formats in order of preference)
+    weight_path = None
+    for fname in ["model.safetensors", "model.bin", "pytorch_model.bin"]:
+        candidate = checkpoint_dir / fname
+        if candidate.exists():
+            weight_path = candidate
+            break
+
+    if weight_path is None:
+        raise FileNotFoundError(
+            f"No model weights found in {checkpoint_dir}. "
+            f"Expected one of: model.safetensors, model.bin, pytorch_model.bin"
+        )
 
     if HAS_SAFETENSORS and weight_path.suffix == ".safetensors":
         state_dict = load_safetensors(str(weight_path))
wavedl/train.py CHANGED
@@ -148,6 +148,24 @@ torch.set_float32_matmul_precision("high")  # Use TF32 for float32 ops
 torch.backends.cudnn.benchmark = True
 
 
+# ==============================================================================
+# LOGGING UTILITIES
+# ==============================================================================
+from contextlib import contextmanager
+
+
+@contextmanager
+def suppress_accelerate_logging():
+    """Temporarily suppress accelerate's verbose checkpoint save messages."""
+    accelerate_logger = logging.getLogger("accelerate.checkpointing")
+    original_level = accelerate_logger.level
+    accelerate_logger.setLevel(logging.WARNING)
+    try:
+        yield
+    finally:
+        accelerate_logger.setLevel(original_level)
+
+
 # ==============================================================================
 # ARGUMENT PARSING
 # ==============================================================================
@@ -357,6 +375,36 @@ def parse_args() -> argparse.Namespace:
         help=argparse.SUPPRESS,  # Hidden: use --precision instead
     )
 
+    # Physical Constraints
+    parser.add_argument(
+        "--constraint",
+        type=str,
+        nargs="+",
+        default=[],
+        help="Soft constraint expressions: 'y0 - y1*y2' (penalize violations)",
+    )
+
+    parser.add_argument(
+        "--constraint_file",
+        type=str,
+        default=None,
+        help="Python file with constraint(pred, inputs) function",
+    )
+    parser.add_argument(
+        "--constraint_weight",
+        type=float,
+        nargs="+",
+        default=[0.1],
+        help="Weight(s) for soft constraints (one per constraint, or single shared weight)",
+    )
+    parser.add_argument(
+        "--constraint_reduction",
+        type=str,
+        default="mse",
+        choices=["mse", "mae"],
+        help="Reduction mode for constraint penalties",
+    )
+
     # Logging
     parser.add_argument(
         "--wandb", action="store_true", help="Enable Weights & Biases logging"
@@ -535,7 +583,7 @@ def main():
         return
 
     # ==========================================================================
-    # 1. SYSTEM INITIALIZATION
+    # SYSTEM INITIALIZATION
     # ==========================================================================
     # Initialize Accelerator for DDP and mixed precision
     accelerator = Accelerator(
@@ -591,7 +639,7 @@ def main():
     )
 
     # ==========================================================================
-    # 2. DATA & MODEL LOADING
+    # DATA & MODEL LOADING
     # ==========================================================================
     train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
         args, logger, accelerator, cache_dir=args.output_dir
@@ -645,7 +693,7 @@ def main():
     )
 
     # ==========================================================================
-    # 2.5. OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
+    # OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
     # ==========================================================================
     # Parse comma-separated arguments with validation
     try:
@@ -689,6 +737,43 @@ def main():
     # Move criterion to device (important for WeightedMSELoss buffer)
     criterion = criterion.to(accelerator.device)
 
+    # ==========================================================================
+    # PHYSICAL CONSTRAINTS INTEGRATION
+    # ==========================================================================
+    from wavedl.utils.constraints import (
+        PhysicsConstrainedLoss,
+        build_constraints,
+    )
+
+    # Build soft constraints
+    soft_constraints = build_constraints(
+        expressions=args.constraint,
+        file_path=args.constraint_file,
+        reduction=args.constraint_reduction,
+    )
+
+    # Wrap criterion with PhysicsConstrainedLoss if we have soft constraints
+    if soft_constraints:
+        # Pass output scaler so constraints can be evaluated in physical space
+        output_mean = scaler.mean_ if hasattr(scaler, "mean_") else None
+        output_std = scaler.scale_ if hasattr(scaler, "scale_") else None
+        criterion = PhysicsConstrainedLoss(
+            criterion,
+            soft_constraints,
+            weights=args.constraint_weight,
+            output_mean=output_mean,
+            output_std=output_std,
+        )
+        if accelerator.is_main_process:
+            logger.info(
+                f"  🔬 Physical constraints: {len(soft_constraints)} constraint(s) "
+                f"with weight(s) {args.constraint_weight}"
+            )
+            if output_mean is not None:
+                logger.info(
+                    "  📐 Constraints evaluated in physical space (denormalized)"
+                )
+
     # Track if scheduler should step per batch (OneCycleLR) or per epoch
     scheduler_step_per_batch = not is_epoch_based(args.scheduler)
 
@@ -744,7 +829,7 @@ def main():
     )
 
     # ==========================================================================
-    # 3. AUTO-RESUME / RESUME FROM CHECKPOINT
+    # AUTO-RESUME / RESUME FROM CHECKPOINT
     # ==========================================================================
     start_epoch = 0
     best_val_loss = float("inf")
@@ -800,7 +885,7 @@ def main():
             raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
 
     # ==========================================================================
-    # 4. PHYSICAL METRIC SETUP
+    # PHYSICAL METRIC SETUP
     # ==========================================================================
     # Physical MAE = normalized MAE * scaler.scale_
     phys_scale = torch.tensor(
@@ -808,7 +893,7 @@ def main():
     )
 
     # ==========================================================================
-    # 5. TRAINING LOOP
+    # TRAINING LOOP
     # ==========================================================================
     # Dynamic console header
     if accelerator.is_main_process:
@@ -1033,7 +1118,8 @@ def main():
             # Step 3: Save checkpoint with all ranks participating
             if is_best_epoch:
                 ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
-                accelerator.save_state(ckpt_dir)  # All ranks must call this
+                with suppress_accelerate_logging():
+                    accelerator.save_state(ckpt_dir, safe_serialization=False)
 
             # Step 4: Rank 0 handles metadata and updates tracking variables
             if accelerator.is_main_process:
@@ -1096,7 +1182,8 @@ def main():
             if periodic_checkpoint_needed:
                 ckpt_name = f"epoch_{epoch + 1}_checkpoint"
                 ckpt_dir = os.path.join(args.output_dir, ckpt_name)
-                accelerator.save_state(ckpt_dir)  # All ranks participate
+                with suppress_accelerate_logging():
+                    accelerator.save_state(ckpt_dir, safe_serialization=False)
 
             if accelerator.is_main_process:
                 with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
@@ -1147,7 +1234,11 @@ def main():
 
     except KeyboardInterrupt:
         logger.warning("Training interrupted. Saving emergency checkpoint...")
-        accelerator.save_state(os.path.join(args.output_dir, "interrupted_checkpoint"))
+        with suppress_accelerate_logging():
+            accelerator.save_state(
+                os.path.join(args.output_dir, "interrupted_checkpoint"),
+                safe_serialization=False,
+            )
 
     except Exception as e:
         logger.error(f"Critical error: {e}", exc_info=True)
wavedl/utils/__init__.py CHANGED
@@ -15,6 +15,12 @@ from .config import (
     save_config,
     validate_config,
 )
+from .constraints import (
+    ExpressionConstraint,
+    FileConstraint,
+    PhysicsConstrainedLoss,
+    build_constraints,
+)
 from .cross_validation import (
     CVDataset,
     run_cross_validation,
@@ -91,8 +97,11 @@ __all__ = [
     "FIGURE_WIDTH_INCH",
     "FONT_SIZE_TEXT",
     "FONT_SIZE_TICKS",
+    # Constraints
     "CVDataset",
     "DataSource",
+    "ExpressionConstraint",
+    "FileConstraint",
     "HDF5Source",
     "LogCoshLoss",
     "MATSource",
@@ -101,10 +110,12 @@ __all__ = [
     # Metrics
     "MetricTracker",
     "NPZSource",
+    "PhysicsConstrainedLoss",
     "WeightedMSELoss",
     # Distributed
     "broadcast_early_stop",
     "broadcast_value",
+    "build_constraints",
     "calc_pearson",
     "calc_per_target_r2",
     "configure_matplotlib_style",
wavedl/utils/config.py CHANGED
@@ -306,6 +306,16 @@ def validate_config(
     # Config
     "config",
     "list_models",
+    # Physical Constraints
+    "constraint",
+    "bounds",
+    "constraint_file",
+    "constraint_weight",
+    "constraint_reduction",
+    "positive",
+    "output_bounds",
+    "output_transform",
+    "output_formula",
     # Metadata (internal)
     "_metadata",
 }
wavedl/utils/constraints.py ADDED
@@ -0,0 +1,470 @@
+"""
+Physical Constraints for Training
+=================================
+
+Soft constraint enforcement via penalty-based loss terms.
+
+Usage:
+    # Expression constraints
+    wavedl-train --constraint "y0 > 0" --constraint_weight 0.1
+
+    # Complex constraints via Python file
+    wavedl-train --constraint_file my_constraint.py
+
+Author: Ductho Le (ductho.le@outlook.com)
+Version: 2.0.0
+"""
+
+from __future__ import annotations
+
+import ast
+import importlib.util
+import sys
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
+# ==============================================================================
+# SAFE EXPRESSION PARSING
+# ==============================================================================
+SAFE_FUNCTIONS: dict[str, Callable] = {
+    "sin": torch.sin,
+    "cos": torch.cos,
+    "tan": torch.tan,
+    "exp": torch.exp,
+    "log": torch.log,
+    "sqrt": torch.sqrt,
+    "abs": torch.abs,
+    "relu": F.relu,
+    "sigmoid": torch.sigmoid,
+    "softplus": F.softplus,
+    "tanh": torch.tanh,
+    "min": torch.minimum,
+    "max": torch.maximum,
+    "pow": torch.pow,
+    "clamp": torch.clamp,
+}
+
+INPUT_AGGREGATES: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
+    "x_mean": lambda x: x.mean(dim=tuple(range(1, x.ndim))),
+    "x_sum": lambda x: x.sum(dim=tuple(range(1, x.ndim))),
+    "x_max": lambda x: x.amax(dim=tuple(range(1, x.ndim))),
+    "x_min": lambda x: x.amin(dim=tuple(range(1, x.ndim))),
+    "x_std": lambda x: x.std(dim=tuple(range(1, x.ndim))),
+    "x_energy": lambda x: (x**2).sum(dim=tuple(range(1, x.ndim))),
+}
+
+
+# ==============================================================================
+# SOFT CONSTRAINTS
+# ==============================================================================
+class ExpressionConstraint(nn.Module):
+    """
+    Soft constraint via string expression.
+
+    Parses mathematical expressions using Python's AST for safe evaluation.
+    Supports output variables (y0, y1, ...), input aggregates (x_mean, ...),
+    and whitelisted math functions.
+
+    Example:
+        >>> constraint = ExpressionConstraint("y0 - y1 * y2")
+        >>> penalty = constraint(predictions, inputs)
+
+        >>> constraint = ExpressionConstraint("sin(y0) + cos(y1)")
+        >>> penalty = constraint(predictions, inputs)
+    """
+
+    def __init__(self, expression: str, reduction: str = "mse"):
+        """
+        Args:
+            expression: Mathematical expression to evaluate (should equal 0)
+            reduction: How to reduce violations - 'mse' or 'mae'
+        """
+        super().__init__()
+        self.expression = expression
+        self.reduction = reduction
+        self._tree = ast.parse(expression, mode="eval")
+        self._validate(self._tree)
+
+    def _validate(self, tree: ast.Expression) -> None:
+        """Validate that expression only uses safe functions."""
+        for node in ast.walk(tree):
+            if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
+                if node.func.id not in SAFE_FUNCTIONS:
+                    raise ValueError(
+                        f"Unsafe function '{node.func.id}' in constraint. "
+                        f"Allowed: {list(SAFE_FUNCTIONS.keys())}"
+                    )
+
+    def _eval(
+        self, node: ast.AST, pred: torch.Tensor, inputs: torch.Tensor | None
+    ) -> torch.Tensor:
+        """Recursively evaluate AST node."""
+        if isinstance(node, ast.Constant):
+            return torch.tensor(node.value, device=pred.device, dtype=pred.dtype)
+        elif isinstance(node, ast.Name):
+            name = node.id
+            # Output variable: y0, y1, ...
+            if name.startswith("y") and name[1:].isdigit():
+                idx = int(name[1:])
+                if idx >= pred.shape[1]:
+                    raise ValueError(
+                        f"Output index {idx} out of range. "
+                        f"Model has {pred.shape[1]} outputs."
+                    )
+                return pred[:, idx]
+            # Input aggregate: x_mean, x_sum, ...
+            elif name in INPUT_AGGREGATES:
+                if inputs is None:
+                    raise ValueError(
+                        f"Constraint uses '{name}' but inputs not provided."
+                    )
+                return INPUT_AGGREGATES[name](inputs)
+            else:
+                raise ValueError(
+                    f"Unknown variable '{name}'. "
+                    f"Use y0, y1, ... for outputs or {list(INPUT_AGGREGATES.keys())} for inputs."
+                )
+        elif isinstance(node, ast.BinOp):
+            left = self._eval(node.left, pred, inputs)
+            right = self._eval(node.right, pred, inputs)
+            ops = {
+                ast.Add: torch.add,
+                ast.Sub: torch.sub,
+                ast.Mult: torch.mul,
+                ast.Div: torch.div,
+                ast.Pow: torch.pow,
+                ast.Mod: torch.remainder,
+            }
+            if type(node.op) not in ops:
+                raise ValueError(f"Unsupported operator: {type(node.op).__name__}")
+            return ops[type(node.op)](left, right)
+        elif isinstance(node, ast.UnaryOp):
+            operand = self._eval(node.operand, pred, inputs)
+            if isinstance(node.op, ast.USub):
+                return -operand
+            elif isinstance(node.op, ast.UAdd):
+                return operand
+            else:
+                raise ValueError(
+                    f"Unsupported unary operator: {type(node.op).__name__}"
+                )
+        elif isinstance(node, ast.Call):
+            if not isinstance(node.func, ast.Name):
+                raise ValueError("Only direct function calls supported (e.g., sin(x))")
+            func_name = node.func.id
+            if func_name not in SAFE_FUNCTIONS:
+                raise ValueError(f"Unsafe function: {func_name}")
+            args = [self._eval(arg, pred, inputs) for arg in node.args]
+            return SAFE_FUNCTIONS[func_name](*args)
+        elif isinstance(node, ast.Compare):
+            # Comparison operators: y0 > 0, y0 < 1, y0 >= 0, y0 <= 1
+            # Returns penalty (violation amount) when constraint is not satisfied
+            if len(node.ops) != 1 or len(node.comparators) != 1:
+                raise ValueError(
+                    "Only single comparisons supported (e.g., 'y0 > 0', not 'y0 > 0 > y1')"
+                )
+            left = self._eval(node.left, pred, inputs)
+            right = self._eval(node.comparators[0], pred, inputs)
+            op = node.ops[0]
+
+            # Return violation amount (0 if satisfied, positive if violated)
+            if isinstance(
+                op, (ast.Gt, ast.GtE)
+            ):  # y0 > right → penalize if y0 <= right
+                return F.relu(right - left)
+            elif isinstance(
+                op, (ast.Lt, ast.LtE)
+            ):  # y0 < right → penalize if y0 >= right
+                return F.relu(left - right)
+            elif isinstance(op, ast.Eq):  # y0 == right → penalize difference
+                return torch.abs(left - right)
+            elif isinstance(op, ast.NotEq):  # y0 != right → not useful as constraint
+                raise ValueError(
+                    "'!=' is not a valid constraint. Use '==' for equality constraints."
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported comparison operator: {type(op).__name__}"
+                )
+        elif isinstance(node, ast.Subscript):
+            # Input indexing: x[0], x[0,5], x[0,5,10]
+            if not isinstance(node.value, ast.Name) or node.value.id != "x":
+                raise ValueError(
+                    "Subscript indexing only supported for 'x' (inputs). "
+                    "Use x[i], x[i,j], or x[i,j,k]."
+                )
+            if inputs is None:
+                raise ValueError("Constraint uses 'x[...]' but inputs not provided.")
+
+            # Parse indices from the slice
+            indices = self._parse_subscript_indices(node.slice)
+
+            # Validate dimensions match
+            # inputs shape: (batch, dim1) or (batch, dim1, dim2) or (batch, dim1, dim2, dim3)
+            input_ndim = inputs.ndim - 1  # Exclude batch dimension
+            if len(indices) != input_ndim:
+                raise ValueError(
+                    f"Input has {input_ndim}D shape, but got {len(indices)} indices. "
+                    f"Use x[i] for 1D, x[i,j] for 2D, x[i,j,k] for 3D inputs."
+                )
+
+            # Extract the value at the specified indices (for entire batch)
+            if len(indices) == 1:
+                return inputs[:, indices[0]]
+            elif len(indices) == 2:
+                return inputs[:, indices[0], indices[1]]
+            elif len(indices) == 3:
+                return inputs[:, indices[0], indices[1], indices[2]]
+            else:
+                raise ValueError("Only 1D, 2D, or 3D input indexing supported.")
+        elif isinstance(node, ast.Expression):
+            return self._eval(node.body, pred, inputs)
+        else:
+            raise ValueError(f"Unsupported AST node type: {type(node).__name__}")
+
+    def _parse_subscript_indices(self, slice_node: ast.AST) -> list[int]:
+        """Parse subscript indices from AST slice node."""
+        if isinstance(slice_node, ast.Constant):
+            # Single index: x[0]
+            return [int(slice_node.value)]
+        elif isinstance(slice_node, ast.Tuple):
+            # Multiple indices: x[0,5] or x[0,5,10]
+            indices = []
+            for elt in slice_node.elts:
+                if not isinstance(elt, ast.Constant):
+                    raise ValueError(
+                        "Only constant indices supported in x[...]. "
+                        "Use x[0,5] not x[i,j]."
+                    )
+                indices.append(int(elt.value))
+            return indices
+        else:
+            raise ValueError(
+                f"Unsupported subscript type: {type(slice_node).__name__}. "
+                "Use x[0], x[0,5], or x[0,5,10]."
+            )
+
+    def forward(
+        self, pred: torch.Tensor, inputs: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        """
+        Compute constraint violation penalty.
+
+        Args:
+            pred: Model predictions of shape (N, num_outputs)
+            inputs: Model inputs of shape (N, ...) for input-dependent constraints
+
+        Returns:
+            Scalar penalty value
+        """
+        violation = self._eval(self._tree, pred, inputs)
+        if self.reduction == "mse":
+            return (violation**2).mean()
+        else:  # mae
+            return violation.abs().mean()
+
+    def __repr__(self) -> str:
+        return (
+            f"ExpressionConstraint('{self.expression}', reduction='{self.reduction}')"
+        )
+
+
+class FileConstraint(nn.Module):
+    """
+    Load constraint function from Python file.
+
+    The file must define a function `constraint(pred, inputs=None)` that
+    returns per-sample violation values.
+
+    Example file (my_constraint.py):
+        import torch
+
+        def constraint(pred, inputs=None):
+            # Monotonicity: y0 < y1 < y2
+            diffs = pred[:, 1:] - pred[:, :-1]
+            return torch.relu(-diffs).sum(dim=1)
+
+    Usage:
+        >>> constraint = FileConstraint("my_constraint.py")
+        >>> penalty = constraint(predictions, inputs)
+    """
+
+    def __init__(self, file_path: str, reduction: str = "mse"):
+        """
+        Args:
+            file_path: Path to Python file containing constraint function
+            reduction: How to reduce violations - 'mse' or 'mae'
+        """
+        super().__init__()
+        self.file_path = file_path
+        self.reduction = reduction
+
+        # Load module from file
+        spec = importlib.util.spec_from_file_location("constraint_module", file_path)
+        if spec is None or spec.loader is None:
+            raise ValueError(f"Could not load constraint file: {file_path}")
+
+        module = importlib.util.module_from_spec(spec)
+        sys.modules["constraint_module"] = module
+        spec.loader.exec_module(module)
+
+        if not hasattr(module, "constraint"):
+            raise ValueError(
+                f"Constraint file must define 'constraint(pred, inputs)' function: {file_path}"
+            )
+
+        self._constraint_fn = module.constraint
+
+    def forward(
+        self, pred: torch.Tensor, inputs: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        """Evaluate constraint from loaded function."""
+        violation = self._constraint_fn(pred, inputs)
+        if self.reduction == "mse":
+            return (violation**2).mean()
+        else:
+            return violation.abs().mean()
+
+    def __repr__(self) -> str:
+        return f"FileConstraint('{self.file_path}')"
+
+
+# ==============================================================================
+# COMBINED LOSS WRAPPER
+# ==============================================================================
+class PhysicsConstrainedLoss(nn.Module):
+    """
+    Combine base loss with constraint penalties.
+
+    Total Loss = Base Loss + Σ(weight_i × constraint_i)
+
+    Constraints are evaluated in **physical space** (denormalized) while
+    the base loss is computed in normalized space for stable training.
+
+    Example:
+        >>> base_loss = nn.MSELoss()
+        >>> constraints = [ExpressionConstraint("y0 - y1*y2")]
+        >>> criterion = PhysicsConstrainedLoss(
+        ...     base_loss,
+        ...     constraints,
+        ...     weights=[0.1],
+        ...     output_mean=[10, 5, 50],
+        ...     output_std=[2, 1, 10],
+        ... )
+        >>> loss = criterion(pred, target, inputs)
+    """
+
+    def __init__(
+        self,
+        base_loss: nn.Module,
+        constraints: list[nn.Module] | None = None,
+        weights: list[float] | None = None,
+        output_mean: torch.Tensor | list[float] | None = None,
+        output_std: torch.Tensor | list[float] | None = None,
+    ):
+        """
+        Args:
+            base_loss: Base loss function (e.g., MSELoss)
+            constraints: List of constraint modules
+            weights: Weight for each constraint. If shorter than constraints,
+                last weight is repeated.
+            output_mean: Mean of each output (for denormalization). Shape: (num_outputs,)
+            output_std: Std of each output (for denormalization). Shape: (num_outputs,)
+        """
+        super().__init__()
+        self.base_loss = base_loss
+        self.constraints = nn.ModuleList(constraints or [])
+        self.weights = weights or [0.1]
+
+        # Store scaler as buffers (moves to correct device automatically)
+        if output_mean is not None:
+            if not isinstance(output_mean, torch.Tensor):
+                output_mean = torch.tensor(output_mean, dtype=torch.float32)
+            self.register_buffer("output_mean", output_mean)
+        else:
+            self.register_buffer("output_mean", None)
+
+        if output_std is not None:
+            if not isinstance(output_std, torch.Tensor):
+                output_std = torch.tensor(output_std, dtype=torch.float32)
+            self.register_buffer("output_std", output_std)
+        else:
+            self.register_buffer("output_std", None)
+
+    def _denormalize(self, pred: torch.Tensor) -> torch.Tensor:
+        """Convert normalized predictions to physical values."""
+        if self.output_mean is None or self.output_std is None:
+            return pred
+        return pred * self.output_std + self.output_mean
+
+    def forward(
+        self,
+        pred: torch.Tensor,
+        target: torch.Tensor,
+        inputs: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Compute combined loss.
+
+        Args:
+            pred: Model predictions (normalized)
+            target: Ground truth targets (normalized)
+            inputs: Model inputs (for input-dependent constraints)
+
+        Returns:
+            Combined loss value
+        """
+        # Base loss in normalized space (stable gradients)
+        loss = self.base_loss(pred, target)
+
+        # Denormalize for constraint evaluation (physical units)
+        pred_physical = self._denormalize(pred)
+
+        for i, constraint in enumerate(self.constraints):
+            weight = self.weights[i] if i < len(self.weights) else self.weights[-1]
+            penalty = constraint(pred_physical, inputs)
+            loss = loss + weight * penalty
+
+        return loss
+
+    def __repr__(self) -> str:
+        has_scaler = self.output_mean is not None
+        return f"PhysicsConstrainedLoss(base={self.base_loss}, constraints={len(self.constraints)}, denormalize={has_scaler})"
+
+
+# ==============================================================================
+# FACTORY FUNCTIONS
+# ==============================================================================
+def build_constraints(
+    expressions: list[str] | None = None,
+    file_path: str | None = None,
+    reduction: str = "mse",
+) -> list[nn.Module]:
+    """
+    Build soft constraint modules from CLI arguments.
+
+    Args:
+        expressions: Expression constraints (e.g., ["y0 - y1*y2", "y0 > 0"])
+        file_path: Path to Python constraint file
+        reduction: Reduction mode for penalties
+
+    Returns:
+        List of constraint modules
+    """
+    constraints: list[nn.Module] = []
+
+    for expr in expressions or []:
+        constraints.append(ExpressionConstraint(expr, reduction))
+
+    if file_path:
+        constraints.append(FileConstraint(file_path, reduction))
+
+    return constraints
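A quick smoke test of the module as shipped above (batch shape, weights, and scaler statistics assumed for illustration):

```python
import torch
from wavedl.utils.constraints import PhysicsConstrainedLoss, build_constraints

# One inequality and one equation constraint, each with its own weight
constraints = build_constraints(expressions=["y0 > 0", "y2 - y0 * y1"])
criterion = PhysicsConstrainedLoss(
    torch.nn.MSELoss(),
    constraints,
    weights=[0.1, 1.0],
    output_mean=[10.0, 5.0, 50.0],  # assumed scaler statistics
    output_std=[2.0, 1.0, 10.0],
)

pred = torch.randn(8, 3, requires_grad=True)  # (batch, num_outputs)
target = torch.randn(8, 3)
loss = criterion(pred, target)  # inputs=None: no x-dependent constraints here
loss.backward()                 # penalties are differentiable
```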
wavedl/utils/metrics.py CHANGED
@@ -560,11 +560,56 @@ def create_training_curves(
         )
         lines.append(line)
 
+    def set_lr_ticks(ax: plt.Axes, data: list[float], n_ticks: int = 4) -> None:
+        """Set n uniformly spaced ticks on LR axis with 10^n format labels."""
+        valid_data = [v for v in data if v is not None and not np.isnan(v) and v > 0]
+        if not valid_data:
+            return
+        vmin, vmax = min(valid_data), max(valid_data)
+        # Snap to clean decade boundaries
+        log_min = np.floor(np.log10(vmin))
+        log_max = np.ceil(np.log10(vmax))
+        # Generate n uniformly spaced ticks as powers of 10
+        log_ticks = np.linspace(log_min, log_max, n_ticks)
+        # Round to nearest integer power of 10 for clean numbers
+        log_ticks = np.round(log_ticks)
+        ticks = 10.0**log_ticks
+        # Remove duplicates while preserving order
+        ticks = list(dict.fromkeys(ticks))
+        ax.set_yticks(ticks)
+        # Format all tick labels as 10^n
+        labels = [f"$10^{{{int(np.log10(t))}}}$" for t in ticks]
+        ax.set_yticklabels(labels)
+        ax.minorticks_off()
+
+    def set_loss_ticks(ax: plt.Axes, data: list[float]) -> None:
+        """Set ticks at powers of 10 that cover the data range."""
+        valid_data = [v for v in data if v is not None and not np.isnan(v) and v > 0]
+        if not valid_data:
+            return
+        vmin, vmax = min(valid_data), max(valid_data)
+        # Get decade range that covers data (ceil for min to avoid going too low)
+        log_min = int(np.ceil(np.log10(vmin)))
+        log_max = int(np.ceil(np.log10(vmax)))
+        # Generate ticks at each power of 10
+        ticks = [10.0**i for i in range(log_min, log_max + 1)]
+        ax.set_yticks(ticks)
+        # Format labels as 10^n
+        labels = [f"$10^{{{i}}}$" for i in range(log_min, log_max + 1)]
+        ax.set_yticklabels(labels)
+        ax.minorticks_off()
+
     ax1.set_xlabel("Epoch")
     ax1.set_ylabel("Loss")
     ax1.set_yscale("log")  # Log scale for loss
     ax1.grid(True, alpha=0.3)
 
+    # Collect all loss values and set clean power of 10 ticks
+    all_loss_values = []
+    for metric in metrics:
+        all_loss_values.extend([h.get(metric, np.nan) for h in history])
+    set_loss_ticks(ax1, all_loss_values)
+
     # Check if learning rate data exists
     has_lr = show_lr and any("lr" in h for h in history)
 
@@ -581,9 +626,11 @@ def create_training_curves(
             alpha=0.7,
             label="Learning Rate",
         )
-        ax2.set_ylabel("Learning Rate", color=COLORS["neutral"])
-        ax2.tick_params(axis="y", labelcolor=COLORS["neutral"])
+        ax2.set_ylabel("Learning Rate")
         ax2.set_yscale("log")  # Log scale for LR
+        set_lr_ticks(ax2, lr_values, n_ticks=4)
+        # Ensure right spine (axis line) is visible
+        ax2.spines["right"].set_visible(True)
        lines.append(line_lr)
 
     # Combined legend
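The decade-snapping in `set_lr_ticks` is easy to verify in isolation; a standalone sketch with an assumed LR range of 3e-5 to 8e-3:

```python
import numpy as np

# Mirror of the tick computation in set_lr_ticks above (values assumed)
vmin, vmax = 3e-5, 8e-3
log_min = np.floor(np.log10(vmin))                      # -5.0
log_max = np.ceil(np.log10(vmax))                       # -2.0
log_ticks = np.round(np.linspace(log_min, log_max, 4))  # [-5, -4, -3, -2]
print([f"$10^{{{int(t)}}}$" for t in log_ticks])        # labels 10^-5 .. 10^-2
```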
wavedl-1.4.5.dist-info/METADATA → wavedl-1.5.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.4.5
+Version: 1.5.0
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -49,7 +49,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 
 ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 
-[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
+[![Python 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
 [![PyTorch 2.x](https://img.shields.io/badge/PyTorch-2.x-ee4c2c.svg?style=plastic&logo=pytorch&logoColor=white)](https://pytorch.org/)
 [![Accelerate](https://img.shields.io/badge/Accelerate-Enabled-yellow.svg?style=plastic&logo=huggingface&logoColor=white)](https://huggingface.co/docs/accelerate/)
 <br>
@@ -57,7 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
 [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
 <br>
-[![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
+[![Downloads](https://img.shields.io/badge/dynamic/json?url=https://pypistats.org/api/packages/wavedl/recent?period=month%26mirrors=false&query=data.last_month&style=plastic&logo=pypi&logoColor=white&color=9ACD32&label=Downloads&suffix=/month)](https://pypistats.org/packages/wavedl)
 [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
 [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
 
@@ -113,14 +113,12 @@ Train on datasets larger than RAM:
 </td>
 <td width="50%" valign="top">
 
-**🧠 One-Line Model Registration**
+**🧠 Models? We've Got Options**
 
-Plug in any architecture:
-```python
-@register_model("my_net")
-class MyNet(BaseModel): ...
-```
-Design your model. Register with one line.
+38 architectures, ready to go:
+- CNNs, ResNets, ViTs, EfficientNets...
+- All adapted for regression
+- [Add your own](#adding-custom-models) in one line
 
 </td>
 </tr>
@@ -137,12 +135,12 @@ Multi-GPU training without the pain:
 </td>
 <td width="50%" valign="top">
 
-**📊 Publish-Ready Output**
+**🔬 Physics-Constrained Training**
 
-Results go straight to your paper:
-- 11 diagnostic plots with LaTeX styling
-- Multi-format export (PNG, PDF, SVG, ...)
-- MAE in physical units per parameter
+Make your model respect the laws:
+- Enforce bounds, positivity, equations
+- Simple expression syntax or Python
+- [Custom constraints](#physical-constraints) for various laws
 
 </td>
 </tr>
@@ -383,7 +381,7 @@ WaveDL/
 ├── configs/            # YAML config templates
 ├── examples/           # Ready-to-run examples
 ├── notebooks/          # Jupyter notebooks
-├── unit_tests/         # Pytest test suite (704 tests)
+├── unit_tests/         # Pytest test suite (725 tests)
 
 ├── pyproject.toml      # Package config, dependencies
 ├── CHANGELOG.md        # Version history
@@ -727,6 +725,104 @@ seed: 2025
 
 </details>
 
+<details>
+<summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
+
+Add penalty terms to the loss function to enforce physical laws:
+
+```
+Total Loss = Data Loss + weight × penalty(violation)
+```
+
+### Expression Constraints
+
+```bash
+# Positivity
+--constraint "y0 > 0"
+
+# Bounds
+--constraint "y0 >= 0" "y0 <= 1"
+
+# Equations (penalize deviations from zero)
+--constraint "y2 - y0 * y1"
+
+# Input-dependent constraints
+--constraint "y0 - 2*x[0]"
+
+# Multiple constraints with different weights
+--constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
+```
+
+### Custom Python Constraints
+
+For complex physics (matrix operations, implicit equations):
+
+```python
+# my_constraint.py
+import torch
+
+def constraint(pred, inputs=None):
+    """
+    Args:
+        pred: (batch, num_outputs)
+        inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
+    Returns:
+        (batch,) — violation per sample (0 = satisfied)
+    """
+    # Outputs (same for all data types)
+    y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
+
+    # Inputs — Tabular: (batch, features)
+    # x0 = inputs[:, 0]                    # Feature 0
+    # x_sum = inputs.sum(dim=1)            # Sum all features
+
+    # Inputs — Images: (batch, C, H, W)
+    # pixel = inputs[:, 0, 3, 5]           # Pixel at (3,5), channel 0
+    # img_mean = inputs.mean(dim=(1,2,3))  # Mean over C,H,W
+
+    # Inputs — 3D Volumes: (batch, C, D, H, W)
+    # voxel = inputs[:, 0, 2, 3, 5]        # Voxel at (2,3,5), channel 0
+
+    # Example constraints:
+    # return y2 - y0 * y1                  # Wave equation
+    # return y0 - 2 * inputs[:, 0]         # Output = 2×input
+    # return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1  # Mixed
+
+    return y0 - y1 * y2
+```
+
+```bash
+--constraint_file my_constraint.py --constraint_weight 1.0
+```
+
+---
+
+### Reference
+
+| Argument | Default | Description |
+|----------|---------|-------------|
+| `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
+| `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
+| `--constraint_weight` | `0.1` | Penalty weight(s) |
+| `--constraint_reduction` | `mse` | `mse` (squared) or `mae` (linear) |
+
+#### Expression Syntax
+
+| Variable | Meaning |
+|----------|---------|
+| `y0`, `y1`, ... | Model outputs |
+| `x[0]`, `x[1]`, ... | Input values (1D tabular) |
+| `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
+| `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
+
+**Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
+
+**Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
+
+</details>
+
+
+
 <details>
 <summary><b>Hyperparameter Search (HPO)</b></summary>
 
@@ -734,18 +830,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
 
 **Run HPO:**
 
-You specify which models to search and how many trials to run:
 ```bash
-# Search 3 models with 100 trials
-python -m wavedl.hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 100
+# Basic HPO (auto-detects GPUs for parallel trials)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 100
 
-# Search 1 model (faster)
-python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
+# Search multiple models
+wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
 
-# Search all your candidate models
-python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
+# Quick mode (fewer parameters, faster)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
 ```
 
+> [!TIP]
+> **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+
 **Train with best parameters**
 
 After HPO completes, it prints the optimal command:
@@ -764,7 +862,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
 | Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
 | Losses | [all 6](#loss-functions) | `--losses X Y` |
 | Learning rate | 1e-5 → 1e-2 | (always searched) |
-| Batch size | 64, 128, 256, 512 | (always searched) |
+| Batch size | 16, 32, 64, 128 | (always searched) |
 
 **Quick Mode** (`--quick`):
 - Uses minimal defaults: cnn + adamw + plateau + mse
@@ -784,7 +882,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
 | `--optimizers` | all 6 | Optimizers to search |
 | `--schedulers` | all 8 | Schedulers to search |
 | `--losses` | all 6 | Losses to search |
-| `--n_jobs` | `1` | Parallel trials (multi-GPU) |
+| `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
 | `--max_epochs` | `50` | Max epochs per trial |
 | `--output` | `hpo_results.json` | Output file |
 
@@ -936,12 +1034,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **materia
 ```bash
 # Run inference on the example data
 python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
-    --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
+    --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
     --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
 
 # Export to ONNX (already included as model.onnx)
 python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
-    --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
+    --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
     --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
 ```
@@ -950,7 +1048,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
 | File | Description |
 |------|-------------|
 | `best_checkpoint/` | Pre-trained CNN checkpoint |
-| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
+| `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
 | `model.onnx` | ONNX export with embedded de-normalization |
 | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
 | `training_curves.png` | Training/validation loss and learning rate plot |
961
1059
 
962
1060
  <p align="center">
963
1061
  <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
964
- <em>Training and validation loss over 162 epochs with learning rate schedule</em>
1062
+ <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
965
1063
  </p>
966
1064
 
967
1065
  **Inference Results:**
wavedl-1.4.5.dist-info/RECORD → wavedl-1.5.0.dist-info/RECORD CHANGED
@@ -1,8 +1,8 @@
-wavedl/__init__.py,sha256=2ro7SYQ3wCmq-ejiAm5sd6BeXf6sZgixC9U2vS7Ckbs,1177
-wavedl/hpc.py,sha256=0h8IZzOT0EzmEv3fU9cKyRVE9V1ivtBzbjuBCaxYadc,8445
-wavedl/hpo.py,sha256=YJXsnSGEBSVUqp_2ah7zu3_VClAUqZrdkuzDaSqQUjU,12952
-wavedl/test.py,sha256=81al6vQBDAJ3CpSEtxZn6xzR1c4-jo28R7tX_84KROc,37642
-wavedl/train.py,sha256=_pW7prvlNqfUGrGweHO2QelS87UiAYKvyJwqMAIj6yI,49292
+wavedl/__init__.py,sha256=jeYx6dZ0_UD4yXl-I4_Aa63820RzS5-IxJP_KUni8Pw,1177
+wavedl/hpc.py,sha256=-iOjjKkXPcV_quj4vAsMBJN_zWKtD1lMRfIZZBhyGms,8756
+wavedl/hpo.py,sha256=DGCGyt2yhr3WAifAuljhE26gg07CHdaQW4wpDaTKbyo,14968
+wavedl/test.py,sha256=oWGSSC7178loqOxwti-oDXUVogOqbwHL__GfoXSE5Ss,37846
+wavedl/train.py,sha256=Xy0-W9Kxqdc7WaAyt_HDSZwhZKrfjVbZd2KzgSid6rQ,52457
 wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
 wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
 wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
@@ -20,18 +20,19 @@ wavedl/models/swin.py,sha256=p-okfq3Qm4_neJTxCcMzoHoVzC0BHW3BMnbpr_Ri2U0,13224
 wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
 wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
 wavedl/models/vit.py,sha256=0C3GZk11VsYFTl14d86Wtl1Zk1T5rYJjvkaEfEN4N3k,11100
-wavedl/utils/__init__.py,sha256=YMgzuwndjr64kt9k0_6_9PMJYTVdiaH5veSMff_ZycA,3051
-wavedl/utils/config.py,sha256=fMoucikIQHn85mVhGMa7TnXTuFDcEEPjfXk2EjbkJR0,10591
+wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
+wavedl/utils/config.py,sha256=jGW-K7AYB6zrD2BfVm2XPnSY9rbfL_EkM4bwxhBLuwM,10859
+wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
 wavedl/utils/cross_validation.py,sha256=117ac9KDzaIaqhtP8ZRs15Xpqmq5fLpX2-vqkNvtMaU,17487
 wavedl/utils/data.py,sha256=_OaWvU5oFVJW0NwM5WyDD0Kb1hy5MgvJIFpzvJGux9w,48214
 wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
 wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
-wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
+wavedl/utils/metrics.py,sha256=EJmJvF7gACQsUoKYldlladN_SbnRiuE-Smj0eSnbraQ,39394
 wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
 wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
-wavedl-1.4.5.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
-wavedl-1.4.5.dist-info/METADATA,sha256=4ltxFDaqPqh4XUAW_K8nkFmvqBzPcL2cxmghH11GMWg,42191
-wavedl-1.4.5.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
-wavedl-1.4.5.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
-wavedl-1.4.5.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
-wavedl-1.4.5.dist-info/RECORD,,
+wavedl-1.5.0.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+wavedl-1.5.0.dist-info/METADATA,sha256=jlQPfPLtwIXzWHKwgN8xQjzco_xB7L3j7p_dl3TJA8E,45224
+wavedl-1.5.0.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+wavedl-1.5.0.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
+wavedl-1.5.0.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+wavedl-1.5.0.dist-info/RECORD,,