braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/usleep.py
CHANGED
|
@@ -3,122 +3,33 @@
|
|
|
3
3
|
#
|
|
4
4
|
# License: BSD (3-clause)
|
|
5
5
|
|
|
6
|
+
|
|
6
7
|
import numpy as np
|
|
7
8
|
import torch
|
|
8
9
|
from torch import nn
|
|
9
|
-
from .base import EEGModuleMixin, deprecated_args
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def _crop_tensors_to_match(x1, x2, axis=-1):
|
|
13
|
-
"""Crops two tensors to their lowest-common-dimension along an axis."""
|
|
14
|
-
dim_cropped = min(x1.shape[axis], x2.shape[axis])
|
|
15
|
-
|
|
16
|
-
x1_cropped = torch.index_select(
|
|
17
|
-
x1, dim=axis,
|
|
18
|
-
index=torch.arange(dim_cropped).to(device=x1.device)
|
|
19
|
-
)
|
|
20
|
-
x2_cropped = torch.index_select(
|
|
21
|
-
x2, dim=axis,
|
|
22
|
-
index=torch.arange(dim_cropped).to(device=x1.device)
|
|
23
|
-
)
|
|
24
|
-
return x1_cropped, x2_cropped
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
class _EncoderBlock(nn.Module):
|
|
28
|
-
"""Encoding block for a timeseries x of shape (B, C, T)."""
|
|
29
|
-
|
|
30
|
-
def __init__(self,
|
|
31
|
-
in_channels=2,
|
|
32
|
-
out_channels=2,
|
|
33
|
-
kernel_size=9,
|
|
34
|
-
downsample=2):
|
|
35
|
-
super().__init__()
|
|
36
|
-
self.in_channels = in_channels
|
|
37
|
-
self.out_channels = out_channels
|
|
38
|
-
self.kernel_size = kernel_size
|
|
39
|
-
self.downsample = downsample
|
|
40
|
-
|
|
41
|
-
self.block_prepool = nn.Sequential(
|
|
42
|
-
nn.Conv1d(in_channels=in_channels,
|
|
43
|
-
out_channels=out_channels,
|
|
44
|
-
kernel_size=kernel_size,
|
|
45
|
-
padding='same'),
|
|
46
|
-
nn.ELU(),
|
|
47
|
-
nn.BatchNorm1d(num_features=out_channels),
|
|
48
|
-
)
|
|
49
|
-
|
|
50
|
-
self.pad = nn.ConstantPad1d(padding=1, value=0)
|
|
51
|
-
self.maxpool = nn.MaxPool1d(
|
|
52
|
-
kernel_size=self.downsample, stride=self.downsample)
|
|
53
|
-
|
|
54
|
-
def forward(self, x):
|
|
55
|
-
x = self.block_prepool(x)
|
|
56
|
-
residual = x
|
|
57
|
-
if x.shape[-1] % 2:
|
|
58
|
-
x = self.pad(x)
|
|
59
|
-
x = self.maxpool(x)
|
|
60
|
-
return x, residual
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
class _DecoderBlock(nn.Module):
|
|
64
|
-
"""Decoding block for a timeseries x of shape (B, C, T)."""
|
|
65
|
-
|
|
66
|
-
def __init__(self,
|
|
67
|
-
in_channels=2,
|
|
68
|
-
out_channels=2,
|
|
69
|
-
kernel_size=9,
|
|
70
|
-
upsample=2,
|
|
71
|
-
with_skip_connection=True):
|
|
72
|
-
super().__init__()
|
|
73
|
-
self.in_channels = in_channels
|
|
74
|
-
self.out_channels = out_channels
|
|
75
|
-
self.kernel_size = kernel_size
|
|
76
|
-
self.upsample = upsample
|
|
77
|
-
self.with_skip_connection = with_skip_connection
|
|
78
|
-
|
|
79
|
-
self.block_preskip = nn.Sequential(
|
|
80
|
-
nn.Upsample(scale_factor=upsample),
|
|
81
|
-
nn.Conv1d(in_channels=in_channels,
|
|
82
|
-
out_channels=out_channels,
|
|
83
|
-
kernel_size=2,
|
|
84
|
-
padding='same'),
|
|
85
|
-
nn.ELU(),
|
|
86
|
-
nn.BatchNorm1d(num_features=out_channels),
|
|
87
|
-
)
|
|
88
|
-
self.block_postskip = nn.Sequential(
|
|
89
|
-
nn.Conv1d(
|
|
90
|
-
in_channels=(
|
|
91
|
-
2 * out_channels if with_skip_connection else out_channels),
|
|
92
|
-
out_channels=out_channels,
|
|
93
|
-
kernel_size=kernel_size,
|
|
94
|
-
padding='same'),
|
|
95
|
-
nn.ELU(),
|
|
96
|
-
nn.BatchNorm1d(num_features=out_channels),
|
|
97
|
-
)
|
|
98
10
|
|
|
99
|
-
|
|
100
|
-
x = self.block_preskip(x)
|
|
101
|
-
if self.with_skip_connection:
|
|
102
|
-
x, residual = _crop_tensors_to_match(x, residual, axis=-1) # in case of mismatch
|
|
103
|
-
x = torch.cat([x, residual], axis=1) # (B, 2 * C, T)
|
|
104
|
-
x = self.block_postskip(x)
|
|
105
|
-
return x
|
|
11
|
+
from braindecode.models.base import EEGModuleMixin
|
|
106
12
|
|
|
107
13
|
|
|
108
14
|
class USleep(EEGModuleMixin, nn.Module):
|
|
109
|
-
"""
|
|
15
|
+
"""
|
|
16
|
+
Sleep staging architecture from Perslev et al. (2021) [1]_.
|
|
17
|
+
|
|
18
|
+
.. figure:: https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs41746-021-00440-5/MediaObjects/41746_2021_440_Fig2_HTML.png
|
|
19
|
+
:align: center
|
|
20
|
+
:alt: USleep Architecture
|
|
110
21
|
|
|
111
22
|
U-Net (autoencoder with skip connections) feature-extractor for sleep
|
|
112
23
|
staging described in [1]_.
|
|
113
24
|
|
|
114
25
|
For the encoder ('down'):
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
26
|
+
- the temporal dimension shrinks (via maxpooling in the time-domain)
|
|
27
|
+
- the spatial dimension expands (via more conv1d filters in the time-domain)
|
|
28
|
+
|
|
118
29
|
For the decoder ('up'):
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
30
|
+
- the temporal dimension expands (via upsampling in the time-domain)
|
|
31
|
+
- the spatial dimension shrinks (via fewer conv1d filters in the time-domain)
|
|
32
|
+
|
|
122
33
|
Both do so at exponential rates.
|
|
123
34
|
|
|
124
35
|
Parameters
|
|
@@ -128,12 +39,12 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
128
39
|
sfreq : float
|
|
129
40
|
EEG sampling frequency. Set to 128 in [1]_.
|
|
130
41
|
depth : int
|
|
131
|
-
Number of conv blocks in encoding layer (number of 2x2 max pools)
|
|
132
|
-
Note: each block
|
|
42
|
+
Number of conv blocks in encoding layer (number of 2x2 max pools).
|
|
43
|
+
Note: each block halves the spatial dimensions of the features.
|
|
133
44
|
n_time_filters : int
|
|
134
45
|
Initial number of convolutional filters. Set to 5 in [1]_.
|
|
135
46
|
complexity_factor : float
|
|
136
|
-
Multiplicative factor for number of channels at each layer of the U-Net.
|
|
47
|
+
Multiplicative factor for the number of channels at each layer of the U-Net.
|
|
137
48
|
Set to 2 in [1]_.
|
|
138
49
|
with_skip_connection : bool
|
|
139
50
|
If True, use skip connections in decoder blocks.
|
|
@@ -147,47 +58,35 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
147
58
|
ensure_odd_conv_size : bool
|
|
148
59
|
If True and the size of the convolutional kernel is an even number, one
|
|
149
60
|
will be added to it to ensure it is odd, so that the decoder blocks can
|
|
150
|
-
work. This can
|
|
61
|
+
work. This can be useful when using different sampling rates from 128
|
|
151
62
|
or 100 Hz.
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
Alias for n_outputs.
|
|
156
|
-
input_size_s : float
|
|
157
|
-
Alias for input_window_seconds.
|
|
63
|
+
activation : nn.Module, default=nn.ELU
|
|
64
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
65
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
|
|
158
66
|
|
|
159
67
|
References
|
|
160
68
|
----------
|
|
161
69
|
.. [1] Perslev M, Darkner S, Kempfner L, Nikolic M, Jennum PJ, Igel C.
|
|
162
|
-
|
|
163
|
-
|
|
70
|
+
U-Sleep: resilient high-frequency sleep staging. *npj Digit. Med.* 4, 72 (2021).
|
|
71
|
+
https://github.com/perslev/U-Time/blob/master/utime/models/usleep.py
|
|
164
72
|
"""
|
|
165
73
|
|
|
166
74
|
def __init__(
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
n_classes=None,
|
|
182
|
-
input_size_s=None,
|
|
183
|
-
add_log_softmax=False,
|
|
75
|
+
self,
|
|
76
|
+
n_chans=None,
|
|
77
|
+
sfreq=None,
|
|
78
|
+
depth=12,
|
|
79
|
+
n_time_filters=5,
|
|
80
|
+
complexity_factor=1.67,
|
|
81
|
+
with_skip_connection=True,
|
|
82
|
+
n_outputs=5,
|
|
83
|
+
input_window_seconds=None,
|
|
84
|
+
time_conv_size_s=9 / 128,
|
|
85
|
+
ensure_odd_conv_size=False,
|
|
86
|
+
activation: nn.Module = nn.ELU,
|
|
87
|
+
chs_info=None,
|
|
88
|
+
n_times=None,
|
|
184
89
|
):
|
|
185
|
-
n_chans, n_outputs, input_window_seconds = deprecated_args(
|
|
186
|
-
self,
|
|
187
|
-
("in_chans", "n_chans", in_chans, n_chans),
|
|
188
|
-
("n_classes", "n_outputs", n_classes, n_outputs),
|
|
189
|
-
("input_size_s", "input_window_seconds", input_size_s, input_window_seconds),
|
|
190
|
-
)
|
|
191
90
|
super().__init__(
|
|
192
91
|
n_outputs=n_outputs,
|
|
193
92
|
n_chans=n_chans,
|
|
@@ -195,16 +94,14 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
195
94
|
n_times=n_times,
|
|
196
95
|
input_window_seconds=input_window_seconds,
|
|
197
96
|
sfreq=sfreq,
|
|
198
|
-
add_log_softmax=add_log_softmax,
|
|
199
97
|
)
|
|
200
98
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
201
|
-
del in_chans, n_classes, input_size_s
|
|
202
99
|
|
|
203
100
|
self.mapping = {
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
101
|
+
"clf.3.weight": "final_layer.0.weight",
|
|
102
|
+
"clf.3.bias": "final_layer.0.bias",
|
|
103
|
+
"clf.5.weight": "final_layer.2.weight",
|
|
104
|
+
"clf.5.bias": "final_layer.2.bias",
|
|
208
105
|
}
|
|
209
106
|
|
|
210
107
|
max_pool_size = 2 # Hardcoded to avoid dimensional errors
|
|
@@ -214,8 +111,9 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
214
111
|
time_conv_size += 1
|
|
215
112
|
else:
|
|
216
113
|
raise ValueError(
|
|
217
|
-
|
|
218
|
-
|
|
114
|
+
"time_conv_size must be an odd number to accommodate the "
|
|
115
|
+
"upsampling step in the decoder blocks."
|
|
116
|
+
)
|
|
219
117
|
|
|
220
118
|
channels = [self.n_chans]
|
|
221
119
|
n_filters = n_time_filters
|
|
@@ -225,38 +123,42 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
225
123
|
self.channels = channels
|
|
226
124
|
|
|
227
125
|
# Instantiate encoder
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
126
|
+
self.encoder_blocks = nn.ModuleList(
|
|
127
|
+
_EncoderBlock(
|
|
128
|
+
in_channels=channels[idx],
|
|
129
|
+
out_channels=channels[idx + 1],
|
|
130
|
+
kernel_size=time_conv_size,
|
|
131
|
+
downsample=max_pool_size,
|
|
132
|
+
activation=activation,
|
|
133
|
+
)
|
|
134
|
+
for idx in range(depth)
|
|
135
|
+
)
|
|
237
136
|
|
|
238
137
|
# Instantiate bottom (channels increase, temporal dim stays the same)
|
|
239
138
|
self.bottom = nn.Sequential(
|
|
240
|
-
nn.Conv1d(
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
139
|
+
nn.Conv1d(
|
|
140
|
+
in_channels=channels[-2],
|
|
141
|
+
out_channels=channels[-1],
|
|
142
|
+
kernel_size=time_conv_size,
|
|
143
|
+
padding=(time_conv_size - 1) // 2,
|
|
144
|
+
), # preserves dimension
|
|
145
|
+
activation(),
|
|
245
146
|
nn.BatchNorm1d(num_features=channels[-1]),
|
|
246
147
|
)
|
|
247
148
|
|
|
248
149
|
# Instantiate decoder
|
|
249
|
-
decoder = list()
|
|
250
150
|
channels_reverse = channels[::-1]
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
151
|
+
self.decoder_blocks = nn.ModuleList(
|
|
152
|
+
_DecoderBlock(
|
|
153
|
+
in_channels=channels_reverse[idx],
|
|
154
|
+
out_channels=channels_reverse[idx + 1],
|
|
155
|
+
kernel_size=time_conv_size,
|
|
156
|
+
upsample=max_pool_size,
|
|
157
|
+
with_skip_connection=with_skip_connection,
|
|
158
|
+
activation=activation,
|
|
159
|
+
)
|
|
160
|
+
for idx in range(depth)
|
|
161
|
+
)
|
|
260
162
|
|
|
261
163
|
# The temporal dimension remains unchanged
|
|
262
164
|
# (except through the AvgPooling which collapses it to 1)
|
|
@@ -282,7 +184,7 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
282
184
|
stride=1,
|
|
283
185
|
padding=0,
|
|
284
186
|
), # output is (B, n_classes, S)
|
|
285
|
-
|
|
187
|
+
activation(),
|
|
286
188
|
nn.Conv1d(
|
|
287
189
|
in_channels=self.n_outputs,
|
|
288
190
|
out_channels=self.n_outputs,
|
|
@@ -290,11 +192,11 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
290
192
|
stride=1,
|
|
291
193
|
padding=0,
|
|
292
194
|
),
|
|
293
|
-
nn.
|
|
195
|
+
nn.Identity(),
|
|
294
196
|
# output is (B, n_classes, S)
|
|
295
197
|
)
|
|
296
198
|
|
|
297
|
-
def forward(self, x):
|
|
199
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
298
200
|
"""If input x has shape (B, S, C, T), return y_pred of shape (B, n_classes, S).
|
|
299
201
|
If input x has shape (B, C, T), return y_pred of shape (B, n_classes).
|
|
300
202
|
"""
|
|
@@ -305,7 +207,7 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
305
207
|
|
|
306
208
|
# encoder
|
|
307
209
|
residuals = []
|
|
308
|
-
for down in self.
|
|
210
|
+
for down in self.encoder_blocks:
|
|
309
211
|
x, res = down(x)
|
|
310
212
|
residuals.append(res)
|
|
311
213
|
|
|
@@ -313,9 +215,11 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
313
215
|
x = self.bottom(x)
|
|
314
216
|
|
|
315
217
|
# decoder
|
|
316
|
-
|
|
317
|
-
for
|
|
318
|
-
|
|
218
|
+
num_blocks = len(self.decoder_blocks) # statically known
|
|
219
|
+
for idx, dec in enumerate(self.decoder_blocks):
|
|
220
|
+
# pick the matching residual in reverse order
|
|
221
|
+
res = residuals[num_blocks - 1 - idx]
|
|
222
|
+
x = dec(x, res)
|
|
319
223
|
|
|
320
224
|
# classifier
|
|
321
225
|
x = self.clf(x)
|
|
@@ -325,3 +229,112 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
325
229
|
y_pred = y_pred[:, :, 0]
|
|
326
230
|
|
|
327
231
|
return y_pred
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class _EncoderBlock(nn.Module):
|
|
235
|
+
"""Encoding block for a timeseries x of shape (B, C, T)."""
|
|
236
|
+
|
|
237
|
+
def __init__(
|
|
238
|
+
self,
|
|
239
|
+
in_channels=2,
|
|
240
|
+
out_channels=2,
|
|
241
|
+
kernel_size=9,
|
|
242
|
+
downsample=2,
|
|
243
|
+
activation: nn.Module = nn.ELU,
|
|
244
|
+
):
|
|
245
|
+
super().__init__()
|
|
246
|
+
self.in_channels = in_channels
|
|
247
|
+
self.out_channels = out_channels
|
|
248
|
+
self.kernel_size = kernel_size
|
|
249
|
+
self.downsample = downsample
|
|
250
|
+
|
|
251
|
+
self.block_prepool = nn.Sequential(
|
|
252
|
+
nn.Conv1d(
|
|
253
|
+
in_channels=in_channels,
|
|
254
|
+
out_channels=out_channels,
|
|
255
|
+
kernel_size=kernel_size,
|
|
256
|
+
padding="same",
|
|
257
|
+
),
|
|
258
|
+
activation(),
|
|
259
|
+
nn.BatchNorm1d(num_features=out_channels),
|
|
260
|
+
)
|
|
261
|
+
|
|
262
|
+
self.pad = nn.ConstantPad1d(padding=1, value=0.0)
|
|
263
|
+
self.maxpool = nn.MaxPool1d(kernel_size=self.downsample, stride=self.downsample)
|
|
264
|
+
|
|
265
|
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
266
|
+
x = self.block_prepool(x)
|
|
267
|
+
residual = x
|
|
268
|
+
if x.shape[-1] % 2:
|
|
269
|
+
x = self.pad(x)
|
|
270
|
+
x = self.maxpool(x)
|
|
271
|
+
return x, residual
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class _DecoderBlock(nn.Module):
|
|
275
|
+
"""Decoding block for a timeseries x of shape (B, C, T)."""
|
|
276
|
+
|
|
277
|
+
def __init__(
|
|
278
|
+
self,
|
|
279
|
+
in_channels=2,
|
|
280
|
+
out_channels=2,
|
|
281
|
+
kernel_size=9,
|
|
282
|
+
upsample=2,
|
|
283
|
+
with_skip_connection=True,
|
|
284
|
+
activation: nn.Module = nn.ELU,
|
|
285
|
+
):
|
|
286
|
+
super().__init__()
|
|
287
|
+
self.in_channels = in_channels
|
|
288
|
+
self.out_channels = out_channels
|
|
289
|
+
self.kernel_size = kernel_size
|
|
290
|
+
self.upsample = upsample
|
|
291
|
+
self.with_skip_connection = with_skip_connection
|
|
292
|
+
|
|
293
|
+
self.block_preskip = nn.Sequential(
|
|
294
|
+
nn.Upsample(scale_factor=upsample),
|
|
295
|
+
nn.Conv1d(
|
|
296
|
+
in_channels=in_channels,
|
|
297
|
+
out_channels=out_channels,
|
|
298
|
+
kernel_size=2,
|
|
299
|
+
padding="same",
|
|
300
|
+
),
|
|
301
|
+
activation(),
|
|
302
|
+
nn.BatchNorm1d(num_features=out_channels),
|
|
303
|
+
)
|
|
304
|
+
self.block_postskip = nn.Sequential(
|
|
305
|
+
nn.Conv1d(
|
|
306
|
+
in_channels=(
|
|
307
|
+
2 * out_channels if with_skip_connection else out_channels
|
|
308
|
+
),
|
|
309
|
+
out_channels=out_channels,
|
|
310
|
+
kernel_size=kernel_size,
|
|
311
|
+
padding="same",
|
|
312
|
+
),
|
|
313
|
+
activation(),
|
|
314
|
+
nn.BatchNorm1d(num_features=out_channels),
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
|
318
|
+
x = self.block_preskip(x)
|
|
319
|
+
if self.with_skip_connection:
|
|
320
|
+
x, residual = self._crop_tensors_to_match(
|
|
321
|
+
x, residual, axis=-1
|
|
322
|
+
) # in case of mismatch
|
|
323
|
+
x = torch.cat([x, residual], dim=1) # (B, 2 * C, T)
|
|
324
|
+
x = self.block_postskip(x)
|
|
325
|
+
return x
|
|
326
|
+
|
|
327
|
+
@staticmethod
|
|
328
|
+
def _crop_tensors_to_match(
|
|
329
|
+
x1: torch.Tensor, x2: torch.Tensor, axis: int = -1
|
|
330
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
331
|
+
"""Crops two tensors to their lowest-common-dimension along an axis."""
|
|
332
|
+
dim_cropped = min(x1.shape[axis], x2.shape[axis])
|
|
333
|
+
|
|
334
|
+
x1_cropped = torch.index_select(
|
|
335
|
+
x1, dim=axis, index=torch.arange(dim_cropped).to(device=x1.device)
|
|
336
|
+
)
|
|
337
|
+
x2_cropped = torch.index_select(
|
|
338
|
+
x2, dim=axis, index=torch.arange(dim_cropped).to(device=x1.device)
|
|
339
|
+
)
|
|
340
|
+
return x1_cropped, x2_cropped
|