braindecode-1.3.0.dev177069446-py3-none-any.whl
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/modules/layers.py
@@ -0,0 +1,216 @@
# Authors: Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)
from __future__ import annotations

import torch
from einops.layers.torch import Rearrange
from torch import nn

from braindecode.functional import drop_path


class Ensure4d(nn.Module):
    """Ensure the input tensor has 4 dimensions.

    This is a small utility layer that repeatedly adds a singleton dimension
    at the end until the input has shape ``(batch, channels, time, 1)``.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import Ensure4d
    >>> module = Ensure4d()
    >>> outputs = module(torch.randn(2, 3, 10))
    >>> outputs.shape
    torch.Size([2, 3, 10, 1])
    """

    def forward(self, x):
        while len(x.shape) < 4:
            x = x.unsqueeze(-1)
        return x


class Chomp1d(nn.Module):
    """Remove samples from the end of a sequence.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import Chomp1d
    >>> module = Chomp1d(chomp_size=5)
    >>> inputs = torch.randn(2, 3, 20)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 3, 15])
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def extra_repr(self):
        return "chomp_size={}".format(self.chomp_size)

    def forward(self, x):
        return x[:, :, : -self.chomp_size].contiguous()


class TimeDistributed(nn.Module):
    """Apply a module to multiple windows.

    Apply the provided module to a sequence of windows and return their
    concatenation.
    Useful with sequence-to-prediction models (e.g. a sleep stager that must
    map a sequence of consecutive windows to the label of the middle window
    in the sequence).

    Parameters
    ----------
    module : nn.Module
        Module to be applied to the input windows. Must accept an input of
        shape (batch_size, n_channels, n_times).

    Examples
    --------
    >>> import torch
    >>> from torch import nn
    >>> from braindecode.modules import TimeDistributed
    >>> module = TimeDistributed(nn.Conv1d(3, 4, kernel_size=3, padding=1))
    >>> inputs = torch.randn(2, 5, 3, 20)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 5, 80])
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            Sequence of windows, of shape (batch_size, seq_len, n_channels,
            n_times).

        Returns
        -------
        torch.Tensor
            Shape (batch_size, seq_len, output_size).
        """
        b, s, c, t = x.shape
        out = self.module(x.view(b * s, c, t))
        return out.view(b, s, -1)


class DropPath(nn.Module):
    """Drop paths, also known as Stochastic Depth, per sample.

    Intended to be applied in the main path of residual blocks.

    Parameters
    ----------
    drop_prob : float (default=None)
        Drop path probability (should be in range 0-1).

    Notes
    -----
    Code copied and modified from VISSL facebookresearch:
    https://github.com/facebookresearch/vissl/blob/0b5d6a94437bc00baed112ca90c9d78c6ccfbafb/vissl/models/model_helpers.py#L676

    All rights reserved.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import DropPath
    >>> module = DropPath(drop_prob=0.2)
    >>> module.train()
    >>> inputs = torch.randn(2, 3, 10)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 3, 10])
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    # Utility method to print the DropPath module
    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"


class SqueezeFinalOutput(nn.Module):
    """Squeeze the final model output.

    Removes the empty dimension at the end and, if it is a singleton, the
    time dimension as well. It does not simply use ``squeeze``, as the first
    (batch) dimension must never be removed.

    Returns
    -------
    x : torch.Tensor
        Squeezed tensor.
    """

    def __init__(self):
        super().__init__()
        self.squeeze = Rearrange("b c t 1 -> b c t")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1) drop the trailing singleton feature dimension
        x = self.squeeze(x)
        # 2) drop the time dimension if it is a singleton
        if x.shape[-1] == 1:
            x = x.squeeze(-1)
        return x


class SubjectLayers(nn.Module):
    """Per-subject linear transformation layer.

    Applies subject-specific linear transformations to the input. Each
    subject owns an independent weight matrix, enabling personalized feature
    processing.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        n_subjects: int,
        init_id: bool = False,
    ):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(n_subjects, in_channels, out_channels))
        if init_id:
            if in_channels != out_channels:
                raise AssertionError("init_id requires in_channels == out_channels")
            self.weights.data[:] = torch.eye(in_channels)[None]
        self.weights.data *= 1 / (in_channels**0.5)

    def forward(self, x: torch.Tensor, subjects: torch.Tensor) -> torch.Tensor:
        """Apply the subject-specific linear transforms."""
        _, C, D = self.weights.shape
        weights = self.weights.gather(0, subjects.view(-1, 1, 1).expand(-1, C, D))
        return torch.einsum("bct,bcd->bdt", x, weights)

    def __repr__(self) -> str:
        S, C, D = self.weights.shape
        return f"SubjectLayers({C}, {D}, {S})"
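Editor's note: SubjectLayers is the only module in this file without an Examples section. A minimal usage sketch, not part of the package; the direct submodule import path is taken from the file location above, and the output values depend on the random initialization:

import torch

from braindecode.modules.layers import SubjectLayers  # assumed import path

layer = SubjectLayers(in_channels=3, out_channels=3, n_subjects=10, init_id=True)
x = torch.randn(4, 3, 100)             # (batch, n_channels, n_times)
subjects = torch.tensor([0, 0, 7, 9])  # one subject index per sample
out = layer(x, subjects)               # each sample is mixed by its subject's matrix
print(out.shape)  # torch.Size([4, 3, 100])
print(layer)      # SubjectLayers(3, 3, 10)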
braindecode/modules/linear.py
@@ -0,0 +1,70 @@
from torch import nn
from torch.nn.utils.parametrize import register_parametrization

from braindecode.modules.parametrization import MaxNorm, MaxNormParametrize


class MaxNormLinear(nn.Linear):
    """Linear layer with a max-norm constraint on the weights.

    Equivalent of Keras ``tf.keras.layers.Dense(..., kernel_constraint=max_norm())``
    [1]_ and [2]_. Implemented as advised in [3]_.

    Parameters
    ----------
    in_features : int
        Size of each input sample.
    out_features : int
        Size of each output sample.
    bias : bool, optional
        If set to ``False``, the layer will not learn an additive bias.
        Default: ``True``.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import MaxNormLinear
    >>> module = MaxNormLinear(10, 5, max_norm_val=2)
    >>> inputs = torch.randn(2, 10)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 5])

    References
    ----------
    .. [1] https://keras.io/api/layers/core_layers/dense/#dense-class
    .. [2] https://www.tensorflow.org/api_docs/python/tf/keras/constraints/MaxNorm
    .. [3] https://discuss.pytorch.org/t/how-to-correctly-implement-in-place-max-norm-constraint/96769
    """

    def __init__(
        self, in_features, out_features, bias=True, max_norm_val=2, eps=1e-5, **kwargs
    ):
        super().__init__(
            in_features=in_features, out_features=out_features, bias=bias, **kwargs
        )
        self._max_norm_val = max_norm_val
        self._eps = eps
        register_parametrization(self, "weight", MaxNorm(self._max_norm_val, self._eps))


class LinearWithConstraint(nn.Linear):
    """Linear layer with a max-norm constraint on the weights.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import LinearWithConstraint
    >>> module = LinearWithConstraint(10, 5, max_norm=1.0)
    >>> inputs = torch.randn(2, 10)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 5])
    """

    def __init__(self, *args, max_norm=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_norm = max_norm
        register_parametrization(self, "weight", MaxNormParametrize(self.max_norm))
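Editor's note: what the parametrization buys you, as a hedged sketch (not part of the diff): every read of ``weight`` routes through the registered parametrization, so the effective rows stay within the bound even if the underlying parameter is scaled up. ``parametrizations.weight.original`` is the standard PyTorch accessor for the unconstrained tensor:

import torch

from braindecode.modules import LinearWithConstraint

module = LinearWithConstraint(10, 5, max_norm=1.0)
# Inflate the underlying, unconstrained parameter...
with torch.no_grad():
    module.parametrizations.weight.original *= 100.0
# ...the weight seen by forward() is still renormalized row-wise.
print(module.weight.norm(p=2, dim=1).max() <= 1.0 + 1e-6)  # tensor(True)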
braindecode/modules/parametrization.py
@@ -0,0 +1,38 @@
import torch
from torch import nn


class MaxNorm(nn.Module):
    def __init__(self, max_norm_val=2.0, eps=1e-5):
        super().__init__()
        self.max_norm_val = max_norm_val
        self.eps = eps

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        norm = X.norm(2, dim=0, keepdim=True)
        denom = norm.clamp(min=self.max_norm_val / 2)
        number = denom.clamp(max=self.max_norm_val)
        return X * (number / (denom + self.eps))

    def right_inverse(self, X: torch.Tensor) -> torch.Tensor:
        # Assuming the forward scales X by a factor s,
        # the right inverse scales it back by 1/s.
        norm = X.norm(2, dim=0, keepdim=True)
        denom = norm.clamp(min=self.max_norm_val / 2)
        number = denom.clamp(max=self.max_norm_val)
        scale = number / (denom + self.eps)
        return X / scale


class MaxNormParametrize(nn.Module):
    """Enforce a max-norm constraint on the rows of a weight tensor via
    parametrization.
    """

    def __init__(self, max_norm: float = 1.0):
        super().__init__()
        self.max_norm = max_norm

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Renormalize each "row" (dim=0 slice) to an L2 norm of at most self.max_norm
        return X.renorm(p=2, dim=0, maxnorm=self.max_norm)
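Editor's note: a numeric sketch of ``MaxNorm``'s scaling (illustrative, not from the package): dim-0 slices whose norm is at most ``max_norm_val`` pass through essentially unchanged, larger ones are pulled down to the bound, and for already-feasible tensors ``right_inverse`` round-trips through ``forward``, which is what ``register_parametrization`` relies on at registration time:

import torch

from braindecode.modules.parametrization import MaxNorm

m = MaxNorm(max_norm_val=2.0)
X = torch.zeros(4, 2)
X[0, 0] = 1.0   # column 0 has norm 1.0  -> kept (up to eps)
X[1, 1] = 10.0  # column 1 has norm 10.0 -> scaled down to ~2.0
print(m(X).norm(dim=0))  # ~tensor([1.0000, 2.0000])

W = torch.randn(4, 2)
W = W / W.norm(dim=0, keepdim=True)  # unit-norm columns, within the bound
print(torch.allclose(m(m.right_inverse(W)), W, atol=1e-3))  # True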
braindecode/modules/stats.py
@@ -0,0 +1,87 @@
from __future__ import annotations

from functools import partial
from typing import Callable, Optional

import torch
from torch import nn


class StatLayer(nn.Module):
    """Generic layer to compute a statistical function along a specified
    dimension.

    Parameters
    ----------
    stat_fn : Callable
        A function like ``torch.mean``, ``torch.std``, etc.
    dim : int
        Dimension along which to apply the function.
    keepdim : bool, default=True
        Whether to keep the reduced dimension.
    clamp_range : tuple(float, float), optional
        Used only for functions requiring clamping (e.g. log variance).
    apply_log : bool, default=False
        Whether to apply a log after the computation (used for LogVarLayer).

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import StatLayer
    >>> module = StatLayer(stat_fn=torch.mean, dim=-1, keepdim=True)
    >>> inputs = torch.randn(2, 3, 10)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 3, 1])
    """

    def __init__(
        self,
        stat_fn: Callable[..., torch.Tensor],
        dim: int,
        keepdim: bool = True,
        clamp_range: Optional[tuple[float, float]] = None,
        apply_log: bool = False,
    ) -> None:
        super().__init__()
        self.stat_fn = stat_fn
        self.dim = dim
        self.keepdim = keepdim
        self.clamp_range = clamp_range
        self.apply_log = apply_log

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.stat_fn(x, dim=self.dim, keepdim=self.keepdim)
        if self.clamp_range is not None:
            out = torch.clamp(out, min=self.clamp_range[0], max=self.clamp_range[1])
        if self.apply_log:
            out = torch.log(out)
        return out


# Helper reduction functions for the partial-based layer aliases below.
def _max_fn(x: torch.Tensor, dim: int, keepdim: bool) -> torch.Tensor:
    return x.max(dim=dim, keepdim=keepdim)[0]


def _power_fn(x: torch.Tensor, dim: int, keepdim: bool) -> torch.Tensor:
    # Compute the mean of squared values along `dim`.
    return torch.mean(x**2, dim=dim, keepdim=keepdim)


MeanLayer: Callable[[int, bool], StatLayer] = partial(StatLayer, torch.mean)
MaxLayer: Callable[[int, bool], StatLayer] = partial(StatLayer, _max_fn)
VarLayer: Callable[[int, bool], StatLayer] = partial(StatLayer, torch.var)
StdLayer: Callable[[int, bool], StatLayer] = partial(StatLayer, torch.std)
LogVarLayer: Callable[[int, bool], StatLayer] = partial(
    StatLayer,
    torch.var,
    clamp_range=(1e-6, 1e6),
    apply_log=True,
)

LogPowerLayer: Callable[[int, bool], StatLayer] = partial(
    StatLayer,
    _power_fn,
    clamp_range=(1e-4, 1e4),
    apply_log=True,
)
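Editor's note: the aliases above are partials over ``StatLayer``, so the remaining positional arguments are ``dim`` and ``keepdim``. A short usage sketch; the assumption here is that the aliases are re-exported from ``braindecode.modules`` the same way ``StatLayer`` is in the docstring above:

import torch

from braindecode.modules import LogVarLayer, MeanLayer  # assumed exports

x = torch.randn(2, 3, 100)  # (batch, n_channels, n_times)

mean_layer = MeanLayer(dim=-1)  # same as StatLayer(torch.mean, dim=-1)
print(mean_layer(x).shape)      # torch.Size([2, 3, 1])

logvar_layer = LogVarLayer(dim=-1)  # variance along time, clamped, then log
expected = torch.log(torch.clamp(x.var(dim=-1, keepdim=True), 1e-6, 1e6))
print(torch.allclose(logvar_layer(x), expected))  # True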
braindecode/modules/util.py
@@ -0,0 +1,85 @@
# Authors: Robin Schirrmeister <robintibor@gmail.com>
#          Hubert Banville <hubert.jbanville@gmail.com>
#
# License: BSD (3-clause)

import numpy as np
from scipy.special import log_softmax


def _pad_shift_array(x, stride=1):
    """Zero-pad and shift rows of a 3D array.

    E.g., used to align predictions of corresponding windows in
    sequence-to-sequence models.

    Parameters
    ----------
    x : np.ndarray
        Array of shape (n_rows, n_classes, n_windows).
    stride : int
        Number of non-overlapping elements between two consecutive sequences.

    Returns
    -------
    np.ndarray
        Array of shape (n_rows, n_classes, (n_rows - 1) * stride + n_windows)
        where each row is obtained by zero-padding the corresponding row in
        ``x`` before and after in the last dimension.
    """
    if x.ndim != 3:
        raise NotImplementedError(
            f"x must be of shape (n_rows, n_classes, n_windows), got {x.shape}"
        )
    x_padded = np.pad(x, ((0, 0), (0, 0), (0, (x.shape[0] - 1) * stride)))
    orig_strides = x_padded.strides
    new_strides = (
        orig_strides[0] - stride * orig_strides[2],
        orig_strides[1],
        orig_strides[2],
    )
    return np.lib.stride_tricks.as_strided(x_padded, strides=new_strides)


def aggregate_probas(logits, n_windows_stride=1):
    """Aggregate predicted probabilities with self-ensembling.

    Aggregate window-wise predicted probabilities obtained on overlapping
    sequences of windows using multiplicative voting as described in
    [Phan2018]_.

    Parameters
    ----------
    logits : np.ndarray
        Array of shape (n_sequences, n_classes, n_windows) containing the
        logits (i.e. the raw unnormalized scores for each class) for each
        window of each sequence.
    n_windows_stride : int
        Number of windows between two consecutive sequences. Default is 1
        (maximally overlapping sequences).

    Returns
    -------
    np.ndarray
        Array of shape ((n_sequences - 1) * n_windows_stride + n_windows,
        n_classes) containing the aggregated predicted probabilities for each
        window contained in the input sequences.

    References
    ----------
    .. [Phan2018] Phan, H., Andreotti, F., Cooray, N., Chén, O. Y., &
        De Vos, M. (2018). Joint classification and prediction CNN framework
        for automatic sleep stage classification. IEEE Transactions on
        Biomedical Engineering, 66(5), 1285-1296.

    Examples
    --------
    >>> import numpy as np
    >>> from braindecode.modules import aggregate_probas
    >>> logits = np.random.randn(3, 4, 5)  # (n_sequences, n_classes, n_windows)
    >>> probas = aggregate_probas(logits, n_windows_stride=1)
    >>> probas.shape
    (7, 4)
    """
    log_probas = log_softmax(logits, axis=1)
    return _pad_shift_array(log_probas, stride=n_windows_stride).sum(axis=0).T
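Editor's note: a sketch of the alignment the stride trick performs (illustrative, not from the package): with a stride of 1, window ``i`` of sequence ``r`` covers overall window ``r + i``, so the aggregated score of a window is the sum of the log-softmaxed predictions of every sequence that covers it:

import numpy as np
from scipy.special import log_softmax

from braindecode.modules import aggregate_probas

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 4))  # (n_sequences, n_classes, n_windows)
probas = aggregate_probas(logits, n_windows_stride=1)
print(probas.shape)  # (5, 3): (2 - 1) * 1 + 4 overall windows, 3 classes

# Overall window 1 is seen by sequence 0 (as window 1) and sequence 1 (as window 0).
lp = log_softmax(logits, axis=1)
print(np.allclose(probas[1], lp[0, :, 1] + lp[1, :, 0]))  # True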
braindecode/modules/wrapper.py
@@ -0,0 +1,90 @@
import torch
from torch import nn


class Expression(nn.Module):
    """Compute a given expression on the forward pass.

    Parameters
    ----------
    expression_fn : callable
        Should accept a variable number of ``torch.Tensor`` arguments and
        compute its output from them.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import Expression
    >>> module = Expression(lambda x: x**2)
    >>> inputs = torch.randn(2, 3)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 3])
    """

    def __init__(self, expression_fn):
        super().__init__()
        self.expression_fn = expression_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.expression_fn(x)

    def __repr__(self):
        if hasattr(self.expression_fn, "func") and hasattr(
            self.expression_fn, "kwargs"
        ):
            expression_str = "{:s} {:s}".format(
                self.expression_fn.func.__name__, str(self.expression_fn.kwargs)
            )
        elif hasattr(self.expression_fn, "__name__"):
            expression_str = self.expression_fn.__name__
        else:
            expression_str = repr(self.expression_fn)
        return self.__class__.__name__ + "(expression=%s) " % expression_str


class IntermediateOutputWrapper(nn.Module):
    """Wrap a network model such that the outputs of intermediate layers can
    be returned. ``forward()`` returns the list of intermediate activations
    collected during the forward pass.

    Parameters
    ----------
    to_select : list
        List of module names for which the activation should be returned.
    model : nn.Module
        Network model.

    Examples
    --------
    >>> model = Deep4Net()
    >>> select_modules = ['conv_spat', 'conv_2', 'conv_3', 'conv_4']  # Specify intermediate outputs
    >>> model_pert = IntermediateOutputWrapper(select_modules, model)  # Wrap model

    >>> import torch
    >>> base = torch.nn.Sequential(torch.nn.Linear(10, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
    >>> wrapped = IntermediateOutputWrapper(to_select=["0", "2"], model=base)
    >>> outputs = wrapped(torch.randn(4, 10))
    >>> len(outputs)
    2
    """

    def __init__(self, to_select, model):
        if not len(list(model.children())) == len(list(model.named_children())):
            raise Exception("All modules in model need to have names!")

        super().__init__()

        modules_list = model.named_children()
        for key, module in modules_list:
            self.add_module(key, module)
            self._modules[key].load_state_dict(module.state_dict())
        self._to_select = to_select

    def forward(self, x):
        # Call modules individually and append the activation to the output
        # if the module is in to_select
        o = []
        for name, module in self._modules.items():
            x = module(x)
            if name in self._to_select:
                o.append(x)
        return o
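Editor's note: a closing sketch of the typical role of ``Expression`` (illustrative only): it lets a small tensor manipulation live inside ``nn.Sequential`` while still printing a readable name via the ``__repr__`` above:

import torch
from torch import nn

from braindecode.modules import Expression

net = nn.Sequential(
    nn.Conv1d(3, 4, kernel_size=5),
    nn.AdaptiveAvgPool1d(1),
    Expression(lambda x: x.squeeze(-1)),  # (batch, 4, 1) -> (batch, 4)
)
print(net(torch.randn(2, 3, 20)).shape)  # torch.Size([2, 4])
print(net[2])  # Expression(expression=<lambda>)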