braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,251 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
def square(x):
    """Return ``x`` multiplied element-wise by itself (``x * x``)."""
    squared = x * x
    return squared
12
+
13
+
14
def safe_log(x, eps: float = 1e-6) -> torch.Tensor:
    """Natural logarithm that never evaluates :math:`log(0)`.

    Values of ``x`` below ``eps`` are floored to ``eps`` before taking the
    logarithm, i.e. this returns :math:`log(max(x, eps))`.
    """
    floored = torch.clamp(x, min=eps)
    return torch.log(floored)
17
+
18
+
19
def identity(x):
    """Return the argument itself, unchanged and without copying."""
    output = x
    return output
21
+
22
+
23
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample.

    Each sample along the first dimension of ``x`` is zeroed out as a whole
    with probability ``drop_prob``. When ``scale_by_keep`` is True, the kept
    samples are multiplied by ``1 / (1 - drop_prob)`` so the expected value of
    the output equals ``x``.

    Notes: This implementation is taken from the timm library.
    All credit goes to Ross Wightman.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of any rank; the first dimension is treated as the
        sample (batch) dimension.
    drop_prob : float, optional
        Probability of dropping a sample (the survival / keep probability is
        ``1 - drop_prob``), by default 0.0.
    training : bool, optional
        Whether the model is in training mode, by default False. Outside of
        training the input is returned unchanged.
    scale_by_keep : bool, optional
        Whether to scale kept samples by ``1 / keep_prob`` during training,
        by default True. No scaling is applied when ``keep_prob`` is 0.

    Returns
    -------
    torch.Tensor
        Output tensor with the same shape as ``x``. When ``drop_prob == 0`` or
        ``training`` is False this is the very same object as ``x``.

    Notes from Ross Wightman:
    (when applied in main path of residual blocks)
    This is the same as the DropConnect impl I created for EfficientNet,
    etc. networks, however,
    the original name is misleading as 'Drop Connect' is a different form
    of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    ... I've opted for changing the layer and argument names to 'drop path'
    rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions,
    # so this works for tensors of any rank, not just 2D ConvNet activations.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
71
+
72
+
73
+ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
74
+ """
75
+ Generates a 1-dimensional Gaussian kernel based on the specified kernel
76
+ size and standard deviation (sigma).
77
+ This kernel is useful for Gaussian smoothing or filtering operations in
78
+ image processing. The function calculates a range limit to ensure the kernel
79
+ effectively covers the Gaussian distribution. It generates a tensor of
80
+ specified size and type, filled with values distributed according to a
81
+ Gaussian curve, normalized using a softmax function
82
+ to ensure all weights sum to 1.
83
+
84
+
85
+ Parameters
86
+ ----------
87
+ kernel_size: int
88
+ sigma: float
89
+
90
+ Returns
91
+ -------
92
+ kernel1d: torch.Tensor
93
+
94
+ Notes
95
+ -----
96
+ Code copied and modified from TorchVision:
97
+ https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py#L725-L732
98
+ All rights reserved.
99
+
100
+ LICENSE in https://github.com/pytorch/vision/blob/main/LICENSE
101
+
102
+ """
103
+ ksize_half = (kernel_size - 1) * 0.5
104
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
105
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
106
+ kernel1d = pdf / pdf.sum()
107
+ return kernel1d
108
+
109
+
110
def hilbert_freq(x, forward_fourier=True):
    r"""
    Compute the analytic signal (Hilbert transform) using PyTorch, returning
    its real and imaginary parts separately.

    The analytic signal :math:`x_a(t)` of a real-valued signal :math:`x(t)`
    is defined as:

    .. math::

        x_a(t) = x(t) + i y(t) = \mathcal{F}^{-1} \{ U(f) \mathcal{F}\{x(t)\} \}

    where:
    - :math:`\mathcal{F}` is the Fourier transform,
    - :math:`U(f)` is the unit step function. In discrete form its weights
      are 1 at DC (and at the Nyquist bin for even lengths), 2 for the other
      positive frequencies, and 0 for negative frequencies,
    - :math:`y(t)` is the Hilbert transform of :math:`x(t)`.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor. The expected shape depends on ``forward_fourier``:

        - If ``forward_fourier`` is True: a real signal of shape
          ``(..., seq_len)``.
        - If ``forward_fourier`` is False: a one-sided spectrum of shape
          ``(..., seq_len // 2 + 1, 2)`` whose last dimension holds the real
          and imaginary parts (e.g. ``torch.view_as_real(torch.fft.rfft(s))``).
          The original signal length is assumed to be even, i.e.
          ``seq_len = 2 * (n_freqs - 1)``.

    forward_fourier : bool, optional
        Determines the format of ``x``.
        - If True, ``x`` is in the time domain, and the forward real FFT is
          computed here.
        - If False, ``x`` already is a one-sided spectrum with separate real
          and imaginary parts.
        Default is True.

    Returns
    -------
    torch.Tensor
        Output tensor with shape ``(..., seq_len, 2)``. The last dimension
        holds the real part (the input signal itself) and the imaginary part
        (its Hilbert transform) of the analytic signal.

    Examples
    --------
    >>> import torch
    >>> x = torch.randn(10, 100)  # Example input tensor
    >>> output = hilbert_freq(x)
    >>> print(output.shape)
    torch.Size([10, 100, 2])

    Notes
    -----
    The implementation matches ``scipy.signal.hilbert`` for both even and odd
    signal lengths, but uses torch operations.
    https://github.com/scipy/scipy/blob/v1.14.1/scipy/signal/_signaltools.py#L2287-L2394
    """
    if forward_fourier:
        n_times = x.shape[-1]
        spectrum = torch.fft.rfft(x, norm=None, dim=-1)
    else:
        # Separate (real, imag) channels -> complex one-sided spectrum.
        spectrum = torch.view_as_complex(x.contiguous())
        n_times = 2 * (spectrum.shape[-1] - 1)

    # One-sided weights of U(f), as in scipy: DC keeps weight 1 and strictly
    # positive frequencies are doubled. For even lengths the Nyquist bin is
    # shared by positive and negative frequencies, so it also keeps weight 1.
    n_freqs = spectrum.shape[-1]
    weights = torch.full(
        (n_freqs,), 2.0, dtype=spectrum.real.dtype, device=spectrum.device
    )
    weights[0] = 1.0
    if n_times % 2 == 0:
        weights[-1] = 1.0
    spectrum = spectrum * weights

    # ``n=n_times`` zero-fills the negative frequencies. The output therefore
    # has exactly ``n_times`` samples, for odd as well as even lengths.
    analytic = torch.fft.ifft(spectrum, n=n_times, norm=None, dim=-1)
    return torch.view_as_real(analytic)
177
+
178
+
179
def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
    """Compute the Phase Locking Value (PLV) metric in the time domain.

    The Phase Locking Value (PLV) is a measure of the synchronization between
    different channels by evaluating the consistency of phase differences
    over time. It ranges from 0 (no synchronization) to 1 (perfect
    synchronization) [1]_.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor containing the signal data.
        - If `forward_fourier` is `True`, the shape should be `(..., channels, time)`.
        - If `forward_fourier` is `False`, the shape should be `(..., channels, freqs, 2)`,
          where the last dimension represents the real and imaginary parts.
    forward_fourier : bool, optional
        Specifies the format of the input tensor `x`.
        - If `True`, `x` is assumed to be in the time domain.
        - If `False`, `x` is assumed to be in the Fourier domain with separate real and
          imaginary components.
        Default is `True`.
    epsilon : float, default 1e-6
        Small constant added under both square roots: the amplitude of the
        analytic signal, and the magnitude of the complex correlation. It
        avoids division by zero and the infinite derivative of ``sqrt`` at 0.

    Returns
    -------
    plv : torch.Tensor
        The Phase Locking Value matrix with shape `(..., channels, channels)`. Each
        element `[i, j]` represents the PLV between channel `i` and channel `j`.

    References
    ----------
    [1] Lachaux, J. P., Rodriguez, E., Martinerie, J., & Varela, F. J. (1999).
    Measuring phase synchrony in brain signals. Human brain mapping,
    8(4), 194-208.
    """
    # Analytic signal with separate real and imaginary parts in the last axis;
    # hilbert_freq always returns time samples along axis -2.
    analytic_signal = hilbert_freq(x, forward_fourier)
    sig_real = analytic_signal[..., 0]
    sig_imag = analytic_signal[..., 1]

    # Amplitude (magnitude) of the analytic signal; epsilon avoids division
    # by zero for silent samples.
    amplitude = torch.sqrt(sig_real**2 + sig_imag**2 + epsilon)

    # Normalize to unit phasors so that only the instantaneous phase remains.
    phasor_real = sig_real / amplitude
    phasor_imag = sig_imag / amplitude

    # Products needed for sum_t u_i(t) * conj(u_j(t)) between all channel pairs.
    real_real = torch.matmul(phasor_real, phasor_real.transpose(-2, -1))
    imag_imag = torch.matmul(phasor_imag, phasor_imag.transpose(-2, -1))
    real_imag = torch.matmul(phasor_real, phasor_imag.transpose(-2, -1))
    imag_real = torch.matmul(phasor_imag, phasor_real.transpose(-2, -1))

    # Real and imaginary parts of the complex correlation. The imaginary part
    # carries the opposite sign of u_i * conj(u_j), which is irrelevant for
    # its magnitude.
    correlation_real = real_real + imag_imag
    correlation_imag = real_imag - imag_real

    # Number of time points the phasors were summed over.
    n_times = amplitude.shape[-1]

    # PLV: magnitude of the summed phasor products, averaged over time.
    plv_matrix = (
        1 / n_times * torch.sqrt(correlation_real**2 + correlation_imag**2 + epsilon)
    )

    return plv_matrix
@@ -0,0 +1,47 @@
1
+ import math
2
+
3
+ from torch import nn
4
+
5
+
6
def glorot_weight_zero_bias(model):
    """Initialize weights with Glorot (Xavier) uniform and biases with zeros.

    Every submodule of ``model`` (including ``model`` itself) is visited:

    - Weights of batch norm layers (class name containing ``"BatchNorm"``)
      are set to 1.
    - Other weight parameters with at least 2 dimensions are initialized with
      ``nn.init.xavier_uniform_(weight, gain=1)``. Weights with fewer
      dimensions (e.g. ``LayerNorm``) are left untouched, because a Glorot fan
      cannot be computed for them.
    - Modules whose ``weight`` is missing or ``None`` (e.g. batch norm with
      ``affine=False``) are skipped.
    - Every non-``None`` ``bias`` is set to 0.

    Parameters
    ----------
    model: Module
        Model whose parameters are re-initialized in place.
    """
    for module in model.modules():
        weight = getattr(module, "weight", None)
        # Only real parameters are initialized. Weights that are None or
        # computed on the fly (not nn.Parameter) are skipped.
        if isinstance(weight, nn.Parameter):
            if "BatchNorm" in module.__class__.__name__:
                nn.init.constant_(weight, 1)
            elif weight.dim() >= 2:
                nn.init.xavier_uniform_(weight, gain=1)
        bias = getattr(module, "bias", None)
        if bias is not None:
            nn.init.constant_(bias, 0)
23
+
24
+
25
def rescale_parameter(param, layer_id):
    r"""Rescale a parameter of the l-th transformer layer in place.

    Divides ``param`` in place by :math:`\sqrt{2 \cdot \text{layer\_id}}`,
    i.e. multiplies it by :math:`\frac{1}{\sqrt{2 \cdot \text{layer\_id}}}`
    [Beit2022].

    In Labram, this is used to rescale the output matrices (i.e., the last
    linear projection within each sub-layer) of the self-attention module.
    As with any in-place tensor operation, a leaf tensor that requires
    gradients must be rescaled inside ``torch.no_grad()``.

    Parameters
    ----------
    param: :class:`torch.Tensor`
        Tensor to be rescaled in place.
    layer_id: int
        Positive (1-based) layer index in the neural network.

    Returns
    -------
    None
        The tensor is modified in place.

    Raises
    ------
    ValueError
        If ``layer_id`` is not positive. A zero id would otherwise silently
        fill the tensor with ``inf``.

    References
    ----------
    [Beit2022] Hangbo Bao, Li Dong, Songhao Piao, Furu Wei (2022). BEiT: BERT
    Pre-Training of Image Transformers.
    """
    if layer_id <= 0:
        raise ValueError(f"layer_id must be a positive integer, got {layer_id}.")
    param.div_(math.sqrt(2.0 * layer_id))
@@ -0,0 +1,52 @@
1
+ """
2
+ Some predefined network architectures for EEG decoding.
3
+ """
4
+
5
+ from .atcnet import ATCNet
6
+ from .attentionbasenet import AttentionBaseNet
7
+ from .base import EEGModuleMixin
8
+ from .biot import BIOT
9
+ from .contrawr import ContraWR
10
+ from .ctnet import CTNet
11
+ from .deep4 import Deep4Net
12
+ from .deepsleepnet import DeepSleepNet
13
+ from .eegconformer import EEGConformer
14
+ from .eeginception_erp import EEGInceptionERP
15
+ from .eeginception_mi import EEGInceptionMI
16
+ from .eegitnet import EEGITNet
17
+ from .eegminer import EEGMiner
18
+ from .eegnet import EEGNetv1, EEGNetv4
19
+ from .eegnex import EEGNeX
20
+ from .eegresnet import EEGResNet
21
+ from .eegsimpleconv import EEGSimpleConv
22
+ from .eegtcnet import EEGTCNet
23
+ from .fbcnet import FBCNet
24
+ from .fblightconvnet import FBLightConvNet
25
+ from .fbmsnet import FBMSNet
26
+ from .hybrid import HybridNet
27
+ from .ifnet import IFNet
28
+ from .labram import Labram
29
+ from .msvtnet import MSVTNet
30
+ from .sccnet import SCCNet
31
+ from .shallow_fbcsp import ShallowFBCSPNet
32
+ from .signal_jepa import (
33
+ SignalJEPA,
34
+ SignalJEPA_Contextual,
35
+ SignalJEPA_PostLocal,
36
+ SignalJEPA_PreLocal,
37
+ )
38
+ from .sinc_shallow import SincShallowNet
39
+ from .sleep_stager_blanco_2020 import SleepStagerBlanco2020
40
+ from .sleep_stager_chambon_2018 import SleepStagerChambon2018
41
+ from .sleep_stager_eldele_2021 import SleepStagerEldele2021
42
+ from .sparcnet import SPARCNet
43
+ from .syncnet import SyncNet
44
+ from .tcn import BDTCN, TCN
45
+ from .tidnet import TIDNet
46
+ from .tsinception import TSceptionV1
47
+ from .usleep import USleep
48
+ from .util import _init_models_dict, models_mandatory_parameters
49
+
50
# Call this last in order to make sure the models dictionary is populated
# with the models imported in this file.
_init_models_dict()