braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Note: this version of braindecode has been flagged as potentially problematic.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/models/eegnex.py (new file)
@@ -0,0 +1,247 @@
+# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
+#
+# License: BSD (3-clause)
+from __future__ import annotations
+
+import math
+
+import torch
+import torch.nn as nn
+from einops.layers.torch import Rearrange
+
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import Conv2dWithConstraint, LinearWithConstraint
+
+
+class EEGNeX(EEGModuleMixin, nn.Module):
+    """EEGNeX model from Chen et al. (2024) [eegnex]_.
+
+    .. figure:: https://braindecode.org/dev/_static/model/eegnex.jpg
+       :align: center
+       :alt: EEGNeX Architecture
+
+    Parameters
+    ----------
+    activation : nn.Module, optional
+        Activation function to use. Default is `nn.ELU`.
+    depth_multiplier : int, optional
+        Depth multiplier for the depthwise convolution. Default is 2.
+    filter_1 : int, optional
+        Number of filters in the first convolutional layer. Default is 8.
+    filter_2 : int, optional
+        Number of filters in the second convolutional layer. Default is 32.
+    drop_prob: float, optional
+        Dropout rate. Default is 0.5.
+    kernel_block_4 : tuple[int, int], optional
+        Kernel size for block 4. Default is (1, 16).
+    dilation_block_4 : tuple[int, int], optional
+        Dilation rate for block 4. Default is (1, 2).
+    avg_pool_block4 : tuple[int, int], optional
+        Pooling size for block 4. Default is (1, 4).
+    kernel_block_5 : tuple[int, int], optional
+        Kernel size for block 5. Default is (1, 16).
+    dilation_block_5 : tuple[int, int], optional
+        Dilation rate for block 5. Default is (1, 4).
+    avg_pool_block5 : tuple[int, int], optional
+        Pooling size for block 5. Default is (1, 8).
+
+    Notes
+    -----
+    This implementation is not guaranteed to be correct, has not been checked
+    by original authors, only reimplemented from the paper description and
+    source code in tensorflow [EEGNexCode]_.
+
+    References
+    ----------
+    .. [eegnex] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
+       Toward reliable signals decoding for electroencephalogram: A benchmark
+       study to EEGNeX. Biomedical Signal Processing and Control, 87, 105475.
+    .. [EEGNexCode] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
+       Toward reliable signals decoding for electroencephalogram: A benchmark
+       study to EEGNeX. https://github.com/chenxiachan/EEGNeX
+    """
+
+    def __init__(
+        self,
+        # Signal related parameters
+        n_chans=None,
+        n_outputs=None,
+        n_times=None,
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=None,
+        # Model parameters
+        activation: nn.Module = nn.ELU,
+        depth_multiplier: int = 2,
+        filter_1: int = 8,
+        filter_2: int = 32,
+        drop_prob: float = 0.5,
+        kernel_block_1_2: int = 64,
+        kernel_block_4: int = 16,
+        dilation_block_4: int = 2,
+        avg_pool_block4: int = 4,
+        kernel_block_5: int = 16,
+        dilation_block_5: int = 4,
+        avg_pool_block5: int = 8,
+        max_norm_conv: float = 1.0,
+        max_norm_linear: float = 0.25,
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+        self.depth_multiplier = depth_multiplier
+        self.filter_1 = filter_1
+        self.filter_2 = filter_2
+        self.filter_3 = self.filter_2 * self.depth_multiplier
+        self.drop_prob = drop_prob
+        self.activation = activation
+        self.kernel_block_1_2 = (1, kernel_block_1_2)
+        self.kernel_block_4 = (1, kernel_block_4)
+        self.dilation_block_4 = (1, dilation_block_4)
+        self.avg_pool_block4 = (1, avg_pool_block4)
+        self.kernel_block_5 = (1, kernel_block_5)
+        self.dilation_block_5 = (1, dilation_block_5)
+        self.avg_pool_block5 = (1, avg_pool_block5)
+
+        # final layers output
+        self.in_features = self._calculate_output_length()
+
+        # Following paper nomenclature
+        self.block_1 = nn.Sequential(
+            Rearrange("batch ch time -> batch 1 ch time"),
+            nn.Conv2d(
+                in_channels=1,
+                out_channels=self.filter_1,
+                kernel_size=self.kernel_block_1_2,
+                padding="same",
+                bias=False,
+            ),
+            nn.BatchNorm2d(num_features=self.filter_1),
+        )
+
+        self.block_2 = nn.Sequential(
+            nn.Conv2d(
+                in_channels=self.filter_1,
+                out_channels=self.filter_2,
+                kernel_size=self.kernel_block_1_2,
+                padding="same",
+                bias=False,
+            ),
+            nn.BatchNorm2d(num_features=self.filter_2),
+        )
+
+        self.block_3 = nn.Sequential(
+            Conv2dWithConstraint(
+                in_channels=self.filter_2,
+                out_channels=self.filter_3,
+                max_norm=max_norm_conv,
+                kernel_size=(self.n_chans, 1),
+                groups=self.filter_2,
+                bias=False,
+            ),
+            nn.BatchNorm2d(num_features=self.filter_3),
+            self.activation(),
+            nn.AvgPool2d(
+                kernel_size=self.avg_pool_block4,
+                padding=(0, 1),
+            ),
+            nn.Dropout(p=self.drop_prob),
+        )
+
+        self.block_4 = nn.Sequential(
+            nn.Conv2d(
+                in_channels=self.filter_3,
+                out_channels=self.filter_2,
+                kernel_size=self.kernel_block_4,
+                dilation=self.dilation_block_4,
+                padding="same",
+                bias=False,
+            ),
+            nn.BatchNorm2d(num_features=self.filter_2),
+        )
+
+        self.block_5 = nn.Sequential(
+            nn.Conv2d(
+                in_channels=self.filter_2,
+                out_channels=self.filter_1,
+                kernel_size=self.kernel_block_5,
+                dilation=self.dilation_block_5,
+                padding="same",
+                bias=False,
+            ),
+            nn.BatchNorm2d(num_features=self.filter_1),
+            self.activation(),
+            nn.AvgPool2d(
+                kernel_size=self.avg_pool_block5,
+                padding=(0, 1),
+            ),
+            nn.Dropout(p=self.drop_prob),
+            nn.Flatten(),
+        )
+
+        self.final_layer = LinearWithConstraint(
+            in_features=self.in_features,
+            out_features=self.n_outputs,
+            max_norm=max_norm_linear,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the EEGNeX model.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (batch_size, n_chans, n_times).
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor of shape (batch_size, n_outputs).
+        """
+        # x shape: (batch_size, n_chans, n_times)
+        x = self.block_1(x)
+        # (batch_size, n_filter, n_chans, n_times)
+        x = self.block_2(x)
+        # (batch_size, n_filter*4, n_chans, n_times)
+        x = self.block_3(x)
+        # (batch_size, 1, n_filter*8, n_times//4)
+        x = self.block_4(x)
+        # (batch_size, 1, n_filter*8, n_times//4)
+        x = self.block_5(x)
+        # (batch_size, n_filter*(n_times//32))
+        x = self.final_layer(x)
+
+        return x
+
+    def _calculate_output_length(self) -> int:
+        # Pooling kernel sizes for the time dimension
+        p4 = self.avg_pool_block4[1]
+        p5 = self.avg_pool_block5[1]
+
+        # Padding for the time dimension (assumed from padding=(0, 1))
+        pad4 = 1
+        pad5 = 1
+
+        # Stride is assumed to be equal to kernel size (p4 and p5)
+
+        # Calculate time dimension after block 3 pooling
+        # Formula: floor((L_in + 2*padding - kernel_size) / stride) + 1
+        T3 = math.floor((self.n_times + 2 * pad4 - p4) / p4) + 1
+
+        # Calculate time dimension after block 5 pooling
+        T5 = math.floor((T3 + 2 * pad5 - p5) / p5) + 1
+
+        # Calculate final flattened features (channels * 1 * time_dim)
+        # The spatial dimension is reduced to 1 after block 3's depthwise conv.
+        final_in_features = (
+            self.filter_1 * T5
+        )  # filter_1 is the number of channels before flatten
+        return final_in_features
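
For orientation, a minimal usage sketch of the newly added model follows. It assumes braindecode 1.1.0 is installed and that EEGNeX is exported from braindecode.models (the enlarged braindecode/models/__init__.py above suggests it is); the channel, class, and sample counts are illustrative values chosen for the example, not package defaults.

# Minimal sketch, not part of the package: run the new EEGNeX model on a
# dummy batch. n_chans=22, n_outputs=4, n_times=1000 are illustrative.
import torch
from braindecode.models import EEGNeX

model = EEGNeX(n_chans=22, n_outputs=4, n_times=1000)

x = torch.randn(8, 22, 1000)   # (batch_size, n_chans, n_times)
y = model(x)                   # (batch_size, n_outputs)
print(y.shape)                 # torch.Size([8, 4])

# The flattened feature count fed to final_layer follows the pooling
# arithmetic in _calculate_output_length() above (stride equals the pooling
# kernel size, with 1 sample of padding on the time axis at each stage):
#   T3 = floor((1000 + 2*1 - 4) / 4) + 1 = 250
#   T5 = floor((250  + 2*1 - 8) / 8) + 1 = 31
#   in_features = filter_1 * T5 = 8 * 31 = 248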