wavedl 1.4.6__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +9 -1
- wavedl/train.py +73 -6
- wavedl/utils/__init__.py +11 -0
- wavedl/utils/config.py +10 -0
- wavedl/utils/constraints.py +470 -0
- wavedl/utils/metrics.py +49 -2
- {wavedl-1.4.6.dist-info → wavedl-1.5.0.dist-info}/METADATA +115 -19
- {wavedl-1.4.6.dist-info → wavedl-1.5.0.dist-info}/RECORD +13 -12
- {wavedl-1.4.6.dist-info → wavedl-1.5.0.dist-info}/LICENSE +0 -0
- {wavedl-1.4.6.dist-info → wavedl-1.5.0.dist-info}/WHEEL +0 -0
- {wavedl-1.4.6.dist-info → wavedl-1.5.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.4.6.dist-info → wavedl-1.5.0.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpo.py
CHANGED
|
@@ -89,7 +89,8 @@ def create_objective(args):
|
|
|
89
89
|
# Suggest hyperparameters
|
|
90
90
|
model = trial.suggest_categorical("model", models)
|
|
91
91
|
lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
|
|
92
|
-
|
|
92
|
+
batch_sizes = args.batch_sizes or [16, 32, 64, 128]
|
|
93
|
+
batch_size = trial.suggest_categorical("batch_size", batch_sizes)
|
|
93
94
|
optimizer = trial.suggest_categorical("optimizer", optimizers)
|
|
94
95
|
scheduler = trial.suggest_categorical("scheduler", schedulers)
|
|
95
96
|
loss = trial.suggest_categorical("loss", losses)
|
|
@@ -317,6 +318,13 @@ Examples:
|
|
|
317
318
|
default=None,
|
|
318
319
|
help=f"Losses to search (default: {DEFAULT_LOSSES})",
|
|
319
320
|
)
|
|
321
|
+
parser.add_argument(
|
|
322
|
+
"--batch_sizes",
|
|
323
|
+
type=int,
|
|
324
|
+
nargs="+",
|
|
325
|
+
default=None,
|
|
326
|
+
help="Batch sizes to search (default: 16 32 64 128)",
|
|
327
|
+
)
|
|
320
328
|
|
|
321
329
|
# Training settings for each trial
|
|
322
330
|
parser.add_argument(
|
wavedl/train.py
CHANGED
|
@@ -375,6 +375,36 @@ def parse_args() -> argparse.Namespace:
|
|
|
375
375
|
help=argparse.SUPPRESS, # Hidden: use --precision instead
|
|
376
376
|
)
|
|
377
377
|
|
|
378
|
+
# Physical Constraints
|
|
379
|
+
parser.add_argument(
|
|
380
|
+
"--constraint",
|
|
381
|
+
type=str,
|
|
382
|
+
nargs="+",
|
|
383
|
+
default=[],
|
|
384
|
+
help="Soft constraint expressions: 'y0 - y1*y2' (penalize violations)",
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
parser.add_argument(
|
|
388
|
+
"--constraint_file",
|
|
389
|
+
type=str,
|
|
390
|
+
default=None,
|
|
391
|
+
help="Python file with constraint(pred, inputs) function",
|
|
392
|
+
)
|
|
393
|
+
parser.add_argument(
|
|
394
|
+
"--constraint_weight",
|
|
395
|
+
type=float,
|
|
396
|
+
nargs="+",
|
|
397
|
+
default=[0.1],
|
|
398
|
+
help="Weight(s) for soft constraints (one per constraint, or single shared weight)",
|
|
399
|
+
)
|
|
400
|
+
parser.add_argument(
|
|
401
|
+
"--constraint_reduction",
|
|
402
|
+
type=str,
|
|
403
|
+
default="mse",
|
|
404
|
+
choices=["mse", "mae"],
|
|
405
|
+
help="Reduction mode for constraint penalties",
|
|
406
|
+
)
|
|
407
|
+
|
|
378
408
|
# Logging
|
|
379
409
|
parser.add_argument(
|
|
380
410
|
"--wandb", action="store_true", help="Enable Weights & Biases logging"
|
|
@@ -553,7 +583,7 @@ def main():
|
|
|
553
583
|
return
|
|
554
584
|
|
|
555
585
|
# ==========================================================================
|
|
556
|
-
#
|
|
586
|
+
# SYSTEM INITIALIZATION
|
|
557
587
|
# ==========================================================================
|
|
558
588
|
# Initialize Accelerator for DDP and mixed precision
|
|
559
589
|
accelerator = Accelerator(
|
|
@@ -609,7 +639,7 @@ def main():
|
|
|
609
639
|
)
|
|
610
640
|
|
|
611
641
|
# ==========================================================================
|
|
612
|
-
#
|
|
642
|
+
# DATA & MODEL LOADING
|
|
613
643
|
# ==========================================================================
|
|
614
644
|
train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
|
|
615
645
|
args, logger, accelerator, cache_dir=args.output_dir
|
|
@@ -663,7 +693,7 @@ def main():
|
|
|
663
693
|
)
|
|
664
694
|
|
|
665
695
|
# ==========================================================================
|
|
666
|
-
#
|
|
696
|
+
# OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
|
|
667
697
|
# ==========================================================================
|
|
668
698
|
# Parse comma-separated arguments with validation
|
|
669
699
|
try:
|
|
@@ -707,6 +737,43 @@ def main():
|
|
|
707
737
|
# Move criterion to device (important for WeightedMSELoss buffer)
|
|
708
738
|
criterion = criterion.to(accelerator.device)
|
|
709
739
|
|
|
740
|
+
# ==========================================================================
|
|
741
|
+
# PHYSICAL CONSTRAINTS INTEGRATION
|
|
742
|
+
# ==========================================================================
|
|
743
|
+
from wavedl.utils.constraints import (
|
|
744
|
+
PhysicsConstrainedLoss,
|
|
745
|
+
build_constraints,
|
|
746
|
+
)
|
|
747
|
+
|
|
748
|
+
# Build soft constraints
|
|
749
|
+
soft_constraints = build_constraints(
|
|
750
|
+
expressions=args.constraint,
|
|
751
|
+
file_path=args.constraint_file,
|
|
752
|
+
reduction=args.constraint_reduction,
|
|
753
|
+
)
|
|
754
|
+
|
|
755
|
+
# Wrap criterion with PhysicsConstrainedLoss if we have soft constraints
|
|
756
|
+
if soft_constraints:
|
|
757
|
+
# Pass output scaler so constraints can be evaluated in physical space
|
|
758
|
+
output_mean = scaler.mean_ if hasattr(scaler, "mean_") else None
|
|
759
|
+
output_std = scaler.scale_ if hasattr(scaler, "scale_") else None
|
|
760
|
+
criterion = PhysicsConstrainedLoss(
|
|
761
|
+
criterion,
|
|
762
|
+
soft_constraints,
|
|
763
|
+
weights=args.constraint_weight,
|
|
764
|
+
output_mean=output_mean,
|
|
765
|
+
output_std=output_std,
|
|
766
|
+
)
|
|
767
|
+
if accelerator.is_main_process:
|
|
768
|
+
logger.info(
|
|
769
|
+
f" 🔬 Physical constraints: {len(soft_constraints)} constraint(s) "
|
|
770
|
+
f"with weight(s) {args.constraint_weight}"
|
|
771
|
+
)
|
|
772
|
+
if output_mean is not None:
|
|
773
|
+
logger.info(
|
|
774
|
+
" 📐 Constraints evaluated in physical space (denormalized)"
|
|
775
|
+
)
|
|
776
|
+
|
|
710
777
|
# Track if scheduler should step per batch (OneCycleLR) or per epoch
|
|
711
778
|
scheduler_step_per_batch = not is_epoch_based(args.scheduler)
|
|
712
779
|
|
|
@@ -762,7 +829,7 @@ def main():
|
|
|
762
829
|
)
|
|
763
830
|
|
|
764
831
|
# ==========================================================================
|
|
765
|
-
#
|
|
832
|
+
# AUTO-RESUME / RESUME FROM CHECKPOINT
|
|
766
833
|
# ==========================================================================
|
|
767
834
|
start_epoch = 0
|
|
768
835
|
best_val_loss = float("inf")
|
|
@@ -818,7 +885,7 @@ def main():
|
|
|
818
885
|
raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
|
|
819
886
|
|
|
820
887
|
# ==========================================================================
|
|
821
|
-
#
|
|
888
|
+
# PHYSICAL METRIC SETUP
|
|
822
889
|
# ==========================================================================
|
|
823
890
|
# Physical MAE = normalized MAE * scaler.scale_
|
|
824
891
|
phys_scale = torch.tensor(
|
|
@@ -826,7 +893,7 @@ def main():
|
|
|
826
893
|
)
|
|
827
894
|
|
|
828
895
|
# ==========================================================================
|
|
829
|
-
#
|
|
896
|
+
# TRAINING LOOP
|
|
830
897
|
# ==========================================================================
|
|
831
898
|
# Dynamic console header
|
|
832
899
|
if accelerator.is_main_process:
|
wavedl/utils/__init__.py
CHANGED
|
@@ -15,6 +15,12 @@ from .config import (
|
|
|
15
15
|
save_config,
|
|
16
16
|
validate_config,
|
|
17
17
|
)
|
|
18
|
+
from .constraints import (
|
|
19
|
+
ExpressionConstraint,
|
|
20
|
+
FileConstraint,
|
|
21
|
+
PhysicsConstrainedLoss,
|
|
22
|
+
build_constraints,
|
|
23
|
+
)
|
|
18
24
|
from .cross_validation import (
|
|
19
25
|
CVDataset,
|
|
20
26
|
run_cross_validation,
|
|
@@ -91,8 +97,11 @@ __all__ = [
|
|
|
91
97
|
"FIGURE_WIDTH_INCH",
|
|
92
98
|
"FONT_SIZE_TEXT",
|
|
93
99
|
"FONT_SIZE_TICKS",
|
|
100
|
+
# Constraints
|
|
94
101
|
"CVDataset",
|
|
95
102
|
"DataSource",
|
|
103
|
+
"ExpressionConstraint",
|
|
104
|
+
"FileConstraint",
|
|
96
105
|
"HDF5Source",
|
|
97
106
|
"LogCoshLoss",
|
|
98
107
|
"MATSource",
|
|
@@ -101,10 +110,12 @@ __all__ = [
|
|
|
101
110
|
# Metrics
|
|
102
111
|
"MetricTracker",
|
|
103
112
|
"NPZSource",
|
|
113
|
+
"PhysicsConstrainedLoss",
|
|
104
114
|
"WeightedMSELoss",
|
|
105
115
|
# Distributed
|
|
106
116
|
"broadcast_early_stop",
|
|
107
117
|
"broadcast_value",
|
|
118
|
+
"build_constraints",
|
|
108
119
|
"calc_pearson",
|
|
109
120
|
"calc_per_target_r2",
|
|
110
121
|
"configure_matplotlib_style",
|
wavedl/utils/config.py
CHANGED
|
@@ -306,6 +306,16 @@ def validate_config(
|
|
|
306
306
|
# Config
|
|
307
307
|
"config",
|
|
308
308
|
"list_models",
|
|
309
|
+
# Physical Constraints
|
|
310
|
+
"constraint",
|
|
311
|
+
"bounds",
|
|
312
|
+
"constraint_file",
|
|
313
|
+
"constraint_weight",
|
|
314
|
+
"constraint_reduction",
|
|
315
|
+
"positive",
|
|
316
|
+
"output_bounds",
|
|
317
|
+
"output_transform",
|
|
318
|
+
"output_formula",
|
|
309
319
|
# Metadata (internal)
|
|
310
320
|
"_metadata",
|
|
311
321
|
}
|
|
@@ -0,0 +1,470 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Physical Constraints for Training
|
|
3
|
+
=================================
|
|
4
|
+
|
|
5
|
+
Soft constraint enforcement via penalty-based loss terms.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
# Expression constraints
|
|
9
|
+
wavedl-train --constraint "y0 > 0" --constraint_weight 0.1
|
|
10
|
+
|
|
11
|
+
# Complex constraints via Python file
|
|
12
|
+
wavedl-train --constraint_file my_constraint.py
|
|
13
|
+
|
|
14
|
+
Author: Ductho Le (ductho.le@outlook.com)
|
|
15
|
+
Version: 2.0.0
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import ast
|
|
21
|
+
import importlib.util
|
|
22
|
+
import sys
|
|
23
|
+
from typing import TYPE_CHECKING
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
import torch.nn as nn
|
|
27
|
+
import torch.nn.functional as F
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
from collections.abc import Callable
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# ==============================================================================
|
|
35
|
+
# SAFE EXPRESSION PARSING
|
|
36
|
+
# ==============================================================================
|
|
37
|
+
# Functions an expression may call by name, e.g. "sin(y0)" or "clamp(y1, 0, 1)".
# Insertion order matters: it is echoed in the "Allowed: [...]" error message.
SAFE_FUNCTIONS: dict[str, Callable] = dict(
    sin=torch.sin,
    cos=torch.cos,
    tan=torch.tan,
    exp=torch.exp,
    log=torch.log,
    sqrt=torch.sqrt,
    abs=torch.abs,
    relu=F.relu,
    sigmoid=torch.sigmoid,
    softplus=F.softplus,
    tanh=torch.tanh,
    min=torch.minimum,  # element-wise minimum of two tensors
    max=torch.maximum,  # element-wise maximum of two tensors
    pow=torch.pow,
    clamp=torch.clamp,
)


def _non_batch_dims(x: torch.Tensor) -> tuple[int, ...]:
    """Return every axis index of ``x`` except the leading (batch) axis 0."""
    return tuple(range(1, x.ndim))


# Per-sample summaries of the model input, usable as bare names in
# expressions (e.g. "y0 - x_mean"). Each reduces over all non-batch axes.
# Insertion order matters: it is echoed in "Unknown variable" error messages.
INPUT_AGGREGATES: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
    "x_mean": lambda x: x.mean(dim=_non_batch_dims(x)),
    "x_sum": lambda x: x.sum(dim=_non_batch_dims(x)),
    "x_max": lambda x: x.amax(dim=_non_batch_dims(x)),
    "x_min": lambda x: x.amin(dim=_non_batch_dims(x)),
    "x_std": lambda x: x.std(dim=_non_batch_dims(x)),
    # Sum of squares over every non-batch element.
    "x_energy": lambda x: (x**2).sum(dim=_non_batch_dims(x)),
}
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ==============================================================================
|
|
66
|
+
# SOFT CONSTRAINTS
|
|
67
|
+
# ==============================================================================
|
|
68
|
+
class ExpressionConstraint(nn.Module):
    """
    Soft constraint via string expression.

    Parses mathematical expressions using Python's AST for safe evaluation.
    Supported building blocks:

    - outputs: ``y0``, ``y1``, ...
    - input aggregates: ``x_mean``, ``x_sum``, ...
    - constant input indexing: ``x[i]``, ``x[i, j]``, ``x[-1]``, ...
    - whitelisted math functions, with positional and/or keyword arguments.

    A plain expression is treated as a residual that should equal 0. A single
    comparison (``>``, ``>=``, ``<``, ``<=``, ``==``) is converted into a
    non-negative violation amount that is 0 when the comparison holds.

    Example:
        >>> constraint = ExpressionConstraint("y0 - y1 * y2")
        >>> penalty = constraint(predictions, inputs)

        >>> constraint = ExpressionConstraint("sin(y0) + cos(y1)")
        >>> penalty = constraint(predictions, inputs)
    """

    # Binary operators permitted in expressions (built once, shared by all instances).
    _BINARY_OPS = {
        ast.Add: torch.add,
        ast.Sub: torch.sub,
        ast.Mult: torch.mul,
        ast.Div: torch.div,
        ast.Pow: torch.pow,
        ast.Mod: torch.remainder,
    }

    # Accepted values for ``reduction``.
    _REDUCTIONS = ("mse", "mae")

    def __init__(self, expression: str, reduction: str = "mse"):
        """
        Args:
            expression: Mathematical expression to evaluate (should equal 0),
                or a single comparison such as ``"y0 > 0"``.
            reduction: How to reduce violations - 'mse' or 'mae'.

        Raises:
            SyntaxError: If ``expression`` is not a valid Python expression.
            ValueError: If ``reduction`` is unknown, or the expression calls a
                non-whitelisted function or references an unknown variable.
        """
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"Unknown reduction '{reduction}'. Use one of {list(self._REDUCTIONS)}."
            )
        self.expression = expression
        self.reduction = reduction
        self._tree = ast.parse(expression, mode="eval")
        self._validate(self._tree)

    # ------------------------------------------------------------------
    # Name helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _output_index(name: str) -> int | None:
        """Return ``k`` for an output variable name ``yk``, else ``None``."""
        suffix = name[1:]
        if name.startswith("y") and suffix.isdecimal():
            return int(suffix)
        return None

    def _is_known_name(self, name: str) -> bool:
        """True if ``name`` may legally appear as an ``ast.Name`` node."""
        return (
            self._output_index(name) is not None
            or name in INPUT_AGGREGATES
            or name in SAFE_FUNCTIONS  # function names are Name nodes too
            or name == "x"  # base of x[...] input indexing
        )

    @staticmethod
    def _unknown_variable_message(name: str) -> str:
        """Error text shared by construction-time and evaluation-time checks."""
        return (
            f"Unknown variable '{name}'. "
            f"Use y0, y1, ... for outputs or {list(INPUT_AGGREGATES.keys())} for inputs."
        )

    def _validate(self, tree: ast.Expression) -> None:
        """Reject non-whitelisted functions and unknown names up front.

        This runs at construction time, so typos surface before training
        starts rather than on the first forward pass. Operator and node-type
        checks still happen in :meth:`_eval`.
        """
        for node in ast.walk(tree):
            if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                if node.func.id not in SAFE_FUNCTIONS:
                    raise ValueError(
                        f"Unsafe function '{node.func.id}' in constraint. "
                        f"Allowed: {list(SAFE_FUNCTIONS.keys())}"
                    )
            elif isinstance(node, ast.Name) and not self._is_known_name(node.id):
                raise ValueError(self._unknown_variable_message(node.id))

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------
    def _eval(
        self, node: ast.AST, pred: torch.Tensor, inputs: torch.Tensor | None
    ) -> torch.Tensor:
        """Recursively evaluate an AST node.

        Returns a per-sample tensor, or a 0-d tensor for constants.
        """
        if isinstance(node, ast.Expression):
            return self._eval(node.body, pred, inputs)

        if isinstance(node, ast.Constant):
            # bool is an int subclass and still evaluates to 0/1.
            if not isinstance(node.value, (int, float)):
                raise ValueError(f"Unsupported constant in constraint: {node.value!r}")
            return torch.tensor(node.value, device=pred.device, dtype=pred.dtype)

        if isinstance(node, ast.Name):
            return self._eval_name(node.id, pred, inputs)

        if isinstance(node, ast.BinOp):
            op = self._BINARY_OPS.get(type(node.op))
            if op is None:
                raise ValueError(f"Unsupported operator: {type(node.op).__name__}")
            left = self._eval(node.left, pred, inputs)
            right = self._eval(node.right, pred, inputs)
            return op(left, right)

        if isinstance(node, ast.UnaryOp):
            operand = self._eval(node.operand, pred, inputs)
            if isinstance(node.op, ast.USub):
                return -operand
            if isinstance(node.op, ast.UAdd):
                return operand
            raise ValueError(f"Unsupported unary operator: {type(node.op).__name__}")

        if isinstance(node, ast.Call):
            return self._eval_call(node, pred, inputs)

        if isinstance(node, ast.Compare):
            return self._eval_compare(node, pred, inputs)

        if isinstance(node, ast.Subscript):
            return self._eval_subscript(node, inputs)

        raise ValueError(f"Unsupported AST node type: {type(node).__name__}")

    def _eval_name(
        self, name: str, pred: torch.Tensor, inputs: torch.Tensor | None
    ) -> torch.Tensor:
        """Resolve an output variable (``yk``) or an input aggregate (``x_*``)."""
        idx = self._output_index(name)
        if idx is not None:
            if idx >= pred.shape[1]:
                raise ValueError(
                    f"Output index {idx} out of range. "
                    f"Model has {pred.shape[1]} outputs."
                )
            return pred[:, idx]
        if name in INPUT_AGGREGATES:
            if inputs is None:
                raise ValueError(f"Constraint uses '{name}' but inputs not provided.")
            return INPUT_AGGREGATES[name](inputs)
        raise ValueError(self._unknown_variable_message(name))

    def _eval_call(
        self, node: ast.Call, pred: torch.Tensor, inputs: torch.Tensor | None
    ) -> torch.Tensor:
        """Call a whitelisted function with evaluated positional and keyword args."""
        if not isinstance(node.func, ast.Name):
            raise ValueError("Only direct function calls supported (e.g., sin(x))")
        func_name = node.func.id
        if func_name not in SAFE_FUNCTIONS:
            raise ValueError(f"Unsafe function: {func_name}")
        args = [self._eval(arg, pred, inputs) for arg in node.args]
        # Forward keyword arguments too (e.g. clamp(y0, min=0)). Ignoring them
        # would silently change what the function computes.
        kwargs = {}
        for kw in node.keywords:
            if kw.arg is None:
                raise ValueError(
                    "'**' argument unpacking is not supported in constraints."
                )
            kwargs[kw.arg] = self._eval(kw.value, pred, inputs)
        return SAFE_FUNCTIONS[func_name](*args, **kwargs)

    def _eval_compare(
        self, node: ast.Compare, pred: torch.Tensor, inputs: torch.Tensor | None
    ) -> torch.Tensor:
        """Turn a single comparison into a violation amount.

        The result is 0 where the comparison is satisfied and positive where it
        is violated. Strict and non-strict forms (``>`` vs ``>=``) are penalized
        identically, because the penalty is soft.
        """
        if len(node.ops) != 1 or len(node.comparators) != 1:
            raise ValueError(
                "Only single comparisons supported (e.g., 'y0 > 0', not 'y0 > 0 > y1')"
            )
        left = self._eval(node.left, pred, inputs)
        right = self._eval(node.comparators[0], pred, inputs)
        op = node.ops[0]

        if isinstance(op, (ast.Gt, ast.GtE)):  # left > right: penalize left <= right
            return F.relu(right - left)
        if isinstance(op, (ast.Lt, ast.LtE)):  # left < right: penalize left >= right
            return F.relu(left - right)
        if isinstance(op, ast.Eq):  # left == right: penalize the difference
            return torch.abs(left - right)
        if isinstance(op, ast.NotEq):
            raise ValueError(
                "'!=' is not a valid constraint. Use '==' for equality constraints."
            )
        raise ValueError(f"Unsupported comparison operator: {type(op).__name__}")

    def _eval_subscript(
        self, node: ast.Subscript, inputs: torch.Tensor | None
    ) -> torch.Tensor:
        """Evaluate ``x[i, j, ...]``: one input element per sample in the batch."""
        if not isinstance(node.value, ast.Name) or node.value.id != "x":
            raise ValueError(
                "Subscript indexing only supported for 'x' (inputs). "
                "Use x[i], x[i,j], or x[i,j,k]."
            )
        if inputs is None:
            raise ValueError("Constraint uses 'x[...]' but inputs not provided.")

        indices = self._parse_subscript_indices(node.slice)

        # One index per non-batch axis, so each sample yields a single value.
        input_ndim = inputs.ndim - 1
        if len(indices) != input_ndim:
            raise ValueError(
                f"Input has {input_ndim}D shape, but got {len(indices)} indices. "
                "Use x[i] for 1D, x[i,j] for 2D, x[i,j,k] for 3D inputs."
            )
        # Keep the batch axis whole and select the element from the remaining axes.
        return inputs[(slice(None), *indices)]

    @staticmethod
    def _constant_index(node: ast.AST) -> int | None:
        """Return the value of an integer literal, optionally negated (``-k``).

        Returns ``None`` for anything else (names, floats, bools, slices).
        """
        negate = False
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            negate = True
            node = node.operand
        if (
            isinstance(node, ast.Constant)
            and isinstance(node.value, int)
            and not isinstance(node.value, bool)
        ):
            return -node.value if negate else node.value
        return None

    def _parse_subscript_indices(self, slice_node: ast.AST) -> list[int]:
        """Parse integer-literal subscript indices from an AST slice node.

        Raises:
            ValueError: If any index is not an integer literal. Float
                literals are rejected rather than silently truncated.
        """
        if isinstance(slice_node, ast.Tuple):
            # Multiple indices: x[0,5] or x[0,5,10]
            indices = []
            for elt in slice_node.elts:
                idx = self._constant_index(elt)
                if idx is None:
                    raise ValueError(
                        "Only constant indices supported in x[...]. "
                        "Use x[0,5] not x[i,j]."
                    )
                indices.append(idx)
            return indices

        # Single index: x[0] or x[-1]
        idx = self._constant_index(slice_node)
        if idx is None:
            raise ValueError(
                f"Unsupported subscript type: {type(slice_node).__name__}. "
                "Use x[0], x[0,5], or x[0,5,10]."
            )
        return [idx]

    def forward(
        self, pred: torch.Tensor, inputs: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Compute constraint violation penalty.

        Args:
            pred: Model predictions of shape (N, num_outputs)
            inputs: Model inputs of shape (N, ...) for input-dependent constraints

        Returns:
            Scalar penalty: mean squared violation ('mse') or mean absolute
            violation ('mae').
        """
        violation = self._eval(self._tree, pred, inputs)
        if self.reduction == "mse":
            return (violation**2).mean()
        return violation.abs().mean()

    def __repr__(self) -> str:
        return (
            f"ExpressionConstraint('{self.expression}', reduction='{self.reduction}')"
        )
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
class FileConstraint(nn.Module):
|
|
281
|
+
"""
|
|
282
|
+
Load constraint function from Python file.
|
|
283
|
+
|
|
284
|
+
The file must define a function `constraint(pred, inputs=None)` that
|
|
285
|
+
returns per-sample violation values.
|
|
286
|
+
|
|
287
|
+
Example file (my_constraint.py):
|
|
288
|
+
import torch
|
|
289
|
+
|
|
290
|
+
def constraint(pred, inputs=None):
|
|
291
|
+
# Monotonicity: y0 < y1 < y2
|
|
292
|
+
diffs = pred[:, 1:] - pred[:, :-1]
|
|
293
|
+
return torch.relu(-diffs).sum(dim=1)
|
|
294
|
+
|
|
295
|
+
Usage:
|
|
296
|
+
>>> constraint = FileConstraint("my_constraint.py")
|
|
297
|
+
>>> penalty = constraint(predictions, inputs)
|
|
298
|
+
"""
|
|
299
|
+
|
|
300
|
+
def __init__(self, file_path: str, reduction: str = "mse"):
|
|
301
|
+
"""
|
|
302
|
+
Args:
|
|
303
|
+
file_path: Path to Python file containing constraint function
|
|
304
|
+
reduction: How to reduce violations - 'mse' or 'mae'
|
|
305
|
+
"""
|
|
306
|
+
super().__init__()
|
|
307
|
+
self.file_path = file_path
|
|
308
|
+
self.reduction = reduction
|
|
309
|
+
|
|
310
|
+
# Load module from file
|
|
311
|
+
spec = importlib.util.spec_from_file_location("constraint_module", file_path)
|
|
312
|
+
if spec is None or spec.loader is None:
|
|
313
|
+
raise ValueError(f"Could not load constraint file: {file_path}")
|
|
314
|
+
|
|
315
|
+
module = importlib.util.module_from_spec(spec)
|
|
316
|
+
sys.modules["constraint_module"] = module
|
|
317
|
+
spec.loader.exec_module(module)
|
|
318
|
+
|
|
319
|
+
if not hasattr(module, "constraint"):
|
|
320
|
+
raise ValueError(
|
|
321
|
+
f"Constraint file must define 'constraint(pred, inputs)' function: {file_path}"
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
self._constraint_fn = module.constraint
|
|
325
|
+
|
|
326
|
+
def forward(
|
|
327
|
+
self, pred: torch.Tensor, inputs: torch.Tensor | None = None
|
|
328
|
+
) -> torch.Tensor:
|
|
329
|
+
"""Evaluate constraint from loaded function."""
|
|
330
|
+
violation = self._constraint_fn(pred, inputs)
|
|
331
|
+
if self.reduction == "mse":
|
|
332
|
+
return (violation**2).mean()
|
|
333
|
+
else:
|
|
334
|
+
return violation.abs().mean()
|
|
335
|
+
|
|
336
|
+
def __repr__(self) -> str:
|
|
337
|
+
return f"FileConstraint('{self.file_path}')"
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
# ==============================================================================
|
|
341
|
+
# COMBINED LOSS WRAPPER
|
|
342
|
+
# ==============================================================================
|
|
343
|
+
class PhysicsConstrainedLoss(nn.Module):
|
|
344
|
+
"""
|
|
345
|
+
Combine base loss with constraint penalties.
|
|
346
|
+
|
|
347
|
+
Total Loss = Base Loss + Σ(weight_i × constraint_i)
|
|
348
|
+
|
|
349
|
+
Constraints are evaluated in **physical space** (denormalized) while
|
|
350
|
+
the base loss is computed in normalized space for stable training.
|
|
351
|
+
|
|
352
|
+
Example:
|
|
353
|
+
>>> base_loss = nn.MSELoss()
|
|
354
|
+
>>> constraints = [ExpressionConstraint("y0 - y1*y2")]
|
|
355
|
+
>>> criterion = PhysicsConstrainedLoss(
|
|
356
|
+
... base_loss,
|
|
357
|
+
... constraints,
|
|
358
|
+
... weights=[0.1],
|
|
359
|
+
... output_mean=[10, 5, 50],
|
|
360
|
+
... output_std=[2, 1, 10],
|
|
361
|
+
... )
|
|
362
|
+
>>> loss = criterion(pred, target, inputs)
|
|
363
|
+
"""
|
|
364
|
+
|
|
365
|
+
def __init__(
|
|
366
|
+
self,
|
|
367
|
+
base_loss: nn.Module,
|
|
368
|
+
constraints: list[nn.Module] | None = None,
|
|
369
|
+
weights: list[float] | None = None,
|
|
370
|
+
output_mean: torch.Tensor | list[float] | None = None,
|
|
371
|
+
output_std: torch.Tensor | list[float] | None = None,
|
|
372
|
+
):
|
|
373
|
+
"""
|
|
374
|
+
Args:
|
|
375
|
+
base_loss: Base loss function (e.g., MSELoss)
|
|
376
|
+
constraints: List of constraint modules
|
|
377
|
+
weights: Weight for each constraint. If shorter than constraints,
|
|
378
|
+
last weight is repeated.
|
|
379
|
+
output_mean: Mean of each output (for denormalization). Shape: (num_outputs,)
|
|
380
|
+
output_std: Std of each output (for denormalization). Shape: (num_outputs,)
|
|
381
|
+
"""
|
|
382
|
+
super().__init__()
|
|
383
|
+
self.base_loss = base_loss
|
|
384
|
+
self.constraints = nn.ModuleList(constraints or [])
|
|
385
|
+
self.weights = weights or [0.1]
|
|
386
|
+
|
|
387
|
+
# Store scaler as buffers (moves to correct device automatically)
|
|
388
|
+
if output_mean is not None:
|
|
389
|
+
if not isinstance(output_mean, torch.Tensor):
|
|
390
|
+
output_mean = torch.tensor(output_mean, dtype=torch.float32)
|
|
391
|
+
self.register_buffer("output_mean", output_mean)
|
|
392
|
+
else:
|
|
393
|
+
self.register_buffer("output_mean", None)
|
|
394
|
+
|
|
395
|
+
if output_std is not None:
|
|
396
|
+
if not isinstance(output_std, torch.Tensor):
|
|
397
|
+
output_std = torch.tensor(output_std, dtype=torch.float32)
|
|
398
|
+
self.register_buffer("output_std", output_std)
|
|
399
|
+
else:
|
|
400
|
+
self.register_buffer("output_std", None)
|
|
401
|
+
|
|
402
|
+
def _denormalize(self, pred: torch.Tensor) -> torch.Tensor:
|
|
403
|
+
"""Convert normalized predictions to physical values."""
|
|
404
|
+
if self.output_mean is None or self.output_std is None:
|
|
405
|
+
return pred
|
|
406
|
+
return pred * self.output_std + self.output_mean
|
|
407
|
+
|
|
408
|
+
def forward(
|
|
409
|
+
self,
|
|
410
|
+
pred: torch.Tensor,
|
|
411
|
+
target: torch.Tensor,
|
|
412
|
+
inputs: torch.Tensor | None = None,
|
|
413
|
+
) -> torch.Tensor:
|
|
414
|
+
"""
|
|
415
|
+
Compute combined loss.
|
|
416
|
+
|
|
417
|
+
Args:
|
|
418
|
+
pred: Model predictions (normalized)
|
|
419
|
+
target: Ground truth targets (normalized)
|
|
420
|
+
inputs: Model inputs (for input-dependent constraints)
|
|
421
|
+
|
|
422
|
+
Returns:
|
|
423
|
+
Combined loss value
|
|
424
|
+
"""
|
|
425
|
+
# Base loss in normalized space (stable gradients)
|
|
426
|
+
loss = self.base_loss(pred, target)
|
|
427
|
+
|
|
428
|
+
# Denormalize for constraint evaluation (physical units)
|
|
429
|
+
pred_physical = self._denormalize(pred)
|
|
430
|
+
|
|
431
|
+
for i, constraint in enumerate(self.constraints):
|
|
432
|
+
weight = self.weights[i] if i < len(self.weights) else self.weights[-1]
|
|
433
|
+
penalty = constraint(pred_physical, inputs)
|
|
434
|
+
loss = loss + weight * penalty
|
|
435
|
+
|
|
436
|
+
return loss
|
|
437
|
+
|
|
438
|
+
def __repr__(self) -> str:
|
|
439
|
+
has_scaler = self.output_mean is not None
|
|
440
|
+
return f"PhysicsConstrainedLoss(base={self.base_loss}, constraints={len(self.constraints)}, denormalize={has_scaler})"
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
# ==============================================================================
|
|
444
|
+
# FACTORY FUNCTIONS
|
|
445
|
+
# ==============================================================================
|
|
446
|
+
def build_constraints(
    expressions: list[str] | None = None,
    file_path: str | None = None,
    reduction: str = "mse",
) -> list[nn.Module]:
    """
    Assemble soft-constraint modules from CLI-style arguments.

    Expression constraints come first, in the order given. The optional
    file-based constraint is appended last.

    Args:
        expressions: Expression strings such as ``"y0 - y1*y2"`` or
            ``"y0 > 0"``. ``None`` or an empty list adds none.
        file_path: Optional path to a Python file defining
            ``constraint(pred, inputs)``. Falsy values (``None``, ``""``)
            are ignored.
        reduction: Penalty reduction forwarded to every constraint.

    Returns:
        List of constraint modules (possibly empty).
    """
    modules: list[nn.Module] = [
        ExpressionConstraint(expression, reduction)
        for expression in (expressions or ())
    ]
    if file_path:
        modules.append(FileConstraint(file_path, reduction))
    return modules
|
wavedl/utils/metrics.py
CHANGED
|
@@ -560,11 +560,56 @@ def create_training_curves(
|
|
|
560
560
|
)
|
|
561
561
|
lines.append(line)
|
|
562
562
|
|
|
563
|
+
def set_lr_ticks(ax: "plt.Axes", data: list[float], n_ticks: int = 4) -> None:
    """Set n uniformly spaced ticks on LR axis with 10^n format labels.

    ``None``, NaN and non-positive values are ignored because they cannot be
    drawn on a log axis. If nothing remains, the axis is left untouched.

    Args:
        ax: Axes whose y-axis shows the learning rate.
        data: Learning-rate values plotted on ``ax``.
        n_ticks: Requested number of ticks. Fewer are used when rounding to
            integer exponents produces duplicates.
    """
    valid_data = [v for v in data if v is not None and not np.isnan(v) and v > 0]
    if not valid_data:
        return
    # Snap outward to whole decades that enclose the data.
    log_min = np.floor(np.log10(min(valid_data)))
    log_max = np.ceil(np.log10(max(valid_data)))
    # Spread n_ticks evenly in log space, round to integer exponents, then
    # drop duplicates while preserving order.
    exponents = list(
        dict.fromkeys(int(e) for e in np.round(np.linspace(log_min, log_max, n_ticks)))
    )
    ax.set_yticks([10.0**e for e in exponents])
    # Build labels from the integer exponents directly. Re-deriving them via
    # int(np.log10(tick)) truncates toward zero and would be off by one
    # whenever log10 is not exact.
    ax.set_yticklabels([f"$10^{{{e}}}$" for e in exponents])
    ax.minorticks_off()
|
|
584
|
+
|
|
585
|
+
def set_loss_ticks(ax: plt.Axes, data: list[float]) -> None:
    """Place a tick at every power of ten spanning the loss data.

    ``None``, NaN and non-positive values are skipped. With no usable values
    the axis is left unchanged. Both ends use the ceiling of log10: the
    lowest tick is the smallest power of ten >= the minimum, so the axis is
    not stretched below the data, and the highest tick is the smallest power
    of ten >= the maximum. Labels read ``$10^{n}$`` and minor ticks are
    disabled.
    """
    positive = [
        value for value in data
        if value is not None and not np.isnan(value) and value > 0
    ]
    if not positive:
        return
    first_exp = int(np.ceil(np.log10(min(positive))))
    last_exp = int(np.ceil(np.log10(max(positive))))
    exponents = range(first_exp, last_exp + 1)
    ax.set_yticks([10.0**k for k in exponents])
    ax.set_yticklabels([f"$10^{{{k}}}$" for k in exponents])
    ax.minorticks_off()
|
|
601
|
+
|
|
563
602
|
ax1.set_xlabel("Epoch")
|
|
564
603
|
ax1.set_ylabel("Loss")
|
|
565
604
|
ax1.set_yscale("log") # Log scale for loss
|
|
566
605
|
ax1.grid(True, alpha=0.3)
|
|
567
606
|
|
|
607
|
+
# Collect all loss values and set clean power of 10 ticks
|
|
608
|
+
all_loss_values = []
|
|
609
|
+
for metric in metrics:
|
|
610
|
+
all_loss_values.extend([h.get(metric, np.nan) for h in history])
|
|
611
|
+
set_loss_ticks(ax1, all_loss_values)
|
|
612
|
+
|
|
568
613
|
# Check if learning rate data exists
|
|
569
614
|
has_lr = show_lr and any("lr" in h for h in history)
|
|
570
615
|
|
|
@@ -581,9 +626,11 @@ def create_training_curves(
|
|
|
581
626
|
alpha=0.7,
|
|
582
627
|
label="Learning Rate",
|
|
583
628
|
)
|
|
584
|
-
ax2.set_ylabel("Learning Rate"
|
|
585
|
-
ax2.tick_params(axis="y", labelcolor=COLORS["neutral"])
|
|
629
|
+
ax2.set_ylabel("Learning Rate")
|
|
586
630
|
ax2.set_yscale("log") # Log scale for LR
|
|
631
|
+
set_lr_ticks(ax2, lr_values, n_ticks=4)
|
|
632
|
+
# Ensure right spine (axis line) is visible
|
|
633
|
+
ax2.spines["right"].set_visible(True)
|
|
587
634
|
lines.append(line_lr)
|
|
588
635
|
|
|
589
636
|
# Combined legend
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.4.6
|
|
3
|
+
Version: 1.5.0
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -113,14 +113,12 @@ Train on datasets larger than RAM:
|
|
|
113
113
|
</td>
|
|
114
114
|
<td width="50%" valign="top">
|
|
115
115
|
|
|
116
|
-
**🧠
|
|
116
|
+
**🧠 Models? We've Got Options**
|
|
117
117
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
```
|
|
123
|
-
Design your model. Register with one line.
|
|
118
|
+
38 architectures, ready to go:
|
|
119
|
+
- CNNs, ResNets, ViTs, EfficientNets...
|
|
120
|
+
- All adapted for regression
|
|
121
|
+
- [Add your own](#adding-custom-models) in one line
|
|
124
122
|
|
|
125
123
|
</td>
|
|
126
124
|
</tr>
|
|
@@ -137,12 +135,12 @@ Multi-GPU training without the pain:
|
|
|
137
135
|
</td>
|
|
138
136
|
<td width="50%" valign="top">
|
|
139
137
|
|
|
140
|
-
|
|
138
|
+
**🔬 Physics-Constrained Training**
|
|
141
139
|
|
|
142
|
-
|
|
143
|
-
-
|
|
144
|
-
-
|
|
145
|
-
-
|
|
140
|
+
Make your model respect physical laws:
|
|
141
|
+
- Enforce bounds, positivity, equations
|
|
142
|
+
- Simple expression syntax or Python
|
|
143
|
+
- [Custom constraints](#physical-constraints) for various laws
|
|
146
144
|
|
|
147
145
|
</td>
|
|
148
146
|
</tr>
|
|
@@ -383,7 +381,7 @@ WaveDL/
|
|
|
383
381
|
├── configs/ # YAML config templates
|
|
384
382
|
├── examples/ # Ready-to-run examples
|
|
385
383
|
├── notebooks/ # Jupyter notebooks
|
|
386
|
-
├── unit_tests/ # Pytest test suite (
|
|
384
|
+
├── unit_tests/ # Pytest test suite (725 tests)
|
|
387
385
|
│
|
|
388
386
|
├── pyproject.toml # Package config, dependencies
|
|
389
387
|
├── CHANGELOG.md # Version history
|
|
@@ -727,6 +725,104 @@ seed: 2025
|
|
|
727
725
|
|
|
728
726
|
</details>
|
|
729
727
|
|
|
728
|
+
<details>
|
|
729
|
+
<summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
|
|
730
|
+
|
|
731
|
+
Add penalty terms to the loss function to enforce physical laws:
|
|
732
|
+
|
|
733
|
+
```
|
|
734
|
+
Total Loss = Data Loss + weight × penalty(violation)
|
|
735
|
+
```
|
|
736
|
+
|
|
737
|
+
### Expression Constraints
|
|
738
|
+
|
|
739
|
+
```bash
|
|
740
|
+
# Positivity
|
|
741
|
+
--constraint "y0 > 0"
|
|
742
|
+
|
|
743
|
+
# Bounds
|
|
744
|
+
--constraint "y0 >= 0" "y0 <= 1"
|
|
745
|
+
|
|
746
|
+
# Equations (penalize deviations from zero)
|
|
747
|
+
--constraint "y2 - y0 * y1"
|
|
748
|
+
|
|
749
|
+
# Input-dependent constraints
|
|
750
|
+
--constraint "y0 - 2*x[0]"
|
|
751
|
+
|
|
752
|
+
# Multiple constraints with different weights
|
|
753
|
+
--constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
|
|
754
|
+
```
|
|
755
|
+
|
|
756
|
+
### Custom Python Constraints
|
|
757
|
+
|
|
758
|
+
For complex physics (matrix operations, implicit equations):
|
|
759
|
+
|
|
760
|
+
```python
|
|
761
|
+
# my_constraint.py
|
|
762
|
+
import torch
|
|
763
|
+
|
|
764
|
+
def constraint(pred, inputs=None):
|
|
765
|
+
"""
|
|
766
|
+
Args:
|
|
767
|
+
pred: (batch, num_outputs)
|
|
768
|
+
inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
|
|
769
|
+
Returns:
|
|
770
|
+
(batch,) — violation per sample (0 = satisfied)
|
|
771
|
+
"""
|
|
772
|
+
# Outputs (same for all data types)
|
|
773
|
+
y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
|
|
774
|
+
|
|
775
|
+
# Inputs — Tabular: (batch, features)
|
|
776
|
+
# x0 = inputs[:, 0] # Feature 0
|
|
777
|
+
# x_sum = inputs.sum(dim=1) # Sum all features
|
|
778
|
+
|
|
779
|
+
# Inputs — Images: (batch, C, H, W)
|
|
780
|
+
# pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
|
|
781
|
+
# img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
|
|
782
|
+
|
|
783
|
+
# Inputs — 3D Volumes: (batch, C, D, H, W)
|
|
784
|
+
# voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
|
|
785
|
+
|
|
786
|
+
# Example constraints:
|
|
787
|
+
# return y2 - y0 * y1 # Wave equation
|
|
788
|
+
# return y0 - 2 * inputs[:, 0] # Output = 2×input
|
|
789
|
+
# return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
|
|
790
|
+
|
|
791
|
+
return y0 - y1 * y2
|
|
792
|
+
```
|
|
793
|
+
|
|
794
|
+
```bash
|
|
795
|
+
--constraint_file my_constraint.py --constraint_weight 1.0
|
|
796
|
+
```
|
|
797
|
+
|
|
798
|
+
---
|
|
799
|
+
|
|
800
|
+
### Reference
|
|
801
|
+
|
|
802
|
+
| Argument | Default | Description |
|
|
803
|
+
|----------|---------|-------------|
|
|
804
|
+
| `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
|
|
805
|
+
| `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
|
|
806
|
+
| `--constraint_weight` | `0.1` | Penalty weight(s) |
|
|
807
|
+
| `--constraint_reduction` | `mse` | `mse` (squared) or `mae` (linear) |
|
|
808
|
+
|
|
809
|
+
#### Expression Syntax
|
|
810
|
+
|
|
811
|
+
| Variable | Meaning |
|
|
812
|
+
|----------|---------|
|
|
813
|
+
| `y0`, `y1`, ... | Model outputs |
|
|
814
|
+
| `x[0]`, `x[1]`, ... | Input values (1D tabular) |
|
|
815
|
+
| `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
|
|
816
|
+
| `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
|
|
817
|
+
|
|
818
|
+
**Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
|
|
819
|
+
|
|
820
|
+
**Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
|
|
821
|
+
|
|
822
|
+
</details>
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
|
|
730
826
|
<details>
|
|
731
827
|
<summary><b>Hyperparameter Search (HPO)</b></summary>
|
|
732
828
|
|
|
@@ -766,7 +862,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
|
|
|
766
862
|
| Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
|
|
767
863
|
| Losses | [all 6](#loss-functions) | `--losses X Y` |
|
|
768
864
|
| Learning rate | 1e-5 → 1e-2 | (always searched) |
|
|
769
|
-
| Batch size |
|
|
865
|
+
| Batch size | 16, 32, 64, 128 | `--batch_sizes X Y` |
|
|
770
866
|
|
|
771
867
|
**Quick Mode** (`--quick`):
|
|
772
868
|
- Uses minimal defaults: cnn + adamw + plateau + mse
|
|
@@ -938,12 +1034,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
|
|
|
938
1034
|
```bash
|
|
939
1035
|
# Run inference on the example data
|
|
940
1036
|
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
|
|
941
|
-
--data_path ./examples/elastic_cnn_example/
|
|
1037
|
+
--data_path ./examples/elastic_cnn_example/Test_data_500.mat \
|
|
942
1038
|
--plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
|
|
943
1039
|
|
|
944
1040
|
# Export to ONNX (already included as model.onnx)
|
|
945
1041
|
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
|
|
946
|
-
--data_path ./examples/elastic_cnn_example/
|
|
1042
|
+
--data_path ./examples/elastic_cnn_example/Test_data_500.mat \
|
|
947
1043
|
--export onnx --export_path ./examples/elastic_cnn_example/model.onnx
|
|
948
1044
|
```
|
|
949
1045
|
|
|
@@ -952,7 +1048,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
|
|
|
952
1048
|
| File | Description |
|
|
953
1049
|
|------|-------------|
|
|
954
1050
|
| `best_checkpoint/` | Pre-trained CNN checkpoint |
|
|
955
|
-
| `
|
|
1051
|
+
| `Test_data_500.mat` | 500-sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
|
|
956
1052
|
| `model.onnx` | ONNX export with embedded de-normalization |
|
|
957
1053
|
| `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
|
|
958
1054
|
| `training_curves.png` | Training/validation loss and learning rate plot |
|
|
@@ -963,7 +1059,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
|
|
|
963
1059
|
|
|
964
1060
|
<p align="center">
|
|
965
1061
|
<img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
|
|
966
|
-
<em>Training and validation loss over
|
|
1062
|
+
<em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
|
|
967
1063
|
</p>
|
|
968
1064
|
|
|
969
1065
|
**Inference Results:**
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
wavedl/__init__.py,sha256=
|
|
1
|
+
wavedl/__init__.py,sha256=jeYx6dZ0_UD4yXl-I4_Aa63820RzS5-IxJP_KUni8Pw,1177
|
|
2
2
|
wavedl/hpc.py,sha256=-iOjjKkXPcV_quj4vAsMBJN_zWKtD1lMRfIZZBhyGms,8756
|
|
3
|
-
wavedl/hpo.py,sha256=
|
|
3
|
+
wavedl/hpo.py,sha256=DGCGyt2yhr3WAifAuljhE26gg07CHdaQW4wpDaTKbyo,14968
|
|
4
4
|
wavedl/test.py,sha256=oWGSSC7178loqOxwti-oDXUVogOqbwHL__GfoXSE5Ss,37846
|
|
5
|
-
wavedl/train.py,sha256=
|
|
5
|
+
wavedl/train.py,sha256=Xy0-W9Kxqdc7WaAyt_HDSZwhZKrfjVbZd2KzgSid6rQ,52457
|
|
6
6
|
wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
|
|
7
7
|
wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
|
|
8
8
|
wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
|
|
@@ -20,18 +20,19 @@ wavedl/models/swin.py,sha256=p-okfq3Qm4_neJTxCcMzoHoVzC0BHW3BMnbpr_Ri2U0,13224
|
|
|
20
20
|
wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
|
|
21
21
|
wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
|
|
22
22
|
wavedl/models/vit.py,sha256=0C3GZk11VsYFTl14d86Wtl1Zk1T5rYJjvkaEfEN4N3k,11100
|
|
23
|
-
wavedl/utils/__init__.py,sha256=
|
|
24
|
-
wavedl/utils/config.py,sha256=
|
|
23
|
+
wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
|
|
24
|
+
wavedl/utils/config.py,sha256=jGW-K7AYB6zrD2BfVm2XPnSY9rbfL_EkM4bwxhBLuwM,10859
|
|
25
|
+
wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
|
|
25
26
|
wavedl/utils/cross_validation.py,sha256=117ac9KDzaIaqhtP8ZRs15Xpqmq5fLpX2-vqkNvtMaU,17487
|
|
26
27
|
wavedl/utils/data.py,sha256=_OaWvU5oFVJW0NwM5WyDD0Kb1hy5MgvJIFpzvJGux9w,48214
|
|
27
28
|
wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
|
|
28
29
|
wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
|
|
29
|
-
wavedl/utils/metrics.py,sha256=
|
|
30
|
+
wavedl/utils/metrics.py,sha256=EJmJvF7gACQsUoKYldlladN_SbnRiuE-Smj0eSnbraQ,39394
|
|
30
31
|
wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
|
|
31
32
|
wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
|
|
32
|
-
wavedl-1.
|
|
33
|
-
wavedl-1.
|
|
34
|
-
wavedl-1.
|
|
35
|
-
wavedl-1.
|
|
36
|
-
wavedl-1.
|
|
37
|
-
wavedl-1.
|
|
33
|
+
wavedl-1.5.0.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
34
|
+
wavedl-1.5.0.dist-info/METADATA,sha256=jlQPfPLtwIXzWHKwgN8xQjzco_xB7L3j7p_dl3TJA8E,45224
|
|
35
|
+
wavedl-1.5.0.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
36
|
+
wavedl-1.5.0.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
|
|
37
|
+
wavedl-1.5.0.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
38
|
+
wavedl-1.5.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|