wavedl 1.5.5__py3-none-any.whl → 1.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/efficientnet.py +24 -7
- wavedl/models/efficientnetv2.py +29 -6
- wavedl/models/mobilenetv3.py +31 -8
- wavedl/models/regnet.py +29 -6
- wavedl/models/swin.py +38 -6
- wavedl/models/tcn.py +22 -2
- wavedl/models/vit.py +85 -25
- wavedl/test.py +7 -3
- wavedl/train.py +79 -18
- wavedl/utils/constraints.py +11 -5
- wavedl/utils/data.py +130 -39
- wavedl/utils/metrics.py +287 -326
- {wavedl-1.5.5.dist-info → wavedl-1.5.7.dist-info}/METADATA +37 -27
- {wavedl-1.5.5.dist-info → wavedl-1.5.7.dist-info}/RECORD +19 -19
- {wavedl-1.5.5.dist-info → wavedl-1.5.7.dist-info}/LICENSE +0 -0
- {wavedl-1.5.5.dist-info → wavedl-1.5.7.dist-info}/WHEEL +0 -0
- {wavedl-1.5.5.dist-info → wavedl-1.5.7.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.5.dist-info → wavedl-1.5.7.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/models/efficientnet.py
CHANGED
```diff
@@ -110,9 +110,30 @@ class EfficientNetBase(BaseModel):
             self._freeze_backbone()
 
     def _adapt_input_channels(self):
-        """Modify first conv to
-
-
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        # EfficientNet stem conv is at: features[0][0]
+        old_conv = self.backbone.features[0][0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            # Initialize with mean of pretrained RGB weights
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.features[0][0] = new_conv
 
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
@@ -130,10 +151,6 @@ class EfficientNetBase(BaseModel):
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
```
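For orientation, the adaptation in this hunk can be reproduced outside the package on a stock torchvision EfficientNet-B0. The helper name `adapt_first_conv` below is illustrative rather than part of wavedl's API, but it mirrors the mean-of-RGB-filters initialization shown above:

```python
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights


def adapt_first_conv(model: nn.Module) -> None:
    """Replace the 3-channel stem conv with a 1-channel copy whose weights
    are the channel-wise mean of the pretrained RGB filters."""
    old = model.features[0][0]  # stem conv of torchvision EfficientNet
    new = nn.Conv2d(
        1,
        old.out_channels,
        kernel_size=old.kernel_size,
        stride=old.stride,
        padding=old.padding,
        bias=old.bias is not None,
    )
    with torch.no_grad():
        new.weight.copy_(old.weight.mean(dim=1, keepdim=True))
    model.features[0][0] = new


model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
adapt_first_conv(model)
out = model(torch.randn(2, 1, 224, 224))  # single-channel input now works directly
print(out.shape)  # torch.Size([2, 1000])
```

In torchvision's EfficientNet the stem conv feeds a BatchNorm layer and carries no bias, so only the weight tensor needs transferring.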
wavedl/models/efficientnetv2.py
CHANGED
```diff
@@ -129,10 +129,37 @@ class EfficientNetV2Base(BaseModel):
             nn.Linear(regression_hidden // 2, out_size),
         )
 
-        #
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.features[0][0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
         for name, param in self.backbone.named_parameters():
@@ -144,15 +171,11 @@ class EfficientNetV2Base(BaseModel):
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
            Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
```
wavedl/models/mobilenetv3.py
CHANGED
```diff
@@ -136,10 +136,37 @@ class MobileNetV3Base(BaseModel):
             nn.Linear(regression_hidden, out_size),
         )
 
-        #
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.features[0][0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
         for name, param in self.backbone.named_parameters():
@@ -151,15 +178,11 @@ class MobileNetV3Base(BaseModel):
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
@@ -194,7 +217,7 @@ class MobileNetV3Small(MobileNetV3Base):
 
     Performance (approximate):
     - CPU inference: ~6ms (single core)
-    - Parameters:
+    - Parameters: ~1.1M
    - MAdds: 56M
 
     Args:
@@ -241,7 +264,7 @@ class MobileNetV3Large(MobileNetV3Base):
 
     Performance (approximate):
     - CPU inference: ~20ms (single core)
-    - Parameters:
+    - Parameters: ~3.2M
     - MAdds: 219M
 
     Args:
```
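Each backbone now performs the channel adaptation before the optional freeze, as the repeated comment notes. The ordering matters because a module swapped in after the freeze pass keeps its default `requires_grad=True`; a minimal sketch with a hypothetical two-layer backbone:

```python
import torch.nn as nn

backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))
for p in backbone.parameters():
    p.requires_grad = False            # freeze first (wrong order)
backbone[0] = nn.Conv2d(1, 8, 3)       # then swap in the 1-channel conv
print(any(p.requires_grad for p in backbone.parameters()))  # True: the new conv would still train
```

Running the adaptation first means the replacement conv already exists when `_freeze_backbone()` sets `requires_grad = False`, so a frozen backbone stays fully frozen.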
wavedl/models/regnet.py
CHANGED
```diff
@@ -140,10 +140,37 @@ class RegNetBase(BaseModel):
             nn.Linear(regression_hidden, out_size),
         )
 
-        #
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.stem[0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.stem[0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the fc layer."""
         for name, param in self.backbone.named_parameters():
@@ -155,15 +182,11 @@ class RegNetBase(BaseModel):
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
```
wavedl/models/swin.py
CHANGED
```diff
@@ -141,10 +141,46 @@ class SwinTransformerBase(BaseModel):
             nn.Linear(regression_hidden // 2, out_size),
         )
 
-        #
+        # Adapt patch embedding conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify patch embedding conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the patch embedding conv with a 1-channel version and
+        initialize weights as the mean of the pretrained RGB filters.
+        """
+        # Swin's patch embedding is at features[0][0]
+        try:
+            old_conv = self.backbone.features[0][0]
+        except (IndexError, AttributeError, TypeError) as e:
+            raise RuntimeError(
+                f"Swin patch embed structure changed in this torchvision version. "
+                f"Cannot adapt input channels. Error: {e}"
+            ) from e
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                if old_conv.bias is not None:
+                    new_conv.bias.copy_(old_conv.bias)
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the head."""
         for name, param in self.backbone.named_parameters():
@@ -156,15 +192,11 @@ class SwinTransformerBase(BaseModel):
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
```
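The Swin variant differs from the CNN backbones in two ways: the patch-embedding lookup is wrapped in a try/except because torchvision's internal module layout is not a stable contract, and the bias is carried over because the patch-embedding Conv2d in torchvision's Swin models has one (unlike the stem convs above, which feed BatchNorm). A quick structural check, assuming the current `swin_t` layout:

```python
import torch.nn as nn
from torchvision.models import swin_t

model = swin_t(weights=None)        # random init is enough to inspect the structure
patch_conv = model.features[0][0]   # patch embedding conv, e.g. Conv2d(3, 96, kernel_size=4, stride=4)
assert isinstance(patch_conv, nn.Conv2d)
print(patch_conv.bias is not None)  # True, so the bias must be copied to the 1-channel replacement
```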
wavedl/models/tcn.py
CHANGED
```diff
@@ -45,6 +45,26 @@ from wavedl.models.base import BaseModel
 from wavedl.models.registry import register_model
 
 
+def _find_group_count(channels: int, max_groups: int = 8) -> int:
+    """
+    Find largest valid group count for GroupNorm.
+
+    GroupNorm requires channels to be divisible by num_groups.
+    This finds the largest divisor up to max_groups.
+
+    Args:
+        channels: Number of channels
+        max_groups: Maximum group count to consider (default: 8)
+
+    Returns:
+        Largest valid group count (always >= 1)
+    """
+    for g in range(min(max_groups, channels), 0, -1):
+        if channels % g == 0:
+            return g
+    return 1
+
+
 class CausalConv1d(nn.Module):
     """
     Causal 1D convolution with dilation.
@@ -101,13 +121,13 @@ class TemporalBlock(nn.Module):
 
         # First causal convolution
         self.conv1 = CausalConv1d(in_channels, out_channels, kernel_size, dilation)
-        self.norm1 = nn.GroupNorm(
+        self.norm1 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
         self.act1 = nn.GELU()
         self.dropout1 = nn.Dropout(dropout)
 
         # Second causal convolution
         self.conv2 = CausalConv1d(out_channels, out_channels, kernel_size, dilation)
-        self.norm2 = nn.GroupNorm(
+        self.norm2 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
         self.act2 = nn.GELU()
         self.dropout2 = nn.Dropout(dropout)
 
```
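The helper simply picks the largest divisor of `channels` that does not exceed `max_groups`, which keeps `nn.GroupNorm` constructible for any channel count (a fixed group count raises `ValueError` whenever it does not divide the channels). Re-implemented verbatim here for illustration:

```python
def _find_group_count(channels: int, max_groups: int = 8) -> int:
    """Largest divisor of `channels` that is <= max_groups (always >= 1)."""
    for g in range(min(max_groups, channels), 0, -1):
        if channels % g == 0:
            return g
    return 1


print(_find_group_count(64))  # 8  -> nn.GroupNorm(8, 64)
print(_find_group_count(30))  # 6  (neither 8 nor 7 divides 30)
print(_find_group_count(7))   # 7  (search is capped at min(8, 7))
print(_find_group_count(1))   # 1  (a single group is always valid)
```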
wavedl/models/vit.py
CHANGED
```diff
@@ -42,47 +42,89 @@ class PatchEmbed(nn.Module):
     Supports 1D and 2D inputs:
     - 1D: Input (B, 1, L) → (B, num_patches, embed_dim)
     - 2D: Input (B, 1, H, W) → (B, num_patches, embed_dim)
+
+    Args:
+        in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
+        patch_size: Size of each patch
+        embed_dim: Embedding dimension
+        pad_if_needed: If True, pad input to nearest patch-aligned size instead of
+            dropping edge pixels. Important for NDE/QUS applications where edge
+            effects matter. Default: False (original behavior with warning).
     """
 
-    def __init__(
+    def __init__(
+        self,
+        in_shape: SpatialShape,
+        patch_size: int,
+        embed_dim: int,
+        pad_if_needed: bool = False,
+    ):
         super().__init__()
 
         self.dim = len(in_shape)
         self.patch_size = patch_size
         self.embed_dim = embed_dim
+        self.pad_if_needed = pad_if_needed
+        self._padding = None  # Will be set if padding is needed
 
         if self.dim == 1:
             # 1D: segment patches
             L = in_shape[0]
-
-
-
-
-
-
-
-
-
-
-
+            remainder = L % patch_size
+            if remainder != 0:
+                if pad_if_needed:
+                    # Pad to next multiple of patch_size
+                    pad_amount = patch_size - remainder
+                    self._padding = (0, pad_amount)  # (left, right)
+                    L_padded = L + pad_amount
+                    self.num_patches = L_padded // patch_size
+                else:
+                    import warnings
+
+                    warnings.warn(
+                        f"Input length {L} not divisible by patch_size {patch_size}. "
+                        f"Last {remainder} elements will be dropped. "
+                        f"Consider using pad_if_needed=True or padding input to "
+                        f"{((L // patch_size) + 1) * patch_size}.",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self.num_patches = L // patch_size
+            else:
+                self.num_patches = L // patch_size
             self.proj = nn.Conv1d(
                 1, embed_dim, kernel_size=patch_size, stride=patch_size
             )
         elif self.dim == 2:
             # 2D: grid patches
             H, W = in_shape
-
-
-
-
-
-
-
-
-
-
-
-
+            h_rem, w_rem = H % patch_size, W % patch_size
+            if h_rem != 0 or w_rem != 0:
+                if pad_if_needed:
+                    # Pad to next multiple of patch_size
+                    h_pad = (patch_size - h_rem) % patch_size
+                    w_pad = (patch_size - w_rem) % patch_size
+                    # Padding format: (left, right, top, bottom)
+                    self._padding = (0, w_pad, 0, h_pad)
+                    H_padded, W_padded = H + h_pad, W + w_pad
+                    self.num_patches = (H_padded // patch_size) * (
+                        W_padded // patch_size
+                    )
+                else:
+                    import warnings
+
+                    warnings.warn(
+                        f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
+                        f"Border pixels will be dropped (H: {h_rem}, W: {w_rem}). "
+                        f"Consider using pad_if_needed=True or padding to "
+                        f"({((H // patch_size) + 1) * patch_size}, "
+                        f"{((W // patch_size) + 1) * patch_size}).",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self.num_patches = (H // patch_size) * (W // patch_size)
+            else:
+                self.num_patches = (H // patch_size) * (W // patch_size)
             self.proj = nn.Conv2d(
                 1, embed_dim, kernel_size=patch_size, stride=patch_size
             )
@@ -97,6 +139,10 @@ class PatchEmbed(nn.Module):
         Returns:
             Patch embeddings (B, num_patches, embed_dim)
         """
+        # Apply padding if configured
+        if self._padding is not None:
+            x = nn.functional.pad(x, self._padding, mode="constant", value=0)
+
         x = self.proj(x)  # (B, embed_dim, ..reduced_spatial..)
         x = x.flatten(2)  # (B, embed_dim, num_patches)
         x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
@@ -185,6 +231,18 @@ class ViTBase(BaseModel):
     3. Transformer encoder blocks
     4. Extract CLS token
     5. Regression head
+
+    Args:
+        in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
+        out_size: Number of regression targets
+        patch_size: Size of each patch (default: 16)
+        embed_dim: Embedding dimension (default: 768)
+        depth: Number of transformer blocks (default: 12)
+        num_heads: Number of attention heads (default: 12)
+        mlp_ratio: MLP hidden dim multiplier (default: 4.0)
+        dropout_rate: Dropout rate (default: 0.1)
+        pad_if_needed: If True, pad input to nearest patch-aligned size instead
+            of dropping edge pixels. Important for NDE/QUS applications.
     """
 
     def __init__(
@@ -197,6 +255,7 @@ class ViTBase(BaseModel):
         num_heads: int = 12,
         mlp_ratio: float = 4.0,
         dropout_rate: float = 0.1,
+        pad_if_needed: bool = False,
         **kwargs,
     ):
         super().__init__(in_shape, out_size)
@@ -207,9 +266,10 @@ class ViTBase(BaseModel):
         self.num_heads = num_heads
         self.dropout_rate = dropout_rate
         self.dim = len(in_shape)
+        self.pad_if_needed = pad_if_needed
 
         # Patch embedding
-        self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim)
+        self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim, pad_if_needed)
         num_patches = self.patch_embed.num_patches
 
         # Learnable CLS token and position embeddings
```
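The patch-count arithmetic behind `pad_if_needed` is easiest to see with concrete numbers (chosen purely for illustration, with `patch_size=16`):

```python
L, patch_size = 100, 16

# pad_if_needed=False: the trailing remainder is dropped
print(L // patch_size)                       # 6 patches, covering 96 of 100 samples

# pad_if_needed=True: zero-pad up to the next multiple of 16
pad_amount = patch_size - (L % patch_size)   # 12
print((L + pad_amount) // patch_size)        # 7 patches, covering all 100 samples

# 2D case, e.g. a 250x250 map: pad 6 rows and 6 columns of zeros
H = W = 250
h_pad = (patch_size - H % patch_size) % patch_size
print(((H + h_pad) // patch_size) ** 2)      # 256 patches vs. 15*15 = 225 without padding
```

Padding before the projection adds at most one extra row and column of patches while keeping edge samples inside the receptive field, which is the stated motivation for NDE/QUS data.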
wavedl/test.py
CHANGED
```diff
@@ -311,7 +311,7 @@ def load_data_for_inference(
 # ==============================================================================
 def load_checkpoint(
     checkpoint_dir: str,
-    in_shape: tuple[int,
+    in_shape: tuple[int, ...],
     out_size: int,
     model_name: str | None = None,
 ) -> tuple[nn.Module, any]:
@@ -320,7 +320,7 @@ def load_checkpoint(
 
     Args:
         checkpoint_dir: Path to checkpoint directory
-        in_shape: Input
+        in_shape: Input spatial shape - (L,) for 1D, (H, W) for 2D, or (D, H, W) for 3D
         out_size: Number of output parameters
         model_name: Model architecture name (auto-detect if None)
 
@@ -376,7 +376,11 @@ def load_checkpoint(
     )
 
     logging.info(f"  Building model: {model_name}")
-
+    # Use pretrained=False: checkpoint weights will overwrite any pretrained weights,
+    # so downloading ImageNet weights is wasteful and breaks offline/HPC inference.
+    model = build_model(
+        model_name, in_shape=in_shape, out_size=out_size, pretrained=False
+    )
 
     # Load weights (check multiple formats in order of preference)
     weight_path = None
```
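Downstream of this hunk, inference therefore builds the bare architecture and restores the checkpoint on top of it. A rough sketch of that path — the import location, model name, shapes, and checkpoint filename are placeholders, since only part of `load_checkpoint` appears in this diff:

```python
import torch
from wavedl.models import build_model  # assumed import path

# Architecture only: no ImageNet download, so it also works on offline HPC nodes.
model = build_model("efficientnet_b0", in_shape=(224, 224), out_size=4, pretrained=False)

# The trained weights overwrite everything, including the adapted 1-channel stem conv.
state = torch.load("checkpoint_dir/model_weights.pt", map_location="cpu")  # placeholder path/format
model.load_state_dict(state)
model.eval()
```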