braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/tidnet.py
CHANGED
|
@@ -1,13 +1,147 @@
|
|
|
1
1
|
from math import ceil
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from einops.layers.torch import Rearrange
|
|
4
5
|
from torch import nn
|
|
5
6
|
from torch.nn import init
|
|
6
|
-
from torch.nn.utils import weight_norm
|
|
7
|
-
|
|
7
|
+
from torch.nn.utils.parametrizations import weight_norm
|
|
8
|
+
|
|
9
|
+
from braindecode.models.base import EEGModuleMixin
|
|
10
|
+
from braindecode.modules import Ensure4d
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TIDNet(EEGModuleMixin, nn.Module):
|
|
14
|
+
"""Thinker Invariance DenseNet model from Kostas et al. (2020) [TIDNet]_.
|
|
15
|
+
|
|
16
|
+
.. figure:: https://content.cld.iop.org/journals/1741-2552/17/5/056008/revision3/jneabb7a7f1_hr.jpg
|
|
17
|
+
:align: center
|
|
18
|
+
:alt: TIDNet Architecture
|
|
19
|
+
|
|
20
|
+
See [TIDNet]_ for details.
|
|
21
|
+
|
|
22
|
+
Parameters
|
|
23
|
+
----------
|
|
24
|
+
s_growth : int
|
|
25
|
+
DenseNet-style growth factor (added filters per DenseFilter)
|
|
26
|
+
t_filters : int
|
|
27
|
+
Number of temporal filters.
|
|
28
|
+
drop_prob : float
|
|
29
|
+
Dropout probability
|
|
30
|
+
pooling : int
|
|
31
|
+
Max temporal pooling (width and stride)
|
|
32
|
+
temp_layers : int
|
|
33
|
+
Number of temporal layers
|
|
34
|
+
spat_layers : int
|
|
35
|
+
Number of DenseFilters
|
|
36
|
+
temp_span : float
|
|
37
|
+
Percentage of n_times that defines the temporal filter length:
|
|
38
|
+
temp_len = ceil(temp_span * n_times)
|
|
39
|
+
e.g A value of 0.05 for temp_span with 1500 n_times will yield a temporal
|
|
40
|
+
filter of length 75.
|
|
41
|
+
bottleneck : int
|
|
42
|
+
Bottleneck factor within Densefilter
|
|
43
|
+
summary : int
|
|
44
|
+
Output size of AdaptiveAvgPool1D layer. If set to -1, value will be calculated
|
|
45
|
+
automatically (n_times // pooling).
|
|
46
|
+
in_chans :
|
|
47
|
+
Alias for n_chans.
|
|
48
|
+
n_classes:
|
|
49
|
+
Alias for n_outputs.
|
|
50
|
+
input_window_samples :
|
|
51
|
+
Alias for n_times.
|
|
52
|
+
activation: nn.Module, default=nn.LeakyReLU
|
|
53
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
54
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.LeakyReLU``.
|
|
55
|
+
|
|
56
|
+
Notes
|
|
57
|
+
-----
|
|
58
|
+
Code adapted from: https://github.com/SPOClab-ca/ThinkerInvariance/
|
|
59
|
+
|
|
60
|
+
References
|
|
61
|
+
----------
|
|
62
|
+
.. [TIDNet] Kostas, D. & Rudzicz, F.
|
|
63
|
+
Thinker invariance: enabling deep neural networks for BCI across more
|
|
64
|
+
people.
|
|
65
|
+
J. Neural Eng. 17, 056008 (2020).
|
|
66
|
+
doi: 10.1088/1741-2552/abb7a7.
|
|
67
|
+
"""
|
|
68
|
+
|
|
69
|
+
def __init__(
|
|
70
|
+
self,
|
|
71
|
+
n_chans=None,
|
|
72
|
+
n_outputs=None,
|
|
73
|
+
n_times=None,
|
|
74
|
+
input_window_seconds=None,
|
|
75
|
+
sfreq=None,
|
|
76
|
+
chs_info=None,
|
|
77
|
+
s_growth: int = 24,
|
|
78
|
+
t_filters: int = 32,
|
|
79
|
+
drop_prob: float = 0.4,
|
|
80
|
+
pooling: int = 15,
|
|
81
|
+
temp_layers: int = 2,
|
|
82
|
+
spat_layers: int = 2,
|
|
83
|
+
temp_span: float = 0.05,
|
|
84
|
+
bottleneck: int = 3,
|
|
85
|
+
summary: int = -1,
|
|
86
|
+
activation: nn.Module = nn.LeakyReLU,
|
|
87
|
+
):
|
|
88
|
+
super().__init__(
|
|
89
|
+
n_outputs=n_outputs,
|
|
90
|
+
n_chans=n_chans,
|
|
91
|
+
n_times=n_times,
|
|
92
|
+
input_window_seconds=input_window_seconds,
|
|
93
|
+
sfreq=sfreq,
|
|
94
|
+
chs_info=chs_info,
|
|
95
|
+
)
|
|
96
|
+
del n_outputs, n_chans, n_times, input_window_seconds, sfreq, chs_info
|
|
97
|
+
|
|
98
|
+
self.temp_len = ceil(temp_span * self.n_times)
|
|
99
|
+
|
|
100
|
+
self.dscnn = _TIDNetFeatures(
|
|
101
|
+
s_growth=s_growth,
|
|
102
|
+
t_filters=t_filters,
|
|
103
|
+
n_chans=self.n_chans,
|
|
104
|
+
n_times=self.n_times,
|
|
105
|
+
drop_prob=drop_prob,
|
|
106
|
+
pooling=pooling,
|
|
107
|
+
temp_layers=temp_layers,
|
|
108
|
+
spat_layers=spat_layers,
|
|
109
|
+
temp_span=temp_span,
|
|
110
|
+
bottleneck=bottleneck,
|
|
111
|
+
summary=summary,
|
|
112
|
+
activation=activation,
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
self._num_features = self.dscnn.num_features
|
|
116
|
+
|
|
117
|
+
self.flatten = nn.Flatten(start_dim=1)
|
|
118
|
+
|
|
119
|
+
self.final_layer = self._create_classifier(self.num_features, self.n_outputs)
|
|
8
120
|
|
|
9
|
-
|
|
10
|
-
|
|
121
|
+
def _create_classifier(self, incoming: int, n_outputs: int):
|
|
122
|
+
classifier = nn.Linear(incoming, n_outputs)
|
|
123
|
+
init.xavier_normal_(classifier.weight)
|
|
124
|
+
classifier.bias.data.zero_()
|
|
125
|
+
seq_clf = nn.Sequential(classifier, nn.Identity())
|
|
126
|
+
|
|
127
|
+
return seq_clf
|
|
128
|
+
|
|
129
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
130
|
+
"""Forward pass.
|
|
131
|
+
|
|
132
|
+
Parameters
|
|
133
|
+
----------
|
|
134
|
+
x: torch.Tensor
|
|
135
|
+
Batch of EEG windows of shape (batch_size, n_channels, n_times).
|
|
136
|
+
"""
|
|
137
|
+
|
|
138
|
+
x = self.dscnn(x)
|
|
139
|
+
x = self.flatten(x)
|
|
140
|
+
return self.final_layer(x)
|
|
141
|
+
|
|
142
|
+
@property
|
|
143
|
+
def num_features(self):
|
|
144
|
+
return self._num_features
|
|
11
145
|
|
|
12
146
|
|
|
13
147
|
class _BatchNormZG(nn.BatchNorm2d):
|
|
@@ -25,27 +159,49 @@ class _ConvBlock2D(nn.Module):
|
|
|
25
159
|
Convolution, dropout, activation, batch-norm
|
|
26
160
|
"""
|
|
27
161
|
|
|
28
|
-
def __init__(
|
|
29
|
-
|
|
162
|
+
def __init__(
|
|
163
|
+
self,
|
|
164
|
+
in_filters: int,
|
|
165
|
+
out_filters: int,
|
|
166
|
+
kernel: tuple[int, int],
|
|
167
|
+
stride: tuple[int, int] = (1, 1),
|
|
168
|
+
padding: int = 0,
|
|
169
|
+
dilation: int = 1,
|
|
170
|
+
groups: int = 1,
|
|
171
|
+
drop_prob: float = 0.5,
|
|
172
|
+
batch_norm: bool = True,
|
|
173
|
+
activation: type[nn.Module] = nn.LeakyReLU,
|
|
174
|
+
residual: bool = False,
|
|
175
|
+
):
|
|
30
176
|
super().__init__()
|
|
31
177
|
self.kernel = kernel
|
|
32
178
|
self.activation = activation()
|
|
33
179
|
self.residual = residual
|
|
34
180
|
|
|
35
|
-
self.conv = nn.Conv2d(
|
|
36
|
-
|
|
37
|
-
|
|
181
|
+
self.conv = nn.Conv2d(
|
|
182
|
+
in_filters,
|
|
183
|
+
out_filters,
|
|
184
|
+
kernel,
|
|
185
|
+
stride=stride,
|
|
186
|
+
padding=padding,
|
|
187
|
+
dilation=dilation,
|
|
188
|
+
groups=groups,
|
|
189
|
+
bias=not batch_norm,
|
|
190
|
+
)
|
|
191
|
+
self.dropout = nn.Dropout2d(p=float(drop_prob))
|
|
38
192
|
self.batch_norm = (
|
|
39
193
|
_BatchNormZG(out_filters)
|
|
40
194
|
if residual
|
|
41
195
|
else nn.BatchNorm2d(out_filters)
|
|
42
196
|
if batch_norm
|
|
43
|
-
else
|
|
197
|
+
else nn.Identity()
|
|
44
198
|
)
|
|
45
199
|
|
|
46
|
-
def forward(self, input):
|
|
200
|
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
47
201
|
res = input
|
|
48
|
-
input = self.conv(
|
|
202
|
+
input = self.conv(
|
|
203
|
+
input,
|
|
204
|
+
)
|
|
49
205
|
input = self.dropout(input)
|
|
50
206
|
input = self.activation(input)
|
|
51
207
|
input = self.batch_norm(input)
|
|
@@ -53,12 +209,20 @@ class _ConvBlock2D(nn.Module):
|
|
|
53
209
|
|
|
54
210
|
|
|
55
211
|
class _DenseFilter(nn.Module):
|
|
56
|
-
def __init__(
|
|
57
|
-
|
|
212
|
+
def __init__(
|
|
213
|
+
self,
|
|
214
|
+
in_features: int,
|
|
215
|
+
growth_rate: int,
|
|
216
|
+
filter_len: int = 5,
|
|
217
|
+
drop_prob: float = 0.5,
|
|
218
|
+
bottleneck: int = 2,
|
|
219
|
+
activation: type[nn.Module] = nn.LeakyReLU,
|
|
220
|
+
dim: int = -2,
|
|
221
|
+
):
|
|
58
222
|
super().__init__()
|
|
59
223
|
dim = dim if dim > 0 else dim + 4
|
|
60
224
|
if dim < 2 or dim > 3:
|
|
61
|
-
raise ValueError(
|
|
225
|
+
raise ValueError("Only last two dimensions supported")
|
|
62
226
|
kernel = (filter_len, 1) if dim == 2 else (1, filter_len)
|
|
63
227
|
|
|
64
228
|
self.net = nn.Sequential(
|
|
@@ -67,29 +231,52 @@ class _DenseFilter(nn.Module):
|
|
|
67
231
|
nn.Conv2d(in_features, bottleneck * growth_rate, 1),
|
|
68
232
|
nn.BatchNorm2d(bottleneck * growth_rate),
|
|
69
233
|
activation(),
|
|
70
|
-
nn.Conv2d(
|
|
71
|
-
|
|
72
|
-
|
|
234
|
+
nn.Conv2d(
|
|
235
|
+
bottleneck * growth_rate,
|
|
236
|
+
growth_rate,
|
|
237
|
+
kernel,
|
|
238
|
+
padding=tuple((k // 2 for k in kernel)),
|
|
239
|
+
),
|
|
240
|
+
nn.Dropout2d(p=float(drop_prob)),
|
|
73
241
|
)
|
|
74
242
|
|
|
75
|
-
def forward(self, x):
|
|
243
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
76
244
|
return torch.cat((x, self.net(x)), dim=1)
|
|
77
245
|
|
|
78
246
|
|
|
79
247
|
class _DenseSpatialFilter(nn.Module):
|
|
80
|
-
def __init__(
|
|
81
|
-
|
|
248
|
+
def __init__(
|
|
249
|
+
self,
|
|
250
|
+
n_chans: int,
|
|
251
|
+
growth: int,
|
|
252
|
+
depth: int,
|
|
253
|
+
in_ch: int = 1,
|
|
254
|
+
bottleneck: int = 4,
|
|
255
|
+
drop_prob: float = 0.0,
|
|
256
|
+
activation: type[nn.Module] = nn.LeakyReLU,
|
|
257
|
+
collapse: bool = True,
|
|
258
|
+
):
|
|
82
259
|
super().__init__()
|
|
83
|
-
self.net = nn.Sequential(
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
260
|
+
self.net = nn.Sequential(
|
|
261
|
+
*[
|
|
262
|
+
_DenseFilter(
|
|
263
|
+
in_ch + growth * d,
|
|
264
|
+
growth,
|
|
265
|
+
bottleneck=bottleneck,
|
|
266
|
+
drop_prob=drop_prob,
|
|
267
|
+
activation=activation,
|
|
268
|
+
)
|
|
269
|
+
for d in range(depth)
|
|
270
|
+
]
|
|
271
|
+
)
|
|
87
272
|
n_filters = in_ch + growth * depth
|
|
88
273
|
self.collapse = collapse
|
|
89
274
|
if collapse:
|
|
90
|
-
self.channel_collapse = _ConvBlock2D(
|
|
275
|
+
self.channel_collapse = _ConvBlock2D(
|
|
276
|
+
n_filters, n_filters, (n_chans, 1), drop_prob=0, activation=activation
|
|
277
|
+
)
|
|
91
278
|
|
|
92
|
-
def forward(self, x):
|
|
279
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
93
280
|
if len(x.shape) < 4:
|
|
94
281
|
x = x.unsqueeze(1).permute([0, 1, 3, 2])
|
|
95
282
|
x = self.net(x)
|
|
@@ -99,8 +286,16 @@ class _DenseSpatialFilter(nn.Module):
|
|
|
99
286
|
|
|
100
287
|
|
|
101
288
|
class _TemporalFilter(nn.Module):
|
|
102
|
-
def __init__(
|
|
103
|
-
|
|
289
|
+
def __init__(
|
|
290
|
+
self,
|
|
291
|
+
n_chans: int,
|
|
292
|
+
filters: int,
|
|
293
|
+
depth: int,
|
|
294
|
+
temp_len: int,
|
|
295
|
+
drop_prob: float = 0.0,
|
|
296
|
+
activation: type[nn.Module] = nn.LeakyReLU,
|
|
297
|
+
residual: str = "netwise",
|
|
298
|
+
):
|
|
104
299
|
super().__init__()
|
|
105
300
|
temp_len = temp_len + 1 - temp_len % 2
|
|
106
301
|
self.residual_style = str(residual)
|
|
@@ -108,32 +303,54 @@ class _TemporalFilter(nn.Module):
|
|
|
108
303
|
|
|
109
304
|
for i in range(depth):
|
|
110
305
|
dil = depth - i
|
|
111
|
-
conv = weight_norm(
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
306
|
+
conv = weight_norm(
|
|
307
|
+
nn.Conv2d(
|
|
308
|
+
n_chans if i == 0 else filters,
|
|
309
|
+
filters,
|
|
310
|
+
kernel_size=(1, temp_len),
|
|
311
|
+
dilation=dil,
|
|
312
|
+
padding=(0, dil * (temp_len - 1) // 2),
|
|
313
|
+
)
|
|
314
|
+
)
|
|
315
|
+
net.append(
|
|
316
|
+
nn.Sequential(conv, activation(), nn.Dropout2d(p=float(drop_prob)))
|
|
317
|
+
)
|
|
318
|
+
if self.residual_style.lower() == "netwise":
|
|
120
319
|
self.net = nn.Sequential(*net)
|
|
121
320
|
self.residual = nn.Conv2d(n_chans, filters, (1, 1))
|
|
122
|
-
elif residual.lower() ==
|
|
321
|
+
elif residual.lower() == "dense":
|
|
123
322
|
self.net = net
|
|
124
323
|
|
|
125
|
-
def forward(self, x):
|
|
126
|
-
|
|
324
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
325
|
+
style = self.residual_style.lower()
|
|
326
|
+
if style == "netwise":
|
|
127
327
|
return self.net(x) + self.residual(x)
|
|
128
|
-
elif
|
|
328
|
+
elif style == "dense":
|
|
129
329
|
for layer in self.net:
|
|
130
330
|
x = torch.cat((x, layer(x)), dim=1)
|
|
131
331
|
return x
|
|
332
|
+
# TorchScript now knows this path always returns or errors
|
|
333
|
+
else:
|
|
334
|
+
# Use an assertion so TorchScript can compile it
|
|
335
|
+
assert False, f"Unsupported residual style: {self.residual_style}"
|
|
132
336
|
|
|
133
337
|
|
|
134
338
|
class _TIDNetFeatures(nn.Module):
|
|
135
|
-
def __init__(
|
|
136
|
-
|
|
339
|
+
def __init__(
|
|
340
|
+
self,
|
|
341
|
+
s_growth: int,
|
|
342
|
+
t_filters: int,
|
|
343
|
+
n_chans: int,
|
|
344
|
+
n_times: int,
|
|
345
|
+
drop_prob: float,
|
|
346
|
+
pooling: int,
|
|
347
|
+
temp_layers: int,
|
|
348
|
+
spat_layers: int,
|
|
349
|
+
temp_span: float,
|
|
350
|
+
bottleneck: int,
|
|
351
|
+
summary: int,
|
|
352
|
+
activation: type[nn.Module] = nn.LeakyReLU,
|
|
353
|
+
):
|
|
137
354
|
super().__init__()
|
|
138
355
|
self.n_chans = n_chans
|
|
139
356
|
self.temp_len = ceil(temp_span * n_times)
|
|
@@ -141,17 +358,29 @@ class _TIDNetFeatures(nn.Module):
|
|
|
141
358
|
self.temporal = nn.Sequential(
|
|
142
359
|
Ensure4d(),
|
|
143
360
|
Rearrange("batch C T 1 -> batch 1 C T"),
|
|
144
|
-
_TemporalFilter(
|
|
361
|
+
_TemporalFilter(
|
|
362
|
+
1,
|
|
363
|
+
t_filters,
|
|
364
|
+
depth=temp_layers,
|
|
365
|
+
temp_len=self.temp_len,
|
|
366
|
+
activation=activation,
|
|
367
|
+
),
|
|
145
368
|
nn.MaxPool2d((1, pooling)),
|
|
146
|
-
nn.Dropout2d(drop_prob),
|
|
369
|
+
nn.Dropout2d(p=float(drop_prob)),
|
|
147
370
|
)
|
|
148
371
|
summary = n_times // pooling if summary == -1 else summary
|
|
149
372
|
|
|
150
|
-
self.spatial = _DenseSpatialFilter(
|
|
151
|
-
|
|
373
|
+
self.spatial = _DenseSpatialFilter(
|
|
374
|
+
n_chans=n_chans,
|
|
375
|
+
growth=s_growth,
|
|
376
|
+
depth=spat_layers,
|
|
377
|
+
in_ch=t_filters,
|
|
378
|
+
drop_prob=drop_prob,
|
|
379
|
+
bottleneck=bottleneck,
|
|
380
|
+
activation=activation,
|
|
381
|
+
)
|
|
152
382
|
self.extract_features = nn.Sequential(
|
|
153
|
-
nn.AdaptiveAvgPool1d(int(summary)),
|
|
154
|
-
nn.Flatten(start_dim=1)
|
|
383
|
+
nn.AdaptiveAvgPool1d(int(summary)), nn.Flatten(start_dim=1)
|
|
155
384
|
)
|
|
156
385
|
|
|
157
386
|
self._num_features = (t_filters + s_growth * spat_layers) * summary
|
|
@@ -160,123 +389,7 @@ class _TIDNetFeatures(nn.Module):
|
|
|
160
389
|
def num_features(self):
|
|
161
390
|
return self._num_features
|
|
162
391
|
|
|
163
|
-
def forward(self, x):
|
|
392
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
164
393
|
x = self.temporal(x)
|
|
165
394
|
x = self.spatial(x)
|
|
166
395
|
return self.extract_features(x)
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
class TIDNet(EEGModuleMixin, nn.Module):
|
|
170
|
-
"""Thinker Invariance DenseNet model from Kostas et al 2020.
|
|
171
|
-
|
|
172
|
-
See [TIDNet]_ for details.
|
|
173
|
-
|
|
174
|
-
Parameters
|
|
175
|
-
----------
|
|
176
|
-
s_growth : int
|
|
177
|
-
DenseNet-style growth factor (added filters per DenseFilter)
|
|
178
|
-
t_filters : int
|
|
179
|
-
Number of temporal filters.
|
|
180
|
-
drop_prob : float
|
|
181
|
-
Dropout probability
|
|
182
|
-
pooling : int
|
|
183
|
-
Max temporal pooling (width and stride)
|
|
184
|
-
temp_layers : int
|
|
185
|
-
Number of temporal layers
|
|
186
|
-
spat_layers : int
|
|
187
|
-
Number of DenseFilters
|
|
188
|
-
temp_span : float
|
|
189
|
-
Percentage of n_times that defines the temporal filter length:
|
|
190
|
-
temp_len = ceil(temp_span * n_times)
|
|
191
|
-
e.g A value of 0.05 for temp_span with 1500 n_times will yield a temporal
|
|
192
|
-
filter of length 75.
|
|
193
|
-
bottleneck : int
|
|
194
|
-
Bottleneck factor within Densefilter
|
|
195
|
-
summary : int
|
|
196
|
-
Output size of AdaptiveAvgPool1D layer. If set to -1, value will be calculated
|
|
197
|
-
automatically (n_times // pooling).
|
|
198
|
-
in_chans :
|
|
199
|
-
Alias for n_chans.
|
|
200
|
-
n_classes:
|
|
201
|
-
Alias for n_outputs.
|
|
202
|
-
input_window_samples :
|
|
203
|
-
Alias for n_times.
|
|
204
|
-
|
|
205
|
-
Notes
|
|
206
|
-
-----
|
|
207
|
-
Code adapted from: https://github.com/SPOClab-ca/ThinkerInvariance/
|
|
208
|
-
|
|
209
|
-
References
|
|
210
|
-
----------
|
|
211
|
-
.. [TIDNet] Kostas, D. & Rudzicz, F.
|
|
212
|
-
Thinker invariance: enabling deep neural networks for BCI across more
|
|
213
|
-
people.
|
|
214
|
-
J. Neural Eng. 17, 056008 (2020).
|
|
215
|
-
doi: 10.1088/1741-2552/abb7a7.
|
|
216
|
-
"""
|
|
217
|
-
|
|
218
|
-
def __init__(self, n_chans=None, n_outputs=None, n_times=None,
|
|
219
|
-
in_chans=None, n_classes=None, input_window_samples=None,
|
|
220
|
-
s_growth=24, t_filters=32, drop_prob=0.4, pooling=15,
|
|
221
|
-
temp_layers=2, spat_layers=2, temp_span=0.05,
|
|
222
|
-
bottleneck=3, summary=-1, add_log_softmax=True):
|
|
223
|
-
n_chans, n_outputs, n_times = deprecated_args(
|
|
224
|
-
self,
|
|
225
|
-
('in_chans', 'n_chans', in_chans, n_chans),
|
|
226
|
-
('n_classes', 'n_outputs', n_classes, n_outputs),
|
|
227
|
-
('input_window_samples', 'n_times', input_window_samples, n_times),
|
|
228
|
-
)
|
|
229
|
-
super().__init__(
|
|
230
|
-
n_outputs=n_outputs,
|
|
231
|
-
n_chans=n_chans,
|
|
232
|
-
n_times=n_times,
|
|
233
|
-
add_log_softmax=add_log_softmax,
|
|
234
|
-
)
|
|
235
|
-
del n_outputs, n_chans, n_times
|
|
236
|
-
del in_chans, n_classes, input_window_samples
|
|
237
|
-
|
|
238
|
-
self.mapping = {
|
|
239
|
-
'classify.1.weight': 'final_layer.0.weight',
|
|
240
|
-
'classify.1.bias': 'final_layer.0.bias'
|
|
241
|
-
}
|
|
242
|
-
|
|
243
|
-
self.temp_len = ceil(temp_span * self.n_times)
|
|
244
|
-
|
|
245
|
-
self.dscnn = _TIDNetFeatures(s_growth=s_growth, t_filters=t_filters, n_chans=self.n_chans,
|
|
246
|
-
n_times=self.n_times,
|
|
247
|
-
drop_prob=drop_prob, pooling=pooling, temp_layers=temp_layers,
|
|
248
|
-
spat_layers=spat_layers, temp_span=temp_span,
|
|
249
|
-
bottleneck=bottleneck, summary=summary)
|
|
250
|
-
|
|
251
|
-
self._num_features = self.dscnn.num_features
|
|
252
|
-
|
|
253
|
-
self.flatten = nn.Flatten(start_dim=1)
|
|
254
|
-
|
|
255
|
-
self.final_layer = self._create_classifier(self.num_features, self.n_outputs)
|
|
256
|
-
|
|
257
|
-
def _create_classifier(self, incoming, n_outputs):
|
|
258
|
-
classifier = nn.Linear(incoming, n_outputs)
|
|
259
|
-
init.xavier_normal_(classifier.weight)
|
|
260
|
-
classifier.bias.data.zero_()
|
|
261
|
-
seq_clf = nn.Sequential(
|
|
262
|
-
classifier,
|
|
263
|
-
nn.LogSoftmax(dim=-1) if self.add_log_softmax else nn.Identity())
|
|
264
|
-
|
|
265
|
-
return seq_clf
|
|
266
|
-
|
|
267
|
-
def forward(self, x):
|
|
268
|
-
"""Forward pass.
|
|
269
|
-
|
|
270
|
-
Parameters
|
|
271
|
-
----------
|
|
272
|
-
x: torch.Tensor
|
|
273
|
-
Batch of EEG windows of shape (batch_size, n_channels, n_times).
|
|
274
|
-
"""
|
|
275
|
-
|
|
276
|
-
x = self.dscnn(x)
|
|
277
|
-
x = self.flatten(x)
|
|
278
|
-
return self.final_layer(x)
|
|
279
|
-
|
|
280
|
-
@property
|
|
281
|
-
def num_features(self):
|
|
282
|
-
return self._num_features
|