braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of braindecode might be problematic.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
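The file list above already sketches the 1.0.0 reorganization: the layer definitions that lived in braindecode/models/modules.py (removed) now form the top-level braindecode.modules package, and the deprecated braindecode/datautil shims (mne.py, preprocess.py, windowers.py, xy.py, all removed) give way to braindecode.preprocessing and braindecode.datasets. A minimal migration sketch follows; it assumes the symbol names carried over unchanged, which the eegitnet.py diff below confirms for Ensure4d and DepthwiseConv2d (check the 1.0.0 docs for the rest):

# braindecode 0.8.1 style (these modules are removed in 1.0.0):
#   from braindecode.models.modules import Ensure4d
#   from braindecode.datautil.preprocess import preprocess, Preprocessor
# braindecode 1.0.0 style:
from braindecode.modules import DepthwiseConv2d, Ensure4d       # reusable layers, new package
from braindecode.preprocessing import Preprocessor, preprocess  # preprocessing API, no datautil shim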
braindecode/models/eegitnet.py
CHANGED
@@ -2,94 +2,23 @@
 #
 # License: BSD-3
 import torch
-from torch import nn
 from einops.layers.torch import Rearrange
+from torch import nn
 
-from .
-from .
-
-
-class _DepthwiseConv2d(torch.nn.Conv2d):
-    def __init__(
-        self,
-        in_channels,
-        depth_multiplier=2,
-        kernel_size=3,
-        stride=1,
-        padding=0,
-        dilation=1,
-        bias=True,
-        padding_mode="zeros",
-    ):
-        out_channels = in_channels * depth_multiplier
-        super().__init__(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=in_channels,
-            bias=bias,
-            padding_mode=padding_mode,
-        )
-
-
-class _InceptionBlock(nn.Module):
-    def __init__(self, branches):
-        super().__init__()
-        self.branches = nn.ModuleList(branches)
-
-    def forward(self, x):
-        return torch.cat([branch(x) for branch in self.branches], 1)
-
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
 
-class _TCBlock(nn.Module):
-    def __init__(self, in_ch, kernel_length, dialation, padding, drop_prob=0.4):
-        super().__init__()
-        self.pad = padding
-        self.tc1 = nn.Sequential(
-            _DepthwiseConv2d(
-                in_ch,
-                kernel_size=(1, kernel_length),
-                depth_multiplier=1,
-                dilation=(1, dialation),
-                bias=False,
-                padding="valid",
-            ),
-            nn.BatchNorm2d(in_ch),
-            nn.ELU(),
-            nn.Dropout(drop_prob),
-        )
-
-        self.tc2 = nn.Sequential(
-            _DepthwiseConv2d(
-                in_ch,
-                kernel_size=(1, kernel_length),
-                depth_multiplier=1,
-                dilation=(1, dialation),
-                bias=False,
-                padding="valid",
-            ),
-            nn.BatchNorm2d(in_ch),
-            nn.ELU(),
-            nn.Dropout(drop_prob),
-        )
 
-    def forward(self, x):
-        residual = x
-        paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)
-        x = nn.functional.pad(x, paddings)
-        x = self.tc1(x)
-        x = nn.functional.pad(x, paddings)
-        x = self.tc2(x) + residual
-        return x
+class EEGITNet(EEGModuleMixin, nn.Sequential):
+    """EEG-ITNet from Salami, et al (2022) [Salami2022]_
 
+    .. figure:: https://braindecode.org/dev/_static/model/eegitnet.jpg
+       :align: center
+       :alt: EEG-ITNet Architecture
 
-class EEGITNet(EEGModuleMixin, nn.Sequential):
-    """EEG-ITNet: An Explainable Inception Temporal Convolutional Network for
-    motor imagery classification from
-    Salami et. al 2022.
+    EEG-ITNet: An Explainable Inception Temporal
+    Convolutional Network for motor imagery classification from
+    Salami et al. 2022.
 
     See [Salami2022]_ for details.
 
@@ -99,45 +28,62 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
     ----------
     drop_prob: float
         Dropout probability.
-
-
-
-
-
-
-
-
-
-
-
-
+    activation: nn.Module, default=nn.ELU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+    kernel_length : int, optional
+        Kernel length for inception branches. Determines the temporal receptive field.
+        Default is 16.
+    pool_kernel : int, optional
+        Pooling kernel size for the average pooling layer. Default is 4.
+    tcn_in_channel : int, optional
+        Number of input channels for Temporal Convolutional (TC) blocks. Default is 14.
+    tcn_kernel_size : int, optional
+        Kernel size for the TC blocks. Determines the temporal receptive field.
+        Default is 4.
+    tcn_padding : int, optional
+        Padding size for the TC blocks to maintain the input dimensions. Default is 3.
+    drop_prob : float, optional
+        Dropout probability applied after certain layers to prevent overfitting.
+        Default is 0.4.
+    tcn_dilatation : int, optional
+        Dilation rate for the first TC block. Subsequent blocks will have
+        dilation rates multiplied by powers of 2. Default is 1.
 
     Notes
     -----
     This implementation is not guaranteed to be correct, has not been checked
     by original authors, only reimplemented from the paper based on author implementation.
+
+
+    References
+    ----------
+    .. [Salami2022] A. Salami, J. Andreu-Perez and H. Gillmeister, "EEG-ITNet:
+        An Explainable Inception Temporal Convolutional Network for motor
+        imagery classification," in IEEE Access,
+        doi: 10.1109/ACCESS.2022.3161489.
     """
 
     def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        # Braindecode parameters
+        n_outputs=None,
+        n_chans=None,
+        n_times=None,
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=None,
+        # Model parameters
+        n_filters_time: int = 2,
+        kernel_length: int = 16,
+        pool_kernel: int = 4,
+        tcn_in_channel: int = 14,
+        tcn_kernel_size: int = 4,
+        tcn_padding: int = 3,
+        drop_prob: float = 0.4,
+        tcn_dilatation: int = 1,
+        activation: nn.Module = nn.ELU,
     ):
-        n_outputs, n_chans, n_times, = deprecated_args(
-            self,
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-            ('in_channels', 'n_chans', in_channels, n_chans),
-            ('input_window_samples', 'n_times', input_window_samples, n_times),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -145,88 +91,131 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         self.mapping = {
-
-
+            "classification.1.weight": "final_layer.clf.weight",
+            "classification.1.bias": "final_layer.clf.weight",
+        }
 
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del n_classes, in_channels, input_window_samples
 
         # ======== Handling EEG input ========================
        self.add_module(
-            "input_preprocess",
-            nn.Sequential(Ensure4d(), Rearrange(
-                "ba ch t 1 -> ba 1 ch t"))
+            "input_preprocess",
+            nn.Sequential(Ensure4d(), Rearrange("ba ch t 1 -> ba 1 ch t")),
        )
         # ======== Inception branches ========================
         block11 = self._get_inception_branch(
-            in_channels=self.n_chans,
+            in_channels=self.n_chans,
+            out_channels=n_filters_time,
+            kernel_length=kernel_length,
+            activation=activation,
        )
         block12 = self._get_inception_branch(
-            in_channels=self.n_chans,
+            in_channels=self.n_chans,
+            out_channels=n_filters_time * 2,
+            kernel_length=kernel_length * 2,
+            activation=activation,
        )
         block13 = self._get_inception_branch(
-            in_channels=self.n_chans,
+            in_channels=self.n_chans,
+            out_channels=n_filters_time * 4,
+            kernel_length=n_filters_time * 4,
+            activation=activation,
+        )
+        self.add_module("inception_block", InceptionBlock((block11, block12, block13)))
+        self.pool1 = self.add_module(
+            "pooling",
+            nn.Sequential(
+                nn.AvgPool2d(kernel_size=(1, pool_kernel)), nn.Dropout(drop_prob)
+            ),
        )
-        self.add_module("inception_block", _InceptionBlock((block11, block12, block13)))
-        self.pool1 = self.add_module("pooling", nn.Sequential(
-            nn.AvgPool2d(kernel_size=(1, 4)),
-            nn.Dropout(drop_prob)))
         # =========== TC blocks =====================
         self.add_module(
             "TC_block1",
-            _TCBlock(
+            _TCBlock(
+                in_ch=tcn_in_channel,
+                kernel_length=tcn_kernel_size,
+                dilatation=tcn_dilatation,
+                padding=tcn_padding,
+                drop_prob=drop_prob,
+                activation=activation,
+            ),
        )
         # ================================
         self.add_module(
             "TC_block2",
-            _TCBlock(
+            _TCBlock(
+                in_ch=tcn_in_channel,
+                kernel_length=tcn_kernel_size,
+                dilatation=tcn_dilatation * 2,
+                padding=tcn_padding * 2,
+                drop_prob=drop_prob,
+                activation=activation,
+            ),
        )
         # ================================
         self.add_module(
             "TC_block3",
-            _TCBlock(
+            _TCBlock(
+                in_ch=tcn_in_channel,
+                kernel_length=tcn_kernel_size,
+                dilatation=tcn_dilatation * 4,
+                padding=tcn_padding * 4,
+                drop_prob=drop_prob,
+                activation=activation,
+            ),
        )
         # ================================
         self.add_module(
             "TC_block4",
-            _TCBlock(
+            _TCBlock(
+                in_ch=tcn_in_channel,
+                kernel_length=tcn_kernel_size,
+                dilatation=tcn_dilatation * 8,
+                padding=tcn_padding * 8,
+                drop_prob=drop_prob,
+                activation=activation,
+            ),
        )
 
         # ============= Dimensionality reduction ===================
-        self.add_module(
-            "dim_reduction",
-            nn.
-
-
-
+        self.add_module(
+            "dim_reduction",
+            nn.Sequential(
+                nn.Conv2d(tcn_in_channel, tcn_in_channel * 2, kernel_size=(1, 1)),
+                nn.BatchNorm2d(tcn_in_channel * 2),
+                activation(),
+                nn.AvgPool2d((1, tcn_kernel_size)),
+                nn.Dropout(drop_prob),
+            ),
+        )
         # ============== Classifier ==================
         # Moved flatten to another layer
         self.add_module("flatten", nn.Flatten())
 
-
-        module = nn.Sequential()
-
-        module.add_module("clf",
-                          nn.Linear(int(int(self.n_times / 4) / 4) * 28, self.n_outputs))
+        num_features = self.get_output_shape()[-1]
 
-        if self.add_log_softmax:
-            module.add_module("out_fun", nn.LogSoftmax(dim=1))
-        else:
-            module.add_module("out_fun", nn.Identity())
-
-        self.add_module("final_layer", module)
+        self.add_module("final_layer", nn.Linear(num_features, self.n_outputs))
 
     @staticmethod
-    def _get_inception_branch(
+    def _get_inception_branch(
+        in_channels,
+        out_channels,
+        kernel_length,
+        depth_multiplier=1,
+        activation: nn.Module = nn.ELU,
+    ):
         return nn.Sequential(
             nn.Conv2d(
-                1,
+                1,
+                out_channels,
+                kernel_size=(1, kernel_length),
+                padding="same",
+                bias=False,
            ),
             nn.BatchNorm2d(out_channels),
-            _DepthwiseConv2d(
+            DepthwiseConv2d(
                 out_channels,
                 kernel_size=(in_channels, 1),
                 depth_multiplier=depth_multiplier,
@@ -234,4 +223,79 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
                 padding="valid",
             ),
             nn.BatchNorm2d(out_channels),
-            nn.ELU())
+            activation(),
+        )
+
+
+class _TCBlock(nn.Module):
+    """
+    Temporal Convolutional (TC) block.
+
+    This module applies two depthwise separable convolutions with dilation and residual
+    connections, commonly used in temporal convolutional networks to capture long-range
+    dependencies in time-series data.
+
+    Parameters
+    ----------
+    in_ch : int
+        Number of input channels.
+    kernel_length : int
+        Length of the convolutional kernels.
+    dilatation : int
+        Dilatation rate for the convolutions.
+    padding : int
+        Amount of padding to add to the input.
+    drop_prob : float, optional
+        Dropout probability. Default is 0.4.
+    activation : nn.Module class, optional
+        Activation function class to use. Should be a PyTorch activation module class
+        like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+    """
+
+    def __init__(
+        self,
+        in_ch,
+        kernel_length,
+        dilatation,
+        padding,
+        drop_prob=0.4,
+        activation: nn.Module = nn.ELU,
+    ):
+        super().__init__()
+        self.pad = padding
+        self.tc1 = nn.Sequential(
+            DepthwiseConv2d(
+                in_ch,
+                kernel_size=(1, kernel_length),
+                depth_multiplier=1,
+                dilation=(1, dilatation),
+                bias=False,
+                padding="valid",
+            ),
+            nn.BatchNorm2d(in_ch),
+            activation(),
+            nn.Dropout(drop_prob),
+        )
+
+        self.tc2 = nn.Sequential(
+            DepthwiseConv2d(
+                in_ch,
+                kernel_size=(1, kernel_length),
+                depth_multiplier=1,
+                dilation=(1, dilatation),
+                bias=False,
+                padding="valid",
+            ),
+            nn.BatchNorm2d(in_ch),
+            activation(),
+            nn.Dropout(drop_prob),
+        )
+
+    def forward(self, x):
+        residual = x
+        paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)
+        x = nn.functional.pad(x, paddings)
+        x = self.tc1(x)
+        x = nn.functional.pad(x, paddings)
+        x = self.tc2(x) + residual
+        return x
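Net effect of the eegitnet.py diff on the public API: the deprecated n_classes/in_channels/input_window_samples aliases and the add_log_softmax flag are gone, the helper blocks (DepthwiseConv2d, InceptionBlock, Ensure4d) now come from braindecode.modules, the classifier width is derived via get_output_shape() instead of a hand-computed formula, and the model emits raw linear outputs rather than log-probabilities. A minimal usage sketch against the new 1.0.0 signature; the geometry (22 channels, 1000 samples, 4 outputs, batch of 8) is illustrative, not taken from the diff:

import torch
from braindecode.models import EEGITNet

# 1.0.0: dataset geometry is passed through the n_outputs/n_chans/n_times keywords
model = EEGITNet(n_outputs=4, n_chans=22, n_times=1000)

x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
scores = model(x)             # raw scores; the optional LogSoftmax head was removed
print(scores.shape)           # torch.Size([8, 4])

Code written against the 0.8.1 keywords (n_classes=, in_channels=, input_window_samples=) now fails, since the deprecated_args shim call was removed from the constructor.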