braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,256 @@
+ """
+ * Copyright (C) Cogitat, Ltd.
+ * Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
+ * Patent GB2609265 - Learnable filters for eeg classification
+ * https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
+ """
+
+ from functools import partial
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from torch import nn
+
+ import braindecode.functional as F
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import GeneralizedGaussianFilter
+
+ _eeg_miner_methods = ["mag", "corr", "plv"]
+
+
+ class EEGMiner(EEGModuleMixin, nn.Module):
+     r"""EEGMiner from Ludwig et al. (2024) [eegminer]_.
+
+     :bdg-success:`Convolution` :bdg-warning:`Interpretability`
+
+     .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036010/revision2/jnead44d7f1_hr.jpg
+         :align: center
+         :alt: EEGMiner Architecture
+
+     EEGMiner is a neural network model for EEG classification built on learnable
+     generalized Gaussian filters. It combines frequency-domain filtering with
+     connectivity metrics or per-electrode feature extraction, such as the phase
+     locking value (PLV), to extract meaningful features from EEG data.
+
+     The model has the following steps:
+
+     - **Generalized Gaussian filters** applied in the frequency domain to the
+       input EEG signals.
+     - **Connectivity estimators** (``corr``, ``plv``) or **electrode-wise band
+       power** (``mag``); ``plv`` by default:
+
+       - ``'corr'``: computes the correlation of the filtered signals.
+       - ``'plv'``: computes the phase locking value of the filtered signals.
+       - ``'mag'``: computes the magnitude of the filtered signals.
+
+     - **Feature normalization**: batch normalization of the extracted features.
+     - **Final layer**: the batch-normalized features are fed into a final
+       linear layer for classification.
+
+     Depending on the selected method (``mag``, ``corr``, or ``plv``), the model
+     computes the filtered signals' magnitude, correlation, or phase locking
+     value, normalizes the resulting features with batch normalization, and
+     classifies them with a final linear layer.
+
+     The input to EEGMiner should be a three-dimensional tensor representing EEG signals:
+
+     ``(batch_size, n_channels, n_timesteps)``.
+
+     Notes
+     -----
+     EEGMiner incorporates learnable parameters for the filter characteristics,
+     allowing the model to adaptively learn optimal frequency bands and phase
+     delays for the classification task. By default, using the PLV connectivity
+     metric makes EEGMiner suitable for tasks requiring the analysis of phase
+     relationships between different EEG channels.
+
+     The model and module are covered by patent [eegminercode]_, and the code is
+     licensed under CC BY-NC 4.0.
+
+     .. versionadded:: 0.9
+
+     Parameters
+     ----------
+     method : str, default="plv"
+         The method used for feature extraction. Options are:
+
+         - "mag": electrode-wise band power of the filtered signals.
+         - "corr": correlation between filtered channels.
+         - "plv": phase locking value connectivity metric.
+     filter_f_mean : list of float, default=[23.0, 23.0]
+         Mean frequencies for the generalized Gaussian filters.
+     filter_bandwidth : list of float, default=[44.0, 44.0]
+         Bandwidths for the generalized Gaussian filters.
+     filter_shape : list of float, default=[2.0, 2.0]
+         Shape parameters for the generalized Gaussian filters.
+     group_delay : tuple of float, default=(20.0, 20.0)
+         Group delay values for the filters in milliseconds.
+     clamp_f_mean : tuple of float, default=(1.0, 45.0)
+         Clamping range for the mean frequency parameters.
+
+     References
+     ----------
+     .. [eegminer] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
+        Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
+        of brain activity with learnable filters. Journal of Neural Engineering,
+        21(3), 036010.
+     .. [eegminercode] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
+        Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
+        of brain activity with learnable filters.
+        https://github.com/SMLudwig/EEGminer/.
+        Cogitat, Ltd. "Learnable filters for EEG classification."
+        Patent GB2609265.
+        https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
+     """
+
+     def __init__(
+         self,
+         method: str = "plv",
+         # Signal-related parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model-related parameters
+         filter_f_mean=(23.0, 23.0),
+         filter_bandwidth=(44.0, 44.0),
+         filter_shape=(2.0, 2.0),
+         group_delay=(20.0, 20.0),
+         clamp_f_mean=(1.0, 45.0),
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         # Initialize filter parameters
+         self.filter_f_mean = filter_f_mean
+         self.filter_bandwidth = filter_bandwidth
+         self.filter_shape = filter_shape
+         self.n_filters = len(self.filter_f_mean)
+         self.group_delay = group_delay
+         self.clamp_f_mean = clamp_f_mean
+         self.method = method.lower()
+
+         if self.method not in _eeg_miner_methods:
+             raise ValueError(
+                 f"The method {self.method} is not one of the valid options"
+                 f" {_eeg_miner_methods}"
+             )
+
+         if self.method == "mag" or self.method == "corr":
+             inverse_fourier = True
+             in_channels = self.n_chans
+             out_channels = self.n_chans * self.n_filters
+         else:
+             inverse_fourier = False
+             in_channels = 1
+             out_channels = 1 * self.n_filters
+
+         # Generalized Gaussian Filter
+         self.filter = GeneralizedGaussianFilter(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             sequence_length=self.n_times,
+             sample_rate=self.sfreq,
+             f_mean=self.filter_f_mean,
+             bandwidth=self.filter_bandwidth,
+             shape=self.filter_shape,
+             affine_group_delay=False,
+             inverse_fourier=inverse_fourier,
+             group_delay=self.group_delay,
+             clamp_f_mean=self.clamp_f_mean,
+         )
+
+         # Select the feature-extraction forward method
+         if self.method == "mag":
+             self.method_forward = self._apply_mag_forward
+             self.n_features = self.n_chans * self.n_filters
+             self.ensure_dim = nn.Identity()
+         elif self.method == "corr":
+             self.method_forward = partial(
+                 self._apply_corr_forward,
+                 n_chans=self.n_chans,
+                 n_filters=self.n_filters,
+                 n_times=self.n_times,
+             )
+             self.n_features = self.n_filters * self.n_chans * (self.n_chans - 1) // 2
+             self.ensure_dim = nn.Identity()
+         elif self.method == "plv":
+             self.method_forward = partial(self._apply_plv, n_chans=self.n_chans)
+             self.ensure_dim = Rearrange("... d -> ... 1 d")
+             self.n_features = (self.n_filters * self.n_chans * (self.n_chans - 1)) // 2
+
+         self.flatten_layer = nn.Flatten()
+         # Classifier
+         self.batch_layer = nn.BatchNorm1d(self.n_features, affine=False)
+         self.final_layer = nn.Linear(self.n_features, self.n_outputs)
+         nn.init.zeros_(self.final_layer.bias)
+
+     def forward(self, x):
+         """x: (batch, electrodes, time)"""
+         batch = x.shape[0]
+         x = self.ensure_dim(x)
+         # Apply Gaussian filters in the frequency domain
+         # x -> (batch, electrodes * filters, time)
+         x = self.filter(x)
+
+         x = self.method_forward(x=x, batch=batch)
+         # Classifier
+         # Note that the order of dimensions before flattening the feature vector
+         # is important for attributing feature weights during interpretation.
+         x = x.reshape(batch, self.n_features)
+         x = self.batch_layer(x)
+         x = self.final_layer(x)
+
+         return x
+
+     @staticmethod
+     def _apply_mag_forward(x, batch=None):
+         # Signal magnitude (RMS over time)
+         x = x * x
+         x = x.mean(dim=-1)
+         x = torch.sqrt(x)
+         return x
+
+     @staticmethod
+     def _apply_corr_forward(
+         x, batch, n_chans, n_filters, n_times, epsilon: float = 1e-6
+     ):
+         x = x.reshape(batch, n_chans, n_filters, n_times).transpose(-3, -2)
+         # Standardize each channel over time before correlating
+         x = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(
+             x.var(dim=-1, keepdim=True) + epsilon
+         )
+         x = torch.matmul(x, x.transpose(-2, -1)) / x.shape[-1]
+         # Original tensor shape: [batch, n_filters, chans, chans]
+         x = x.permute(0, 2, 3, 1)
+         # New tensor shape: [batch, chans, chans, n_filters]
+         # (filter dimension moved to the end)
+         x = x.abs()
+
+         # Get the upper triangle of the symmetric connectivity matrix
+         triu = torch.triu_indices(n_chans, n_chans, 1)
+         x = x[:, triu[0], triu[1], :]
+
+         return x
+
+     @staticmethod
+     def _apply_plv(x, n_chans, batch=None):
+         # Compute PLV connectivity
+         x = x.transpose(-4, -3)  # swap electrodes and filters for plv_time
+         x = F.plv_time(x, forward_fourier=False)
+         # Connectivity matrices per filter: [batch, n_filters, chans, chans]
+         x = x.permute(0, 2, 3, 1)
+         # [batch, chans, chans, n_filters]
+
+         # Get the upper triangle of the symmetric connectivity matrix
+         triu = torch.triu_indices(n_chans, n_chans, 1)
+         x = x[:, triu[0], triu[1], :]
+         return x
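
A minimal usage sketch for the class above (not part of the package diff; it assumes `EEGMiner` is re-exported from `braindecode.models`, and the dimensions are illustrative). With the default two filters and `method="plv"`, a 22-channel input yields 2 * (22 * 21) / 2 = 462 connectivity features ahead of the linear head.

    import torch

    from braindecode.models import EEGMiner

    # Illustrative dimensions: 8 trials, 22 channels, 4 s at 250 Hz, 4 classes.
    model = EEGMiner(method="plv", n_chans=22, n_outputs=4, n_times=1000, sfreq=250)
    x = torch.randn(8, 22, 1000)  # (batch, electrodes, time)
    logits = model(x)             # filters -> PLV -> batch norm -> linear head
    print(logits.shape)           # torch.Size([8, 4])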
@@ -0,0 +1,359 @@
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+ from __future__ import annotations
+
+ from typing import Dict, Optional
+
+ from einops.layers.torch import Rearrange
+ from mne.utils import deprecated, warn
+ from torch import nn
+
+ from braindecode.functional import glorot_weight_zero_bias
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import (
+     Conv2dWithConstraint,
+     Ensure4d,
+     LinearWithConstraint,
+     SqueezeFinalOutput,
+ )
+
+
+ class EEGNet(EEGModuleMixin, nn.Sequential):
+     r"""EEGNet model from Lawhern et al. (2018) [Lawhern2018]_.
+
+     :bdg-success:`Convolution`
+
+     .. figure:: https://content.cld.iop.org/journals/1741-2552/15/5/056013/revision2/jneaace8cf01_hr.jpg
+         :align: center
+         :alt: EEGNet Architecture
+         :width: 600px
+
+     .. rubric:: Architectural Overview
+
+     EEGNet is a compact convolutional network designed for EEG decoding, with a
+     pipeline that mirrors classical EEG processing:
+
+     - (i) learn temporal frequency-selective filters,
+     - (ii) learn spatial filters for those frequencies, and
+     - (iii) condense features with depthwise-separable convolutions before a lightweight classifier.
+
+     The architecture is deliberately small, and its temporal and spatial
+     convolutions can be read as learned frequency filters and spatial patterns
+     [Lawhern2018]_.
+
+     .. rubric:: Macro Components
+
+     - **Temporal Convolution.**
+       Applied per channel; learns ``F1`` kernels that act as data-driven band-pass filters.
+     - **Depthwise Spatial Filtering.**
+       Depthwise convolution spanning the channel dimension with ``groups = F1``,
+       yielding ``D`` spatial filters for each temporal filter (no cross-filter mixing).
+     - **Norm-Nonlinearity-Pooling (+ dropout).**
+       Batch normalization → ELU → temporal pooling, with dropout.
+     - **Depthwise-Separable Convolution Block.**
+       (a) depthwise temporal conv to refine temporal structure;
+       (b) pointwise 1x1 conv to mix feature maps into ``F2`` combinations.
+     - **Classifier Head.**
+       Lightweight 1x1 conv or dense layer (often with max-norm constraint).
+
+     .. rubric:: Convolutional Details
+
+     - **Temporal.** The initial temporal convs serve as a *learned filter bank*:
+       long 1-D kernels (implemented as 2-D with singleton spatial extent) emphasize oscillatory bands and transients.
+       Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each feature's spectrum [Lawhern2018]_.
+
+     - **Spatial.** The depthwise spatial conv spans the full channel axis (kernel height = #electrodes; temporal size = 1).
+       With ``groups = F1``, each temporal filter learns its own set of ``D`` spatial projections, akin to CSP, learned end-to-end and
+       typically regularized with max-norm.
+
+     - **Spectral.** No explicit Fourier/wavelet transform is used. Frequency structure
+       is captured implicitly by the temporal filter bank; later depthwise temporal kernels act as short-time integrators/refiners.
+
+     .. rubric:: Additional Comments
+
+     - **Filter-bank structure:** Parallel temporal kernels (``F1``) emulate classical filter banks; pairing them with frequency-specific spatial filters
+       yields features mappable to rhythms and topographies.
+     - **Depthwise & separable convs:** The parameter-efficient decomposition (depthwise + pointwise) retains expressive power while limiting overfitting
+       [Chollet2017]_ and keeps the temporal vs. mixing steps interpretable.
+     - **Regularization:** Batch norm, dropout, pooling, and an optional max-norm on spatial kernels aid stability on small EEG datasets.
+     - The ``v4`` in the historical class name ``EEGNetv4`` refers to version 4 of the arXiv paper [Lawhern2018]_.
+
+
+     Parameters
+     ----------
+     final_conv_length : int or "auto", default="auto"
+         Length of the final convolution layer. If "auto", it is set based on ``n_times``.
+     pool_mode : {"mean", "max"}, default="mean"
+         Pooling method to use in pooling layers.
+     F1 : int, default=8
+         Number of temporal filters in the first convolutional layer.
+     D : int, default=2
+         Depth multiplier for the depthwise convolution.
+     F2 : int or None, default=None
+         Number of pointwise filters in the separable convolution. Usually set to ``F1 * D``.
+     depthwise_kernel_length : int, default=16
+         Length of the depthwise convolution kernel in the separable convolution.
+     pool1_kernel_size : int, default=4
+         Kernel size of the first pooling layer.
+     pool2_kernel_size : int, default=8
+         Kernel size of the second pooling layer.
+     kernel_length : int, default=64
+         Length of the temporal convolution kernel.
+     conv_spatial_max_norm : float, default=1
+         Maximum norm constraint for the spatial (depthwise) convolution.
+     activation : nn.Module, default=nn.ELU
+         Non-linear activation function to be used in the layers.
+     batch_norm_momentum : float, default=0.01
+         Momentum for the running statistics of the batch norm layers.
+     batch_norm_affine : bool, default=True
+         If True, batch norm has learnable affine parameters.
+     batch_norm_eps : float, default=1e-3
+         Epsilon for numeric stability in batch norm layers.
+     drop_prob : float, default=0.25
+         Dropout probability.
+     final_layer_with_constraint : bool, default=False
+         If ``False``, uses a convolution-based classification layer. If ``True``,
+         applies a flattened linear layer with a max-norm constraint on the weights
+         as the final classification step.
+     norm_rate : float, default=0.25
+         Max-norm constraint value for the linear layer (used if ``final_layer_with_constraint=True``).
+
+     References
+     ----------
+     .. [Lawhern2018] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon, S. M.,
+        Hung, C. P., & Lance, B. J. (2018). EEGNet: a compact convolutional
+        neural network for EEG-based brain–computer interfaces. Journal of
+        neural engineering, 15(5), 056013.
+     .. [Chollet2017] Chollet, F., *Xception: Deep Learning with Depthwise Separable
+        Convolutions*, CVPR, 2017.
+     """
+
+     def __init__(
+         self,
+         # Signal-related parameters
+         n_chans: Optional[int] = None,
+         n_outputs: Optional[int] = None,
+         n_times: Optional[int] = None,
+         # Model-related parameters
+         final_conv_length: str | int = "auto",
+         pool_mode: str = "mean",
+         F1: int = 8,
+         D: int = 2,
+         F2: Optional[int] = None,
+         kernel_length: int = 64,
+         *,
+         depthwise_kernel_length: int = 16,
+         pool1_kernel_size: int = 4,
+         pool2_kernel_size: int = 8,
+         conv_spatial_max_norm: int = 1,
+         activation: type[nn.Module] = nn.ELU,
+         batch_norm_momentum: float = 0.01,
+         batch_norm_affine: bool = True,
+         batch_norm_eps: float = 1e-3,
+         drop_prob: float = 0.25,
+         final_layer_with_constraint: bool = False,
+         norm_rate: float = 0.25,
+         # Other ways to construct the signal-related parameters
+         chs_info: Optional[list[Dict]] = None,
+         input_window_seconds=None,
+         sfreq=None,
+         **kwargs,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+         if final_conv_length == "auto":
+             assert self.n_times is not None
+
+         if not final_layer_with_constraint:
+             warn(
+                 "Parameter 'final_layer_with_constraint=False' is deprecated and will be "
+                 "removed in a future release. Please use `final_layer_with_constraint=True`.",
+                 DeprecationWarning,
+             )
+
+         if "third_kernel_size" in kwargs:
+             warn(
+                 "The parameter `third_kernel_size` is deprecated "
+                 "and will be removed in a future version.",
+             )
+         unexpected_kwargs = set(kwargs) - {"third_kernel_size"}
+         if unexpected_kwargs:
+             raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}")
+
+         self.final_conv_length = final_conv_length
+         self.pool_mode = pool_mode
+         self.F1 = F1
+         self.D = D
+
+         if F2 is None:
+             F2 = self.F1 * self.D
+         self.F2 = F2
+
+         self.kernel_length = kernel_length
+         self.depthwise_kernel_length = depthwise_kernel_length
+         self.pool1_kernel_size = pool1_kernel_size
+         self.pool2_kernel_size = pool2_kernel_size
+         self.drop_prob = drop_prob
+         self.activation = activation
+         self.batch_norm_momentum = batch_norm_momentum
+         self.batch_norm_affine = batch_norm_affine
+         self.batch_norm_eps = batch_norm_eps
+         self.conv_spatial_max_norm = conv_spatial_max_norm
+         self.norm_rate = norm_rate
+
+         # For load_state_dict: when standardizing all layers,
+         # add the old parameter names here.
+         self.mapping = {
+             "conv_classifier.weight": "final_layer.conv_classifier.weight",
+             "conv_classifier.bias": "final_layer.conv_classifier.bias",
+         }
+
+         pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
+         self.add_module("ensuredims", Ensure4d())
+
+         self.add_module("dimshuffle", Rearrange("batch ch t 1 -> batch 1 ch t"))
+         self.add_module(
+             "conv_temporal",
+             nn.Conv2d(
+                 1,
+                 self.F1,
+                 (1, self.kernel_length),
+                 bias=False,
+                 padding=(0, self.kernel_length // 2),
+             ),
+         )
+         self.add_module(
+             "bnorm_temporal",
+             nn.BatchNorm2d(
+                 self.F1,
+                 momentum=self.batch_norm_momentum,
+                 affine=self.batch_norm_affine,
+                 eps=self.batch_norm_eps,
+             ),
+         )
+         self.add_module(
+             "conv_spatial",
+             Conv2dWithConstraint(
+                 in_channels=self.F1,
+                 out_channels=self.F1 * self.D,
+                 kernel_size=(self.n_chans, 1),
+                 max_norm=self.conv_spatial_max_norm,
+                 bias=False,
+                 groups=self.F1,
+             ),
+         )
+
+         self.add_module(
+             "bnorm_1",
+             nn.BatchNorm2d(
+                 self.F1 * self.D,
+                 momentum=self.batch_norm_momentum,
+                 affine=self.batch_norm_affine,
+                 eps=self.batch_norm_eps,
+             ),
+         )
+         self.add_module("elu_1", activation())
+
+         self.add_module(
+             "pool_1",
+             pool_class(
+                 kernel_size=(1, self.pool1_kernel_size),
+             ),
+         )
+         self.add_module("drop_1", nn.Dropout(p=self.drop_prob))
+
+         # https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/7
+         self.add_module(
+             "conv_separable_depth",
+             nn.Conv2d(
+                 self.F1 * self.D,
+                 self.F1 * self.D,
+                 (1, self.depthwise_kernel_length),
+                 bias=False,
+                 groups=self.F1 * self.D,
+                 padding=(0, self.depthwise_kernel_length // 2),
+             ),
+         )
+         self.add_module(
+             "conv_separable_point",
+             nn.Conv2d(
+                 self.F1 * self.D,
+                 self.F2,
+                 kernel_size=(1, 1),
+                 bias=False,
+             ),
+         )
+
+         self.add_module(
+             "bnorm_2",
+             nn.BatchNorm2d(
+                 self.F2,
+                 momentum=self.batch_norm_momentum,
+                 affine=self.batch_norm_affine,
+                 eps=self.batch_norm_eps,
+             ),
+         )
+         self.add_module("elu_2", self.activation())
+         self.add_module(
+             "pool_2",
+             pool_class(
+                 kernel_size=(1, self.pool2_kernel_size),
+             ),
+         )
+         self.add_module("drop_2", nn.Dropout(p=self.drop_prob))
+
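+         # With the defaults above, an input of shape (batch, 1, n_chans, n_times)
+         # has been reduced at this point to approximately
+         # (batch, F2, 1, n_times // (pool1_kernel_size * pool2_kernel_size)),
+         # up to off-by-one padding effects of the temporal convolutions.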
+         output_shape = self.get_output_shape()
+         n_out_virtual_chans = output_shape[2]
+
+         if self.final_conv_length == "auto":
+             n_out_time = output_shape[3]
+             self.final_conv_length = n_out_time
+
+         # Group the classification module and the subsequent reshaping
+         # operations into a single final layer.
+         module = nn.Sequential()
+         if not final_layer_with_constraint:
+             module.add_module(
+                 "conv_classifier",
+                 nn.Conv2d(
+                     self.F2,
+                     self.n_outputs,
+                     (n_out_virtual_chans, self.final_conv_length),
+                     bias=True,
+                 ),
+             )
+
+             # Transpose back to the braindecode convention,
+             # i.e. time in the third dimension (axis=2)
+             module.add_module(
+                 "permute_back",
+                 Rearrange("batch x y z -> batch x z y"),
+             )
+
+             module.add_module("squeeze", SqueezeFinalOutput())
+         else:
+             module.add_module("flatten", nn.Flatten())
+             module.add_module(
+                 "linearconstraint",
+                 LinearWithConstraint(
+                     in_features=self.F2 * self.final_conv_length,
+                     out_features=self.n_outputs,
+                     max_norm=norm_rate,
+                 ),
+             )
+         self.add_module("final_layer", module)
+
+         glorot_weight_zero_bias(self)
+
+
+ @deprecated(
+     "`EEGNetv4` was renamed to `EEGNet` in v1.12; this alias will be removed in v1.14."
+ )
+ class EEGNetv4(EEGNet):
+     r"""Deprecated alias for EEGNet."""
+
+     pass
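
Likewise, a minimal usage sketch for `EEGNet` (not part of the package diff; it assumes `EEGNet` is re-exported from `braindecode.models`, and the dimensions are illustrative). Passing `final_layer_with_constraint=True` selects the max-norm linear head and avoids the deprecation warning emitted for the default convolutional classifier.

    import torch

    from braindecode.models import EEGNet

    # Illustrative dimensions: 16 trials, 22 channels, 2 s at 250 Hz, 4 classes.
    model = EEGNet(n_chans=22, n_outputs=4, n_times=500, final_layer_with_constraint=True)
    x = torch.randn(16, 22, 500)  # (batch, channels, time)
    logits = model(x)
    print(logits.shape)           # torch.Size([16, 4])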