braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic; see the registry's advisory for this release for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,147 @@
1
1
  from math import ceil
2
2
 
3
3
  import torch
4
+ from einops.layers.torch import Rearrange
4
5
  from torch import nn
5
6
  from torch.nn import init
6
- from torch.nn.utils import weight_norm
7
- from einops.layers.torch import Rearrange
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+
9
+ from braindecode.models.base import EEGModuleMixin
10
+ from braindecode.modules import Ensure4d
11
+
12
+
13
+ class TIDNet(EEGModuleMixin, nn.Module):
14
+ """Thinker Invariance DenseNet model from Kostas et al. (2020) [TIDNet]_.
15
+
16
+ .. figure:: https://content.cld.iop.org/journals/1741-2552/17/5/056008/revision3/jneabb7a7f1_hr.jpg
17
+ :align: center
18
+ :alt: TIDNet Architecture
19
+
20
+ See [TIDNet]_ for details.
21
+
22
+ Parameters
23
+ ----------
24
+ s_growth : int
25
+ DenseNet-style growth factor (added filters per DenseFilter)
26
+ t_filters : int
27
+ Number of temporal filters.
28
+ drop_prob : float
29
+ Dropout probability
30
+ pooling : int
31
+ Max temporal pooling (width and stride)
32
+ temp_layers : int
33
+ Number of temporal layers
34
+ spat_layers : int
35
+ Number of DenseFilters
36
+ temp_span : float
37
+ Percentage of n_times that defines the temporal filter length:
38
+ temp_len = ceil(temp_span * n_times)
39
+ e.g A value of 0.05 for temp_span with 1500 n_times will yield a temporal
40
+ filter of length 75.
41
+ bottleneck : int
42
+ Bottleneck factor within Densefilter
43
+ summary : int
44
+ Output size of AdaptiveAvgPool1D layer. If set to -1, value will be calculated
45
+ automatically (n_times // pooling).
46
+ in_chans :
47
+ Alias for n_chans.
48
+ n_classes:
49
+ Alias for n_outputs.
50
+ input_window_samples :
51
+ Alias for n_times.
52
+ activation: nn.Module, default=nn.LeakyReLU
53
+ Activation function class to apply. Should be a PyTorch activation
54
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.LeakyReLU``.
55
+
56
+ Notes
57
+ -----
58
+ Code adapted from: https://github.com/SPOClab-ca/ThinkerInvariance/
59
+
60
+ References
61
+ ----------
62
+ .. [TIDNet] Kostas, D. & Rudzicz, F.
63
+ Thinker invariance: enabling deep neural networks for BCI across more
64
+ people.
65
+ J. Neural Eng. 17, 056008 (2020).
66
+ doi: 10.1088/1741-2552/abb7a7.
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ n_chans=None,
72
+ n_outputs=None,
73
+ n_times=None,
74
+ input_window_seconds=None,
75
+ sfreq=None,
76
+ chs_info=None,
77
+ s_growth: int = 24,
78
+ t_filters: int = 32,
79
+ drop_prob: float = 0.4,
80
+ pooling: int = 15,
81
+ temp_layers: int = 2,
82
+ spat_layers: int = 2,
83
+ temp_span: float = 0.05,
84
+ bottleneck: int = 3,
85
+ summary: int = -1,
86
+ activation: nn.Module = nn.LeakyReLU,
87
+ ):
88
+ super().__init__(
89
+ n_outputs=n_outputs,
90
+ n_chans=n_chans,
91
+ n_times=n_times,
92
+ input_window_seconds=input_window_seconds,
93
+ sfreq=sfreq,
94
+ chs_info=chs_info,
95
+ )
96
+ del n_outputs, n_chans, n_times, input_window_seconds, sfreq, chs_info
97
+
98
+ self.temp_len = ceil(temp_span * self.n_times)
99
+
100
+ self.dscnn = _TIDNetFeatures(
101
+ s_growth=s_growth,
102
+ t_filters=t_filters,
103
+ n_chans=self.n_chans,
104
+ n_times=self.n_times,
105
+ drop_prob=drop_prob,
106
+ pooling=pooling,
107
+ temp_layers=temp_layers,
108
+ spat_layers=spat_layers,
109
+ temp_span=temp_span,
110
+ bottleneck=bottleneck,
111
+ summary=summary,
112
+ activation=activation,
113
+ )
114
+
115
+ self._num_features = self.dscnn.num_features
116
+
117
+ self.flatten = nn.Flatten(start_dim=1)
118
+
119
+ self.final_layer = self._create_classifier(self.num_features, self.n_outputs)
8
120
 
9
- from .modules import Ensure4d
10
- from .base import EEGModuleMixin, deprecated_args
121
+ def _create_classifier(self, incoming: int, n_outputs: int):
122
+ classifier = nn.Linear(incoming, n_outputs)
123
+ init.xavier_normal_(classifier.weight)
124
+ classifier.bias.data.zero_()
125
+ seq_clf = nn.Sequential(classifier, nn.Identity())
126
+
127
+ return seq_clf
128
+
129
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
130
+ """Forward pass.
131
+
132
+ Parameters
133
+ ----------
134
+ x: torch.Tensor
135
+ Batch of EEG windows of shape (batch_size, n_channels, n_times).
136
+ """
137
+
138
+ x = self.dscnn(x)
139
+ x = self.flatten(x)
140
+ return self.final_layer(x)
141
+
142
+ @property
143
+ def num_features(self):
144
+ return self._num_features
11
145
 
12
146
 
13
147
  class _BatchNormZG(nn.BatchNorm2d):
@@ -25,27 +159,49 @@ class _ConvBlock2D(nn.Module):
25
159
  Convolution, dropout, activation, batch-norm
26
160
  """
27
161
 
28
- def __init__(self, in_filters, out_filters, kernel, stride=(1, 1), padding=0, dilation=1,
29
- groups=1, drop_prob=0.5, batch_norm=True, activation=nn.LeakyReLU, residual=False):
162
+ def __init__(
163
+ self,
164
+ in_filters: int,
165
+ out_filters: int,
166
+ kernel: tuple[int, int],
167
+ stride: tuple[int, int] = (1, 1),
168
+ padding: int = 0,
169
+ dilation: int = 1,
170
+ groups: int = 1,
171
+ drop_prob: float = 0.5,
172
+ batch_norm: bool = True,
173
+ activation: type[nn.Module] = nn.LeakyReLU,
174
+ residual: bool = False,
175
+ ):
30
176
  super().__init__()
31
177
  self.kernel = kernel
32
178
  self.activation = activation()
33
179
  self.residual = residual
34
180
 
35
- self.conv = nn.Conv2d(in_filters, out_filters, kernel, stride=stride, padding=padding,
36
- dilation=dilation, groups=groups, bias=not batch_norm)
37
- self.dropout = nn.Dropout2d(p=drop_prob)
181
+ self.conv = nn.Conv2d(
182
+ in_filters,
183
+ out_filters,
184
+ kernel,
185
+ stride=stride,
186
+ padding=padding,
187
+ dilation=dilation,
188
+ groups=groups,
189
+ bias=not batch_norm,
190
+ )
191
+ self.dropout = nn.Dropout2d(p=float(drop_prob))
38
192
  self.batch_norm = (
39
193
  _BatchNormZG(out_filters)
40
194
  if residual
41
195
  else nn.BatchNorm2d(out_filters)
42
196
  if batch_norm
43
- else lambda x: x
197
+ else nn.Identity()
44
198
  )
45
199
 
46
- def forward(self, input):
200
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
47
201
  res = input
48
- input = self.conv(input, )
202
+ input = self.conv(
203
+ input,
204
+ )
49
205
  input = self.dropout(input)
50
206
  input = self.activation(input)
51
207
  input = self.batch_norm(input)
@@ -53,12 +209,20 @@ class _ConvBlock2D(nn.Module):
53
209
 
54
210
 
55
211
  class _DenseFilter(nn.Module):
56
- def __init__(self, in_features, growth_rate, filter_len=5, drop_prob=0.5, bottleneck=2,
57
- activation=nn.LeakyReLU, dim=-2):
212
+ def __init__(
213
+ self,
214
+ in_features: int,
215
+ growth_rate: int,
216
+ filter_len: int = 5,
217
+ drop_prob: float = 0.5,
218
+ bottleneck: int = 2,
219
+ activation: type[nn.Module] = nn.LeakyReLU,
220
+ dim: int = -2,
221
+ ):
58
222
  super().__init__()
59
223
  dim = dim if dim > 0 else dim + 4
60
224
  if dim < 2 or dim > 3:
61
- raise ValueError('Only last two dimensions supported')
225
+ raise ValueError("Only last two dimensions supported")
62
226
  kernel = (filter_len, 1) if dim == 2 else (1, filter_len)
63
227
 
64
228
  self.net = nn.Sequential(
@@ -67,29 +231,52 @@ class _DenseFilter(nn.Module):
67
231
  nn.Conv2d(in_features, bottleneck * growth_rate, 1),
68
232
  nn.BatchNorm2d(bottleneck * growth_rate),
69
233
  activation(),
70
- nn.Conv2d(bottleneck * growth_rate, growth_rate, kernel,
71
- padding=tuple((k // 2 for k in kernel))),
72
- nn.Dropout2d(drop_prob)
234
+ nn.Conv2d(
235
+ bottleneck * growth_rate,
236
+ growth_rate,
237
+ kernel,
238
+ padding=tuple((k // 2 for k in kernel)),
239
+ ),
240
+ nn.Dropout2d(p=float(drop_prob)),
73
241
  )
74
242
 
75
- def forward(self, x):
243
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
76
244
  return torch.cat((x, self.net(x)), dim=1)
77
245
 
78
246
 
79
247
  class _DenseSpatialFilter(nn.Module):
80
- def __init__(self, n_chans, growth, depth, in_ch=1, bottleneck=4, drop_prob=0.0,
81
- activation=nn.LeakyReLU, collapse=True):
248
+ def __init__(
249
+ self,
250
+ n_chans: int,
251
+ growth: int,
252
+ depth: int,
253
+ in_ch: int = 1,
254
+ bottleneck: int = 4,
255
+ drop_prob: float = 0.0,
256
+ activation: type[nn.Module] = nn.LeakyReLU,
257
+ collapse: bool = True,
258
+ ):
82
259
  super().__init__()
83
- self.net = nn.Sequential(*[
84
- _DenseFilter(in_ch + growth * d, growth, bottleneck=bottleneck, drop_prob=drop_prob,
85
- activation=activation) for d in range(depth)
86
- ])
260
+ self.net = nn.Sequential(
261
+ *[
262
+ _DenseFilter(
263
+ in_ch + growth * d,
264
+ growth,
265
+ bottleneck=bottleneck,
266
+ drop_prob=drop_prob,
267
+ activation=activation,
268
+ )
269
+ for d in range(depth)
270
+ ]
271
+ )
87
272
  n_filters = in_ch + growth * depth
88
273
  self.collapse = collapse
89
274
  if collapse:
90
- self.channel_collapse = _ConvBlock2D(n_filters, n_filters, (n_chans, 1), drop_prob=0)
275
+ self.channel_collapse = _ConvBlock2D(
276
+ n_filters, n_filters, (n_chans, 1), drop_prob=0, activation=activation
277
+ )
91
278
 
92
- def forward(self, x):
279
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
280
  if len(x.shape) < 4:
94
281
  x = x.unsqueeze(1).permute([0, 1, 3, 2])
95
282
  x = self.net(x)
@@ -99,8 +286,16 @@ class _DenseSpatialFilter(nn.Module):
99
286
 
100
287
 
101
288
  class _TemporalFilter(nn.Module):
102
- def __init__(self, n_chans, filters, depth, temp_len, drop_prob=0., activation=nn.LeakyReLU,
103
- residual='netwise'):
289
+ def __init__(
290
+ self,
291
+ n_chans: int,
292
+ filters: int,
293
+ depth: int,
294
+ temp_len: int,
295
+ drop_prob: float = 0.0,
296
+ activation: type[nn.Module] = nn.LeakyReLU,
297
+ residual: str = "netwise",
298
+ ):
104
299
  super().__init__()
105
300
  temp_len = temp_len + 1 - temp_len % 2
106
301
  self.residual_style = str(residual)
@@ -108,32 +303,54 @@ class _TemporalFilter(nn.Module):
108
303
 
109
304
  for i in range(depth):
110
305
  dil = depth - i
111
- conv = weight_norm(nn.Conv2d(n_chans if i == 0 else filters, filters,
112
- kernel_size=(1, temp_len), dilation=dil,
113
- padding=(0, dil * (temp_len - 1) // 2)))
114
- net.append(nn.Sequential(
115
- conv,
116
- activation(),
117
- nn.Dropout2d(drop_prob)
118
- ))
119
- if self.residual_style.lower() == 'netwise':
306
+ conv = weight_norm(
307
+ nn.Conv2d(
308
+ n_chans if i == 0 else filters,
309
+ filters,
310
+ kernel_size=(1, temp_len),
311
+ dilation=dil,
312
+ padding=(0, dil * (temp_len - 1) // 2),
313
+ )
314
+ )
315
+ net.append(
316
+ nn.Sequential(conv, activation(), nn.Dropout2d(p=float(drop_prob)))
317
+ )
318
+ if self.residual_style.lower() == "netwise":
120
319
  self.net = nn.Sequential(*net)
121
320
  self.residual = nn.Conv2d(n_chans, filters, (1, 1))
122
- elif residual.lower() == 'dense':
321
+ elif residual.lower() == "dense":
123
322
  self.net = net
124
323
 
125
- def forward(self, x):
126
- if self.residual_style.lower() == 'netwise':
324
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
325
+ style = self.residual_style.lower()
326
+ if style == "netwise":
127
327
  return self.net(x) + self.residual(x)
128
- elif self.residual_style.lower() == 'dense':
328
+ elif style == "dense":
129
329
  for layer in self.net:
130
330
  x = torch.cat((x, layer(x)), dim=1)
131
331
  return x
332
+ # TorchScript now knows this path always returns or errors
333
+ else:
334
+ # Use an assertion so TorchScript can compile it
335
+ assert False, f"Unsupported residual style: {self.residual_style}"
132
336
 
133
337
 
134
338
  class _TIDNetFeatures(nn.Module):
135
- def __init__(self, s_growth, t_filters, n_chans, n_times, drop_prob, pooling,
136
- temp_layers, spat_layers, temp_span, bottleneck, summary):
339
+ def __init__(
340
+ self,
341
+ s_growth: int,
342
+ t_filters: int,
343
+ n_chans: int,
344
+ n_times: int,
345
+ drop_prob: float,
346
+ pooling: int,
347
+ temp_layers: int,
348
+ spat_layers: int,
349
+ temp_span: float,
350
+ bottleneck: int,
351
+ summary: int,
352
+ activation: type[nn.Module] = nn.LeakyReLU,
353
+ ):
137
354
  super().__init__()
138
355
  self.n_chans = n_chans
139
356
  self.temp_len = ceil(temp_span * n_times)
@@ -141,17 +358,29 @@ class _TIDNetFeatures(nn.Module):
141
358
  self.temporal = nn.Sequential(
142
359
  Ensure4d(),
143
360
  Rearrange("batch C T 1 -> batch 1 C T"),
144
- _TemporalFilter(1, t_filters, depth=temp_layers, temp_len=self.temp_len),
361
+ _TemporalFilter(
362
+ 1,
363
+ t_filters,
364
+ depth=temp_layers,
365
+ temp_len=self.temp_len,
366
+ activation=activation,
367
+ ),
145
368
  nn.MaxPool2d((1, pooling)),
146
- nn.Dropout2d(drop_prob),
369
+ nn.Dropout2d(p=float(drop_prob)),
147
370
  )
148
371
  summary = n_times // pooling if summary == -1 else summary
149
372
 
150
- self.spatial = _DenseSpatialFilter(n_chans, s_growth, spat_layers, in_ch=t_filters,
151
- drop_prob=drop_prob, bottleneck=bottleneck)
373
+ self.spatial = _DenseSpatialFilter(
374
+ n_chans=n_chans,
375
+ growth=s_growth,
376
+ depth=spat_layers,
377
+ in_ch=t_filters,
378
+ drop_prob=drop_prob,
379
+ bottleneck=bottleneck,
380
+ activation=activation,
381
+ )
152
382
  self.extract_features = nn.Sequential(
153
- nn.AdaptiveAvgPool1d(int(summary)),
154
- nn.Flatten(start_dim=1)
383
+ nn.AdaptiveAvgPool1d(int(summary)), nn.Flatten(start_dim=1)
155
384
  )
156
385
 
157
386
  self._num_features = (t_filters + s_growth * spat_layers) * summary
@@ -160,123 +389,7 @@ class _TIDNetFeatures(nn.Module):
160
389
  def num_features(self):
161
390
  return self._num_features
162
391
 
163
- def forward(self, x):
392
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
164
393
  x = self.temporal(x)
165
394
  x = self.spatial(x)
166
395
  return self.extract_features(x)
167
-
168
-
169
- class TIDNet(EEGModuleMixin, nn.Module):
170
- """Thinker Invariance DenseNet model from Kostas et al 2020.
171
-
172
- See [TIDNet]_ for details.
173
-
174
- Parameters
175
- ----------
176
- s_growth : int
177
- DenseNet-style growth factor (added filters per DenseFilter)
178
- t_filters : int
179
- Number of temporal filters.
180
- drop_prob : float
181
- Dropout probability
182
- pooling : int
183
- Max temporal pooling (width and stride)
184
- temp_layers : int
185
- Number of temporal layers
186
- spat_layers : int
187
- Number of DenseFilters
188
- temp_span : float
189
- Percentage of n_times that defines the temporal filter length:
190
- temp_len = ceil(temp_span * n_times)
191
- e.g A value of 0.05 for temp_span with 1500 n_times will yield a temporal
192
- filter of length 75.
193
- bottleneck : int
194
- Bottleneck factor within Densefilter
195
- summary : int
196
- Output size of AdaptiveAvgPool1D layer. If set to -1, value will be calculated
197
- automatically (n_times // pooling).
198
- in_chans :
199
- Alias for n_chans.
200
- n_classes:
201
- Alias for n_outputs.
202
- input_window_samples :
203
- Alias for n_times.
204
-
205
- Notes
206
- -----
207
- Code adapted from: https://github.com/SPOClab-ca/ThinkerInvariance/
208
-
209
- References
210
- ----------
211
- .. [TIDNet] Kostas, D. & Rudzicz, F.
212
- Thinker invariance: enabling deep neural networks for BCI across more
213
- people.
214
- J. Neural Eng. 17, 056008 (2020).
215
- doi: 10.1088/1741-2552/abb7a7.
216
- """
217
-
218
- def __init__(self, n_chans=None, n_outputs=None, n_times=None,
219
- in_chans=None, n_classes=None, input_window_samples=None,
220
- s_growth=24, t_filters=32, drop_prob=0.4, pooling=15,
221
- temp_layers=2, spat_layers=2, temp_span=0.05,
222
- bottleneck=3, summary=-1, add_log_softmax=True):
223
- n_chans, n_outputs, n_times = deprecated_args(
224
- self,
225
- ('in_chans', 'n_chans', in_chans, n_chans),
226
- ('n_classes', 'n_outputs', n_classes, n_outputs),
227
- ('input_window_samples', 'n_times', input_window_samples, n_times),
228
- )
229
- super().__init__(
230
- n_outputs=n_outputs,
231
- n_chans=n_chans,
232
- n_times=n_times,
233
- add_log_softmax=add_log_softmax,
234
- )
235
- del n_outputs, n_chans, n_times
236
- del in_chans, n_classes, input_window_samples
237
-
238
- self.mapping = {
239
- 'classify.1.weight': 'final_layer.0.weight',
240
- 'classify.1.bias': 'final_layer.0.bias'
241
- }
242
-
243
- self.temp_len = ceil(temp_span * self.n_times)
244
-
245
- self.dscnn = _TIDNetFeatures(s_growth=s_growth, t_filters=t_filters, n_chans=self.n_chans,
246
- n_times=self.n_times,
247
- drop_prob=drop_prob, pooling=pooling, temp_layers=temp_layers,
248
- spat_layers=spat_layers, temp_span=temp_span,
249
- bottleneck=bottleneck, summary=summary)
250
-
251
- self._num_features = self.dscnn.num_features
252
-
253
- self.flatten = nn.Flatten(start_dim=1)
254
-
255
- self.final_layer = self._create_classifier(self.num_features, self.n_outputs)
256
-
257
- def _create_classifier(self, incoming, n_outputs):
258
- classifier = nn.Linear(incoming, n_outputs)
259
- init.xavier_normal_(classifier.weight)
260
- classifier.bias.data.zero_()
261
- seq_clf = nn.Sequential(
262
- classifier,
263
- nn.LogSoftmax(dim=-1) if self.add_log_softmax else nn.Identity())
264
-
265
- return seq_clf
266
-
267
- def forward(self, x):
268
- """Forward pass.
269
-
270
- Parameters
271
- ----------
272
- x: torch.Tensor
273
- Batch of EEG windows of shape (batch_size, n_channels, n_times).
274
- """
275
-
276
- x = self.dscnn(x)
277
- x = self.flatten(x)
278
- return self.final_layer(x)
279
-
280
- @property
281
- def num_features(self):
282
- return self._num_features