braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,255 @@
1
+ """
2
+ * Copyright (C) Cogitat, Ltd.
3
+ * Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
4
+ * Patent GB2609265 - Learnable filters for eeg classification
5
+ * https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
6
+ """
7
+
8
+ from functools import partial
9
+
10
+ import torch
11
+ from einops.layers.torch import Rearrange
12
+ from torch import nn
13
+ from torch.fft import fftfreq
14
+
15
+ import braindecode.functional as F
16
+ from braindecode.models.base import EEGModuleMixin
17
+ from braindecode.modules import GeneralizedGaussianFilter
18
+
19
+ _eeg_miner_methods = ["mag", "corr", "plv"]
20
+
21
+
22
class EEGMiner(EEGModuleMixin, nn.Module):
    """EEGMiner from Ludwig et al (2024) [eegminer]_.

    .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036010/revision2/jnead44d7f1_hr.jpg
        :align: center
        :alt: EEGMiner Architecture

    EEGMiner is a neural network model for EEG signal classification using
    learnable generalized Gaussian filters. The model combines frequency-domain
    filtering with connectivity metrics for feature extraction, such as the
    Phase Locking Value (PLV), to extract meaningful features from EEG data.

    The model has the following steps:

    - **Generalized Gaussian** filters applied in the frequency domain to the
      input EEG signals.

    - **Connectivity estimators** (corr, plv) or **Electrode-Wise Band Power**
      (mag); the default is plv.

      - `'corr'`: Computes the correlation of the filtered signals.
      - `'plv'`: Computes the phase locking value of the filtered signals.
      - `'mag'`: Computes the magnitude of the filtered signals.

    - **Feature Normalization**

      - Apply (non-affine) batch normalization.

    - **Final Layer**

      - Feeds the batch-normalized features into a final linear layer for
        classification.

    Depending on the selected method (`mag`, `corr`, or `plv`), the model
    computes the filtered signals' magnitude, correlation, or phase locking
    value. These features are flattened, passed through a batch normalization
    layer without learnable affine parameters, and then fed into a final
    linear layer for classification.

    The input to EEGMiner should be a three-dimensional tensor representing
    EEG signals:

    ``(batch_size, n_channels, n_timesteps)``.

    Notes
    -----
    EEGMiner incorporates learnable parameters for filter characteristics,
    allowing the model to adaptively learn optimal frequency bands and phase
    delays for the classification task. By default, using the PLV as a
    connectivity metric makes EEGMiner suitable for tasks requiring the
    analysis of phase relationships between different EEG channels.

    The model and the module have patent [eegminercode]_, and the code is
    CC BY-NC 4.0.

    .. versionadded:: 0.9

    Parameters
    ----------
    method : str, default="plv"
        The method used for feature extraction (case-insensitive). Options are:

        - "mag": Electrode-Wise band power of the filtered signals.
        - "corr": Correlation between filtered channels.
        - "plv": Phase Locking Value connectivity metric.
    filter_f_mean : sequence of float, default=(23.0, 23.0)
        Mean frequencies for the generalized Gaussian filters. Its length sets
        the number of filters.
    filter_bandwidth : sequence of float, default=(44.0, 44.0)
        Bandwidths for the generalized Gaussian filters.
    filter_shape : sequence of float, default=(2.0, 2.0)
        Shape parameters for the generalized Gaussian filters.
    group_delay : tuple of float, default=(20.0, 20.0)
        Group delay values for the filters in milliseconds.
    clamp_f_mean : tuple of float, default=(1.0, 45.0)
        Clamping range for the mean frequency parameters.

    Raises
    ------
    ValueError
        If ``method`` is not one of ``"mag"``, ``"corr"`` or ``"plv"``.

    References
    ----------
    .. [eegminer] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters. Journal of Neural Engineering,
       21(3), 036010.
    .. [eegminercode] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters.
       https://github.com/SMLudwig/EEGminer/.
       Cogitat, Ltd. "Learnable filters for EEG classification."
       Patent GB2609265.
       https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
    """

    def __init__(
        self,
        # Signal related parameters
        method: str = "plv",
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # model related
        filter_f_mean=(23.0, 23.0),
        filter_bandwidth=(44.0, 44.0),
        filter_shape=(2.0, 2.0),
        group_delay=(20.0, 20.0),
        clamp_f_mean=(1.0, 45.0),
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # From here on the signal parameters are read back through the
        # mixin's attributes (self.n_chans, self.n_times, self.sfreq, ...).
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Initialize filter parameters
        self.filter_f_mean = filter_f_mean
        self.filter_bandwidth = filter_bandwidth
        self.filter_shape = filter_shape
        # One filter per entry of ``filter_f_mean``.
        self.n_filters = len(self.filter_f_mean)
        self.group_delay = group_delay
        self.clamp_f_mean = clamp_f_mean
        self.method = method.lower()

        if self.method not in _eeg_miner_methods:
            raise ValueError(
                f"The method {self.method} is not one of the valid options"
                f" {_eeg_miner_methods}"
            )

        if self.method in ("mag", "corr"):
            # Filtered signals are brought back to the time domain, with a
            # separate bank of ``n_filters`` filters per electrode.
            inverse_fourier = True
            in_channels = self.n_chans
            out_channels = self.n_chans * self.n_filters
        else:
            # PLV keeps the filter output in the frequency domain. A singleton
            # channel axis is inserted in ``forward`` (see ``ensure_dim``),
            # presumably so that the same ``n_filters`` filters are applied
            # to every electrode.
            inverse_fourier = False
            in_channels = 1
            out_channels = self.n_filters

        # Generalized Gaussian Filter
        self.filter = GeneralizedGaussianFilter(
            in_channels=in_channels,
            out_channels=out_channels,
            sequence_length=self.n_times,
            sample_rate=self.sfreq,
            f_mean=self.filter_f_mean,
            bandwidth=self.filter_bandwidth,
            shape=self.filter_shape,
            affine_group_delay=False,
            inverse_fourier=inverse_fourier,
            group_delay=self.group_delay,
            clamp_f_mean=self.clamp_f_mean,
        )

        # Select the feature extractor and the resulting feature count.
        if self.method == "mag":
            self.method_forward = self._apply_mag_forward
            # One band-power value per (electrode, filter).
            self.n_features = self.n_chans * self.n_filters
            self.ensure_dim = nn.Identity()
        elif self.method == "corr":
            self.method_forward = partial(
                self._apply_corr_forward,
                n_chans=self.n_chans,
                n_filters=self.n_filters,
                n_times=self.n_times,
            )
            # Upper triangle (no diagonal) of a chans x chans matrix per filter.
            self.n_features = self.n_filters * self.n_chans * (self.n_chans - 1) // 2
            self.ensure_dim = nn.Identity()
        elif self.method == "plv":
            self.method_forward = partial(self._apply_plv, n_chans=self.n_chans)
            # (..., time) -> (..., 1, time): inserts the singleton channel axis
            # expected by the in_channels=1 filter bank.
            self.ensure_dim = Rearrange("... d -> ... 1 d")
            # Upper triangle (no diagonal) of a chans x chans matrix per filter.
            self.n_features = (self.n_filters * self.n_chans * (self.n_chans - 1)) // 2

        self.flatten_layer = nn.Flatten()
        # Classifier: batch norm without learnable scale/shift, then linear.
        self.batch_layer = nn.BatchNorm1d(self.n_features, affine=False)
        self.final_layer = nn.Linear(self.n_features, self.n_outputs)
        nn.init.zeros_(self.final_layer.bias)

    def forward(self, x):
        """Compute class scores.

        Parameters
        ----------
        x : torch.Tensor
            EEG input of shape ``(batch, electrodes, time)``.

        Returns
        -------
        torch.Tensor
            Output of shape ``(batch, n_outputs)``.
        """
        batch = x.shape[0]
        x = self.ensure_dim(x)
        # Apply Gaussian filters in frequency domain
        # x -> (batch, electrodes * filters, time) for "mag"/"corr"
        x = self.filter(x)

        x = self.method_forward(x=x, batch=batch)
        # Classifier
        # Note that the order of dimensions before flattening the feature vector
        # is important for attributing feature weights during interpretation.
        x = x.reshape(batch, self.n_features)
        x = self.batch_layer(x)
        x = self.final_layer(x)

        return x

    @staticmethod
    def _apply_mag_forward(x, batch=None):
        """Root-mean-square of ``x`` over its last (time) axis.

        ``batch`` is unused; it is accepted so that all feature extractors
        share the same call signature in ``forward``.
        """
        # Signal magnitude
        x = x * x
        x = x.mean(dim=-1)
        x = torch.sqrt(x)
        return x

    @staticmethod
    def _apply_corr_forward(
        x, batch, n_chans, n_filters, n_times, epsilon: float = 1e-6
    ):
        """Absolute inter-electrode correlation per filter.

        Parameters
        ----------
        x : torch.Tensor
            Filtered time-domain signals, reshapeable to
            ``(batch, n_chans, n_filters, n_times)``.
        batch, n_chans, n_filters, n_times : int
            Sizes used for the reshape.
        epsilon : float
            Added to the variance to avoid division by zero.

        Returns
        -------
        torch.Tensor
            Shape ``(batch, n_chans * (n_chans - 1) // 2, n_filters)``: the
            strict upper triangle of each correlation matrix, filters last.
        """
        # (batch, chans, filters, time) -> (batch, filters, chans, time)
        x = x.reshape(batch, n_chans, n_filters, n_times).transpose(-3, -2)
        # Z-score along time. ``var`` uses the unbiased estimator while the
        # matmul below divides by n_times, matching the reference
        # implementation.
        x = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(
            x.var(dim=-1, keepdim=True) + epsilon
        )
        x = torch.matmul(x, x.transpose(-2, -1)) / x.shape[-1]
        # Original tensor shape: [batch, n_filters, chans, chans]
        x = x.permute(0, 2, 3, 1)
        # New tensor shape: [batch, chans, chans, n_filters]
        # (filter channels moved to the end)
        x = x.abs()

        # Get upper triu of symmetric connectivity matrix. Indices are built on
        # the input's device to avoid a host->device copy on every call.
        triu = torch.triu_indices(n_chans, n_chans, 1, device=x.device)
        x = x[:, triu[0], triu[1], :]

        return x

    @staticmethod
    def _apply_plv(x, n_chans, batch=None):
        """Phase locking value between electrodes per filter.

        ``batch`` is unused; it is accepted so that all feature extractors
        share the same call signature in ``forward``.

        Returns
        -------
        torch.Tensor
            Shape ``(batch, n_chans * (n_chans - 1) // 2, n_filters)``: the
            strict upper triangle of each PLV matrix, filters last.
        """
        # Compute PLV connectivity.
        # NOTE(review): the transpose is intended to move the filter axis in
        # front of the electrode axis so that ``plv_time`` receives
        # (batch, filters, electrodes, freq); the exact layout returned by
        # GeneralizedGaussianFilter is defined elsewhere -- TODO confirm.
        x = x.transpose(-4, -3)  # swap electrodes and filters
        # The filter output is already in the frequency domain.
        x = F.plv_time(x, forward_fourier=False)
        # [batch, n_filters, chans, chans] -> [batch, chans, chans, n_filters]
        x = x.permute(0, 2, 3, 1)

        # Get upper triu of symmetric connectivity matrix. Indices are built on
        # the input's device to avoid a host->device copy on every call.
        triu = torch.triu_indices(n_chans, n_chans, 1, device=x.device)
        x = x[:, triu[0], triu[1], :]
        return x