braindecode 0.8-py3-none-any.whl → 1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/regressor.py
CHANGED
@@ -11,8 +11,8 @@ import warnings
 import numpy as np
 from skorch.regressor import NeuralNetRegressor
 
-from .training.scoring import predict_trials
 from .eegneuralnet import _EEGNeuralNet
+from .training.scoring import predict_trials
 from .util import ThrowAwayIndexLoader, update_estimator_docstring
 
 

@@ -58,19 +58,28 @@ class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
     """  # noqa: E501
     __doc__ = update_estimator_docstring(NeuralNetRegressor, doc)
 
-    def __init__(
-
-
-
+    def __init__(
+        self,
+        module,
+        *args,
+        cropped=False,
+        callbacks=None,
+        iterator_train__shuffle=True,
+        iterator_train__drop_last=True,
+        aggregate_predictions=True,
+        **kwargs,
+    ):
         self.cropped = cropped
         self.aggregate_predictions = aggregate_predictions
         self._last_window_inds_ = None
-        super().__init__(
-
-
-
-
-
+        super().__init__(
+            module,
+            *args,
+            callbacks=callbacks,
+            iterator_train__shuffle=iterator_train__shuffle,
+            iterator_train__drop_last=iterator_train__drop_last,
+            **kwargs,
+        )
 
     def get_iterator(self, dataset, training=False, drop_index=True):
         iterator = super().get_iterator(dataset, training=training)

@@ -155,7 +164,9 @@ class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
         warnings.warn(
             "This method was designed to predict trials in cropped mode. "
             "Calling it when cropped is False will give the same result as "
-            "'.predict'.",
+            "'.predict'.",
+            UserWarning,
+        )
         preds = self.predict(X)
         if return_targets:
             return preds, np.concatenate([X[i][1] for i in range(len(X))])
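As a quick orientation for the signature shown in this hunk, here is a minimal usage sketch. The model choice, input shape and hyperparameter values are illustrative assumptions rather than values taken from the release; the extra keyword arguments are simply forwarded to skorch's NeuralNetRegressor.

# Illustrative sketch only: model, shapes and hyperparameters are assumptions.
import torch

from braindecode.models import ShallowFBCSPNet
from braindecode.regressor import EEGRegressor

module = ShallowFBCSPNet(n_chans=22, n_outputs=1, n_times=1000)

regressor = EEGRegressor(
    module,
    cropped=False,                  # trial-wise (non-cropped) decoding
    criterion=torch.nn.MSELoss,     # extra kwargs pass through to skorch
    optimizer=torch.optim.Adam,
    batch_size=64,
    train_split=None,
    aggregate_predictions=True,     # aggregate window outputs per trial when cropped
)
# regressor.fit(X, y) / regressor.predict(X) then behave like any skorch regressor.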
braindecode/samplers/__init__.py
ADDED

@@ -0,0 +1,18 @@
+"""Classes to sample examples."""
+
+from .base import (
+    BalancedSequenceSampler,
+    DistributedRecordingSampler,
+    RecordingSampler,
+    SequenceSampler,
+)
+from .ssl import DistributedRelativePositioningSampler, RelativePositioningSampler
+
+__all__ = [
+    "RecordingSampler",
+    "SequenceSampler",
+    "BalancedSequenceSampler",
+    "RelativePositioningSampler",
+    "DistributedRecordingSampler",
+    "DistributedRelativePositioningSampler",
+]
braindecode/samplers/base.py
ADDED

@@ -0,0 +1,401 @@
+"""
+Sampler classes.
+"""
+
+# Authors: Hubert Banville <hubert.jbanville@gmail.com>
+#          Theo Gnassounou <>
+#          Young Truong <dt.young112@gmail.com>
+#
+# License: BSD (3-clause)
+
+from typing import Optional
+
+import numpy as np
+from sklearn.utils import check_random_state
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import Sampler
+
+
+class RecordingSampler(Sampler):
+    """Base sampler simplifying sampling from recordings.
+
+    Parameters
+    ----------
+    metadata : pd.DataFrame
+        DataFrame with at least one of {subject, session, run} columns for each
+        window in the BaseConcatDataset to sample examples from. Normally
+        obtained with `BaseConcatDataset.get_metadata()`. For instance,
+        `metadata.head()` might look like this:
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+        | i_window_in_trial | i_start_in_trial | i_stop_in_trial | target | subject | session   | run   |
+        +===================+==================+=================+========+=========+===========+=======+
+        | 0                 | 0                | 500             | -1     | 4       | session_T | run_0 |
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+        | 1                 | 500              | 1000            | -1     | 4       | session_T | run_0 |
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+        | 2                 | 1000             | 1500            | -1     | 4       | session_T | run_0 |
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+        | 3                 | 1500             | 2000            | -1     | 4       | session_T | run_0 |
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+        | 4                 | 2000             | 2500            | -1     | 4       | session_T | run_0 |
+        +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+
+    random_state : np.RandomState | int | None
+        Random state.
+
+    Attributes
+    ----------
+    info : pd.DataFrame
+        Series with MultiIndex index which contains the subject, session, run
+        and window indices information in an easily accessible structure for
+        quick sampling of windows.
+    n_recordings : int
+        Number of recordings available.
+    """
+
+    def __init__(self, metadata, random_state=None):
+        self.metadata = metadata
+        self.info = self._init_info(metadata)
+        self.rng = check_random_state(random_state)
+
+    def _init_info(self, metadata, required_keys=None):
+        """Initialize ``info`` DataFrame.
+
+        Parameters
+        ----------
+        required_keys : list(str) | None
+            List of additional columns of the metadata DataFrame that we should
+            groupby when creating ``info``.
+
+        Returns
+        -------
+        See class attributes.
+        """
+        keys = [k for k in ["subject", "session", "run"] if k in self.metadata.columns]
+        if not keys:
+            raise ValueError(
+                "metadata must contain at least one of the following columns: "
+                "subject, session or run."
+            )
+
+        if required_keys is not None:
+            missing_keys = [k for k in required_keys if k not in self.metadata.columns]
+            if len(missing_keys) > 0:
+                raise ValueError(f"Columns {missing_keys} were not found in metadata.")
+            keys += required_keys
+
+        metadata = metadata.reset_index().rename(columns={"index": "window_index"})
+        info = (
+            metadata.reset_index()
+            .groupby(keys)[["index", "i_start_in_trial"]]
+            .agg(["unique"])
+        )
+        info.columns = info.columns.get_level_values(0)
+
+        return info
+
+    def sample_recording(self):
+        """Return a random recording index."""
+        # XXX docstring missing
+        return self.rng.choice(self.n_recordings)
+
+    def sample_window(self, rec_ind=None):
+        """Return a specific window."""
+        # XXX docstring missing
+        if rec_ind is None:
+            rec_ind = self.sample_recording()
+        win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
+        return win_ind, rec_ind
+
+    def __iter__(self):
+        raise NotImplementedError
+
+    @property
+    def n_recordings(self):
+        return self.info.shape[0]
+
+
+class DistributedRecordingSampler(DistributedSampler):
+    """Base sampler simplifying sampling from recordings in distributed setting.
+
+    Parameters
+    ----------
+    metadata : pd.DataFrame
+        DataFrame with at least one of {subject, session, run} columns for each
+        window in the BaseConcatDataset to sample examples from. Normally
+        obtained with `BaseConcatDataset.get_metadata()`. For instance,
+        `metadata.head()` might look like this:
+
+           i_window_in_trial  i_start_in_trial  i_stop_in_trial  target  subject    session    run
+        0                  0                 0              500      -1        4  session_T  run_0
+        1                  1               500             1000      -1        4  session_T  run_0
+        2                  2              1000             1500      -1        4  session_T  run_0
+        3                  3              1500             2000      -1        4  session_T  run_0
+        4                  4              2000             2500      -1        4  session_T  run_0
+
+    random_state : np.RandomState | int | None
+        Random state.
+
+    Attributes
+    ----------
+    info : pd.DataFrame
+        Series with MultiIndex index which contains the subject, session, run
+        and window indices information in an easily accessible structure for
+        quick sampling of windows.
+    n_recordings : int
+        Number of recordings available.
+    kwargs : dict
+        Additional keyword arguments to pass to torch DistributedSampler.
+        See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
+    """
+
+    def __init__(
+        self,
+        metadata,
+        random_state=None,
+        **kwargs,
+    ):
+        self.metadata = metadata
+        self.info = self._init_info(metadata)
+        self.rng = check_random_state(random_state)
+        # send information to DistributedSampler parent to handle data splitting among workers
+        super().__init__(self.info, seed=random_state, **kwargs)
+
+    def _init_info(self, metadata, required_keys=None):
+        """Initialize ``info`` DataFrame.
+
+        Parameters
+        ----------
+        required_keys : list(str) | None
+            List of additional columns of the metadata DataFrame that we should
+            groupby when creating ``info``.
+
+        Returns
+        -------
+        See class attributes.
+        """
+        keys = [k for k in ["subject", "session", "run"] if k in self.metadata.columns]
+        if not keys:
+            raise ValueError(
+                "metadata must contain at least one of the following columns: "
+                "subject, session or run."
+            )
+
+        if required_keys is not None:
+            missing_keys = [k for k in required_keys if k not in self.metadata.columns]
+            if len(missing_keys) > 0:
+                raise ValueError(f"Columns {missing_keys} were not found in metadata.")
+            keys += required_keys
+
+        metadata = metadata.reset_index().rename(columns={"index": "window_index"})
+        info = (
+            metadata.reset_index()
+            .groupby(keys)[["index", "i_start_in_trial"]]
+            .agg(["unique"])
+        )
+        info.columns = info.columns.get_level_values(0)
+
+        return info
+
+    def sample_recording(self):
+        """Return a random recording index.
+        super().__iter__() contains indices of datasets specific to the current process
+        determined by the DistributedSampler
+        """
+        # XXX docstring missing
+        return self.rng.choice(list(super().__iter__()))
+
+    def sample_window(self, rec_ind=None):
+        """Return a specific window."""
+        # XXX docstring missing
+        if rec_ind is None:
+            rec_ind = self.sample_recording()
+        win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
+        return win_ind, rec_ind
+
+    @property
+    def n_recordings(self):
+        return super().__len__()
+
+
+class SequenceSampler(RecordingSampler):
+    """Sample sequences of consecutive windows.
+
+    Parameters
+    ----------
+    metadata : pd.DataFrame
+        See RecordingSampler.
+    n_windows : int
+        Number of consecutive windows in a sequence.
+    n_windows_stride : int
+        Number of windows between two consecutive sequences.
+    random : bool
+        If True, sample sequences randomly. If False, sample sequences in
+        order.
+    random_state : np.random.RandomState | int | None
+        Random state.
+
+    Attributes
+    ----------
+    info : pd.DataFrame
+        See RecordingSampler.
+    file_ids : np.ndarray of ints
+        Array of shape (n_sequences,) that indicates from which file each
+        sequence comes from. Useful e.g. to do self-ensembling.
+    """
+
+    def __init__(
+        self, metadata, n_windows, n_windows_stride, randomize=False, random_state=None
+    ):
+        super().__init__(metadata, random_state=random_state)
+        self.randomize = randomize
+        self.n_windows = n_windows
+        self.n_windows_stride = n_windows_stride
+        self.start_inds, self.file_ids = self._compute_seq_start_inds()
+
+    def _compute_seq_start_inds(self):
+        """Compute sequence start indices.
+
+        Returns
+        -------
+        np.ndarray :
+            Array of shape (n_sequences,) containing the indices of the first
+            windows of possible sequences.
+        np.ndarray :
+            Array of shape (n_sequences,) containing the unique file number of
+            each sequence. Useful e.g. to do self-ensembling.
+        """
+        end_offset = 1 - self.n_windows if self.n_windows > 1 else None
+        start_inds = (
+            self.info["index"]
+            .apply(lambda x: x[: end_offset : self.n_windows_stride])
+            .values
+        )
+        file_ids = [[i] * len(inds) for i, inds in enumerate(start_inds)]
+        return np.concatenate(start_inds), np.concatenate(file_ids)
+
+    def __len__(self):
+        return len(self.start_inds)
+
+    def __iter__(self):
+        if self.randomize:
+            start_inds = self.start_inds.copy()
+            self.rng.shuffle(start_inds)
+            for start_ind in start_inds:
+                yield tuple(range(start_ind, start_ind + self.n_windows))
+        else:
+            for start_ind in self.start_inds:
+                yield tuple(range(start_ind, start_ind + self.n_windows))
+
+
+class BalancedSequenceSampler(RecordingSampler):
+    """Balanced sampling of sequences of consecutive windows with categorical
+    targets.
+
+    Balanced sampling of sequences inspired by the approach of [Perslev2021]_:
+    1. Uniformly sample a recording out of the available ones.
+    2. Uniformly sample one of the classes.
+    3. Sample a window of the corresponding class in the selected recording.
+    4. Extract a sequence of windows around the sampled window.
+
+    Parameters
+    ----------
+    metadata : pd.DataFrame
+        See RecordingSampler.
+        Must contain a column `target` with categorical targets.
+    n_windows : int
+        Number of consecutive windows in a sequence.
+    n_sequences : int
+        Number of sequences to sample.
+    random_state : np.random.RandomState | int | None
+        Random state.
+
+    References
+    ----------
+    .. [Perslev2021] Perslev M, Darkner S, Kempfner L, Nikolic M, Jennum PJ,
+        Igel C. U-Sleep: resilient high-frequency sleep staging. npj Digit.
+        Med. 4, 72 (2021).
+        https://github.com/perslev/U-Time/blob/master/utime/models/usleep.py
+    """
+
+    def __init__(self, metadata, n_windows, n_sequences=10, random_state=None):
+        super().__init__(metadata, random_state=random_state)
+
+        self.n_windows = n_windows
+        self.n_sequences = n_sequences
+        self.info_class = self._init_info(metadata, required_keys=["target"])
+
+    def sample_class(self, rec_ind=None):
+        """Return a random class.
+
+        Parameters
+        ----------
+        rec_ind : int | None
+            Index to the recording to sample from. If None, the recording will
+            be uniformly sampled across available recordings.
+
+        Returns
+        -------
+        int
+            Sampled class.
+        int
+            Index to the recording the class was sampled from.
+        """
+        if rec_ind is None:
+            rec_ind = self.sample_recording()
+        available_classes = self.info_class.loc[self.info.iloc[rec_ind].name].index
+        return self.rng.choice(available_classes), rec_ind
+
+    def _sample_seq_start_ind(self, rec_ind=None, class_ind=None):
+        """Sample a sequence and return its start index.
+
+        Sample a window associated with a random recording and a random class
+        and randomly sample a sequence with it inside. The function returns the
+        index of the beginning of the sequence.
+
+        Parameters
+        ----------
+        rec_ind : int | None
+            Index to the recording to sample from. If None, the recording will
+            be uniformly sampled across available recordings.
+        class_ind : int | None
+            If provided as int, sample a window of the corresponding class. If
+            None, the class will be uniformly sampled across available classes.
+
+        Returns
+        -------
+        int
+            Index of the first window of the sequence.
+        int
+            Corresponding recording index.
+        int
+            Class of the sampled window.
+        """
+        if class_ind is None:
+            class_ind, rec_ind = self.sample_class(rec_ind)
+
+        rec_inds = self.info.iloc[rec_ind]["index"]
+        len_rec_inds = len(rec_inds)
+
+        row = self.info.iloc[rec_ind].name
+        if not isinstance(row, tuple):
+            # Theres's only one category, e.g. "subject"
+            row = tuple([row])
+        available_indices = self.info_class.loc[row + tuple([class_ind]), "index"]
+        win_ind = self.rng.choice(available_indices)
+        win_ind_in_rec = np.where(rec_inds == win_ind)[0][0]
+
+        # Minimum and maximum start indices in the sequence
+        min_pos = max(0, win_ind_in_rec - self.n_windows + 1)
+        max_pos = min(len_rec_inds - self.n_windows, win_ind_in_rec)
+        start_ind = rec_inds[self.rng.randint(min_pos, max_pos + 1)]
+
+        return start_ind, rec_ind, class_ind
+
+    def __len__(self):
+        return self.n_sequences
+
+    def __iter__(self):
+        for _ in range(self.n_sequences):
+            start_ind, _, _ = self._sample_seq_start_ind()
+            yield tuple(range(start_ind, start_ind + self.n_windows))
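To show how the samplers added above fit together, here is a small self-contained sketch. The metadata frame below only mimics what `BaseConcatDataset.get_metadata()` would return (the columns described in the docstrings above); the window counts, targets, strides and seeds are arbitrary illustrations, not values from the release.

# Illustrative sketch only: the metadata mimics BaseConcatDataset.get_metadata()
# output; all sizes, targets and seeds are arbitrary.
import numpy as np
import pandas as pd

from braindecode.samplers import BalancedSequenceSampler, SequenceSampler

n_windows_total = 200  # two recordings (subjects 1 and 2) of 100 windows each
metadata = pd.DataFrame({
    "i_window_in_trial": np.arange(n_windows_total),
    "i_start_in_trial": np.arange(n_windows_total) * 500,
    "i_stop_in_trial": (np.arange(n_windows_total) + 1) * 500,
    "target": np.tile(np.arange(5), n_windows_total // 5),  # e.g. 5 sleep stages
    "subject": np.repeat([1, 2], n_windows_total // 2),
    "session": "session_T",
    "run": "run_0",
})

# Shuffled, non-overlapping sequences of 21 consecutive windows per recording.
seq_sampler = SequenceSampler(
    metadata, n_windows=21, n_windows_stride=21, randomize=True, random_state=87
)

# Class-balanced alternative: each draw picks a recording, then a target class,
# then a sequence containing a window of that class (n_sequences draws per pass).
bal_sampler = BalancedSequenceSampler(
    metadata, n_windows=21, n_sequences=100, random_state=87
)

# Both samplers yield tuples of consecutive window indices, typically passed as
# the sampler of a DataLoader (or skorch iterator) over the windowed dataset.
first_seq = next(iter(seq_sampler))
print(len(first_seq))  # 21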