braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class InceptionBlock(nn.Module):
    """
    Inception block module.

    Feeds the same input to every branch and stacks the branch outputs along
    the channel dimension (dim 1). Giving branches different configurations
    (e.g. kernel sizes) lets the block capture multi-scale features.

    Parameters
    ----------
    branches : list of nn.Module
        Branches applied in parallel to the input. Their outputs must match on
        every dimension except dim 1 so that they can be concatenated.

    Examples
    --------
    >>> import torch
    >>> from torch import nn
    >>> from braindecode.modules import InceptionBlock
    >>> block = InceptionBlock(
    ...     [
    ...         nn.Conv1d(3, 4, kernel_size=1),
    ...         nn.Conv1d(3, 4, kernel_size=3, padding=1),
    ...     ]
    ... )
    >>> inputs = torch.randn(2, 3, 100)
    >>> outputs = block(inputs)
    >>> outputs.shape
    torch.Size([2, 8, 100])
    """

    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        branch_outputs = []
        for branch in self.branches:
            branch_outputs.append(branch(x))
        # Concatenate along the channel axis.
        return torch.cat(branch_outputs, dim=1)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MLP(nn.Sequential):
    r"""Multilayer Perceptron (MLP) with GELU activation and optional dropout.

    Also known as fully connected feedforward network, an MLP is a sequence of
    non-linear parametric functions

    .. math:: h_{i + 1} = a_{i + 1}(h_i W_{i + 1}^T + b_{i + 1}),

    over feature vectors :math:`h_i`, with the input and output feature vectors
    :math:`x = h_0` and :math:`y = h_L`, respectively. The non-linear functions
    :math:`a_i` are called activation functions. The trainable parameters of an
    MLP are its weights and biases :math:`\phi = \{W_i, b_i | i = 1, \dots, L\}`.

    Every hidden linear layer is followed by the activation (and, optionally,
    a layer normalization). The last linear layer is a plain projection with
    no activation and no normalization, and is followed by a single dropout
    layer.

    Parameters
    ----------
    in_features : int
        Number of input features.
    hidden_features : Sequence[int] or None (default=None)
        Width of each hidden layer, one entry per hidden layer. If ``None``
        (or empty), two hidden layers of width ``in_features`` are used.
    out_features : int or None (default=None)
        Number of output features. If ``None``, set to ``in_features``.
    activation : type[nn.Module] or None (default=nn.GELU)
        Activation constructor, called without arguments after every hidden
        linear layer. If ``None``, :class:`torch.nn.GELU` is used.
    drop : float (default=0.0)
        Dropout rate of the final dropout layer.
    normalize : bool (default=False)
        If True, a :class:`torch.nn.LayerNorm` over the hidden width is
        inserted after each hidden activation.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import MLP
    >>> module = MLP(in_features=32, hidden_features=(64,), out_features=16)
    >>> inputs = torch.randn(2, 10, 32)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 10, 16])
    """

    def __init__(
        self,
        in_features: int,
        hidden_features=None,
        out_features=None,
        activation=nn.GELU,
        drop=0.0,
        normalize=False,
    ):
        if activation is None:
            activation = nn.GELU
        # Keep the constructor (or None) rather than a lambda so that the
        # module remains picklable.
        self.normalization = nn.LayerNorm if normalize else None
        self.in_features = in_features
        self.out_features = out_features or self.in_features
        if hidden_features:
            self.hidden_features = hidden_features
        else:
            self.hidden_features = (self.in_features, self.in_features)
        self.activation = activation

        widths = (self.in_features, *self.hidden_features, self.out_features)
        n_linear = len(widths) - 1

        layers = []
        for idx, (before, after) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(in_features=before, out_features=after))
            if idx == n_linear - 1:
                # Output projection: no activation / normalization after it.
                break
            layers.append(self.activation())
            if self.normalization is not None:
                # LayerNorm must know the size of the feature dim it normalizes.
                layers.append(self.normalization(after))
        layers.append(nn.Dropout(p=drop))

        super().__init__(*layers)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class FeedForwardBlock(nn.Sequential):
    """Feedforward network block.

    Applied to the last dimension of the input, it runs
    ``Linear(emb_size -> expansion * emb_size)``, the activation, dropout and
    ``Linear(expansion * emb_size -> emb_size)``, so the output keeps the
    input's last-dimension size.

    Parameters
    ----------
    emb_size : int
        Embedding dimension (input and output feature size).
    expansion : int
        Expansion factor for the hidden layer size.
    drop_p : float
        Dropout probability applied after the activation.
    activation : type[nn.Module], default=nn.GELU
        Activation function constructor, called without arguments.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import FeedForwardBlock
    >>> module = FeedForwardBlock(emb_size=32, expansion=2, drop_p=0.1)
    >>> inputs = torch.randn(2, 10, 32)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 10, 32])
    """

    def __init__(
        self, emb_size, expansion, drop_p, activation: type[nn.Module] = nn.GELU
    ):
        hidden_size = expansion * emb_size
        # Sub-modules are created in forward order so parameter initialization
        # consumes the RNG in the same sequence as the layer order.
        expand = nn.Linear(emb_size, hidden_size)
        act = activation()
        dropout = nn.Dropout(drop_p)
        project = nn.Linear(hidden_size, emb_size)
        super().__init__(expand, act, dropout, project)
|
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn
|
|
6
|
+
from torch.nn import functional as F
|
|
7
|
+
from torch.nn.utils.parametrize import register_parametrization
|
|
8
|
+
|
|
9
|
+
from braindecode.util import np_to_th
|
|
10
|
+
|
|
11
|
+
from .parametrization import MaxNormParametrize
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AvgPool2dWithConv(nn.Module):
    """
    Compute average pooling using a convolution, to have the dilation parameter.

    Each input channel is convolved on its own (``groups=in_channels``) with a
    constant, non-trainable kernel whose entries all equal ``1 / (kh * kw)``.

    Parameters
    ----------
    kernel_size: int or (int, int)
        Size of the pooling region. An int is used for both dimensions.
    stride: int or (int, int)
        Stride of the pooling operation.
    dilation: int or (int, int)
        Dilation applied to the pooling filter.
    padding: int or (int, int)
        Padding applied before the pooling operation.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import AvgPool2dWithConv
    >>> module = AvgPool2dWithConv(kernel_size=(1, 4), stride=(1, 4))
    >>> inputs = torch.randn(2, 4, 1, 16)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 4, 1, 4])
    """

    def __init__(self, kernel_size, stride, dilation=1, padding=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        # don't name them "weights" to
        # make sure these are not accidentally used by some procedure
        # that initializes parameters or something
        self._pool_weights = None

    def forward(self, x):
        in_channels = x.size(1)
        kernel_size = self.kernel_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        weight_shape = (in_channels, 1, kernel_size[0], kernel_size[1])

        # (Re)build the constant averaging kernel on demand whenever the
        # channel count, dtype or device (including CUDA device index) of the
        # input differs from the cached kernel.
        cached = self._pool_weights
        if (
            cached is None
            or tuple(cached.shape) != weight_shape
            or cached.dtype != x.dtype
            or cached.device != x.device
        ):
            n_pool = kernel_size[0] * kernel_size[1]
            self._pool_weights = torch.full(
                weight_shape, 1.0 / n_pool, dtype=x.dtype, device=x.device
            )

        return F.conv2d(
            x,
            self._pool_weights,
            bias=None,
            stride=self.stride,
            dilation=self.dilation,
            padding=self.padding,
            groups=in_channels,
        )
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class Conv2dWithConstraint(nn.Conv2d):
    """2D convolution with max-norm constraint on the weights.

    Behaves like :class:`torch.nn.Conv2d`, except that the weights are
    Xavier-uniform initialized and then wrapped in a ``MaxNormParametrize``
    parametrization constructed from ``max_norm``.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to :class:`torch.nn.Conv2d`.
    max_norm : float, default=1
        Value passed to ``MaxNormParametrize``.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import Conv2dWithConstraint
    >>> module = Conv2dWithConstraint(4, 8, kernel_size=(1, 3), padding=(0, 1), bias=False)
    >>> inputs = torch.randn(2, 4, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 8, 1, 64])
    """

    def __init__(self, *args, max_norm=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_norm = max_norm
        # Initialize before registering the parametrization, which takes over
        # the current weight tensor as its original value.
        nn.init.xavier_uniform_(self.weight, gain=1)
        constraint = MaxNormParametrize(self.max_norm)
        register_parametrization(self, "weight", constraint)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class CombinedConv(nn.Module):
    """Merged convolutional layer for temporal and spatial convs in Deep4/ShallowFBCSP

    Numerically equivalent to the separate sequential approach, but this should be faster.
    The temporal and spatial kernels (and biases) are folded into a single
    convolution applied to a one-channel input.

    Parameters
    ----------
    in_chans : int
        Number of EEG input channels.
    n_filters_time: int
        Number of temporal filters.
    filter_time_length: int
        Length of the temporal filter.
    n_filters_spat: int
        Number of spatial filters.
    bias_time: bool
        Whether to use bias in the temporal conv
    bias_spat: bool
        Whether to use bias in the spatial conv

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import CombinedConv
    >>> module = CombinedConv(in_chans=8, n_filters_time=4, n_filters_spat=4, filter_time_length=5)
    >>> inputs = torch.randn(2, 1, 100, 8)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 4, 96, 1])
    """

    def __init__(
        self,
        in_chans,
        n_filters_time=40,
        n_filters_spat=40,
        filter_time_length=25,
        bias_time=True,
        bias_spat=True,
    ):
        super().__init__()
        self.bias_time = bias_time
        self.bias_spat = bias_spat
        self.conv_time = nn.Conv2d(
            1, n_filters_time, (filter_time_length, 1), bias=bias_time, stride=1
        )
        self.conv_spat = nn.Conv2d(
            n_filters_time, n_filters_spat, (1, in_chans), bias=bias_spat, stride=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Weight shapes follow from the Conv2d definitions in __init__:
        # time: (n_filters_time, 1, filter_time_length, 1)
        # spat: (n_filters_spat, n_filters_time, 1, in_chans)
        time_weight = self.conv_time.weight
        spat_weight = self.conv_spat.weight

        # Merge time and spat weights -> (n_filters_spat, 1, filter_time_length, in_chans)
        combined_weight = (
            (time_weight * spat_weight.permute(1, 0, 2, 3)).sum(0).unsqueeze(1)
        )

        calculated_bias: Optional[torch.Tensor] = None

        if self.bias_time:
            time_bias = self.conv_time.bias
            if time_bias is None:
                raise RuntimeError("conv_time.bias is None despite bias_time=True")
            # The temporal bias flows through the spatial conv: spatial filter s
            # adds sum_{f, c} w_spat[s, f, 0, c] * b_time[f]. Summing explicit
            # dims (instead of .squeeze()) keeps this correct when in_chans,
            # n_filters_time or n_filters_spat equals 1.
            calculated_bias = spat_weight.sum(dim=(2, 3)) @ time_bias

        if self.bias_spat:
            spat_bias = self.conv_spat.bias
            if spat_bias is None:
                raise RuntimeError("conv_spat.bias is None despite bias_spat=True")
            if calculated_bias is None:
                calculated_bias = spat_bias
            else:
                calculated_bias = calculated_bias + spat_bias

        return F.conv2d(x, weight=combined_weight, bias=calculated_bias, stride=(1, 1))
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class CausalConv1d(nn.Conv1d):
    """Causal 1-dimensional convolution

    Code modified from [1]_ and [2]_.

    The input is padded by ``(kernel_size - 1) * dilation`` on both sides and
    the trailing padded outputs are dropped, so output position ``t`` only
    depends on input positions ``<= t``.

    Parameters
    ----------
    in_channels : int
        Input channels.
    out_channels : int
        Output channels (number of filters).
    kernel_size : int
        Kernel size.
    dilation : int, optional
        Dilation (number of elements to skip within kernel multiplication).
        Default to 1.
    **kwargs :
        Other keyword arguments to pass to torch.nn.Conv1d, except for
        ``padding``, which is computed internally.

    Raises
    ------
    ValueError
        If ``padding`` is passed in ``kwargs``.

    References
    ----------
    .. [1] https://discuss.pytorch.org/t/causal-convolution/3456/4
    .. [2] https://gist.github.com/paultsw/7a9d6e3ce7b70e9e2c61bc9287addefc

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import CausalConv1d
    >>> module = CausalConv1d(in_channels=4, out_channels=8, kernel_size=5, dilation=2)
    >>> inputs = torch.randn(2, 4, 128)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 8, 128])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        **kwargs,
    ):
        if "padding" in kwargs:
            raise ValueError(
                "The padding parameter is controlled internally by "
                f"{type(self).__name__} class. You should not try to override this"
                " parameter."
            )

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=(kernel_size - 1) * dilation,
            **kwargs,
        )

    def forward(self, X):
        out = F.conv1d(
            X,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        trim = self.padding[0]
        if trim == 0:
            # kernel_size == 1: nothing was padded, and out[..., :-0] would
            # wrongly return an empty tensor.
            return out
        # Drop the outputs that depend on the right-hand (future) padding.
        return out[..., :-trim]
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
class DepthwiseConv2d(torch.nn.Conv2d):
    """
    Depthwise convolution layer.

    A :class:`torch.nn.Conv2d` with ``groups=in_channels``: every input
    channel is convolved with its own ``depth_multiplier`` filters, so the
    channels are processed independently of one another.

    Parameters
    ----------
    in_channels : int
        Number of channels in the input tensor.
    depth_multiplier : int, optional
        Filters per input channel. The layer outputs
        ``in_channels * depth_multiplier`` channels. Default is 2.
    kernel_size : int or tuple, optional
        Size of the convolutional kernel. Default is 3.
    stride : int or tuple, optional
        Stride of the convolution. Default is 1.
    padding : int or tuple, optional
        Padding added to both sides of the input. Default is 0.
    dilation : int or tuple, optional
        Spacing between kernel elements. Default is 1.
    bias : bool, optional
        If True, adds a learnable bias to the output. Default is True.
    padding_mode : str, optional
        Padding mode to use. Options are 'zeros', 'reflect', 'replicate', or
        'circular'. Default is 'zeros'.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import DepthwiseConv2d
    >>> module = DepthwiseConv2d(in_channels=4, depth_multiplier=2, kernel_size=3, padding=1)
    >>> inputs = torch.randn(2, 4, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 8, 1, 64])
    """

    def __init__(
        self,
        in_channels,
        depth_multiplier=2,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode="zeros",
    ):
        conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=depth_multiplier * in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            # One group per input channel -> each channel gets its own filters.
            groups=in_channels,
            bias=bias,
            padding_mode=padding_mode,
        )
        super().__init__(**conv_kwargs)
|