braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/models/eegitnet.py
CHANGED
|
@@ -1,95 +1,23 @@
|
|
|
1
1
|
# Authors: Ghaith Bouallegue <ghaithbouallegue@gmail.com>
|
|
2
2
|
#
|
|
3
3
|
# License: BSD-3
|
|
4
|
-
import torch
|
|
5
|
-
from torch import nn
|
|
6
4
|
from einops.layers.torch import Rearrange
|
|
5
|
+
from torch import nn
|
|
7
6
|
|
|
8
|
-
from .
|
|
9
|
-
from .
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class _DepthwiseConv2d(torch.nn.Conv2d):
|
|
13
|
-
def __init__(
|
|
14
|
-
self,
|
|
15
|
-
in_channels,
|
|
16
|
-
depth_multiplier=2,
|
|
17
|
-
kernel_size=3,
|
|
18
|
-
stride=1,
|
|
19
|
-
padding=0,
|
|
20
|
-
dilation=1,
|
|
21
|
-
bias=True,
|
|
22
|
-
padding_mode="zeros",
|
|
23
|
-
):
|
|
24
|
-
out_channels = in_channels * depth_multiplier
|
|
25
|
-
super().__init__(
|
|
26
|
-
in_channels=in_channels,
|
|
27
|
-
out_channels=out_channels,
|
|
28
|
-
kernel_size=kernel_size,
|
|
29
|
-
stride=stride,
|
|
30
|
-
padding=padding,
|
|
31
|
-
dilation=dilation,
|
|
32
|
-
groups=in_channels,
|
|
33
|
-
bias=bias,
|
|
34
|
-
padding_mode=padding_mode,
|
|
35
|
-
)
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
class _InceptionBlock(nn.Module):
|
|
39
|
-
def __init__(self, branches):
|
|
40
|
-
super().__init__()
|
|
41
|
-
self.branches = nn.ModuleList(branches)
|
|
42
|
-
|
|
43
|
-
def forward(self, x):
|
|
44
|
-
return torch.cat([branch(x) for branch in self.branches], 1)
|
|
45
|
-
|
|
7
|
+
from braindecode.models.base import EEGModuleMixin
|
|
8
|
+
from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
|
|
46
9
|
|
|
47
|
-
class _TCBlock(nn.Module):
|
|
48
|
-
def __init__(self, in_ch, kernel_length, dialation, padding, drop_prob=0.4):
|
|
49
|
-
super().__init__()
|
|
50
|
-
self.pad = padding
|
|
51
|
-
self.tc1 = nn.Sequential(
|
|
52
|
-
_DepthwiseConv2d(
|
|
53
|
-
in_ch,
|
|
54
|
-
kernel_size=(1, kernel_length),
|
|
55
|
-
depth_multiplier=1,
|
|
56
|
-
dilation=(1, dialation),
|
|
57
|
-
bias=False,
|
|
58
|
-
padding="valid",
|
|
59
|
-
),
|
|
60
|
-
nn.BatchNorm2d(in_ch),
|
|
61
|
-
nn.ELU(),
|
|
62
|
-
nn.Dropout(drop_prob),
|
|
63
|
-
)
|
|
64
|
-
|
|
65
|
-
self.tc2 = nn.Sequential(
|
|
66
|
-
_DepthwiseConv2d(
|
|
67
|
-
in_ch,
|
|
68
|
-
kernel_size=(1, kernel_length),
|
|
69
|
-
depth_multiplier=1,
|
|
70
|
-
dilation=(1, dialation),
|
|
71
|
-
bias=False,
|
|
72
|
-
padding="valid",
|
|
73
|
-
),
|
|
74
|
-
nn.BatchNorm2d(in_ch),
|
|
75
|
-
nn.ELU(),
|
|
76
|
-
nn.Dropout(drop_prob),
|
|
77
|
-
)
|
|
78
10
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)
|
|
82
|
-
x = nn.functional.pad(x, paddings)
|
|
83
|
-
x = self.tc1(x)
|
|
84
|
-
x = nn.functional.pad(x, paddings)
|
|
85
|
-
x = self.tc2(x) + residual
|
|
86
|
-
return x
|
|
11
|
+
class EEGITNet(EEGModuleMixin, nn.Sequential):
|
|
12
|
+
"""EEG-ITNet from Salami, et al (2022) [Salami2022]_
|
|
87
13
|
|
|
14
|
+
.. figure:: https://braindecode.org/dev/_static/model/eegitnet.jpg
|
|
15
|
+
:align: center
|
|
16
|
+
:alt: EEG-ITNet Architecture
|
|
88
17
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
Salami et. al 2022.
|
|
18
|
+
EEG-ITNet: An Explainable Inception Temporal
|
|
19
|
+
Convolutional Network for motor imagery classification from
|
|
20
|
+
Salami et al. 2022.
|
|
93
21
|
|
|
94
22
|
See [Salami2022]_ for details.
|
|
95
23
|
|
|
@@ -99,45 +27,62 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
|
|
|
99
27
|
----------
|
|
100
28
|
drop_prob: float
|
|
101
29
|
Dropout probability.
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
30
|
+
activation: nn.Module, default=nn.ELU
|
|
31
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
32
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
|
|
33
|
+
kernel_length : int, optional
|
|
34
|
+
Kernel length for inception branches. Determines the temporal receptive field.
|
|
35
|
+
Default is 16.
|
|
36
|
+
pool_kernel : int, optional
|
|
37
|
+
Pooling kernel size for the average pooling layer. Default is 4.
|
|
38
|
+
tcn_in_channel : int, optional
|
|
39
|
+
Number of input channels for Temporal Convolutional (TC) blocks. Default is 14.
|
|
40
|
+
tcn_kernel_size : int, optional
|
|
41
|
+
Kernel size for the TC blocks. Determines the temporal receptive field.
|
|
42
|
+
Default is 4.
|
|
43
|
+
tcn_padding : int, optional
|
|
44
|
+
Padding size for the TC blocks to maintain the input dimensions. Default is 3.
|
|
45
|
+
drop_prob : float, optional
|
|
46
|
+
Dropout probability applied after certain layers to prevent overfitting.
|
|
47
|
+
Default is 0.4.
|
|
48
|
+
tcn_dilatation : int, optional
|
|
49
|
+
Dilation rate for the first TC block. Subsequent blocks will have
|
|
50
|
+
dilation rates multiplied by powers of 2. Default is 1.
|
|
114
51
|
|
|
115
52
|
Notes
|
|
116
53
|
-----
|
|
117
54
|
This implementation is not guaranteed to be correct, has not been checked
|
|
118
55
|
by original authors, only reimplemented from the paper based on author implementation.
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
References
|
|
59
|
+
----------
|
|
60
|
+
.. [Salami2022] A. Salami, J. Andreu-Perez and H. Gillmeister, "EEG-ITNet:
|
|
61
|
+
An Explainable Inception Temporal Convolutional Network for motor
|
|
62
|
+
imagery classification," in IEEE Access,
|
|
63
|
+
doi: 10.1109/ACCESS.2022.3161489.
|
|
119
64
|
"""
|
|
120
65
|
|
|
121
66
|
def __init__(
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
67
|
+
self,
|
|
68
|
+
# Braindecode parameters
|
|
69
|
+
n_outputs=None,
|
|
70
|
+
n_chans=None,
|
|
71
|
+
n_times=None,
|
|
72
|
+
chs_info=None,
|
|
73
|
+
input_window_seconds=None,
|
|
74
|
+
sfreq=None,
|
|
75
|
+
# Model parameters
|
|
76
|
+
n_filters_time: int = 2,
|
|
77
|
+
kernel_length: int = 16,
|
|
78
|
+
pool_kernel: int = 4,
|
|
79
|
+
tcn_in_channel: int = 14,
|
|
80
|
+
tcn_kernel_size: int = 4,
|
|
81
|
+
tcn_padding: int = 3,
|
|
82
|
+
drop_prob: float = 0.4,
|
|
83
|
+
tcn_dilatation: int = 1,
|
|
84
|
+
activation: nn.Module = nn.ELU,
|
|
134
85
|
):
|
|
135
|
-
n_outputs, n_chans, n_times, = deprecated_args(
|
|
136
|
-
self,
|
|
137
|
-
('n_classes', 'n_outputs', n_classes, n_outputs),
|
|
138
|
-
('in_channels', 'n_chans', in_channels, n_chans),
|
|
139
|
-
('input_window_samples', 'n_times', input_window_samples, n_times),
|
|
140
|
-
)
|
|
141
86
|
super().__init__(
|
|
142
87
|
n_outputs=n_outputs,
|
|
143
88
|
n_chans=n_chans,
|
|
@@ -145,88 +90,131 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
|
|
|
145
90
|
n_times=n_times,
|
|
146
91
|
input_window_seconds=input_window_seconds,
|
|
147
92
|
sfreq=sfreq,
|
|
148
|
-
add_log_softmax=add_log_softmax,
|
|
149
93
|
)
|
|
150
94
|
self.mapping = {
|
|
151
|
-
|
|
152
|
-
|
|
95
|
+
"classification.1.weight": "final_layer.clf.weight",
|
|
96
|
+
"classification.1.bias": "final_layer.clf.weight",
|
|
97
|
+
}
|
|
153
98
|
|
|
154
99
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
155
|
-
del n_classes, in_channels, input_window_samples
|
|
156
100
|
|
|
157
101
|
# ======== Handling EEG input ========================
|
|
158
102
|
self.add_module(
|
|
159
|
-
"input_preprocess",
|
|
160
|
-
|
|
161
|
-
"ba ch t 1 -> ba 1 ch t"))
|
|
103
|
+
"input_preprocess",
|
|
104
|
+
nn.Sequential(Ensure4d(), Rearrange("ba ch t 1 -> ba 1 ch t")),
|
|
162
105
|
)
|
|
163
106
|
# ======== Inception branches ========================
|
|
164
107
|
block11 = self._get_inception_branch(
|
|
165
|
-
in_channels=self.n_chans,
|
|
108
|
+
in_channels=self.n_chans,
|
|
109
|
+
out_channels=n_filters_time,
|
|
110
|
+
kernel_length=kernel_length,
|
|
111
|
+
activation=activation,
|
|
166
112
|
)
|
|
167
113
|
block12 = self._get_inception_branch(
|
|
168
|
-
in_channels=self.n_chans,
|
|
114
|
+
in_channels=self.n_chans,
|
|
115
|
+
out_channels=n_filters_time * 2,
|
|
116
|
+
kernel_length=kernel_length * 2,
|
|
117
|
+
activation=activation,
|
|
169
118
|
)
|
|
170
119
|
block13 = self._get_inception_branch(
|
|
171
|
-
in_channels=self.n_chans,
|
|
120
|
+
in_channels=self.n_chans,
|
|
121
|
+
out_channels=n_filters_time * 4,
|
|
122
|
+
kernel_length=n_filters_time * 4,
|
|
123
|
+
activation=activation,
|
|
124
|
+
)
|
|
125
|
+
self.add_module("inception_block", InceptionBlock((block11, block12, block13)))
|
|
126
|
+
self.pool1 = self.add_module(
|
|
127
|
+
"pooling",
|
|
128
|
+
nn.Sequential(
|
|
129
|
+
nn.AvgPool2d(kernel_size=(1, pool_kernel)), nn.Dropout(drop_prob)
|
|
130
|
+
),
|
|
172
131
|
)
|
|
173
|
-
self.add_module("inception_block", _InceptionBlock((block11, block12, block13)))
|
|
174
|
-
self.pool1 = self.add_module("pooling", nn.Sequential(
|
|
175
|
-
nn.AvgPool2d(kernel_size=(1, 4)),
|
|
176
|
-
nn.Dropout(drop_prob)))
|
|
177
132
|
# =========== TC blocks =====================
|
|
178
133
|
self.add_module(
|
|
179
134
|
"TC_block1",
|
|
180
|
-
_TCBlock(
|
|
135
|
+
_TCBlock(
|
|
136
|
+
in_ch=tcn_in_channel,
|
|
137
|
+
kernel_length=tcn_kernel_size,
|
|
138
|
+
dilatation=tcn_dilatation,
|
|
139
|
+
padding=tcn_padding,
|
|
140
|
+
drop_prob=drop_prob,
|
|
141
|
+
activation=activation,
|
|
142
|
+
),
|
|
181
143
|
)
|
|
182
144
|
# ================================
|
|
183
145
|
self.add_module(
|
|
184
146
|
"TC_block2",
|
|
185
|
-
_TCBlock(
|
|
147
|
+
_TCBlock(
|
|
148
|
+
in_ch=tcn_in_channel,
|
|
149
|
+
kernel_length=tcn_kernel_size,
|
|
150
|
+
dilatation=tcn_dilatation * 2,
|
|
151
|
+
padding=tcn_padding * 2,
|
|
152
|
+
drop_prob=drop_prob,
|
|
153
|
+
activation=activation,
|
|
154
|
+
),
|
|
186
155
|
)
|
|
187
156
|
# ================================
|
|
188
157
|
self.add_module(
|
|
189
158
|
"TC_block3",
|
|
190
|
-
_TCBlock(
|
|
159
|
+
_TCBlock(
|
|
160
|
+
in_ch=tcn_in_channel,
|
|
161
|
+
kernel_length=tcn_kernel_size,
|
|
162
|
+
dilatation=tcn_dilatation * 4,
|
|
163
|
+
padding=tcn_padding * 4,
|
|
164
|
+
drop_prob=drop_prob,
|
|
165
|
+
activation=activation,
|
|
166
|
+
),
|
|
191
167
|
)
|
|
192
168
|
# ================================
|
|
193
169
|
self.add_module(
|
|
194
170
|
"TC_block4",
|
|
195
|
-
_TCBlock(
|
|
171
|
+
_TCBlock(
|
|
172
|
+
in_ch=tcn_in_channel,
|
|
173
|
+
kernel_length=tcn_kernel_size,
|
|
174
|
+
dilatation=tcn_dilatation * 8,
|
|
175
|
+
padding=tcn_padding * 8,
|
|
176
|
+
drop_prob=drop_prob,
|
|
177
|
+
activation=activation,
|
|
178
|
+
),
|
|
196
179
|
)
|
|
197
180
|
|
|
198
181
|
# ============= Dimensionality reduction ===================
|
|
199
|
-
self.add_module(
|
|
200
|
-
|
|
201
|
-
nn.
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
182
|
+
self.add_module(
|
|
183
|
+
"dim_reduction",
|
|
184
|
+
nn.Sequential(
|
|
185
|
+
nn.Conv2d(tcn_in_channel, tcn_in_channel * 2, kernel_size=(1, 1)),
|
|
186
|
+
nn.BatchNorm2d(tcn_in_channel * 2),
|
|
187
|
+
activation(),
|
|
188
|
+
nn.AvgPool2d((1, tcn_kernel_size)),
|
|
189
|
+
nn.Dropout(drop_prob),
|
|
190
|
+
),
|
|
191
|
+
)
|
|
205
192
|
# ============== Classifier ==================
|
|
206
193
|
# Moved flatten to another layer
|
|
207
194
|
self.add_module("flatten", nn.Flatten())
|
|
208
195
|
|
|
209
|
-
|
|
210
|
-
module = nn.Sequential()
|
|
211
|
-
|
|
212
|
-
module.add_module("clf",
|
|
213
|
-
nn.Linear(int(int(self.n_times / 4) / 4) * 28, self.n_outputs))
|
|
196
|
+
num_features = self.get_output_shape()[-1]
|
|
214
197
|
|
|
215
|
-
|
|
216
|
-
module.add_module("out_fun", nn.LogSoftmax(dim=1))
|
|
217
|
-
else:
|
|
218
|
-
module.add_module("out_fun", nn.Identity())
|
|
219
|
-
|
|
220
|
-
self.add_module("final_layer", module)
|
|
198
|
+
self.add_module("final_layer", nn.Linear(num_features, self.n_outputs))
|
|
221
199
|
|
|
222
200
|
@staticmethod
|
|
223
|
-
def _get_inception_branch(
|
|
201
|
+
def _get_inception_branch(
|
|
202
|
+
in_channels,
|
|
203
|
+
out_channels,
|
|
204
|
+
kernel_length,
|
|
205
|
+
depth_multiplier=1,
|
|
206
|
+
activation: nn.Module = nn.ELU,
|
|
207
|
+
):
|
|
224
208
|
return nn.Sequential(
|
|
225
209
|
nn.Conv2d(
|
|
226
|
-
1,
|
|
210
|
+
1,
|
|
211
|
+
out_channels,
|
|
212
|
+
kernel_size=(1, kernel_length),
|
|
213
|
+
padding="same",
|
|
214
|
+
bias=False,
|
|
227
215
|
),
|
|
228
216
|
nn.BatchNorm2d(out_channels),
|
|
229
|
-
|
|
217
|
+
DepthwiseConv2d(
|
|
230
218
|
out_channels,
|
|
231
219
|
kernel_size=(in_channels, 1),
|
|
232
220
|
depth_multiplier=depth_multiplier,
|
|
@@ -234,4 +222,79 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
|
|
|
234
222
|
padding="valid",
|
|
235
223
|
),
|
|
236
224
|
nn.BatchNorm2d(out_channels),
|
|
237
|
-
|
|
225
|
+
activation(),
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class _TCBlock(nn.Module):
|
|
230
|
+
"""
|
|
231
|
+
Temporal Convolutional (TC) block.
|
|
232
|
+
|
|
233
|
+
This module applies two depthwise separable convolutions with dilation and residual
|
|
234
|
+
connections, commonly used in temporal convolutional networks to capture long-range
|
|
235
|
+
dependencies in time-series data.
|
|
236
|
+
|
|
237
|
+
Parameters
|
|
238
|
+
----------
|
|
239
|
+
in_ch : int
|
|
240
|
+
Number of input channels.
|
|
241
|
+
kernel_length : int
|
|
242
|
+
Length of the convolutional kernels.
|
|
243
|
+
dilatation : int
|
|
244
|
+
Dilatation rate for the convolutions.
|
|
245
|
+
padding : int
|
|
246
|
+
Amount of padding to add to the input.
|
|
247
|
+
drop_prob : float, optional
|
|
248
|
+
Dropout probability. Default is 0.4.
|
|
249
|
+
activation : nn.Module class, optional
|
|
250
|
+
Activation function class to use. Should be a PyTorch activation module class
|
|
251
|
+
like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
|
|
252
|
+
"""
|
|
253
|
+
|
|
254
|
+
def __init__(
|
|
255
|
+
self,
|
|
256
|
+
in_ch,
|
|
257
|
+
kernel_length,
|
|
258
|
+
dilatation,
|
|
259
|
+
padding,
|
|
260
|
+
drop_prob=0.4,
|
|
261
|
+
activation: nn.Module = nn.ELU,
|
|
262
|
+
):
|
|
263
|
+
super().__init__()
|
|
264
|
+
self.pad = padding
|
|
265
|
+
self.tc1 = nn.Sequential(
|
|
266
|
+
DepthwiseConv2d(
|
|
267
|
+
in_ch,
|
|
268
|
+
kernel_size=(1, kernel_length),
|
|
269
|
+
depth_multiplier=1,
|
|
270
|
+
dilation=(1, dilatation),
|
|
271
|
+
bias=False,
|
|
272
|
+
padding="valid",
|
|
273
|
+
),
|
|
274
|
+
nn.BatchNorm2d(in_ch),
|
|
275
|
+
activation(),
|
|
276
|
+
nn.Dropout(drop_prob),
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
self.tc2 = nn.Sequential(
|
|
280
|
+
DepthwiseConv2d(
|
|
281
|
+
in_ch,
|
|
282
|
+
kernel_size=(1, kernel_length),
|
|
283
|
+
depth_multiplier=1,
|
|
284
|
+
dilation=(1, dilatation),
|
|
285
|
+
bias=False,
|
|
286
|
+
padding="valid",
|
|
287
|
+
),
|
|
288
|
+
nn.BatchNorm2d(in_ch),
|
|
289
|
+
activation(),
|
|
290
|
+
nn.Dropout(drop_prob),
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
def forward(self, x):
|
|
294
|
+
residual = x
|
|
295
|
+
paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)
|
|
296
|
+
x = nn.functional.pad(x, paddings)
|
|
297
|
+
x = self.tc1(x)
|
|
298
|
+
x = nn.functional.pad(x, paddings)
|
|
299
|
+
x = self.tc2(x) + residual
|
|
300
|
+
return x
|