braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/regressor.py
CHANGED

@@ -11,8 +11,8 @@ import warnings
 import numpy as np
 from skorch.regressor import NeuralNetRegressor

-from .training.scoring import predict_trials
 from .eegneuralnet import _EEGNeuralNet
+from .training.scoring import predict_trials
 from .util import ThrowAwayIndexLoader, update_estimator_docstring


@@ -58,19 +58,28 @@ class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
     """  # noqa: E501
     __doc__ = update_estimator_docstring(NeuralNetRegressor, doc)

-    def __init__(
-        …
-        …
-        …
+    def __init__(
+        self,
+        module,
+        *args,
+        cropped=False,
+        callbacks=None,
+        iterator_train__shuffle=True,
+        iterator_train__drop_last=True,
+        aggregate_predictions=True,
+        **kwargs,
+    ):
         self.cropped = cropped
         self.aggregate_predictions = aggregate_predictions
         self._last_window_inds_ = None
-        super().__init__(
-            …
-            …
-            …
-            …
-            …
+        super().__init__(
+            module,
+            *args,
+            callbacks=callbacks,
+            iterator_train__shuffle=iterator_train__shuffle,
+            iterator_train__drop_last=iterator_train__drop_last,
+            **kwargs,
+        )

     def get_iterator(self, dataset, training=False, drop_index=True):
         iterator = super().get_iterator(dataset, training=training)

@@ -155,7 +164,9 @@ class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
         warnings.warn(
             "This method was designed to predict trials in cropped mode. "
             "Calling it when cropped is False will give the same result as "
-            "'.predict'.",
+            "'.predict'.",
+            UserWarning,
+        )
         preds = self.predict(X)
         if return_targets:
             return preds, np.concatenate([X[i][1] for i in range(len(X))])
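
The rewritten `EEGRegressor.__init__` above simply spells out the keyword arguments it forwards to skorch. Below is a minimal usage sketch, assuming braindecode 1.0.0 and PyTorch are installed; the `TinyRegressor` module and every hyperparameter value are invented for illustration, and only the `EEGRegressor` keyword arguments mirror the signature shown in this diff.

    # Hedged sketch: driving the new-style EEGRegressor with a toy module.
    import numpy as np
    import torch
    from torch import nn

    from braindecode import EEGRegressor


    class TinyRegressor(nn.Module):
        """Stand-in for a braindecode model (hypothetical)."""

        def __init__(self, n_chans=4, n_times=100):
            super().__init__()
            self.net = nn.Sequential(nn.Flatten(), nn.Linear(n_chans * n_times, 1))

        def forward(self, x):
            return self.net(x)


    X = np.random.randn(32, 4, 100).astype(np.float32)   # (trials, channels, times)
    y = np.random.randn(32, 1).astype(np.float32)

    reg = EEGRegressor(
        TinyRegressor(),
        cropped=False,                # trial-wise rather than cropped decoding
        aggregate_predictions=True,   # average window predictions per trial
        optimizer=torch.optim.Adam,   # forwarded to skorch via **kwargs
        lr=1e-3,
        max_epochs=2,
        batch_size=8,
    )
    reg.fit(X, y)
    preds = reg.predict(X)
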
braindecode/samplers/__init__.py
CHANGED

@@ -1,10 +1,18 @@
-"""Classes to sample examples.
-"""
+"""Classes to sample examples."""

-from .base import …
-    …
+from .base import (
+    BalancedSequenceSampler,
+    DistributedRecordingSampler,
+    RecordingSampler,
+    SequenceSampler,
+)
+from .ssl import DistributedRelativePositioningSampler, RelativePositioningSampler

-__all__ = [
-    …
-    …
-    …
-    …
+__all__ = [
+    "RecordingSampler",
+    "SequenceSampler",
+    "BalancedSequenceSampler",
+    "RelativePositioningSampler",
+    "DistributedRecordingSampler",
+    "DistributedRelativePositioningSampler",
+]
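
The reorganized `__all__` above defines the public sampler surface for 1.0.0; a quick import check, assuming the new wheel is installed:

    # Every name in the new __all__ should be importable from braindecode.samplers.
    from braindecode.samplers import (
        BalancedSequenceSampler,
        DistributedRecordingSampler,
        DistributedRelativePositioningSampler,
        RecordingSampler,
        RelativePositioningSampler,
        SequenceSampler,
    )

    print([cls.__name__ for cls in (
        RecordingSampler, SequenceSampler, BalancedSequenceSampler,
        RelativePositioningSampler, DistributedRecordingSampler,
        DistributedRelativePositioningSampler,
    )])
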
braindecode/samplers/base.py
CHANGED

@@ -4,17 +4,120 @@ Sampler classes.

 # Authors: Hubert Banville <hubert.jbanville@gmail.com>
 # Theo Gnassounou <>
+# Young Truong <dt.young112@gmail.com>
 #
 # License: BSD (3-clause)

+from typing import Optional
+
 import numpy as np
-from torch.utils.data.sampler import Sampler
 from sklearn.utils import check_random_state
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import Sampler


 class RecordingSampler(Sampler):
     """Base sampler simplifying sampling from recordings.

+    Parameters
+    ----------
+    metadata : pd.DataFrame
+        DataFrame with at least one of {subject, session, run} columns for each
+        window in the BaseConcatDataset to sample examples from. Normally
+        obtained with `BaseConcatDataset.get_metadata()`. For instance,
+        `metadata.head()` might look like this:
+
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+        | i_window_in_trial | i_start_in_trial | i_stop_in_trial | target | subject | session   | run   |
+        +===================+==================+=================+========+=========+===========+=======+
+        | 0                 | 0                | 500             | -1     | 4       | session_T | run_0 |
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+        | 1                 | 500              | 1000            | -1     | 4       | session_T | run_0 |
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+        | 2                 | 1000             | 1500            | -1     | 4       | session_T | run_0 |
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+        | 3                 | 1500             | 2000            | -1     | 4       | session_T | run_0 |
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+        | 4                 | 2000             | 2500            | -1     | 4       | session_T | run_0 |
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+
+    random_state : np.RandomState | int | None
+        Random state.
+
+    Attributes
+    ----------
+    info : pd.DataFrame
+        Series with MultiIndex index which contains the subject, session, run
+        and window indices information in an easily accessible structure for
+        quick sampling of windows.
+    n_recordings : int
+        Number of recordings available.
+    """
+
+    def __init__(self, metadata, random_state=None):
+        self.metadata = metadata
+        self.info = self._init_info(metadata)
+        self.rng = check_random_state(random_state)
+
+    def _init_info(self, metadata, required_keys=None):
+        """Initialize ``info`` DataFrame.
+
+        Parameters
+        ----------
+        required_keys : list(str) | None
+            List of additional columns of the metadata DataFrame that we should
+            groupby when creating ``info``.
+
+        Returns
+        -------
+        See class attributes.
+        """
+        keys = [k for k in ["subject", "session", "run"] if k in self.metadata.columns]
+        if not keys:
+            raise ValueError(
+                "metadata must contain at least one of the following columns: "
+                "subject, session or run."
+            )
+
+        if required_keys is not None:
+            missing_keys = [k for k in required_keys if k not in self.metadata.columns]
+            if len(missing_keys) > 0:
+                raise ValueError(f"Columns {missing_keys} were not found in metadata.")
+            keys += required_keys
+
+        metadata = metadata.reset_index().rename(columns={"index": "window_index"})
+        info = (
+            metadata.reset_index()
+            .groupby(keys)[["index", "i_start_in_trial"]]
+            .agg(["unique"])
+        )
+        info.columns = info.columns.get_level_values(0)
+
+        return info
+
+    def sample_recording(self):
+        """Return a random recording index."""
+        # XXX docstring missing
+        return self.rng.choice(self.n_recordings)
+
+    def sample_window(self, rec_ind=None):
+        """Return a specific window."""
+        # XXX docstring missing
+        if rec_ind is None:
+            rec_ind = self.sample_recording()
+        win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
+        return win_ind, rec_ind
+
+    def __iter__(self):
+        raise NotImplementedError
+
+    @property
+    def n_recordings(self):
+        return self.info.shape[0]
+
+
+class DistributedRecordingSampler(DistributedSampler):
+    """Base sampler simplifying sampling from recordings in distributed setting.
+
     Parameters
     ----------
     metadata : pd.DataFrame

@@ -41,11 +144,22 @@ class RecordingSampler(Sampler):
         quick sampling of windows.
     n_recordings : int
         Number of recordings available.
+    kwargs : dict
+        Additional keyword arguments to pass to torch DistributedSampler.
+        See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
     """
-    …
+
+    def __init__(
+        self,
+        metadata,
+        random_state=None,
+        **kwargs,
+    ):
         self.metadata = metadata
         self.info = self._init_info(metadata)
         self.rng = check_random_state(random_state)
+        # send information to DistributedSampler parent to handle data splitting among workers
+        super().__init__(self.info, seed=random_state, **kwargs)

     def _init_info(self, metadata, required_keys=None):
         """Initialize ``info`` DataFrame.

@@ -60,50 +174,48 @@ class RecordingSampler(Sampler):
         -------
         See class attributes.
         """
-        keys = [k for k in […
-                if k in self.metadata.columns]
+        keys = [k for k in ["subject", "session", "run"] if k in self.metadata.columns]
         if not keys:
             raise ValueError(
-                …
-                …
+                "metadata must contain at least one of the following columns: "
+                "subject, session or run."
+            )

         if required_keys is not None:
-            missing_keys = [
-                k for k in required_keys if k not in self.metadata.columns]
+            missing_keys = [k for k in required_keys if k not in self.metadata.columns]
             if len(missing_keys) > 0:
-                raise ValueError(
-                    f'Columns {missing_keys} were not found in metadata.')
+                raise ValueError(f"Columns {missing_keys} were not found in metadata.")
             keys += required_keys

-        metadata = metadata.reset_index().rename(…
-            …
-            …
-            […
+        metadata = metadata.reset_index().rename(columns={"index": "window_index"})
+        info = (
+            metadata.reset_index()
+            .groupby(keys)[["index", "i_start_in_trial"]]
+            .agg(["unique"])
+        )
         info.columns = info.columns.get_level_values(0)

         return info

     def sample_recording(self):
         """Return a random recording index.
+        super().__iter__() contains indices of datasets specific to the current process
+        determined by the DistributedSampler
         """
         # XXX docstring missing
-        return self.rng.choice(…
+        return self.rng.choice(list(super().__iter__()))

     def sample_window(self, rec_ind=None):
-        """Return a specific window.
-        """
+        """Return a specific window."""
         # XXX docstring missing
         if rec_ind is None:
             rec_ind = self.sample_recording()
-        win_ind = self.rng.choice(self.info.iloc[rec_ind][…
+        win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
         return win_ind, rec_ind

-    def __iter__(self):
-        raise NotImplementedError
-
     @property
     def n_recordings(self):
-        return …
+        return super().__len__()


 class SequenceSampler(RecordingSampler):

@@ -131,8 +243,10 @@ class SequenceSampler(RecordingSampler):
         Array of shape (n_sequences,) that indicates from which file each
         sequence comes from. Useful e.g. to do self-ensembling.
     """
-    …
-    …
+
+    def __init__(
+        self, metadata, n_windows, n_windows_stride, randomize=False, random_state=None
+    ):
         super().__init__(metadata, random_state=random_state)
         self.randomize = randomize
         self.n_windows = n_windows

@@ -152,8 +266,11 @@ class SequenceSampler(RecordingSampler):
         each sequence. Useful e.g. to do self-ensembling.
         """
         end_offset = 1 - self.n_windows if self.n_windows > 1 else None
-        start_inds = …
-            …
+        start_inds = (
+            self.info["index"]
+            .apply(lambda x: x[: end_offset : self.n_windows_stride])
+            .values
+        )
         file_ids = [[i] * len(inds) for i, inds in enumerate(start_inds)]
         return np.concatenate(start_inds), np.concatenate(file_ids)


@@ -200,12 +317,13 @@ class BalancedSequenceSampler(RecordingSampler):
         Med. 4, 72 (2021).
         https://github.com/perslev/U-Time/blob/master/utime/models/usleep.py
     """
+
     def __init__(self, metadata, n_windows, n_sequences=10, random_state=None):
         super().__init__(metadata, random_state=random_state)

         self.n_windows = n_windows
         self.n_sequences = n_sequences
-        self.info_class = self._init_info(metadata, required_keys=[…
+        self.info_class = self._init_info(metadata, required_keys=["target"])

     def sample_class(self, rec_ind=None):
         """Return a random class.

@@ -225,8 +343,7 @@ class BalancedSequenceSampler(RecordingSampler):
         """
         if rec_ind is None:
             rec_ind = self.sample_recording()
-        available_classes = self.info_class.loc[
-            self.info.iloc[rec_ind].name].index
+        available_classes = self.info_class.loc[self.info.iloc[rec_ind].name].index
         return self.rng.choice(available_classes), rec_ind

     def _sample_seq_start_ind(self, rec_ind=None, class_ind=None):

@@ -257,15 +374,14 @@ class BalancedSequenceSampler(RecordingSampler):
         if class_ind is None:
             class_ind, rec_ind = self.sample_class(rec_ind)

-        rec_inds = self.info.iloc[rec_ind][…
+        rec_inds = self.info.iloc[rec_ind]["index"]
         len_rec_inds = len(rec_inds)

         row = self.info.iloc[rec_ind].name
         if not isinstance(row, tuple):
             # Theres's only one category, e.g. "subject"
             row = tuple([row])
-        available_indices = self.info_class.loc[
-            row + tuple([class_ind]), 'index']
+        available_indices = self.info_class.loc[row + tuple([class_ind]), "index"]
         win_ind = self.rng.choice(available_indices)
         win_ind_in_rec = np.where(rec_inds == win_ind)[0][0]

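
`_init_info` is the heart of both samplers above: it groups the window-level metadata by whichever of subject/session/run columns are present and aggregates the unique window positions per recording. A pandas-only sketch of that grouping, using a toy metadata frame whose values are invented for illustration:

    # Mirrors the groupby/agg logic shown in _init_info above (pandas only).
    import pandas as pd

    metadata = pd.DataFrame({
        "i_window_in_trial": [0, 1, 2, 0, 1, 2],
        "i_start_in_trial": [0, 500, 1000, 0, 500, 1000],
        "i_stop_in_trial": [500, 1000, 1500, 500, 1000, 1500],
        "target": [-1] * 6,
        "subject": [4, 4, 4, 5, 5, 5],
        "session": ["session_T"] * 6,
        "run": ["run_0"] * 6,
    })

    keys = [k for k in ["subject", "session", "run"] if k in metadata.columns]
    metadata = metadata.reset_index().rename(columns={"index": "window_index"})
    info = (
        metadata.reset_index()
        .groupby(keys)[["index", "i_start_in_trial"]]
        .agg(["unique"])
    )
    info.columns = info.columns.get_level_values(0)

    # One row per (subject, session, run); the "index" column holds the array of
    # window positions belonging to that recording, ready for random sampling.
    print(info)
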
braindecode/samplers/ssl.py
CHANGED

@@ -3,12 +3,16 @@ Self-supervised learning samplers.
 """

 # Authors: Hubert Banville <hubert.jbanville@gmail.com>
+# Young Truong <dt.young112@gmail.com>
 #
 # License: BSD (3-clause)

+import warnings
+
 import numpy as np
+import torch.distributed as dist

-from . import RecordingSampler
+from . import DistributedRecordingSampler, RecordingSampler


 class RelativePositioningSampler(RecordingSampler):

@@ -45,8 +49,17 @@ class RelativePositioningSampler(RecordingSampler):
         signals with self-supervised learning.
         arXiv preprint arXiv:2007.16104.
     """
-    …
-    …
+
+    def __init__(
+        self,
+        metadata,
+        tau_pos,
+        tau_neg,
+        n_examples,
+        tau_max=None,
+        same_rec_neg=True,
+        random_state=None,
+    ):
         super().__init__(metadata, random_state=random_state)

         self.tau_pos = tau_pos

@@ -56,25 +69,153 @@ class RelativePositioningSampler(RecordingSampler):
         self.same_rec_neg = same_rec_neg

         if not same_rec_neg and self.n_recordings < 2:
-            raise ValueError(
-                …
+            raise ValueError(
+                "More than one recording must be available when "
+                "using across-recording negative sampling."
+            )

     def _sample_pair(self):
-        """Sample a pair of two windows.
+        """Sample a pair of two windows."""
+        # Sample first window
+        win_ind1, rec_ind1 = self.sample_window()
+        ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
+        ts = self.info.iloc[rec_ind1]["i_start_in_trial"]
+
+        # Decide whether the pair will be positive or negative
+        pair_type = self.rng.binomial(1, 0.5)
+        win_ind2 = None
+        if pair_type == 0:  # Negative example
+            if self.same_rec_neg:
+                mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
+                    (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
+                )
+            else:
+                rec_ind2 = rec_ind1
+                while rec_ind2 == rec_ind1:
+                    win_ind2, rec_ind2 = self.sample_window()
+        elif pair_type == 1:  # Positive example
+            mask = (ts >= ts1 - self.tau_pos) & (ts <= ts1 + self.tau_pos)
+
+        if win_ind2 is None:
+            mask[ts == ts1] = False  # same window cannot be sampled twice
+            if sum(mask) == 0:
+                raise NotImplementedError
+            win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])
+
+        return win_ind1, win_ind2, float(pair_type)
+
+    def presample(self):
+        """Presample examples.
+
+        Once presampled, the examples are the same from one epoch to another.
+        """
+        self.examples = [self._sample_pair() for _ in range(self.n_examples)]
+        return self
+
+    def __iter__(self):
+        """
+        Iterate over pairs.
+
+        Yields
+        ------
+        int
+            Position of the first window in the dataset.
+        int
+            Position of the second window in the dataset.
+        float
+            0 for a negative pair, 1 for a positive pair.
         """
+        for i in range(self.n_examples):
+            if hasattr(self, "examples"):
+                yield self.examples[i]
+            else:
+                yield self._sample_pair()
+
+    def __len__(self):
+        return self.n_examples
+
+
+class DistributedRelativePositioningSampler(DistributedRecordingSampler):
+    """Sample examples for the relative positioning task from [Banville2020]_ in distributed mode.
+
+    Sample examples as tuples of two window indices, with a label indicating
+    whether the windows are close or far, as defined by tau_pos and tau_neg.
+
+    Parameters
+    ----------
+    metadata : pd.DataFrame
+        See RecordingSampler.
+    tau_pos : int
+        Size of the positive context, in samples. A positive pair contains two
+        windows x1 and x2 which are separated by at most `tau_pos` samples.
+    tau_neg : int
+        Size of the negative context, in samples. A negative pair contains two
+        windows x1 and x2 which are separated by at least `tau_neg` samples and
+        at most `tau_max` samples. Ignored if `same_rec_neg` is False.
+    n_examples : int
+        Number of pairs to extract.
+    tau_max : int | None
+        See `tau_neg`.
+    same_rec_neg : bool
+        If True, sample negative pairs from within the same recording. If
+        False, sample negative pairs from two different recordings.
+    random_state : None | np.RandomState | int
+        Random state.
+    kwargs: dict
+        Additional keyword arguments to pass to torch DistributedSampler.
+        See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
+
+    References
+    ----------
+    .. [Banville2020] Banville, H., Chehab, O., Hyvärinen, A., Engemann, D. A.,
+        & Gramfort, A. (2020). Uncovering the structure of clinical EEG
+        signals with self-supervised learning.
+        arXiv preprint arXiv:2007.16104.
+    """
+
+    def __init__(
+        self,
+        metadata,
+        tau_pos,
+        tau_neg,
+        n_examples,
+        tau_max=None,
+        same_rec_neg=True,
+        random_state=None,
+        **kwargs,
+    ):
+        super().__init__(metadata, random_state=random_state, **kwargs)
+        self.tau_pos = tau_pos
+        self.tau_neg = tau_neg
+        self.tau_max = np.inf if tau_max is None else tau_max
+        self.same_rec_neg = same_rec_neg
+
+        self.n_examples = n_examples // self.info.shape[0] * self.n_recordings
+        warnings.warn(
+            f"Rank {dist.get_rank()} - Number of datasets: {self.n_recordings}"
+        )
+        warnings.warn(f"Rank {dist.get_rank()} - Number of samples: {self.n_examples}")
+
+        if not same_rec_neg and self.n_recordings < 2:
+            raise ValueError(
+                "More than one recording must be available when "
+                "using across-recording negative sampling."
+            )
+
+    def _sample_pair(self):
+        """Sample a pair of two windows."""
         # Sample first window
         win_ind1, rec_ind1 = self.sample_window()
-        ts1 = self.metadata.iloc[win_ind1][…
-        ts = self.info.iloc[rec_ind1][…
+        ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
+        ts = self.info.iloc[rec_ind1]["i_start_in_trial"]

         # Decide whether the pair will be positive or negative
         pair_type = self.rng.binomial(1, 0.5)
         win_ind2 = None
         if pair_type == 0:  # Negative example
             if self.same_rec_neg:
-                mask = (…
-                    (…
-                    ((ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max))
+                mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
+                    (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
                 )
             else:
                 rec_ind2 = rec_ind1

@@ -87,7 +228,7 @@ class RelativePositioningSampler(RecordingSampler):
             mask[ts == ts1] = False  # same window cannot be sampled twice
             if sum(mask) == 0:
                 raise NotImplementedError
-            win_ind2 = self.rng.choice(self.info.iloc[rec_ind1][…
+            win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])

         return win_ind1, win_ind2, float(pair_type)


@@ -100,16 +241,20 @@ class RelativePositioningSampler(RecordingSampler):
         return self

     def __iter__(self):
-        """
+        """
+        Iterate over pairs.

         Yields
         ------
-        …
-        …
-        …
+        int
+            Position of the first window in the dataset.
+        int
+            Position of the second window in the dataset.
+        float
+            0 for a negative pair, 1 for a positive pair.
         """
         for i in range(self.n_examples):
-            if hasattr(self,…
+            if hasattr(self, "examples"):
                 yield self.examples[i]
             else:
                 yield self._sample_pair()
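
A usage sketch for the non-distributed `RelativePositioningSampler`, whose constructor and `__iter__` are shown above. The toy metadata and the tau values are invented, and braindecode 1.0.0 is assumed to be installed; each iteration yields `(win_ind1, win_ind2, label)` exactly as returned by `_sample_pair`.

    # Sample relative-positioning pairs from a single toy recording.
    import pandas as pd

    from braindecode.samplers import RelativePositioningSampler

    n_windows, win_len = 60, 500
    metadata = pd.DataFrame({
        "i_window_in_trial": range(n_windows),
        "i_start_in_trial": [i * win_len for i in range(n_windows)],
        "i_stop_in_trial": [(i + 1) * win_len for i in range(n_windows)],
        "target": [-1] * n_windows,
        "subject": [1] * n_windows,  # at least one of subject/session/run is required
    })

    sampler = RelativePositioningSampler(
        metadata,
        tau_pos=2 * win_len,    # windows closer than this form positive pairs
        tau_neg=10 * win_len,   # windows farther than this form negative pairs
        n_examples=16,
        same_rec_neg=True,
        random_state=87,
    )

    for win_ind1, win_ind2, label in sampler:
        print(win_ind1, win_ind2, label)  # label: 1.0 = close pair, 0.0 = far pair
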
braindecode/training/__init__.py
CHANGED

@@ -2,14 +2,22 @@
 Functionality for skorch-based training.
 """

+from .losses import CroppedLoss, TimeSeriesLoss, mixup_criterion
+from .scoring import (
+    CroppedTimeSeriesEpochScoring,
+    CroppedTrialEpochScoring,
+    PostEpochTrainScoring,
+    predict_trials,
+    trial_preds_from_window_preds,
+)

-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
+__all__ = [
+    "CroppedLoss",
+    "mixup_criterion",
+    "TimeSeriesLoss",
+    "CroppedTrialEpochScoring",
+    "PostEpochTrainScoring",
+    "CroppedTimeSeriesEpochScoring",
+    "trial_preds_from_window_preds",
+    "predict_trials",
+]
braindecode/training/callbacks.py
CHANGED

@@ -2,8 +2,8 @@
 #
 # License: BSD (3-clause)

-from skorch.callbacks import Callback
 import torch
+from skorch.callbacks import Callback


 class MaxNormConstraintCallback(Callback):

@@ -20,6 +20,4 @@ class MaxNormConstraintCallback(Callback):
                     )
                     last_weight = module.weight
             if last_weight is not None:
-                last_weight.data = torch.renorm(
-                    last_weight.data, 2, 0, maxnorm=0.5
-                )
+                last_weight.data = torch.renorm(last_weight.data, 2, 0, maxnorm=0.5)
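
The callback above keeps its max-norm constraint but collapses the `torch.renorm` call onto one line. A standalone sketch of what that call does to a weight matrix (tensor values are arbitrary):

    # Clamp the L2 norm of each slice along dim 0 to 0.5, as in the callback above.
    import torch

    w = torch.randn(5, 10) * 3.0       # 5 output units, 10 inputs each
    print(w.norm(p=2, dim=1))          # row norms, typically well above 0.5

    w_clamped = torch.renorm(w, p=2, dim=0, maxnorm=0.5)
    print(w_clamped.norm(p=2, dim=1))  # every row norm is now capped at 0.5
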
braindecode/training/losses.py
CHANGED

@@ -96,15 +96,10 @@ def mixup_criterion(preds, target):
         # unpack target
         y_a, y_b, lam = target
         # compute loss per sample
-        loss_a = torch.nn.functional.nll_loss(preds,
-                                              y_a,
-                                              reduction='none')
-        loss_b = torch.nn.functional.nll_loss(preds,
-                                              y_b,
-                                              reduction='none')
+        loss_a = torch.nn.functional.nll_loss(preds, y_a, reduction="none")
+        loss_b = torch.nn.functional.nll_loss(preds, y_b, reduction="none")
         # compute weighted mean
         ret = torch.mul(lam, loss_a) + torch.mul(1 - lam, loss_b)
         return ret.mean()
     else:
-        return torch.nn.functional.nll_loss(preds,
-                                            target)
+        return torch.nn.functional.nll_loss(preds, target)
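
The reformatted mixup branch computes a per-sample negative log-likelihood against both label vectors and blends them with `lam`. A standalone sketch reproducing that arithmetic on toy tensors (shapes and values are arbitrary); calling `mixup_criterion(preds, (y_a, y_b, lam))` from the module above should give the same value.

    # Reproduce the mixup loss arithmetic shown in the diff above.
    import torch

    torch.manual_seed(0)
    n, n_classes = 4, 3
    preds = torch.log_softmax(torch.randn(n, n_classes), dim=1)  # log-probabilities
    y_a = torch.randint(0, n_classes, (n,))
    y_b = torch.randint(0, n_classes, (n,))
    lam = torch.rand(n)  # per-sample mixing coefficients

    loss_a = torch.nn.functional.nll_loss(preds, y_a, reduction="none")
    loss_b = torch.nn.functional.nll_loss(preds, y_b, reduction="none")
    loss = (torch.mul(lam, loss_a) + torch.mul(1 - lam, loss_b)).mean()
    print(loss)
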