wavedl 1.4.6__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.4.6"
21
+ __version__ = "1.5.0"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
wavedl/hpo.py CHANGED
@@ -89,7 +89,8 @@ def create_objective(args):
89
89
  # Suggest hyperparameters
90
90
  model = trial.suggest_categorical("model", models)
91
91
  lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
92
- batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
92
+ batch_sizes = args.batch_sizes or [16, 32, 64, 128]
93
+ batch_size = trial.suggest_categorical("batch_size", batch_sizes)
93
94
  optimizer = trial.suggest_categorical("optimizer", optimizers)
94
95
  scheduler = trial.suggest_categorical("scheduler", schedulers)
95
96
  loss = trial.suggest_categorical("loss", losses)
@@ -317,6 +318,13 @@ Examples:
317
318
  default=None,
318
319
  help=f"Losses to search (default: {DEFAULT_LOSSES})",
319
320
  )
321
+ parser.add_argument(
322
+ "--batch_sizes",
323
+ type=int,
324
+ nargs="+",
325
+ default=None,
326
+ help="Batch sizes to search (default: 16 32 64 128)",
327
+ )
320
328
 
321
329
  # Training settings for each trial
322
330
  parser.add_argument(
wavedl/train.py CHANGED
@@ -375,6 +375,36 @@ def parse_args() -> argparse.Namespace:
375
375
  help=argparse.SUPPRESS, # Hidden: use --precision instead
376
376
  )
377
377
 
378
+ # Physical Constraints
379
+ parser.add_argument(
380
+ "--constraint",
381
+ type=str,
382
+ nargs="+",
383
+ default=[],
384
+ help="Soft constraint expressions: 'y0 - y1*y2' (penalize violations)",
385
+ )
386
+
387
+ parser.add_argument(
388
+ "--constraint_file",
389
+ type=str,
390
+ default=None,
391
+ help="Python file with constraint(pred, inputs) function",
392
+ )
393
+ parser.add_argument(
394
+ "--constraint_weight",
395
+ type=float,
396
+ nargs="+",
397
+ default=[0.1],
398
+ help="Weight(s) for soft constraints (one per constraint, or single shared weight)",
399
+ )
400
+ parser.add_argument(
401
+ "--constraint_reduction",
402
+ type=str,
403
+ default="mse",
404
+ choices=["mse", "mae"],
405
+ help="Reduction mode for constraint penalties",
406
+ )
407
+
378
408
  # Logging
379
409
  parser.add_argument(
380
410
  "--wandb", action="store_true", help="Enable Weights & Biases logging"
@@ -553,7 +583,7 @@ def main():
553
583
  return
554
584
 
555
585
  # ==========================================================================
556
- # 1. SYSTEM INITIALIZATION
586
+ # SYSTEM INITIALIZATION
557
587
  # ==========================================================================
558
588
  # Initialize Accelerator for DDP and mixed precision
559
589
  accelerator = Accelerator(
@@ -609,7 +639,7 @@ def main():
609
639
  )
610
640
 
611
641
  # ==========================================================================
612
- # 2. DATA & MODEL LOADING
642
+ # DATA & MODEL LOADING
613
643
  # ==========================================================================
614
644
  train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
615
645
  args, logger, accelerator, cache_dir=args.output_dir
@@ -663,7 +693,7 @@ def main():
663
693
  )
664
694
 
665
695
  # ==========================================================================
666
- # 2.5. OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
696
+ # OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
667
697
  # ==========================================================================
668
698
  # Parse comma-separated arguments with validation
669
699
  try:
@@ -707,6 +737,43 @@ def main():
707
737
  # Move criterion to device (important for WeightedMSELoss buffer)
708
738
  criterion = criterion.to(accelerator.device)
709
739
 
740
+ # ==========================================================================
741
+ # PHYSICAL CONSTRAINTS INTEGRATION
742
+ # ==========================================================================
743
+ from wavedl.utils.constraints import (
744
+ PhysicsConstrainedLoss,
745
+ build_constraints,
746
+ )
747
+
748
+ # Build soft constraints
749
+ soft_constraints = build_constraints(
750
+ expressions=args.constraint,
751
+ file_path=args.constraint_file,
752
+ reduction=args.constraint_reduction,
753
+ )
754
+
755
+ # Wrap criterion with PhysicsConstrainedLoss if we have soft constraints
756
+ if soft_constraints:
757
+ # Pass output scaler so constraints can be evaluated in physical space
758
+ output_mean = scaler.mean_ if hasattr(scaler, "mean_") else None
759
+ output_std = scaler.scale_ if hasattr(scaler, "scale_") else None
760
+ criterion = PhysicsConstrainedLoss(
761
+ criterion,
762
+ soft_constraints,
763
+ weights=args.constraint_weight,
764
+ output_mean=output_mean,
765
+ output_std=output_std,
766
+ )
767
+ if accelerator.is_main_process:
768
+ logger.info(
769
+ f" 🔬 Physical constraints: {len(soft_constraints)} constraint(s) "
770
+ f"with weight(s) {args.constraint_weight}"
771
+ )
772
+ if output_mean is not None:
773
+ logger.info(
774
+ " 📐 Constraints evaluated in physical space (denormalized)"
775
+ )
776
+
710
777
  # Track if scheduler should step per batch (OneCycleLR) or per epoch
711
778
  scheduler_step_per_batch = not is_epoch_based(args.scheduler)
712
779
 
@@ -762,7 +829,7 @@ def main():
762
829
  )
763
830
 
764
831
  # ==========================================================================
765
- # 3. AUTO-RESUME / RESUME FROM CHECKPOINT
832
+ # AUTO-RESUME / RESUME FROM CHECKPOINT
766
833
  # ==========================================================================
767
834
  start_epoch = 0
768
835
  best_val_loss = float("inf")
@@ -818,7 +885,7 @@ def main():
818
885
  raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
819
886
 
820
887
  # ==========================================================================
821
- # 4. PHYSICAL METRIC SETUP
888
+ # PHYSICAL METRIC SETUP
822
889
  # ==========================================================================
823
890
  # Physical MAE = normalized MAE * scaler.scale_
824
891
  phys_scale = torch.tensor(
@@ -826,7 +893,7 @@ def main():
826
893
  )
827
894
 
828
895
  # ==========================================================================
829
- # 5. TRAINING LOOP
896
+ # TRAINING LOOP
830
897
  # ==========================================================================
831
898
  # Dynamic console header
832
899
  if accelerator.is_main_process:
wavedl/utils/__init__.py CHANGED
@@ -15,6 +15,12 @@ from .config import (
15
15
  save_config,
16
16
  validate_config,
17
17
  )
18
+ from .constraints import (
19
+ ExpressionConstraint,
20
+ FileConstraint,
21
+ PhysicsConstrainedLoss,
22
+ build_constraints,
23
+ )
18
24
  from .cross_validation import (
19
25
  CVDataset,
20
26
  run_cross_validation,
@@ -91,8 +97,11 @@ __all__ = [
91
97
  "FIGURE_WIDTH_INCH",
92
98
  "FONT_SIZE_TEXT",
93
99
  "FONT_SIZE_TICKS",
100
+ # Constraints
94
101
  "CVDataset",
95
102
  "DataSource",
103
+ "ExpressionConstraint",
104
+ "FileConstraint",
96
105
  "HDF5Source",
97
106
  "LogCoshLoss",
98
107
  "MATSource",
@@ -101,10 +110,12 @@ __all__ = [
101
110
  # Metrics
102
111
  "MetricTracker",
103
112
  "NPZSource",
113
+ "PhysicsConstrainedLoss",
104
114
  "WeightedMSELoss",
105
115
  # Distributed
106
116
  "broadcast_early_stop",
107
117
  "broadcast_value",
118
+ "build_constraints",
108
119
  "calc_pearson",
109
120
  "calc_per_target_r2",
110
121
  "configure_matplotlib_style",
wavedl/utils/config.py CHANGED
@@ -306,6 +306,16 @@ def validate_config(
306
306
  # Config
307
307
  "config",
308
308
  "list_models",
309
+ # Physical Constraints
310
+ "constraint",
311
+ "bounds",
312
+ "constraint_file",
313
+ "constraint_weight",
314
+ "constraint_reduction",
315
+ "positive",
316
+ "output_bounds",
317
+ "output_transform",
318
+ "output_formula",
309
319
  # Metadata (internal)
310
320
  "_metadata",
311
321
  }
@@ -0,0 +1,470 @@
1
+ """
2
+ Physical Constraints for Training
3
+ =================================
4
+
5
+ Soft constraint enforcement via penalty-based loss terms.
6
+
7
+ Usage:
8
+ # Expression constraints
9
+ wavedl-train --constraint "y0 > 0" --constraint_weight 0.1
10
+
11
+ # Complex constraints via Python file
12
+ wavedl-train --constraint_file my_constraint.py
13
+
14
+ Author: Ductho Le (ductho.le@outlook.com)
15
+ Version: 2.0.0
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import ast
21
+ import importlib.util
22
+ import sys
23
+ from typing import TYPE_CHECKING
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+
29
+
30
+ if TYPE_CHECKING:
31
+ from collections.abc import Callable
32
+
33
+
34
+ # ==============================================================================
35
+ # SAFE EXPRESSION PARSING
36
+ # ==============================================================================
37
+ SAFE_FUNCTIONS: dict[str, Callable] = {
38
+ "sin": torch.sin,
39
+ "cos": torch.cos,
40
+ "tan": torch.tan,
41
+ "exp": torch.exp,
42
+ "log": torch.log,
43
+ "sqrt": torch.sqrt,
44
+ "abs": torch.abs,
45
+ "relu": F.relu,
46
+ "sigmoid": torch.sigmoid,
47
+ "softplus": F.softplus,
48
+ "tanh": torch.tanh,
49
+ "min": torch.minimum,
50
+ "max": torch.maximum,
51
+ "pow": torch.pow,
52
+ "clamp": torch.clamp,
53
+ }
54
+
55
+ INPUT_AGGREGATES: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
56
+ "x_mean": lambda x: x.mean(dim=tuple(range(1, x.ndim))),
57
+ "x_sum": lambda x: x.sum(dim=tuple(range(1, x.ndim))),
58
+ "x_max": lambda x: x.amax(dim=tuple(range(1, x.ndim))),
59
+ "x_min": lambda x: x.amin(dim=tuple(range(1, x.ndim))),
60
+ "x_std": lambda x: x.std(dim=tuple(range(1, x.ndim))),
61
+ "x_energy": lambda x: (x**2).sum(dim=tuple(range(1, x.ndim))),
62
+ }
63
+
64
+
65
+ # ==============================================================================
66
+ # SOFT CONSTRAINTS
67
+ # ==============================================================================
68
+ class ExpressionConstraint(nn.Module):
69
+ """
70
+ Soft constraint via string expression.
71
+
72
+ Parses mathematical expressions using Python's AST for safe evaluation.
73
+ Supports output variables (y0, y1, ...), input aggregates (x_mean, ...),
74
+ and whitelisted math functions.
75
+
76
+ Example:
77
+ >>> constraint = ExpressionConstraint("y0 - y1 * y2")
78
+ >>> penalty = constraint(predictions, inputs)
79
+
80
+ >>> constraint = ExpressionConstraint("sin(y0) + cos(y1)")
81
+ >>> penalty = constraint(predictions, inputs)
82
+ """
83
+
84
+ def __init__(self, expression: str, reduction: str = "mse"):
85
+ """
86
+ Args:
87
+ expression: Mathematical expression to evaluate (should equal 0)
88
+ reduction: How to reduce violations - 'mse' or 'mae'
89
+ """
90
+ super().__init__()
91
+ self.expression = expression
92
+ self.reduction = reduction
93
+ self._tree = ast.parse(expression, mode="eval")
94
+ self._validate(self._tree)
95
+
96
+ def _validate(self, tree: ast.Expression) -> None:
97
+ """Validate that expression only uses safe functions."""
98
+ for node in ast.walk(tree):
99
+ if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
100
+ if node.func.id not in SAFE_FUNCTIONS:
101
+ raise ValueError(
102
+ f"Unsafe function '{node.func.id}' in constraint. "
103
+ f"Allowed: {list(SAFE_FUNCTIONS.keys())}"
104
+ )
105
+
106
+ def _eval(
107
+ self, node: ast.AST, pred: torch.Tensor, inputs: torch.Tensor | None
108
+ ) -> torch.Tensor:
109
+ """Recursively evaluate AST node."""
110
+ if isinstance(node, ast.Constant):
111
+ return torch.tensor(node.value, device=pred.device, dtype=pred.dtype)
112
+ elif isinstance(node, ast.Name):
113
+ name = node.id
114
+ # Output variable: y0, y1, ...
115
+ if name.startswith("y") and name[1:].isdigit():
116
+ idx = int(name[1:])
117
+ if idx >= pred.shape[1]:
118
+ raise ValueError(
119
+ f"Output index {idx} out of range. "
120
+ f"Model has {pred.shape[1]} outputs."
121
+ )
122
+ return pred[:, idx]
123
+ # Input aggregate: x_mean, x_sum, ...
124
+ elif name in INPUT_AGGREGATES:
125
+ if inputs is None:
126
+ raise ValueError(
127
+ f"Constraint uses '{name}' but inputs not provided."
128
+ )
129
+ return INPUT_AGGREGATES[name](inputs)
130
+ else:
131
+ raise ValueError(
132
+ f"Unknown variable '{name}'. "
133
+ f"Use y0, y1, ... for outputs or {list(INPUT_AGGREGATES.keys())} for inputs."
134
+ )
135
+ elif isinstance(node, ast.BinOp):
136
+ left = self._eval(node.left, pred, inputs)
137
+ right = self._eval(node.right, pred, inputs)
138
+ ops = {
139
+ ast.Add: torch.add,
140
+ ast.Sub: torch.sub,
141
+ ast.Mult: torch.mul,
142
+ ast.Div: torch.div,
143
+ ast.Pow: torch.pow,
144
+ ast.Mod: torch.remainder,
145
+ }
146
+ if type(node.op) not in ops:
147
+ raise ValueError(f"Unsupported operator: {type(node.op).__name__}")
148
+ return ops[type(node.op)](left, right)
149
+ elif isinstance(node, ast.UnaryOp):
150
+ operand = self._eval(node.operand, pred, inputs)
151
+ if isinstance(node.op, ast.USub):
152
+ return -operand
153
+ elif isinstance(node.op, ast.UAdd):
154
+ return operand
155
+ else:
156
+ raise ValueError(
157
+ f"Unsupported unary operator: {type(node.op).__name__}"
158
+ )
159
+ elif isinstance(node, ast.Call):
160
+ if not isinstance(node.func, ast.Name):
161
+ raise ValueError("Only direct function calls supported (e.g., sin(x))")
162
+ func_name = node.func.id
163
+ if func_name not in SAFE_FUNCTIONS:
164
+ raise ValueError(f"Unsafe function: {func_name}")
165
+ args = [self._eval(arg, pred, inputs) for arg in node.args]
166
+ return SAFE_FUNCTIONS[func_name](*args)
167
+ elif isinstance(node, ast.Compare):
168
+ # Comparison operators: y0 > 0, y0 < 1, y0 >= 0, y0 <= 1
169
+ # Returns penalty (violation amount) when constraint is not satisfied
170
+ if len(node.ops) != 1 or len(node.comparators) != 1:
171
+ raise ValueError(
172
+ "Only single comparisons supported (e.g., 'y0 > 0', not 'y0 > 0 > y1')"
173
+ )
174
+ left = self._eval(node.left, pred, inputs)
175
+ right = self._eval(node.comparators[0], pred, inputs)
176
+ op = node.ops[0]
177
+
178
+ # Return violation amount (0 if satisfied, positive if violated)
179
+ if isinstance(
180
+ op, (ast.Gt, ast.GtE)
181
+ ): # y0 > right → penalize if y0 <= right
182
+ return F.relu(right - left)
183
+ elif isinstance(
184
+ op, (ast.Lt, ast.LtE)
185
+ ): # y0 < right → penalize if y0 >= right
186
+ return F.relu(left - right)
187
+ elif isinstance(op, ast.Eq): # y0 == right → penalize difference
188
+ return torch.abs(left - right)
189
+ elif isinstance(op, ast.NotEq): # y0 != right → not useful as constraint
190
+ raise ValueError(
191
+ "'!=' is not a valid constraint. Use '==' for equality constraints."
192
+ )
193
+ else:
194
+ raise ValueError(
195
+ f"Unsupported comparison operator: {type(op).__name__}"
196
+ )
197
+ elif isinstance(node, ast.Subscript):
198
+ # Input indexing: x[0], x[0,5], x[0,5,10]
199
+ if not isinstance(node.value, ast.Name) or node.value.id != "x":
200
+ raise ValueError(
201
+ "Subscript indexing only supported for 'x' (inputs). "
202
+ "Use x[i], x[i,j], or x[i,j,k]."
203
+ )
204
+ if inputs is None:
205
+ raise ValueError("Constraint uses 'x[...]' but inputs not provided.")
206
+
207
+ # Parse indices from the slice
208
+ indices = self._parse_subscript_indices(node.slice)
209
+
210
+ # Validate dimensions match
211
+ # inputs shape: (batch, dim1) or (batch, dim1, dim2) or (batch, dim1, dim2, dim3)
212
+ input_ndim = inputs.ndim - 1 # Exclude batch dimension
213
+ if len(indices) != input_ndim:
214
+ raise ValueError(
215
+ f"Input has {input_ndim}D shape, but got {len(indices)} indices. "
216
+ f"Use x[i] for 1D, x[i,j] for 2D, x[i,j,k] for 3D inputs."
217
+ )
218
+
219
+ # Extract the value at the specified indices (for entire batch)
220
+ if len(indices) == 1:
221
+ return inputs[:, indices[0]]
222
+ elif len(indices) == 2:
223
+ return inputs[:, indices[0], indices[1]]
224
+ elif len(indices) == 3:
225
+ return inputs[:, indices[0], indices[1], indices[2]]
226
+ else:
227
+ raise ValueError("Only 1D, 2D, or 3D input indexing supported.")
228
+ elif isinstance(node, ast.Expression):
229
+ return self._eval(node.body, pred, inputs)
230
+ else:
231
+ raise ValueError(f"Unsupported AST node type: {type(node).__name__}")
232
+
233
+ def _parse_subscript_indices(self, slice_node: ast.AST) -> list[int]:
234
+ """Parse subscript indices from AST slice node."""
235
+ if isinstance(slice_node, ast.Constant):
236
+ # Single index: x[0]
237
+ return [int(slice_node.value)]
238
+ elif isinstance(slice_node, ast.Tuple):
239
+ # Multiple indices: x[0,5] or x[0,5,10]
240
+ indices = []
241
+ for elt in slice_node.elts:
242
+ if not isinstance(elt, ast.Constant):
243
+ raise ValueError(
244
+ "Only constant indices supported in x[...]. "
245
+ "Use x[0,5] not x[i,j]."
246
+ )
247
+ indices.append(int(elt.value))
248
+ return indices
249
+ else:
250
+ raise ValueError(
251
+ f"Unsupported subscript type: {type(slice_node).__name__}. "
252
+ "Use x[0], x[0,5], or x[0,5,10]."
253
+ )
254
+
255
+ def forward(
256
+ self, pred: torch.Tensor, inputs: torch.Tensor | None = None
257
+ ) -> torch.Tensor:
258
+ """
259
+ Compute constraint violation penalty.
260
+
261
+ Args:
262
+ pred: Model predictions of shape (N, num_outputs)
263
+ inputs: Model inputs of shape (N, ...) for input-dependent constraints
264
+
265
+ Returns:
266
+ Scalar penalty value
267
+ """
268
+ violation = self._eval(self._tree, pred, inputs)
269
+ if self.reduction == "mse":
270
+ return (violation**2).mean()
271
+ else: # mae
272
+ return violation.abs().mean()
273
+
274
+ def __repr__(self) -> str:
275
+ return (
276
+ f"ExpressionConstraint('{self.expression}', reduction='{self.reduction}')"
277
+ )
278
+
279
+
280
+ class FileConstraint(nn.Module):
281
+ """
282
+ Load constraint function from Python file.
283
+
284
+ The file must define a function `constraint(pred, inputs=None)` that
285
+ returns per-sample violation values.
286
+
287
+ Example file (my_constraint.py):
288
+ import torch
289
+
290
+ def constraint(pred, inputs=None):
291
+ # Monotonicity: y0 < y1 < y2
292
+ diffs = pred[:, 1:] - pred[:, :-1]
293
+ return torch.relu(-diffs).sum(dim=1)
294
+
295
+ Usage:
296
+ >>> constraint = FileConstraint("my_constraint.py")
297
+ >>> penalty = constraint(predictions, inputs)
298
+ """
299
+
300
+ def __init__(self, file_path: str, reduction: str = "mse"):
301
+ """
302
+ Args:
303
+ file_path: Path to Python file containing constraint function
304
+ reduction: How to reduce violations - 'mse' or 'mae'
305
+ """
306
+ super().__init__()
307
+ self.file_path = file_path
308
+ self.reduction = reduction
309
+
310
+ # Load module from file
311
+ spec = importlib.util.spec_from_file_location("constraint_module", file_path)
312
+ if spec is None or spec.loader is None:
313
+ raise ValueError(f"Could not load constraint file: {file_path}")
314
+
315
+ module = importlib.util.module_from_spec(spec)
316
+ sys.modules["constraint_module"] = module
317
+ spec.loader.exec_module(module)
318
+
319
+ if not hasattr(module, "constraint"):
320
+ raise ValueError(
321
+ f"Constraint file must define 'constraint(pred, inputs)' function: {file_path}"
322
+ )
323
+
324
+ self._constraint_fn = module.constraint
325
+
326
+ def forward(
327
+ self, pred: torch.Tensor, inputs: torch.Tensor | None = None
328
+ ) -> torch.Tensor:
329
+ """Evaluate constraint from loaded function."""
330
+ violation = self._constraint_fn(pred, inputs)
331
+ if self.reduction == "mse":
332
+ return (violation**2).mean()
333
+ else:
334
+ return violation.abs().mean()
335
+
336
+ def __repr__(self) -> str:
337
+ return f"FileConstraint('{self.file_path}')"
338
+
339
+
340
+ # ==============================================================================
341
+ # COMBINED LOSS WRAPPER
342
+ # ==============================================================================
343
+ class PhysicsConstrainedLoss(nn.Module):
344
+ """
345
+ Combine base loss with constraint penalties.
346
+
347
+ Total Loss = Base Loss + Σ(weight_i × constraint_i)
348
+
349
+ Constraints are evaluated in **physical space** (denormalized) while
350
+ the base loss is computed in normalized space for stable training.
351
+
352
+ Example:
353
+ >>> base_loss = nn.MSELoss()
354
+ >>> constraints = [ExpressionConstraint("y0 - y1*y2")]
355
+ >>> criterion = PhysicsConstrainedLoss(
356
+ ... base_loss,
357
+ ... constraints,
358
+ ... weights=[0.1],
359
+ ... output_mean=[10, 5, 50],
360
+ ... output_std=[2, 1, 10],
361
+ ... )
362
+ >>> loss = criterion(pred, target, inputs)
363
+ """
364
+
365
+ def __init__(
366
+ self,
367
+ base_loss: nn.Module,
368
+ constraints: list[nn.Module] | None = None,
369
+ weights: list[float] | None = None,
370
+ output_mean: torch.Tensor | list[float] | None = None,
371
+ output_std: torch.Tensor | list[float] | None = None,
372
+ ):
373
+ """
374
+ Args:
375
+ base_loss: Base loss function (e.g., MSELoss)
376
+ constraints: List of constraint modules
377
+ weights: Weight for each constraint. If shorter than constraints,
378
+ last weight is repeated.
379
+ output_mean: Mean of each output (for denormalization). Shape: (num_outputs,)
380
+ output_std: Std of each output (for denormalization). Shape: (num_outputs,)
381
+ """
382
+ super().__init__()
383
+ self.base_loss = base_loss
384
+ self.constraints = nn.ModuleList(constraints or [])
385
+ self.weights = weights or [0.1]
386
+
387
+ # Store scaler as buffers (moves to correct device automatically)
388
+ if output_mean is not None:
389
+ if not isinstance(output_mean, torch.Tensor):
390
+ output_mean = torch.tensor(output_mean, dtype=torch.float32)
391
+ self.register_buffer("output_mean", output_mean)
392
+ else:
393
+ self.register_buffer("output_mean", None)
394
+
395
+ if output_std is not None:
396
+ if not isinstance(output_std, torch.Tensor):
397
+ output_std = torch.tensor(output_std, dtype=torch.float32)
398
+ self.register_buffer("output_std", output_std)
399
+ else:
400
+ self.register_buffer("output_std", None)
401
+
402
+ def _denormalize(self, pred: torch.Tensor) -> torch.Tensor:
403
+ """Convert normalized predictions to physical values."""
404
+ if self.output_mean is None or self.output_std is None:
405
+ return pred
406
+ return pred * self.output_std + self.output_mean
407
+
408
+ def forward(
409
+ self,
410
+ pred: torch.Tensor,
411
+ target: torch.Tensor,
412
+ inputs: torch.Tensor | None = None,
413
+ ) -> torch.Tensor:
414
+ """
415
+ Compute combined loss.
416
+
417
+ Args:
418
+ pred: Model predictions (normalized)
419
+ target: Ground truth targets (normalized)
420
+ inputs: Model inputs (for input-dependent constraints)
421
+
422
+ Returns:
423
+ Combined loss value
424
+ """
425
+ # Base loss in normalized space (stable gradients)
426
+ loss = self.base_loss(pred, target)
427
+
428
+ # Denormalize for constraint evaluation (physical units)
429
+ pred_physical = self._denormalize(pred)
430
+
431
+ for i, constraint in enumerate(self.constraints):
432
+ weight = self.weights[i] if i < len(self.weights) else self.weights[-1]
433
+ penalty = constraint(pred_physical, inputs)
434
+ loss = loss + weight * penalty
435
+
436
+ return loss
437
+
438
+ def __repr__(self) -> str:
439
+ has_scaler = self.output_mean is not None
440
+ return f"PhysicsConstrainedLoss(base={self.base_loss}, constraints={len(self.constraints)}, denormalize={has_scaler})"
441
+
442
+
443
+ # ==============================================================================
444
+ # FACTORY FUNCTIONS
445
+ # ==============================================================================
446
+ def build_constraints(
447
+ expressions: list[str] | None = None,
448
+ file_path: str | None = None,
449
+ reduction: str = "mse",
450
+ ) -> list[nn.Module]:
451
+ """
452
+ Build soft constraint modules from CLI arguments.
453
+
454
+ Args:
455
+ expressions: Expression constraints (e.g., ["y0 - y1*y2", "y0 > 0"])
456
+ file_path: Path to Python constraint file
457
+ reduction: Reduction mode for penalties
458
+
459
+ Returns:
460
+ List of constraint modules
461
+ """
462
+ constraints: list[nn.Module] = []
463
+
464
+ for expr in expressions or []:
465
+ constraints.append(ExpressionConstraint(expr, reduction))
466
+
467
+ if file_path:
468
+ constraints.append(FileConstraint(file_path, reduction))
469
+
470
+ return constraints
wavedl/utils/metrics.py CHANGED
@@ -560,11 +560,56 @@ def create_training_curves(
560
560
  )
561
561
  lines.append(line)
562
562
 
563
+ def set_lr_ticks(ax: plt.Axes, data: list[float], n_ticks: int = 4) -> None:
564
+ """Set n uniformly spaced ticks on LR axis with 10^n format labels."""
565
+ valid_data = [v for v in data if v is not None and not np.isnan(v) and v > 0]
566
+ if not valid_data:
567
+ return
568
+ vmin, vmax = min(valid_data), max(valid_data)
569
+ # Snap to clean decade boundaries
570
+ log_min = np.floor(np.log10(vmin))
571
+ log_max = np.ceil(np.log10(vmax))
572
+ # Generate n uniformly spaced ticks as powers of 10
573
+ log_ticks = np.linspace(log_min, log_max, n_ticks)
574
+ # Round to nearest integer power of 10 for clean numbers
575
+ log_ticks = np.round(log_ticks)
576
+ ticks = 10.0**log_ticks
577
+ # Remove duplicates while preserving order
578
+ ticks = list(dict.fromkeys(ticks))
579
+ ax.set_yticks(ticks)
580
+ # Format all tick labels as 10^n
581
+ labels = [f"$10^{{{int(np.log10(t))}}}$" for t in ticks]
582
+ ax.set_yticklabels(labels)
583
+ ax.minorticks_off()
584
+
585
+ def set_loss_ticks(ax: plt.Axes, data: list[float]) -> None:
586
+ """Set ticks at powers of 10 that cover the data range."""
587
+ valid_data = [v for v in data if v is not None and not np.isnan(v) and v > 0]
588
+ if not valid_data:
589
+ return
590
+ vmin, vmax = min(valid_data), max(valid_data)
591
+ # Get decade range that covers data (ceil for min to avoid going too low)
592
+ log_min = int(np.ceil(np.log10(vmin)))
593
+ log_max = int(np.ceil(np.log10(vmax)))
594
+ # Generate ticks at each power of 10
595
+ ticks = [10.0**i for i in range(log_min, log_max + 1)]
596
+ ax.set_yticks(ticks)
597
+ # Format labels as 10^n
598
+ labels = [f"$10^{{{i}}}$" for i in range(log_min, log_max + 1)]
599
+ ax.set_yticklabels(labels)
600
+ ax.minorticks_off()
601
+
563
602
  ax1.set_xlabel("Epoch")
564
603
  ax1.set_ylabel("Loss")
565
604
  ax1.set_yscale("log") # Log scale for loss
566
605
  ax1.grid(True, alpha=0.3)
567
606
 
607
+ # Collect all loss values and set clean power of 10 ticks
608
+ all_loss_values = []
609
+ for metric in metrics:
610
+ all_loss_values.extend([h.get(metric, np.nan) for h in history])
611
+ set_loss_ticks(ax1, all_loss_values)
612
+
568
613
  # Check if learning rate data exists
569
614
  has_lr = show_lr and any("lr" in h for h in history)
570
615
 
@@ -581,9 +626,11 @@ def create_training_curves(
581
626
  alpha=0.7,
582
627
  label="Learning Rate",
583
628
  )
584
- ax2.set_ylabel("Learning Rate", color=COLORS["neutral"])
585
- ax2.tick_params(axis="y", labelcolor=COLORS["neutral"])
629
+ ax2.set_ylabel("Learning Rate")
586
630
  ax2.set_yscale("log") # Log scale for LR
631
+ set_lr_ticks(ax2, lr_values, n_ticks=4)
632
+ # Ensure right spine (axis line) is visible
633
+ ax2.spines["right"].set_visible(True)
587
634
  lines.append(line_lr)
588
635
 
589
636
  # Combined legend
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.4.6
3
+ Version: 1.5.0
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -113,14 +113,12 @@ Train on datasets larger than RAM:
113
113
  </td>
114
114
  <td width="50%" valign="top">
115
115
 
116
- **🧠 One-Line Model Registration**
116
+ **🧠 Models? We've Got Options**
117
117
 
118
- Plug in any architecture:
119
- ```python
120
- @register_model("my_net")
121
- class MyNet(BaseModel): ...
122
- ```
123
- Design your model. Register with one line.
118
+ 38 architectures, ready to go:
119
+ - CNNs, ResNets, ViTs, EfficientNets...
120
+ - All adapted for regression
121
+ - [Add your own](#adding-custom-models) in one line
124
122
 
125
123
  </td>
126
124
  </tr>
@@ -137,12 +135,12 @@ Multi-GPU training without the pain:
137
135
  </td>
138
136
  <td width="50%" valign="top">
139
137
 
140
- **📊 Publish-Ready Output**
138
+ **🔬 Physics-Constrained Training**
141
139
 
142
- Results go straight to your paper:
143
- - 11 diagnostic plots with LaTeX styling
144
- - Multi-format export (PNG, PDF, SVG, ...)
145
- - MAE in physical units per parameter
140
+ Make your model respect the laws:
141
+ - Enforce bounds, positivity, equations
142
+ - Simple expression syntax or Python
143
+ - [Custom constraints](#physical-constraints) for various laws
146
144
 
147
145
  </td>
148
146
  </tr>
@@ -383,7 +381,7 @@ WaveDL/
383
381
  ├── configs/ # YAML config templates
384
382
  ├── examples/ # Ready-to-run examples
385
383
  ├── notebooks/ # Jupyter notebooks
386
- ├── unit_tests/ # Pytest test suite (704 tests)
384
+ ├── unit_tests/ # Pytest test suite (725 tests)
387
385
 
388
386
  ├── pyproject.toml # Package config, dependencies
389
387
  ├── CHANGELOG.md # Version history
@@ -727,6 +725,104 @@ seed: 2025
727
725
 
728
726
  </details>
729
727
 
728
+ <details>
729
+ <summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
730
+
731
+ Add penalty terms to the loss function to enforce physical laws:
732
+
733
+ ```
734
+ Total Loss = Data Loss + weight × penalty(violation)
735
+ ```
736
+
737
+ ### Expression Constraints
738
+
739
+ ```bash
740
+ # Positivity
741
+ --constraint "y0 > 0"
742
+
743
+ # Bounds
744
+ --constraint "y0 >= 0" "y0 <= 1"
745
+
746
+ # Equations (penalize deviations from zero)
747
+ --constraint "y2 - y0 * y1"
748
+
749
+ # Input-dependent constraints
750
+ --constraint "y0 - 2*x[0]"
751
+
752
+ # Multiple constraints with different weights
753
+ --constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
754
+ ```
755
+
756
+ ### Custom Python Constraints
757
+
758
+ For complex physics (matrix operations, implicit equations):
759
+
760
+ ```python
761
+ # my_constraint.py
762
+ import torch
763
+
764
+ def constraint(pred, inputs=None):
765
+ """
766
+ Args:
767
+ pred: (batch, num_outputs)
768
+ inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
769
+ Returns:
770
+ (batch,) — violation per sample (0 = satisfied)
771
+ """
772
+ # Outputs (same for all data types)
773
+ y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
774
+
775
+ # Inputs — Tabular: (batch, features)
776
+ # x0 = inputs[:, 0] # Feature 0
777
+ # x_sum = inputs.sum(dim=1) # Sum all features
778
+
779
+ # Inputs — Images: (batch, C, H, W)
780
+ # pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
781
+ # img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
782
+
783
+ # Inputs — 3D Volumes: (batch, C, D, H, W)
784
+ # voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
785
+
786
+ # Example constraints:
787
+ # return y2 - y0 * y1 # Wave equation
788
+ # return y0 - 2 * inputs[:, 0] # Output = 2×input
789
+ # return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
790
+
791
+ return y0 - y1 * y2
792
+ ```
793
+
794
+ ```bash
795
+ --constraint_file my_constraint.py --constraint_weight 1.0
796
+ ```
797
+
798
+ ---
799
+
800
+ ### Reference
801
+
802
+ | Argument | Default | Description |
803
+ |----------|---------|-------------|
804
+ | `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
805
+ | `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
806
+ | `--constraint_weight` | `0.1` | Penalty weight(s) |
807
+ | `--constraint_reduction` | `mse` | `mse` (squared) or `mae` (linear) |
808
+
809
+ #### Expression Syntax
810
+
811
+ | Variable | Meaning |
812
+ |----------|---------|
813
+ | `y0`, `y1`, ... | Model outputs |
814
+ | `x[0]`, `x[1]`, ... | Input values (1D tabular) |
815
+ | `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
816
+ | `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
817
+
818
+ **Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
819
+
820
+ **Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
821
+
822
+ </details>
823
+
824
+
825
+
730
826
  <details>
731
827
  <summary><b>Hyperparameter Search (HPO)</b></summary>
732
828
 
@@ -766,7 +862,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
766
862
  | Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
767
863
  | Losses | [all 6](#loss-functions) | `--losses X Y` |
768
864
  | Learning rate | 1e-5 → 1e-2 | (always searched) |
769
- | Batch size | 64, 128, 256, 512 | (always searched) |
865
+ | Batch size | 16, 32, 64, 128 | (always searched) |
770
866
 
771
867
  **Quick Mode** (`--quick`):
772
868
  - Uses minimal defaults: cnn + adamw + plateau + mse
@@ -938,12 +1034,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
938
1034
  ```bash
939
1035
  # Run inference on the example data
940
1036
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
941
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
1037
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
942
1038
  --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
943
1039
 
944
1040
  # Export to ONNX (already included as model.onnx)
945
1041
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
946
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
1042
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
947
1043
  --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
948
1044
  ```
949
1045
 
@@ -952,7 +1048,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
952
1048
  | File | Description |
953
1049
  |------|-------------|
954
1050
  | `best_checkpoint/` | Pre-trained CNN checkpoint |
955
- | `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
1051
+ | `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
956
1052
  | `model.onnx` | ONNX export with embedded de-normalization |
957
1053
  | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
958
1054
  | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -963,7 +1059,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
963
1059
 
964
1060
  <p align="center">
965
1061
  <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
966
- <em>Training and validation loss over 162 epochs with learning rate schedule</em>
1062
+ <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
967
1063
  </p>
968
1064
 
969
1065
  **Inference Results:**
@@ -1,8 +1,8 @@
1
- wavedl/__init__.py,sha256=ItdZLt3f7sbtAMgiwUtGwwG5Cko4tPLugC_OVhfHMno,1177
1
+ wavedl/__init__.py,sha256=jeYx6dZ0_UD4yXl-I4_Aa63820RzS5-IxJP_KUni8Pw,1177
2
2
  wavedl/hpc.py,sha256=-iOjjKkXPcV_quj4vAsMBJN_zWKtD1lMRfIZZBhyGms,8756
3
- wavedl/hpo.py,sha256=JQvwPgiVHj3sB9Wombn1QO4ammpuo0QAMpRee0LjkuI,14731
3
+ wavedl/hpo.py,sha256=DGCGyt2yhr3WAifAuljhE26gg07CHdaQW4wpDaTKbyo,14968
4
4
  wavedl/test.py,sha256=oWGSSC7178loqOxwti-oDXUVogOqbwHL__GfoXSE5Ss,37846
5
- wavedl/train.py,sha256=9l4aVW1Jd1Sq6yBr8BOoVIKUYmxASDO8XK6BqEkLLWs,50151
5
+ wavedl/train.py,sha256=Xy0-W9Kxqdc7WaAyt_HDSZwhZKrfjVbZd2KzgSid6rQ,52457
6
6
  wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
7
7
  wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
8
8
  wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
@@ -20,18 +20,19 @@ wavedl/models/swin.py,sha256=p-okfq3Qm4_neJTxCcMzoHoVzC0BHW3BMnbpr_Ri2U0,13224
20
20
  wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
21
21
  wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
22
22
  wavedl/models/vit.py,sha256=0C3GZk11VsYFTl14d86Wtl1Zk1T5rYJjvkaEfEN4N3k,11100
23
- wavedl/utils/__init__.py,sha256=YMgzuwndjr64kt9k0_6_9PMJYTVdiaH5veSMff_ZycA,3051
24
- wavedl/utils/config.py,sha256=fMoucikIQHn85mVhGMa7TnXTuFDcEEPjfXk2EjbkJR0,10591
23
+ wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
24
+ wavedl/utils/config.py,sha256=jGW-K7AYB6zrD2BfVm2XPnSY9rbfL_EkM4bwxhBLuwM,10859
25
+ wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
25
26
  wavedl/utils/cross_validation.py,sha256=117ac9KDzaIaqhtP8ZRs15Xpqmq5fLpX2-vqkNvtMaU,17487
26
27
  wavedl/utils/data.py,sha256=_OaWvU5oFVJW0NwM5WyDD0Kb1hy5MgvJIFpzvJGux9w,48214
27
28
  wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
28
29
  wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
29
- wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
30
+ wavedl/utils/metrics.py,sha256=EJmJvF7gACQsUoKYldlladN_SbnRiuE-Smj0eSnbraQ,39394
30
31
  wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
31
32
  wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
32
- wavedl-1.4.6.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
33
- wavedl-1.4.6.dist-info/METADATA,sha256=Hnot8ui2oksCz2UXhj3FHd_Z9MtoP8MJyiMzC6eWq5s,42453
34
- wavedl-1.4.6.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
35
- wavedl-1.4.6.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
36
- wavedl-1.4.6.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
37
- wavedl-1.4.6.dist-info/RECORD,,
33
+ wavedl-1.5.0.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
34
+ wavedl-1.5.0.dist-info/METADATA,sha256=jlQPfPLtwIXzWHKwgN8xQjzco_xB7L3j7p_dl3TJA8E,45224
35
+ wavedl-1.5.0.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
36
+ wavedl-1.5.0.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
37
+ wavedl-1.5.0.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
38
+ wavedl-1.5.0.dist-info/RECORD,,
File without changes