wavedl 1.4.6-py3-none-any.whl → 1.5.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +9 -1
- wavedl/models/vit.py +21 -0
- wavedl/test.py +28 -5
- wavedl/train.py +122 -15
- wavedl/utils/__init__.py +11 -0
- wavedl/utils/config.py +10 -0
- wavedl/utils/constraints.py +470 -0
- wavedl/utils/cross_validation.py +12 -2
- wavedl/utils/data.py +26 -7
- wavedl/utils/metrics.py +49 -2
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/METADATA +122 -19
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/RECORD +17 -16
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/LICENSE +0 -0
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/WHEEL +0 -0
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.4.6.dist-info → wavedl-1.5.1.dist-info}/top_level.txt +0 -0
wavedl/utils/constraints.py
ADDED
@@ -0,0 +1,470 @@
+"""
+Physical Constraints for Training
+=================================
+
+Soft constraint enforcement via penalty-based loss terms.
+
+Usage:
+    # Expression constraints
+    wavedl-train --constraint "y0 > 0" --constraint_weight 0.1
+
+    # Complex constraints via Python file
+    wavedl-train --constraint_file my_constraint.py
+
+Author: Ductho Le (ductho.le@outlook.com)
+Version: 2.0.0
+"""
+
+from __future__ import annotations
+
+import ast
+import importlib.util
+import sys
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
+# ==============================================================================
+# SAFE EXPRESSION PARSING
+# ==============================================================================
+SAFE_FUNCTIONS: dict[str, Callable] = {
+    "sin": torch.sin,
+    "cos": torch.cos,
+    "tan": torch.tan,
+    "exp": torch.exp,
+    "log": torch.log,
+    "sqrt": torch.sqrt,
+    "abs": torch.abs,
+    "relu": F.relu,
+    "sigmoid": torch.sigmoid,
+    "softplus": F.softplus,
+    "tanh": torch.tanh,
+    "min": torch.minimum,
+    "max": torch.maximum,
+    "pow": torch.pow,
+    "clamp": torch.clamp,
+}
+
+INPUT_AGGREGATES: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
+    "x_mean": lambda x: x.mean(dim=tuple(range(1, x.ndim))),
+    "x_sum": lambda x: x.sum(dim=tuple(range(1, x.ndim))),
+    "x_max": lambda x: x.amax(dim=tuple(range(1, x.ndim))),
+    "x_min": lambda x: x.amin(dim=tuple(range(1, x.ndim))),
+    "x_std": lambda x: x.std(dim=tuple(range(1, x.ndim))),
+    "x_energy": lambda x: (x**2).sum(dim=tuple(range(1, x.ndim))),
+}
+
+
+# ==============================================================================
+# SOFT CONSTRAINTS
+# ==============================================================================
+class ExpressionConstraint(nn.Module):
+    """
+    Soft constraint via string expression.
+
+    Parses mathematical expressions using Python's AST for safe evaluation.
+    Supports output variables (y0, y1, ...), input aggregates (x_mean, ...),
+    and whitelisted math functions.
+
+    Example:
+        >>> constraint = ExpressionConstraint("y0 - y1 * y2")
+        >>> penalty = constraint(predictions, inputs)
+
+        >>> constraint = ExpressionConstraint("sin(y0) + cos(y1)")
+        >>> penalty = constraint(predictions, inputs)
+    """
+
+    def __init__(self, expression: str, reduction: str = "mse"):
+        """
+        Args:
+            expression: Mathematical expression to evaluate (should equal 0)
+            reduction: How to reduce violations - 'mse' or 'mae'
+        """
+        super().__init__()
+        self.expression = expression
+        self.reduction = reduction
+        self._tree = ast.parse(expression, mode="eval")
+        self._validate(self._tree)
+
+    def _validate(self, tree: ast.Expression) -> None:
+        """Validate that expression only uses safe functions."""
+        for node in ast.walk(tree):
+            if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
+                if node.func.id not in SAFE_FUNCTIONS:
+                    raise ValueError(
+                        f"Unsafe function '{node.func.id}' in constraint. "
+                        f"Allowed: {list(SAFE_FUNCTIONS.keys())}"
+                    )
+
+    def _eval(
+        self, node: ast.AST, pred: torch.Tensor, inputs: torch.Tensor | None
+    ) -> torch.Tensor:
+        """Recursively evaluate AST node."""
+        if isinstance(node, ast.Constant):
+            return torch.tensor(node.value, device=pred.device, dtype=pred.dtype)
+        elif isinstance(node, ast.Name):
+            name = node.id
+            # Output variable: y0, y1, ...
+            if name.startswith("y") and name[1:].isdigit():
+                idx = int(name[1:])
+                if idx >= pred.shape[1]:
+                    raise ValueError(
+                        f"Output index {idx} out of range. "
+                        f"Model has {pred.shape[1]} outputs."
+                    )
+                return pred[:, idx]
+            # Input aggregate: x_mean, x_sum, ...
+            elif name in INPUT_AGGREGATES:
+                if inputs is None:
+                    raise ValueError(
+                        f"Constraint uses '{name}' but inputs not provided."
+                    )
+                return INPUT_AGGREGATES[name](inputs)
+            else:
+                raise ValueError(
+                    f"Unknown variable '{name}'. "
+                    f"Use y0, y1, ... for outputs or {list(INPUT_AGGREGATES.keys())} for inputs."
+                )
+        elif isinstance(node, ast.BinOp):
+            left = self._eval(node.left, pred, inputs)
+            right = self._eval(node.right, pred, inputs)
+            ops = {
+                ast.Add: torch.add,
+                ast.Sub: torch.sub,
+                ast.Mult: torch.mul,
+                ast.Div: torch.div,
+                ast.Pow: torch.pow,
+                ast.Mod: torch.remainder,
+            }
+            if type(node.op) not in ops:
+                raise ValueError(f"Unsupported operator: {type(node.op).__name__}")
+            return ops[type(node.op)](left, right)
+        elif isinstance(node, ast.UnaryOp):
+            operand = self._eval(node.operand, pred, inputs)
+            if isinstance(node.op, ast.USub):
+                return -operand
+            elif isinstance(node.op, ast.UAdd):
+                return operand
+            else:
+                raise ValueError(
+                    f"Unsupported unary operator: {type(node.op).__name__}"
+                )
+        elif isinstance(node, ast.Call):
+            if not isinstance(node.func, ast.Name):
+                raise ValueError("Only direct function calls supported (e.g., sin(x))")
+            func_name = node.func.id
+            if func_name not in SAFE_FUNCTIONS:
+                raise ValueError(f"Unsafe function: {func_name}")
+            args = [self._eval(arg, pred, inputs) for arg in node.args]
+            return SAFE_FUNCTIONS[func_name](*args)
+        elif isinstance(node, ast.Compare):
+            # Comparison operators: y0 > 0, y0 < 1, y0 >= 0, y0 <= 1
+            # Returns penalty (violation amount) when constraint is not satisfied
+            if len(node.ops) != 1 or len(node.comparators) != 1:
+                raise ValueError(
+                    "Only single comparisons supported (e.g., 'y0 > 0', not 'y0 > 0 > y1')"
+                )
+            left = self._eval(node.left, pred, inputs)
+            right = self._eval(node.comparators[0], pred, inputs)
+            op = node.ops[0]
+
+            # Return violation amount (0 if satisfied, positive if violated)
+            if isinstance(
+                op, (ast.Gt, ast.GtE)
+            ):  # y0 > right → penalize if y0 <= right
+                return F.relu(right - left)
+            elif isinstance(
+                op, (ast.Lt, ast.LtE)
+            ):  # y0 < right → penalize if y0 >= right
+                return F.relu(left - right)
+            elif isinstance(op, ast.Eq):  # y0 == right → penalize difference
+                return torch.abs(left - right)
+            elif isinstance(op, ast.NotEq):  # y0 != right → not useful as constraint
+                raise ValueError(
+                    "'!=' is not a valid constraint. Use '==' for equality constraints."
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported comparison operator: {type(op).__name__}"
+                )
+        elif isinstance(node, ast.Subscript):
+            # Input indexing: x[0], x[0,5], x[0,5,10]
+            if not isinstance(node.value, ast.Name) or node.value.id != "x":
+                raise ValueError(
+                    "Subscript indexing only supported for 'x' (inputs). "
+                    "Use x[i], x[i,j], or x[i,j,k]."
+                )
+            if inputs is None:
+                raise ValueError("Constraint uses 'x[...]' but inputs not provided.")
+
+            # Parse indices from the slice
+            indices = self._parse_subscript_indices(node.slice)
+
+            # Validate dimensions match
+            # inputs shape: (batch, dim1) or (batch, dim1, dim2) or (batch, dim1, dim2, dim3)
+            input_ndim = inputs.ndim - 1  # Exclude batch dimension
+            if len(indices) != input_ndim:
+                raise ValueError(
+                    f"Input has {input_ndim}D shape, but got {len(indices)} indices. "
+                    f"Use x[i] for 1D, x[i,j] for 2D, x[i,j,k] for 3D inputs."
+                )
+
+            # Extract the value at the specified indices (for entire batch)
+            if len(indices) == 1:
+                return inputs[:, indices[0]]
+            elif len(indices) == 2:
+                return inputs[:, indices[0], indices[1]]
+            elif len(indices) == 3:
+                return inputs[:, indices[0], indices[1], indices[2]]
+            else:
+                raise ValueError("Only 1D, 2D, or 3D input indexing supported.")
+        elif isinstance(node, ast.Expression):
+            return self._eval(node.body, pred, inputs)
+        else:
+            raise ValueError(f"Unsupported AST node type: {type(node).__name__}")
+
+    def _parse_subscript_indices(self, slice_node: ast.AST) -> list[int]:
+        """Parse subscript indices from AST slice node."""
+        if isinstance(slice_node, ast.Constant):
+            # Single index: x[0]
+            return [int(slice_node.value)]
+        elif isinstance(slice_node, ast.Tuple):
+            # Multiple indices: x[0,5] or x[0,5,10]
+            indices = []
+            for elt in slice_node.elts:
+                if not isinstance(elt, ast.Constant):
+                    raise ValueError(
+                        "Only constant indices supported in x[...]. "
+                        "Use x[0,5] not x[i,j]."
+                    )
+                indices.append(int(elt.value))
+            return indices
+        else:
+            raise ValueError(
+                f"Unsupported subscript type: {type(slice_node).__name__}. "
+                "Use x[0], x[0,5], or x[0,5,10]."
+            )
+
+    def forward(
+        self, pred: torch.Tensor, inputs: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        """
+        Compute constraint violation penalty.
+
+        Args:
+            pred: Model predictions of shape (N, num_outputs)
+            inputs: Model inputs of shape (N, ...) for input-dependent constraints
+
+        Returns:
+            Scalar penalty value
+        """
+        violation = self._eval(self._tree, pred, inputs)
+        if self.reduction == "mse":
+            return (violation**2).mean()
+        else:  # mae
+            return violation.abs().mean()
+
+    def __repr__(self) -> str:
+        return (
+            f"ExpressionConstraint('{self.expression}', reduction='{self.reduction}')"
+        )
+
+
+class FileConstraint(nn.Module):
+    """
+    Load constraint function from Python file.
+
+    The file must define a function `constraint(pred, inputs=None)` that
+    returns per-sample violation values.
+
+    Example file (my_constraint.py):
+        import torch
+
+        def constraint(pred, inputs=None):
+            # Monotonicity: y0 < y1 < y2
+            diffs = pred[:, 1:] - pred[:, :-1]
+            return torch.relu(-diffs).sum(dim=1)
+
+    Usage:
+        >>> constraint = FileConstraint("my_constraint.py")
+        >>> penalty = constraint(predictions, inputs)
+    """
+
+    def __init__(self, file_path: str, reduction: str = "mse"):
+        """
+        Args:
+            file_path: Path to Python file containing constraint function
+            reduction: How to reduce violations - 'mse' or 'mae'
+        """
+        super().__init__()
+        self.file_path = file_path
+        self.reduction = reduction
+
+        # Load module from file
+        spec = importlib.util.spec_from_file_location("constraint_module", file_path)
+        if spec is None or spec.loader is None:
+            raise ValueError(f"Could not load constraint file: {file_path}")
+
+        module = importlib.util.module_from_spec(spec)
+        sys.modules["constraint_module"] = module
+        spec.loader.exec_module(module)
+
+        if not hasattr(module, "constraint"):
+            raise ValueError(
+                f"Constraint file must define 'constraint(pred, inputs)' function: {file_path}"
+            )
+
+        self._constraint_fn = module.constraint
+
+    def forward(
+        self, pred: torch.Tensor, inputs: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        """Evaluate constraint from loaded function."""
+        violation = self._constraint_fn(pred, inputs)
+        if self.reduction == "mse":
+            return (violation**2).mean()
+        else:
+            return violation.abs().mean()
+
+    def __repr__(self) -> str:
+        return f"FileConstraint('{self.file_path}')"
+
+
+# ==============================================================================
+# COMBINED LOSS WRAPPER
+# ==============================================================================
+class PhysicsConstrainedLoss(nn.Module):
+    """
+    Combine base loss with constraint penalties.
+
+    Total Loss = Base Loss + Σ(weight_i × constraint_i)
+
+    Constraints are evaluated in **physical space** (denormalized) while
+    the base loss is computed in normalized space for stable training.
+
+    Example:
+        >>> base_loss = nn.MSELoss()
+        >>> constraints = [ExpressionConstraint("y0 - y1*y2")]
+        >>> criterion = PhysicsConstrainedLoss(
+        ...     base_loss,
+        ...     constraints,
+        ...     weights=[0.1],
+        ...     output_mean=[10, 5, 50],
+        ...     output_std=[2, 1, 10],
+        ... )
+        >>> loss = criterion(pred, target, inputs)
+    """
+
+    def __init__(
+        self,
+        base_loss: nn.Module,
+        constraints: list[nn.Module] | None = None,
+        weights: list[float] | None = None,
+        output_mean: torch.Tensor | list[float] | None = None,
+        output_std: torch.Tensor | list[float] | None = None,
+    ):
+        """
+        Args:
+            base_loss: Base loss function (e.g., MSELoss)
+            constraints: List of constraint modules
+            weights: Weight for each constraint. If shorter than constraints,
+                last weight is repeated.
+            output_mean: Mean of each output (for denormalization). Shape: (num_outputs,)
+            output_std: Std of each output (for denormalization). Shape: (num_outputs,)
+        """
+        super().__init__()
+        self.base_loss = base_loss
+        self.constraints = nn.ModuleList(constraints or [])
+        self.weights = weights or [0.1]
+
+        # Store scaler as buffers (moves to correct device automatically)
+        if output_mean is not None:
+            if not isinstance(output_mean, torch.Tensor):
+                output_mean = torch.tensor(output_mean, dtype=torch.float32)
+            self.register_buffer("output_mean", output_mean)
+        else:
+            self.register_buffer("output_mean", None)
+
+        if output_std is not None:
+            if not isinstance(output_std, torch.Tensor):
+                output_std = torch.tensor(output_std, dtype=torch.float32)
+            self.register_buffer("output_std", output_std)
+        else:
+            self.register_buffer("output_std", None)
+
+    def _denormalize(self, pred: torch.Tensor) -> torch.Tensor:
+        """Convert normalized predictions to physical values."""
+        if self.output_mean is None or self.output_std is None:
+            return pred
+        return pred * self.output_std + self.output_mean
+
+    def forward(
+        self,
+        pred: torch.Tensor,
+        target: torch.Tensor,
+        inputs: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Compute combined loss.
+
+        Args:
+            pred: Model predictions (normalized)
+            target: Ground truth targets (normalized)
+            inputs: Model inputs (for input-dependent constraints)
+
+        Returns:
+            Combined loss value
+        """
+        # Base loss in normalized space (stable gradients)
+        loss = self.base_loss(pred, target)
+
+        # Denormalize for constraint evaluation (physical units)
+        pred_physical = self._denormalize(pred)
+
+        for i, constraint in enumerate(self.constraints):
+            weight = self.weights[i] if i < len(self.weights) else self.weights[-1]
+            penalty = constraint(pred_physical, inputs)
+            loss = loss + weight * penalty
+
+        return loss
+
+    def __repr__(self) -> str:
+        has_scaler = self.output_mean is not None
+        return f"PhysicsConstrainedLoss(base={self.base_loss}, constraints={len(self.constraints)}, denormalize={has_scaler})"
+
+
+# ==============================================================================
+# FACTORY FUNCTIONS
+# ==============================================================================
+def build_constraints(
+    expressions: list[str] | None = None,
+    file_path: str | None = None,
+    reduction: str = "mse",
+) -> list[nn.Module]:
+    """
+    Build soft constraint modules from CLI arguments.
+
+    Args:
+        expressions: Expression constraints (e.g., ["y0 - y1*y2", "y0 > 0"])
+        file_path: Path to Python constraint file
+        reduction: Reduction mode for penalties
+
+    Returns:
+        List of constraint modules
+    """
+    constraints: list[nn.Module] = []
+
+    for expr in expressions or []:
+        constraints.append(ExpressionConstraint(expr, reduction))
+
+    if file_path:
+        constraints.append(FileConstraint(file_path, reduction))

+    return constraints
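To see how the new pieces fit together, here is a minimal sketch composing the docstring examples above. The batch shape, constraint weights, and normalization statistics are illustrative assumptions, not values shipped with the package:

import torch
import torch.nn as nn
from wavedl.utils.constraints import ExpressionConstraint, PhysicsConstrainedLoss

# Hypothetical 3-output regression batch (shapes invented for illustration)
pred = torch.randn(8, 3, requires_grad=True)   # normalized predictions
target = torch.randn(8, 3)                     # normalized targets

criterion = PhysicsConstrainedLoss(
    nn.MSELoss(),
    constraints=[
        ExpressionConstraint("y0 > 0"),        # inequality → relu-hinge penalty
        ExpressionConstraint("y0 - y1 * y2"),  # residual expression → driven toward 0
    ],
    weights=[0.1],                             # last weight repeats if list is short
    output_mean=[10.0, 5.0, 50.0],             # denormalization stats (illustrative)
    output_std=[2.0, 1.0, 10.0],
)

loss = criterion(pred, target)  # pass inputs= only for x_*/x[...] constraints
loss.backward()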
wavedl/utils/cross_validation.py
CHANGED
@@ -128,6 +128,12 @@ def train_fold(
     best_state = None
     history = []
 
+    # Determine if scheduler steps per batch (OneCycleLR) or per epoch
+    # Use isinstance check since class name 'OneCycleLR' != 'onecycle' string in is_epoch_based
+    from torch.optim.lr_scheduler import OneCycleLR
+
+    step_per_batch = isinstance(scheduler, OneCycleLR)
+
     for epoch in range(epochs):
         # Training
         model.train()
@@ -144,6 +150,10 @@ def train_fold(
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
 
+            # Per-batch LR scheduling (OneCycleLR)
+            if step_per_batch:
+                scheduler.step()
+
             train_loss += loss.item() * x.size(0)
             train_samples += x.size(0)
 
@@ -186,8 +196,8 @@ def train_fold(
             }
         )
 
-        # LR scheduling
-        if hasattr(scheduler, "step"):
+        # LR scheduling (epoch-based only, not for per-batch schedulers)
+        if not step_per_batch and hasattr(scheduler, "step"):
             if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                 scheduler.step(avg_val_loss)
             else:
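The fix encodes the rule that OneCycleLR advances once per optimizer step while other schedulers advance once per epoch. A self-contained sketch of that stepping pattern, with a dummy model and random data assumed purely for illustration:

import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(4, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
epochs, steps_per_epoch = 3, 10
sched = OneCycleLR(opt, max_lr=1e-2, epochs=epochs, steps_per_epoch=steps_per_epoch)

step_per_batch = isinstance(sched, OneCycleLR)
for epoch in range(epochs):
    for _ in range(steps_per_epoch):
        opt.zero_grad()
        model(torch.randn(8, 4)).pow(2).mean().backward()
        opt.step()
        if step_per_batch:
            sched.step()   # per-batch: exactly epochs * steps_per_epoch steps total
    if not step_per_batch:
        sched.step()       # epoch-based schedulers would advance here instead

Stepping OneCycleLR only once per epoch would leave most of its one-cycle schedule unused, which is the bug this change addresses.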
wavedl/utils/data.py
CHANGED
@@ -201,8 +201,18 @@ class DataSource(ABC):
 class NPZSource(DataSource):
     """Load data from NumPy .npz archives."""
 
+    @staticmethod
+    def _safe_load(path: str, mmap_mode: str | None = None):
+        """Load NPZ with pickle only if needed (sparse matrix support)."""
+        try:
+            return np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
+        except ValueError:
+            # Fallback for sparse matrices stored as object arrays
+            return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
+
     def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
-
+        """Load NPZ file (pickle enabled only for sparse matrices)."""
+        data = self._safe_load(path)
         keys = list(data.keys())
 
         input_key = self._find_key(keys, INPUT_KEYS)
@@ -233,7 +243,7 @@ class NPZSource(DataSource):
 
         Note: Returns memory-mapped arrays - do NOT modify them.
         """
-        data =
+        data = self._safe_load(path, mmap_mode="r")
         keys = list(data.keys())
 
         input_key = self._find_key(keys, INPUT_KEYS)
@@ -253,7 +263,7 @@ class NPZSource(DataSource):
 
     def load_outputs_only(self, path: str) -> np.ndarray:
         """Load only targets from NPZ (avoids loading large input arrays)."""
-        data =
+        data = self._safe_load(path)
         keys = list(data.keys())
 
         output_key = self._find_key(keys, OUTPUT_KEYS)
@@ -677,6 +687,7 @@ def load_test_data(
     format: str = "auto",
     input_key: str | None = None,
     output_key: str | None = None,
+    input_channels: int | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Load test/inference data and return PyTorch tensors ready for model input.
@@ -698,6 +709,9 @@ def load_test_data(
         format: Format hint ('npz', 'hdf5', 'mat', or 'auto' for detection)
         input_key: Custom key for input data (overrides auto-detection)
         output_key: Custom key for output data (overrides auto-detection)
+        input_channels: Explicit number of input channels. If provided, bypasses
+            the heuristic for 4D data. Use input_channels=1 for 3D volumes that
+            look like multi-channel 2D (e.g., depth ≤16).
 
     Returns:
         Tuple of:
@@ -737,7 +751,7 @@ def load_test_data(
     except KeyError:
         # Try with just inputs if outputs not found (inference-only mode)
         if format == "npz":
-            data =
+            data = NPZSource._safe_load(path)
             keys = list(data.keys())
             inp_key = DataSource._find_key(keys, custom_input_keys)
             if inp_key is None:
@@ -822,15 +836,20 @@ def load_test_data(
     # Add channel dimension if needed (dimension-agnostic)
     # X.ndim == 2: 1D data (N, L) → (N, 1, L)
     # X.ndim == 3: 2D data (N, H, W) → (N, 1, H, W)
-    # X.ndim == 4: Check if already has channel dim
+    # X.ndim == 4: Check if already has channel dim
     if X.ndim == 2:
         X = X.unsqueeze(1)  # 1D signal: (N, L) → (N, 1, L)
     elif X.ndim == 3:
         X = X.unsqueeze(1)  # 2D image: (N, H, W) → (N, 1, H, W)
    elif X.ndim == 4:
         # Could be 3D volume (N, D, H, W) or 2D with channel (N, C, H, W)
-
-
+        if input_channels is not None:
+            # Explicit override: user specifies channel count
+            if input_channels == 1:
+                X = X.unsqueeze(1)  # Add channel: (N, D, H, W) → (N, 1, D, H, W)
+            # else: already has channels, leave as-is
+        elif X.shape[1] > 16:
+            # Heuristic fallback: large dim 1 suggests 3D volume needing channel
             X = X.unsqueeze(1)  # 3D volume: (N, D, H, W) → (N, 1, D, H, W)
     # X.ndim >= 5: assume channel dimension already exists
 
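As a rough illustration of the 4D disambiguation that `input_channels` now controls, here is a standalone sketch re-implementing the branch above; the helper name `add_channel_dim` and the shapes are invented for this example:

import torch

def add_channel_dim(X: torch.Tensor, input_channels: int | None = None) -> torch.Tensor:
    # Mirrors the new logic: explicit override first, size heuristic as fallback
    if X.ndim in (2, 3):                  # (N, L) signals or (N, H, W) images
        return X.unsqueeze(1)
    if X.ndim == 4:                       # (N, D, H, W) volume or (N, C, H, W) image
        if input_channels is not None:
            return X.unsqueeze(1) if input_channels == 1 else X
        if X.shape[1] > 16:               # >16 "channels" more likely a depth axis
            return X.unsqueeze(1)
    return X                              # ndim >= 5: channels assumed present

volume = torch.zeros(4, 8, 32, 32)        # depth-8 volume that looks multi-channel
print(add_channel_dim(volume).shape)                     # torch.Size([4, 8, 32, 32])
print(add_channel_dim(volume, input_channels=1).shape)   # torch.Size([4, 1, 8, 32, 32])

The second call shows why the override exists: a shallow volume (depth ≤ 16) is indistinguishable from a multi-channel image by shape alone.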
wavedl/utils/metrics.py
CHANGED
@@ -560,11 +560,56 @@ def create_training_curves(
         )
         lines.append(line)
 
+    def set_lr_ticks(ax: plt.Axes, data: list[float], n_ticks: int = 4) -> None:
+        """Set n uniformly spaced ticks on LR axis with 10^n format labels."""
+        valid_data = [v for v in data if v is not None and not np.isnan(v) and v > 0]
+        if not valid_data:
+            return
+        vmin, vmax = min(valid_data), max(valid_data)
+        # Snap to clean decade boundaries
+        log_min = np.floor(np.log10(vmin))
+        log_max = np.ceil(np.log10(vmax))
+        # Generate n uniformly spaced ticks as powers of 10
+        log_ticks = np.linspace(log_min, log_max, n_ticks)
+        # Round to nearest integer power of 10 for clean numbers
+        log_ticks = np.round(log_ticks)
+        ticks = 10.0**log_ticks
+        # Remove duplicates while preserving order
+        ticks = list(dict.fromkeys(ticks))
+        ax.set_yticks(ticks)
+        # Format all tick labels as 10^n
+        labels = [f"$10^{{{int(np.log10(t))}}}$" for t in ticks]
+        ax.set_yticklabels(labels)
+        ax.minorticks_off()
+
+    def set_loss_ticks(ax: plt.Axes, data: list[float]) -> None:
+        """Set ticks at powers of 10 that cover the data range."""
+        valid_data = [v for v in data if v is not None and not np.isnan(v) and v > 0]
+        if not valid_data:
+            return
+        vmin, vmax = min(valid_data), max(valid_data)
+        # Get decade range that covers data (ceil for min to avoid going too low)
+        log_min = int(np.ceil(np.log10(vmin)))
+        log_max = int(np.ceil(np.log10(vmax)))
+        # Generate ticks at each power of 10
+        ticks = [10.0**i for i in range(log_min, log_max + 1)]
+        ax.set_yticks(ticks)
+        # Format labels as 10^n
+        labels = [f"$10^{{{i}}}$" for i in range(log_min, log_max + 1)]
+        ax.set_yticklabels(labels)
+        ax.minorticks_off()
+
     ax1.set_xlabel("Epoch")
     ax1.set_ylabel("Loss")
     ax1.set_yscale("log")  # Log scale for loss
     ax1.grid(True, alpha=0.3)
 
+    # Collect all loss values and set clean power of 10 ticks
+    all_loss_values = []
+    for metric in metrics:
+        all_loss_values.extend([h.get(metric, np.nan) for h in history])
+    set_loss_ticks(ax1, all_loss_values)
+
     # Check if learning rate data exists
     has_lr = show_lr and any("lr" in h for h in history)
 
@@ -581,9 +626,11 @@ def create_training_curves(
             alpha=0.7,
             label="Learning Rate",
         )
-        ax2.set_ylabel("Learning Rate"
-        ax2.tick_params(axis="y", labelcolor=COLORS["neutral"])
+        ax2.set_ylabel("Learning Rate")
         ax2.set_yscale("log")  # Log scale for LR
+        set_lr_ticks(ax2, lr_values, n_ticks=4)
+        # Ensure right spine (axis line) is visible
+        ax2.spines["right"].set_visible(True)
         lines.append(line_lr)
 
        # Combined legend
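Both helpers reduce to one idea: place ticks only at integer powers of 10 and label them as $10^n$. A standalone matplotlib sketch of that idea, with invented data and output filename:

import matplotlib.pyplot as plt
import numpy as np

lrs = [1e-2 * 0.9**i for i in range(50)]   # decaying LR curve, illustrative

fig, ax = plt.subplots()
ax.plot(lrs)
ax.set_yscale("log")

# Decade ticks covering the data range, labeled as powers of ten
lo = int(np.floor(np.log10(min(lrs))))
hi = int(np.ceil(np.log10(max(lrs))))
ax.set_yticks([10.0**i for i in range(lo, hi + 1)])
ax.set_yticklabels([f"$10^{{{i}}}$" for i in range(lo, hi + 1)])
ax.minorticks_off()
fig.savefig("lr_ticks.png")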
|