braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/regressor.py CHANGED
@@ -11,8 +11,8 @@ import warnings
 import numpy as np
 from skorch.regressor import NeuralNetRegressor
 
-from .training.scoring import predict_trials
 from .eegneuralnet import _EEGNeuralNet
+from .training.scoring import predict_trials
 from .util import ThrowAwayIndexLoader, update_estimator_docstring
 
 
@@ -58,19 +58,28 @@ class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
     """  # noqa: E501
     __doc__ = update_estimator_docstring(NeuralNetRegressor, doc)
 
-    def __init__(self, module, *args, cropped=False, callbacks=None,
-                 iterator_train__shuffle=True,
-                 iterator_train__drop_last=True,
-                 aggregate_predictions=True, **kwargs):
+    def __init__(
+        self,
+        module,
+        *args,
+        cropped=False,
+        callbacks=None,
+        iterator_train__shuffle=True,
+        iterator_train__drop_last=True,
+        aggregate_predictions=True,
+        **kwargs,
+    ):
         self.cropped = cropped
         self.aggregate_predictions = aggregate_predictions
         self._last_window_inds_ = None
-        super().__init__(module,
-                         *args,
-                         callbacks=callbacks,
-                         iterator_train__shuffle=iterator_train__shuffle,
-                         iterator_train__drop_last=iterator_train__drop_last,
-                         **kwargs)
+        super().__init__(
+            module,
+            *args,
+            callbacks=callbacks,
+            iterator_train__shuffle=iterator_train__shuffle,
+            iterator_train__drop_last=iterator_train__drop_last,
+            **kwargs,
+        )
 
     def get_iterator(self, dataset, training=False, drop_index=True):
         iterator = super().get_iterator(dataset, training=training)
@@ -155,7 +164,9 @@ class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
             warnings.warn(
                 "This method was designed to predict trials in cropped mode. "
                 "Calling it when cropped is False will give the same result as "
-                "'.predict'.", UserWarning)
+                "'.predict'.",
+                UserWarning,
+            )
         preds = self.predict(X)
         if return_targets:
             return preds, np.concatenate([X[i][1] for i in range(len(X))])
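
The constructor changes above are cosmetic, but the `predict_trials` warning clarifies intended usage: the method is meant for cropped decoding. As a rough, hedged sketch of how this estimator is typically driven (the model and dataset variables are placeholders, not part of this diff):

    import torch
    from braindecode.regressor import EEGRegressor
    from braindecode.training import CroppedLoss

    # `module` stands in for any braindecode model configured for regression;
    # its construction is omitted here.
    regressor = EEGRegressor(
        module,
        cropped=True,                 # decode each trial as a series of crops
        aggregate_predictions=True,   # average crop predictions per trial
        criterion=CroppedLoss,        # cropped decoding wraps the base loss
        criterion__loss_function=torch.nn.functional.mse_loss,
        optimizer=torch.optim.Adam,
    )
    # regressor.fit(windows_dataset, y=None)
    # preds, targets = regressor.predict_trials(windows_dataset, return_targets=True)
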
braindecode/samplers/__init__.py CHANGED
@@ -1,10 +1,18 @@
-"""Classes to sample examples.
-"""
+"""Classes to sample examples."""
 
-from .base import RecordingSampler, SequenceSampler, BalancedSequenceSampler
-from .ssl import RelativePositioningSampler
+from .base import (
+    BalancedSequenceSampler,
+    DistributedRecordingSampler,
+    RecordingSampler,
+    SequenceSampler,
+)
+from .ssl import DistributedRelativePositioningSampler, RelativePositioningSampler
 
-__all__ = ["RecordingSampler",
-           "SequenceSampler",
-           "BalancedSequenceSampler",
-           "RelativePositioningSampler"]
+__all__ = [
+    "RecordingSampler",
+    "SequenceSampler",
+    "BalancedSequenceSampler",
+    "RelativePositioningSampler",
+    "DistributedRecordingSampler",
+    "DistributedRelativePositioningSampler",
+]
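
Since both new distributed samplers are re-exported here, downstream code can import them directly from the subpackage, e.g.:

    from braindecode.samplers import (
        DistributedRecordingSampler,
        DistributedRelativePositioningSampler,
    )
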
braindecode/samplers/base.py CHANGED
@@ -4,17 +4,120 @@ Sampler classes.
 
 # Authors: Hubert Banville <hubert.jbanville@gmail.com>
 #          Theo Gnassounou <>
+#          Young Truong <dt.young112@gmail.com>
 #
 # License: BSD (3-clause)
 
+from typing import Optional
+
 import numpy as np
-from torch.utils.data.sampler import Sampler
 from sklearn.utils import check_random_state
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import Sampler
 
 
 class RecordingSampler(Sampler):
     """Base sampler simplifying sampling from recordings.
 
+    Parameters
+    ----------
+    metadata : pd.DataFrame
+        DataFrame with at least one of {subject, session, run} columns for each
+        window in the BaseConcatDataset to sample examples from. Normally
+        obtained with `BaseConcatDataset.get_metadata()`. For instance,
+        `metadata.head()` might look like this:
+    +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+    | i_window_in_trial | i_start_in_trial | i_stop_in_trial | target | subject | session   | run   |
+    +===================+==================+=================+========+=========+===========+=======+
+    | 0                 | 0                | 500             | -1     | 4       | session_T | run_0 |
+    +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+    | 1                 | 500              | 1000            | -1     | 4       | session_T | run_0 |
+    +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+    | 2                 | 1000             | 1500            | -1     | 4       | session_T | run_0 |
+    +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+    | 3                 | 1500             | 2000            | -1     | 4       | session_T | run_0 |
+    +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+    | 4                 | 2000             | 2500            | -1     | 4       | session_T | run_0 |
+    +-------------------+------------------+-----------------+--------+---------+-----------+-------+
+
+    random_state : np.RandomState | int | None
+        Random state.
+
+    Attributes
+    ----------
+    info : pd.DataFrame
+        Series with MultiIndex index which contains the subject, session, run
+        and window indices information in an easily accessible structure for
+        quick sampling of windows.
+    n_recordings : int
+        Number of recordings available.
+    """
+
+    def __init__(self, metadata, random_state=None):
+        self.metadata = metadata
+        self.info = self._init_info(metadata)
+        self.rng = check_random_state(random_state)
+
+    def _init_info(self, metadata, required_keys=None):
+        """Initialize ``info`` DataFrame.
+
+        Parameters
+        ----------
+        required_keys : list(str) | None
+            List of additional columns of the metadata DataFrame that we should
+            groupby when creating ``info``.
+
+        Returns
+        -------
+        See class attributes.
+        """
+        keys = [k for k in ["subject", "session", "run"] if k in self.metadata.columns]
+        if not keys:
+            raise ValueError(
+                "metadata must contain at least one of the following columns: "
+                "subject, session or run."
+            )
+
+        if required_keys is not None:
+            missing_keys = [k for k in required_keys if k not in self.metadata.columns]
+            if len(missing_keys) > 0:
+                raise ValueError(f"Columns {missing_keys} were not found in metadata.")
+            keys += required_keys
+
+        metadata = metadata.reset_index().rename(columns={"index": "window_index"})
+        info = (
+            metadata.reset_index()
+            .groupby(keys)[["index", "i_start_in_trial"]]
+            .agg(["unique"])
+        )
+        info.columns = info.columns.get_level_values(0)
+
+        return info
+
+    def sample_recording(self):
+        """Return a random recording index."""
+        # XXX docstring missing
+        return self.rng.choice(self.n_recordings)
+
+    def sample_window(self, rec_ind=None):
+        """Return a specific window."""
+        # XXX docstring missing
+        if rec_ind is None:
+            rec_ind = self.sample_recording()
+        win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
+        return win_ind, rec_ind
+
+    def __iter__(self):
+        raise NotImplementedError
+
+    @property
+    def n_recordings(self):
+        return self.info.shape[0]
+
+
+class DistributedRecordingSampler(DistributedSampler):
+    """Base sampler simplifying sampling from recordings in distributed setting.
+
     Parameters
     ----------
     metadata : pd.DataFrame
@@ -41,11 +144,22 @@ class RecordingSampler(Sampler):
        quick sampling of windows.
     n_recordings : int
         Number of recordings available.
+    kwargs : dict
+        Additional keyword arguments to pass to torch DistributedSampler.
+        See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
     """
-    def __init__(self, metadata, random_state=None):
+
+    def __init__(
+        self,
+        metadata,
+        random_state=None,
+        **kwargs,
+    ):
        self.metadata = metadata
        self.info = self._init_info(metadata)
        self.rng = check_random_state(random_state)
+        # send information to DistributedSampler parent to handle data splitting among workers
+        super().__init__(self.info, seed=random_state, **kwargs)
 
     def _init_info(self, metadata, required_keys=None):
         """Initialize ``info`` DataFrame.
@@ -60,50 +174,48 @@ class RecordingSampler(Sampler):
         Returns
         -------
         See class attributes.
         """
-        keys = [k for k in ['subject', 'session', 'run']
-                if k in self.metadata.columns]
+        keys = [k for k in ["subject", "session", "run"] if k in self.metadata.columns]
         if not keys:
             raise ValueError(
-                'metadata must contain at least one of the following columns: '
-                'subject, session or run.')
+                "metadata must contain at least one of the following columns: "
+                "subject, session or run."
+            )
 
         if required_keys is not None:
-            missing_keys = [
-                k for k in required_keys if k not in self.metadata.columns]
+            missing_keys = [k for k in required_keys if k not in self.metadata.columns]
             if len(missing_keys) > 0:
-                raise ValueError(
-                    f'Columns {missing_keys} were not found in metadata.')
+                raise ValueError(f"Columns {missing_keys} were not found in metadata.")
             keys += required_keys
 
-        metadata = metadata.reset_index().rename(
-            columns={'index': 'window_index'})
-        info = metadata.reset_index().groupby(keys)[
-            ['index', 'i_start_in_trial']].agg(['unique'])
+        metadata = metadata.reset_index().rename(columns={"index": "window_index"})
+        info = (
+            metadata.reset_index()
+            .groupby(keys)[["index", "i_start_in_trial"]]
+            .agg(["unique"])
+        )
         info.columns = info.columns.get_level_values(0)
 
         return info
 
     def sample_recording(self):
         """Return a random recording index.
+        super().__iter__() contains indices of datasets specific to the current process
+        determined by the DistributedSampler
         """
         # XXX docstring missing
-        return self.rng.choice(self.n_recordings)
+        return self.rng.choice(list(super().__iter__()))
 
     def sample_window(self, rec_ind=None):
-        """Return a specific window.
-        """
+        """Return a specific window."""
         # XXX docstring missing
         if rec_ind is None:
             rec_ind = self.sample_recording()
-        win_ind = self.rng.choice(self.info.iloc[rec_ind]['index'])
+        win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
         return win_ind, rec_ind
 
-    def __iter__(self):
-        raise NotImplementedError
-
     @property
     def n_recordings(self):
-        return self.info.shape[0]
+        return super().__len__()
 
 
 class SequenceSampler(RecordingSampler):
@@ -131,8 +243,10 @@ class SequenceSampler(RecordingSampler):
        Array of shape (n_sequences,) that indicates from which file each
        sequence comes from. Useful e.g. to do self-ensembling.
     """
-    def __init__(self, metadata, n_windows, n_windows_stride, randomize=False,
-                 random_state=None):
+
+    def __init__(
+        self, metadata, n_windows, n_windows_stride, randomize=False, random_state=None
+    ):
         super().__init__(metadata, random_state=random_state)
         self.randomize = randomize
         self.n_windows = n_windows
@@ -152,8 +266,11 @@ class SequenceSampler(RecordingSampler):
            each sequence. Useful e.g. to do self-ensembling.
         """
         end_offset = 1 - self.n_windows if self.n_windows > 1 else None
-        start_inds = self.info['index'].apply(
-            lambda x: x[:end_offset:self.n_windows_stride]).values
+        start_inds = (
+            self.info["index"]
+            .apply(lambda x: x[: end_offset : self.n_windows_stride])
+            .values
+        )
         file_ids = [[i] * len(inds) for i, inds in enumerate(start_inds)]
         return np.concatenate(start_inds), np.concatenate(file_ids)
 
@@ -200,12 +317,13 @@ class BalancedSequenceSampler(RecordingSampler):
        Med. 4, 72 (2021).
        https://github.com/perslev/U-Time/blob/master/utime/models/usleep.py
     """
+
     def __init__(self, metadata, n_windows, n_sequences=10, random_state=None):
         super().__init__(metadata, random_state=random_state)
 
         self.n_windows = n_windows
         self.n_sequences = n_sequences
-        self.info_class = self._init_info(metadata, required_keys=['target'])
+        self.info_class = self._init_info(metadata, required_keys=["target"])
 
     def sample_class(self, rec_ind=None):
         """Return a random class.
@@ -225,8 +343,7 @@ class BalancedSequenceSampler(RecordingSampler):
         """
         if rec_ind is None:
             rec_ind = self.sample_recording()
-        available_classes = self.info_class.loc[
-            self.info.iloc[rec_ind].name].index
+        available_classes = self.info_class.loc[self.info.iloc[rec_ind].name].index
         return self.rng.choice(available_classes), rec_ind
 
     def _sample_seq_start_ind(self, rec_ind=None, class_ind=None):
@@ -257,15 +374,14 @@ class BalancedSequenceSampler(RecordingSampler):
         if class_ind is None:
             class_ind, rec_ind = self.sample_class(rec_ind)
 
-        rec_inds = self.info.iloc[rec_ind]['index']
+        rec_inds = self.info.iloc[rec_ind]["index"]
         len_rec_inds = len(rec_inds)
 
         row = self.info.iloc[rec_ind].name
         if not isinstance(row, tuple):
             # Theres's only one category, e.g. "subject"
             row = tuple([row])
-        available_indices = self.info_class.loc[
-            row + tuple([class_ind]), 'index']
+        available_indices = self.info_class.loc[row + tuple([class_ind]), "index"]
         win_ind = self.rng.choice(available_indices)
         win_ind_in_rec = np.where(rec_inds == win_ind)[0][0]
 
braindecode/samplers/ssl.py CHANGED
@@ -3,12 +3,16 @@ Self-supervised learning samplers.
 """
 
 # Authors: Hubert Banville <hubert.jbanville@gmail.com>
+#          Young Truong <dt.young112@gmail.com>
 #
 # License: BSD (3-clause)
 
+import warnings
+
 import numpy as np
+import torch.distributed as dist
 
-from . import RecordingSampler
+from . import DistributedRecordingSampler, RecordingSampler
 
 
 class RelativePositioningSampler(RecordingSampler):
@@ -45,8 +49,17 @@ class RelativePositioningSampler(RecordingSampler):
       signals with self-supervised learning.
       arXiv preprint arXiv:2007.16104.
    """
-    def __init__(self, metadata, tau_pos, tau_neg, n_examples, tau_max=None,
-                 same_rec_neg=True, random_state=None):
+
+    def __init__(
+        self,
+        metadata,
+        tau_pos,
+        tau_neg,
+        n_examples,
+        tau_max=None,
+        same_rec_neg=True,
+        random_state=None,
+    ):
         super().__init__(metadata, random_state=random_state)
 
         self.tau_pos = tau_pos
@@ -56,25 +69,153 @@ class RelativePositioningSampler(RecordingSampler):
         self.same_rec_neg = same_rec_neg
 
         if not same_rec_neg and self.n_recordings < 2:
-            raise ValueError('More than one recording must be available when '
-                             'using across-recording negative sampling.')
+            raise ValueError(
+                "More than one recording must be available when "
+                "using across-recording negative sampling."
+            )
 
     def _sample_pair(self):
-        """Sample a pair of two windows.
+        """Sample a pair of two windows."""
+        # Sample first window
+        win_ind1, rec_ind1 = self.sample_window()
+        ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
+        ts = self.info.iloc[rec_ind1]["i_start_in_trial"]
+
+        # Decide whether the pair will be positive or negative
+        pair_type = self.rng.binomial(1, 0.5)
+        win_ind2 = None
+        if pair_type == 0:  # Negative example
+            if self.same_rec_neg:
+                mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
+                    (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
+                )
+            else:
+                rec_ind2 = rec_ind1
+                while rec_ind2 == rec_ind1:
+                    win_ind2, rec_ind2 = self.sample_window()
+        elif pair_type == 1:  # Positive example
+            mask = (ts >= ts1 - self.tau_pos) & (ts <= ts1 + self.tau_pos)
+
+        if win_ind2 is None:
+            mask[ts == ts1] = False  # same window cannot be sampled twice
+            if sum(mask) == 0:
+                raise NotImplementedError
+            win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])
+
+        return win_ind1, win_ind2, float(pair_type)
+
+    def presample(self):
+        """Presample examples.
+
+        Once presampled, the examples are the same from one epoch to another.
+        """
+        self.examples = [self._sample_pair() for _ in range(self.n_examples)]
+        return self
+
+    def __iter__(self):
+        """
+        Iterate over pairs.
+
+        Yields
+        ------
+        int
+            Position of the first window in the dataset.
+        int
+            Position of the second window in the dataset.
+        float
+            0 for a negative pair, 1 for a positive pair.
         """
+        for i in range(self.n_examples):
+            if hasattr(self, "examples"):
+                yield self.examples[i]
+            else:
+                yield self._sample_pair()
+
+    def __len__(self):
+        return self.n_examples
+
+
+class DistributedRelativePositioningSampler(DistributedRecordingSampler):
+    """Sample examples for the relative positioning task from [Banville2020]_ in distributed mode.
+
+    Sample examples as tuples of two window indices, with a label indicating
+    whether the windows are close or far, as defined by tau_pos and tau_neg.
+
+    Parameters
+    ----------
+    metadata : pd.DataFrame
+        See RecordingSampler.
+    tau_pos : int
+        Size of the positive context, in samples. A positive pair contains two
+        windows x1 and x2 which are separated by at most `tau_pos` samples.
+    tau_neg : int
+        Size of the negative context, in samples. A negative pair contains two
+        windows x1 and x2 which are separated by at least `tau_neg` samples and
+        at most `tau_max` samples. Ignored if `same_rec_neg` is False.
+    n_examples : int
+        Number of pairs to extract.
+    tau_max : int | None
+        See `tau_neg`.
+    same_rec_neg : bool
+        If True, sample negative pairs from within the same recording. If
+        False, sample negative pairs from two different recordings.
+    random_state : None | np.RandomState | int
+        Random state.
+    kwargs : dict
+        Additional keyword arguments to pass to torch DistributedSampler.
+        See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
+
+    References
+    ----------
+    .. [Banville2020] Banville, H., Chehab, O., Hyvärinen, A., Engemann, D. A.,
+       & Gramfort, A. (2020). Uncovering the structure of clinical EEG
+       signals with self-supervised learning.
+       arXiv preprint arXiv:2007.16104.
+    """
+
+    def __init__(
+        self,
+        metadata,
+        tau_pos,
+        tau_neg,
+        n_examples,
+        tau_max=None,
+        same_rec_neg=True,
+        random_state=None,
+        **kwargs,
+    ):
+        super().__init__(metadata, random_state=random_state, **kwargs)
+        self.tau_pos = tau_pos
+        self.tau_neg = tau_neg
+        self.tau_max = np.inf if tau_max is None else tau_max
+        self.same_rec_neg = same_rec_neg
+
+        self.n_examples = n_examples // self.info.shape[0] * self.n_recordings
+        warnings.warn(
+            f"Rank {dist.get_rank()} - Number of datasets: {self.n_recordings}"
+        )
+        warnings.warn(f"Rank {dist.get_rank()} - Number of samples: {self.n_examples}")
+
+        if not same_rec_neg and self.n_recordings < 2:
+            raise ValueError(
+                "More than one recording must be available when "
+                "using across-recording negative sampling."
+            )
+
+    def _sample_pair(self):
+        """Sample a pair of two windows."""
         # Sample first window
         win_ind1, rec_ind1 = self.sample_window()
-        ts1 = self.metadata.iloc[win_ind1]['i_start_in_trial']
-        ts = self.info.iloc[rec_ind1]['i_start_in_trial']
+        ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
+        ts = self.info.iloc[rec_ind1]["i_start_in_trial"]
 
         # Decide whether the pair will be positive or negative
         pair_type = self.rng.binomial(1, 0.5)
         win_ind2 = None
         if pair_type == 0:  # Negative example
             if self.same_rec_neg:
-                mask = (
-                    ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) |
-                    ((ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max))
+                mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
+                    (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
                 )
             else:
                 rec_ind2 = rec_ind1
@@ -87,7 +228,7 @@ class RelativePositioningSampler(RecordingSampler):
             mask[ts == ts1] = False  # same window cannot be sampled twice
             if sum(mask) == 0:
                 raise NotImplementedError
-            win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]['index'][mask])
+            win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])
 
         return win_ind1, win_ind2, float(pair_type)
 
@@ -100,16 +241,20 @@ class RelativePositioningSampler(RecordingSampler):
         return self
 
     def __iter__(self):
-        """Iterate over pairs.
+        """
+        Iterate over pairs.
 
         Yields
         ------
-        (int): position of the first window in the dataset.
-        (int): position of the second window in the dataset.
-        (float): 0 for negative pair, 1 for positive pair.
+        int
+            Position of the first window in the dataset.
+        int
+            Position of the second window in the dataset.
+        float
+            0 for a negative pair, 1 for a positive pair.
         """
         for i in range(self.n_examples):
-            if hasattr(self, 'examples'):
+            if hasattr(self, "examples"):
                 yield self.examples[i]
             else:
                 yield self._sample_pair()
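
Putting the relative-positioning pieces together, a hedged sketch reusing the toy `metadata` from the RecordingSampler example above (the tau values are in samples and purely illustrative; the sampler yields `(win_ind1, win_ind2, label)` triplets as documented in `__iter__`):

    from braindecode.samplers import RelativePositioningSampler

    sampler = RelativePositioningSampler(
        metadata,
        tau_pos=500,      # windows <= 500 samples apart form positive pairs
        tau_neg=1000,     # windows >= 1000 samples apart form negative pairs
        n_examples=8,
        same_rec_neg=True,
        random_state=87,
    ).presample()         # freeze the pairs so every epoch iterates the same examples

    for win_ind1, win_ind2, label in sampler:
        ...               # fetch the two windows by index; label is 0.0 (far) or 1.0 (close)

`DistributedRelativePositioningSampler` takes the same arguments plus DistributedSampler kwargs, and as the constructor above shows, rescales `n_examples` to the share of recordings owned by the current rank.
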
braindecode/training/__init__.py CHANGED
@@ -2,14 +2,22 @@
 Functionality for skorch-based training.
 """
 
+from .losses import CroppedLoss, TimeSeriesLoss, mixup_criterion
+from .scoring import (
+    CroppedTimeSeriesEpochScoring,
+    CroppedTrialEpochScoring,
+    PostEpochTrainScoring,
+    predict_trials,
+    trial_preds_from_window_preds,
+)
 
-from .losses import CroppedLoss, mixup_criterion, TimeSeriesLoss
-from .scoring import (CroppedTrialEpochScoring, PostEpochTrainScoring,
-                      CroppedTimeSeriesEpochScoring, trial_preds_from_window_preds, predict_trials)
-
-__all__ = ["CroppedLoss", "mixup_criterion", "TimeSeriesLoss",
-           "CroppedTrialEpochScoring",
-           "PostEpochTrainScoring",
-           "CroppedTimeSeriesEpochScoring",
-           "trial_preds_from_window_preds",
-           "predict_trials"]
+__all__ = [
+    "CroppedLoss",
+    "mixup_criterion",
+    "TimeSeriesLoss",
+    "CroppedTrialEpochScoring",
+    "PostEpochTrainScoring",
+    "CroppedTimeSeriesEpochScoring",
+    "trial_preds_from_window_preds",
+    "predict_trials",
+]
braindecode/training/callbacks.py CHANGED
@@ -2,8 +2,8 @@
 #
 # License: BSD (3-clause)
 
-from skorch.callbacks import Callback
 import torch
+from skorch.callbacks import Callback
 
 
 class MaxNormConstraintCallback(Callback):
@@ -20,6 +20,4 @@ class MaxNormConstraintCallback(Callback):
             )
             last_weight = module.weight
         if last_weight is not None:
-            last_weight.data = torch.renorm(
-                last_weight.data, 2, 0, maxnorm=0.5
-            )
+            last_weight.data = torch.renorm(last_weight.data, 2, 0, maxnorm=0.5)
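
For reference, `torch.renorm` is what enforces the max-norm constraint here: it rescales every slice along `dim` whose p-norm exceeds `maxnorm`. A small self-contained illustration:

    import torch

    w = 3 * torch.randn(4, 10)                 # rows with L2 norms well above 0.5
    w = torch.renorm(w, p=2, dim=0, maxnorm=0.5)
    print(w.norm(p=2, dim=1))                  # every row norm is now <= 0.5
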
braindecode/training/losses.py CHANGED
@@ -96,15 +96,10 @@ def mixup_criterion(preds, target):
         # unpack target
         y_a, y_b, lam = target
         # compute loss per sample
-        loss_a = torch.nn.functional.nll_loss(preds,
-                                              y_a,
-                                              reduction='none')
-        loss_b = torch.nn.functional.nll_loss(preds,
-                                              y_b,
-                                              reduction='none')
+        loss_a = torch.nn.functional.nll_loss(preds, y_a, reduction="none")
+        loss_b = torch.nn.functional.nll_loss(preds, y_b, reduction="none")
         # compute weighted mean
         ret = torch.mul(lam, loss_a) + torch.mul(1 - lam, loss_b)
         return ret.mean()
     else:
-        return torch.nn.functional.nll_loss(preds,
-                                            target)
+        return torch.nn.functional.nll_loss(preds, target)
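
The reformatted code computes the standard mixup objective unchanged: `lam * NLL(preds, y_a) + (1 - lam) * NLL(preds, y_b)`, averaged over the batch. A small self-contained sketch of the same computation (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    preds = torch.log_softmax(torch.randn(8, 4), dim=1)  # nll_loss expects log-probabilities
    y_a = torch.randint(0, 4, (8,))
    y_b = torch.randint(0, 4, (8,))
    lam = torch.rand(8)                                   # per-sample mixing coefficients

    loss_a = F.nll_loss(preds, y_a, reduction="none")
    loss_b = F.nll_loss(preds, y_b, reduction="none")
    loss = (lam * loss_a + (1 - lam) * loss_b).mean()
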