wavedl 1.5.7__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,285 @@
1
+ """
2
+ FastViT: A Fast Hybrid Vision Transformer
3
+ ==========================================
4
+
5
+ FastViT from Apple uses RepMixer for efficient token mixing with structural
6
+ reparameterization - train with skip connections, deploy without.
7
+
8
+ **Key Features**:
9
+ - RepMixer: Reparameterizable token mixing
10
+ - Train-time overparameterization
11
+ - Faster than EfficientNet/ConvNeXt on mobile
12
+ - CoreML compatible
13
+
14
+ **Variants**:
15
+ - fastvit_t8: 4M params (fastest)
16
+ - fastvit_t12: 7M params
17
+ - fastvit_s12: 9M params
18
+ - fastvit_sa12: 21M params (with attention)
19
+
20
+ **Requirements**:
21
+ - timm >= 0.9.0 (for FastViT models)
22
+
23
+ Reference:
24
+ Vasu, P.K.A., et al. (2023). FastViT: A Fast Hybrid Vision Transformer
25
+ using Structural Reparameterization. ICCV 2023.
26
+ https://arxiv.org/abs/2303.14189
27
+
28
+ Author: Ductho Le (ductho.le@outlook.com)
29
+ """
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+
34
+ from wavedl.models._timm_utils import build_regression_head
35
+ from wavedl.models.base import BaseModel
36
+ from wavedl.models.registry import register_model
37
+
38
+
39
# Names exported by this module (declares the public API for `import *`).
__all__ = [
    "FastViTBase",
    "FastViTS12",
    "FastViTSA12",
    "FastViTT8",
    "FastViTT12",
]
46
+
47
+
48
+ # =============================================================================
49
+ # FASTVIT BASE CLASS
50
+ # =============================================================================
51
+
52
+
53
class FastViTBase(BaseModel):
    """
    FastViT base class wrapping the timm implementation.

    Uses RepMixer for efficient token mixing with structural
    reparameterization: train with skip connections, deploy without.
    2D input only.

    Args:
        in_shape: Spatial input size as (H, W).
        out_size: Number of regression outputs.
        model_name: timm model identifier (e.g. "fastvit_t8").
        pretrained: If True, load pretrained weights; the RGB stem kernels
            are then collapsed to single-channel by averaging.
        freeze_backbone: If True, freeze all backbone parameters so only
            the regression head is trained.
        dropout_rate: Dropout probability used in the regression head.
        **kwargs: Accepted for interface compatibility; unused here.

    Raises:
        ValueError: If ``in_shape`` is not 2D.
        ImportError: If timm is not installed.
        RuntimeError: If timm fails to create ``model_name``.
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        model_name: str = "fastvit_t8",
        pretrained: bool = True,
        freeze_backbone: bool = False,
        dropout_rate: float = 0.3,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 2:
            raise ValueError(f"FastViT requires 2D input (H, W), got {len(in_shape)}D")

        self.pretrained = pretrained
        self.freeze_backbone = freeze_backbone
        self.model_name = model_name

        # Try to load from timm
        try:
            import timm

            self.backbone = timm.create_model(
                model_name,
                pretrained=pretrained,
                num_classes=0,  # Remove classifier
            )

            # Probe with a dummy batch to get the pooled feature dimension
            # (it varies across FastViT variants).
            with torch.no_grad():
                dummy = torch.zeros(1, 3, *in_shape)
                features = self.backbone(dummy)
                in_features = features.shape[-1]

        except ImportError as e:
            # FIX: chain the original error so the missing-module detail
            # is preserved in the traceback.
            raise ImportError(
                "timm >= 0.9.0 is required for FastViT. "
                "Install with: pip install timm>=0.9.0"
            ) from e
        except Exception as e:
            raise RuntimeError(
                f"Failed to load FastViT model '{model_name}': {e}"
            ) from e

        # Adapt input channels (3 -> 1)
        self._adapt_input_channels()

        # Regression head
        self.head = build_regression_head(in_features, out_size, dropout_rate)

        if freeze_backbone:
            self._freeze_backbone()

    def _adapt_input_channels(self):
        """Adapt all conv layers with 3 input channels for single-channel input.

        FastViT may have multiple modules with 3 input channels (e.g.
        conv_kxk, conv_scale in the reparameterizable stem), so every
        matching module is rewritten, not just the first one found.
        """
        adapted_count = 0

        for name, module in self.backbone.named_modules():
            if hasattr(module, "in_channels") and module.in_channels == 3:
                # Wrapper module (e.g. ConvNormAct) exposing an inner .conv
                if hasattr(module, "conv") and isinstance(module.conv, nn.Conv2d):
                    module.conv = self._make_new_conv(module.conv)
                    adapted_count += 1
                elif isinstance(module, nn.Conv2d):
                    # Direct Conv2d: re-assign it on its parent module so the
                    # replacement is registered properly.
                    parts = name.split(".")
                    parent = self.backbone
                    for part in parts[:-1]:
                        parent = getattr(parent, part)
                    setattr(parent, parts[-1], self._make_new_conv(module))
                    adapted_count += 1

        if adapted_count == 0:
            import warnings

            warnings.warn(
                "Could not adapt FastViT input channels. Model may fail.", stacklevel=2
            )

    def _make_new_conv(self, old_conv: nn.Conv2d) -> nn.Conv2d:
        """Create a 1-input-channel replacement for ``old_conv``.

        FIX: also carry over ``dilation`` and ``padding_mode``; previously
        these were silently reset to their defaults, changing behavior for
        dilated or non-zeros-padded convolutions.

        NOTE(review): with 1 input channel the new conv necessarily uses
        ``groups=1``; for a depthwise (groups=3) original the averaged
        weight shape still matches, so the copy below remains valid.
        """
        new_conv = nn.Conv2d(
            1,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            dilation=old_conv.dilation,
            bias=old_conv.bias is not None,
            padding_mode=old_conv.padding_mode,
        )
        if self.pretrained:
            with torch.no_grad():
                # Collapse RGB kernels: (out, in, kH, kW) -> (out, 1, kH, kW).
                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
                if old_conv.bias is not None:
                    new_conv.bias.copy_(old_conv.bias)
        return new_conv

    def _freeze_backbone(self):
        """Freeze backbone parameters (the regression head stays trainable)."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run backbone feature extraction followed by the regression head."""
        features = self.backbone(x)
        return self.head(features)

    def reparameterize(self):
        """
        Reparameterize model for inference.

        Fuses RepMixer blocks for faster inference.
        Call this before deployment.
        """
        # Guarded: older timm versions may not expose this method.
        if hasattr(self.backbone, "reparameterize"):
            self.backbone.reparameterize()
180
+
181
+
182
+ # =============================================================================
183
+ # REGISTERED VARIANTS
184
+ # =============================================================================
185
+
186
+
187
@register_model("fastvit_t8")
class FastViTT8(FastViTBase):
    """
    FastViT-T8: ~3.3M backbone parameters (fastest variant).

    Optimized for mobile and edge deployment.
    2D only.

    Example:
        >>> model = FastViTT8(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Delegate to the base class, pinning the timm variant name.
        super().__init__(in_shape, out_size, model_name="fastvit_t8", **kwargs)

    def __repr__(self) -> str:
        args = (
            f"in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained}"
        )
        return f"FastViT_T8({args})"
214
+
215
+
216
@register_model("fastvit_t12")
class FastViTT12(FastViTBase):
    """
    FastViT-T12: ~6.5M backbone parameters.

    Balanced speed and accuracy.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Delegate to the base class, pinning the timm variant name.
        super().__init__(in_shape, out_size, model_name="fastvit_t12", **kwargs)

    def __repr__(self) -> str:
        args = (
            f"in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained}"
        )
        return f"FastViT_T12({args})"
238
+
239
+
240
@register_model("fastvit_s12")
class FastViTS12(FastViTBase):
    """
    FastViT-S12: ~8.5M backbone parameters.

    Slightly larger for better accuracy.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Delegate to the base class, pinning the timm variant name.
        super().__init__(in_shape, out_size, model_name="fastvit_s12", **kwargs)

    def __repr__(self) -> str:
        args = (
            f"in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained}"
        )
        return f"FastViT_S12({args})"
262
+
263
+
264
@register_model("fastvit_sa12")
class FastViTSA12(FastViTBase):
    """
    FastViT-SA12: ~10.6M backbone parameters.

    With self-attention for better accuracy at the cost of speed.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Delegate to the base class, pinning the timm variant name.
        super().__init__(in_shape, out_size, model_name="fastvit_sa12", **kwargs)

    def __repr__(self) -> str:
        args = (
            f"in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained}"
        )
        return f"FastViT_SA12({args})"