wavedl 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/models/__init__.py CHANGED
@@ -80,8 +80,24 @@ from .vit import ViTBase_, ViTSmall, ViTTiny
80
80
  # Optional timm-based models (imported conditionally)
81
81
  try:
82
82
  from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
83
+ from .efficientvit import (
84
+ EfficientViTB0,
85
+ EfficientViTB1,
86
+ EfficientViTB2,
87
+ EfficientViTB3,
88
+ EfficientViTL1,
89
+ EfficientViTL2,
90
+ EfficientViTM0,
91
+ EfficientViTM1,
92
+ EfficientViTM2,
93
+ )
83
94
  from .fastvit import FastViTS12, FastViTSA12, FastViTT8, FastViTT12
84
95
  from .maxvit import MaxViTBaseLarge, MaxViTSmall, MaxViTTiny
96
+ from .unireplknet import (
97
+ UniRepLKNetBaseLarge,
98
+ UniRepLKNetSmall,
99
+ UniRepLKNetTiny,
100
+ )
85
101
 
86
102
  _HAS_TIMM_MODELS = True
87
103
  except ImportError:
@@ -148,6 +164,15 @@ if _HAS_TIMM_MODELS:
148
164
  [
149
165
  "CaFormerS18",
150
166
  "CaFormerS36",
167
+ "EfficientViTB0",
168
+ "EfficientViTB1",
169
+ "EfficientViTB2",
170
+ "EfficientViTB3",
171
+ "EfficientViTL1",
172
+ "EfficientViTL2",
173
+ "EfficientViTM0",
174
+ "EfficientViTM1",
175
+ "EfficientViTM2",
151
176
  "FastViTS12",
152
177
  "FastViTSA12",
153
178
  "FastViTT8",
@@ -156,5 +181,8 @@ if _HAS_TIMM_MODELS:
156
181
  "MaxViTSmall",
157
182
  "MaxViTTiny",
158
183
  "PoolFormerS12",
184
+ "UniRepLKNetBaseLarge",
185
+ "UniRepLKNetSmall",
186
+ "UniRepLKNetTiny",
159
187
  ]
160
188
  )
wavedl/models/_pretrained_utils.py CHANGED (file header reconstructed — it was dropped during extraction; this hunk extends `adapt_input_channels`, which other hunks in this diff import from `wavedl.models._pretrained_utils`, not from `__init__.py`)
@@ -236,3 +236,131 @@ def adapt_input_channels(
236
236
  return new_conv
237
237
  else:
238
238
  raise NotImplementedError(f"Unsupported layer type: {type(conv_layer)}")
239
+
240
+
241
+ def adapt_first_conv_for_single_channel(
242
+ module: nn.Module,
243
+ conv_path: str,
244
+ pretrained: bool = True,
245
+ ) -> None:
246
+ """
247
+ Adapt the first convolutional layer of a pretrained model for single-channel input.
248
+
249
+ This is a convenience function for torchvision-style models where the path
250
+ to the first conv layer is known. It modifies the model in-place.
251
+
252
+ For pretrained models, the RGB weights are averaged to create grayscale weights,
253
+ which provides a reasonable initialization for single-channel inputs.
254
+
255
+ Args:
256
+ module: The model or submodule containing the conv layer
257
+ conv_path: Dot-separated path to the conv layer (e.g., "conv1", "features.0.0")
258
+ pretrained: Whether to adapt pretrained weights by averaging RGB channels
259
+
260
+ Example:
261
+ >>> # For torchvision ResNet
262
+ >>> adapt_first_conv_for_single_channel(
263
+ ... model.backbone, "conv1", pretrained=True
264
+ ... )
265
+ >>> # For torchvision ConvNeXt
266
+ >>> adapt_first_conv_for_single_channel(
267
+ ... model.backbone, "features.0.0", pretrained=True
268
+ ... )
269
+ >>> # For torchvision DenseNet
270
+ >>> adapt_first_conv_for_single_channel(
271
+ ... model.backbone, "features.conv0", pretrained=True
272
+ ... )
273
+ """
274
+ # Navigate to parent and get the conv layer
275
+ parts = conv_path.split(".")
276
+ parent = module
277
+ for part in parts[:-1]:
278
+ if part.isdigit():
279
+ parent = parent[int(part)]
280
+ else:
281
+ parent = getattr(parent, part)
282
+
283
+ # Get the final attribute name and the old conv
284
+ final_attr = parts[-1]
285
+ if final_attr.isdigit():
286
+ old_conv = parent[int(final_attr)]
287
+ else:
288
+ old_conv = getattr(parent, final_attr)
289
+
290
+ # Create and set the new conv
291
+ new_conv = adapt_input_channels(old_conv, new_in_channels=1, pretrained=pretrained)
292
+
293
+ if final_attr.isdigit():
294
+ parent[int(final_attr)] = new_conv
295
+ else:
296
+ setattr(parent, final_attr, new_conv)
297
+
298
+
299
+ def find_and_adapt_input_convs(
300
+ backbone: nn.Module,
301
+ pretrained: bool = True,
302
+ adapt_all: bool = False,
303
+ ) -> int:
304
+ """
305
+ Find and adapt Conv2d layers with 3 input channels for single-channel input.
306
+
307
+ This is useful for timm-style models where the exact path to the first
308
+ conv layer may vary or where multiple layers need adaptation.
309
+
310
+ Args:
311
+ backbone: The backbone model to adapt
312
+ pretrained: Whether to adapt pretrained weights by averaging RGB channels
313
+ adapt_all: If True, adapt all Conv2d layers with 3 input channels.
314
+ If False (default), only adapt the first one found.
315
+
316
+ Returns:
317
+ Number of layers adapted
318
+
319
+ Example:
320
+ >>> # For timm models (adapt first conv only)
321
+ >>> count = find_and_adapt_input_convs(model.backbone, pretrained=True)
322
+ >>> # For models with multiple input convs (e.g., FastViT)
323
+ >>> count = find_and_adapt_input_convs(
324
+ ... model.backbone, pretrained=True, adapt_all=True
325
+ ... )
326
+ """
327
+ adapted_count = 0
328
+
329
+ for name, module in backbone.named_modules():
330
+ if not hasattr(module, "in_channels") or module.in_channels != 3:
331
+ continue
332
+
333
+ # Check if this is a wrapper with inner .conv attribute
334
+ if hasattr(module, "conv") and isinstance(module.conv, nn.Conv2d):
335
+ old_conv = module.conv
336
+ module.conv = adapt_input_channels(
337
+ old_conv, new_in_channels=1, pretrained=pretrained
338
+ )
339
+ adapted_count += 1
340
+
341
+ elif isinstance(module, nn.Conv2d):
342
+ # Direct Conv2d - need to replace it in parent
343
+ parts = name.split(".")
344
+ parent = backbone
345
+ for part in parts[:-1]:
346
+ if part.isdigit():
347
+ parent = parent[int(part)]
348
+ else:
349
+ parent = getattr(parent, part)
350
+
351
+ child_name = parts[-1]
352
+ new_conv = adapt_input_channels(
353
+ module, new_in_channels=1, pretrained=pretrained
354
+ )
355
+
356
+ if child_name.isdigit():
357
+ parent[int(child_name)] = new_conv
358
+ else:
359
+ setattr(parent, child_name, new_conv)
360
+
361
+ adapted_count += 1
362
+
363
+ if not adapt_all and adapted_count > 0:
364
+ break
365
+
366
+ return adapted_count
wavedl/models/base.py CHANGED
@@ -15,6 +15,54 @@ import torch
15
15
  import torch.nn as nn
16
16
 
17
17
 
18
+ # =============================================================================
19
+ # TYPE ALIASES
20
+ # =============================================================================
21
+
22
+ # Spatial shape type aliases for model input dimensions
23
+ SpatialShape1D = tuple[int]
24
+ SpatialShape2D = tuple[int, int]
25
+ SpatialShape3D = tuple[int, int, int]
26
+ SpatialShape = SpatialShape1D | SpatialShape2D | SpatialShape3D
27
+
28
+
29
+ # =============================================================================
30
+ # UTILITY FUNCTIONS
31
+ # =============================================================================
32
+
33
+
34
+ def compute_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
35
+ """
36
+ Compute valid num_groups for GroupNorm that divides num_channels evenly.
37
+
38
+ GroupNorm requires num_channels to be divisible by num_groups. This function
39
+ finds the largest valid divisor up to preferred_groups.
40
+
41
+ Args:
42
+ num_channels: Number of channels to normalize (must be positive)
43
+ preferred_groups: Preferred number of groups (default: 32)
44
+
45
+ Returns:
46
+ Valid num_groups that satisfies num_channels % num_groups == 0
47
+
48
+ Example:
49
+ >>> compute_num_groups(64) # Returns 32
50
+ >>> compute_num_groups(48) # Returns 16 (48 % 32 != 0)
51
+ >>> compute_num_groups(7) # Returns 1 (prime number)
52
+ """
53
+ # Try preferred groups first, then common divisors
54
+ for groups in [preferred_groups, 16, 8, 4, 2, 1]:
55
+ if groups <= num_channels and num_channels % groups == 0:
56
+ return groups
57
+
58
+ # Fallback: find any valid divisor (always returns at least 1)
59
+ for groups in range(min(32, num_channels), 0, -1):
60
+ if num_channels % groups == 0:
61
+ return groups
62
+
63
+ return 1 # Always valid
64
+
65
+
18
66
  class BaseModel(nn.Module, ABC):
19
67
  """
20
68
  Abstract base class for all regression models.
wavedl/models/caformer.py CHANGED
@@ -33,7 +33,7 @@ Author: Ductho Le (ductho.le@outlook.com)
33
33
  import torch
34
34
  import torch.nn as nn
35
35
 
36
- from wavedl.models._timm_utils import build_regression_head
36
+ from wavedl.models._pretrained_utils import build_regression_head
37
37
  from wavedl.models.base import BaseModel
38
38
  from wavedl.models.registry import register_model
39
39
 
wavedl/models/cnn.py CHANGED
@@ -24,14 +24,10 @@ from typing import Any
24
24
  import torch
25
25
  import torch.nn as nn
26
26
 
27
- from wavedl.models.base import BaseModel
27
+ from wavedl.models.base import BaseModel, SpatialShape, compute_num_groups
28
28
  from wavedl.models.registry import register_model
29
29
 
30
30
 
31
- # Type alias for spatial shapes
32
- SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
33
-
34
-
35
31
  def _get_conv_layers(
36
32
  dim: int,
37
33
  ) -> tuple[type[nn.Module], type[nn.Module], type[nn.Module]]:
@@ -163,27 +159,6 @@ class CNN(BaseModel):
163
159
  nn.Linear(64, out_size),
164
160
  )
165
161
 
166
- @staticmethod
167
- def _compute_num_groups(num_channels: int, target_groups: int = 4) -> int:
168
- """
169
- Compute valid num_groups for GroupNorm that divides num_channels.
170
-
171
- Finds the largest divisor of num_channels that is <= target_groups,
172
- or falls back to 1 if no suitable divisor exists.
173
-
174
- Args:
175
- num_channels: Number of channels (must be positive)
176
- target_groups: Desired number of groups (default: 4)
177
-
178
- Returns:
179
- Valid num_groups that satisfies num_channels % num_groups == 0
180
- """
181
- # Try target_groups down to 1, return first valid divisor
182
- for g in range(min(target_groups, num_channels), 0, -1):
183
- if num_channels % g == 0:
184
- return g
185
- return 1 # Fallback (always valid)
186
-
187
162
  def _make_conv_block(
188
163
  self, in_channels: int, out_channels: int, dropout: float = 0.0
189
164
  ) -> nn.Sequential:
@@ -198,7 +173,7 @@ class CNN(BaseModel):
198
173
  Returns:
199
174
  Sequential block: Conv → GroupNorm → LeakyReLU → MaxPool [→ Dropout]
200
175
  """
201
- num_groups = self._compute_num_groups(out_channels, target_groups=4)
176
+ num_groups = compute_num_groups(out_channels, preferred_groups=4)
202
177
 
203
178
  layers = [
204
179
  self._Conv(in_channels, out_channels, kernel_size=3, padding=1),
wavedl/models/convnext.py CHANGED
@@ -28,14 +28,10 @@ import torch
28
28
  import torch.nn as nn
29
29
  import torch.nn.functional as F
30
30
 
31
- from wavedl.models.base import BaseModel
31
+ from wavedl.models.base import BaseModel, SpatialShape
32
32
  from wavedl.models.registry import register_model
33
33
 
34
34
 
35
- # Type alias for spatial shapes
36
- SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
37
-
38
-
39
35
  def _get_conv_layer(dim: int) -> type[nn.Module]:
40
36
  """Get dimension-appropriate Conv class."""
41
37
  if dim == 1:
@@ -468,20 +464,11 @@ class ConvNeXtTinyPretrained(BaseModel):
468
464
  )
469
465
 
470
466
  # Modify first conv for single-channel input
471
- old_conv = self.backbone.features[0][0]
472
- self.backbone.features[0][0] = nn.Conv2d(
473
- 1,
474
- old_conv.out_channels,
475
- kernel_size=old_conv.kernel_size,
476
- stride=old_conv.stride,
477
- padding=old_conv.padding,
478
- bias=old_conv.bias is not None,
467
+ from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
468
+
469
+ adapt_first_conv_for_single_channel(
470
+ self.backbone, "features.0.0", pretrained=pretrained
479
471
  )
480
- if pretrained:
481
- with torch.no_grad():
482
- self.backbone.features[0][0].weight = nn.Parameter(
483
- old_conv.weight.mean(dim=1, keepdim=True)
484
- )
485
472
 
486
473
  if freeze_backbone:
487
474
  self._freeze_backbone()
[missing file header — extraction dropped it: the following hunk cannot belong to wavedl/models/convnext.py (line numbers restart at 31 after the previous hunk at 468+); its content references the ConvNeXt V2 module (`ConvNeXtV2Base`, `ConvNeXtV2TinyPretrained`), so it is presumably a separate convnext_v2-style file — confirm against the package contents]
@@ -31,20 +31,17 @@
31
31
  import torch
32
32
  import torch.nn as nn
33
33
 
34
- from wavedl.models._timm_utils import (
34
+ from wavedl.models._pretrained_utils import (
35
35
  LayerNormNd,
36
36
  build_regression_head,
37
37
  get_conv_layer,
38
38
  get_grn_layer,
39
39
  get_pool_layer,
40
40
  )
41
- from wavedl.models.base import BaseModel
41
+ from wavedl.models.base import BaseModel, SpatialShape
42
42
  from wavedl.models.registry import register_model
43
43
 
44
44
 
45
- # Type alias for spatial shapes
46
- SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
47
-
48
45
  __all__ = [
49
46
  "ConvNeXtV2Base",
50
47
  "ConvNeXtV2BaseLarge",
@@ -469,24 +466,11 @@ class ConvNeXtV2TinyPretrained(BaseModel):
469
466
 
470
467
  def _adapt_input_channels(self):
471
468
  """Adapt first conv layer for single-channel input."""
472
- old_conv = self.backbone.features[0][0]
473
- new_conv = nn.Conv2d(
474
- 1,
475
- old_conv.out_channels,
476
- kernel_size=old_conv.kernel_size,
477
- stride=old_conv.stride,
478
- padding=old_conv.padding,
479
- bias=old_conv.bias is not None,
480
- )
469
+ from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
481
470
 
482
- if self.pretrained:
483
- with torch.no_grad():
484
- # Average RGB weights for grayscale
485
- new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
486
- if old_conv.bias is not None:
487
- new_conv.bias.copy_(old_conv.bias)
488
-
489
- self.backbone.features[0][0] = new_conv
471
+ adapt_first_conv_for_single_channel(
472
+ self.backbone, "features.0.0", pretrained=self.pretrained
473
+ )
490
474
 
491
475
  def _freeze_backbone(self):
492
476
  """Freeze all backbone parameters except classifier."""
wavedl/models/densenet.py CHANGED
@@ -26,14 +26,10 @@ from typing import Any
26
26
  import torch
27
27
  import torch.nn as nn
28
28
 
29
- from wavedl.models.base import BaseModel
29
+ from wavedl.models.base import BaseModel, SpatialShape
30
30
  from wavedl.models.registry import register_model
31
31
 
32
32
 
33
- # Type alias for spatial shapes
34
- SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
35
-
36
-
37
33
  def _get_layers(dim: int):
38
34
  """Get dimension-appropriate layer classes."""
39
35
  if dim == 1:
@@ -374,20 +370,11 @@ class DenseNet121Pretrained(BaseModel):
374
370
  )
375
371
 
376
372
  # Modify first conv for single-channel input
377
- old_conv = self.backbone.features.conv0
378
- self.backbone.features.conv0 = nn.Conv2d(
379
- 1,
380
- old_conv.out_channels,
381
- kernel_size=old_conv.kernel_size,
382
- stride=old_conv.stride,
383
- padding=old_conv.padding,
384
- bias=False,
373
+ from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
374
+
375
+ adapt_first_conv_for_single_channel(
376
+ self.backbone, "features.conv0", pretrained=pretrained
385
377
  )
386
- if pretrained:
387
- with torch.no_grad():
388
- self.backbone.features.conv0.weight = nn.Parameter(
389
- old_conv.weight.mean(dim=1, keepdim=True)
390
- )
391
378
 
392
379
  if freeze_backbone:
393
380
  self._freeze_backbone()