wavedl 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/models/tcn.py CHANGED
@@ -1,409 +1,393 @@
1
- """
2
- Temporal Convolutional Network (TCN): Dilated Causal Convolutions for 1D Signals
3
- =================================================================================
4
-
5
- A dedicated 1D architecture using dilated causal convolutions to capture
6
- long-range temporal dependencies in waveforms and time-series data.
7
- Provides exponentially growing receptive field with linear parameter growth.
8
-
9
- **Key Features**:
10
- - Dilated convolutions: Exponentially growing receptive field
11
- - Causal padding: No information leakage from future
12
- - Residual connections: Stable gradient flow
13
- - Weight normalization: Faster convergence
14
-
15
- **Variants**:
16
- - tcn: Standard TCN with configurable depth and channels
17
- - tcn_small: Lightweight variant for quick experiments
18
- - tcn_large: Higher capacity for complex patterns
19
-
20
- **Receptive Field Calculation**:
21
- RF = 1 + (kernel_size - 1) * sum(dilation[i] for i in layers)
22
- With default settings (kernel=3, 8 layers, dilation=2^i):
23
- RF = 1 + 2 * (1+2+4+8+16+32+64+128) = 511 samples
24
-
25
- **Note**: TCN is 1D-only. For 2D/3D data, use ResNet, EfficientNet, or Swin.
26
-
27
- References:
28
- Bai, S., Kolter, J.Z., & Koltun, V. (2018). An Empirical Evaluation of
29
- Generic Convolutional and Recurrent Networks for Sequence Modeling.
30
- arXiv:1803.01271. https://arxiv.org/abs/1803.01271
31
-
32
- van den Oord, A., et al. (2016). WaveNet: A Generative Model for Raw Audio.
33
- arXiv:1609.03499. https://arxiv.org/abs/1609.03499
34
-
35
- Author: Ductho Le (ductho.le@outlook.com)
36
- """
37
-
38
- from typing import Any
39
-
40
- import torch
41
- import torch.nn as nn
42
- import torch.nn.functional as F
43
-
44
- from wavedl.models.base import BaseModel
45
- from wavedl.models.registry import register_model
46
-
47
-
48
- def _find_group_count(channels: int, max_groups: int = 8) -> int:
49
- """
50
- Find largest valid group count for GroupNorm.
51
-
52
- GroupNorm requires channels to be divisible by num_groups.
53
- This finds the largest divisor up to max_groups.
54
-
55
- Args:
56
- channels: Number of channels
57
- max_groups: Maximum group count to consider (default: 8)
58
-
59
- Returns:
60
- Largest valid group count (always >= 1)
61
- """
62
- for g in range(min(max_groups, channels), 0, -1):
63
- if channels % g == 0:
64
- return g
65
- return 1
66
-
67
-
68
- class CausalConv1d(nn.Module):
69
- """
70
- Causal 1D convolution with dilation.
71
-
72
- Ensures output at time t only depends on inputs at times <= t.
73
- Uses left-side padding to achieve causal behavior.
74
- """
75
-
76
- def __init__(
77
- self,
78
- in_channels: int,
79
- out_channels: int,
80
- kernel_size: int,
81
- dilation: int = 1,
82
- ):
83
- super().__init__()
84
- self.kernel_size = kernel_size
85
- self.dilation = dilation
86
- # Causal padding: only pad on the left
87
- self.padding = (kernel_size - 1) * dilation
88
-
89
- self.conv = nn.Conv1d(
90
- in_channels,
91
- out_channels,
92
- kernel_size,
93
- dilation=dilation,
94
- padding=0, # We handle padding manually for causality
95
- )
96
-
97
- def forward(self, x: torch.Tensor) -> torch.Tensor:
98
- # Pad on the left only (causal)
99
- x = F.pad(x, (self.padding, 0))
100
- return self.conv(x)
101
-
102
-
103
- class TemporalBlock(nn.Module):
104
- """
105
- Temporal block with two causal dilated convolutions and residual connection.
106
-
107
- Architecture:
108
- Input → CausalConv → LayerNorm → GELU → Dropout →
109
- CausalConv → LayerNorm → GELU → Dropout → (+Input) → Output
110
- """
111
-
112
- def __init__(
113
- self,
114
- in_channels: int,
115
- out_channels: int,
116
- kernel_size: int,
117
- dilation: int,
118
- dropout: float = 0.1,
119
- ):
120
- super().__init__()
121
-
122
- # First causal convolution
123
- self.conv1 = CausalConv1d(in_channels, out_channels, kernel_size, dilation)
124
- self.norm1 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
125
- self.act1 = nn.GELU()
126
- self.dropout1 = nn.Dropout(dropout)
127
-
128
- # Second causal convolution
129
- self.conv2 = CausalConv1d(out_channels, out_channels, kernel_size, dilation)
130
- self.norm2 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
131
- self.act2 = nn.GELU()
132
- self.dropout2 = nn.Dropout(dropout)
133
-
134
- # Residual connection (1x1 conv if channels change)
135
- self.downsample = (
136
- nn.Conv1d(in_channels, out_channels, 1)
137
- if in_channels != out_channels
138
- else nn.Identity()
139
- )
140
-
141
- def forward(self, x: torch.Tensor) -> torch.Tensor:
142
- residual = self.downsample(x)
143
-
144
- # First conv block
145
- out = self.conv1(x)
146
- out = self.norm1(out)
147
- out = self.act1(out)
148
- out = self.dropout1(out)
149
-
150
- # Second conv block
151
- out = self.conv2(out)
152
- out = self.norm2(out)
153
- out = self.act2(out)
154
- out = self.dropout2(out)
155
-
156
- return out + residual
157
-
158
-
159
- class TCNBase(BaseModel):
160
- """
161
- Base Temporal Convolutional Network for 1D regression.
162
-
163
- Architecture:
164
- 1. Input projection (optional channel expansion)
165
- 2. Stack of temporal blocks with exponentially increasing dilation
166
- 3. Global average pooling
167
- 4. Regression head
168
-
169
- The receptive field grows exponentially with depth:
170
- RF = 1 + (kernel_size - 1) * sum(2^i for i in range(num_layers))
171
- """
172
-
173
- def __init__(
174
- self,
175
- in_shape: tuple[int],
176
- out_size: int,
177
- num_channels: list[int],
178
- kernel_size: int = 3,
179
- dropout_rate: float = 0.1,
180
- **kwargs,
181
- ):
182
- """
183
- Initialize TCN for regression.
184
-
185
- Args:
186
- in_shape: (L,) input signal length
187
- out_size: Number of regression output targets
188
- num_channels: List of channel sizes for each temporal block
189
- kernel_size: Convolution kernel size (default: 3)
190
- dropout_rate: Dropout rate (default: 0.1)
191
- """
192
- super().__init__(in_shape, out_size)
193
-
194
- if len(in_shape) != 1:
195
- raise ValueError(
196
- f"TCN requires 1D input (L,), got {len(in_shape)}D. "
197
- "For 2D/3D data, use ResNet, EfficientNet, or Swin."
198
- )
199
-
200
- self.num_channels = num_channels
201
- self.kernel_size = kernel_size
202
- self.dropout_rate = dropout_rate
203
-
204
- # Build temporal blocks with exponentially increasing dilation
205
- layers = []
206
- num_levels = len(num_channels)
207
-
208
- for i in range(num_levels):
209
- dilation = 2**i
210
- in_ch = 1 if i == 0 else num_channels[i - 1]
211
- out_ch = num_channels[i]
212
- layers.append(
213
- TemporalBlock(in_ch, out_ch, kernel_size, dilation, dropout_rate)
214
- )
215
-
216
- self.network = nn.Sequential(*layers)
217
-
218
- # Global pooling
219
- self.global_pool = nn.AdaptiveAvgPool1d(1)
220
-
221
- # Regression head
222
- final_channels = num_channels[-1]
223
- self.head = nn.Sequential(
224
- nn.Dropout(dropout_rate),
225
- nn.Linear(final_channels, 256),
226
- nn.GELU(),
227
- nn.Dropout(dropout_rate * 0.5),
228
- nn.Linear(256, 128),
229
- nn.GELU(),
230
- nn.Linear(128, out_size),
231
- )
232
-
233
- # Calculate and store receptive field
234
- self.receptive_field = self._compute_receptive_field()
235
-
236
- # Initialize weights
237
- self._init_weights()
238
-
239
- def _compute_receptive_field(self) -> int:
240
- """Compute the receptive field of the network."""
241
- rf = 1
242
- for i in range(len(self.num_channels)):
243
- dilation = 2**i
244
- # Each temporal block has 2 convolutions
245
- rf += 2 * (self.kernel_size - 1) * dilation
246
- return rf
247
-
248
- def _init_weights(self):
249
- """Initialize weights using Kaiming initialization."""
250
- for m in self.modules():
251
- if isinstance(m, (nn.Conv1d, nn.Linear)):
252
- nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
253
- if m.bias is not None:
254
- nn.init.constant_(m.bias, 0)
255
- elif isinstance(m, nn.GroupNorm):
256
- nn.init.constant_(m.weight, 1)
257
- nn.init.constant_(m.bias, 0)
258
-
259
- def forward(self, x: torch.Tensor) -> torch.Tensor:
260
- """
261
- Forward pass.
262
-
263
- Args:
264
- x: Input tensor of shape (B, 1, L)
265
-
266
- Returns:
267
- Output tensor of shape (B, out_size)
268
- """
269
- # Temporal blocks
270
- x = self.network(x)
271
-
272
- # Global pooling
273
- x = self.global_pool(x)
274
- x = x.flatten(1)
275
-
276
- # Regression head
277
- return self.head(x)
278
-
279
- @classmethod
280
- def get_default_config(cls) -> dict[str, Any]:
281
- """Return default configuration for TCN."""
282
- return {
283
- "num_channels": [64, 128, 256, 256, 512, 512, 512, 512],
284
- "kernel_size": 3,
285
- "dropout_rate": 0.1,
286
- }
287
-
288
-
289
- # =============================================================================
290
- # REGISTERED MODEL VARIANTS
291
- # =============================================================================
292
-
293
-
294
- @register_model("tcn")
295
- class TCN(TCNBase):
296
- """
297
- TCN: Standard Temporal Convolutional Network.
298
-
299
- ~6.9M backbone parameters. 8 temporal blocks with channels [64→128→256→256→512→512→512→512].
300
- Receptive field: 511 samples with kernel_size=3.
301
-
302
- Recommended for:
303
- - Ultrasonic A-scan processing
304
- - Acoustic emission signals
305
- - Seismic waveform analysis
306
- - Any 1D time-series regression
307
-
308
- Args:
309
- in_shape: (L,) input signal length
310
- out_size: Number of regression targets
311
- kernel_size: Convolution kernel size (default: 3)
312
- dropout_rate: Dropout rate (default: 0.1)
313
-
314
- Example:
315
- >>> model = TCN(in_shape=(4096,), out_size=3)
316
- >>> x = torch.randn(4, 1, 4096)
317
- >>> out = model(x) # (4, 3)
318
- """
319
-
320
- def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
321
- # Default: 8 layers, 64→512 channels
322
- num_channels = kwargs.pop(
323
- "num_channels", [64, 128, 256, 256, 512, 512, 512, 512]
324
- )
325
- super().__init__(
326
- in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
327
- )
328
-
329
- def __repr__(self) -> str:
330
- return (
331
- f"TCN(in_shape={self.in_shape}, out={self.out_size}, "
332
- f"RF={self.receptive_field})"
333
- )
334
-
335
-
336
- @register_model("tcn_small")
337
- class TCNSmall(TCNBase):
338
- """
339
- TCN-Small: Lightweight variant for quick experiments.
340
-
341
- ~0.9M backbone parameters. 6 temporal blocks with channels [32→64→128→128→256→256].
342
- Receptive field: 127 samples with kernel_size=3.
343
-
344
- Recommended for:
345
- - Quick prototyping
346
- - Smaller datasets
347
- - Real-time inference on edge devices
348
-
349
- Args:
350
- in_shape: (L,) input signal length
351
- out_size: Number of regression targets
352
- kernel_size: Convolution kernel size (default: 3)
353
- dropout_rate: Dropout rate (default: 0.1)
354
-
355
- Example:
356
- >>> model = TCNSmall(in_shape=(1024,), out_size=3)
357
- >>> x = torch.randn(4, 1, 1024)
358
- >>> out = model(x) # (4, 3)
359
- """
360
-
361
- def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
362
- num_channels = [32, 64, 128, 128, 256, 256]
363
- super().__init__(
364
- in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
365
- )
366
-
367
- def __repr__(self) -> str:
368
- return (
369
- f"TCN_Small(in_shape={self.in_shape}, out={self.out_size}, "
370
- f"RF={self.receptive_field})"
371
- )
372
-
373
-
374
- @register_model("tcn_large")
375
- class TCNLarge(TCNBase):
376
- """
377
- TCN-Large: High-capacity variant for complex patterns.
378
-
379
- ~10.0M backbone parameters. 10 temporal blocks with channels [64→128→256→256→512→512→512→512→512→512].
380
- Receptive field: 2047 samples with kernel_size=3.
381
-
382
- Recommended for:
383
- - Long sequences (>4096 samples)
384
- - Complex temporal patterns
385
- - Large datasets with sufficient compute
386
-
387
- Args:
388
- in_shape: (L,) input signal length
389
- out_size: Number of regression targets
390
- kernel_size: Convolution kernel size (default: 3)
391
- dropout_rate: Dropout rate (default: 0.1)
392
-
393
- Example:
394
- >>> model = TCNLarge(in_shape=(8192,), out_size=3)
395
- >>> x = torch.randn(4, 1, 8192)
396
- >>> out = model(x) # (4, 3)
397
- """
398
-
399
- def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
400
- num_channels = [64, 128, 256, 256, 512, 512, 512, 512, 512, 512]
401
- super().__init__(
402
- in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
403
- )
404
-
405
- def __repr__(self) -> str:
406
- return (
407
- f"TCN_Large(in_shape={self.in_shape}, out={self.out_size}, "
408
- f"RF={self.receptive_field})"
409
- )
1
+ """
2
+ Temporal Convolutional Network (TCN): Dilated Causal Convolutions for 1D Signals
3
+ =================================================================================
4
+
5
+ A dedicated 1D architecture using dilated causal convolutions to capture
6
+ long-range temporal dependencies in waveforms and time-series data.
7
+ Provides exponentially growing receptive field with linear parameter growth.
8
+
9
+ **Key Features**:
10
+ - Dilated convolutions: Exponentially growing receptive field
11
+ - Causal padding: No information leakage from future
12
+ - Residual connections: Stable gradient flow
13
+ - Weight normalization: Faster convergence
14
+
15
+ **Variants**:
16
+ - tcn: Standard TCN with configurable depth and channels
17
+ - tcn_small: Lightweight variant for quick experiments
18
+ - tcn_large: Higher capacity for complex patterns
19
+
20
+ **Receptive Field Calculation**:
21
+ RF = 1 + (kernel_size - 1) * sum(dilation[i] for i in layers)
22
+ With default settings (kernel=3, 8 layers, dilation=2^i):
23
+ RF = 1 + 2 * (1+2+4+8+16+32+64+128) = 511 samples
24
+
25
+ **Note**: TCN is 1D-only. For 2D/3D data, use ResNet, EfficientNet, or Swin.
26
+
27
+ References:
28
+ Bai, S., Kolter, J.Z., & Koltun, V. (2018). An Empirical Evaluation of
29
+ Generic Convolutional and Recurrent Networks for Sequence Modeling.
30
+ arXiv:1803.01271. https://arxiv.org/abs/1803.01271
31
+
32
+ van den Oord, A., et al. (2016). WaveNet: A Generative Model for Raw Audio.
33
+ arXiv:1609.03499. https://arxiv.org/abs/1609.03499
34
+
35
+ Author: Ductho Le (ductho.le@outlook.com)
36
+ """
37
+
38
+ from typing import Any
39
+
40
+ import torch
41
+ import torch.nn as nn
42
+ import torch.nn.functional as F
43
+
44
+ from wavedl.models.base import BaseModel, compute_num_groups
45
+ from wavedl.models.registry import register_model
46
+
47
+
48
+ class CausalConv1d(nn.Module):
49
+ """
50
+ Causal 1D convolution with dilation.
51
+
52
+ Ensures output at time t only depends on inputs at times <= t.
53
+ Uses left-side padding to achieve causal behavior.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ in_channels: int,
59
+ out_channels: int,
60
+ kernel_size: int,
61
+ dilation: int = 1,
62
+ ):
63
+ super().__init__()
64
+ self.kernel_size = kernel_size
65
+ self.dilation = dilation
66
+ # Causal padding: only pad on the left
67
+ self.padding = (kernel_size - 1) * dilation
68
+
69
+ self.conv = nn.Conv1d(
70
+ in_channels,
71
+ out_channels,
72
+ kernel_size,
73
+ dilation=dilation,
74
+ padding=0, # We handle padding manually for causality
75
+ )
76
+
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ # Pad on the left only (causal)
79
+ x = F.pad(x, (self.padding, 0))
80
+ return self.conv(x)
81
+
82
+
83
+ class TemporalBlock(nn.Module):
84
+ """
85
+ Temporal block with two causal dilated convolutions and residual connection.
86
+
87
+ Architecture:
88
+ Input → CausalConv → LayerNorm → GELU → Dropout →
89
+ CausalConv → LayerNorm → GELU → Dropout → (+Input) → Output
90
+ """
91
+
92
+ def __init__(
93
+ self,
94
+ in_channels: int,
95
+ out_channels: int,
96
+ kernel_size: int,
97
+ dilation: int,
98
+ dropout: float = 0.1,
99
+ ):
100
+ super().__init__()
101
+
102
+ # First causal convolution
103
+ self.conv1 = CausalConv1d(in_channels, out_channels, kernel_size, dilation)
104
+ self.norm1 = nn.GroupNorm(
105
+ compute_num_groups(out_channels, preferred_groups=8), out_channels
106
+ )
107
+ self.act1 = nn.GELU()
108
+ self.dropout1 = nn.Dropout(dropout)
109
+
110
+ # Second causal convolution
111
+ self.conv2 = CausalConv1d(out_channels, out_channels, kernel_size, dilation)
112
+ self.norm2 = nn.GroupNorm(
113
+ compute_num_groups(out_channels, preferred_groups=8), out_channels
114
+ )
115
+ self.act2 = nn.GELU()
116
+ self.dropout2 = nn.Dropout(dropout)
117
+
118
+ # Residual connection (1x1 conv if channels change)
119
+ self.downsample = (
120
+ nn.Conv1d(in_channels, out_channels, 1)
121
+ if in_channels != out_channels
122
+ else nn.Identity()
123
+ )
124
+
125
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
126
+ residual = self.downsample(x)
127
+
128
+ # First conv block
129
+ out = self.conv1(x)
130
+ out = self.norm1(out)
131
+ out = self.act1(out)
132
+ out = self.dropout1(out)
133
+
134
+ # Second conv block
135
+ out = self.conv2(out)
136
+ out = self.norm2(out)
137
+ out = self.act2(out)
138
+ out = self.dropout2(out)
139
+
140
+ return out + residual
141
+
142
+
143
+ class TCNBase(BaseModel):
144
+ """
145
+ Base Temporal Convolutional Network for 1D regression.
146
+
147
+ Architecture:
148
+ 1. Input projection (optional channel expansion)
149
+ 2. Stack of temporal blocks with exponentially increasing dilation
150
+ 3. Global average pooling
151
+ 4. Regression head
152
+
153
+ The receptive field grows exponentially with depth:
154
+ RF = 1 + (kernel_size - 1) * sum(2^i for i in range(num_layers))
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ in_shape: tuple[int],
160
+ out_size: int,
161
+ num_channels: list[int],
162
+ kernel_size: int = 3,
163
+ dropout_rate: float = 0.1,
164
+ **kwargs,
165
+ ):
166
+ """
167
+ Initialize TCN for regression.
168
+
169
+ Args:
170
+ in_shape: (L,) input signal length
171
+ out_size: Number of regression output targets
172
+ num_channels: List of channel sizes for each temporal block
173
+ kernel_size: Convolution kernel size (default: 3)
174
+ dropout_rate: Dropout rate (default: 0.1)
175
+ """
176
+ super().__init__(in_shape, out_size)
177
+
178
+ if len(in_shape) != 1:
179
+ raise ValueError(
180
+ f"TCN requires 1D input (L,), got {len(in_shape)}D. "
181
+ "For 2D/3D data, use ResNet, EfficientNet, or Swin."
182
+ )
183
+
184
+ self.num_channels = num_channels
185
+ self.kernel_size = kernel_size
186
+ self.dropout_rate = dropout_rate
187
+
188
+ # Build temporal blocks with exponentially increasing dilation
189
+ layers = []
190
+ num_levels = len(num_channels)
191
+
192
+ for i in range(num_levels):
193
+ dilation = 2**i
194
+ in_ch = 1 if i == 0 else num_channels[i - 1]
195
+ out_ch = num_channels[i]
196
+ layers.append(
197
+ TemporalBlock(in_ch, out_ch, kernel_size, dilation, dropout_rate)
198
+ )
199
+
200
+ self.network = nn.Sequential(*layers)
201
+
202
+ # Global pooling
203
+ self.global_pool = nn.AdaptiveAvgPool1d(1)
204
+
205
+ # Regression head
206
+ final_channels = num_channels[-1]
207
+ self.head = nn.Sequential(
208
+ nn.Dropout(dropout_rate),
209
+ nn.Linear(final_channels, 256),
210
+ nn.GELU(),
211
+ nn.Dropout(dropout_rate * 0.5),
212
+ nn.Linear(256, 128),
213
+ nn.GELU(),
214
+ nn.Linear(128, out_size),
215
+ )
216
+
217
+ # Calculate and store receptive field
218
+ self.receptive_field = self._compute_receptive_field()
219
+
220
+ # Initialize weights
221
+ self._init_weights()
222
+
223
+ def _compute_receptive_field(self) -> int:
224
+ """Compute the receptive field of the network."""
225
+ rf = 1
226
+ for i in range(len(self.num_channels)):
227
+ dilation = 2**i
228
+ # Each temporal block has 2 convolutions
229
+ rf += 2 * (self.kernel_size - 1) * dilation
230
+ return rf
231
+
232
+ def _init_weights(self):
233
+ """Initialize weights using Kaiming initialization."""
234
+ for m in self.modules():
235
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
236
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
237
+ if m.bias is not None:
238
+ nn.init.constant_(m.bias, 0)
239
+ elif isinstance(m, nn.GroupNorm):
240
+ nn.init.constant_(m.weight, 1)
241
+ nn.init.constant_(m.bias, 0)
242
+
243
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
244
+ """
245
+ Forward pass.
246
+
247
+ Args:
248
+ x: Input tensor of shape (B, 1, L)
249
+
250
+ Returns:
251
+ Output tensor of shape (B, out_size)
252
+ """
253
+ # Temporal blocks
254
+ x = self.network(x)
255
+
256
+ # Global pooling
257
+ x = self.global_pool(x)
258
+ x = x.flatten(1)
259
+
260
+ # Regression head
261
+ return self.head(x)
262
+
263
+ @classmethod
264
+ def get_default_config(cls) -> dict[str, Any]:
265
+ """Return default configuration for TCN."""
266
+ return {
267
+ "num_channels": [64, 128, 256, 256, 512, 512, 512, 512],
268
+ "kernel_size": 3,
269
+ "dropout_rate": 0.1,
270
+ }
271
+
272
+
273
+ # =============================================================================
274
+ # REGISTERED MODEL VARIANTS
275
+ # =============================================================================
276
+
277
+
278
+ @register_model("tcn")
279
+ class TCN(TCNBase):
280
+ """
281
+ TCN: Standard Temporal Convolutional Network.
282
+
283
+ ~6.9M backbone parameters. 8 temporal blocks with channels [64→128→256→256→512→512→512→512].
284
+ Receptive field: 511 samples with kernel_size=3.
285
+
286
+ Recommended for:
287
+ - Ultrasonic A-scan processing
288
+ - Acoustic emission signals
289
+ - Seismic waveform analysis
290
+ - Any 1D time-series regression
291
+
292
+ Args:
293
+ in_shape: (L,) input signal length
294
+ out_size: Number of regression targets
295
+ kernel_size: Convolution kernel size (default: 3)
296
+ dropout_rate: Dropout rate (default: 0.1)
297
+
298
+ Example:
299
+ >>> model = TCN(in_shape=(4096,), out_size=3)
300
+ >>> x = torch.randn(4, 1, 4096)
301
+ >>> out = model(x) # (4, 3)
302
+ """
303
+
304
+ def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
305
+ # Default: 8 layers, 64→512 channels
306
+ num_channels = kwargs.pop(
307
+ "num_channels", [64, 128, 256, 256, 512, 512, 512, 512]
308
+ )
309
+ super().__init__(
310
+ in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
311
+ )
312
+
313
+ def __repr__(self) -> str:
314
+ return (
315
+ f"TCN(in_shape={self.in_shape}, out={self.out_size}, "
316
+ f"RF={self.receptive_field})"
317
+ )
318
+
319
+
320
+ @register_model("tcn_small")
321
+ class TCNSmall(TCNBase):
322
+ """
323
+ TCN-Small: Lightweight variant for quick experiments.
324
+
325
+ ~0.9M backbone parameters. 6 temporal blocks with channels [32→64→128→128→256→256].
326
+ Receptive field: 127 samples with kernel_size=3.
327
+
328
+ Recommended for:
329
+ - Quick prototyping
330
+ - Smaller datasets
331
+ - Real-time inference on edge devices
332
+
333
+ Args:
334
+ in_shape: (L,) input signal length
335
+ out_size: Number of regression targets
336
+ kernel_size: Convolution kernel size (default: 3)
337
+ dropout_rate: Dropout rate (default: 0.1)
338
+
339
+ Example:
340
+ >>> model = TCNSmall(in_shape=(1024,), out_size=3)
341
+ >>> x = torch.randn(4, 1, 1024)
342
+ >>> out = model(x) # (4, 3)
343
+ """
344
+
345
+ def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
346
+ num_channels = [32, 64, 128, 128, 256, 256]
347
+ super().__init__(
348
+ in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
349
+ )
350
+
351
+ def __repr__(self) -> str:
352
+ return (
353
+ f"TCN_Small(in_shape={self.in_shape}, out={self.out_size}, "
354
+ f"RF={self.receptive_field})"
355
+ )
356
+
357
+
358
+ @register_model("tcn_large")
359
+ class TCNLarge(TCNBase):
360
+ """
361
+ TCN-Large: High-capacity variant for complex patterns.
362
+
363
+ ~10.0M backbone parameters. 10 temporal blocks with channels [64→128→256→256→512→512→512→512→512→512].
364
+ Receptive field: 2047 samples with kernel_size=3.
365
+
366
+ Recommended for:
367
+ - Long sequences (>4096 samples)
368
+ - Complex temporal patterns
369
+ - Large datasets with sufficient compute
370
+
371
+ Args:
372
+ in_shape: (L,) input signal length
373
+ out_size: Number of regression targets
374
+ kernel_size: Convolution kernel size (default: 3)
375
+ dropout_rate: Dropout rate (default: 0.1)
376
+
377
+ Example:
378
+ >>> model = TCNLarge(in_shape=(8192,), out_size=3)
379
+ >>> x = torch.randn(4, 1, 8192)
380
+ >>> out = model(x) # (4, 3)
381
+ """
382
+
383
+ def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
384
+ num_channels = [64, 128, 256, 256, 512, 512, 512, 512, 512, 512]
385
+ super().__init__(
386
+ in_shape=in_shape, out_size=out_size, num_channels=num_channels, **kwargs
387
+ )
388
+
389
+ def __repr__(self) -> str:
390
+ return (
391
+ f"TCN_Large(in_shape={self.in_shape}, out={self.out_size}, "
392
+ f"RF={self.receptive_field})"
393
+ )
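The snippet below is a brief usage sketch of wavedl/models/tcn.py as it stands in 1.6.1; it is not part of the packaged diff. The model construction and tensor shapes follow the docstring examples shown above. The largest_group_count helper at the end is a hypothetical standalone rewrite of the removed _find_group_count, included only to illustrate the group-selection rule that the new compute_num_groups(out_channels, preferred_groups=8) call presumably provides; the actual implementation in wavedl.models.base is not shown in this diff.

import torch

from wavedl.models.tcn import TCN, TCNSmall

# Registered variants, instantiated as in the docstring examples.
model = TCN(in_shape=(4096,), out_size=3)       # 8 blocks, channels 64 -> 512
small = TCNSmall(in_shape=(1024,), out_size=3)  # 6 blocks, channels 32 -> 256

x = torch.randn(4, 1, 4096)  # (batch, channels=1, length)
out = model(x)               # -> shape (4, 3)

# Each instance stores the receptive field computed by _compute_receptive_field,
# i.e. rf = 1 + sum over blocks of 2 * (kernel_size - 1) * 2**i (two convs per block).
print(model.receptive_field, small.receptive_field)


# Hypothetical equivalent of the removed _find_group_count: the largest divisor
# of `channels` that does not exceed `max_groups`, so GroupNorm always receives a
# valid group count. Assumed (not confirmed here) to match compute_num_groups.
def largest_group_count(channels: int, max_groups: int = 8) -> int:
    for g in range(min(max_groups, channels), 0, -1):
        if channels % g == 0:
            return g
    return 1


assert largest_group_count(64) == 8
assert largest_group_count(12) == 6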