braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic; see the registry's advisory page for this release for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/tcn.py CHANGED
@@ -2,22 +2,23 @@
2
2
  # Lukas Gemein <l.gemein@gmail.com>
3
3
  #
4
4
  # License: BSD-3
5
-
5
+ import torch
6
6
  from torch import nn
7
7
  from torch.nn import init
8
- from torch.nn.utils import weight_norm
8
+ from torch.nn.utils.parametrizations import weight_norm
9
9
 
10
- from .modules import Ensure4d, Expression
11
- from .functions import squeeze_final_output
12
- from .base import EEGModuleMixin, deprecated_args
10
+ from braindecode.models.base import EEGModuleMixin
11
+ from braindecode.modules import Chomp1d, Ensure4d, Expression, SqueezeFinalOutput
13
12
 
14
13
 
15
- class TCN(EEGModuleMixin, nn.Module):
16
- """Temporal Convolutional Network (TCN) from Bai et al 2018.
14
+ class BDTCN(EEGModuleMixin, nn.Module):
15
+ """Braindecode TCN from Gemein, L et al (2020) [gemein2020]_.
17
16
 
18
- See [Bai2018]_ for details.
17
+ .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S1053811920305073-gr3_lrg.jpg
18
+ :align: center
19
+ :alt: Braindecode TCN Architecture
19
20
 
20
- Code adapted from https://github.com/locuslab/TCN/blob/master/TCN/tcn.py
21
+ See [gemein2020]_ for details.
21
22
 
22
23
  Parameters
23
24
  ----------
@@ -29,36 +30,33 @@ class TCN(EEGModuleMixin, nn.Module):
29
30
  kernel size of the convolutions
30
31
  drop_prob: float
31
32
  dropout probability
32
- n_in_chans: int
33
- Alias for `n_chans`.
33
+ activation: nn.Module, default=nn.ReLU
34
+ Activation function class to apply. Should be a PyTorch activation
35
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
34
36
 
35
37
  References
36
38
  ----------
37
- .. [Bai2018] Bai, S., Kolter, J. Z., & Koltun, V. (2018).
38
- An empirical evaluation of generic convolutional and recurrent networks
39
- for sequence modeling.
40
- arXiv preprint arXiv:1803.01271.
39
+ .. [gemein2020] Gemein, L. A., Schirrmeister, R. T., Chrabąszcz, P., Wilson, D.,
40
+ Boedecker, J., Schulze-Bonhage, A., ... & Ball, T. (2020). Machine-learning-based
41
+ diagnostics of EEG pathology. NeuroImage, 220, 117021.
41
42
  """
42
43
 
43
44
  def __init__(
44
- self,
45
- n_chans=None,
46
- n_outputs=None,
47
- n_blocks=None,
48
- n_filters=None,
49
- kernel_size=None,
50
- drop_prob=None,
51
- chs_info=None,
52
- n_times=None,
53
- input_window_seconds=None,
54
- sfreq=None,
55
- n_in_chans=None,
56
- add_log_softmax=False,
45
+ self,
46
+ # Braindecode parameters
47
+ n_chans=None,
48
+ n_outputs=None,
49
+ chs_info=None,
50
+ n_times=None,
51
+ sfreq=None,
52
+ input_window_seconds=None,
53
+ # Model's parameters
54
+ n_blocks=3,
55
+ n_filters=30,
56
+ kernel_size=5,
57
+ drop_prob=0.5,
58
+ activation: nn.Module = nn.ReLU,
57
59
  ):
58
- n_chans, = deprecated_args(
59
- self,
60
- ("n_in_chans", "n_chans", n_in_chans, n_chans),
61
- )
62
60
  super().__init__(
63
61
  n_outputs=n_outputs,
64
62
  n_chans=n_chans,
@@ -66,43 +64,106 @@ class TCN(EEGModuleMixin, nn.Module):
66
64
  n_times=n_times,
67
65
  input_window_seconds=input_window_seconds,
68
66
  sfreq=sfreq,
69
- add_log_softmax=add_log_softmax,
70
67
  )
71
- del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
72
- del n_in_chans
68
+ del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
69
+
70
+ self.base_tcn = TCN(
71
+ n_chans=self.n_chans,
72
+ n_outputs=self.n_outputs,
73
+ n_blocks=n_blocks,
74
+ n_filters=n_filters,
75
+ kernel_size=kernel_size,
76
+ drop_prob=drop_prob,
77
+ activation=activation,
78
+ )
79
+
80
+ self.final_layer = torch.nn.Sequential(
81
+ torch.nn.AdaptiveAvgPool1d(1), torch.nn.Flatten()
82
+ )
83
+
84
+ def forward(self, x):
85
+ x = self.base_tcn(x)
86
+ x = self.final_layer(x)
87
+ return x
88
+
89
+
90
+ class TCN(nn.Module):
91
+ """Temporal Convolutional Network (TCN) from Bai et al. 2018 [Bai2018]_.
92
+
93
+ See [Bai2018]_ for details.
94
+
95
+ Code adapted from https://github.com/locuslab/TCN/blob/master/TCN/tcn.py
96
+
97
+ Parameters
98
+ ----------
99
+ n_filters: int
100
+ number of output filters of each convolution
101
+ n_blocks: int
102
+ number of temporal blocks in the network
103
+ kernel_size: int
104
+ kernel size of the convolutions
105
+ drop_prob: float
106
+ dropout probability
107
+ activation: nn.Module, default=nn.ReLU
108
+ Activation function class to apply. Should be a PyTorch activation
109
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
110
+
111
+ References
112
+ ----------
113
+ .. [Bai2018] Bai, S., Kolter, J. Z., & Koltun, V. (2018).
114
+ An empirical evaluation of generic convolutional and recurrent networks
115
+ for sequence modeling.
116
+ arXiv preprint arXiv:1803.01271.
117
+ """
73
118
 
119
+ def __init__(
120
+ self,
121
+ n_chans=None,
122
+ n_outputs=None,
123
+ n_blocks=3,
124
+ n_filters=30,
125
+ kernel_size=5,
126
+ drop_prob=0.5,
127
+ activation: nn.Module = nn.ReLU,
128
+ ):
129
+ super().__init__()
74
130
  self.mapping = {
75
131
  "fc.weight": "final_layer.fc.weight",
76
- "fc.bias": "final_layer.fc.bias"
132
+ "fc.bias": "final_layer.fc.bias",
77
133
  }
78
134
  self.ensuredims = Ensure4d()
79
135
  t_blocks = nn.Sequential()
80
136
  for i in range(n_blocks):
81
- n_inputs = self.n_chans if i == 0 else n_filters
82
- dilation_size = 2 ** i
83
- t_blocks.add_module("temporal_block_{:d}".format(i), TemporalBlock(
84
- n_inputs=n_inputs,
85
- n_outputs=n_filters,
86
- kernel_size=kernel_size,
87
- stride=1,
88
- dilation=dilation_size,
89
- padding=(kernel_size - 1) * dilation_size,
90
- drop_prob=drop_prob
91
- ))
137
+ n_inputs = n_chans if i == 0 else n_filters
138
+ dilation_size = 2**i
139
+ t_blocks.add_module(
140
+ "temporal_block_{:d}".format(i),
141
+ _TemporalBlock(
142
+ n_inputs=n_inputs,
143
+ n_outputs=n_filters,
144
+ kernel_size=kernel_size,
145
+ stride=1,
146
+ dilation=dilation_size,
147
+ padding=(kernel_size - 1) * dilation_size,
148
+ drop_prob=drop_prob,
149
+ activation=activation,
150
+ ),
151
+ )
92
152
  self.temporal_blocks = t_blocks
93
153
 
94
- # Here, change to final_layer
95
- self.final_layer = _FinalLayer(in_features=n_filters, out_features=self.n_outputs,
96
- add_log_softmax=add_log_softmax)
154
+ self.final_layer = _FinalLayer(
155
+ in_features=n_filters,
156
+ out_features=n_outputs,
157
+ )
97
158
  self.min_len = 1
98
159
  for i in range(n_blocks):
99
- dilation = 2 ** i
160
+ dilation = 2**i
100
161
  self.min_len += 2 * (kernel_size - 1) * dilation
101
162
 
102
163
  # start in eval mode
103
- self.eval()
164
+ self.train()
104
165
 
105
- def forward(self, x):
166
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
106
167
  """Forward pass.
107
168
 
108
169
  Parameters
@@ -126,21 +187,18 @@ class TCN(EEGModuleMixin, nn.Module):
126
187
 
127
188
 
128
189
  class _FinalLayer(nn.Module):
129
- def __init__(self, in_features, out_features, add_log_softmax=True):
130
-
190
+ def __init__(self, in_features, out_features):
131
191
  super().__init__()
132
192
 
133
193
  self.fc = nn.Linear(in_features=in_features, out_features=out_features)
134
194
 
135
- if add_log_softmax:
136
- self.out_fun = nn.LogSoftmax(dim=1)
137
- else:
138
- self.out_fun = nn.Identity()
195
+ self.out_fun = nn.Identity()
139
196
 
140
- self.squeeze = Expression(squeeze_final_output)
141
-
142
- def forward(self, x, batch_size, time_size, min_len):
197
+ self.squeeze = SqueezeFinalOutput()
143
198
 
199
+ def forward(
200
+ self, x: torch.Tensor, batch_size: int, time_size: int, min_len: int
201
+ ) -> torch.Tensor:
144
202
  fc_out = self.fc(x.view(batch_size * time_size, x.size(2)))
145
203
  fc_out = self.out_fun(fc_out)
146
204
  fc_out = fc_out.view(batch_size, time_size, fc_out.size(1))
@@ -151,27 +209,51 @@ class _FinalLayer(nn.Module):
151
209
  return self.squeeze(out[:, :, :, None])
152
210
 
153
211
 
154
- class TemporalBlock(nn.Module):
155
- def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation,
156
- padding, drop_prob):
212
+ class _TemporalBlock(nn.Module):
213
+ def __init__(
214
+ self,
215
+ n_inputs,
216
+ n_outputs,
217
+ kernel_size,
218
+ stride,
219
+ dilation,
220
+ padding,
221
+ drop_prob,
222
+ activation: nn.Module = nn.ReLU,
223
+ ):
157
224
  super().__init__()
158
- self.conv1 = weight_norm(nn.Conv1d(
159
- n_inputs, n_outputs, kernel_size,
160
- stride=stride, padding=padding, dilation=dilation))
225
+ self.conv1 = weight_norm(
226
+ nn.Conv1d(
227
+ n_inputs,
228
+ n_outputs,
229
+ kernel_size,
230
+ stride=stride,
231
+ padding=padding,
232
+ dilation=dilation,
233
+ )
234
+ )
161
235
  self.chomp1 = Chomp1d(padding)
162
- self.relu1 = nn.ReLU()
236
+ self.relu1 = activation()
163
237
  self.dropout1 = nn.Dropout2d(drop_prob)
164
238
 
165
- self.conv2 = weight_norm(nn.Conv1d(
166
- n_outputs, n_outputs, kernel_size,
167
- stride=stride, padding=padding, dilation=dilation))
239
+ self.conv2 = weight_norm(
240
+ nn.Conv1d(
241
+ n_outputs,
242
+ n_outputs,
243
+ kernel_size,
244
+ stride=stride,
245
+ padding=padding,
246
+ dilation=dilation,
247
+ )
248
+ )
168
249
  self.chomp2 = Chomp1d(padding)
169
- self.relu2 = nn.ReLU()
250
+ self.relu2 = activation()
170
251
  self.dropout2 = nn.Dropout2d(drop_prob)
171
252
 
172
- self.downsample = (nn.Conv1d(n_inputs, n_outputs, 1)
173
- if n_inputs != n_outputs else None)
174
- self.relu = nn.ReLU()
253
+ self.downsample = (
254
+ nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
255
+ )
256
+ self.relu = activation()
175
257
 
176
258
  init.normal_(self.conv1.weight, 0, 0.01)
177
259
  init.normal_(self.conv2.weight, 0, 0.01)
@@ -189,15 +271,3 @@ class TemporalBlock(nn.Module):
189
271
  out = self.dropout2(out)
190
272
  res = x if self.downsample is None else self.downsample(x)
191
273
  return self.relu(out + res)
192
-
193
-
194
- class Chomp1d(nn.Module):
195
- def __init__(self, chomp_size):
196
- super().__init__()
197
- self.chomp_size = chomp_size
198
-
199
- def extra_repr(self):
200
- return 'chomp_size={}'.format(self.chomp_size)
201
-
202
- def forward(self, x):
203
- return x[:, :, :-self.chomp_size].contiguous()