braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. See the package's registry page for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/regressor.py CHANGED
@@ -11,8 +11,8 @@ import warnings
11
11
  import numpy as np
12
12
  from skorch.regressor import NeuralNetRegressor
13
13
 
14
- from .training.scoring import predict_trials
15
14
  from .eegneuralnet import _EEGNeuralNet
15
+ from .training.scoring import predict_trials
16
16
  from .util import ThrowAwayIndexLoader, update_estimator_docstring
17
17
 
18
18
 
@@ -58,19 +58,28 @@ class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
58
58
  """ # noqa: E501
59
59
  __doc__ = update_estimator_docstring(NeuralNetRegressor, doc)
60
60
 
61
- def __init__(self, module, *args, cropped=False, callbacks=None,
62
- iterator_train__shuffle=True,
63
- iterator_train__drop_last=True,
64
- aggregate_predictions=True, **kwargs):
61
+ def __init__(
62
+ self,
63
+ module,
64
+ *args,
65
+ cropped=False,
66
+ callbacks=None,
67
+ iterator_train__shuffle=True,
68
+ iterator_train__drop_last=True,
69
+ aggregate_predictions=True,
70
+ **kwargs,
71
+ ):
65
72
  self.cropped = cropped
66
73
  self.aggregate_predictions = aggregate_predictions
67
74
  self._last_window_inds_ = None
68
- super().__init__(module,
69
- *args,
70
- callbacks=callbacks,
71
- iterator_train__shuffle=iterator_train__shuffle,
72
- iterator_train__drop_last=iterator_train__drop_last,
73
- **kwargs)
75
+ super().__init__(
76
+ module,
77
+ *args,
78
+ callbacks=callbacks,
79
+ iterator_train__shuffle=iterator_train__shuffle,
80
+ iterator_train__drop_last=iterator_train__drop_last,
81
+ **kwargs,
82
+ )
74
83
 
75
84
  def get_iterator(self, dataset, training=False, drop_index=True):
76
85
  iterator = super().get_iterator(dataset, training=training)
@@ -155,7 +164,9 @@ class EEGRegressor(_EEGNeuralNet, NeuralNetRegressor):
155
164
  warnings.warn(
156
165
  "This method was designed to predict trials in cropped mode. "
157
166
  "Calling it when cropped is False will give the same result as "
158
- "'.predict'.", UserWarning)
167
+ "'.predict'.",
168
+ UserWarning,
169
+ )
159
170
  preds = self.predict(X)
160
171
  if return_targets:
161
172
  return preds, np.concatenate([X[i][1] for i in range(len(X))])
@@ -1,10 +1,18 @@
1
- """Classes to sample examples.
2
- """
1
+ """Classes to sample examples."""
3
2
 
4
- from .base import RecordingSampler, SequenceSampler, BalancedSequenceSampler
5
- from .ssl import RelativePositioningSampler
3
+ from .base import (
4
+ BalancedSequenceSampler,
5
+ DistributedRecordingSampler,
6
+ RecordingSampler,
7
+ SequenceSampler,
8
+ )
9
+ from .ssl import DistributedRelativePositioningSampler, RelativePositioningSampler
6
10
 
7
- __all__ = ["RecordingSampler",
8
- "SequenceSampler",
9
- "BalancedSequenceSampler",
10
- "RelativePositioningSampler"]
11
+ __all__ = [
12
+ "RecordingSampler",
13
+ "SequenceSampler",
14
+ "BalancedSequenceSampler",
15
+ "RelativePositioningSampler",
16
+ "DistributedRecordingSampler",
17
+ "DistributedRelativePositioningSampler",
18
+ ]
@@ -4,17 +4,118 @@ Sampler classes.
4
4
 
5
5
  # Authors: Hubert Banville <hubert.jbanville@gmail.com>
6
6
  # Theo Gnassounou <>
7
+ # Young Truong <dt.young112@gmail.com>
7
8
  #
8
9
  # License: BSD (3-clause)
9
10
 
10
11
  import numpy as np
11
- from torch.utils.data.sampler import Sampler
12
12
  from sklearn.utils import check_random_state
13
+ from torch.utils.data.distributed import DistributedSampler
14
+ from torch.utils.data.sampler import Sampler
13
15
 
14
16
 
15
17
  class RecordingSampler(Sampler):
16
18
  """Base sampler simplifying sampling from recordings.
17
19
 
20
+ Parameters
21
+ ----------
22
+ metadata : pd.DataFrame
23
+ DataFrame with at least one of {subject, session, run} columns for each
24
+ window in the BaseConcatDataset to sample examples from. Normally
25
+ obtained with `BaseConcatDataset.get_metadata()`. For instance,
26
+ `metadata.head()` might look like this:
27
+ +-------------------+-----------------+-----------------+--------+----------+-----------+-------+
28
+ | i_window_in_trial | i_start_in_trial| i_stop_in_trial | target | subject | session | run |
29
+ +===================+=================+=================+========+==========+===========+=======+
30
+ | 0 | 0 | 500 | -1 | 4 | session_T | run_0 |
31
+ +-------------------+-----------------+-----------------+--------+----------+-----------+-------+
32
+ | 1 | 500 | 1000 | -1 | 4 | session_T | run_0 |
33
+ +-------------------+-----------------+-----------------+--------+----------+-----------+-------+
34
+ | 2 | 1000 | 1500 | -1 | 4 | session_T | run_0 |
35
+ +-------------------+-----------------+-----------------+--------+----------+-----------+-------+
36
+ | 3 | 1500 | 2000 | -1 | 4 | session_T | run_0 |
37
+ +-------------------+-----------------+-----------------+--------+----------+-----------+-------+
38
+ | 4 | 2000 | 2500 | -1 | 4 | session_T | run_0 |
39
+ +-------------------+-----------------+-----------------+--------+----------+-----------+-------+
40
+
41
+ random_state : np.RandomState | int | None
42
+ Random state.
43
+
44
+ Attributes
45
+ ----------
46
+ info : pd.DataFrame
47
+ DataFrame with MultiIndex index which contains the subject, session, run
48
+ and window indices information in an easily accessible structure for
49
+ quick sampling of windows.
50
+ n_recordings : int
51
+ Number of recordings available.
52
+ """
53
+
54
+ def __init__(self, metadata, random_state=None):
55
+ self.metadata = metadata
56
+ self.info = self._init_info(metadata)
57
+ self.rng = check_random_state(random_state)
58
+
59
+ def _init_info(self, metadata, required_keys=None):
60
+ """Initialize ``info`` DataFrame.
61
+
62
+ Parameters
63
+ ----------
64
+ required_keys : list(str) | None
65
+ List of additional columns of the metadata DataFrame that we should
66
+ groupby when creating ``info``.
67
+
68
+ Returns
69
+ -------
70
+ See class attributes.
71
+ """
72
+ keys = [k for k in ["subject", "session", "run"] if k in self.metadata.columns]
73
+ if not keys:
74
+ raise ValueError(
75
+ "metadata must contain at least one of the following columns: "
76
+ "subject, session or run."
77
+ )
78
+
79
+ if required_keys is not None:
80
+ missing_keys = [k for k in required_keys if k not in self.metadata.columns]
81
+ if len(missing_keys) > 0:
82
+ raise ValueError(f"Columns {missing_keys} were not found in metadata.")
83
+ keys += required_keys
84
+
85
+ metadata = metadata.reset_index().rename(columns={"index": "window_index"})
86
+ info = (
87
+ metadata.reset_index()
88
+ .groupby(keys)[["index", "i_start_in_trial"]]
89
+ .agg(["unique"])
90
+ )
91
+ info.columns = info.columns.get_level_values(0)
92
+
93
+ return info
94
+
95
+ def sample_recording(self):
96
+ """Return a random recording index."""
97
+ # XXX docstring missing
98
+ return self.rng.choice(self.n_recordings)
99
+
100
+ def sample_window(self, rec_ind=None):
101
+ """Return a specific window."""
102
+ # XXX docstring missing
103
+ if rec_ind is None:
104
+ rec_ind = self.sample_recording()
105
+ win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
106
+ return win_ind, rec_ind
107
+
108
+ def __iter__(self):
109
+ raise NotImplementedError
110
+
111
+ @property
112
+ def n_recordings(self):
113
+ return self.info.shape[0]
114
+
115
+
116
+ class DistributedRecordingSampler(DistributedSampler):
117
+ """Base sampler simplifying sampling from recordings in distributed setting.
118
+
18
119
  Parameters
19
120
  ----------
20
121
  metadata : pd.DataFrame
@@ -41,11 +142,22 @@ class RecordingSampler(Sampler):
41
142
  quick sampling of windows.
42
143
  n_recordings : int
43
144
  Number of recordings available.
145
+ kwargs : dict
146
+ Additional keyword arguments to pass to torch DistributedSampler.
147
+ See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
44
148
  """
45
- def __init__(self, metadata, random_state=None):
149
+
150
+ def __init__(
151
+ self,
152
+ metadata,
153
+ random_state=None,
154
+ **kwargs,
155
+ ):
46
156
  self.metadata = metadata
47
157
  self.info = self._init_info(metadata)
48
158
  self.rng = check_random_state(random_state)
159
+ # send information to DistributedSampler parent to handle data splitting among workers
160
+ super().__init__(self.info, seed=random_state, **kwargs)
49
161
 
50
162
  def _init_info(self, metadata, required_keys=None):
51
163
  """Initialize ``info`` DataFrame.
@@ -60,50 +172,48 @@ class RecordingSampler(Sampler):
60
172
  -------
61
173
  See class attributes.
62
174
  """
63
- keys = [k for k in ['subject', 'session', 'run']
64
- if k in self.metadata.columns]
175
+ keys = [k for k in ["subject", "session", "run"] if k in self.metadata.columns]
65
176
  if not keys:
66
177
  raise ValueError(
67
- 'metadata must contain at least one of the following columns: '
68
- 'subject, session or run.')
178
+ "metadata must contain at least one of the following columns: "
179
+ "subject, session or run."
180
+ )
69
181
 
70
182
  if required_keys is not None:
71
- missing_keys = [
72
- k for k in required_keys if k not in self.metadata.columns]
183
+ missing_keys = [k for k in required_keys if k not in self.metadata.columns]
73
184
  if len(missing_keys) > 0:
74
- raise ValueError(
75
- f'Columns {missing_keys} were not found in metadata.')
185
+ raise ValueError(f"Columns {missing_keys} were not found in metadata.")
76
186
  keys += required_keys
77
187
 
78
- metadata = metadata.reset_index().rename(
79
- columns={'index': 'window_index'})
80
- info = metadata.reset_index().groupby(keys)[
81
- ['index', 'i_start_in_trial']].agg(['unique'])
188
+ metadata = metadata.reset_index().rename(columns={"index": "window_index"})
189
+ info = (
190
+ metadata.reset_index()
191
+ .groupby(keys)[["index", "i_start_in_trial"]]
192
+ .agg(["unique"])
193
+ )
82
194
  info.columns = info.columns.get_level_values(0)
83
195
 
84
196
  return info
85
197
 
86
198
  def sample_recording(self):
87
199
  """Return a random recording index.
200
+ super().__iter__() contains indices of datasets specific to the current process
201
+ determined by the DistributedSampler
88
202
  """
89
203
  # XXX docstring missing
90
- return self.rng.choice(self.n_recordings)
204
+ return self.rng.choice(list(super().__iter__()))
91
205
 
92
206
  def sample_window(self, rec_ind=None):
93
- """Return a specific window.
94
- """
207
+ """Return a specific window."""
95
208
  # XXX docstring missing
96
209
  if rec_ind is None:
97
210
  rec_ind = self.sample_recording()
98
- win_ind = self.rng.choice(self.info.iloc[rec_ind]['index'])
211
+ win_ind = self.rng.choice(self.info.iloc[rec_ind]["index"])
99
212
  return win_ind, rec_ind
100
213
 
101
- def __iter__(self):
102
- raise NotImplementedError
103
-
104
214
  @property
105
215
  def n_recordings(self):
106
- return self.info.shape[0]
216
+ return super().__len__()
107
217
 
108
218
 
109
219
  class SequenceSampler(RecordingSampler):
@@ -131,8 +241,10 @@ class SequenceSampler(RecordingSampler):
131
241
  Array of shape (n_sequences,) that indicates from which file each
132
242
  sequence comes from. Useful e.g. to do self-ensembling.
133
243
  """
134
- def __init__(self, metadata, n_windows, n_windows_stride, randomize=False,
135
- random_state=None):
244
+
245
+ def __init__(
246
+ self, metadata, n_windows, n_windows_stride, randomize=False, random_state=None
247
+ ):
136
248
  super().__init__(metadata, random_state=random_state)
137
249
  self.randomize = randomize
138
250
  self.n_windows = n_windows
@@ -152,8 +264,11 @@ class SequenceSampler(RecordingSampler):
152
264
  each sequence. Useful e.g. to do self-ensembling.
153
265
  """
154
266
  end_offset = 1 - self.n_windows if self.n_windows > 1 else None
155
- start_inds = self.info['index'].apply(
156
- lambda x: x[:end_offset:self.n_windows_stride]).values
267
+ start_inds = (
268
+ self.info["index"]
269
+ .apply(lambda x: x[: end_offset : self.n_windows_stride])
270
+ .values
271
+ )
157
272
  file_ids = [[i] * len(inds) for i, inds in enumerate(start_inds)]
158
273
  return np.concatenate(start_inds), np.concatenate(file_ids)
159
274
 
@@ -200,12 +315,13 @@ class BalancedSequenceSampler(RecordingSampler):
200
315
  Med. 4, 72 (2021).
201
316
  https://github.com/perslev/U-Time/blob/master/utime/models/usleep.py
202
317
  """
318
+
203
319
  def __init__(self, metadata, n_windows, n_sequences=10, random_state=None):
204
320
  super().__init__(metadata, random_state=random_state)
205
321
 
206
322
  self.n_windows = n_windows
207
323
  self.n_sequences = n_sequences
208
- self.info_class = self._init_info(metadata, required_keys=['target'])
324
+ self.info_class = self._init_info(metadata, required_keys=["target"])
209
325
 
210
326
  def sample_class(self, rec_ind=None):
211
327
  """Return a random class.
@@ -225,8 +341,7 @@ class BalancedSequenceSampler(RecordingSampler):
225
341
  """
226
342
  if rec_ind is None:
227
343
  rec_ind = self.sample_recording()
228
- available_classes = self.info_class.loc[
229
- self.info.iloc[rec_ind].name].index
344
+ available_classes = self.info_class.loc[self.info.iloc[rec_ind].name].index
230
345
  return self.rng.choice(available_classes), rec_ind
231
346
 
232
347
  def _sample_seq_start_ind(self, rec_ind=None, class_ind=None):
@@ -257,15 +372,14 @@ class BalancedSequenceSampler(RecordingSampler):
257
372
  if class_ind is None:
258
373
  class_ind, rec_ind = self.sample_class(rec_ind)
259
374
 
260
- rec_inds = self.info.iloc[rec_ind]['index']
375
+ rec_inds = self.info.iloc[rec_ind]["index"]
261
376
  len_rec_inds = len(rec_inds)
262
377
 
263
378
  row = self.info.iloc[rec_ind].name
264
379
  if not isinstance(row, tuple):
265
380
  # There's only one category, e.g. "subject"
266
381
  row = tuple([row])
267
- available_indices = self.info_class.loc[
268
- row + tuple([class_ind]), 'index']
382
+ available_indices = self.info_class.loc[row + tuple([class_ind]), "index"]
269
383
  win_ind = self.rng.choice(available_indices)
270
384
  win_ind_in_rec = np.where(rec_inds == win_ind)[0][0]
271
385
 
@@ -3,12 +3,16 @@ Self-supervised learning samplers.
3
3
  """
4
4
 
5
5
  # Authors: Hubert Banville <hubert.jbanville@gmail.com>
6
+ # Young Truong <dt.young112@gmail.com>
6
7
  #
7
8
  # License: BSD (3-clause)
8
9
 
10
+ import warnings
11
+
9
12
  import numpy as np
13
+ import torch.distributed as dist
10
14
 
11
- from . import RecordingSampler
15
+ from . import DistributedRecordingSampler, RecordingSampler
12
16
 
13
17
 
14
18
  class RelativePositioningSampler(RecordingSampler):
@@ -45,8 +49,17 @@ class RelativePositioningSampler(RecordingSampler):
45
49
  signals with self-supervised learning.
46
50
  arXiv preprint arXiv:2007.16104.
47
51
  """
48
- def __init__(self, metadata, tau_pos, tau_neg, n_examples, tau_max=None,
49
- same_rec_neg=True, random_state=None):
52
+
53
+ def __init__(
54
+ self,
55
+ metadata,
56
+ tau_pos,
57
+ tau_neg,
58
+ n_examples,
59
+ tau_max=None,
60
+ same_rec_neg=True,
61
+ random_state=None,
62
+ ):
50
63
  super().__init__(metadata, random_state=random_state)
51
64
 
52
65
  self.tau_pos = tau_pos
@@ -56,25 +69,153 @@ class RelativePositioningSampler(RecordingSampler):
56
69
  self.same_rec_neg = same_rec_neg
57
70
 
58
71
  if not same_rec_neg and self.n_recordings < 2:
59
- raise ValueError('More than one recording must be available when '
60
- 'using across-recording negative sampling.')
72
+ raise ValueError(
73
+ "More than one recording must be available when "
74
+ "using across-recording negative sampling."
75
+ )
61
76
 
62
77
  def _sample_pair(self):
63
- """Sample a pair of two windows.
78
+ """Sample a pair of two windows."""
79
+ # Sample first window
80
+ win_ind1, rec_ind1 = self.sample_window()
81
+ ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
82
+ ts = self.info.iloc[rec_ind1]["i_start_in_trial"]
83
+
84
+ # Decide whether the pair will be positive or negative
85
+ pair_type = self.rng.binomial(1, 0.5)
86
+ win_ind2 = None
87
+ if pair_type == 0: # Negative example
88
+ if self.same_rec_neg:
89
+ mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
90
+ (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
91
+ )
92
+ else:
93
+ rec_ind2 = rec_ind1
94
+ while rec_ind2 == rec_ind1:
95
+ win_ind2, rec_ind2 = self.sample_window()
96
+ elif pair_type == 1: # Positive example
97
+ mask = (ts >= ts1 - self.tau_pos) & (ts <= ts1 + self.tau_pos)
98
+
99
+ if win_ind2 is None:
100
+ mask[ts == ts1] = False # same window cannot be sampled twice
101
+ if sum(mask) == 0:
102
+ raise NotImplementedError
103
+ win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])
104
+
105
+ return win_ind1, win_ind2, float(pair_type)
106
+
107
+ def presample(self):
108
+ """Presample examples.
109
+
110
+ Once presampled, the examples are the same from one epoch to another.
111
+ """
112
+ self.examples = [self._sample_pair() for _ in range(self.n_examples)]
113
+ return self
114
+
115
+ def __iter__(self):
116
+ """
117
+ Iterate over pairs.
118
+
119
+ Yields
120
+ ------
121
+ int
122
+ Position of the first window in the dataset.
123
+ int
124
+ Position of the second window in the dataset.
125
+ float
126
+ 0 for a negative pair, 1 for a positive pair.
64
127
  """
128
+ for i in range(self.n_examples):
129
+ if hasattr(self, "examples"):
130
+ yield self.examples[i]
131
+ else:
132
+ yield self._sample_pair()
133
+
134
+ def __len__(self):
135
+ return self.n_examples
136
+
137
+
138
+ class DistributedRelativePositioningSampler(DistributedRecordingSampler):
139
+ """Sample examples for the relative positioning task from [Banville2020]_ in distributed mode.
140
+
141
+ Sample examples as tuples of two window indices, with a label indicating
142
+ whether the windows are close or far, as defined by tau_pos and tau_neg.
143
+
144
+ Parameters
145
+ ----------
146
+ metadata : pd.DataFrame
147
+ See RecordingSampler.
148
+ tau_pos : int
149
+ Size of the positive context, in samples. A positive pair contains two
150
+ windows x1 and x2 which are separated by at most `tau_pos` samples.
151
+ tau_neg : int
152
+ Size of the negative context, in samples. A negative pair contains two
153
+ windows x1 and x2 which are separated by at least `tau_neg` samples and
154
+ at most `tau_max` samples. Ignored if `same_rec_neg` is False.
155
+ n_examples : int
156
+ Number of pairs to extract.
157
+ tau_max : int | None
158
+ See `tau_neg`.
159
+ same_rec_neg : bool
160
+ If True, sample negative pairs from within the same recording. If
161
+ False, sample negative pairs from two different recordings.
162
+ random_state : None | np.RandomState | int
163
+ Random state.
164
+ kwargs: dict
165
+ Additional keyword arguments to pass to torch DistributedSampler.
166
+ See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
167
+
168
+ References
169
+ ----------
170
+ .. [Banville2020] Banville, H., Chehab, O., Hyvärinen, A., Engemann, D. A.,
171
+ & Gramfort, A. (2020). Uncovering the structure of clinical EEG
172
+ signals with self-supervised learning.
173
+ arXiv preprint arXiv:2007.16104.
174
+ """
175
+
176
+ def __init__(
177
+ self,
178
+ metadata,
179
+ tau_pos,
180
+ tau_neg,
181
+ n_examples,
182
+ tau_max=None,
183
+ same_rec_neg=True,
184
+ random_state=None,
185
+ **kwargs,
186
+ ):
187
+ super().__init__(metadata, random_state=random_state, **kwargs)
188
+ self.tau_pos = tau_pos
189
+ self.tau_neg = tau_neg
190
+ self.tau_max = np.inf if tau_max is None else tau_max
191
+ self.same_rec_neg = same_rec_neg
192
+
193
+ self.n_examples = n_examples // self.info.shape[0] * self.n_recordings
194
+ warnings.warn(
195
+ f"Rank {dist.get_rank()} - Number of datasets: {self.n_recordings}"
196
+ )
197
+ warnings.warn(f"Rank {dist.get_rank()} - Number of samples: {self.n_examples}")
198
+
199
+ if not same_rec_neg and self.n_recordings < 2:
200
+ raise ValueError(
201
+ "More than one recording must be available when "
202
+ "using across-recording negative sampling."
203
+ )
204
+
205
+ def _sample_pair(self):
206
+ """Sample a pair of two windows."""
65
207
  # Sample first window
66
208
  win_ind1, rec_ind1 = self.sample_window()
67
- ts1 = self.metadata.iloc[win_ind1]['i_start_in_trial']
68
- ts = self.info.iloc[rec_ind1]['i_start_in_trial']
209
+ ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
210
+ ts = self.info.iloc[rec_ind1]["i_start_in_trial"]
69
211
 
70
212
  # Decide whether the pair will be positive or negative
71
213
  pair_type = self.rng.binomial(1, 0.5)
72
214
  win_ind2 = None
73
215
  if pair_type == 0: # Negative example
74
216
  if self.same_rec_neg:
75
- mask = (
76
- ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) |
77
- ((ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max))
217
+ mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
218
+ (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
78
219
  )
79
220
  else:
80
221
  rec_ind2 = rec_ind1
@@ -87,7 +228,7 @@ class RelativePositioningSampler(RecordingSampler):
87
228
  mask[ts == ts1] = False # same window cannot be sampled twice
88
229
  if sum(mask) == 0:
89
230
  raise NotImplementedError
90
- win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]['index'][mask])
231
+ win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])
91
232
 
92
233
  return win_ind1, win_ind2, float(pair_type)
93
234
 
@@ -100,16 +241,20 @@ class RelativePositioningSampler(RecordingSampler):
100
241
  return self
101
242
 
102
243
  def __iter__(self):
103
- """Iterate over pairs.
244
+ """
245
+ Iterate over pairs.
104
246
 
105
247
  Yields
106
248
  ------
107
- (int): position of the first window in the dataset.
108
- (int): position of the second window in the dataset.
109
- (float): 0 for negative pair, 1 for positive pair.
249
+ int
250
+ Position of the first window in the dataset.
251
+ int
252
+ Position of the second window in the dataset.
253
+ float
254
+ 0 for a negative pair, 1 for a positive pair.
110
255
  """
111
256
  for i in range(self.n_examples):
112
- if hasattr(self, 'examples'):
257
+ if hasattr(self, "examples"):
113
258
  yield self.examples[i]
114
259
  else:
115
260
  yield self._sample_pair()
@@ -2,14 +2,22 @@
2
2
  Functionality for skorch-based training.
3
3
  """
4
4
 
5
+ from .losses import CroppedLoss, TimeSeriesLoss, mixup_criterion
6
+ from .scoring import (
7
+ CroppedTimeSeriesEpochScoring,
8
+ CroppedTrialEpochScoring,
9
+ PostEpochTrainScoring,
10
+ predict_trials,
11
+ trial_preds_from_window_preds,
12
+ )
5
13
 
6
- from .losses import CroppedLoss, mixup_criterion, TimeSeriesLoss
7
- from .scoring import (CroppedTrialEpochScoring, PostEpochTrainScoring,
8
- CroppedTimeSeriesEpochScoring, trial_preds_from_window_preds, predict_trials)
9
-
10
- __all__ = ["CroppedLoss", "mixup_criterion", "TimeSeriesLoss",
11
- "CroppedTrialEpochScoring",
12
- "PostEpochTrainScoring",
13
- "CroppedTimeSeriesEpochScoring",
14
- "trial_preds_from_window_preds",
15
- "predict_trials"]
14
+ __all__ = [
15
+ "CroppedLoss",
16
+ "mixup_criterion",
17
+ "TimeSeriesLoss",
18
+ "CroppedTrialEpochScoring",
19
+ "PostEpochTrainScoring",
20
+ "CroppedTimeSeriesEpochScoring",
21
+ "trial_preds_from_window_preds",
22
+ "predict_trials",
23
+ ]
@@ -2,8 +2,8 @@
2
2
  #
3
3
  # License: BSD (3-clause)
4
4
 
5
- from skorch.callbacks import Callback
6
5
  import torch
6
+ from skorch.callbacks import Callback
7
7
 
8
8
 
9
9
  class MaxNormConstraintCallback(Callback):
@@ -20,6 +20,4 @@ class MaxNormConstraintCallback(Callback):
20
20
  )
21
21
  last_weight = module.weight
22
22
  if last_weight is not None:
23
- last_weight.data = torch.renorm(
24
- last_weight.data, 2, 0, maxnorm=0.5
25
- )
23
+ last_weight.data = torch.renorm(last_weight.data, 2, 0, maxnorm=0.5)
@@ -96,15 +96,10 @@ def mixup_criterion(preds, target):
96
96
  # unpack target
97
97
  y_a, y_b, lam = target
98
98
  # compute loss per sample
99
- loss_a = torch.nn.functional.nll_loss(preds,
100
- y_a,
101
- reduction='none')
102
- loss_b = torch.nn.functional.nll_loss(preds,
103
- y_b,
104
- reduction='none')
99
+ loss_a = torch.nn.functional.nll_loss(preds, y_a, reduction="none")
100
+ loss_b = torch.nn.functional.nll_loss(preds, y_b, reduction="none")
105
101
  # compute weighted mean
106
102
  ret = torch.mul(lam, loss_a) + torch.mul(1 - lam, loss_b)
107
103
  return ret.mean()
108
104
  else:
109
- return torch.nn.functional.nll_loss(preds,
110
- target)
105
+ return torch.nn.functional.nll_loss(preds, target)