wavedl 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

wavedl/models/maxvit.py CHANGED
@@ -28,9 +28,9 @@ Author: Ductho Le (ductho.le@outlook.com)
 """
 
 import torch
-import torch.nn as nn
+import torch.nn.functional as F
 
-from wavedl.models._timm_utils import build_regression_head
+from wavedl.models._pretrained_utils import build_regression_head
 from wavedl.models.base import BaseModel
 from wavedl.models.registry import register_model
 
@@ -54,8 +54,16 @@ class MaxViTBase(BaseModel):
 
     Multi-axis attention with local block and global grid attention.
     2D only due to attention structure.
+
+    Note:
+        MaxViT requires input dimensions divisible by 28 (4x stem downsample × 7 window).
+        This implementation automatically resizes inputs to the nearest compatible size.
     """
 
+    # MaxViT stem downsamples by 4x, then requires divisibility by 7 (window size)
+    # So original input must be divisible by 4 * 7 = 28
+    _DIVISOR = 28
+
     def __init__(
         self,
         in_shape: tuple[int, int],
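The new `_DIVISOR` constant and the `Note:` above encode the constraint driving this change: MaxViT's stem downsamples by 4x and its attention windows span 7 cells, so each spatial dimension must be a multiple of 4 × 7 = 28. A minimal standalone sketch of the round-up logic this diff adds later in `_compute_compatible_size` (names here are illustrative, stdlib only):

    import math

    DIVISOR = 28  # 4x stem downsample × 7-cell attention window

    def compatible_size(h: int, w: int) -> tuple[int, int]:
        # Round each dimension up to the next multiple of 28
        return (math.ceil(h / DIVISOR) * DIVISOR, math.ceil(w / DIVISOR) * DIVISOR)

    print(compatible_size(100, 100))  # (112, 112): rounded up
    print(compatible_size(224, 224))  # (224, 224): already compatible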
@@ -75,6 +83,9 @@ class MaxViTBase(BaseModel):
         self.freeze_backbone = freeze_backbone
         self.model_name = model_name
 
+        # Compute compatible input size for MaxViT attention windows
+        self._target_size = self._compute_compatible_size(in_shape)
+
         # Try to load from timm
         try:
             import timm
@@ -85,9 +96,9 @@ class MaxViTBase(BaseModel):
                 num_classes=0,  # Remove classifier
             )
 
-            # Get feature dimension
+            # Get feature dimension using compatible size
             with torch.no_grad():
-                dummy = torch.zeros(1, 3, *in_shape)
+                dummy = torch.zeros(1, 3, *self._target_size)
                 features = self.backbone(dummy)
                 in_features = features.shape[-1]
 
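The dummy probe must now run at the compatible size: with `num_classes=0`, timm returns pooled backbone features whose width feeds the regression head, but MaxViT's attention layers raise a shape error if the probe's spatial size is not divisible by 28. A hedged sketch of the probing pattern, assuming a timm MaxViT variant is installed (the model name is illustrative):

    import timm
    import torch

    backbone = timm.create_model("maxvit_tiny_tf_224", pretrained=False, num_classes=0)
    with torch.no_grad():
        # Spatial size must already be divisible by 28, e.g. 224 = 8 * 28
        dummy = torch.zeros(1, 3, 224, 224)
        in_features = backbone(dummy).shape[-1]  # feature width for the head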
@@ -109,62 +120,54 @@ class MaxViTBase(BaseModel):
 
     def _adapt_input_channels(self):
         """Adapt first conv layer for single-channel input."""
-        # MaxViT uses stem.conv1 (Conv2dSame from timm)
-        adapted = False
-
-        # Find the first Conv2d with 3 input channels
-        for name, module in self.backbone.named_modules():
-            if hasattr(module, "in_channels") and module.in_channels == 3:
-                # Get parent and child names
-                parts = name.split(".")
-                parent = self.backbone
-                for part in parts[:-1]:
-                    parent = getattr(parent, part)
-                child_name = parts[-1]
-
-                # Create new conv with 1 input channel
-                new_conv = self._make_new_conv(module)
-                setattr(parent, child_name, new_conv)
-                adapted = True
-                break
-
-        if not adapted:
+        from wavedl.models._pretrained_utils import find_and_adapt_input_convs
+
+        adapted_count = find_and_adapt_input_convs(
+            self.backbone, pretrained=self.pretrained, adapt_all=False
+        )
+
+        if adapted_count == 0:
             import warnings
 
             warnings.warn(
                 "Could not adapt MaxViT input channels. Model may fail.", stacklevel=2
             )
 
-    def _make_new_conv(self, old_conv: nn.Module) -> nn.Module:
-        """Create new conv layer with 1 input channel."""
-        # Handle both Conv2d and Conv2dSame from timm
-        type(old_conv)
-
-        # Get common parameters
-        kwargs = {
-            "out_channels": old_conv.out_channels,
-            "kernel_size": old_conv.kernel_size,
-            "stride": old_conv.stride,
-            "padding": old_conv.padding if hasattr(old_conv, "padding") else 0,
-            "bias": old_conv.bias is not None,
-        }
-
-        # Create new conv (use regular Conv2d for simplicity)
-        new_conv = nn.Conv2d(1, **kwargs)
-
-        if self.pretrained:
-            with torch.no_grad():
-                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
-                if old_conv.bias is not None:
-                    new_conv.bias.copy_(old_conv.bias)
-        return new_conv
-
     def _freeze_backbone(self):
         """Freeze backbone parameters."""
         for param in self.backbone.parameters():
             param.requires_grad = False
 
+    def _compute_compatible_size(self, in_shape: tuple[int, int]) -> tuple[int, int]:
+        """
+        Compute the nearest input size compatible with MaxViT attention windows.
+
+        MaxViT requires input dimensions divisible by 28 (4x stem downsample × 7 window).
+        This rounds up to the nearest compatible size.
+
+        Args:
+            in_shape: Original (H, W) input shape
+
+        Returns:
+            Compatible (H, W) shape divisible by 28
+        """
+        import math
+
+        h, w = in_shape
+        target_h = math.ceil(h / self._DIVISOR) * self._DIVISOR
+        target_w = math.ceil(w / self._DIVISOR) * self._DIVISOR
+        return (target_h, target_w)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Resize input to compatible size if needed
+        _, _, h, w = x.shape
+        if (h, w) != self._target_size:
+            x = F.interpolate(
+                x,
+                size=self._target_size,
+                mode="bilinear",
+                align_corners=False,
+            )
         features = self.backbone(x)
         return self.head(features)
 
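Taken together, `_compute_compatible_size` and the new `forward` guard mean callers can pass arbitrary (H, W) inputs: anything off the 28-grid is bilinearly resized once before the backbone runs. A self-contained sketch of the same guard in isolation (the function name is illustrative):

    import torch
    import torch.nn.functional as F

    def resize_if_needed(x: torch.Tensor, target: tuple[int, int]) -> torch.Tensor:
        # Bilinearly resize an NCHW batch to the MaxViT-compatible size, if needed
        _, _, h, w = x.shape
        if (h, w) != target:
            x = F.interpolate(x, size=target, mode="bilinear", align_corners=False)
        return x

    x = torch.randn(2, 1, 100, 100)
    print(resize_if_needed(x, (112, 112)).shape)  # torch.Size([2, 1, 112, 112])

Bilinear resizing arguably preserves pretrained behavior better than zero-padding for image backbones, at the cost of a small aspect-ratio distortion when H and W round up by different amounts.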