braindecode 1.3.0.dev177069446__py3-none-any.whl

Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/modules/filter.py
@@ -0,0 +1,654 @@
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ from mne.filter import _check_coefficients, create_filter
+ from mne.utils import warn
+ from torch import Tensor, from_numpy, nn
+ from torch.fft import fftfreq
+ from torchaudio.functional import fftconvolve, filtfilt
+
+
+ class FilterBankLayer(nn.Module):
+     """Apply multiple band-pass filters to generate a multiview signal representation.
+
+     This layer constructs a bank of signals filtered in specific bands for each channel.
+     It uses MNE's `create_filter` function to create the band-specific filters and
+     applies them to multi-channel time-series data. Each filter in the bank corresponds to a
+     specific frequency band and is applied to all channels of the input data. The filtering is
+     performed with FFT-based convolution via :func:`torchaudio.functional.fftconvolve` if the
+     method is FIR, and with forward-backward filtering via
+     :func:`torchaudio.functional.filtfilt` if the method is IIR.
+
+     The default configuration creates 9 non-overlapping frequency bands with a 4 Hz bandwidth,
+     spanning from 4 Hz to 40 Hz (i.e., 4-8 Hz, 8-12 Hz, ..., 36-40 Hz). This setup is based on
+     the reference: *FBCNet: A Multi-view Convolutional Neural Network for Brain-Computer
+     Interface*.
+
+     Parameters
+     ----------
+     n_chans : int
+         Number of channels in the input signal.
+     sfreq : float
+         Sampling frequency of the input signal in Hz.
+     band_filters : Optional[list[tuple[float, float]]] or int, default=None
+         List of frequency bands as (low_freq, high_freq) tuples. Each tuple defines
+         the frequency range for one filter in the bank. If not provided, defaults
+         to 9 non-overlapping bands with 4 Hz bandwidths spanning from 4 to 40 Hz.
+         If an int is given, that many bands of almost equal width are created
+         between 4 and 40 Hz.
+     method : str, default='fir'
+         ``'fir'`` will use FIR filtering, ``'iir'`` will use IIR
+         forward-backward filtering (via :func:`torchaudio.functional.filtfilt`).
+         For more details, please check the `MNE Preprocessing Tutorial <https://mne.tools/stable/auto_tutorials/preprocessing/25_background_filtering.html>`_.
+     filter_length : str | int
+         Length of the FIR filter to use (if applicable):
+
+         * **'auto' (default)**: The filter length is chosen based
+           on the size of the transition regions (6.6 times the reciprocal
+           of the shortest transition band for fir_window='hamming'
+           and fir_design="firwin2", and half that for "firwin").
+         * **str**: A human-readable time in
+           units of "s" or "ms" (e.g., "10s" or "5500ms") will be
+           converted to that number of samples if ``phase="zero"``, or
+           the shortest power-of-two length at least that duration for
+           ``phase="zero-double"``.
+         * **int**: Specified length in samples. For fir_design="firwin",
+           this should not be used.
+     l_trans_bandwidth : str | float | int, default='auto'
+         Width of the transition band at the low cut-off frequency in Hz
+         (high pass or cutoff 1 in bandpass). Can be "auto"
+         (default) to use a multiple of ``l_freq``::
+
+             min(max(l_freq * 0.25, 2), l_freq)
+
+         Only used for ``method='fir'``.
+     h_trans_bandwidth : str | float | int, default='auto'
+         Width of the transition band at the high cut-off frequency in Hz
+         (low pass or cutoff 2 in bandpass). Can be "auto"
+         (default) to use a multiple of ``h_freq``::
+
+             min(max(h_freq * 0.25, 2.), sfreq / 2. - h_freq)
+
+         Only used for ``method='fir'``.
+     phase : str, default='zero'
+         Phase of the filter.
+         When ``method='fir'``, symmetric linear-phase FIR filters are constructed
+         with the following behaviors:
+
+         ``"zero"`` (default)
+             The delay of this filter is compensated for, making it non-causal.
+         ``"minimum"``
+             A minimum-phase filter will be constructed by decomposing the zero-phase filter
+             into a minimum-phase and all-pass systems, and then retaining only the
+             minimum-phase system (of the same length as the original zero-phase filter)
+             via :func:`scipy.signal.minimum_phase`.
+         ``"zero-double"``
+             *This is a legacy option for compatibility with MNE <= 0.13.*
+             The filter is applied twice, once forward, and once backward
+             (also making it non-causal).
+         ``"minimum-half"``
+             *This is a legacy option for compatibility with MNE <= 1.6.*
+             A minimum-phase filter will be reconstructed from the zero-phase filter with
+             half the length of the original filter.
+
+         When ``method='iir'``, ``phase='zero'`` (default) or equivalently ``'zero-double'``
+         constructs and applies the IIR filter twice, once forward, and once backward (making
+         it non-causal) using :func:`~scipy.signal.filtfilt`; ``phase='forward'`` will apply
+         the filter once in the forward (causal) direction using
+         :func:`~scipy.signal.lfilter`.
+     iir_params : Optional[dict], default=None
+         Dictionary of parameters to use for IIR filtering.
+         If ``iir_params=None`` and ``method="iir"``, a 4th-order Butterworth filter
+         will be used. For more information, see
+         :func:`mne.filter.construct_iir_filter`.
+     fir_window : str, default='hamming'
+         The window to use in FIR design, can be "hamming" (default),
+         "hann", or "blackman".
+     fir_design : str, default='firwin'
+         Can be "firwin" (default) to use :func:`scipy.signal.firwin`,
+         or "firwin2" to use :func:`scipy.signal.firwin2`. "firwin" uses
+         a time-domain design technique that generally gives improved
+         attenuation using fewer samples than "firwin2".
+     verbose : bool | str | int | None, default=True
+         Control verbosity of the logging output. If ``None``, use the default
+         verbosity level. See :func:`mne.verbose` for details.
+         Should only be passed as a keyword argument.
+
+     Examples
+     --------
+     >>> import torch
+     >>> from braindecode.modules import FilterBankLayer
+     >>> module = FilterBankLayer(
+     ...     n_chans=2,
+     ...     sfreq=128,
+     ...     band_filters=[(4.0, 8.0), (8.0, 12.0)],
+     ...     method="fir",
+     ...     verbose=False,
+     ... )
+     >>> inputs = torch.randn(2, 2, 256)
+     >>> outputs = module(inputs)
+     >>> outputs.shape
+     torch.Size([2, 2, 2, 256])
+     """
+
+     def __init__(
+         self,
+         n_chans: int,
+         sfreq: float,
+         band_filters: Optional[list[tuple[float, float]] | int] = None,
+         method: str = "fir",
+         filter_length: str | float | int = "auto",
+         l_trans_bandwidth: str | float | int = "auto",
+         h_trans_bandwidth: str | float | int = "auto",
+         phase: str = "zero",
+         iir_params: Optional[dict] = None,
+         fir_window: str = "hamming",
+         fir_design: str = "firwin",
+         verbose: bool = True,
+     ):
+         super().__init__()
+
+         # First, validate band_filters. None is accepted and triggers the
+         # default configuration.
+         if band_filters is None:
+             # The filter bank is constructed using 9 filters with non-overlapping
+             # frequency bands, each of 4 Hz bandwidth, spanning from 4 to 40 Hz
+             # (4-8, 8-12, ..., 36-40 Hz), based on the reference: FBCNet: A
+             # Multi-view Convolutional Neural Network for Brain-Computer Interface.
+             band_filters = [(low, low + 4) for low in range(4, 36 + 1, 4)]
+         # An int is also accepted and splits the 4-40 Hz range evenly.
+         if isinstance(band_filters, int):
+             warn(
+                 "Creating a filter bank by equally dividing the interval "
+                 "4 Hz to 40 Hz into bands of almost equal bandwidth. "
+                 "If you want specific intervals, "
+                 "please specify 'band_filters' as a list of tuples.",
+                 UserWarning,
+             )
+             start = 4.0
+             end = 40.0
+
+             total_band_width = end - start  # 36 Hz across the 4-40 Hz range
+
+             band_width_calculated = total_band_width / band_filters
+             band_filters = [
+                 (
+                     float(start + i * band_width_calculated),
+                     float(start + (i + 1) * band_width_calculated),
+                 )
+                 for i in range(band_filters)
+             ]
+
+         if not isinstance(band_filters, list):
+             raise ValueError(
+                 "`band_filters` should be a list of tuples of "
+                 "(low_freq, high_freq)."
+             )
+         else:
+             if any(len(bands) != 2 for bands in band_filters):
+                 raise ValueError(
+                     "Each item in `band_filters` should contain exactly "
+                     "two values: (low_freq, high_freq)."
+                 )
+
+         self.band_filters = band_filters
+         self.n_bands = len(band_filters)
+         self.phase = phase
+         self.method = method
+         self.n_chans = n_chans
+
+         self.method_iir = self.method == "iir"
+
+         # Prepare ParameterLists
+         self.fir_list = nn.ParameterList()
+         self.b_list = nn.ParameterList()
+         self.a_list = nn.ParameterList()
+
+         if self.method_iir:
+             if iir_params is None:
+                 iir_params = dict(output="ba")
+             else:
+                 if "output" in iir_params:
+                     if iir_params["output"] == "sos":
+                         warn(
+                             "Second-order section (sos) filtering is not "
+                             "supported with Torch. Falling back to 'ba' "
+                             "coefficients.",
+                             UserWarning,
+                         )
+                         iir_params["output"] = "ba"
+
+         for l_freq, h_freq in band_filters:
+             filt = create_filter(
+                 data=None,
+                 sfreq=sfreq,
+                 l_freq=float(l_freq),
+                 h_freq=float(h_freq),
+                 filter_length=filter_length,
+                 l_trans_bandwidth=l_trans_bandwidth,
+                 h_trans_bandwidth=h_trans_bandwidth,
+                 method=self.method,
+                 iir_params=iir_params,
+                 phase=phase,
+                 fir_window=fir_window,
+                 fir_design=fir_design,
+                 verbose=verbose,
+             )
+             if not self.method_iir:
+                 # FIR filter: create_filter returns the filter taps as an array
+                 filt = from_numpy(filt).float()
+                 self.fir_list.append(nn.Parameter(filt, requires_grad=False))
+             else:
+                 # IIR filter: create_filter returns a dict of 'b'/'a' coefficients
+                 a_coeffs = filt["a"]
+                 b_coeffs = filt["b"]
+
+                 _check_coefficients((b_coeffs, a_coeffs))
+
+                 b = torch.tensor(b_coeffs, dtype=torch.float64)
+                 a = torch.tensor(a_coeffs, dtype=torch.float64)
+
+                 self.b_list.append(nn.Parameter(b, requires_grad=False))
+                 self.a_list.append(nn.Parameter(a, requires_grad=False))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Apply the filter bank to the input signal.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, n_chans, time_points).
+
+         Returns
+         -------
+         torch.Tensor
+             Filtered output tensor of shape
+             (batch_size, n_bands, n_chans, filtered_time_points).
+         """
+         outs = []
+         if self.method_iir:
+             for b, a in zip(self.b_list, self.a_list):
+                 # Pass numerator and denominator directly
+                 outs.append(self._apply_iir(x=x, b_coeffs=b, a_coeffs=a))
+         else:
+             for fir in self.fir_list:
+                 # Pass FIR filter directly
+                 outs.append(self._apply_fir(x=x, filt=fir, n_chans=self.n_chans))
+
+         return torch.cat(outs, dim=1)
+
+     @staticmethod
+     def _apply_fir(x: Tensor, filt: Tensor, n_chans: int) -> Tensor:
+         """
+         Apply an FIR filter to the input tensor.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor of shape (batch_size, n_chans, n_times).
+         filt : Tensor
+             FIR filter taps of shape (filter_length,).
+         n_chans : int
+             Number of channels.
+
+         Returns
+         -------
+         Tensor
+             Filtered tensor of shape (batch_size, 1, n_chans, n_times).
+         """
+         # Expand filter coefficients to match the number of channels
+         # Original filter shape: (filter_length,)
+         # After unsqueeze and repeat: (n_chans, filter_length)
+         # After final unsqueeze: (1, n_chans, filter_length)
+         filt_expanded = filt.to(x.device).unsqueeze(0).repeat(n_chans, 1).unsqueeze(0)
+
+         # Perform FFT-based convolution
+         # Input x shape: (batch_size, n_chans, n_times)
+         # filt_expanded shape: (1, n_chans, filter_length)
+         # After convolution: (batch_size, n_chans, n_times)
+         filtered = fftconvolve(
+             x, filt_expanded, mode="same"
+         )  # Shape: (batch_size, n_chans, n_times)
+
+         # Add a new dimension for the band
+         # Shape after unsqueeze: (batch_size, 1, n_chans, n_times)
+         return filtered.unsqueeze(1)
+
+     @staticmethod
+     def _apply_iir(x: Tensor, b_coeffs: Tensor, a_coeffs: Tensor) -> Tensor:
+         """
+         Apply an IIR filter to the input tensor.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor of shape (batch_size, n_chans, n_times).
+         b_coeffs : Tensor
+             Numerator (b) coefficients of the IIR filter.
+         a_coeffs : Tensor
+             Denominator (a) coefficients of the IIR filter.
+
+         Returns
+         -------
+         Tensor
+             Filtered tensor of shape (batch_size, 1, n_chans, n_times).
+         """
+         # Apply forward-backward filtering using torchaudio's filtfilt
+         filtered = filtfilt(
+             x,
+             a_coeffs=a_coeffs.type_as(x).to(x.device),
+             b_coeffs=b_coeffs.type_as(x).to(x.device),
+             clamp=False,
+         )
+         # Add a band dimension -> (batch_size, 1, n_chans, n_times)
+         return filtered.unsqueeze(1)
+
+
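The docstring example above exercises the FIR path; the IIR path instead runs the ``'ba'`` coefficients through ``torchaudio.functional.filtfilt``. A minimal sketch of the same bank with ``method="iir"``, assuming torchaudio is installed (the ``iir_bank`` name is illustrative; the output shape is the same, one band per ``band_filters`` entry):

>>> import torch
>>> from braindecode.modules import FilterBankLayer
>>> iir_bank = FilterBankLayer(
...     n_chans=2,
...     sfreq=128,
...     band_filters=[(4.0, 8.0), (8.0, 12.0)],
...     method="iir",
...     verbose=False,
... )
>>> iir_bank(torch.randn(2, 2, 256)).shape
torch.Size([2, 2, 2, 256])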
+ class GeneralizedGaussianFilter(nn.Module):
+     """Generalized Gaussian Filter from Ludwig et al. (2024) [eegminer]_.
+
+     Implements trainable temporal filters based on generalized Gaussian functions
+     in the frequency domain.
+
+     This module creates filters in the frequency domain using the generalized
+     Gaussian function, allowing for trainable center frequency (`f_mean`),
+     bandwidth (`bandwidth`), and shape (`shape`) parameters.
+
+     The filters are applied to the input signal in the frequency domain, and can
+     be optionally transformed back to the time domain using the inverse
+     Fourier transform.
+
+     The generalized Gaussian function in the frequency domain is defined as:
+
+     .. math::
+
+         F(x) = \\exp\\left( - \\left( \\frac{|x - \\mu|}{\\alpha} \\right)^{\\beta} \\right)
+
+     where:
+
+     - :math:`\\mu` is the center frequency (`f_mean`).
+
+     - :math:`\\alpha` is the scale parameter, reparameterized in terms of the
+       full width at half maximum (FWHM) :math:`h` as:
+
+       .. math::
+
+           \\alpha = \\frac{h}{2 \\left( \\ln(2) \\right)^{1/\\beta}}
+
+     - :math:`\\beta` is the shape parameter (`shape`), controlling the shape of the filter.
+
+     The filters are constructed in the frequency domain to allow full control
+     over the magnitude and phase responses.
+
+     A linear phase response is used, with an optional trainable group delay
+     (`group_delay`).
+
+     - Copyright (C) Cogitat, Ltd.
+     - Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
+     - Patent GB2609265 - Learnable filters for EEG classification
+     - https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
+
+
393
+ Parameters
394
+ ----------
395
+ in_channels : int
396
+ Number of input channels.
397
+ out_channels : int
398
+ Number of output channels. Must be a multiple of `in_channels`.
399
+ sequence_length : int
400
+ Length of the input sequences (time steps).
401
+ sample_rate : float
402
+ Sampling rate of the input signals in Hz.
403
+ inverse_fourier : bool, optional
404
+ If True, applies the inverse Fourier transform to return to the time domain after filtering.
405
+ Default is True.
406
+ affine_group_delay : bool, optional
407
+ If True, makes the group delay parameter trainable. Default is False.
408
+ group_delay : tuple of float, optional
409
+ Initial group delay(s) in milliseconds for the filters. Default is (20.0,).
410
+ f_mean : tuple of float, optional
411
+ Initial center frequency (frequencies) of the filters in Hz. Default is (23.0,).
412
+ bandwidth : tuple of float, optional
413
+ Initial bandwidth(s) (full width at half maximum) of the filters in Hz. Default is (44.0,).
414
+ shape : tuple of float, optional
415
+ Initial shape parameter(s) of the generalized Gaussian filters. Must be >= 2.0. Default is (2.0,).
416
+ clamp_f_mean : tuple of float, optional
417
+ Minimum and maximum allowable values for the center frequency `f_mean` in Hz.
418
+ Specified as (min_f_mean, max_f_mean). Default is (1.0, 45.0).
419
+
420
+ Examples
421
+ --------
422
+ >>> import torch
423
+ >>> from braindecode.modules import GeneralizedGaussianFilter
424
+ >>> module = GeneralizedGaussianFilter(
425
+ ... in_channels=2,
426
+ ... out_channels=2,
427
+ ... sequence_length=256,
428
+ ... sample_rate=128,
429
+ ... inverse_fourier=True,
430
+ ... )
431
+ >>> inputs = torch.randn(3, 2, 256)
432
+ >>> outputs = module(inputs)
433
+ >>> outputs.shape
434
+ torch.Size([3, 2, 256])
435
+
436
+ Notes
437
+ -----
438
+ The model and the module **have a patent** [eegminercode]_, and the **code is CC BY-NC 4.0**.
439
+
440
+ .. versionadded:: 0.9
441
+
442
+ References
443
+ ----------
444
+ .. [eegminer] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
445
+ Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
446
+ of brain activity with learnable filters. Journal of Neural Engineering,
447
+ 21(3), 036010.
448
+ .. [eegminercode] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
449
+ Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
450
+ of brain activity with learnable filters.
451
+ https://github.com/SMLudwig/EEGminer/.
452
+ Cogitat, Ltd. "Learnable filters for EEG classification."
453
+ Patent GB2609265.
454
+ https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
455
+ """
456
+
457
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         sequence_length,
+         sample_rate,
+         inverse_fourier=True,
+         affine_group_delay=False,
+         group_delay=(20.0,),
+         f_mean=(23.0,),
+         bandwidth=(44.0,),
+         shape=(2.0,),
+         clamp_f_mean=(1.0, 45.0),
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.sequence_length = sequence_length
+         self.sample_rate = sample_rate
+         self.inverse_fourier = inverse_fourier
+         self.affine_group_delay = affine_group_delay
+         self.clamp_f_mean = clamp_f_mean
+         if out_channels % in_channels != 0:
+             raise ValueError("out_channels must be a multiple of in_channels")
+         if len(f_mean) * in_channels != out_channels:
+             raise ValueError("len(f_mean) * in_channels must equal out_channels")
+         if len(bandwidth) * in_channels != out_channels:
+             raise ValueError("len(bandwidth) * in_channels must equal out_channels")
+         if len(shape) * in_channels != out_channels:
+             raise ValueError("len(shape) * in_channels must equal out_channels")
+
+         # Frequency bins from 0 to the Nyquist frequency, normalized to [0, 1]
+         self.n_range = nn.Parameter(
+             torch.tensor(
+                 list(
+                     fftfreq(n=sequence_length, d=1 / sample_rate)[
+                         : sequence_length // 2
+                     ]
+                 )
+                 + [sample_rate / 2]
+             )
+             / (sample_rate / 2),
+             requires_grad=False,
+         )
+
+         # Trainable filter parameters
+         self.f_mean = nn.Parameter(
+             torch.tensor(f_mean * in_channels) / (sample_rate / 2), requires_grad=True
+         )
+         self.bandwidth = nn.Parameter(
+             torch.tensor(bandwidth * in_channels) / (sample_rate / 2),
+             requires_grad=True,
+         )  # full width at half maximum
+         self.shape = nn.Parameter(torch.tensor(shape * in_channels), requires_grad=True)
+
+         # Normalize group delay so that group_delay=1 corresponds to 1000 ms
+         self.group_delay = nn.Parameter(
+             torch.tensor(group_delay * in_channels) / 1000,
+             requires_grad=affine_group_delay,
+         )
+
+         # Construct filters from parameters
+         self.filters = self.construct_filters()
+
+     @staticmethod
+     def exponential_power(x, mean, fwhm, shape):
+         """
+         Compute the generalized Gaussian function:
+
+         .. math::
+
+             F(x) = \\exp\\left( - \\left( \\frac{|x - \\mu|}{\\alpha} \\right)^{\\beta} \\right)
+
+         where:
+
+         - :math:`\\mu` is the mean (`mean`).
+
+         - :math:`\\alpha` is the scale parameter, reparameterized using the FWHM :math:`h` as:
+
+           .. math::
+
+               \\alpha = \\frac{h}{2 \\left( \\ln(2) \\right)^{1/\\beta}}
+
+         - :math:`\\beta` is the shape parameter (`shape`).
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             The input tensor representing frequencies, normalized between 0 and 1.
+         mean : torch.Tensor
+             The center frequency (`f_mean`), normalized between 0 and 1.
+         fwhm : torch.Tensor
+             The full width at half maximum (`bandwidth`), normalized between 0 and 1.
+         shape : torch.Tensor
+             The shape parameter (`shape`) of the generalized Gaussian.
+
+         Returns
+         -------
+         torch.Tensor
+             The computed generalized Gaussian function values at frequencies `x`.
+         """
+         mean = mean.unsqueeze(1)
+         fwhm = fwhm.unsqueeze(1)
+         shape = shape.unsqueeze(1)
+         log2 = torch.log(torch.tensor(2.0, device=x.device, dtype=x.dtype))
+         scale = fwhm / (2 * log2 ** (1 / shape))
+         # Add a small constant to the difference between x and mean,
+         # since the gradient of 0 ** shape is nan
+         return torch.exp(-((((x - mean).abs() + 1e-8) / scale) ** shape))
+
+     def construct_filters(self):
+         """
+         Construct the filters in the frequency domain based on current parameters.
+
+         Returns
+         -------
+         torch.Tensor
+             The constructed filters with shape `(out_channels, freq_bins, 2)`.
+         """
+         # Clamp parameters
+         self.f_mean.data = torch.clamp(
+             self.f_mean.data,
+             min=self.clamp_f_mean[0] / (self.sample_rate / 2),
+             max=self.clamp_f_mean[1] / (self.sample_rate / 2),
+         )
+         self.bandwidth.data = torch.clamp(
+             self.bandwidth.data, min=1.0 / (self.sample_rate / 2), max=1.0
+         )
+         self.shape.data = torch.clamp(self.shape.data, min=2.0, max=3.0)
+
+         # Create magnitude response with gain=1 -> (channels, freqs)
+         mag_response = self.exponential_power(
+             self.n_range, self.f_mean, self.bandwidth, self.shape * 8 - 14
+         )
+         mag_response = mag_response / mag_response.max(dim=-1, keepdim=True)[0]
+
+         # Create phase response, scaled so that normalized group_delay=1
+         # corresponds to a group delay of 1000 ms.
+         phase = torch.linspace(
+             0,
+             self.sample_rate,
+             self.sequence_length // 2 + 1,
+             device=mag_response.device,
+             dtype=mag_response.dtype,
+         )
+         phase = phase.expand(mag_response.shape[0], -1)  # repeat for filter channels
+         pha_response = -self.group_delay.unsqueeze(-1) * phase * torch.pi
+
+         # Create real and imaginary parts of the filters
+         real = mag_response * torch.cos(pha_response)
+         imag = mag_response * torch.sin(pha_response)
+
+         # Stack real and imaginary parts to create filters
+         # -> (channels, freqs, 2)
+         filters = torch.stack((real, imag), dim=-1)
+
+         return filters
+
+     def forward(self, x):
+         """
+         Apply the generalized Gaussian filters to the input signal.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape `(..., in_channels, sequence_length)`.
+
+         Returns
+         -------
+         torch.Tensor
+             The filtered signal. If `inverse_fourier` is True, returns the signal
+             in the time domain with shape `(..., out_channels, sequence_length)`.
+             Otherwise, returns the signal in the frequency domain with shape
+             `(..., out_channels, freq_bins, 2)`.
+         """
+         # Reconstruct filters from the current parameter values
+         self.filters = self.construct_filters()
+         # Preserve the original dtype.
+         dtype = x.dtype
+         # Apply FFT -> (..., channels, freqs, 2)
+         x = torch.fft.rfft(x, dim=-1)
+         x = torch.view_as_real(x)  # separate real and imag
+
+         # Repeat channels in case of multiple filters per channel
+         x = torch.repeat_interleave(x, self.out_channels // self.in_channels, dim=-3)
+
+         # Apply filters in the frequency domain
+         x = x * self.filters
+
+         # Apply inverse FFT if requested
+         if self.inverse_fourier:
+             x = torch.view_as_complex(x)
+             x = torch.fft.irfft(x, n=self.sequence_length, dim=-1)
+
+         x = x.to(dtype)
+
+         return x
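
With ``inverse_fourier=False``, the filtered signal stays in the frequency domain as stacked real and imaginary parts. A minimal sketch under the same import path as the docstring example (the ``freq_filter`` name is illustrative; ``freq_bins`` is ``sequence_length // 2 + 1``, i.e. 129 for 256 samples, following the ``rfft`` call above):

>>> import torch
>>> from braindecode.modules import GeneralizedGaussianFilter
>>> freq_filter = GeneralizedGaussianFilter(
...     in_channels=2,
...     out_channels=2,
...     sequence_length=256,
...     sample_rate=128,
...     inverse_fourier=False,
... )
>>> freq_filter(torch.randn(3, 2, 256)).shape
torch.Size([3, 2, 129, 2])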