braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,258 @@
1
+ # Authors: Bruno Aristimunha <b.aristimunha>
2
+ #
3
+ # License: BSD (3-clause)
4
+
5
+ from __future__ import annotations
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops.layers.torch import Rearrange
10
+
11
+ from braindecode.models.base import EEGModuleMixin
12
+
13
+
14
class TSceptionV1(EEGModuleMixin, nn.Module):
    """TSception model from Ding et al. (2020) from [ding2020]_.

    TSception: A deep learning framework for emotion detection using EEG.

    .. figure:: https://user-images.githubusercontent.com/58539144/74716976-80415e00-526a-11ea-9433-02ab2b753f6b.PNG
       :align: center
       :alt: TSceptionV1 Architecture

    The model consists of temporal and spatial convolutional layers
    (Tception and Sception) designed to learn temporal and spatial features
    from EEG data.

    Parameters
    ----------
    number_filter_temp : int
        Number of temporal convolutional filters.
    number_filter_spat : int
        Number of spatial convolutional filters.
    hidden_size : int
        Number of units in the hidden fully connected layer.
    drop_prob : float
        Dropout rate applied after the hidden layer.
    activation : type[nn.Module], optional
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.LeakyReLU``; a new instance is
        created for every block that uses it. Default is ``nn.LeakyReLU``.
    pool_size : int, optional
        Pooling size for the average pooling layers of the temporal blocks.
        The spatial blocks pool with ``pool_size // 4`` (clamped to at least
        1). Default is 8.
    inception_windows : tuple of float, optional
        Window sizes (in seconds) for the temporal inception branches. Each
        window is converted to a kernel length of ``int(window * sfreq)``
        samples (clamped to at least 1). Default is (0.5, 0.25, 0.125).

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors. The modifications with respect to the
    original code [code2020]_ are minimal, and the model is expected to work
    as intended.

    References
    ----------
    .. [ding2020] Ding, Y., Robinson, N., Zeng, Q., Chen, D., Wai, A. A. P.,
       Lee, T. S., & Guan, C. (2020, July). Tsception: a deep learning framework
       for emotion detection using EEG. In 2020 international joint conference
       on neural networks (IJCNN) (pp. 1-7). IEEE.
    .. [code2020] Ding, Y., Robinson, N., Zeng, Q., Chen, D., Wai, A. A. P.,
       Lee, T. S., & Guan, C. (2020, July). Tsception: a deep learning framework
       for emotion detection using EEG.
       https://github.com/deepBrains/TSception/blob/master/Models.py
    """

    def __init__(
        self,
        # Braindecode parameters
        n_chans=None,
        n_outputs=None,
        input_window_seconds=None,
        chs_info=None,
        n_times=None,
        sfreq=None,
        # Model parameters
        number_filter_temp: int = 9,
        number_filter_spat: int = 6,
        hidden_size: int = 128,
        drop_prob: float = 0.5,
        activation: type[nn.Module] = nn.LeakyReLU,
        pool_size: int = 8,
        inception_windows: tuple[float, float, float] = (0.5, 0.25, 0.125),
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.activation = activation
        self.pool_size = pool_size
        self.inception_windows = inception_windows
        self.number_filter_spat = number_filter_spat
        self.number_filter_temp = number_filter_temp
        self.drop_prob = drop_prob

        ### Layers
        # NOTE: submodules are registered in the same order and under the same
        # names as before so that state_dict keys stay compatible.
        self.ensuredim = Rearrange("batch nchans time -> batch 1 nchans time")

        # Tception: one temporal conv branch per inception window. The window
        # (seconds) is converted to samples; clamp to >= 1 so that short
        # windows or low sampling rates never yield a zero-width kernel.
        self.temporal_blocks = nn.ModuleList(
            [
                self._conv_block(
                    in_channels=1,
                    out_channels=self.number_filter_temp,
                    kernel_size=(1, max(1, int(window * self.sfreq))),
                    stride=1,
                    pool_size=self.pool_size,
                    activation=self.activation,
                )
                for window in self.inception_windows
            ]
        )
        self.batch_temporal_lay = nn.BatchNorm2d(self.number_filter_temp)

        # Sception: the spatial blocks pool a quarter as much as the temporal
        # ones; clamp to >= 1 because AvgPool2d rejects a zero-size kernel.
        pool_size_spat = max(1, self.pool_size // 4)

        # Branch 1: kernel spans all electrodes at once.
        self.spatial_block_1 = self._conv_block(
            in_channels=self.number_filter_temp,
            out_channels=self.number_filter_spat,
            kernel_size=(self.n_chans, 1),
            stride=1,
            pool_size=pool_size_spat,
            activation=self.activation,
        )

        # Branch 2: kernel spans half of the electrodes, applied with an equal
        # stride so the halves do not overlap.
        kernel_size_spat_2 = (max(1, self.n_chans // 2), 1)
        self.spatial_block_2 = self._conv_block(
            in_channels=self.number_filter_temp,
            out_channels=self.number_filter_spat,
            kernel_size=kernel_size_spat_2,
            stride=kernel_size_spat_2,
            pool_size=pool_size_spat,
            activation=self.activation,
        )
        self.batch_spatial_lay = nn.BatchNorm2d(self.number_filter_spat)

        # Flattened size of the conv/pool feature map, used to size the
        # classifier's first linear layer.
        self.feature_size = self._calculate_feature_size()

        # Classification head.
        self.dense_layer = nn.Sequential(
            nn.Linear(self.feature_size, hidden_size),
            self.activation(),
            nn.Dropout(self.drop_prob),
        )

        self.final_layer = nn.Linear(hidden_size, self.n_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the TSception model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # (batch_size, n_channels, n_times) -> (batch_size, 1, n_channels, n_times)
        x = self.ensuredim(x)

        features = self._forward_features(x)

        # Flatten everything except the batch dimension, then classify.
        features = features.flatten(start_dim=1)
        output = self.dense_layer(features)
        return self.final_layer(output)

    def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Tception and Sception stages (with their BatchNorms).

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape (batch_size, 1, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            4-D feature map with ``number_filter_spat`` channels. The outputs
            of the two spatial branches are stacked along dim 2; this is the
            tensor that gets flattened before the dense layers.
        """
        # Temporal branches differ in kernel length, so their pooled outputs
        # differ in length along the time axis; concatenate along it.
        t_out = torch.cat([block(x) for block in self.temporal_blocks], dim=-1)
        t_out = self.batch_temporal_lay(t_out)

        # Both spatial branches share the time axis; stack them along dim 2.
        s_out = torch.cat(
            (self.spatial_block_1(t_out), self.spatial_block_2(t_out)), dim=2
        )
        return self.batch_spatial_lay(s_out)

    def _calculate_feature_size(self) -> int:
        """
        Calculates the size of the features after convolution and pooling layers.

        A dummy input is pushed through :meth:`_forward_features`. The probe
        runs in eval mode: in training mode the BatchNorm layers would fold the
        dummy input into their running statistics (and increment
        ``num_batches_tracked``) even under ``torch.no_grad()``, silently
        altering the freshly built model. The previous train/eval mode is
        restored afterwards.

        Returns
        -------
        int
            Flattened size of the features after convolution and pooling layers.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.ones(1, 1, self.n_chans, self.n_times)
                features = self._forward_features(dummy_input)
        finally:
            self.train(was_training)
        return features.flatten(start_dim=1).size(1)

    @staticmethod
    def _conv_block(
        in_channels: int,
        out_channels: int,
        kernel_size: tuple,
        stride: tuple[int, int] | int,
        pool_size: int,
        activation: type[nn.Module],
    ) -> nn.Sequential:
        """
        Creates a convolutional block with Conv2d, activation, and AvgPool2d layers.

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        kernel_size : tuple
            Size of the convolutional kernel.
        stride : tuple of int or int
            Stride of the convolution.
        pool_size : int
            Size (and stride) of the pooling window along the last axis.
        activation : type[nn.Module]
            Activation function class; instantiated once for this block.

        Returns
        -------
        nn.Sequential
            A sequential container of the convolutional block.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=0,
            ),
            activation(),
            # Non-overlapping average pooling along the time axis only.
            nn.AvgPool2d(kernel_size=(1, pool_size), stride=(1, pool_size)),
        )