braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,397 @@
|
|
|
1
|
+
from math import ceil
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from einops.layers.torch import Rearrange
|
|
5
|
+
from torch import nn
|
|
6
|
+
from torch.nn import init
|
|
7
|
+
from torch.nn.utils.parametrizations import weight_norm
|
|
8
|
+
|
|
9
|
+
from braindecode.models.base import EEGModuleMixin
|
|
10
|
+
from braindecode.modules import Ensure4d
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TIDNet(EEGModuleMixin, nn.Module):
    r"""Thinker Invariance DenseNet model from Kostas et al (2020) [TIDNet]_.

    :bdg-success:`Convolution`

    .. figure:: https://content.cld.iop.org/journals/1741-2552/17/5/056008/revision3/jneabb7a7f1_hr.jpg
        :align: center
        :alt: TIDNet Architecture

    See [TIDNet]_ for details.

    The signal-related arguments (``n_chans``, ``n_outputs``, ``n_times``,
    ``input_window_seconds``, ``sfreq``, ``chs_info``) are forwarded unchanged
    to :class:`EEGModuleMixin`.

    Parameters
    ----------
    s_growth : int
        DenseNet-style growth factor (added filters per DenseFilter).
    t_filters : int
        Number of temporal filters.
    drop_prob : float
        Dropout probability.
    pooling : int
        Max temporal pooling (width and stride).
    temp_layers : int
        Number of temporal layers.
    spat_layers : int
        Number of DenseFilters.
    temp_span : float
        Fraction of ``n_times`` that defines the temporal filter length:
        ``temp_len = ceil(temp_span * n_times)``.
        E.g. a value of 0.05 for ``temp_span`` with 1500 ``n_times`` yields a
        temporal filter of length 75.
    bottleneck : int
        Bottleneck factor within each DenseFilter.
    summary : int
        Output size of the ``AdaptiveAvgPool1d`` layer. If set to -1, the
        value is computed automatically as ``n_times // pooling``.
    activation : type[nn.Module], default=nn.LeakyReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``.

    Notes
    -----
    Code adapted from: https://github.com/SPOClab-ca/ThinkerInvariance/

    References
    ----------
    .. [TIDNet] Kostas, D. & Rudzicz, F.
        Thinker invariance: enabling deep neural networks for BCI across more
        people.
        J. Neural Eng. 17, 056008 (2020).
        doi: 10.1088/1741-2552/abb7a7.
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        chs_info=None,
        s_growth: int = 24,
        t_filters: int = 32,
        drop_prob: float = 0.4,
        pooling: int = 15,
        temp_layers: int = 2,
        spat_layers: int = 2,
        temp_span: float = 0.05,
        bottleneck: int = 3,
        summary: int = -1,
        activation: type[nn.Module] = nn.LeakyReLU,
    ):
        super().__init__(
            n_chans=n_chans,
            n_outputs=n_outputs,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            chs_info=chs_info,
        )
        # From here on only the ``self.*`` attributes are used.
        del n_outputs, n_chans, n_times, input_window_seconds, sfreq, chs_info

        # Temporal kernel length expressed as a fraction of the window.
        self.temp_len = ceil(temp_span * self.n_times)

        # Feature extractor: temporal filtering followed by dense spatial
        # filtering.
        self.dscnn = _TIDNetFeatures(
            n_chans=self.n_chans,
            n_times=self.n_times,
            t_filters=t_filters,
            s_growth=s_growth,
            temp_layers=temp_layers,
            spat_layers=spat_layers,
            temp_span=temp_span,
            pooling=pooling,
            bottleneck=bottleneck,
            summary=summary,
            drop_prob=drop_prob,
            activation=activation,
        )
        self._num_features = self.dscnn.num_features

        self.flatten = nn.Flatten(start_dim=1)
        self.final_layer = self._create_classifier(self.num_features, self.n_outputs)

    def _create_classifier(self, incoming: int, n_outputs: int):
        """Build the classification head.

        Parameters
        ----------
        incoming : int
            Number of input features of the linear layer.
        n_outputs : int
            Number of outputs of the linear layer.

        Returns
        -------
        nn.Sequential
            A linear layer (Xavier-normal weights, zero bias) followed by an
            identity module.
        """
        linear = nn.Linear(incoming, n_outputs)
        init.xavier_normal_(linear.weight)
        init.zeros_(linear.bias)
        return nn.Sequential(linear, nn.Identity())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Output of the classification head for every window.
        """
        return self.final_layer(self.flatten(self.dscnn(x)))

    @property
    def num_features(self):
        """Number of features produced by the feature extractor."""
        return self._num_features
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class _BatchNormZG(nn.BatchNorm2d):
    """``BatchNorm2d`` whose affine weight and bias both start at zero."""

    def reset_parameters(self):
        """Reset running statistics and zero the affine parameters.

        Running mean is set to 0 and running variance to 1 (only when running
        statistics are tracked); weight and bias are set to 0 (only when the
        layer is affine).
        """
        if self.track_running_stats:
            nn.init.zeros_(self.running_mean)
            nn.init.ones_(self.running_var)
        if self.affine:
            for param in (self.weight, self.bias):
                nn.init.zeros_(param)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class _ConvBlock2D(nn.Module):
    r"""Convolution block applying, in this order:
    convolution, dropout, activation, batch-norm.

    Parameters
    ----------
    in_filters : int
        Number of input feature maps.
    out_filters : int
        Number of output feature maps.
    kernel : tuple[int, int]
        Convolution kernel size.
    stride, padding, dilation, groups
        Forwarded to ``nn.Conv2d``.
    drop_prob : float
        Probability used by the ``nn.Dropout2d`` layer.
    batch_norm : bool
        If True the convolution has no bias and a normalisation layer is
        applied last; if False (and ``residual`` is False) an identity is
        used instead.
    activation : type[nn.Module]
        Activation class, instantiated without arguments.
    residual : bool
        If True the block input is added to the block output and the
        normalisation layer is a ``_BatchNormZG``.
    """

    def __init__(
        self,
        in_filters: int,
        out_filters: int,
        kernel: tuple[int, int],
        stride: tuple[int, int] = (1, 1),
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        drop_prob: float = 0.5,
        batch_norm: bool = True,
        activation: type[nn.Module] = nn.LeakyReLU,
        residual: bool = False,
    ):
        super().__init__()
        self.kernel = kernel
        self.activation = activation()
        self.residual = residual

        # The convolution only carries a bias when no batch-norm is requested.
        use_bias = not batch_norm
        self.conv = nn.Conv2d(
            in_filters,
            out_filters,
            kernel,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=use_bias,
        )
        self.dropout = nn.Dropout2d(p=float(drop_prob))

        # ``residual`` is checked first: a residual block always gets the
        # zero-initialised batch-norm, whatever ``batch_norm`` says.
        if residual:
            norm = _BatchNormZG(out_filters)
        elif batch_norm:
            norm = nn.BatchNorm2d(out_filters)
        else:
            norm = nn.Identity()
        self.batch_norm = norm

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the block, adding the input back when ``residual`` is set."""
        out = self.conv(input)
        out = self.dropout(out)
        out = self.activation(out)
        out = self.batch_norm(out)
        if self.residual:
            return out + input
        return out
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class _DenseFilter(nn.Module):
    """DenseNet-style layer: bottleneck convolution, then a 1-D filter.

    The output of the internal network (``growth_rate`` feature maps) is
    concatenated to the input along the feature dimension.

    Parameters
    ----------
    in_features : int
        Number of input feature maps.
    growth_rate : int
        Number of feature maps appended to the input.
    filter_len : int
        Length of the filter along the selected dimension.
    drop_prob : float
        Probability used by the final ``nn.Dropout2d``.
    bottleneck : int
        The 1x1 convolution produces ``bottleneck * growth_rate`` maps.
    activation : type[nn.Module]
        Activation class, instantiated without arguments.
    dim : int
        Dimension of a 4-D input the filter runs along; only the last two
        dimensions are accepted.

    Raises
    ------
    ValueError
        If ``dim`` does not resolve to one of the last two dimensions.
    """

    def __init__(
        self,
        in_features: int,
        growth_rate: int,
        filter_len: int = 5,
        drop_prob: float = 0.5,
        bottleneck: int = 2,
        activation: type[nn.Module] = nn.LeakyReLU,
        dim: int = -2,
    ):
        super().__init__()
        # Non-positive indices are shifted by the input rank (4).
        axis = dim if dim > 0 else dim + 4
        if not 2 <= axis <= 3:
            raise ValueError("Only last two dimensions supported")
        if axis == 2:
            kernel = (filter_len, 1)
        else:
            kernel = (1, filter_len)
        half_kernel = tuple(size // 2 for size in kernel)
        hidden = bottleneck * growth_rate

        layers = [
            nn.BatchNorm2d(in_features),
            activation(),
            nn.Conv2d(in_features, hidden, 1),
            nn.BatchNorm2d(hidden),
            activation(),
            nn.Conv2d(hidden, growth_rate, kernel, padding=half_kernel),
            nn.Dropout2d(p=float(drop_prob)),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the new feature maps appended along dim 1."""
        new_maps = self.net(x)
        return torch.cat((x, new_maps), dim=1)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class _DenseSpatialFilter(nn.Module):
    """Chain of ``_DenseFilter`` layers with an optional channel collapse.

    Parameters
    ----------
    n_chans : int
        Height of the collapsing convolution kernel, ``(n_chans, 1)``.
    growth : int
        Feature maps added by each dense filter.
    depth : int
        Number of dense filters.
    in_ch : int
        Number of feature maps of the input.
    bottleneck : int
        Bottleneck factor forwarded to every dense filter.
    drop_prob : float
        Dropout probability forwarded to every dense filter.
    activation : type[nn.Module]
        Activation class used throughout.
    collapse : bool
        If True, a final ``_ConvBlock2D`` with kernel ``(n_chans, 1)`` is
        applied and the second-to-last dimension is squeezed.
    """

    def __init__(
        self,
        n_chans: int,
        growth: int,
        depth: int,
        in_ch: int = 1,
        bottleneck: int = 4,
        drop_prob: float = 0.0,
        activation: type[nn.Module] = nn.LeakyReLU,
        collapse: bool = True,
    ):
        super().__init__()
        # Each dense filter sees the input plus everything appended so far.
        stages = []
        for level in range(depth):
            stages.append(
                _DenseFilter(
                    in_ch + growth * level,
                    growth,
                    bottleneck=bottleneck,
                    drop_prob=drop_prob,
                    activation=activation,
                )
            )
        self.net = nn.Sequential(*stages)

        self.collapse = collapse
        if collapse:
            total_filters = in_ch + growth * depth
            self.channel_collapse = _ConvBlock2D(
                total_filters,
                total_filters,
                (n_chans, 1),
                drop_prob=0,
                activation=activation,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the dense filters and, if enabled, the channel collapse."""
        if x.dim() < 4:
            # Add a feature dimension and swap the last two dimensions.
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
        x = self.net(x)
        if not self.collapse:
            return x
        return self.channel_collapse(x).squeeze(-2)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class _TemporalFilter(nn.Module):
|
|
291
|
+
def __init__(
|
|
292
|
+
self,
|
|
293
|
+
n_chans: int,
|
|
294
|
+
filters: int,
|
|
295
|
+
depth: int,
|
|
296
|
+
temp_len: int,
|
|
297
|
+
drop_prob: float = 0.0,
|
|
298
|
+
activation: type[nn.Module] = nn.LeakyReLU,
|
|
299
|
+
residual: str = "netwise",
|
|
300
|
+
):
|
|
301
|
+
super().__init__()
|
|
302
|
+
temp_len = temp_len + 1 - temp_len % 2
|
|
303
|
+
self.residual_style = str(residual)
|
|
304
|
+
net = list()
|
|
305
|
+
|
|
306
|
+
for i in range(depth):
|
|
307
|
+
dil = depth - i
|
|
308
|
+
conv = weight_norm(
|
|
309
|
+
nn.Conv2d(
|
|
310
|
+
n_chans if i == 0 else filters,
|
|
311
|
+
filters,
|
|
312
|
+
kernel_size=(1, temp_len),
|
|
313
|
+
dilation=dil,
|
|
314
|
+
padding=(0, dil * (temp_len - 1) // 2),
|
|
315
|
+
)
|
|
316
|
+
)
|
|
317
|
+
net.append(
|
|
318
|
+
nn.Sequential(conv, activation(), nn.Dropout2d(p=float(drop_prob)))
|
|
319
|
+
)
|
|
320
|
+
if self.residual_style.lower() == "netwise":
|
|
321
|
+
self.net = nn.Sequential(*net)
|
|
322
|
+
self.residual = nn.Conv2d(n_chans, filters, (1, 1))
|
|
323
|
+
elif residual.lower() == "dense":
|
|
324
|
+
self.net = net
|
|
325
|
+
|
|
326
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
327
|
+
style = self.residual_style.lower()
|
|
328
|
+
if style == "netwise":
|
|
329
|
+
return self.net(x) + self.residual(x)
|
|
330
|
+
elif style == "dense":
|
|
331
|
+
for layer in self.net:
|
|
332
|
+
x = torch.cat((x, layer(x)), dim=1)
|
|
333
|
+
return x
|
|
334
|
+
# TorchScript now knows this path always returns or errors
|
|
335
|
+
else:
|
|
336
|
+
# Use an assertion so TorchScript can compile it
|
|
337
|
+
assert False, f"Unsupported residual style: {self.residual_style}"
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
class _TIDNetFeatures(nn.Module):
    """Feature extractor of TIDNet: temporal stage, spatial stage, summary.

    Parameters
    ----------
    s_growth : int
        Growth factor of the dense spatial filter.
    t_filters : int
        Number of temporal filters.
    n_chans : int
        Number of EEG channels.
    n_times : int
        Number of time samples per window.
    drop_prob : float
        Dropout probability used after pooling and in the spatial stage.
    pooling : int
        Width of the temporal max pooling.
    temp_layers : int
        Depth of the temporal filter.
    spat_layers : int
        Depth of the dense spatial filter.
    temp_span : float
        Fraction of ``n_times`` used as temporal kernel length.
    bottleneck : int
        Bottleneck factor of the dense spatial filter.
    summary : int
        Output size of the adaptive average pooling; ``-1`` selects
        ``n_times // pooling``.
    activation : type[nn.Module]
        Activation class used throughout.
    """

    def __init__(
        self,
        s_growth: int,
        t_filters: int,
        n_chans: int,
        n_times: int,
        drop_prob: float,
        pooling: int,
        temp_layers: int,
        spat_layers: int,
        temp_span: float,
        bottleneck: int,
        summary: int,
        activation: type[nn.Module] = nn.LeakyReLU,
    ):
        super().__init__()
        self.n_chans = n_chans
        self.temp_len = ceil(temp_span * n_times)

        # Temporal stage: reshape to (batch, 1, C, T), filter, pool, drop.
        temporal_layers = [
            Ensure4d(),
            Rearrange("batch C T 1 -> batch 1 C T"),
            _TemporalFilter(
                1,
                t_filters,
                depth=temp_layers,
                temp_len=self.temp_len,
                activation=activation,
            ),
            nn.MaxPool2d((1, pooling)),
            nn.Dropout2d(p=float(drop_prob)),
        ]
        self.temporal = nn.Sequential(*temporal_layers)

        # -1 requests the automatic summary size.
        if summary == -1:
            summary = n_times // pooling

        self.spatial = _DenseSpatialFilter(
            n_chans=n_chans,
            growth=s_growth,
            depth=spat_layers,
            in_ch=t_filters,
            bottleneck=bottleneck,
            drop_prob=drop_prob,
            activation=activation,
        )

        summary_pool = nn.AdaptiveAvgPool1d(int(summary))
        self.extract_features = nn.Sequential(summary_pool, nn.Flatten(start_dim=1))

        n_spatial_maps = t_filters + s_growth * spat_layers
        self._num_features = n_spatial_maps * summary

    @property
    def num_features(self):
        """Number of features per example produced by ``forward``."""
        return self._num_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the temporal stage, the spatial stage and the summary pooling."""
        temporal_out = self.temporal(x)
        spatial_out = self.spatial(temporal_out)
        return self.extract_features(spatial_out)
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
# Authors: Bruno Aristimunha <b.aristimunha>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
from einops.layers.torch import Rearrange
|
|
10
|
+
from mne.utils import deprecated, warn
|
|
11
|
+
|
|
12
|
+
from braindecode.models.base import EEGModuleMixin
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TSception(EEGModuleMixin, nn.Module):
    r"""TSception model from Ding et al. (2020) [ding2020]_.

    :bdg-success:`Convolution`

    TSception: A deep learning framework for emotion detection using EEG.

    .. figure:: https://user-images.githubusercontent.com/58539144/74716976-80415e00-526a-11ea-9433-02ab2b753f6b.PNG
        :align: center
        :alt: TSception Architecture

    The model consists of temporal and spatial convolutional layers
    (Tception and Sception) designed to learn temporal and spatial features
    from EEG data.

    Parameters
    ----------
    number_filter_temp : int
        Number of temporal convolutional filters.
    number_filter_spat : int
        Number of spatial convolutional filters.
    hidden_size : int
        Number of units in the hidden fully connected layer.
    drop_prob : float
        Dropout rate applied after the hidden layer.
    activation : type[nn.Module], optional
        Activation function class to apply. Should be a PyTorch activation
        module like ``nn.ReLU`` or ``nn.LeakyReLU``. Default is ``nn.LeakyReLU``.
    pool_size : int, optional
        Pooling size for the average pooling layers. Default is 8. The
        temporal blocks use at most the length of their convolution output;
        the spatial blocks use ``max(1, pool_size // 4)``.
    inception_windows : tuple[float, float, float], optional
        Window sizes (in seconds) for the inception modules.
        Default is (0.5, 0.25, 0.125). If the input window is shorter than
        the largest value, the windows are replaced by the input window
        length, half of it and a quarter of it, and a warning is emitted.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors. The modifications are minimal and the
    model is expected to work as intended. The original code is available
    from [code2020]_.

    References
    ----------
    .. [ding2020] Ding, Y., Robinson, N., Zeng, Q., Chen, D., Wai, A. A. P.,
        Lee, T. S., & Guan, C. (2020, July). Tsception: a deep learning framework
        for emotion detection using EEG. In 2020 international joint conference
        on neural networks (IJCNN) (pp. 1-7). IEEE.
    .. [code2020] Ding, Y., Robinson, N., Zeng, Q., Chen, D., Wai, A. A. P.,
        Lee, T. S., & Guan, C. (2020, July). Tsception: a deep learning framework
        for emotion detection using EEG.
        https://github.com/deepBrains/TSception/blob/master/Models.py
    """

    def __init__(
        self,
        # Braindecode parameters
        n_chans=None,
        n_outputs=None,
        input_window_seconds=None,
        chs_info=None,
        n_times=None,
        sfreq=None,
        # Model parameters
        number_filter_temp: int = 9,
        number_filter_spat: int = 6,
        hidden_size: int = 128,
        drop_prob: float = 0.5,
        activation: type[nn.Module] = nn.LeakyReLU,
        pool_size: int = 8,
        inception_windows: tuple[float, float, float] = (0.5, 0.25, 0.125),
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.activation = activation
        self.pool_size = pool_size
        self.inception_windows = inception_windows
        self.number_filter_spat = number_filter_spat
        self.number_filter_temp = number_filter_temp
        self.drop_prob = drop_prob

        ### Layers
        self.ensuredim = Rearrange("batch nchans time -> batch 1 nchans time")
        if self.input_window_seconds < max(self.inception_windows):
            # The largest temporal kernel would be longer than the signal:
            # shrink the inception windows so that they fit the input window.
            inception_windows = (
                self.input_window_seconds,
                self.input_window_seconds / 2,
                self.input_window_seconds / 4,
            )
            warning_msg = (
                "Input window size is smaller than the maximum inception window size. "
                "We are adjusting the inception windows to fit the input window size.\n"
                f"Original inception windows: {self.inception_windows}, \n"
                f"Adjusted inception windows: {inception_windows}"
            )
            warn(warning_msg, UserWarning)
            self.inception_windows = inception_windows

        # Define temporal convolutional layers (Tception)
        self.temporal_blocks = nn.ModuleList()
        for window in self.inception_windows:
            # 1. Temporal kernel size for this block, kept within
            #    [1, n_times] so the convolution is always valid.
            kernel_size_t = min(self.n_times, max(1, int(window * self.sfreq)))

            # 2. Output length of the (unpadded, stride 1) convolution.
            conv_out_len = self.n_times - kernel_size_t + 1

            # 3. The pooling size must not exceed the conv output and must
            #    be at least 1.
            dynamic_pool_size = max(1, min(self.pool_size, conv_out_len))

            # 4. Create the block with the dynamic pooling size.
            block = self._conv_block(
                in_channels=1,
                out_channels=self.number_filter_temp,
                kernel_size=(1, kernel_size_t),
                stride=1,
                pool_size=dynamic_pool_size,
                activation=self.activation,
            )
            self.temporal_blocks.append(block)

        self.batch_temporal_lay = nn.BatchNorm2d(self.number_filter_temp)

        # Define spatial convolutional layers (Sception)

        # A quarter of the temporal pooling size, but never 0: a zero-sized
        # pooling kernel is invalid.
        pool_size_spat = max(1, self.pool_size // 4)

        # First spatial block: one kernel spanning all channels.
        self.spatial_block_1 = self._conv_block(
            in_channels=self.number_filter_temp,
            out_channels=self.number_filter_spat,
            kernel_size=(self.n_chans, 1),
            stride=1,
            pool_size=pool_size_spat,
            activation=self.activation,
        )

        # Second spatial block: non-overlapping kernels spanning half of the
        # channels each (kernel size equals stride).
        kernel_size_spat_2 = (max(1, self.n_chans // 2), 1)

        self.spatial_block_2 = self._conv_block(
            in_channels=self.number_filter_temp,
            out_channels=self.number_filter_spat,
            kernel_size=kernel_size_spat_2,
            stride=kernel_size_spat_2,
            pool_size=pool_size_spat,
            activation=self.activation,
        )
        self.batch_spatial_lay = nn.BatchNorm2d(self.number_filter_spat)

        # Size of the flattened features after convolution and pooling.
        self.feature_size = self._calculate_feature_size()

        # Define the final classification layers
        self.dense_layer = nn.Sequential(
            nn.Linear(self.feature_size, hidden_size),
            self.activation(),
            nn.Dropout(self.drop_prob),
        )

        self.final_layer = nn.Linear(hidden_size, self.n_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the TSception model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_classes).
        """
        # Temporal Convolution
        # shape: (batch_size, n_channels, n_times)
        x = self.ensuredim(x)
        # shape: (batch_size, 1, n_channels, n_times)

        # Every block yields (batch_size, number_filter_temp, n_channels,
        # pooled_time); the blocks are concatenated along the last dimension.
        t_features = [layer(x) for layer in self.temporal_blocks]
        t_out = torch.cat(t_features, dim=-1)

        t_out = self.batch_temporal_lay(t_out)

        # Spatial Convolution: both blocks are concatenated along dim 2.
        s_out1 = self.spatial_block_1(t_out)
        s_out2 = self.spatial_block_2(t_out)
        s_out = torch.cat((s_out1, s_out2), dim=2)
        s_out = self.batch_spatial_lay(s_out)

        # Flatten and apply final layers
        s_out = s_out.view(s_out.size(0), -1)
        output = self.dense_layer(s_out)
        output = self.final_layer(output)
        return output

    def _calculate_feature_size(self) -> int:
        """
        Calculates the size of the features after convolution and pooling layers.

        A dummy input is passed through the temporal and spatial blocks. The
        batch-norm layers are skipped on purpose: they do not change the
        shape, and running them in training mode would update their running
        statistics with the dummy data (``torch.no_grad`` does not prevent
        that).

        Returns
        -------
        int
            Flattened size of the features after convolution and pooling layers.
        """
        with torch.no_grad():
            dummy_input = torch.ones(1, 1, self.n_chans, self.n_times)
            t_features = [layer(dummy_input) for layer in self.temporal_blocks]
            t_out = torch.cat(t_features, dim=-1)

            s_out1 = self.spatial_block_1(t_out)
            s_out2 = self.spatial_block_2(t_out)
            s_out = torch.cat((s_out1, s_out2), dim=2)

            feature_size = s_out.view(1, -1).size(1)
        return feature_size

    @staticmethod
    def _conv_block(
        in_channels: int,
        out_channels: int,
        kernel_size: tuple,
        stride: tuple[int, int] | int,
        pool_size: int,
        activation: type[nn.Module],
    ) -> nn.Sequential:
        """
        Creates a convolutional block with Conv2d, activation, and AvgPool2d layers.

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        kernel_size : tuple
            Size of the convolutional kernel.
        stride : tuple of int or int
            Stride of the convolution.
        pool_size : int
            Size (and stride) of the pooling kernel along the last dimension.
        activation : type[nn.Module]
            Activation function class, instantiated without arguments.

        Returns
        -------
        nn.Sequential
            A sequential container of the convolutional block.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=0,
            ),
            activation(),
            nn.AvgPool2d(kernel_size=(1, pool_size), stride=(1, pool_size)),
        )
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
@deprecated(
    "`TSceptionV1` was renamed to `TSception` in v1.12; "
    "this alias will be removed in v1.14."
)
class TSceptionV1(TSception):
    """Deprecated alias of :class:`TSception`, kept under the former name."""
|