wavedl 1.3.0-py3-none-any.whl → 1.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +28 -26
- wavedl/models/__init__.py +33 -7
- wavedl/models/_template.py +0 -1
- wavedl/models/base.py +0 -1
- wavedl/models/cnn.py +0 -1
- wavedl/models/convnext.py +4 -1
- wavedl/models/densenet.py +4 -1
- wavedl/models/efficientnet.py +9 -5
- wavedl/models/efficientnetv2.py +292 -0
- wavedl/models/mobilenetv3.py +272 -0
- wavedl/models/registry.py +0 -1
- wavedl/models/regnet.py +383 -0
- wavedl/models/resnet.py +7 -4
- wavedl/models/resnet3d.py +258 -0
- wavedl/models/swin.py +390 -0
- wavedl/models/tcn.py +389 -0
- wavedl/models/unet.py +44 -110
- wavedl/models/vit.py +8 -4
- wavedl/train.py +1113 -1117
- {wavedl-1.3.0.dist-info → wavedl-1.4.0.dist-info}/METADATA +111 -93
- wavedl-1.4.0.dist-info/RECORD +37 -0
- wavedl-1.3.0.dist-info/RECORD +0 -31
- {wavedl-1.3.0.dist-info → wavedl-1.4.0.dist-info}/LICENSE +0 -0
- {wavedl-1.3.0.dist-info → wavedl-1.4.0.dist-info}/WHEEL +0 -0
- {wavedl-1.3.0.dist-info → wavedl-1.4.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.3.0.dist-info → wavedl-1.4.0.dist-info}/top_level.txt +0 -0
wavedl/models/tcn.py
ADDED
@@ -0,0 +1,389 @@
"""
Temporal Convolutional Network (TCN): Dilated Causal Convolutions for 1D Signals
=================================================================================

A dedicated 1D architecture using dilated causal convolutions to capture
long-range temporal dependencies in waveforms and time-series data.
Provides an exponentially growing receptive field with linear parameter growth.

**Key Features**:
- Dilated convolutions: Exponentially growing receptive field
- Causal padding: No information leakage from the future
- Residual connections: Stable gradient flow
- Group normalization: Stable training

**Variants**:
- tcn: Standard TCN with configurable depth and channels
- tcn_small: Lightweight variant for quick experiments
- tcn_large: Higher capacity for complex patterns

**Receptive Field Calculation** (each block applies two convolutions):
    RF = 1 + 2 * (kernel_size - 1) * sum(dilation[i] for i in layers)
    With default settings (kernel=3, 8 layers, dilation=2^i):
    RF = 1 + 4 * (1+2+4+8+16+32+64+128) = 1021 samples

**Note**: TCN is 1D-only. For 2D/3D data, use ResNet, EfficientNet, or Swin.

References:
    Bai, S., Kolter, J.Z., & Koltun, V. (2018). An Empirical Evaluation of
    Generic Convolutional and Recurrent Networks for Sequence Modeling.
    arXiv:1803.01271. https://arxiv.org/abs/1803.01271

    van den Oord, A., et al. (2016). WaveNet: A Generative Model for Raw Audio.
    arXiv:1609.03499. https://arxiv.org/abs/1609.03499

Author: Ductho Le (ductho.le@outlook.com)
"""

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F

from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


class CausalConv1d(nn.Module):
    """
    Causal 1D convolution with dilation.

    Ensures output at time t only depends on inputs at times <= t.
    Uses left-side padding to achieve causal behavior.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int = 1,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        # Causal padding: only pad on the left
        self.padding = (kernel_size - 1) * dilation

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding=0,  # We handle padding manually for causality
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pad on the left only (causal)
        x = F.pad(x, (self.padding, 0))
        return self.conv(x)


class TemporalBlock(nn.Module):
    """
    Temporal block with two causal dilated convolutions and a residual connection.

    Architecture:
        Input → CausalConv → GroupNorm → GELU → Dropout →
        CausalConv → GroupNorm → GELU → Dropout → (+Input) → Output
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        # First causal convolution
        self.conv1 = CausalConv1d(in_channels, out_channels, kernel_size, dilation)
        self.norm1 = nn.GroupNorm(min(8, out_channels), out_channels)
        self.act1 = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)

        # Second causal convolution
        self.conv2 = CausalConv1d(out_channels, out_channels, kernel_size, dilation)
        self.norm2 = nn.GroupNorm(min(8, out_channels), out_channels)
        self.act2 = nn.GELU()
        self.dropout2 = nn.Dropout(dropout)

        # Residual connection (1x1 conv if channels change)
        self.downsample = (
            nn.Conv1d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.downsample(x)

        # First conv block
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)
        out = self.dropout1(out)

        # Second conv block
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.act2(out)
        out = self.dropout2(out)

        return out + residual


class TCNBase(BaseModel):
    """
    Base Temporal Convolutional Network for 1D regression.

    Architecture:
        1. Input projection (optional channel expansion)
        2. Stack of temporal blocks with exponentially increasing dilation
        3. Global average pooling
        4. Regression head

    The receptive field grows exponentially with depth (two convolutions per block):
        RF = 1 + 2 * (kernel_size - 1) * sum(2^i for i in range(num_layers))
    """

    def __init__(
        self,
        in_shape: tuple[int],
        out_size: int,
        num_channels: list[int],
        kernel_size: int = 3,
        dropout_rate: float = 0.1,
        **kwargs,
    ):
        """
        Initialize TCN for regression.

        Args:
            in_shape: (L,) input signal length
            out_size: Number of regression output targets
            num_channels: List of channel sizes for each temporal block
            kernel_size: Convolution kernel size (default: 3)
            dropout_rate: Dropout rate (default: 0.1)
        """
        super().__init__(in_shape, out_size)

        if len(in_shape) != 1:
            raise ValueError(
                f"TCN requires 1D input (L,), got {len(in_shape)}D. "
                "For 2D/3D data, use ResNet, EfficientNet, or Swin."
            )

        self.num_channels = num_channels
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate

        # Build temporal blocks with exponentially increasing dilation
        layers = []
        num_levels = len(num_channels)

        for i in range(num_levels):
            dilation = 2**i
            in_ch = 1 if i == 0 else num_channels[i - 1]
            out_ch = num_channels[i]
            layers.append(
                TemporalBlock(in_ch, out_ch, kernel_size, dilation, dropout_rate)
            )

        self.network = nn.Sequential(*layers)

        # Global pooling
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Regression head
        final_channels = num_channels[-1]
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(final_channels, 256),
            nn.GELU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, out_size),
        )

        # Calculate and store receptive field
        self.receptive_field = self._compute_receptive_field()

        # Initialize weights
        self._init_weights()

    def _compute_receptive_field(self) -> int:
        """Compute the receptive field of the network."""
        rf = 1
        for i in range(len(self.num_channels)):
            dilation = 2**i
            # Each temporal block has 2 convolutions
            rf += 2 * (self.kernel_size - 1) * dilation
        return rf

    def _init_weights(self):
        """Initialize weights using Kaiming initialization."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, 1, L)

        Returns:
            Output tensor of shape (B, out_size)
        """
        # Temporal blocks
        x = self.network(x)

        # Global pooling
        x = self.global_pool(x)
        x = x.flatten(1)

        # Regression head
        return self.head(x)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for TCN."""
        return {
            "num_channels": [64, 128, 256, 256, 512, 512, 512, 512],
            "kernel_size": 3,
            "dropout_rate": 0.1,
        }


# =============================================================================
# REGISTERED MODEL VARIANTS
# =============================================================================


@register_model("tcn")
class TCN(TCNBase):
    """
    TCN: Standard Temporal Convolutional Network.

    ~7.0M parameters. 8 temporal blocks with channels [64→128→256→256→512→512→512→512].
    Receptive field: 1021 samples with kernel_size=3.

    Recommended for:
    - Ultrasonic A-scan processing
    - Acoustic emission signals
    - Seismic waveform analysis
    - Any 1D time-series regression

    Args:
        in_shape: (L,) input signal length
        out_size: Number of regression targets
        kernel_size: Convolution kernel size (default: 3)
        dropout_rate: Dropout rate (default: 0.1)

    Example:
        >>> model = TCN(in_shape=(4096,), out_size=3)
        >>> x = torch.randn(4, 1, 4096)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
        # Default: 8 layers, 64→512 channels
        num_channels = kwargs.pop(
            "num_channels", [64, 128, 256, 256, 512, 512, 512, 512]
        )
        super().__init__(
            in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
        )

    def __repr__(self) -> str:
        return (
            f"TCN(in_shape={self.in_shape}, out={self.out_size}, "
            f"RF={self.receptive_field})"
        )


@register_model("tcn_small")
class TCNSmall(TCNBase):
    """
    TCN-Small: Lightweight variant for quick experiments.

    ~1.0M parameters. 6 temporal blocks with channels [32→64→128→128→256→256].
    Receptive field: 253 samples with kernel_size=3.

    Recommended for:
    - Quick prototyping
    - Smaller datasets
    - Real-time inference on edge devices

    Args:
        in_shape: (L,) input signal length
        out_size: Number of regression targets
        kernel_size: Convolution kernel size (default: 3)
        dropout_rate: Dropout rate (default: 0.1)

    Example:
        >>> model = TCNSmall(in_shape=(1024,), out_size=3)
        >>> x = torch.randn(4, 1, 1024)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
        num_channels = [32, 64, 128, 128, 256, 256]
        super().__init__(
            in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
        )

    def __repr__(self) -> str:
        return (
            f"TCN_Small(in_shape={self.in_shape}, out={self.out_size}, "
            f"RF={self.receptive_field})"
        )


@register_model("tcn_large")
class TCNLarge(TCNBase):
    """
    TCN-Large: High-capacity variant for complex patterns.

    ~10.2M parameters. 10 temporal blocks with channels [64→128→256→256→512→512→512→512→512→512].
    Receptive field: 4093 samples with kernel_size=3.

    Recommended for:
    - Long sequences (>4096 samples)
    - Complex temporal patterns
    - Large datasets with sufficient compute

    Args:
        in_shape: (L,) input signal length
        out_size: Number of regression targets
        kernel_size: Convolution kernel size (default: 3)
        dropout_rate: Dropout rate (default: 0.1)

    Example:
        >>> model = TCNLarge(in_shape=(8192,), out_size=3)
        >>> x = torch.randn(4, 1, 8192)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
        num_channels = [64, 128, 256, 256, 512, 512, 512, 512, 512, 512]
        super().__init__(
            in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
        )

    def __repr__(self) -> str:
        return (
            f"TCN_Large(in_shape={self.in_shape}, out={self.out_size}, "
            f"RF={self.receptive_field})"
        )
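The receptive-field figures in these docstrings are easy to verify outside the package. Below is a minimal standalone sketch (plain Python, mirroring `_compute_receptive_field` above); the factor of 2 reflects the two causal convolutions inside each `TemporalBlock`:

```python
def tcn_receptive_field(num_blocks: int, kernel_size: int = 3) -> int:
    """Receptive field of a stack of TemporalBlocks with dilation 2**i.

    Each block applies two causal convolutions, and each convolution
    extends the receptive field by (kernel_size - 1) * dilation.
    """
    rf = 1
    for i in range(num_blocks):
        rf += 2 * (kernel_size - 1) * 2**i
    return rf


print(tcn_receptive_field(8))   # 1021 -> tcn (8 blocks)
print(tcn_receptive_field(6))   # 253  -> tcn_small (6 blocks)
print(tcn_receptive_field(10))  # 4093 -> tcn_large (10 blocks)
```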
wavedl/models/unet.py
CHANGED
@@ -1,10 +1,10 @@
 """
-U-Net: Encoder-Decoder Architecture for Regression
-
+U-Net Regression: Encoder-Decoder Architecture for Vector Regression
+=====================================================================
 
-A dimension-agnostic U-Net implementation for tasks requiring either:
--
-
+A dimension-agnostic U-Net implementation adapted for vector regression output.
+Uses encoder-decoder architecture with skip connections, then applies global
+pooling to produce a regression vector.
 
 **Dimensionality Support**:
 - 1D: Waveforms, signals (N, 1, L) → Conv1d
@@ -12,11 +12,9 @@ A dimension-agnostic U-Net implementation for tasks requiring either:
 - 3D: Volumetric data (N, 1, D, H, W) → Conv3d
 
 **Variants**:
-- unet: Full encoder-decoder with spatial output capability
 - unet_regression: U-Net with global pooling for vector regression
 
 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """
 
 from typing import Any
@@ -90,10 +88,6 @@ class Up(nn.Module):
         super().__init__()
         _, ConvTranspose, _, _ = _get_layers(dim)
 
-        # in_channels comes from previous layer
-        # After upconv: in_channels // 2
-        # After concat with skip (out_channels): in_channels // 2 + out_channels = in_channels
-        # Then DoubleConv: in_channels -> out_channels
         self.up = ConvTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2)
         self.conv = DoubleConv(in_channels, out_channels, dim)
 
@@ -103,7 +97,6 @@ class Up(nn.Module):
         # Handle size mismatch (pad x1 to match x2)
         if x1.shape[2:] != x2.shape[2:]:
             diff = [x2.size(i + 2) - x1.size(i + 2) for i in range(len(x1.shape) - 2)]
-            # Pad x1 to match x2
             pad = []
             for d in reversed(diff):
                 pad.extend([d // 2, d - d // 2])
@@ -113,14 +106,33 @@ class Up(nn.Module):
         return self.conv(x)
 
 
-
+# =============================================================================
+# REGISTERED MODEL
+# =============================================================================
+
+
+@register_model("unet_regression")
+class UNetRegression(BaseModel):
     """
-
+    U-Net for vector regression output.
+
+    Uses U-Net encoder-decoder architecture with skip connections,
+    then applies global pooling for standard vector regression output.
+
+    ~31.1M parameters (2D). Good for leveraging multi-scale features
+    and skip connections for regression tasks.
+
+    Args:
+        in_shape: (L,), (H, W), or (D, H, W)
+        out_size: Number of regression targets
+        base_channels: Base channel count (default: 64)
+        depth: Number of encoder/decoder levels (default: 4)
+        dropout_rate: Dropout rate (default: 0.1)
 
-
-
-
-
+    Example:
+        >>> model = UNetRegression(in_shape=(224, 224), out_size=3)
+        >>> x = torch.randn(4, 1, 224, 224)
+        >>> out = model(x)  # (4, 3)
     """
 
     def __init__(
@@ -130,7 +142,6 @@ class UNetBase(BaseModel):
         base_channels: int = 64,
         depth: int = 4,
         dropout_rate: float = 0.1,
-        spatial_output: bool = False,
         **kwargs,
     ):
         super().__init__(in_shape, out_size)
@@ -139,12 +150,10 @@ class UNetBase(BaseModel):
         self.base_channels = base_channels
         self.depth = depth
         self.dropout_rate = dropout_rate
-        self.spatial_output = spatial_output
 
-
+        _, _, _, AdaptivePool = _get_layers(self.dim)
 
         # Channel progression: 64 -> 128 -> 256 -> 512 (for depth=4)
-        # features[i] = base_channels * 2^i
         features = [base_channels * (2**i) for i in range(depth + 1)]
 
         # Initial double conv (1 -> features[0])
@@ -158,22 +167,17 @@ class UNetBase(BaseModel):
         # Decoder (up path)
         self.ups = nn.ModuleList()
         for i in range(depth):
-            # Input: features[depth - i], Skip: features[depth - 1 - i], Output: features[depth - 1 - i]
             self.ups.append(Up(features[depth - i], features[depth - 1 - i], self.dim))
 
-
-
-
-
-
-
-
-
-
-            nn.ReLU(inplace=True),
-            nn.Dropout(dropout_rate),
-            nn.Linear(256, out_size),
-        )
+        # Vector output: global pooling + regression head
+        self.global_pool = AdaptivePool(1)
+        self.head = nn.Sequential(
+            nn.Dropout(dropout_rate),
+            nn.Linear(features[0], 256),
+            nn.ReLU(inplace=True),
+            nn.Dropout(dropout_rate),
+            nn.Linear(256, out_size),
+        )
 
         self._init_weights()
 
@@ -220,85 +224,15 @@ class UNetBase(BaseModel):
         for up, skip in zip(self.ups, reversed(skips)):
             x = up(x, skip)
 
-
-
-
-
-        x = x.flatten(1)
-        return self.head(x)
+        # Global pooling + regression head
+        x = self.global_pool(x)
+        x = x.flatten(1)
+        return self.head(x)
 
     @classmethod
     def get_default_config(cls) -> dict[str, Any]:
         """Return default configuration."""
         return {"base_channels": 64, "depth": 4, "dropout_rate": 0.1}
 
-
-# =============================================================================
-# REGISTERED MODEL VARIANTS
-# =============================================================================
-
-
-@register_model("unet")
-class UNet(UNetBase):
-    """
-    U-Net with spatial output capability.
-
-    Good for: Pixel/voxel-wise regression (velocity fields, spatial maps).
-
-    Note: For spatial output, out_size is the number of output channels.
-    Output shape: (B, out_size, *spatial_dims) for spatial_output=True.
-
-    Args:
-        in_shape: (L,), (H, W), or (D, H, W)
-        out_size: Number of output channels for spatial output
-        base_channels: Base channel count (default: 64)
-        depth: Number of encoder/decoder levels (default: 4)
-        spatial_output: If True, output spatial map; if False, output vector
-        dropout_rate: Dropout rate (default: 0.1)
-    """
-
-    def __init__(
-        self,
-        in_shape: SpatialShape,
-        out_size: int,
-        spatial_output: bool = True,
-        **kwargs,
-    ):
-        super().__init__(
-            in_shape=in_shape,
-            out_size=out_size,
-            spatial_output=spatial_output,
-            **kwargs,
-        )
-
-    def __repr__(self) -> str:
-        mode = "spatial" if self.spatial_output else "vector"
-        return f"UNet({self.dim}D, {mode}, in_shape={self.in_shape}, out_size={self.out_size})"
-
-
-@register_model("unet_regression")
-class UNetRegression(UNetBase):
-    """
-    U-Net for vector regression output.
-
-    Uses U-Net encoder-decoder but applies global pooling at the end
-    for standard vector regression output.
-
-    Good for: Leveraging U-Net features (multi-scale, skip connections)
-    for standard regression tasks.
-
-    Args:
-        in_shape: (L,), (H, W), or (D, H, W)
-        out_size: Number of regression targets
-        base_channels: Base channel count (default: 64)
-        depth: Number of encoder/decoder levels (default: 4)
-        dropout_rate: Dropout rate (default: 0.1)
-    """
-
-    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
-        super().__init__(
-            in_shape=in_shape, out_size=out_size, spatial_output=False, **kwargs
-        )
-
     def __repr__(self) -> str:
         return f"UNet_Regression({self.dim}D, in_shape={self.in_shape}, out_size={self.out_size})"
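The size-mismatch branch in `Up.forward` above depends on `torch.nn.functional.pad` consuming padding amounts for the last dimension first, as (left, right) pairs, which is why the size differences are walked in reverse. A minimal standalone sketch of that computation (only `torch` assumed; the shapes are illustrative):

```python
import torch
import torch.nn.functional as F

x1 = torch.randn(1, 8, 13, 13)  # upsampled decoder features (odd size)
x2 = torch.randn(1, 8, 14, 14)  # encoder skip connection

# Size difference per spatial dim, then (left, right) pads in reversed order,
# because F.pad expects pads for the last dimension first.
diff = [x2.size(i + 2) - x1.size(i + 2) for i in range(x1.dim() - 2)]
pad = []
for d in reversed(diff):
    pad.extend([d // 2, d - d // 2])

x1 = F.pad(x1, pad)
print(x1.shape)  # torch.Size([1, 8, 14, 14]) -- now matches x2 for concatenation
```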
wavedl/models/vit.py
CHANGED
@@ -10,12 +10,16 @@ Supports both 1D (signals) and 2D (images) inputs via configurable patch embedding
 - 2D: Images/spectrograms → patches are grid squares
 
 **Variants**:
-- vit_tiny: Smallest (embed_dim=192, depth=12, heads=3)
-- vit_small: Light (embed_dim=384, depth=12, heads=6)
-- vit_base: Standard (embed_dim=768, depth=12, heads=12)
+- vit_tiny: Smallest (~5.7M params, embed_dim=192, depth=12, heads=3)
+- vit_small: Light (~22M params, embed_dim=384, depth=12, heads=6)
+- vit_base: Standard (~86M params, embed_dim=768, depth=12, heads=12)
+
+References:
+    Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words:
+    Transformers for Image Recognition at Scale. ICLR 2021.
+    https://arxiv.org/abs/2010.11929
 
 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """
 
 from typing import Any