braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,322 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+
5
+ from einops.layers.torch import Rearrange
6
+ from torch import nn
7
+ from torch.nn import init
8
+
9
+ from braindecode.models.base import EEGModuleMixin
10
+ from braindecode.modules import (
11
+ AvgPool2dWithConv,
12
+ CombinedConv,
13
+ Ensure4d,
14
+ SqueezeFinalOutput,
15
+ )
16
+
17
+
18
+ class Deep4Net(EEGModuleMixin, nn.Sequential):
19
+ """Deep ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.
20
+
21
+ .. figure:: https://onlinelibrary.wiley.com/cms/asset/fc200ccc-d8c4-45b4-8577-56ce4d15999a/hbm23730-fig-0001-m.jpg
22
+ :align: center
23
+ :alt: CTNet Architecture
24
+
25
+ Model described in [Schirrmeister2017]_.
26
+
27
+ Parameters
28
+ ----------
29
+ final_conv_length: int | str
30
+ Length of the final convolution layer.
31
+ If set to "auto", n_times must not be None. Default: "auto".
32
+ n_filters_time: int
33
+ Number of temporal filters.
34
+ n_filters_spat: int
35
+ Number of spatial filters.
36
+ filter_time_length: int
37
+ Length of the temporal filter in layer 1.
38
+ pool_time_length: int
39
+ Length of temporal pooling filter.
40
+ pool_time_stride: int
41
+ Length of stride between temporal pooling filters.
42
+ n_filters_2: int
43
+ Number of temporal filters in layer 2.
44
+ filter_length_2: int
45
+ Length of the temporal filter in layer 2.
46
+ n_filters_3: int
47
+ Number of temporal filters in layer 3.
48
+ filter_length_3: int
49
+ Length of the temporal filter in layer 3.
50
+ n_filters_4: int
51
+ Number of temporal filters in layer 4.
52
+ filter_length_4: int
53
+ Length of the temporal filter in layer 4.
54
+ activation_first_conv_nonlin: nn.Module, default is nn.ELU
55
+ Non-linear activation function to be used after convolution in layer 1.
56
+ first_pool_mode: str
57
+ Pooling mode in layer 1. "max" or "mean".
58
+ first_pool_nonlin: callable
59
+ Non-linear activation function to be used after pooling in layer 1.
60
+ activation_later_conv_nonlin: nn.Module, default is nn.ELU
61
+ Non-linear activation function to be used after convolution in later layers.
62
+ later_pool_mode: str
63
+ Pooling mode in later layers. "max" or "mean".
64
+ later_pool_nonlin: callable
65
+ Non-linear activation function to be used after pooling in later layers.
66
+ drop_prob: float
67
+ Dropout probability.
68
+ split_first_layer: bool
69
+ Split first layer into temporal and spatial layers (True) or just use temporal (False).
70
+ There would be no non-linearity between the split layers.
71
+ batch_norm: bool
72
+ Whether to use batch normalisation.
73
+ batch_norm_alpha: float
74
+ Momentum for BatchNorm2d.
75
+ stride_before_pool: bool
76
+ Stride before pooling.
77
+
78
+
79
+ References
80
+ ----------
81
+ .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
82
+ L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
83
+ & Ball, T. (2017).
84
+ Deep learning with convolutional neural networks for EEG decoding and
85
+ visualization.
86
+ Human Brain Mapping , Aug. 2017.
87
+ Online: http://dx.doi.org/10.1002/hbm.23730
88
+ """
89
+
90
+ def __init__(
91
+ self,
92
+ n_chans=None,
93
+ n_outputs=None,
94
+ n_times=None,
95
+ final_conv_length="auto",
96
+ n_filters_time=25,
97
+ n_filters_spat=25,
98
+ filter_time_length=10,
99
+ pool_time_length=3,
100
+ pool_time_stride=3,
101
+ n_filters_2=50,
102
+ filter_length_2=10,
103
+ n_filters_3=100,
104
+ filter_length_3=10,
105
+ n_filters_4=200,
106
+ filter_length_4=10,
107
+ activation_first_conv_nonlin: nn.Module = nn.ELU,
108
+ first_pool_mode="max",
109
+ first_pool_nonlin: nn.Module = nn.Identity,
110
+ activation_later_conv_nonlin: nn.Module = nn.ELU,
111
+ later_pool_mode="max",
112
+ later_pool_nonlin: nn.Module = nn.Identity,
113
+ drop_prob=0.5,
114
+ split_first_layer=True,
115
+ batch_norm=True,
116
+ batch_norm_alpha=0.1,
117
+ stride_before_pool=False,
118
+ chs_info=None,
119
+ input_window_seconds=None,
120
+ sfreq=None,
121
+ ):
122
+ super().__init__(
123
+ n_outputs=n_outputs,
124
+ n_chans=n_chans,
125
+ chs_info=chs_info,
126
+ n_times=n_times,
127
+ input_window_seconds=input_window_seconds,
128
+ sfreq=sfreq,
129
+ )
130
+ del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
131
+
132
+ if final_conv_length == "auto":
133
+ assert self.n_times is not None
134
+ self.final_conv_length = final_conv_length
135
+ self.n_filters_time = n_filters_time
136
+ self.n_filters_spat = n_filters_spat
137
+ self.filter_time_length = filter_time_length
138
+ self.pool_time_length = pool_time_length
139
+ self.pool_time_stride = pool_time_stride
140
+ self.n_filters_2 = n_filters_2
141
+ self.filter_length_2 = filter_length_2
142
+ self.n_filters_3 = n_filters_3
143
+ self.filter_length_3 = filter_length_3
144
+ self.n_filters_4 = n_filters_4
145
+ self.filter_length_4 = filter_length_4
146
+ self.first_nonlin = activation_first_conv_nonlin
147
+ self.first_pool_mode = first_pool_mode
148
+ self.first_pool_nonlin = first_pool_nonlin
149
+ self.later_conv_nonlin = activation_later_conv_nonlin
150
+ self.later_pool_mode = later_pool_mode
151
+ self.later_pool_nonlin = later_pool_nonlin
152
+ self.drop_prob = drop_prob
153
+ self.split_first_layer = split_first_layer
154
+ self.batch_norm = batch_norm
155
+ self.batch_norm_alpha = batch_norm_alpha
156
+ self.stride_before_pool = stride_before_pool
157
+
158
+ # For the load_state_dict
159
+ # When padronize all layers,
160
+ # add the old's parameters here
161
+ self.mapping = {
162
+ "conv_time.weight": "conv_time_spat.conv_time.weight",
163
+ "conv_spat.weight": "conv_time_spat.conv_spat.weight",
164
+ "conv_time.bias": "conv_time_spat.conv_time.bias",
165
+ "conv_spat.bias": "conv_time_spat.conv_spat.bias",
166
+ "conv_classifier.weight": "final_layer.conv_classifier.weight",
167
+ "conv_classifier.bias": "final_layer.conv_classifier.bias",
168
+ }
169
+
170
+ if self.stride_before_pool:
171
+ conv_stride = self.pool_time_stride
172
+ pool_stride = 1
173
+ else:
174
+ conv_stride = 1
175
+ pool_stride = self.pool_time_stride
176
+ self.add_module("ensuredims", Ensure4d())
177
+ pool_class_dict = dict(max=nn.MaxPool2d, mean=AvgPool2dWithConv)
178
+ first_pool_class = pool_class_dict[self.first_pool_mode]
179
+ later_pool_class = pool_class_dict[self.later_pool_mode]
180
+ if self.split_first_layer:
181
+ self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
182
+ self.add_module(
183
+ "conv_time_spat",
184
+ CombinedConv(
185
+ in_chans=self.n_chans,
186
+ n_filters_time=self.n_filters_time,
187
+ n_filters_spat=self.n_filters_spat,
188
+ filter_time_length=filter_time_length,
189
+ bias_time=True,
190
+ bias_spat=not self.batch_norm,
191
+ ),
192
+ )
193
+ n_filters_conv = self.n_filters_spat
194
+ else:
195
+ self.add_module(
196
+ "conv_time",
197
+ nn.Conv2d(
198
+ self.n_chans,
199
+ self.n_filters_time,
200
+ (self.filter_time_length, 1),
201
+ stride=(conv_stride, 1),
202
+ bias=not self.batch_norm,
203
+ ),
204
+ )
205
+ n_filters_conv = self.n_filters_time
206
+ if self.batch_norm:
207
+ self.add_module(
208
+ "bnorm",
209
+ nn.BatchNorm2d(
210
+ n_filters_conv,
211
+ momentum=self.batch_norm_alpha,
212
+ affine=True,
213
+ eps=1e-5,
214
+ ),
215
+ )
216
+ self.add_module("conv_nonlin", self.first_nonlin())
217
+ self.add_module(
218
+ "pool",
219
+ first_pool_class(
220
+ kernel_size=(self.pool_time_length, 1), stride=(pool_stride, 1)
221
+ ),
222
+ )
223
+ self.add_module("pool_nonlin", self.first_pool_nonlin())
224
+
225
+ def add_conv_pool_block(
226
+ model, n_filters_before, n_filters, filter_length, block_nr
227
+ ):
228
+ suffix = "_{:d}".format(block_nr)
229
+ self.add_module("drop" + suffix, nn.Dropout(p=self.drop_prob))
230
+ self.add_module(
231
+ "conv" + suffix,
232
+ nn.Conv2d(
233
+ n_filters_before,
234
+ n_filters,
235
+ (filter_length, 1),
236
+ stride=(conv_stride, 1),
237
+ bias=not self.batch_norm,
238
+ ),
239
+ )
240
+ if self.batch_norm:
241
+ self.add_module(
242
+ "bnorm" + suffix,
243
+ nn.BatchNorm2d(
244
+ n_filters,
245
+ momentum=self.batch_norm_alpha,
246
+ affine=True,
247
+ eps=1e-5,
248
+ ),
249
+ )
250
+ self.add_module("nonlin" + suffix, self.later_conv_nonlin())
251
+
252
+ self.add_module(
253
+ "pool" + suffix,
254
+ later_pool_class(
255
+ kernel_size=(self.pool_time_length, 1),
256
+ stride=(pool_stride, 1),
257
+ ),
258
+ )
259
+ self.add_module("pool_nonlin" + suffix, self.later_pool_nonlin())
260
+
261
+ add_conv_pool_block(
262
+ self, n_filters_conv, self.n_filters_2, self.filter_length_2, 2
263
+ )
264
+ add_conv_pool_block(
265
+ self, self.n_filters_2, self.n_filters_3, self.filter_length_3, 3
266
+ )
267
+ add_conv_pool_block(
268
+ self, self.n_filters_3, self.n_filters_4, self.filter_length_4, 4
269
+ )
270
+
271
+ # self.add_module('drop_classifier', nn.Dropout(p=self.drop_prob))
272
+ self.eval()
273
+ if self.final_conv_length == "auto":
274
+ self.final_conv_length = self.get_output_shape()[2]
275
+
276
+ # Incorporating classification module and subsequent ones in one final layer
277
+ module = nn.Sequential()
278
+
279
+ module.add_module(
280
+ "conv_classifier",
281
+ nn.Conv2d(
282
+ self.n_filters_4,
283
+ self.n_outputs,
284
+ (self.final_conv_length, 1),
285
+ bias=True,
286
+ ),
287
+ )
288
+
289
+ module.add_module("squeeze", SqueezeFinalOutput())
290
+
291
+ self.add_module("final_layer", module)
292
+
293
+ # Initialization, xavier is same as in our paper...
294
+ # was default from lasagne
295
+ init.xavier_uniform_(self.conv_time_spat.conv_time.weight, gain=1)
296
+ # maybe no bias in case of no split layer and batch norm
297
+ if self.split_first_layer or (not self.batch_norm):
298
+ init.constant_(self.conv_time_spat.conv_time.bias, 0)
299
+ if self.split_first_layer:
300
+ init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
301
+ if not self.batch_norm:
302
+ init.constant_(self.conv_spat.bias, 0)
303
+ if self.batch_norm:
304
+ init.constant_(self.bnorm.weight, 1)
305
+ init.constant_(self.bnorm.bias, 0)
306
+ param_dict = dict(list(self.named_parameters()))
307
+ for block_nr in range(2, 5):
308
+ conv_weight = param_dict["conv_{:d}.weight".format(block_nr)]
309
+ init.xavier_uniform_(conv_weight, gain=1)
310
+ if not self.batch_norm:
311
+ conv_bias = param_dict["conv_{:d}.bias".format(block_nr)]
312
+ init.constant_(conv_bias, 0)
313
+ else:
314
+ bnorm_weight = param_dict["bnorm_{:d}.weight".format(block_nr)]
315
+ bnorm_bias = param_dict["bnorm_{:d}.bias".format(block_nr)]
316
+ init.constant_(bnorm_weight, 1)
317
+ init.constant_(bnorm_bias, 0)
318
+
319
+ init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
320
+ init.constant_(self.final_layer.conv_classifier.bias, 0)
321
+
322
+ self.train()
@@ -0,0 +1,295 @@
1
+ # Authors: Théo Gnassounou <theo.gnassounou@inria.fr>
2
+ #
3
+ # License: BSD (3-clause)
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from braindecode.models.base import EEGModuleMixin
8
+
9
+
10
+ class DeepSleepNet(EEGModuleMixin, nn.Module):
11
+ """Sleep staging architecture from Supratak et al. (2017) [Supratak2017]_.
12
+
13
+ .. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/refs/heads/master/img/deepsleepnet.png
14
+ :align: center
15
+ :alt: DeepSleepNet Architecture
16
+
17
+ Convolutional neural network and bidirectional-Long Short-Term
18
+ for single channels sleep staging described in [Supratak2017]_.
19
+
20
+ Parameters
21
+ ----------
22
+ activation_large: nn.Module, default=nn.ELU
23
+ Activation function class to apply. Should be a PyTorch activation
24
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
25
+ activation_small: nn.Module, default=nn.ReLU
26
+ Activation function class to apply. Should be a PyTorch activation
27
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
28
+ return_feats : bool
29
+ If True, return the features, i.e. the output of the feature extractor
30
+ (before the final linear layer). If False, pass the features through
31
+ the final linear layer.
32
+ drop_prob : float, default=0.5
33
+ The dropout rate for regularization. Values should be between 0 and 1.
34
+
35
+
36
+ References
37
+ ----------
38
+ .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
39
+ DeepSleepNet: A model for automatic sleep stage scoring based
40
+ on raw single-channel EEG. IEEE Transactions on Neural Systems
41
+ and Rehabilitation Engineering, 25(11), 1998-2008.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ n_outputs=5,
47
+ return_feats=False,
48
+ n_chans=None,
49
+ chs_info=None,
50
+ n_times=None,
51
+ input_window_seconds=None,
52
+ sfreq=None,
53
+ activation_large: nn.Module = nn.ELU,
54
+ activation_small: nn.Module = nn.ReLU,
55
+ drop_prob: float = 0.5,
56
+ ):
57
+ super().__init__(
58
+ n_outputs=n_outputs,
59
+ n_chans=n_chans,
60
+ chs_info=chs_info,
61
+ n_times=n_times,
62
+ input_window_seconds=input_window_seconds,
63
+ sfreq=sfreq,
64
+ )
65
+ del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
66
+ self.cnn1 = _SmallCNN(activation=activation_small, drop_prob=drop_prob)
67
+ self.cnn2 = _LargeCNN(activation=activation_large, drop_prob=drop_prob)
68
+ self.dropout = nn.Dropout(0.5)
69
+ self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
70
+ self.fc = nn.Sequential(
71
+ nn.Linear(3072, 1024, bias=False), nn.BatchNorm1d(num_features=1024)
72
+ )
73
+
74
+ self.features_extractor = nn.Identity()
75
+ self.len_last_layer = 1024
76
+ self.return_feats = return_feats
77
+
78
+ # TODO: Add new way to handle return_features == True
79
+ if not return_feats:
80
+ self.final_layer = nn.Linear(1024, self.n_outputs)
81
+ else:
82
+ self.final_layer = nn.Identity()
83
+
84
+ def forward(self, x):
85
+ """Forward pass.
86
+
87
+ Parameters
88
+ ----------
89
+ x: torch.Tensor
90
+ Batch of EEG windows of shape (batch_size, n_channels, n_times).
91
+ """
92
+
93
+ if x.ndim == 3:
94
+ x = x.unsqueeze(1)
95
+
96
+ x1 = self.cnn1(x)
97
+ x1 = x1.flatten(start_dim=1)
98
+
99
+ x2 = self.cnn2(x)
100
+ x2 = x2.flatten(start_dim=1)
101
+
102
+ x = torch.cat((x1, x2), dim=1)
103
+ x = self.dropout(x)
104
+ temp = x.clone()
105
+ temp = self.fc(temp)
106
+ x = x.unsqueeze(1)
107
+ x = self.bilstm(x)
108
+ x = x.squeeze()
109
+ x = torch.add(x, temp)
110
+ x = self.dropout(x)
111
+
112
+ feats = self.features_extractor(x)
113
+
114
+ if self.return_feats:
115
+ return feats
116
+ else:
117
+ return self.final_layer(feats)
118
+
119
+
120
+ class _SmallCNN(nn.Module):
121
+ """
122
+ Smaller filter sizes to learn temporal information.
123
+
124
+ Parameters
125
+ ----------
126
+ activation: nn.Module, default=nn.ReLU
127
+ Activation function class to apply. Should be a PyTorch activation
128
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
129
+ drop_prob : float, default=0.5
130
+ The dropout rate for regularization. Values should be between 0 and 1.
131
+ """
132
+
133
+ def __init__(self, activation: nn.Module = nn.ReLU, drop_prob: float = 0.5):
134
+ super().__init__()
135
+ self.conv1 = nn.Sequential(
136
+ nn.Conv2d(
137
+ in_channels=1,
138
+ out_channels=64,
139
+ kernel_size=(1, 50),
140
+ stride=(1, 6),
141
+ padding=(0, 22),
142
+ bias=False,
143
+ ),
144
+ nn.BatchNorm2d(num_features=64),
145
+ activation(),
146
+ )
147
+ self.pool1 = nn.MaxPool2d(kernel_size=(1, 8), stride=(1, 8), padding=(0, 2))
148
+ self.dropout = nn.Dropout(p=drop_prob)
149
+ self.conv2 = nn.Sequential(
150
+ nn.Conv2d(
151
+ in_channels=64,
152
+ out_channels=128,
153
+ kernel_size=(1, 8),
154
+ stride=1,
155
+ padding="same",
156
+ bias=False,
157
+ ),
158
+ nn.BatchNorm2d(num_features=128),
159
+ activation(),
160
+ )
161
+ self.conv3 = nn.Sequential(
162
+ nn.Conv2d(
163
+ in_channels=128,
164
+ out_channels=128,
165
+ kernel_size=(1, 8),
166
+ stride=1,
167
+ padding="same",
168
+ bias=False,
169
+ ),
170
+ nn.BatchNorm2d(num_features=128),
171
+ activation(),
172
+ )
173
+ self.conv4 = nn.Sequential(
174
+ nn.Conv2d(
175
+ in_channels=128,
176
+ out_channels=128,
177
+ kernel_size=(1, 8),
178
+ stride=1,
179
+ padding="same",
180
+ bias=False,
181
+ ),
182
+ nn.BatchNorm2d(num_features=128),
183
+ activation(),
184
+ )
185
+ self.pool2 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=(0, 1))
186
+
187
+ def forward(self, x):
188
+ x = self.conv1(x)
189
+ x = self.dropout(self.pool1(x))
190
+ x = self.conv2(x)
191
+ x = self.conv3(x)
192
+ x = self.conv4(x)
193
+ x = self.pool2(x)
194
+ return x
195
+
196
+
197
+ class _LargeCNN(nn.Module):
198
+ """
199
+ Larger filter sizes to learn frequency information.
200
+
201
+ Parameters
202
+ ----------
203
+ activation: nn.Module, default=nn.ELU
204
+ Activation function class to apply. Should be a PyTorch activation
205
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
206
+
207
+ """
208
+
209
+ def __init__(self, activation: nn.Module = nn.ELU, drop_prob: float = 0.5):
210
+ super().__init__()
211
+
212
+ self.conv1 = nn.Sequential(
213
+ nn.Conv2d(
214
+ in_channels=1,
215
+ out_channels=64,
216
+ kernel_size=(1, 400),
217
+ stride=(1, 50),
218
+ padding=(0, 175),
219
+ bias=False,
220
+ ),
221
+ nn.BatchNorm2d(num_features=64),
222
+ activation(),
223
+ )
224
+ self.pool1 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))
225
+ self.dropout = nn.Dropout(p=drop_prob)
226
+ self.conv2 = nn.Sequential(
227
+ nn.Conv2d(
228
+ in_channels=64,
229
+ out_channels=128,
230
+ kernel_size=(1, 6),
231
+ stride=1,
232
+ padding="same",
233
+ bias=False,
234
+ ),
235
+ nn.BatchNorm2d(num_features=128),
236
+ activation(),
237
+ )
238
+ self.conv3 = nn.Sequential(
239
+ nn.Conv2d(
240
+ in_channels=128,
241
+ out_channels=128,
242
+ kernel_size=(1, 6),
243
+ stride=1,
244
+ padding="same",
245
+ bias=False,
246
+ ),
247
+ nn.BatchNorm2d(num_features=128),
248
+ activation(),
249
+ )
250
+ self.conv4 = nn.Sequential(
251
+ nn.Conv2d(
252
+ in_channels=128,
253
+ out_channels=128,
254
+ kernel_size=(1, 6),
255
+ stride=1,
256
+ padding="same",
257
+ bias=False,
258
+ ),
259
+ nn.BatchNorm2d(num_features=128),
260
+ activation(),
261
+ )
262
+ self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 1))
263
+
264
+ def forward(self, x):
265
+ x = self.conv1(x)
266
+ x = self.dropout(self.pool1(x))
267
+ x = self.conv2(x)
268
+ x = self.conv3(x)
269
+ x = self.conv4(x)
270
+ x = self.pool2(x)
271
+ return x
272
+
273
+
274
+ class _BiLSTM(nn.Module):
275
+ def __init__(self, input_size, hidden_size, num_layers):
276
+ super(_BiLSTM, self).__init__()
277
+ self.hidden_size = hidden_size
278
+ self.num_layers = num_layers
279
+ self.lstm = nn.LSTM(
280
+ input_size,
281
+ hidden_size,
282
+ num_layers,
283
+ batch_first=True,
284
+ dropout=0.5,
285
+ bidirectional=True,
286
+ )
287
+
288
+ def forward(self, x):
289
+ # set initial hidden and cell states
290
+ h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
291
+ c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
292
+
293
+ # forward propagate LSTM
294
+ out, _ = self.lstm(x, (h0, c0))
295
+ return out