wavedl 1.5.7-py3-none-any.whl → 1.6.1-py3-none-any.whl

Files changed (39)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/models/__init__.py +80 -4
  4. wavedl/models/_pretrained_utils.py +366 -0
  5. wavedl/models/base.py +48 -0
  6. wavedl/models/caformer.py +270 -0
  7. wavedl/models/cnn.py +2 -27
  8. wavedl/models/convnext.py +113 -51
  9. wavedl/models/convnext_v2.py +488 -0
  10. wavedl/models/densenet.py +10 -23
  11. wavedl/models/efficientnet.py +6 -6
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +252 -0
  15. wavedl/models/mamba.py +555 -0
  16. wavedl/models/maxvit.py +254 -0
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +19 -61
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +2 -6
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +9 -9
  26. wavedl/train.py +1430 -1425
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/data.py +39 -6
  30. wavedl/utils/losses.py +216 -216
  31. wavedl/utils/optimizers.py +216 -216
  32. wavedl/utils/schedulers.py +251 -251
  33. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
  34. wavedl-1.6.1.dist-info/RECORD +46 -0
  35. wavedl-1.5.7.dist-info/RECORD +0 -38
  36. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
  37. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
  38. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
  39. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/efficientvit.py
@@ -0,0 +1,398 @@
+ """
+ EfficientViT: Memory-Efficient Vision Transformer
+ =================================================
+
+ EfficientViT achieves a state-of-the-art speed-accuracy trade-off. The M-series
+ (Liu et al., CVPR 2023) uses cascaded group attention (CGA) to reduce
+ computational redundancy in multi-head self-attention while maintaining model
+ capability; the B/L-series (Cai et al., ICCV 2023) uses multi-scale linear
+ attention.
+
+ **Key Features**:
+ - Cascaded Group Attention (CGA) in the M-series; linear-complexity
+   multi-scale attention in the B/L-series
+ - Memory-efficient design for edge deployment
+ - Faster than Swin Transformer at similar accuracy
+ - Excellent for real-time NDE applications
+
+ **Variants**:
+ - efficientvit_m0: 2.3M params (mobile, fastest)
+ - efficientvit_m1: 2.9M params (mobile)
+ - efficientvit_m2: 4.2M params (mobile)
+ - efficientvit_b0: 3.4M params (balanced)
+ - efficientvit_b1: 9.1M params (balanced)
+ - efficientvit_b2: 24M params (balanced)
+ - efficientvit_b3: 49M params (balanced)
+ - efficientvit_l1: 53M params (large)
+ - efficientvit_l2: 64M params (large)
+
+ **Requirements**:
+ - timm >= 0.9.0 (for EfficientViT models)
+
+ References:
+     Liu, X., et al. (2023). EfficientViT: Memory Efficient Vision Transformer
+     with Cascaded Group Attention. CVPR 2023. (M-series)
+     https://arxiv.org/abs/2305.07027
+
+     Cai, H., et al. (2023). EfficientViT: Multi-Scale Linear Attention for
+     High-Resolution Dense Prediction. ICCV 2023. (B/L-series)
+     https://arxiv.org/abs/2205.14756
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ """
+
+ import torch
+
+ from wavedl.models._pretrained_utils import build_regression_head
+ from wavedl.models.base import BaseModel
+ from wavedl.models.registry import register_model
+
+
+ __all__ = [
+     "EfficientViTB0",
+     "EfficientViTB1",
+     "EfficientViTB2",
+     "EfficientViTB3",
+     "EfficientViTBase",
+     "EfficientViTL1",
+     "EfficientViTL2",
+     "EfficientViTM0",
+     "EfficientViTM1",
+     "EfficientViTM2",
+ ]
+
+
+ # =============================================================================
+ # EFFICIENTVIT BASE CLASS
+ # =============================================================================
+
+
+ class EfficientViTBase(BaseModel):
+     """
+     EfficientViT base class wrapping the timm implementation.
+
+     Uses cascaded group attention (M-series) or multi-scale linear attention
+     (B/L-series) for efficient multi-head attention. 2D only due to the
+     attention structure.
+
+     Args:
+         in_shape: (H, W) input shape (2D only)
+         out_size: Number of regression targets
+         model_name: timm model name
+         pretrained: Whether to load pretrained weights
+         freeze_backbone: Whether to freeze backbone for fine-tuning
+         dropout_rate: Dropout rate for regression head
+     """
+
+     def __init__(
+         self,
+         in_shape: tuple[int, int],
+         out_size: int,
+         model_name: str = "efficientvit_b0",
+         pretrained: bool = True,
+         freeze_backbone: bool = False,
+         dropout_rate: float = 0.3,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         if len(in_shape) != 2:
+             raise ValueError(
+                 f"EfficientViT requires 2D input (H, W), got {len(in_shape)}D"
+             )
+
+         self.pretrained = pretrained
+         self.freeze_backbone = freeze_backbone
+         self.model_name = model_name
+
+         # Load from timm
+         try:
+             import timm
+
+             self.backbone = timm.create_model(
+                 model_name,
+                 pretrained=pretrained,
+                 num_classes=0,  # Remove classifier
+             )
+
+             # Get feature dimension
+             with torch.no_grad():
+                 dummy = torch.zeros(1, 3, *in_shape)
+                 features = self.backbone(dummy)
+                 in_features = features.shape[-1]
+
+         except ImportError as e:
+             raise ImportError(
+                 "timm >= 0.9.0 is required for EfficientViT. "
+                 "Install with: pip install 'timm>=0.9.0'"
+             ) from e
+         except Exception as e:
+             raise RuntimeError(
+                 f"Failed to load EfficientViT model '{model_name}': {e}"
+             ) from e
+
+         # Adapt input channels (3 -> 1)
+         self._adapt_input_channels()
+
+         # Regression head
+         self.head = build_regression_head(in_features, out_size, dropout_rate)
+
+         if freeze_backbone:
+             self._freeze_backbone()
+
+     def _adapt_input_channels(self):
+         """Adapt the first conv layer for single-channel input."""
+         from wavedl.models._pretrained_utils import find_and_adapt_input_convs
+
+         adapted_count = find_and_adapt_input_convs(
+             self.backbone, pretrained=self.pretrained, adapt_all=False
+         )
+
+         if adapted_count == 0:
+             import warnings
+
+             warnings.warn(
+                 "Could not adapt EfficientViT input channels. Model may fail.",
+                 stacklevel=2,
+             )
+
+     def _freeze_backbone(self):
+         """Freeze backbone parameters."""
+         for param in self.backbone.parameters():
+             param.requires_grad = False
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         features = self.backbone(x)
+         return self.head(features)
+
+
+ # =============================================================================
+ # MOBILE VARIANTS (Ultra-lightweight)
+ # =============================================================================
+
+
+ @register_model("efficientvit_m0")
+ class EfficientViTM0(EfficientViTBase):
+     """
+     EfficientViT-M0: ~2.2M backbone parameters (fastest mobile variant).
+
+     Cascaded group attention for efficient inference.
+     Ideal for edge deployment and real-time NDE applications.
+     2D only.
+
+     Example:
+         >>> model = EfficientViTM0(in_shape=(224, 224), out_size=3)
+         >>> x = torch.randn(4, 1, 224, 224)
+         >>> out = model(x)  # (4, 3)
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="efficientvit_m0",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"EfficientViT_M0(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("efficientvit_m1")
+ class EfficientViTM1(EfficientViTBase):
+     """
+     EfficientViT-M1: ~2.6M backbone parameters.
+
+     Slightly larger mobile variant with better accuracy.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="efficientvit_m1",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"EfficientViT_M1(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("efficientvit_m2")
+ class EfficientViTM2(EfficientViTBase):
+     """
+     EfficientViT-M2: ~3.8M backbone parameters.
+
+     Largest mobile variant, best accuracy among the M-series here.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="efficientvit_m2",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"EfficientViT_M2(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ # =============================================================================
+ # BALANCED VARIANTS (B-series)
+ # =============================================================================
+
+
+ @register_model("efficientvit_b0")
+ class EfficientViTB0(EfficientViTBase):
+     """
+     EfficientViT-B0: ~2.1M backbone parameters.
+
+     Smallest balanced variant. Good accuracy-speed trade-off.
+     2D only.
+
+     Example:
+         >>> model = EfficientViTB0(in_shape=(224, 224), out_size=3)
+         >>> x = torch.randn(4, 1, 224, 224)
+         >>> out = model(x)  # (4, 3)
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="efficientvit_b0",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"EfficientViT_B0(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("efficientvit_b1")
+ class EfficientViTB1(EfficientViTBase):
+     """
+     EfficientViT-B1: ~7.5M backbone parameters.
+
+     Medium balanced variant with improved capacity.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="efficientvit_b1",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"EfficientViT_B1(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("efficientvit_b2")
+ class EfficientViTB2(EfficientViTBase):
+     """
+     EfficientViT-B2: ~21.8M backbone parameters.
+
+     Larger balanced variant for complex patterns.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="efficientvit_b2",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"EfficientViT_B2(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("efficientvit_b3")
+ class EfficientViTB3(EfficientViTBase):
+     """
+     EfficientViT-B3: ~46.1M backbone parameters.
+
+     Largest balanced variant, highest accuracy in the B-series.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="efficientvit_b3",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"EfficientViT_B3(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ # =============================================================================
+ # LARGE VARIANTS (L-series)
+ # =============================================================================
+
+
+ @register_model("efficientvit_l1")
+ class EfficientViTL1(EfficientViTBase):
+     """
+     EfficientViT-L1: ~49.5M backbone parameters.
+
+     Large variant for maximum accuracy.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="efficientvit_l1",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"EfficientViT_L1(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("efficientvit_l2")
+ class EfficientViTL2(EfficientViTBase):
+     """
+     EfficientViT-L2: ~60.5M backbone parameters.
+
+     Largest variant, best accuracy.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="efficientvit_l2",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"EfficientViT_L2(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
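
For orientation, a minimal usage sketch for the wrappers this file adds, following the constructor signature and the docstring examples above. pretrained=False merely skips the ImageNet weight download; freeze_backbone=True is the fine-tuning path wired up in __init__:

    import torch
    from wavedl.models.efficientvit import EfficientViTB0

    # Single-channel 224x224 inputs, 3 regression targets
    model = EfficientViTB0(in_shape=(224, 224), out_size=3, pretrained=False)
    out = model(torch.randn(4, 1, 224, 224))  # -> (4, 3)

    # Transfer learning: freeze the timm backbone, train only the head
    frozen = EfficientViTB0(in_shape=(224, 224), out_size=3, freeze_backbone=True)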
wavedl/models/fastvit.py
@@ -0,0 +1,252 @@
+ """
+ FastViT: A Fast Hybrid Vision Transformer
+ =========================================
+
+ FastViT from Apple uses RepMixer for efficient token mixing with structural
+ reparameterization: train with skip connections, then deploy with them fused
+ away.
+
+ **Key Features**:
+ - RepMixer: Reparameterizable token mixing
+ - Train-time overparameterization
+ - Faster than EfficientNet/ConvNeXt on mobile
+ - CoreML compatible
+
+ **Variants**:
+ - fastvit_t8: 4M params (fastest)
+ - fastvit_t12: 7M params
+ - fastvit_s12: 9M params
+ - fastvit_sa12: 11M params (with attention)
+
+ **Requirements**:
+ - timm >= 0.9.0 (for FastViT models)
+
+ Reference:
+     Vasu, P.K.A., et al. (2023). FastViT: A Fast Hybrid Vision Transformer
+     using Structural Reparameterization. ICCV 2023.
+     https://arxiv.org/abs/2303.14189
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ """
+
+ import torch
+
+ from wavedl.models._pretrained_utils import build_regression_head
+ from wavedl.models.base import BaseModel
+ from wavedl.models.registry import register_model
+
+
+ __all__ = [
+     "FastViTBase",
+     "FastViTS12",
+     "FastViTSA12",
+     "FastViTT8",
+     "FastViTT12",
+ ]
+
+
+ # =============================================================================
+ # FASTVIT BASE CLASS
+ # =============================================================================
+
+
+ class FastViTBase(BaseModel):
+     """
+     FastViT base class wrapping the timm implementation.
+
+     Uses RepMixer for efficient token mixing with reparameterization.
+     2D only.
+
+     Args:
+         in_shape: (H, W) input shape (2D only)
+         out_size: Number of regression targets
+         model_name: timm model name
+         pretrained: Whether to load pretrained weights
+         freeze_backbone: Whether to freeze backbone for fine-tuning
+         dropout_rate: Dropout rate for regression head
+     """
+
+     def __init__(
+         self,
+         in_shape: tuple[int, int],
+         out_size: int,
+         model_name: str = "fastvit_t8",
+         pretrained: bool = True,
+         freeze_backbone: bool = False,
+         dropout_rate: float = 0.3,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         if len(in_shape) != 2:
+             raise ValueError(f"FastViT requires 2D input (H, W), got {len(in_shape)}D")
+
+         self.pretrained = pretrained
+         self.freeze_backbone = freeze_backbone
+         self.model_name = model_name
+
+         # Try to load from timm
+         try:
+             import timm
+
+             self.backbone = timm.create_model(
+                 model_name,
+                 pretrained=pretrained,
+                 num_classes=0,  # Remove classifier
+             )
+
+             # Get feature dimension
+             with torch.no_grad():
+                 dummy = torch.zeros(1, 3, *in_shape)
+                 features = self.backbone(dummy)
+                 in_features = features.shape[-1]
+
+         except ImportError as e:
+             raise ImportError(
+                 "timm >= 0.9.0 is required for FastViT. "
+                 "Install with: pip install 'timm>=0.9.0'"
+             ) from e
+         except Exception as e:
+             raise RuntimeError(
+                 f"Failed to load FastViT model '{model_name}': {e}"
+             ) from e
+
+         # Adapt input channels (3 -> 1)
+         self._adapt_input_channels()
+
+         # Regression head
+         self.head = build_regression_head(in_features, out_size, dropout_rate)
+
+         if freeze_backbone:
+             self._freeze_backbone()
+
+     def _adapt_input_channels(self):
+         """Adapt all conv layers with 3 input channels for single-channel input."""
+         # FastViT may have multiple modules with 3 input channels
+         # (e.g., conv_kxk, conv_scale), so adapt all of them.
+         from wavedl.models._pretrained_utils import find_and_adapt_input_convs
+
+         adapted_count = find_and_adapt_input_convs(
+             self.backbone, pretrained=self.pretrained, adapt_all=True
+         )
+
+         if adapted_count == 0:
+             import warnings
+
+             warnings.warn(
+                 "Could not adapt FastViT input channels. Model may fail.", stacklevel=2
+             )
+
+     def _freeze_backbone(self):
+         """Freeze backbone parameters."""
+         for param in self.backbone.parameters():
+             param.requires_grad = False
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         features = self.backbone(x)
+         return self.head(features)
+
+     def reparameterize(self):
+         """
+         Reparameterize the model for inference.
+
+         Fuses RepMixer blocks for faster inference.
+         Call this before deployment.
+         """
+         if hasattr(self.backbone, "reparameterize"):
+             self.backbone.reparameterize()
+
+
+ # =============================================================================
+ # REGISTERED VARIANTS
+ # =============================================================================
+
+
+ @register_model("fastvit_t8")
+ class FastViTT8(FastViTBase):
+     """
+     FastViT-T8: ~3.3M backbone parameters (fastest variant).
+
+     Optimized for mobile and edge deployment.
+     2D only.
+
+     Example:
+         >>> model = FastViTT8(in_shape=(224, 224), out_size=3)
+         >>> x = torch.randn(4, 1, 224, 224)
+         >>> out = model(x)  # (4, 3)
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="fastvit_t8",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"FastViT_T8(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("fastvit_t12")
+ class FastViTT12(FastViTBase):
+     """
+     FastViT-T12: ~6.5M backbone parameters.
+
+     Balanced speed and accuracy.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="fastvit_t12",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"FastViT_T12(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("fastvit_s12")
+ class FastViTS12(FastViTBase):
+     """
+     FastViT-S12: ~8.5M backbone parameters.
+
+     Slightly larger for better accuracy.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="fastvit_s12",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"FastViT_S12(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("fastvit_sa12")
+ class FastViTSA12(FastViTBase):
+     """
+     FastViT-SA12: ~10.6M backbone parameters.
+
+     Adds self-attention stages for better accuracy at some cost in speed.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="fastvit_sa12",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"FastViT_SA12(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
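
The reparameterize() hook implements the train-then-fuse flow from the FastViT paper: train with the multi-branch RepMixer blocks, then fuse them before deployment. A minimal sketch, assuming the installed timm backbone exposes reparameterize() (the hasattr guard above makes the call a harmless no-op otherwise):

    import torch
    from wavedl.models.fastvit import FastViTT8

    model = FastViTT8(in_shape=(224, 224), out_size=2, pretrained=False)
    # ... train as usual ...
    model.eval()
    model.reparameterize()  # fuse RepMixer branches for faster inference
    with torch.no_grad():
        out = model(torch.randn(1, 1, 224, 224))  # -> (1, 2)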