braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -6,13 +6,23 @@ from einops.layers.torch import Rearrange
6
6
  from torch import nn
7
7
  from torch.nn import init
8
8
 
9
- from .base import EEGModuleMixin, deprecated_args
10
- from .functions import safe_log, square, squeeze_final_output
11
- from .modules import CombinedConv, Ensure4d, Expression
9
+ from braindecode.functional import square
10
+ from braindecode.models.base import EEGModuleMixin
11
+ from braindecode.modules import (
12
+ CombinedConv,
13
+ Ensure4d,
14
+ Expression,
15
+ SafeLog,
16
+ SqueezeFinalOutput,
17
+ )
12
18
 
13
19
 
14
20
  class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
15
- """Shallow ConvNet model from Schirrmeister et al 2017.
21
+ """Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.
22
+
23
+ .. figure:: https://onlinelibrary.wiley.com/cms/asset/221ea375-6701-40d3-ab3f-e411aad62d9e/hbm23730-fig-0002-m.jpg
24
+ :align: center
25
+ :alt: ShallowNet Architecture
16
26
 
17
27
  Model described in [Schirrmeister2017]_.
18
28
 
@@ -35,7 +45,7 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
35
45
  Non-linear function to be used after convolution layers.
36
46
  pool_mode: str
37
47
  Method to use on pooling layers. "max" or "mean".
38
- pool_nonlin: callable
48
+ activation_pool_nonlin: callable
39
49
  Non-linear function to be used after pooling layers.
40
50
  split_first_layer: bool
41
51
  Split first layer into temporal and spatial layers (True) or just use temporal (False).
@@ -46,12 +56,6 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
46
56
  Momentum for BatchNorm2d.
47
57
  drop_prob: float
48
58
  Dropout probability.
49
- in_chans : int
50
- Alias for `n_chans`.
51
- n_classes: int
52
- Alias for `n_outputs`.
53
- input_window_samples: int | None
54
- Alias for `n_times`.
55
59
 
56
60
  References
57
61
  ----------
@@ -65,37 +69,27 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
65
69
  """
66
70
 
67
71
  def __init__(
68
- self,
69
- n_chans=None,
70
- n_outputs=None,
71
- n_times=None,
72
- n_filters_time=40,
73
- filter_time_length=25,
74
- n_filters_spat=40,
75
- pool_time_length=75,
76
- pool_time_stride=15,
77
- final_conv_length=30,
78
- conv_nonlin=square,
79
- pool_mode="mean",
80
- pool_nonlin=safe_log,
81
- split_first_layer=True,
82
- batch_norm=True,
83
- batch_norm_alpha=0.1,
84
- drop_prob=0.5,
85
- chs_info=None,
86
- input_window_seconds=None,
87
- sfreq=None,
88
- in_chans=None,
89
- n_classes=None,
90
- input_window_samples=None,
91
- add_log_softmax=True,
72
+ self,
73
+ n_chans=None,
74
+ n_outputs=None,
75
+ n_times=None,
76
+ n_filters_time=40,
77
+ filter_time_length=25,
78
+ n_filters_spat=40,
79
+ pool_time_length=75,
80
+ pool_time_stride=15,
81
+ final_conv_length="auto",
82
+ conv_nonlin=square,
83
+ pool_mode="mean",
84
+ activation_pool_nonlin: nn.Module = SafeLog,
85
+ split_first_layer=True,
86
+ batch_norm=True,
87
+ batch_norm_alpha=0.1,
88
+ drop_prob=0.5,
89
+ chs_info=None,
90
+ input_window_seconds=None,
91
+ sfreq=None,
92
92
  ):
93
- n_chans, n_outputs, n_times = deprecated_args(
94
- self,
95
- ("in_chans", "n_chans", in_chans, n_chans),
96
- ("n_classes", "n_outputs", n_classes, n_outputs),
97
- ("input_window_samples", "n_times", input_window_samples, n_times),
98
- )
99
93
  super().__init__(
100
94
  n_outputs=n_outputs,
101
95
  n_chans=n_chans,
@@ -103,10 +97,8 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
103
97
  n_times=n_times,
104
98
  input_window_seconds=input_window_seconds,
105
99
  sfreq=sfreq,
106
- add_log_softmax=add_log_softmax,
107
100
  )
108
101
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
109
- del in_chans, n_classes, input_window_samples
110
102
  if final_conv_length == "auto":
111
103
  assert self.n_times is not None
112
104
  self.n_filters_time = n_filters_time
@@ -117,7 +109,7 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
117
109
  self.final_conv_length = final_conv_length
118
110
  self.conv_nonlin = conv_nonlin
119
111
  self.pool_mode = pool_mode
120
- self.pool_nonlin = pool_nonlin
112
+ self.pool_nonlin = activation_pool_nonlin
121
113
  self.split_first_layer = split_first_layer
122
114
  self.batch_norm = batch_norm
123
115
  self.batch_norm_alpha = batch_norm_alpha
@@ -129,7 +121,7 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
129
121
  "conv_time.bias": "conv_time_spat.conv_time.bias",
130
122
  "conv_spat.bias": "conv_time_spat.conv_spat.bias",
131
123
  "conv_classifier.weight": "final_layer.conv_classifier.weight",
132
- "conv_classifier.bias": "final_layer.conv_classifier.bias"
124
+ "conv_classifier.bias": "final_layer.conv_classifier.bias",
133
125
  }
134
126
 
135
127
  self.add_module("ensuredims", Ensure4d())
@@ -175,7 +167,7 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
175
167
  stride=(self.pool_time_stride, 1),
176
168
  ),
177
169
  )
178
- self.add_module("pool_nonlin_exp", Expression(self.pool_nonlin))
170
+ self.add_module("pool_nonlin_exp", self.pool_nonlin())
179
171
  self.add_module("drop", nn.Dropout(p=self.drop_prob))
180
172
  self.eval()
181
173
  if self.final_conv_length == "auto":
@@ -184,17 +176,17 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
184
176
  # Incorporating classification module and subsequent ones in one final layer
185
177
  module = nn.Sequential()
186
178
 
187
- module.add_module("conv_classifier",
188
- nn.Conv2d(
189
- n_filters_conv,
190
- self.n_outputs,
191
- (self.final_conv_length, 1),
192
- bias=True, ))
193
-
194
- if self.add_log_softmax:
195
- module.add_module("logsoftmax", nn.LogSoftmax(dim=1))
179
+ module.add_module(
180
+ "conv_classifier",
181
+ nn.Conv2d(
182
+ n_filters_conv,
183
+ self.n_outputs,
184
+ (self.final_conv_length, 1),
185
+ bias=True,
186
+ ),
187
+ )
196
188
 
197
- module.add_module("squeeze", Expression(squeeze_final_output))
189
+ module.add_module("squeeze", SqueezeFinalOutput())
198
190
 
199
191
  self.add_module("final_layer", module)
200
192
 
@@ -212,3 +204,5 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
212
204
  init.constant_(self.bnorm.bias, 0)
213
205
  init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
214
206
  init.constant_(self.final_layer.conv_classifier.bias, 0)
207
+
208
+ self.train()