wavedl 1.5.7-py3-none-any.whl → 1.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/models/__init__.py +80 -4
  4. wavedl/models/_pretrained_utils.py +366 -0
  5. wavedl/models/base.py +48 -0
  6. wavedl/models/caformer.py +270 -0
  7. wavedl/models/cnn.py +2 -27
  8. wavedl/models/convnext.py +113 -51
  9. wavedl/models/convnext_v2.py +488 -0
  10. wavedl/models/densenet.py +10 -23
  11. wavedl/models/efficientnet.py +6 -6
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +252 -0
  15. wavedl/models/mamba.py +555 -0
  16. wavedl/models/maxvit.py +254 -0
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +19 -61
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +2 -6
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +9 -9
  26. wavedl/train.py +1430 -1425
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/data.py +39 -6
  30. wavedl/utils/losses.py +216 -216
  31. wavedl/utils/optimizers.py +216 -216
  32. wavedl/utils/schedulers.py +251 -251
  33. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
  34. wavedl-1.6.1.dist-info/RECORD +46 -0
  35. wavedl-1.5.7.dist-info/RECORD +0 -38
  36. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
  37. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
  38. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
  39. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
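
The headline change in 1.6.1 is the batch of new backbones added as whole files (_pretrained_utils, caformer, convnext_v2, efficientvit, fastvit, mamba, maxvit, unireplknet). A minimal usage sketch for one of them, based on the constructor signature and docstring example in the unireplknet.py diff below; the import path assumes the expanded wavedl.models exports hinted at by the __init__.py diff, so verify it against your installed version:

    import torch

    from wavedl.models.unireplknet import UniRepLKNetTiny

    # 2D input of shape (256, 256), three regression targets
    model = UniRepLKNetTiny(in_shape=(256, 256), out_size=3)
    x = torch.randn(4, 1, 256, 256)   # (B, 1, *in_shape)
    out = model(x)                    # expected shape: (4, 3)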
wavedl/models/unet.py CHANGED
@@ -22,14 +22,10 @@ from typing import Any
  import torch
  import torch.nn as nn
 
- from wavedl.models.base import BaseModel
+ from wavedl.models.base import BaseModel, SpatialShape
  from wavedl.models.registry import register_model
 
 
- # Type alias for spatial shapes
- SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
  def _get_layers(dim: int):
      """Get dimension-appropriate layer classes."""
      if dim == 1:
@@ -119,7 +115,7 @@ class UNetRegression(BaseModel):
      Uses U-Net encoder-decoder architecture with skip connections,
      then applies global pooling for standard vector regression output.
 
-     ~31.1M parameters (2D). Good for leveraging multi-scale features
+     ~31.0M backbone parameters (2D). Good for leveraging multi-scale features
      and skip connections for regression tasks.
 
      Args:
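
Note that unet.py (above) and vit.py (further below) both stop defining SpatialShape inline and instead import the alias from wavedl.models.base, whose +48 added lines are not shown in this diff. A plausible reconstruction of the centralized aliases, inferred purely from the inline definitions being removed (hypothetical, not taken from base.py itself):

    # Assumed contents of the new aliases in wavedl.models.base,
    # inferred from the removed inline definitions in unet.py and vit.py.
    SpatialShape1D = tuple[int]
    SpatialShape2D = tuple[int, int]
    SpatialShape3D = tuple[int, int, int]
    SpatialShape = SpatialShape1D | SpatialShape2D | SpatialShape3D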
wavedl/models/unireplknet.py ADDED
@@ -0,0 +1,491 @@
+ """
+ UniRepLKNet: Universal Large-Kernel ConvNet for Regression
+ ===========================================================
+
+ A dimension-agnostic implementation of UniRepLKNet featuring ultra-large kernels
+ (up to 31x31) for capturing long-range dependencies. Particularly effective for
+ wave-based problems where spatial correlations span large distances.
+
+ **Key Features**:
+ - Large kernels (13x13 to 31x31) via efficient decomposition
+ - Dilated small kernel reparam for efficient training
+ - SE (Squeeze-and-Excitation) attention
+ - GRN (Global Response Normalization) from ConvNeXt V2
+ - Dimension-agnostic: supports 1D, 2D, 3D inputs
+
+ **Variants**:
+ - unireplknet_tiny: 31M params, depths [3,3,18,3], dims [80,160,320,640]
+ - unireplknet_small: 56M params, depths [3,3,27,3], dims [96,192,384,768]
+ - unireplknet_base: 97M params, depths [3,3,27,3], dims [128,256,512,1024]
+
+ **Why Large Kernels for Wave Problems**:
+ - Dispersion curves: Long-range frequency-wavenumber correlations
+ - B-scans: Defect signatures span many pixels
+ - Time-series: Capture multiple wave periods without deep stacking
+
+ Reference:
+     Ding, X., et al. (2024). UniRepLKNet: A Universal Perception Large-Kernel
+     ConvNet for Audio, Video, Point Cloud, Time-Series and Image Recognition.
+     CVPR 2024. https://arxiv.org/abs/2311.15599
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ """
+
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+
+ from wavedl.models._pretrained_utils import (
+     LayerNormNd,
+     get_conv_layer,
+     get_grn_layer,
+     get_pool_layer,
+ )
+ from wavedl.models.base import BaseModel, SpatialShape
+ from wavedl.models.registry import register_model
+
+
+ __all__ = [
+     "UniRepLKNetBase",
+     "UniRepLKNetBaseLarge",
+     "UniRepLKNetSmall",
+     "UniRepLKNetTiny",
+ ]
+
+
+ # =============================================================================
+ # LARGE KERNEL CONVOLUTION BLOCK
+ # =============================================================================
+
+
+ class LargeKernelConv(nn.Module):
+     """
+     Large kernel depthwise convolution.
+
+     Implements efficient large kernel convolutions following UniRepLKNet.
+     Uses a single large depthwise conv for simplicity and reliability.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         kernel_size: int,
+         dim: int = 2,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.kernel_size = kernel_size
+
+         Conv = get_conv_layer(dim)
+         padding = kernel_size // 2
+
+         # Large kernel depthwise conv
+         self.conv = Conv(
+             channels,
+             channels,
+             kernel_size=kernel_size,
+             padding=padding,
+             groups=channels,
+             bias=False,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.conv(x)
+
+
+ class SEBlock(nn.Module):
+     """
+     Squeeze-and-Excitation block for channel attention.
+
+     Adaptively recalibrates channel-wise feature responses by explicitly
+     modeling interdependencies between channels.
+     """
+
+     def __init__(self, channels: int, reduction: int = 4):
+         super().__init__()
+         reduced = max(channels // reduction, 8)
+         self.fc1 = nn.Linear(channels, reduced, bias=False)
+         self.fc2 = nn.Linear(reduced, channels, bias=False)
+         self.act = nn.GELU()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Global average pooling
+         if x.ndim == 3:  # 1D: (B, C, L)
+             gap = x.mean(dim=2)
+         elif x.ndim == 4:  # 2D: (B, C, H, W)
+             gap = x.mean(dim=(2, 3))
+         else:  # 3D: (B, C, D, H, W)
+             gap = x.mean(dim=(2, 3, 4))
+
+         # FC layers
+         scale = self.act(self.fc1(gap))
+         scale = torch.sigmoid(self.fc2(scale))
+
+         # Reshape for broadcasting
+         if x.ndim == 3:
+             scale = scale.unsqueeze(-1)
+         elif x.ndim == 4:
+             scale = scale.unsqueeze(-1).unsqueeze(-1)
+         else:
+             scale = scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+
+         return x * scale
+
+
+ class DropPath(nn.Module):
+     """Stochastic Depth (drop path) regularization."""
+
+     def __init__(self, drop_prob: float = 0.0):
+         super().__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.drop_prob == 0.0 or not self.training:
+             return x
+
+         keep_prob = 1 - self.drop_prob
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+         random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+         random_tensor.floor_()
+         return x.div(keep_prob) * random_tensor
+
+
+ # =============================================================================
+ # UNIREPLKNET BLOCK
+ # =============================================================================
+
+
+ class UniRepLKNetBlock(nn.Module):
+     """
+     UniRepLKNet block with large kernel convolution, SE attention, and GRN.
+
+     Architecture:
+         Input → LargeKernelConv → LayerNorm → SE → Linear → GELU → GRN → Linear → Residual
+
+     This combines the effective receptive field of large kernels with the
+     feature recalibration of SE attention and the inter-channel competition
+     of GRN from ConvNeXt V2.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         spatial_dim: int,
+         kernel_size: int = 13,
+         drop_path: float = 0.0,
+         mlp_ratio: float = 4.0,
+     ):
+         super().__init__()
+         self.spatial_dim = spatial_dim
+
+         GRN = get_grn_layer(spatial_dim)
+
+         # Large kernel depthwise conv
+         self.large_kernel = LargeKernelConv(
+             dim, kernel_size=kernel_size, dim=spatial_dim
+         )
+
+         # Layer norm (applied in channels-last format)
+         self.norm = nn.LayerNorm(dim, eps=1e-6)
+
+         # SE attention
+         self.se = SEBlock(dim)
+
+         # MLP with expansion
+         hidden_dim = int(dim * mlp_ratio)
+         self.pwconv1 = nn.Linear(dim, hidden_dim)
+         self.act = nn.GELU()
+         self.grn = GRN(hidden_dim)
+         self.pwconv2 = nn.Linear(hidden_dim, dim)
+
+         # Stochastic depth
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+     def _to_channels_last(self, x: torch.Tensor) -> torch.Tensor:
+         """Convert from channels-first to channels-last."""
+         if self.spatial_dim == 1:
+             return x.permute(0, 2, 1)  # (B, C, L) -> (B, L, C)
+         elif self.spatial_dim == 2:
+             return x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
+         else:
+             return x.permute(0, 2, 3, 4, 1)  # (B, C, D, H, W) -> (B, D, H, W, C)
+
+     def _to_channels_first(self, x: torch.Tensor) -> torch.Tensor:
+         """Convert from channels-last to channels-first."""
+         if self.spatial_dim == 1:
+             return x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L)
+         elif self.spatial_dim == 2:
+             return x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
+         else:
+             return x.permute(0, 4, 1, 2, 3)  # (B, D, H, W, C) -> (B, C, D, H, W)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         residual = x
+
+         # Large kernel conv (channels-first)
+         x = self.large_kernel(x)
+
+         # SE attention (channels-first)
+         x = self.se(x)
+
+         # LayerNorm + MLP (channels-last)
+         x = self._to_channels_last(x)
+         x = self.norm(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+
+         # GRN (channels-first)
+         x = self._to_channels_first(x)
+         x = self.grn(x)
+
+         # Final projection (channels-last)
+         x = self._to_channels_last(x)
+         x = self.pwconv2(x)
+         x = self._to_channels_first(x)
+
+         # Residual + drop path
+         x = residual + self.drop_path(x)
+         return x
+
+
+ # =============================================================================
+ # UNIREPLKNET BASE CLASS
+ # =============================================================================
+
+
+ class UniRepLKNetBase(BaseModel):
+     """
+     UniRepLKNet base class for regression.
+
+     Dimension-agnostic implementation supporting 1D, 2D, and 3D inputs.
+     Features large kernels for capturing long-range dependencies in wave data.
+
+     Architecture:
+         1. Stem: 4x downsampling conv
+         2. 4 stages with UniRepLKNet blocks
+         3. Downsampling between stages
+         4. Global pooling + regression head
+     """
+
+     def __init__(
+         self,
+         in_shape: SpatialShape,
+         out_size: int,
+         depths: list[int],
+         dims: list[int],
+         kernel_sizes: list[int] | None = None,
+         drop_path_rate: float = 0.1,
+         dropout_rate: float = 0.3,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         self.dim = len(in_shape)
+         self.depths = depths
+         self.dims = dims
+
+         # Default kernel sizes: larger in early stages, smaller in later stages
+         # Early stages: large receptive field for low-level features
+         # Later stages: smaller kernels sufficient for high-level features
+         if kernel_sizes is None:
+             kernel_sizes = [31, 29, 17, 13]
+
+         Conv = get_conv_layer(self.dim)
+         Pool = get_pool_layer(self.dim)
+
+         # Stem: aggressive 4x downsampling (like ConvNeXt)
+         self.stem = nn.Sequential(
+             Conv(1, dims[0], kernel_size=4, stride=4),
+             LayerNormNd(dims[0], self.dim),
+         )
+
+         # Stochastic depth decay
+         dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+
+         # Build stages
+         self.stages = nn.ModuleList()
+         self.downsamples = nn.ModuleList()
+         cur = 0
+
+         for i in range(len(depths)):
+             # Adjust kernel size for 1D (can use larger kernels)
+             # Ensure kernel size is always odd for proper same-padding
+             kernel_size = kernel_sizes[i]
+             if self.dim == 1:
+                 kernel_size = min(kernel_size * 2 - 1, 63)  # Keep odd for 1D
+
+             stage = nn.Sequential(
+                 *[
+                     UniRepLKNetBlock(
+                         dim=dims[i],
+                         spatial_dim=self.dim,
+                         kernel_size=kernel_size,
+                         drop_path=dp_rates[cur + j],
+                     )
+                     for j in range(depths[i])
+                 ]
+             )
+             self.stages.append(stage)
+             cur += depths[i]
+
+             # Downsample between stages (except after last)
+             if i < len(depths) - 1:
+                 downsample = nn.Sequential(
+                     LayerNormNd(dims[i], self.dim),
+                     Conv(dims[i], dims[i + 1], kernel_size=2, stride=2),
+                 )
+                 self.downsamples.append(downsample)
+
+         # Global pooling and head
+         self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
+         self.global_pool = Pool(1)
+         self.head = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(dims[-1], dims[-1] // 2),
+             nn.GELU(),
+             nn.Dropout(dropout_rate * 0.5),
+             nn.Linear(dims[-1] // 2, out_size),
+         )
+
+         # Initialize weights
+         self._init_weights()
+
+     def _init_weights(self):
+         """Initialize weights with truncated normal."""
+         for m in self.modules():
+             if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
+                 nn.init.trunc_normal_(m.weight, std=0.02)
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.LayerNorm):
+                 nn.init.ones_(m.weight)
+                 nn.init.zeros_(m.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Args:
+             x: Input tensor (B, 1, *in_shape)
+
+         Returns:
+             Output tensor (B, out_size)
+         """
+         x = self.stem(x)
+
+         for i, stage in enumerate(self.stages):
+             x = stage(x)
+             if i < len(self.downsamples):
+                 x = self.downsamples[i](x)
+
+         # Global pooling
+         x = self.global_pool(x)
+         x = x.flatten(1)
+
+         # Final norm and head
+         x = self.norm(x)
+         x = self.head(x)
+
+         return x
+
+     @classmethod
+     def get_default_config(cls) -> dict[str, Any]:
+         return {
+             "depths": [3, 3, 18, 3],
+             "dims": [80, 160, 320, 640],
+             "kernel_sizes": [31, 29, 17, 13],
+             "drop_path_rate": 0.1,
+             "dropout_rate": 0.3,
+         }
+
+
+ # =============================================================================
+ # REGISTERED VARIANTS
+ # =============================================================================
+
+
+ @register_model("unireplknet_tiny")
+ class UniRepLKNetTiny(UniRepLKNetBase):
+     """
+     UniRepLKNet Tiny: ~30.8M backbone parameters.
+
+     Large kernels [31, 29, 17, 13] for capturing long-range wave patterns.
+     Depths [3,3,18,3], Dims [80,160,320,640].
+     Supports 1D, 2D, 3D inputs.
+
+     Example:
+         >>> model = UniRepLKNetTiny(in_shape=(256, 256), out_size=3)
+         >>> x = torch.randn(4, 1, 256, 256)
+         >>> out = model(x)  # (4, 3)
+     """
+
+     def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             depths=[3, 3, 18, 3],
+             dims=[80, 160, 320, 640],
+             kernel_sizes=[31, 29, 17, 13],
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"UniRepLKNet_Tiny({self.dim}D, in_shape={self.in_shape}, "
+             f"out_size={self.out_size})"
+         )
+
+
+ @register_model("unireplknet_small")
+ class UniRepLKNetSmall(UniRepLKNetBase):
+     """
+     UniRepLKNet Small: ~56.0M backbone parameters.
+
+     Large kernels [31, 29, 17, 13] for capturing long-range wave patterns.
+     Depths [3,3,27,3], Dims [96,192,384,768].
+     Supports 1D, 2D, 3D inputs.
+     """
+
+     def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             depths=[3, 3, 27, 3],
+             dims=[96, 192, 384, 768],
+             kernel_sizes=[31, 29, 17, 13],
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"UniRepLKNet_Small({self.dim}D, in_shape={self.in_shape}, "
+             f"out_size={self.out_size})"
+         )
+
+
+ @register_model("unireplknet_base")
+ class UniRepLKNetBaseLarge(UniRepLKNetBase):
+     """
+     UniRepLKNet Base: ~97.6M backbone parameters.
+
+     Large kernels [31, 29, 17, 13] for capturing long-range wave patterns.
+     Depths [3,3,27,3], Dims [128,256,512,1024].
+     Supports 1D, 2D, 3D inputs.
+     """
+
+     def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             depths=[3, 3, 27, 3],
+             dims=[128, 256, 512, 1024],
+             kernel_sizes=[31, 29, 17, 13],
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"UniRepLKNet_Base({self.dim}D, in_shape={self.in_shape}, "
+             f"out_size={self.out_size})"
+         )
wavedl/models/vit.py CHANGED
@@ -10,9 +10,9 @@ Supports both 1D (signals) and 2D (images) inputs via configurable patch embeddi
  - 2D: Images/spectrograms → patches are grid squares
 
  **Variants**:
- - vit_tiny: Smallest (~5.7M params, embed_dim=192, depth=12, heads=3)
- - vit_small: Light (~22M params, embed_dim=384, depth=12, heads=6)
- - vit_base: Standard (~86M params, embed_dim=768, depth=12, heads=12)
+ - vit_tiny: Smallest (~5.4M backbone params, embed_dim=192, depth=12, heads=3)
+ - vit_small: Light (~21.4M backbone params, embed_dim=384, depth=12, heads=6)
+ - vit_base: Standard (~85.3M backbone params, embed_dim=768, depth=12, heads=12)
 
  References:
      Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words:
@@ -27,12 +27,12 @@ from typing import Any
  import torch
  import torch.nn as nn
 
- from wavedl.models.base import BaseModel
+ from wavedl.models.base import BaseModel, SpatialShape1D, SpatialShape2D
  from wavedl.models.registry import register_model
 
 
- # Type alias for spatial shapes
- SpatialShape = tuple[int] | tuple[int, int]
+ # ViT supports 1D and 2D only
+ SpatialShape = SpatialShape1D | SpatialShape2D
 
 
  class PatchEmbed(nn.Module):
@@ -365,7 +365,7 @@ class ViTTiny(ViTBase):
      """
      ViT-Tiny: Smallest Vision Transformer variant.
 
-     ~5.7M parameters. Good for: Quick experiments, smaller datasets.
+     ~5.4M backbone parameters. Good for: Quick experiments, smaller datasets.
 
      Args:
          in_shape: (L,) for 1D or (H, W) for 2D
@@ -398,7 +398,7 @@ class ViTSmall(ViTBase):
      """
      ViT-Small: Light Vision Transformer variant.
 
-     ~22M parameters. Good for: Balanced performance.
+     ~21.4M backbone parameters. Good for: Balanced performance.
 
      Args:
          in_shape: (L,) for 1D or (H, W) for 2D
@@ -429,7 +429,7 @@ class ViTBase_(ViTBase):
      """
      ViT-Base: Standard Vision Transformer variant.
 
-     ~86M parameters. Good for: High accuracy, larger datasets.
+     ~85.3M backbone parameters. Good for: High accuracy, larger datasets.
 
      Args:
          in_shape: (L,) for 1D or (H, W) for 2D
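
For completeness, a quick sketch of the 1D path these docstrings describe. The class name and constructor signature come from this diff; the (B, 1, *in_shape) input convention is assumed to match the other wavedl models, and the input length is chosen assuming the default patch size divides it:

    import torch

    from wavedl.models.vit import ViTTiny

    model = ViTTiny(in_shape=(1024,), out_size=2)  # 1D signal, two targets
    x = torch.randn(8, 1, 1024)                    # (B, 1, L)
    out = model(x)                                 # expected shape: (8, 2)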