braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/hybrid.py
@@ -0,0 +1,126 @@
# Authors: Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)

import torch
from torch import nn
from torch.nn import ConstantPad2d

from braindecode.models.deep4 import Deep4Net
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet


class HybridNet(nn.Module):
    r"""Hybrid ConvNet model from Schirrmeister, R T et al (2017) [Schirrmeister2017]_.

    See [Schirrmeister2017]_ for details.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       Human Brain Mapping, Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        chs_info=None,
        activation: type[nn.Module] = nn.ELU,
        drop_prob: float = 0.5,
    ):
        super().__init__()
        self.mapping = {
            "final_conv.weight": "final_layer.weight",
            "final_conv.bias": "final_layer.bias",
        }

        deep_model = Deep4Net(
            n_chans=n_chans,
            n_outputs=n_outputs,
            n_filters_time=20,
            n_filters_spat=30,
            n_filters_2=40,
            n_filters_3=50,
            n_filters_4=60,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            chs_info=chs_info,
            final_conv_length=2,
            activation_first_conv_nonlin=activation,
            activation_later_conv_nonlin=activation,
            drop_prob=drop_prob,
        )
        shallow_model = ShallowFBCSPNet(
            n_chans=n_chans,
            n_outputs=n_outputs,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            chs_info=chs_info,
            n_filters_time=30,
            n_filters_spat=40,
            filter_time_length=28,
            final_conv_length=29,
            drop_prob=drop_prob,
        )
        new_conv_layer = nn.Conv2d(
            deep_model.final_layer.conv_classifier.in_channels,
            60,
            kernel_size=deep_model.final_layer.conv_classifier.kernel_size,
            stride=deep_model.final_layer.conv_classifier.stride,
        )
        deep_model.final_layer = new_conv_layer

        new_conv_layer = nn.Conv2d(
            shallow_model.final_layer.conv_classifier.in_channels,
            40,
            kernel_size=shallow_model.final_layer.conv_classifier.kernel_size,
            stride=shallow_model.final_layer.conv_classifier.stride,
        )
        shallow_model.final_layer = new_conv_layer

        deep_model.to_dense_prediction_model()
        shallow_model.to_dense_prediction_model()
        self.reduced_deep_model = deep_model
        self.reduced_shallow_model = shallow_model

        self.final_layer = nn.Sequential(
            nn.Conv2d(100, n_outputs, kernel_size=(1, 1), stride=1),
            nn.Identity(),
        )

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).
        """
        deep_out = self.reduced_deep_model(x)
        shallow_out = self.reduced_shallow_model(x)

        # The two sub-networks may produce outputs of slightly different
        # temporal length; zero-pad the shorter one at the start of the
        # time axis so they can be concatenated.
        n_diff_deep_shallow = deep_out.size()[2] - shallow_out.size()[2]

        if n_diff_deep_shallow < 0:
            deep_out = ConstantPad2d((0, 0, -n_diff_deep_shallow, 0), 0)(deep_out)
        elif n_diff_deep_shallow > 0:
            shallow_out = ConstantPad2d((0, 0, n_diff_deep_shallow, 0), 0)(shallow_out)

        merged_out = torch.cat((deep_out, shallow_out), dim=1)

        output = self.final_layer(merged_out)

        squeezed = output.squeeze(3)

        return squeezed
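For orientation, a minimal usage sketch of HybridNet (not part of the package; the channel count, window length, and class count below are illustrative assumptions):

import torch

from braindecode.models.hybrid import HybridNet

# Assumed example setup: 22 EEG channels, 1000-sample windows, 4 classes.
model = HybridNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)

x = torch.randn(8, 22, 1000)  # (batch_size, n_channels, n_times)
out = model(x)
# Both sub-networks were converted with to_dense_prediction_model(), so the
# result is a dense (cropped) prediction of shape
# (batch_size, n_outputs, n_preds_per_input) after the final squeeze.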
braindecode/models/ifnet.py
@@ -0,0 +1,443 @@
"""IFNet Neural Network.

Authors: Jiaheng Wang
         Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
License: MIT (https://github.com/Jiaheng-Wang/IFNet/blob/main/LICENSE)

J. Wang, L. Yao and Y. Wang, "IFNet: An Interactive Frequency Convolutional
Neural Network for Enhancing Motor Imagery Decoding from EEG," in IEEE
Transactions on Neural Systems and Rehabilitation Engineering,
doi: 10.1109/TNSRE.2023.3257319.
"""

from __future__ import annotations

from typing import Optional, Sequence

import torch
from einops.layers.torch import Rearrange
from mne.utils import warn
from torch import nn
from torch.nn.init import trunc_normal_

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import (
    FilterBankLayer,
    LinearWithConstraint,
    LogPowerLayer,
)


class IFNet(EEGModuleMixin, nn.Module):
    r"""IFNetV2 from Wang J et al (2023) [ifnet]_.

    :bdg-success:`Convolution` :bdg-primary:`Filterbank`

    .. figure:: https://raw.githubusercontent.com/Jiaheng-Wang/IFNet/main/IFNet.png
       :align: center
       :alt: IFNetV2 Architecture

       Overview of the Interactive Frequency Convolutional Neural Network architecture.

    IFNetV2 is designed to effectively capture spectro-spatial-temporal
    features for motor imagery decoding from EEG data. The model consists of
    three stages: Spectro-Spatial Feature Representation, Cross-Frequency
    Interactions, and Classification.

    - **Spectro-Spatial Feature Representation**: The raw EEG signals are
      filtered into two characteristic frequency bands: low (4-16 Hz) and
      high (16-40 Hz), covering the most relevant motor imagery bands.
      Spectro-spatial features are then extracted through 1D point-wise
      spatial convolution followed by temporal convolution.

    - **Cross-Frequency Interactions**: The extracted spectro-spatial
      features from each frequency band are combined through an element-wise
      summation operation, which enhances feature representation while
      preserving distinct characteristics.

    - **Classification**: The aggregated spectro-spatial features are further
      reduced through temporal average pooling and passed through a fully
      connected layer followed by a softmax operation to generate output
      probabilities for each class.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors; it was reimplemented from the paper
    description and the Torch source code [ifnetv2code]_. Version 2 is
    present only in the repository; its main difference is one pooling layer,
    described in Table VII of the paper:
    https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10070810


    Parameters
    ----------
    bands : list[tuple[float, float]] or int or None, default=[(4, 16), (16, 40)]
        Frequency bands for filtering.
    n_filters_spat : int, default=64
        Number of spatial filters, i.e. output feature dimensions.
    kernel_sizes : tuple of int, default=(63, 31)
        Kernel sizes for the temporal convolutions, one per band.
    stride_factor : int, default=8
        Stride factor for temporal segmentation.
    drop_prob : float, default=0.5
        Dropout probability.
    linear_max_norm : float, default=0.5
        Maximum norm constraint of the final linear layer.
    activation : nn.Module, default=nn.GELU
        Activation function after the InterFrequency Layer.
    verbose : bool, default=False
        Verbosity of the filtering layer.
    filter_parameters : dict, default=None
        Additional parameters for the filter bank layer.

    References
    ----------
    .. [ifnet] Wang, J., Yao, L., & Wang, Y. (2023). IFNet: An interactive
       frequency convolutional neural network for enhancing motor imagery
       decoding from EEG. IEEE Transactions on Neural Systems and
       Rehabilitation Engineering, 31, 1900-1911.
    .. [ifnetv2code] Wang, J., Yao, L., & Wang, Y. (2023). IFNet: An interactive
       frequency convolutional neural network for enhancing motor imagery
       decoding from EEG.
       https://github.com/Jiaheng-Wang/IFNet
    """

    def __init__(
        self,
        # Braindecode parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model-specific parameters
        bands: list[tuple[float, float]] | int | None = [(4.0, 16.0), (16, 40)],
        n_filters_spat: int = 64,
        kernel_sizes: tuple[int, int] = (63, 31),
        stride_factor: int = 8,
        drop_prob: float = 0.5,
        linear_max_norm: float = 0.5,
        activation: type[nn.Module] = nn.GELU,
        verbose: bool = False,
        filter_parameters: Optional[dict] = None,
    ):
        super().__init__(
            n_chans=n_chans,
            n_outputs=n_outputs,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.bands = bands
        self.n_filters_spat = n_filters_spat
        self.stride_factor = stride_factor
        self.kernel_sizes = kernel_sizes
        self.verbose = verbose
        self.drop_prob = drop_prob
        self.activation = activation
        self.linear_max_norm = linear_max_norm
        self.filter_parameters = filter_parameters or {}

        # Layers
        # Following paper nomenclature
        self.spectral_filtering = FilterBankLayer(
            n_chans=self.n_chans,
            sfreq=self.sfreq,
            band_filters=self.bands,
            verbose=verbose,
            **self.filter_parameters,
        )
        # The filter bank builds the bands internally, so read the final
        # number of bands back from it.
        self.n_bands = self.spectral_filtering.n_bands

        # Interpretation of Table VII (IFNet architecture) from the paper.
        self.ensuredim = Rearrange(
            "batch nbands chans time -> batch (nbands chans) time"
        )

        # SpatioTemporal Feature Block
        self.feature_block = _SpatioTemporalFeatureBlock(
            in_channels=self.n_chans * self.n_bands,
            out_channels=self.n_filters_spat,
            kernel_sizes=self.kernel_sizes,
            stride_factor=self.stride_factor,
            n_bands=self.n_bands,
            drop_prob=self.drop_prob,
            activation=self.activation,
            n_times=self.n_times,
        )

        # Final classification layer
        self.final_layer = LinearWithConstraint(
            in_features=self.n_filters_spat * stride_factor,
            out_features=self.n_outputs,
            max_norm=self.linear_max_norm,
        )

        self.flatten = Rearrange("batch ... -> batch (...)")

        # Initialize parameters of all submodules
        self.apply(self._initialize_weights)

    @staticmethod
    def _initialize_weights(m):
        """Initialize the weights of the network.

        Parameters
        ----------
        m : nn.Module
            Module to initialize.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d)):
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of IFNet.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor with shape (batch_size, n_outputs).
        """
        # Pass through the spectral filtering layer
        x = self.spectral_filtering(x)
        # x is now of shape (batch_size, n_bands, n_chans, n_times)
        x = self.ensuredim(x)
        # x is now of shape (batch_size, n_bands * n_chans, n_times)

        # Pass through the feature block
        x = self.feature_block(x)

        # Flatten and pass through the final layer
        x = self.flatten(x)

        x = self.final_layer(x)

        return x


class _InterFrequencyModule(nn.Module):
    r"""Module that combines outputs from different frequency bands."""

    def __init__(self, activation: type[nn.Module] = nn.GELU):
        """
        Parameters
        ----------
        activation : nn.Module
            Activation function for the InterFrequency Module.
        """
        super().__init__()

        self.activation = activation()

    def forward(self, x_list: list[torch.Tensor]) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x_list : list of torch.Tensor
            List of tensors to be combined.

        Returns
        -------
        torch.Tensor
            Combined tensor after applying the activation.
        """
        x = torch.stack(x_list, dim=0).sum(dim=0)
        x = self.activation(x)
        return x


class _SpatioTemporalFeatureBlock(nn.Module):
    r"""Spatio-temporal feature block consisting of spatial and temporal convolutions."""

    def __init__(
        self,
        n_times: int,
        in_channels: int,
        out_channels: int = 64,
        kernel_sizes: Sequence[int] = (63, 31),
        stride_factor: int = 8,
        n_bands: int = 2,
        drop_prob: float = 0.5,
        activation: type[nn.Module] = nn.GELU,
        dim: int = 3,
    ):
        """
        Parameters
        ----------
        n_times : int
            Number of time samples in the input window.
        in_channels : int
            Number of input channels.
        out_channels : int, default=64
            Number of output channels.
        kernel_sizes : sequence of int, default=(63, 31)
            Kernel sizes for the temporal convolutions.
        stride_factor : int, default=8
            Stride factor for temporal segmentation.
        n_bands : int, default=2
            Number of frequency bands or groups.
        drop_prob : float, default=0.5
            Dropout probability.
        activation : nn.Module, default=nn.GELU
            Activation function after the InterFrequency Layer.
        dim : int, default=3
            Internal dimension over which the LogPowerLayer is applied.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_bands = n_bands
        self.stride_factor = stride_factor
        self.drop_prob = drop_prob
        self.activation = activation
        self.dim = dim
        self.n_times = n_times
        self.kernel_sizes = kernel_sizes

        if self.n_bands != len(self.kernel_sizes):
            warn(
                f"Got {self.n_bands} bands but {len(self.kernel_sizes)} kernel "
                "sizes for the temporal convolutions; "
                "min(n_bands, len(kernel_sizes)) will be used.",
                UserWarning,
            )
            if self.n_bands > len(self.kernel_sizes):
                self.n_bands = len(self.kernel_sizes)
                warn(
                    f"Reducing the number of bands to {len(self.kernel_sizes)} "
                    "to match the number of kernels.",
                    UserWarning,
                )
            elif self.n_bands < len(self.kernel_sizes):
                self.kernel_sizes = self.kernel_sizes[: self.n_bands]
                warn(
                    f"Reducing the number of kernels to {self.n_bands} "
                    "to match the number of bands.",
                    UserWarning,
                )

        if self.n_times % self.stride_factor != 0:
            warn(
                f"Time dimension ({self.n_times}) is not divisible by"
                f" stride_factor ({self.stride_factor}). Input will be padded.",
                UserWarning,
            )

        out_channels_spatial = self.out_channels * self.n_bands

        # Spatial convolution
        self.spatial_conv = nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=out_channels_spatial,
            kernel_size=1,
            groups=self.n_bands,
            bias=False,
        )
        self.spatial_bn = nn.BatchNorm1d(out_channels_spatial)

        self.unpack_bands = nn.Unflatten(
            dim=1, unflattened_size=(self.n_bands, self.out_channels)
        )

        # Temporal convolutions, one per band (radix)
        self.temporal_convs = nn.ModuleList()
        for kernel_size in self.kernel_sizes:
            self.temporal_convs.append(
                nn.Sequential(
                    nn.Conv1d(
                        in_channels=self.out_channels,
                        out_channels=self.out_channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        groups=self.out_channels,
                        bias=False,
                    ),
                    nn.BatchNorm1d(self.out_channels),
                )
            )
        # Inter-frequency module
        self.inter_frequency = _InterFrequencyModule(activation=self.activation)

        if self.n_times % self.stride_factor != 0:
            self.padding_size = stride_factor - (self.n_times % stride_factor)
            self.n_times_padded = self.n_times + self.padding_size
            self.padding_layer = nn.ConstantPad1d((0, self.padding_size), 0.0)
        else:
            self.padding_layer = nn.Identity()
            self.n_times_padded = self.n_times

        # Log-power layer
        self.log_power = LogPowerLayer(dim=self.dim)  # type: ignore
        # Dropout
        self.dropout_layer = nn.Dropout(self.drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_channels, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor after processing.
        """
        batch_size, _, _ = x.shape

        # Spatial convolution
        x = self.spatial_conv(x)

        x = self.spatial_bn(x)

        # Split the output into one group of channels per frequency band
        x_split = self.unpack_bands(x)

        x_t = []
        for idx, conv in enumerate(self.temporal_convs):
            x_t.append(conv(x_split[:, idx]))

        # Inter-frequency interaction
        x = self.inter_frequency(x_t)

        # Pad the time axis so it is divisible by stride_factor
        x = self.padding_layer(x)
        # x is now of shape (batch_size, ..., n_times_padded)

        # Reshape for log-power computation
        x = x.view(
            batch_size,
            self.out_channels,
            self.stride_factor,
            self.n_times_padded // self.stride_factor,
        )

        # Log-power layer
        x = self.log_power(x)

        # Dropout
        x = self.dropout_layer(x)

        return x
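Likewise, a minimal usage sketch for IFNet (not part of the package; values below are illustrative assumptions, and the shape trace follows the forward pass above with the default two bands, n_filters_spat=64, and stride_factor=8):

import torch

from braindecode.models.ifnet import IFNet

# Assumed example setup: 22 channels, 1000 samples, 4 classes. 1000 is
# divisible by the default stride_factor=8, so no padding is applied.
model = IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)

x = torch.randn(8, 22, 1000)  # (batch_size, n_chans, n_times)
logits = model(x)             # (batch_size, n_outputs)

# Shape flow inside forward() for the defaults above:
#   spectral_filtering: (8, 22, 1000) -> (8, 2, 22, 1000)  two filter bands
#   ensuredim:                        -> (8, 44, 1000)     bands folded into channels
#   feature_block:                    -> (8, 64, 8)        log-power over 8 segments
#   flatten:                          -> (8, 512)          64 filters * 8 segments
#   final_layer:                      -> (8, 4)            class scores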