wavedl 1.5.5__py3-none-any.whl → 1.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
      # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
  """

- __version__ = "1.5.5"
+ __version__ = "1.5.7"
  __author__ = "Ductho Le"
  __email__ = "ductho.le@outlook.com"

@@ -110,9 +110,30 @@ class EfficientNetBase(BaseModel):
              self._freeze_backbone()

      def _adapt_input_channels(self):
-         """Modify first conv to handle single-channel input by expanding to 3ch."""
-         # We'll handle this in forward by repeating channels
-         pass
+         """Modify first conv to accept single-channel input.
+
+         Instead of expanding 1→3 channels in forward (which triples memory),
+         we replace the first conv layer with a 1-channel version and initialize
+         weights as the mean of the pretrained RGB filters.
+         """
+         # EfficientNet stem conv is at: features[0][0]
+         old_conv = self.backbone.features[0][0]
+         new_conv = nn.Conv2d(
+             1,  # Single channel input
+             old_conv.out_channels,
+             kernel_size=old_conv.kernel_size,
+             stride=old_conv.stride,
+             padding=old_conv.padding,
+             dilation=old_conv.dilation,
+             groups=old_conv.groups,
+             padding_mode=old_conv.padding_mode,
+             bias=old_conv.bias is not None,
+         )
+         if self.pretrained:
+             # Initialize with mean of pretrained RGB weights
+             with torch.no_grad():
+                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+         self.backbone.features[0][0] = new_conv

      def _freeze_backbone(self):
          """Freeze all backbone parameters except the classifier."""
@@ -130,10 +151,6 @@ class EfficientNetBase(BaseModel):
          Returns:
              Output tensor of shape (B, out_size)
          """
-         # Expand single channel to 3 channels for pretrained weights
-         if x.size(1) == 1:
-             x = x.expand(-1, 3, -1, -1)
-
          return self.backbone(x)

      @classmethod
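A note on the channel adaptation above (the same change is applied to the EfficientNetV2, MobileNetV3, RegNet, and Swin backbones below): the replace-the-stem-conv trick can be reproduced outside the package. The following is a minimal illustrative sketch, assuming torchvision >= 0.13 and the standard torchvision EfficientNet layout where the stem conv sits at features[0][0]; the helper name adapt_first_conv_to_grayscale is hypothetical, not part of wavedl.

import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0

def adapt_first_conv_to_grayscale(model: nn.Module) -> nn.Module:
    old = model.features[0][0]
    new = nn.Conv2d(
        1,
        old.out_channels,
        kernel_size=old.kernel_size,
        stride=old.stride,
        padding=old.padding,
        bias=old.bias is not None,
    )
    with torch.no_grad():
        # Average the pretrained RGB filters into a single input channel
        new.weight.copy_(old.weight.mean(dim=1, keepdim=True))
    model.features[0][0] = new
    return model

model = adapt_first_conv_to_grayscale(efficientnet_b0(weights=None))
print(model(torch.randn(2, 1, 224, 224)).shape)  # a grayscale batch now goes straight through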
@@ -129,10 +129,37 @@ class EfficientNetV2Base(BaseModel):
              nn.Linear(regression_hidden // 2, out_size),
          )

-         # Optionally freeze backbone for fine-tuning
+         # Adapt first conv for single-channel input (3× memory savings vs expand)
+         self._adapt_input_channels()
+
+         # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
          if freeze_backbone:
              self._freeze_backbone()

+     def _adapt_input_channels(self):
+         """Modify first conv to accept single-channel input.
+
+         Instead of expanding 1→3 channels in forward (which triples memory),
+         we replace the first conv layer with a 1-channel version and initialize
+         weights as the mean of the pretrained RGB filters.
+         """
+         old_conv = self.backbone.features[0][0]
+         new_conv = nn.Conv2d(
+             1,  # Single channel input
+             old_conv.out_channels,
+             kernel_size=old_conv.kernel_size,
+             stride=old_conv.stride,
+             padding=old_conv.padding,
+             dilation=old_conv.dilation,
+             groups=old_conv.groups,
+             padding_mode=old_conv.padding_mode,
+             bias=old_conv.bias is not None,
+         )
+         if self.pretrained:
+             with torch.no_grad():
+                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+         self.backbone.features[0][0] = new_conv
+
      def _freeze_backbone(self):
          """Freeze all backbone parameters except the classifier."""
          for name, param in self.backbone.named_parameters():
@@ -144,15 +171,11 @@ class EfficientNetV2Base(BaseModel):
          Forward pass.

          Args:
-             x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+             x: Input tensor of shape (B, 1, H, W)

          Returns:
              Output tensor of shape (B, out_size)
          """
-         # Expand single channel to 3 channels for pretrained weights compatibility
-         if x.size(1) == 1:
-             x = x.expand(-1, 3, -1, -1)
-
          return self.backbone(x)

      @classmethod
@@ -136,10 +136,37 @@ class MobileNetV3Base(BaseModel):
              nn.Linear(regression_hidden, out_size),
          )

-         # Optionally freeze backbone for fine-tuning
+         # Adapt first conv for single-channel input (3× memory savings vs expand)
+         self._adapt_input_channels()
+
+         # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
          if freeze_backbone:
              self._freeze_backbone()

+     def _adapt_input_channels(self):
+         """Modify first conv to accept single-channel input.
+
+         Instead of expanding 1→3 channels in forward (which triples memory),
+         we replace the first conv layer with a 1-channel version and initialize
+         weights as the mean of the pretrained RGB filters.
+         """
+         old_conv = self.backbone.features[0][0]
+         new_conv = nn.Conv2d(
+             1,  # Single channel input
+             old_conv.out_channels,
+             kernel_size=old_conv.kernel_size,
+             stride=old_conv.stride,
+             padding=old_conv.padding,
+             dilation=old_conv.dilation,
+             groups=old_conv.groups,
+             padding_mode=old_conv.padding_mode,
+             bias=old_conv.bias is not None,
+         )
+         if self.pretrained:
+             with torch.no_grad():
+                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+         self.backbone.features[0][0] = new_conv
+
      def _freeze_backbone(self):
          """Freeze all backbone parameters except the classifier."""
          for name, param in self.backbone.named_parameters():
@@ -151,15 +178,11 @@ class MobileNetV3Base(BaseModel):
          Forward pass.

          Args:
-             x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+             x: Input tensor of shape (B, 1, H, W)

          Returns:
              Output tensor of shape (B, out_size)
          """
-         # Expand single channel to 3 channels for pretrained weights compatibility
-         if x.size(1) == 1:
-             x = x.expand(-1, 3, -1, -1)
-
          return self.backbone(x)

      @classmethod
@@ -194,7 +217,7 @@ class MobileNetV3Small(MobileNetV3Base):

      Performance (approximate):
      - CPU inference: ~6ms (single core)
-     - Parameters: 2.5M
+     - Parameters: ~1.1M
      - MAdds: 56M

      Args:
@@ -241,7 +264,7 @@ class MobileNetV3Large(MobileNetV3Base):

      Performance (approximate):
      - CPU inference: ~20ms (single core)
-     - Parameters: 5.4M
+     - Parameters: ~3.2M
      - MAdds: 219M

      Args:
wavedl/models/regnet.py CHANGED
@@ -140,10 +140,37 @@ class RegNetBase(BaseModel):
              nn.Linear(regression_hidden, out_size),
          )

-         # Optionally freeze backbone for fine-tuning
+         # Adapt first conv for single-channel input (3× memory savings vs expand)
+         self._adapt_input_channels()
+
+         # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
          if freeze_backbone:
              self._freeze_backbone()

+     def _adapt_input_channels(self):
+         """Modify first conv to accept single-channel input.
+
+         Instead of expanding 1→3 channels in forward (which triples memory),
+         we replace the first conv layer with a 1-channel version and initialize
+         weights as the mean of the pretrained RGB filters.
+         """
+         old_conv = self.backbone.stem[0]
+         new_conv = nn.Conv2d(
+             1,  # Single channel input
+             old_conv.out_channels,
+             kernel_size=old_conv.kernel_size,
+             stride=old_conv.stride,
+             padding=old_conv.padding,
+             dilation=old_conv.dilation,
+             groups=old_conv.groups,
+             padding_mode=old_conv.padding_mode,
+             bias=old_conv.bias is not None,
+         )
+         if self.pretrained:
+             with torch.no_grad():
+                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+         self.backbone.stem[0] = new_conv
+
      def _freeze_backbone(self):
          """Freeze all backbone parameters except the fc layer."""
          for name, param in self.backbone.named_parameters():
@@ -155,15 +182,11 @@ class RegNetBase(BaseModel):
          Forward pass.

          Args:
-             x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+             x: Input tensor of shape (B, 1, H, W)

          Returns:
              Output tensor of shape (B, out_size)
          """
-         # Expand single channel to 3 channels for pretrained weights compatibility
-         if x.size(1) == 1:
-             x = x.expand(-1, 3, -1, -1)
-
          return self.backbone(x)

      @classmethod
wavedl/models/swin.py CHANGED
@@ -141,10 +141,46 @@ class SwinTransformerBase(BaseModel):
              nn.Linear(regression_hidden // 2, out_size),
          )

-         # Optionally freeze backbone for fine-tuning
+         # Adapt patch embedding conv for single-channel input (3× memory savings vs expand)
+         self._adapt_input_channels()
+
+         # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
          if freeze_backbone:
              self._freeze_backbone()

+     def _adapt_input_channels(self):
+         """Modify patch embedding conv to accept single-channel input.
+
+         Instead of expanding 1→3 channels in forward (which triples memory),
+         we replace the patch embedding conv with a 1-channel version and
+         initialize weights as the mean of the pretrained RGB filters.
+         """
+         # Swin's patch embedding is at features[0][0]
+         try:
+             old_conv = self.backbone.features[0][0]
+         except (IndexError, AttributeError, TypeError) as e:
+             raise RuntimeError(
+                 f"Swin patch embed structure changed in this torchvision version. "
+                 f"Cannot adapt input channels. Error: {e}"
+             ) from e
+         new_conv = nn.Conv2d(
+             1,  # Single channel input
+             old_conv.out_channels,
+             kernel_size=old_conv.kernel_size,
+             stride=old_conv.stride,
+             padding=old_conv.padding,
+             dilation=old_conv.dilation,
+             groups=old_conv.groups,
+             padding_mode=old_conv.padding_mode,
+             bias=old_conv.bias is not None,
+         )
+         if self.pretrained:
+             with torch.no_grad():
+                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                 if old_conv.bias is not None:
+                     new_conv.bias.copy_(old_conv.bias)
+         self.backbone.features[0][0] = new_conv
+
      def _freeze_backbone(self):
          """Freeze all backbone parameters except the head."""
          for name, param in self.backbone.named_parameters():
@@ -156,15 +192,11 @@ class SwinTransformerBase(BaseModel):
          Forward pass.

          Args:
-             x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+             x: Input tensor of shape (B, 1, H, W)

          Returns:
              Output tensor of shape (B, out_size)
          """
-         # Expand single channel to 3 channels for pretrained weights compatibility
-         if x.size(1) == 1:
-             x = x.expand(-1, 3, -1, -1)
-
          return self.backbone(x)

      @classmethod
wavedl/models/tcn.py CHANGED
@@ -45,6 +45,26 @@ from wavedl.models.base import BaseModel
  from wavedl.models.registry import register_model


+ def _find_group_count(channels: int, max_groups: int = 8) -> int:
+     """
+     Find largest valid group count for GroupNorm.
+
+     GroupNorm requires channels to be divisible by num_groups.
+     This finds the largest divisor up to max_groups.
+
+     Args:
+         channels: Number of channels
+         max_groups: Maximum group count to consider (default: 8)
+
+     Returns:
+         Largest valid group count (always >= 1)
+     """
+     for g in range(min(max_groups, channels), 0, -1):
+         if channels % g == 0:
+             return g
+     return 1
+
+
  class CausalConv1d(nn.Module):
      """
      Causal 1D convolution with dilation.
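The helper added above matters because nn.GroupNorm raises at construction time whenever the channel count is not divisible by the group count, which min(8, out_channels) does not guarantee. A small illustrative check (the channel values are arbitrary examples, not taken from the package):

import torch.nn as nn

def _find_group_count(channels: int, max_groups: int = 8) -> int:
    # Same logic as the helper in the diff above
    for g in range(min(max_groups, channels), 0, -1):
        if channels % g == 0:
            return g
    return 1

try:
    nn.GroupNorm(min(8, 12), 12)            # old pattern: 12 % 8 != 0, raises ValueError
except ValueError as e:
    print("min(8, C) fails:", e)

print(_find_group_count(12))                # 6, the largest divisor of 12 that is <= 8
print(_find_group_count(7))                 # 7 (prime counts degrade to one group per channel)
print(_find_group_count(64))                # 8
nn.GroupNorm(_find_group_count(12), 12)     # constructs without error

The TemporalBlock hunk below swaps the old min(8, out_channels) call for this helper.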
@@ -101,13 +121,13 @@ class TemporalBlock(nn.Module):

          # First causal convolution
          self.conv1 = CausalConv1d(in_channels, out_channels, kernel_size, dilation)
-         self.norm1 = nn.GroupNorm(min(8, out_channels), out_channels)
+         self.norm1 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
          self.act1 = nn.GELU()
          self.dropout1 = nn.Dropout(dropout)

          # Second causal convolution
          self.conv2 = CausalConv1d(out_channels, out_channels, kernel_size, dilation)
-         self.norm2 = nn.GroupNorm(min(8, out_channels), out_channels)
+         self.norm2 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
          self.act2 = nn.GELU()
          self.dropout2 = nn.Dropout(dropout)

wavedl/models/vit.py CHANGED
@@ -42,47 +42,89 @@ class PatchEmbed(nn.Module):
      Supports 1D and 2D inputs:
      - 1D: Input (B, 1, L) → (B, num_patches, embed_dim)
      - 2D: Input (B, 1, H, W) → (B, num_patches, embed_dim)
+
+     Args:
+         in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
+         patch_size: Size of each patch
+         embed_dim: Embedding dimension
+         pad_if_needed: If True, pad input to nearest patch-aligned size instead of
+             dropping edge pixels. Important for NDE/QUS applications where edge
+             effects matter. Default: False (original behavior with warning).
      """

-     def __init__(self, in_shape: SpatialShape, patch_size: int, embed_dim: int):
+     def __init__(
+         self,
+         in_shape: SpatialShape,
+         patch_size: int,
+         embed_dim: int,
+         pad_if_needed: bool = False,
+     ):
          super().__init__()

          self.dim = len(in_shape)
          self.patch_size = patch_size
          self.embed_dim = embed_dim
+         self.pad_if_needed = pad_if_needed
+         self._padding = None  # Will be set if padding is needed

          if self.dim == 1:
              # 1D: segment patches
              L = in_shape[0]
-             if L % patch_size != 0:
-                 import warnings
-
-                 warnings.warn(
-                     f"Input length {L} not divisible by patch_size {patch_size}. "
-                     f"Last {L % patch_size} elements will be dropped. "
-                     f"Consider padding input to {((L // patch_size) + 1) * patch_size}.",
-                     UserWarning,
-                     stacklevel=2,
-                 )
-             self.num_patches = L // patch_size
+             remainder = L % patch_size
+             if remainder != 0:
+                 if pad_if_needed:
+                     # Pad to next multiple of patch_size
+                     pad_amount = patch_size - remainder
+                     self._padding = (0, pad_amount)  # (left, right)
+                     L_padded = L + pad_amount
+                     self.num_patches = L_padded // patch_size
+                 else:
+                     import warnings
+
+                     warnings.warn(
+                         f"Input length {L} not divisible by patch_size {patch_size}. "
+                         f"Last {remainder} elements will be dropped. "
+                         f"Consider using pad_if_needed=True or padding input to "
+                         f"{((L // patch_size) + 1) * patch_size}.",
+                         UserWarning,
+                         stacklevel=2,
+                     )
+                     self.num_patches = L // patch_size
+             else:
+                 self.num_patches = L // patch_size
              self.proj = nn.Conv1d(
                  1, embed_dim, kernel_size=patch_size, stride=patch_size
              )
          elif self.dim == 2:
              # 2D: grid patches
              H, W = in_shape
-             if H % patch_size != 0 or W % patch_size != 0:
-                 import warnings
-
-                 warnings.warn(
-                     f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
-                     f"Border pixels will be dropped (H: {H % patch_size}, W: {W % patch_size}). "
-                     f"Consider padding to ({((H // patch_size) + 1) * patch_size}, "
-                     f"{((W // patch_size) + 1) * patch_size}).",
-                     UserWarning,
-                     stacklevel=2,
-                 )
-             self.num_patches = (H // patch_size) * (W // patch_size)
+             h_rem, w_rem = H % patch_size, W % patch_size
+             if h_rem != 0 or w_rem != 0:
+                 if pad_if_needed:
+                     # Pad to next multiple of patch_size
+                     h_pad = (patch_size - h_rem) % patch_size
+                     w_pad = (patch_size - w_rem) % patch_size
+                     # Padding format: (left, right, top, bottom)
+                     self._padding = (0, w_pad, 0, h_pad)
+                     H_padded, W_padded = H + h_pad, W + w_pad
+                     self.num_patches = (H_padded // patch_size) * (
+                         W_padded // patch_size
+                     )
+                 else:
+                     import warnings
+
+                     warnings.warn(
+                         f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
+                         f"Border pixels will be dropped (H: {h_rem}, W: {w_rem}). "
+                         f"Consider using pad_if_needed=True or padding to "
+                         f"({((H // patch_size) + 1) * patch_size}, "
+                         f"{((W // patch_size) + 1) * patch_size}).",
+                         UserWarning,
+                         stacklevel=2,
+                     )
+                     self.num_patches = (H // patch_size) * (W // patch_size)
+             else:
+                 self.num_patches = (H // patch_size) * (W // patch_size)
              self.proj = nn.Conv2d(
                  1, embed_dim, kernel_size=patch_size, stride=patch_size
              )
@@ -97,6 +139,10 @@ class PatchEmbed(nn.Module):
          Returns:
              Patch embeddings (B, num_patches, embed_dim)
          """
+         # Apply padding if configured
+         if self._padding is not None:
+             x = nn.functional.pad(x, self._padding, mode="constant", value=0)
+
          x = self.proj(x)  # (B, embed_dim, ..reduced_spatial..)
          x = x.flatten(2)  # (B, embed_dim, num_patches)
          x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
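The padding arithmetic introduced above can be sanity-checked in isolation. A worked example for a 2D input that is not patch-aligned (the 250×250 shape and patch size 16 are arbitrary, and the snippet does not depend on wavedl):

import torch
import torch.nn.functional as F

H, W, patch_size = 250, 250, 16
h_rem, w_rem = H % patch_size, W % patch_size      # 10, 10
h_pad = (patch_size - h_rem) % patch_size          # 6
w_pad = (patch_size - w_rem) % patch_size          # 6
padding = (0, w_pad, 0, h_pad)                     # (left, right, top, bottom)

x = torch.randn(1, 1, H, W)
x_padded = F.pad(x, padding, mode="constant", value=0)
print(x_padded.shape)                              # torch.Size([1, 1, 256, 256])

num_patches = ((H + h_pad) // patch_size) * ((W + w_pad) // patch_size)
print(num_patches)                                 # 256 patches, vs. 225 if the 10-pixel borders were dropped

With pad_if_needed=False the class keeps the old behavior: it warns and drops the remainder rows and columns.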
@@ -185,6 +231,18 @@ class ViTBase(BaseModel):
      3. Transformer encoder blocks
      4. Extract CLS token
      5. Regression head
+
+     Args:
+         in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
+         out_size: Number of regression targets
+         patch_size: Size of each patch (default: 16)
+         embed_dim: Embedding dimension (default: 768)
+         depth: Number of transformer blocks (default: 12)
+         num_heads: Number of attention heads (default: 12)
+         mlp_ratio: MLP hidden dim multiplier (default: 4.0)
+         dropout_rate: Dropout rate (default: 0.1)
+         pad_if_needed: If True, pad input to nearest patch-aligned size instead
+             of dropping edge pixels. Important for NDE/QUS applications.
      """

      def __init__(
@@ -197,6 +255,7 @@ class ViTBase(BaseModel):
          num_heads: int = 12,
          mlp_ratio: float = 4.0,
          dropout_rate: float = 0.1,
+         pad_if_needed: bool = False,
          **kwargs,
      ):
          super().__init__(in_shape, out_size)
@@ -207,9 +266,10 @@ class ViTBase(BaseModel):
          self.num_heads = num_heads
          self.dropout_rate = dropout_rate
          self.dim = len(in_shape)
+         self.pad_if_needed = pad_if_needed

          # Patch embedding
-         self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim)
+         self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim, pad_if_needed)
          num_patches = self.patch_embed.num_patches

          # Learnable CLS token and position embeddings
wavedl/test.py CHANGED
@@ -311,7 +311,7 @@ def load_data_for_inference(
  # ==============================================================================
  def load_checkpoint(
      checkpoint_dir: str,
-     in_shape: tuple[int, int],
+     in_shape: tuple[int, ...],
      out_size: int,
      model_name: str | None = None,
  ) -> tuple[nn.Module, any]:
@@ -320,7 +320,7 @@ def load_checkpoint(

      Args:
          checkpoint_dir: Path to checkpoint directory
-         in_shape: Input image shape (H, W)
+         in_shape: Input spatial shape - (L,) for 1D, (H, W) for 2D, or (D, H, W) for 3D
          out_size: Number of output parameters
          model_name: Model architecture name (auto-detect if None)

@@ -376,7 +376,11 @@ def load_checkpoint(
      )

      logging.info(f" Building model: {model_name}")
-     model = build_model(model_name, in_shape=in_shape, out_size=out_size)
+     # Use pretrained=False: checkpoint weights will overwrite any pretrained weights,
+     # so downloading ImageNet weights is wasteful and breaks offline/HPC inference.
+     model = build_model(
+         model_name, in_shape=in_shape, out_size=out_size, pretrained=False
+     )

      # Load weights (check multiple formats in order of preference)
      weight_path = None
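The pretrained=False change above follows the usual pattern of building the architecture cheaply and letting the checkpoint supply every weight. A hedged usage sketch; the import path, the registry name "efficientnet_b0", and the path "checkpoint/model_weights.pt" are assumptions for illustration, not the package's documented layout:

import torch
from wavedl.models import build_model  # assumed import location of build_model

# No ImageNet download: the checkpoint below overwrites all backbone weights anyway,
# which also keeps inference working on offline/air-gapped HPC nodes.
model = build_model("efficientnet_b0", in_shape=(224, 224), out_size=3, pretrained=False)
state_dict = torch.load("checkpoint/model_weights.pt", map_location="cpu")  # hypothetical path
model.load_state_dict(state_dict)
model.eval()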