braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.
Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/eegneuralnet.py
@@ -5,32 +5,36 @@
 
 
 import abc
-import logging
 import inspect
+import logging
 
 import mne
 import numpy as np
 import torch
-from skorch import NeuralNet
 from sklearn.metrics import get_scorer
+from skorch import NeuralNet
 from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
-from skorch.utils import noop, to_numpy, train_loss_score, valid_loss_score, is_dataset
+from skorch.helper import SliceDataset
+from skorch.utils import is_dataset, noop, to_numpy, train_loss_score, valid_loss_score
 
-from .training.scoring import (CroppedTimeSeriesEpochScoring,
-                               CroppedTrialEpochScoring, PostEpochTrainScoring)
-from .models.util import models_dict
 from .datasets.base import BaseConcatDataset, WindowsDataset
+from .models.util import models_dict
+from .training.scoring import (
+    CroppedTimeSeriesEpochScoring,
+    CroppedTrialEpochScoring,
+    PostEpochTrainScoring,
+)
 
 log = logging.getLogger(__name__)
 
 
-def _get_model(model):
-    ''' Returns the corresponding class in case the model passed is a string. '''
+def _get_model(model: str):
+    """Returns the corresponding class in case the model passed is a string."""
     if isinstance(model, str):
         if model in models_dict:
            model = models_dict[model]
        else:
-            raise ValueError(f'Unknown model name {model!r}.')
+            raise ValueError(f"Unknown model name {model!r}.")
    return model
 
 
@@ -50,7 +54,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
         will be left as is.
 
         """
-        kwargs = self.get_params_for('module')
+        kwargs = self.get_params_for("module")
         module = _get_model(self.module)
         module = self.initialized_instance(module, kwargs)
         # pylint: disable=attribute-defined-outside-init
@@ -61,7 +65,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
         # Here we parse the callbacks supplied as strings,
         # e.g. 'accuracy', to the callbacks skorch expects
         for name, cb, named_by_user in super()._yield_callbacks():
-            if name == 'str':
+            if name == "str":
                 train_cb, valid_cb = self._parse_str_callback(cb)
                 yield train_cb
                 if self.train_split is not None:
@@ -72,15 +76,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
     def _parse_str_callback(self, cb_supplied_name):
         scoring = get_scorer(cb_supplied_name)
         scoring_name = scoring._score_func.__name__
-        assert scoring_name.endswith(
-            ('_score', '_error', '_deviance', '_loss'))
-        if (scoring_name.endswith('_score') or
-                cb_supplied_name.startswith('neg_')):
+        assert scoring_name.endswith(("_score", "_error", "_deviance", "_loss"))
+        if scoring_name.endswith("_score") or cb_supplied_name.startswith("neg_"):
             lower_is_better = False
         else:
             lower_is_better = True
-        train_name = f'train_{cb_supplied_name}'
-        valid_name = f'valid_{cb_supplied_name}'
+        train_name = f"train_{cb_supplied_name}"
+        valid_name = f"valid_{cb_supplied_name}"
         if self.cropped:
             train_scoring = CroppedTrialEpochScoring(
                 cb_supplied_name, lower_is_better, on_train=True, name=train_name
@@ -98,7 +100,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
             named_by_user = True
         train_valid_callbacks = [
             (train_name, train_scoring, named_by_user),
-            (valid_name, valid_scoring, named_by_user)
+            (valid_name, valid_scoring, named_by_user),
         ]
         return train_valid_callbacks
 
@@ -108,8 +110,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
         if not training:
             epoch_cbs = []
             for name, cb in self.callbacks_:
-                if isinstance(cb, (CroppedTrialEpochScoring, CroppedTimeSeriesEpochScoring)) and (
-                        hasattr(cb, 'window_inds_')) and (not cb.on_train):
+                if (
+                    isinstance(
+                        cb, (CroppedTrialEpochScoring, CroppedTimeSeriesEpochScoring)
+                    )
+                    and (hasattr(cb, "window_inds_"))
+                    and (not cb.on_train)
+                ):
                     epoch_cbs.append(cb)
             # for trialwise decoding stuffs it might also be we don't have
             # cropped loader, so no indices there
@@ -136,8 +143,11 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
         i_window_stops = np.concatenate(i_window_stops)
         window_ys = np.concatenate(window_ys)
         return dict(
-            preds=preds, i_window_in_trials=i_window_in_trials,
-            i_window_stops=i_window_stops, window_ys=window_ys)
+            preds=preds,
+            i_window_in_trials=i_window_in_trials,
+            i_window_stops=i_window_stops,
+            window_ys=window_ys,
+        )
 
     # Changes the default target extractor to noop
     @property
@@ -156,7 +166,9 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
             (
                 "valid_loss",
                 BatchScoring(
-                    valid_loss_score, name="valid_loss", target_extractor=noop,
+                    valid_loss_score,
+                    name="valid_loss",
+                    target_extractor=noop,
                 ),
             ),
             ("print_log", PrintLog()),
@@ -179,17 +191,27 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
             return
         # get kwargs from signal:
         signal_kwargs = dict()
-        if isinstance(X, mne.BaseEpochs) or isinstance(X, np.ndarray):
+        # Using shape to work both with torch.tensor and numpy.array:
+        if (
+            isinstance(X, mne.BaseEpochs)
+            or (hasattr(X, "shape") and len(X.shape) >= 2)
+            or isinstance(X, SliceDataset)
+        ):
             if y is None:
-                raise ValueError("y must be specified if X is a numpy array.")
-            signal_kwargs['n_outputs'] = self._get_n_outputs(y, classes)
+                raise ValueError("y must be specified if X is array-like.")
+            signal_kwargs["n_outputs"] = self._get_n_outputs(y, classes)
             if isinstance(X, mne.BaseEpochs):
                 self.log.info("Using mne.Epochs to find signal-related parameters.")
                 signal_kwargs["n_times"] = len(X.times)
-                signal_kwargs["sfreq"] = X.info['sfreq']
-                signal_kwargs["chs_info"] = X.info['chs']
+                signal_kwargs["sfreq"] = X.info["sfreq"]
+                signal_kwargs["chs_info"] = X.info["chs"]
+            elif isinstance(X, SliceDataset):
+                self.log.info("Using SliceDataset to find signal-related parameters.")
+                Xshape = X[0].shape
+                signal_kwargs["n_times"] = Xshape[-1]
+                signal_kwargs["n_chans"] = Xshape[-2]
             else:
-                self.log.info("Using numpy array to find signal-related parameters.")
+                self.log.info("Using array-like to find signal-related parameters.")
                 signal_kwargs["n_times"] = X.shape[-1]
                 signal_kwargs["n_chans"] = X.shape[-2]
         elif is_dataset(X):
@@ -198,21 +220,17 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
             Xshape = X0.shape
             signal_kwargs["n_times"] = Xshape[-1]
             signal_kwargs["n_chans"] = Xshape[-2]
-            if (
-                    isinstance(X, BaseConcatDataset) and
-                    all(ds.targets_from == 'metadata' for ds in X.datasets)
+            if isinstance(X, BaseConcatDataset) and all(
+                ds.targets_from == "metadata" for ds in X.datasets
             ):
                 y_target = X.get_metadata().target
-                signal_kwargs['n_outputs'] = self._get_n_outputs(y_target, classes)
-            elif (
-                    isinstance(X, WindowsDataset) and
-                    X.targets_from == "metadata"
-            ):
+                signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
+            elif isinstance(X, WindowsDataset) and X.targets_from == "metadata":
                 y_target = X.windows.metadata.target
-                signal_kwargs['n_outputs'] = self._get_n_outputs(y_target, classes)
+                signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
         else:
             self.log.warning(
-                "Can only infer signal shape of numpy arrays or and Datasets, "
+                "Can only infer signal shape of array-like and Datasets, "
                 f"got {type(X)!r}."
             )
             return
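The two hunks above are the signal-parameter inference: the net derives ``n_chans``, ``n_times``, and (where targets are available) ``n_outputs`` from whatever ``X`` it is handed. A minimal sketch of the array-like branch, with hypothetical shapes:

import numpy as np

# Hypothetical EEG batch: (n_examples, n_chans, n_times)
X = np.random.randn(100, 22, 1000).astype(np.float32)
y = np.random.randint(0, 4, size=100)

# Mirrors the array-like branch (anything with a >= 2-D .shape attribute):
signal_kwargs = {"n_times": X.shape[-1], "n_chans": X.shape[-2]}
# n_outputs comes from self._get_n_outputs(y, classes); for classification
# it is essentially the number of distinct classes:
signal_kwargs["n_outputs"] = len(np.unique(y))
print(signal_kwargs)  # {'n_times': 1000, 'n_chans': 22, 'n_outputs': 4}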
@@ -227,15 +245,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
             if k in all_module_kwargs:
                 module_kwargs[k] = v
             else:
-                self.log.warning(
-                    f"Module {self.module!r} "
-                    f"is missing parameter {k!r}."
-                )
+                self.log.warning(f"Module {self.module!r} is missing parameter {k!r}.")
 
         # save kwargs to self:
         self.log.info(
             f"Passing additional parameters {module_kwargs!r} "
-            f"to module {self.module!r}.")
+            f"to module {self.module!r}."
+        )
         module_kwargs = {f"module__{k}": v for k, v in module_kwargs.items()}
         self.set_params(**module_kwargs)
 
@@ -275,7 +291,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
 
         """
         if isinstance(X, mne.BaseEpochs):
-            X = X.get_data(units='uV')
+            X = X.get_data(units="uV")
         return super().get_dataset(X, y)
 
     def partial_fit(self, X, y=None, classes=None, **fit_params):
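The ``units="uV"`` argument pins the unit conversion: MNE stores EEG internally in volts, and ``get_data(units="uV")`` rescales to microvolts before the data reach the network. A minimal sketch with a synthetic Epochs object (channel names and shapes here are illustrative):

import mne
import numpy as np

info = mne.create_info(ch_names=["C3", "C4"], sfreq=100.0, ch_types="eeg")
data_volts = np.random.randn(5, 2, 100) * 1e-5  # MNE convention: volts
epochs = mne.EpochsArray(data_volts, info, verbose=False)

X = epochs.get_data(units="uV")  # same call as get_dataset above
print(X.shape, X.max() / data_volts.max())  # (5, 2, 100) 1000000.0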
@@ -291,7 +307,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
 
         * mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
           ``sfreq``, ``input_window_seconds``
-        * numpy array: ``n_times``, ``n_chans``, ``n_outputs``
+        * array-like: ``n_times``, ``n_chans``, ``n_outputs``
         * WindowsDataset with ``targets_from='metadata'``
           (or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
         * other Dataset: ``n_times``, ``n_chans``
@@ -345,7 +361,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
 
         * mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
           ``sfreq``, ``input_window_seconds``
-        * numpy array: ``n_times``, ``n_chans``, ``n_outputs``
+        * array-like: ``n_times``, ``n_chans``, ``n_outputs``
         * WindowsDataset with ``targets_from='metadata'``
           (or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
         * other Dataset: ``n_times``, ``n_chans``
braindecode/functional/__init__.py (new file)
@@ -0,0 +1,22 @@
+from .functions import (
+    _get_gaussian_kernel1d,
+    drop_path,
+    hilbert_freq,
+    identity,
+    plv_time,
+    safe_log,
+    square,
+)
+from .initialization import glorot_weight_zero_bias, rescale_parameter
+
+__all__ = [
+    "_get_gaussian_kernel1d",
+    "drop_path",
+    "hilbert_freq",
+    "identity",
+    "plv_time",
+    "safe_log",
+    "square",
+    "glorot_weight_zero_bias",
+    "rescale_parameter",
+]
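The new ``braindecode.functional`` package collects the stateless helpers that previously lived in ``braindecode/models/functions.py`` (removed in this release, see file 105 above). A quick usage sketch of the two simplest exports:

import torch
from braindecode.functional import safe_log, square

x = torch.tensor([0.0, 0.5, 2.0])
print(square(x))    # tensor([0.0000, 0.2500, 4.0000])
print(safe_log(x))  # first entry is log(1e-6), not -inf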
braindecode/functional/functions.py (new file)
@@ -0,0 +1,250 @@
+# Authors: Robin Schirrmeister <robintibor@gmail.com>
+#
+# License: BSD (3-clause)
+
+import torch
+import torch.nn.functional as F
+
+
+def square(x):
+    return x * x
+
+
+def safe_log(x, eps: float = 1e-6) -> torch.Tensor:
+    """Prevents :math:`log(0)` by using :math:`log(max(x, eps))`."""
+    return torch.log(torch.clamp(x, min=eps))
+
+
+def identity(x):
+    return x
+
+
+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """Drop paths (Stochastic Depth) per sample.
+
+    Notes: This implementation is taken from the timm library.
+
+    All credit goes to Ross Wightman.
+
+    Parameters
+    ----------
+    x: torch.Tensor
+        input tensor
+    drop_prob : float, optional
+        probability of dropping the path (1 - survival rate), by default 0.0
+    training : bool, optional
+        whether the model is in training mode, by default False
+    scale_by_keep : bool, optional
+        whether to scale output by (1/keep_prob) during training, by default True
+
+    Returns
+    -------
+    torch.Tensor
+        output tensor
+
+    Notes from Ross Wightman:
+    (when applied in main path of residual blocks)
+    This is the same as the DropConnect impl I created for EfficientNet,
+    etc. networks, however,
+    the original name is misleading as 'Drop Connect' is a different form
+    of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
+    ... I've opted for changing the layer and argument names to 'drop path'
+    rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+    """
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
+    """
+    Generates a 1-dimensional Gaussian kernel based on the specified kernel
+    size and standard deviation (sigma).
+
+    This kernel is useful for Gaussian smoothing or filtering operations in
+    image processing. The function calculates a range limit to ensure the
+    kernel effectively covers the Gaussian distribution. It generates a
+    tensor of the specified size, filled with values distributed according
+    to a Gaussian curve, normalized so that all weights sum to 1.
+
+    Parameters
+    ----------
+    kernel_size: int
+    sigma: float
+
+    Returns
+    -------
+    kernel1d: torch.Tensor
+
+    Notes
+    -----
+    Code copied and modified from TorchVision:
+    https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py#L725-L732
+    All rights reserved.
+
+    LICENSE in https://github.com/pytorch/vision/blob/main/LICENSE
+    """
+    ksize_half = (kernel_size - 1) * 0.5
+    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+    kernel1d = pdf / pdf.sum()
+    return kernel1d
+
+
+def hilbert_freq(x, forward_fourier=True):
+    r"""
+    Compute the analytic signal via the Hilbert transform using PyTorch,
+    returning separate real and imaginary parts.
+
+    The analytic signal :math:`x_a(t)` of a real-valued signal :math:`x(t)`
+    is defined as:
+
+    .. math::
+
+        x_a(t) = x(t) + i y(t) = \mathcal{F}^{-1} \{ U(f) \mathcal{F}\{x(t)\} \}
+
+    where:
+
+    - :math:`\mathcal{F}` is the Fourier transform,
+    - :math:`U(f)` is the unit step function,
+    - :math:`y(t)` is the Hilbert transform of :math:`x(t)`.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor. The expected shape depends on the `forward_fourier`
+        parameter:
+
+        - If `forward_fourier` is True: (..., seq_len)
+        - If `forward_fourier` is False: (..., seq_len / 2 + 1, 2)
+
+    forward_fourier : bool, optional
+        Determines the format of the input tensor.
+        - If True, the input is a time-domain signal.
+        - If False, the input contains separate real and imaginary Fourier
+          coefficients.
+        Default is True.
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor with shape (..., seq_len, 2), where the last dimension
+        represents the real and imaginary parts of the analytic signal.
+
+    Examples
+    --------
+    >>> import torch
+    >>> input = torch.randn(10, 100)  # Example input tensor
+    >>> output = hilbert_freq(input)
+    >>> print(output.shape)
+    torch.Size([10, 100, 2])
+
+    Notes
+    -----
+    The implementation matches the SciPy implementation, but uses torch:
+    https://github.com/scipy/scipy/blob/v1.14.1/scipy/signal/_signaltools.py#L2287-L2394
+    """
+    if forward_fourier:
+        x = torch.fft.rfft(x, norm=None, dim=-1)
+        x = torch.view_as_real(x)
+    x = x * 2.0
+    x[..., 0, :] = x[..., 0, :] / 2.0  # Don't multiply the DC-term by 2
+    x = F.pad(
+        x, [0, 0, 0, x.shape[-2] - 2]
+    )  # Fill Fourier coefficients to retain shape
+    x = torch.view_as_complex(x)
+    x = torch.fft.ifft(x, norm=None, dim=-1)  # returns complex signal
+    x = torch.view_as_real(x)
+
+    return x
+
+
+def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
+    """Compute the Phase Locking Value (PLV) metric in the time domain.
+
+    The Phase Locking Value (PLV) is a measure of the synchronization between
+    different channels by evaluating the consistency of phase differences
+    over time. It ranges from 0 (no synchronization) to 1 (perfect
+    synchronization) [1]_.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor containing the signal data.
+        - If `forward_fourier` is `True`, the shape should be
+          `(..., channels, time)`.
+        - If `forward_fourier` is `False`, the shape should be
+          `(..., channels, freqs, 2)`, where the last dimension represents
+          the real and imaginary parts.
+    forward_fourier : bool, optional
+        Specifies the format of the input tensor `x`.
+        - If `True`, `x` is assumed to be in the time domain.
+        - If `False`, `x` is assumed to be in the Fourier domain with
+          separate real and imaginary components.
+        Default is `True`.
+    epsilon : float, default 1e-6
+        Small numerical value added to ensure positivity inside the
+        square root.
+
+    Returns
+    -------
+    plv : torch.Tensor
+        The Phase Locking Value matrix with shape `(..., channels, channels)`.
+        Each element `[i, j]` represents the PLV between channel `i` and
+        channel `j`.
+
+    References
+    ----------
+    [1] Lachaux, J. P., Rodriguez, E., Martinerie, J., & Varela, F. J. (1999).
+        Measuring phase synchrony in brain signals. Human brain mapping,
+        8(4), 194-208.
+    """
+    # Compute the analytic signal using the Hilbert transform.
+    # The result has separate real and imaginary parts.
+    analytic_signal = hilbert_freq(x, forward_fourier)
+    # Calculate the amplitude (magnitude) of the analytic signal.
+    # Adding a small epsilon (1e-6) to avoid division by zero.
+    amplitude = torch.sqrt(
+        analytic_signal[..., 0] ** 2 + analytic_signal[..., 1] ** 2 + 1e-6
+    )
+    # Normalize the analytic signal to obtain unit vectors (phasors).
+    unit_phasor = analytic_signal / amplitude.unsqueeze(-1)
+
+    # Compute the real part of the outer product between phasors of
+    # different channels.
+    real_real = torch.matmul(unit_phasor[..., 0], unit_phasor[..., 0].transpose(-2, -1))
+
+    # Compute the imaginary part of the outer product between phasors of
+    # different channels.
+    imag_imag = torch.matmul(unit_phasor[..., 1], unit_phasor[..., 1].transpose(-2, -1))
+
+    # Compute the cross-terms for the real and imaginary parts.
+    real_imag = torch.matmul(unit_phasor[..., 0], unit_phasor[..., 1].transpose(-2, -1))
+    imag_real = torch.matmul(unit_phasor[..., 1], unit_phasor[..., 0].transpose(-2, -1))
+
+    # Combine the real and imaginary parts to form the complex correlation.
+    correlation_real = real_real + imag_imag
+    correlation_imag = real_imag - imag_real
+
+    # Determine the number of time points (or frequency bins if in the
+    # Fourier domain).
+    time = amplitude.shape[-1]
+
+    # Calculate the PLV by averaging the magnitude of the complex correlation
+    # over time; epsilon keeps the square root strictly positive.
+    plv_matrix = (
+        1 / time * torch.sqrt(correlation_real**2 + correlation_imag**2 + epsilon)
+    )
+
+    return plv_matrix
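A shape check for the two signal helpers above, following their docstrings: ``hilbert_freq`` stacks real and imaginary parts on a trailing axis, and ``plv_time`` reduces the time axis to a channels-by-channels synchronization matrix.

import torch
from braindecode.functional import hilbert_freq, plv_time

x = torch.randn(4, 8, 256)  # (batch, channels, time)
analytic = hilbert_freq(x)  # real/imag parts stacked on the last axis
plv = plv_time(x)           # pairwise phase locking between channels
print(analytic.shape)       # torch.Size([4, 8, 256, 2])
print(plv.shape)            # torch.Size([4, 8, 8])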
braindecode/functional/initialization.py (new file)
@@ -0,0 +1,47 @@
+import math
+
+from torch import nn
+
+
+def glorot_weight_zero_bias(model):
+    """Initialize parameters of all modules by initializing weights with
+    glorot uniform/xavier initialization, and setting biases to zero.
+    Weights from batch norm layers are set to 1.
+
+    Parameters
+    ----------
+    model: Module
+    """
+    for module in model.modules():
+        if hasattr(module, "weight"):
+            if "BatchNorm" in module.__class__.__name__:
+                nn.init.constant_(module.weight, 1)
+            else:
+                nn.init.xavier_uniform_(module.weight, gain=1)
+        if hasattr(module, "bias"):
+            if module.bias is not None:
+                nn.init.constant_(module.bias, 0)
+
+
+def rescale_parameter(param, layer_id):
+    r"""Rescale the l-th transformer layer.
+
+    Rescales the parameter tensor by the inverse square root of the layer id.
+    Made inplace. :math:`\frac{1}{\sqrt{2 \cdot \text{layer\_id}}}` [Beit2022]_
+
+    In Labram, this is used to rescale the output matrices
+    (i.e., the last linear projection within each sub-layer) of the
+    self-attention module.
+
+    Parameters
+    ----------
+    param: :class:`torch.Tensor`
+        tensor to be rescaled
+    layer_id: int
+        layer id in the neural network
+
+    References
+    ----------
+    [Beit2022] Hangbo Bao, Li Dong, Songhao Piao, Furu Wei (2022). BEiT: BERT
+        Pre-Training of Image Transformers.
+    """
+    param.div_(math.sqrt(2.0 * layer_id))
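A minimal sketch exercising both initializers, assuming the braindecode.functional exports shown in the __init__.py hunk above (the layer shapes are illustrative):

import torch
from torch import nn
from braindecode.functional import glorot_weight_zero_bias, rescale_parameter

net = nn.Sequential(nn.Conv1d(8, 16, kernel_size=5), nn.BatchNorm1d(16))
glorot_weight_zero_bias(net)  # xavier conv weights, BN weights 1, biases 0
print(net[1].weight.detach().unique())  # tensor([1.])
print(net[0].bias.detach().unique())    # tensor([0.])

w = torch.ones(4, 4)
rescale_parameter(w, layer_id=2)  # in-place division by sqrt(2 * 2) = 2
print(w[0, 0].item())             # 0.5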
braindecode/models/__init__.py
@@ -1,30 +1,100 @@
 """
 Some predefined network architectures for EEG decoding.
 """
+
+from .atcnet import ATCNet
+from .attentionbasenet import AttentionBaseNet
 from .base import EEGModuleMixin
-from .eegconformer import EEGConformer
-from .eegitnet import EEGITNet
+from .biot import BIOT
+from .contrawr import ContraWR
+from .ctnet import CTNet
 from .deep4 import Deep4Net
 from .deepsleepnet import DeepSleepNet
-from .eegnet import EEGNetv4, EEGNetv1
-from .hybrid import HybridNet
-from .shallow_fbcsp import ShallowFBCSPNet
-from .eegresnet import EEGResNet
-from .eeginception import EEGInception
+from .eegconformer import EEGConformer
 from .eeginception_erp import EEGInceptionERP
 from .eeginception_mi import EEGInceptionMI
-from .atcnet import ATCNet
-from .tcn import TCN
-from .sleep_stager_chambon_2018 import SleepStagerChambon2018
+from .eegitnet import EEGITNet
+from .eegminer import EEGMiner
+from .eegnet import EEGNetv1, EEGNetv4
+from .eegnex import EEGNeX
+from .eegresnet import EEGResNet
+from .eegsimpleconv import EEGSimpleConv
+from .eegtcnet import EEGTCNet
+from .fbcnet import FBCNet
+from .fblightconvnet import FBLightConvNet
+from .fbmsnet import FBMSNet
+from .hybrid import HybridNet
+from .ifnet import IFNet
+from .labram import Labram
+from .msvtnet import MSVTNet
+from .sccnet import SCCNet
+from .shallow_fbcsp import ShallowFBCSPNet
+from .signal_jepa import (
+    SignalJEPA,
+    SignalJEPA_Contextual,
+    SignalJEPA_PostLocal,
+    SignalJEPA_PreLocal,
+)
+from .sinc_shallow import SincShallowNet
 from .sleep_stager_blanco_2020 import SleepStagerBlanco2020
+from .sleep_stager_chambon_2018 import SleepStagerChambon2018
 from .sleep_stager_eldele_2021 import SleepStagerEldele2021
+from .sparcnet import SPARCNet
+from .syncnet import SyncNet
+from .tcn import BDTCN, TCN
 from .tidnet import TIDNet
+from .tsinception import TSceptionV1
 from .usleep import USleep
-from .util import get_output_shape, to_dense_prediction_model
-from .modules import TimeDistributed
-
-from .util import _init_models_dict
+from .util import _init_models_dict, models_mandatory_parameters
 
 # Call this last in order to make sure the dataset list is populated with
 # the models imported in this file.
 _init_models_dict()
+
+__all__ = [
+    "ATCNet",
+    "AttentionBaseNet",
+    "EEGModuleMixin",
+    "BIOT",
+    "ContraWR",
+    "CTNet",
+    "Deep4Net",
+    "DeepSleepNet",
+    "EEGConformer",
+    "EEGInceptionERP",
+    "EEGInceptionMI",
+    "EEGITNet",
+    "EEGMiner",
+    "EEGNetv1",
+    "EEGNetv4",
+    "EEGNeX",
+    "EEGResNet",
+    "EEGSimpleConv",
+    "EEGTCNet",
+    "FBCNet",
+    "FBLightConvNet",
+    "FBMSNet",
+    "HybridNet",
+    "IFNet",
+    "Labram",
+    "MSVTNet",
+    "SCCNet",
+    "ShallowFBCSPNet",
+    "SignalJEPA",
+    "SignalJEPA_Contextual",
+    "SignalJEPA_PostLocal",
+    "SignalJEPA_PreLocal",
+    "SincShallowNet",
+    "SleepStagerBlanco2020",
+    "SleepStagerChambon2018",
+    "SleepStagerEldele2021",
+    "SPARCNet",
+    "SyncNet",
+    "BDTCN",
+    "TCN",
+    "TIDNet",
+    "TSceptionV1",
+    "USleep",
+    "_init_models_dict",
+    "models_mandatory_parameters",
+]
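Every architecture in the ``__all__`` list above is importable directly from ``braindecode.models``. A minimal sketch, assuming the 1.x keyword API (``n_chans``/``n_outputs``/``n_times`` from EEGModuleMixin) and illustrative signal dimensions:

import torch
from braindecode.models import EEGNetv4

# Hypothetical signal: 22 channels, 4 classes, 1000 time samples.
model = EEGNetv4(n_chans=22, n_outputs=4, n_times=1000)
x = torch.randn(2, 22, 1000)  # (batch, channels, time)
print(model(x).shape)         # expected: torch.Size([2, 4])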