wavedl 1.5.7__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +80 -4
- wavedl/models/_pretrained_utils.py +366 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +113 -51
- wavedl/models/convnext_v2.py +488 -0
- wavedl/models/densenet.py +10 -23
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +252 -0
- wavedl/models/mamba.py +555 -0
- wavedl/models/maxvit.py +254 -0
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +19 -61
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +2 -6
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +9 -9
- wavedl/train.py +1430 -1425
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/data.py +39 -6
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/__init__.py
CHANGED
@@ -6,10 +6,11 @@ This module provides a centralized registry for neural network architectures,
 enabling dynamic model selection via command-line arguments.
 
 **Dimensionality Coverage**:
-- 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, DenseNet, ViT
-- 2D (images): CNN, ResNet, ConvNeXt, DenseNet, ViT, UNet,
-  EfficientNet, MobileNetV3, RegNet, Swin
-
+- 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, Mamba
+- 2D (images): CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, UNet,
+  EfficientNet, MobileNetV3, RegNet, Swin, MaxViT, FastViT,
+  CAFormer, PoolFormer, Vision Mamba
+- 3D (volumes): ResNet3D, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet
 
 Usage:
     from wavedl.models import get_model, list_models, MODEL_REGISTRY

@@ -46,9 +47,19 @@ from .base import BaseModel
 # Import model implementations (triggers registration via decorators)
 from .cnn import CNN
 from .convnext import ConvNeXtBase_, ConvNeXtSmall, ConvNeXtTiny
+
+# New models (v1.6+)
+from .convnext_v2 import (
+    ConvNeXtV2Base,
+    ConvNeXtV2BaseLarge,
+    ConvNeXtV2Small,
+    ConvNeXtV2Tiny,
+    ConvNeXtV2TinyPretrained,
+)
 from .densenet import DenseNet121, DenseNet169
 from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2
 from .efficientnetv2 import EfficientNetV2L, EfficientNetV2M, EfficientNetV2S
+from .mamba import Mamba1D, VimBase, VimSmall, VimTiny
 from .mobilenetv3 import MobileNetV3Large, MobileNetV3Small
 from .registry import (
     MODEL_REGISTRY,

@@ -66,6 +77,33 @@ from .unet import UNetRegression
 from .vit import ViTBase_, ViTSmall, ViTTiny
 
 
+# Optional timm-based models (imported conditionally)
+try:
+    from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
+    from .efficientvit import (
+        EfficientViTB0,
+        EfficientViTB1,
+        EfficientViTB2,
+        EfficientViTB3,
+        EfficientViTL1,
+        EfficientViTL2,
+        EfficientViTM0,
+        EfficientViTM1,
+        EfficientViTM2,
+    )
+    from .fastvit import FastViTS12, FastViTSA12, FastViTT8, FastViTT12
+    from .maxvit import MaxViTBaseLarge, MaxViTSmall, MaxViTTiny
+    from .unireplknet import (
+        UniRepLKNetBaseLarge,
+        UniRepLKNetSmall,
+        UniRepLKNetTiny,
+    )
+
+    _HAS_TIMM_MODELS = True
+except ImportError:
+    _HAS_TIMM_MODELS = False
+
+
 # Export public API (sorted alphabetically per RUF022)
 # See module docstring for dimensionality support details
 __all__ = [

@@ -77,6 +115,11 @@ __all__ = [
     "ConvNeXtBase_",
     "ConvNeXtSmall",
     "ConvNeXtTiny",
+    "ConvNeXtV2Base",
+    "ConvNeXtV2BaseLarge",
+    "ConvNeXtV2Small",
+    "ConvNeXtV2Tiny",
+    "ConvNeXtV2TinyPretrained",
     "DenseNet121",
     "DenseNet169",
     "EfficientNetB0",

@@ -85,6 +128,7 @@ __all__ = [
     "EfficientNetV2L",
     "EfficientNetV2M",
     "EfficientNetV2S",
+    "Mamba1D",
     "MobileNetV3Large",
     "MobileNetV3Small",
     "RegNetY1_6GF",

@@ -105,8 +149,40 @@ __all__ = [
     "ViTBase_",
     "ViTSmall",
     "ViTTiny",
+    "VimBase",
+    "VimSmall",
+    "VimTiny",
     "build_model",
     "get_model",
     "list_models",
     "register_model",
 ]
+
+# Add timm-based models to __all__ if available
+if _HAS_TIMM_MODELS:
+    __all__.extend(
+        [
+            "CaFormerS18",
+            "CaFormerS36",
+            "EfficientViTB0",
+            "EfficientViTB1",
+            "EfficientViTB2",
+            "EfficientViTB3",
+            "EfficientViTL1",
+            "EfficientViTL2",
+            "EfficientViTM0",
+            "EfficientViTM1",
+            "EfficientViTM2",
+            "FastViTS12",
+            "FastViTSA12",
+            "FastViTT8",
+            "FastViTT12",
+            "MaxViTBaseLarge",
+            "MaxViTSmall",
+            "MaxViTTiny",
+            "PoolFormerS12",
+            "UniRepLKNetBaseLarge",
+            "UniRepLKNetSmall",
+            "UniRepLKNetTiny",
+        ]
+    )
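Because every timm-backed architecture is imported inside a single try/except, downstream code can feature-test rather than hard-depend on timm. A minimal consumption sketch (editorial, not part of the diff; the no-argument list_models() call and the specific class names picked are assumptions based only on the imports above):

    import wavedl.models as wm

    print(wm.list_models())  # assumed no-arg listing of registered architectures

    # timm-gated classes are exported only when the optional import succeeded
    if hasattr(wm, "MaxViTTiny"):
        model_cls = wm.MaxViTTiny
    else:
        model_cls = wm.ConvNeXtV2Tiny  # imported unconditionally in 1.6.1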
wavedl/models/_pretrained_utils.py
ADDED
@@ -0,0 +1,366 @@
+"""
+Shared Utilities for Model Architectures
+=========================================
+
+Common components used across multiple models:
+- GRN (Global Response Normalization) for ConvNeXt V2
+- Dimension-agnostic layer factories
+- Regression head builders
+- Input channel adaptation for pretrained models
+
+Author: Ductho Le (ductho.le@outlook.com)
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# =============================================================================
+# DIMENSION-AGNOSTIC LAYER FACTORIES
+# =============================================================================
+
+
+def get_conv_layer(dim: int) -> type[nn.Module]:
+    """Get dimension-appropriate Conv class."""
+    layers = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
+    if dim not in layers:
+        raise ValueError(f"Unsupported dimension: {dim}")
+    return layers[dim]
+
+
+def get_norm_layer(dim: int) -> type[nn.Module]:
+    """Get dimension-appropriate BatchNorm class."""
+    layers = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}
+    if dim not in layers:
+        raise ValueError(f"Unsupported dimension: {dim}")
+    return layers[dim]
+
+
+def get_pool_layer(dim: int) -> type[nn.Module]:
+    """Get dimension-appropriate AdaptiveAvgPool class."""
+    layers = {1: nn.AdaptiveAvgPool1d, 2: nn.AdaptiveAvgPool2d, 3: nn.AdaptiveAvgPool3d}
+    if dim not in layers:
+        raise ValueError(f"Unsupported dimension: {dim}")
+    return layers[dim]
+
+
+# =============================================================================
+# GLOBAL RESPONSE NORMALIZATION (GRN) - ConvNeXt V2
+# =============================================================================
+
+
+class GRN1d(nn.Module):
+    """
+    Global Response Normalization for 1D inputs.
+
+    GRN enhances inter-channel feature competition and promotes diversity.
+    Replaces LayerScale in ConvNeXt V2.
+
+    Reference: ConvNeXt V2 (CVPR 2023)
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(1, dim, 1))
+        self.beta = nn.Parameter(torch.zeros(1, dim, 1))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, C, L)
+        Gx = torch.norm(x, p=2, dim=2, keepdim=True)  # (B, C, 1)
+        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps)  # (B, C, 1)
+        return self.gamma * (x * Nx) + self.beta + x
+
+
+class GRN2d(nn.Module):
+    """
+    Global Response Normalization for 2D inputs.
+
+    GRN enhances inter-channel feature competition and promotes diversity.
+    Replaces LayerScale in ConvNeXt V2.
+
+    Reference: ConvNeXt V2 (CVPR 2023)
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
+        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, C, H, W)
+        Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
+        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps)  # (B, C, 1, 1)
+        return self.gamma * (x * Nx) + self.beta + x
+
+
+class GRN3d(nn.Module):
+    """
+    Global Response Normalization for 3D inputs.
+
+    GRN enhances inter-channel feature competition and promotes diversity.
+    Replaces LayerScale in ConvNeXt V2.
+
+    Reference: ConvNeXt V2 (CVPR 2023)
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))
+        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, C, D, H, W)
+        Gx = torch.norm(x, p=2, dim=(2, 3, 4), keepdim=True)  # (B, C, 1, 1, 1)
+        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps)  # (B, C, 1, 1, 1)
+        return self.gamma * (x * Nx) + self.beta + x
+
+
+def get_grn_layer(dim: int) -> type[nn.Module]:
+    """Get dimension-appropriate GRN class."""
+    layers = {1: GRN1d, 2: GRN2d, 3: GRN3d}
+    if dim not in layers:
+        raise ValueError(f"Unsupported dimension: {dim}")
+    return layers[dim]
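A quick sanity check on the GRN blocks above (editorial sketch, not part of the package diff). Because gamma and beta are zero-initialized, a freshly constructed GRN is exactly the identity map, so inserting it into an existing block does not perturb behavior until training updates the two parameters:

    import torch
    from wavedl.models._pretrained_utils import GRN2d

    grn = GRN2d(dim=64)
    x = torch.randn(8, 64, 32, 32)  # (B, C, H, W)
    y = grn(x)
    assert y.shape == x.shape       # GRN is shape-preserving
    assert torch.allclose(y, x)     # zero-init gamma/beta -> identity at init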
+
+
+# =============================================================================
+# LAYER NORMALIZATION (Channels Last for CNNs)
+# =============================================================================
+
+
+class LayerNormNd(nn.Module):
+    """
+    LayerNorm that works with channels-first tensors of any dimension.
+    Applies normalization over the channel dimension.
+    """
+
+    def __init__(self, normalized_shape: int, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.dim = dim
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Move channels to last, apply LN, move back
+        if self.dim == 1:
+            # (B, C, L) -> (B, L, C) -> LN -> (B, C, L)
+            x = x.permute(0, 2, 1)
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+            x = x.permute(0, 2, 1)
+        elif self.dim == 2:
+            # (B, C, H, W) -> (B, H, W, C) -> LN -> (B, C, H, W)
+            x = x.permute(0, 2, 3, 1)
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+            x = x.permute(0, 3, 1, 2)
+        elif self.dim == 3:
+            # (B, C, D, H, W) -> (B, D, H, W, C) -> LN -> (B, C, D, H, W)
+            x = x.permute(0, 2, 3, 4, 1)
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+            x = x.permute(0, 4, 1, 2, 3)
+        return x
+
+
+# =============================================================================
+# REGRESSION HEAD BUILDERS
+# =============================================================================
+
+
+def build_regression_head(
+    in_features: int,
+    out_size: int,
+    dropout_rate: float = 0.3,
+    hidden_dim: int = 512,
+) -> nn.Sequential:
+    """
+    Build a standard regression head for pretrained models.
+
+    Args:
+        in_features: Input feature dimension
+        out_size: Number of regression targets
+        dropout_rate: Dropout rate
+        hidden_dim: Hidden layer dimension
+
+    Returns:
+        nn.Sequential regression head
+    """
+    return nn.Sequential(
+        nn.Dropout(dropout_rate),
+        nn.Linear(in_features, hidden_dim),
+        nn.SiLU(inplace=True),
+        nn.Dropout(dropout_rate * 0.5),
+        nn.Linear(hidden_dim, hidden_dim // 2),
+        nn.SiLU(inplace=True),
+        nn.Linear(hidden_dim // 2, out_size),
+    )
+
+
+def adapt_input_channels(
+    conv_layer: nn.Module,
+    new_in_channels: int = 1,
+    pretrained: bool = True,
+) -> nn.Module:
+    """
+    Adapt a convolutional layer for different input channels.
+
+    For pretrained models, averages RGB weights to grayscale.
+
+    Args:
+        conv_layer: Original conv layer (expects 3 input channels)
+        new_in_channels: New number of input channels (default: 1)
+        pretrained: Whether to adapt pretrained weights
+
+    Returns:
+        New conv layer with adapted input channels
+    """
+    if isinstance(conv_layer, nn.Conv2d):
+        new_conv = nn.Conv2d(
+            new_in_channels,
+            conv_layer.out_channels,
+            kernel_size=conv_layer.kernel_size,
+            stride=conv_layer.stride,
+            padding=conv_layer.padding,
+            bias=conv_layer.bias is not None,
+        )
+        if pretrained and conv_layer.in_channels == 3:
+            with torch.no_grad():
+                # Average RGB weights
+                new_conv.weight.copy_(conv_layer.weight.mean(dim=1, keepdim=True))
+                if conv_layer.bias is not None:
+                    new_conv.bias.copy_(conv_layer.bias)
+        return new_conv
+    else:
+        raise NotImplementedError(f"Unsupported layer type: {type(conv_layer)}")
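The replacement conv keeps the kernel size, stride, padding, and bias setting, and the weight transfer is a plain mean over the RGB axis. A small check (editorial sketch; an untrained conv stands in for real pretrained weights):

    import torch
    import torch.nn as nn
    from wavedl.models._pretrained_utils import adapt_input_channels

    conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    new_conv = adapt_input_channels(conv, new_in_channels=1, pretrained=True)
    assert new_conv.weight.shape == (64, 1, 7, 7)
    assert torch.allclose(new_conv.weight, conv.weight.mean(dim=1, keepdim=True))
    y = new_conv(torch.randn(2, 1, 224, 224))  # grayscale input now accepted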
+
+
+def adapt_first_conv_for_single_channel(
+    module: nn.Module,
+    conv_path: str,
+    pretrained: bool = True,
+) -> None:
+    """
+    Adapt the first convolutional layer of a pretrained model for single-channel input.
+
+    This is a convenience function for torchvision-style models where the path
+    to the first conv layer is known. It modifies the model in-place.
+
+    For pretrained models, the RGB weights are averaged to create grayscale weights,
+    which provides a reasonable initialization for single-channel inputs.
+
+    Args:
+        module: The model or submodule containing the conv layer
+        conv_path: Dot-separated path to the conv layer (e.g., "conv1", "features.0.0")
+        pretrained: Whether to adapt pretrained weights by averaging RGB channels
+
+    Example:
+        >>> # For torchvision ResNet
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "conv1", pretrained=True
+        ... )
+        >>> # For torchvision ConvNeXt
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "features.0.0", pretrained=True
+        ... )
+        >>> # For torchvision DenseNet
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "features.conv0", pretrained=True
+        ... )
+    """
+    # Navigate to parent and get the conv layer
+    parts = conv_path.split(".")
+    parent = module
+    for part in parts[:-1]:
+        if part.isdigit():
+            parent = parent[int(part)]
+        else:
+            parent = getattr(parent, part)
+
+    # Get the final attribute name and the old conv
+    final_attr = parts[-1]
+    if final_attr.isdigit():
+        old_conv = parent[int(final_attr)]
+    else:
+        old_conv = getattr(parent, final_attr)
+
+    # Create and set the new conv
+    new_conv = adapt_input_channels(old_conv, new_in_channels=1, pretrained=pretrained)
+
+    if final_attr.isdigit():
+        parent[int(final_attr)] = new_conv
+    else:
+        setattr(parent, final_attr, new_conv)
+
+
+def find_and_adapt_input_convs(
+    backbone: nn.Module,
+    pretrained: bool = True,
+    adapt_all: bool = False,
+) -> int:
+    """
+    Find and adapt Conv2d layers with 3 input channels for single-channel input.
+
+    This is useful for timm-style models where the exact path to the first
+    conv layer may vary or where multiple layers need adaptation.
+
+    Args:
+        backbone: The backbone model to adapt
+        pretrained: Whether to adapt pretrained weights by averaging RGB channels
+        adapt_all: If True, adapt all Conv2d layers with 3 input channels.
+            If False (default), only adapt the first one found.
+
+    Returns:
+        Number of layers adapted
+
+    Example:
+        >>> # For timm models (adapt first conv only)
+        >>> count = find_and_adapt_input_convs(model.backbone, pretrained=True)
+        >>> # For models with multiple input convs (e.g., FastViT)
+        >>> count = find_and_adapt_input_convs(
+        ...     model.backbone, pretrained=True, adapt_all=True
+        ... )
+    """
+    adapted_count = 0
+
+    for name, module in backbone.named_modules():
+        if not hasattr(module, "in_channels") or module.in_channels != 3:
+            continue
+
+        # Check if this is a wrapper with inner .conv attribute
+        if hasattr(module, "conv") and isinstance(module.conv, nn.Conv2d):
+            old_conv = module.conv
+            module.conv = adapt_input_channels(
+                old_conv, new_in_channels=1, pretrained=pretrained
+            )
+            adapted_count += 1
+
+        elif isinstance(module, nn.Conv2d):
+            # Direct Conv2d - need to replace it in parent
+            parts = name.split(".")
+            parent = backbone
+            for part in parts[:-1]:
+                if part.isdigit():
+                    parent = parent[int(part)]
+                else:
+                    parent = getattr(parent, part)
+
+            child_name = parts[-1]
+            new_conv = adapt_input_channels(
+                module, new_in_channels=1, pretrained=pretrained
+            )
+
+            if child_name.isdigit():
+                parent[int(child_name)] = new_conv
+            else:
+                setattr(parent, child_name, new_conv)
+
+            adapted_count += 1
+
+        if not adapt_all and adapted_count > 0:
+            break
+
+    return adapted_count
wavedl/models/base.py
CHANGED
@@ -15,6 +15,54 @@ import torch
 import torch.nn as nn
 
 
+# =============================================================================
+# TYPE ALIASES
+# =============================================================================
+
+# Spatial shape type aliases for model input dimensions
+SpatialShape1D = tuple[int]
+SpatialShape2D = tuple[int, int]
+SpatialShape3D = tuple[int, int, int]
+SpatialShape = SpatialShape1D | SpatialShape2D | SpatialShape3D
+
+
+# =============================================================================
+# UTILITY FUNCTIONS
+# =============================================================================
+
+
+def compute_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
+    """
+    Compute valid num_groups for GroupNorm that divides num_channels evenly.
+
+    GroupNorm requires num_channels to be divisible by num_groups. This function
+    finds the largest valid divisor up to preferred_groups.
+
+    Args:
+        num_channels: Number of channels to normalize (must be positive)
+        preferred_groups: Preferred number of groups (default: 32)
+
+    Returns:
+        Valid num_groups that satisfies num_channels % num_groups == 0
+
+    Example:
+        >>> compute_num_groups(64)  # Returns 32
+        >>> compute_num_groups(48)  # Returns 16 (48 % 32 != 0)
+        >>> compute_num_groups(7)  # Returns 1 (prime number)
+    """
+    # Try preferred groups first, then common divisors
+    for groups in [preferred_groups, 16, 8, 4, 2, 1]:
+        if groups <= num_channels and num_channels % groups == 0:
+            return groups
+
+    # Fallback: find any valid divisor (always returns at least 1)
+    for groups in range(min(32, num_channels), 0, -1):
+        if num_channels % groups == 0:
+            return groups
+
+    return 1  # Always valid
+
+
 class BaseModel(nn.Module, ABC):
     """
     Abstract base class for all regression models.