wavedl 1.5.6.tar.gz → 1.5.7.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {wavedl-1.5.6/src/wavedl.egg-info → wavedl-1.5.7}/PKG-INFO +16 -7
  2. {wavedl-1.5.6 → wavedl-1.5.7}/README.md +15 -6
  3. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/efficientnet.py +24 -7
  5. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/efficientnetv2.py +29 -6
  6. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/mobilenetv3.py +31 -8
  7. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/regnet.py +29 -6
  8. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/swin.py +38 -6
  9. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/tcn.py +22 -2
  10. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/test.py +7 -3
  11. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/train.py +41 -12
  12. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/constraints.py +11 -5
  13. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/data.py +82 -13
  14. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/metrics.py +287 -326
  15. {wavedl-1.5.6 → wavedl-1.5.7/src/wavedl.egg-info}/PKG-INFO +16 -7
  16. {wavedl-1.5.6 → wavedl-1.5.7}/LICENSE +0 -0
  17. {wavedl-1.5.6 → wavedl-1.5.7}/pyproject.toml +0 -0
  18. {wavedl-1.5.6 → wavedl-1.5.7}/setup.cfg +0 -0
  19. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/hpc.py +0 -0
  20. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/hpo.py +0 -0
  21. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/__init__.py +0 -0
  22. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/_template.py +0 -0
  23. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/base.py +0 -0
  24. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/cnn.py +0 -0
  25. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/convnext.py +0 -0
  26. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/densenet.py +0 -0
  27. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/registry.py +0 -0
  28. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/resnet.py +0 -0
  29. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/resnet3d.py +0 -0
  30. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/unet.py +0 -0
  31. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/vit.py +0 -0
  32. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/__init__.py +0 -0
  33. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/config.py +0 -0
  34. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/cross_validation.py +0 -0
  35. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/distributed.py +0 -0
  36. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/losses.py +0 -0
  37. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/optimizers.py +0 -0
  38. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/schedulers.py +0 -0
  39. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl.egg-info/SOURCES.txt +0 -0
  40. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl.egg-info/dependency_links.txt +0 -0
  41. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl.egg-info/entry_points.txt +0 -0
  42. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl.egg-info/requires.txt +0 -0
  43. {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl.egg-info/top_level.txt +0 -0
{wavedl-1.5.6/src/wavedl.egg-info → wavedl-1.5.7}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.6
+Version: 1.5.7
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -388,7 +388,7 @@ WaveDL/
 ├── configs/          # YAML config templates
 ├── examples/         # Ready-to-run examples
 ├── notebooks/        # Jupyter notebooks
-├── unit_tests/       # Pytest test suite (731 tests)
+├── unit_tests/       # Pytest test suite (903 tests)
 
 ├── pyproject.toml    # Package config, dependencies
 ├── CHANGELOG.md      # Version history
@@ -1035,12 +1035,20 @@ The `examples/` folder contains a **complete, ready-to-run example** for **material
 
 | Parameter | Unit | Description |
 |-----------|------|-------------|
-| *h* | mm | Plate thickness |
-| √(*E*/ρ) | km/s | Square root of Young's modulus over density |
-| *ν* | — | Poisson's ratio |
+| $h$ | mm | Plate thickness |
+| $\sqrt{E/\rho}$ | km/s | Square root of Young's modulus over density |
+| $\nu$ | — | Poisson's ratio |
 
 > [!NOTE]
-> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"Deep learning-based ultrasonic assessment of plate thickness and elasticity"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/Deep-learningbased-ultrasonic-assessment-of-plate-thickness-and-elasticity/13951-4) (Paper 13951-4, to appear).
+> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"A lightweight deep learning model for ultrasonic assessment of plate thickness and elasticity
+"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/A-lightweight-deep-learning-model-for-ultrasonic-assessment-of-plate/13951-4) (Paper 13951-4, to appear).
+
+**Sample Dispersion Data:**
+
+<p align="center">
+  <img src="examples/elasticity_prediction/dispersion_samples.png" alt="Dispersion curve samples" width="700"><br>
+  <em>Test samples showing the wavenumber-frequency relationship for different plate properties</em>
+</p>
 
 **Try it yourself:**
 
@@ -1061,7 +1069,8 @@ python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpo
 | File | Description |
 |------|-------------|
 | `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
-| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
+| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → $h$, $\sqrt{E/\rho}$, $\nu$) |
+| `dispersion_samples.png` | Visualization of sample dispersion curves with material parameters |
 | `model.onnx` | ONNX export with embedded de-normalization |
 | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
 | `training_curves.png` | Training/validation loss and learning rate plot |
{wavedl-1.5.6 → wavedl-1.5.7}/README.md

@@ -342,7 +342,7 @@ WaveDL/
 ├── configs/          # YAML config templates
 ├── examples/         # Ready-to-run examples
 ├── notebooks/        # Jupyter notebooks
-├── unit_tests/       # Pytest test suite (731 tests)
+├── unit_tests/       # Pytest test suite (903 tests)
 
 ├── pyproject.toml    # Package config, dependencies
 ├── CHANGELOG.md      # Version history
@@ -989,12 +989,20 @@ The `examples/` folder contains a **complete, ready-to-run example** for **material
 
 | Parameter | Unit | Description |
 |-----------|------|-------------|
-| *h* | mm | Plate thickness |
-| √(*E*/ρ) | km/s | Square root of Young's modulus over density |
-| *ν* | — | Poisson's ratio |
+| $h$ | mm | Plate thickness |
+| $\sqrt{E/\rho}$ | km/s | Square root of Young's modulus over density |
+| $\nu$ | — | Poisson's ratio |
 
 > [!NOTE]
-> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"Deep learning-based ultrasonic assessment of plate thickness and elasticity"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/Deep-learningbased-ultrasonic-assessment-of-plate-thickness-and-elasticity/13951-4) (Paper 13951-4, to appear).
+> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"A lightweight deep learning model for ultrasonic assessment of plate thickness and elasticity
+"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/A-lightweight-deep-learning-model-for-ultrasonic-assessment-of-plate/13951-4) (Paper 13951-4, to appear).
+
+**Sample Dispersion Data:**
+
+<p align="center">
+  <img src="examples/elasticity_prediction/dispersion_samples.png" alt="Dispersion curve samples" width="700"><br>
+  <em>Test samples showing the wavenumber-frequency relationship for different plate properties</em>
+</p>
 
 **Try it yourself:**
 
@@ -1015,7 +1023,8 @@ python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpo
 | File | Description |
 |------|-------------|
 | `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
-| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
+| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → $h$, $\sqrt{E/\rho}$, $\nu$) |
+| `dispersion_samples.png` | Visualization of sample dispersion curves with material parameters |
 | `model.onnx` | ONNX export with embedded de-normalization |
 | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
 | `training_curves.png` | Training/validation loss and learning rate plot |
{wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/__init__.py

@@ -18,7 +18,7 @@ For inference:
 # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
 """
 
-__version__ = "1.5.6"
+__version__ = "1.5.7"
 __author__ = "Ductho Le"
 __email__ = "ductho.le@outlook.com"
 
{wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/efficientnet.py

@@ -110,9 +110,30 @@ class EfficientNetBase(BaseModel):
             self._freeze_backbone()
 
     def _adapt_input_channels(self):
-        """Modify first conv to handle single-channel input by expanding to 3ch."""
-        # We'll handle this in forward by repeating channels
-        pass
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        # EfficientNet stem conv is at: features[0][0]
+        old_conv = self.backbone.features[0][0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            # Initialize with mean of pretrained RGB weights
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.features[0][0] = new_conv
 
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
@@ -130,10 +151,6 @@ class EfficientNetBase(BaseModel):
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
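A note on the initialization above: because convolution is linear, the old expand path (repeating a grayscale image into 3 channels) produces exactly three times the response of the new mean-of-RGB filter whenever the stem conv has no bias, which holds for these torchvision CNN stems (Swin's patch embedding does carry a bias, which is why that hunk below also copies it). A minimal sanity check in plain PyTorch, with hypothetical layer sizes and no torchvision download:

```python
import torch
import torch.nn as nn

# Hypothetical stem: 3->24 channels, 3x3, stride 2, no bias (as in these backbones)
rgb_conv = nn.Conv2d(3, 24, kernel_size=3, stride=2, padding=1, bias=False)
gray_conv = nn.Conv2d(1, 24, kernel_size=3, stride=2, padding=1, bias=False)
with torch.no_grad():
    # Same initialization as _adapt_input_channels: mean over the RGB dimension
    gray_conv.weight.copy_(rgb_conv.weight.mean(dim=1, keepdim=True))

x = torch.randn(2, 1, 64, 64)                    # single-channel input
old = rgb_conv(x.expand(-1, 3, -1, -1))          # old path: repeat channels
new = gray_conv(x)                               # new path: adapted stem
assert torch.allclose(3 * new, old, atol=1e-5)   # mean filter = sum / 3
```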
{wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/efficientnetv2.py

@@ -129,10 +129,37 @@ class EfficientNetV2Base(BaseModel):
             nn.Linear(regression_hidden // 2, out_size),
         )
 
-        # Optionally freeze backbone for fine-tuning
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
        if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.features[0][0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
         for name, param in self.backbone.named_parameters():
@@ -144,15 +171,11 @@
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
{wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/mobilenetv3.py

@@ -136,10 +136,37 @@ class MobileNetV3Base(BaseModel):
             nn.Linear(regression_hidden, out_size),
         )
 
-        # Optionally freeze backbone for fine-tuning
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.features[0][0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
         for name, param in self.backbone.named_parameters():
@@ -151,15 +178,11 @@
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
@@ -194,7 +217,7 @@ class MobileNetV3Small(MobileNetV3Base):
 
     Performance (approximate):
     - CPU inference: ~6ms (single core)
-    - Parameters: 2.5M
+    - Parameters: ~1.1M
     - MAdds: 56M
 
     Args:
@@ -241,7 +264,7 @@ class MobileNetV3Large(MobileNetV3Base):
 
     Performance (approximate):
     - CPU inference: ~20ms (single core)
-    - Parameters: 5.4M
+    - Parameters: ~3.2M
     - MAdds: 219M
 
     Args:
{wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/regnet.py

@@ -140,10 +140,37 @@ class RegNetBase(BaseModel):
             nn.Linear(regression_hidden, out_size),
         )
 
-        # Optionally freeze backbone for fine-tuning
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.stem[0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.stem[0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the fc layer."""
         for name, param in self.backbone.named_parameters():
@@ -155,15 +182,11 @@
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
{wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/swin.py

@@ -141,10 +141,46 @@ class SwinTransformerBase(BaseModel):
             nn.Linear(regression_hidden // 2, out_size),
         )
 
-        # Optionally freeze backbone for fine-tuning
+        # Adapt patch embedding conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify patch embedding conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the patch embedding conv with a 1-channel version and
+        initialize weights as the mean of the pretrained RGB filters.
+        """
+        # Swin's patch embedding is at features[0][0]
+        try:
+            old_conv = self.backbone.features[0][0]
+        except (IndexError, AttributeError, TypeError) as e:
+            raise RuntimeError(
+                f"Swin patch embed structure changed in this torchvision version. "
+                f"Cannot adapt input channels. Error: {e}"
+            ) from e
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                if old_conv.bias is not None:
+                    new_conv.bias.copy_(old_conv.bias)
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the head."""
         for name, param in self.backbone.named_parameters():
@@ -156,15 +192,11 @@
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
{wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/tcn.py

@@ -45,6 +45,26 @@ from wavedl.models.base import BaseModel
 from wavedl.models.registry import register_model
 
 
+def _find_group_count(channels: int, max_groups: int = 8) -> int:
+    """
+    Find largest valid group count for GroupNorm.
+
+    GroupNorm requires channels to be divisible by num_groups.
+    This finds the largest divisor up to max_groups.
+
+    Args:
+        channels: Number of channels
+        max_groups: Maximum group count to consider (default: 8)
+
+    Returns:
+        Largest valid group count (always >= 1)
+    """
+    for g in range(min(max_groups, channels), 0, -1):
+        if channels % g == 0:
+            return g
+    return 1
+
+
 class CausalConv1d(nn.Module):
     """
     Causal 1D convolution with dilation.
@@ -101,13 +121,13 @@ class TemporalBlock(nn.Module):
 
         # First causal convolution
         self.conv1 = CausalConv1d(in_channels, out_channels, kernel_size, dilation)
-        self.norm1 = nn.GroupNorm(min(8, out_channels), out_channels)
+        self.norm1 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
         self.act1 = nn.GELU()
         self.dropout1 = nn.Dropout(dropout)
 
         # Second causal convolution
         self.conv2 = CausalConv1d(out_channels, out_channels, kernel_size, dilation)
-        self.norm2 = nn.GroupNorm(min(8, out_channels), out_channels)
+        self.norm2 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
         self.act2 = nn.GELU()
         self.dropout2 = nn.Dropout(dropout)
 
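The helper fixes a real construction bug: `nn.GroupNorm(num_groups, num_channels)` raises `ValueError` unless `num_channels % num_groups == 0`, so the old `min(8, out_channels)` fails for any channel width above 8 that is not a multiple of 8. A small illustration with a hypothetical width of 12 (helper copied from the tcn.py hunk above):

```python
import torch.nn as nn

def _find_group_count(channels: int, max_groups: int = 8) -> int:
    # Copied from the tcn.py hunk above: largest divisor of `channels` <= max_groups
    for g in range(min(max_groups, channels), 0, -1):
        if channels % g == 0:
            return g
    return 1

try:
    nn.GroupNorm(min(8, 12), 12)                 # old code path: 12 % 8 != 0
except ValueError as e:
    print(f"old code fails: {e}")                # num_channels must be divisible by num_groups

print(_find_group_count(12))                     # 6 (largest divisor of 12 that is <= 8)
norm = nn.GroupNorm(_find_group_count(12), 12)   # constructs fine
```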
{wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/test.py

@@ -311,7 +311,7 @@ def load_data_for_inference(
 # ==============================================================================
 def load_checkpoint(
     checkpoint_dir: str,
-    in_shape: tuple[int, int],
+    in_shape: tuple[int, ...],
     out_size: int,
     model_name: str | None = None,
 ) -> tuple[nn.Module, any]:
@@ -320,7 +320,7 @@ def load_checkpoint(
 
     Args:
         checkpoint_dir: Path to checkpoint directory
-        in_shape: Input image shape (H, W)
+        in_shape: Input spatial shape - (L,) for 1D, (H, W) for 2D, or (D, H, W) for 3D
         out_size: Number of output parameters
         model_name: Model architecture name (auto-detect if None)
 
@@ -376,7 +376,11 @@
     )
 
     logging.info(f"  Building model: {model_name}")
-    model = build_model(model_name, in_shape=in_shape, out_size=out_size)
+    # Use pretrained=False: checkpoint weights will overwrite any pretrained weights,
+    # so downloading ImageNet weights is wasteful and breaks offline/HPC inference.
+    model = build_model(
+        model_name, in_shape=in_shape, out_size=out_size, pretrained=False
+    )
 
     # Load weights (check multiple formats in order of preference)
     weight_path = None
{wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/train.py

@@ -240,15 +240,9 @@ def parse_args() -> argparse.Namespace:
     )
     parser.add_argument(
         "--pretrained",
-        action="store_true",
+        action=argparse.BooleanOptionalAction,
         default=True,
-        help="Use pretrained weights (default: True)",
-    )
-    parser.add_argument(
-        "--no_pretrained",
-        dest="pretrained",
-        action="store_false",
-        help="Train from scratch without pretrained weights",
+        help="Use pretrained weights (default: True). Use --no-pretrained to train from scratch.",
     )
 
     # Configuration File
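`argparse.BooleanOptionalAction` (stdlib, Python 3.9+) generates the paired negative flag automatically, so the hand-rolled `--no_pretrained` argument is dropped and the negative spelling becomes `--no-pretrained` with a hyphen. A standalone sketch of the resulting behavior:

```python
import argparse

p = argparse.ArgumentParser()
# Mirrors the new train.py argument; the --no-pretrained flag is auto-generated
p.add_argument("--pretrained", action=argparse.BooleanOptionalAction, default=True)

print(p.parse_args([]).pretrained)                    # True (default)
print(p.parse_args(["--pretrained"]).pretrained)      # True
print(p.parse_args(["--no-pretrained"]).pretrained)   # False
```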
@@ -367,6 +361,18 @@
         help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
     )
     parser.add_argument("--seed", type=int, default=2025, help="Random seed")
+    parser.add_argument(
+        "--deterministic",
+        action="store_true",
+        help="Enable deterministic mode for reproducibility (slower, disables TF32/cuDNN benchmark)",
+    )
+    parser.add_argument(
+        "--cache_validate",
+        type=str,
+        default="sha256",
+        choices=["sha256", "fast", "size"],
+        help="Cache validation mode: sha256 (full hash), fast (partial), size (quick)",
+    )
     parser.add_argument(
         "--single_channel",
         action="store_true",
@@ -512,11 +518,23 @@
             # Import as regular module
             importlib.import_module(module_name)
             print(f"✓ Imported module: {module_name}")
-        except ImportError as e:
+        except (ImportError, FileNotFoundError, SyntaxError, PermissionError) as e:
             print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
-            print(
-                "  Make sure the module is in your Python path or current directory."
-            )
+            if isinstance(e, FileNotFoundError):
+                print("  File does not exist. Check the path.", file=sys.stderr)
+            elif isinstance(e, SyntaxError):
+                print(
+                    f"  Syntax error at line {e.lineno}: {e.msg}", file=sys.stderr
+                )
+            elif isinstance(e, PermissionError):
+                print(
+                    "  Permission denied. Check file permissions.", file=sys.stderr
+                )
+            else:
+                print(
+                    "  Make sure the module is in your Python path or current directory.",
+                    file=sys.stderr,
+                )
             sys.exit(1)
 
     # Handle --list_models flag
@@ -648,6 +666,17 @@
     )
     set_seed(args.seed)
 
+    # Deterministic mode for scientific reproducibility
+    # Disables TF32 and cuDNN benchmark for exact reproducibility (slower)
+    if args.deterministic:
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cudnn.allow_tf32 = False
+        torch.use_deterministic_algorithms(True, warn_only=True)
+        if accelerator.is_main_process:
+            print("🔒 Deterministic mode enabled (slower but reproducible)")
+
     # Configure logging (rank 0 only prints to console)
     logging.basicConfig(
         level=logging.INFO if accelerator.is_main_process else logging.ERROR,
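One caveat not covered by this hunk: with `warn_only=True`, ops that lack deterministic implementations only emit a warning rather than raising, and deterministic cuBLAS matmuls on CUDA additionally expect a workspace configuration to be set in the environment before the first CUDA call. A sketch of the usual companion setting:

```python
import os

# Must be set before any CUDA work; required by torch.use_deterministic_algorithms
# for deterministic cuBLAS matmuls (":16:8" is the lower-memory alternative)
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
```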
{wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/constraints.py

@@ -207,22 +207,28 @@ class ExpressionConstraint(nn.Module):
            # Parse indices from the slice
            indices = self._parse_subscript_indices(node.slice)
 
+            # Auto-squeeze channel dimension for single-channel inputs
+            # This allows x[i,j] syntax for (B, 1, H, W) inputs instead of x[c,i,j]
+            inputs_for_indexing = inputs
+            if inputs.ndim >= 3 and inputs.shape[1] == 1:
+                inputs_for_indexing = inputs.squeeze(1)  # (B, 1, H, W) → (B, H, W)
+
            # Validate dimensions match
            # inputs shape: (batch, dim1) or (batch, dim1, dim2) or (batch, dim1, dim2, dim3)
-            input_ndim = inputs.ndim - 1  # Exclude batch dimension
+            input_ndim = inputs_for_indexing.ndim - 1  # Exclude batch dimension
            if len(indices) != input_ndim:
                raise ValueError(
-                    f"Input has {input_ndim}D shape, but got {len(indices)} indices. "
+                    f"Input has {input_ndim}D shape (after channel squeeze), but got {len(indices)} indices. "
                    f"Use x[i] for 1D, x[i,j] for 2D, x[i,j,k] for 3D inputs."
                )
 
            # Extract the value at the specified indices (for entire batch)
            if len(indices) == 1:
-                return inputs[:, indices[0]]
+                return inputs_for_indexing[:, indices[0]]
            elif len(indices) == 2:
-                return inputs[:, indices[0], indices[1]]
+                return inputs_for_indexing[:, indices[0], indices[1]]
            elif len(indices) == 3:
-                return inputs[:, indices[0], indices[1], indices[2]]
+                return inputs_for_indexing[:, indices[0], indices[1], indices[2]]
            else:
                raise ValueError("Only 1D, 2D, or 3D input indexing supported.")
        elif isinstance(node, ast.Expression):
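The auto-squeeze rule can be checked in isolation; a minimal sketch with hypothetical shapes (plain tensor indexing, not the `ExpressionConstraint` API itself):

```python
import torch

inputs = torch.randn(4, 1, 500, 500)          # (B, 1, H, W), e.g. dispersion maps
if inputs.ndim >= 3 and inputs.shape[1] == 1:
    inputs = inputs.squeeze(1)                # -> (B, H, W)

# With the channel dim squeezed away, x[i,j] indexes spatial positions directly
i, j = 10, 20
value_per_sample = inputs[:, i, j]            # one scalar per batch item
print(value_per_sample.shape)                 # torch.Size([4])
```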