wavedl-1.5.7-py3-none-any.whl → wavedl-1.6.1-py3-none-any.whl
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +80 -4
- wavedl/models/_pretrained_utils.py +366 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +113 -51
- wavedl/models/convnext_v2.py +488 -0
- wavedl/models/densenet.py +10 -23
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +252 -0
- wavedl/models/mamba.py +555 -0
- wavedl/models/maxvit.py +254 -0
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +19 -61
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +2 -6
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +9 -9
- wavedl/train.py +1430 -1425
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/data.py +39 -6
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/unet.py CHANGED

@@ -22,14 +22,10 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
 def _get_layers(dim: int):
     """Get dimension-appropriate layer classes."""
     if dim == 1:

@@ -119,7 +115,7 @@ class UNetRegression(BaseModel):
     Uses U-Net encoder-decoder architecture with skip connections,
     then applies global pooling for standard vector regression output.
 
-    ~31.
+    ~31.0M backbone parameters (2D). Good for leveraging multi-scale features
     and skip connections for regression tasks.
 
     Args:
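The only functional change here is that the SpatialShape alias moved out of unet.py into wavedl.models.base, so all models can share one definition. A minimal sketch of the consolidated alias in use, assuming it keeps the definition shown in the removed lines:

    # Assumes SpatialShape is exported from wavedl.models.base with the same
    # definition the removed lines show:
    #   SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
    from wavedl.models.base import SpatialShape

    def describe(shape: SpatialShape) -> str:
        # len(shape) gives the spatial dimensionality (1D, 2D, or 3D)
        return f"{len(shape)}D input with extent {shape}"

    print(describe((256, 256)))  # -> "2D input with extent (256, 256)"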
wavedl/models/unireplknet.py ADDED

@@ -0,0 +1,491 @@
"""
UniRepLKNet: Universal Large-Kernel ConvNet for Regression
===========================================================

A dimension-agnostic implementation of UniRepLKNet featuring ultra-large kernels
(up to 31x31) for capturing long-range dependencies. Particularly effective for
wave-based problems where spatial correlations span large distances.

**Key Features**:
- Large kernels (13x13 to 31x31) via efficient decomposition
- Dilated small kernel reparam for efficient training
- SE (Squeeze-and-Excitation) attention
- GRN (Global Response Normalization) from ConvNeXt V2
- Dimension-agnostic: supports 1D, 2D, 3D inputs

**Variants**:
- unireplknet_tiny: 31M params, depths [3,3,18,3], dims [80,160,320,640]
- unireplknet_small: 56M params, depths [3,3,27,3], dims [96,192,384,768]
- unireplknet_base: 97M params, depths [3,3,27,3], dims [128,256,512,1024]

**Why Large Kernels for Wave Problems**:
- Dispersion curves: Long-range frequency-wavenumber correlations
- B-scans: Defect signatures span many pixels
- Time-series: Capture multiple wave periods without deep stacking

Reference:
    Ding, X., et al. (2024). UniRepLKNet: A Universal Perception Large-Kernel
    ConvNet for Audio, Video, Point Cloud, Time-Series and Image Recognition.
    CVPR 2024. https://arxiv.org/abs/2311.15599

Author: Ductho Le (ductho.le@outlook.com)
"""

from typing import Any

import torch
import torch.nn as nn

from wavedl.models._pretrained_utils import (
    LayerNormNd,
    get_conv_layer,
    get_grn_layer,
    get_pool_layer,
)
from wavedl.models.base import BaseModel, SpatialShape
from wavedl.models.registry import register_model


__all__ = [
    "UniRepLKNetBase",
    "UniRepLKNetBaseLarge",
    "UniRepLKNetSmall",
    "UniRepLKNetTiny",
]


# =============================================================================
# LARGE KERNEL CONVOLUTION BLOCK
# =============================================================================


class LargeKernelConv(nn.Module):
    """
    Large kernel depthwise convolution.

    Implements efficient large kernel convolutions following UniRepLKNet.
    Uses a single large depthwise conv for simplicity and reliability.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        dim: int = 2,
    ):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        Conv = get_conv_layer(dim)
        padding = kernel_size // 2

        # Large kernel depthwise conv
        self.conv = Conv(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=channels,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block for channel attention.

    Adaptively recalibrates channel-wise feature responses by explicitly
    modeling interdependencies between channels.
    """

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        reduced = max(channels // reduction, 8)
        self.fc1 = nn.Linear(channels, reduced, bias=False)
        self.fc2 = nn.Linear(reduced, channels, bias=False)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Global average pooling
        if x.ndim == 3:  # 1D: (B, C, L)
            gap = x.mean(dim=2)
        elif x.ndim == 4:  # 2D: (B, C, H, W)
            gap = x.mean(dim=(2, 3))
        else:  # 3D: (B, C, D, H, W)
            gap = x.mean(dim=(2, 3, 4))

        # FC layers
        scale = self.act(self.fc1(gap))
        scale = torch.sigmoid(self.fc2(scale))

        # Reshape for broadcasting
        if x.ndim == 3:
            scale = scale.unsqueeze(-1)
        elif x.ndim == 4:
            scale = scale.unsqueeze(-1).unsqueeze(-1)
        else:
            scale = scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

        return x * scale


class DropPath(nn.Module):
    """Stochastic Depth (drop path) regularization."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x

        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


# =============================================================================
# UNIREPLKNET BLOCK
# =============================================================================


class UniRepLKNetBlock(nn.Module):
    """
    UniRepLKNet block with large kernel convolution, SE attention, and GRN.

    Architecture:
        Input → LargeKernelConv → LayerNorm → SE → Linear → GELU → GRN → Linear → Residual

    This combines the effective receptive field of large kernels with the
    feature recalibration of SE attention and the inter-channel competition
    of GRN from ConvNeXt V2.
    """

    def __init__(
        self,
        dim: int,
        spatial_dim: int,
        kernel_size: int = 13,
        drop_path: float = 0.0,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()
        self.spatial_dim = spatial_dim

        GRN = get_grn_layer(spatial_dim)

        # Large kernel depthwise conv
        self.large_kernel = LargeKernelConv(
            dim, kernel_size=kernel_size, dim=spatial_dim
        )

        # Layer norm (applied in channels-last format)
        self.norm = nn.LayerNorm(dim, eps=1e-6)

        # SE attention
        self.se = SEBlock(dim)

        # MLP with expansion
        hidden_dim = int(dim * mlp_ratio)
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.grn = GRN(hidden_dim)
        self.pwconv2 = nn.Linear(hidden_dim, dim)

        # Stochastic depth
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def _to_channels_last(self, x: torch.Tensor) -> torch.Tensor:
        """Convert from channels-first to channels-last."""
        if self.spatial_dim == 1:
            return x.permute(0, 2, 1)  # (B, C, L) -> (B, L, C)
        elif self.spatial_dim == 2:
            return x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
        else:
            return x.permute(0, 2, 3, 4, 1)  # (B, C, D, H, W) -> (B, D, H, W, C)

    def _to_channels_first(self, x: torch.Tensor) -> torch.Tensor:
        """Convert from channels-last to channels-first."""
        if self.spatial_dim == 1:
            return x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L)
        elif self.spatial_dim == 2:
            return x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
        else:
            return x.permute(0, 4, 1, 2, 3)  # (B, D, H, W, C) -> (B, C, D, H, W)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x

        # Large kernel conv (channels-first)
        x = self.large_kernel(x)

        # SE attention (channels-first)
        x = self.se(x)

        # LayerNorm + MLP (channels-last)
        x = self._to_channels_last(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)

        # GRN (channels-first)
        x = self._to_channels_first(x)
        x = self.grn(x)

        # Final projection (channels-last)
        x = self._to_channels_last(x)
        x = self.pwconv2(x)
        x = self._to_channels_first(x)

        # Residual + drop path
        x = residual + self.drop_path(x)
        return x


# =============================================================================
# UNIREPLKNET BASE CLASS
# =============================================================================


class UniRepLKNetBase(BaseModel):
    """
    UniRepLKNet base class for regression.

    Dimension-agnostic implementation supporting 1D, 2D, and 3D inputs.
    Features large kernels for capturing long-range dependencies in wave data.

    Architecture:
        1. Stem: 4x downsampling conv
        2. 4 stages with UniRepLKNet blocks
        3. Downsampling between stages
        4. Global pooling + regression head
    """

    def __init__(
        self,
        in_shape: SpatialShape,
        out_size: int,
        depths: list[int],
        dims: list[int],
        kernel_sizes: list[int] | None = None,
        drop_path_rate: float = 0.1,
        dropout_rate: float = 0.3,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        self.dim = len(in_shape)
        self.depths = depths
        self.dims = dims

        # Default kernel sizes: larger in early stages, smaller in later stages
        # Early stages: large receptive field for low-level features
        # Later stages: smaller kernels sufficient for high-level features
        if kernel_sizes is None:
            kernel_sizes = [31, 29, 17, 13]

        Conv = get_conv_layer(self.dim)
        Pool = get_pool_layer(self.dim)

        # Stem: aggressive 4x downsampling (like ConvNeXt)
        self.stem = nn.Sequential(
            Conv(1, dims[0], kernel_size=4, stride=4),
            LayerNormNd(dims[0], self.dim),
        )

        # Stochastic depth decay
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Build stages
        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        cur = 0

        for i in range(len(depths)):
            # Adjust kernel size for 1D (can use larger kernels)
            # Ensure kernel size is always odd for proper same-padding
            kernel_size = kernel_sizes[i]
            if self.dim == 1:
                kernel_size = min(kernel_size * 2 - 1, 63)  # Keep odd for 1D

            stage = nn.Sequential(
                *[
                    UniRepLKNetBlock(
                        dim=dims[i],
                        spatial_dim=self.dim,
                        kernel_size=kernel_size,
                        drop_path=dp_rates[cur + j],
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

            # Downsample between stages (except after last)
            if i < len(depths) - 1:
                downsample = nn.Sequential(
                    LayerNormNd(dims[i], self.dim),
                    Conv(dims[i], dims[i + 1], kernel_size=2, stride=2),
                )
                self.downsamples.append(downsample)

        # Global pooling and head
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.global_pool = Pool(1)
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(dims[-1], dims[-1] // 2),
            nn.GELU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(dims[-1] // 2, out_size),
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights with truncated normal."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor (B, 1, *in_shape)

        Returns:
            Output tensor (B, out_size)
        """
        x = self.stem(x)

        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i < len(self.downsamples):
                x = self.downsamples[i](x)

        # Global pooling
        x = self.global_pool(x)
        x = x.flatten(1)

        # Final norm and head
        x = self.norm(x)
        x = self.head(x)

        return x

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        return {
            "depths": [3, 3, 18, 3],
            "dims": [80, 160, 320, 640],
            "kernel_sizes": [31, 29, 17, 13],
            "drop_path_rate": 0.1,
            "dropout_rate": 0.3,
        }


# =============================================================================
# REGISTERED VARIANTS
# =============================================================================


@register_model("unireplknet_tiny")
class UniRepLKNetTiny(UniRepLKNetBase):
    """
    UniRepLKNet Tiny: ~30.8M backbone parameters.

    Large kernels [31, 29, 17, 13] for capturing long-range wave patterns.
    Depths [3,3,18,3], Dims [80,160,320,640].
    Supports 1D, 2D, 3D inputs.

    Example:
        >>> model = UniRepLKNetTiny(in_shape=(256, 256), out_size=3)
        >>> x = torch.randn(4, 1, 256, 256)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            depths=[3, 3, 18, 3],
            dims=[80, 160, 320, 640],
            kernel_sizes=[31, 29, 17, 13],
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"UniRepLKNet_Tiny({self.dim}D, in_shape={self.in_shape}, "
            f"out_size={self.out_size})"
        )


@register_model("unireplknet_small")
class UniRepLKNetSmall(UniRepLKNetBase):
    """
    UniRepLKNet Small: ~56.0M backbone parameters.

    Large kernels [31, 29, 17, 13] for capturing long-range wave patterns.
    Depths [3,3,27,3], Dims [96,192,384,768].
    Supports 1D, 2D, 3D inputs.
    """

    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            depths=[3, 3, 27, 3],
            dims=[96, 192, 384, 768],
            kernel_sizes=[31, 29, 17, 13],
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"UniRepLKNet_Small({self.dim}D, in_shape={self.in_shape}, "
            f"out_size={self.out_size})"
        )


@register_model("unireplknet_base")
class UniRepLKNetBaseLarge(UniRepLKNetBase):
    """
    UniRepLKNet Base: ~97.6M backbone parameters.

    Large kernels [31, 29, 17, 13] for capturing long-range wave patterns.
    Depths [3,3,27,3], Dims [128,256,512,1024].
    Supports 1D, 2D, 3D inputs.
    """

    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            depths=[3, 3, 27, 3],
            dims=[128, 256, 512, 1024],
            kernel_sizes=[31, 29, 17, 13],
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"UniRepLKNet_Base({self.dim}D, in_shape={self.in_shape}, "
            f"out_size={self.out_size})"
        )
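The Tiny variant's docstring example extends naturally into a quick smoke test of the new module. A sketch, assuming the package layout shown above (the exact total parameter count is not stated in the diff; the regression head adds a small amount on top of the ~30.8M backbone):

    import torch
    from wavedl.models.unireplknet import UniRepLKNetTiny

    model = UniRepLKNetTiny(in_shape=(256, 256), out_size=3)  # 2D input (H, W)
    x = torch.randn(4, 1, 256, 256)  # (B, 1, H, W), matching forward()'s contract
    out = model(x)
    assert out.shape == (4, 3)

    # Rough size check against the documented ~30.8M backbone figure
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{n_params / 1e6:.1f}M total parameters")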
wavedl/models/vit.py CHANGED

@@ -10,9 +10,9 @@ Supports both 1D (signals) and 2D (images) inputs via configurable patch embedding
 - 2D: Images/spectrograms → patches are grid squares
 
 **Variants**:
-- vit_tiny: Smallest (~5.
-- vit_small: Light (~
-- vit_base: Standard (~
+- vit_tiny: Smallest (~5.4M backbone params, embed_dim=192, depth=12, heads=3)
+- vit_small: Light (~21.4M backbone params, embed_dim=384, depth=12, heads=6)
+- vit_base: Standard (~85.3M backbone params, embed_dim=768, depth=12, heads=12)
 
 References:
     Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words:

@@ -27,12 +27,12 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape1D, SpatialShape2D
 from wavedl.models.registry import register_model
 
 
-#
-SpatialShape =
+# ViT supports 1D and 2D only
+SpatialShape = SpatialShape1D | SpatialShape2D
 
 
 class PatchEmbed(nn.Module):

@@ -365,7 +365,7 @@ class ViTTiny(ViTBase):
     """
     ViT-Tiny: Smallest Vision Transformer variant.
 
-    ~5.
+    ~5.4M backbone parameters. Good for: Quick experiments, smaller datasets.
 
     Args:
         in_shape: (L,) for 1D or (H, W) for 2D

@@ -398,7 +398,7 @@ class ViTSmall(ViTBase):
     """
     ViT-Small: Light Vision Transformer variant.
 
-    ~
+    ~21.4M backbone parameters. Good for: Balanced performance.
 
     Args:
         in_shape: (L,) for 1D or (H, W) for 2D

@@ -429,7 +429,7 @@ class ViTBase_(ViTBase):
     """
     ViT-Base: Standard Vision Transformer variant.
 
-    ~
+    ~85.3M backbone parameters. Good for: High accuracy, larger datasets.
 
     Args:
         in_shape: (L,) for 1D or (H, W) for 2D