wavedl-1.2.0-py3-none-any.whl
- wavedl/__init__.py +43 -0
- wavedl/hpo.py +366 -0
- wavedl/models/__init__.py +86 -0
- wavedl/models/_template.py +157 -0
- wavedl/models/base.py +173 -0
- wavedl/models/cnn.py +249 -0
- wavedl/models/convnext.py +425 -0
- wavedl/models/densenet.py +406 -0
- wavedl/models/efficientnet.py +236 -0
- wavedl/models/registry.py +104 -0
- wavedl/models/resnet.py +555 -0
- wavedl/models/unet.py +304 -0
- wavedl/models/vit.py +372 -0
- wavedl/test.py +1069 -0
- wavedl/train.py +1079 -0
- wavedl/utils/__init__.py +151 -0
- wavedl/utils/config.py +269 -0
- wavedl/utils/cross_validation.py +509 -0
- wavedl/utils/data.py +1220 -0
- wavedl/utils/distributed.py +138 -0
- wavedl/utils/losses.py +216 -0
- wavedl/utils/metrics.py +1236 -0
- wavedl/utils/optimizers.py +216 -0
- wavedl/utils/schedulers.py +251 -0
- wavedl-1.2.0.dist-info/LICENSE +21 -0
- wavedl-1.2.0.dist-info/METADATA +991 -0
- wavedl-1.2.0.dist-info/RECORD +30 -0
- wavedl-1.2.0.dist-info/WHEEL +5 -0
- wavedl-1.2.0.dist-info/entry_points.txt +4 -0
- wavedl-1.2.0.dist-info/top_level.txt +1 -0
wavedl/models/unet.py
ADDED
@@ -0,0 +1,304 @@
"""
U-Net: Encoder-Decoder Architecture for Regression
====================================================

A dimension-agnostic U-Net implementation for tasks requiring either:
- Spatial output (e.g., velocity field prediction)
- Vector output (global pooling → regression head)

**Dimensionality Support**:
- 1D: Waveforms, signals (N, 1, L) → Conv1d
- 2D: Images, spectrograms (N, 1, H, W) → Conv2d
- 3D: Volumetric data (N, 1, D, H, W) → Conv3d

**Variants**:
- unet: Full encoder-decoder with spatial output capability
- unet_regression: U-Net with global pooling for vector regression

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""

from typing import Any

import torch
import torch.nn as nn

from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


# Type alias for spatial shapes
SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]


def _get_layers(dim: int):
    """Get dimension-appropriate layer classes."""
    if dim == 1:
        return nn.Conv1d, nn.ConvTranspose1d, nn.MaxPool1d, nn.AdaptiveAvgPool1d
    elif dim == 2:
        return nn.Conv2d, nn.ConvTranspose2d, nn.MaxPool2d, nn.AdaptiveAvgPool2d
    elif dim == 3:
        return nn.Conv3d, nn.ConvTranspose3d, nn.MaxPool3d, nn.AdaptiveAvgPool3d
    else:
        raise ValueError(f"Unsupported dimensionality: {dim}D")


class DoubleConv(nn.Module):
    """Double convolution block: Conv-GN-ReLU-Conv-GN-ReLU."""

    def __init__(self, in_channels: int, out_channels: int, dim: int = 2):
        super().__init__()
        Conv = _get_layers(dim)[0]

        num_groups = min(32, out_channels)
        while out_channels % num_groups != 0 and num_groups > 1:
            num_groups -= 1

        self.double_conv = nn.Sequential(
            Conv(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(num_groups, out_channels),
            nn.ReLU(inplace=True),
            Conv(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(num_groups, out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv."""

    def __init__(self, in_channels: int, out_channels: int, dim: int = 2):
        super().__init__()
        _, _, MaxPool, _ = _get_layers(dim)

        self.maxpool_conv = nn.Sequential(
            MaxPool(2), DoubleConv(in_channels, out_channels, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv."""

    def __init__(self, in_channels: int, out_channels: int, dim: int = 2):
        super().__init__()
        _, ConvTranspose, _, _ = _get_layers(dim)

        # in_channels comes from previous layer
        # After upconv: in_channels // 2
        # After concat with skip (out_channels): in_channels // 2 + out_channels = in_channels
        # Then DoubleConv: in_channels -> out_channels
        self.up = ConvTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels, dim)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)

        # Handle size mismatch (pad x1 to match x2)
        if x1.shape[2:] != x2.shape[2:]:
            diff = [x2.size(i + 2) - x1.size(i + 2) for i in range(len(x1.shape) - 2)]
            # Pad x1 to match x2
            pad = []
            for d in reversed(diff):
                pad.extend([d // 2, d - d // 2])
            x1 = nn.functional.pad(x1, pad)

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNetBase(BaseModel):
    """
    Base U-Net class for regression tasks.

    Standard U-Net architecture:
    - Encoder path with downsampling
    - Decoder path with upsampling and skip connections
    - Optional spatial or vector output
    """

    def __init__(
        self,
        in_shape: SpatialShape,
        out_size: int,
        base_channels: int = 64,
        depth: int = 4,
        dropout_rate: float = 0.1,
        spatial_output: bool = False,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        self.dim = len(in_shape)
        self.base_channels = base_channels
        self.depth = depth
        self.dropout_rate = dropout_rate
        self.spatial_output = spatial_output

        Conv, _, _, AdaptivePool = _get_layers(self.dim)

        # Channel progression: 64 -> 128 -> 256 -> 512 (for depth=4)
        # features[i] = base_channels * 2^i
        features = [base_channels * (2**i) for i in range(depth + 1)]

        # Initial double conv (1 -> features[0])
        self.inc = DoubleConv(1, features[0], self.dim)

        # Encoder (down path)
        self.downs = nn.ModuleList()
        for i in range(depth):
            self.downs.append(Down(features[i], features[i + 1], self.dim))

        # Decoder (up path)
        self.ups = nn.ModuleList()
        for i in range(depth):
            # Input: features[depth - i], Skip: features[depth - 1 - i], Output: features[depth - 1 - i]
            self.ups.append(Up(features[depth - i], features[depth - 1 - i], self.dim))

        if spatial_output:
            # Spatial output: 1x1 conv to out_size channels
            self.outc = Conv(features[0], out_size, kernel_size=1)
        else:
            # Vector output: global pooling + regression head
            self.global_pool = AdaptivePool(1)
            self.head = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(features[0], 256),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_rate),
                nn.Linear(256, out_size),
            )

        self._init_weights()

    def _init_weights(self):
        """Initialize weights."""
        for m in self.modules():
            if isinstance(
                m,
                (
                    nn.Conv1d,
                    nn.Conv2d,
                    nn.Conv3d,
                    nn.ConvTranspose1d,
                    nn.ConvTranspose2d,
                    nn.ConvTranspose3d,
                ),
            ):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        # Encoder path - collect skip connections
        x1 = self.inc(x)

        skips = [x1]
        x = x1
        for down in self.downs:
            x = down(x)
            skips.append(x)

        # Remove last (bottleneck output, not a skip)
        skips = skips[:-1]

        # Decoder path - use skips in reverse order
        for up, skip in zip(self.ups, reversed(skips)):
            x = up(x, skip)

        if self.spatial_output:
            return self.outc(x)
        else:
            x = self.global_pool(x)
            x = x.flatten(1)
            return self.head(x)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration."""
        return {"base_channels": 64, "depth": 4, "dropout_rate": 0.1}


# =============================================================================
# REGISTERED MODEL VARIANTS
# =============================================================================


@register_model("unet")
class UNet(UNetBase):
    """
    U-Net with spatial output capability.

    Good for: Pixel/voxel-wise regression (velocity fields, spatial maps).

    Note: For spatial output, out_size is the number of output channels.
    Output shape: (B, out_size, *spatial_dims) for spatial_output=True.

    Args:
        in_shape: (L,), (H, W), or (D, H, W)
        out_size: Number of output channels for spatial output
        base_channels: Base channel count (default: 64)
        depth: Number of encoder/decoder levels (default: 4)
        spatial_output: If True, output spatial map; if False, output vector
        dropout_rate: Dropout rate (default: 0.1)
    """

    def __init__(
        self,
        in_shape: SpatialShape,
        out_size: int,
        spatial_output: bool = True,
        **kwargs,
    ):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            spatial_output=spatial_output,
            **kwargs,
        )

    def __repr__(self) -> str:
        mode = "spatial" if self.spatial_output else "vector"
        return f"UNet({self.dim}D, {mode}, in_shape={self.in_shape}, out_size={self.out_size})"


@register_model("unet_regression")
class UNetRegression(UNetBase):
    """
    U-Net for vector regression output.

    Uses U-Net encoder-decoder but applies global pooling at the end
    for standard vector regression output.

    Good for: Leveraging U-Net features (multi-scale, skip connections)
    for standard regression tasks.

    Args:
        in_shape: (L,), (H, W), or (D, H, W)
        out_size: Number of regression targets
        base_channels: Base channel count (default: 64)
        depth: Number of encoder/decoder levels (default: 4)
        dropout_rate: Dropout rate (default: 0.1)
    """

    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape, out_size=out_size, spatial_output=False, **kwargs
        )

    def __repr__(self) -> str:
        return f"UNet_Regression({self.dim}D, in_shape={self.in_shape}, out_size={self.out_size})"
wavedl/models/vit.py
ADDED
@@ -0,0 +1,372 @@
"""
Vision Transformer (ViT): Transformer-Based Architecture for Regression
========================================================================

A flexible Vision Transformer implementation for regression tasks.
Supports both 1D (signals) and 2D (images) inputs via configurable patch embedding.

**Dimensionality Support**:
- 1D: Waveforms/signals → patches are segments
- 2D: Images/spectrograms → patches are grid squares

**Variants**:
- vit_tiny: Smallest (embed_dim=192, depth=12, heads=3)
- vit_small: Light (embed_dim=384, depth=12, heads=6)
- vit_base: Standard (embed_dim=768, depth=12, heads=12)

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""

from typing import Any

import torch
import torch.nn as nn

from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


# Type alias for spatial shapes
SpatialShape = tuple[int] | tuple[int, int]


class PatchEmbed(nn.Module):
    """
    Patch Embedding module that converts input into sequence of patch embeddings.

    Supports 1D and 2D inputs:
    - 1D: Input (B, 1, L) → (B, num_patches, embed_dim)
    - 2D: Input (B, 1, H, W) → (B, num_patches, embed_dim)
    """

    def __init__(self, in_shape: SpatialShape, patch_size: int, embed_dim: int):
        super().__init__()

        self.dim = len(in_shape)
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        if self.dim == 1:
            # 1D: segment patches
            L = in_shape[0]
            self.num_patches = L // patch_size
            self.proj = nn.Conv1d(
                1, embed_dim, kernel_size=patch_size, stride=patch_size
            )
        elif self.dim == 2:
            # 2D: grid patches
            H, W = in_shape
            self.num_patches = (H // patch_size) * (W // patch_size)
            self.proj = nn.Conv2d(
                1, embed_dim, kernel_size=patch_size, stride=patch_size
            )
        else:
            raise ValueError(f"ViT supports 1D and 2D inputs, got {self.dim}D")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor (B, 1, ..spatial..)

        Returns:
            Patch embeddings (B, num_patches, embed_dim)
        """
        x = self.proj(x)  # (B, embed_dim, ..reduced_spatial..)
        x = x.flatten(2)  # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
        return x


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention mechanism."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        return x


class MLP(nn.Module):
    """MLP block with GELU activation."""

    def __init__(self, embed_dim: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """Transformer encoder block with pre-norm architecture."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class ViTBase(BaseModel):
    """
    Vision Transformer base class for regression.

    Architecture:
    1. Patch embedding
    2. Add learnable position embeddings + CLS token
    3. Transformer encoder blocks
    4. Extract CLS token
    5. Regression head
    """

    def __init__(
        self,
        in_shape: SpatialShape,
        out_size: int,
        patch_size: int = 16,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        dropout_rate: float = 0.1,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.dim = len(in_shape)

        # Patch embedding
        self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable CLS token and position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout_rate)

        # Transformer encoder
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout_rate)
                for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(embed_dim)

        # Regression head
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(embed_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, out_size),
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights."""
        # Initialize position embeddings
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        # Initialize linear layers
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor (B, 1, ..spatial..)

        Returns:
            Regression output (B, out_size)
        """
        B = x.size(0)

        # Patch embedding
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, num_patches+1, embed_dim)

        # Add position embeddings
        x = x + self.pos_embed
        x = self.dropout(x)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # Extract CLS token for regression
        cls_output = x[:, 0]

        return self.head(cls_output)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration."""
        return {"patch_size": 16, "dropout_rate": 0.1}


# =============================================================================
# REGISTERED MODEL VARIANTS
# =============================================================================


@register_model("vit_tiny")
class ViTTiny(ViTBase):
    """
    ViT-Tiny: Smallest Vision Transformer variant.

    ~5.7M parameters. Good for: Quick experiments, smaller datasets.

    Args:
        in_shape: (L,) for 1D or (H, W) for 2D
        out_size: Number of regression targets
        patch_size: Patch size (default: 16)
        dropout_rate: Dropout rate (default: 0.1)
    """

    def __init__(
        self, in_shape: SpatialShape, out_size: int, patch_size: int = 16, **kwargs
    ):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            patch_size=patch_size,
            embed_dim=192,
            depth=12,
            num_heads=3,
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"ViT_Tiny({self.dim}D, in_shape={self.in_shape}, out_size={self.out_size})"
        )


@register_model("vit_small")
class ViTSmall(ViTBase):
    """
    ViT-Small: Light Vision Transformer variant.

    ~22M parameters. Good for: Balanced performance.

    Args:
        in_shape: (L,) for 1D or (H, W) for 2D
        out_size: Number of regression targets
        patch_size: Patch size (default: 16)
        dropout_rate: Dropout rate (default: 0.1)
    """

    def __init__(
        self, in_shape: SpatialShape, out_size: int, patch_size: int = 16, **kwargs
    ):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            patch_size=patch_size,
            embed_dim=384,
            depth=12,
            num_heads=6,
            **kwargs,
        )

    def __repr__(self) -> str:
        return f"ViT_Small({self.dim}D, in_shape={self.in_shape}, out_size={self.out_size})"


@register_model("vit_base")
class ViTBase_(ViTBase):
    """
    ViT-Base: Standard Vision Transformer variant.

    ~86M parameters. Good for: High accuracy, larger datasets.

    Args:
        in_shape: (L,) for 1D or (H, W) for 2D
        out_size: Number of regression targets
        patch_size: Patch size (default: 16)
        dropout_rate: Dropout rate (default: 0.1)
    """

    def __init__(
        self, in_shape: SpatialShape, out_size: int, patch_size: int = 16, **kwargs
    ):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            patch_size=patch_size,
            embed_dim=768,
            depth=12,
            num_heads=12,
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"ViT_Base({self.dim}D, in_shape={self.in_shape}, out_size={self.out_size})"
        )
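
As with the U-Net above, a short illustrative check shows how the patch embedding and CLS token determine the sequence length and output shape. This is again an assumption-laden sketch rather than documented package usage: it bypasses the registry and presumes BaseModel accepts only in_shape and out_size.

# Illustrative ViT-Tiny check (assumption: an installed wavedl; not part of the package).
import torch
from wavedl.models.vit import ViTTiny

model = ViTTiny(in_shape=(128, 128), out_size=4, patch_size=16)
print(model.patch_embed.num_patches)   # (128 // 16) * (128 // 16) = 64 patches
print(model.pos_embed.shape)           # torch.Size([1, 65, 192]): 64 patches + 1 CLS token

x = torch.randn(2, 1, 128, 128)        # (B, 1, H, W) single-channel batch
print(model(x).shape)                  # expected: torch.Size([2, 4])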