wavedl 1.5.6__py3-none-any.whl → 1.6.0__py3-none-any.whl

wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
  """

- __version__ = "1.5.6"
+ __version__ = "1.6.0"
  __author__ = "Ductho Le"
  __email__ = "ductho.le@outlook.com"

wavedl/models/__init__.py CHANGED
@@ -6,10 +6,11 @@ This module provides a centralized registry for neural network architectures,
  enabling dynamic model selection via command-line arguments.

  **Dimensionality Coverage**:
- - 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, DenseNet, ViT
- - 2D (images): CNN, ResNet, ConvNeXt, DenseNet, ViT, UNet,
-      EfficientNet, MobileNetV3, RegNet, Swin
- - 3D (volumes): ResNet3D, CNN, ResNet, ConvNeXt, DenseNet
+ - 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, Mamba
+ - 2D (images): CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, UNet,
+      EfficientNet, MobileNetV3, RegNet, Swin, MaxViT, FastViT,
+      CAFormer, PoolFormer, Vision Mamba
+ - 3D (volumes): ResNet3D, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet

  Usage:
      from wavedl.models import get_model, list_models, MODEL_REGISTRY
@@ -46,9 +47,19 @@ from .base import BaseModel
  # Import model implementations (triggers registration via decorators)
  from .cnn import CNN
  from .convnext import ConvNeXtBase_, ConvNeXtSmall, ConvNeXtTiny
+
+ # New models (v1.6+)
+ from .convnext_v2 import (
+     ConvNeXtV2Base,
+     ConvNeXtV2BaseLarge,
+     ConvNeXtV2Small,
+     ConvNeXtV2Tiny,
+     ConvNeXtV2TinyPretrained,
+ )
  from .densenet import DenseNet121, DenseNet169
  from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2
  from .efficientnetv2 import EfficientNetV2L, EfficientNetV2M, EfficientNetV2S
+ from .mamba import Mamba1D, VimBase, VimSmall, VimTiny
  from .mobilenetv3 import MobileNetV3Large, MobileNetV3Small
  from .registry import (
      MODEL_REGISTRY,
@@ -66,6 +77,17 @@ from .unet import UNetRegression
  from .vit import ViTBase_, ViTSmall, ViTTiny


+ # Optional timm-based models (imported conditionally)
+ try:
+     from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
+     from .fastvit import FastViTS12, FastViTSA12, FastViTT8, FastViTT12
+     from .maxvit import MaxViTBaseLarge, MaxViTSmall, MaxViTTiny
+
+     _HAS_TIMM_MODELS = True
+ except ImportError:
+     _HAS_TIMM_MODELS = False
+
+
  # Export public API (sorted alphabetically per RUF022)
  # See module docstring for dimensionality support details
  __all__ = [
@@ -77,6 +99,11 @@ __all__ = [
      "ConvNeXtBase_",
      "ConvNeXtSmall",
      "ConvNeXtTiny",
+     "ConvNeXtV2Base",
+     "ConvNeXtV2BaseLarge",
+     "ConvNeXtV2Small",
+     "ConvNeXtV2Tiny",
+     "ConvNeXtV2TinyPretrained",
      "DenseNet121",
      "DenseNet169",
      "EfficientNetB0",
@@ -85,6 +112,7 @@ __all__ = [
      "EfficientNetV2L",
      "EfficientNetV2M",
      "EfficientNetV2S",
+     "Mamba1D",
      "MobileNetV3Large",
      "MobileNetV3Small",
      "RegNetY1_6GF",
@@ -105,8 +133,28 @@ __all__ = [
      "ViTBase_",
      "ViTSmall",
      "ViTTiny",
+     "VimBase",
+     "VimSmall",
+     "VimTiny",
      "build_model",
      "get_model",
      "list_models",
      "register_model",
  ]
+
+ # Add timm-based models to __all__ if available
+ if _HAS_TIMM_MODELS:
+     __all__.extend(
+         [
+             "CaFormerS18",
+             "CaFormerS36",
+             "FastViTS12",
+             "FastViTSA12",
+             "FastViTT8",
+             "FastViTT12",
+             "MaxViTBaseLarge",
+             "MaxViTSmall",
+             "MaxViTTiny",
+             "PoolFormerS12",
+         ]
+     )
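
For context, the registry pattern above makes every model addressable by name. A minimal sketch of how the registry is typically consumed; the exact return convention of get_model (class vs. built instance) is an assumption, not confirmed by this diff:

    from wavedl.models import MODEL_REGISTRY, get_model, list_models

    print(list_models())  # lists "caformer_s18" etc. only when timm is importable

    # Assumed convention: get_model("name") returns the registered class,
    # which is then constructed with (in_shape, out_size) as in the docstrings below.
    ModelCls = get_model("caformer_s18")
    model = ModelCls(in_shape=(224, 224), out_size=3, pretrained=False)
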
wavedl/models/_timm_utils.py ADDED
@@ -0,0 +1,238 @@
+ """
+ Shared Utilities for Model Architectures
+ ========================================
+
+ Common components used across multiple models:
+ - GRN (Global Response Normalization) for ConvNeXt V2
+ - Dimension-agnostic layer factories
+ - Regression head builders
+ - Input channel adaptation for pretrained models
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # =============================================================================
+ # DIMENSION-AGNOSTIC LAYER FACTORIES
+ # =============================================================================
+
+
+ def get_conv_layer(dim: int) -> type[nn.Module]:
+     """Get dimension-appropriate Conv class."""
+     layers = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
+     if dim not in layers:
+         raise ValueError(f"Unsupported dimension: {dim}")
+     return layers[dim]
+
+
+ def get_norm_layer(dim: int) -> type[nn.Module]:
+     """Get dimension-appropriate BatchNorm class."""
+     layers = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}
+     if dim not in layers:
+         raise ValueError(f"Unsupported dimension: {dim}")
+     return layers[dim]
+
+
+ def get_pool_layer(dim: int) -> type[nn.Module]:
+     """Get dimension-appropriate AdaptiveAvgPool class."""
+     layers = {1: nn.AdaptiveAvgPool1d, 2: nn.AdaptiveAvgPool2d, 3: nn.AdaptiveAvgPool3d}
+     if dim not in layers:
+         raise ValueError(f"Unsupported dimension: {dim}")
+     return layers[dim]
+
+
+ # =============================================================================
+ # GLOBAL RESPONSE NORMALIZATION (GRN) - ConvNeXt V2
+ # =============================================================================
+
+
+ class GRN1d(nn.Module):
+     """
+     Global Response Normalization for 1D inputs.
+
+     GRN enhances inter-channel feature competition and promotes diversity.
+     Replaces LayerScale in ConvNeXt V2.
+
+     Reference: ConvNeXt V2 (CVPR 2023)
+     """
+
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.gamma = nn.Parameter(torch.zeros(1, dim, 1))
+         self.beta = nn.Parameter(torch.zeros(1, dim, 1))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: (B, C, L)
+         Gx = torch.norm(x, p=2, dim=2, keepdim=True)  # (B, C, 1)
+         Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps)  # (B, C, 1)
+         return self.gamma * (x * Nx) + self.beta + x
+
+
+ class GRN2d(nn.Module):
+     """
+     Global Response Normalization for 2D inputs.
+
+     GRN enhances inter-channel feature competition and promotes diversity.
+     Replaces LayerScale in ConvNeXt V2.
+
+     Reference: ConvNeXt V2 (CVPR 2023)
+     """
+
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
+         self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: (B, C, H, W)
+         Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
+         Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps)  # (B, C, 1, 1)
+         return self.gamma * (x * Nx) + self.beta + x
+
+
+ class GRN3d(nn.Module):
+     """
+     Global Response Normalization for 3D inputs.
+
+     GRN enhances inter-channel feature competition and promotes diversity.
+     Replaces LayerScale in ConvNeXt V2.
+
+     Reference: ConvNeXt V2 (CVPR 2023)
+     """
+
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))
+         self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: (B, C, D, H, W)
+         Gx = torch.norm(x, p=2, dim=(2, 3, 4), keepdim=True)  # (B, C, 1, 1, 1)
+         Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps)  # (B, C, 1, 1, 1)
+         return self.gamma * (x * Nx) + self.beta + x
+
+
+ def get_grn_layer(dim: int) -> type[nn.Module]:
+     """Get dimension-appropriate GRN class."""
+     layers = {1: GRN1d, 2: GRN2d, 3: GRN3d}
+     if dim not in layers:
+         raise ValueError(f"Unsupported dimension: {dim}")
+     return layers[dim]
+
+
+ # =============================================================================
+ # LAYER NORMALIZATION (Channels Last for CNNs)
+ # =============================================================================
+
+
+ class LayerNormNd(nn.Module):
+     """
+     LayerNorm that works with channels-first tensors of any dimension.
+     Applies normalization over the channel dimension.
+     """
+
+     def __init__(self, normalized_shape: int, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.eps = eps
+         self.dim = dim
+         self.normalized_shape = (normalized_shape,)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Move channels to last, apply LN, move back
+         if self.dim == 1:
+             # (B, C, L) -> (B, L, C) -> LN -> (B, C, L)
+             x = x.permute(0, 2, 1)
+             x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+             x = x.permute(0, 2, 1)
+         elif self.dim == 2:
+             # (B, C, H, W) -> (B, H, W, C) -> LN -> (B, C, H, W)
+             x = x.permute(0, 2, 3, 1)
+             x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+             x = x.permute(0, 3, 1, 2)
+         elif self.dim == 3:
+             # (B, C, D, H, W) -> (B, D, H, W, C) -> LN -> (B, C, D, H, W)
+             x = x.permute(0, 2, 3, 4, 1)
+             x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+             x = x.permute(0, 4, 1, 2, 3)
+         return x
+
+
+ # =============================================================================
+ # REGRESSION HEAD BUILDERS
+ # =============================================================================
+
+
+ def build_regression_head(
+     in_features: int,
+     out_size: int,
+     dropout_rate: float = 0.3,
+     hidden_dim: int = 512,
+ ) -> nn.Sequential:
+     """
+     Build a standard regression head for pretrained models.
+
+     Args:
+         in_features: Input feature dimension
+         out_size: Number of regression targets
+         dropout_rate: Dropout rate
+         hidden_dim: Hidden layer dimension
+
+     Returns:
+         nn.Sequential regression head
+     """
+     return nn.Sequential(
+         nn.Dropout(dropout_rate),
+         nn.Linear(in_features, hidden_dim),
+         nn.SiLU(inplace=True),
+         nn.Dropout(dropout_rate * 0.5),
+         nn.Linear(hidden_dim, hidden_dim // 2),
+         nn.SiLU(inplace=True),
+         nn.Linear(hidden_dim // 2, out_size),
+     )
+
+
+ def adapt_input_channels(
+     conv_layer: nn.Module,
+     new_in_channels: int = 1,
+     pretrained: bool = True,
+ ) -> nn.Module:
+     """
+     Adapt a convolutional layer for different input channels.
+
+     For pretrained models, averages RGB weights to grayscale.
+
+     Args:
+         conv_layer: Original conv layer (expects 3 input channels)
+         new_in_channels: New number of input channels (default: 1)
+         pretrained: Whether to adapt pretrained weights
+
+     Returns:
+         New conv layer with adapted input channels
+     """
+     if isinstance(conv_layer, nn.Conv2d):
+         new_conv = nn.Conv2d(
+             new_in_channels,
+             conv_layer.out_channels,
+             kernel_size=conv_layer.kernel_size,
+             stride=conv_layer.stride,
+             padding=conv_layer.padding,
+             bias=conv_layer.bias is not None,
+         )
+         if pretrained and conv_layer.in_channels == 3:
+             with torch.no_grad():
+                 # Average RGB weights
+                 new_conv.weight.copy_(conv_layer.weight.mean(dim=1, keepdim=True))
+                 if conv_layer.bias is not None:
+                     new_conv.bias.copy_(conv_layer.bias)
+         return new_conv
+     else:
+         raise NotImplementedError(f"Unsupported layer type: {type(conv_layer)}")
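
The utilities above are pure PyTorch, so their contracts can be smoke-tested in isolation. A minimal sketch follows; the wavedl.models._timm_utils import path is inferred from the caformer.py import below, and note that GRN's gamma/beta start at zero, so a freshly initialized GRN reduces to the identity:

    import torch
    import torch.nn as nn
    from wavedl.models._timm_utils import (  # path inferred from caformer.py below
        GRN2d,
        LayerNormNd,
        adapt_input_channels,
        build_regression_head,
    )

    x = torch.randn(4, 64, 32, 32)                     # (B, C, H, W)
    assert GRN2d(64)(x).shape == x.shape               # GRN is shape-preserving
    assert LayerNormNd(64, dim=2)(x).shape == x.shape  # channels-first LayerNorm

    head = build_regression_head(in_features=512, out_size=3)
    assert head(torch.randn(4, 512)).shape == (4, 3)

    # Collapse a 3-channel stem to single-channel input (random init here)
    stem = nn.Conv2d(3, 96, kernel_size=4, stride=4)
    gray = adapt_input_channels(stem, new_in_channels=1, pretrained=False)
    assert gray(torch.randn(1, 1, 224, 224)).shape == (1, 96, 56, 56)
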
wavedl/models/caformer.py ADDED
@@ -0,0 +1,270 @@
+ """
+ CaFormer: MetaFormer with Convolution and Attention
+ ====================================================
+
+ CaFormer implements the MetaFormer architecture using depthwise separable
+ convolutions in early stages and vanilla self-attention in later stages.
+
+ **Key Features**:
+ - MetaFormer principle: architecture > token mixer
+ - Hybrid: Conv (early) + Attention (late)
+ - StarReLU activation for efficiency
+ - State-of-the-art 85.5% top-1 on ImageNet-1K at release
+
+ **Variants**:
+ - caformer_s18: 26M params
+ - caformer_s36: 39M params
+ - caformer_m36: 56M params
+
+ **Related Models**:
+ - PoolFormer: Uses pooling instead of attention
+ - ConvFormer: Uses only convolutions
+
+ **Requirements**:
+ - timm >= 0.9.0 (for CaFormer models)
+
+ Reference:
+     Yu, W., et al. (2023). MetaFormer Baselines for Vision.
+     TPAMI 2023. https://arxiv.org/abs/2210.13452
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ """
+
+ import torch
+ import torch.nn as nn
+
+ from wavedl.models._timm_utils import build_regression_head
+ from wavedl.models.base import BaseModel
+ from wavedl.models.registry import register_model
+
+
+ __all__ = [
+     "CaFormerBase",
+     "CaFormerM36",
+     "CaFormerS18",
+     "CaFormerS36",
+     "PoolFormerS12",
+ ]
+
+
+ # =============================================================================
+ # CAFORMER BASE CLASS
+ # =============================================================================
+
+
+ class CaFormerBase(BaseModel):
+     """
+     CaFormer base class wrapping the timm implementation.
+
+     MetaFormer with conv (early) + attention (late) token mixing.
+     2D only.
+     """
+
+     def __init__(
+         self,
+         in_shape: tuple[int, int],
+         out_size: int,
+         model_name: str = "caformer_s18",
+         pretrained: bool = True,
+         freeze_backbone: bool = False,
+         dropout_rate: float = 0.3,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         if len(in_shape) != 2:
+             raise ValueError(f"CaFormer requires 2D input (H, W), got {len(in_shape)}D")
+
+         self.pretrained = pretrained
+         self.freeze_backbone = freeze_backbone
+         self.model_name = model_name
+
+         # Try to load from timm
+         try:
+             import timm
+
+             self.backbone = timm.create_model(
+                 model_name,
+                 pretrained=pretrained,
+                 num_classes=0,  # Remove classifier
+             )
+
+             # Get feature dimension (probe with 3 channels before stem adaptation)
+             with torch.no_grad():
+                 dummy = torch.zeros(1, 3, *in_shape)
+                 features = self.backbone(dummy)
+                 in_features = features.shape[-1]
+
+         except ImportError:
+             raise ImportError(
+                 "timm >= 0.9.0 is required for CaFormer. "
+                 "Install with: pip install timm>=0.9.0"
+             )
+         except Exception as e:
+             raise RuntimeError(f"Failed to load CaFormer model '{model_name}': {e}")
+
+         # Adapt input channels (3 -> 1)
+         self._adapt_input_channels()
+
+         # Regression head
+         self.head = build_regression_head(in_features, out_size, dropout_rate)
+
+         if freeze_backbone:
+             self._freeze_backbone()
+
+     def _adapt_input_channels(self):
+         """Adapt first conv layer for single-channel input."""
+         # CaFormer uses stem for first layer
+         if hasattr(self.backbone, "stem"):
+             first_conv = None
+             # Find first conv in stem
+             for name, module in self.backbone.stem.named_modules():
+                 if isinstance(module, nn.Conv2d):
+                     first_conv = (name, module)
+                     break
+
+             if first_conv is not None:
+                 name, old_conv = first_conv
+                 new_conv = self._make_new_conv(old_conv)
+                 # Set the new conv (handle nested structure)
+                 self._set_module(self.backbone.stem, name, new_conv)
+
+     def _make_new_conv(self, old_conv: nn.Conv2d) -> nn.Conv2d:
+         """Create new conv layer with 1 input channel."""
+         new_conv = nn.Conv2d(
+             1,
+             old_conv.out_channels,
+             kernel_size=old_conv.kernel_size,
+             stride=old_conv.stride,
+             padding=old_conv.padding,
+             bias=old_conv.bias is not None,
+         )
+         if self.pretrained:
+             with torch.no_grad():
+                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                 if old_conv.bias is not None:
+                     new_conv.bias.copy_(old_conv.bias)
+         return new_conv
+
+     def _set_module(self, parent: nn.Module, name: str, module: nn.Module):
+         """Set a nested module by name."""
+         parts = name.split(".")
+         for part in parts[:-1]:
+             parent = getattr(parent, part)
+         setattr(parent, parts[-1], module)
+
+     def _freeze_backbone(self):
+         """Freeze backbone parameters."""
+         for param in self.backbone.parameters():
+             param.requires_grad = False
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         features = self.backbone(x)
+         return self.head(features)
+
+
+ # =============================================================================
+ # REGISTERED VARIANTS
+ # =============================================================================
+
+
+ @register_model("caformer_s18")
+ class CaFormerS18(CaFormerBase):
+     """
+     CaFormer-S18: ~23.2M backbone parameters.
+
+     MetaFormer with conv + attention.
+     2D only.
+
+     Example:
+         >>> model = CaFormerS18(in_shape=(224, 224), out_size=3)
+         >>> x = torch.randn(4, 1, 224, 224)
+         >>> out = model(x)  # (4, 3)
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="caformer_s18",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"CaFormer_S18(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("caformer_s36")
+ class CaFormerS36(CaFormerBase):
+     """
+     CaFormer-S36: ~36.2M backbone parameters.
+
+     Deeper MetaFormer variant.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="caformer_s36",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"CaFormer_S36(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("caformer_m36")
+ class CaFormerM36(CaFormerBase):
+     """
+     CaFormer-M36: ~52.6M backbone parameters.
+
+     Medium-size MetaFormer variant.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="caformer_m36",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"CaFormer_M36(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("poolformer_s12")
+ class PoolFormerS12(CaFormerBase):
+     """
+     PoolFormer-S12: ~11.4M backbone parameters.
+
+     MetaFormer with simple pooling token mixer.
+     Shows that the overall architecture matters more than complex attention.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="poolformer_s12",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"PoolFormer_S12(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
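
Taken together, the new module plugs into the existing flow like any other registered model. A minimal end-to-end sketch, assuming timm >= 0.9.0 is installed (without it, the conditional import in wavedl/models/__init__.py leaves these names undefined); pretrained=False only skips the weight download, timm itself is still required:

    import torch

    try:
        from wavedl.models import CaFormerS18
    except ImportError:  # timm missing: the optional import above did not expose it
        CaFormerS18 = None

    if CaFormerS18 is not None:
        model = CaFormerS18(
            in_shape=(224, 224),
            out_size=3,
            pretrained=False,   # random init; avoids the checkpoint download
            freeze_backbone=False,
            dropout_rate=0.3,
        )
        x = torch.randn(4, 1, 224, 224)  # single-channel input (stem adapted 3 -> 1)
        assert model(x).shape == (4, 3)
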