braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,441 @@
1
+ """IFNet Neural Network.
2
+
3
+ Authors: Jiaheng Wang
4
+ Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
5
+ License: MIT (https://github.com/Jiaheng-Wang/IFNet/blob/main/LICENSE)
6
+
7
+ J. Wang, L. Yao and Y. Wang, "IFNet: An Interactive Frequency Convolutional
8
+ Neural Network for Enhancing Motor Imagery Decoding from EEG," in IEEE
9
+ Transactions on Neural Systems and Rehabilitation Engineering,
10
+ doi: 10.1109/TNSRE.2023.3257319.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from typing import Optional, Sequence
16
+
17
+ import torch
18
+ from einops.layers.torch import Rearrange
19
+ from mne.utils import warn
20
+ from torch import nn
21
+ from torch.nn.init import trunc_normal_
22
+
23
+ from braindecode.models.base import EEGModuleMixin
24
+ from braindecode.modules import (
25
+ FilterBankLayer,
26
+ LinearWithConstraint,
27
+ LogPowerLayer,
28
+ )
29
+
30
+
31
class IFNet(EEGModuleMixin, nn.Module):
    """IFNetV2 from Wang J et al (2023) [ifnet]_.

    .. figure:: https://raw.githubusercontent.com/Jiaheng-Wang/IFNet/main/IFNet.png
       :align: center
       :alt: IFNetV2 Architecture

    Overview of the Interactive Frequency Convolutional Neural Network architecture.

    IFNetV2 is designed to effectively capture spectro-spatial-temporal
    features for motor imagery decoding from EEG data. The model consists of
    three stages: Spectro-Spatial Feature Representation, Cross-Frequency
    Interactions, and Classification.

    - **Spectro-Spatial Feature Representation**: The raw EEG signals are
      filtered by a filter bank into frequency bands (by default low, 4-16 Hz,
      and high, 16-40 Hz), covering the most relevant motor imagery bands.
      Spectro-spatial features are then extracted through 1D point-wise
      spatial convolution followed by temporal convolution.

    - **Cross-Frequency Interactions**: The extracted spectro-spatial
      features from each frequency band are combined through an element-wise
      summation followed by a non-linear activation.

    - **Classification**: The aggregated features are split into
      ``stride_factor`` temporal segments, reduced with a log-power pooling
      layer, and passed through a max-norm constrained fully connected layer.
      The model returns unnormalised class scores (no softmax is applied).

    Notes
    -----
    This implementation is not guaranteed to be correct, has not been checked
    by original authors, only reimplemented from the paper description and
    Torch source code [ifnetv2code]_. Version 2 is present only in the repository,
    and the main difference is one pooling layer, described in TABLE VII
    of the paper: https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10070810


    Parameters
    ----------
    bands : list[tuple[float, float]] or int or None, default=[(4.0, 16.0), (16, 40)]
        Frequency bands for filtering, forwarded to the filter bank layer as
        ``band_filters``.
    n_filters_spat : int, default=64
        Number of spatial filters (feature maps) produced per frequency band.
    kernel_sizes : tuple of int, default=(63, 31)
        Kernel sizes of the per-band temporal convolutions.
    stride_factor : int, default=8
        Number of temporal segments the signal is split into before the
        log-power pooling. The final layer has
        ``n_filters_spat * stride_factor`` input features.
    drop_prob : float, default=0.5
        Dropout probability.
    linear_max_norm : float, default=0.5
        Max-norm constraint of the final linear layer.
    activation : type[nn.Module], default=nn.GELU
        Activation class applied after the inter-frequency summation.
    verbose : bool, default=False
        Verbose to control the filtering layer.
    filter_parameters : dict or None, default=None
        Additional keyword arguments for the filter bank layer.

    References
    ----------
    .. [ifnet] Wang, J., Yao, L., & Wang, Y. (2023). IFNet: An interactive
       frequency convolutional neural network for enhancing motor imagery
       decoding from EEG. IEEE Transactions on Neural Systems and
       Rehabilitation Engineering, 31, 1900-1911.
    .. [ifnetv2code] Wang, J., Yao, L., & Wang, Y. (2023). IFNet: An interactive
       frequency convolutional neural network for enhancing motor imagery
       decoding from EEG.
       https://github.com/Jiaheng-Wang/IFNet
    """

    def __init__(
        self,
        # Braindecode parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model-specific parameters
        bands: list[tuple[float, float]] | int | None = [(4.0, 16.0), (16, 40)],
        n_filters_spat: int = 64,
        kernel_sizes: tuple[int, int] = (63, 31),
        stride_factor: int = 8,
        drop_prob: float = 0.5,
        linear_max_norm: float = 0.5,
        activation: type[nn.Module] = nn.GELU,
        verbose: bool = False,
        filter_parameters: Optional[dict] = None,
    ):
        super().__init__(
            n_chans=n_chans,
            n_outputs=n_outputs,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # The default ``bands`` is a list shared by every call; copy it so the
        # instance and the filter bank never alias (and cannot mutate) it.
        if isinstance(bands, list):
            bands = list(bands)

        self.bands = bands
        self.n_filters_spat = n_filters_spat
        self.stride_factor = stride_factor
        self.kernel_sizes = kernel_sizes
        self.verbose = verbose
        self.drop_prob = drop_prob
        self.activation = activation
        self.linear_max_norm = linear_max_norm
        self.filter_parameters = filter_parameters or {}

        # Layers
        # Following paper nomenclature
        self.spectral_filtering = FilterBankLayer(
            n_chans=self.n_chans,
            sfreq=self.sfreq,
            band_filters=self.bands,
            verbose=verbose,
            **self.filter_parameters,
        )
        # As we have an internal process to create the bands,
        # we get the values from the filterbank
        self.n_bands = self.spectral_filtering.n_bands

        # My interpretation from the TABLE VII IFNet Architecture from the
        # paper: bands are stacked along the channel axis.
        self.ensuredim = Rearrange(
            "batch nbands chans time -> batch (nbands chans) time"
        )

        # SpatioTemporal Feature Block
        self.feature_block = _SpatioTemporalFeatureBlock(
            in_channels=self.n_chans * self.n_bands,
            out_channels=self.n_filters_spat,
            kernel_sizes=self.kernel_sizes,
            stride_factor=self.stride_factor,
            n_bands=self.n_bands,
            drop_prob=self.drop_prob,
            activation=self.activation,
            n_times=self.n_times,
        )

        # Final classification layer
        self.final_layer = LinearWithConstraint(
            in_features=self.n_filters_spat * self.stride_factor,
            out_features=self.n_outputs,
            max_norm=self.linear_max_norm,
        )

        self.flatten = Rearrange("batch ... -> batch (...)")

        # Initialize parameters. ``apply`` visits every submodule; calling the
        # initialiser on the model object alone would match none of its
        # ``isinstance`` checks and initialise nothing. The filter bank is
        # deliberately not re-initialised.
        self.feature_block.apply(self._initialize_weights)
        self.final_layer.apply(self._initialize_weights)

    @staticmethod
    def _initialize_weights(m):
        """Initializes weights of a single module (use with ``Module.apply``).

        Linear and Conv1d/Conv2d weights get a truncated normal with
        ``std=0.01``; LayerNorm/BatchNorm weights are set to 1. All biases
        of these module types are set to 0. Other modules are left untouched.

        Parameters
        ----------
        m : nn.Module
            Module to initialize.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d)):
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of IFNet.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor with shape (batch_size, n_outputs).
        """
        # Pass through the spectral filtering layer
        x = self.spectral_filtering(x)
        # x is now of shape (batch_size, n_bands, n_chans, n_times)
        x = self.ensuredim(x)
        # x is now of shape (batch_size, n_bands * n_chans, n_times)

        # Pass through the feature block
        x = self.feature_block(x)

        # Flatten and pass through the final layer
        x = self.flatten(x)

        x = self.final_layer(x)

        return x
236
+
237
+
238
class _InterFrequencyModule(nn.Module):
    """Fuse per-band feature maps by summation followed by a non-linearity."""

    def __init__(self, activation: type[nn.Module] = nn.GELU):
        """
        Parameters
        ----------
        activation : type[nn.Module], default=nn.GELU
            Activation class. It is instantiated without arguments and applied
            to the summed feature maps.
        """
        super().__init__()
        self.activation = activation()

    def forward(self, x_list: list[torch.Tensor]) -> torch.Tensor:
        """Sum the tensors element-wise, then apply the activation.

        Parameters
        ----------
        x_list : list of torch.Tensor
            Tensors of identical shape (one per frequency band).

        Returns
        -------
        torch.Tensor
            Activated element-wise sum, with the same shape as each input.
        """
        stacked = torch.stack(x_list, dim=0)
        fused = torch.sum(stacked, dim=0)
        return self.activation(fused)
270
+
271
+
272
class _SpatioTemporalFeatureBlock(nn.Module):
    """SpatioTemporal Feature Block consisting of spatial and temporal convolutions.

    ``forward`` applies, in order:
    - a grouped point-wise spatial convolution and batch norm;
    - one depthwise temporal convolution and batch norm per band;
    - summation across bands followed by the activation;
    - right zero-padding of the time axis to a multiple of ``stride_factor``;
    - a split of the time axis into ``stride_factor`` segments;
    - ``LogPowerLayer`` and dropout.
    """

    def __init__(
        self,
        n_times: int,
        in_channels: int,
        out_channels: int = 64,
        kernel_sizes: Sequence[int] = (63, 31),
        stride_factor: int = 8,
        n_bands: int = 2,
        drop_prob: float = 0.5,
        activation: type[nn.Module] = nn.GELU,
        dim: int = 3,
    ):
        """
        Parameters
        ----------
        n_times : int
            Number of time samples expected at construction. Used to warn
            about and precompute padding; ``forward`` pads according to the
            actual length of its input.
        in_channels : int
            Number of input channels.
        out_channels : int, default=64
            Number of output channels (spatial filters) per band.
        kernel_sizes : sequence of int, default=(63, 31)
            Kernel sizes for the per-band temporal convolutions.
        stride_factor : int, default=8
            Number of temporal segments used before the log-power layer.
        n_bands : int, default=2
            Number of frequency bands or groups.
        drop_prob : float, default=0.5
            Dropout probability.
        activation : type[nn.Module], default=nn.GELU
            Activation class applied after the InterFrequency summation.
        dim : int, default=3
            Dimension along which the LogPowerLayer is applied.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_bands = n_bands
        self.stride_factor = stride_factor
        self.drop_prob = drop_prob
        self.activation = activation
        self.dim = dim
        self.n_times = n_times
        self.kernel_sizes = kernel_sizes

        if self.n_bands != len(self.kernel_sizes):
            warn(
                f"Got {self.n_bands} bands, different from {len(self.kernel_sizes)} amount of "
                "kernels to build the temporal convolution, we will apply "
                "min(n_bands, len(kernel_sizes)) to apply the convolution.",
                UserWarning,
            )
            if self.n_bands > len(self.kernel_sizes):
                # NOTE(review): in_channels still covers every input band, and
                # the spatial convolution below uses groups=self.n_bands. In
                # this case in_channels must be divisible by the reduced
                # n_bands, and channel groups no longer map one-to-one to
                # bands. Confirm this is intended.
                self.n_bands = len(self.kernel_sizes)
                warn(
                    f"Reducing number of bands to {len(self.kernel_sizes)} to match the number of kernels.",
                    UserWarning,
                )
            elif self.n_bands < len(self.kernel_sizes):
                self.kernel_sizes = self.kernel_sizes[: self.n_bands]
                warn(
                    f"Reducing number of kernels to {self.n_bands} to match the number of bands.",
                    UserWarning,
                )

        if self.n_times % self.stride_factor != 0:
            warn(
                f"Time dimension ({self.n_times}) is not divisible by"
                f" stride_factor ({self.stride_factor}). Input will be padded.",
                UserWarning,
            )

        out_channels_spatial = self.out_channels * self.n_bands

        # Spatial convolution: point-wise (kernel_size=1), one group per band.
        self.spatial_conv = nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=out_channels_spatial,
            kernel_size=1,
            groups=self.n_bands,
            bias=False,
        )
        self.spatial_bn = nn.BatchNorm1d(out_channels_spatial)

        # (batch, n_bands * out_channels, time) -> (batch, n_bands, out_channels, time)
        self.unpack_bands = nn.Unflatten(
            dim=1, unflattened_size=(self.n_bands, self.out_channels)
        )

        # Temporal depthwise convolutions, one per band.
        self.temporal_convs = nn.ModuleList()
        for kernel_size in self.kernel_sizes:
            self.temporal_convs.append(
                nn.Sequential(
                    nn.Conv1d(
                        in_channels=self.out_channels,
                        out_channels=self.out_channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        groups=self.out_channels,
                        bias=False,
                    ),
                    nn.BatchNorm1d(self.out_channels),
                )
            )
        # Inter-frequency module
        self.inter_frequency = _InterFrequencyModule(activation=self.activation)

        # Padding expected for an input of exactly ``n_times`` samples (odd
        # kernel sizes). These attributes are kept for compatibility.
        # ``forward`` recomputes the padding from the actual tensor length.
        if self.n_times % self.stride_factor != 0:
            self.padding_size = stride_factor - (self.n_times % stride_factor)
            self.n_times_padded = self.n_times + self.padding_size
            self.padding_layer = nn.ConstantPad1d((0, self.padding_size), 0.0)
        else:
            self.padding_layer = nn.Identity()
            self.n_times_padded = self.n_times

        # Log-Power layer
        self.log_power = LogPowerLayer(dim=self.dim)  # type: ignore
        # Dropout
        self.dropout_layer = nn.Dropout(self.drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_channels, n_times).

        Returns
        -------
        torch.Tensor
            Output of the log-power layer (after dropout), computed on a tensor
            of shape (batch_size, out_channels, stride_factor, segment_len).
        """
        batch_size = x.shape[0]

        # Spatial convolution
        x = self.spatial_conv(x)
        x = self.spatial_bn(x)

        # Split the output by bands for each frequency
        x_split = self.unpack_bands(x)

        x_t = [conv(x_split[:, idx]) for idx, conv in enumerate(self.temporal_convs)]

        # Inter-frequency interaction
        x = self.inter_frequency(x_t)

        # Right zero-pad the time axis to a multiple of stride_factor.
        # The padding is based on the actual length because even kernel sizes
        # make the temporal convolutions return one extra sample, and inputs
        # may differ from the ``n_times`` given at construction.
        n_pad = (-x.shape[-1]) % self.stride_factor
        if n_pad:
            x = nn.functional.pad(x, (0, n_pad), mode="constant", value=0.0)

        # Split time into ``stride_factor`` contiguous segments for log-power.
        x = x.reshape(batch_size, self.out_channels, self.stride_factor, -1)

        # Log-Power layer
        x = self.log_power(x)

        # Dropout
        x = self.dropout_layer(x)

        return x