braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/tcn.py
@@ -0,0 +1,273 @@
+# Authors: Patryk Chrabaszcz
+#          Lukas Gemein <l.gemein@gmail.com>
+#
+# License: BSD-3
+import torch
+from torch import nn
+from torch.nn import init
+from torch.nn.utils.parametrizations import weight_norm
+
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import Chomp1d, Ensure4d, Expression, SqueezeFinalOutput
+
+
+class BDTCN(EEGModuleMixin, nn.Module):
+    """Braindecode TCN from Gemein, L et al (2020) [gemein2020]_.
+
+    .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S1053811920305073-gr3_lrg.jpg
+       :align: center
+       :alt: Braindecode TCN Architecture
+
+    See [gemein2020]_ for details.
+
+    Parameters
+    ----------
+    n_filters: int
+        number of output filters of each convolution
+    n_blocks: int
+        number of temporal blocks in the network
+    kernel_size: int
+        kernel size of the convolutions
+    drop_prob: float
+        dropout probability
+    activation: nn.Module, default=nn.ReLU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+
+    References
+    ----------
+    .. [gemein2020] Gemein, L. A., Schirrmeister, R. T., Chrabąszcz, P., Wilson, D.,
+       Boedecker, J., Schulze-Bonhage, A., ... & Ball, T. (2020). Machine-learning-based
+       diagnostics of EEG pathology. NeuroImage, 220, 117021.
+    """
+
+    def __init__(
+        self,
+        # Braindecode parameters
+        n_chans=None,
+        n_outputs=None,
+        chs_info=None,
+        n_times=None,
+        sfreq=None,
+        input_window_seconds=None,
+        # Model's parameters
+        n_blocks=3,
+        n_filters=30,
+        kernel_size=5,
+        drop_prob=0.5,
+        activation: nn.Module = nn.ReLU,
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
+
+        self.base_tcn = TCN(
+            n_chans=self.n_chans,
+            n_outputs=self.n_outputs,
+            n_blocks=n_blocks,
+            n_filters=n_filters,
+            kernel_size=kernel_size,
+            drop_prob=drop_prob,
+            activation=activation,
+        )
+
+        self.final_layer = torch.nn.Sequential(
+            torch.nn.AdaptiveAvgPool1d(1), torch.nn.Flatten()
+        )
+
+    def forward(self, x):
+        x = self.base_tcn(x)
+        x = self.final_layer(x)
+        return x
+
+
+class TCN(nn.Module):
+    """Temporal Convolutional Network (TCN) from Bai et al. 2018 [Bai2018]_.
+
+    See [Bai2018]_ for details.
+
+    Code adapted from https://github.com/locuslab/TCN/blob/master/TCN/tcn.py
+
+    Parameters
+    ----------
+    n_filters: int
+        number of output filters of each convolution
+    n_blocks: int
+        number of temporal blocks in the network
+    kernel_size: int
+        kernel size of the convolutions
+    drop_prob: float
+        dropout probability
+    activation: nn.Module, default=nn.ReLU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+
+    References
+    ----------
+    .. [Bai2018] Bai, S., Kolter, J. Z., & Koltun, V. (2018).
+       An empirical evaluation of generic convolutional and recurrent networks
+       for sequence modeling.
+       arXiv preprint arXiv:1803.01271.
+    """
+
+    def __init__(
+        self,
+        n_chans=None,
+        n_outputs=None,
+        n_blocks=3,
+        n_filters=30,
+        kernel_size=5,
+        drop_prob=0.5,
+        activation: nn.Module = nn.ReLU,
+    ):
+        super().__init__()
+        self.mapping = {
+            "fc.weight": "final_layer.fc.weight",
+            "fc.bias": "final_layer.fc.bias",
+        }
+        self.ensuredims = Ensure4d()
+        t_blocks = nn.Sequential()
+        for i in range(n_blocks):
+            n_inputs = n_chans if i == 0 else n_filters
+            dilation_size = 2**i
+            t_blocks.add_module(
+                "temporal_block_{:d}".format(i),
+                _TemporalBlock(
+                    n_inputs=n_inputs,
+                    n_outputs=n_filters,
+                    kernel_size=kernel_size,
+                    stride=1,
+                    dilation=dilation_size,
+                    padding=(kernel_size - 1) * dilation_size,
+                    drop_prob=drop_prob,
+                    activation=activation,
+                ),
+            )
+        self.temporal_blocks = t_blocks
+
+        self.final_layer = _FinalLayer(
+            in_features=n_filters,
+            out_features=n_outputs,
+        )
+        self.min_len = 1
+        for i in range(n_blocks):
+            dilation = 2**i
+            self.min_len += 2 * (kernel_size - 1) * dilation
+
+        # start in eval mode
+        self.train()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            Batch of EEG windows of shape (batch_size, n_channels, n_times).
+        """
+        x = self.ensuredims(x)
+        # x is in format: B x C x T x 1
+        (batch_size, _, time_size, _) = x.size()
+        assert time_size >= self.min_len
+        # remove empty trailing dimension
+        x = x.squeeze(3)
+        x = self.temporal_blocks(x)
+        # Convert to: B x T x C
+        x = x.transpose(1, 2).contiguous()
+
+        out = self.final_layer(x, batch_size, time_size, self.min_len)
+
+        return out
+
+
+class _FinalLayer(nn.Module):
+    def __init__(self, in_features, out_features):
+        super().__init__()
+
+        self.fc = nn.Linear(in_features=in_features, out_features=out_features)
+
+        self.out_fun = nn.Identity()
+
+        self.squeeze = SqueezeFinalOutput()
+
+    def forward(
+        self, x: torch.Tensor, batch_size: int, time_size: int, min_len: int
+    ) -> torch.Tensor:
+        fc_out = self.fc(x.view(batch_size * time_size, x.size(2)))
+        fc_out = self.out_fun(fc_out)
+        fc_out = fc_out.view(batch_size, time_size, fc_out.size(1))
+
+        out_size = 1 + max(0, time_size - min_len)
+        out = fc_out[:, -out_size:, :].transpose(1, 2)
+        # re-add 4th dimension for compatibility with braindecode
+        return self.squeeze(out[:, :, :, None])
+
+
+class _TemporalBlock(nn.Module):
+    def __init__(
+        self,
+        n_inputs,
+        n_outputs,
+        kernel_size,
+        stride,
+        dilation,
+        padding,
+        drop_prob,
+        activation: nn.Module = nn.ReLU,
+    ):
+        super().__init__()
+        self.conv1 = weight_norm(
+            nn.Conv1d(
+                n_inputs,
+                n_outputs,
+                kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+            )
+        )
+        self.chomp1 = Chomp1d(padding)
+        self.relu1 = activation()
+        self.dropout1 = nn.Dropout2d(drop_prob)
+
+        self.conv2 = weight_norm(
+            nn.Conv1d(
+                n_outputs,
+                n_outputs,
+                kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+            )
+        )
+        self.chomp2 = Chomp1d(padding)
+        self.relu2 = activation()
+        self.dropout2 = nn.Dropout2d(drop_prob)
+
+        self.downsample = (
+            nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
+        )
+        self.relu = activation()
+
+        init.normal_(self.conv1.weight, 0, 0.01)
+        init.normal_(self.conv2.weight, 0, 0.01)
+        if self.downsample is not None:
+            init.normal_(self.downsample.weight, 0, 0.01)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.chomp1(out)
+        out = self.relu1(out)
+        out = self.dropout1(out)
+        out = self.conv2(out)
+        out = self.chomp2(out)
+        out = self.relu2(out)
+        out = self.dropout2(out)
+        res = x if self.downsample is None else self.downsample(x)
+        return self.relu(out + res)
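
For orientation, a minimal usage sketch of the new BDTCN model follows. It is not part of the diff and assumes braindecode 1.0.0 exports BDTCN from braindecode.models (the updated braindecode/models/__init__.py above suggests this, but the export name is an assumption). TCN.forward asserts time_size >= min_len; with the defaults n_blocks=3 and kernel_size=5, min_len = 1 + 2*(5-1)*(1 + 2 + 4) = 57 samples.

import torch
from braindecode.models import BDTCN  # assumed export in 1.0.0

# 8 windows of 21-channel EEG, 600 samples each (>= min_len of 57)
model = BDTCN(n_chans=21, n_outputs=2, n_times=600)
x = torch.randn(8, 21, 600)  # (batch_size, n_channels, n_times)
y = model(x)  # TCN features, then AdaptiveAvgPool1d(1) + Flatten
print(y.shape)  # torch.Size([8, 2])
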
braindecode/models/tidnet.py
@@ -0,0 +1,395 @@
+from math import ceil
+
+import torch
+from einops.layers.torch import Rearrange
+from torch import nn
+from torch.nn import init
+from torch.nn.utils.parametrizations import weight_norm
+
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import Ensure4d
+
+
+class TIDNet(EEGModuleMixin, nn.Module):
+    """Thinker Invariance DenseNet model from Kostas et al. (2020) [TIDNet]_.
+
+    .. figure:: https://content.cld.iop.org/journals/1741-2552/17/5/056008/revision3/jneabb7a7f1_hr.jpg
+       :align: center
+       :alt: TIDNet Architecture
+
+    See [TIDNet]_ for details.
+
+    Parameters
+    ----------
+    s_growth : int
+        DenseNet-style growth factor (added filters per DenseFilter)
+    t_filters : int
+        Number of temporal filters.
+    drop_prob : float
+        Dropout probability
+    pooling : int
+        Max temporal pooling (width and stride)
+    temp_layers : int
+        Number of temporal layers
+    spat_layers : int
+        Number of DenseFilters
+    temp_span : float
+        Percentage of n_times that defines the temporal filter length:
+        temp_len = ceil(temp_span * n_times)
+        e.g A value of 0.05 for temp_span with 1500 n_times will yield a temporal
+        filter of length 75.
+    bottleneck : int
+        Bottleneck factor within Densefilter
+    summary : int
+        Output size of AdaptiveAvgPool1D layer. If set to -1, value will be calculated
+        automatically (n_times // pooling).
+    in_chans :
+        Alias for n_chans.
+    n_classes:
+        Alias for n_outputs.
+    input_window_samples :
+        Alias for n_times.
+    activation: nn.Module, default=nn.LeakyReLU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.LeakyReLU``.
+
+    Notes
+    -----
+    Code adapted from: https://github.com/SPOClab-ca/ThinkerInvariance/
+
+    References
+    ----------
+    .. [TIDNet] Kostas, D. & Rudzicz, F.
+       Thinker invariance: enabling deep neural networks for BCI across more
+       people.
+       J. Neural Eng. 17, 056008 (2020).
+       doi: 10.1088/1741-2552/abb7a7.
+    """
+
+    def __init__(
+        self,
+        n_chans=None,
+        n_outputs=None,
+        n_times=None,
+        input_window_seconds=None,
+        sfreq=None,
+        chs_info=None,
+        s_growth: int = 24,
+        t_filters: int = 32,
+        drop_prob: float = 0.4,
+        pooling: int = 15,
+        temp_layers: int = 2,
+        spat_layers: int = 2,
+        temp_span: float = 0.05,
+        bottleneck: int = 3,
+        summary: int = -1,
+        activation: nn.Module = nn.LeakyReLU,
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+            chs_info=chs_info,
+        )
+        del n_outputs, n_chans, n_times, input_window_seconds, sfreq, chs_info
+
+        self.temp_len = ceil(temp_span * self.n_times)
+
+        self.dscnn = _TIDNetFeatures(
+            s_growth=s_growth,
+            t_filters=t_filters,
+            n_chans=self.n_chans,
+            n_times=self.n_times,
+            drop_prob=drop_prob,
+            pooling=pooling,
+            temp_layers=temp_layers,
+            spat_layers=spat_layers,
+            temp_span=temp_span,
+            bottleneck=bottleneck,
+            summary=summary,
+            activation=activation,
+        )
+
+        self._num_features = self.dscnn.num_features
+
+        self.flatten = nn.Flatten(start_dim=1)
+
+        self.final_layer = self._create_classifier(self.num_features, self.n_outputs)
+
+    def _create_classifier(self, incoming: int, n_outputs: int):
+        classifier = nn.Linear(incoming, n_outputs)
+        init.xavier_normal_(classifier.weight)
+        classifier.bias.data.zero_()
+        seq_clf = nn.Sequential(classifier, nn.Identity())
+
+        return seq_clf
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            Batch of EEG windows of shape (batch_size, n_channels, n_times).
+        """
+
+        x = self.dscnn(x)
+        x = self.flatten(x)
+        return self.final_layer(x)
+
+    @property
+    def num_features(self):
+        return self._num_features
+
+
+class _BatchNormZG(nn.BatchNorm2d):
+    def reset_parameters(self):
+        if self.track_running_stats:
+            self.running_mean.zero_()
+            self.running_var.fill_(1)
+        if self.affine:
+            self.weight.data.zero_()
+            self.bias.data.zero_()
+
+
+class _ConvBlock2D(nn.Module):
+    """Implements Convolution block with order:
+    Convolution, dropout, activation, batch-norm
+    """
+
+    def __init__(
+        self,
+        in_filters: int,
+        out_filters: int,
+        kernel: tuple[int, int],
+        stride: tuple[int, int] = (1, 1),
+        padding: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        drop_prob: float = 0.5,
+        batch_norm: bool = True,
+        activation: type[nn.Module] = nn.LeakyReLU,
+        residual: bool = False,
+    ):
+        super().__init__()
+        self.kernel = kernel
+        self.activation = activation()
+        self.residual = residual
+
+        self.conv = nn.Conv2d(
+            in_filters,
+            out_filters,
+            kernel,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=not batch_norm,
+        )
+        self.dropout = nn.Dropout2d(p=float(drop_prob))
+        self.batch_norm = (
+            _BatchNormZG(out_filters)
+            if residual
+            else nn.BatchNorm2d(out_filters)
+            if batch_norm
+            else nn.Identity()
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        res = input
+        input = self.conv(
+            input,
+        )
+        input = self.dropout(input)
+        input = self.activation(input)
+        input = self.batch_norm(input)
+        return input + res if self.residual else input
+
+
+class _DenseFilter(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        growth_rate: int,
+        filter_len: int = 5,
+        drop_prob: float = 0.5,
+        bottleneck: int = 2,
+        activation: type[nn.Module] = nn.LeakyReLU,
+        dim: int = -2,
+    ):
+        super().__init__()
+        dim = dim if dim > 0 else dim + 4
+        if dim < 2 or dim > 3:
+            raise ValueError("Only last two dimensions supported")
+        kernel = (filter_len, 1) if dim == 2 else (1, filter_len)
+
+        self.net = nn.Sequential(
+            nn.BatchNorm2d(in_features),
+            activation(),
+            nn.Conv2d(in_features, bottleneck * growth_rate, 1),
+            nn.BatchNorm2d(bottleneck * growth_rate),
+            activation(),
+            nn.Conv2d(
+                bottleneck * growth_rate,
+                growth_rate,
+                kernel,
+                padding=tuple((k // 2 for k in kernel)),
+            ),
+            nn.Dropout2d(p=float(drop_prob)),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.cat((x, self.net(x)), dim=1)
+
+
+class _DenseSpatialFilter(nn.Module):
+    def __init__(
+        self,
+        n_chans: int,
+        growth: int,
+        depth: int,
+        in_ch: int = 1,
+        bottleneck: int = 4,
+        drop_prob: float = 0.0,
+        activation: type[nn.Module] = nn.LeakyReLU,
+        collapse: bool = True,
+    ):
+        super().__init__()
+        self.net = nn.Sequential(
+            *[
+                _DenseFilter(
+                    in_ch + growth * d,
+                    growth,
+                    bottleneck=bottleneck,
+                    drop_prob=drop_prob,
+                    activation=activation,
+                )
+                for d in range(depth)
+            ]
+        )
+        n_filters = in_ch + growth * depth
+        self.collapse = collapse
+        if collapse:
+            self.channel_collapse = _ConvBlock2D(
+                n_filters, n_filters, (n_chans, 1), drop_prob=0, activation=activation
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if len(x.shape) < 4:
+            x = x.unsqueeze(1).permute([0, 1, 3, 2])
+        x = self.net(x)
+        if self.collapse:
+            return self.channel_collapse(x).squeeze(-2)
+        return x
+
+
+class _TemporalFilter(nn.Module):
+    def __init__(
+        self,
+        n_chans: int,
+        filters: int,
+        depth: int,
+        temp_len: int,
+        drop_prob: float = 0.0,
+        activation: type[nn.Module] = nn.LeakyReLU,
+        residual: str = "netwise",
+    ):
+        super().__init__()
+        temp_len = temp_len + 1 - temp_len % 2
+        self.residual_style = str(residual)
+        net = list()
+
+        for i in range(depth):
+            dil = depth - i
+            conv = weight_norm(
+                nn.Conv2d(
+                    n_chans if i == 0 else filters,
+                    filters,
+                    kernel_size=(1, temp_len),
+                    dilation=dil,
+                    padding=(0, dil * (temp_len - 1) // 2),
+                )
+            )
+            net.append(
+                nn.Sequential(conv, activation(), nn.Dropout2d(p=float(drop_prob)))
+            )
+        if self.residual_style.lower() == "netwise":
+            self.net = nn.Sequential(*net)
+            self.residual = nn.Conv2d(n_chans, filters, (1, 1))
+        elif residual.lower() == "dense":
+            self.net = net
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        style = self.residual_style.lower()
+        if style == "netwise":
+            return self.net(x) + self.residual(x)
+        elif style == "dense":
+            for layer in self.net:
+                x = torch.cat((x, layer(x)), dim=1)
+            return x
+        # TorchScript now knows this path always returns or errors
+        else:
+            # Use an assertion so TorchScript can compile it
+            assert False, f"Unsupported residual style: {self.residual_style}"
+
+
+class _TIDNetFeatures(nn.Module):
+    def __init__(
+        self,
+        s_growth: int,
+        t_filters: int,
+        n_chans: int,
+        n_times: int,
+        drop_prob: float,
+        pooling: int,
+        temp_layers: int,
+        spat_layers: int,
+        temp_span: float,
+        bottleneck: int,
+        summary: int,
+        activation: type[nn.Module] = nn.LeakyReLU,
+    ):
+        super().__init__()
+        self.n_chans = n_chans
+        self.temp_len = ceil(temp_span * n_times)
+
+        self.temporal = nn.Sequential(
+            Ensure4d(),
+            Rearrange("batch C T 1 -> batch 1 C T"),
+            _TemporalFilter(
+                1,
+                t_filters,
+                depth=temp_layers,
+                temp_len=self.temp_len,
+                activation=activation,
+            ),
+            nn.MaxPool2d((1, pooling)),
+            nn.Dropout2d(p=float(drop_prob)),
+        )
+        summary = n_times // pooling if summary == -1 else summary
+
+        self.spatial = _DenseSpatialFilter(
+            n_chans=n_chans,
+            growth=s_growth,
+            depth=spat_layers,
+            in_ch=t_filters,
+            drop_prob=drop_prob,
+            bottleneck=bottleneck,
+            activation=activation,
+        )
+        self.extract_features = nn.Sequential(
+            nn.AdaptiveAvgPool1d(int(summary)), nn.Flatten(start_dim=1)
+        )
+
+        self._num_features = (t_filters + s_growth * spat_layers) * summary
+
+    @property
+    def num_features(self):
+        return self._num_features
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.temporal(x)
+        x = self.spatial(x)
+        return self.extract_features(x)
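
Likewise, a minimal usage sketch of the new TIDNet model (again illustrative, not part of the diff, and assuming the braindecode.models export). With the defaults, temp_span=0.05 and n_times=600 give a temporal filter of ceil(0.05 * 600) = 30 samples (forced odd, i.e. 31, inside _TemporalFilter), and summary=-1 resolves to n_times // pooling = 600 // 15 = 40 pooled features per filter.

import torch
from braindecode.models import TIDNet  # assumed export in 1.0.0

model = TIDNet(n_chans=21, n_outputs=4, n_times=600)
x = torch.randn(8, 21, 600)  # (batch_size, n_channels, n_times)
y = model(x)  # dense spatio-temporal features, then linear classifier
print(y.shape)  # torch.Size([8, 4])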