braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,208 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+
5
+ from einops.layers.torch import Rearrange
6
+ from torch import nn
7
+ from torch.nn import init
8
+
9
+ from braindecode.functional import square
10
+ from braindecode.models.base import EEGModuleMixin
11
+ from braindecode.modules import (
12
+ CombinedConv,
13
+ Ensure4d,
14
+ Expression,
15
+ SafeLog,
16
+ SqueezeFinalOutput,
17
+ )
18
+
19
+
20
class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
    """Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.

    .. figure:: https://onlinelibrary.wiley.com/cms/asset/221ea375-6701-40d3-ab3f-e411aad62d9e/hbm23730-fig-0002-m.jpg
       :align: center
       :alt: ShallowNet Architecture

    Model described in [Schirrmeister2017]_.

    The network is an ``nn.Sequential`` made of the following stages:

    - input reshaping (``Ensure4d``);
    - either a split temporal + spatial convolution (``CombinedConv``) or a
      single temporal convolution;
    - optional batch normalisation;
    - ``conv_nonlin``;
    - temporal pooling;
    - ``activation_pool_nonlin``;
    - dropout;
    - a final convolutional classifier whose output is squeezed.

    Parameters
    ----------
    n_filters_time: int
        Number of temporal filters.
    filter_time_length: int
        Length of the temporal filter.
    n_filters_spat: int
        Number of spatial filters. Only used when ``split_first_layer=True``.
    pool_time_length: int
        Length of temporal pooling filter.
    pool_time_stride: int
        Length of stride between temporal pooling filters.
    final_conv_length: int | str
        Length of the final convolution layer.
        If set to "auto", length of the input signal (``n_times``) must be
        specified. The length is then inferred from the output shape of the
        layers that precede the classifier.
    conv_nonlin: callable
        Non-linear function used after the convolution (and batch norm)
        layers. It is wrapped in ``Expression``.
    pool_mode: str
        Method to use on pooling layers. "max" or "mean"; any other value
        raises ``KeyError``.
    activation_pool_nonlin: type of nn.Module
        Module class used after the pooling layer. It is instantiated
        without arguments.
    split_first_layer: bool
        Split first layer into temporal and spatial layers (True) or just use
        temporal (False).
        There is no non-linearity between the split layers.
    batch_norm: bool
        Whether to use batch normalisation after the first convolution block.
        When True, the bias is disabled on the spatial convolution (split
        case) or on the temporal convolution (unsplit case).
    batch_norm_alpha: float
        Momentum for BatchNorm2d.
    drop_prob: float
        Dropout probability.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       Human Brain Mapping , Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        n_filters_time=40,
        filter_time_length=25,
        n_filters_spat=40,
        pool_time_length=75,
        pool_time_stride=15,
        final_conv_length="auto",
        conv_nonlin=square,
        pool_mode="mean",
        activation_pool_nonlin: nn.Module = SafeLog,
        split_first_layer=True,
        batch_norm=True,
        batch_norm_alpha=0.1,
        drop_prob=0.5,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # From here on, the signal-related values are read through the
        # attributes set up by the mixin (self.n_chans, self.n_times, ...).
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        if final_conv_length == "auto":
            # The classifier length is inferred from a forward shape pass
            # below, which needs the input length.
            assert self.n_times is not None, (
                'n_times must be specified when final_conv_length="auto"'
            )
        self.n_filters_time = n_filters_time
        self.filter_time_length = filter_time_length
        self.n_filters_spat = n_filters_spat
        self.pool_time_length = pool_time_length
        self.pool_time_stride = pool_time_stride
        self.final_conv_length = final_conv_length
        self.conv_nonlin = conv_nonlin
        self.pool_mode = pool_mode
        self.pool_nonlin = activation_pool_nonlin
        self.split_first_layer = split_first_layer
        self.batch_norm = batch_norm
        self.batch_norm_alpha = batch_norm_alpha
        self.drop_prob = drop_prob

        # Maps flat parameter names onto the current nested module names,
        # presumably so that older checkpoints can be loaded (TODO confirm
        # against EEGModuleMixin).
        # NOTE(review): the conv_time/conv_spat entries target the split
        # layout. With split_first_layer=False the model's own key is
        # "conv_time.weight"; confirm this remapping does not interfere.
        self.mapping = {
            "conv_time.weight": "conv_time_spat.conv_time.weight",
            "conv_spat.weight": "conv_time_spat.conv_spat.weight",
            "conv_time.bias": "conv_time_spat.conv_time.bias",
            "conv_spat.bias": "conv_time_spat.conv_spat.bias",
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        self.add_module("ensuredims", Ensure4d())
        pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
        if self.split_first_layer:
            # Move channels to the last axis so that CombinedConv can convolve
            # over time and then across channels.
            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
            self.add_module(
                "conv_time_spat",
                CombinedConv(
                    in_chans=self.n_chans,
                    n_filters_time=self.n_filters_time,
                    n_filters_spat=self.n_filters_spat,
                    filter_time_length=self.filter_time_length,
                    bias_time=True,
                    # Batch norm follows and has its own shift term.
                    bias_spat=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_spat
        else:
            # Channels are treated as Conv2d input channels, so this single
            # temporal convolution also mixes all channels.
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    self.n_chans,
                    self.n_filters_time,
                    (self.filter_time_length, 1),
                    stride=1,
                    bias=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_time
        if self.batch_norm:
            self.add_module(
                "bnorm",
                nn.BatchNorm2d(
                    n_filters_conv, momentum=self.batch_norm_alpha, affine=True
                ),
            )
        self.add_module("conv_nonlin_exp", Expression(self.conv_nonlin))
        self.add_module(
            "pool",
            pool_class(
                kernel_size=(self.pool_time_length, 1),
                stride=(self.pool_time_stride, 1),
            ),
        )
        self.add_module("pool_nonlin_exp", self.pool_nonlin())
        self.add_module("drop", nn.Dropout(p=self.drop_prob))
        # Eval mode for the shape-inference pass below; the model is switched
        # back to train mode at the end of __init__.
        self.eval()
        if self.final_conv_length == "auto":
            # Index 2 is the remaining temporal length after pooling.
            self.final_conv_length = self.get_output_shape()[2]

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()

        module.add_module(
            "conv_classifier",
            nn.Conv2d(
                n_filters_conv,
                self.n_outputs,
                (self.final_conv_length, 1),
                bias=True,
            ),
        )

        module.add_module("squeeze", SqueezeFinalOutput())

        self.add_module("final_layer", module)

        # Initialization, xavier is same as in paper...
        if self.split_first_layer:
            init.xavier_uniform_(self.conv_time_spat.conv_time.weight, gain=1)
            # The temporal conv of CombinedConv is always built with a bias.
            init.constant_(self.conv_time_spat.conv_time.bias, 0)
            init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
            if not self.batch_norm:
                init.constant_(self.conv_time_spat.conv_spat.bias, 0)
        else:
            # Unsplit layout: the first conv is registered as "conv_time",
            # not "conv_time_spat". Accessing conv_time_spat here would raise
            # AttributeError.
            init.xavier_uniform_(self.conv_time.weight, gain=1)
            # conv_time was built with bias=not self.batch_norm.
            if not self.batch_norm:
                init.constant_(self.conv_time.bias, 0)
        if self.batch_norm:
            init.constant_(self.bnorm.weight, 1)
            init.constant_(self.bnorm.bias, 0)
        init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
        init.constant_(self.final_layer.conv_classifier.bias, 0)

        self.train()