wavedl-1.2.0-py3-none-any.whl

wavedl/models/unet.py ADDED
@@ -0,0 +1,304 @@
+ """
+ U-Net: Encoder-Decoder Architecture for Regression
+ ====================================================
+
+ A dimension-agnostic U-Net implementation for tasks requiring either:
+ - Spatial output (e.g., velocity field prediction)
+ - Vector output (global pooling → regression head)
+
+ **Dimensionality Support**:
+ - 1D: Waveforms, signals (N, 1, L) → Conv1d
+ - 2D: Images, spectrograms (N, 1, H, W) → Conv2d
+ - 3D: Volumetric data (N, 1, D, H, W) → Conv3d
+
+ **Variants**:
+ - unet: Full encoder-decoder with spatial output capability
+ - unet_regression: U-Net with global pooling for vector regression
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ Version: 1.0.0
+ """
+
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+
+ from wavedl.models.base import BaseModel
+ from wavedl.models.registry import register_model
+
+
+ # Type alias for spatial shapes
+ SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
+
+
+ def _get_layers(dim: int):
+     """Get dimension-appropriate layer classes."""
+     if dim == 1:
+         return nn.Conv1d, nn.ConvTranspose1d, nn.MaxPool1d, nn.AdaptiveAvgPool1d
+     elif dim == 2:
+         return nn.Conv2d, nn.ConvTranspose2d, nn.MaxPool2d, nn.AdaptiveAvgPool2d
+     elif dim == 3:
+         return nn.Conv3d, nn.ConvTranspose3d, nn.MaxPool3d, nn.AdaptiveAvgPool3d
+     else:
+         raise ValueError(f"Unsupported dimensionality: {dim}D")
+
+
+ class DoubleConv(nn.Module):
+     """Double convolution block: Conv-GN-ReLU-Conv-GN-ReLU."""
+
+     def __init__(self, in_channels: int, out_channels: int, dim: int = 2):
+         super().__init__()
+         Conv = _get_layers(dim)[0]
+
+         num_groups = min(32, out_channels)
+         while out_channels % num_groups != 0 and num_groups > 1:
+             num_groups -= 1
+
+         self.double_conv = nn.Sequential(
+             Conv(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.GroupNorm(num_groups, out_channels),
+             nn.ReLU(inplace=True),
+             Conv(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.GroupNorm(num_groups, out_channels),
+             nn.ReLU(inplace=True),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.double_conv(x)
+
+
+ class Down(nn.Module):
+     """Downscaling with maxpool then double conv."""
+
+     def __init__(self, in_channels: int, out_channels: int, dim: int = 2):
+         super().__init__()
+         _, _, MaxPool, _ = _get_layers(dim)
+
+         self.maxpool_conv = nn.Sequential(
+             MaxPool(2), DoubleConv(in_channels, out_channels, dim)
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.maxpool_conv(x)
+
+
+ class Up(nn.Module):
+     """Upscaling then double conv."""
+
+     def __init__(self, in_channels: int, out_channels: int, dim: int = 2):
+         super().__init__()
+         _, ConvTranspose, _, _ = _get_layers(dim)
+
+         # in_channels comes from previous layer
+         # After upconv: in_channels // 2
+         # After concat with skip (out_channels): in_channels // 2 + out_channels = in_channels
+         # Then DoubleConv: in_channels -> out_channels
+         self.up = ConvTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2)
+         self.conv = DoubleConv(in_channels, out_channels, dim)
+
+     def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+         x1 = self.up(x1)
+
+         # Handle size mismatch (pad x1 to match x2)
+         if x1.shape[2:] != x2.shape[2:]:
+             diff = [x2.size(i + 2) - x1.size(i + 2) for i in range(len(x1.shape) - 2)]
+             # Pad x1 to match x2
+             pad = []
+             for d in reversed(diff):
+                 pad.extend([d // 2, d - d // 2])
+             x1 = nn.functional.pad(x1, pad)
+
+         x = torch.cat([x2, x1], dim=1)
+         return self.conv(x)
+
+
+ class UNetBase(BaseModel):
+     """
+     Base U-Net class for regression tasks.
+
+     Standard U-Net architecture:
+     - Encoder path with downsampling
+     - Decoder path with upsampling and skip connections
+     - Optional spatial or vector output
+     """
+
+     def __init__(
+         self,
+         in_shape: SpatialShape,
+         out_size: int,
+         base_channels: int = 64,
+         depth: int = 4,
+         dropout_rate: float = 0.1,
+         spatial_output: bool = False,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         self.dim = len(in_shape)
+         self.base_channels = base_channels
+         self.depth = depth
+         self.dropout_rate = dropout_rate
+         self.spatial_output = spatial_output
+
+         Conv, _, _, AdaptivePool = _get_layers(self.dim)
+
+         # Channel progression: 64 -> 128 -> 256 -> 512 -> 1024 (for depth=4)
+         # features[i] = base_channels * 2^i
+         features = [base_channels * (2**i) for i in range(depth + 1)]
+
+         # Initial double conv (1 -> features[0])
+         self.inc = DoubleConv(1, features[0], self.dim)
+
+         # Encoder (down path)
+         self.downs = nn.ModuleList()
+         for i in range(depth):
+             self.downs.append(Down(features[i], features[i + 1], self.dim))
+
+         # Decoder (up path)
+         self.ups = nn.ModuleList()
+         for i in range(depth):
+             # Input: features[depth - i], Skip: features[depth - 1 - i], Output: features[depth - 1 - i]
+             self.ups.append(Up(features[depth - i], features[depth - 1 - i], self.dim))
+
+         if spatial_output:
+             # Spatial output: 1x1 conv to out_size channels
+             self.outc = Conv(features[0], out_size, kernel_size=1)
+         else:
+             # Vector output: global pooling + regression head
+             self.global_pool = AdaptivePool(1)
+             self.head = nn.Sequential(
+                 nn.Dropout(dropout_rate),
+                 nn.Linear(features[0], 256),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(dropout_rate),
+                 nn.Linear(256, out_size),
+             )
+
+         self._init_weights()
+
+     def _init_weights(self):
+         """Initialize weights."""
+         for m in self.modules():
+             if isinstance(
+                 m,
+                 (
+                     nn.Conv1d,
+                     nn.Conv2d,
+                     nn.Conv3d,
+                     nn.ConvTranspose1d,
+                     nn.ConvTranspose2d,
+                     nn.ConvTranspose3d,
+                 ),
+             ):
+                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.GroupNorm):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.Linear):
+                 nn.init.normal_(m.weight, 0, 0.01)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass."""
+         # Encoder path - collect skip connections
+         x1 = self.inc(x)
+
+         skips = [x1]
+         x = x1
+         for down in self.downs:
+             x = down(x)
+             skips.append(x)
+
+         # Remove last (bottleneck output, not a skip)
+         skips = skips[:-1]
+
+         # Decoder path - use skips in reverse order
+         for up, skip in zip(self.ups, reversed(skips)):
+             x = up(x, skip)
+
+         if self.spatial_output:
+             return self.outc(x)
+         else:
+             x = self.global_pool(x)
+             x = x.flatten(1)
+             return self.head(x)
+
+     @classmethod
+     def get_default_config(cls) -> dict[str, Any]:
+         """Return default configuration."""
+         return {"base_channels": 64, "depth": 4, "dropout_rate": 0.1}
+
+
+ # =============================================================================
+ # REGISTERED MODEL VARIANTS
+ # =============================================================================
+
+
+ @register_model("unet")
+ class UNet(UNetBase):
+     """
+     U-Net with spatial output capability.
+
+     Good for: Pixel/voxel-wise regression (velocity fields, spatial maps).
+
+     Note: For spatial output, out_size is the number of output channels.
+     Output shape: (B, out_size, *spatial_dims) for spatial_output=True.
+
+     Args:
+         in_shape: (L,), (H, W), or (D, H, W)
+         out_size: Number of output channels for spatial output
+         base_channels: Base channel count (default: 64)
+         depth: Number of encoder/decoder levels (default: 4)
+         spatial_output: If True, output spatial map; if False, output vector
+         dropout_rate: Dropout rate (default: 0.1)
+     """
+
+     def __init__(
+         self,
+         in_shape: SpatialShape,
+         out_size: int,
+         spatial_output: bool = True,
+         **kwargs,
+     ):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             spatial_output=spatial_output,
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         mode = "spatial" if self.spatial_output else "vector"
+         return f"UNet({self.dim}D, {mode}, in_shape={self.in_shape}, out_size={self.out_size})"
+
+
+ @register_model("unet_regression")
+ class UNetRegression(UNetBase):
+     """
+     U-Net for vector regression output.
+
+     Uses U-Net encoder-decoder but applies global pooling at the end
+     for standard vector regression output.
+
+     Good for: Leveraging U-Net features (multi-scale, skip connections)
+     for standard regression tasks.
+
+     Args:
+         in_shape: (L,), (H, W), or (D, H, W)
+         out_size: Number of regression targets
+         base_channels: Base channel count (default: 64)
+         depth: Number of encoder/decoder levels (default: 4)
+         dropout_rate: Dropout rate (default: 0.1)
+     """
+
+     def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape, out_size=out_size, spatial_output=False, **kwargs
+         )
+
+     def __repr__(self) -> str:
+         return f"UNet_Regression({self.dim}D, in_shape={self.in_shape}, out_size={self.out_size})"
wavedl/models/vit.py ADDED
@@ -0,0 +1,372 @@
+ """
+ Vision Transformer (ViT): Transformer-Based Architecture for Regression
+ ========================================================================
+
+ A flexible Vision Transformer implementation for regression tasks.
+ Supports both 1D (signals) and 2D (images) inputs via configurable patch embedding.
+
+ **Dimensionality Support**:
+ - 1D: Waveforms/signals → patches are segments
+ - 2D: Images/spectrograms → patches are grid squares
+
+ **Variants**:
+ - vit_tiny: Smallest (embed_dim=192, depth=12, heads=3)
+ - vit_small: Light (embed_dim=384, depth=12, heads=6)
+ - vit_base: Standard (embed_dim=768, depth=12, heads=12)
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ Version: 1.0.0
+ """
+
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+
+ from wavedl.models.base import BaseModel
+ from wavedl.models.registry import register_model
+
+
+ # Type alias for spatial shapes
+ SpatialShape = tuple[int] | tuple[int, int]
+
+
+ class PatchEmbed(nn.Module):
+     """
+     Patch Embedding module that converts input into sequence of patch embeddings.
+
+     Supports 1D and 2D inputs:
+     - 1D: Input (B, 1, L) → (B, num_patches, embed_dim)
+     - 2D: Input (B, 1, H, W) → (B, num_patches, embed_dim)
+     """
+
+     def __init__(self, in_shape: SpatialShape, patch_size: int, embed_dim: int):
+         super().__init__()
+
+         self.dim = len(in_shape)
+         self.patch_size = patch_size
+         self.embed_dim = embed_dim
+
+         if self.dim == 1:
+             # 1D: segment patches
+             L = in_shape[0]
+             self.num_patches = L // patch_size
+             self.proj = nn.Conv1d(
+                 1, embed_dim, kernel_size=patch_size, stride=patch_size
+             )
+         elif self.dim == 2:
+             # 2D: grid patches
+             H, W = in_shape
+             self.num_patches = (H // patch_size) * (W // patch_size)
+             self.proj = nn.Conv2d(
+                 1, embed_dim, kernel_size=patch_size, stride=patch_size
+             )
+         else:
+             raise ValueError(f"ViT supports 1D and 2D inputs, got {self.dim}D")
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: Input tensor (B, 1, ..spatial..)
+
+         Returns:
+             Patch embeddings (B, num_patches, embed_dim)
+         """
+         x = self.proj(x)  # (B, embed_dim, ..reduced_spatial..)
+         x = x.flatten(2)  # (B, embed_dim, num_patches)
+         x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
+         return x
+
+
+ class MultiHeadAttention(nn.Module):
+     """Multi-head self-attention mechanism."""
+
+     def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.scale = self.head_dim**-0.5
+
+         self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
+         self.proj = nn.Linear(embed_dim, embed_dim)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, N, C = x.shape
+
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
+         qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
+         q, k, v = qkv[0], qkv[1], qkv[2]
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.dropout(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.dropout(x)
+         return x
+
+
+ class MLP(nn.Module):
+     """MLP block with GELU activation."""
+
+     def __init__(self, embed_dim: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
+         super().__init__()
+         hidden_dim = int(embed_dim * mlp_ratio)
+         self.fc1 = nn.Linear(embed_dim, hidden_dim)
+         self.act = nn.GELU()
+         self.fc2 = nn.Linear(hidden_dim, embed_dim)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         x = self.dropout(x)
+         return x
+
+
+ class TransformerBlock(nn.Module):
+     """Transformer encoder block with pre-norm architecture."""
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         dropout: float = 0.0,
+     ):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(embed_dim)
+         self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
+         self.norm2 = nn.LayerNorm(embed_dim)
+         self.mlp = MLP(embed_dim, mlp_ratio, dropout)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.attn(self.norm1(x))
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ class ViTBase(BaseModel):
+     """
+     Vision Transformer base class for regression.
+
+     Architecture:
+     1. Patch embedding
+     2. Add learnable position embeddings + CLS token
+     3. Transformer encoder blocks
+     4. Extract CLS token
+     5. Regression head
+     """
+
+     def __init__(
+         self,
+         in_shape: SpatialShape,
+         out_size: int,
+         patch_size: int = 16,
+         embed_dim: int = 768,
+         depth: int = 12,
+         num_heads: int = 12,
+         mlp_ratio: float = 4.0,
+         dropout_rate: float = 0.1,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         self.patch_size = patch_size
+         self.embed_dim = embed_dim
+         self.depth = depth
+         self.num_heads = num_heads
+         self.dropout_rate = dropout_rate
+         self.dim = len(in_shape)
+
+         # Patch embedding
+         self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         # Learnable CLS token and position embeddings
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+         self.dropout = nn.Dropout(dropout_rate)
+
+         # Transformer encoder
+         self.blocks = nn.ModuleList(
+             [
+                 TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout_rate)
+                 for _ in range(depth)
+             ]
+         )
+
+         self.norm = nn.LayerNorm(embed_dim)
+
+         # Regression head
+         self.head = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(embed_dim, 256),
+             nn.GELU(),
+             nn.Dropout(dropout_rate),
+             nn.Linear(256, out_size),
+         )
+
+         # Initialize weights
+         self._init_weights()
+
+     def _init_weights(self):
+         """Initialize weights."""
+         # Initialize position embeddings
+         nn.init.trunc_normal_(self.pos_embed, std=0.02)
+         nn.init.trunc_normal_(self.cls_token, std=0.02)
+
+         # Initialize linear layers
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.trunc_normal_(m.weight, std=0.02)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.LayerNorm):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Args:
+             x: Input tensor (B, 1, ..spatial..)
+
+         Returns:
+             Regression output (B, out_size)
+         """
+         B = x.size(0)
+
+         # Patch embedding
+         x = self.patch_embed(x)  # (B, num_patches, embed_dim)
+
+         # Prepend CLS token
+         cls_tokens = self.cls_token.expand(B, -1, -1)
+         x = torch.cat([cls_tokens, x], dim=1)  # (B, num_patches+1, embed_dim)
+
+         # Add position embeddings
+         x = x + self.pos_embed
+         x = self.dropout(x)
+
+         # Transformer blocks
+         for block in self.blocks:
+             x = block(x)
+
+         x = self.norm(x)
+
+         # Extract CLS token for regression
+         cls_output = x[:, 0]
+
+         return self.head(cls_output)
+
+     @classmethod
+     def get_default_config(cls) -> dict[str, Any]:
+         """Return default configuration."""
+         return {"patch_size": 16, "dropout_rate": 0.1}
+
+
+ # =============================================================================
+ # REGISTERED MODEL VARIANTS
+ # =============================================================================
+
+
+ @register_model("vit_tiny")
+ class ViTTiny(ViTBase):
+     """
+     ViT-Tiny: Smallest Vision Transformer variant.
+
+     ~5.7M parameters. Good for: Quick experiments, smaller datasets.
+
+     Args:
+         in_shape: (L,) for 1D or (H, W) for 2D
+         out_size: Number of regression targets
+         patch_size: Patch size (default: 16)
+         dropout_rate: Dropout rate (default: 0.1)
+     """
+
+     def __init__(
+         self, in_shape: SpatialShape, out_size: int, patch_size: int = 16, **kwargs
+     ):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             patch_size=patch_size,
+             embed_dim=192,
+             depth=12,
+             num_heads=3,
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"ViT_Tiny({self.dim}D, in_shape={self.in_shape}, out_size={self.out_size})"
+         )
+
+
+ @register_model("vit_small")
+ class ViTSmall(ViTBase):
+     """
+     ViT-Small: Light Vision Transformer variant.
+
+     ~22M parameters. Good for: Balanced performance.
+
+     Args:
+         in_shape: (L,) for 1D or (H, W) for 2D
+         out_size: Number of regression targets
+         patch_size: Patch size (default: 16)
+         dropout_rate: Dropout rate (default: 0.1)
+     """
+
+     def __init__(
+         self, in_shape: SpatialShape, out_size: int, patch_size: int = 16, **kwargs
+     ):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             patch_size=patch_size,
+             embed_dim=384,
+             depth=12,
+             num_heads=6,
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return f"ViT_Small({self.dim}D, in_shape={self.in_shape}, out_size={self.out_size})"
+
+
+ @register_model("vit_base")
+ class ViTBase_(ViTBase):
+     """
+     ViT-Base: Standard Vision Transformer variant.
+
+     ~86M parameters. Good for: High accuracy, larger datasets.
+
+     Args:
+         in_shape: (L,) for 1D or (H, W) for 2D
+         out_size: Number of regression targets
+         patch_size: Patch size (default: 16)
+         dropout_rate: Dropout rate (default: 0.1)
+     """
+
+     def __init__(
+         self, in_shape: SpatialShape, out_size: int, patch_size: int = 16, **kwargs
+     ):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             patch_size=patch_size,
+             embed_dim=768,
+             depth=12,
+             num_heads=12,
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"ViT_Base({self.dim}D, in_shape={self.in_shape}, out_size={self.out_size})"
+         )
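
Usage sketch (illustrative only, not part of the package diff above): constructing one of the registered ViT variants directly, again assuming BaseModel takes only in_shape and out_size. Two constraints follow from the code: pos_embed is sized from in_shape at construction, so inference inputs must keep that spatial size, and sides/lengths should be divisible by patch_size because the strided patch convolution drops any remainder.

import torch

from wavedl.models.vit import ViTTiny  # registered as "vit_tiny"

# Assumption: BaseModel.__init__(in_shape, out_size) has no other required arguments.
# 2D input: a 256x256 single-channel image becomes 16x16 patches -> 256 tokens + CLS.
model_2d = ViTTiny(in_shape=(256, 256), out_size=4, patch_size=16)
print(model_2d(torch.randn(2, 1, 256, 256)).shape)  # expected: torch.Size([2, 4])

# 1D input: a length-2048 waveform becomes 128 patches of 16 samples each.
model_1d = ViTTiny(in_shape=(2048,), out_size=1, patch_size=16)
print(model_1d(torch.randn(2, 1, 2048)).shape)  # expected: torch.Size([2, 1])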