braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Self-supervised learning samplers.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
|
|
6
|
+
# Young Truong <dt.young112@gmail.com>
|
|
7
|
+
#
|
|
8
|
+
# License: BSD (3-clause)
|
|
9
|
+
|
|
10
|
+
import warnings
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch.distributed as dist
|
|
14
|
+
|
|
15
|
+
from . import DistributedRecordingSampler, RecordingSampler
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RelativePositioningSampler(RecordingSampler):
    """Sample examples for the relative positioning task from [Banville2020]_.

    Sample examples as tuples of two window indices, with a label indicating
    whether the windows are close or far, as defined by tau_pos and tau_neg.

    Parameters
    ----------
    metadata : pd.DataFrame
        See RecordingSampler.
    tau_pos : int
        Size of the positive context, in samples. A positive pair contains two
        windows x1 and x2 which are separated by at most `tau_pos` samples.
    tau_neg : int
        Size of the negative context, in samples. A negative pair contains two
        windows x1 and x2 which are separated by at least `tau_neg` samples and
        at most `tau_max` samples. Ignored if `same_rec_neg` is False.
    n_examples : int
        Number of pairs to extract.
    tau_max : int | None
        See `tau_neg`. If None, no upper bound is applied (stored as
        ``np.inf``).
    same_rec_neg : bool
        If True, sample negative pairs from within the same recording. If
        False, sample negative pairs from two different recordings.
    random_state : None | np.random.RandomState | int
        Random state.

    Raises
    ------
    ValueError
        If `same_rec_neg` is False and fewer than two recordings are
        available.

    References
    ----------
    .. [Banville2020] Banville, H., Chehab, O., Hyvärinen, A., Engemann, D. A.,
           & Gramfort, A. (2020). Uncovering the structure of clinical EEG
           signals with self-supervised learning.
           arXiv preprint arXiv:2007.16104.
    """

    def __init__(
        self,
        metadata,
        tau_pos,
        tau_neg,
        n_examples,
        tau_max=None,
        same_rec_neg=True,
        random_state=None,
    ):
        super().__init__(metadata, random_state=random_state)

        self.tau_pos = tau_pos
        self.tau_neg = tau_neg
        # An infinite bound makes the `tau_max` comparisons always true.
        self.tau_max = np.inf if tau_max is None else tau_max
        self.n_examples = n_examples
        self.same_rec_neg = same_rec_neg

        if not same_rec_neg and self.n_recordings < 2:
            raise ValueError(
                "More than one recording must be available when "
                "using across-recording negative sampling."
            )

    def _sample_pair(self):
        """Sample a pair of two windows.

        Returns
        -------
        win_ind1 : int
            Index of the first (anchor) window.
        win_ind2 : int
            Index of the second window.
        pair_type : float
            0.0 for a negative pair, 1.0 for a positive pair.

        Raises
        ------
        NotImplementedError
            If no other window of the anchor's recording satisfies the
            positive/negative distance constraints.
        """
        # Sample first window
        win_ind1, rec_ind1 = self.sample_window()
        ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
        # Start positions of the windows of the anchor's recording
        # (assumed to be array-like, aligned with the "index" entry used below).
        ts = self.info.iloc[rec_ind1]["i_start_in_trial"]

        # Decide whether the pair will be positive or negative
        pair_type = self.rng.binomial(1, 0.5)
        win_ind2 = None
        if pair_type == 0:  # Negative example
            if self.same_rec_neg:
                mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
                    (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
                )
            else:
                # Redraw until the second window comes from another recording.
                rec_ind2 = rec_ind1
                while rec_ind2 == rec_ind1:
                    win_ind2, rec_ind2 = self.sample_window()
        elif pair_type == 1:  # Positive example
            mask = (ts >= ts1 - self.tau_pos) & (ts <= ts1 + self.tau_pos)

        if win_ind2 is None:
            mask[ts == ts1] = False  # same window cannot be sampled twice
            if not mask.any():
                kind = "positive" if pair_type == 1 else "negative"
                raise NotImplementedError(
                    f"No candidate window available to form a {kind} pair with "
                    f"window {win_ind1} of recording {rec_ind1} "
                    f"(tau_pos={self.tau_pos}, tau_neg={self.tau_neg}, "
                    f"tau_max={self.tau_max})."
                )
            win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])

        return win_ind1, win_ind2, float(pair_type)

    def presample(self):
        """Presample examples.

        Once presampled, the examples are the same from one epoch to another.

        Returns
        -------
        self
            The sampler itself, with ``examples`` populated.
        """
        self.examples = [self._sample_pair() for _ in range(self.n_examples)]
        return self

    def __iter__(self):
        """
        Iterate over pairs.

        Presampled pairs are yielded if `presample` was called; otherwise a
        fresh pair is drawn for every item.

        Yields
        ------
        int
            Position of the first window in the dataset.
        int
            Position of the second window in the dataset.
        float
            0 for a negative pair, 1 for a positive pair.
        """
        for i in range(self.n_examples):
            if hasattr(self, "examples"):
                yield self.examples[i]
            else:
                yield self._sample_pair()

    def __len__(self):
        """Return the number of pairs produced per iteration."""
        return self.n_examples
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class DistributedRelativePositioningSampler(DistributedRecordingSampler):
    """Sample examples for the relative positioning task from [Banville2020]_ in distributed mode.

    Sample examples as tuples of two window indices, with a label indicating
    whether the windows are close or far, as defined by tau_pos and tau_neg.

    Parameters
    ----------
    metadata : pd.DataFrame
        See RecordingSampler.
    tau_pos : int
        Size of the positive context, in samples. A positive pair contains two
        windows x1 and x2 which are separated by at most `tau_pos` samples.
    tau_neg : int
        Size of the negative context, in samples. A negative pair contains two
        windows x1 and x2 which are separated by at least `tau_neg` samples and
        at most `tau_max` samples. Ignored if `same_rec_neg` is False.
    n_examples : int
        Number of pairs to extract. The value stored on the sampler is scaled
        by ``n_recordings / info.shape[0]`` (integer division).
    tau_max : int | None
        See `tau_neg`. If None, no upper bound is applied (stored as
        ``np.inf``).
    same_rec_neg : bool
        If True, sample negative pairs from within the same recording. If
        False, sample negative pairs from two different recordings.
    random_state : None | np.random.RandomState | int
        Random state.
    kwargs: dict
        Additional keyword arguments to pass to torch DistributedSampler.
        See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler

    Raises
    ------
    ValueError
        If `same_rec_neg` is False and fewer than two recordings are
        available.

    References
    ----------
    .. [Banville2020] Banville, H., Chehab, O., Hyvärinen, A., Engemann, D. A.,
           & Gramfort, A. (2020). Uncovering the structure of clinical EEG
           signals with self-supervised learning.
           arXiv preprint arXiv:2007.16104.
    """

    def __init__(
        self,
        metadata,
        tau_pos,
        tau_neg,
        n_examples,
        tau_max=None,
        same_rec_neg=True,
        random_state=None,
        **kwargs,
    ):
        super().__init__(metadata, random_state=random_state, **kwargs)
        self.tau_pos = tau_pos
        self.tau_neg = tau_neg
        # An infinite bound makes the `tau_max` comparisons always true.
        self.tau_max = np.inf if tau_max is None else tau_max
        self.same_rec_neg = same_rec_neg

        # Scale the requested number of pairs by the ratio of `n_recordings`
        # to the number of rows in `info`.
        # NOTE(review): presumably `n_recordings` is the per-rank share of
        # recordings - confirm against DistributedRecordingSampler.
        self.n_examples = n_examples * self.n_recordings // self.info.shape[0]
        rank = dist.get_rank()
        warnings.warn(f"Rank {rank} - Number of datasets: {self.n_recordings}")
        warnings.warn(f"Rank {rank} - Number of samples: {self.n_examples}")

        if not same_rec_neg and self.n_recordings < 2:
            raise ValueError(
                "More than one recording must be available when "
                "using across-recording negative sampling."
            )

    def _sample_pair(self):
        """Sample a pair of two windows.

        Returns
        -------
        win_ind1 : int
            Index of the first (anchor) window.
        win_ind2 : int
            Index of the second window.
        pair_type : float
            0.0 for a negative pair, 1.0 for a positive pair.

        Raises
        ------
        NotImplementedError
            If no other window of the anchor's recording satisfies the
            positive/negative distance constraints.
        """
        # Sample first window
        win_ind1, rec_ind1 = self.sample_window()
        ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
        # Start positions of the windows of the anchor's recording
        # (assumed to be array-like, aligned with the "index" entry used below).
        ts = self.info.iloc[rec_ind1]["i_start_in_trial"]

        # Decide whether the pair will be positive or negative
        pair_type = self.rng.binomial(1, 0.5)
        win_ind2 = None
        if pair_type == 0:  # Negative example
            if self.same_rec_neg:
                mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
                    (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
                )
            else:
                # Redraw until the second window comes from another recording.
                rec_ind2 = rec_ind1
                while rec_ind2 == rec_ind1:
                    win_ind2, rec_ind2 = self.sample_window()
        elif pair_type == 1:  # Positive example
            mask = (ts >= ts1 - self.tau_pos) & (ts <= ts1 + self.tau_pos)

        if win_ind2 is None:
            mask[ts == ts1] = False  # same window cannot be sampled twice
            if not mask.any():
                kind = "positive" if pair_type == 1 else "negative"
                raise NotImplementedError(
                    f"No candidate window available to form a {kind} pair with "
                    f"window {win_ind1} of recording {rec_ind1} "
                    f"(tau_pos={self.tau_pos}, tau_neg={self.tau_neg}, "
                    f"tau_max={self.tau_max})."
                )
            win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])

        return win_ind1, win_ind2, float(pair_type)

    def presample(self):
        """Presample examples.

        Once presampled, the examples are the same from one epoch to another.

        Returns
        -------
        self
            The sampler itself, with ``examples`` populated.
        """
        self.examples = [self._sample_pair() for _ in range(self.n_examples)]
        return self

    def __iter__(self):
        """
        Iterate over pairs.

        Presampled pairs are yielded if `presample` was called; otherwise a
        fresh pair is drawn for every item.

        Yields
        ------
        int
            Position of the first window in the dataset.
        int
            Position of the second window in the dataset.
        float
            0 for a negative pair, 1 for a positive pair.
        """
        for i in range(self.n_examples):
            if hasattr(self, "examples"):
                yield self.examples[i]
            else:
                yield self._sample_pair()

    def __len__(self):
        """Return the number of pairs produced per iteration."""
        return self.n_examples
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Functionality for skorch-based training.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from .losses import CroppedLoss, TimeSeriesLoss, mixup_criterion
|
|
6
|
+
from .scoring import (
|
|
7
|
+
CroppedTimeSeriesEpochScoring,
|
|
8
|
+
CroppedTrialEpochScoring,
|
|
9
|
+
PostEpochTrainScoring,
|
|
10
|
+
predict_trials,
|
|
11
|
+
trial_preds_from_window_preds,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"CroppedLoss",
|
|
16
|
+
"mixup_criterion",
|
|
17
|
+
"TimeSeriesLoss",
|
|
18
|
+
"CroppedTrialEpochScoring",
|
|
19
|
+
"PostEpochTrainScoring",
|
|
20
|
+
"CroppedTimeSeriesEpochScoring",
|
|
21
|
+
"trial_preds_from_window_preds",
|
|
22
|
+
"predict_trials",
|
|
23
|
+
]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from skorch.callbacks import Callback
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MaxNormConstraintCallback(Callback):
    """Apply a max-norm constraint to layer weights after each training batch.

    For every direct child of ``net.module_`` that has a non-``None``
    ``weight`` and whose class name does not start with ``"BatchNorm"``, the
    weight is renormalized with ``torch.renorm(weight, p=2, dim=0, maxnorm=2)``.
    The weight of the last such child is then renormalized again with
    ``maxnorm=0.5``. Nothing is done for non-training batches.
    """

    def on_batch_end(self, net, training, *args, **kwargs):
        """Renormalize the weights of the module's direct children.

        Parameters
        ----------
        net : object
            Object exposing the trained module as ``net.module_``.
        training : bool
            Whether the batch that just ended was a training batch.
        *args, **kwargs
            Ignored.
        """
        if not training:
            return

        last_weight = None
        for module in net.module_.children():
            # A module can register `weight` as None (e.g. a normalization
            # layer with its affine parameters disabled); there is nothing to
            # constrain in that case.
            weight = getattr(module, "weight", None)
            if weight is None:
                continue
            if module.__class__.__name__.startswith("BatchNorm"):
                continue
            weight.data = torch.renorm(weight.data, 2, 0, maxnorm=2)
            last_weight = weight

        # The last constrained layer gets a tighter bound on top of the
        # generic one applied in the loop.
        if last_weight is not None:
            last_weight.data = torch.renorm(last_weight.data, 2, 0, maxnorm=0.5)
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
# Maciej Sliwowski <maciek.sliwowski@gmail.com>
|
|
3
|
+
# Mohammed Fattouh <mo.fattouh@gmail.com>
|
|
4
|
+
#
|
|
5
|
+
# License: BSD (3-clause)
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import nn
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CroppedLoss(nn.Module):
    """Apply a loss to predictions averaged over the time axis.

    Predictions are expected in shape
    n_batch size x n_classes x n_predictions (in time).

    Parameters
    ----------
    loss_function : callable
        Called as ``loss_function(averaged_preds, targets)``.
    """

    def __init__(self, loss_function):
        super().__init__()
        self.loss_function = loss_function

    def forward(self, preds, targets):
        """Average `preds` over time and evaluate the wrapped loss.

        Parameters
        ----------
        preds: torch.Tensor
            Model's prediction with shape (batch_size, n_classes, n_times).
        targets: torch.Tensor
            Targets, handed to ``loss_function`` unchanged.

        Returns
        -------
        Whatever ``loss_function`` returns.
        """
        # Collapse the time axis, then drop the class axis if it has size 1.
        time_averaged = preds.mean(dim=2).squeeze(dim=1)
        return self.loss_function(time_averaged, targets)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class TimeSeriesLoss(nn.Module):
    """Apply a loss between time-series targets and predictions.

    Predictions are expected in shape
    n_batch size x n_classes x n_predictions (in time), and targets in shape
    n_batch size x n_classes x window_len (in time).
    Only the last n_predictions time steps of the targets are used. NaN
    target values are masked out, so the loss is computed only on predictions
    that correspond to valid target values.

    Parameters
    ----------
    loss_function : callable
        Called as ``loss_function(masked_preds, masked_targets)``.
    """

    def __init__(self, loss_function):
        super().__init__()
        self.loss_function = loss_function

    def forward(self, preds, targets):
        """Mask out NaN targets and evaluate the wrapped loss.

        Parameters
        ----------
        preds: torch.Tensor
            Model's prediction with shape (batch_size, n_classes, n_times).
        targets: torch.Tensor
            Target values with shape (batch_size, n_classes, window_len).

        Returns
        -------
        Whatever ``loss_function`` returns for the valid (non-NaN) entries.
        """
        # Keep only the trailing part of the targets that lines up with preds.
        aligned_targets = targets[:, :, -preds.shape[-1]:]
        # True wherever the target holds a real value.
        is_valid = torch.logical_not(torch.isnan(aligned_targets))
        # Boolean indexing flattens both tensors to the valid entries.
        return self.loss_function(preds[is_valid], aligned_targets[is_valid])
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def mixup_criterion(preds, target):
    """Implements loss for Mixup for EEG data. See [1]_.

    Implementation based on [2]_.

    Parameters
    ----------
    preds : torch.Tensor
        Predictions from the model. They are passed to
        ``torch.nn.functional.nll_loss``, which expects log-probabilities.
    target : torch.Tensor | list of torch.Tensor
        For predictions without mixup, the targets as a tensor. If mixup has
        been applied, a list containing the targets of the two mixed
        samples and the mixing coefficients as tensors.

    Returns
    -------
    loss : torch.Tensor
        The loss value, as a scalar tensor.

    References
    ----------
    .. [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
       mixup: Beyond Empirical Risk Minimization
       Online: https://arxiv.org/abs/1710.09412
    .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
    """
    # A plain label tensor is never a mixup triple. Checking only
    # `len(target) == 3` would misread a batch of exactly three labels as
    # (y_a, y_b, lam).
    if isinstance(target, torch.Tensor) or len(target) != 3:
        return torch.nn.functional.nll_loss(preds, target)

    # unpack target
    y_a, y_b, lam = target
    # compute loss per sample
    loss_a = torch.nn.functional.nll_loss(preds, y_a, reduction="none")
    loss_b = torch.nn.functional.nll_loss(preds, y_b, reduction="none")
    # compute weighted mean
    ret = torch.mul(lam, loss_a) + torch.mul(1 - lam, loss_b)
    return ret.mean()
|