braindecode 1.3.0.dev173691341__py3-none-any.whl → 1.3.0.dev173767962__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/base.py +1 -1
- braindecode/augmentation/functional.py +154 -54
- braindecode/augmentation/transforms.py +2 -2
- braindecode/datasets/__init__.py +10 -2
- braindecode/datasets/base.py +116 -152
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +2 -2
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +2 -2
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +4 -3
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/serialization.py +18 -13
- braindecode/eegneuralnet.py +2 -0
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +6 -0
- braindecode/models/atcnet.py +26 -27
- braindecode/models/attentionbasenet.py +39 -32
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +3 -1
- braindecode/models/ctnet.py +6 -3
- braindecode/models/deepsleepnet.py +27 -18
- braindecode/models/eegconformer.py +2 -2
- braindecode/models/eeginception_erp.py +31 -25
- braindecode/models/eegnet.py +1 -1
- braindecode/models/labram.py +188 -84
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/signal_jepa.py +109 -27
- braindecode/models/sinc_shallow.py +10 -9
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +9 -6
- braindecode/models/usleep.py +26 -21
- braindecode/models/util.py +3 -0
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -3
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/preprocess.py +23 -14
- braindecode/preprocessing/windowers.py +24 -19
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev173691341.dist-info → braindecode-1.3.0.dev173767962.dist-info}/METADATA +4 -2
- {braindecode-1.3.0.dev173691341.dist-info → braindecode-1.3.0.dev173767962.dist-info}/RECORD +52 -49
- {braindecode-1.3.0.dev173691341.dist-info → braindecode-1.3.0.dev173767962.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev173691341.dist-info → braindecode-1.3.0.dev173767962.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev173691341.dist-info → braindecode-1.3.0.dev173767962.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev173691341.dist-info → braindecode-1.3.0.dev173767962.dist-info}/top_level.txt +0 -0
braindecode/models/summary.csv
CHANGED
@@ -2,7 +2,7 @@
 ATCNet,General,Classification,250,"n_chans, n_outputs, n_times",113732,"ATCNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Small Attention"
 AttentionBaseNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",3692,"AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
 BDTCN,Normal Abnormal,Classification,100,"n_chans, n_outputs, n_times",456502,"BDTCN(n_chans=21, n_outputs=2, n_times=6000, n_blocks=5, n_filters=55, kernel_size=16)","Convolution,Recurrent"
-BIOT,"Sleep Staging, Epilepsy",Classification,200,"n_chans, n_outputs",3183879,"BIOT(n_chans=2, n_outputs=5, n_times=6000)","Large
+BIOT,"Sleep Staging, Epilepsy",Classification,200,"n_chans, n_outputs",3183879,"BIOT(n_chans=2, n_outputs=5, n_times=6000)","Large Brain Model"
 ContraWR,Sleep Staging,"Classification, Embedding",125,"n_chans, n_outputs, sfreq",1160165,"ContraWR(n_chans=2, n_outputs=5, n_times=3750, emb_size=256, sfreq=125)",Convolution
 CTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",26900,"CTNet(n_chans=22, n_outputs=4, n_times=1000, n_filters_time=8, kernel_size=16, heads=2, emb_size=16)","Convolution,Small Attention"
 Deep4Net,General,Classification,250,"n_chans, n_outputs, n_times",282879,"Deep4Net(n_chans=22, n_outputs=4, n_times=1000)","Convolution"
@@ -16,13 +16,13 @@ EEGNeX,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",55940,"EEG
 EEGMiner,Emotion Recognition,Classification,128,"n_chans, n_outputs, n_times, sfreq",7572,"EEGMiner(n_chans=62, n_outputs=2, n_times=2560, sfreq=128)","Convolution,Interpretability"
 EEGSimpleConv,Motor Imagery,Classification,80,"n_chans, n_outputs, sfreq",730404,"EEGSimpleConv(n_chans=22, n_outputs=4, n_times=320, sfreq=80)","Convolution"
 EEGTCNet,Motor Imagery,Classification,250,"n_chans, n_outputs",4516,"EEGTCNet(n_chans=22, n_outputs=4, n_times=1000, kern_length=32)","Convolution,Recurrent"
-Labram,General,"Classification, Embedding",200,"n_chans, n_outputs, n_times",5866180,"Labram(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Large
+Labram,General,"Classification, Embedding",200,"n_chans, n_outputs, n_times",5866180,"Labram(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Large Brain Model"
 MSVTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",75494," MSVTNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Small Attention"
 SCCNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times, sfreq",12070,"SCCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=125)","Convolution"
-SignalJEPA,"Motor Imagery, ERP, SSVEP",Embedding,128,"n_times, chs_info",3456882,"SignalJEPA(n_times=512, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Large
-SignalJEPA_Contextual,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",3459184,"SignalJEPA_Contextual(n_outputs=2, input_window_seconds=4.19, sfreq=128, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Large
-SignalJEPA_PostLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_chans, n_outputs, n_times",16142,"SignalJEPA_PostLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Large
-SignalJEPA_PreLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",16142,"SignalJEPA_PreLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Large
+SignalJEPA,"Motor Imagery, ERP, SSVEP",Embedding,128,"n_times, chs_info",3456882,"SignalJEPA(n_times=512, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Large Brain Model"
+SignalJEPA_Contextual,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",3459184,"SignalJEPA_Contextual(n_outputs=2, input_window_seconds=4.19, sfreq=128, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Large Brain Model"
+SignalJEPA_PostLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_chans, n_outputs, n_times",16142,"SignalJEPA_PostLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Large Brain Model"
+SignalJEPA_PreLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",16142,"SignalJEPA_PreLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Large Brain Model"
 SincShallowNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",21892,"SincShallowNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Interpretability"
 ShallowFBCSPNet,General,Classification,250,"n_chans, n_outputs, n_times",46084,"ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution"
 SleepStagerBlanco2020,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times",2845,"SleepStagerBlanco2020(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution"
@@ -37,3 +37,6 @@ FBCNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",118
 FBMSNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",16231,"FBMSNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
 FBLightConvNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",6596,"FBLightConvNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
 IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860,"IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
+PBT,General,Classification,250,"n_chans, n_outputs, n_times",818948,"PBT(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Large Brain Model"
+SSTDPN,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",19502,"SSTDPN(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
+BENDR,General,"Classification,Embedding",250,"n_chans, n_times, n_outputs",157141049,"BENDR(n_chans=22, n_outputs=4, n_times=1000)","Large Brain Model,Convolution"
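Note: the three new rows register PBT, SSTDPN, and BENDR in the model summary. A minimal sketch of what the listed instantiation strings imply, assuming the three classes are exported from `braindecode.models` (the `models/__init__.py` change in the file list suggests they are):

    # Sketch based on the instantiation strings in summary.csv; constructor
    # defaults in the released package may differ.
    from braindecode.models import BENDR, PBT, SSTDPN

    pbt = PBT(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)  # ~819k params per the CSV
    sstdpn = SSTDPN(n_chans=22, n_outputs=4, n_times=1000)       # ~19.5k params
    bendr = BENDR(n_chans=22, n_outputs=4, n_times=1000)         # ~157M params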
braindecode/models/usleep.py
CHANGED
@@ -62,43 +62,48 @@ class USleep(EEGModuleMixin, nn.Module):
 - Decoder :class:`_DecoderBlock` **(progressive upsampling + skip fusion to high-frequency map, 12 blocks; upsampling x2 per block)**

 - *Operations.*
-
-
-
-
-
-
-
-
+
+  - **Nearest-neighbor upsample**, :class:`nn.Upsample` (x2)
+  - **Convolution2d** (k=2), :class:`torch.nn.Conv2d`
+  - ELU, :class:`torch.nn.ELU`
+  - Batch Norm, :class:`torch.nn.BatchNorm2d`
+  - **Concatenate** with the encoder skip at the same temporal scale, ``torch.cat``
+  - **Convolution**, :class:`torch.nn.Conv2d`
+  - ELU, :class:`torch.nn.ELU`
+  - Batch Norm, :class:`torch.nn.BatchNorm2d`.

 **Output**: A multi-class, **high-frequency** per-sample representation aligned to the input rate (128 Hz).

 - **Segment Classifier incorporate into :class:`braindecode.models.USleep` (aggregation to fixed epochs)**

 - *Operations.*
-
-
-
-
+
+  - **Mean-pool**, :class:`torch.nn.AvgPool2d` per class with kernel = epoch length *i* and stride *i*
+  - **1x1 conv**, :class:`torch.nn.Conv2d`
+  - ELU, :class:`torch.nn.ELU`
+  - **1x1 conv**, :class:`torch.nn.Conv2d` with ``(T, K)`` (epochs x stages).

 **Role**: Learns a **non-linear** weighted combination over each 30-s window (unlike U-Time's linear combiner).

 .. rubric:: Convolutional Details

 - **Temporal (where time-domain patterns are learned).**
-
-
-
+
+  All convolutions are **1-D along time**; depth (12 levels) plus pooling yields an extensive receptive field
+  (reported sensitivity to ±6.75 min around each epoch; theoretical field ≈ 9.6 min at the deepest layer).
+  The decoder restores sample-level resolution before epoch aggregation.

 - **Spatial (how channels are processed).**
-
-
-
+
+  Convolutions mix across the *channel* dimension jointly with time (no separate spatial operator). The system
+  is **montage-agnostic** (any reasonable EEG/EOG pair) and was trained across diverse cohorts/protocols,
+  supporting robustness to channel placement and hardware differences.

 - **Spectral (how frequency content is captured).**
-
-
-
+
+  No explicit Fourier/wavelet transform is used; the **stack of temporal convolutions** acts as a learned
+  filter bank whose effective bandwidth grows with depth. The high-frequency decoder output (128 Hz)
+  retains fine temporal detail for the segment classifier.

 .. rubric:: Attention / Sequential Modules
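Note: the rewritten docstring spells out the decoder-block operation order. Below is a minimal PyTorch sketch of that sequence, illustrative only: it assumes an (N, C, 1, T) feature layout and is not the actual `_DecoderBlock` code.

    import torch
    from torch import nn

    class DecoderBlockSketch(nn.Module):
        # Mirrors the documented order: upsample -> conv/ELU/BN -> concat skip -> conv/ELU/BN.
        def __init__(self, in_ch: int, skip_ch: int, out_ch: int):
            super().__init__()
            self.up = nn.Upsample(scale_factor=(1, 2), mode="nearest")  # x2 along time
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=(1, 2), padding="same"),
                nn.ELU(),
                nn.BatchNorm2d(out_ch),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_ch + skip_ch, out_ch, kernel_size=(1, 2), padding="same"),
                nn.ELU(),
                nn.BatchNorm2d(out_ch),
            )

        def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
            x = self.conv1(self.up(x))       # nearest-neighbor upsample, conv (k=2), ELU, BN
            x = torch.cat([x, skip], dim=1)  # fuse encoder skip at the same temporal scale
            return self.conv2(x)             # conv, ELU, BN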
braindecode/models/util.py
CHANGED
@@ -95,6 +95,9 @@ models_mandatory_parameters = [
     ("FBMSNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
     ("FBLightConvNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
     ("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
+    ("PBT", ["n_chans", "n_outputs", "n_times"], None),
+    ("SSTDPN", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
+    ("BENDR", ["n_chans", "n_outputs", "n_times"], None),
 ]

 ################################################################
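Note: each entry in `models_mandatory_parameters` is a `(name, mandatory_args, non_default_kwargs)` triple, with `None` meaning no extra kwargs. A hedged sketch of how such an entry could drive a smoke test (illustrative, not the package's own test code):

    import braindecode.models as models

    name, mandatory, extra = ("BENDR", ["n_chans", "n_outputs", "n_times"], None)
    kwargs = {"n_chans": 22, "n_outputs": 4, "n_times": 1000, **(extra or {})}
    assert set(mandatory) <= set(kwargs)     # every mandatory argument is supplied
    model = getattr(models, name)(**kwargs)  # resolve the class by name and build it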
braindecode/modules/attention.py
CHANGED
@@ -38,7 +38,7 @@ class SqueezeAndExcitation(nn.Module):
     References
     ----------
     .. [Hu2018] Hu, J., Albanie, S., Sun, G., Wu, E., 2018.
-
+       Squeeze-and-Excitation Networks. CVPR 2018.
     """

     def __init__(self, in_channels: int, reduction_rate: int, bias: bool = False):
@@ -93,7 +93,7 @@ class GSoP(nn.Module):
     References
     ----------
     .. [Gao2018] Gao, Z., Jiangtao, X., Wang, Q., Li, P., 2018.
-
+       Global Second-order Pooling Convolutional Networks. CVPR 2018.
     """

     def __init__(self, in_channels: int, reduction_rate: int, bias: bool = True):
@@ -149,7 +149,7 @@ class FCA(nn.Module):
     References
     ----------
     .. [Qin2021] Qin, Z., Zhang, P., Wu, F., Li, X., 2021.
-
+       FcaNet: Frequency Channel Attention Networks. ICCV 2021.
     """

     def __init__(
@@ -233,7 +233,7 @@ class EncNet(nn.Module):
     References
     ----------
     .. [Zhang2018] Zhang, H. et al. 2018.
-
+       Context Encoding for Semantic Segmentation. CVPR 2018.
     """

     def __init__(self, in_channels: int, n_codewords: int):
@@ -290,7 +290,7 @@ class ECA(nn.Module):
     References
     ----------
     .. [Wang2021] Wang, Q. et al., 2021. ECA-Net: Efficient Channel Attention
-
+       for Deep Convolutional Neural Networks. CVPR 2021.
     """

     def __init__(self, in_channels: int, kernel_size: int):
@@ -341,8 +341,8 @@ class GatherExcite(nn.Module):
     References
     ----------
     .. [Hu2018b] Hu, J., Albanie, S., Sun, G., Vedaldi, A., 2018.
-
-
+       Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks.
+       NeurIPS 2018.
     """

     def __init__(
@@ -410,7 +410,7 @@ class GCT(nn.Module):
     References
     ----------
     .. [Yang2020] Yang, Z. Linchao, Z., Wu, Y., Yang, Y., 2020.
-
+       Gated Channel Transformation for Visual Recognition. CVPR 2020.
     """

     def __init__(self, in_channels: int):
@@ -455,7 +455,7 @@ class SRM(nn.Module):
     References
     ----------
     .. [Lee2019] Lee, H., Kim, H., Nam, H., 2019. SRM: A Style-based
-
+       Recalibration Module for Convolutional Neural Networks. ICCV 2019.
     """

     def __init__(
@@ -520,7 +520,7 @@ class CBAM(nn.Module):
     References
     ----------
     .. [Woo2018] Woo, S., Park, J., Lee, J., Kweon, I., 2018.
-
+       CBAM: Convolutional Block Attention Module. ECCV 2018.
     """

     def __init__(self, in_channels: int, reduction_rate: int, kernel_size: int):
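Note: all nine hunks only repair the docstring reference lines; the constructor signatures visible in the context are unchanged. A hedged usage sketch for one of these modules, where the `(batch, channels, 1, time)` input layout is an assumption not confirmed by this diff:

    import torch
    from braindecode.modules.attention import SqueezeAndExcitation

    se = SqueezeAndExcitation(in_channels=16, reduction_rate=4)
    x = torch.randn(8, 16, 1, 250)  # assumed feature-map layout
    out = se(x)                     # channel recalibration; shape should be preserved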
braindecode/modules/blocks.py
CHANGED
@@ -37,8 +37,8 @@ class MLP(nn.Sequential):
     :math:`a_i` are called activation functions. The trainable parameters of an
     MLP are its weights and biases :math:`\\phi = \{W_i, b_i | i = 1, \dots, L\}`.

-    Parameters
-
+    Parameters
+    ----------
     in_features: int
         Number of input features.
     hidden_features: Sequential[int] (default=None)
@@ -49,7 +49,7 @@ class MLP(nn.Sequential):
     out_features: int (default=None)
         Number of output features, if None, set to in_features.
     act_layer: nn.GELU (default)
-        The activation function constructor. If
+        The activation function constructor. If ``None``, use
         :class:`torch.nn.GELU` instead.
     drop: float (default=0.0)
         Dropout rate.
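Note: the hunks restore the `Parameters` underline and complete the `act_layer` sentence. A minimal usage sketch, assuming the constructor matches the documented parameters:

    import torch
    from braindecode.modules.blocks import MLP

    mlp = MLP(in_features=40, hidden_features=(64, 64), out_features=4, drop=0.1)
    y = mlp(torch.randn(8, 40))  # (batch, in_features) -> (batch, out_features)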
braindecode/modules/filter.py
CHANGED
@@ -17,9 +17,8 @@ class FilterBankLayer(nn.Module):
     It uses MNE's `create_filter` function to create the band-specific filters and
     applies them to multi-channel time-series data. Each filter in the bank corresponds to a
     specific frequency band and is applied to all channels of the input data. The filtering is
-    performed using FFT-based convolution via the
-
-    :func:`torchaudio.functional if the method is IIR.
+    performed using FFT-based convolution via the ``torchaudio.functional`` if the method is FIR,
+    and ``torchaudio.functional`` if the method is IIR.

     The default configuration creates 9 non-overlapping frequency bands with a 4 Hz bandwidth,
     spanning from 4 Hz to 40 Hz (i.e., 4-8 Hz, 8-12 Hz, ..., 36-40 Hz). This setup is based on the
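Note: the default band layout described in the docstring can be written out explicitly; this is plain arithmetic, not package code:

    # Nine non-overlapping 4 Hz bands covering 4-40 Hz.
    band_edges = [(low, low + 4) for low in range(4, 40, 4)]
    assert band_edges == [(4, 8), (8, 12), (12, 16), (16, 20), (20, 24),
                          (24, 28), (28, 32), (32, 36), (36, 40)]
    assert len(band_edges) == 9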
braindecode/modules/layers.py
CHANGED
@@ -70,26 +70,27 @@ class TimeDistributed(nn.Module):
 class DropPath(nn.Module):
     """Drop paths, also known as Stochastic Depth, per sample.

-
+    When applied in main path of residual blocks.

-
-
-
-
+    Parameters
+    ----------
+    drop_prob: float (default=None)
+        Drop path probability (should be in range 0-1).

-
-
-
+    Notes
+    -----
+    Code copied and modified from VISSL facebookresearch:
     https://github.com/facebookresearch/vissl/blob/0b5d6a94437bc00baed112ca90c9d78c6ccfbafb/vissl/models/model_helpers.py#L676
-
-
-
-
-
-
-
-
-
+
+    All rights reserved.
+
+    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+    SOFTWARE.
     """

     def __init__(self, drop_prob=None):
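Note: the docstring now documents `drop_prob` and credits the VISSL source. For reference, a standard stochastic-depth sketch of the technique the class implements (the braindecode version is adapted from VISSL and may differ in detail):

    import torch

    def drop_path(x: torch.Tensor, drop_prob: float, training: bool) -> torch.Tensor:
        if drop_prob == 0.0 or not training:
            return x
        keep_prob = 1.0 - drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x / keep_prob * mask  # rescale so the expected value is unchanged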
braindecode/preprocessing/preprocess.py
CHANGED
@@ -30,8 +30,8 @@ from numpy.typing import NDArray

 from braindecode.datasets.base import (
     BaseConcatDataset,
-    BaseDataset,
     EEGWindowsDataset,
+    RawDataset,
     WindowsDataset,
 )
 from braindecode.datautil.serialization import (
@@ -112,13 +112,14 @@ def preprocess(
     n_jobs: int | None = None,
     offset: int = 0,
     copy_data: bool | None = None,
+    parallel_kwargs: dict | None = None,
 ):
     """Apply preprocessors to a concat dataset.

     Parameters
     ----------
     concat_ds : BaseConcatDataset
-        A concat of ``
+        A concat of ``RecordDataset`` to be preprocessed.
     preprocessors : list of Preprocessor
         Preprocessor objects to apply to each dataset.
     save_dir : str | None
@@ -135,6 +136,10 @@ def preprocess(
         and saving very large datasets in chunks to preserve original positions.
     copy_data : bool | None
         Whether the data passed to parallel jobs should be copied or passed by reference.
+    parallel_kwargs : dict | None
+        Additional keyword arguments forwarded to ``joblib.Parallel``.
+        Defaults to None (equivalent to ``{}``).
+        See https://joblib.readthedocs.io/en/stable/generated/joblib.Parallel.html for details.

     Returns
     -------
@@ -153,8 +158,12 @@ def preprocess(

     parallel_processing = (n_jobs is not None) and (n_jobs != 1)

-
-
+    parallel_params = {} if parallel_kwargs is None else dict(parallel_kwargs)
+    parallel_params.setdefault(
+        "prefer", "threads" if platform.system() == "Windows" else None
+    )
+
+    list_of_ds = Parallel(n_jobs=n_jobs, **parallel_params)(
         delayed(_preprocess)(
             ds,
             i + offset,
@@ -220,15 +229,15 @@ def _preprocess(

     Parameters
     ----------
-    ds:
+    ds: RecordDataset
         Dataset object to preprocess.
     ds_index : int
-        Index of the
+        Index of the ``RecordDataset`` in its ``BaseConcatDataset``. Ignored if save_dir
         is None.
     preprocessors: list(Preprocessor)
         List of preprocessors to apply to the dataset.
     save_dir : str | None
-        If provided, save the preprocessed
+        If provided, save the preprocessed RecordDataset in the
         specified directory.
     overwrite : bool
         If True, overwrite existing file with the same name.
@@ -254,8 +263,8 @@ def _preprocess(
         _preprocess_raw_or_epochs(ds.windows, preprocessors)
     else:
         raise ValueError(
-            "Can only preprocess concatenation of
-            "
+            "Can only preprocess concatenation of RecordDataset, "
+            "with either a `raw` or `windows` attribute."
         )

     # Store preprocessing keyword arguments in the dataset
@@ -288,11 +297,11 @@ def _get_preproc_kwargs(preprocessors):


 def _set_preproc_kwargs(ds, preprocessors):
-    """Record preprocessing keyword arguments in
+    """Record preprocessing keyword arguments in RecordDataset.

     Parameters
     ----------
-    ds :
+    ds : RecordDataset
         Dataset in which to record preprocessing keyword arguments.
     preprocessors : list
         List of preprocessors.
@@ -300,12 +309,12 @@ def _set_preproc_kwargs(ds, preprocessors):
     preproc_kwargs = _get_preproc_kwargs(preprocessors)
     if isinstance(ds, WindowsDataset):
         kind = "window"
-
+    elif isinstance(ds, EEGWindowsDataset):
         kind = "raw"
-    elif isinstance(ds,
+    elif isinstance(ds, RawDataset):
         kind = "raw"
     else:
-        raise TypeError(f"ds must be a
+        raise TypeError(f"ds must be a RecordDataset, got {type(ds)}")
     setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)

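Note: the new `parallel_kwargs` argument is forwarded to `joblib.Parallel`, with a `prefer="threads"` default applied on Windows. A hedged usage sketch, where `concat_ds` stands for a previously built `BaseConcatDataset` and `backend`/`verbose` are standard joblib options:

    from braindecode.preprocessing import Preprocessor, preprocess

    preprocessors = [Preprocessor("resample", sfreq=100)]  # applies Raw.resample
    preprocess(
        concat_ds,
        preprocessors,
        n_jobs=4,
        parallel_kwargs={"backend": "loky", "verbose": 10},
    )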
braindecode/preprocessing/windowers.py
CHANGED
@@ -25,7 +25,12 @@ import pandas as pd
 from joblib import Parallel, delayed
 from numpy.typing import ArrayLike

-from ..datasets.base import
+from ..datasets.base import (
+    BaseConcatDataset,
+    EEGWindowsDataset,
+    RawDataset,
+    WindowsDataset,
+)


 class _LazyDataFrame:
@@ -189,7 +194,7 @@ def _get_use_mne_epochs(use_mne_epochs, reject, picks, flat, drop_bad_windows):

 # XXX it's called concat_ds...
 def create_windows_from_events(
-    concat_ds: BaseConcatDataset,
+    concat_ds: BaseConcatDataset[RawDataset],
     trial_start_offset_samples: int = 0,
     trial_stop_offset_samples: int = 0,
     window_size_samples: int | None = None,
@@ -206,7 +211,7 @@
     use_mne_epochs: bool | None = None,
     n_jobs: int = 1,
     verbose: bool | str | int | None = "error",
-):
+) -> BaseConcatDataset[WindowsDataset | EEGWindowsDataset]:
     """Create windows based on events in mne.Raw.

     This function extracts windows of size window_size_samples in the interval
@@ -228,7 +233,7 @@

     Parameters
     ----------
-    concat_ds: BaseConcatDataset
+    concat_ds: BaseConcatDataset[RawDataset]
         A concat of base datasets each holding raw and description.
     trial_start_offset_samples: int
         Start offset from original trial onsets, in samples. Defaults to zero.
@@ -268,7 +273,7 @@
         rejection based on flatness is done. See mne.Epochs.
     on_missing: str
         What to do if one or several event ids are not found in the recording.
-        Valid keys are ‘error
+        Valid keys are ‘error' | ‘warning' | ‘ignore'. See mne.Epochs.
     accepted_bads_ratio: float, optional
         Acceptable proportion of trials with inconsistent length in a raw. If
         the number of trials whose length is exceeded by the window size is
@@ -286,7 +291,7 @@

     Returns
     -------
-    windows_datasets: BaseConcatDataset
+    windows_datasets: BaseConcatDataset[WindowsDataset | EEGWindowsDataset]
         Concatenated datasets of WindowsDataset containing the extracted windows.
     """
     _check_windowing_arguments(
@@ -341,7 +346,7 @@


 def create_fixed_length_windows(
-    concat_ds: BaseConcatDataset,
+    concat_ds: BaseConcatDataset[RawDataset],
     start_offset_samples: int = 0,
     stop_offset_samples: int | None = None,
     window_size_samples: int | None = None,
@@ -358,12 +363,12 @@
     on_missing: str = "error",
     n_jobs: int = 1,
     verbose: bool | str | int | None = "error",
-):
+) -> BaseConcatDataset[EEGWindowsDataset]:
     """Windower that creates sliding windows.

     Parameters
     ----------
-    concat_ds: ConcatDataset
+    concat_ds: ConcatDataset[RawDataset]
         A concat of base datasets each holding raw and description.
     start_offset_samples: int
         Start offset from beginning of recording in samples.
@@ -398,7 +403,7 @@
         by using the _LazyDataFrame (experimental).
     on_missing: str
         What to do if one or several event ids are not found in the recording.
-        Valid keys are ‘error
+        Valid keys are ‘error' | ‘warning' | ‘ignore'. See mne.Epochs.
     n_jobs: int
         Number of jobs to use to parallelize the windowing.
     verbose: bool | str | int | None
@@ -406,7 +411,7 @@

     Returns
     -------
-    windows_datasets: BaseConcatDataset
+    windows_datasets: BaseConcatDataset[EEGWindowsDataset]
         Concatenated datasets of WindowsDataset containing the extracted windows.
     """
     stop_offset_samples, drop_last_window = (
@@ -473,11 +478,11 @@ def _create_windows_from_events(
     verbose="error",
     use_mne_epochs=False,
 ):
-    """Create WindowsDataset from
+    """Create WindowsDataset from RawDataset based on events.

     Parameters
     ----------
-    ds :
+    ds : RawDataset
         Dataset containing continuous data and description.
     infer_mapping : bool
         If True, extract all events from all datasets and map them to
@@ -648,11 +653,11 @@ def _create_fixed_length_windows(
     on_missing="error",
     verbose="error",
 ):
-    """Create WindowsDataset from
+    """Create WindowsDataset from RawDataset with sliding windows.

     Parameters
     ----------
-    ds :
+    ds : RawDataset
         Dataset containing continuous data and description.

     See `create_fixed_length_windows` for description of other parameters.
@@ -750,7 +755,7 @@ def _create_fixed_length_windows(


 def create_windows_from_target_channels(
-    concat_ds,
+    concat_ds: BaseConcatDataset[RawDataset],
     window_size_samples=None,
     preload=False,
     picks=None,
@@ -759,7 +764,7 @@ def create_windows_from_target_channels(
     n_jobs=1,
     last_target_only=True,
     verbose="error",
-):
+) -> BaseConcatDataset[EEGWindowsDataset]:
     list_of_windows_ds = Parallel(n_jobs=n_jobs)(
         delayed(_create_windows_from_target_channels)(
             ds,
@@ -788,11 +793,11 @@ def _create_windows_from_target_channels(
     on_missing="error",
     verbose="error",
 ):
-    """Create WindowsDataset from
+    """Create WindowsDataset from RawDataset using targets `misc` channels from mne.Raw.

     Parameters
     ----------
-    ds :
+    ds : RawDataset
         Dataset containing continuous data and description.

     See `create_fixed_length_windows` for description of other parameters.
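Note: the windowers now carry parametrized annotations: they consume a `BaseConcatDataset[RawDataset]` and return a concat of window datasets. A hedged sketch against the annotated signature (`concat_ds` as in the preprocessing example above):

    from braindecode.preprocessing import create_fixed_length_windows

    windows_ds = create_fixed_length_windows(
        concat_ds,                 # BaseConcatDataset[RawDataset]
        start_offset_samples=0,
        window_size_samples=500,
        window_stride_samples=500,
        drop_last_window=False,
    )                              # -> BaseConcatDataset[EEGWindowsDataset]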
braindecode/samplers/base.py
CHANGED
@@ -122,14 +122,14 @@ class DistributedRecordingSampler(DistributedSampler):
         DataFrame with at least one of {subject, session, run} columns for each
         window in the BaseConcatDataset to sample examples from. Normally
         obtained with `BaseConcatDataset.get_metadata()`. For instance,
-        `metadata.head()` might look like this
-
-
-
-
-
-
-
+        `metadata.head()` might look like this::
+
+               i_window_in_trial  i_start_in_trial  i_stop_in_trial  target  subject    session    run
+            0                  0                 0              500      -1        4  session_T  run_0
+            1                  1               500             1000      -1        4  session_T  run_0
+            2                  2              1000             1500      -1        4  session_T  run_0
+            3                  3              1500             2000      -1        4  session_T  run_0
+            4                  4              2000             2500      -1        4  session_T  run_0

     random_state : np.RandomState | int | None
         Random state.
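Note: the reconstructed table is ordinary `DataFrame.head()` output. As the docstring itself says, it is normally produced like this (sketch, reusing `windows_ds` from the windowing example above):

    metadata = windows_ds.get_metadata()
    print(metadata.head())  # i_window_in_trial, i_start_in_trial, ..., subject, session, run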
braindecode/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "1.3.0.dev173691341"
+__version__ = "1.3.0.dev173767962"
{braindecode-1.3.0.dev173691341.dist-info → braindecode-1.3.0.dev173767962.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: braindecode
-Version: 1.3.0.dev173691341
+Version: 1.3.0.dev173767962
 Summary: Deep learning software to decode EEG, ECG or MEG signals
 Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>
 Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -40,6 +40,8 @@ Requires-Dist: linear_attention_transformer
 Requires-Dist: docstring_inheritance
 Provides-Extra: moabb
 Requires-Dist: moabb>=1.2.0; extra == "moabb"
+Provides-Extra: hug
+Requires-Dist: huggingface_hub[torch]>=0.20.0; extra == "hug"
 Provides-Extra: tests
 Requires-Dist: pytest; extra == "tests"
 Requires-Dist: pytest-cov; extra == "tests"
@@ -65,7 +67,7 @@ Requires-Dist: pre-commit; extra == "docs"
 Requires-Dist: openneuro-py; extra == "docs"
 Requires-Dist: plotly; extra == "docs"
 Provides-Extra: all
-Requires-Dist: braindecode[docs,moabb,tests]; extra == "all"
+Requires-Dist: braindecode[docs,hug,moabb,tests]; extra == "all"
 Dynamic: license-file

 .. image:: https://badges.gitter.im/braindecodechat/community.svg