wavedl 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

wavedl/models/tcn.py ADDED
@@ -0,0 +1,389 @@
+"""
+Temporal Convolutional Network (TCN): Dilated Causal Convolutions for 1D Signals
+=================================================================================
+
+A dedicated 1D architecture using dilated causal convolutions to capture
+long-range temporal dependencies in waveforms and time-series data.
+Provides an exponentially growing receptive field with linear parameter growth.
+
+**Key Features**:
+- Dilated convolutions: Exponentially growing receptive field
+- Causal padding: No information leakage from the future
+- Residual connections: Stable gradient flow
+- Group normalization: Batch-size-independent training stability
+
+**Variants**:
+- tcn: Standard TCN with configurable depth and channels
+- tcn_small: Lightweight variant for quick experiments
+- tcn_large: Higher capacity for complex patterns
+
+**Receptive Field Calculation** (each block applies two dilated convolutions):
+    RF = 1 + 2 * (kernel_size - 1) * sum(dilation[i] for i in layers)
+With default settings (kernel=3, 8 layers, dilation=2^i):
+    RF = 1 + 4 * (1+2+4+8+16+32+64+128) = 1021 samples
+
+**Note**: TCN is 1D-only. For 2D/3D data, use ResNet, EfficientNet, or Swin.
+
+References:
+    Bai, S., Kolter, J.Z., & Koltun, V. (2018). An Empirical Evaluation of
+    Generic Convolutional and Recurrent Networks for Sequence Modeling.
+    arXiv:1803.01271. https://arxiv.org/abs/1803.01271
+
+    van den Oord, A., et al. (2016). WaveNet: A Generative Model for Raw Audio.
+    arXiv:1609.03499. https://arxiv.org/abs/1609.03499
+
+Author: Ductho Le (ductho.le@outlook.com)
+"""
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from wavedl.models.base import BaseModel
+from wavedl.models.registry import register_model
+
+
+class CausalConv1d(nn.Module):
+    """
+    Causal 1D convolution with dilation.
+
+    Ensures output at time t only depends on inputs at times <= t.
+    Uses left-side padding to achieve causal behavior.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        dilation: int = 1,
+    ):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        # Causal padding: only pad on the left
+        self.padding = (kernel_size - 1) * dilation
+
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            dilation=dilation,
+            padding=0,  # We handle padding manually for causality
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Pad on the left only (causal)
+        x = F.pad(x, (self.padding, 0))
+        return self.conv(x)
+
+
+class TemporalBlock(nn.Module):
+    """
+    Temporal block with two causal dilated convolutions and a residual connection.
+
+    Architecture:
+        Input → CausalConv → GroupNorm → GELU → Dropout →
+        CausalConv → GroupNorm → GELU → Dropout → (+Input) → Output
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        dilation: int,
+        dropout: float = 0.1,
+    ):
+        super().__init__()
+
+        # First causal convolution
+        self.conv1 = CausalConv1d(in_channels, out_channels, kernel_size, dilation)
+        self.norm1 = nn.GroupNorm(min(8, out_channels), out_channels)
+        self.act1 = nn.GELU()
+        self.dropout1 = nn.Dropout(dropout)
+
+        # Second causal convolution
+        self.conv2 = CausalConv1d(out_channels, out_channels, kernel_size, dilation)
+        self.norm2 = nn.GroupNorm(min(8, out_channels), out_channels)
+        self.act2 = nn.GELU()
+        self.dropout2 = nn.Dropout(dropout)
+
+        # Residual connection (1x1 conv if channels change)
+        self.downsample = (
+            nn.Conv1d(in_channels, out_channels, 1)
+            if in_channels != out_channels
+            else nn.Identity()
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        residual = self.downsample(x)
+
+        # First conv block
+        out = self.conv1(x)
+        out = self.norm1(out)
+        out = self.act1(out)
+        out = self.dropout1(out)
+
+        # Second conv block
+        out = self.conv2(out)
+        out = self.norm2(out)
+        out = self.act2(out)
+        out = self.dropout2(out)
+
+        return out + residual
+
+
+class TCNBase(BaseModel):
+    """
+    Base Temporal Convolutional Network for 1D regression.
+
+    Architecture:
+        1. Stack of temporal blocks with exponentially increasing dilation
+           (the first block expands the single input channel)
+        2. Global average pooling
+        3. Regression head
+
+    The receptive field grows exponentially with depth (2 convs per block):
+        RF = 1 + 2 * (kernel_size - 1) * sum(2^i for i in range(num_layers))
+    """
+
+    def __init__(
+        self,
+        in_shape: tuple[int],
+        out_size: int,
+        num_channels: list[int],
+        kernel_size: int = 3,
+        dropout_rate: float = 0.1,
+        **kwargs,
+    ):
+        """
+        Initialize TCN for regression.
+
+        Args:
+            in_shape: (L,) input signal length
+            out_size: Number of regression output targets
+            num_channels: List of channel sizes for each temporal block
+            kernel_size: Convolution kernel size (default: 3)
+            dropout_rate: Dropout rate (default: 0.1)
+        """
+        super().__init__(in_shape, out_size)
+
+        if len(in_shape) != 1:
+            raise ValueError(
+                f"TCN requires 1D input (L,), got {len(in_shape)}D. "
+                "For 2D/3D data, use ResNet, EfficientNet, or Swin."
+            )
+
+        self.num_channels = num_channels
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+
+        # Build temporal blocks with exponentially increasing dilation
+        layers = []
+        num_levels = len(num_channels)
+
+        for i in range(num_levels):
+            dilation = 2**i
+            in_ch = 1 if i == 0 else num_channels[i - 1]
+            out_ch = num_channels[i]
+            layers.append(
+                TemporalBlock(in_ch, out_ch, kernel_size, dilation, dropout_rate)
+            )
+
+        self.network = nn.Sequential(*layers)
+
+        # Global pooling
+        self.global_pool = nn.AdaptiveAvgPool1d(1)
+
+        # Regression head
+        final_channels = num_channels[-1]
+        self.head = nn.Sequential(
+            nn.Dropout(dropout_rate),
+            nn.Linear(final_channels, 256),
+            nn.GELU(),
+            nn.Dropout(dropout_rate * 0.5),
+            nn.Linear(256, 128),
+            nn.GELU(),
+            nn.Linear(128, out_size),
+        )
+
+        # Calculate and store receptive field
+        self.receptive_field = self._compute_receptive_field()
+
+        # Initialize weights
+        self._init_weights()
+
+    def _compute_receptive_field(self) -> int:
+        """Compute the receptive field of the network."""
+        rf = 1
+        for i in range(len(self.num_channels)):
+            dilation = 2**i
+            # Each temporal block has 2 convolutions
+            rf += 2 * (self.kernel_size - 1) * dilation
+        return rf
+
+    def _init_weights(self):
+        """Initialize weights using Kaiming initialization."""
+        for m in self.modules():
+            if isinstance(m, (nn.Conv1d, nn.Linear)):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.GroupNorm):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass.
+
+        Args:
+            x: Input tensor of shape (B, 1, L)
+
+        Returns:
+            Output tensor of shape (B, out_size)
+        """
+        # Temporal blocks
+        x = self.network(x)
+
+        # Global pooling
+        x = self.global_pool(x)
+        x = x.flatten(1)
+
+        # Regression head
+        return self.head(x)
+
+    @classmethod
+    def get_default_config(cls) -> dict[str, Any]:
+        """Return default configuration for TCN."""
+        return {
+            "num_channels": [64, 128, 256, 256, 512, 512, 512, 512],
+            "kernel_size": 3,
+            "dropout_rate": 0.1,
+        }
+
+
+# =============================================================================
+# REGISTERED MODEL VARIANTS
+# =============================================================================
+
+
+@register_model("tcn")
+class TCN(TCNBase):
+    """
+    TCN: Standard Temporal Convolutional Network.
+
+    ~7.0M parameters. 8 temporal blocks with channels [64→128→256→256→512→512→512→512].
+    Receptive field: 1021 samples with kernel_size=3.
+
+    Recommended for:
+    - Ultrasonic A-scan processing
+    - Acoustic emission signals
+    - Seismic waveform analysis
+    - Any 1D time-series regression
+
+    Args:
+        in_shape: (L,) input signal length
+        out_size: Number of regression targets
+        kernel_size: Convolution kernel size (default: 3)
+        dropout_rate: Dropout rate (default: 0.1)
+
+    Example:
+        >>> model = TCN(in_shape=(4096,), out_size=3)
+        >>> x = torch.randn(4, 1, 4096)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
+        # Default: 8 layers, 64→512 channels
+        num_channels = kwargs.pop(
+            "num_channels", [64, 128, 256, 256, 512, 512, 512, 512]
+        )
+        super().__init__(
+            in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"TCN(in_shape={self.in_shape}, out={self.out_size}, "
+            f"RF={self.receptive_field})"
+        )
+
+
+@register_model("tcn_small")
+class TCNSmall(TCNBase):
+    """
+    TCN-Small: Lightweight variant for quick experiments.
+
+    ~1.0M parameters. 6 temporal blocks with channels [32→64→128→128→256→256].
+    Receptive field: 253 samples with kernel_size=3.
+
+    Recommended for:
+    - Quick prototyping
+    - Smaller datasets
+    - Real-time inference on edge devices
+
+    Args:
+        in_shape: (L,) input signal length
+        out_size: Number of regression targets
+        kernel_size: Convolution kernel size (default: 3)
+        dropout_rate: Dropout rate (default: 0.1)
+
+    Example:
+        >>> model = TCNSmall(in_shape=(1024,), out_size=3)
+        >>> x = torch.randn(4, 1, 1024)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
+        num_channels = [32, 64, 128, 128, 256, 256]
+        super().__init__(
+            in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"TCN_Small(in_shape={self.in_shape}, out={self.out_size}, "
+            f"RF={self.receptive_field})"
+        )
+
+
+@register_model("tcn_large")
+class TCNLarge(TCNBase):
+    """
+    TCN-Large: High-capacity variant for complex patterns.
+
+    ~10.2M parameters. 10 temporal blocks with channels [64→128→256→256→512→512→512→512→512→512].
+    Receptive field: 4093 samples with kernel_size=3.
+
+    Recommended for:
+    - Long sequences (>4096 samples)
+    - Complex temporal patterns
+    - Large datasets with sufficient compute
+
+    Args:
+        in_shape: (L,) input signal length
+        out_size: Number of regression targets
+        kernel_size: Convolution kernel size (default: 3)
+        dropout_rate: Dropout rate (default: 0.1)
+
+    Example:
+        >>> model = TCNLarge(in_shape=(8192,), out_size=3)
+        >>> x = torch.randn(4, 1, 8192)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
+        num_channels = [64, 128, 256, 256, 512, 512, 512, 512, 512, 512]
+        super().__init__(
+            in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"TCN_Large(in_shape={self.in_shape}, out={self.out_size}, "
+            f"RF={self.receptive_field})"
+        )
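
The new module can be smoke-tested directly against its own receptive-field formula. A minimal sketch, assuming wavedl 1.4.0 is installed and using only the classes added above (the expected-RF arithmetic mirrors `_compute_receptive_field`):

    import torch

    from wavedl.models.tcn import TCN, CausalConv1d, TCNLarge, TCNSmall

    # Causal left-padding preserves sequence length: (B, C_in, L) -> (B, C_out, L)
    conv = CausalConv1d(1, 4, kernel_size=3, dilation=2)
    assert conv(torch.randn(1, 1, 100)).shape == (1, 4, 100)

    # Each variant's stored RF matches 1 + 2*(k-1)*sum(2^i): two convs per block,
    # dilation doubling per block, default kernel_size k=3
    for cls, n_blocks in [(TCNSmall, 6), (TCN, 8), (TCNLarge, 10)]:
        model = cls(in_shape=(4096,), out_size=3)
        assert model.receptive_field == 1 + 4 * sum(2**i for i in range(n_blocks))
        assert model(torch.randn(2, 1, 4096)).shape == (2, 3)
        print(model)  # e.g. TCN(in_shape=(4096,), out=3, RF=1021)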
wavedl/models/unet.py CHANGED
@@ -1,10 +1,10 @@
 """
-U-Net: Encoder-Decoder Architecture for Regression
-====================================================
+U-Net Regression: Encoder-Decoder Architecture for Vector Regression
+=====================================================================
 
-A dimension-agnostic U-Net implementation for tasks requiring either:
-- Spatial output (e.g., velocity field prediction)
-- Vector output (global pooling → regression head)
+A dimension-agnostic U-Net implementation adapted for vector regression output.
+Uses encoder-decoder architecture with skip connections, then applies global
+pooling to produce a regression vector.
 
 **Dimensionality Support**:
 - 1D: Waveforms, signals (N, 1, L) → Conv1d
@@ -12,11 +12,9 @@ A dimension-agnostic U-Net implementation for tasks requiring either:
 - 3D: Volumetric data (N, 1, D, H, W) → Conv3d
 
 **Variants**:
-- unet: Full encoder-decoder with spatial output capability
 - unet_regression: U-Net with global pooling for vector regression
 
 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """
 
 from typing import Any
@@ -90,10 +88,6 @@ class Up(nn.Module):
         super().__init__()
         _, ConvTranspose, _, _ = _get_layers(dim)
 
-        # in_channels comes from previous layer
-        # After upconv: in_channels // 2
-        # After concat with skip (out_channels): in_channels // 2 + out_channels = in_channels
-        # Then DoubleConv: in_channels -> out_channels
         self.up = ConvTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2)
         self.conv = DoubleConv(in_channels, out_channels, dim)
 
@@ -103,7 +97,6 @@ class Up(nn.Module):
         # Handle size mismatch (pad x1 to match x2)
         if x1.shape[2:] != x2.shape[2:]:
             diff = [x2.size(i + 2) - x1.size(i + 2) for i in range(len(x1.shape) - 2)]
-            # Pad x1 to match x2
             pad = []
             for d in reversed(diff):
                 pad.extend([d // 2, d - d // 2])
@@ -113,14 +106,33 @@ class Up(nn.Module):
         return self.conv(x)
 
 
-class UNetBase(BaseModel):
+# =============================================================================
+# REGISTERED MODEL
+# =============================================================================
+
+
+@register_model("unet_regression")
+class UNetRegression(BaseModel):
     """
-    Base U-Net class for regression tasks.
+    U-Net for vector regression output.
+
+    Uses U-Net encoder-decoder architecture with skip connections,
+    then applies global pooling for standard vector regression output.
+
+    ~31.1M parameters (2D). Good for leveraging multi-scale features
+    and skip connections for regression tasks.
+
+    Args:
+        in_shape: (L,), (H, W), or (D, H, W)
+        out_size: Number of regression targets
+        base_channels: Base channel count (default: 64)
+        depth: Number of encoder/decoder levels (default: 4)
+        dropout_rate: Dropout rate (default: 0.1)
 
-    Standard U-Net architecture:
-    - Encoder path with downsampling
-    - Decoder path with upsampling and skip connections
-    - Optional spatial or vector output
+    Example:
+        >>> model = UNetRegression(in_shape=(224, 224), out_size=3)
+        >>> x = torch.randn(4, 1, 224, 224)
+        >>> out = model(x)  # (4, 3)
     """
 
     def __init__(
@@ -130,7 +142,6 @@ class UNetBase(BaseModel):
         base_channels: int = 64,
         depth: int = 4,
         dropout_rate: float = 0.1,
-        spatial_output: bool = False,
         **kwargs,
     ):
         super().__init__(in_shape, out_size)
@@ -139,12 +150,10 @@ class UNetBase(BaseModel):
         self.base_channels = base_channels
         self.depth = depth
         self.dropout_rate = dropout_rate
-        self.spatial_output = spatial_output
 
-        Conv, _, _, AdaptivePool = _get_layers(self.dim)
+        _, _, _, AdaptivePool = _get_layers(self.dim)
 
         # Channel progression: 64 -> 128 -> 256 -> 512 (for depth=4)
-        # features[i] = base_channels * 2^i
         features = [base_channels * (2**i) for i in range(depth + 1)]
 
         # Initial double conv (1 -> features[0])
@@ -158,22 +167,17 @@ class UNetBase(BaseModel):
         # Decoder (up path)
         self.ups = nn.ModuleList()
         for i in range(depth):
-            # Input: features[depth - i], Skip: features[depth - 1 - i], Output: features[depth - 1 - i]
             self.ups.append(Up(features[depth - i], features[depth - 1 - i], self.dim))
 
-        if spatial_output:
-            # Spatial output: 1x1 conv to out_size channels
-            self.outc = Conv(features[0], out_size, kernel_size=1)
-        else:
-            # Vector output: global pooling + regression head
-            self.global_pool = AdaptivePool(1)
-            self.head = nn.Sequential(
-                nn.Dropout(dropout_rate),
-                nn.Linear(features[0], 256),
-                nn.ReLU(inplace=True),
-                nn.Dropout(dropout_rate),
-                nn.Linear(256, out_size),
-            )
+        # Vector output: global pooling + regression head
+        self.global_pool = AdaptivePool(1)
+        self.head = nn.Sequential(
+            nn.Dropout(dropout_rate),
+            nn.Linear(features[0], 256),
+            nn.ReLU(inplace=True),
+            nn.Dropout(dropout_rate),
+            nn.Linear(256, out_size),
+        )
 
         self._init_weights()
 
@@ -220,85 +224,15 @@ class UNetBase(BaseModel):
         for up, skip in zip(self.ups, reversed(skips)):
             x = up(x, skip)
 
-        if self.spatial_output:
-            return self.outc(x)
-        else:
-            x = self.global_pool(x)
-            x = x.flatten(1)
-            return self.head(x)
+        # Global pooling + regression head
+        x = self.global_pool(x)
+        x = x.flatten(1)
+        return self.head(x)
 
     @classmethod
     def get_default_config(cls) -> dict[str, Any]:
         """Return default configuration."""
         return {"base_channels": 64, "depth": 4, "dropout_rate": 0.1}
 
-
-# =============================================================================
-# REGISTERED MODEL VARIANTS
-# =============================================================================
-
-
-@register_model("unet")
-class UNet(UNetBase):
-    """
-    U-Net with spatial output capability.
-
-    Good for: Pixel/voxel-wise regression (velocity fields, spatial maps).
-
-    Note: For spatial output, out_size is the number of output channels.
-    Output shape: (B, out_size, *spatial_dims) for spatial_output=True.
-
-    Args:
-        in_shape: (L,), (H, W), or (D, H, W)
-        out_size: Number of output channels for spatial output
-        base_channels: Base channel count (default: 64)
-        depth: Number of encoder/decoder levels (default: 4)
-        spatial_output: If True, output spatial map; if False, output vector
-        dropout_rate: Dropout rate (default: 0.1)
-    """
-
-    def __init__(
-        self,
-        in_shape: SpatialShape,
-        out_size: int,
-        spatial_output: bool = True,
-        **kwargs,
-    ):
-        super().__init__(
-            in_shape=in_shape,
-            out_size=out_size,
-            spatial_output=spatial_output,
-            **kwargs,
-        )
-
-    def __repr__(self) -> str:
-        mode = "spatial" if self.spatial_output else "vector"
-        return f"UNet({self.dim}D, {mode}, in_shape={self.in_shape}, out_size={self.out_size})"
-
-
-@register_model("unet_regression")
-class UNetRegression(UNetBase):
-    """
-    U-Net for vector regression output.
-
-    Uses U-Net encoder-decoder but applies global pooling at the end
-    for standard vector regression output.
-
-    Good for: Leveraging U-Net features (multi-scale, skip connections)
-    for standard regression tasks.
-
-    Args:
-        in_shape: (L,), (H, W), or (D, H, W)
-        out_size: Number of regression targets
-        base_channels: Base channel count (default: 64)
-        depth: Number of encoder/decoder levels (default: 4)
-        dropout_rate: Dropout rate (default: 0.1)
-    """
-
-    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
-        super().__init__(
-            in_shape=in_shape, out_size=out_size, spatial_output=False, **kwargs
-        )
-
     def __repr__(self) -> str:
         return f"UNet_Regression({self.dim}D, in_shape={self.in_shape}, out_size={self.out_size})"
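
With the spatial-output `unet` variant removed, `unet_regression` is the only U-Net registered in 1.4.0, and its forward pass always ends in global pooling. A minimal sketch of the surviving code path, reusing the docstring example above (assuming wavedl 1.4.0 is installed):

    import torch

    from wavedl.models.unet import UNetRegression

    # 2D input (H, W): Conv2d encoder/decoder with skip connections,
    # then AdaptiveAvgPool2d(1) feeds the Linear regression head
    model = UNetRegression(in_shape=(224, 224), out_size=3)
    out = model(torch.randn(4, 1, 224, 224))
    assert out.shape == (4, 3)  # vector output only; no spatial maps in 1.4.0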
wavedl/models/vit.py CHANGED
@@ -10,12 +10,16 @@ Supports both 1D (signals) and 2D (images) inputs via configurable patch embeddi
 - 2D: Images/spectrograms → patches are grid squares
 
 **Variants**:
-- vit_tiny: Smallest (embed_dim=192, depth=12, heads=3)
-- vit_small: Light (embed_dim=384, depth=12, heads=6)
-- vit_base: Standard (embed_dim=768, depth=12, heads=12)
+- vit_tiny: Smallest (~5.7M params, embed_dim=192, depth=12, heads=3)
+- vit_small: Light (~22M params, embed_dim=384, depth=12, heads=6)
+- vit_base: Standard (~86M params, embed_dim=768, depth=12, heads=12)
+
+References:
+    Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words:
+    Transformers for Image Recognition at Scale. ICLR 2021.
+    https://arxiv.org/abs/2010.11929
 
 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """
 
 from typing import Any
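
The parameter counts added to the variant list can be sanity-checked from the configurations alone: each transformer block carries roughly 12·embed_dim² weights (4·d² in the attention projections plus 8·d² in an MLP with a 4× hidden ratio). A rough sketch; the 4× MLP ratio and the embedding/head overhead are assumptions from the standard ViT recipe, not read from this diff:

    # Encoder-only estimate: depth * (4*d^2 attention + 8*d^2 MLP) weights.
    # Patch embedding, position embeddings, and the head add the remainder.
    def approx_vit_params_m(embed_dim: int, depth: int = 12) -> float:
        return depth * 12 * embed_dim**2 / 1e6

    for name, d in [("vit_tiny", 192), ("vit_small", 384), ("vit_base", 768)]:
        print(f"{name}: ~{approx_vit_params_m(d):.1f}M encoder params")
    # vit_tiny: ~5.3M, vit_small: ~21.2M, vit_base: ~84.9M, consistent with
    # the ~5.7M / ~22M / ~86M totals once embeddings and head are included.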