braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,426 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import OrderedDict
4
+ from math import floor, log2
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from braindecode.models.base import EEGModuleMixin
11
+
12
+
13
class SPARCNet(EEGModuleMixin, nn.Module):
    r"""Seizures, Periodic and Rhythmic pattern Continuum Neural Network (SPaRCNet) from Jing et al (2023) [jing2023]_.

    :bdg-success:`Convolution`

    This is a temporal CNN model for biosignal classification based on the DenseNet
    architecture.

    The model is based on the unofficial implementation [Code2023]_.

    .. versionadded:: 0.9

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors.

    The network is a stem followed by a stack of dense blocks:

    * The stem is ``conv0 -> norm0 -> activation -> pool0``. Its number of
      output channels is the smallest power of two strictly greater than
      ``n_chans``.
    * The stem's ``norm0`` batch normalization is always present. The
      ``batch_norm`` flag only controls the normalization layers inside the
      dense blocks and transition layers.
    * ``floor(log2(n_times // 4))`` dense blocks follow the stem.
    * Each dense block is followed by a transition layer. The transition
      halves the number of channels and average-pools the time axis.
    * The result is average-pooled to length 1, passed through the
      activation, and fed to a linear classification head.

    Because each dense layer concatenates its input with its output along
    the channel axis, the dense-layer convolutions (``*_conv1``,
    ``*_conv2`` arguments) must preserve the temporal length.

    Parameters
    ----------
    block_layers : int, optional
        Number of layers per dense block. Default is 4.
    growth_rate : int, optional
        Growth rate of the DenseNet, i.e. number of channels added by each
        dense layer. Default is 16.
    bottleneck_size : int, optional
        Bottleneck size: the bottleneck convolution of each dense layer has
        ``bottleneck_size * growth_rate`` output channels. Default is 16.
    drop_prob : float, optional
        Dropout rate applied to the new features of each dense layer.
        Default is 0.5.
    conv_bias : bool, optional
        Whether to use bias in convolutional layers. Default is True.
    batch_norm : bool, optional
        Whether to use batch normalization inside the dense blocks and
        transition layers. The stem always uses batch normalization.
        Default is True.
    activation : type[nn.Module], optional
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    kernel_size_conv0 : int, optional
        Kernel size of the stem convolution. Default is 7.
    kernel_size_conv1 : int, optional
        Kernel size of the bottleneck convolution in each dense layer.
        This convolution is applied without padding. Default is 1.
    kernel_size_conv2 : int, optional
        Kernel size of the second convolution in each dense layer.
        Default is 3.
    kernel_size_pool : int, optional
        Kernel size of the stem max-pooling. Default is 3.
    stride_pool : int, optional
        Stride of the stem max-pooling. Default is 2.
    stride_conv0 : int, optional
        Stride of the stem convolution. Default is 2.
    stride_conv1 : int, optional
        Stride of the bottleneck convolution in each dense layer.
        Default is 1.
    stride_conv2 : int, optional
        Stride of the second convolution in each dense layer. Default is 1.
    padding_pool : int, optional
        Padding of the stem max-pooling. Default is 1.
    padding_conv0 : int, optional
        Padding of the stem convolution. Default is 3.
    padding_conv2 : int, optional
        Padding of the second convolution in each dense layer. Default is 1.
    kernel_size_trans : int, optional
        Kernel size of the average pooling in each transition layer.
        Default is 2.
    stride_trans : int, optional
        Stride of the average pooling in each transition layer. Default is 2.

    Raises
    ------
    ValueError
        If ``n_times`` is smaller than 4. In that case ``n_times // 4`` is 0
        and the number of dense blocks is undefined.

    References
    ----------
    .. [jing2023] Jing, J., Ge, W., Hong, S., Fernandes, M. B., Lin, Z.,
       Yang, C., ... & Westover, M. B. (2023). Development of expert-level
       classification of seizures and rhythmic and periodic
       patterns during eeg interpretation. Neurology, 100(17), e1750-e1762.
    .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT
       Biosignal Transformer for Cross-data Learning in the Wild.
       GitHub https://github.com/ycq091044/BIOT (accessed 2024-02-13)

    """

    def __init__(
        self,
        n_chans=None,
        n_times=None,
        n_outputs=None,
        # Neural network parameters
        block_layers: int = 4,
        growth_rate: int = 16,
        bottleneck_size: int = 16,
        drop_prob: float = 0.5,
        conv_bias: bool = True,
        batch_norm: bool = True,
        activation: type[nn.Module] = nn.ELU,
        kernel_size_conv0: int = 7,
        kernel_size_conv1: int = 1,
        kernel_size_conv2: int = 3,
        kernel_size_pool: int = 3,
        stride_pool: int = 2,
        stride_conv0: int = 2,
        stride_conv1: int = 1,
        stride_conv2: int = 1,
        padding_pool: int = 1,
        padding_conv0: int = 3,
        padding_conv2: int = 1,
        kernel_size_trans: int = 2,
        stride_trans: int = 2,
        # EEGModuleMixin parameters
        # (another way to present the same parameters)
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds

        # The number of dense blocks is floor(log2(n_times // 4)). For
        # n_times < 4 this would be log2(0), which fails with an opaque
        # "math domain error", so report the real problem instead.
        if self.n_times < 4:
            raise ValueError(
                "SPARCNet requires n_times >= 4 to build its dense blocks "
                f"(number of blocks is floor(log2(n_times // 4))); "
                f"got n_times={self.n_times}."
            )

        # Stem: the number of output channels is the smallest power of 2
        # strictly greater than the number of input channels.
        out_channels = 2 ** (floor(log2(self.n_chans)) + 1)
        first_conv = OrderedDict(
            [
                (
                    "conv0",
                    nn.Conv1d(
                        in_channels=self.n_chans,
                        out_channels=out_channels,
                        kernel_size=kernel_size_conv0,
                        stride=stride_conv0,
                        padding=padding_conv0,
                        bias=conv_bias,
                    ),
                )
            ]
        )
        # The stem normalization is unconditional, matching the reference
        # implementation; `batch_norm` only affects the dense/transition layers.
        first_conv["norm0"] = nn.BatchNorm1d(out_channels)
        first_conv["act_layer"] = activation()
        first_conv["pool0"] = nn.MaxPool1d(
            kernel_size=kernel_size_pool,
            stride=stride_pool,
            padding=padding_pool,
        )

        self.encoder = nn.Sequential(first_conv)

        n_channels = out_channels

        # Dense blocks, each followed by a transition layer that halves the
        # channel count and average-pools the time axis.
        for n_layer in range(floor(log2(self.n_times // 4))):
            block = _DenseBlock(
                num_layers=block_layers,
                in_channels=n_channels,
                growth_rate=growth_rate,
                bottleneck_size=bottleneck_size,
                drop_prob=drop_prob,
                conv_bias=conv_bias,
                batch_norm=batch_norm,
                activation=activation,
                kernel_size_conv1=kernel_size_conv1,
                kernel_size_conv2=kernel_size_conv2,
                stride_conv1=stride_conv1,
                stride_conv2=stride_conv2,
                padding_conv2=padding_conv2,
            )
            self.encoder.add_module("denseblock%d" % (n_layer + 1), block)
            # each of the block's layers appends `growth_rate` channels
            n_channels = n_channels + block_layers * growth_rate

            trans = _TransitionLayer(
                in_channels=n_channels,
                out_channels=n_channels // 2,
                conv_bias=conv_bias,
                batch_norm=batch_norm,
                activation=activation,
                kernel_size_trans=kernel_size_trans,
                stride_trans=stride_trans,
            )
            self.encoder.add_module("transition%d" % (n_layer + 1), trans)
            # the transition layer halves the number of channels
            n_channels = n_channels // 2

        self.adaptative_pool = nn.AdaptiveAvgPool1d(1)
        self.activation_layer = activation()
        self.flatten_layer = nn.Flatten()

        # linear classification head
        self.final_layer = nn.Linear(n_channels, self.n_outputs)

        self._init_weights()

    def _init_weights(self):
        """Initialize the model weights in place.

        Follows the reference ("official torch repo") initialization:

        * ``Conv1d`` weights use ``kaiming_normal_`` with PyTorch's default
          arguments. Convolution biases are left untouched.
        * ``BatchNorm1d`` layers get weight 1 and bias 0.
        * ``Linear`` layers get a zero bias. Their weights keep PyTorch's
          default initialization.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.zeros_(module.bias)

    def forward(self, X: torch.Tensor):
        """
        Forward pass of the model.

        Parameters
        ----------
        X : torch.Tensor
            The input tensor of the model with shape (batch_size, n_channels, n_times)

        Returns
        -------
        torch.Tensor
            The output tensor of the model with shape (batch_size, n_outputs)
        """
        emb = self.encoder(X)
        # collapse the remaining time axis to length 1
        emb = self.adaptative_pool(emb)
        emb = self.activation_layer(emb)
        emb = self.flatten_layer(emb)
        out = self.final_layer(emb)
        return out
213
+
214
+
215
class _DenseLayer(nn.Sequential):
    r"""
    Single DenseNet layer whose new features are concatenated to its input.

    The registered sub-modules run in this order:
    ``[norm1] -> elu1 -> conv1 -> [norm2] -> elu2 -> conv2``.
    The ``norm*`` modules exist only when ``batch_norm`` is True. The
    ``elu*`` names are kept for state-dict compatibility whatever
    ``activation`` is.

    Dropout (functional, active only in training mode) is then applied to
    the new features. The layer returns ``cat([x, new_features], dim=1)``,
    which has ``in_channels + growth_rate`` channels.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    growth_rate : int
        Number of channels this layer adds to its input.
    bottleneck_size : int
        Multiplicative factor for the bottleneck layer (does not affect the output size):
        ``conv1`` produces ``bottleneck_size * growth_rate`` channels.
    drop_prob : float, optional
        Dropout rate. Default is 0.5.
    conv_bias : bool, optional
        Whether to use bias in convolutional layers. Default is True.
    batch_norm : bool, optional
        Whether to use batch normalization. Default is True.
    activation : type[nn.Module], optional
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    kernel_size_conv1 : int, optional
        Kernel size of the (unpadded) bottleneck convolution. Default is 1.
    kernel_size_conv2 : int, optional
        Kernel size of the second convolution. Default is 3.
    stride_conv1 : int, optional
        Stride of the bottleneck convolution. Default is 1.
    stride_conv2 : int, optional
        Stride of the second convolution. Default is 1.
    padding_conv2 : int, optional
        Padding of the second convolution. Default is 1.

    Notes
    -----
    The concatenation requires the new features to keep the temporal
    length of ``x``. Kernel/stride/padding values that change the length
    make ``torch.cat`` fail.

    Examples
    --------
    >>> x = torch.randn(128, 5, 1000)
    >>> batch, channels, length = x.shape
    >>> model = _DenseLayer(channels, 5, 2)
    >>> y = model(x)
    >>> y.shape
    torch.Size([128, 10, 1000])
    """

    def __init__(
        self,
        in_channels: int,
        growth_rate: int,
        bottleneck_size: int,
        drop_prob: float = 0.5,
        conv_bias: bool = True,
        batch_norm: bool = True,
        activation: type[nn.Module] = nn.ELU,
        kernel_size_conv1: int = 1,
        kernel_size_conv2: int = 3,
        stride_conv1: int = 1,
        stride_conv2: int = 1,
        padding_conv2: int = 1,
    ):
        super().__init__()
        hidden_channels = bottleneck_size * growth_rate

        # Stage modules first, then register them in the same order; the
        # construction order (and thus RNG consumption) is unchanged.
        stages = []
        if batch_norm:
            stages.append(("norm1", nn.BatchNorm1d(in_channels)))
        stages.append(("elu1", activation()))
        stages.append(
            (
                "conv1",
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=hidden_channels,
                    kernel_size=kernel_size_conv1,
                    stride=stride_conv1,
                    bias=conv_bias,
                ),
            )
        )
        if batch_norm:
            stages.append(("norm2", nn.BatchNorm1d(hidden_channels)))
        stages.append(("elu2", activation()))
        stages.append(
            (
                "conv2",
                nn.Conv1d(
                    in_channels=hidden_channels,
                    out_channels=growth_rate,
                    kernel_size=kernel_size_conv2,
                    stride=stride_conv2,
                    padding=padding_conv2,
                    bias=conv_bias,
                ),
            )
        )
        for stage_name, stage_module in stages:
            self.add_module(stage_name, stage_module)

        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Sequential.forward chains the registered sub-modules in order.
        new_features = super().forward(x)
        new_features = F.dropout(
            new_features, p=self.drop_prob, training=self.training
        )
        # dense connectivity: keep the input and append the new features
        return torch.cat([x, new_features], dim=1)
302
+
303
+
304
class _DenseBlock(nn.Sequential):
    r"""
    Stack of ``num_layers`` dense layers, each seeing all previous features.

    Layer ``i`` (1-based, registered as ``denselayer{i}``) receives
    ``in_channels + (i - 1) * growth_rate`` channels. Each layer appends
    ``growth_rate`` channels, so the block outputs
    ``in_channels + num_layers * growth_rate`` channels.

    Parameters
    ----------
    num_layers : int
        Number of layers in this block.
    in_channels : int
        Number of input channels.
    growth_rate : int
        Number of channels added by each layer.
    bottleneck_size : int
        Multiplicative factor for the bottleneck layer (does not affect the output size).
    drop_prob : float, optional
        Dropout rate. Default is 0.5.
    conv_bias : bool, optional
        Whether to use bias in convolutional layers. Default is True.
    batch_norm : bool, optional
        Whether to use batch normalization. Default is True.
    activation : type[nn.Module], optional
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    kernel_size_conv1, kernel_size_conv2 : int, optional
        Kernel sizes forwarded to every dense layer. Defaults are 1 and 3.
    stride_conv1, stride_conv2 : int, optional
        Strides forwarded to every dense layer. Both default to 1.
    padding_conv2 : int, optional
        Padding of the second convolution of every dense layer. Default is 1.

    Examples
    --------
    >>> x = torch.randn(128, 5, 1000)
    >>> batch, channels, length = x.shape
    >>> model = _DenseBlock(3, channels, 5, 2)
    >>> y = model(x)
    >>> y.shape
    torch.Size([128, 20, 1000])
    """

    def __init__(
        self,
        num_layers,
        in_channels,
        growth_rate,
        bottleneck_size,
        drop_prob=0.5,
        conv_bias=True,
        batch_norm=True,
        activation: type[nn.Module] = nn.ELU,
        kernel_size_conv1: int = 1,
        kernel_size_conv2: int = 3,
        stride_conv1: int = 1,
        stride_conv2: int = 1,
        padding_conv2: int = 1,
    ):
        super().__init__()
        # Settings shared by every layer of the block.
        layer_kwargs = dict(
            growth_rate=growth_rate,
            bottleneck_size=bottleneck_size,
            drop_prob=drop_prob,
            conv_bias=conv_bias,
            batch_norm=batch_norm,
            activation=activation,
            kernel_size_conv1=kernel_size_conv1,
            kernel_size_conv2=kernel_size_conv2,
            stride_conv1=stride_conv1,
            stride_conv2=stride_conv2,
            padding_conv2=padding_conv2,
        )
        channels_so_far = in_channels
        for position in range(1, num_layers + 1):
            self.add_module(
                f"denselayer{position}",
                _DenseLayer(in_channels=channels_so_far, **layer_kwargs),
            )
            # the next layer also sees this layer's `growth_rate` new channels
            channels_so_far += growth_rate
371
+
372
+
373
class _TransitionLayer(nn.Sequential):
    r"""
    Transition between dense blocks: channel projection plus temporal pooling.

    The registered sub-modules run in this order:
    ``[norm] -> elu -> conv -> pool``.

    * ``norm`` exists only when ``batch_norm`` is True.
    * ``elu`` holds whatever ``activation`` class is given.
    * ``conv`` is a 1x1 convolution with stride 1 that maps ``in_channels``
      to ``out_channels``.
    * ``pool`` is an average pooling over time.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    conv_bias : bool, optional
        Whether to use bias in convolutional layers. Default is True.
    batch_norm : bool, optional
        Whether to use batch normalization. Default is True.
    activation : type[nn.Module], optional
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    kernel_size_trans : int, optional
        Kernel size of the average pooling. Default is 2.
    stride_trans : int, optional
        Stride of the average pooling. Default is 2.

    Examples
    --------
    >>> x = torch.randn(128, 5, 1000)
    >>> model = _TransitionLayer(5, 18)
    >>> y = model(x)
    >>> y.shape
    torch.Size([128, 18, 500])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        conv_bias=True,
        batch_norm=True,
        activation: type[nn.Module] = nn.ELU,
        kernel_size_trans: int = 2,
        stride_trans: int = 2,
    ):
        super().__init__()
        # Collect (name, module) pairs in execution order, then register.
        pipeline = [("norm", nn.BatchNorm1d(in_channels))] if batch_norm else []
        pipeline += [
            ("elu", activation()),
            (
                "conv",
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=1,
                    bias=conv_bias,
                ),
            ),
            ("pool", nn.AvgPool1d(kernel_size=kernel_size_trans, stride=stride_trans)),
        ]
        for step_name, step_module in pipeline:
            self.add_module(step_name, step_module)