wavedl-1.5.6-py3-none-any.whl → wavedl-1.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/__init__.py +52 -4
- wavedl/models/_timm_utils.py +238 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/convnext.py +108 -33
- wavedl/models/convnext_v2.py +504 -0
- wavedl/models/densenet.py +5 -5
- wavedl/models/efficientnet.py +30 -13
- wavedl/models/efficientnetv2.py +32 -9
- wavedl/models/fastvit.py +285 -0
- wavedl/models/mamba.py +535 -0
- wavedl/models/maxvit.py +251 -0
- wavedl/models/mobilenetv3.py +35 -12
- wavedl/models/regnet.py +39 -16
- wavedl/models/resnet.py +5 -5
- wavedl/models/resnet3d.py +2 -2
- wavedl/models/swin.py +41 -9
- wavedl/models/tcn.py +25 -5
- wavedl/models/unet.py +1 -1
- wavedl/models/vit.py +6 -6
- wavedl/test.py +7 -3
- wavedl/train.py +57 -23
- wavedl/utils/constraints.py +11 -5
- wavedl/utils/data.py +120 -18
- wavedl/utils/metrics.py +287 -326
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/METADATA +104 -67
- wavedl-1.6.0.dist-info/RECORD +44 -0
- wavedl-1.5.6.dist-info/RECORD +0 -38
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/LICENSE +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/WHEEL +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/models/__init__.py
CHANGED
@@ -6,10 +6,11 @@ This module provides a centralized registry for neural network architectures,
 enabling dynamic model selection via command-line arguments.
 
 **Dimensionality Coverage**:
-- 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, DenseNet, ViT
-- 2D (images): CNN, ResNet, ConvNeXt, DenseNet, ViT, UNet,
-  EfficientNet, MobileNetV3, RegNet, Swin
-
+- 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, Mamba
+- 2D (images): CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, UNet,
+  EfficientNet, MobileNetV3, RegNet, Swin, MaxViT, FastViT,
+  CAFormer, PoolFormer, Vision Mamba
+- 3D (volumes): ResNet3D, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet
 
 Usage:
     from wavedl.models import get_model, list_models, MODEL_REGISTRY
@@ -46,9 +47,19 @@ from .base import BaseModel
 # Import model implementations (triggers registration via decorators)
 from .cnn import CNN
 from .convnext import ConvNeXtBase_, ConvNeXtSmall, ConvNeXtTiny
+
+# New models (v1.6+)
+from .convnext_v2 import (
+    ConvNeXtV2Base,
+    ConvNeXtV2BaseLarge,
+    ConvNeXtV2Small,
+    ConvNeXtV2Tiny,
+    ConvNeXtV2TinyPretrained,
+)
 from .densenet import DenseNet121, DenseNet169
 from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2
 from .efficientnetv2 import EfficientNetV2L, EfficientNetV2M, EfficientNetV2S
+from .mamba import Mamba1D, VimBase, VimSmall, VimTiny
 from .mobilenetv3 import MobileNetV3Large, MobileNetV3Small
 from .registry import (
     MODEL_REGISTRY,
@@ -66,6 +77,17 @@ from .unet import UNetRegression
 from .vit import ViTBase_, ViTSmall, ViTTiny
 
 
+# Optional timm-based models (imported conditionally)
+try:
+    from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
+    from .fastvit import FastViTS12, FastViTSA12, FastViTT8, FastViTT12
+    from .maxvit import MaxViTBaseLarge, MaxViTSmall, MaxViTTiny
+
+    _HAS_TIMM_MODELS = True
+except ImportError:
+    _HAS_TIMM_MODELS = False
+
+
 # Export public API (sorted alphabetically per RUF022)
 # See module docstring for dimensionality support details
 __all__ = [
@@ -77,6 +99,11 @@ __all__ = [
     "ConvNeXtBase_",
     "ConvNeXtSmall",
     "ConvNeXtTiny",
+    "ConvNeXtV2Base",
+    "ConvNeXtV2BaseLarge",
+    "ConvNeXtV2Small",
+    "ConvNeXtV2Tiny",
+    "ConvNeXtV2TinyPretrained",
     "DenseNet121",
     "DenseNet169",
     "EfficientNetB0",
@@ -85,6 +112,7 @@ __all__ = [
     "EfficientNetV2L",
     "EfficientNetV2M",
     "EfficientNetV2S",
+    "Mamba1D",
     "MobileNetV3Large",
     "MobileNetV3Small",
     "RegNetY1_6GF",
@@ -105,8 +133,28 @@ __all__ = [
     "ViTBase_",
     "ViTSmall",
     "ViTTiny",
+    "VimBase",
+    "VimSmall",
+    "VimTiny",
     "build_model",
     "get_model",
     "list_models",
     "register_model",
 ]
+
+# Add timm-based models to __all__ if available
+if _HAS_TIMM_MODELS:
+    __all__.extend(
+        [
+            "CaFormerS18",
+            "CaFormerS36",
+            "FastViTS12",
+            "FastViTSA12",
+            "FastViTT8",
+            "FastViTT12",
+            "MaxViTBaseLarge",
+            "MaxViTSmall",
+            "MaxViTTiny",
+            "PoolFormerS12",
+        ]
+    )
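For context, a minimal usage sketch of the new registry entries. The registry names ("caformer_s18", "poolformer_s12") and the (in_shape, out_size) constructors are confirmed by this diff; that MODEL_REGISTRY maps names to classes is an assumption inferred from the module docstring and not shown in the diff.

import torch

from wavedl.models import MODEL_REGISTRY, list_models

print(list_models())  # timm-backed names appear only when timm is installed

if "caformer_s18" in MODEL_REGISTRY:  # assumes a dict-like name -> class registry
    ModelCls = MODEL_REGISTRY["caformer_s18"]
    # pretrained=False avoids a weight download; kwarg confirmed in caformer.py below
    model = ModelCls(in_shape=(224, 224), out_size=3, pretrained=False)
    out = model(torch.randn(2, 1, 224, 224))  # expected shape: (2, 3)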
wavedl/models/_timm_utils.py
NEW
@@ -0,0 +1,238 @@
"""
Shared Utilities for Model Architectures
=========================================

Common components used across multiple models:
- GRN (Global Response Normalization) for ConvNeXt V2
- Dimension-agnostic layer factories
- Regression head builders
- Input channel adaptation for pretrained models

Author: Ductho Le (ductho.le@outlook.com)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


# =============================================================================
# DIMENSION-AGNOSTIC LAYER FACTORIES
# =============================================================================


def get_conv_layer(dim: int) -> type[nn.Module]:
    """Get dimension-appropriate Conv class."""
    layers = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dim not in layers:
        raise ValueError(f"Unsupported dimension: {dim}")
    return layers[dim]


def get_norm_layer(dim: int) -> type[nn.Module]:
    """Get dimension-appropriate BatchNorm class."""
    layers = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}
    if dim not in layers:
        raise ValueError(f"Unsupported dimension: {dim}")
    return layers[dim]


def get_pool_layer(dim: int) -> type[nn.Module]:
    """Get dimension-appropriate AdaptiveAvgPool class."""
    layers = {1: nn.AdaptiveAvgPool1d, 2: nn.AdaptiveAvgPool2d, 3: nn.AdaptiveAvgPool3d}
    if dim not in layers:
        raise ValueError(f"Unsupported dimension: {dim}")
    return layers[dim]


# =============================================================================
# GLOBAL RESPONSE NORMALIZATION (GRN) - ConvNeXt V2
# =============================================================================


class GRN1d(nn.Module):
    """
    Global Response Normalization for 1D inputs.

    GRN enhances inter-channel feature competition and promotes diversity.
    Replaces LayerScale in ConvNeXt V2.

    Reference: ConvNeXt V2 (CVPR 2023)
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, dim, 1))
        self.beta = nn.Parameter(torch.zeros(1, dim, 1))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, L)
        Gx = torch.norm(x, p=2, dim=2, keepdim=True)  # (B, C, 1)
        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps)  # (B, C, 1)
        return self.gamma * (x * Nx) + self.beta + x


class GRN2d(nn.Module):
    """
    Global Response Normalization for 2D inputs.

    GRN enhances inter-channel feature competition and promotes diversity.
    Replaces LayerScale in ConvNeXt V2.

    Reference: ConvNeXt V2 (CVPR 2023)
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W)
        Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps)  # (B, C, 1, 1)
        return self.gamma * (x * Nx) + self.beta + x


class GRN3d(nn.Module):
    """
    Global Response Normalization for 3D inputs.

    GRN enhances inter-channel feature competition and promotes diversity.
    Replaces LayerScale in ConvNeXt V2.

    Reference: ConvNeXt V2 (CVPR 2023)
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, D, H, W)
        Gx = torch.norm(x, p=2, dim=(2, 3, 4), keepdim=True)  # (B, C, 1, 1, 1)
        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps)  # (B, C, 1, 1, 1)
        return self.gamma * (x * Nx) + self.beta + x


def get_grn_layer(dim: int) -> type[nn.Module]:
    """Get dimension-appropriate GRN class."""
    layers = {1: GRN1d, 2: GRN2d, 3: GRN3d}
    if dim not in layers:
        raise ValueError(f"Unsupported dimension: {dim}")
    return layers[dim]


# =============================================================================
# LAYER NORMALIZATION (Channels Last for CNNs)
# =============================================================================


class LayerNormNd(nn.Module):
    """
    LayerNorm that works with channels-first tensors of any dimension.
    Applies normalization over the channel dimension.
    """

    def __init__(self, normalized_shape: int, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.dim = dim
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move channels to last, apply LN, move back
        if self.dim == 1:
            # (B, C, L) -> (B, L, C) -> LN -> (B, C, L)
            x = x.permute(0, 2, 1)
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
            x = x.permute(0, 2, 1)
        elif self.dim == 2:
            # (B, C, H, W) -> (B, H, W, C) -> LN -> (B, C, H, W)
            x = x.permute(0, 2, 3, 1)
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
            x = x.permute(0, 3, 1, 2)
        elif self.dim == 3:
            # (B, C, D, H, W) -> (B, D, H, W, C) -> LN -> (B, C, D, H, W)
            x = x.permute(0, 2, 3, 4, 1)
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
            x = x.permute(0, 4, 1, 2, 3)
        return x


# =============================================================================
# REGRESSION HEAD BUILDERS
# =============================================================================


def build_regression_head(
    in_features: int,
    out_size: int,
    dropout_rate: float = 0.3,
    hidden_dim: int = 512,
) -> nn.Sequential:
    """
    Build a standard regression head for pretrained models.

    Args:
        in_features: Input feature dimension
        out_size: Number of regression targets
        dropout_rate: Dropout rate
        hidden_dim: Hidden layer dimension

    Returns:
        nn.Sequential regression head
    """
    return nn.Sequential(
        nn.Dropout(dropout_rate),
        nn.Linear(in_features, hidden_dim),
        nn.SiLU(inplace=True),
        nn.Dropout(dropout_rate * 0.5),
        nn.Linear(hidden_dim, hidden_dim // 2),
        nn.SiLU(inplace=True),
        nn.Linear(hidden_dim // 2, out_size),
    )


def adapt_input_channels(
    conv_layer: nn.Module,
    new_in_channels: int = 1,
    pretrained: bool = True,
) -> nn.Module:
    """
    Adapt a convolutional layer for different input channels.

    For pretrained models, averages RGB weights to grayscale.

    Args:
        conv_layer: Original conv layer (expects 3 input channels)
        new_in_channels: New number of input channels (default: 1)
        pretrained: Whether to adapt pretrained weights

    Returns:
        New conv layer with adapted input channels
    """
    if isinstance(conv_layer, nn.Conv2d):
        new_conv = nn.Conv2d(
            new_in_channels,
            conv_layer.out_channels,
            kernel_size=conv_layer.kernel_size,
            stride=conv_layer.stride,
            padding=conv_layer.padding,
            bias=conv_layer.bias is not None,
        )
        if pretrained and conv_layer.in_channels == 3:
            with torch.no_grad():
                # Average RGB weights
                new_conv.weight.copy_(conv_layer.weight.mean(dim=1, keepdim=True))
                if conv_layer.bias is not None:
                    new_conv.bias.copy_(conv_layer.bias)
        return new_conv
    else:
        raise NotImplementedError(f"Unsupported layer type: {type(conv_layer)}")
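A quick sanity-check sketch for two of the utilities above, assuming only that wavedl is importable and behaves as the code in this diff shows.

import torch
import torch.nn as nn

from wavedl.models._timm_utils import GRN2d, adapt_input_channels

# GRN initializes gamma = beta = 0, so gamma * (x * Nx) + beta + x == x:
# the block starts as an identity and learns its recalibration during training.
x = torch.randn(2, 8, 16, 16)
grn = GRN2d(dim=8)
assert torch.allclose(grn(x), x)

# adapt_input_channels collapses an ImageNet-style RGB stem to one channel
# by averaging the pretrained kernels over the input-channel axis.
rgb_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
gray_conv = adapt_input_channels(rgb_conv, new_in_channels=1, pretrained=True)
assert gray_conv.weight.shape == (16, 1, 3, 3)
assert torch.allclose(gray_conv.weight, rgb_conv.weight.mean(dim=1, keepdim=True))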
wavedl/models/caformer.py
NEW
@@ -0,0 +1,270 @@
"""
CaFormer: MetaFormer with Convolution and Attention
====================================================

CaFormer implements the MetaFormer architecture using depthwise separable
convolutions in early stages and vanilla self-attention in later stages.

**Key Features**:
- MetaFormer principle: architecture > token mixer
- Hybrid: Conv (early) + Attention (late)
- StarReLU activation for efficiency
- State-of-the-art on ImageNet (85.5%)

**Variants**:
- caformer_s18: 26M params
- caformer_s36: 39M params
- caformer_m36: 56M params

**Related Models**:
- PoolFormer: Uses pooling instead of attention
- ConvFormer: Uses only convolutions

**Requirements**:
- timm >= 0.9.0 (for CaFormer models)

Reference:
    Yu, W., et al. (2023). MetaFormer Baselines for Vision.
    TPAMI 2023. https://arxiv.org/abs/2210.13452

Author: Ductho Le (ductho.le@outlook.com)
"""

import torch
import torch.nn as nn

from wavedl.models._timm_utils import build_regression_head
from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


__all__ = [
    "CaFormerBase",
    "CaFormerM36",
    "CaFormerS18",
    "CaFormerS36",
    "PoolFormerS12",
]


# =============================================================================
# CAFORMER BASE CLASS
# =============================================================================


class CaFormerBase(BaseModel):
    """
    CaFormer base class wrapping timm implementation.

    MetaFormer with conv (early) + attention (late) token mixing.
    2D only.
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        model_name: str = "caformer_s18",
        pretrained: bool = True,
        freeze_backbone: bool = False,
        dropout_rate: float = 0.3,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 2:
            raise ValueError(f"CaFormer requires 2D input (H, W), got {len(in_shape)}D")

        self.pretrained = pretrained
        self.freeze_backbone = freeze_backbone
        self.model_name = model_name

        # Try to load from timm
        try:
            import timm

            self.backbone = timm.create_model(
                model_name,
                pretrained=pretrained,
                num_classes=0,  # Remove classifier
            )

            # Get feature dimension
            with torch.no_grad():
                dummy = torch.zeros(1, 3, *in_shape)
                features = self.backbone(dummy)
                in_features = features.shape[-1]

        except ImportError:
            raise ImportError(
                "timm >= 0.9.0 is required for CaFormer. "
                "Install with: pip install timm>=0.9.0"
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load CaFormer model '{model_name}': {e}")

        # Adapt input channels (3 -> 1)
        self._adapt_input_channels()

        # Regression head
        self.head = build_regression_head(in_features, out_size, dropout_rate)

        if freeze_backbone:
            self._freeze_backbone()

    def _adapt_input_channels(self):
        """Adapt first conv layer for single-channel input."""
        # CaFormer uses stem for first layer
        if hasattr(self.backbone, "stem"):
            first_conv = None
            # Find first conv in stem
            for name, module in self.backbone.stem.named_modules():
                if isinstance(module, nn.Conv2d):
                    first_conv = (name, module)
                    break

            if first_conv is not None:
                name, old_conv = first_conv
                new_conv = self._make_new_conv(old_conv)
                # Set the new conv (handle nested structure)
                self._set_module(self.backbone.stem, name, new_conv)

    def _make_new_conv(self, old_conv: nn.Conv2d) -> nn.Conv2d:
        """Create new conv layer with 1 input channel."""
        new_conv = nn.Conv2d(
            1,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=old_conv.bias is not None,
        )
        if self.pretrained:
            with torch.no_grad():
                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
                if old_conv.bias is not None:
                    new_conv.bias.copy_(old_conv.bias)
        return new_conv

    def _set_module(self, parent: nn.Module, name: str, module: nn.Module):
        """Set a nested module by name."""
        parts = name.split(".")
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], module)

    def _freeze_backbone(self):
        """Freeze backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        return self.head(features)


# =============================================================================
# REGISTERED VARIANTS
# =============================================================================


@register_model("caformer_s18")
class CaFormerS18(CaFormerBase):
    """
    CaFormer-S18: ~23.2M backbone parameters.

    MetaFormer with conv + attention.
    2D only.

    Example:
        >>> model = CaFormerS18(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="caformer_s18",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"CaFormer_S18(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("caformer_s36")
class CaFormerS36(CaFormerBase):
    """
    CaFormer-S36: ~36.2M backbone parameters.

    Deeper MetaFormer variant.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="caformer_s36",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"CaFormer_S36(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("caformer_m36")
class CaFormerM36(CaFormerBase):
    """
    CaFormer-M36: ~52.6M backbone parameters.

    Medium-size MetaFormer variant.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="caformer_m36",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"CaFormer_M36(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("poolformer_s12")
class PoolFormerS12(CaFormerBase):
    """
    PoolFormer-S12: ~11.4M backbone parameters.

    MetaFormer with simple pooling token mixer.
    Proves that architecture matters more than complex attention.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="poolformer_s12",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"PoolFormer_S12(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )