braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,317 +0,0 @@
|
|
|
1
|
-
# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
|
|
2
|
-
# Cedric Rommel <cedric.rommel@inria.fr>
|
|
3
|
-
#
|
|
4
|
-
# License: BSD (3-clause)
|
|
5
|
-
from warnings import warn
|
|
6
|
-
|
|
7
|
-
from numpy import prod
|
|
8
|
-
|
|
9
|
-
from torch import nn
|
|
10
|
-
from einops.layers.torch import Rearrange
|
|
11
|
-
from .modules import Ensure4d
|
|
12
|
-
from .eegnet import _glorot_weight_zero_bias
|
|
13
|
-
from .eegitnet import _InceptionBlock, _DepthwiseConv2d
|
|
14
|
-
from .base import EEGModuleMixin, deprecated_args
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class EEGInception(EEGModuleMixin, nn.Sequential):
    """EEG Inception for ERP-based classification.

    --> DEPRECATED <--
    THIS CLASS IS DEPRECATED AND WILL BE REMOVED IN THE RELEASE 0.9 OF
    BRAINDECODE. PLEASE USE braindecode.models.EEGInceptionERP INSTEAD IN THE
    FUTURE.

    The code for the paper and this model is also available at
    [Santamaria2020]_ and an adaptation for PyTorch [2]_.

    The model is strongly based on the original InceptionNet for images. The
    main goal is to extract features in parallel at different scales. The
    authors extracted three scales proportional to the window sample size.
    The network has three parts: 1) a larger inception block, 2) a smaller
    inception block, followed by 3) a bottleneck for classification.

    One advantage of the EEG-Inception block is that it allows a network
    to learn simultaneous components of low and high frequency associated
    with the signal. The winners of BEETL Competition/NeurIPS 2021 used parts
    of the model [beetl]_.

    The model is fully described in [Santamaria2020]_.

    Notes
    -----
    This implementation is not guaranteed to be correct, has not been checked
    by original authors, only reimplemented from the paper based on [2]_.

    The single ``activation`` module instance is reused at every position
    where an activation is applied.

    Parameters
    ----------
    n_chans : int
        Forwarded to ``EEGModuleMixin``. Also used as the height of the
        depthwise kernel in the first inception block.
    n_outputs : int
        Forwarded to ``EEGModuleMixin``. Also the output size of the final
        linear layer.
    n_times : int
        Forwarded to ``EEGModuleMixin``. Also used to compute the number of
        flattened features fed to the final linear layer.
    sfreq : float
        Forwarded to ``EEGModuleMixin``. Also used to convert
        ``scales_samples_s`` into kernel lengths in samples.
    drop_prob : float
        Dropout rate inside all the network.
    scales_samples_s : list(float)
        Temporal scales of the convolutions on each Inception module. Each
        value is multiplied by ``sfreq`` and truncated to an integer to get
        the kernel length in samples; the second inception block uses a
        quarter of that length. One branch is built per scale.
    n_filters : int
        Initial number of convolutional filters. Set to 8 in
        [Santamaria2020]_.
    activation : nn.Module
        Activation function, default: ELU activation.
    batch_norm_alpha : float
        Momentum for BatchNorm2d.
    depth_multiplier : int
        Depth multiplier for the depthwise convolution.
    pooling_sizes : list(int)
        Four average-pooling sizes along time: one after each of the two
        inception blocks and two inside the final block.
    chs_info : list | None
        Forwarded to ``EEGModuleMixin``.
    input_window_seconds : float | None
        Forwarded to ``EEGModuleMixin``.
    in_channels : int
        Alias for n_chans.
    n_classes : int
        Alias for n_outputs.
    input_window_samples : int
        Alias for n_times.
    add_log_softmax : bool
        If True, a ``LogSoftmax(dim=1)`` follows the final linear layer;
        otherwise an ``Identity`` is used.

    References
    ----------
    .. [Santamaria2020] Santamaria-Vazquez, E., Martinez-Cagigal, V.,
       Vaquerizo-Villar, F., & Hornero, R. (2020).
       EEG-inception: A novel deep convolutional neural network for assistive
       ERP-based brain-computer interfaces.
       IEEE Transactions on Neural Systems and Rehabilitation Engineering,
       v. 28. Online: http://dx.doi.org/10.1109/TNSRE.2020.3048106
    .. [2] Grifcc. Implementation of the EEGInception in torch (2022).
       Online: https://github.com/Grifcc/EEG/tree/90e412a407c5242dfc953d5ffb490bdb32faf022
    .. [beetl] Wei, X., Faisal, A.A., Grosse-Wentrup, M., Gramfort, A.,
       Chevallier, S., Jayaram, V., Jeunet, C., Bakas, S., Ludwig, S.,
       Barmpas, K., Bahri, M., Panagakis, Y., Laskaris, N., Adamos, D.A.,
       Zafeiriou, S., Duong, W.C., Gordon, S.M., Lawhern, V.J., Śliwowski, M.,
       Rouanne, V. & Tempczyk, P. (2022).
       2021 BEETL Competition: Advancing Transfer Learning for Subject
       Independence & Heterogeneous EEG Data Sets. *Proceedings of the
       NeurIPS 2021 Competitions and Demonstrations Track*, in *Proceedings
       of Machine Learning Research* 176:205-219. Available from
       https://proceedings.mlr.press/v176/wei22a.html.
    """

    def __init__(
            self,
            n_chans=None,
            n_outputs=None,
            n_times=1000,
            sfreq=128,
            drop_prob=0.5,
            scales_samples_s=(0.5, 0.25, 0.125),
            n_filters=8,
            activation=nn.ELU(),
            batch_norm_alpha=0.01,
            depth_multiplier=2,
            pooling_sizes=(4, 2, 2, 2),
            chs_info=None,
            input_window_seconds=None,
            in_channels=None,
            n_classes=None,
            input_window_samples=None,
            add_log_softmax=True,
    ):
        # Resolve the deprecated aliases onto their current names.
        n_chans, n_outputs, n_times, = deprecated_args(
            self,
            ('in_channels', 'n_chans', in_channels, n_chans),
            ('n_classes', 'n_outputs', n_classes, n_outputs),
            ('input_window_samples', 'n_times', input_window_samples, n_times),
        )
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            add_log_softmax=add_log_softmax,
        )
        # From here on only the attributes set by the mixin are used.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        del in_channels, n_classes, input_window_samples
        warn(
            "The class EEGInception is deprecated and will be removed in the "
            "release 0.9 of braindecode. Please use "
            "braindecode.models.EEGInceptionERP instead in the future.",
            DeprecationWarning,
            # Attribute the warning to the code instantiating the class.
            stacklevel=2,
        )

        self.drop_prob = drop_prob
        self.n_filters = n_filters
        self.scales_samples_s = scales_samples_s
        self.scales_samples = tuple(
            int(size_s * self.sfreq) for size_s in self.scales_samples_s)
        self.activation = activation
        self.alpha_momentum = batch_norm_alpha
        self.depth_multiplier = depth_multiplier
        self.pooling_sizes = pooling_sizes

        self.mapping = {
            'classification.1.weight': 'final_layer.fc.weight',
            'classification.1.bias': 'final_layer.fc.bias'}

        self.add_module("ensuredims", Ensure4d())

        self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 C T"))

        # ======== Inception branches ========================
        # One branch per temporal scale, built in scale order so that the
        # channel counts derived from len(self.scales_samples) below always
        # match the number of concatenated branches.
        first_branches = tuple(
            self._get_inception_branch_1(
                in_channels=self.n_chans,
                out_channels=self.n_filters,
                kernel_length=kernel_length,
                alpha_momentum=self.alpha_momentum,
                activation=self.activation,
                drop_prob=self.drop_prob,
                depth_multiplier=self.depth_multiplier,
            )
            for kernel_length in self.scales_samples
        )

        self.add_module("inception_block_1", _InceptionBlock(first_branches))

        self.add_module("avg_pool_1", nn.AvgPool2d((1, self.pooling_sizes[0])))

        # ======== Inception branches ========================
        n_concat_filters = len(self.scales_samples) * self.n_filters
        n_concat_dw_filters = n_concat_filters * self.depth_multiplier
        # The second block uses kernels a quarter as long as the first.
        second_branches = tuple(
            self._get_inception_branch_2(
                in_channels=n_concat_dw_filters,
                out_channels=self.n_filters,
                kernel_length=kernel_length // 4,
                alpha_momentum=self.alpha_momentum,
                activation=self.activation,
                drop_prob=self.drop_prob,
            )
            for kernel_length in self.scales_samples
        )

        self.add_module(
            "inception_block_2", _InceptionBlock(second_branches))

        self.add_module("avg_pool_2", nn.AvgPool2d((1, self.pooling_sizes[1])))

        self.add_module("final_block", nn.Sequential(
            nn.Conv2d(
                n_concat_filters,
                n_concat_filters // 2,
                (1, 8),
                padding="same",
                bias=False
            ),
            nn.BatchNorm2d(n_concat_filters // 2,
                           momentum=self.alpha_momentum),
            activation,
            nn.Dropout(self.drop_prob),
            nn.AvgPool2d((1, self.pooling_sizes[2])),

            nn.Conv2d(
                n_concat_filters // 2,
                n_concat_filters // 4,
                (1, 4),
                padding="same",
                bias=False
            ),
            nn.BatchNorm2d(n_concat_filters // 4,
                           momentum=self.alpha_momentum),
            activation,
            nn.Dropout(self.drop_prob),
            nn.AvgPool2d((1, self.pooling_sizes[3])),
        ))

        # numpy.prod yields a NumPy scalar; cast so that the Linear layer is
        # sized with a plain Python int.
        spatial_dim_last_layer = int(
            self.n_times // prod(self.pooling_sizes))
        n_channels_last_layer = self.n_filters * len(self.scales_samples) // 4

        self.add_module("flat", nn.Flatten())

        module = nn.Sequential()

        module.add_module("fc",
                          nn.Linear(
                              spatial_dim_last_layer * n_channels_last_layer,
                              self.n_outputs
                          ), )

        if self.add_log_softmax:
            module.add_module("logsoftmax", nn.LogSoftmax(dim=1))
        else:
            module.add_module("identity", nn.Identity())

        # The conv_classifier will be the final_layer and the other ones will
        # be incorporated
        self.add_module("final_layer", module)

        _glorot_weight_zero_bias(self)

    @staticmethod
    def _get_inception_branch_1(in_channels, out_channels, kernel_length,
                                alpha_momentum, drop_prob, activation,
                                depth_multiplier):
        """Build one branch of the first inception block.

        The branch is a ``(1, kernel_length)`` convolution over a
        single-channel input, followed by a depthwise convolution with a
        ``(in_channels, 1)`` kernel; each convolution is followed by batch
        norm, ``activation`` and dropout.

        Returns
        -------
        nn.Sequential
            The assembled branch, producing
            ``depth_multiplier * out_channels`` feature maps.
        """
        return nn.Sequential(
            nn.Conv2d(
                1,
                out_channels,
                kernel_size=(1, kernel_length),
                padding="same",
                bias=True
            ),
            nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
            activation,
            nn.Dropout(drop_prob),
            _DepthwiseConv2d(
                out_channels,
                kernel_size=(in_channels, 1),
                depth_multiplier=depth_multiplier,
                bias=False,
                padding="valid",
            ),
            nn.BatchNorm2d(
                depth_multiplier * out_channels,
                momentum=alpha_momentum
            ),
            activation,
            nn.Dropout(drop_prob),
        )

    @staticmethod
    def _get_inception_branch_2(in_channels, out_channels, kernel_length,
                                alpha_momentum, drop_prob, activation):
        """Build one branch of the second inception block.

        The branch is a bias-free ``(1, kernel_length)`` convolution followed
        by batch norm, ``activation`` and dropout.

        Returns
        -------
        nn.Sequential
            The assembled branch, producing ``out_channels`` feature maps.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=(1, kernel_length),
                padding="same",
                bias=False
            ),
            nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
            activation,
            nn.Dropout(drop_prob),
        )
|
braindecode/models/functions.py
DELETED
|
@@ -1,47 +0,0 @@
|
|
|
1
|
-
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
-
#
|
|
3
|
-
# License: BSD (3-clause)
|
|
4
|
-
|
|
5
|
-
import torch
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
def square(x):
    """Return ``x`` multiplied by itself."""
    squared = x * x
    return squared
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def safe_log(x, eps=1e-6):
    """Compute :math:`log(max(x, eps))` so that :math:`log(0)` never occurs."""
    # Floor the input at eps first, then take the logarithm.
    floored = torch.clamp(x, min=eps)
    return torch.log(floored)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def identity(x):
    """Return the argument unchanged."""
    unchanged = x
    return unchanged
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def squeeze_final_output(x):
    """Drop the trailing singleton dimension and, if singleton, the time one.

    A plain ``squeeze`` is deliberately not used because the first
    dimension must never be removed.

    Returns
    -------
    x: torch.Tensor
        squeezed tensor
    """
    assert x.shape[3] == 1
    # Take index 0 along dimension 3, which was just checked to have size 1.
    trimmed = x.select(3, 0)
    if trimmed.shape[2] != 1:
        return trimmed
    # Dimension 2 is a singleton as well: drop it too.
    return trimmed.select(2, 0)
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def transpose_time_to_spat(x):
    """Swap time and spatial dimensions.

    Dimensions 0 and 2 stay in place; dimensions 1 and 3 are exchanged.

    Parameters
    ----------
    x: torch.Tensor
        Four-dimensional tensor.

    Returns
    -------
    x: torch.Tensor
        tensor in which the second and last dimensions are swapped
    """
    return x.permute(0, 3, 2, 1)
|