wavedl 1.5.7__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff shows the content of publicly available package versions as published to their respective public registries. It is provided for informational purposes only.
@@ -0,0 +1,251 @@
+ """
+ MaxViT: Multi-Axis Vision Transformer
+ ======================================
+
+ MaxViT combines local and global attention with O(n) complexity using
+ multi-axis attention: block attention (local) + grid attention (global sparse).
+
+ **Key Features**:
+ - Multi-axis attention for both local and global context
+ - Hybrid design with MBConv + attention
+ - Linear O(n) complexity
+ - Hierarchical multi-scale features
+
+ **Variants**:
+ - maxvit_tiny: 31M params
+ - maxvit_small: 69M params
+ - maxvit_base: 120M params
+
+ **Requirements**:
+ - timm (for pretrained models and architecture)
+ - torchvision (fallback, limited support)
+
+ Reference:
+     Tu, Z., et al. (2022). MaxViT: Multi-Axis Vision Transformer.
+     ECCV 2022. https://arxiv.org/abs/2204.01697
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ """
+
+ import torch
+ import torch.nn as nn
+
+ from wavedl.models._timm_utils import build_regression_head
+ from wavedl.models.base import BaseModel
+ from wavedl.models.registry import register_model
+
+
+ __all__ = [
+     "MaxViTBase",
+     "MaxViTBaseLarge",
+     "MaxViTSmall",
+     "MaxViTTiny",
+ ]
+
+
+ # =============================================================================
+ # MAXVIT BASE CLASS
+ # =============================================================================
+
+
+ class MaxViTBase(BaseModel):
+     """
+     MaxViT base class wrapping timm implementation.
+
+     Multi-axis attention with local block and global grid attention.
+     2D only due to attention structure.
+     """
+
+     def __init__(
+         self,
+         in_shape: tuple[int, int],
+         out_size: int,
+         model_name: str = "maxvit_tiny_tf_224",
+         pretrained: bool = True,
+         freeze_backbone: bool = False,
+         dropout_rate: float = 0.3,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         if len(in_shape) != 2:
+             raise ValueError(f"MaxViT requires 2D input (H, W), got {len(in_shape)}D")
+
+         self.pretrained = pretrained
+         self.freeze_backbone = freeze_backbone
+         self.model_name = model_name
+
+         # Try to load from timm
+         try:
+             import timm
+
+             self.backbone = timm.create_model(
+                 model_name,
+                 pretrained=pretrained,
+                 num_classes=0,  # Remove classifier
+             )
+
+             # Get feature dimension
+             with torch.no_grad():
+                 dummy = torch.zeros(1, 3, *in_shape)
+                 features = self.backbone(dummy)
+                 in_features = features.shape[-1]
+
+         except ImportError:
+             raise ImportError(
+                 "timm is required for MaxViT. Install with: pip install timm"
+             )
+         except Exception as e:
+             raise RuntimeError(f"Failed to load MaxViT model '{model_name}': {e}")
+
+         # Adapt input channels (3 -> 1)
+         self._adapt_input_channels()
+
+         # Regression head
+         self.head = build_regression_head(in_features, out_size, dropout_rate)
+
+         if freeze_backbone:
+             self._freeze_backbone()
+
+     def _adapt_input_channels(self):
+         """Adapt first conv layer for single-channel input."""
+         # MaxViT uses stem.conv1 (Conv2dSame from timm)
+         adapted = False
+
+         # Find the first Conv2d with 3 input channels
+         for name, module in self.backbone.named_modules():
+             if hasattr(module, "in_channels") and module.in_channels == 3:
+                 # Get parent and child names
+                 parts = name.split(".")
+                 parent = self.backbone
+                 for part in parts[:-1]:
+                     parent = getattr(parent, part)
+                 child_name = parts[-1]
+
+                 # Create new conv with 1 input channel
+                 new_conv = self._make_new_conv(module)
+                 setattr(parent, child_name, new_conv)
+                 adapted = True
+                 break
+
+         if not adapted:
+             import warnings
+
+             warnings.warn(
+                 "Could not adapt MaxViT input channels. Model may fail.", stacklevel=2
+             )
+
+     def _make_new_conv(self, old_conv: nn.Module) -> nn.Module:
+         """Create new conv layer with 1 input channel."""
+         # Works for both Conv2d and Conv2dSame from timm: the replacement is
+         # rebuilt below as a plain nn.Conv2d from the old layer's parameters.
+
+         # Get common parameters
+         kwargs = {
+             "out_channels": old_conv.out_channels,
+             "kernel_size": old_conv.kernel_size,
+             "stride": old_conv.stride,
+             "padding": old_conv.padding if hasattr(old_conv, "padding") else 0,
+             "bias": old_conv.bias is not None,
+         }
+
+         # Create new conv (use regular Conv2d for simplicity)
+         new_conv = nn.Conv2d(1, **kwargs)
+
+         if self.pretrained:
+             with torch.no_grad():
+                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                 if old_conv.bias is not None:
+                     new_conv.bias.copy_(old_conv.bias)
+         return new_conv
+
+     def _freeze_backbone(self):
+         """Freeze backbone parameters."""
+         for param in self.backbone.parameters():
+             param.requires_grad = False
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         features = self.backbone(x)
+         return self.head(features)
+
+
+ # =============================================================================
+ # REGISTERED VARIANTS
+ # =============================================================================
+
+
+ @register_model("maxvit_tiny")
+ class MaxViTTiny(MaxViTBase):
+     """
+     MaxViT Tiny: ~30.1M backbone parameters.
+
+     Multi-axis attention with local+global context.
+     2D only.
+
+     Example:
+         >>> model = MaxViTTiny(in_shape=(224, 224), out_size=3)
+         >>> x = torch.randn(4, 1, 224, 224)
+         >>> out = model(x)  # (4, 3)
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="maxvit_tiny_tf_224",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"MaxViT_Tiny(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("maxvit_small")
+ class MaxViTSmall(MaxViTBase):
+     """
+     MaxViT Small: ~67.6M backbone parameters.
+
+     Multi-axis attention with local+global context.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="maxvit_small_tf_224",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"MaxViT_Small(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("maxvit_base")
+ class MaxViTBaseLarge(MaxViTBase):
+     """
+     MaxViT Base: ~118.1M backbone parameters.
+
+     Multi-axis attention with local+global context.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="maxvit_base_tf_224",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"MaxViT_Base(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
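The new module registers three MaxViT variants for single-channel 2D regression. As a quick smoke test, the class docstring example can be run directly; the sketch below assumes the file lands at wavedl/models/maxvit.py (the path is not shown in this diff) and uses pretrained=False so no ImageNet weights are downloaded:

    import torch
    from wavedl.models.maxvit import MaxViTTiny  # module path assumed, not shown in the diff

    model = MaxViTTiny(in_shape=(224, 224), out_size=3, pretrained=False)
    x = torch.randn(4, 1, 224, 224)  # single-channel input; the stem conv is adapted from 3 to 1 channels
    out = model(x)
    print(out.shape)  # torch.Size([4, 3])

For intuition about the multi-axis attention described in the module docstring, the two partitions reduce to plain reshapes. This is an illustrative sketch on a toy 8x8 feature map with window/grid size P = 4, not code from the package:

    import torch

    x = torch.randn(1, 8, 8, 32)  # (B, H, W, C) toy feature map
    P = 4

    # Block attention: non-overlapping P x P windows of neighbouring tokens (local)
    blocks = x.reshape(1, 8 // P, P, 8 // P, P, 32).permute(0, 1, 3, 2, 4, 5)
    print(blocks.shape)  # (1, 2, 2, 4, 4, 32): 4 windows, each of 4x4 adjacent tokens

    # Grid attention: a P x P set of tokens strided across the whole map (sparse global)
    grid = x.reshape(1, P, 8 // P, P, 8 // P, 32).permute(0, 2, 4, 1, 3, 5)
    print(grid.shape)  # (1, 2, 2, 4, 4, 32): 4 groups, each spanning the full map

Because both partitions keep the attended window at a fixed size, attention cost grows linearly with the number of tokens, which is where the O(n) claim in the docstring comes from.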
@@ -13,8 +13,8 @@ optimization to achieve excellent accuracy with minimal computational cost.
  - Designed for real-time inference on CPUs and edge devices

  **Variants**:
- - mobilenet_v3_small: Ultra-lightweight (~1.1M params) - Edge/embedded
- - mobilenet_v3_large: Balanced (~3.2M params) - Mobile deployment
+ - mobilenet_v3_small: Ultra-lightweight (~0.9M backbone params) - Edge/embedded
+ - mobilenet_v3_large: Balanced (~3.0M backbone params) - Mobile deployment

  **Use Cases**:
  - Real-time structural health monitoring on embedded systems
@@ -206,7 +206,7 @@ class MobileNetV3Small(MobileNetV3Base):
  """
  MobileNetV3-Small: Ultra-lightweight for edge deployment.

- ~1.1M parameters. Designed for the most constrained environments.
+ ~0.9M backbone parameters. Designed for the most constrained environments.
  Achieves ~67% ImageNet accuracy with minimal compute.

  Recommended for:
@@ -217,7 +217,7 @@ class MobileNetV3Small(MobileNetV3Base):

  Performance (approximate):
  - CPU inference: ~6ms (single core)
- - Parameters: ~1.1M
+ - Parameters: ~0.9M backbone
  - MAdds: 56M

  Args:
@@ -253,7 +253,7 @@ class MobileNetV3Large(MobileNetV3Base):
  """
  MobileNetV3-Large: Balanced efficiency and accuracy.

- ~3.2M parameters. Best trade-off for mobile/portable deployment.
+ ~3.0M backbone parameters. Best trade-off for mobile/portable deployment.
  Achieves ~75% ImageNet accuracy with efficient inference.

  Recommended for:
@@ -264,7 +264,7 @@ class MobileNetV3Large(MobileNetV3Base):

  Performance (approximate):
  - CPU inference: ~20ms (single core)
- - Parameters: ~3.2M
+ - Parameters: ~3.0M backbone
  - MAdds: 219M

  Args:
wavedl/models/regnet.py CHANGED
@@ -13,11 +13,11 @@ Models scale smoothly from mobile to server deployments.
  - Optional Squeeze-and-Excitation (SE) attention

  **Variants** (RegNetY includes SE attention):
- - regnet_y_400mf: Ultra-light (~4.0M params, 0.4 GFLOPs)
- - regnet_y_800mf: Light (~5.8M params, 0.8 GFLOPs)
- - regnet_y_1_6gf: Medium (~10.5M params, 1.6 GFLOPs) - Recommended
- - regnet_y_3_2gf: Large (~18.3M params, 3.2 GFLOPs)
- - regnet_y_8gf: Very large (~37.9M params, 8.0 GFLOPs)
+ - regnet_y_400mf: Ultra-light (~3.9M backbone params, 0.4 GFLOPs)
+ - regnet_y_800mf: Light (~5.7M backbone params, 0.8 GFLOPs)
+ - regnet_y_1_6gf: Medium (~10.3M backbone params, 1.6 GFLOPs) - Recommended
+ - regnet_y_3_2gf: Large (~17.9M backbone params, 3.2 GFLOPs)
+ - regnet_y_8gf: Very large (~37.4M backbone params, 8.0 GFLOPs)

  **When to Use RegNet**:
  - When you need predictable performance at a given compute budget
@@ -210,7 +210,7 @@ class RegNetY400MF(RegNetBase):
  """
  RegNetY-400MF: Ultra-lightweight for constrained environments.

- ~4.0M parameters, 0.4 GFLOPs. Smallest RegNet variant with SE attention.
+ ~3.9M backbone parameters, 0.4 GFLOPs. Smallest RegNet variant with SE attention.

  Recommended for:
  - Edge deployment with moderate accuracy needs
@@ -250,7 +250,7 @@ class RegNetY800MF(RegNetBase):
  """
  RegNetY-800MF: Light variant with good accuracy.

- ~6.4M parameters, 0.8 GFLOPs. Good balance for mobile deployment.
+ ~5.7M backbone parameters, 0.8 GFLOPs. Good balance for mobile deployment.

  Recommended for:
  - Mobile/portable devices
@@ -290,7 +290,7 @@ class RegNetY1_6GF(RegNetBase):
  """
  RegNetY-1.6GF: Recommended default for balanced performance.

- ~11.2M parameters, 1.6 GFLOPs. Best trade-off of accuracy and efficiency.
+ ~10.3M backbone parameters, 1.6 GFLOPs. Best trade-off of accuracy and efficiency.
  Comparable to ResNet50 but more efficient.

  Recommended for:
@@ -331,7 +331,7 @@ class RegNetY3_2GF(RegNetBase):
  """
  RegNetY-3.2GF: Higher accuracy for demanding tasks.

- ~19.4M parameters, 3.2 GFLOPs. Use when 1.6GF isn't sufficient.
+ ~17.9M backbone parameters, 3.2 GFLOPs. Use when 1.6GF isn't sufficient.

  Recommended for:
  - Larger datasets requiring more capacity
@@ -371,7 +371,7 @@ class RegNetY8GF(RegNetBase):
  """
  RegNetY-8GF: High capacity for large-scale tasks.

- ~39.2M parameters, 8.0 GFLOPs. Use for maximum accuracy needs.
+ ~37.4M backbone parameters, 8.0 GFLOPs. Use for maximum accuracy needs.

  Recommended for:
  - Very large datasets (>50k samples)
wavedl/models/resnet.py CHANGED
@@ -11,9 +11,9 @@ Provides multiple depth variants (18, 34, 50) with optional pretrained weights f
  - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d

  **Variants**:
- - resnet18: Lightweight, fast training (~11M params)
- - resnet34: Balanced capacity (~21M params)
- - resnet50: Higher capacity with bottleneck blocks (~25M params)
+ - resnet18: Lightweight, fast training (~11.2M backbone params)
+ - resnet34: Balanced capacity (~21.3M backbone params)
+ - resnet50: Higher capacity with bottleneck blocks (~23.5M backbone params)

  References:
  He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning
@@ -534,7 +534,7 @@ class ResNet18Pretrained(PretrainedResNetBase):
  """
  ResNet-18 with ImageNet pretrained weights (2D only).

- ~11M parameters. Good for: Transfer learning, faster convergence.
+ ~11.2M backbone parameters. Good for: Transfer learning, faster convergence.

  Args:
  in_shape: (H, W) image dimensions
@@ -563,7 +563,7 @@ class ResNet50Pretrained(PretrainedResNetBase):
  """
  ResNet-50 with ImageNet pretrained weights (2D only).

- ~25M parameters. Good for: High accuracy with transfer learning.
+ ~23.5M backbone parameters. Good for: High accuracy with transfer learning.

  Args:
  in_shape: (H, W) image dimensions
wavedl/models/resnet3d.py CHANGED
@@ -179,7 +179,7 @@ class ResNet3D18(ResNet3DBase):
  """
  ResNet3D-18: Lightweight 3D ResNet for volumetric data.

- ~33M parameters. Uses 3D convolutions throughout for true volumetric processing.
+ ~33.2M backbone parameters. Uses 3D convolutions throughout for true volumetric processing.
  Pretrained on Kinetics-400 (video action recognition).

  Recommended for:
@@ -221,7 +221,7 @@ class MC3_18(ResNet3DBase):
  """
  MC3-18: Mixed Convolution 3D ResNet (3D stem + 2D residual blocks).

- ~11M parameters. More efficient than pure 3D ResNet while maintaining
+ ~11.5M backbone parameters. More efficient than pure 3D ResNet while maintaining
  good spatiotemporal modeling. Uses 3D convolutions in early layers
  and 2D convolutions in later layers.

wavedl/models/swin.py CHANGED
@@ -304,7 +304,7 @@ class SwinTiny(SwinTransformerBase):
  """
  Swin-T (Tiny): Efficient default for most wave-based tasks.

- ~28M parameters. Good balance of accuracy and computational cost.
+ ~27.5M backbone parameters. Good balance of accuracy and computational cost.
  Outperforms ResNet50 while being more efficient.

  Recommended for:
@@ -353,7 +353,7 @@ class SwinSmall(SwinTransformerBase):
  """
  Swin-S (Small): Higher accuracy with moderate compute.

- ~50M parameters. Better accuracy than Swin-T for larger datasets.
+ ~48.8M backbone parameters. Better accuracy than Swin-T for larger datasets.

  Recommended for:
  - Larger datasets (>20k samples)
@@ -400,7 +400,7 @@ class SwinBase(SwinTransformerBase):
  """
  Swin-B (Base): Maximum accuracy for large-scale tasks.

- ~88M parameters. Best accuracy but requires more compute and data.
+ ~86.7M backbone parameters. Best accuracy but requires more compute and data.

  Recommended for:
  - Very large datasets (>50k samples)
wavedl/models/tcn.py CHANGED
@@ -296,7 +296,7 @@ class TCN(TCNBase):
  """
  TCN: Standard Temporal Convolutional Network.

- ~7.0M parameters. 8 temporal blocks with channels [64→128→256→256→512→512→512→512].
+ ~6.9M backbone parameters. 8 temporal blocks with channels [64→128→256→256→512→512→512→512].
  Receptive field: 511 samples with kernel_size=3.

  Recommended for:
@@ -338,7 +338,7 @@ class TCNSmall(TCNBase):
  """
  TCN-Small: Lightweight variant for quick experiments.

- ~1.0M parameters. 6 temporal blocks with channels [32→64→128→128→256→256].
+ ~0.9M backbone parameters. 6 temporal blocks with channels [32→64→128→128→256→256].
  Receptive field: 127 samples with kernel_size=3.

  Recommended for:
@@ -376,7 +376,7 @@ class TCNLarge(TCNBase):
  """
  TCN-Large: High-capacity variant for complex patterns.

- ~10.2M parameters. 10 temporal blocks with channels [64→128→256→256→512→512→512→512→512→512].
+ ~10.0M backbone parameters. 10 temporal blocks with channels [64→128→256→256→512→512→512→512→512→512].
  Receptive field: 2047 samples with kernel_size=3.

  Recommended for:
wavedl/models/unet.py CHANGED
@@ -119,7 +119,7 @@ class UNetRegression(BaseModel):
  Uses U-Net encoder-decoder architecture with skip connections,
  then applies global pooling for standard vector regression output.

- ~31.1M parameters (2D). Good for leveraging multi-scale features
+ ~31.0M backbone parameters (2D). Good for leveraging multi-scale features
  and skip connections for regression tasks.

  Args:
wavedl/models/vit.py CHANGED
@@ -10,9 +10,9 @@ Supports both 1D (signals) and 2D (images) inputs via configurable patch embeddi
  - 2D: Images/spectrograms → patches are grid squares

  **Variants**:
- - vit_tiny: Smallest (~5.7M params, embed_dim=192, depth=12, heads=3)
- - vit_small: Light (~22M params, embed_dim=384, depth=12, heads=6)
- - vit_base: Standard (~86M params, embed_dim=768, depth=12, heads=12)
+ - vit_tiny: Smallest (~5.4M backbone params, embed_dim=192, depth=12, heads=3)
+ - vit_small: Light (~21.4M backbone params, embed_dim=384, depth=12, heads=6)
+ - vit_base: Standard (~85.3M backbone params, embed_dim=768, depth=12, heads=12)

  References:
  Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words:
@@ -365,7 +365,7 @@ class ViTTiny(ViTBase):
  """
  ViT-Tiny: Smallest Vision Transformer variant.

- ~5.7M parameters. Good for: Quick experiments, smaller datasets.
+ ~5.4M backbone parameters. Good for: Quick experiments, smaller datasets.

  Args:
  in_shape: (L,) for 1D or (H, W) for 2D
@@ -398,7 +398,7 @@ class ViTSmall(ViTBase):
  """
  ViT-Small: Light Vision Transformer variant.

- ~22M parameters. Good for: Balanced performance.
+ ~21.4M backbone parameters. Good for: Balanced performance.

  Args:
  in_shape: (L,) for 1D or (H, W) for 2D
@@ -429,7 +429,7 @@ class ViTBase_(ViTBase):
  """
  ViT-Base: Standard Vision Transformer variant.

- ~86M parameters. Good for: High accuracy, larger datasets.
+ ~85.3M backbone parameters. Good for: High accuracy, larger datasets.

  Args:
  in_shape: (L,) for 1D or (H, W) for 2D
wavedl/train.py CHANGED
@@ -239,11 +239,12 @@ def parse_args() -> argparse.Namespace:
  help="Python modules to import before training (for custom models)",
  )
  parser.add_argument(
- "--pretrained",
- action=argparse.BooleanOptionalAction,
- default=True,
- help="Use pretrained weights (default: True). Use --no-pretrained to train from scratch.",
+ "--no_pretrained",
+ dest="pretrained",
+ action="store_false",
+ help="Train from scratch without pretrained weights (default: use pretrained)",
  )
+ parser.set_defaults(pretrained=True)

  # Configuration File
  parser.add_argument(
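The hunk above replaces the argparse.BooleanOptionalAction pair with a single opt-out flag. A minimal sketch of the new flag in isolation (not the full wavedl parser) shows the resulting behaviour; note that the old spellings --pretrained and the hyphenated --no-pretrained are no longer accepted, only --no_pretrained is:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--no_pretrained",
        dest="pretrained",
        action="store_false",
        help="Train from scratch without pretrained weights (default: use pretrained)",
    )
    parser.set_defaults(pretrained=True)

    print(parser.parse_args([]).pretrained)                   # True
    print(parser.parse_args(["--no_pretrained"]).pretrained)  # False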
@@ -1028,12 +1029,14 @@

  for x, y in pbar:
  with accelerator.accumulate(model):
- pred = model(x)
- # Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
- if isinstance(criterion, PhysicsConstrainedLoss):
- loss = criterion(pred, y, x)
- else:
- loss = criterion(pred, y)
+ # Use mixed precision for forward pass (respects --precision flag)
+ with accelerator.autocast():
+ pred = model(x)
+ # Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
+ if isinstance(criterion, PhysicsConstrainedLoss):
+ loss = criterion(pred, y, x)
+ else:
+ loss = criterion(pred, y)

  accelerator.backward(loss)

@@ -1082,12 +1085,14 @@

  with torch.inference_mode():
  for x, y in val_dl:
- pred = model(x)
- # Pass inputs for input-dependent constraints
- if isinstance(criterion, PhysicsConstrainedLoss):
- loss = criterion(pred, y, x)
- else:
- loss = criterion(pred, y)
+ # Use mixed precision for validation (consistent with training)
+ with accelerator.autocast():
+ pred = model(x)
+ # Pass inputs for input-dependent constraints
+ if isinstance(criterion, PhysicsConstrainedLoss):
+ loss = criterion(pred, y, x)
+ else:
+ loss = criterion(pred, y)

  val_loss_sum += loss.detach() * x.size(0)
  val_samples += x.size(0)
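Both training and validation hunks now wrap the forward pass and loss in accelerator.autocast(), so the precision chosen via --precision is actually applied to the model's computation. A self-contained sketch of the pattern with a toy model (the real loop also handles PhysicsConstrainedLoss, gradient accumulation, and logging):

    import torch
    import torch.nn as nn
    from accelerate import Accelerator

    accelerator = Accelerator()  # mixed precision comes from the accelerate config / --precision flag
    model = accelerator.prepare(nn.Linear(8, 3))
    criterion = nn.MSELoss()

    x = torch.randn(4, 8, device=accelerator.device)
    y = torch.randn(4, 3, device=accelerator.device)

    with accelerator.autocast():  # forward pass and loss under the configured precision
        pred = model(x)
        loss = criterion(pred, y)

    accelerator.backward(loss)  # backward stays outside the autocast block, as in the diff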
wavedl/utils/data.py CHANGED
@@ -474,9 +474,18 @@ class _TransposedH5Dataset:
  self.shape = tuple(reversed(h5_dataset.shape))
  self.dtype = h5_dataset.dtype

- # Precompute transpose axis order for efficiency
- # For shape (A, B, C) -> reversed (C, B, A), transpose axes are (2, 1, 0)
- self._transpose_axes = tuple(range(len(h5_dataset.shape) - 1, -1, -1))
+ @property
+ def ndim(self) -> int:
+ """Number of dimensions (derived from shape for numpy compatibility)."""
+ return len(self.shape)
+
+ @property
+ def _transpose_axes(self) -> tuple[int, ...]:
+ """Transpose axis order for reversing dimensions.
+
+ For shape (A, B, C) -> reversed (C, B, A), transpose axes are (2, 1, 0).
+ """
+ return tuple(range(len(self._dataset.shape) - 1, -1, -1))

  def __len__(self) -> int:
  return self.shape[0]
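This hunk turns the precomputed _transpose_axes attribute into derived properties, so ndim and the axis order always follow the wrapped dataset's current shape rather than a value cached at construction time. A toy stand-in (not the real _TransposedH5Dataset) illustrating the derivation:

    class TransposedView:
        """Minimal illustration of shape reversal with derived properties."""

        def __init__(self, shape: tuple[int, ...]):
            self._orig_shape = shape
            self.shape = tuple(reversed(shape))  # (A, B, C) -> (C, B, A)

        @property
        def ndim(self) -> int:
            return len(self.shape)

        @property
        def transpose_axes(self) -> tuple[int, ...]:
            # Reversing every dimension is numpy.transpose with axes (n-1, ..., 1, 0)
            return tuple(range(len(self._orig_shape) - 1, -1, -1))

    view = TransposedView((100, 64, 3))
    print(view.shape, view.ndim, view.transpose_axes)  # (3, 64, 100) 3 (2, 1, 0)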
@@ -965,8 +974,17 @@
  else:
  # Fallback to default source.load() for unknown formats
  inp, outp = source.load(path)
- except KeyError:
- # Try with just inputs if outputs not found (inference-only mode)
+ except KeyError as e:
+ # IMPORTANT: Only fall back to inference-only mode if outputs are
+ # genuinely missing (auto-detection failed). If user explicitly
+ # provided --output_key, they expect it to exist - don't silently drop.
+ if output_key is not None:
+ raise KeyError(
+ f"Explicit --output_key '{output_key}' not found in file. "
+ f"Available keys depend on file format. Original error: {e}"
+ ) from e
+
+ # Legitimate fallback: no explicit output_key, outputs just not present
  if format == "npz":
  # First pass to find keys
  with np.load(path, allow_pickle=False) as probe:
@@ -1083,11 +1101,26 @@
  raise ValueError(
  f"Input appears to be channels-last format: {tuple(X.shape)}. "
  "WaveDL expects channels-first (N, C, H, W). "
- "Convert your data using: X = X.permute(0, 3, 1, 2)"
+ "Convert your data using: X = X.permute(0, 3, 1, 2). "
+ "If this is actually a 3D volume with small depth, "
+ "use --input_channels 1 to add a channel dimension."
  )
  elif X.shape[1] > 16:
  # Heuristic fallback: large dim 1 suggests 3D volume needing channel
  X = X.unsqueeze(1)  # 3D volume: (N, D, H, W) → (N, 1, D, H, W)
+ else:
+ # Ambiguous case: shallow 3D volume (D <= 16) or multi-channel 2D
+ # Default to treating as multi-channel 2D (no modification needed)
+ # Log a warning so users know about the --input_channels option
+ import warnings
+
+ warnings.warn(
+ f"Ambiguous 4D input shape: {tuple(X.shape)}. "
+ f"Assuming {X.shape[1]} channels (multi-channel 2D). "
+ f"For 3D volumes with depth={X.shape[1]}, use --input_channels 1.",
+ UserWarning,
+ stacklevel=2,
+ )
  # X.ndim >= 5: assume channel dimension already exists

  return X, y
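The final hunk handles the previously silent middle case: a 4D input whose second axis is 16 or smaller is now treated as multi-channel 2D, with a warning pointing at --input_channels. A hedged, standalone re-implementation of the heuristic for illustration (the function name is mine, and the channels-last error path is omitted):

    import warnings

    import torch

    def add_channel_dim_if_needed(X: torch.Tensor) -> torch.Tensor:
        """Decide whether a 4D tensor (N, ?, H, W) already carries a channel axis."""
        if X.ndim != 4:
            return X
        if X.shape[1] > 16:
            # Large second axis: treat as a 3D volume (N, D, H, W) and add a channel dim
            return X.unsqueeze(1)
        warnings.warn(
            f"Ambiguous 4D input shape {tuple(X.shape)}: assuming {X.shape[1]} channels "
            f"(multi-channel 2D). For shallow 3D volumes, pass --input_channels 1.",
            UserWarning,
            stacklevel=2,
        )
        return X

    print(add_channel_dim_if_needed(torch.zeros(8, 32, 64, 64)).shape)  # (8, 1, 32, 64, 64)
    print(add_channel_dim_if_needed(torch.zeros(8, 3, 64, 64)).shape)   # (8, 3, 64, 64), plus a warning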