braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/summary.csv CHANGED
@@ -1,13 +1,13 @@
 Model,Application,Type,Sampling Frequency (Hz),Hyperparameters,#Parameters,get_#Parameters,Categorization
-ATCNet,General,Classification,250,"n_chans, n_outputs, n_times",113732,"ATCNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Attention/Transformer"
-AttentionBaseNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",3692,"AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Attention/Transformer"
+ATCNet,General,Classification,250,"n_chans, n_outputs, n_times",113732,"ATCNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Small Attention"
+AttentionBaseNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",3692,"AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
 BDTCN,Normal Abnormal,Classification,100,"n_chans, n_outputs, n_times",456502,"BDTCN(n_chans=21, n_outputs=2, n_times=6000, n_blocks=5, n_filters=55, kernel_size=16)","Convolution,Recurrent"
-BIOT,"Sleep Staging, Epilepsy",Classification,200,"n_chans, n_outputs",3183879,"BIOT(n_chans=2, n_outputs=5, n_times=6000)","Foundation Model"
+BIOT,"Sleep Staging, Epilepsy",Classification,200,"n_chans, n_outputs",3183879,"BIOT(n_chans=2, n_outputs=5, n_times=6000)","Large Brain Model"
 ContraWR,Sleep Staging,"Classification, Embedding",125,"n_chans, n_outputs, sfreq",1160165,"ContraWR(n_chans=2, n_outputs=5, n_times=3750, emb_size=256, sfreq=125)",Convolution
-CTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",26900,"CTNet(n_chans=22, n_outputs=4, n_times=1000, n_filters_time=8, kernel_size=16, num_heads=2, embed_dim=16)","Convolution,Attention/Transformer"
+CTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",26900,"CTNet(n_chans=22, n_outputs=4, n_times=1000, n_filters_time=8, kernel_size=16, heads=2, emb_size=16)","Convolution,Small Attention"
 Deep4Net,General,Classification,250,"n_chans, n_outputs, n_times",282879,"Deep4Net(n_chans=22, n_outputs=4, n_times=1000)","Convolution"
 DeepSleepNet,Sleep Staging,Classification,256,"n_chans, n_outputs",24744837,"DeepSleepNet(n_chans=1, n_outputs=5, n_times=7680, sfreq=256)","Convolution,Recurrent"
-EEGConformer,General,Classification,250,"n_chans, n_outputs, n_times",789572,"EEGConformer(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Attention/Transformer"
+EEGConformer,General,Classification,250,"n_chans, n_outputs, n_times",789572,"EEGConformer(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
 EEGInceptionERP,"ERP, SSVEP",Classification,128,"n_chans, n_outputs",14926,"EEGInceptionERP(n_chans=8, n_outputs=2, n_times=128, sfreq=128)","Convolution"
 EEGInceptionMI,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",558028,"EEGInceptionMI(n_chans=22, n_outputs=4, n_times=1000, n_convs=5, n_filters=12)","Convolution"
 EEGITNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",5212,"EEGITNet(n_chans=22, n_outputs=4, n_times=500)","Convolution,Recurrent"
@@ -17,18 +17,18 @@ EEGSym,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",299
 EEGMiner,Emotion Recognition,Classification,128,"n_chans, n_outputs, n_times, sfreq",7572,"EEGMiner(n_chans=62, n_outputs=2, n_times=2560, sfreq=128)","Convolution,Interpretability"
 EEGSimpleConv,Motor Imagery,Classification,80,"n_chans, n_outputs, sfreq",730404,"EEGSimpleConv(n_chans=22, n_outputs=4, n_times=320, sfreq=80)","Convolution"
 EEGTCNet,Motor Imagery,Classification,250,"n_chans, n_outputs",4516,"EEGTCNet(n_chans=22, n_outputs=4, n_times=1000, kern_length=32)","Convolution,Recurrent"
-Labram,General,"Classification, Embedding",200,"n_chans, n_outputs, n_times",5866180,"Labram(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Foundation Model"
-MSVTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",75494," MSVTNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Attention/Transformer"
+Labram,General,"Classification, Embedding",200,"n_chans, n_outputs, n_times",5866180,"Labram(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Large Brain Model"
+MSVTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",75494," MSVTNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Small Attention"
 SCCNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times, sfreq",12070,"SCCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=125)","Convolution"
-SignalJEPA,"Motor Imagery, ERP, SSVEP",Embedding,128,"n_times, chs_info",3456882,"SignalJEPA(n_times=512, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Foundation Model"
-SignalJEPA_Contextual,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",3459184,"SignalJEPA_Contextual(n_outputs=2, input_window_seconds=4.19, sfreq=128, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Foundation Model"
-SignalJEPA_PostLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_chans, n_outputs, n_times",16142,"SignalJEPA_PostLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Foundation Model"
-SignalJEPA_PreLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",16142,"SignalJEPA_PreLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Foundation Model"
+SignalJEPA,"Motor Imagery, ERP, SSVEP",Embedding,128,"n_times, chs_info",3456882,"SignalJEPA(n_times=512, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Large Brain Model"
+SignalJEPA_Contextual,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",3459184,"SignalJEPA_Contextual(n_outputs=2, input_window_seconds=4.19, sfreq=128, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Large Brain Model"
+SignalJEPA_PostLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_chans, n_outputs, n_times",16142,"SignalJEPA_PostLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Large Brain Model"
+SignalJEPA_PreLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",16142,"SignalJEPA_PreLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Large Brain Model"
 SincShallowNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",21892,"SincShallowNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Interpretability"
 ShallowFBCSPNet,General,Classification,250,"n_chans, n_outputs, n_times",46084,"ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution"
 SleepStagerBlanco2020,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times",2845,"SleepStagerBlanco2020(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution"
 SleepStagerChambon2018,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",5835,"SleepStagerChambon2018(n_chans=2, n_outputs=5, n_times=3840, sfreq=128)","Convolution"
-AttnSleep,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times, sfreq",719925,"AttnSleep(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution, Attention/Transformer"
+AttnSleep,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times, sfreq",719925,"AttnSleep(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution, Small Attention"
 SPARCNet,Epilepsy,Classification,200,"n_chans, n_outputs, n_times",1141921,"SPARCNet(n_chans=16, n_outputs=6, n_times=2000, sfreq=200)","Convolution"
 SyncNet,"Emotion Recognition, Alcoholism",Classification,256,"n_chans, n_outputs, n_times",554,"SyncNet(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)","Interpretability"
 TSception,Emotion Recognition,Classification,256,"n_chans, n_outputs, n_times, sfreq",2187206,"TSception(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)","Convolution"
@@ -38,10 +38,8 @@ FBCNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",118
 FBMSNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",16231,"FBMSNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
 FBLightConvNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",6596,"FBLightConvNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
 IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860,"IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
-BrainModule,Speech Decoding,Classification,250,"n_chans, n_outputs, n_times, sfreq",6186909,"BrainModule(n_chans=64, n_outputs=29, n_times=160, sfreq=1000)","Convolution"
-PBT,General,Classification,250,"n_chans, n_outputs, n_times",818948,"PBT(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Foundation Model"
-SSTDPN,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",19502,"SSTDPN(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Attention/Transformer"
-BENDR,General,"Classification,Embedding",250,"n_chans, n_times, n_outputs",157141049,"BENDR(n_chans=22, n_outputs=4, n_times=1000)","Foundation Model,Convolution"
-LUNA,General,"Classification,Embedding",128,"n_chans, n_times, sfreq, chs_info",7100731,"LUNA(n_chans=22, n_times=512, sfreq=128)","Convolution,Channel,Foundation Model"
-MEDFormer,General,Classification,250,"n_chans, n_outputs, n_times",5313924,"MEDFormer(n_chans=22, n_outputs=4, n_times=1000)","Foundation Model,Convolution"
-REVE,General,Classification,200,"n_outputs, n_times, n_chans",69481476,"REVE(n_times=1000, n_outputs=4, n_chans=19)","Foundation Model,Attention/Transformer"
+PBT,General,Classification,250,"n_chans, n_outputs, n_times",818948,"PBT(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Large Brain Model"
+SSTDPN,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",19502,"SSTDPN(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
+BENDR,General,"Classification,Embedding",250,"n_chans, n_times, n_outputs",157141049,"BENDR(n_chans=22, n_outputs=4, n_times=1000)","Large Brain Model,Convolution"
+LUNA,General,"Classification,Embedding",128,"n_chans, n_times, sfreq, chs_info",7100731,"LUNA(n_chans=22, n_times=512, sfreq=128)","Convolution,Channel,Large Brain Model"
+MEDFormer,General,Classification,250,"n_chans, n_outputs, n_times",5313924,"MEDFormer(n_chans=22, n_outputs=4, n_times=1000)","Large Brain Model,Convolution"
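
Note on the summary table: throughout, the Categorization labels "Attention/Transformer" and "Foundation Model" are renamed to "Small Attention" and "Large Brain Model", and the BrainModule and REVE rows are dropped. A minimal sketch of inspecting the renamed labels from an installed wheel (the CSV path and column names are taken from the diff above, not from a documented API):

# Minimal sketch: read the bundled summary.csv and filter on the new labels.
from importlib import resources

import pandas as pd

with resources.files("braindecode.models").joinpath("summary.csv").open() as f:
    summary = pd.read_csv(f)

# Rows previously tagged "Foundation Model" now carry "Large Brain Model".
large = summary[summary["Categorization"].str.contains("Large Brain Model")]
print(large["Model"].tolist())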
braindecode/models/syncnet.py CHANGED
@@ -8,7 +8,7 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class SyncNet(EEGModuleMixin, nn.Module):
-    r"""Synchronization Network (SyncNet) from Li, Y et al (2017) [Li2017]_.
+    """Synchronization Network (SyncNet) from Li, Y et al (2017) [Li2017]_.
 
     :bdg-warning:`Interpretability`
 
@@ -89,7 +89,7 @@ class SyncNet(EEGModuleMixin, nn.Module):
         num_filters=1,
         filter_width=40,
         pool_size=40,
-        activation: type[nn.Module] = nn.ReLU,
+        activation: nn.Module = nn.ReLU,
         ampli_init_values: tuple[float, float] = (-0.05, 0.05),
         omega_init_values: tuple[float, float] = (0.0, 1.0),
         beta_init_values: tuple[float, float] = (0.0, 0.05),
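
The same annotation change recurs in tcn.py, tidnet.py, tsinception.py, and usleep.py below: `activation: type[nn.Module]` becomes `activation: nn.Module`, while the default stays the class itself (e.g. `nn.ReLU`), which the model instantiates internally. A generic sketch of that pattern (illustrative only, not braindecode internals):

# Sketch of the activation-class pattern shared by these signatures:
# the caller passes an nn.Module subclass; the model calls it to build the layer.
import torch
from torch import nn


class TinyBlock(nn.Module):
    def __init__(self, activation=nn.ReLU):
        super().__init__()
        self.conv = nn.Conv1d(1, 4, kernel_size=3)
        self.act = activation()  # instantiate the class passed by the caller

    def forward(self, x):
        return self.act(self.conv(x))


block = TinyBlock(activation=nn.ELU)
print(block(torch.randn(2, 1, 16)).shape)  # torch.Size([2, 4, 14])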
braindecode/models/tcn.py CHANGED
@@ -12,7 +12,7 @@ from braindecode.modules import Chomp1d, Ensure4d, SqueezeFinalOutput
 
 
 class BDTCN(EEGModuleMixin, nn.Module):
-    r"""Braindecode TCN from Gemein, L et al (2020) [gemein2020]_.
+    """Braindecode TCN from Gemein, L et al (2020) [gemein2020]_.
 
     :bdg-success:`Convolution` :bdg-secondary:`Recurrent`
 
@@ -57,7 +57,7 @@ class BDTCN(EEGModuleMixin, nn.Module):
         n_filters=30,
         kernel_size=5,
         drop_prob=0.5,
-        activation: type[nn.Module] = nn.ReLU,
+        activation: nn.Module = nn.ReLU,
     ):
         super().__init__(
             n_outputs=n_outputs,
@@ -90,7 +90,7 @@ class BDTCN(EEGModuleMixin, nn.Module):
 
 
 class TCN(nn.Module):
-    r"""Temporal Convolutional Network (TCN) from Bai et al. 2018 [Bai2018]_.
+    """Temporal Convolutional Network (TCN) from Bai et al. 2018 [Bai2018]_.
 
     See [Bai2018]_ for details.
 
@@ -126,7 +126,7 @@ class TCN(nn.Module):
         n_filters=30,
         kernel_size=5,
         drop_prob=0.5,
-        activation: type[nn.Module] = nn.ReLU,
+        activation: nn.Module = nn.ReLU,
     ):
         super().__init__()
         self.mapping = {
@@ -221,7 +221,7 @@ class _TemporalBlock(nn.Module):
         dilation,
         padding,
         drop_prob,
-        activation: type[nn.Module] = nn.ReLU,
+        activation: nn.Module = nn.ReLU,
     ):
         super().__init__()
         self.conv1 = weight_norm(
braindecode/models/tidnet.py CHANGED
@@ -11,7 +11,7 @@ from braindecode.modules import Ensure4d
 
 
 class TIDNet(EEGModuleMixin, nn.Module):
-    r"""Thinker Invariance DenseNet model from Kostas et al (2020) [TIDNet]_.
+    """Thinker Invariance DenseNet model from Kostas et al. (2020) [TIDNet]_.
 
     :bdg-success:`Convolution`
 
@@ -85,7 +85,7 @@ class TIDNet(EEGModuleMixin, nn.Module):
         temp_span: float = 0.05,
         bottleneck: int = 3,
         summary: int = -1,
-        activation: type[nn.Module] = nn.LeakyReLU,
+        activation: nn.Module = nn.LeakyReLU,
     ):
         super().__init__(
             n_outputs=n_outputs,
@@ -157,7 +157,7 @@ class _BatchNormZG(nn.BatchNorm2d):
 
 
 class _ConvBlock2D(nn.Module):
-    r"""Implements Convolution block with order:
+    """Implements Convolution block with order:
     Convolution, dropout, activation, batch-norm
     """
 
braindecode/models/tsinception.py CHANGED
@@ -13,7 +13,7 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class TSception(EEGModuleMixin, nn.Module):
-    r"""TSception model from Ding et al. (2020) from [ding2020]_.
+    """TSception model from Ding et al. (2020) from [ding2020]_.
 
     :bdg-success:`Convolution`
 
@@ -78,7 +78,7 @@ class TSception(EEGModuleMixin, nn.Module):
         number_filter_spat: int = 6,
         hidden_size: int = 128,
         drop_prob: float = 0.5,
-        activation: type[nn.Module] = nn.LeakyReLU,
+        activation: nn.Module = nn.LeakyReLU,
         pool_size: int = 8,
         inception_windows: tuple[float, float, float] = (0.5, 0.25, 0.125),
     ):
@@ -290,6 +290,6 @@ class TSception(EEGModuleMixin, nn.Module):
     "this alias will be removed in v1.14."
 )
 class TSceptionV1(TSception):
-    r"""Deprecated alias for TSception."""
+    """Deprecated alias for TSception."""
 
     pass
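
The hunk above keeps the deprecated `TSceptionV1` alias (scheduled for removal in v1.14); new code should construct the model under its current name, e.g. with the constructor arguments listed for TSception in summary.csv:

# Construct TSception under its current name; kwargs mirror the summary.csv row.
from braindecode.models import TSception

model = TSception(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)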
braindecode/models/usleep.py CHANGED
@@ -12,8 +12,8 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class USleep(EEGModuleMixin, nn.Module):
-    r"""
-    Sleep staging architecture from Perslev et al (2021) [1]_.
+    """
+    Sleep staging architecture from Perslev et al. (2021) [1]_.
 
     :bdg-success:`Convolution`
 
@@ -182,7 +182,7 @@ class USleep(EEGModuleMixin, nn.Module):
         input_window_seconds=None,
         time_conv_size_s=9 / 128,
         ensure_odd_conv_size=False,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
         chs_info=None,
         n_times=None,
     ):
@@ -331,7 +331,7 @@ class USleep(EEGModuleMixin, nn.Module):
 
 
 class _EncoderBlock(nn.Module):
-    r"""Encoding block for a timeseries x of shape (B, C, T)."""
+    """Encoding block for a timeseries x of shape (B, C, T)."""
 
     def __init__(
         self,
@@ -339,7 +339,7 @@ class _EncoderBlock(nn.Module):
         out_channels=2,
         kernel_size=9,
         downsample=2,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
     ):
         super().__init__()
         self.in_channels = in_channels
@@ -371,7 +371,7 @@ class _EncoderBlock(nn.Module):
 
 
 class _DecoderBlock(nn.Module):
-    r"""Decoding block for a timeseries x of shape (B, C, T)."""
+    """Decoding block for a timeseries x of shape (B, C, T)."""
 
     def __init__(
         self,
@@ -380,7 +380,7 @@ class _DecoderBlock(nn.Module):
         kernel_size=9,
         upsample=2,
         with_skip_connection=True,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
     ):
         super().__init__()
         self.in_channels = in_channels
braindecode/models/util.py CHANGED
@@ -3,9 +3,8 @@
 #
 # License: BSD (3-clause)
 import inspect
-from copy import deepcopy
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Sequence
+from typing import Any, Dict, Optional, Sequence
 
 import numpy as np
 import pandas as pd
@@ -30,16 +29,6 @@ def _init_models_dict():
         models_dict[m[0]] = m[1]
 
 
-SigArgName = Literal[
-    "n_outputs",
-    "n_chans",
-    "chs_info",
-    "n_times",
-    "input_window_seconds",
-    "sfreq",
-]
-
-
 ################################################################
 # Test cases for models
 #
@@ -61,9 +50,7 @@ SigArgName = Literal[
 # The keys of this dictionary can only be among those of
 # default_signal_params.
 ################################################################
-models_mandatory_parameters: list[
-    tuple[str, list[SigArgName], dict[SigArgName, Any] | None]
-] = [
+models_mandatory_parameters = [
     ("ATCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("BDTCN", ["n_chans", "n_outputs"], None),
    ("Deep4Net", ["n_chans", "n_outputs", "n_times"], None),
@@ -77,60 +64,45 @@ models_mandatory_parameters: list[
     (
         "SleepStagerBlanco2020",
         ["n_chans", "n_outputs", "n_times"],
-        {"n_chans": 4},  # n_chans dividable by n_groups=2
+        dict(n_chans=4),  # n_chans dividable by n_groups=2
     ),
     ("SleepStagerChambon2018", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
     (
         "AttnSleep",
         ["n_outputs", "n_times", "sfreq"],
-        {
-            "sfreq": 100.0,
-            "n_times": 3000,
-            "chs_info": [{"ch_name": "C1", "kind": "eeg"}],
-        },
+        dict(sfreq=100.0, n_times=3000, chs_info=[dict(ch_name="C1", kind="eeg")]),
     ),  # 1 channel
     ("TIDNet", ["n_chans", "n_outputs", "n_times"], None),
-    ("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 128.0}),
+    ("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=128.0)),
     ("BIOT", ["n_chans", "n_outputs", "sfreq", "n_times"], None),
     ("AttentionBaseNet", ["n_chans", "n_outputs", "n_times"], None),
     ("Labram", ["n_chans", "n_outputs", "n_times"], None),
     ("EEGSimpleConv", ["n_chans", "n_outputs", "sfreq"], None),
     ("SPARCNet", ["n_chans", "n_outputs", "n_times"], None),
-    ("ContraWR", ["n_chans", "n_outputs", "sfreq", "n_times"], {"sfreq": 200.0}),
+    ("ContraWR", ["n_chans", "n_outputs", "sfreq", "n_times"], dict(sfreq=200.0)),
     ("EEGNeX", ["n_chans", "n_outputs", "n_times"], None),
     ("EEGSym", ["chs_info", "n_chans", "n_outputs", "n_times", "sfreq"], None),
-    ("TSception", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
+    ("TSception", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
     ("EEGTCNet", ["n_chans", "n_outputs", "n_times"], None),
     ("SyncNet", ["n_chans", "n_outputs", "n_times"], None),
     ("MSVTNet", ["n_chans", "n_outputs", "n_times"], None),
-    ("EEGMiner", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
+    ("EEGMiner", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
     ("CTNet", ["n_chans", "n_outputs", "n_times"], None),
-    ("SincShallowNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 250.0}),
-    ("SCCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
+    ("SincShallowNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=250.0)),
+    ("SCCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
     ("SignalJEPA", ["chs_info"], None),
     ("SignalJEPA_Contextual", ["chs_info", "n_times", "n_outputs"], None),
     ("SignalJEPA_PostLocal", ["n_chans", "n_times", "n_outputs"], None),
     ("SignalJEPA_PreLocal", ["n_chans", "n_times", "n_outputs"], None),
-    ("FBCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
-    ("FBMSNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
-    ("FBLightConvNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
-    ("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
+    ("FBCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
+    ("FBMSNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
+    ("FBLightConvNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
+    ("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
     ("PBT", ["n_chans", "n_outputs", "n_times"], None),
     ("SSTDPN", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
-    ("BrainModule", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
     ("BENDR", ["n_chans", "n_outputs", "n_times"], None),
     ("LUNA", ["n_chans", "n_times", "n_outputs"], None),
     ("MEDFormer", ["n_chans", "n_outputs", "n_times"], None),
-    (
-        "REVE",
-        ["n_times", "n_outputs", "n_chans", "chs_info"],
-        {
-            "sfreq": 200.0,
-            "n_chans": 19,
-            "n_times": 1_000,
-            "chs_info": [{"ch_name": f"E{i + 1}", "kind": "eeg"} for i in range(19)],
-        },
-    ),
 ]
 
 ################################################################
@@ -143,129 +115,6 @@ non_classification_models = [
     "SignalJEPA",
 ]
 
-################################################################
-
-rng = np.random.default_rng(12)
-# Generating the channel info
-chs_info = [
-    {
-        "ch_name": f"C{i}",
-        "kind": "eeg",
-        "loc": rng.random(12),
-    }
-    for i in range(1, 4)
-]
-default_signal_params: dict[SigArgName, Any] = {
-    "n_times": 1000,
-    "sfreq": 250.0,
-    "n_outputs": 2,
-    "chs_info": chs_info,
-    "n_chans": len(chs_info),
-    "input_window_seconds": 4.0,
-}
-
-
-def _get_signal_params(
-    signal_params: dict[SigArgName, Any] | None,
-    required_params: list[SigArgName] | None = None,
-) -> dict[SigArgName, Any]:
-    """Get signal parameters for model initialization in tests."""
-    sp = deepcopy(default_signal_params)
-    if signal_params is not None:
-        sp.update(signal_params)
-        if "chs_info" in signal_params and "n_chans" not in signal_params:
-            sp["n_chans"] = len(signal_params["chs_info"])
-        if "n_chans" in signal_params and "chs_info" not in signal_params:
-            sp["chs_info"] = [
-                {"ch_name": f"C{i}", "kind": "eeg", "loc": rng.random(12)}
-                for i in range(signal_params["n_chans"])
-            ]
-        assert isinstance(sp["n_times"], int)
-        assert isinstance(sp["sfreq"], float)
-        assert isinstance(sp["input_window_seconds"], float)
-        if "input_window_seconds" not in signal_params:
-            sp["input_window_seconds"] = sp["n_times"] / sp["sfreq"]
-        if "sfreq" not in signal_params:
-            sp["sfreq"] = sp["n_times"] / sp["input_window_seconds"]
-        if "n_times" not in signal_params:
-            sp["n_times"] = int(sp["input_window_seconds"] * sp["sfreq"])
-    if required_params is not None:
-        sp = {
-            k: sp[k] for k in set((signal_params or {}).keys()).union(required_params)
-        }
-    return sp
-
-
-def _get_possible_signal_params(
-    signal_params: dict[SigArgName, Any], required_params: list[SigArgName]
-):
-    sp = signal_params
-
-    # List possible model kwargs:
-    output_kwargs = []
-    output_kwargs.append(dict(n_outputs=sp["n_outputs"]))
-
-    if "n_outputs" not in required_params:
-        output_kwargs.append(dict(n_outputs=None))
-
-    channel_kwargs = []
-    channel_kwargs.append(dict(chs_info=sp["chs_info"], n_chans=None))
-    if "chs_info" not in required_params:
-        channel_kwargs.append(dict(n_chans=sp["n_chans"], chs_info=None))
-    if "n_chans" not in required_params and "chs_info" not in required_params:
-        channel_kwargs.append(dict(n_chans=None, chs_info=None))
-
-    time_kwargs = []
-    time_kwargs.append(
-        dict(n_times=sp["n_times"], sfreq=sp["sfreq"], input_window_seconds=None)
-    )
-    time_kwargs.append(
-        dict(
-            n_times=None,
-            sfreq=sp["sfreq"],
-            input_window_seconds=sp["input_window_seconds"],
-        )
-    )
-    time_kwargs.append(
-        dict(
-            n_times=sp["n_times"],
-            sfreq=None,
-            input_window_seconds=sp["input_window_seconds"],
-        )
-    )
-    if "n_times" not in required_params and "sfreq" not in required_params:
-        time_kwargs.append(
-            dict(
-                n_times=None,
-                sfreq=None,
-                input_window_seconds=sp["input_window_seconds"],
-            )
-        )
-    if (
-        "n_times" not in required_params
-        and "input_window_seconds" not in required_params
-    ):
-        time_kwargs.append(
-            dict(n_times=None, sfreq=sp["sfreq"], input_window_seconds=None)
-        )
-    if "sfreq" not in required_params and "input_window_seconds" not in required_params:
-        time_kwargs.append(
-            dict(n_times=sp["n_times"], sfreq=None, input_window_seconds=None)
-        )
-    if (
-        "n_times" not in required_params
-        and "sfreq" not in required_params
-        and "input_window_seconds" not in required_params
-    ):
-        time_kwargs.append(dict(n_times=None, sfreq=None, input_window_seconds=None))
-
-    return [
-        dict(**o, **c, **t)
-        for o in output_kwargs
-        for c in channel_kwargs
-        for t in time_kwargs
-    ]
-
 
 ################################################################
 def get_summary_table(dir_name=None):
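
With the `_get_signal_params`/`_get_possible_signal_params` helpers deleted, each `models_mandatory_parameters` entry still reads as (model name, required signal parameters, overrides). A hypothetical sketch of how such an entry can drive a smoke test; the default values below are illustrative, not braindecode's own:

# Hypothetical smoke test driven by one models_mandatory_parameters entry.
from braindecode import models

DEFAULTS = dict(n_chans=22, n_outputs=4, n_times=1000, sfreq=250.0)

name, required, overrides = ("ATCNet", ["n_chans", "n_outputs", "n_times"], None)
params = {**DEFAULTS, **(overrides or {})}
kwargs = {key: params[key] for key in required}  # only the mandatory arguments
model = getattr(models, name)(**kwargs)
print(sum(p.numel() for p in model.parameters()))  # 113732 per summary.csv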
braindecode/modules/__init__.py CHANGED
@@ -22,14 +22,7 @@ from .convolution import (
     DepthwiseConv2d,
 )
 from .filter import FilterBankLayer, GeneralizedGaussianFilter
-from .layers import (
-    Chomp1d,
-    DropPath,
-    Ensure4d,
-    SqueezeFinalOutput,
-    SubjectLayers,
-    TimeDistributed,
-)
+from .layers import Chomp1d, DropPath, Ensure4d, SqueezeFinalOutput, TimeDistributed
 from .linear import LinearWithConstraint, MaxNormLinear
 from .parametrization import MaxNorm, MaxNormParametrize
 from .stats import (
@@ -72,7 +65,6 @@ __all__ = [
     "Chomp1d",
     "DropPath",
     "Ensure4d",
-    "SubjectLayers",
     "SqueezeFinalOutput",
     "TimeDistributed",
     "LinearWithConstraint",
braindecode/modules/activation.py CHANGED
@@ -8,24 +8,14 @@ class SafeLog(nn.Module):
     r"""
     Safe logarithm activation function module.
 
-    :math:`\text{SafeLog}(x) = \log\left(\max(x, \epsilon)\right)`
+    :math:\text{SafeLog}(x) = \log\left(\max(x, \epsilon)\right)
 
     Parameters
     ----------
-    epsilon : float, optional
+    eps : float, optional
         A small value to clamp the input tensor to prevent computing log(0) or log of negative numbers.
         Default is 1e-6.
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import SafeLog
-    >>> module = SafeLog(epsilon=1e-6)
-    >>> inputs = torch.rand(2, 3)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 3])
-
     """
 
     def __init__(self, epsilon: float = 1e-6):
@@ -54,23 +44,7 @@ class SafeLog(nn.Module):
 
 
 class LogActivation(nn.Module):
-    """Logarithm activation function.
-
-    Parameters
-    ----------
-    epsilon : float, default=1e-6
-        Small float to adjust the activation.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import LogActivation
-    >>> module = LogActivation(epsilon=1e-6)
-    >>> inputs = torch.rand(2, 3)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 3])
-    """
+    """Logarithm activation function."""
 
     def __init__(self, epsilon: float = 1e-6, *args, **kwargs):
         """