braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/regressor.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
# Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
|
|
2
|
+
# Robin Schirrmeister <robintibor@gmail.com>
|
|
3
|
+
# Lukas Gemein <l.gemein@gmail.com>
|
|
4
|
+
# Bruno Aristimunha <b.aristimunha@gmail.com>
|
|
5
|
+
# Pierre Guetschel <pierre.guetschel@gmail.com>
|
|
6
|
+
#
|
|
7
|
+
# License: BSD (3-clause)
|
|
8
|
+
|
|
9
|
+
import warnings
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
from skorch.regressor import NeuralNetRegressor
|
|
13
|
+
|
|
14
|
+
from .eegneuralnet import _EEGNeuralNet
|
|
15
|
+
from .training.scoring import predict_trials
|
|
16
|
+
from .util import ThrowAwayIndexLoader, update_estimator_docstring
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
    doc = """Regressor that calls loss function directly.

    Parameters
    ----------
    module: str or torch Module (class or instance)
        Either the name of one of the braindecode models (see
        :obj:`braindecode.models.util.models_dict`) or directly a PyTorch module.
        When passing directly a torch module, uninstantiated class should be preferred,
        although instantiated modules will also work.

    cropped: bool (default=False)
        Defines whether torch model passed to this class is cropped or not.
        Currently used for callbacks definition.

    callbacks: None or list of strings or list of Callback instances (default=None)
        More callbacks, in addition to those returned by
        ``get_default_callbacks``. Each callback should inherit from
        :class:`skorch.callbacks.Callback`. If not ``None``, callbacks can be a
        list of strings specifying `sklearn` scoring functions (for scoring
        functions names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter)
        or a list of callbacks where the callback names are inferred from the
        class name. Name conflicts are resolved by appending a count suffix
        starting with 1, e.g. ``EpochScoring_1``. Alternatively,
        a tuple ``(name, callback)`` can be passed, where ``name``
        should be unique. Callbacks may or may not be instantiated.
        The callback name can be used to set parameters on specific
        callbacks (e.g., for the callback with name ``'print_log'``, use
        ``net.set_params(callbacks__print_log__keys_ignored=['epoch',
        'train_loss'])``).

    iterator_train__shuffle: bool (default=True)
        Defines whether train dataset will be shuffled. As skorch does not
        shuffle the train dataset by default this one overwrites this option.

    aggregate_predictions: bool (default=True)
        Whether to average cropped predictions to obtain window predictions. Used only in the
        cropped mode.

    """  # noqa: E501
    __doc__ = update_estimator_docstring(NeuralNetRegressor, doc)

    def __init__(
        self,
        module,
        *args,
        cropped=False,
        callbacks=None,
        iterator_train__shuffle=True,
        iterator_train__drop_last=True,
        aggregate_predictions=True,
        **kwargs,
    ):
        # Braindecode-specific options are stored before delegating the rest
        # (module, callbacks, iterator options, ...) to the skorch/base classes.
        self.cropped = cropped
        self.aggregate_predictions = aggregate_predictions
        self._last_window_inds_ = None
        super().__init__(
            module,
            *args,
            callbacks=callbacks,
            iterator_train__shuffle=iterator_train__shuffle,
            iterator_train__drop_last=iterator_train__drop_last,
            **kwargs,
        )

    def get_iterator(self, dataset, training=False, drop_index=True):
        """Return a data iterator over ``dataset``.

        Parameters
        ----------
        dataset : torch Dataset
            Dataset to iterate over.
        training : bool
            Forwarded to skorch's ``get_iterator`` to select the train or
            valid/test iterator settings.
        drop_index : bool
            If True (default), wrap skorch's iterator in
            ``ThrowAwayIndexLoader`` (created with ``is_regression=True``),
            which is expected to drop the window indices yielded with each
            batch. If False, return skorch's iterator unchanged.
        """
        iterator = super().get_iterator(dataset, training=training)
        if drop_index:
            return ThrowAwayIndexLoader(self, iterator, is_regression=True)
        return iterator

    def predict_proba(self, X):
        """Return the output of the module's forward method as a numpy array.

        In case of cropped decoding returns averaged values for each trial.

        If the module's forward method returns multiple outputs as a
        tuple, it is assumed that the first output contains the
        relevant information and the other values are ignored.
        If all values are relevant or module's output for each crop
        is needed, consider using :func:`~skorch.NeuralNet.forward`
        instead.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        Returns
        -------
        y_proba : numpy ndarray

        Warnings
        --------
        Regressors predict regression targets, so output of this method
        can't be interpreted as probabilities. We advise you to use
        `predict` method instead of `predict_proba`.
        """
        y_pred = super().predict_proba(X)
        # Normally, we have to average the predictions across crops/timesteps
        # to get one prediction per window/trial.
        # Predictions may be already averaged in CroppedTrialEpochScoring
        # (y_pred.ndim == 2). However, when predictions are computed outside of
        # CroppedTrialEpochScoring we have to average them ourselves, hence
        # the check for y_pred.ndim == 3.
        if self.cropped and self.aggregate_predictions and y_pred.ndim == 3:
            return y_pred.mean(-1)
        return y_pred

    def predict_trials(self, X, return_targets=True):
        """Create trialwise predictions, optionally with trialwise targets.

        Designed for cropped datasets. When ``cropped`` is False a warning is
        emitted and the result of :meth:`predict` is returned instead.

        Parameters
        ----------
        X : braindecode.datasets.BaseConcatDataset
            A braindecode dataset to be predicted.
        return_targets : bool
            If True, additionally returns the trial targets.

        Returns
        -------
        trial_predictions : np.ndarray
            In cropped mode, 3-dimensional array
            (n_trials x n_outputs x n_predictions), where the number of
            predictions depend on the chosen window size and the receptive
            field of the network. In non-cropped mode, the output of
            :meth:`predict`.
        trial_labels : np.ndarray
            Trial targets, only returned if ``return_targets`` is True.
            Their shape depends on the decoding paradigm: a single value,
            multiple values, or a sequence per trial.
        """
        if not self.cropped:
            warnings.warn(
                "This method was designed to predict trials in cropped mode. "
                "Calling it when cropped is False will give the same result as "
                "'.predict'.",
                UserWarning,
            )
            preds = self.predict(X)
            if return_targets:
                # atleast_1d lets scalar per-window targets be concatenated;
                # array targets are passed through unchanged.
                targets = np.concatenate(
                    [np.atleast_1d(X[i][1]) for i in range(len(X))]
                )
                return preds, targets
            return preds
        return predict_trials(
            module=self.module,
            dataset=X,
            return_targets=return_targets,
            batch_size=self.batch_size,
            num_workers=self.get_iterator(X, training=False).loader.num_workers,
        )

    def fit(self, X, y=None, **kwargs):
        """Initialize and fit the module.

        If the module was already initialized, by calling fit, the
        module will be re-initialized (unless ``warm_start`` is True).
        If possible, signal-related parameters are inferred from the
        data and passed to the module at initialisation.
        Depending on the type of input passed, the following parameters
        are inferred:

          * mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
            ``sfreq``, ``input_window_seconds``
          * numpy array: ``n_times``, ``n_chans``, ``n_outputs``
          * WindowsDataset with ``targets_from='metadata'``
            (or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
          * other Dataset: ``n_times``, ``n_chans``
          * other types: no parameters are inferred.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

              * mne.Epochs
              * numpy arrays
              * torch tensors
              * pandas DataFrame or Series
              * scipy sparse CSR matrices
              * a dictionary of the former three
              * a list/tuple of the former three
              * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        y : target data, compatible with skorch.dataset.Dataset
            The same data types as for ``X`` are supported. If your X is
            a Dataset that contains the target, ``y`` may be set to
            None. 1-dimensional targets are reshaped to ``(n_samples, 1)``.

        **kwargs : dict
            Additional fit parameters passed to the ``forward`` method of
            the module and to the ``self.train_split`` call.

        Returns
        -------
        self : EEGRegressor
            The fitted estimator, so calls can be chained
            (e.g. ``net.fit(X, y).predict(X)``).
        """
        # np.ndim also works for lists and other array-likes without `.ndim`.
        if y is not None and np.ndim(y) == 1:
            # Regression targets are turned into a single-column 2-D array.
            y = np.array(y).reshape(-1, 1)
        return super().fit(X=X, y=y, **kwargs)

    @property
    def mode(self):
        return "regression"
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""Classes to sample examples."""
|
|
2
|
+
|
|
3
|
+
from .base import (
|
|
4
|
+
BalancedSequenceSampler,
|
|
5
|
+
DistributedRecordingSampler,
|
|
6
|
+
RecordingSampler,
|
|
7
|
+
SequenceSampler,
|
|
8
|
+
)
|
|
9
|
+
from .ssl import DistributedRelativePositioningSampler, RelativePositioningSampler
|
|
10
|
+
|
|
11
|
+
# Public API of the samplers package: the sampler classes imported above from
# the ``base`` and ``ssl`` submodules.
__all__ = [
    "RecordingSampler",
    "SequenceSampler",
    "BalancedSequenceSampler",
    "RelativePositioningSampler",
    "DistributedRecordingSampler",
    "DistributedRelativePositioningSampler",
]
|
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Sampler classes.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
|
|
6
|
+
# Theo Gnassounou <>
|
|
7
|
+
# Young Truong <dt.young112@gmail.com>
|
|
8
|
+
#
|
|
9
|
+
# License: BSD (3-clause)
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
from sklearn.utils import check_random_state
|
|
13
|
+
from torch.utils.data.distributed import DistributedSampler
|
|
14
|
+
from torch.utils.data.sampler import Sampler
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RecordingSampler(Sampler):
    """Base sampler simplifying sampling from recordings.

    Parameters
    ----------
    metadata : pd.DataFrame
        DataFrame with at least one of {subject, session, run} columns for each
        window in the BaseConcatDataset to sample examples from. Normally
        obtained with `BaseConcatDataset.get_metadata()`. For instance,
        `metadata.head()` might look like this:

        +-------------------+-----------------+-----------------+--------+----------+-----------+-------+
        | i_window_in_trial | i_start_in_trial| i_stop_in_trial | target | subject  | session   | run   |
        +===================+=================+=================+========+==========+===========+=======+
        | 0                 | 0               | 500             | -1     | 4        | session_T | run_0 |
        +-------------------+-----------------+-----------------+--------+----------+-----------+-------+
        | 1                 | 500             | 1000            | -1     | 4        | session_T | run_0 |
        +-------------------+-----------------+-----------------+--------+----------+-----------+-------+
        | 2                 | 1000            | 1500            | -1     | 4        | session_T | run_0 |
        +-------------------+-----------------+-----------------+--------+----------+-----------+-------+
        | 3                 | 1500            | 2000            | -1     | 4        | session_T | run_0 |
        +-------------------+-----------------+-----------------+--------+----------+-----------+-------+
        | 4                 | 2000            | 2500            | -1     | 4        | session_T | run_0 |
        +-------------------+-----------------+-----------------+--------+----------+-----------+-------+

    random_state : np.RandomState | int | None
        Random state.

    Attributes
    ----------
    info : pd.DataFrame
        DataFrame with one row per recording, indexed by the available
        subject/session/run keys. Column ``index`` holds the positional
        indices of the recording's windows and column ``i_start_in_trial``
        their unique start samples, for quick sampling of windows.
    n_recordings : int
        Number of recordings available.
    """

    def __init__(self, metadata, random_state=None):
        self.metadata = metadata
        self.info = self._init_info(metadata)
        self.rng = check_random_state(random_state)

    def _init_info(self, metadata, required_keys=None):
        """Initialize ``info`` DataFrame.

        Parameters
        ----------
        metadata : pd.DataFrame
            Window metadata to group by recording.
        required_keys : list(str) | None
            List of additional columns of the metadata DataFrame that we should
            groupby when creating ``info``.

        Returns
        -------
        See class attributes.

        Raises
        ------
        ValueError
            If ``metadata`` has none of the subject/session/run columns, or
            lacks one of ``required_keys``.
        """
        # Look for the grouping keys in the DataFrame that is actually grouped
        # below (not ``self.metadata``), so both always agree.
        keys = [k for k in ["subject", "session", "run"] if k in metadata.columns]
        if not keys:
            raise ValueError(
                "metadata must contain at least one of the following columns: "
                "subject, session or run."
            )

        if required_keys is not None:
            missing_keys = [k for k in required_keys if k not in metadata.columns]
            if len(missing_keys) > 0:
                raise ValueError(f"Columns {missing_keys} were not found in metadata.")
            keys = keys + list(required_keys)

        # The original index is kept as ``window_index``; the second
        # reset_index adds a positional ``index`` column (0..n_windows-1).
        metadata = metadata.reset_index().rename(columns={"index": "window_index"})
        info = (
            metadata.reset_index()
            .groupby(keys)[["index", "i_start_in_trial"]]
            .agg(["unique"])
        )
        # Drop the "unique" level produced by agg(["unique"]).
        info.columns = info.columns.get_level_values(0)

        return info

    def sample_recording(self):
        """Return a random recording index, drawn uniformly."""
        return self.rng.choice(self.n_recordings)

    def sample_window(self, rec_ind=None):
        """Return a random window index and the recording it belongs to.

        Parameters
        ----------
        rec_ind : int | None
            Recording to sample from. If None, it is sampled uniformly.

        Returns
        -------
        int
            Positional index of the sampled window.
        int
            Index of the recording the window was sampled from.
        """
        if rec_ind is None:
            rec_ind = self.sample_recording()
        win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
        return win_ind, rec_ind

    def __iter__(self):
        raise NotImplementedError

    @property
    def n_recordings(self):
        return self.info.shape[0]
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class DistributedRecordingSampler(DistributedSampler):
    """Base sampler simplifying sampling from recordings in distributed setting.

    Parameters
    ----------
    metadata : pd.DataFrame
        DataFrame with at least one of {subject, session, run} columns for each
        window in the BaseConcatDataset to sample examples from. Normally
        obtained with `BaseConcatDataset.get_metadata()`. For instance,
        `metadata.head()` might look like this::

               i_window_in_trial  i_start_in_trial  i_stop_in_trial  target  subject    session    run
            0                  0                 0              500      -1        4  session_T  run_0
            1                  1               500             1000      -1        4  session_T  run_0
            2                  2              1000             1500      -1        4  session_T  run_0
            3                  3              1500             2000      -1        4  session_T  run_0
            4                  4              2000             2500      -1        4  session_T  run_0

    random_state : np.RandomState | int | None
        Random state. An integer is also used as the ``seed`` of the parent
        DistributedSampler; otherwise that seed defaults to 0 (torch's
        default), since torch requires the same integer seed on every process.
    kwargs : dict
        Additional keyword arguments to pass to torch DistributedSampler.
        See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler

    Attributes
    ----------
    info : pd.DataFrame
        DataFrame with one row per recording, indexed by the available
        subject/session/run keys. Column ``index`` holds the positional
        indices of the recording's windows and column ``i_start_in_trial``
        their unique start samples, for quick sampling of windows.
    n_recordings : int
        Number of recordings assigned to the current process.
    """

    def __init__(
        self,
        metadata,
        random_state=None,
        **kwargs,
    ):
        self.metadata = metadata
        self.info = self._init_info(metadata)
        self.rng = check_random_state(random_state)
        # Send information to the DistributedSampler parent to handle data
        # splitting among workers. Its ``seed`` must be an int (it is added to
        # the epoch when shuffling), so None/RandomState cannot be forwarded.
        super().__init__(self.info, seed=self._as_seed(random_state), **kwargs)

    @staticmethod
    def _as_seed(random_state):
        """Convert ``random_state`` to the integer seed DistributedSampler needs."""
        if isinstance(random_state, (int, np.integer)):
            return int(random_state)
        return 0

    def _init_info(self, metadata, required_keys=None):
        """Initialize ``info`` DataFrame.

        Parameters
        ----------
        metadata : pd.DataFrame
            Window metadata to group by recording.
        required_keys : list(str) | None
            List of additional columns of the metadata DataFrame that we should
            groupby when creating ``info``.

        Returns
        -------
        See class attributes.

        Raises
        ------
        ValueError
            If ``metadata`` has none of the subject/session/run columns, or
            lacks one of ``required_keys``.
        """
        # Detect keys on the DataFrame that is actually grouped below.
        keys = [k for k in ["subject", "session", "run"] if k in metadata.columns]
        if not keys:
            raise ValueError(
                "metadata must contain at least one of the following columns: "
                "subject, session or run."
            )

        if required_keys is not None:
            missing_keys = [k for k in required_keys if k not in metadata.columns]
            if len(missing_keys) > 0:
                raise ValueError(f"Columns {missing_keys} were not found in metadata.")
            keys = keys + list(required_keys)

        # The original index is kept as ``window_index``; the second
        # reset_index adds a positional ``index`` column (0..n_windows-1).
        metadata = metadata.reset_index().rename(columns={"index": "window_index"})
        info = (
            metadata.reset_index()
            .groupby(keys)[["index", "i_start_in_trial"]]
            .agg(["unique"])
        )
        # Drop the "unique" level produced by agg(["unique"]).
        info.columns = info.columns.get_level_values(0)

        return info

    def sample_recording(self):
        """Return a random recording index.

        super().__iter__() contains indices of datasets specific to the current process
        determined by the DistributedSampler.
        """
        return self.rng.choice(list(super().__iter__()))

    def sample_window(self, rec_ind=None):
        """Return a random window index and the recording it belongs to.

        Parameters
        ----------
        rec_ind : int | None
            Recording to sample from. If None, it is sampled among the
            recordings assigned to the current process.

        Returns
        -------
        int
            Positional index of the sampled window.
        int
            Index of the recording the window was sampled from.
        """
        if rec_ind is None:
            rec_ind = self.sample_recording()
        win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
        return win_ind, rec_ind

    @property
    def n_recordings(self):
        return super().__len__()
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class SequenceSampler(RecordingSampler):
    """Sample sequences of consecutive windows.

    Parameters
    ----------
    metadata : pd.DataFrame
        See RecordingSampler.
    n_windows : int
        Number of consecutive windows in a sequence.
    n_windows_stride : int
        Number of windows between two consecutive sequences.
    randomize : bool
        If True, sample sequences randomly. If False, sample sequences in
        order.
    random_state : np.random.RandomState | int | None
        Random state.

    Attributes
    ----------
    info : pd.DataFrame
        See RecordingSampler.
    start_inds : np.ndarray of ints
        Array of shape (n_sequences,) with the index of the first window of
        each sequence.
    file_ids : np.ndarray of ints
        Array of shape (n_sequences,) that indicates from which file each
        sequence comes from. Useful e.g. to do self-ensembling.
    """

    def __init__(
        self, metadata, n_windows, n_windows_stride, randomize=False, random_state=None
    ):
        super().__init__(metadata, random_state=random_state)
        self.randomize = randomize
        self.n_windows = n_windows
        self.n_windows_stride = n_windows_stride
        self.start_inds, self.file_ids = self._compute_seq_start_inds()

    def _compute_seq_start_inds(self):
        """Compute sequence start indices.

        Returns
        -------
        np.ndarray :
            Array of shape (n_sequences,) containing the indices of the first
            windows of possible sequences.
        np.ndarray :
            Array of shape (n_sequences,) containing the unique file number of
            each sequence. Useful e.g. to do self-ensembling.
        """
        # Exclude the last n_windows - 1 windows of each recording so that
        # every sequence fits entirely inside its recording.
        end_offset = 1 - self.n_windows if self.n_windows > 1 else None
        start_inds = (
            self.info["index"]
            .apply(lambda x: x[:end_offset : self.n_windows_stride])
            .values
        )
        # Explicit integer arrays: concatenating Python lists would upcast
        # file_ids to float as soon as one recording yields no sequence.
        file_ids = [
            np.full(len(inds), i, dtype=int) for i, inds in enumerate(start_inds)
        ]
        return np.concatenate(start_inds), np.concatenate(file_ids)

    def __len__(self):
        return len(self.start_inds)

    def __iter__(self):
        if self.randomize:
            # Shuffle a copy so the ordered ``start_inds`` attribute is kept.
            start_inds = self.start_inds.copy()
            self.rng.shuffle(start_inds)
        else:
            start_inds = self.start_inds
        for start_ind in start_inds:
            yield tuple(range(start_ind, start_ind + self.n_windows))
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class BalancedSequenceSampler(RecordingSampler):
    """Balanced sampling of sequences of consecutive windows with categorical
    targets.

    Balanced sampling of sequences inspired by the approach of [Perslev2021]_:
    1. Uniformly sample a recording out of the available ones.
    2. Uniformly sample one of the classes.
    3. Sample a window of the corresponding class in the selected recording.
    4. Extract a sequence of windows around the sampled window.

    Parameters
    ----------
    metadata : pd.DataFrame
        See RecordingSampler.
        Must contain a column `target` with categorical targets.
    n_windows : int
        Number of consecutive windows in a sequence.
    n_sequences : int
        Number of sequences to sample.
    random_state : np.random.RandomState | int | None
        Random state.

    References
    ----------
    .. [Perslev2021] Perslev M, Darkner S, Kempfner L, Nikolic M, Jennum PJ,
           Igel C. U-Sleep: resilient high-frequency sleep staging. npj Digit.
           Med. 4, 72 (2021).
           https://github.com/perslev/U-Time/blob/master/utime/models/usleep.py
    """

    def __init__(self, metadata, n_windows, n_sequences=10, random_state=None):
        super().__init__(metadata, random_state=random_state)

        self.n_windows = n_windows
        self.n_sequences = n_sequences
        # Same grouping as ``info`` but with one row per (recording, target).
        self.info_class = self._init_info(metadata, required_keys=["target"])

    def sample_class(self, rec_ind=None):
        """Return a random class.

        Parameters
        ----------
        rec_ind : int | None
            Index to the recording to sample from. If None, the recording will
            be uniformly sampled across available recordings.

        Returns
        -------
        int
            Sampled class.
        int
            Index to the recording the class was sampled from.
        """
        if rec_ind is None:
            rec_ind = self.sample_recording()
        available_classes = self.info_class.loc[self.info.iloc[rec_ind].name].index
        return self.rng.choice(available_classes), rec_ind

    def _sample_seq_start_ind(self, rec_ind=None, class_ind=None):
        """Sample a sequence and return its start index.

        Sample a window associated with a random recording and a random class
        and randomly sample a sequence with it inside. The function returns the
        index of the beginning of the sequence.

        Parameters
        ----------
        rec_ind : int | None
            Index to the recording to sample from. If None, the recording will
            be uniformly sampled across available recordings.
        class_ind : int | None
            If provided as int, sample a window of the corresponding class. If
            None, the class will be uniformly sampled across available classes.

        Returns
        -------
        int
            Index of the first window of the sequence.
        int
            Corresponding recording index.
        int
            Class of the sampled window.

        Raises
        ------
        ValueError
            If the selected recording has fewer than ``n_windows`` windows.
        """
        # Resolve the recording first so that an explicit ``class_ind`` works
        # without an explicit ``rec_ind``. For the default call this performs
        # the same random draws, in the same order, as sample_class(None).
        if rec_ind is None:
            rec_ind = self.sample_recording()
        if class_ind is None:
            class_ind, _ = self.sample_class(rec_ind)

        rec_inds = self.info.iloc[rec_ind]["index"]
        len_rec_inds = len(rec_inds)
        if len_rec_inds < self.n_windows:
            raise ValueError(
                f"Recording {rec_ind} contains only {len_rec_inds} windows, "
                f"fewer than n_windows={self.n_windows}."
            )

        row = self.info.index[rec_ind]
        if not isinstance(row, tuple):
            # There's only one category, e.g. "subject".
            row = (row,)
        available_indices = self.info_class.loc[(*row, class_ind), "index"]
        win_ind = self.rng.choice(available_indices)
        win_ind_in_rec = np.where(rec_inds == win_ind)[0][0]

        # Minimum and maximum start positions (within the recording) of a
        # sequence that contains the sampled window and fits in the recording.
        min_pos = max(0, win_ind_in_rec - self.n_windows + 1)
        max_pos = min(len_rec_inds - self.n_windows, win_ind_in_rec)
        start_ind = rec_inds[self.rng.randint(min_pos, max_pos + 1)]

        return start_ind, rec_ind, class_ind

    def __len__(self):
        return self.n_sequences

    def __iter__(self):
        for _ in range(self.n_sequences):
            start_ind, _, _ = self._sample_seq_start_ind()
            yield tuple(range(start_ind, start_ind + self.n_windows))
|