braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,263 @@
+ """
+ Self-supervised learning samplers.
+ """
+
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
+ #          Young Truong <dt.young112@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import warnings
+
+ import numpy as np
+ import torch.distributed as dist
+
+ from . import DistributedRecordingSampler, RecordingSampler
+
+
+ class RelativePositioningSampler(RecordingSampler):
+     """Sample examples for the relative positioning task from [Banville2020]_.
+
+     Sample examples as tuples of two window indices, with a label indicating
+     whether the windows are close or far, as defined by tau_pos and tau_neg.
+
+     Parameters
+     ----------
+     metadata : pd.DataFrame
+         See RecordingSampler.
+     tau_pos : int
+         Size of the positive context, in samples. A positive pair contains two
+         windows x1 and x2 which are separated by at most `tau_pos` samples.
+     tau_neg : int
+         Size of the negative context, in samples. A negative pair contains two
+         windows x1 and x2 which are separated by at least `tau_neg` samples and
+         at most `tau_max` samples. Ignored if `same_rec_neg` is False.
+     n_examples : int
+         Number of pairs to extract.
+     tau_max : int | None
+         See `tau_neg`.
+     same_rec_neg : bool
+         If True, sample negative pairs from within the same recording. If
+         False, sample negative pairs from two different recordings.
+     random_state : None | np.RandomState | int
+         Random state.
+
+     References
+     ----------
+     .. [Banville2020] Banville, H., Chehab, O., Hyvärinen, A., Engemann, D. A.,
+        & Gramfort, A. (2020). Uncovering the structure of clinical EEG
+        signals with self-supervised learning.
+        arXiv preprint arXiv:2007.16104.
+     """
+
+     def __init__(
+         self,
+         metadata,
+         tau_pos,
+         tau_neg,
+         n_examples,
+         tau_max=None,
+         same_rec_neg=True,
+         random_state=None,
+     ):
+         super().__init__(metadata, random_state=random_state)
+
+         self.tau_pos = tau_pos
+         self.tau_neg = tau_neg
+         self.tau_max = np.inf if tau_max is None else tau_max
+         self.n_examples = n_examples
+         self.same_rec_neg = same_rec_neg
+
+         if not same_rec_neg and self.n_recordings < 2:
+             raise ValueError(
+                 "More than one recording must be available when "
+                 "using across-recording negative sampling."
+             )
+
+     def _sample_pair(self):
+         """Sample a pair of two windows."""
+         # Sample first window
+         win_ind1, rec_ind1 = self.sample_window()
+         ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
+         ts = self.info.iloc[rec_ind1]["i_start_in_trial"]
+
+         # Decide whether the pair will be positive or negative
+         pair_type = self.rng.binomial(1, 0.5)
+         win_ind2 = None
+         if pair_type == 0:  # Negative example
+             if self.same_rec_neg:
+                 mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
+                     (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
+                 )
+             else:
+                 rec_ind2 = rec_ind1
+                 while rec_ind2 == rec_ind1:
+                     win_ind2, rec_ind2 = self.sample_window()
+         elif pair_type == 1:  # Positive example
+             mask = (ts >= ts1 - self.tau_pos) & (ts <= ts1 + self.tau_pos)
+
+         if win_ind2 is None:
+             mask[ts == ts1] = False  # same window cannot be sampled twice
+             if sum(mask) == 0:
+                 raise NotImplementedError
+             win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])
+
+         return win_ind1, win_ind2, float(pair_type)
+
+     def presample(self):
+         """Presample examples.
+
+         Once presampled, the examples are the same from one epoch to another.
+         """
+         self.examples = [self._sample_pair() for _ in range(self.n_examples)]
+         return self
+
+     def __iter__(self):
+         """
+         Iterate over pairs.
+
+         Yields
+         ------
+         int
+             Position of the first window in the dataset.
+         int
+             Position of the second window in the dataset.
+         float
+             0 for a negative pair, 1 for a positive pair.
+         """
+         for i in range(self.n_examples):
+             if hasattr(self, "examples"):
+                 yield self.examples[i]
+             else:
+                 yield self._sample_pair()
+
+     def __len__(self):
+         return self.n_examples
+
+
+ class DistributedRelativePositioningSampler(DistributedRecordingSampler):
+     """Sample examples for the relative positioning task from [Banville2020]_ in distributed mode.
+
+     Sample examples as tuples of two window indices, with a label indicating
+     whether the windows are close or far, as defined by tau_pos and tau_neg.
+
+     Parameters
+     ----------
+     metadata : pd.DataFrame
+         See RecordingSampler.
+     tau_pos : int
+         Size of the positive context, in samples. A positive pair contains two
+         windows x1 and x2 which are separated by at most `tau_pos` samples.
+     tau_neg : int
+         Size of the negative context, in samples. A negative pair contains two
+         windows x1 and x2 which are separated by at least `tau_neg` samples and
+         at most `tau_max` samples. Ignored if `same_rec_neg` is False.
+     n_examples : int
+         Number of pairs to extract.
+     tau_max : int | None
+         See `tau_neg`.
+     same_rec_neg : bool
+         If True, sample negative pairs from within the same recording. If
+         False, sample negative pairs from two different recordings.
+     random_state : None | np.RandomState | int
+         Random state.
+     kwargs : dict
+         Additional keyword arguments to pass to torch DistributedSampler.
+         See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
+
+     References
+     ----------
+     .. [Banville2020] Banville, H., Chehab, O., Hyvärinen, A., Engemann, D. A.,
+        & Gramfort, A. (2020). Uncovering the structure of clinical EEG
+        signals with self-supervised learning.
+        arXiv preprint arXiv:2007.16104.
+     """
+
+     def __init__(
+         self,
+         metadata,
+         tau_pos,
+         tau_neg,
+         n_examples,
+         tau_max=None,
+         same_rec_neg=True,
+         random_state=None,
+         **kwargs,
+     ):
+         super().__init__(metadata, random_state=random_state, **kwargs)
+         self.tau_pos = tau_pos
+         self.tau_neg = tau_neg
+         self.tau_max = np.inf if tau_max is None else tau_max
+         self.same_rec_neg = same_rec_neg
+
+         self.n_examples = n_examples * self.n_recordings // self.info.shape[0]
+         warnings.warn(
+             f"Rank {dist.get_rank()} - Number of datasets: {self.n_recordings}"
+         )
+         warnings.warn(f"Rank {dist.get_rank()} - Number of samples: {self.n_examples}")
+
+         if not same_rec_neg and self.n_recordings < 2:
+             raise ValueError(
+                 "More than one recording must be available when "
+                 "using across-recording negative sampling."
+             )
+
+     def _sample_pair(self):
+         """Sample a pair of two windows."""
+         # Sample first window
+         win_ind1, rec_ind1 = self.sample_window()
+         ts1 = self.metadata.iloc[win_ind1]["i_start_in_trial"]
+         ts = self.info.iloc[rec_ind1]["i_start_in_trial"]
+
+         # Decide whether the pair will be positive or negative
+         pair_type = self.rng.binomial(1, 0.5)
+         win_ind2 = None
+         if pair_type == 0:  # Negative example
+             if self.same_rec_neg:
+                 mask = ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) | (
+                     (ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max)
+                 )
+             else:
+                 rec_ind2 = rec_ind1
+                 while rec_ind2 == rec_ind1:
+                     win_ind2, rec_ind2 = self.sample_window()
+         elif pair_type == 1:  # Positive example
+             mask = (ts >= ts1 - self.tau_pos) & (ts <= ts1 + self.tau_pos)
+
+         if win_ind2 is None:
+             mask[ts == ts1] = False  # same window cannot be sampled twice
+             if sum(mask) == 0:
+                 raise NotImplementedError
+             win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]["index"][mask])
+
+         return win_ind1, win_ind2, float(pair_type)
+
+     def presample(self):
+         """Presample examples.
+
+         Once presampled, the examples are the same from one epoch to another.
+         """
+         self.examples = [self._sample_pair() for _ in range(self.n_examples)]
+         return self
+
+     def __iter__(self):
+         """
+         Iterate over pairs.
+
+         Yields
+         ------
+         int
+             Position of the first window in the dataset.
+         int
+             Position of the second window in the dataset.
+         float
+             0 for a negative pair, 1 for a positive pair.
+         """
+         for i in range(self.n_examples):
+             if hasattr(self, "examples"):
+                 yield self.examples[i]
+             else:
+                 yield self._sample_pair()
+
+     def __len__(self):
+         return self.n_examples
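
The pair-sampling rule shared by both samplers above reduces to two masks over the window start times (`i_start_in_trial`). The following self-contained sketch reproduces those masks on hypothetical values of `ts`, `ts1`, `tau_pos`, `tau_neg` and `tau_max`; the real samplers read these from their metadata.

import numpy as np

# Hypothetical toy values; the samplers read these from their metadata.
ts = np.arange(0, 10_000, 500)   # window start times (samples) in one recording
ts1 = 4_000                      # start time of the first sampled window
tau_pos, tau_neg, tau_max = 1_000, 3_000, np.inf

# Positive candidates: windows at most tau_pos samples away from ts1
pos_mask = (ts >= ts1 - tau_pos) & (ts <= ts1 + tau_pos)
# Negative candidates: windows between tau_neg and tau_max samples away
neg_mask = ((ts <= ts1 - tau_neg) & (ts >= ts1 - tau_max)) | (
    (ts >= ts1 + tau_neg) & (ts <= ts1 + tau_max)
)
# The anchor window itself can never be the second window of a pair
pos_mask[ts == ts1] = False

print("positive candidates:", ts[pos_mask])  # 3000..5000, excluding 4000
print("negative candidates:", ts[neg_mask])  # <= 1000 or >= 7000

A window within `tau_pos` samples of the anchor can be drawn as a positive pair (label 1); a window at least `tau_neg` and at most `tau_max` samples away can be drawn as a negative pair (label 0).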
@@ -0,0 +1,23 @@
+ """
+ Functionality for skorch-based training.
+ """
+
+ from .losses import CroppedLoss, TimeSeriesLoss, mixup_criterion
+ from .scoring import (
+     CroppedTimeSeriesEpochScoring,
+     CroppedTrialEpochScoring,
+     PostEpochTrainScoring,
+     predict_trials,
+     trial_preds_from_window_preds,
+ )
+
+ __all__ = [
+     "CroppedLoss",
+     "mixup_criterion",
+     "TimeSeriesLoss",
+     "CroppedTrialEpochScoring",
+     "PostEpochTrainScoring",
+     "CroppedTimeSeriesEpochScoring",
+     "trial_preds_from_window_preds",
+     "predict_trials",
+ ]
@@ -0,0 +1,23 @@
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import torch
+ from skorch.callbacks import Callback
+
+
+ class MaxNormConstraintCallback(Callback):
+     def on_batch_end(self, net, training, *args, **kwargs):
+         if training:
+             model = net.module_
+             last_weight = None
+             for name, module in list(model.named_children()):
+                 if hasattr(module, "weight") and (
+                     not module.__class__.__name__.startswith("BatchNorm")
+                 ):
+                     module.weight.data = torch.renorm(
+                         module.weight.data, 2, 0, maxnorm=2
+                     )
+                     last_weight = module.weight
+             if last_weight is not None:
+                 last_weight.data = torch.renorm(last_weight.data, 2, 0, maxnorm=0.5)
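
After every training batch, `MaxNormConstraintCallback` renorms the weight of each weight-carrying, non-BatchNorm child module so that each sub-tensor along dim 0 (e.g. each output unit's weight row) has an L2 norm of at most 2, and additionally clamps the last such weight to a norm of 0.5. It would typically be attached to a skorch-based estimator through its `callbacks` argument (e.g. `callbacks=[MaxNormConstraintCallback()]`). A minimal sketch of the underlying constraint, reusing the same `torch.renorm` call on a hypothetical weight matrix:

import torch

torch.manual_seed(0)
# Hypothetical weight matrix: 4 output units, 10 inputs each
w = torch.randn(4, 10) * 3

# Same call as in the callback: clamp the L2 norm of each row (dim 0) to 2
w_constrained = torch.renorm(w, 2, 0, maxnorm=2)
print(w.norm(p=2, dim=1))              # original row norms, well above 2
print(w_constrained.norm(p=2, dim=1))  # every row norm is now <= 2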
@@ -0,0 +1,105 @@
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
+ #          Maciej Sliwowski <maciek.sliwowski@gmail.com>
+ #          Mohammed Fattouh <mo.fattouh@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import torch
+ from torch import nn
+
+
+ class CroppedLoss(nn.Module):
+     """Compute loss after averaging predictions across time.
+     Assumes predictions are in shape:
+     n_batch size x n_classes x n_predictions (in time)"""
+
+     def __init__(self, loss_function):
+         super().__init__()
+         self.loss_function = loss_function
+
+     def forward(self, preds, targets):
+         """Forward pass.
+
+         Parameters
+         ----------
+         preds: torch.Tensor
+             Model's prediction with shape (batch_size, n_classes, n_times).
+         targets: torch.Tensor
+             Target labels with shape (batch_size, n_classes, n_times).
+         """
+         avg_preds = torch.mean(preds, dim=2)
+         avg_preds = avg_preds.squeeze(dim=1)
+         return self.loss_function(avg_preds, targets)
+
+
+ class TimeSeriesLoss(nn.Module):
+     """Compute loss between time series targets and predictions.
+     Assumes predictions are in shape:
+     n_batch size x n_classes x n_predictions (in time)
+     Assumes targets are in shape:
+     n_batch size x n_classes x window_len (in time)
+     If the targets contain NaNs, the NaNs will be masked out and the loss will
+     only be computed for predictions corresponding to valid target values."""
+
+     def __init__(self, loss_function):
+         super().__init__()
+         self.loss_function = loss_function
+
+     def forward(self, preds, targets):
+         """Forward pass.
+
+         Parameters
+         ----------
+         preds: torch.Tensor
+             Model's prediction with shape (batch_size, n_classes, n_times).
+         targets: torch.Tensor
+             Target labels with shape (batch_size, n_classes, n_times).
+         """
+         n_preds = preds.shape[-1]
+         # slice the targets to fit preds shape
+         targets = targets[:, :, -n_preds:]
+         # create valid targets mask
+         mask = ~torch.isnan(targets)
+         # select valid targets that have matching predictions
+         masked_targets = targets[mask]
+         masked_preds = preds[mask]
+         return self.loss_function(masked_preds, masked_targets)
+
+
+ def mixup_criterion(preds, target):
+     """Implements the loss for mixup on EEG data. See [1]_.
+
+     Implementation based on [2]_.
+
+     Parameters
+     ----------
+     preds : torch.Tensor
+         Predictions from the model.
+     target : torch.Tensor | list of torch.Tensor
+         For predictions without mixup, the targets as a tensor. If mixup has
+         been applied, a list containing the targets of the two mixed
+         samples and the mixing coefficients as tensors.
+
+     Returns
+     -------
+     loss : float
+         The loss value.
+
+     References
+     ----------
+     .. [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
+        mixup: Beyond Empirical Risk Minimization
+        Online: https://arxiv.org/abs/1710.09412
+     .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
+     """
+     if len(target) == 3:
+         # unpack target
+         y_a, y_b, lam = target
+         # compute loss per sample
+         loss_a = torch.nn.functional.nll_loss(preds, y_a, reduction="none")
+         loss_b = torch.nn.functional.nll_loss(preds, y_b, reduction="none")
+         # compute weighted mean
+         ret = torch.mul(lam, loss_a) + torch.mul(1 - lam, loss_b)
+         return ret.mean()
+     else:
+         return torch.nn.functional.nll_loss(preds, target)
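
`CroppedLoss` above wraps an arbitrary (preds, targets) criterion for cropped decoding: predictions of shape (batch_size, n_classes, n_predictions) are averaged over the time dimension before the wrapped loss is applied. A minimal sketch with hypothetical shapes, importing the class from `braindecode.training` as exported in the package's `__init__.py` above:

import torch
from torch.nn.functional import nll_loss

from braindecode.training import CroppedLoss

# Hypothetical cropped outputs: 8 windows, 4 classes, 11 per-crop predictions
log_probs = torch.log_softmax(torch.randn(8, 4, 11), dim=1)
labels = torch.randint(0, 4, (8,))

criterion = CroppedLoss(nll_loss)
loss = criterion(log_probs, labels)  # averages over the 11 time steps, then applies nll_loss

# Equivalent computation by hand
manual = nll_loss(log_probs.mean(dim=2), labels)
assert torch.isclose(loss, manual)

`mixup_criterion` follows the same pattern but expects either plain targets or the `(y_a, y_b, lam)` triple produced by a mixup augmentation, as handled in the function body above.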