wavedl 1.4.6__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,470 @@
1
+ """
2
+ Physical Constraints for Training
3
+ =================================
4
+
5
+ Soft constraint enforcement via penalty-based loss terms.
6
+
7
+ Usage:
8
+ # Expression constraints
9
+ wavedl-train --constraint "y0 > 0" --constraint_weight 0.1
10
+
11
+ # Complex constraints via Python file
12
+ wavedl-train --constraint_file my_constraint.py
13
+
14
+ Author: Ductho Le (ductho.le@outlook.com)
15
+ Version: 2.0.0
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import ast
21
+ import importlib.util
22
+ import sys
23
+ from typing import TYPE_CHECKING
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+
29
+
30
+ if TYPE_CHECKING:
31
+ from collections.abc import Callable
32
+
33
+
34
+ # ==============================================================================
35
+ # SAFE EXPRESSION PARSING
36
+ # ==============================================================================
37
# Functions that constraint expressions may call. Each entry maps a name
# usable inside an expression string to the corresponding torch callable,
# so evaluation stays on-tensor and differentiable.
SAFE_FUNCTIONS: dict[str, Callable] = {
    "sin": torch.sin,
    "cos": torch.cos,
    "tan": torch.tan,
    "exp": torch.exp,
    "log": torch.log,
    "sqrt": torch.sqrt,
    "abs": torch.abs,
    "relu": F.relu,
    "sigmoid": torch.sigmoid,
    "softplus": F.softplus,
    "tanh": torch.tanh,
    "min": torch.minimum,  # elementwise two-tensor minimum, not a reduction
    "max": torch.maximum,  # elementwise two-tensor maximum, not a reduction
    "pow": torch.pow,
    "clamp": torch.clamp,
}

# Named aggregates over the model *input* tensor. Each reduces every
# non-batch dimension, yielding one value per sample, shape (N,).
INPUT_AGGREGATES: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
    "x_mean": lambda x: x.mean(dim=tuple(range(1, x.ndim))),
    "x_sum": lambda x: x.sum(dim=tuple(range(1, x.ndim))),
    "x_max": lambda x: x.amax(dim=tuple(range(1, x.ndim))),
    "x_min": lambda x: x.amin(dim=tuple(range(1, x.ndim))),
    "x_std": lambda x: x.std(dim=tuple(range(1, x.ndim))),
    "x_energy": lambda x: (x**2).sum(dim=tuple(range(1, x.ndim))),
}
63
+
64
+
65
+ # ==============================================================================
66
+ # SOFT CONSTRAINTS
67
+ # ==============================================================================
68
class ExpressionConstraint(nn.Module):
    """
    Soft constraint via string expression.

    Parses mathematical expressions using Python's AST for safe evaluation.
    Supports output variables (y0, y1, ...), input aggregates (x_mean, ...),
    input element indexing (x[i], x[i,j], x[i,j,k]), single comparisons
    (y0 > 0), and whitelisted math functions.

    Example:
        >>> constraint = ExpressionConstraint("y0 - y1 * y2")
        >>> penalty = constraint(predictions, inputs)

        >>> constraint = ExpressionConstraint("sin(y0) + cos(y1)")
        >>> penalty = constraint(predictions, inputs)
    """

    def __init__(self, expression: str, reduction: str = "mse"):
        """
        Args:
            expression: Mathematical expression to evaluate (should equal 0)
            reduction: How to reduce violations - 'mse' or 'mae'

        Raises:
            ValueError: If reduction is unknown, or the expression uses an
                unsafe function or unsupported call form.
            SyntaxError: If the expression is not valid Python syntax.
        """
        super().__init__()
        # Fail fast: previously any reduction other than 'mse' was silently
        # treated as 'mae', hiding typos such as 'mea' or 'rmse'.
        if reduction not in ("mse", "mae"):
            raise ValueError(f"reduction must be 'mse' or 'mae', got '{reduction}'")
        self.expression = expression
        self.reduction = reduction
        self._tree = ast.parse(expression, mode="eval")
        self._validate(self._tree)

    def _validate(self, tree: ast.Expression) -> None:
        """Validate the expression at construction time (fail fast, not in forward)."""
        for node in ast.walk(tree):
            if isinstance(node, ast.Call):
                # Reject attribute calls (e.g. torch.sin(...)) here rather
                # than deferring the error to the first forward() pass.
                if not isinstance(node.func, ast.Name):
                    raise ValueError(
                        "Only direct function calls supported (e.g., sin(x))"
                    )
                if node.func.id not in SAFE_FUNCTIONS:
                    raise ValueError(
                        f"Unsafe function '{node.func.id}' in constraint. "
                        f"Allowed: {list(SAFE_FUNCTIONS.keys())}"
                    )

    def _eval(
        self, node: ast.AST, pred: torch.Tensor, inputs: torch.Tensor | None
    ) -> torch.Tensor:
        """Recursively evaluate an AST node against predictions/inputs."""
        if isinstance(node, ast.Constant):
            # Literal number → 0-dim tensor on the prediction's device/dtype.
            return torch.tensor(node.value, device=pred.device, dtype=pred.dtype)
        elif isinstance(node, ast.Name):
            name = node.id
            # Output variable: y0, y1, ...
            if name.startswith("y") and name[1:].isdigit():
                idx = int(name[1:])
                if idx >= pred.shape[1]:
                    raise ValueError(
                        f"Output index {idx} out of range. "
                        f"Model has {pred.shape[1]} outputs."
                    )
                return pred[:, idx]
            # Input aggregate: x_mean, x_sum, ...
            elif name in INPUT_AGGREGATES:
                if inputs is None:
                    raise ValueError(
                        f"Constraint uses '{name}' but inputs not provided."
                    )
                return INPUT_AGGREGATES[name](inputs)
            else:
                raise ValueError(
                    f"Unknown variable '{name}'. "
                    f"Use y0, y1, ... for outputs or {list(INPUT_AGGREGATES.keys())} for inputs."
                )
        elif isinstance(node, ast.BinOp):
            left = self._eval(node.left, pred, inputs)
            right = self._eval(node.right, pred, inputs)
            ops = {
                ast.Add: torch.add,
                ast.Sub: torch.sub,
                ast.Mult: torch.mul,
                ast.Div: torch.div,
                ast.Pow: torch.pow,
                ast.Mod: torch.remainder,
            }
            if type(node.op) not in ops:
                raise ValueError(f"Unsupported operator: {type(node.op).__name__}")
            return ops[type(node.op)](left, right)
        elif isinstance(node, ast.UnaryOp):
            operand = self._eval(node.operand, pred, inputs)
            if isinstance(node.op, ast.USub):
                return -operand
            elif isinstance(node.op, ast.UAdd):
                return operand
            else:
                raise ValueError(
                    f"Unsupported unary operator: {type(node.op).__name__}"
                )
        elif isinstance(node, ast.Call):
            # Defensive re-check; _validate already enforced both conditions.
            if not isinstance(node.func, ast.Name):
                raise ValueError("Only direct function calls supported (e.g., sin(x))")
            func_name = node.func.id
            if func_name not in SAFE_FUNCTIONS:
                raise ValueError(f"Unsafe function: {func_name}")
            args = [self._eval(arg, pred, inputs) for arg in node.args]
            return SAFE_FUNCTIONS[func_name](*args)
        elif isinstance(node, ast.Compare):
            # Comparison operators: y0 > 0, y0 < 1, y0 >= 0, y0 <= 1
            # Returns penalty (violation amount) when constraint is not satisfied
            if len(node.ops) != 1 or len(node.comparators) != 1:
                raise ValueError(
                    "Only single comparisons supported (e.g., 'y0 > 0', not 'y0 > 0 > y1')"
                )
            left = self._eval(node.left, pred, inputs)
            right = self._eval(node.comparators[0], pred, inputs)
            op = node.ops[0]

            # Return violation amount (0 if satisfied, positive if violated)
            if isinstance(
                op, (ast.Gt, ast.GtE)
            ):  # y0 > right → penalize if y0 <= right
                return F.relu(right - left)
            elif isinstance(
                op, (ast.Lt, ast.LtE)
            ):  # y0 < right → penalize if y0 >= right
                return F.relu(left - right)
            elif isinstance(op, ast.Eq):  # y0 == right → penalize difference
                return torch.abs(left - right)
            elif isinstance(op, ast.NotEq):  # y0 != right → not useful as constraint
                raise ValueError(
                    "'!=' is not a valid constraint. Use '==' for equality constraints."
                )
            else:
                raise ValueError(
                    f"Unsupported comparison operator: {type(op).__name__}"
                )
        elif isinstance(node, ast.Subscript):
            # Input indexing: x[0], x[0,5], x[0,5,10]
            if not isinstance(node.value, ast.Name) or node.value.id != "x":
                raise ValueError(
                    "Subscript indexing only supported for 'x' (inputs). "
                    "Use x[i], x[i,j], or x[i,j,k]."
                )
            if inputs is None:
                raise ValueError("Constraint uses 'x[...]' but inputs not provided.")

            # Parse indices from the slice
            indices = self._parse_subscript_indices(node.slice)

            # Validate dimensions match
            # inputs shape: (batch, dim1) or (batch, dim1, dim2) or (batch, dim1, dim2, dim3)
            input_ndim = inputs.ndim - 1  # Exclude batch dimension
            if len(indices) != input_ndim:
                raise ValueError(
                    f"Input has {input_ndim}D shape, but got {len(indices)} indices. "
                    f"Use x[i] for 1D, x[i,j] for 2D, x[i,j,k] for 3D inputs."
                )

            # Extract the value at the specified indices (for entire batch)
            if len(indices) == 1:
                return inputs[:, indices[0]]
            elif len(indices) == 2:
                return inputs[:, indices[0], indices[1]]
            elif len(indices) == 3:
                return inputs[:, indices[0], indices[1], indices[2]]
            else:
                raise ValueError("Only 1D, 2D, or 3D input indexing supported.")
        elif isinstance(node, ast.Expression):
            return self._eval(node.body, pred, inputs)
        else:
            raise ValueError(f"Unsupported AST node type: {type(node).__name__}")

    def _parse_subscript_indices(self, slice_node: ast.AST) -> list[int]:
        """Parse constant subscript indices from an AST slice node."""
        if isinstance(slice_node, ast.Constant):
            # Single index: x[0]
            return [int(slice_node.value)]
        elif isinstance(slice_node, ast.Tuple):
            # Multiple indices: x[0,5] or x[0,5,10]
            indices = []
            for elt in slice_node.elts:
                if not isinstance(elt, ast.Constant):
                    raise ValueError(
                        "Only constant indices supported in x[...]. "
                        "Use x[0,5] not x[i,j]."
                    )
                indices.append(int(elt.value))
            return indices
        else:
            raise ValueError(
                f"Unsupported subscript type: {type(slice_node).__name__}. "
                "Use x[0], x[0,5], or x[0,5,10]."
            )

    def forward(
        self, pred: torch.Tensor, inputs: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Compute constraint violation penalty.

        Args:
            pred: Model predictions of shape (N, num_outputs)
            inputs: Model inputs of shape (N, ...) for input-dependent constraints

        Returns:
            Scalar penalty value
        """
        violation = self._eval(self._tree, pred, inputs)
        if self.reduction == "mse":
            return (violation**2).mean()
        else:  # mae
            return violation.abs().mean()

    def __repr__(self) -> str:
        return (
            f"ExpressionConstraint('{self.expression}', reduction='{self.reduction}')"
        )
278
+
279
+
280
class FileConstraint(nn.Module):
    """
    Load constraint function from Python file.

    The file must define a function `constraint(pred, inputs=None)` that
    returns per-sample violation values.

    Example file (my_constraint.py):
        import torch

        def constraint(pred, inputs=None):
            # Monotonicity: y0 < y1 < y2
            diffs = pred[:, 1:] - pred[:, :-1]
            return torch.relu(-diffs).sum(dim=1)

    Usage:
        >>> constraint = FileConstraint("my_constraint.py")
        >>> penalty = constraint(predictions, inputs)

    Note:
        This executes arbitrary Python code from `file_path`; only load
        trusted constraint files.
    """

    def __init__(self, file_path: str, reduction: str = "mse"):
        """
        Args:
            file_path: Path to Python file containing constraint function
            reduction: How to reduce violations - 'mse' or 'mae'

        Raises:
            ValueError: If the file cannot be loaded or does not define a
                callable named `constraint`.
        """
        super().__init__()
        self.file_path = file_path
        self.reduction = reduction

        # Use a unique module name per file: the previous fixed name
        # ("constraint_module") made multiple FileConstraint instances
        # clobber each other's entry in sys.modules.
        module_name = f"wavedl_constraint_{abs(hash(file_path))}"
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            raise ValueError(f"Could not load constraint file: {file_path}")

        module = importlib.util.module_from_spec(spec)
        # Register before exec so imports/decorators inside the file can
        # resolve their own module.
        sys.modules[module_name] = module
        spec.loader.exec_module(module)

        fn = getattr(module, "constraint", None)
        if not callable(fn):
            # Also rejects a non-callable `constraint` attribute, which the
            # previous hasattr() check let through until the first forward().
            raise ValueError(
                f"Constraint file must define 'constraint(pred, inputs)' function: {file_path}"
            )

        self._constraint_fn = fn

    def forward(
        self, pred: torch.Tensor, inputs: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Evaluate the loaded constraint and reduce violations to a scalar."""
        violation = self._constraint_fn(pred, inputs)
        if self.reduction == "mse":
            return (violation**2).mean()
        else:
            return violation.abs().mean()

    def __repr__(self) -> str:
        return f"FileConstraint('{self.file_path}')"
338
+
339
+
340
+ # ==============================================================================
341
+ # COMBINED LOSS WRAPPER
342
+ # ==============================================================================
343
class PhysicsConstrainedLoss(nn.Module):
    """
    Combine base loss with constraint penalties.

    Total Loss = Base Loss + Σ(weight_i × constraint_i)

    The base loss is computed in normalized space for stable training, while
    each constraint is evaluated on **denormalized** (physical-unit)
    predictions.

    Example:
        >>> base_loss = nn.MSELoss()
        >>> constraints = [ExpressionConstraint("y0 - y1*y2")]
        >>> criterion = PhysicsConstrainedLoss(
        ...     base_loss,
        ...     constraints,
        ...     weights=[0.1],
        ...     output_mean=[10, 5, 50],
        ...     output_std=[2, 1, 10],
        ... )
        >>> loss = criterion(pred, target, inputs)
    """

    def __init__(
        self,
        base_loss: nn.Module,
        constraints: list[nn.Module] | None = None,
        weights: list[float] | None = None,
        output_mean: torch.Tensor | list[float] | None = None,
        output_std: torch.Tensor | list[float] | None = None,
    ):
        """
        Args:
            base_loss: Base loss function (e.g., MSELoss)
            constraints: List of constraint modules
            weights: Weight for each constraint. If shorter than constraints,
                last weight is repeated.
            output_mean: Mean of each output (for denormalization). Shape: (num_outputs,)
            output_std: Std of each output (for denormalization). Shape: (num_outputs,)
        """
        super().__init__()
        self.base_loss = base_loss
        self.constraints = nn.ModuleList(constraints or [])
        self.weights = weights or [0.1]
        # Buffers follow .to(device)/.cuda() moves automatically.
        self._register_stat("output_mean", output_mean)
        self._register_stat("output_std", output_std)

    def _register_stat(
        self, name: str, value: torch.Tensor | list[float] | None
    ) -> None:
        """Register an optional normalization statistic as a module buffer."""
        if value is not None and not isinstance(value, torch.Tensor):
            value = torch.tensor(value, dtype=torch.float32)
        self.register_buffer(name, value)

    def _denormalize(self, pred: torch.Tensor) -> torch.Tensor:
        """Map normalized predictions to physical units (identity without a scaler)."""
        mean, std = self.output_mean, self.output_std
        if mean is None or std is None:
            return pred
        return pred * std + mean

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        inputs: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Compute combined loss.

        Args:
            pred: Model predictions (normalized)
            target: Ground truth targets (normalized)
            inputs: Model inputs (for input-dependent constraints)

        Returns:
            Combined loss value
        """
        # Base loss in normalized space (stable gradients).
        total = self.base_loss(pred, target)

        # Constraints see physical-unit predictions.
        physical = self._denormalize(pred)

        last = len(self.weights) - 1
        for idx, constraint in enumerate(self.constraints):
            # Reuse the final weight when fewer weights than constraints given.
            total = total + self.weights[min(idx, last)] * constraint(physical, inputs)

        return total

    def __repr__(self) -> str:
        scaled = self.output_mean is not None
        return f"PhysicsConstrainedLoss(base={self.base_loss}, constraints={len(self.constraints)}, denormalize={scaled})"
441
+
442
+
443
+ # ==============================================================================
444
+ # FACTORY FUNCTIONS
445
+ # ==============================================================================
446
def build_constraints(
    expressions: list[str] | None = None,
    file_path: str | None = None,
    reduction: str = "mse",
) -> list[nn.Module]:
    """
    Build soft constraint modules from CLI arguments.

    Args:
        expressions: Expression constraints (e.g., ["y0 - y1*y2", "y0 > 0"])
        file_path: Path to Python constraint file
        reduction: Reduction mode for penalties

    Returns:
        List of constraint modules (empty if nothing was requested)
    """
    modules: list[nn.Module] = [
        ExpressionConstraint(expr, reduction) for expr in expressions or []
    ]
    if file_path:
        modules.append(FileConstraint(file_path, reduction))
    return modules
@@ -128,6 +128,12 @@ def train_fold(
128
128
  best_state = None
129
129
  history = []
130
130
 
131
+ # Determine if scheduler steps per batch (OneCycleLR) or per epoch
132
+ # Use isinstance check since class name 'OneCycleLR' != 'onecycle' string in is_epoch_based
133
+ from torch.optim.lr_scheduler import OneCycleLR
134
+
135
+ step_per_batch = isinstance(scheduler, OneCycleLR)
136
+
131
137
  for epoch in range(epochs):
132
138
  # Training
133
139
  model.train()
@@ -144,6 +150,10 @@ def train_fold(
144
150
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
145
151
  optimizer.step()
146
152
 
153
+ # Per-batch LR scheduling (OneCycleLR)
154
+ if step_per_batch:
155
+ scheduler.step()
156
+
147
157
  train_loss += loss.item() * x.size(0)
148
158
  train_samples += x.size(0)
149
159
 
@@ -186,8 +196,8 @@ def train_fold(
186
196
  }
187
197
  )
188
198
 
189
- # LR scheduling
190
- if hasattr(scheduler, "step"):
199
+ # LR scheduling (epoch-based only, not for per-batch schedulers)
200
+ if not step_per_batch and hasattr(scheduler, "step"):
191
201
  if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
192
202
  scheduler.step(avg_val_loss)
193
203
  else:
wavedl/utils/data.py CHANGED
@@ -201,8 +201,18 @@ class DataSource(ABC):
201
201
  class NPZSource(DataSource):
202
202
  """Load data from NumPy .npz archives."""
203
203
 
204
@staticmethod
def _safe_load(path: str, mmap_mode: str | None = None):
    """Load a NumPy archive, enabling pickle only when actually required.

    Tries the safe ``allow_pickle=False`` path first; ``np.load`` raises
    ``ValueError`` when the archive contains object arrays (e.g. sparse
    matrices stored via pickle), in which case we retry with pickle enabled.

    NOTE(review): the fallback executes pickle deserialization, which can run
    arbitrary code — only use on trusted files.

    Args:
        path: Path to the ``.npz``/``.npy`` file.
        mmap_mode: Optional memory-map mode forwarded to ``np.load``.
    """
    try:
        return np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
    except ValueError:
        # Fallback for sparse matrices stored as object arrays
        return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
212
+
204
213
  def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
205
- data = np.load(path, allow_pickle=True)
214
+ """Load NPZ file (pickle enabled only for sparse matrices)."""
215
+ data = self._safe_load(path)
206
216
  keys = list(data.keys())
207
217
 
208
218
  input_key = self._find_key(keys, INPUT_KEYS)
@@ -233,7 +243,7 @@ class NPZSource(DataSource):
233
243
 
234
244
  Note: Returns memory-mapped arrays - do NOT modify them.
235
245
  """
236
- data = np.load(path, allow_pickle=True, mmap_mode="r")
246
+ data = self._safe_load(path, mmap_mode="r")
237
247
  keys = list(data.keys())
238
248
 
239
249
  input_key = self._find_key(keys, INPUT_KEYS)
@@ -253,7 +263,7 @@ class NPZSource(DataSource):
253
263
 
254
264
  def load_outputs_only(self, path: str) -> np.ndarray:
255
265
  """Load only targets from NPZ (avoids loading large input arrays)."""
256
- data = np.load(path, allow_pickle=True)
266
+ data = self._safe_load(path)
257
267
  keys = list(data.keys())
258
268
 
259
269
  output_key = self._find_key(keys, OUTPUT_KEYS)
@@ -677,6 +687,7 @@ def load_test_data(
677
687
  format: str = "auto",
678
688
  input_key: str | None = None,
679
689
  output_key: str | None = None,
690
+ input_channels: int | None = None,
680
691
  ) -> tuple[torch.Tensor, torch.Tensor | None]:
681
692
  """
682
693
  Load test/inference data and return PyTorch tensors ready for model input.
@@ -698,6 +709,9 @@ def load_test_data(
698
709
  format: Format hint ('npz', 'hdf5', 'mat', or 'auto' for detection)
699
710
  input_key: Custom key for input data (overrides auto-detection)
700
711
  output_key: Custom key for output data (overrides auto-detection)
712
+ input_channels: Explicit number of input channels. If provided, bypasses
713
+ the heuristic for 4D data. Use input_channels=1 for 3D volumes that
714
+ look like multi-channel 2D (e.g., depth ≤16).
701
715
 
702
716
  Returns:
703
717
  Tuple of:
@@ -737,7 +751,7 @@ def load_test_data(
737
751
  except KeyError:
738
752
  # Try with just inputs if outputs not found (inference-only mode)
739
753
  if format == "npz":
740
- data = np.load(path, allow_pickle=True)
754
+ data = NPZSource._safe_load(path)
741
755
  keys = list(data.keys())
742
756
  inp_key = DataSource._find_key(keys, custom_input_keys)
743
757
  if inp_key is None:
@@ -822,15 +836,20 @@ def load_test_data(
822
836
  # Add channel dimension if needed (dimension-agnostic)
823
837
  # X.ndim == 2: 1D data (N, L) → (N, 1, L)
824
838
  # X.ndim == 3: 2D data (N, H, W) → (N, 1, H, W)
825
- # X.ndim == 4: Check if already has channel dim (C <= 16 heuristic)
839
+ # X.ndim == 4: Check if already has channel dim
826
840
  if X.ndim == 2:
827
841
  X = X.unsqueeze(1) # 1D signal: (N, L) → (N, 1, L)
828
842
  elif X.ndim == 3:
829
843
  X = X.unsqueeze(1) # 2D image: (N, H, W) → (N, 1, H, W)
830
844
  elif X.ndim == 4:
831
845
  # Could be 3D volume (N, D, H, W) or 2D with channel (N, C, H, W)
832
- # Heuristic: if dim 1 is small (<=16), assume it's already a channel dim
833
- if X.shape[1] > 16:
846
+ if input_channels is not None:
847
+ # Explicit override: user specifies channel count
848
+ if input_channels == 1:
849
+ X = X.unsqueeze(1) # Add channel: (N, D, H, W) → (N, 1, D, H, W)
850
+ # else: already has channels, leave as-is
851
+ elif X.shape[1] > 16:
852
+ # Heuristic fallback: large dim 1 suggests 3D volume needing channel
834
853
  X = X.unsqueeze(1) # 3D volume: (N, D, H, W) → (N, 1, D, H, W)
835
854
  # X.ndim >= 5: assume channel dimension already exists
836
855
 
wavedl/utils/metrics.py CHANGED
@@ -560,11 +560,56 @@ def create_training_curves(
560
560
  )
561
561
  lines.append(line)
562
562
 
563
def set_lr_ticks(ax: plt.Axes, data: list[float], n_ticks: int = 4) -> None:
    """Place up to n_ticks decade-aligned ticks on the LR axis, labelled 10^n."""
    positive = [v for v in data if v is not None and not np.isnan(v) and v > 0]
    if not positive:
        return  # nothing plottable on a log axis
    # Snap the data range outward to whole decades.
    lo = np.floor(np.log10(min(positive)))
    hi = np.ceil(np.log10(max(positive)))
    # Uniformly spaced exponents, rounded to integer powers of 10.
    exponents = np.round(np.linspace(lo, hi, n_ticks))
    tick_values = list(dict.fromkeys(10.0**exponents))  # dedupe, keep order
    ax.set_yticks(tick_values)
    ax.set_yticklabels([f"$10^{{{int(np.log10(t))}}}$" for t in tick_values])
    ax.minorticks_off()
584
+
585
def set_loss_ticks(ax: plt.Axes, data: list[float]) -> None:
    """Tick every power of 10 spanning the positive data range."""
    positive = [v for v in data if v is not None and not np.isnan(v) and v > 0]
    if not positive:
        return  # nothing plottable on a log axis
    # Decade range covering the data (ceil on the minimum avoids a tick
    # below the smallest value).
    first = int(np.ceil(np.log10(min(positive))))
    last = int(np.ceil(np.log10(max(positive))))
    decades = range(first, last + 1)
    ax.set_yticks([10.0**k for k in decades])
    ax.set_yticklabels([f"$10^{{{k}}}$" for k in decades])
    ax.minorticks_off()
601
+
563
602
  ax1.set_xlabel("Epoch")
564
603
  ax1.set_ylabel("Loss")
565
604
  ax1.set_yscale("log") # Log scale for loss
566
605
  ax1.grid(True, alpha=0.3)
567
606
 
607
+ # Collect all loss values and set clean power of 10 ticks
608
+ all_loss_values = []
609
+ for metric in metrics:
610
+ all_loss_values.extend([h.get(metric, np.nan) for h in history])
611
+ set_loss_ticks(ax1, all_loss_values)
612
+
568
613
  # Check if learning rate data exists
569
614
  has_lr = show_lr and any("lr" in h for h in history)
570
615
 
@@ -581,9 +626,11 @@ def create_training_curves(
581
626
  alpha=0.7,
582
627
  label="Learning Rate",
583
628
  )
584
- ax2.set_ylabel("Learning Rate", color=COLORS["neutral"])
585
- ax2.tick_params(axis="y", labelcolor=COLORS["neutral"])
629
+ ax2.set_ylabel("Learning Rate")
586
630
  ax2.set_yscale("log") # Log scale for LR
631
+ set_lr_ticks(ax2, lr_values, n_ticks=4)
632
+ # Ensure right spine (axis line) is visible
633
+ ax2.spines["right"].set_visible(True)
587
634
  lines.append(line_lr)
588
635
 
589
636
  # Combined legend