braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,234 @@
1
+ # Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
2
+ # Robin Schirrmeister <robintibor@gmail.com>
3
+ # Lukas Gemein <l.gemein@gmail.com>
4
+ # Bruno Aristimunha <b.aristimunha@gmail.com>
5
+ # Pierre Guetschel <pierre.guetschel@gmail.com>
6
+ #
7
+ # License: BSD (3-clause)
8
+
9
+ import warnings
10
+
11
+ import numpy as np
12
+ from skorch.regressor import NeuralNetRegressor
13
+
14
+ from .eegneuralnet import _EEGNeuralNet
15
+ from .training.scoring import predict_trials
16
+ from .util import ThrowAwayIndexLoader, update_estimator_docstring
17
+
18
+
19
class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
    # NOTE: ``doc`` is merged with the skorch ``NeuralNetRegressor`` docstring by
    # ``update_estimator_docstring`` below, so its layout (section headers and
    # indentation) is kept in the same numpydoc shape as before.
    doc = """Regressor that calls loss function directly.

    Parameters
    ----------
    module: str or torch Module (class or instance)
        Either the name of one of the braindecode models (see
        :obj:`braindecode.models.util.models_dict`) or directly a PyTorch module.
        When passing directly a torch module, uninstantiated class should be preferred,
        although instantiated modules will also work.

    cropped: bool (default=False)
        Defines whether torch model passed to this class is cropped or not.
        Currently used for callbacks definition.

    callbacks: None or list of strings or list of Callback instances (default=None)
        More callbacks, in addition to those returned by
        ``get_default_callbacks``. Each callback should inherit from
        :class:`skorch.callbacks.Callback`. If not ``None``, callbacks can be a
        list of strings specifying `sklearn` scoring functions (for scoring
        functions names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter)
        or a list of callbacks where the callback names are inferred from the
        class name. Name conflicts are resolved by appending a count suffix
        starting with 1, e.g. ``EpochScoring_1``. Alternatively,
        a tuple ``(name, callback)`` can be passed, where ``name``
        should be unique. Callbacks may or may not be instantiated.
        The callback name can be used to set parameters on specific
        callbacks (e.g., for the callback with name ``'print_log'``, use
        ``net.set_params(callbacks__print_log__keys_ignored=['epoch',
        'train_loss'])``).

    iterator_train__shuffle: bool (default=True)
        Defines whether train dataset will be shuffled. As skorch does not
        shuffle the train dataset by default this one overwrites this option.

    iterator_train__drop_last: bool (default=True)
        Defines whether the training iterator drops the last incomplete
        batch. Forwarded unchanged to the parent estimator.

    aggregate_predictions: bool (default=True)
        Whether to average cropped predictions to obtain window predictions. Used only in the
        cropped mode.

    """  # noqa: E501
    __doc__ = update_estimator_docstring(NeuralNetRegressor, doc)

    def __init__(
        self,
        module,
        *args,
        cropped=False,
        callbacks=None,
        iterator_train__shuffle=True,
        iterator_train__drop_last=True,
        aggregate_predictions=True,
        **kwargs,
    ):
        self.cropped = cropped
        self.aggregate_predictions = aggregate_predictions
        self._last_window_inds_ = None
        super().__init__(
            module,
            *args,
            callbacks=callbacks,
            iterator_train__shuffle=iterator_train__shuffle,
            iterator_train__drop_last=iterator_train__drop_last,
            **kwargs,
        )

    def get_iterator(self, dataset, training=False, drop_index=True):
        """Return the data iterator for ``dataset``.

        Parameters
        ----------
        dataset : Dataset
            Dataset handed to the parent ``get_iterator``.
        training : bool
            Forwarded to the parent ``get_iterator``.
        drop_index : bool
            If True, the parent's iterator is wrapped in a
            ``ThrowAwayIndexLoader`` (with ``is_regression=True``); if False
            the parent's iterator is returned as is.
        """
        iterator = super().get_iterator(dataset, training=training)
        if not drop_index:
            return iterator
        return ThrowAwayIndexLoader(self, iterator, is_regression=True)

    def predict_proba(self, X):
        """Return the output of the module's forward method as a numpy array.

        In case of cropped decoding returns averaged values for
        each trial.

        If the module's forward method returns multiple outputs as a
        tuple, it is assumed that the first output contains the
        relevant information and the other values are ignored.
        If all values are relevant or module's output for each crop
        is needed, consider using :func:`~skorch.NeuralNet.forward`
        instead.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        Returns
        -------
        y_proba : numpy ndarray

        Warnings
        --------
        Regressors predict regression targets, so output of this method
        can't be interpreted as probabilities. We advise you to use
        `predict` method instead of `predict_proba`.
        """
        y_pred = super().predict_proba(X)
        # Normally, we have to average the predictions across crops/timesteps
        # to get one prediction per window/trial.
        # Predictions may be already averaged in CroppedTrialEpochScoring
        # (2-dimensional y_pred). However, when predictions are computed
        # outside of CroppedTrialEpochScoring we have to average them here,
        # hence the check for a 3-dimensional y_pred.
        needs_averaging = (
            self.cropped and self.aggregate_predictions and len(y_pred.shape) == 3
        )
        if needs_averaging:
            return y_pred.mean(-1)
        return y_pred

    def predict_trials(self, X, return_targets=True):
        """Create trialwise predictions and optionally also return trialwise
        labels from cropped dataset.

        Parameters
        ----------
        X : braindecode.datasets.BaseConcatDataset
            A braindecode dataset to be predicted.
        return_targets : bool
            If True, additionally returns the trial targets.

        Returns
        -------
        trial_predictions : np.ndarray
            3-dimensional array (n_trials x n_classes x n_predictions), where
            the number of predictions depend on the chosen window size and the
            receptive field of the network.
        trial_labels : np.ndarray
            2-dimensional array (n_trials x n_targets) where the number of
            targets depends on the decoding paradigm and can be either a single
            value, multiple values, or a sequence.
        """
        if not self.cropped:
            warnings.warn(
                "This method was designed to predict trials in cropped mode. "
                "Calling it when cropped is False will give the same result as "
                "'.predict'.",
                UserWarning,
            )
            preds = self.predict(X)
            if not return_targets:
                return preds
            # ``np.atleast_1d`` lets scalar (0-d) targets be concatenated;
            # targets that are already arrays are passed through unchanged.
            targets = np.concatenate(
                [np.atleast_1d(X[i][1]) for i in range(len(X))]
            )
            return preds, targets
        return predict_trials(
            module=self.module,
            dataset=X,
            return_targets=return_targets,
            batch_size=self.batch_size,
            num_workers=self.get_iterator(X, training=False).loader.num_workers,
        )

    def fit(self, X, y=None, **kwargs):
        """Initialize and fit the module.

        If the module was already initialized, by calling fit, the
        module will be re-initialized (unless ``warm_start`` is True).
        If possible, signal-related parameters are inferred from the
        data and passed to the module at initialisation.
        Depending on the type of input passed, the following parameters
        are inferred:

        * mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
          ``sfreq``, ``input_window_seconds``
        * numpy array: ``n_times``, ``n_chans``, ``n_outputs``
        * WindowsDataset with ``targets_from='metadata'``
          (or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
        * other Dataset: ``n_times``, ``n_chans``
        * other types: no parameters are inferred.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * mne.Epochs
            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        y : target data, compatible with skorch.dataset.Dataset
            The same data types as for ``X`` are supported. If your X is
            a Dataset that contains the target, ``y`` may be set to
            None. One-dimensional targets are reshaped to a single column
            before being passed on.

        **fit_params : dict
            Additional parameters passed to the ``forward`` method of
            the module and to the ``self.train_split`` call.

        Returns
        -------
        Whatever the parent ``fit`` returns.
        """
        # ``np.ndim`` also accepts plain sequences (lists/tuples), which have
        # no ``ndim`` attribute of their own.
        if y is not None and np.ndim(y) == 1:
            y = np.array(y).reshape(-1, 1)
        return super().fit(X=X, y=y, **kwargs)

    @property
    def mode(self):
        """Decoding mode identifier of this estimator (always ``"regression"``)."""
        return "regression"
@@ -0,0 +1,18 @@
1
+ """Classes to sample examples."""
2
+
3
+ from .base import (
4
+ BalancedSequenceSampler,
5
+ DistributedRecordingSampler,
6
+ RecordingSampler,
7
+ SequenceSampler,
8
+ )
9
+ from .ssl import DistributedRelativePositioningSampler, RelativePositioningSampler
10
+
11
+ __all__ = [
12
+ "RecordingSampler",
13
+ "SequenceSampler",
14
+ "BalancedSequenceSampler",
15
+ "RelativePositioningSampler",
16
+ "DistributedRecordingSampler",
17
+ "DistributedRelativePositioningSampler",
18
+ ]
@@ -0,0 +1,399 @@
1
+ """
2
+ Sampler classes.
3
+ """
4
+
5
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
6
+ # Theo Gnassounou <>
7
+ # Young Truong <dt.young112@gmail.com>
8
+ #
9
+ # License: BSD (3-clause)
10
+
11
+ import numpy as np
12
+ from sklearn.utils import check_random_state
13
+ from torch.utils.data.distributed import DistributedSampler
14
+ from torch.utils.data.sampler import Sampler
15
+
16
+
17
class RecordingSampler(Sampler):
    """Base sampler simplifying sampling from recordings.

    Parameters
    ----------
    metadata : pd.DataFrame
        DataFrame with at least one of {subject, session, run} columns for each
        window in the BaseConcatDataset to sample examples from. Normally
        obtained with `BaseConcatDataset.get_metadata()`. For instance,
        `metadata.head()` might look like this::

               i_window_in_trial  i_start_in_trial  i_stop_in_trial  target  subject    session    run
            0                  0                 0              500      -1        4  session_T  run_0
            1                  1               500             1000      -1        4  session_T  run_0
            2                  2              1000             1500      -1        4  session_T  run_0
            3                  3              1500             2000      -1        4  session_T  run_0
            4                  4              2000             2500      -1        4  session_T  run_0

    random_state : np.RandomState | int | None
        Random state.

    Attributes
    ----------
    info : pd.DataFrame
        DataFrame indexed by the available grouping keys (subject, session,
        run) with one row per recording. The ``index`` and
        ``i_start_in_trial`` columns hold, for each recording, the arrays of
        unique window indices and window start positions, for quick sampling
        of windows.
    n_recordings : int
        Number of recordings available.
    """

    def __init__(self, metadata, random_state=None):
        self.metadata = metadata
        self.info = self._init_info(metadata)
        self.rng = check_random_state(random_state)

    def _init_info(self, metadata, required_keys=None):
        """Initialize ``info`` DataFrame.

        Parameters
        ----------
        metadata : pd.DataFrame
            Window metadata to group; see the class parameters.
        required_keys : list(str) | None
            List of additional columns of the metadata DataFrame that we should
            groupby when creating ``info``.

        Returns
        -------
        pd.DataFrame
            See class attributes.

        Raises
        ------
        ValueError
            If none of the subject/session/run columns is present, or if one
            of ``required_keys`` is not a column of ``metadata``.
        """
        # Validate against the DataFrame that is actually grouped below.
        keys = [k for k in ("subject", "session", "run") if k in metadata.columns]
        if not keys:
            raise ValueError(
                "metadata must contain at least one of the following columns: "
                "subject, session or run."
            )

        if required_keys is not None:
            missing_keys = [k for k in required_keys if k not in metadata.columns]
            if missing_keys:
                raise ValueError(f"Columns {missing_keys} were not found in metadata.")
            keys += required_keys

        # The first reset keeps the original index as ``window_index``; the
        # second one creates a fresh positional ``index`` column that is used
        # as the window index from here on.
        metadata = metadata.reset_index().rename(columns={"index": "window_index"})
        info = (
            metadata.reset_index()
            .groupby(keys)[["index", "i_start_in_trial"]]
            .agg(["unique"])
        )
        # Drop the ("unique",) level that ``agg`` adds to the column names.
        info.columns = info.columns.get_level_values(0)

        return info

    def sample_recording(self):
        """Return a random recording index.

        Returns
        -------
        int
            Row position in ``info``, drawn uniformly from
            ``range(n_recordings)``.
        """
        return self.rng.choice(self.n_recordings)

    def sample_window(self, rec_ind=None):
        """Return a random window of a recording.

        Parameters
        ----------
        rec_ind : int | None
            Row position of the recording in ``info``. If None, a recording
            is drawn with :meth:`sample_recording` first.

        Returns
        -------
        int
            Index of the sampled window.
        int
            Index of the recording the window was sampled from.
        """
        if rec_ind is None:
            rec_ind = self.sample_recording()
        win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
        return win_ind, rec_ind

    def __iter__(self):
        # Subclasses define the actual sampling scheme.
        raise NotImplementedError

    @property
    def n_recordings(self):
        """Number of recordings, i.e. number of rows of ``info``."""
        return self.info.shape[0]
114
+
115
+
116
class DistributedRecordingSampler(DistributedSampler):
    """Base sampler simplifying sampling from recordings in distributed setting.

    Parameters
    ----------
    metadata : pd.DataFrame
        DataFrame with at least one of {subject, session, run} columns for each
        window in the BaseConcatDataset to sample examples from. Normally
        obtained with `BaseConcatDataset.get_metadata()`. For instance,
        `metadata.head()` might look like this::

               i_window_in_trial  i_start_in_trial  i_stop_in_trial  target  subject    session    run
            0                  0                 0              500      -1        4  session_T  run_0
            1                  1               500             1000      -1        4  session_T  run_0
            2                  2              1000             1500      -1        4  session_T  run_0
            3                  3              1500             2000      -1        4  session_T  run_0
            4                  4              2000             2500      -1        4  session_T  run_0

    random_state : np.RandomState | int | None
        Random state. Only an integer value is also used as the ``seed`` of
        the underlying ``DistributedSampler``; for any other value that seed
        falls back to 0.
    kwargs : dict
        Additional keyword arguments to pass to torch DistributedSampler.
        See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler

    Attributes
    ----------
    info : pd.DataFrame
        DataFrame indexed by the available grouping keys (subject, session,
        run) with one row per recording. The ``index`` and
        ``i_start_in_trial`` columns hold, for each recording, the arrays of
        unique window indices and window start positions, for quick sampling
        of windows.
    n_recordings : int
        Number of recordings available to the current process.
    """

    def __init__(
        self,
        metadata,
        random_state=None,
        **kwargs,
    ):
        self.metadata = metadata
        self.info = self._init_info(metadata)
        self.rng = check_random_state(random_state)
        # send information to DistributedSampler parent to handle data splitting among workers
        super().__init__(self.info, seed=self._resolve_seed(random_state), **kwargs)

    @staticmethod
    def _resolve_seed(random_state):
        """Return an integer seed usable by ``DistributedSampler``.

        ``DistributedSampler`` adds its seed to the epoch number when
        shuffling, so it needs a real integer. ``None`` or a ``RandomState``
        instance would make iteration fail, hence the fallback to 0.
        """
        if isinstance(random_state, (int, np.integer)):
            return int(random_state)
        return 0

    def _init_info(self, metadata, required_keys=None):
        """Initialize ``info`` DataFrame.

        Parameters
        ----------
        metadata : pd.DataFrame
            Window metadata to group; see the class parameters.
        required_keys : list(str) | None
            List of additional columns of the metadata DataFrame that we should
            groupby when creating ``info``.

        Returns
        -------
        pd.DataFrame
            See class attributes.

        Raises
        ------
        ValueError
            If none of the subject/session/run columns is present, or if one
            of ``required_keys`` is not a column of ``metadata``.
        """
        # Validate against the DataFrame that is actually grouped below.
        keys = [k for k in ("subject", "session", "run") if k in metadata.columns]
        if not keys:
            raise ValueError(
                "metadata must contain at least one of the following columns: "
                "subject, session or run."
            )

        if required_keys is not None:
            missing_keys = [k for k in required_keys if k not in metadata.columns]
            if missing_keys:
                raise ValueError(f"Columns {missing_keys} were not found in metadata.")
            keys += required_keys

        # The first reset keeps the original index as ``window_index``; the
        # second one creates a fresh positional ``index`` column that is used
        # as the window index from here on.
        metadata = metadata.reset_index().rename(columns={"index": "window_index"})
        info = (
            metadata.reset_index()
            .groupby(keys)[["index", "i_start_in_trial"]]
            .agg(["unique"])
        )
        # Drop the ("unique",) level that ``agg`` adds to the column names.
        info.columns = info.columns.get_level_values(0)

        return info

    def sample_recording(self):
        """Return a random recording index.

        The index is drawn from ``super().__iter__()``, which contains the
        indices of the recordings specific to the current process as
        determined by the DistributedSampler.
        """
        return self.rng.choice(list(super().__iter__()))

    def sample_window(self, rec_ind=None):
        """Return a random window of a recording.

        Parameters
        ----------
        rec_ind : int | None
            Row position of the recording in ``info``. If None, a recording
            is drawn with :meth:`sample_recording` first.

        Returns
        -------
        int
            Index of the sampled window.
        int
            Index of the recording the window was sampled from.
        """
        if rec_ind is None:
            rec_ind = self.sample_recording()
        win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
        return win_ind, rec_ind

    @property
    def n_recordings(self):
        """Length reported by the parent ``DistributedSampler``."""
        return super().__len__()
217
+
218
+
219
class SequenceSampler(RecordingSampler):
    """Sample sequences of consecutive windows.

    Parameters
    ----------
    metadata : pd.DataFrame
        See RecordingSampler.
    n_windows : int
        Number of consecutive windows in a sequence.
    n_windows_stride : int
        Number of windows between two consecutive sequences.
    randomize : bool
        If True, sample sequences randomly. If False, sample sequences in
        order.
    random_state : np.random.RandomState | int | None
        Random state.

    Attributes
    ----------
    info : pd.DataFrame
        See RecordingSampler.
    start_inds : np.ndarray of ints
        Array of shape (n_sequences,) with the index of the first window of
        each sequence.
    file_ids : np.ndarray of ints
        Array of shape (n_sequences,) that indicates from which file each
        sequence comes from. Useful e.g. to do self-ensembling.
    """

    def __init__(
        self, metadata, n_windows, n_windows_stride, randomize=False, random_state=None
    ):
        super().__init__(metadata, random_state=random_state)
        self.randomize = randomize
        self.n_windows = n_windows
        self.n_windows_stride = n_windows_stride
        self.start_inds, self.file_ids = self._compute_seq_start_inds()

    def _compute_seq_start_inds(self):
        """Compute sequence start indices.

        Returns
        -------
        np.ndarray :
            Array of shape (n_sequences,) containing the indices of the first
            windows of possible sequences.
        np.ndarray :
            Array of shape (n_sequences,) containing the unique file number of
            each sequence. Useful e.g. to do self-ensembling.
        """
        # Exclude the last ``n_windows - 1`` windows of every recording so a
        # sequence never runs past the end of its recording.
        end_offset = 1 - self.n_windows if self.n_windows > 1 else None
        start_inds = (
            self.info["index"]
            .apply(lambda x: x[: end_offset : self.n_windows_stride])
            .values
        )
        # ``np.repeat`` keeps an integer dtype even when a recording is too
        # short to contribute any sequence (count of 0).
        n_seqs_per_rec = [len(inds) for inds in start_inds]
        file_ids = np.repeat(np.arange(len(start_inds)), n_seqs_per_rec)
        return np.concatenate(start_inds), file_ids

    def __len__(self):
        """Number of sequences yielded per pass."""
        return len(self.start_inds)

    def __iter__(self):
        """Yield tuples of ``n_windows`` consecutive window indices."""
        start_inds = self.start_inds
        if self.randomize:
            # Shuffle a copy so the stored ordering stays untouched.
            start_inds = start_inds.copy()
            self.rng.shuffle(start_inds)
        for start_ind in start_inds:
            yield tuple(range(start_ind, start_ind + self.n_windows))
287
+
288
+
289
class BalancedSequenceSampler(RecordingSampler):
    """Balanced sampling of sequences of consecutive windows with categorical
    targets.

    Balanced sampling of sequences inspired by the approach of [Perslev2021]_:

    1. Uniformly sample a recording out of the available ones.
    2. Uniformly sample one of the classes.
    3. Sample a window of the corresponding class in the selected recording.
    4. Extract a sequence of windows around the sampled window.

    Parameters
    ----------
    metadata : pd.DataFrame
        See RecordingSampler.
        Must contain a column `target` with categorical targets.
    n_windows : int
        Number of consecutive windows in a sequence.
    n_sequences : int
        Number of sequences to sample.
    random_state : np.random.RandomState | int | None
        Random state.

    Attributes
    ----------
    info : pd.DataFrame
        See RecordingSampler.
    info_class : pd.DataFrame
        Same structure as ``info`` but additionally grouped by ``target``.

    References
    ----------
    .. [Perslev2021] Perslev M, Darkner S, Kempfner L, Nikolic M, Jennum PJ,
           Igel C. U-Sleep: resilient high-frequency sleep staging. npj Digit.
           Med. 4, 72 (2021).
           https://github.com/perslev/U-Time/blob/master/utime/models/usleep.py
    """

    def __init__(self, metadata, n_windows, n_sequences=10, random_state=None):
        super().__init__(metadata, random_state=random_state)

        self.n_windows = n_windows
        self.n_sequences = n_sequences
        self.info_class = self._init_info(metadata, required_keys=["target"])

    def sample_class(self, rec_ind=None):
        """Return a random class.

        Parameters
        ----------
        rec_ind : int | None
            Index to the recording to sample from. If None, the recording will
            be uniformly sampled across available recordings.

        Returns
        -------
        int
            Sampled class.
        int
            Index to the recording the class was sampled from.
        """
        if rec_ind is None:
            rec_ind = self.sample_recording()
        available_classes = self.info_class.loc[self.info.iloc[rec_ind].name].index
        return self.rng.choice(available_classes), rec_ind

    def _sample_seq_start_ind(self, rec_ind=None, class_ind=None):
        """Sample a sequence and return its start index.

        Sample a window associated with a random recording and a random class
        and randomly sample a sequence with it inside. The function returns the
        index of the beginning of the sequence.

        Parameters
        ----------
        rec_ind : int | None
            Index to the recording to sample from. If None, the recording will
            be uniformly sampled across available recordings.
        class_ind : int | None
            If provided as int, sample a window of the corresponding class. If
            None, the class will be uniformly sampled across available classes.

        Returns
        -------
        int
            Index of the first window of the sequence.
        int
            Corresponding recording index.
        int
            Class of the sampled window.

        Raises
        ------
        ValueError
            If the selected recording has fewer than ``n_windows`` windows.
        """
        # Sample the recording first so that an explicit ``class_ind`` can be
        # combined with ``rec_ind=None``.
        if rec_ind is None:
            rec_ind = self.sample_recording()
        if class_ind is None:
            class_ind, rec_ind = self.sample_class(rec_ind)

        rec_inds = self.info.iloc[rec_ind]["index"]
        len_rec_inds = len(rec_inds)
        if len_rec_inds < self.n_windows:
            raise ValueError(
                f"Recording {rec_ind} contains {len_rec_inds} windows, which is "
                f"fewer than n_windows={self.n_windows}."
            )

        row = self.info.iloc[rec_ind].name
        if not isinstance(row, tuple):
            # There's only one category, e.g. "subject"
            row = (row,)
        available_indices = self.info_class.loc[row + (class_ind,), "index"]
        win_ind = self.rng.choice(available_indices)
        win_ind_in_rec = np.where(rec_inds == win_ind)[0][0]

        # Minimum and maximum start positions such that the sequence contains
        # the sampled window and stays inside the recording.
        min_pos = max(0, win_ind_in_rec - self.n_windows + 1)
        max_pos = min(len_rec_inds - self.n_windows, win_ind_in_rec)
        start_ind = rec_inds[self.rng.randint(min_pos, max_pos + 1)]

        return start_ind, rec_ind, class_ind

    def __len__(self):
        """Number of sequences yielded per pass (``n_sequences``)."""
        return self.n_sequences

    def __iter__(self):
        """Yield ``n_sequences`` tuples of ``n_windows`` consecutive window indices."""
        for _ in range(self.n_sequences):
            start_ind, _, _ = self._sample_seq_start_ind()
            yield tuple(range(start_ind, start_ind + self.n_windows))