braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic.

Files changed (108):
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,254 @@
1
+ """
2
+ * Copyright (C) Cogitat, Ltd.
3
+ * Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
4
+ * Patent GB2609265 - Learnable filters for eeg classification
5
+ * https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
6
+ """
7
+
8
+ from functools import partial
9
+
10
+ import torch
11
+ from einops.layers.torch import Rearrange
12
+ from torch import nn
13
+
14
+ import braindecode.functional as F
15
+ from braindecode.models.base import EEGModuleMixin
16
+ from braindecode.modules import GeneralizedGaussianFilter
17
+
18
# Feature-extraction methods accepted by ``EEGMiner(method=...)``. The user's
# value is lowercased before being checked against this list, and the list
# itself is interpolated into the ValueError raised for an unknown method.
_eeg_miner_methods = ["mag", "corr", "plv"]
19
+
20
+
21
class EEGMiner(EEGModuleMixin, nn.Module):
    """EEGMiner from Ludwig et al (2024) [eegminer]_.

    .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036010/revision2/jnead44d7f1_hr.jpg
        :align: center
        :alt: EEGMiner Architecture

    EEGMiner is a neural network model for EEG signal classification using
    learnable generalized Gaussian filters. The model leverages frequency domain
    filtering and connectivity metrics for feature extraction, such as the Phase
    Locking Value (PLV), to extract meaningful features from EEG data, enabling
    effective classification.

    The model has the following steps:

    - **Generalized Gaussian** filters applied in the frequency domain to the
      input EEG signals.

    - **Connectivity estimators** (corr, plv) or **Electrode-Wise Band Power**
      (mag), by default (plv).

        - `'corr'`: Computes the correlation of the filtered signals.
        - `'plv'`: Computes the phase locking value of the filtered signals.
        - `'mag'`: Computes the magnitude of the filtered signals.

    - **Feature Normalization**

        - Apply batch normalization.

    - **Final Layer**

        - Feeds the batch-normalized features into a final linear layer for
          classification.

    Depending on the selected method (`mag`, `corr`, or `plv`), it computes the
    filtered signals' magnitude, correlation, or phase locking value. These
    features are passed through a batch normalization layer before being fed
    into a final linear layer for classification. The number of features is
    ``n_chans * n_filters`` for `mag`, and
    ``n_filters * n_chans * (n_chans - 1) // 2`` for `corr` and `plv`, which
    keep only the upper triangle (diagonal excluded) of each channel-by-channel
    connectivity matrix.

    The input to EEGMiner should be a three-dimensional tensor representing EEG
    signals:

    ``(batch_size, n_channels, n_timesteps)``.

    Notes
    -----
    EEGMiner incorporates learnable parameters for filter characteristics,
    allowing the model to adaptively learn optimal frequency bands and phase
    delays for the classification task. By default, using the PLV as a
    connectivity metric makes EEGMiner suitable for tasks requiring the
    analysis of phase relationships between different EEG channels.

    The model and the module have patent [eegminercode]_, and the code is
    CC BY-NC 4.0.

    .. versionadded:: 0.9

    Parameters
    ----------
    method : str, default="plv"
        The method used for feature extraction (case-insensitive). Options are:

        - "mag": Electrode-Wise band power of the filtered signals.
        - "corr": Correlation between filtered channels.
        - "plv": Phase Locking Value connectivity metric.
    filter_f_mean : tuple of float, default=(23.0, 23.0)
        Mean frequencies for the generalized Gaussian filters. Its length sets
        the number of filters.
    filter_bandwidth : tuple of float, default=(44.0, 44.0)
        Bandwidths for the generalized Gaussian filters.
    filter_shape : tuple of float, default=(2.0, 2.0)
        Shape parameters for the generalized Gaussian filters.
    group_delay : tuple of float, default=(20.0, 20.0)
        Group delay values for the filters in milliseconds.
    clamp_f_mean : tuple of float, default=(1.0, 45.0)
        Clamping range for the mean frequency parameters.

    Raises
    ------
    ValueError
        If ``method`` is not one of ``"mag"``, ``"corr"`` or ``"plv"``, or if
        a connectivity method (``"corr"`` or ``"plv"``) is requested with fewer
        than two channels (there would be no channel pairs to build features
        from).

    References
    ----------
    .. [eegminer] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters. Journal of Neural Engineering,
       21(3), 036010.
    .. [eegminercode] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters.
       https://github.com/SMLudwig/EEGminer/.
       Cogitat, Ltd. "Learnable filters for EEG classification."
       Patent GB2609265.
       https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
    """

    def __init__(
        self,  # Signal related parameters
        method: str = "plv",
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # model related
        filter_f_mean=(23.0, 23.0),
        filter_bandwidth=(44.0, 44.0),
        filter_shape=(2.0, 2.0),
        group_delay=(20.0, 20.0),
        clamp_f_mean=(1.0, 45.0),
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Initialize filter parameters
        self.filter_f_mean = filter_f_mean
        self.filter_bandwidth = filter_bandwidth
        self.filter_shape = filter_shape
        # One learnable filter per entry of ``filter_f_mean``.
        self.n_filters = len(self.filter_f_mean)
        self.group_delay = group_delay
        self.clamp_f_mean = clamp_f_mean
        self.method = method.lower()

        if self.method not in _eeg_miner_methods:
            raise ValueError(
                f"The method {self.method} is not one of the valid options"
                f" {_eeg_miner_methods}"
            )

        # 'corr' and 'plv' keep only the strict upper triangle of an
        # (n_chans x n_chans) connectivity matrix. With fewer than two
        # channels that triangle is empty, which would silently produce a
        # zero-feature classifier (BatchNorm1d(0) / Linear(0, n_outputs)).
        if self.method in ("corr", "plv") and self.n_chans < 2:
            raise ValueError(
                f"The method {self.method} computes pairwise channel "
                f"connectivity and requires at least 2 channels, got "
                f"n_chans={self.n_chans}."
            )

        if self.method == "mag" or self.method == "corr":
            # Time-domain output (inverse Fourier applied), one filtered copy
            # of every electrode per filter.
            inverse_fourier = True
            in_channels = self.n_chans
            out_channels = self.n_chans * self.n_filters
        else:
            # 'plv' keeps the frequency-domain output: plv_time is called with
            # forward_fourier=False below. in_channels=1 because ensure_dim
            # inserts a singleton axis before time; presumably the same filters
            # are thus shared by all electrodes (TODO confirm against
            # GeneralizedGaussianFilter).
            inverse_fourier = False
            in_channels = 1
            out_channels = 1 * self.n_filters

        # Generalized Gaussian Filter
        self.filter = GeneralizedGaussianFilter(
            in_channels=in_channels,
            out_channels=out_channels,
            sequence_length=self.n_times,
            sample_rate=self.sfreq,
            f_mean=self.filter_f_mean,
            bandwidth=self.filter_bandwidth,
            shape=self.filter_shape,
            affine_group_delay=False,
            inverse_fourier=inverse_fourier,
            group_delay=self.group_delay,
            clamp_f_mean=self.clamp_f_mean,
        )

        # Method-specific feature extractor and resulting feature count.
        if self.method == "mag":
            self.method_forward = self._apply_mag_forward
            self.n_features = self.n_chans * self.n_filters
            self.ensure_dim = nn.Identity()
        elif self.method == "corr":
            self.method_forward = partial(
                self._apply_corr_forward,
                n_chans=self.n_chans,
                n_filters=self.n_filters,
                n_times=self.n_times,
            )
            # Upper triangle (diagonal excluded) per filter.
            self.n_features = self.n_filters * self.n_chans * (self.n_chans - 1) // 2
            self.ensure_dim = nn.Identity()
        elif self.method == "plv":
            self.method_forward = partial(self._apply_plv, n_chans=self.n_chans)
            # (..., time) -> (..., 1, time): adds the single filter input axis.
            self.ensure_dim = Rearrange("... d -> ... 1 d")
            # Upper triangle (diagonal excluded) per filter.
            self.n_features = (self.n_filters * self.n_chans * (self.n_chans - 1)) // 2

        # Kept (although unused in forward) so the module layout stays stable.
        self.flatten_layer = nn.Flatten()
        # Classifier
        self.batch_layer = nn.BatchNorm1d(self.n_features, affine=False)
        self.final_layer = nn.Linear(self.n_features, self.n_outputs)
        nn.init.zeros_(self.final_layer.bias)

    def forward(self, x):
        """Classify a batch of EEG windows.

        Parameters
        ----------
        x : torch.Tensor
            EEG input of shape ``(batch, electrodes, time)``.

        Returns
        -------
        torch.Tensor
            Output of the final linear layer, shape ``(batch, n_outputs)``.
        """
        batch = x.shape[0]
        # Identity for 'mag'/'corr'; inserts a singleton axis for 'plv'.
        x = self.ensure_dim(x)
        # Apply Gaussian filters in frequency domain.
        # For 'mag'/'corr' this yields (batch, electrodes * filters, time);
        # for 'plv' the output stays in the frequency domain.
        x = self.filter(x)

        x = self.method_forward(x=x, batch=batch)
        # Classifier
        # Note that the order of dimensions before flattening the feature vector
        # is important for attributing feature weights during interpretation.
        x = x.reshape(batch, self.n_features)
        x = self.batch_layer(x)
        x = self.final_layer(x)

        return x

    @staticmethod
    def _apply_mag_forward(x, batch=None):
        """Root-mean-square amplitude over the last (time) axis.

        ``batch`` is accepted only for a uniform call signature and is unused.
        """
        x = x * x
        x = x.mean(dim=-1)
        x = torch.sqrt(x)
        return x

    @staticmethod
    def _apply_corr_forward(
        x, batch, n_chans, n_filters, n_times, epilson: float = 1e-6
    ):
        """Absolute channel-by-channel correlation per filter.

        ``x`` is expected to hold ``batch * n_chans * n_filters * n_times``
        elements laid out as ``(batch, n_chans * n_filters, n_times)``.
        Returns the upper triangle (diagonal excluded) with shape
        ``(batch, n_chans * (n_chans - 1) // 2, n_filters)``.
        ``epilson`` (sic; name kept for compatibility) stabilises the
        variance normalisation.
        """
        # (batch, n_chans, n_filters, time) -> (batch, n_filters, n_chans, time)
        x = x.reshape(batch, n_chans, n_filters, n_times).transpose(-3, -2)
        # Standardise each signal over time.
        x = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(
            x.var(dim=-1, keepdim=True) + epilson
        )
        x = torch.matmul(x, x.transpose(-2, -1)) / x.shape[-1]
        # [batch, n_filters, chans, chans] -> [batch, chans, chans, n_filters]
        # i.e. move the filter axis to the end.
        x = x.permute(0, 2, 3, 1)
        x = x.abs()

        # Get upper triu of symmetric connectivity matrix
        triu = torch.triu_indices(n_chans, n_chans, 1)
        x = x[:, triu[0], triu[1], :]

        return x

    @staticmethod
    def _apply_plv(x, n_chans, batch=None):
        """Phase locking value between channels per filter.

        Returns the upper triangle (diagonal excluded) with shape
        ``(batch, n_chans * (n_chans - 1) // 2, n_filters)``.
        ``batch`` is accepted only for a uniform call signature and is unused.
        """
        # Swap electrodes and filters. NOTE(review): transpose(-4, -3) hits
        # the electrode/filter axes only if the filter output carries two
        # trailing axes (presumably frequency and real/imag) -- confirm
        # against GeneralizedGaussianFilter.
        x = x.transpose(-4, -3)
        # The input is already in the frequency domain.
        x = F.plv_time(x, forward_fourier=False)
        # [batch, n_filters, chans, chans] -> [batch, chans, chans, n_filters]
        x = x.permute(0, 2, 3, 1)

        # Get upper triu of symmetric connectivity matrix
        triu = torch.triu_indices(n_chans, n_chans, 1)
        x = x[:, triu[0], triu[1], :]
        return x