braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,628 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from mne.filter import _check_coefficients, create_filter
7
+ from mne.utils import warn
8
+ from torch import Tensor, from_numpy, nn
9
+ from torch.fft import fftfreq
10
+ from torchaudio.functional import fftconvolve, filtfilt
11
+
12
+
13
class FilterBankLayer(nn.Module):
    """Apply multiple band-pass filters to generate multiview signal representation.

    This layer constructs a bank of signals filtered in specific bands for each channel.
    It uses MNE's :func:`mne.filter.create_filter` function to create the band-specific
    filters and applies them to multi-channel time-series data. Each filter in the bank
    corresponds to a specific frequency band and is applied to all channels of the input
    data. FIR filters are applied with FFT-based convolution via
    :func:`torchaudio.functional.fftconvolve`; IIR filters are applied with
    :func:`torchaudio.functional.filtfilt` (forward-backward), or with
    :func:`torchaudio.functional.lfilter` (single causal pass) when ``phase='forward'``.

    No edge padding (such as MNE's ``pad='reflect_limited'``) is applied before
    filtering. Outputs near the window edges may therefore contain filter transients.

    The default configuration creates 9 non-overlapping frequency bands with a 4 Hz bandwidth,
    spanning from 4 Hz to 40 Hz (i.e., 4-8 Hz, 8-12 Hz, ..., 36-40 Hz). This setup is based on the
    reference: *FBCNet: A Multi-view Convolutional Neural Network for Brain-Computer Interface*.

    Parameters
    ----------
    n_chans : int
        Number of channels in the input signal.
    sfreq : float
        Sampling frequency of the input signal in Hz.
    band_filters : list of tuple of (float, float) | int | None, default=None
        List of frequency bands as (low_freq, high_freq) tuples. Each tuple defines
        the frequency range for one filter in the bank. If an int ``n`` is given,
        the 4-40 Hz interval is split into ``n`` equal-width bands (``n`` must be
        positive). If not provided, defaults to 9 non-overlapping bands with 4 Hz
        bandwidths spanning from 4 to 40 Hz.
    method : str, default='fir'
        ``'fir'`` will use FIR filtering, ``'iir'`` will use IIR
        filtering (forward-backward unless ``phase='forward'``).
        For more details, please check the `MNE Preprocessing Tutorial <https://mne.tools/stable/auto_tutorials/preprocessing/25_background_filtering.html>`_.
    filter_length : str | int
        Length of the FIR filter to use (if applicable):

        * **'auto' (default)**: The filter length is chosen based
          on the size of the transition regions (6.6 times the reciprocal
          of the shortest transition band for fir_window='hamming'
          and fir_design="firwin2", and half that for "firwin").
        * **str**: A human-readable time in
          units of "s" or "ms" (e.g., "10s" or "5500ms") will be
          converted to that number of samples if ``phase="zero"``, or
          the shortest power-of-two length at least that duration for
          ``phase="zero-double"``.
        * **int**: Specified length in samples. For fir_design="firwin",
          this should not be used.
    l_trans_bandwidth : Union[str, float, int], default='auto'
        Width of the transition band at the low cut-off frequency in Hz
        (high pass or cutoff 1 in bandpass). Can be "auto"
        (default) to use a multiple of ``l_freq``::

            min(max(l_freq * 0.25, 2), l_freq)

        Only used for ``method='fir'``.
    h_trans_bandwidth : Union[str, float, int], default='auto'
        Width of the transition band at the high cut-off frequency in Hz
        (low pass or cutoff 2 in bandpass). Can be "auto"
        (default in 0.14) to use a multiple of ``h_freq``::

            min(max(h_freq * 0.25, 2.), info['sfreq'] / 2. - h_freq)

        Only used for ``method='fir'``.
    phase : str, default='zero'
        Phase of the filter.
        When ``method='fir'``, symmetric linear-phase FIR filters are constructed
        with the following behaviors:

        ``"zero"`` (default)
            The delay of this filter is compensated for, making it non-causal.
        ``"minimum"``
            A minimum-phase filter will be constructed by decomposing the zero-phase filter
            into a minimum-phase and all-pass systems, and then retaining only the
            minimum-phase system (of the same length as the original zero-phase filter)
            via :func:`scipy.signal.minimum_phase`. It is applied causally.
        ``"zero-double"``
            *This is a legacy option for compatibility with MNE <= 0.13.*
            The filter is applied twice, once forward, and once backward
            (also making it non-causal).
        ``"minimum-half"``
            *This is a legacy option for compatibility with MNE <= 1.6.*
            A minimum-phase filter will be reconstructed from the zero-phase filter with
            half the length of the original filter. It is applied causally.

        When ``method='iir'``, ``phase='zero'`` (default) or equivalently ``'zero-double'``
        constructs and applies the IIR filter twice, once forward, and once backward (making it
        non-causal) using :func:`torchaudio.functional.filtfilt`; ``phase='forward'`` will apply
        the filter once in the forward (causal) direction using
        :func:`torchaudio.functional.lfilter`.
    iir_params : Optional[dict], default=None
        Dictionary of parameters to use for IIR filtering.
        If ``iir_params=None`` and ``method="iir"``, 4th order Butterworth will be used.
        The ``"output"`` entry is always forced to ``"ba"`` because second-order
        sections are not supported with Torch. The caller's dictionary is not modified.
        For more information, see :func:`mne.filter.construct_iir_filter`.
    fir_window : str, default='hamming'
        The window to use in FIR design, can be "hamming" (default),
        "hann" (default in 0.13), or "blackman".
    fir_design : str, default='firwin'
        Can be "firwin" (default) to use :func:`scipy.signal.firwin`,
        or "firwin2" to use :func:`scipy.signal.firwin2`. "firwin" uses
        a time-domain design technique that generally gives improved
        attenuation using fewer samples than "firwin2".
    verbose: bool | str | int | None, default=True
        Control verbosity of the logging output of MNE's filter design. If ``None``,
        use the default verbosity level. See :func:`mne.verbose` for details.
        Should only be passed as a keyword argument.
    """

    def __init__(
        self,
        n_chans: int,
        sfreq: float,
        band_filters: Optional[list[tuple[float, float]] | int] = None,
        method: str = "fir",
        filter_length: str | float | int = "auto",
        l_trans_bandwidth: str | float | int = "auto",
        h_trans_bandwidth: str | float | int = "auto",
        phase: str = "zero",
        iir_params: Optional[dict] = None,
        fir_window: str = "hamming",
        fir_design: str = "firwin",
        verbose: bool = True,
    ):
        super().__init__()

        if band_filters is None:
            # FBCNet default: 9 non-overlapping 4 Hz bands from 4 to 40 Hz
            # (4-8, 8-12, ..., 36-40 Hz).
            band_filters = [(low, low + 4) for low in range(4, 36 + 1, 4)]

        if isinstance(band_filters, int):
            # An int is the number of equal-width bands splitting 4-40 Hz.
            if band_filters <= 0:
                raise ValueError(
                    "`band_filters` must be a positive integer when given as "
                    f"an int, got {band_filters}."
                )
            warn(
                "Creating the filter banks equally divided in the "
                "interval 4Hz to 40Hz with almost equal bandwidths. "
                "If you want a specific interval, "
                "please specify 'band_filters' as a list of tuples.",
                UserWarning,
            )
            start = 4.0
            end = 40.0
            band_width = (end - start) / band_filters
            band_filters = [
                (float(start + i * band_width), float(start + (i + 1) * band_width))
                for i in range(band_filters)
            ]

        if not isinstance(band_filters, list):
            raise ValueError(
                "`band_filters` should be a list of tuples if you want to "
                "use them this way."
            )
        if any(len(bands) != 2 for bands in band_filters):
            raise ValueError("The band_filters items should be splitable in 2 values.")
        if len(band_filters) == 0:
            # An empty bank would only fail later, in ``torch.cat`` in forward.
            raise ValueError("`band_filters` must contain at least one band.")

        self.band_filters = band_filters
        self.n_bands = len(band_filters)
        self.phase = phase
        self.method = method
        self.n_chans = n_chans

        self.method_iir = self.method == "iir"

        # Non-trainable coefficient storage; ParameterList keeps them on the
        # module's device and in its state_dict.
        self.fir_list = nn.ParameterList()
        self.b_list = nn.ParameterList()
        self.a_list = nn.ParameterList()

        if self.method_iir:
            # Work on a copy so the caller's dict is never mutated. Always
            # request (b, a) coefficients: torchaudio has no second-order-
            # section filtering, and the loop below reads filt["b"]/filt["a"].
            # (An empty dict would otherwise make MNE default to "sos".)
            iir_params = {} if iir_params is None else dict(iir_params)
            if iir_params.get("output") == "sos":
                warn(
                    "It is not possible to use second-order section filtering with Torch. Changing to filter ba",
                    UserWarning,
                )
            iir_params["output"] = "ba"

        for l_freq, h_freq in band_filters:
            filt = create_filter(
                data=None,
                sfreq=sfreq,
                l_freq=float(l_freq),
                h_freq=float(h_freq),
                filter_length=filter_length,
                l_trans_bandwidth=l_trans_bandwidth,
                h_trans_bandwidth=h_trans_bandwidth,
                method=self.method,
                iir_params=iir_params,
                phase=phase,
                fir_window=fir_window,
                fir_design=fir_design,
                verbose=verbose,
            )
            if not self.method_iir:
                # FIR: create_filter returns the impulse response as an array.
                fir = from_numpy(filt).float()
                self.fir_list.append(nn.Parameter(fir, requires_grad=False))
            else:
                # IIR: create_filter returns a dict holding the "b"/"a" coefficients.
                a_coeffs = filt["a"]
                b_coeffs = filt["b"]

                _check_coefficients((b_coeffs, a_coeffs))

                b = torch.tensor(b_coeffs, dtype=torch.float64)
                a = torch.tensor(a_coeffs, dtype=torch.float64)

                self.b_list.append(nn.Parameter(b, requires_grad=False))
                self.a_list.append(nn.Parameter(a, requires_grad=False))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the filter bank to the input signal.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, time_points).

        Returns
        -------
        torch.Tensor
            Filtered output tensor of shape (batch_size, n_bands, n_chans, time_points).
        """
        if self.method_iir:
            outs = [
                self._apply_iir(x=x, b_coeffs=b, a_coeffs=a, phase=self.phase)
                for b, a in zip(self.b_list, self.a_list)
            ]
        else:
            outs = [
                self._apply_fir(x=x, filt=fir, n_chans=self.n_chans, phase=self.phase)
                for fir in self.fir_list
            ]
        # Each output is (batch, 1, n_chans, n_times); stack bands on dim 1.
        return torch.cat(outs, dim=1)

    @staticmethod
    def _apply_fir(x, filt: Tensor, n_chans: int, phase: str = "zero") -> Tensor:
        """
        Apply an FIR filter to the input tensor.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, n_chans, n_times).
        filt : Tensor
            FIR impulse response of shape (filter_length,).
        n_chans : int
            Number of channels; the filter is replicated once per channel.
        phase : str, default="zero"
            ``"zero"``: single centred pass (non-causal).
            ``"zero-double"``: centred forward pass followed by a backward pass.
            ``"minimum"`` / ``"minimum-half"``: causal pass (output sample ``t``
            only depends on input samples ``<= t``).

        Returns
        -------
        Tensor
            Filtered tensor of shape (batch_size, 1, n_chans, n_times).
        """
        # (filter_length,) -> (n_chans, filter_length) -> (1, n_chans, filter_length),
        # on x's device and dtype (consistent with the IIR path).
        filt_expanded = (
            filt.to(device=x.device, dtype=x.dtype)
            .unsqueeze(0)
            .repeat(n_chans, 1)
            .unsqueeze(0)
        )

        if phase in ("minimum", "minimum-half"):
            # Causal application: keep the first n_times samples of the full
            # linear convolution instead of re-centring the kernel.
            n_times = x.shape[-1]
            filtered = fftconvolve(x, filt_expanded, mode="full")[..., :n_times]
        else:
            filtered = fftconvolve(x, filt_expanded, mode="same")
            if phase == "zero-double":
                # Second pass on the time-reversed signal (forward-backward).
                filtered = fftconvolve(
                    filtered.flip(-1), filt_expanded, mode="same"
                ).flip(-1)

        # Add the band dimension: (batch_size, 1, n_chans, n_times).
        return filtered.unsqueeze(1)

    @staticmethod
    def _apply_iir(
        x: Tensor, b_coeffs: Tensor, a_coeffs: Tensor, phase: str = "zero"
    ) -> Tensor:
        """
        Apply an IIR filter to the input tensor.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, n_chans, n_times).
        b_coeffs : Tensor
            Numerator coefficients of the filter.
        a_coeffs : Tensor
            Denominator coefficients of the filter.
        phase : str, default="zero"
            ``"forward"`` applies a single causal pass with
            :func:`torchaudio.functional.lfilter`; any other value applies a
            forward-backward pass with :func:`torchaudio.functional.filtfilt`.

        Returns
        -------
        Tensor
            Filtered tensor of shape (batch_size, 1, n_chans, n_times).
        """
        a = a_coeffs.to(device=x.device, dtype=x.dtype)
        b = b_coeffs.to(device=x.device, dtype=x.dtype)
        if phase == "forward":
            # Local import: only needed for causal IIR filtering.
            from torchaudio.functional import lfilter

            filtered = lfilter(x, a_coeffs=a, b_coeffs=b, clamp=False)
        else:
            filtered = filtfilt(x, a_coeffs=a, b_coeffs=b, clamp=False)
        # Add the band dimension: (batch_size, 1, n_chans, n_times).
        return filtered.unsqueeze(1)
341
+
342
+
343
class GeneralizedGaussianFilter(nn.Module):
    """Generalized Gaussian Filter from Ludwig et al (2024) [eegminer]_.

    Implements trainable temporal filters based on generalized Gaussian functions
    in the frequency domain.

    This module creates filters in the frequency domain using the generalized
    Gaussian function, allowing for trainable center frequency (`f_mean`),
    bandwidth (`bandwidth`), and shape (`shape`) parameters.

    The filters are applied to the input signal in the frequency domain by complex
    multiplication with the input spectrum, and the result can be optionally
    transformed back to the time domain using the inverse Fourier transform.

    The generalized Gaussian function in the frequency domain is defined as:

    .. math::

        F(x) = \\exp\\left( - \\left( \\frac{abs(x - \\mu)}{\\alpha} \\right)^{\\beta} \\right)

    where:
      - μ (mu) is the center frequency (`f_mean`).

      - α (alpha) is the scale parameter, reparameterized in terms of the full width at half maximum (FWHM) `h` as:

        .. math::

            \\alpha = \\frac{h}{2 \\left( \\ln(2) \\right)^{1/\\beta}}

      - β (beta) is the shape exponent, derived from the `shape` parameter as
        ``8 * shape - 14``.

    The filters are constructed in the frequency domain to allow full control
    over the magnitude and phase responses.

    A linear phase response :math:`\\phi(f) = -2 \\pi f \\tau` is used, with an
    optional trainable group delay :math:`\\tau` (`group_delay`).

    - Copyright (C) Cogitat, Ltd.
    - Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
    - Patent GB2609265 - Learnable filters for eeg classification
    - https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels. Must be a multiple of `in_channels`.
    sequence_length : int
        Length of the input sequences (time steps).
    sample_rate : float
        Sampling rate of the input signals in Hz.
    inverse_fourier : bool, optional
        If True, applies the inverse Fourier transform to return to the time domain after filtering.
        Default is True.
    affine_group_delay : bool, optional
        If True, makes the group delay parameter trainable. Default is False.
    group_delay : tuple of float, optional
        Initial group delay(s) in milliseconds for the filters. Default is (20.0,).
        ``len(group_delay) * in_channels`` must equal ``out_channels`` (or be 1)
        so that the phase response broadcasts against the filters.
    f_mean : tuple of float, optional
        Initial center frequency (frequencies) of the filters in Hz. Default is (23.0,).
        ``len(f_mean) * in_channels`` must equal ``out_channels``.
    bandwidth : tuple of float, optional
        Initial bandwidth(s) (full width at half maximum) of the filters in Hz. Default is (44.0,).
        Clamped to ``[1, sample_rate / 2]`` Hz. ``len(bandwidth) * in_channels``
        must equal ``out_channels``.
    shape : tuple of float, optional
        Initial shape parameter(s) of the generalized Gaussian filters. Clamped to
        ``[2.0, 3.0]``, which maps to an exponent β in ``[2, 10]``. Default is (2.0,).
        ``len(shape) * in_channels`` must equal ``out_channels``.
    clamp_f_mean : tuple of float, optional
        Minimum and maximum allowable values for the center frequency `f_mean` in Hz.
        Specified as (min_f_mean, max_f_mean). Default is (1.0, 45.0).


    Notes
    -----
    The model and the module **have a patent** [eegminercode]_, and the **code is CC BY-NC 4.0**.

    The ``filters`` attribute holds a detached snapshot of the most recently
    constructed filters (shape ``(out_channels, freq_bins, 2)``), for inspection only.

    .. versionadded:: 0.9

    References
    ----------
    .. [eegminer] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters. Journal of Neural Engineering,
       21(3), 036010.
    .. [eegminercode] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters.
       https://github.com/SMLudwig/EEGminer/.
       Cogitat, Ltd. "Learnable filters for EEG classification."
       Patent GB2609265.
       https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        sequence_length,
        sample_rate,
        inverse_fourier=True,
        affine_group_delay=False,
        group_delay=(20.0,),
        f_mean=(23.0,),
        bandwidth=(44.0,),
        shape=(2.0,),
        clamp_f_mean=(1.0, 45.0),
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sequence_length = sequence_length
        self.sample_rate = sample_rate
        self.inverse_fourier = inverse_fourier
        self.affine_group_delay = affine_group_delay
        self.clamp_f_mean = clamp_f_mean
        assert out_channels % in_channels == 0, (
            "out_channels has to be multiple of in_channels"
        )
        assert len(f_mean) * in_channels == out_channels, (
            "len(f_mean) * in_channels must equal out_channels"
        )
        assert len(bandwidth) * in_channels == out_channels, (
            "len(bandwidth) * in_channels must equal out_channels"
        )
        assert len(shape) * in_channels == out_channels, (
            "len(shape) * in_channels must equal out_channels"
        )

        # Frequency of every rfft bin, normalized so that 0 is DC and 1 is
        # Nyquist (sample_rate / 2). rfftfreq matches the bins produced by
        # torch.fft.rfft for both even and odd sequence lengths.
        self.n_range = nn.Parameter(
            torch.fft.rfftfreq(sequence_length, d=1.0 / sample_rate)
            / (sample_rate / 2),
            requires_grad=False,
        )

        # Trainable filter parameters (frequencies normalized by Nyquist).
        self.f_mean = nn.Parameter(
            torch.tensor(f_mean * in_channels) / (sample_rate / 2), requires_grad=True
        )
        self.bandwidth = nn.Parameter(
            torch.tensor(bandwidth * in_channels) / (sample_rate / 2),
            requires_grad=True,
        )  # full width half maximum
        self.shape = nn.Parameter(torch.tensor(shape * in_channels), requires_grad=True)

        # Normalize group delay so that group_delay=1 corresponds to 1000ms
        # (i.e. the stored value is in seconds).
        self.group_delay = nn.Parameter(
            torch.tensor(group_delay * in_channels) / 1000,
            requires_grad=affine_group_delay,
        )

        # Initial snapshot of the filters. Built without autograd so the module
        # holds no graph-attached tensor (those cannot be deep-copied).
        with torch.no_grad():
            self.filters = self.construct_filters()

    @staticmethod
    def exponential_power(x, mean, fwhm, shape):
        """
        Computes the generalized Gaussian function:

        .. math::

            F(x) = \\exp\\left( - \\left( \\frac{|x - \\mu|}{\\alpha} \\right)^{\\beta} \\right)

        where:

        - :math:`\\mu` is the mean (`mean`).

        - :math:`\\alpha` is the scale parameter, reparameterized using the FWHM :math:`h` as:

          .. math::

              \\alpha = \\frac{h}{2 \\left( \\ln(2) \\right)^{1/\\beta}}

        - :math:`\\beta` is the shape parameter (`shape`).

        Parameters
        ----------
        x : torch.Tensor
            The input tensor representing frequencies, normalized between 0 and 1.
        mean : torch.Tensor
            The center frequency (`f_mean`), normalized between 0 and 1.
        fwhm : torch.Tensor
            The full width at half maximum (`bandwidth`), normalized between 0 and 1.
        shape : torch.Tensor
            The shape parameter (`shape`) of the generalized Gaussian.

        Returns
        -------
        torch.Tensor
            The computed generalized Gaussian function values at frequencies `x`,
            with shape ``(len(mean), len(x))``.

        """
        mean = mean.unsqueeze(1)
        fwhm = fwhm.unsqueeze(1)
        shape = shape.unsqueeze(1)
        log2 = torch.log(torch.tensor(2.0, device=x.device, dtype=x.dtype))
        scale = fwhm / (2 * log2 ** (1 / shape))
        # Add small constant to difference between x and mean since grad of 0 ** shape is nan
        return torch.exp(-((((x - mean).abs() + 1e-8) / scale) ** shape))

    def construct_filters(self):
        """
        Constructs the filters in the frequency domain based on current parameters.

        The trainable parameters are first clamped in place to their valid
        ranges: `f_mean` to `clamp_f_mean`, `bandwidth` to ``[1 Hz, Nyquist]``
        and `shape` to ``[2, 3]``.

        Returns
        -------
        torch.Tensor
            The constructed filters with shape `(out_channels, freq_bins, 2)`,
            where the last axis holds (real, imaginary) parts.

        """
        # Clamp parameters (projection onto the valid parameter domain).
        self.f_mean.data = torch.clamp(
            self.f_mean.data,
            min=self.clamp_f_mean[0] / (self.sample_rate / 2),
            max=self.clamp_f_mean[1] / (self.sample_rate / 2),
        )
        self.bandwidth.data = torch.clamp(
            self.bandwidth.data, min=1.0 / (self.sample_rate / 2), max=1.0
        )
        self.shape.data = torch.clamp(self.shape.data, min=2.0, max=3.0)

        # Create magnitude response with gain=1 -> (channels, freqs)
        mag_response = self.exponential_power(
            self.n_range, self.f_mean, self.bandwidth, self.shape * 8 - 14
        )
        mag_response = mag_response / mag_response.max(dim=-1, keepdim=True)[0]

        # Linear phase response phi(f) = -2 * pi * f * tau, with tau the group
        # delay in seconds. n_range * sample_rate equals 2 * f (Hz) for each bin.
        two_f = self.n_range * self.sample_rate
        pha_response = -self.group_delay.unsqueeze(-1) * two_f * torch.pi

        # Create real and imaginary parts of the filters
        real = mag_response * torch.cos(pha_response)
        imag = mag_response * torch.sin(pha_response)

        # Stack real and imaginary parts to create filters
        # -> (channels, freqs, 2)
        return torch.stack((real, imag), dim=-1)

    def forward(self, x):
        """
        Applies the generalized Gaussian filters to the input signal.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape `(..., in_channels, sequence_length)`.

        Returns
        -------
        torch.Tensor
            The filtered signal. If `inverse_fourier` is True, returns the signal in the time domain
            with shape `(..., out_channels, sequence_length)`. Otherwise, returns the signal in the
            frequency domain with shape `(..., out_channels, freq_bins, 2)`.

        """
        # Construct filters from parameters (keeps the autograd graph).
        filters = self.construct_filters()
        # Expose a detached snapshot only; storing the graph-attached tensor
        # would make the module impossible to deepcopy.
        self.filters = filters.detach()
        # Preserving the original dtype.
        dtype = x.dtype
        # Apply FFT -> (..., channels, freqs, 2)
        x = torch.view_as_real(torch.fft.rfft(x, dim=-1))

        # Repeat channels in case of multiple filters per channel
        x = torch.repeat_interleave(x, self.out_channels // self.in_channels, dim=-3)

        # Apply filters in the frequency domain: complex product X * H written
        # on the (real, imag) representation.
        x_re, x_im = x.unbind(dim=-1)
        h_re, h_im = filters.unbind(dim=-1)
        x = torch.stack(
            (x_re * h_re - x_im * h_im, x_re * h_im + x_im * h_re), dim=-1
        )

        # Apply inverse FFT if requested
        if self.inverse_fourier:
            x = torch.fft.irfft(
                torch.view_as_complex(x), n=self.sequence_length, dim=-1
            )

        return x.to(dtype)