braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/models/tsinception.py (new file, +283 lines)
@@ -0,0 +1,283 @@
+ # Authors: Bruno Aristimunha <b.aristimunha>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class TSceptionV1(EEGModuleMixin, nn.Module):
+     """TSception model from Ding et al. (2020) [ding2020]_.
+
+     TSception: a deep learning framework for emotion detection using EEG.
+
+     .. figure:: https://user-images.githubusercontent.com/58539144/74716976-80415e00-526a-11ea-9433-02ab2b753f6b.PNG
+         :align: center
+         :alt: TSceptionV1 Architecture
+
+     The model consists of temporal and spatial convolutional layers
+     (Tception and Sception) designed to learn temporal and spatial features
+     from EEG data.
+
+     Parameters
+     ----------
+     number_filter_temp : int
+         Number of temporal convolutional filters.
+     number_filter_spat : int
+         Number of spatial convolutional filters.
+     hidden_size : int
+         Number of units in the hidden fully connected layer.
+     drop_prob : float
+         Dropout rate applied after the hidden layer.
+     activation : nn.Module, optional
+         Activation function class to apply. Should be a PyTorch activation
+         module like ``nn.ReLU`` or ``nn.LeakyReLU``. Default is ``nn.LeakyReLU``.
+     pool_size : int, optional
+         Pooling size for the average pooling layers. Default is 8.
+     inception_windows : tuple[float, float, float], optional
+         Window sizes (in seconds) for the inception modules.
+         Default is (0.5, 0.25, 0.125).
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors. The modifications relative to the
+     original code [code2020]_ are minimal, and the model is expected to
+     work as intended.
+
+     References
+     ----------
+     .. [ding2020] Ding, Y., Robinson, N., Zeng, Q., Chen, D., Wai, A. A. P.,
+         Lee, T. S., & Guan, C. (2020, July). TSception: a deep learning
+         framework for emotion detection using EEG. In 2020 International Joint
+         Conference on Neural Networks (IJCNN) (pp. 1-7). IEEE.
+     .. [code2020] Ding, Y., Robinson, N., Zeng, Q., Chen, D., Wai, A. A. P.,
+         Lee, T. S., & Guan, C. (2020). TSception reference implementation:
+         https://github.com/deepBrains/TSception/blob/master/Models.py
+     """
+
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         input_window_seconds=None,
+         chs_info=None,
+         n_times=None,
+         sfreq=None,
+         # Model parameters
+         number_filter_temp: int = 9,
+         number_filter_spat: int = 6,
+         hidden_size: int = 128,
+         drop_prob: float = 0.5,
+         activation: nn.Module = nn.LeakyReLU,
+         pool_size: int = 8,
+         inception_windows: tuple[float, float, float] = (0.5, 0.25, 0.125),
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.activation = activation
+         self.pool_size = pool_size
+         self.inception_windows = inception_windows
+         self.number_filter_spat = number_filter_spat
+         self.number_filter_temp = number_filter_temp
+         self.drop_prob = drop_prob
+
+         # Layers
+         self.ensuredim = Rearrange("batch nchans time -> batch 1 nchans time")
+         if self.input_window_seconds < max(self.inception_windows):
+             inception_windows = (
+                 self.input_window_seconds,
+                 self.input_window_seconds / 2,
+                 self.input_window_seconds / 4,
+             )
+             warning_msg = (
+                 "Input window is shorter than the largest inception window. "
+                 "Adjusting the inception windows to fit the input window.\n"
+                 f"Original inception windows: {self.inception_windows},\n"
+                 f"Adjusted inception windows: {inception_windows}"
+             )
+             warn(warning_msg, UserWarning)
+             self.inception_windows = inception_windows
+
+         # Define temporal convolutional layers (Tception)
+         self.temporal_blocks = nn.ModuleList()
+         for window in self.inception_windows:
+             # 1. Calculate the temporal kernel size for this block
+             kernel_size_t = int(window * self.sfreq)
+
+             # 2. Calculate the output length of the convolution
+             conv_out_len = self.n_times - kernel_size_t + 1
+
+             # 3. Ensure the pooling size is not larger than the conv output
+             #    and is at least 1.
+             dynamic_pool_size = max(1, min(self.pool_size, conv_out_len))
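+             # Illustrative arithmetic (added note, not in the original source):
+             # e.g. with sfreq=128 Hz and n_times=512 (a 4 s window), the default
+             # windows (0.5, 0.25, 0.125) give kernels of 64, 32 and 16 samples
+             # and conv outputs of 449, 481 and 497 samples, so the pool size
+             # stays at the default 8; the clamp only engages for very short inputs.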
+
+             # 4. Create the block with the dynamic pooling size
+             block = self._conv_block(
+                 in_channels=1,
+                 out_channels=self.number_filter_temp,
+                 kernel_size=(1, kernel_size_t),
+                 stride=1,
+                 pool_size=dynamic_pool_size,  # use the dynamic size
+                 activation=self.activation,
+             )
+             self.temporal_blocks.append(block)
+
+         self.batch_temporal_lay = nn.BatchNorm2d(self.number_filter_temp)
+
+         # Define spatial convolutional layers (Sception)
+         pool_size_spat = self.pool_size // 4  # 2 with the default pool_size=8
+
+         self.spatial_block_1 = self._conv_block(
+             in_channels=self.number_filter_temp,
+             out_channels=self.number_filter_spat,
+             kernel_size=(self.n_chans, 1),
+             stride=1,
+             pool_size=pool_size_spat,
+             activation=self.activation,
+         )
+
+         kernel_size_spat_2 = (max(1, self.n_chans // 2), 1)
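+         # Added note (illustrative): with kernel and stride both equal to
+         # (n_chans // 2, 1), this second Sception block convolves two
+         # non-overlapping halves of the channel axis, e.g. 28 channels ->
+         # two groups of 14.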
+
+         self.spatial_block_2 = self._conv_block(
+             in_channels=self.number_filter_temp,
+             out_channels=self.number_filter_spat,
+             kernel_size=kernel_size_spat_2,
+             stride=kernel_size_spat_2,
+             pool_size=pool_size_spat,
+             activation=self.activation,
+         )
+         self.batch_spatial_lay = nn.BatchNorm2d(self.number_filter_spat)
+
+         # Calculate the size of the features after convolution and pooling layers
+         self.feature_size = self._calculate_feature_size()
+
+         # Define the final classification layers
+         self.dense_layer = nn.Sequential(
+             nn.Linear(self.feature_size, hidden_size),
+             self.activation(),
+             nn.Dropout(self.drop_prob),
+         )
+
+         self.final_layer = nn.Linear(hidden_size, self.n_outputs)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass of the TSception model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, n_channels, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor of shape (batch_size, n_outputs).
+         """
+         # Temporal convolution
+         # shape: (batch_size, n_channels, n_times)
+         x = self.ensuredim(x)
+         # shape: (batch_size, 1, n_channels, n_times)
+
+         t_features = [layer(x) for layer in self.temporal_blocks]
+         # each element: (batch_size, number_filter_temp, n_channels, pooled_time)
+         t_out = torch.cat(t_features, dim=-1)
+         t_out = self.batch_temporal_lay(t_out)
+
+         # Spatial convolution
+         s_out1 = self.spatial_block_1(t_out)
+         s_out2 = self.spatial_block_2(t_out)
+         s_out = torch.cat((s_out1, s_out2), dim=2)
+         s_out = self.batch_spatial_lay(s_out)
+
+         # Flatten and apply final layers
+         s_out = s_out.view(s_out.size(0), -1)
+         output = self.dense_layer(s_out)
+         output = self.final_layer(output)
+         return output
+
+     def _calculate_feature_size(self) -> int:
+         """Calculates the size of the features after convolution and pooling layers.
+
+         Returns
+         -------
+         int
+             Flattened size of the features after convolution and pooling layers.
+         """
+         with torch.no_grad():
+             dummy_input = torch.ones(1, 1, self.n_chans, self.n_times)
+             t_features = [layer(dummy_input) for layer in self.temporal_blocks]
+             t_out = torch.cat(t_features, dim=-1)
+             t_out = self.batch_temporal_lay(t_out)
+
+             s_out1 = self.spatial_block_1(t_out)
+             s_out2 = self.spatial_block_2(t_out)
+             s_out = torch.cat((s_out1, s_out2), dim=2)
+             s_out = self.batch_spatial_lay(s_out)
+
+             feature_size = s_out.view(1, -1).size(1)
+         return feature_size
+
+     @staticmethod
+     def _conv_block(
+         in_channels: int,
+         out_channels: int,
+         kernel_size: tuple,
+         stride: tuple[int, int] | int,
+         pool_size: int,
+         activation: nn.Module,
+     ) -> nn.Sequential:
+         """Creates a convolutional block with Conv2d, activation, and AvgPool2d layers.
+
+         Parameters
+         ----------
+         in_channels : int
+             Number of input channels.
+         out_channels : int
+             Number of output channels.
+         kernel_size : tuple
+             Size of the convolutional kernel.
+         stride : tuple[int, int] | int
+             Stride of the convolution.
+         pool_size : int
+             Size of the pooling kernel, applied along the time axis.
+         activation : nn.Module
+             Activation function class.
+
+         Returns
+         -------
+         nn.Sequential
+             A sequential container of the convolutional block.
+         """
+         return nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 padding=0,
+             ),
+             activation(),
+             nn.AvgPool2d(kernel_size=(1, pool_size), stride=(1, pool_size)),
+         )
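
Below is a minimal usage sketch of the new model (illustrative only, not part of the release). The import path is an assumption based on the file list above (braindecode/models/tsinception.py and the updated braindecode/models/__init__.py); if TSceptionV1 is not re-exported from braindecode.models, import it from braindecode.models.tsinception instead. All dimensions are arbitrary example values.

import torch

from braindecode.models import TSceptionV1  # assumed export path

# Arbitrary example dimensions: 28 EEG channels, 4 s at 128 Hz, 2 outputs.
model = TSceptionV1(n_chans=28, n_outputs=2, sfreq=128, n_times=512)
model.eval()

x = torch.randn(8, 28, 512)  # (batch_size, n_chans, n_times)
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([8, 2])

With these dimensions the three temporal branches produce pooled time lengths of 56, 60 and 62 (concatenated to 178), the Sception blocks pool this to 89 with channel-axis heights 1 and 2, and the flattened feature size fed to the hidden layer is 6 * 3 * 89 = 1602.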