wavedl 1.5.6__tar.gz → 1.5.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavedl-1.5.6/src/wavedl.egg-info → wavedl-1.5.7}/PKG-INFO +16 -7
- {wavedl-1.5.6 → wavedl-1.5.7}/README.md +15 -6
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/__init__.py +1 -1
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/efficientnet.py +24 -7
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/efficientnetv2.py +29 -6
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/mobilenetv3.py +31 -8
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/regnet.py +29 -6
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/swin.py +38 -6
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/tcn.py +22 -2
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/test.py +7 -3
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/train.py +41 -12
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/constraints.py +11 -5
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/data.py +82 -13
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/metrics.py +287 -326
- {wavedl-1.5.6 → wavedl-1.5.7/src/wavedl.egg-info}/PKG-INFO +16 -7
- {wavedl-1.5.6 → wavedl-1.5.7}/LICENSE +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/pyproject.toml +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/setup.cfg +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/hpc.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/hpo.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/__init__.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/_template.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/base.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/cnn.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/convnext.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/densenet.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/resnet.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/resnet3d.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/unet.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/models/vit.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/__init__.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/config.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/cross_validation.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/losses.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/optimizers.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl/utils/schedulers.py +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl.egg-info/SOURCES.txt +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl.egg-info/requires.txt +0 -0
- {wavedl-1.5.6 → wavedl-1.5.7}/src/wavedl.egg-info/top_level.txt +0 -0
**PKG-INFO**

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.6
+Version: 1.5.7
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -388,7 +388,7 @@ WaveDL/
 ├── configs/                  # YAML config templates
 ├── examples/                 # Ready-to-run examples
 ├── notebooks/                # Jupyter notebooks
-├── unit_tests/               # Pytest test suite (
+├── unit_tests/               # Pytest test suite (903 tests)
 │
 ├── pyproject.toml            # Package config, dependencies
 ├── CHANGELOG.md              # Version history
@@ -1035,12 +1035,20 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
 
 | Parameter | Unit | Description |
 |-----------|------|-------------|
-
-
-
+| $h$ | mm | Plate thickness |
+| $\sqrt{E/\rho}$ | km/s | Square root of Young's modulus over density |
+| $\nu$ | — | Poisson's ratio |
 
 > [!NOTE]
-> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"
+> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"A lightweight deep learning model for ultrasonic assessment of plate thickness and elasticity
+"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/A-lightweight-deep-learning-model-for-ultrasonic-assessment-of-plate/13951-4) (Paper 13951-4, to appear).
+
+**Sample Dispersion Data:**
+
+<p align="center">
+  <img src="examples/elasticity_prediction/dispersion_samples.png" alt="Dispersion curve samples" width="700"><br>
+  <em>Test samples showing the wavenumber-frequency relationship for different plate properties</em>
+</p>
 
 **Try it yourself:**
 
@@ -1061,7 +1069,8 @@ python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpo
 | File | Description |
 |------|-------------|
 | `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
-| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves →
+| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → $h$, $\sqrt{E/\rho}$, $\nu$) |
+| `dispersion_samples.png` | Visualization of sample dispersion curves with material parameters |
 | `model.onnx` | ONNX export with embedded de-normalization |
 | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
 | `training_curves.png` | Training/validation loss and learning rate plot |
```
**README.md**

```diff
@@ -342,7 +342,7 @@ WaveDL/
 ├── configs/                  # YAML config templates
 ├── examples/                 # Ready-to-run examples
 ├── notebooks/                # Jupyter notebooks
-├── unit_tests/               # Pytest test suite (
+├── unit_tests/               # Pytest test suite (903 tests)
 │
 ├── pyproject.toml            # Package config, dependencies
 ├── CHANGELOG.md              # Version history
@@ -989,12 +989,20 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
 
 | Parameter | Unit | Description |
 |-----------|------|-------------|
-
-
-
+| $h$ | mm | Plate thickness |
+| $\sqrt{E/\rho}$ | km/s | Square root of Young's modulus over density |
+| $\nu$ | — | Poisson's ratio |
 
 > [!NOTE]
-> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"
+> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"A lightweight deep learning model for ultrasonic assessment of plate thickness and elasticity
+"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/A-lightweight-deep-learning-model-for-ultrasonic-assessment-of-plate/13951-4) (Paper 13951-4, to appear).
+
+**Sample Dispersion Data:**
+
+<p align="center">
+  <img src="examples/elasticity_prediction/dispersion_samples.png" alt="Dispersion curve samples" width="700"><br>
+  <em>Test samples showing the wavenumber-frequency relationship for different plate properties</em>
+</p>
 
 **Try it yourself:**
 
@@ -1015,7 +1023,8 @@ python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpo
 | File | Description |
 |------|-------------|
 | `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
-| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves →
+| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → $h$, $\sqrt{E/\rho}$, $\nu$) |
+| `dispersion_samples.png` | Visualization of sample dispersion curves with material parameters |
 | `model.onnx` | ONNX export with embedded de-normalization |
 | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
 | `training_curves.png` | Training/validation loss and learning rate plot |
```
**src/wavedl/models/efficientnet.py**

```diff
@@ -110,9 +110,30 @@ class EfficientNetBase(BaseModel):
             self._freeze_backbone()
 
     def _adapt_input_channels(self):
-        """Modify first conv to
-
-
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        # EfficientNet stem conv is at: features[0][0]
+        old_conv = self.backbone.features[0][0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            # Initialize with mean of pretrained RGB weights
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.features[0][0] = new_conv
 
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
@@ -130,10 +151,6 @@ class EfficientNetBase(BaseModel):
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
```
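The change above is the heart of this release: instead of expanding a single-channel input to three channels in `forward`, the stem convolution itself is rebuilt for one channel, with weights initialized as the mean of the pretrained RGB filters. The same pattern repeats in `efficientnetv2.py`, `mobilenetv3.py`, `regnet.py`, and `swin.py` below. A minimal standalone sketch of the technique, assuming torchvision is available (the `efficientnet_b0` backbone is chosen here only for illustration):

```python
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0

# Stock torchvision backbone; its stem conv expects 3-channel RGB input.
model = efficientnet_b0(weights=None)  # pass weights="IMAGENET1K_V1" for pretrained
old_conv = model.features[0][0]        # stem conv: Conv2d(3, 32, ...)

# Rebuild the stem conv for 1-channel input, copying the other hyperparameters.
new_conv = nn.Conv2d(
    1,
    old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    bias=old_conv.bias is not None,
)
# Averaging the RGB filters keeps pretrained features meaningful for grayscale:
# for an input replicated across R, G, B, the old conv computes (W_r + W_g + W_b) * x,
# and the averaged filter computes exactly 1/3 of that.
with torch.no_grad():
    new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
model.features[0][0] = new_conv

# The network now consumes (B, 1, H, W) directly; no 1→3 channel expansion needed.
out = model(torch.randn(2, 1, 224, 224))
print(out.shape)  # torch.Size([2, 1000])
```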
**src/wavedl/models/efficientnetv2.py**

```diff
@@ -129,10 +129,37 @@ class EfficientNetV2Base(BaseModel):
             nn.Linear(regression_hidden // 2, out_size),
         )
 
-        #
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.features[0][0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
         for name, param in self.backbone.named_parameters():
@@ -144,15 +171,11 @@ class EfficientNetV2Base(BaseModel):
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
```
**src/wavedl/models/mobilenetv3.py**

```diff
@@ -136,10 +136,37 @@ class MobileNetV3Base(BaseModel):
             nn.Linear(regression_hidden, out_size),
         )
 
-        #
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.features[0][0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
         for name, param in self.backbone.named_parameters():
@@ -151,15 +178,11 @@ class MobileNetV3Base(BaseModel):
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
@@ -194,7 +217,7 @@ class MobileNetV3Small(MobileNetV3Base):
 
     Performance (approximate):
         - CPU inference: ~6ms (single core)
-        - Parameters:
+        - Parameters: ~1.1M
         - MAdds: 56M
 
     Args:
@@ -241,7 +264,7 @@ class MobileNetV3Large(MobileNetV3Base):
 
     Performance (approximate):
        - CPU inference: ~20ms (single core)
-        - Parameters:
+        - Parameters: ~3.2M
        - MAdds: 219M
 
     Args:
```
**src/wavedl/models/regnet.py**

```diff
@@ -140,10 +140,37 @@ class RegNetBase(BaseModel):
             nn.Linear(regression_hidden, out_size),
         )
 
-        #
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.stem[0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.stem[0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the fc layer."""
         for name, param in self.backbone.named_parameters():
@@ -155,15 +182,11 @@ class RegNetBase(BaseModel):
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
```
**src/wavedl/models/swin.py**

```diff
@@ -141,10 +141,46 @@ class SwinTransformerBase(BaseModel):
             nn.Linear(regression_hidden // 2, out_size),
         )
 
-        #
+        # Adapt patch embedding conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify patch embedding conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the patch embedding conv with a 1-channel version and
+        initialize weights as the mean of the pretrained RGB filters.
+        """
+        # Swin's patch embedding is at features[0][0]
+        try:
+            old_conv = self.backbone.features[0][0]
+        except (IndexError, AttributeError, TypeError) as e:
+            raise RuntimeError(
+                f"Swin patch embed structure changed in this torchvision version. "
+                f"Cannot adapt input channels. Error: {e}"
+            ) from e
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                if old_conv.bias is not None:
+                    new_conv.bias.copy_(old_conv.bias)
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the head."""
         for name, param in self.backbone.named_parameters():
@@ -156,15 +192,11 @@ class SwinTransformerBase(BaseModel):
         Forward pass.
 
        Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
```
**src/wavedl/models/tcn.py**

```diff
@@ -45,6 +45,26 @@ from wavedl.models.base import BaseModel
 from wavedl.models.registry import register_model
 
 
+def _find_group_count(channels: int, max_groups: int = 8) -> int:
+    """
+    Find largest valid group count for GroupNorm.
+
+    GroupNorm requires channels to be divisible by num_groups.
+    This finds the largest divisor up to max_groups.
+
+    Args:
+        channels: Number of channels
+        max_groups: Maximum group count to consider (default: 8)
+
+    Returns:
+        Largest valid group count (always >= 1)
+    """
+    for g in range(min(max_groups, channels), 0, -1):
+        if channels % g == 0:
+            return g
+    return 1
+
+
 class CausalConv1d(nn.Module):
     """
     Causal 1D convolution with dilation.
@@ -101,13 +121,13 @@ class TemporalBlock(nn.Module):
 
         # First causal convolution
         self.conv1 = CausalConv1d(in_channels, out_channels, kernel_size, dilation)
-        self.norm1 = nn.GroupNorm(
+        self.norm1 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
         self.act1 = nn.GELU()
         self.dropout1 = nn.Dropout(dropout)
 
         # Second causal convolution
         self.conv2 = CausalConv1d(out_channels, out_channels, kernel_size, dilation)
-        self.norm2 = nn.GroupNorm(
+        self.norm2 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
         self.act2 = nn.GELU()
         self.dropout2 = nn.Dropout(dropout)
 
```
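`nn.GroupNorm(num_groups, num_channels)` raises a `ValueError` when `num_channels` is not divisible by `num_groups`, so a hard-coded group count breaks for arbitrary TCN channel widths; the new `_find_group_count` helper picks the largest divisor up to 8. A quick sketch of the behavior (the helper is re-stated here for illustration):

```python
import torch.nn as nn

def find_group_count(channels: int, max_groups: int = 8) -> int:
    """Largest divisor of `channels` not exceeding max_groups (always >= 1)."""
    for g in range(min(max_groups, channels), 0, -1):
        if channels % g == 0:
            return g
    return 1

# 64 -> 8, 48 -> 8, 50 -> 5 (8, 7, 6 do not divide 50), 7 -> 7, 1 -> 1
for c in (64, 48, 50, 7, 1):
    print(c, find_group_count(c))

norm = nn.GroupNorm(find_group_count(50), 50)  # valid; nn.GroupNorm(8, 50) would raise
```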
**src/wavedl/test.py**

```diff
@@ -311,7 +311,7 @@ def load_data_for_inference(
 # ==============================================================================
 def load_checkpoint(
     checkpoint_dir: str,
-    in_shape: tuple[int,
+    in_shape: tuple[int, ...],
     out_size: int,
     model_name: str | None = None,
 ) -> tuple[nn.Module, any]:
@@ -320,7 +320,7 @@ def load_checkpoint(
 
     Args:
         checkpoint_dir: Path to checkpoint directory
-        in_shape: Input
+        in_shape: Input spatial shape - (L,) for 1D, (H, W) for 2D, or (D, H, W) for 3D
         out_size: Number of output parameters
         model_name: Model architecture name (auto-detect if None)
 
@@ -376,7 +376,11 @@ def load_checkpoint(
     )
 
     logging.info(f"  Building model: {model_name}")
-
+    # Use pretrained=False: checkpoint weights will overwrite any pretrained weights,
+    # so downloading ImageNet weights is wasteful and breaks offline/HPC inference.
+    model = build_model(
+        model_name, in_shape=in_shape, out_size=out_size, pretrained=False
+    )
 
     # Load weights (check multiple formats in order of preference)
     weight_path = None
```
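The inference-side change builds the model with `pretrained=False` because the checkpoint restore overwrites every parameter anyway; skipping the ImageNet download also keeps inference working on offline or air-gapped HPC nodes. A hedged sketch of that load order (the import path, model name, and checkpoint filename are illustrative assumptions, not wavedl's exact API):

```python
import torch

from wavedl.models import build_model  # assumed import path for the registry helper

# Build the bare architecture; no ImageNet weights are fetched, since the
# checkpoint below replaces every parameter regardless.
model = build_model("mobilenet_v3_small", in_shape=(500, 500), out_size=3, pretrained=False)

# Restore the trained weights (filename illustrative; wavedl probes several formats).
state = torch.load("best_checkpoint/model.pt", map_location="cpu")
model.load_state_dict(state)
model.eval()
```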
**src/wavedl/train.py**

```diff
@@ -240,15 +240,9 @@ def parse_args() -> argparse.Namespace:
     )
     parser.add_argument(
         "--pretrained",
-        action=
+        action=argparse.BooleanOptionalAction,
         default=True,
-        help="Use pretrained weights (default: True)",
-    )
-    parser.add_argument(
-        "--no_pretrained",
-        dest="pretrained",
-        action="store_false",
-        help="Train from scratch without pretrained weights",
+        help="Use pretrained weights (default: True). Use --no-pretrained to train from scratch.",
     )
 
     # Configuration File
```
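`argparse.BooleanOptionalAction` (Python 3.9+) derives the negated flag automatically, so one declaration replaces the old `--pretrained`/`--no_pretrained` pair. Note the auto-generated negative flag is spelled `--no-pretrained` with a hyphen, unlike the removed underscore variant. A minimal sketch:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--pretrained",
    action=argparse.BooleanOptionalAction,  # also registers --no-pretrained
    default=True,
    help="Use pretrained weights (default: True).",
)

print(parser.parse_args([]).pretrained)                   # True (default)
print(parser.parse_args(["--pretrained"]).pretrained)     # True
print(parser.parse_args(["--no-pretrained"]).pretrained)  # False
```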
```diff
@@ -367,6 +361,18 @@ def parse_args() -> argparse.Namespace:
         help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
     )
     parser.add_argument("--seed", type=int, default=2025, help="Random seed")
+    parser.add_argument(
+        "--deterministic",
+        action="store_true",
+        help="Enable deterministic mode for reproducibility (slower, disables TF32/cuDNN benchmark)",
+    )
+    parser.add_argument(
+        "--cache_validate",
+        type=str,
+        default="sha256",
+        choices=["sha256", "fast", "size"],
+        help="Cache validation mode: sha256 (full hash), fast (partial), size (quick)",
+    )
     parser.add_argument(
         "--single_channel",
         action="store_true",
@@ -512,11 +518,23 @@ def main():
             # Import as regular module
             importlib.import_module(module_name)
             print(f"✓ Imported module: {module_name}")
-        except ImportError as e:
+        except (ImportError, FileNotFoundError, SyntaxError, PermissionError) as e:
             print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
-
-            "
-            )
+            if isinstance(e, FileNotFoundError):
+                print("  File does not exist. Check the path.", file=sys.stderr)
+            elif isinstance(e, SyntaxError):
+                print(
+                    f"  Syntax error at line {e.lineno}: {e.msg}", file=sys.stderr
+                )
+            elif isinstance(e, PermissionError):
+                print(
+                    "  Permission denied. Check file permissions.", file=sys.stderr
+                )
+            else:
+                print(
+                    "  Make sure the module is in your Python path or current directory.",
+                    file=sys.stderr,
+                )
             sys.exit(1)
 
         # Handle --list_models flag
@@ -648,6 +666,17 @@ def main():
     )
     set_seed(args.seed)
 
+    # Deterministic mode for scientific reproducibility
+    # Disables TF32 and cuDNN benchmark for exact reproducibility (slower)
+    if args.deterministic:
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cudnn.allow_tf32 = False
+        torch.use_deterministic_algorithms(True, warn_only=True)
+        if accelerator.is_main_process:
+            print("🔒 Deterministic mode enabled (slower but reproducible)")
+
     # Configure logging (rank 0 only prints to console)
     logging.basicConfig(
         level=logging.INFO if accelerator.is_main_process else logging.ERROR,
```
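The new `--deterministic` flag trades speed for bit-exact reruns: cuDNN autotuning and TF32 math are switched off, and PyTorch is asked to prefer deterministic kernels. A standalone sketch of the same settings; note that with `warn_only=True` ops lacking a deterministic implementation only warn, and fully deterministic cuBLAS on CUDA additionally requires `CUBLAS_WORKSPACE_CONFIG=:4096:8` in the environment:

```python
import torch

def enable_determinism(seed: int = 2025) -> None:
    """Best-effort bit-exact reproducibility: fixed seed + deterministic kernels."""
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False         # no autotuned algorithm selection
    torch.backends.cudnn.deterministic = True      # force deterministic cuDNN kernels
    torch.backends.cuda.matmul.allow_tf32 = False  # exact fp32 matmuls
    torch.backends.cudnn.allow_tf32 = False
    # Warn (rather than raise) when an op has no deterministic implementation
    torch.use_deterministic_algorithms(True, warn_only=True)

enable_determinism()
```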
**src/wavedl/utils/constraints.py**

```diff
@@ -207,22 +207,28 @@ class ExpressionConstraint(nn.Module):
             # Parse indices from the slice
             indices = self._parse_subscript_indices(node.slice)
 
+            # Auto-squeeze channel dimension for single-channel inputs
+            # This allows x[i,j] syntax for (B, 1, H, W) inputs instead of x[c,i,j]
+            inputs_for_indexing = inputs
+            if inputs.ndim >= 3 and inputs.shape[1] == 1:
+                inputs_for_indexing = inputs.squeeze(1)  # (B, 1, H, W) → (B, H, W)
+
             # Validate dimensions match
             # inputs shape: (batch, dim1) or (batch, dim1, dim2) or (batch, dim1, dim2, dim3)
-            input_ndim =
+            input_ndim = inputs_for_indexing.ndim - 1  # Exclude batch dimension
             if len(indices) != input_ndim:
                 raise ValueError(
-                    f"Input has {input_ndim}D shape, but got {len(indices)} indices. "
+                    f"Input has {input_ndim}D shape (after channel squeeze), but got {len(indices)} indices. "
                     f"Use x[i] for 1D, x[i,j] for 2D, x[i,j,k] for 3D inputs."
                 )
 
             # Extract the value at the specified indices (for entire batch)
             if len(indices) == 1:
-                return
+                return inputs_for_indexing[:, indices[0]]
             elif len(indices) == 2:
-                return
+                return inputs_for_indexing[:, indices[0], indices[1]]
             elif len(indices) == 3:
-                return
+                return inputs_for_indexing[:, indices[0], indices[1], indices[2]]
             else:
                 raise ValueError("Only 1D, 2D, or 3D input indexing supported.")
         elif isinstance(node, ast.Expression):
```
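The constraint change mirrors the single-channel models above: a `(B, 1, H, W)` batch is squeezed to `(B, H, W)` so constraint expressions can write `x[i,j]` without an explicit channel index. A small sketch of the indexing behavior in plain PyTorch, independent of wavedl:

```python
import torch

inputs = torch.randn(4, 1, 500, 500)  # (B, 1, H, W) single-channel batch

# Auto-squeeze the channel dimension so 2D indices address (H, W) directly.
x = inputs.squeeze(1) if inputs.ndim >= 3 and inputs.shape[1] == 1 else inputs

i, j = 10, 20
value = x[:, i, j]  # shape (4,): one scalar per batch element, as the constraint sees it
print(value.shape)  # torch.Size([4])

# Without the squeeze, the same lookup would have required x[c, i, j]
# with an explicit (and always-zero) channel index c.
```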