braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of braindecode might be problematic.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/regressor.py
CHANGED
@@ -11,8 +11,8 @@ import warnings
  11   11    import numpy as np
  12   12    from skorch.regressor import NeuralNetRegressor
  13   13
  14       -  from .training.scoring import predict_trials
  15   14    from .eegneuralnet import _EEGNeuralNet
       15 +  from .training.scoring import predict_trials
  16   16    from .util import ThrowAwayIndexLoader, update_estimator_docstring
  17   17
  18   18

@@ -58,19 +58,28 @@ class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
  58   58        """ # noqa: E501
  59   59        __doc__ = update_estimator_docstring(NeuralNetRegressor, doc)
  60   60
  61       -      def __init__(
  62       -
  63       -
  64       -
       61 +      def __init__(
       62 +          self,
       63 +          module,
       64 +          *args,
       65 +          cropped=False,
       66 +          callbacks=None,
       67 +          iterator_train__shuffle=True,
       68 +          iterator_train__drop_last=True,
       69 +          aggregate_predictions=True,
       70 +          **kwargs,
       71 +      ):
  65   72            self.cropped = cropped
  66   73            self.aggregate_predictions = aggregate_predictions
  67   74            self._last_window_inds_ = None
  68       -          super().__init__(
  69       -
  70       -
  71       -
  72       -
  73       -
       75 +          super().__init__(
       76 +              module,
       77 +              *args,
       78 +              callbacks=callbacks,
       79 +              iterator_train__shuffle=iterator_train__shuffle,
       80 +              iterator_train__drop_last=iterator_train__drop_last,
       81 +              **kwargs,
       82 +          )
  74   83
  75   84        def get_iterator(self, dataset, training=False, drop_index=True):
  76   85            iterator = super().get_iterator(dataset, training=training)

@@ -155,7 +164,9 @@ class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
 155  164            warnings.warn(
 156  165                "This method was designed to predict trials in cropped mode. "
 157  166                "Calling it when cropped is False will give the same result as "
 158       -               "'.predict'.",
      167 +               "'.predict'.",
      168 +               UserWarning,
      169 +           )
 159  170            preds = self.predict(X)
 160  171            if return_targets:
 161  172                return preds, np.concatenate([X[i][1] for i in range(len(X))])
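To put the reformatted constructor in context, here is a minimal usage sketch. The ShallowFBCSPNet shapes, the optimizer choice and the batch size are illustrative assumptions, not values taken from this diff:

    import torch
    from braindecode import EEGRegressor
    from braindecode.models import ShallowFBCSPNet

    # Hypothetical module: 22 channels, 1 regression output, 1125-sample windows.
    module = ShallowFBCSPNet(n_chans=22, n_outputs=1, n_times=1125,
                             final_conv_length="auto")

    # The refactored __init__ is a formatting change only: module, callbacks and
    # the iterator settings are still forwarded to skorch's NeuralNetRegressor.
    regressor = EEGRegressor(
        module,
        criterion=torch.nn.MSELoss,
        optimizer=torch.optim.Adam,
        cropped=False,
        batch_size=32,
        train_split=None,  # plug in your own validation split if needed
    )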
braindecode/samplers/__init__.py
CHANGED
@@ -1,10 +1,18 @@
   1       -  """Classes to sample examples.
   2       -  """
        1 +  """Classes to sample examples."""
   3    2
   4       -  from .base import
   5       -
        3 +  from .base import (
        4 +      BalancedSequenceSampler,
        5 +      DistributedRecordingSampler,
        6 +      RecordingSampler,
        7 +      SequenceSampler,
        8 +  )
        9 +  from .ssl import DistributedRelativePositioningSampler, RelativePositioningSampler
   6   10
   7       -  __all__ = [
   8       -
   9       -
  10       -
       11 +  __all__ = [
       12 +      "RecordingSampler",
       13 +      "SequenceSampler",
       14 +      "BalancedSequenceSampler",
       15 +      "RelativePositioningSampler",
       16 +      "DistributedRecordingSampler",
       17 +      "DistributedRelativePositioningSampler",
       18 +  ]
braindecode/samplers/base.py
CHANGED
@@ -4,17 +4,118 @@ Sampler classes.
   4    4
   5    5    # Authors: Hubert Banville <hubert.jbanville@gmail.com>
   6    6    # Theo Gnassounou <>
        7 +  # Young Truong <dt.young112@gmail.com>
   7    8    #
   8    9    # License: BSD (3-clause)
   9   10
  10   11    import numpy as np
  11       -  from torch.utils.data.sampler import Sampler
  12   12    from sklearn.utils import check_random_state
       13 +  from torch.utils.data.distributed import DistributedSampler
       14 +  from torch.utils.data.sampler import Sampler
  13   15
  14   16
  15   17    class RecordingSampler(Sampler):
  16   18        """Base sampler simplifying sampling from recordings.
  17   19
       20 +      Parameters
       21 +      ----------
       22 +      metadata : pd.DataFrame
       23 +          DataFrame with at least one of {subject, session, run} columns for each
       24 +          window in the BaseConcatDataset to sample examples from. Normally
       25 +          obtained with `BaseConcatDataset.get_metadata()`. For instance,
       26 +          `metadata.head()` might look like this:
       27 +          +-------------------+------------------+-----------------+--------+---------+-----------+-------+
       28 +          | i_window_in_trial | i_start_in_trial | i_stop_in_trial | target | subject | session   | run   |
       29 +          +===================+==================+=================+========+=========+===========+=======+
       30 +          | 0                 | 0                | 500             | -1     | 4       | session_T | run_0 |
       31 +          +-------------------+------------------+-----------------+--------+---------+-----------+-------+
       32 +          | 1                 | 500              | 1000            | -1     | 4       | session_T | run_0 |
       33 +          +-------------------+------------------+-----------------+--------+---------+-----------+-------+
       34 +          | 2                 | 1000             | 1500            | -1     | 4       | session_T | run_0 |
       35 +          +-------------------+------------------+-----------------+--------+---------+-----------+-------+
       36 +          | 3                 | 1500             | 2000            | -1     | 4       | session_T | run_0 |
       37 +          +-------------------+------------------+-----------------+--------+---------+-----------+-------+
       38 +          | 4                 | 2000             | 2500            | -1     | 4       | session_T | run_0 |
       39 +          +-------------------+------------------+-----------------+--------+---------+-----------+-------+
       40 +
       41 +      random_state : np.RandomState | int | None
       42 +          Random state.
       43 +
       44 +      Attributes
       45 +      ----------
       46 +      info : pd.DataFrame
       47 +          Series with MultiIndex index which contains the subject, session, run
       48 +          and window indices information in an easily accessible structure for
       49 +          quick sampling of windows.
       50 +      n_recordings : int
       51 +          Number of recordings available.
       52 +      """
       53 +
       54 +      def __init__(self, metadata, random_state=None):
       55 +          self.metadata = metadata
       56 +          self.info = self._init_info(metadata)
       57 +          self.rng = check_random_state(random_state)
       58 +
       59 +      def _init_info(self, metadata, required_keys=None):
       60 +          """Initialize ``info`` DataFrame.
       61 +
       62 +          Parameters
       63 +          ----------
       64 +          required_keys : list(str) | None
       65 +              List of additional columns of the metadata DataFrame that we should
       66 +              groupby when creating ``info``.
       67 +
       68 +          Returns
       69 +          -------
       70 +          See class attributes.
       71 +          """
       72 +          keys = [k for k in ["subject", "session", "run"] if k in self.metadata.columns]
       73 +          if not keys:
       74 +              raise ValueError(
       75 +                  "metadata must contain at least one of the following columns: "
       76 +                  "subject, session or run."
       77 +              )
       78 +
       79 +          if required_keys is not None:
       80 +              missing_keys = [k for k in required_keys if k not in self.metadata.columns]
       81 +              if len(missing_keys) > 0:
       82 +                  raise ValueError(f"Columns {missing_keys} were not found in metadata.")
       83 +              keys += required_keys
       84 +
       85 +          metadata = metadata.reset_index().rename(columns={"index": "window_index"})
       86 +          info = (
       87 +              metadata.reset_index()
       88 +              .groupby(keys)[["index", "i_start_in_trial"]]
       89 +              .agg(["unique"])
       90 +          )
       91 +          info.columns = info.columns.get_level_values(0)
       92 +
       93 +          return info
       94 +
       95 +      def sample_recording(self):
       96 +          """Return a random recording index."""
       97 +          # XXX docstring missing
       98 +          return self.rng.choice(self.n_recordings)
       99 +
      100 +      def sample_window(self, rec_ind=None):
      101 +          """Return a specific window."""
      102 +          # XXX docstring missing
      103 +          if rec_ind is None:
      104 +              rec_ind = self.sample_recording()
      105 +          win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
      106 +          return win_ind, rec_ind
      107 +
      108 +      def __iter__(self):
      109 +          raise NotImplementedError
      110 +
      111 +      @property
      112 +      def n_recordings(self):
      113 +          return self.info.shape[0]
      114 +
      115 +
      116 +  class DistributedRecordingSampler(DistributedSampler):
      117 +      """Base sampler simplifying sampling from recordings in distributed setting.
      118 +
  18  119        Parameters
  19  120        ----------
  20  121        metadata : pd.DataFrame

@@ -41,11 +142,22 @@ class RecordingSampler(Sampler):
  41  142            quick sampling of windows.
  42  143        n_recordings : int
  43  144            Number of recordings available.
      145 +      kwargs : dict
      146 +          Additional keyword arguments to pass to torch DistributedSampler.
      147 +          See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
  44  148        """
  45       -
      149 +
      150 +      def __init__(
      151 +          self,
      152 +          metadata,
      153 +          random_state=None,
      154 +          **kwargs,
      155 +      ):
  46  156            self.metadata = metadata
  47  157            self.info = self._init_info(metadata)
  48  158            self.rng = check_random_state(random_state)
      159 +          # send information to DistributedSampler parent to handle data splitting among workers
      160 +          super().__init__(self.info, seed=random_state, **kwargs)
  49  161
  50  162        def _init_info(self, metadata, required_keys=None):
  51  163            """Initialize ``info`` DataFrame.

@@ -60,50 +172,48 @@ class RecordingSampler(Sampler):
  60  172            -------
  61  173            See class attributes.
  62  174            """
  63       -          keys = [k for k in [
  64       -              if k in self.metadata.columns]
      175 +          keys = [k for k in ["subject", "session", "run"] if k in self.metadata.columns]
  65  176            if not keys:
  66  177                raise ValueError(
  67       -
  68       -
      178 +                  "metadata must contain at least one of the following columns: "
      179 +                  "subject, session or run."
      180 +              )
  69  181
  70  182            if required_keys is not None:
  71       -              missing_keys = [
  72       -                  k for k in required_keys if k not in self.metadata.columns]
      183 +              missing_keys = [k for k in required_keys if k not in self.metadata.columns]
  73  184                if len(missing_keys) > 0:
  74       -                  raise ValueError(
  75       -                      f'Columns {missing_keys} were not found in metadata.')
      185 +                  raise ValueError(f"Columns {missing_keys} were not found in metadata.")
  76  186                keys += required_keys
  77  187
  78       -          metadata = metadata.reset_index().rename(
  79       -
  80       -
  81       -              [
      188 +          metadata = metadata.reset_index().rename(columns={"index": "window_index"})
      189 +          info = (
      190 +              metadata.reset_index()
      191 +              .groupby(keys)[["index", "i_start_in_trial"]]
      192 +              .agg(["unique"])
      193 +          )
  82  194            info.columns = info.columns.get_level_values(0)
  83  195
  84  196            return info
  85  197
  86  198        def sample_recording(self):
  87  199            """Return a random recording index.
      200 +          super().__iter__() contains indices of datasets specific to the current process
      201 +          determined by the DistributedSampler
  88  202            """
  89  203            # XXX docstring missing
  90       -          return self.rng.choice(
      204 +          return self.rng.choice(list(super().__iter__()))
  91  205
  92  206        def sample_window(self, rec_ind=None):
  93       -          """Return a specific window.
  94       -          """
      207 +          """Return a specific window."""
  95  208            # XXX docstring missing
  96  209            if rec_ind is None:
  97  210                rec_ind = self.sample_recording()
  98       -          win_ind = self.rng.choice(self.info.iloc[rec_ind][
      211 +          win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
  99  212            return win_ind, rec_ind
 100  213
 101       -      def __iter__(self):
 102       -          raise NotImplementedError
 103       -
 104  214        @property
 105  215        def n_recordings(self):
 106       -          return
      216 +          return super().__len__()
 107  217
 108  218
 109  219    class SequenceSampler(RecordingSampler):

@@ -131,8 +241,10 @@ class SequenceSampler(RecordingSampler):
 131  241            Array of shape (n_sequences,) that indicates from which file each
 132  242            sequence comes from. Useful e.g. to do self-ensembling.
 133  243        """
 134       -
 135       -
      244 +
      245 +      def __init__(
      246 +          self, metadata, n_windows, n_windows_stride, randomize=False, random_state=None
      247 +      ):
 136  248            super().__init__(metadata, random_state=random_state)
 137  249            self.randomize = randomize
 138  250            self.n_windows = n_windows

@@ -152,8 +264,11 @@ class SequenceSampler(RecordingSampler):
 152  264            each sequence. Useful e.g. to do self-ensembling.
 153  265            """
 154  266            end_offset = 1 - self.n_windows if self.n_windows > 1 else None
 155       -          start_inds =
 156       -
      267 +          start_inds = (
      268 +              self.info["index"]
      269 +              .apply(lambda x: x[: end_offset : self.n_windows_stride])
      270 +              .values
      271 +          )
 157  272            file_ids = [[i] * len(inds) for i, inds in enumerate(start_inds)]
 158  273            return np.concatenate(start_inds), np.concatenate(file_ids)
 159  274

@@ -200,12 +315,13 @@ class BalancedSequenceSampler(RecordingSampler):
 200  315            Med. 4, 72 (2021).
 201  316            https://github.com/perslev/U-Time/blob/master/utime/models/usleep.py
 202  317        """
      318 +
 203  319        def __init__(self, metadata, n_windows, n_sequences=10, random_state=None):
 204  320            super().__init__(metadata, random_state=random_state)
 205  321
 206  322            self.n_windows = n_windows
 207  323            self.n_sequences = n_sequences
 208       -          self.info_class = self._init_info(metadata, required_keys=[
      324 +          self.info_class = self._init_info(metadata, required_keys=["target"])
 209  325
 210  326        def sample_class(self, rec_ind=None):
 211  327            """Return a random class.

@@ -225,8 +341,7 @@ class BalancedSequenceSampler(RecordingSampler):
 225  341            """
 226  342            if rec_ind is None:
 227  343                rec_ind = self.sample_recording()
 228       -          available_classes = self.info_class.loc[
 229       -              self.info.iloc[rec_ind].name].index
      344 +          available_classes = self.info_class.loc[self.info.iloc[rec_ind].name].index
 230  345            return self.rng.choice(available_classes), rec_ind
 231  346
 232  347        def _sample_seq_start_ind(self, rec_ind=None, class_ind=None):

@@ -257,15 +372,14 @@ class BalancedSequenceSampler(RecordingSampler):
 257  372            if class_ind is None:
 258  373                class_ind, rec_ind = self.sample_class(rec_ind)
 259  374
 260       -          rec_inds = self.info.iloc[rec_ind][
      375 +          rec_inds = self.info.iloc[rec_ind]["index"]
 261  376            len_rec_inds = len(rec_inds)
 262  377
 263  378            row = self.info.iloc[rec_ind].name
 264  379            if not isinstance(row, tuple):
 265  380                # Theres's only one category, e.g. "subject"
 266  381                row = tuple([row])
 267       -          available_indices = self.info_class.loc[
 268       -              row + tuple([class_ind]), 'index']
      382 +          available_indices = self.info_class.loc[row + tuple([class_ind]), "index"]
 269  383            win_ind = self.rng.choice(available_indices)
 270  384            win_ind_in_rec = np.where(rec_inds == win_ind)[0][0]
 271  385
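For reference, a minimal sketch of how these samplers are typically fed. It assumes an existing windowed BaseConcatDataset named windows_ds; the seed is arbitrary:

    from braindecode.samplers import RecordingSampler

    # One metadata row per window, with subject/session/run columns whenever
    # they are present in the dataset description.
    metadata = windows_ds.get_metadata()

    sampler = RecordingSampler(metadata, random_state=87)
    win_ind, rec_ind = sampler.sample_window()  # random window of a random recording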
braindecode/samplers/ssl.py
CHANGED
@@ -3,12 +3,16 @@ Self-supervised learning samplers.
   3    3    """
   4    4
   5    5    # Authors: Hubert Banville <hubert.jbanville@gmail.com>
        6 +  # Young Truong <dt.young112@gmail.com>
   6    7    #
   7    8    # License: BSD (3-clause)
   8    9
       10 +  import warnings
       11 +
   9   12    import numpy as np
       13 +  import torch.distributed as dist
  10   14
  11       -  from . import RecordingSampler
       15 +  from . import DistributedRecordingSampler, RecordingSampler
  12   16
  13   17
  14   18    class RelativePositioningSampler(RecordingSampler):

@@ -45,8 +49,17 @@ class RelativePositioningSampler(RecordingSampler):
  45   49            signals with self-supervised learning.
  46   50            arXiv preprint arXiv:2007.16104.
  47   51        """
  48       -
  49       -
       52 +
       53 +      def __init__(
       54 +          self,
       55 +          metadata,
       56 +          tau_pos,
       57 +          tau_neg,
       58 +          n_examples,
       59 +          tau_max=None,
       60 +          same_rec_neg=True,
       61 +          random_state=None,
       62 +      ):
  50   63            super().__init__(metadata, random_state=random_state)
  51   64
  52   65            self.tau_pos = tau_pos

@@ -56,25 +69,153 @@ class RelativePositioningSampler(RecordingSampler):
  56   69            self.same_rec_neg = same_rec_neg
  57   70
  58   71            if not same_rec_neg and self.n_recordings < 2:
  59       -              raise ValueError(
  60       -
       72 +              raise ValueError(
       73 +                  "More than one recording must be available when "
       74 +                  "using across-recording negative sampling."
       75 +              )
  61   76
  62   77        def _sample_pair(self):
  63       -          """Sample a pair of two windows.
       78 +          """Sample a pair of two windows."""
       79 +          # Sample first window
       80 +          win_ind1, rec_ind1 = self.sample_window()
       81 +          ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
       82 +          ts = self.info.iloc[rec_ind1]["i_start_in_trial"]
       83 +
       84 +          # Decide whether the pair will be positive or negative
       85 +          pair_type = self.rng.binomial(1, 0.5)
       86 +          win_ind2 = None
       87 +          if pair_type == 0:  # Negative example
       88 +              if self.same_rec_neg:
       89 +                  mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
       90 +                      (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
       91 +                  )
       92 +              else:
       93 +                  rec_ind2 = rec_ind1
       94 +                  while rec_ind2 == rec_ind1:
       95 +                      win_ind2, rec_ind2 = self.sample_window()
       96 +          elif pair_type == 1:  # Positive example
       97 +              mask = (ts >= ts1 - self.tau_pos) & (ts <= ts1 + self.tau_pos)
       98 +
       99 +          if win_ind2 is None:
      100 +              mask[ts == ts1] = False  # same window cannot be sampled twice
      101 +              if sum(mask) == 0:
      102 +                  raise NotImplementedError
      103 +              win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])
      104 +
      105 +          return win_ind1, win_ind2, float(pair_type)
      106 +
      107 +      def presample(self):
      108 +          """Presample examples.
      109 +
      110 +          Once presampled, the examples are the same from one epoch to another.
      111 +          """
      112 +          self.examples = [self._sample_pair() for _ in range(self.n_examples)]
      113 +          return self
      114 +
      115 +      def __iter__(self):
      116 +          """
      117 +          Iterate over pairs.
      118 +
      119 +          Yields
      120 +          ------
      121 +          int
      122 +              Position of the first window in the dataset.
      123 +          int
      124 +              Position of the second window in the dataset.
      125 +          float
      126 +              0 for a negative pair, 1 for a positive pair.
  64  127            """
      128 +          for i in range(self.n_examples):
      129 +              if hasattr(self, "examples"):
      130 +                  yield self.examples[i]
      131 +              else:
      132 +                  yield self._sample_pair()
      133 +
      134 +      def __len__(self):
      135 +          return self.n_examples
      136 +
      137 +
      138 +  class DistributedRelativePositioningSampler(DistributedRecordingSampler):
      139 +      """Sample examples for the relative positioning task from [Banville2020]_ in distributed mode.
      140 +
      141 +      Sample examples as tuples of two window indices, with a label indicating
      142 +      whether the windows are close or far, as defined by tau_pos and tau_neg.
      143 +
      144 +      Parameters
      145 +      ----------
      146 +      metadata : pd.DataFrame
      147 +          See RecordingSampler.
      148 +      tau_pos : int
      149 +          Size of the positive context, in samples. A positive pair contains two
      150 +          windows x1 and x2 which are separated by at most `tau_pos` samples.
      151 +      tau_neg : int
      152 +          Size of the negative context, in samples. A negative pair contains two
      153 +          windows x1 and x2 which are separated by at least `tau_neg` samples and
      154 +          at most `tau_max` samples. Ignored if `same_rec_neg` is False.
      155 +      n_examples : int
      156 +          Number of pairs to extract.
      157 +      tau_max : int | None
      158 +          See `tau_neg`.
      159 +      same_rec_neg : bool
      160 +          If True, sample negative pairs from within the same recording. If
      161 +          False, sample negative pairs from two different recordings.
      162 +      random_state : None | np.RandomState | int
      163 +          Random state.
      164 +      kwargs: dict
      165 +          Additional keyword arguments to pass to torch DistributedSampler.
      166 +          See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
      167 +
      168 +      References
      169 +      ----------
      170 +      .. [Banville2020] Banville, H., Chehab, O., Hyvärinen, A., Engemann, D. A.,
      171 +          & Gramfort, A. (2020). Uncovering the structure of clinical EEG
      172 +          signals with self-supervised learning.
      173 +          arXiv preprint arXiv:2007.16104.
      174 +      """
      175 +
      176 +      def __init__(
      177 +          self,
      178 +          metadata,
      179 +          tau_pos,
      180 +          tau_neg,
      181 +          n_examples,
      182 +          tau_max=None,
      183 +          same_rec_neg=True,
      184 +          random_state=None,
      185 +          **kwargs,
      186 +      ):
      187 +          super().__init__(metadata, random_state=random_state, **kwargs)
      188 +          self.tau_pos = tau_pos
      189 +          self.tau_neg = tau_neg
      190 +          self.tau_max = np.inf if tau_max is None else tau_max
      191 +          self.same_rec_neg = same_rec_neg
      192 +
      193 +          self.n_examples = n_examples // self.info.shape[0] * self.n_recordings
      194 +          warnings.warn(
      195 +              f"Rank {dist.get_rank()} - Number of datasets: {self.n_recordings}"
      196 +          )
      197 +          warnings.warn(f"Rank {dist.get_rank()} - Number of samples: {self.n_examples}")
      198 +
      199 +          if not same_rec_neg and self.n_recordings < 2:
      200 +              raise ValueError(
      201 +                  "More than one recording must be available when "
      202 +                  "using across-recording negative sampling."
      203 +              )
      204 +
      205 +      def _sample_pair(self):
      206 +          """Sample a pair of two windows."""
  65  207            # Sample first window
  66  208            win_ind1, rec_ind1 = self.sample_window()
  67       -          ts1 = self.metadata.iloc[win_ind1][
  68       -          ts = self.info.iloc[rec_ind1][
      209 +          ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
      210 +          ts = self.info.iloc[rec_ind1]["i_start_in_trial"]
  69  211
  70  212            # Decide whether the pair will be positive or negative
  71  213            pair_type = self.rng.binomial(1, 0.5)
  72  214            win_ind2 = None
  73  215            if pair_type == 0:  # Negative example
  74  216                if self.same_rec_neg:
  75       -                  mask = (
  76       -                      (
  77       -                      ((ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max))
      217 +                  mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
      218 +                      (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
  78  219                    )
  79  220                else:
  80  221                    rec_ind2 = rec_ind1

@@ -87,7 +228,7 @@ class RelativePositioningSampler(RecordingSampler):
  87  228                mask[ts == ts1] = False  # same window cannot be sampled twice
  88  229                if sum(mask) == 0:
  89  230                    raise NotImplementedError
  90       -              win_ind2 = self.rng.choice(self.info.iloc[rec_ind1][
      231 +              win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])
  91  232
  92  233            return win_ind1, win_ind2, float(pair_type)
  93  234

@@ -100,16 +241,20 @@ class RelativePositioningSampler(RecordingSampler):
 100  241            return self
 101  242
 102  243        def __iter__(self):
 103       -          """
      244 +          """
      245 +          Iterate over pairs.
 104  246
 105  247            Yields
 106  248            ------
 107       -
 108       -
 109       -
      249 +          int
      250 +              Position of the first window in the dataset.
      251 +          int
      252 +              Position of the second window in the dataset.
      253 +          float
      254 +              0 for a negative pair, 1 for a positive pair.
 110  255            """
 111  256            for i in range(self.n_examples):
 112       -              if hasattr(self,
      257 +              if hasattr(self, "examples"):
 113  258                    yield self.examples[i]
 114  259                else:
 115  260                    yield self._sample_pair()
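For orientation, a minimal sketch of wiring the relative-positioning sampler to a windowed dataset. The dataset windows_ds, the tau values (in samples) and the number of pairs are illustrative assumptions:

    from braindecode.samplers import RelativePositioningSampler

    metadata = windows_ds.get_metadata()  # one row per window

    # Windows closer than tau_pos samples form positive pairs (label 1.0);
    # windows further apart than tau_neg samples form negative pairs (label 0.0).
    sampler = RelativePositioningSampler(
        metadata,
        tau_pos=2000,
        tau_neg=10000,
        n_examples=250 * len(windows_ds.datasets),
        same_rec_neg=True,
        random_state=87,
    )

    for win_ind1, win_ind2, pair_label in sampler:
        break  # two indices into windows_ds plus the 0/1 pair label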
braindecode/training/__init__.py
CHANGED
@@ -2,14 +2,22 @@
   2    2    Functionality for skorch-based training.
   3    3    """
   4    4
        5 +  from .losses import CroppedLoss, TimeSeriesLoss, mixup_criterion
        6 +  from .scoring import (
        7 +      CroppedTimeSeriesEpochScoring,
        8 +      CroppedTrialEpochScoring,
        9 +      PostEpochTrainScoring,
       10 +      predict_trials,
       11 +      trial_preds_from_window_preds,
       12 +  )
   5   13
   6       -
   7       -
   8       -
   9       -
  10       -
  11       -
  12       -
  13       -
  14       -
  15       -
       14 +  __all__ = [
       15 +      "CroppedLoss",
       16 +      "mixup_criterion",
       17 +      "TimeSeriesLoss",
       18 +      "CroppedTrialEpochScoring",
       19 +      "PostEpochTrainScoring",
       20 +      "CroppedTimeSeriesEpochScoring",
       21 +      "trial_preds_from_window_preds",
       22 +      "predict_trials",
       23 +  ]
braindecode/training/callbacks.py
CHANGED
@@ -2,8 +2,8 @@
   2    2    #
   3    3    # License: BSD (3-clause)
   4    4
   5       -  from skorch.callbacks import Callback
   6    5    import torch
        6 +  from skorch.callbacks import Callback
   7    7
   8    8
   9    9    class MaxNormConstraintCallback(Callback):

@@ -20,6 +20,4 @@ class MaxNormConstraintCallback(Callback):
  20   20                    )
  21   21                    last_weight = module.weight
  22   22            if last_weight is not None:
  23       -              last_weight.data = torch.renorm(
  24       -                  last_weight.data, 2, 0, maxnorm=0.5
  25       -              )
       23 +              last_weight.data = torch.renorm(last_weight.data, 2, 0, maxnorm=0.5)
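As a side note on what the collapsed call does, a standalone torch sketch (not braindecode API) of the max-norm constraint applied above:

    import torch

    w = torch.randn(4, 10) * 3                    # row norms well above 0.5
    w = torch.renorm(w, p=2, dim=0, maxnorm=0.5)  # clamp each row's L2 norm
    print(w.norm(p=2, dim=1))                     # every row norm is now <= 0.5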
braindecode/training/losses.py
CHANGED
@@ -96,15 +96,10 @@ def mixup_criterion(preds, target):
  96   96            # unpack target
  97   97            y_a, y_b, lam = target
  98   98            # compute loss per sample
  99       -          loss_a = torch.nn.functional.nll_loss(preds,
 100       -
 101       -                                                reduction='none')
 102       -          loss_b = torch.nn.functional.nll_loss(preds,
 103       -                                                y_b,
 104       -                                                reduction='none')
       99 +          loss_a = torch.nn.functional.nll_loss(preds, y_a, reduction="none")
      100 +          loss_b = torch.nn.functional.nll_loss(preds, y_b, reduction="none")
 105  101            # compute weighted mean
 106  102            ret = torch.mul(lam, loss_a) + torch.mul(1 - lam, loss_b)
 107  103            return ret.mean()
 108  104        else:
 109       -          return torch.nn.functional.nll_loss(preds,
 110       -              target)
      105 +          return torch.nn.functional.nll_loss(preds, target)
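To illustrate the two call patterns handled by mixup_criterion (shapes and the mixing coefficients are illustrative; preds are expected to be log-probabilities since nll_loss is used):

    import torch
    from braindecode.training import mixup_criterion

    preds = torch.log_softmax(torch.randn(8, 4), dim=1)  # 8 samples, 4 classes
    y = torch.randint(0, 4, (8,))

    # Plain integer labels: the function falls back to ordinary NLL loss.
    loss_plain = mixup_criterion(preds, y)

    # Mixup target (labels_a, labels_b, lam): the two per-sample losses are
    # weighted by lam and 1 - lam, then averaged.
    y_b = torch.randint(0, 4, (8,))
    lam = torch.full((8,), 0.3)
    loss_mixup = mixup_criterion(preds, (y, y_b, lam))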