braindecode-1.3.0.dev177069446-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/deep4.py
@@ -0,0 +1,376 @@
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import nn
+ from torch.nn import init
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import (
+     AvgPool2dWithConv,
+     CombinedConv,
+     Ensure4d,
+     SqueezeFinalOutput,
+ )
+
+
+ class Deep4Net(EEGModuleMixin, nn.Sequential):
+     r"""Deep ConvNet model from Schirrmeister et al. (2017) [Schirrmeister2017]_.
+
+     :bdg-success:`Convolution`
+
+     .. figure:: https://onlinelibrary.wiley.com/cms/asset/fc200ccc-d8c4-45b4-8577-56ce4d15999a/hbm23730-fig-0001-m.jpg
+         :align: center
+         :alt: Deep4Net Architecture
+         :width: 600px
+
+
+     Model described in [Schirrmeister2017]_.
+
+     Parameters
+     ----------
+     final_conv_length: int | str
+         Length of the final convolution layer.
+         If set to "auto", n_times must not be None. Default: "auto".
+     n_filters_time: int
+         Number of temporal filters.
+     n_filters_spat: int
+         Number of spatial filters.
+     filter_time_length: int
+         Length of the temporal filter in layer 1.
+     pool_time_length: int
+         Length of temporal pooling filter.
+     pool_time_stride: int
+         Length of stride between temporal pooling filters.
+     n_filters_2: int
+         Number of temporal filters in layer 2.
+     filter_length_2: int
+         Length of the temporal filter in layer 2.
+     n_filters_3: int
+         Number of temporal filters in layer 3.
+     filter_length_3: int
+         Length of the temporal filter in layer 3.
+     n_filters_4: int
+         Number of temporal filters in layer 4.
+     filter_length_4: int
+         Length of the temporal filter in layer 4.
+     activation_first_conv_nonlin: nn.Module, default is nn.ELU
+         Non-linear activation function to be used after convolution in layer 1.
+     first_pool_mode: str
+         Pooling mode in layer 1. "max" or "mean".
+     first_pool_nonlin: callable
+         Non-linear activation function to be used after pooling in layer 1.
+     activation_later_conv_nonlin: nn.Module, default is nn.ELU
+         Non-linear activation function to be used after convolution in later layers.
+     later_pool_mode: str
+         Pooling mode in later layers. "max" or "mean".
+     later_pool_nonlin: callable
+         Non-linear activation function to be used after pooling in later layers.
+     drop_prob: float
+         Dropout probability.
+     split_first_layer: bool
+         Split first layer into temporal and spatial layers (True) or just use temporal (False).
+         There is no non-linearity between the split layers.
+     batch_norm: bool
+         Whether to use batch normalisation.
+     batch_norm_alpha: float
+         Momentum for BatchNorm2d.
+     stride_before_pool: bool
+         Stride before pooling.
+
+
+     References
+     ----------
+     .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
+        L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
+        & Ball, T. (2017).
+        Deep learning with convolutional neural networks for EEG decoding and
+        visualization.
+        Human Brain Mapping, Aug. 2017.
+        Online: http://dx.doi.org/10.1002/hbm.23730
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         final_conv_length="auto",
+         n_filters_time=25,
+         n_filters_spat=25,
+         filter_time_length=10,
+         pool_time_length=3,
+         pool_time_stride=3,
+         n_filters_2=50,
+         filter_length_2=10,
+         n_filters_3=100,
+         filter_length_3=10,
+         n_filters_4=200,
+         filter_length_4=10,
+         activation_first_conv_nonlin: type[nn.Module] = nn.ELU,
+         first_pool_mode="max",
+         first_pool_nonlin: type[nn.Module] = nn.Identity,
+         activation_later_conv_nonlin: type[nn.Module] = nn.ELU,
+         later_pool_mode="max",
+         later_pool_nonlin: type[nn.Module] = nn.Identity,
+         drop_prob=0.5,
+         split_first_layer=True,
+         batch_norm=True,
+         batch_norm_alpha=0.1,
+         stride_before_pool=False,
+         # Braindecode EEGModuleMixin parameters
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         if final_conv_length == "auto":
+             assert self.n_times is not None
+         self.final_conv_length = final_conv_length
+         self.n_filters_time = n_filters_time
+         self.n_filters_spat = n_filters_spat
+         self.filter_time_length = filter_time_length
+         self.pool_time_length = pool_time_length
+         self.pool_time_stride = pool_time_stride
+         self.n_filters_2 = n_filters_2
+         self.filter_length_2 = filter_length_2
+         self.n_filters_3 = n_filters_3
+         self.filter_length_3 = filter_length_3
+         self.n_filters_4 = n_filters_4
+         self.filter_length_4 = filter_length_4
+         self.first_nonlin = activation_first_conv_nonlin
+         self.first_pool_mode = first_pool_mode
+         self.first_pool_nonlin = first_pool_nonlin
+         self.later_conv_nonlin = activation_later_conv_nonlin
+         self.later_pool_mode = later_pool_mode
+         self.later_pool_nonlin = later_pool_nonlin
+         self.drop_prob = drop_prob
+         self.split_first_layer = split_first_layer
+         self.batch_norm = batch_norm
+         self.batch_norm_alpha = batch_norm_alpha
+         self.stride_before_pool = stride_before_pool
+
+         min_n_times = self._get_min_n_times()
+         if self.n_times < min_n_times:
+             scaling_factor = self.n_times / min_n_times
+             warn(
+                 f"n_times ({self.n_times}) is smaller than the minimum required "
+                 f"({min_n_times}) for the current model parameter configuration. "
+                 "Adjusting parameters to ensure compatibility. "
+                 "Reducing the kernel, pooling, and stride sizes accordingly. "
+                 "Scaling factor: {:.2f}".format(scaling_factor),
+                 UserWarning,
+             )
+             # Apply the scaling factor to all temporal kernel and pooling sizes
+             # so that the network still produces a valid output length.
+             self.filter_time_length = max(
+                 1, int(self.filter_time_length * scaling_factor)
+             )
+             self.pool_time_length = max(1, int(self.pool_time_length * scaling_factor))
+             self.pool_time_stride = max(1, int(self.pool_time_stride * scaling_factor))
+             self.filter_length_2 = max(1, int(self.filter_length_2 * scaling_factor))
+             self.filter_length_3 = max(1, int(self.filter_length_3 * scaling_factor))
+             self.filter_length_4 = max(1, int(self.filter_length_4 * scaling_factor))
+         # For load_state_dict:
+         # when standardizing all layers,
+         # map the old parameter names to the new ones here.
+         self.mapping = {
+             "conv_time.weight": "conv_time_spat.conv_time.weight",
+             "conv_spat.weight": "conv_time_spat.conv_spat.weight",
+             "conv_time.bias": "conv_time_spat.conv_time.bias",
+             "conv_spat.bias": "conv_time_spat.conv_spat.bias",
+             "conv_classifier.weight": "final_layer.conv_classifier.weight",
+             "conv_classifier.bias": "final_layer.conv_classifier.bias",
+         }
+
+         if self.stride_before_pool:
+             conv_stride = self.pool_time_stride
+             pool_stride = 1
+         else:
+             conv_stride = 1
+             pool_stride = self.pool_time_stride
+         self.add_module("ensuredims", Ensure4d())
+         pool_class_dict = dict(max=nn.MaxPool2d, mean=AvgPool2dWithConv)
+         first_pool_class = pool_class_dict[self.first_pool_mode]
+         later_pool_class = pool_class_dict[self.later_pool_mode]
+         if self.split_first_layer:
+             self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
+             self.add_module(
+                 "conv_time_spat",
+                 CombinedConv(
+                     in_chans=self.n_chans,
+                     n_filters_time=self.n_filters_time,
+                     n_filters_spat=self.n_filters_spat,
+                     filter_time_length=filter_time_length,
+                     bias_time=True,
+                     bias_spat=not self.batch_norm,
+                 ),
+             )
+             n_filters_conv = self.n_filters_spat
+         else:
+             self.add_module(
+                 "conv_time",
+                 nn.Conv2d(
+                     self.n_chans,
+                     self.n_filters_time,
+                     (self.filter_time_length, 1),
+                     stride=(conv_stride, 1),
+                     bias=not self.batch_norm,
+                 ),
+             )
+             n_filters_conv = self.n_filters_time
+         if self.batch_norm:
+             self.add_module(
+                 "bnorm",
+                 nn.BatchNorm2d(
+                     n_filters_conv,
+                     momentum=self.batch_norm_alpha,
+                     affine=True,
+                     eps=1e-5,
+                 ),
+             )
+         self.add_module("conv_nonlin", self.first_nonlin())
+         self.add_module(
+             "pool",
+             first_pool_class(
+                 kernel_size=(self.pool_time_length, 1), stride=(pool_stride, 1)
+             ),
+         )
+         self.add_module("pool_nonlin", self.first_pool_nonlin())
+
+         def add_conv_pool_block(
+             model, n_filters_before, n_filters, filter_length, block_nr
+         ):
+             suffix = "_{:d}".format(block_nr)
+             self.add_module("drop" + suffix, nn.Dropout(p=self.drop_prob))
+             self.add_module(
+                 "conv" + suffix,
+                 nn.Conv2d(
+                     n_filters_before,
+                     n_filters,
+                     (filter_length, 1),
+                     stride=(conv_stride, 1),
+                     bias=not self.batch_norm,
+                 ),
+             )
+             if self.batch_norm:
+                 self.add_module(
+                     "bnorm" + suffix,
+                     nn.BatchNorm2d(
+                         n_filters,
+                         momentum=self.batch_norm_alpha,
+                         affine=True,
+                         eps=1e-5,
+                     ),
+                 )
+             self.add_module("nonlin" + suffix, self.later_conv_nonlin())
+
+             self.add_module(
+                 "pool" + suffix,
+                 later_pool_class(
+                     kernel_size=(self.pool_time_length, 1),
+                     stride=(pool_stride, 1),
+                 ),
+             )
+             self.add_module("pool_nonlin" + suffix, self.later_pool_nonlin())
+
+         add_conv_pool_block(
+             self, n_filters_conv, self.n_filters_2, self.filter_length_2, 2
+         )
+         add_conv_pool_block(
+             self, self.n_filters_2, self.n_filters_3, self.filter_length_3, 3
+         )
+         add_conv_pool_block(
+             self, self.n_filters_3, self.n_filters_4, self.filter_length_4, 4
+         )
+
+         self.eval()
+         if self.final_conv_length == "auto":
+             self.final_conv_length = self.get_output_shape()[2]
+
+         # Incorporating classification module and subsequent ones in one final layer
+         module = nn.Sequential()
+
+         module.add_module(
+             "conv_classifier",
+             nn.Conv2d(
+                 self.n_filters_4,
+                 self.n_outputs,
+                 (self.final_conv_length, 1),
+                 bias=True,
+             ),
+         )
+
+         module.add_module("squeeze", SqueezeFinalOutput())
+
+         self.add_module("final_layer", module)
+
+         # Initialization, xavier is same as in our paper...
+         # was default from lasagne
+         init.xavier_uniform_(self.conv_time_spat.conv_time.weight, gain=1)
+         # maybe no bias in case of no split layer and batch norm
+         if self.split_first_layer or (not self.batch_norm):
+             init.constant_(self.conv_time_spat.conv_time.bias, 0)
+         if self.split_first_layer:
+             init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
+             if not self.batch_norm:
+                 init.constant_(self.conv_time_spat.conv_spat.bias, 0)
+         if self.batch_norm:
+             init.constant_(self.bnorm.weight, 1)
+             init.constant_(self.bnorm.bias, 0)
+         param_dict = dict(list(self.named_parameters()))
+         for block_nr in range(2, 5):
+             conv_weight = param_dict["conv_{:d}.weight".format(block_nr)]
+             init.xavier_uniform_(conv_weight, gain=1)
+             if not self.batch_norm:
+                 conv_bias = param_dict["conv_{:d}.bias".format(block_nr)]
+                 init.constant_(conv_bias, 0)
+             else:
+                 bnorm_weight = param_dict["bnorm_{:d}.weight".format(block_nr)]
+                 bnorm_bias = param_dict["bnorm_{:d}.bias".format(block_nr)]
+                 init.constant_(bnorm_weight, 1)
+                 init.constant_(bnorm_bias, 0)
+
+         init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
+         init.constant_(self.final_layer.conv_classifier.bias, 0)
+
+         self.train()
+
+     def _get_min_n_times(self) -> int:
+         """
+         Calculate the minimum number of time samples required for the model
+         to work with the given temporal parameters.
+         """
+         # Start with the minimum valid output length of the network (1)
+         min_len = 1
+
+         # List of conv kernel sizes and pool parameters for the 4 blocks, in reverse order
+         # Each tuple: (filter_length, pool_length, pool_stride)
+         block_params = [
+             (self.filter_length_4, self.pool_time_length, self.pool_time_stride),
+             (self.filter_length_3, self.pool_time_length, self.pool_time_stride),
+             (self.filter_length_2, self.pool_time_length, self.pool_time_stride),
+             (self.filter_time_length, self.pool_time_length, self.pool_time_stride),
+         ]
+
+         # Work backward from the last layer to the input
+         for filter_len, pool_len, pool_stride in block_params:
+             # Reverse the pooling operation
+             # L_in = stride * (L_out - 1) + kernel_size
+             min_len = pool_stride * (min_len - 1) + pool_len
+             # Reverse the convolution operation (assuming stride=1)
+             # L_in = L_out + kernel_size - 1
+             min_len = min_len + filter_len - 1
+
+         return min_len
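
As a quick check of the backward size computation in _get_min_n_times, the same recurrence can be evaluated with this file's default hyperparameters (all filter lengths 10, pool length 3, pool stride 3). The numbers below follow directly from the two formulas in the loop above and are not quoted from the package.

# Undo each of the 4 blocks in reverse: first the pooling, then the stride-1 convolution.
min_len = 1
for filter_len, pool_len, pool_stride in [(10, 3, 3)] * 4:
    min_len = pool_stride * (min_len - 1) + pool_len  # L_in = stride * (L_out - 1) + kernel_size
    min_len = min_len + filter_len - 1                # L_in = L_out + kernel_size - 1
print(min_len)  # 1 -> 12 -> 45 -> 144 -> 441: the defaults need at least 441 time samples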
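
Usage sketch (not part of the diff). The snippet below shows how the Deep4Net class defined in this file might be instantiated and applied to a batch of EEG windows; it assumes braindecode and torch are installed and that inputs follow the (batch, n_chans, n_times) convention of EEGModuleMixin. The channel count, class count and window length are illustrative values, not taken from the package.

import torch

from braindecode.models import Deep4Net

# Hypothetical setup: 22 EEG channels, 4 classes, 1000-sample windows.
model = Deep4Net(n_chans=22, n_outputs=4, n_times=1000, final_conv_length="auto")

x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([8, 4]) once SqueezeFinalOutput drops the singleton dims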