braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,632 @@
1
+ from __future__ import annotations
2
+
3
+ from functools import partial
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from einops.layers.torch import Rearrange
8
+ from mne.filter import _check_coefficients, create_filter
9
+ from mne.utils import warn
10
+ from torch import Tensor, from_numpy, nn
11
+ from torch.fft import fftfreq
12
+ from torchaudio.functional import fftconvolve, filtfilt
13
+
14
+ import braindecode.functional as F
15
+
16
+
17
class FilterBankLayer(nn.Module):
    """Apply multiple band-pass filters to generate multiview signal representation.

    This layer constructs a bank of signals filtered in specific bands for each channel.
    It uses MNE's :func:`mne.filter.create_filter` function to create the band-specific
    filters and applies them to multi-channel time-series data. Each filter in the bank
    corresponds to a specific frequency band and is applied to all channels of the input
    data. The filtering is performed using FFT-based convolution via
    :func:`torchaudio.functional.fftconvolve` if the method is FIR, and via
    :func:`torchaudio.functional.filtfilt` (or :func:`torchaudio.functional.lfilter`
    when ``phase="forward"``) if the method is IIR.

    The default configuration creates 9 non-overlapping frequency bands with a 4 Hz bandwidth,
    spanning from 4 Hz to 40 Hz (i.e., 4-8 Hz, 8-12 Hz, ..., 36-40 Hz). This setup is based on the
    reference: *FBCNet: A Multi-view Convolutional Neural Network for Brain-Computer Interface*.

    Parameters
    ----------
    n_chans : int
        Number of channels in the input signal.
    sfreq : float
        Sampling frequency of the input signal in Hz.
    band_filters : list of (float, float) | int | None, default=None
        Frequency bands as (low_freq, high_freq) pairs, given as a list or a tuple.
        Each pair defines the frequency range for one filter in the bank. If an int
        ``k`` is given, ``k`` contiguous bands of equal width are created between
        4 Hz and 40 Hz (a warning is emitted). If not provided, defaults to 9
        non-overlapping bands with 4 Hz bandwidths spanning from 4 to 40 Hz.
    method : str, default='fir'
        ``'fir'`` will use FIR filtering, ``'iir'`` will use IIR filtering
        (forward-backward via :func:`torchaudio.functional.filtfilt` unless
        ``phase='forward'``).
        For more details, please check the `MNE Preprocessing Tutorial <https://mne.tools/stable/auto_tutorials/preprocessing/25_background_filtering.html>`_.
    filter_length : str | int
        Length of the FIR filter to use (if applicable):

        * **'auto' (default)**: The filter length is chosen based
          on the size of the transition regions (6.6 times the reciprocal
          of the shortest transition band for fir_window='hamming'
          and fir_design="firwin2", and half that for "firwin").
        * **str**: A human-readable time in
          units of "s" or "ms" (e.g., "10s" or "5500ms") will be
          converted to that number of samples if ``phase="zero"``, or
          the shortest power-of-two length at least that duration for
          ``phase="zero-double"``.
        * **int**: Specified length in samples. For fir_design="firwin",
          this should not be used.
    l_trans_bandwidth : str | float | int, default='auto'
        Width of the transition band at the low cut-off frequency in Hz
        (high pass or cutoff 1 in bandpass). Can be "auto"
        (default) to use a multiple of ``l_freq``::

            min(max(l_freq * 0.25, 2), l_freq)

        Only used for ``method='fir'``.
    h_trans_bandwidth : str | float | int, default='auto'
        Width of the transition band at the high cut-off frequency in Hz
        (low pass or cutoff 2 in bandpass). Can be "auto"
        (default in 0.14) to use a multiple of ``h_freq``::

            min(max(h_freq * 0.25, 2.), info['sfreq'] / 2. - h_freq)

        Only used for ``method='fir'``.
    phase : str, default='zero'
        Phase of the filter.
        When ``method='fir'``, the filters are constructed and applied as follows:

        ``"zero"`` (default)
            Symmetric linear-phase filter whose delay is compensated for
            (centred convolution), making it non-causal.
        ``"minimum"``
            A minimum-phase filter will be constructed by decomposing the zero-phase filter
            into a minimum-phase and all-pass systems, and then retaining only the
            minimum-phase system (of the same length as the original zero-phase filter)
            via :func:`scipy.signal.minimum_phase`. It is applied causally
            (no delay compensation).
        ``"zero-double"``
            *This is a legacy option for compatibility with MNE <= 0.13.*
            The filter is applied twice, once forward, and once backward
            (also making it non-causal). This is implemented by convolving the
            filter with its time-reversed copy once at construction time.
        ``"minimum-half"``
            *This is a legacy option for compatibility with MNE <= 1.6.*
            A minimum-phase filter will be reconstructed from the zero-phase filter with
            half the length of the original filter. It is applied causally.

        When ``method='iir'``, ``phase='zero'`` (default) or equivalently ``'zero-double'``
        applies the IIR filter twice, once forward, and once backward (making it
        non-causal) using :func:`torchaudio.functional.filtfilt`; ``phase='forward'``
        applies the filter once in the forward (causal) direction using
        :func:`torchaudio.functional.lfilter`.
    iir_params : dict | None, default=None
        Dictionary of parameters to use for IIR filtering.
        If ``iir_params=None`` and ``method="iir"``, 4th order Butterworth will be used.
        Second-order sections (``output="sos"``) are not supported with torch and are
        replaced by ``"ba"`` with a warning. The dictionary passed in is not modified.
        For more information, see :func:`mne.filter.construct_iir_filter`.
    fir_window : str, default='hamming'
        The window to use in FIR design, can be "hamming" (default),
        "hann" (default in 0.13), or "blackman".
    fir_design : str, default='firwin'
        Can be "firwin" (default) to use :func:`scipy.signal.firwin`,
        or "firwin2" to use :func:`scipy.signal.firwin2`. "firwin" uses
        a time-domain design technique that generally gives improved
        attenuation using fewer samples than "firwin2".
    verbose : bool | str | int | None, default=True
        Control verbosity of the logging output. If ``None``, use the default
        verbosity level. See :func:`mne.verbose` for details.
        Should only be passed as a keyword argument.

    Raises
    ------
    ValueError
        If ``band_filters`` is a non-positive int, is neither an int, a list nor a
        tuple, is empty, or contains an item that is not a (low, high) pair.
    """

    def __init__(
        self,
        n_chans: int,
        sfreq: float,
        band_filters: Optional[list[tuple[float, float]] | int] = None,
        method: str = "fir",
        filter_length: str | float | int = "auto",
        l_trans_bandwidth: str | float | int = "auto",
        h_trans_bandwidth: str | float | int = "auto",
        phase: str = "zero",
        iir_params: Optional[dict] = None,
        fir_window: str = "hamming",
        fir_design: str = "firwin",
        verbose: bool = True,
    ):
        super().__init__()

        band_filters = self._check_band_filters(band_filters)

        self.band_filters = band_filters
        self.n_bands = len(band_filters)
        self.phase = phase
        self.method = method
        self.n_chans = n_chans

        self.method_iir = self.method == "iir"

        # Coefficients are kept as frozen parameters (requires_grad=False) so that
        # they follow the module across devices and appear in the state dict.
        self.fir_list = nn.ParameterList()
        self.b_list = nn.ParameterList()
        self.a_list = nn.ParameterList()

        if self.method_iir:
            iir_params = self._check_iir_params(iir_params)

        for l_freq, h_freq in band_filters:
            filt = create_filter(
                data=None,
                sfreq=sfreq,
                l_freq=float(l_freq),
                h_freq=float(h_freq),
                filter_length=filter_length,
                l_trans_bandwidth=l_trans_bandwidth,
                h_trans_bandwidth=h_trans_bandwidth,
                method=self.method,
                # A fresh copy per band, so that anything written back into the
                # dict while designing one band cannot leak into the next one.
                iir_params=None if iir_params is None else dict(iir_params),
                phase=phase,
                fir_window=fir_window,
                fir_design=fir_design,
                verbose=verbose,
            )
            if not self.method_iir:
                # FIR filter: 1-D impulse response.
                kernel = from_numpy(filt)
                if phase == "zero-double":
                    # Forward + backward application == a single pass with the
                    # filter convolved with its time-reversed copy (odd length,
                    # symmetric, zero phase, squared magnitude response).
                    kernel = fftconvolve(kernel, kernel.flip(0))
                self.fir_list.append(nn.Parameter(kernel.float(), requires_grad=False))
            else:
                b_coeffs = filt["b"]
                a_coeffs = filt["a"]

                _check_coefficients((b_coeffs, a_coeffs))

                # Stored in float64; cast to the input dtype at filtering time.
                b = torch.tensor(b_coeffs, dtype=torch.float64)
                a = torch.tensor(a_coeffs, dtype=torch.float64)

                self.b_list.append(nn.Parameter(b, requires_grad=False))
                self.a_list.append(nn.Parameter(a, requires_grad=False))

    @staticmethod
    def _check_band_filters(band_filters) -> list:
        """Validate ``band_filters`` and normalise it to a list of (low, high) pairs.

        Raises
        ------
        ValueError
            For a non-positive int, a value that is neither int, list nor tuple,
            an empty collection, or an item that is not a pair.
        """
        if band_filters is None:
            # FBCNet default: 9 non-overlapping 4 Hz bands from 4 to 40 Hz
            # (4-8, 8-12, ..., 36-40 Hz).
            return [(low, low + 4) for low in range(4, 36 + 1, 4)]

        if isinstance(band_filters, int):
            if band_filters <= 0:
                raise ValueError(
                    "`band_filters` must be a positive integer when given as an "
                    f"int, got {band_filters}."
                )
            warn(
                "Creating the filter banks equally divided in the "
                "interval 4Hz to 40Hz with almost equal bandwidths. "
                "If you want a specific interval, "
                "please specify 'band_filters' as a list of tuples.",
                UserWarning,
            )
            start, end = 4.0, 40.0
            band_width = (end - start) / band_filters
            return [
                (
                    float(start + i * band_width),
                    float(start + (i + 1) * band_width),
                )
                for i in range(band_filters)
            ]

        if not isinstance(band_filters, (list, tuple)):
            raise ValueError(
                "`band_filters` should be a list of tuples if you want to "
                "use them this way."
            )
        band_filters = list(band_filters)
        if not band_filters:
            raise ValueError("`band_filters` must contain at least one band.")

        def is_pair(item) -> bool:
            # Non-sized items (e.g. a bare float) are not valid bands.
            try:
                return len(item) == 2
            except TypeError:
                return False

        if not all(is_pair(bands) for bands in band_filters):
            raise ValueError("The band_filters items should be splitable in 2 values.")
        return band_filters

    @staticmethod
    def _check_iir_params(iir_params: Optional[dict]) -> dict:
        """Return a copy of ``iir_params`` that requests ``"ba"`` coefficients.

        The caller's dictionary is never modified.
        """
        if iir_params is None:
            return dict(output="ba")
        iir_params = dict(iir_params)
        if iir_params.get("output") == "sos":
            warn(
                "It is not possible to use second-order section filtering with Torch. Changing to filter ba",
                UserWarning,
            )
            iir_params["output"] = "ba"
        return iir_params

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the filter bank to the input signal.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, time_points).

        Returns
        -------
        torch.Tensor
            Filtered output tensor of shape (batch_size, n_bands, n_chans, time_points).
        """
        if self.method_iir:
            forward_only = self.phase == "forward"
            outs = [
                self._apply_iir(x=x, b_coeffs=b, a_coeffs=a, forward_only=forward_only)
                for b, a in zip(self.b_list, self.a_list)
            ]
        else:
            causal = self.phase.startswith("minimum")
            outs = [
                self._apply_fir(x=x, filt=fir, n_chans=self.n_chans, causal=causal)
                for fir in self.fir_list
            ]

        return torch.cat(outs, dim=1)

    @staticmethod
    def _apply_fir(x: Tensor, filt: Tensor, n_chans: int, causal: bool = False) -> Tensor:
        """
        Apply an FIR filter to the input tensor.

        Parameters
        ----------
        x : Tensor
            Input tensor, expected of shape (batch_size, n_chans, n_times).
        filt : Tensor
            1-D tensor of FIR filter coefficients, shape (filter_length,).
        n_chans : int
            Number of channels. Kept for backward compatibility: the filter is
            broadcast over every leading dimension of ``x``.
        causal : bool, default=False
            If False, the delay of a symmetric (zero-phase) filter is compensated
            by keeping the centred part of the convolution. If True (minimum-phase
            filters), the first ``n_times`` samples of the full convolution are kept.

        Returns
        -------
        Tensor
            Filtered tensor of shape (batch_size, 1, n_chans, n_times).
        """
        # torchaudio's fftconvolve needs operands with the same number of
        # dimensions and broadcastable leading dimensions: view the 1-D filter
        # as (1, ..., 1, filter_length) so it is shared by all batches/channels.
        kernel = filt.to(x).reshape((1,) * (x.ndim - 1) + (-1,))

        if causal:
            filtered = fftconvolve(x, kernel, mode="full")[..., : x.shape[-1]]
        else:
            filtered = fftconvolve(x, kernel, mode="same")

        # Add the band dimension: (batch_size, 1, n_chans, n_times).
        return filtered.unsqueeze(1)

    @staticmethod
    def _apply_iir(
        x: Tensor, b_coeffs: Tensor, a_coeffs: Tensor, forward_only: bool = False
    ) -> Tensor:
        """
        Apply an IIR filter to the input tensor.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, n_chans, n_times).
        b_coeffs : Tensor
            Numerator coefficients of the filter.
        a_coeffs : Tensor
            Denominator coefficients of the filter (same length as ``b_coeffs``).
        forward_only : bool, default=False
            If True, filter once in the forward (causal) direction with
            ``lfilter``; otherwise filter forward and backward with ``filtfilt``.

        Returns
        -------
        Tensor
            Filtered tensor of shape (batch_size, 1, n_chans, n_times).
        """
        a = a_coeffs.to(x)
        b = b_coeffs.to(x)
        if forward_only:
            from torchaudio.functional import lfilter

            filtered = lfilter(x, a_coeffs=a, b_coeffs=b, clamp=False)
        else:
            filtered = filtfilt(x, a_coeffs=a, b_coeffs=b, clamp=False)
        # Add the band dimension: (batch_size, 1, n_chans, n_times).
        return filtered.unsqueeze(1)
345
+
346
+
347
class GeneralizedGaussianFilter(nn.Module):
    """Generalized Gaussian Filter from Ludwig et al (2024) [eegminer]_.

    Implements trainable temporal filters based on generalized Gaussian functions
    in the frequency domain.

    This module creates filters in the frequency domain using the generalized
    Gaussian function, allowing for trainable center frequency (`f_mean`),
    bandwidth (`bandwidth`), and shape (`shape`) parameters.

    The filters are applied to the input signal in the frequency domain, and can
    be optionally transformed back to the time domain using the inverse
    Fourier transform.

    The generalized Gaussian function in the frequency domain is defined as:

    .. math::

        F(x) = \\exp\\left( - \\left( \\frac{|x - \\mu|}{\\alpha} \\right)^{\\beta} \\right)

    where:
      - μ (mu) is the center frequency (`f_mean`).

      - α (alpha) is the scale parameter, reparameterized in terms of the full width at half maximum (FWHM) `h` as:

        .. math::

            \\alpha = \\frac{h}{2 \\left( \\ln(2) \\right)^{1/\\beta}}

      - β (beta) is the exponent controlling the shape of the filter. It is
        obtained from the `shape` parameter (clamped to [2, 3]) as
        ``beta = 8 * shape - 14``, i.e. beta lies in [2, 10].

    The filters are constructed in the frequency domain to allow full control
    over the magnitude and phase responses.

    A linear phase response :math:`\\phi(f) = -2 \\pi f \\tau` is used, evaluated at
    the rFFT bin frequencies, with an optional trainable group delay :math:`\\tau`
    (`group_delay`).

    - Copyright (C) Cogitat, Ltd.
    - Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
    - Patent GB2609265 - Learnable filters for eeg classification
    - https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels. Must be a multiple of `in_channels`.
    sequence_length : int
        Length of the input sequences (time steps).
    sample_rate : float
        Sampling rate of the input signals in Hz.
    inverse_fourier : bool, optional
        If True, applies the inverse Fourier transform to return to the time domain after filtering.
        Default is True.
    affine_group_delay : bool, optional
        If True, makes the group delay parameter trainable. Default is False.
    group_delay : tuple of float, optional
        Initial group delay(s) in milliseconds for the filters, either one value per
        filter (``out_channels // in_channels`` values) or a single value used for
        every filter. Default is (20.0,).
    f_mean : tuple of float, optional
        Initial center frequency (frequencies) of the filters in Hz, one per filter
        (``out_channels // in_channels`` values). Default is (23.0,).
    bandwidth : tuple of float, optional
        Initial bandwidth(s) (full width at half maximum) of the filters in Hz, one per
        filter. Clamped to [1 Hz, sample_rate / 2]. Default is (44.0,).
    shape : tuple of float, optional
        Initial shape parameter(s) of the generalized Gaussian filters, one per filter.
        Clamped to [2.0, 3.0]. Default is (2.0,).
    clamp_f_mean : tuple of float, optional
        Minimum and maximum allowable values for the center frequency `f_mean` in Hz.
        Specified as (min_f_mean, max_f_mean). Default is (1.0, 45.0).

    Notes
    -----
    The model and the module **have a patent** [eegminercode]_, and the **code is CC BY-NC 4.0**.

    .. versionadded:: 0.9

    References
    ----------
    .. [eegminer] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters. Journal of Neural Engineering,
       21(3), 036010.
    .. [eegminercode] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters.
       https://github.com/SMLudwig/EEGminer/.
       Cogitat, Ltd. "Learnable filters for EEG classification."
       Patent GB2609265.
       https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        sequence_length,
        sample_rate,
        inverse_fourier=True,
        affine_group_delay=False,
        group_delay=(20.0,),
        f_mean=(23.0,),
        bandwidth=(44.0,),
        shape=(2.0,),
        clamp_f_mean=(1.0, 45.0),
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sequence_length = sequence_length
        self.sample_rate = sample_rate
        self.inverse_fourier = inverse_fourier
        self.affine_group_delay = affine_group_delay
        self.clamp_f_mean = clamp_f_mean
        assert out_channels % in_channels == 0, (
            "out_channels has to be multiple of in_channels"
        )
        n_filters = out_channels // in_channels

        # Convert to tuples so that ``* in_channels`` below repeats the entries
        # (for a NumPy array, ``*`` would scale the values instead).
        f_mean = tuple(f_mean)
        bandwidth = tuple(bandwidth)
        shape = tuple(shape)
        group_delay = tuple(group_delay)

        assert len(f_mean) * in_channels == out_channels, (
            "f_mean needs out_channels // in_channels values"
        )
        assert len(bandwidth) * in_channels == out_channels, (
            "bandwidth needs out_channels // in_channels values"
        )
        assert len(shape) * in_channels == out_channels, (
            "shape needs out_channels // in_channels values"
        )
        if len(group_delay) == 1 and in_channels > 1 and n_filters > 1:
            # A single delay repeated per input channel would give
            # ``in_channels`` entries, which cannot broadcast against
            # ``out_channels`` filters: use it for every filter instead.
            group_delay = group_delay * n_filters
        assert len(group_delay) in (1, n_filters), (
            "group_delay needs 1 or out_channels // in_channels values"
        )

        dtype = torch.get_default_dtype()
        nyquist = sample_rate / 2

        # rFFT bin frequencies (0 .. Nyquist), normalized by the Nyquist frequency.
        self.n_range = nn.Parameter(
            torch.fft.rfftfreq(sequence_length, d=1 / sample_rate) / nyquist,
            requires_grad=False,
        )

        # Trainable filter parameters (explicit float dtype so that integer
        # inputs still yield tensors that can require gradients).
        self.f_mean = nn.Parameter(
            torch.tensor(f_mean * in_channels, dtype=dtype) / nyquist,
            requires_grad=True,
        )
        self.bandwidth = nn.Parameter(
            torch.tensor(bandwidth * in_channels, dtype=dtype) / nyquist,
            requires_grad=True,
        )  # full width half maximum
        self.shape = nn.Parameter(
            torch.tensor(shape * in_channels, dtype=dtype), requires_grad=True
        )

        # Normalize group delay so that group_delay=1 corresponds to 1000ms
        self.group_delay = nn.Parameter(
            torch.tensor(group_delay * in_channels, dtype=dtype) / 1000,
            requires_grad=affine_group_delay,
        )

        # Construct filters from parameters
        self.filters = self.construct_filters()

    @staticmethod
    def exponential_power(x, mean, fwhm, shape):
        """
        Computes the generalized Gaussian function:

        .. math::

            F(x) = \\exp\\left( - \\left( \\frac{|x - \\mu|}{\\alpha} \\right)^{\\beta} \\right)

        where:

        - :math:`\\mu` is the mean (`mean`).

        - :math:`\\alpha` is the scale parameter, reparameterized using the FWHM :math:`h` as:

          .. math::

              \\alpha = \\frac{h}{2 \\left( \\ln(2) \\right)^{1/\\beta}}

        - :math:`\\beta` is the shape parameter (`shape`).

        Parameters
        ----------
        x : torch.Tensor
            1-D tensor of frequencies, normalized between 0 and 1.
        mean : torch.Tensor
            1-D tensor of center frequencies (`f_mean`), normalized between 0 and 1.
        fwhm : torch.Tensor
            1-D tensor of full widths at half maximum (`bandwidth`), normalized between 0 and 1.
        shape : torch.Tensor
            1-D tensor of exponents of the generalized Gaussian.

        Returns
        -------
        torch.Tensor
            Function values of shape ``(len(mean), len(x))``.
        """
        mean = mean.unsqueeze(1)
        fwhm = fwhm.unsqueeze(1)
        shape = shape.unsqueeze(1)
        log2 = torch.log(torch.tensor(2.0, device=x.device, dtype=x.dtype))
        scale = fwhm / (2 * log2 ** (1 / shape))
        # Add small constant to difference between x and mean since grad of 0 ** shape is nan
        return torch.exp(-((((x - mean).abs() + 1e-8) / scale) ** shape))

    def construct_filters(self):
        """
        Constructs the filters in the frequency domain based on current parameters.

        The trainable parameters are clamped in place to their valid ranges first.

        Returns
        -------
        torch.Tensor
            The constructed filters with shape `(out_channels, freq_bins, 2)`, where
            the last dimension holds the real and imaginary parts.
        """
        nyquist = self.sample_rate / 2

        # Project parameters back onto their valid ranges (writes to .data, so the
        # parameters themselves are clamped, without recording it in autograd).
        self.f_mean.data = torch.clamp(
            self.f_mean.data,
            min=self.clamp_f_mean[0] / nyquist,
            max=self.clamp_f_mean[1] / nyquist,
        )
        self.bandwidth.data = torch.clamp(
            self.bandwidth.data, min=1.0 / nyquist, max=1.0
        )
        self.shape.data = torch.clamp(self.shape.data, min=2.0, max=3.0)

        # Magnitude response with unit peak gain -> (channels, freqs).
        # ``shape`` in [2, 3] is mapped to an exponent in [2, 10].
        mag_response = self.exponential_power(
            self.n_range, self.f_mean, self.bandwidth, self.shape * 8 - 14
        )
        mag_response = mag_response / mag_response.max(dim=-1, keepdim=True)[0]

        # Linear phase phi(f) = -2*pi*f*tau at the rFFT bin frequencies (Hz);
        # tau is the group delay in seconds (normalized group_delay=1 -> 1000 ms).
        freqs_hz = (self.n_range * nyquist).to(mag_response.dtype)
        pha_response = -2 * torch.pi * self.group_delay.unsqueeze(-1) * freqs_hz

        # Real and imaginary parts, stacked -> (channels, freqs, 2)
        real = mag_response * torch.cos(pha_response)
        imag = mag_response * torch.sin(pha_response)
        return torch.stack((real, imag), dim=-1)

    def forward(self, x):
        """
        Applies the generalized Gaussian filters to the input signal.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape `(..., in_channels, sequence_length)`.

        Returns
        -------
        torch.Tensor
            The filtered signal. If `inverse_fourier` is True, returns the signal in the time domain
            with shape `(..., out_channels, sequence_length)`. Otherwise, returns the signal in the
            frequency domain with shape `(..., out_channels, freq_bins, 2)`.
        """
        # Rebuild the filters so that they reflect the current parameter values.
        self.filters = self.construct_filters()
        # Preserving the original dtype.
        dtype = x.dtype

        # FFT, real and imaginary parts separated -> (..., channels, freqs, 2)
        x = torch.view_as_real(torch.fft.rfft(x, dim=-1))

        # Each input channel feeds ``out_channels // in_channels`` consecutive filters.
        x = torch.repeat_interleave(x, self.out_channels // self.in_channels, dim=-3)

        # Apply filters in the frequency domain
        x = x * self.filters

        if self.inverse_fourier:
            x = torch.fft.irfft(
                torch.view_as_complex(x), n=self.sequence_length, dim=-1
            )

        return x.to(dtype)