wavedl 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/{hpc.py → launcher.py} +135 -61
- wavedl/models/__init__.py +28 -0
- wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +1 -1
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +5 -18
- wavedl/models/convnext_v2.py +6 -22
- wavedl/models/densenet.py +5 -18
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +6 -39
- wavedl/models/mamba.py +44 -24
- wavedl/models/maxvit.py +51 -48
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +14 -56
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +1 -5
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +3 -3
- wavedl/train.py +1427 -1430
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/METADATA +150 -113
- wavedl-1.6.2.dist-info/RECORD +46 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/entry_points.txt +2 -2
- wavedl-1.6.0.dist-info/RECORD +0 -44
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/LICENSE +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/WHEEL +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/top_level.txt +0 -0
wavedl/models/tcn.py
CHANGED
|
@@ -1,409 +1,393 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Temporal Convolutional Network (TCN): Dilated Causal Convolutions for 1D Signals
|
|
3
|
-
=================================================================================
|
|
4
|
-
|
|
5
|
-
A dedicated 1D architecture using dilated causal convolutions to capture
|
|
6
|
-
long-range temporal dependencies in waveforms and time-series data.
|
|
7
|
-
Provides exponentially growing receptive field with linear parameter growth.
|
|
8
|
-
|
|
9
|
-
**Key Features**:
|
|
10
|
-
- Dilated convolutions: Exponentially growing receptive field
|
|
11
|
-
- Causal padding: No information leakage from future
|
|
12
|
-
- Residual connections: Stable gradient flow
|
|
13
|
-
- Weight normalization: Faster convergence
|
|
14
|
-
|
|
15
|
-
**Variants**:
|
|
16
|
-
- tcn: Standard TCN with configurable depth and channels
|
|
17
|
-
- tcn_small: Lightweight variant for quick experiments
|
|
18
|
-
- tcn_large: Higher capacity for complex patterns
|
|
19
|
-
|
|
20
|
-
**Receptive Field Calculation**:
|
|
21
|
-
RF = 1 + (kernel_size - 1) * sum(dilation[i] for i in layers)
|
|
22
|
-
With default settings (kernel=3, 8 layers, dilation=2^i):
|
|
23
|
-
RF = 1 + 2 * (1+2+4+8+16+32+64+128) = 511 samples
|
|
24
|
-
|
|
25
|
-
**Note**: TCN is 1D-only. For 2D/3D data, use ResNet, EfficientNet, or Swin.
|
|
26
|
-
|
|
27
|
-
References:
|
|
28
|
-
Bai, S., Kolter, J.Z., & Koltun, V. (2018). An Empirical Evaluation of
|
|
29
|
-
Generic Convolutional and Recurrent Networks for Sequence Modeling.
|
|
30
|
-
arXiv:1803.01271. https://arxiv.org/abs/1803.01271
|
|
31
|
-
|
|
32
|
-
van den Oord, A., et al. (2016). WaveNet: A Generative Model for Raw Audio.
|
|
33
|
-
arXiv:1609.03499. https://arxiv.org/abs/1609.03499
|
|
34
|
-
|
|
35
|
-
Author: Ductho Le (ductho.le@outlook.com)
|
|
36
|
-
"""
|
|
37
|
-
|
|
38
|
-
from typing import Any
|
|
39
|
-
|
|
40
|
-
import torch
|
|
41
|
-
import torch.nn as nn
|
|
42
|
-
import torch.nn.functional as F
|
|
43
|
-
|
|
44
|
-
from wavedl.models.base import BaseModel
|
|
45
|
-
from wavedl.models.registry import register_model
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
"""
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
#
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
#
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
out_size
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
)
|
|
199
|
-
|
|
200
|
-
self.
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
)
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
>>> model = TCNLarge(in_shape=(8192,), out_size=3)
|
|
395
|
-
>>> x = torch.randn(4, 1, 8192)
|
|
396
|
-
>>> out = model(x) # (4, 3)
|
|
397
|
-
"""
|
|
398
|
-
|
|
399
|
-
def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
|
|
400
|
-
num_channels = [64, 128, 256, 256, 512, 512, 512, 512, 512, 512]
|
|
401
|
-
super().__init__(
|
|
402
|
-
in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
|
|
403
|
-
)
|
|
404
|
-
|
|
405
|
-
def __repr__(self) -> str:
|
|
406
|
-
return (
|
|
407
|
-
f"TCN_Large(in_shape={self.in_shape}, out={self.out_size}, "
|
|
408
|
-
f"RF={self.receptive_field})"
|
|
409
|
-
)
|
|
1
|
+
"""
|
|
2
|
+
Temporal Convolutional Network (TCN): Dilated Causal Convolutions for 1D Signals
|
|
3
|
+
=================================================================================
|
|
4
|
+
|
|
5
|
+
A dedicated 1D architecture using dilated causal convolutions to capture
|
|
6
|
+
long-range temporal dependencies in waveforms and time-series data.
|
|
7
|
+
Provides exponentially growing receptive field with linear parameter growth.
|
|
8
|
+
|
|
9
|
+
**Key Features**:
|
|
10
|
+
- Dilated convolutions: Exponentially growing receptive field
|
|
11
|
+
- Causal padding: No information leakage from future
|
|
12
|
+
- Residual connections: Stable gradient flow
|
|
13
|
+
- Weight normalization: Faster convergence
|
|
14
|
+
|
|
15
|
+
**Variants**:
|
|
16
|
+
- tcn: Standard TCN with configurable depth and channels
|
|
17
|
+
- tcn_small: Lightweight variant for quick experiments
|
|
18
|
+
- tcn_large: Higher capacity for complex patterns
|
|
19
|
+
|
|
20
|
+
**Receptive Field Calculation**:
|
|
21
|
+
RF = 1 + (kernel_size - 1) * sum(dilation[i] for i in layers)
|
|
22
|
+
With default settings (kernel=3, 8 layers, dilation=2^i):
|
|
23
|
+
RF = 1 + 2 * (1+2+4+8+16+32+64+128) = 511 samples
|
|
24
|
+
|
|
25
|
+
**Note**: TCN is 1D-only. For 2D/3D data, use ResNet, EfficientNet, or Swin.
|
|
26
|
+
|
|
27
|
+
References:
|
|
28
|
+
Bai, S., Kolter, J.Z., & Koltun, V. (2018). An Empirical Evaluation of
|
|
29
|
+
Generic Convolutional and Recurrent Networks for Sequence Modeling.
|
|
30
|
+
arXiv:1803.01271. https://arxiv.org/abs/1803.01271
|
|
31
|
+
|
|
32
|
+
van den Oord, A., et al. (2016). WaveNet: A Generative Model for Raw Audio.
|
|
33
|
+
arXiv:1609.03499. https://arxiv.org/abs/1609.03499
|
|
34
|
+
|
|
35
|
+
Author: Ductho Le (ductho.le@outlook.com)
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
from typing import Any
|
|
39
|
+
|
|
40
|
+
import torch
|
|
41
|
+
import torch.nn as nn
|
|
42
|
+
import torch.nn.functional as F
|
|
43
|
+
|
|
44
|
+
from wavedl.models.base import BaseModel, compute_num_groups
|
|
45
|
+
from wavedl.models.registry import register_model
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class CausalConv1d(nn.Module):
    """
    Dilated 1D convolution with strictly causal padding.

    The output at time step t depends only on inputs at steps <= t.
    This is achieved by zero-padding the left edge of the sequence
    before the convolution (nn.Conv1d itself does no padding), which
    also keeps the output length equal to the input length.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int = 1,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        # Amount of left-side padding needed so no future sample leaks in.
        self.padding = (kernel_size - 1) * dilation

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding=0,  # padding is applied manually in forward() for causality
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Left-pad the sequence, then apply the dilated convolution."""
        padded = F.pad(x, (self.padding, 0))
        return self.conv(padded)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class TemporalBlock(nn.Module):
    """
    Residual block of two dilated causal convolutions.

    Each convolution is followed by GroupNorm, GELU, and Dropout; the
    block output is the sum of that transformation and a skip
    connection (projected through a 1x1 conv when the channel count
    changes).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        groups = compute_num_groups(out_channels, preferred_groups=8)

        # Stage 1: conv -> norm -> activation -> dropout
        self.conv1 = CausalConv1d(in_channels, out_channels, kernel_size, dilation)
        self.norm1 = nn.GroupNorm(groups, out_channels)
        self.act1 = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)

        # Stage 2: conv -> norm -> activation -> dropout
        self.conv2 = CausalConv1d(out_channels, out_channels, kernel_size, dilation)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.act2 = nn.GELU()
        self.dropout2 = nn.Dropout(dropout)

        # Skip path: 1x1 conv only when channel counts differ.
        if in_channels != out_channels:
            self.downsample = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.downsample = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both conv stages and add the (projected) residual."""
        skip = self.downsample(x)
        out = self.dropout1(self.act1(self.norm1(self.conv1(x))))
        out = self.dropout2(self.act2(self.norm2(self.conv2(out))))
        return out + skip
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class TCNBase(BaseModel):
    """
    Base Temporal Convolutional Network for 1D regression.

    Architecture:
        1. Stack of temporal blocks with exponentially increasing dilation
           (the first block also expands the single input channel)
        2. Global average pooling over time
        3. MLP regression head

    The receptive field grows exponentially with depth. Each temporal
    block applies TWO dilated convolutions, so:
        RF = 1 + 2 * (kernel_size - 1) * sum(2^i for i in range(num_layers))
    (This matches ``_compute_receptive_field`` below; the factor of 2
    accounts for the two convolutions per block.)
    """

    def __init__(
        self,
        in_shape: tuple[int],
        out_size: int,
        num_channels: list[int],
        kernel_size: int = 3,
        dropout_rate: float = 0.1,
        **kwargs,
    ):
        """
        Initialize TCN for regression.

        Args:
            in_shape: (L,) input signal length
            out_size: Number of regression output targets
            num_channels: List of channel sizes for each temporal block
            kernel_size: Convolution kernel size (default: 3)
            dropout_rate: Dropout rate (default: 0.1)

        Raises:
            ValueError: If in_shape is not 1D.
        """
        super().__init__(in_shape, out_size)

        if len(in_shape) != 1:
            raise ValueError(
                f"TCN requires 1D input (L,), got {len(in_shape)}D. "
                "For 2D/3D data, use ResNet, EfficientNet, or Swin."
            )

        self.num_channels = num_channels
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate

        # Build temporal blocks with exponentially increasing dilation
        layers = []
        num_levels = len(num_channels)

        for i in range(num_levels):
            dilation = 2**i
            # The first block consumes the raw single-channel waveform.
            in_ch = 1 if i == 0 else num_channels[i - 1]
            out_ch = num_channels[i]
            layers.append(
                TemporalBlock(in_ch, out_ch, kernel_size, dilation, dropout_rate)
            )

        self.network = nn.Sequential(*layers)

        # Global pooling collapses the time axis to one feature vector.
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Regression head: two hidden layers with GELU and dropout.
        final_channels = num_channels[-1]
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(final_channels, 256),
            nn.GELU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, out_size),
        )

        # Calculate and store receptive field
        self.receptive_field = self._compute_receptive_field()

        # Initialize weights
        self._init_weights()

    def _compute_receptive_field(self) -> int:
        """Compute the receptive field (in samples) of the network.

        Each temporal block contributes two dilated convolutions, hence
        the factor of 2 per dilation level.
        """
        rf = 1
        for i in range(len(self.num_channels)):
            dilation = 2**i
            # Each temporal block has 2 convolutions
            rf += 2 * (self.kernel_size - 1) * dilation
        return rf

    def _init_weights(self):
        """Initialize conv/linear weights (Kaiming) and norm affine params."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, 1, L)

        Returns:
            Output tensor of shape (B, out_size)
        """
        # Temporal blocks
        x = self.network(x)

        # Global pooling
        x = self.global_pool(x)
        x = x.flatten(1)

        # Regression head
        return self.head(x)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for TCN."""
        return {
            "num_channels": [64, 128, 256, 256, 512, 512, 512, 512],
            "kernel_size": 3,
            "dropout_rate": 0.1,
        }
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
# =============================================================================
|
|
274
|
+
# REGISTERED MODEL VARIANTS
|
|
275
|
+
# =============================================================================
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@register_model("tcn")
class TCN(TCNBase):
    """
    TCN: Standard Temporal Convolutional Network.

    ~6.9M backbone parameters. 8 temporal blocks with channels
    [64→128→256→256→512→512→512→512].
    Receptive field: 1021 samples with kernel_size=3
    (1 + 2*(3-1)*(2^8 - 1); each block applies two dilated convolutions —
    see TCNBase._compute_receptive_field).

    Recommended for:
    - Ultrasonic A-scan processing
    - Acoustic emission signals
    - Seismic waveform analysis
    - Any 1D time-series regression

    Args:
        in_shape: (L,) input signal length
        out_size: Number of regression targets
        kernel_size: Convolution kernel size (default: 3)
        dropout_rate: Dropout rate (default: 0.1)

    Example:
        >>> model = TCN(in_shape=(4096,), out_size=3)
        >>> x = torch.randn(4, 1, 4096)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
        # Default: 8 layers, 64→512 channels (overridable via kwargs)
        num_channels = kwargs.pop(
            "num_channels", [64, 128, 256, 256, 512, 512, 512, 512]
        )
        super().__init__(
            in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
        )

    def __repr__(self) -> str:
        return (
            f"TCN(in_shape={self.in_shape}, out={self.out_size}, "
            f"RF={self.receptive_field})"
        )
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
@register_model("tcn_small")
class TCNSmall(TCNBase):
    """
    TCN-Small: Lightweight variant for quick experiments.

    ~0.9M backbone parameters. 6 temporal blocks with channels
    [32→64→128→128→256→256].
    Receptive field: 253 samples with kernel_size=3
    (1 + 2*(3-1)*(2^6 - 1); each block applies two dilated convolutions —
    see TCNBase._compute_receptive_field).

    Recommended for:
    - Quick prototyping
    - Smaller datasets
    - Real-time inference on edge devices

    Args:
        in_shape: (L,) input signal length
        out_size: Number of regression targets
        kernel_size: Convolution kernel size (default: 3)
        dropout_rate: Dropout rate (default: 0.1)

    Example:
        >>> model = TCNSmall(in_shape=(1024,), out_size=3)
        >>> x = torch.randn(4, 1, 1024)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
        num_channels = [32, 64, 128, 128, 256, 256]
        super().__init__(
            in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
        )

    def __repr__(self) -> str:
        return (
            f"TCN_Small(in_shape={self.in_shape}, out={self.out_size}, "
            f"RF={self.receptive_field})"
        )
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
@register_model("tcn_large")
class TCNLarge(TCNBase):
    """
    TCN-Large: High-capacity variant for complex patterns.

    ~10.0M backbone parameters. 10 temporal blocks with channels
    [64→128→256→256→512→512→512→512→512→512].
    Receptive field: 4093 samples with kernel_size=3
    (1 + 2*(3-1)*(2^10 - 1); each block applies two dilated convolutions —
    see TCNBase._compute_receptive_field).

    Recommended for:
    - Long sequences (>4096 samples)
    - Complex temporal patterns
    - Large datasets with sufficient compute

    Args:
        in_shape: (L,) input signal length
        out_size: Number of regression targets
        kernel_size: Convolution kernel size (default: 3)
        dropout_rate: Dropout rate (default: 0.1)

    Example:
        >>> model = TCNLarge(in_shape=(8192,), out_size=3)
        >>> x = torch.randn(4, 1, 8192)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
        num_channels = [64, 128, 256, 256, 512, 512, 512, 512, 512, 512]
        super().__init__(
            in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
        )

    def __repr__(self) -> str:
        return (
            f"TCN_Large(in_shape={self.in_shape}, out={self.out_size}, "
            f"RF={self.receptive_field})"
        )