braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,654 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from mne.filter import _check_coefficients, create_filter
|
|
7
|
+
from mne.utils import warn
|
|
8
|
+
from torch import Tensor, from_numpy, nn
|
|
9
|
+
from torch.fft import fftfreq
|
|
10
|
+
from torchaudio.functional import fftconvolve, filtfilt
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class FilterBankLayer(nn.Module):
    """Apply multiple band-pass filters to generate multiview signal representation.

    This layer constructs a bank of signals filtered in specific bands for each channel.
    It uses MNE's `create_filter` function to create the band-specific filters and
    applies them to multi-channel time-series data. Each filter in the bank corresponds to a
    specific frequency band and is applied to all channels of the input data. The filtering is
    performed using FFT-based convolution (:func:`torchaudio.functional.fftconvolve`) if the
    method is FIR, and with :func:`torchaudio.functional.filtfilt` (or
    :func:`torchaudio.functional.lfilter` for ``phase='forward'``) if the method is IIR.

    The default configuration creates 9 non-overlapping frequency bands with a 4 Hz bandwidth,
    spanning from 4 Hz to 40 Hz (i.e., 4-8 Hz, 8-12 Hz, ..., 36-40 Hz). This setup is based on the
    reference: *FBCNet: A Multi-view Convolutional Neural Network for Brain-Computer Interface*.

    Parameters
    ----------
    n_chans : int
        Number of channels in the input signal.
    sfreq : float
        Sampling frequency of the input signal in Hz.
    band_filters : Optional[list[tuple[float, float]]] or int, default=None
        List (or tuple) of frequency bands as (low_freq, high_freq) pairs. Each pair
        defines the frequency range for one filter in the bank. If an int ``n`` is
        given, ``n`` (> 0) bands of equal width spanning 4 to 40 Hz are created. If not
        provided, defaults to 9 non-overlapping bands with 4 Hz bandwidths spanning
        from 4 to 40 Hz.
    method : str, default='fir'
        ``'fir'`` will use FIR filtering, ``'iir'`` will use IIR
        forward-backward filtering (via :func:`torchaudio.functional.filtfilt`).
        For more details, please check the `MNE Preprocessing Tutorial <https://mne.tools/stable/auto_tutorials/preprocessing/25_background_filtering.html>`_.
    filter_length : str | int
        Length of the FIR filter to use (if applicable):

        * **'auto' (default)**: The filter length is chosen based
          on the size of the transition regions (6.6 times the reciprocal
          of the shortest transition band for fir_window='hamming'
          and fir_design="firwin2", and half that for "firwin").
        * **str**: A human-readable time in
          units of "s" or "ms" (e.g., "10s" or "5500ms") will be
          converted to that number of samples if ``phase="zero"``, or
          the shortest power-of-two length at least that duration for
          ``phase="zero-double"``.
        * **int**: Specified length in samples. For fir_design="firwin",
          this should not be used.
    l_trans_bandwidth : Union[str, float, int], default='auto'
        Width of the transition band at the low cut-off frequency in Hz
        (high pass or cutoff 1 in bandpass). Can be "auto"
        (default) to use a multiple of ``l_freq``::

            min(max(l_freq * 0.25, 2), l_freq)

        Only used for ``method='fir'``.
    h_trans_bandwidth : Union[str, float, int], default='auto'
        Width of the transition band at the high cut-off frequency in Hz
        (low pass or cutoff 2 in bandpass). Can be "auto"
        (default) to use a multiple of ``h_freq``::

            min(max(h_freq * 0.25, 2.), info['sfreq'] / 2. - h_freq)

        Only used for ``method='fir'``.
    phase : str, default='zero'
        Phase of the filter.
        When ``method='fir'``, symmetric linear-phase FIR filters are constructed
        with the following behaviors:

        ``"zero"`` (default)
            The delay of this filter is compensated for, making it non-causal.
        ``"minimum"``
            A minimum-phase filter will be constructed by decomposing the zero-phase filter
            into a minimum-phase and all-pass systems, and then retaining only the
            minimum-phase system (of the same length as the original zero-phase filter)
            via :func:`scipy.signal.minimum_phase`. It is applied causally (no delay
            compensation).
        ``"zero-double"``
            *This is a legacy option for compatibility with MNE <= 0.13.*
            The filter is applied twice, once forward, and once backward
            (also making it non-causal).
        ``"minimum-half"``
            *This is a legacy option for compatibility with MNE <= 1.6.*
            A minimum-phase filter will be reconstructed from the zero-phase filter with
            half the length of the original filter. It is applied causally.

        When ``method='iir'``, ``phase='zero'`` (default) or equivalently ``'zero-double'``
        constructs and applies IIR filter twice, once forward, and once backward (making it
        non-causal) using :func:`torchaudio.functional.filtfilt`; ``phase='forward'`` will
        apply the filter once in the forward (causal) direction using
        :func:`torchaudio.functional.lfilter`.
    iir_params : Optional[dict], default=None
        Dictionary of parameters to use for IIR filtering.
        If ``iir_params=None`` and ``method="iir"``, 4th order Butterworth will be used.
        The filter is always requested in ``"ba"`` form (second-order sections
        cannot be applied with Torch). The dictionary passed in is not modified.
        For more information, see :func:`mne.filter.construct_iir_filter`.
    fir_window : str, default='hamming'
        The window to use in FIR design, can be "hamming" (default),
        "hann", or "blackman".
    fir_design : str, default='firwin'
        Can be "firwin" (default) to use :func:`scipy.signal.firwin`,
        or "firwin2" to use :func:`scipy.signal.firwin2`. "firwin" uses
        a time-domain design technique that generally gives improved
        attenuation using fewer samples than "firwin2".
    verbose : bool | str | int | None, default=True
        Control verbosity of the logging output. If ``None``, use the default
        verbosity level. See :func:`mne.verbose` for details.
        Should only be passed as a keyword argument.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import FilterBankLayer
    >>> module = FilterBankLayer(
    ...     n_chans=2,
    ...     sfreq=128,
    ...     band_filters=[(4.0, 8.0), (8.0, 12.0)],
    ...     method="fir",
    ...     verbose=False,
    ... )
    >>> inputs = torch.randn(2, 2, 256)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 2, 2, 256])
    """

    def __init__(
        self,
        n_chans: int,
        sfreq: float,
        band_filters: Optional[list[tuple[float, float]] | int] = None,
        method: str = "fir",
        filter_length: str | float | int = "auto",
        l_trans_bandwidth: str | float | int = "auto",
        h_trans_bandwidth: str | float | int = "auto",
        phase: str = "zero",
        iir_params: Optional[dict] = None,
        fir_window: str = "hamming",
        fir_design: str = "firwin",
        verbose: bool | str | int | None = True,
    ):
        super().__init__()

        # Normalize/validate band_filters into a non-empty list of (low, high) pairs.
        band_filters = self._resolve_band_filters(band_filters)

        self.band_filters = band_filters
        self.n_bands = len(band_filters)
        self.phase = phase
        self.method = method
        self.n_chans = n_chans

        self.method_iir = self.method == "iir"

        # Coefficients are kept as frozen Parameters (rather than buffers) so the
        # state_dict keys (fir_list.*, b_list.*, a_list.*) stay unchanged.
        self.fir_list = nn.ParameterList()
        self.b_list = nn.ParameterList()
        self.a_list = nn.ParameterList()

        if self.method_iir:
            iir_params = self._prepare_iir_params(iir_params)

        for l_freq, h_freq in band_filters:
            filt = create_filter(
                data=None,
                sfreq=sfreq,
                l_freq=float(l_freq),
                h_freq=float(h_freq),
                filter_length=filter_length,
                l_trans_bandwidth=l_trans_bandwidth,
                h_trans_bandwidth=h_trans_bandwidth,
                method=self.method,
                iir_params=iir_params,
                phase=phase,
                fir_window=fir_window,
                fir_design=fir_design,
                verbose=verbose,
            )
            if not self.method_iir:
                # FIR filter: a 1-D numpy kernel.
                kernel = from_numpy(filt).float()
                self.fir_list.append(nn.Parameter(kernel, requires_grad=False))
            else:
                # IIR filter: MNE returns a dict holding "b"/"a" coefficients.
                b_coeffs = filt["b"]
                a_coeffs = filt["a"]

                _check_coefficients((b_coeffs, a_coeffs))

                b = torch.tensor(b_coeffs, dtype=torch.float64)
                a = torch.tensor(a_coeffs, dtype=torch.float64)

                self.b_list.append(nn.Parameter(b, requires_grad=False))
                self.a_list.append(nn.Parameter(a, requires_grad=False))

    @staticmethod
    def _resolve_band_filters(band_filters):
        """Return ``band_filters`` as a non-empty list of (low, high) pairs.

        Raises
        ------
        ValueError
            If an int count is not positive, if the value is neither an int, a
            list nor a tuple, if any item does not hold exactly 2 values, or if
            no band is given.
        """
        if band_filters is None:
            # FBCNet default: 9 non-overlapping 4 Hz bands spanning 4 to 40 Hz
            # (4-8, 8-12, ..., 36-40 Hz).
            return [(low, low + 4) for low in range(4, 36 + 1, 4)]

        if isinstance(band_filters, int):
            # Validate before dividing: 0 would divide by zero and a negative
            # count would silently yield an empty bank.
            if band_filters <= 0:
                raise ValueError(
                    "When given as an int, `band_filters` must be a positive "
                    f"number of bands, got {band_filters}."
                )
            warn(
                "Creating the filter banks equally divided in the "
                "interval 4Hz to 40Hz with almost equal bandwidths. "
                "If you want a specific interval, "
                "please specify 'band_filters' as a list of tuples.",
                UserWarning,
            )
            start = 4.0
            end = 40.0
            band_width = (end - start) / band_filters
            return [
                (
                    float(start + i * band_width),
                    float(start + (i + 1) * band_width),
                )
                for i in range(band_filters)
            ]

        if isinstance(band_filters, tuple):
            band_filters = list(band_filters)

        if not isinstance(band_filters, list):
            raise ValueError(
                "`band_filters` should be a list of tuples if you want to "
                "use them this way."
            )
        if any(len(bands) != 2 for bands in band_filters):
            raise ValueError("The band_filters items should be splitable in 2 values.")
        if not band_filters:
            raise ValueError("`band_filters` must contain at least one (low, high) band.")
        return band_filters

    @staticmethod
    def _prepare_iir_params(iir_params):
        """Return a copy of ``iir_params`` that requests ``"ba"`` coefficients.

        The caller's dictionary is never modified. ``"ba"`` is always requested
        explicitly: otherwise MNE may default to second-order sections (e.g.
        for ``{}``), which cannot be applied with Torch here.
        """
        if iir_params is None:
            return dict(output="ba")
        iir_params = dict(iir_params)
        if iir_params.get("output") == "sos":
            warn(
                "It is not possible to use second-order section filtering with Torch. Changing to filter ba",
                UserWarning,
            )
        iir_params["output"] = "ba"
        return iir_params

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the filter bank to the input signal.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, time_points).

        Returns
        -------
        torch.Tensor
            Filtered output tensor of shape (batch_size, n_bands, n_chans, filtered_time_points).
        """
        if self.method_iir:
            outs = [
                self._apply_iir(x=x, b_coeffs=b, a_coeffs=a, phase=self.phase)
                for b, a in zip(self.b_list, self.a_list)
            ]
        else:
            outs = [
                self._apply_fir(x=x, filt=fir, n_chans=self.n_chans, phase=self.phase)
                for fir in self.fir_list
            ]
        # Each entry is (batch, 1, n_chans, n_times); stack along the band axis.
        return torch.cat(outs, dim=1)

    @staticmethod
    def _apply_fir(x, filt: Tensor, n_chans: int, phase: str = "zero") -> Tensor:
        """
        Apply an FIR filter to the input tensor.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, n_chans, n_times).
        filt : Tensor
            1-D FIR kernel of shape (filter_length,).
        n_chans : int
            Number of channels.
        phase : str, default="zero"
            ``"zero"``: delay-compensated (centred) convolution.
            ``"zero-double"``: applied forward then backward.
            ``"minimum"`` / ``"minimum-half"``: causal convolution (no delay
            compensation).

        Returns
        -------
        Tensor
            Filtered tensor of shape (batch_size, 1, n_chans, n_times).
        """
        n_times = x.shape[-1]
        # (filter_length,) -> (n_chans, filter_length) -> (1, n_chans, filter_length)
        kernel = filt.to(x.device).unsqueeze(0).repeat(n_chans, 1).unsqueeze(0)

        if phase in ("minimum", "minimum-half"):
            # Minimum-phase kernels are causal: keep the first n_times samples
            # of the full convolution instead of centring the kernel.
            filtered = fftconvolve(x, kernel, mode="full")[..., :n_times]
        elif phase == "zero-double":
            # Forward pass, then backward pass on the time-reversed signal.
            filtered = fftconvolve(x, kernel, mode="same")
            filtered = fftconvolve(filtered.flip(-1), kernel, mode="same").flip(-1)
        else:
            # Zero phase: "same" mode compensates the kernel's delay.
            filtered = fftconvolve(x, kernel, mode="same")

        # Add the band dimension -> (batch_size, 1, n_chans, n_times)
        return filtered.unsqueeze(1)

    @staticmethod
    def _apply_iir(
        x: Tensor, b_coeffs: Tensor, a_coeffs: Tensor, phase: str = "zero"
    ) -> Tensor:
        """
        Apply an IIR filter to the input tensor.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, n_chans, n_times).
        b_coeffs : Tensor
            Numerator coefficients.
        a_coeffs : Tensor
            Denominator coefficients.
        phase : str, default="zero"
            ``"forward"`` applies the filter once, causally (``lfilter``); any
            other value applies it forward and backward (``filtfilt``).

        Returns
        -------
        Tensor
            Filtered tensor of shape (batch_size, 1, n_chans, n_times).
        """
        a = a_coeffs.to(device=x.device, dtype=x.dtype)
        b = b_coeffs.to(device=x.device, dtype=x.dtype)
        if phase == "forward":
            from torchaudio.functional import lfilter

            filtered = lfilter(x, a_coeffs=a, b_coeffs=b, clamp=False)
        else:
            filtered = filtfilt(x, a_coeffs=a, b_coeffs=b, clamp=False)
        # Add the band dimension -> (batch_size, 1, n_chans, n_times)
        return filtered.unsqueeze(1)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
class GeneralizedGaussianFilter(nn.Module):
    """Generalized Gaussian Filter from Ludwig et al (2024) [eegminer]_.

    Implements trainable temporal filters based on generalized Gaussian functions
    in the frequency domain.

    This module creates filters in the frequency domain using the generalized
    Gaussian function, allowing for trainable center frequency (`f_mean`),
    bandwidth (`bandwidth`), and shape (`shape`) parameters.

    The filters are applied to the input signal in the frequency domain (as a
    complex multiplication of the signal's spectrum with the filter's frequency
    response), and can be optionally transformed back to the time domain using
    the inverse Fourier transform.

    The generalized Gaussian function in the frequency domain is defined as:

    .. math::

        F(x) = \\exp\\left( - \\left( \\frac{abs(x - \\mu)}{\\alpha} \\right)^{\\beta} \\right)

    where:
      - μ (mu) is the center frequency (`f_mean`).

      - α (alpha) is the scale parameter, reparameterized in terms of the full width at half maximum (FWHM) `h` as:

        .. math::

            \\alpha = \\frac{h}{2 \\left( \\ln(2) \\right)^{1/\\beta}}

      - β (beta) is the shape parameter (`shape`), controlling the shape of the filter.

    The filters are constructed in the frequency domain to allow full control
    over the magnitude and phase responses.

    A linear phase response is used, with an optional trainable group delay (`group_delay`).

    - Copyright (C) Cogitat, Ltd.
    - Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
    - Patent GB2609265 - Learnable filters for eeg classification
    - https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels. Must be a multiple of `in_channels`.
    sequence_length : int
        Length of the input sequences (time steps).
    sample_rate : float
        Sampling rate of the input signals in Hz.
    inverse_fourier : bool, optional
        If True, applies the inverse Fourier transform to return to the time domain after filtering.
        Default is True.
    affine_group_delay : bool, optional
        If True, makes the group delay parameter trainable. Default is False.
    group_delay : tuple of float, optional
        Initial group delay(s) in milliseconds for the filters. Default is (20.0,).
        ``len(group_delay) * in_channels`` must equal ``out_channels`` (or 1, in
        which case a single delay is shared by all filters).
    f_mean : tuple of float, optional
        Initial center frequency (frequencies) of the filters in Hz. Default is (23.0,).
    bandwidth : tuple of float, optional
        Initial bandwidth(s) (full width at half maximum) of the filters in Hz. Default is (44.0,).
    shape : tuple of float, optional
        Initial shape parameter(s) of the generalized Gaussian filters. Must be >= 2.0. Default is (2.0,).
    clamp_f_mean : tuple of float, optional
        Minimum and maximum allowable values for the center frequency `f_mean` in Hz.
        Specified as (min_f_mean, max_f_mean). Default is (1.0, 45.0).

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import GeneralizedGaussianFilter
    >>> module = GeneralizedGaussianFilter(
    ...     in_channels=2,
    ...     out_channels=2,
    ...     sequence_length=256,
    ...     sample_rate=128,
    ...     inverse_fourier=True,
    ... )
    >>> inputs = torch.randn(3, 2, 256)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([3, 2, 256])

    Notes
    -----
    The model and the module **have a patent** [eegminercode]_, and the **code is CC BY-NC 4.0**.

    .. versionadded:: 0.9

    References
    ----------
    .. [eegminer] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters. Journal of Neural Engineering,
       21(3), 036010.
    .. [eegminercode] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters.
       https://github.com/SMLudwig/EEGminer/.
       Cogitat, Ltd. "Learnable filters for EEG classification."
       Patent GB2609265.
       https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        sequence_length,
        sample_rate,
        inverse_fourier=True,
        affine_group_delay=False,
        group_delay=(20.0,),
        f_mean=(23.0,),
        bandwidth=(44.0,),
        shape=(2.0,),
        clamp_f_mean=(1.0, 45.0),
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sequence_length = sequence_length
        self.sample_rate = sample_rate
        self.inverse_fourier = inverse_fourier
        self.affine_group_delay = affine_group_delay
        self.clamp_f_mean = clamp_f_mean
        if out_channels % in_channels != 0:
            raise ValueError("out_channels has to be multiple of in_channels")
        for name, values in (("f_mean", f_mean), ("bandwidth", bandwidth), ("shape", shape)):
            if len(values) * in_channels != out_channels:
                raise ValueError(f"len({name}) * in_channels must equal out_channels")
        # The delay vector is broadcast against the (out_channels, freqs) phase
        # grid, so its length must be out_channels or 1.
        if len(group_delay) * in_channels not in (1, out_channels):
            raise ValueError(
                "len(group_delay) * in_channels must equal out_channels "
                "(or 1 to share a single delay across all filters)"
            )

        half_sr = sample_rate / 2
        n_freqs = sequence_length // 2 + 1

        # Frequencies of the rFFT bins, normalized by the Nyquist frequency.
        # The first n_freqs entries of fftfreq are exactly those bins, except
        # that for an even length the Nyquist bin is reported as -sample_rate/2,
        # hence abs(). This is also correct for odd sequence lengths, whose
        # last bin lies below Nyquist.
        bin_freqs = fftfreq(n=sequence_length, d=1 / sample_rate)[:n_freqs].abs()
        self.n_range = nn.Parameter(bin_freqs / half_sr, requires_grad=False)

        # Force a floating dtype so integer-valued inputs (e.g. shape=(2,))
        # still yield Parameters that can require gradients.
        param_dtype = torch.get_default_dtype()

        def _per_filter(values):
            # One entry per output filter: the tuple is repeated for each input
            # channel, matching repeat_interleave in forward().
            return torch.tensor(list(values) * in_channels, dtype=param_dtype)

        # Trainable filter parameters (frequencies normalized by Nyquist)
        self.f_mean = nn.Parameter(_per_filter(f_mean) / half_sr, requires_grad=True)
        self.bandwidth = nn.Parameter(
            _per_filter(bandwidth) / half_sr,
            requires_grad=True,
        )  # full width half maximum
        self.shape = nn.Parameter(_per_filter(shape), requires_grad=True)

        # Normalize group delay so that group_delay=1 corresponds to 1000ms
        self.group_delay = nn.Parameter(
            _per_filter(group_delay) / 1000,
            requires_grad=affine_group_delay,
        )

        # Construct filters from parameters
        self.filters = self.construct_filters()

    @staticmethod
    def exponential_power(x, mean, fwhm, shape):
        """
        Computes the generalized Gaussian function:

        .. math::

            F(x) = \\exp\\left( - \\left( \\frac{|x - \\mu|}{\\alpha} \\right)^{\\beta} \\right)

        where:

          - :math:`\\mu` is the mean (`mean`).

          - :math:`\\alpha` is the scale parameter, reparameterized using the FWHM :math:`h` as:

            .. math::

                \\alpha = \\frac{h}{2 \\left( \\ln(2) \\right)^{1/\\beta}}

          - :math:`\\beta` is the shape parameter (`shape`).

        Parameters
        ----------
        x : torch.Tensor
            The input tensor representing frequencies, normalized between 0 and 1.
        mean : torch.Tensor
            The center frequency (`f_mean`), normalized between 0 and 1.
        fwhm : torch.Tensor
            The full width at half maximum (`bandwidth`), normalized between 0 and 1.
        shape : torch.Tensor
            The shape parameter (`shape`) of the generalized Gaussian.

        Returns
        -------
        torch.Tensor
            The computed generalized Gaussian function values at frequencies `x`.

        """
        mean = mean.unsqueeze(1)
        fwhm = fwhm.unsqueeze(1)
        shape = shape.unsqueeze(1)
        log2 = torch.log(torch.tensor(2.0, device=x.device, dtype=x.dtype))
        scale = fwhm / (2 * log2 ** (1 / shape))
        # Add small constant to difference between x and mean since grad of 0 ** shape is nan
        return torch.exp(-((((x - mean).abs() + 1e-8) / scale) ** shape))

    def construct_filters(self):
        """
        Constructs the filters in the frequency domain based on current parameters.

        Returns
        -------
        torch.Tensor
            The constructed filters with shape `(out_channels, freq_bins, 2)`,
            where the last axis holds the (real, imaginary) parts.

        """
        # Clamp parameters
        self.f_mean.data = torch.clamp(
            self.f_mean.data,
            min=self.clamp_f_mean[0] / (self.sample_rate / 2),
            max=self.clamp_f_mean[1] / (self.sample_rate / 2),
        )
        self.bandwidth.data = torch.clamp(
            self.bandwidth.data, min=1.0 / (self.sample_rate / 2), max=1.0
        )
        self.shape.data = torch.clamp(self.shape.data, min=2.0, max=3.0)

        # Create magnitude response with gain=1 -> (channels, freqs)
        mag_response = self.exponential_power(
            self.n_range, self.f_mean, self.bandwidth, self.shape * 8 - 14
        )
        mag_response = mag_response / mag_response.max(dim=-1, keepdim=True)[0]

        # Linear phase of a pure delay tau (seconds): -2*pi*f*tau. Since n_range
        # is f / (sample_rate / 2), n_range * sample_rate == 2 * f, and the
        # normalized group_delay (1 == 1000 ms) is tau in seconds. Using the
        # actual bin frequencies keeps this exact for odd sequence lengths too.
        two_f = (self.n_range * self.sample_rate).to(mag_response.dtype)
        two_f = two_f.expand(mag_response.shape[0], -1)  # repeat for filter channels
        pha_response = -self.group_delay.unsqueeze(-1) * two_f * torch.pi

        # Create real and imaginary parts of the filters
        real = mag_response * torch.cos(pha_response)
        imag = mag_response * torch.sin(pha_response)

        # Stack real and imaginary parts to create filters
        # -> (channels, freqs, 2)
        filters = torch.stack((real, imag), dim=-1)

        return filters

    def forward(self, x):
        """
        Applies the generalized Gaussian filters to the input signal.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape `(..., in_channels, sequence_length)`.

        Returns
        -------
        torch.Tensor
            The filtered signal. If `inverse_fourier` is True, returns the signal in the time domain
            with shape `(..., out_channels, sequence_length)`. Otherwise, returns the signal in the
            frequency domain with shape `(..., out_channels, freq_bins, 2)`.

        """
        # Construct filters from parameters
        self.filters = self.construct_filters()
        # Preserving the original dtype.
        input_dtype = x.dtype

        # Complex spectrum -> (..., in_channels, freqs)
        spectrum = torch.fft.rfft(x, dim=-1)

        # Repeat channels in case of multiple filters per channel
        spectrum = torch.repeat_interleave(
            spectrum, self.out_channels // self.in_channels, dim=-2
        )

        # Filtering is a complex product of spectrum and frequency response;
        # multiplying the (real, imag) views element-wise would not be.
        spectrum = spectrum * torch.view_as_complex(self.filters)

        if self.inverse_fourier:
            out = torch.fft.irfft(spectrum, n=self.sequence_length, dim=-1)
        else:
            # -> (..., out_channels, freqs, 2)
            out = torch.view_as_real(spectrum)

        return out.to(input_dtype)
|