braindecode 1.2.0.dev182094932__py3-none-any.whl → 1.3.0.dev173691341__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (34)
  1. braindecode/datasets/experimental.py +218 -0
  2. braindecode/models/__init__.py +6 -8
  3. braindecode/models/atcnet.py +156 -16
  4. braindecode/models/attentionbasenet.py +151 -26
  5. braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
  6. braindecode/models/ctnet.py +1 -1
  7. braindecode/models/deep4.py +6 -2
  8. braindecode/models/deepsleepnet.py +118 -5
  9. braindecode/models/eegconformer.py +114 -15
  10. braindecode/models/eeginception_erp.py +76 -7
  11. braindecode/models/eeginception_mi.py +2 -0
  12. braindecode/models/eegnet.py +27 -190
  13. braindecode/models/eegnex.py +113 -6
  14. braindecode/models/eegsimpleconv.py +2 -0
  15. braindecode/models/eegtcnet.py +1 -1
  16. braindecode/models/sccnet.py +81 -8
  17. braindecode/models/shallow_fbcsp.py +2 -0
  18. braindecode/models/sleep_stager_blanco_2020.py +2 -0
  19. braindecode/models/sleep_stager_chambon_2018.py +2 -0
  20. braindecode/models/sparcnet.py +2 -0
  21. braindecode/models/summary.csv +39 -41
  22. braindecode/models/tidnet.py +2 -0
  23. braindecode/models/tsinception.py +15 -3
  24. braindecode/models/usleep.py +103 -9
  25. braindecode/models/util.py +5 -5
  26. braindecode/preprocessing/preprocess.py +20 -26
  27. braindecode/version.py +1 -1
  28. {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev173691341.dist-info}/METADATA +7 -2
  29. {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev173691341.dist-info}/RECORD +33 -33
  30. braindecode/models/eegresnet.py +0 -362
  31. {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev173691341.dist-info}/WHEEL +0 -0
  32. {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev173691341.dist-info}/licenses/LICENSE.txt +0 -0
  33. {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev173691341.dist-info}/licenses/NOTICE.txt +0 -0
  34. {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev173691341.dist-info}/top_level.txt +0 -0
braindecode/models/eegresnet.py +0 -362
@@ -1,362 +0,0 @@
1
- # Authors: Robin Tibor Schirrmeister <robintibor@gmail.com>
2
- # Tonio Ball
3
- #
4
- # License: BSD-3
5
-
6
- import numpy as np
7
- import torch
8
- from einops.layers.torch import Rearrange
9
- from torch import nn
10
- from torch.nn import init
11
-
12
- from braindecode.models.base import EEGModuleMixin
13
- from braindecode.modules import (
14
- AvgPool2dWithConv,
15
- Ensure4d,
16
- SqueezeFinalOutput,
17
- )
18
-
19
-
20
- class EEGResNet(EEGModuleMixin, nn.Sequential):
21
- """EEGResNet from Schirrmeister et al. 2017 [Schirrmeister2017]_.
22
-
23
- .. figure:: https://onlinelibrary.wiley.com/cms/asset/bed1b768-809f-4bc6-b942-b36970d81271/hbm23730-fig-0003-m.jpg
24
- :align: center
25
- :alt: EEGResNet Architecture
26
-
27
- Model described in [Schirrmeister2017]_.
28
-
29
- Parameters
30
- ----------
31
- in_chans :
32
- Alias for ``n_chans``.
33
- n_classes :
34
- Alias for ``n_outputs``.
35
- input_window_samples :
36
- Alias for ``n_times``.
37
- activation: nn.Module, default=nn.ELU
38
- Activation function class to apply. Should be a PyTorch activation
39
- module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
40
-
41
- References
42
- ----------
43
- .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
44
- L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
45
- & Ball, T. (2017). Deep learning with convolutional neural networks for ,
46
- EEG decoding and visualization. Human Brain Mapping, Aug. 2017.
47
- Online: http://dx.doi.org/10.1002/hbm.23730
48
- """
49
-
50
- def __init__(
51
- self,
52
- n_chans=None,
53
- n_outputs=None,
54
- n_times=None,
55
- final_pool_length="auto",
56
- n_first_filters=20,
57
- n_layers_per_block=2,
58
- first_filter_length=3,
59
- activation=nn.ELU,
60
- split_first_layer=True,
61
- batch_norm_alpha=0.1,
62
- batch_norm_epsilon=1e-4,
63
- conv_weight_init_fn=lambda w: init.kaiming_normal_(w, a=0),
64
- chs_info=None,
65
- input_window_seconds=None,
66
- sfreq=250,
67
- ):
68
- super().__init__(
69
- n_outputs=n_outputs,
70
- n_chans=n_chans,
71
- chs_info=chs_info,
72
- n_times=n_times,
73
- input_window_seconds=input_window_seconds,
74
- sfreq=sfreq,
75
- )
76
- del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
77
-
78
- if final_pool_length == "auto":
79
- assert self.n_times is not None
80
- assert first_filter_length % 2 == 1
81
- self.final_pool_length = final_pool_length
82
- self.n_first_filters = n_first_filters
83
- self.n_layers_per_block = n_layers_per_block
84
- self.first_filter_length = first_filter_length
85
- self.nonlinearity = activation
86
- self.split_first_layer = split_first_layer
87
- self.batch_norm_alpha = batch_norm_alpha
88
- self.batch_norm_epsilon = batch_norm_epsilon
89
- self.conv_weight_init_fn = conv_weight_init_fn
90
-
91
- self.mapping = {
92
- "conv_classifier.weight": "final_layer.conv_classifier.weight",
93
- "conv_classifier.bias": "final_layer.conv_classifier.bias",
94
- }
95
-
96
- self.add_module("ensuredims", Ensure4d())
97
- if self.split_first_layer:
98
- self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
99
- self.add_module(
100
- "conv_time",
101
- nn.Conv2d(
102
- 1,
103
- self.n_first_filters,
104
- (self.first_filter_length, 1),
105
- stride=1,
106
- padding=(self.first_filter_length // 2, 0),
107
- ),
108
- )
109
- self.add_module(
110
- "conv_spat",
111
- nn.Conv2d(
112
- self.n_first_filters,
113
- self.n_first_filters,
114
- (1, self.n_chans),
115
- stride=(1, 1),
116
- bias=False,
117
- ),
118
- )
119
- else:
120
- self.add_module(
121
- "conv_time",
122
- nn.Conv2d(
123
- self.n_chans,
124
- self.n_first_filters,
125
- (self.first_filter_length, 1),
126
- stride=(1, 1),
127
- padding=(self.first_filter_length // 2, 0),
128
- bias=False,
129
- ),
130
- )
131
- n_filters_conv = self.n_first_filters
132
- self.add_module(
133
- "bnorm",
134
- nn.BatchNorm2d(
135
- n_filters_conv, momentum=self.batch_norm_alpha, affine=True, eps=1e-5
136
- ),
137
- )
138
- self.add_module("conv_nonlin", self.nonlinearity())
139
- cur_dilation = np.array([1, 1])
140
- n_cur_filters = n_filters_conv
141
- i_block = 1
142
- for i_layer in range(self.n_layers_per_block):
143
- self.add_module(
144
- "res_{:d}_{:d}".format(i_block, i_layer),
145
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
146
- )
147
- i_block += 1
148
- cur_dilation[0] *= 2
149
- n_out_filters = int(2 * n_cur_filters)
150
- self.add_module(
151
- "res_{:d}_{:d}".format(i_block, 0),
152
- _ResidualBlock(
153
- n_cur_filters,
154
- n_out_filters,
155
- dilation=cur_dilation,
156
- ),
157
- )
158
- n_cur_filters = n_out_filters
159
- for i_layer in range(1, self.n_layers_per_block):
160
- self.add_module(
161
- "res_{:d}_{:d}".format(i_block, i_layer),
162
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
163
- )
164
-
165
- i_block += 1
166
- cur_dilation[0] *= 2
167
- n_out_filters = int(1.5 * n_cur_filters)
168
- self.add_module(
169
- "res_{:d}_{:d}".format(i_block, 0),
170
- _ResidualBlock(
171
- n_cur_filters,
172
- n_out_filters,
173
- dilation=cur_dilation,
174
- ),
175
- )
176
- n_cur_filters = n_out_filters
177
- for i_layer in range(1, self.n_layers_per_block):
178
- self.add_module(
179
- "res_{:d}_{:d}".format(i_block, i_layer),
180
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
181
- )
182
-
183
- i_block += 1
184
- cur_dilation[0] *= 2
185
- self.add_module(
186
- "res_{:d}_{:d}".format(i_block, 0),
187
- _ResidualBlock(
188
- n_cur_filters,
189
- n_cur_filters,
190
- dilation=cur_dilation,
191
- ),
192
- )
193
- for i_layer in range(1, self.n_layers_per_block):
194
- self.add_module(
195
- "res_{:d}_{:d}".format(i_block, i_layer),
196
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
197
- )
198
-
199
- i_block += 1
200
- cur_dilation[0] *= 2
201
- self.add_module(
202
- "res_{:d}_{:d}".format(i_block, 0),
203
- _ResidualBlock(
204
- n_cur_filters,
205
- n_cur_filters,
206
- dilation=cur_dilation,
207
- ),
208
- )
209
- for i_layer in range(1, self.n_layers_per_block):
210
- self.add_module(
211
- "res_{:d}_{:d}".format(i_block, i_layer),
212
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
213
- )
214
-
215
- i_block += 1
216
- cur_dilation[0] *= 2
217
- self.add_module(
218
- "res_{:d}_{:d}".format(i_block, 0),
219
- _ResidualBlock(
220
- n_cur_filters,
221
- n_cur_filters,
222
- dilation=cur_dilation,
223
- ),
224
- )
225
- for i_layer in range(1, self.n_layers_per_block):
226
- self.add_module(
227
- "res_{:d}_{:d}".format(i_block, i_layer),
228
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
229
- )
230
- i_block += 1
231
- cur_dilation[0] *= 2
232
- self.add_module(
233
- "res_{:d}_{:d}".format(i_block, 0),
234
- _ResidualBlock(
235
- n_cur_filters,
236
- n_cur_filters,
237
- dilation=cur_dilation,
238
- ),
239
- )
240
- for i_layer in range(1, self.n_layers_per_block):
241
- self.add_module(
242
- "res_{:d}_{:d}".format(i_block, i_layer),
243
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
244
- )
245
-
246
- self.eval()
247
- if self.final_pool_length == "auto":
248
- self.add_module("mean_pool", nn.AdaptiveAvgPool2d((1, 1)))
249
- else:
250
- pool_dilation = int(cur_dilation[0]), int(cur_dilation[1])
251
- self.add_module(
252
- "mean_pool",
253
- AvgPool2dWithConv(
254
- (self.final_pool_length, 1), (1, 1), dilation=pool_dilation
255
- ),
256
- )
257
-
258
- # Incorporating classification module and subsequent ones in one final layer
259
- module = nn.Sequential()
260
-
261
- module.add_module(
262
- "conv_classifier",
263
- nn.Conv2d(
264
- n_cur_filters,
265
- self.n_outputs,
266
- (1, 1),
267
- bias=True,
268
- ),
269
- )
270
-
271
- module.add_module("squeeze", SqueezeFinalOutput())
272
-
273
- self.add_module("final_layer", module)
274
-
275
- # Initialize all weights
276
- self.apply(lambda module: self._weights_init(module, self.conv_weight_init_fn))
277
-
278
- # Start in train mode
279
- self.train()
280
-
281
- @staticmethod
282
- def _weights_init(module, conv_weight_init_fn):
283
- """
284
- initialize weights
285
- """
286
- classname = module.__class__.__name__
287
- if "Conv" in classname and classname != "AvgPool2dWithConv":
288
- conv_weight_init_fn(module.weight)
289
- if module.bias is not None:
290
- init.constant_(module.bias, 0)
291
- elif "BatchNorm" in classname:
292
- init.constant_(module.weight, 1)
293
- init.constant_(module.bias, 0)
294
-
295
-
296
- class _ResidualBlock(nn.Module):
297
- """
298
- create a residual learning building block with two stacked 3x3 convlayers as in paper
299
- """
300
-
301
- def __init__(
302
- self,
303
- in_filters,
304
- out_num_filters,
305
- dilation,
306
- filter_time_length=3,
307
- nonlinearity: nn.Module = nn.ELU,
308
- batch_norm_alpha=0.1,
309
- batch_norm_epsilon=1e-4,
310
- ):
311
- super(_ResidualBlock, self).__init__()
312
- time_padding = int((filter_time_length - 1) * dilation[0])
313
- assert time_padding % 2 == 0
314
- time_padding = int(time_padding // 2)
315
- dilation = (int(dilation[0]), int(dilation[1]))
316
- assert (out_num_filters - in_filters) % 2 == 0, (
317
- "Need even number of extra channels in order to be able to pad correctly"
318
- )
319
- self.n_pad_chans = out_num_filters - in_filters
320
-
321
- self.conv_1 = nn.Conv2d(
322
- in_filters,
323
- out_num_filters,
324
- (filter_time_length, 1),
325
- stride=(1, 1),
326
- dilation=dilation,
327
- padding=(time_padding, 0),
328
- )
329
- self.bn1 = nn.BatchNorm2d(
330
- out_num_filters,
331
- momentum=batch_norm_alpha,
332
- affine=True,
333
- eps=batch_norm_epsilon,
334
- )
335
- self.conv_2 = nn.Conv2d(
336
- out_num_filters,
337
- out_num_filters,
338
- (filter_time_length, 1),
339
- stride=(1, 1),
340
- dilation=dilation,
341
- padding=(time_padding, 0),
342
- )
343
- self.bn2 = nn.BatchNorm2d(
344
- out_num_filters,
345
- momentum=batch_norm_alpha,
346
- affine=True,
347
- eps=batch_norm_epsilon,
348
- )
349
- # also see https://mail.google.com/mail/u/0/#search/ilya+joos/1576137dd34c3127
350
- # for resnet options as ilya used them
351
- self.nonlinearity = nonlinearity()
352
-
353
- def forward(self, x):
354
- stack_1 = self.nonlinearity(self.bn1(self.conv_1(x)))
355
- stack_2 = self.bn2(self.conv_2(stack_1)) # next nonlin after sum
356
- if self.n_pad_chans != 0:
357
- zeros_for_padding = x.new_zeros(
358
- (x.shape[0], self.n_pad_chans // 2, x.shape[2], x.shape[3])
359
- )
360
- x = torch.cat((zeros_for_padding, x, zeros_for_padding), dim=1)
361
- out = self.nonlinearity(x + stack_2)
362
- return out