braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/modules/attention.py
CHANGED
|
@@ -35,16 +35,6 @@ class SqueezeAndExcitation(nn.Module):
|
|
|
35
35
|
bias: bool, default=False
|
|
36
36
|
if True, adds a learnable bias will be used in the convolution.
|
|
37
37
|
|
|
38
|
-
Examples
|
|
39
|
-
--------
|
|
40
|
-
>>> import torch
|
|
41
|
-
>>> from braindecode.modules import SqueezeAndExcitation
|
|
42
|
-
>>> module = SqueezeAndExcitation(in_channels=16, reduction_rate=4)
|
|
43
|
-
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
44
|
-
>>> outputs = module(inputs)
|
|
45
|
-
>>> outputs.shape
|
|
46
|
-
torch.Size([2, 16, 1, 64])
|
|
47
|
-
|
|
48
38
|
References
|
|
49
39
|
----------
|
|
50
40
|
.. [Hu2018] Hu, J., Albanie, S., Sun, G., Wu, E., 2018.
|
|
@@ -100,16 +90,6 @@ class GSoP(nn.Module):
|
|
|
100
90
|
bias: bool, default=False
|
|
101
91
|
if True, adds a learnable bias will be used in the convolution.
|
|
102
92
|
|
|
103
|
-
Examples
|
|
104
|
-
--------
|
|
105
|
-
>>> import torch
|
|
106
|
-
>>> from braindecode.modules import GSoP
|
|
107
|
-
>>> module = GSoP(in_channels=16, reduction_rate=4)
|
|
108
|
-
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
109
|
-
>>> outputs = module(inputs)
|
|
110
|
-
>>> outputs.shape
|
|
111
|
-
torch.Size([2, 16, 1, 64])
|
|
112
|
-
|
|
113
93
|
References
|
|
114
94
|
----------
|
|
115
95
|
.. [Gao2018] Gao, Z., Jiangtao, X., Wang, Q., Li, P., 2018.
|
|
@@ -166,16 +146,6 @@ class FCA(nn.Module):
|
|
|
166
146
|
reduction_rate : int, default=4
|
|
167
147
|
Reduction ratio of the fully-connected layers.
|
|
168
148
|
|
|
169
|
-
Examples
|
|
170
|
-
--------
|
|
171
|
-
>>> import torch
|
|
172
|
-
>>> from braindecode.modules import FCA
|
|
173
|
-
>>> module = FCA(in_channels=16, seq_len=64, reduction_rate=4, freq_idx=0)
|
|
174
|
-
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
175
|
-
>>> outputs = module(inputs)
|
|
176
|
-
>>> outputs.shape
|
|
177
|
-
torch.Size([2, 16, 1, 64])
|
|
178
|
-
|
|
179
149
|
References
|
|
180
150
|
----------
|
|
181
151
|
.. [Qin2021] Qin, Z., Zhang, P., Wu, F., Li, X., 2021.
|
|
@@ -260,16 +230,6 @@ class EncNet(nn.Module):
|
|
|
260
230
|
n_codewords : int
|
|
261
231
|
number of codewords
|
|
262
232
|
|
|
263
|
-
Examples
|
|
264
|
-
--------
|
|
265
|
-
>>> import torch
|
|
266
|
-
>>> from braindecode.modules import EncNet
|
|
267
|
-
>>> module = EncNet(in_channels=16, n_codewords=8)
|
|
268
|
-
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
269
|
-
>>> outputs = module(inputs)
|
|
270
|
-
>>> outputs.shape
|
|
271
|
-
torch.Size([2, 16, 1, 64])
|
|
272
|
-
|
|
273
233
|
References
|
|
274
234
|
----------
|
|
275
235
|
.. [Zhang2018] Zhang, H. et al. 2018.
|
|
@@ -327,16 +287,6 @@ class ECA(nn.Module):
|
|
|
327
287
|
kernel size of convolutional layer, determines degree of channel
|
|
328
288
|
interaction, must be odd.
|
|
329
289
|
|
|
330
|
-
Examples
|
|
331
|
-
--------
|
|
332
|
-
>>> import torch
|
|
333
|
-
>>> from braindecode.modules import ECA
|
|
334
|
-
>>> module = ECA(in_channels=16, kernel_size=3)
|
|
335
|
-
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
336
|
-
>>> outputs = module(inputs)
|
|
337
|
-
>>> outputs.shape
|
|
338
|
-
torch.Size([2, 16, 1, 64])
|
|
339
|
-
|
|
340
290
|
References
|
|
341
291
|
----------
|
|
342
292
|
.. [Wang2021] Wang, Q. et al., 2021. ECA-Net: Efficient Channel Attention
|
|
@@ -388,16 +338,6 @@ class GatherExcite(nn.Module):
|
|
|
388
338
|
reduction_rate : int, default=4
|
|
389
339
|
reduction ratio of the excite block (if used)
|
|
390
340
|
|
|
391
|
-
Examples
|
|
392
|
-
--------
|
|
393
|
-
>>> import torch
|
|
394
|
-
>>> from braindecode.modules import GatherExcite
|
|
395
|
-
>>> module = GatherExcite(in_channels=16, seq_len=64, extra_params=False, use_mlp=True)
|
|
396
|
-
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
397
|
-
>>> outputs = module(inputs)
|
|
398
|
-
>>> outputs.shape
|
|
399
|
-
torch.Size([2, 16, 1, 64])
|
|
400
|
-
|
|
401
341
|
References
|
|
402
342
|
----------
|
|
403
343
|
.. [Hu2018b] Hu, J., Albanie, S., Sun, G., Vedaldi, A., 2018.
|
|
@@ -467,16 +407,6 @@ class GCT(nn.Module):
|
|
|
467
407
|
in_channels : int
|
|
468
408
|
number of input feature channels
|
|
469
409
|
|
|
470
|
-
Examples
|
|
471
|
-
--------
|
|
472
|
-
>>> import torch
|
|
473
|
-
>>> from braindecode.modules import GCT
|
|
474
|
-
>>> module = GCT(in_channels=16)
|
|
475
|
-
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
476
|
-
>>> outputs = module(inputs)
|
|
477
|
-
>>> outputs.shape
|
|
478
|
-
torch.Size([2, 16, 1, 64])
|
|
479
|
-
|
|
480
410
|
References
|
|
481
411
|
----------
|
|
482
412
|
.. [Yang2020] Yang, Z. Linchao, Z., Wu, Y., Yang, Y., 2020.
|
|
@@ -522,16 +452,6 @@ class SRM(nn.Module):
|
|
|
522
452
|
reduction_rate : int, default=4
|
|
523
453
|
reduction ratio of the fully-connected layers (if used),
|
|
524
454
|
|
|
525
|
-
Examples
|
|
526
|
-
--------
|
|
527
|
-
>>> import torch
|
|
528
|
-
>>> from braindecode.modules import SRM
|
|
529
|
-
>>> module = SRM(in_channels=16, use_mlp=False)
|
|
530
|
-
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
531
|
-
>>> outputs = module(inputs)
|
|
532
|
-
>>> outputs.shape
|
|
533
|
-
torch.Size([2, 16, 1, 64])
|
|
534
|
-
|
|
535
455
|
References
|
|
536
456
|
----------
|
|
537
457
|
.. [Lee2019] Lee, H., Kim, H., Nam, H., 2019. SRM: A Style-based
|
|
@@ -597,16 +517,6 @@ class CBAM(nn.Module):
|
|
|
597
517
|
kernel_size : int
|
|
598
518
|
kernel size of the convolutional layer
|
|
599
519
|
|
|
600
|
-
Examples
|
|
601
|
-
--------
|
|
602
|
-
>>> import torch
|
|
603
|
-
>>> from braindecode.modules import CBAM
|
|
604
|
-
>>> module = CBAM(in_channels=16, reduction_rate=4, kernel_size=3)
|
|
605
|
-
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
606
|
-
>>> outputs = module(inputs)
|
|
607
|
-
>>> outputs.shape
|
|
608
|
-
torch.Size([2, 16, 1, 64])
|
|
609
|
-
|
|
610
520
|
References
|
|
611
521
|
----------
|
|
612
522
|
.. [Woo2018] Woo, S., Park, J., Lee, J., Kweon, I., 2018.
|
|
@@ -665,16 +575,6 @@ class CAT(nn.Module):
|
|
|
665
575
|
bias : bool, default=False
|
|
666
576
|
if True, adds a learnable bias will be used in the convolution,
|
|
667
577
|
|
|
668
|
-
Examples
|
|
669
|
-
--------
|
|
670
|
-
>>> import torch
|
|
671
|
-
>>> from braindecode.modules import CAT
|
|
672
|
-
>>> module = CAT(in_channels=16, reduction_rate=4, kernel_size=3)
|
|
673
|
-
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
674
|
-
>>> outputs = module(inputs)
|
|
675
|
-
>>> outputs.shape
|
|
676
|
-
torch.Size([2, 16, 1, 64])
|
|
677
|
-
|
|
678
578
|
References
|
|
679
579
|
----------
|
|
680
580
|
.. [Wu2023] Wu, Z. et al., 2023
|
|
@@ -771,16 +671,6 @@ class CATLite(nn.Module):
|
|
|
771
671
|
bias : bool, default=True
|
|
772
672
|
if True, adds a learnable bias will be used in the convolution,
|
|
773
673
|
|
|
774
|
-
Examples
|
|
775
|
-
--------
|
|
776
|
-
>>> import torch
|
|
777
|
-
>>> from braindecode.modules import CATLite
|
|
778
|
-
>>> module = CATLite(in_channels=16, reduction_rate=4)
|
|
779
|
-
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
780
|
-
>>> outputs = module(inputs)
|
|
781
|
-
>>> outputs.shape
|
|
782
|
-
torch.Size([2, 16, 1, 64])
|
|
783
|
-
|
|
784
674
|
References
|
|
785
675
|
----------
|
|
786
676
|
.. [Wu2023] Wu, Z. et al., 2023 CAT: Learning to Collaborate Channel and
|
|
@@ -834,19 +724,6 @@ class CATLite(nn.Module):
|
|
|
834
724
|
|
|
835
725
|
|
|
836
726
|
class MultiHeadAttention(nn.Module):
|
|
837
|
-
"""Multi-head self-attention block.
|
|
838
|
-
|
|
839
|
-
Examples
|
|
840
|
-
--------
|
|
841
|
-
>>> import torch
|
|
842
|
-
>>> from braindecode.modules import MultiHeadAttention
|
|
843
|
-
>>> module = MultiHeadAttention(emb_size=32, num_heads=4, dropout=0.1)
|
|
844
|
-
>>> inputs = torch.randn(2, 10, 32)
|
|
845
|
-
>>> outputs = module(inputs)
|
|
846
|
-
>>> outputs.shape
|
|
847
|
-
torch.Size([2, 10, 32])
|
|
848
|
-
"""
|
|
849
|
-
|
|
850
727
|
def __init__(self, emb_size, num_heads, dropout):
|
|
851
728
|
super().__init__()
|
|
852
729
|
self.emb_size = emb_size
|
braindecode/modules/blocks.py
CHANGED
|
@@ -14,22 +14,6 @@ class InceptionBlock(nn.Module):
|
|
|
14
14
|
----------
|
|
15
15
|
branches : list of nn.Module
|
|
16
16
|
List of convolutional branches to apply to the input.
|
|
17
|
-
|
|
18
|
-
Examples
|
|
19
|
-
--------
|
|
20
|
-
>>> import torch
|
|
21
|
-
>>> from torch import nn
|
|
22
|
-
>>> from braindecode.modules import InceptionBlock
|
|
23
|
-
>>> block = InceptionBlock(
|
|
24
|
-
... [
|
|
25
|
-
... nn.Conv1d(3, 4, kernel_size=1),
|
|
26
|
-
... nn.Conv1d(3, 4, kernel_size=3, padding=1),
|
|
27
|
-
... ]
|
|
28
|
-
... )
|
|
29
|
-
>>> inputs = torch.randn(2, 3, 100)
|
|
30
|
-
>>> outputs = block(inputs)
|
|
31
|
-
>>> outputs.shape
|
|
32
|
-
torch.Size([2, 8, 100])
|
|
33
17
|
"""
|
|
34
18
|
|
|
35
19
|
def __init__(self, branches):
|
|
@@ -71,16 +55,6 @@ class MLP(nn.Sequential):
|
|
|
71
55
|
Dropout rate.
|
|
72
56
|
normalize: bool (default=False)
|
|
73
57
|
Whether to apply layer normalization.
|
|
74
|
-
|
|
75
|
-
Examples
|
|
76
|
-
--------
|
|
77
|
-
>>> import torch
|
|
78
|
-
>>> from braindecode.modules import MLP
|
|
79
|
-
>>> module = MLP(in_features=32, hidden_features=(64,), out_features=16)
|
|
80
|
-
>>> inputs = torch.randn(2, 10, 32)
|
|
81
|
-
>>> outputs = module(inputs)
|
|
82
|
-
>>> outputs.shape
|
|
83
|
-
torch.Size([2, 10, 16])
|
|
84
58
|
"""
|
|
85
59
|
|
|
86
60
|
def __init__(
|
|
@@ -125,33 +99,7 @@ class MLP(nn.Sequential):
|
|
|
125
99
|
|
|
126
100
|
|
|
127
101
|
class FeedForwardBlock(nn.Sequential):
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
Parameters
|
|
131
|
-
----------
|
|
132
|
-
emb_size : int
|
|
133
|
-
Embedding dimension.
|
|
134
|
-
expansion : int
|
|
135
|
-
Expansion factor for the hidden layer size.
|
|
136
|
-
drop_p : float
|
|
137
|
-
Dropout probability.
|
|
138
|
-
activation : type[nn.Module], default=nn.GELU
|
|
139
|
-
Activation function constructor.
|
|
140
|
-
|
|
141
|
-
Examples
|
|
142
|
-
--------
|
|
143
|
-
>>> import torch
|
|
144
|
-
>>> from braindecode.modules import FeedForwardBlock
|
|
145
|
-
>>> module = FeedForwardBlock(emb_size=32, expansion=2, drop_p=0.1)
|
|
146
|
-
>>> inputs = torch.randn(2, 10, 32)
|
|
147
|
-
>>> outputs = module(inputs)
|
|
148
|
-
>>> outputs.shape
|
|
149
|
-
torch.Size([2, 10, 32])
|
|
150
|
-
"""
|
|
151
|
-
|
|
152
|
-
def __init__(
|
|
153
|
-
self, emb_size, expansion, drop_p, activation: type[nn.Module] = nn.GELU
|
|
154
|
-
):
|
|
102
|
+
def __init__(self, emb_size, expansion, drop_p, activation: nn.Module = nn.GELU):
|
|
155
103
|
super().__init__(
|
|
156
104
|
nn.Linear(emb_size, expansion * emb_size),
|
|
157
105
|
activation(),
|
|
@@ -25,16 +25,6 @@ class AvgPool2dWithConv(nn.Module):
|
|
|
25
25
|
Dilation applied to the pooling filter.
|
|
26
26
|
padding: int or (int,int)
|
|
27
27
|
Padding applied before the pooling operation.
|
|
28
|
-
|
|
29
|
-
Examples
|
|
30
|
-
--------
|
|
31
|
-
>>> import torch
|
|
32
|
-
>>> from braindecode.modules import AvgPool2dWithConv
|
|
33
|
-
>>> module = AvgPool2dWithConv(kernel_size=(1, 4), stride=(1, 4))
|
|
34
|
-
>>> inputs = torch.randn(2, 4, 1, 16)
|
|
35
|
-
>>> outputs = module(inputs)
|
|
36
|
-
>>> outputs.shape
|
|
37
|
-
torch.Size([2, 4, 1, 4])
|
|
38
28
|
"""
|
|
39
29
|
|
|
40
30
|
def __init__(self, kernel_size, stride, dilation=1, padding=0):
|
|
@@ -83,19 +73,6 @@ class AvgPool2dWithConv(nn.Module):
|
|
|
83
73
|
|
|
84
74
|
|
|
85
75
|
class Conv2dWithConstraint(nn.Conv2d):
|
|
86
|
-
"""2D convolution with max-norm constraint on the weights.
|
|
87
|
-
|
|
88
|
-
Examples
|
|
89
|
-
--------
|
|
90
|
-
>>> import torch
|
|
91
|
-
>>> from braindecode.modules import Conv2dWithConstraint
|
|
92
|
-
>>> module = Conv2dWithConstraint(4, 8, kernel_size=(1, 3), padding=(0, 1), bias=False)
|
|
93
|
-
>>> inputs = torch.randn(2, 4, 1, 64)
|
|
94
|
-
>>> outputs = module(inputs)
|
|
95
|
-
>>> outputs.shape
|
|
96
|
-
torch.Size([2, 8, 1, 64])
|
|
97
|
-
"""
|
|
98
|
-
|
|
99
76
|
def __init__(self, *args, max_norm=1, **kwargs):
|
|
100
77
|
super().__init__(*args, **kwargs)
|
|
101
78
|
self.max_norm = max_norm
|
|
@@ -124,16 +101,6 @@ class CombinedConv(nn.Module):
|
|
|
124
101
|
bias_spat: bool
|
|
125
102
|
Whether to use bias in the spatial conv
|
|
126
103
|
|
|
127
|
-
Examples
|
|
128
|
-
--------
|
|
129
|
-
>>> import torch
|
|
130
|
-
>>> from braindecode.modules import CombinedConv
|
|
131
|
-
>>> module = CombinedConv(in_chans=8, n_filters_time=4, n_filters_spat=4, filter_time_length=5)
|
|
132
|
-
>>> inputs = torch.randn(2, 1, 100, 8)
|
|
133
|
-
>>> outputs = module(inputs)
|
|
134
|
-
>>> outputs.shape
|
|
135
|
-
torch.Size([2, 4, 96, 1])
|
|
136
|
-
|
|
137
104
|
"""
|
|
138
105
|
|
|
139
106
|
def __init__(
|
|
@@ -215,16 +182,6 @@ class CausalConv1d(nn.Conv1d):
|
|
|
215
182
|
----------
|
|
216
183
|
.. [1] https://discuss.pytorch.org/t/causal-convolution/3456/4
|
|
217
184
|
.. [2] https://gist.github.com/paultsw/7a9d6e3ce7b70e9e2c61bc9287addefc
|
|
218
|
-
|
|
219
|
-
Examples
|
|
220
|
-
--------
|
|
221
|
-
>>> import torch
|
|
222
|
-
>>> from braindecode.modules import CausalConv1d
|
|
223
|
-
>>> module = CausalConv1d(in_channels=4, out_channels=8, kernel_size=5, dilation=2)
|
|
224
|
-
>>> inputs = torch.randn(2, 4, 128)
|
|
225
|
-
>>> outputs = module(inputs)
|
|
226
|
-
>>> outputs.shape
|
|
227
|
-
torch.Size([2, 8, 128])
|
|
228
185
|
"""
|
|
229
186
|
|
|
230
187
|
def __init__(
|
|
@@ -293,16 +250,6 @@ class DepthwiseConv2d(torch.nn.Conv2d):
|
|
|
293
250
|
Padding mode to use. Options are 'zeros', 'reflect', 'replicate', or
|
|
294
251
|
'circular'.
|
|
295
252
|
Default is 'zeros'.
|
|
296
|
-
|
|
297
|
-
Examples
|
|
298
|
-
--------
|
|
299
|
-
>>> import torch
|
|
300
|
-
>>> from braindecode.modules import DepthwiseConv2d
|
|
301
|
-
>>> module = DepthwiseConv2d(in_channels=4, depth_multiplier=2, kernel_size=3, padding=1)
|
|
302
|
-
>>> inputs = torch.randn(2, 4, 1, 64)
|
|
303
|
-
>>> outputs = module(inputs)
|
|
304
|
-
>>> outputs.shape
|
|
305
|
-
torch.Size([2, 8, 1, 64])
|
|
306
253
|
"""
|
|
307
254
|
|
|
308
255
|
def __init__(
|
braindecode/modules/filter.py
CHANGED
|
@@ -113,22 +113,6 @@ class FilterBankLayer(nn.Module):
|
|
|
113
113
|
Control verbosity of the logging output. If ``None``, use the default
|
|
114
114
|
verbosity level. See the func:`mne.verbose` for details.
|
|
115
115
|
Should only be passed as a keyword argument.
|
|
116
|
-
|
|
117
|
-
Examples
|
|
118
|
-
--------
|
|
119
|
-
>>> import torch
|
|
120
|
-
>>> from braindecode.modules import FilterBankLayer
|
|
121
|
-
>>> module = FilterBankLayer(
|
|
122
|
-
... n_chans=2,
|
|
123
|
-
... sfreq=128,
|
|
124
|
-
... band_filters=[(4.0, 8.0), (8.0, 12.0)],
|
|
125
|
-
... method="fir",
|
|
126
|
-
... verbose=False,
|
|
127
|
-
... )
|
|
128
|
-
>>> inputs = torch.randn(2, 2, 256)
|
|
129
|
-
>>> outputs = module(inputs)
|
|
130
|
-
>>> outputs.shape
|
|
131
|
-
torch.Size([2, 2, 2, 256])
|
|
132
116
|
"""
|
|
133
117
|
|
|
134
118
|
def __init__(
|
|
@@ -417,21 +401,6 @@ class GeneralizedGaussianFilter(nn.Module):
|
|
|
417
401
|
Minimum and maximum allowable values for the center frequency `f_mean` in Hz.
|
|
418
402
|
Specified as (min_f_mean, max_f_mean). Default is (1.0, 45.0).
|
|
419
403
|
|
|
420
|
-
Examples
|
|
421
|
-
--------
|
|
422
|
-
>>> import torch
|
|
423
|
-
>>> from braindecode.modules import GeneralizedGaussianFilter
|
|
424
|
-
>>> module = GeneralizedGaussianFilter(
|
|
425
|
-
... in_channels=2,
|
|
426
|
-
... out_channels=2,
|
|
427
|
-
... sequence_length=256,
|
|
428
|
-
... sample_rate=128,
|
|
429
|
-
... inverse_fourier=True,
|
|
430
|
-
... )
|
|
431
|
-
>>> inputs = torch.randn(3, 2, 256)
|
|
432
|
-
>>> outputs = module(inputs)
|
|
433
|
-
>>> outputs.shape
|
|
434
|
-
torch.Size([3, 2, 256])
|
|
435
404
|
|
|
436
405
|
Notes
|
|
437
406
|
-----
|
braindecode/modules/layers.py
CHANGED
|
@@ -11,21 +11,6 @@ from braindecode.functional import drop_path
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class Ensure4d(nn.Module):
|
|
14
|
-
"""Ensure the input tensor has 4 dimensions.
|
|
15
|
-
|
|
16
|
-
This is a small utility layer that repeatedly adds a singleton dimension at
|
|
17
|
-
the end until the input has shape ``(batch, channels, time, 1)``.
|
|
18
|
-
|
|
19
|
-
Examples
|
|
20
|
-
--------
|
|
21
|
-
>>> import torch
|
|
22
|
-
>>> from braindecode.modules import Ensure4d
|
|
23
|
-
>>> module = Ensure4d()
|
|
24
|
-
>>> outputs = module(torch.randn(2, 3, 10))
|
|
25
|
-
>>> outputs.shape
|
|
26
|
-
torch.Size([2, 3, 10, 1])
|
|
27
|
-
"""
|
|
28
|
-
|
|
29
14
|
def forward(self, x):
|
|
30
15
|
while len(x.shape) < 4:
|
|
31
16
|
x = x.unsqueeze(-1)
|
|
@@ -33,19 +18,6 @@ class Ensure4d(nn.Module):
|
|
|
33
18
|
|
|
34
19
|
|
|
35
20
|
class Chomp1d(nn.Module):
|
|
36
|
-
"""Remove samples from the end of a sequence.
|
|
37
|
-
|
|
38
|
-
Examples
|
|
39
|
-
--------
|
|
40
|
-
>>> import torch
|
|
41
|
-
>>> from braindecode.modules import Chomp1d
|
|
42
|
-
>>> module = Chomp1d(chomp_size=5)
|
|
43
|
-
>>> inputs = torch.randn(2, 3, 20)
|
|
44
|
-
>>> outputs = module(inputs)
|
|
45
|
-
>>> outputs.shape
|
|
46
|
-
torch.Size([2, 3, 15])
|
|
47
|
-
"""
|
|
48
|
-
|
|
49
21
|
def __init__(self, chomp_size):
|
|
50
22
|
super().__init__()
|
|
51
23
|
self.chomp_size = chomp_size
|
|
@@ -71,17 +43,6 @@ class TimeDistributed(nn.Module):
|
|
|
71
43
|
module : nn.Module
|
|
72
44
|
Module to be applied to the input windows. Must accept an input of
|
|
73
45
|
shape (batch_size, n_channels, n_times).
|
|
74
|
-
|
|
75
|
-
Examples
|
|
76
|
-
--------
|
|
77
|
-
>>> import torch
|
|
78
|
-
>>> from torch import nn
|
|
79
|
-
>>> from braindecode.modules import TimeDistributed
|
|
80
|
-
>>> module = TimeDistributed(nn.Conv1d(3, 4, kernel_size=3, padding=1))
|
|
81
|
-
>>> inputs = torch.randn(2, 5, 3, 20)
|
|
82
|
-
>>> outputs = module(inputs)
|
|
83
|
-
>>> outputs.shape
|
|
84
|
-
torch.Size([2, 5, 4])
|
|
85
46
|
"""
|
|
86
47
|
|
|
87
48
|
def __init__(self, module):
|
|
@@ -130,17 +91,6 @@ class DropPath(nn.Module):
|
|
|
130
91
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
131
92
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
132
93
|
SOFTWARE.
|
|
133
|
-
|
|
134
|
-
Examples
|
|
135
|
-
--------
|
|
136
|
-
>>> import torch
|
|
137
|
-
>>> from braindecode.modules import DropPath
|
|
138
|
-
>>> module = DropPath(drop_prob=0.2)
|
|
139
|
-
>>> module.train()
|
|
140
|
-
>>> inputs = torch.randn(2, 3, 10)
|
|
141
|
-
>>> outputs = module(inputs)
|
|
142
|
-
>>> outputs.shape
|
|
143
|
-
torch.Size([2, 3, 10])
|
|
144
94
|
"""
|
|
145
95
|
|
|
146
96
|
def __init__(self, drop_prob=None):
|
|
@@ -180,37 +130,3 @@ class SqueezeFinalOutput(nn.Module):
|
|
|
180
130
|
if x.shape[-1] == 1:
|
|
181
131
|
x = x.squeeze(-1)
|
|
182
132
|
return x
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
class SubjectLayers(nn.Module):
|
|
186
|
-
"""Per-subject linear transformation layer.
|
|
187
|
-
|
|
188
|
-
Applies subject-specific linear transformations to the input. Each subject
|
|
189
|
-
owns an independent weight matrix, enabling personalized feature
|
|
190
|
-
processing.
|
|
191
|
-
"""
|
|
192
|
-
|
|
193
|
-
def __init__(
|
|
194
|
-
self,
|
|
195
|
-
in_channels: int,
|
|
196
|
-
out_channels: int,
|
|
197
|
-
n_subjects: int,
|
|
198
|
-
init_id: bool = False,
|
|
199
|
-
):
|
|
200
|
-
super().__init__()
|
|
201
|
-
self.weights = nn.Parameter(torch.randn(n_subjects, in_channels, out_channels))
|
|
202
|
-
if init_id:
|
|
203
|
-
if in_channels != out_channels:
|
|
204
|
-
raise AssertionError("init_id requires in_channels == out_channels")
|
|
205
|
-
self.weights.data[:] = torch.eye(in_channels)[None]
|
|
206
|
-
self.weights.data *= 1 / (in_channels**0.5)
|
|
207
|
-
|
|
208
|
-
def forward(self, x: torch.Tensor, subjects: torch.Tensor) -> torch.Tensor:
|
|
209
|
-
"""Apply the subject-specific linear transforms."""
|
|
210
|
-
_, C, D = self.weights.shape
|
|
211
|
-
weights = self.weights.gather(0, subjects.view(-1, 1, 1).expand(-1, C, D))
|
|
212
|
-
return torch.einsum("bct,bcd->bdt", x, weights)
|
|
213
|
-
|
|
214
|
-
def __repr__(self) -> str:
|
|
215
|
-
S, C, D = self.weights.shape
|
|
216
|
-
return f"SubjectLayers({C}, {D}, {S})"
|
braindecode/modules/linear.py
CHANGED
|
@@ -20,16 +20,6 @@ class MaxNormLinear(nn.Linear):
|
|
|
20
20
|
If set to ``False``, the layer will not learn an additive bias.
|
|
21
21
|
Default: ``True``.
|
|
22
22
|
|
|
23
|
-
Examples
|
|
24
|
-
--------
|
|
25
|
-
>>> import torch
|
|
26
|
-
>>> from braindecode.modules import MaxNormLinear
|
|
27
|
-
>>> module = MaxNormLinear(10, 5, max_norm_val=2)
|
|
28
|
-
>>> inputs = torch.randn(2, 10)
|
|
29
|
-
>>> outputs = module(inputs)
|
|
30
|
-
>>> outputs.shape
|
|
31
|
-
torch.Size([2, 5])
|
|
32
|
-
|
|
33
23
|
References
|
|
34
24
|
----------
|
|
35
25
|
.. [1] https://keras.io/api/layers/core_layers/dense/#dense-class
|
|
@@ -51,18 +41,7 @@ class MaxNormLinear(nn.Linear):
|
|
|
51
41
|
|
|
52
42
|
|
|
53
43
|
class LinearWithConstraint(nn.Linear):
|
|
54
|
-
"""Linear layer with max-norm constraint on the weights.
|
|
55
|
-
|
|
56
|
-
Examples
|
|
57
|
-
--------
|
|
58
|
-
>>> import torch
|
|
59
|
-
>>> from braindecode.modules import LinearWithConstraint
|
|
60
|
-
>>> module = LinearWithConstraint(10, 5, max_norm=1.0)
|
|
61
|
-
>>> inputs = torch.randn(2, 10)
|
|
62
|
-
>>> outputs = module(inputs)
|
|
63
|
-
>>> outputs.shape
|
|
64
|
-
torch.Size([2, 5])
|
|
65
|
-
"""
|
|
44
|
+
"""Linear layer with max-norm constraint on the weights."""
|
|
66
45
|
|
|
67
46
|
def __init__(self, *args, max_norm=1.0, **kwargs):
|
|
68
47
|
super(LinearWithConstraint, self).__init__(*args, **kwargs)
|
braindecode/modules/stats.py
CHANGED
|
@@ -22,16 +22,6 @@ class StatLayer(nn.Module):
|
|
|
22
22
|
Used only for functions requiring clamping (e.g., log variance).
|
|
23
23
|
apply_log : bool, default=False
|
|
24
24
|
Whether to apply log after computation (used for LogVarLayer).
|
|
25
|
-
|
|
26
|
-
Examples
|
|
27
|
-
--------
|
|
28
|
-
>>> import torch
|
|
29
|
-
>>> from braindecode.modules import StatLayer
|
|
30
|
-
>>> module = StatLayer(stat_fn=torch.mean, dim=-1, keepdim=True)
|
|
31
|
-
>>> inputs = torch.randn(2, 3, 10)
|
|
32
|
-
>>> outputs = module(inputs)
|
|
33
|
-
>>> outputs.shape
|
|
34
|
-
torch.Size([2, 3, 1])
|
|
35
25
|
"""
|
|
36
26
|
|
|
37
27
|
def __init__(
|
braindecode/modules/util.py
CHANGED
|
@@ -71,15 +71,6 @@ def aggregate_probas(logits, n_windows_stride=1):
|
|
|
71
71
|
De Vos, M. (2018). Joint classification and prediction CNN framework
|
|
72
72
|
for automatic sleep stage classification. IEEE Transactions on
|
|
73
73
|
Biomedical Engineering, 66(5), 1285-1296.
|
|
74
|
-
|
|
75
|
-
Examples
|
|
76
|
-
--------
|
|
77
|
-
>>> import numpy as np
|
|
78
|
-
>>> from braindecode.modules import aggregate_probas
|
|
79
|
-
>>> logits = np.random.randn(3, 4, 5) # (n_sequences, n_classes, n_windows)
|
|
80
|
-
>>> probas = aggregate_probas(logits, n_windows_stride=1)
|
|
81
|
-
>>> probas.shape
|
|
82
|
-
(7, 4)
|
|
83
74
|
"""
|
|
84
75
|
log_probas = log_softmax(logits, axis=1)
|
|
85
76
|
return _pad_shift_array(log_probas, stride=n_windows_stride).sum(axis=0).T
|
braindecode/modules/wrapper.py
CHANGED
|
@@ -10,16 +10,6 @@ class Expression(nn.Module):
|
|
|
10
10
|
expression_fn : callable
|
|
11
11
|
Should accept variable number of objects of type
|
|
12
12
|
`torch.autograd.Variable` to compute its output.
|
|
13
|
-
|
|
14
|
-
Examples
|
|
15
|
-
--------
|
|
16
|
-
>>> import torch
|
|
17
|
-
>>> from braindecode.modules import Expression
|
|
18
|
-
>>> module = Expression(lambda x: x**2)
|
|
19
|
-
>>> inputs = torch.randn(2, 3)
|
|
20
|
-
>>> outputs = module(inputs)
|
|
21
|
-
>>> outputs.shape
|
|
22
|
-
torch.Size([2, 3])
|
|
23
13
|
"""
|
|
24
14
|
|
|
25
15
|
def __init__(self, expression_fn):
|
|
@@ -59,13 +49,6 @@ class IntermediateOutputWrapper(nn.Module):
|
|
|
59
49
|
>>> model = Deep4Net()
|
|
60
50
|
>>> select_modules = ['conv_spat','conv_2','conv_3','conv_4'] # Specify intermediate outputs
|
|
61
51
|
>>> model_pert = IntermediateOutputWrapper(select_modules,model) # Wrap model
|
|
62
|
-
|
|
63
|
-
>>> import torch
|
|
64
|
-
>>> base = torch.nn.Sequential(torch.nn.Linear(10, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
|
|
65
|
-
>>> wrapped = IntermediateOutputWrapper(to_select=["0", "2"], model=base)
|
|
66
|
-
>>> outputs = wrapped(torch.randn(4, 10))
|
|
67
|
-
>>> len(outputs)
|
|
68
|
-
2
|
|
69
52
|
"""
|
|
70
53
|
|
|
71
54
|
def __init__(self, to_select, model):
|
|
@@ -324,9 +324,6 @@ def _replace_inplace(concat_ds, new_concat_ds):
|
|
|
324
324
|
concat_ds, preproc_kwargs_attr, getattr(new_concat_ds, preproc_kwargs_attr)
|
|
325
325
|
)
|
|
326
326
|
|
|
327
|
-
# Recompute cumulative_sizes after replacing datasets
|
|
328
|
-
concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)
|
|
329
|
-
|
|
330
327
|
|
|
331
328
|
def _preprocess(
|
|
332
329
|
ds: RecordDataset,
|