braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,417 @@
# Authors: Théo Gnassounou <theo.gnassounou@inria.fr>
#
# License: BSD (3-clause)
import torch
import torch.nn as nn

from braindecode.models.base import EEGModuleMixin


class DeepSleepNet(EEGModuleMixin, nn.Module):
    r"""DeepSleepNet from Supratak et al (2017) [Supratak2017]_.

    :bdg-success:`Convolution` :bdg-secondary:`Recurrent`

    .. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/master/img/deepsleepnet.png
        :align: center
        :alt: DeepSleepNet Architecture
        :width: 700px

    .. rubric:: Architectural Overview

    DeepSleepNet couples **dual-path convolutional representation learning**
    with **sequence residual learning** via bidirectional LSTMs.

    The network:

    - (i) learns complementary time-frequency features from each 30-s epoch
      using **two parallel CNNs** (small vs. large first-layer filters), then
    - (ii) models **temporal dependencies across epochs** using **two-layer
      BiLSTMs** with a **residual shortcut** from the CNN features, and finally
    - (iii) outputs per-epoch sleep stages. This design encodes both
      epoch-local patterns and longer-range transition rules used by human
      scorers.

    In terms of implementation:

    - (i) :class:`_RepresentationLearning`: two CNNs extract epoch-wise
      features (small-filter path for temporal precision; large-filter path
      for frequency precision);
    - (ii) :class:`_SequenceResidualLearning`: stacked BiLSTMs with peepholes
      and a residual shortcut inject temporal context while preserving CNN
      evidence;
    - (iii) :class:`_Classifier`: linear readout (softmax) for the five sleep
      stages.

    .. rubric:: Macro Components

    - :class:`_RepresentationLearning` **(dual-path CNN → epoch feature)**

      - *Operations.*

        - **Small-filter CNN**, 4 times:

          - :class:`~torch.nn.Conv1d`
          - :class:`~torch.nn.BatchNorm1d`
          - :class:`~torch.nn.ReLU`
          - :class:`~torch.nn.MaxPool1d` after.

          The first conv uses **filter length ≈ Fs/2** and **stride ≈ Fs/16**
          to emphasize the *timing* of graphoelements.

        - **Large-filter CNN**:

          - the same stack, but the first conv uses **filter length ≈ 4·Fs**
            and **stride ≈ Fs/2** to emphasize *frequency* content.

        - Outputs from both paths are **concatenated** into the epoch
          embedding ``a_t``.

      - *Rationale.* Two first-layer scales provide a **learned, dual-scale
        filter bank** that trades temporal vs. frequency precision without
        hand-crafted features.

    - :class:`_SequenceResidualLearning` **(BiLSTM context + residual fusion)**

      - *Operations.*

        - A **two-layer BiLSTM** with **peephole connections** processes the
          sequence of epoch embeddings ``{a_t}`` forward and backward; hidden
          states from both directions are **concatenated**.
        - A **shortcut MLP** (fully connected + :class:`~torch.nn.BatchNorm1d`
          + :class:`~torch.nn.ReLU`) projects ``a_t`` to the BiLSTM output
          dimension and is **added** (residual) to the BiLSTM output at each
          time step.

      - *Role.* Encodes **stage-transition rules** and smooths predictions
        over time while preserving salient CNN features via the residual path.

    - :class:`_Classifier` **(epoch-wise prediction)**

      - *Operations.*

        - :class:`~torch.nn.Linear` to produce per-epoch class probabilities.

    Original training uses two-step optimization: CNN pretraining on
    class-balanced data, then end-to-end fine-tuning with sequential batches.
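
    A class-balanced pretraining loader can be sketched with
    :class:`torch.utils.data.WeightedRandomSampler` (illustrative; the paper's
    exact sampling pipeline is not reproduced here):

    >>> import torch
    >>> labels = torch.tensor([0, 0, 0, 1, 2])    # toy sleep-stage labels
    >>> counts = torch.bincount(labels).float()
    >>> weights = (1.0 / counts)[labels]          # inverse-frequency weights
    >>> sampler = torch.utils.data.WeightedRandomSampler(
    ...     weights, num_samples=len(labels)
    ... )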

    .. rubric:: Convolutional Details

    - **Temporal (where time-domain patterns are learned).**

      Both CNN paths use **1-D temporal convolutions**. The *small-filter*
      path (first kernel ≈ Fs/2, stride ≈ Fs/16) captures *when*
      characteristic transients occur; the *large-filter* path (first kernel
      ≈ 4·Fs, stride ≈ Fs/2) captures *which* frequency components dominate
      over the epoch. Deeper layers use **small kernels** to refine features
      with fewer parameters, interleaved with **max pooling** for
      downsampling.

    - **Spatial (how channels are processed).**

      The original model operates on **single-channel** raw EEG; convolutions
      therefore mix only along time (no spatial convolution across
      electrodes).

    - **Spectral (how frequency information emerges).**

      No explicit Fourier/wavelet transform is used. The **large-filter
      path** serves as a *frequency-sensitive* analyzer, while the
      **small-filter path** remains *time-sensitive*, together functioning as
      a **two-band learned filter bank** at the first layer.
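
    At the 100 Hz sampling rate assumed by this implementation, these rules
    of thumb reduce to the hard-coded first-layer hyperparameters used below
    (a doctest sketch of the arithmetic):

    >>> fs = 100
    >>> fs // 2, fs // 16, 4 * fs, fs // 2  # small kernel/stride, large kernel/stride
    (50, 6, 400, 50)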

    .. rubric:: Attention / Sequential Modules

    - **Type.** **Bidirectional LSTM** (two layers) with **peephole
      connections**; the forward and backward streams are independent and
      concatenated.
    - **Shapes.** For a sequence of ``N`` epochs, the CNN produces
      ``a_t ∈ R^D``; the BiLSTM outputs ``h_t ∈ R^{2H}``; the shortcut MLP
      maps ``a_t → R^{2H}`` to enable **element-wise residual addition**.
    - **Role.** Models **long-range temporal dependencies** (e.g., persisting
      N2 without visible K-complexes/spindles), stabilizing per-epoch
      predictions.
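
    A shape-level sketch of the residual fusion (using this implementation's
    sizes, D = 3072 and 2H = 1024; the names here are illustrative):

    >>> import torch
    >>> a_t = torch.randn(8, 3072)              # epoch embeddings (batch, D)
    >>> shortcut = torch.nn.Linear(3072, 1024)  # projection to the BiLSTM width
    >>> h_t = torch.randn(8, 1024)              # BiLSTM output (batch, 2H)
    >>> (h_t + shortcut(a_t)).shape
    torch.Size([8, 1024])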

    .. rubric:: Additional Mechanisms

    - **Residual shortcut over the sequence encoder.** Adds projected CNN
      features to the BiLSTM outputs, improving gradient flow and retaining
      discriminative content from representation learning.
    - **Two-step training.**

      - (i) **Pretrain** the CNN paths with class-balanced sampling;
      - (ii) **fine-tune** the full network with sequential batches, using a
        **lower LR** for the CNNs and a **higher LR** for the sequence
        encoder.

    - **State handling.** BiLSTM states are **reinitialized per subject** so
      that temporal context does not leak across recordings.
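
    For the fine-tuning step, the two learning rates can be sketched with
    per-parameter groups (the values below are illustrative, not the paper's):

    >>> import torch
    >>> model = DeepSleepNet(n_outputs=5, n_chans=1, n_times=3000, sfreq=100)
    >>> optimizer = torch.optim.Adam([
    ...     {"params": model.cnn1.parameters(), "lr": 1e-6},    # pretrained CNN path
    ...     {"params": model.cnn2.parameters(), "lr": 1e-6},    # pretrained CNN path
    ...     {"params": model.bilstm.parameters(), "lr": 1e-4},  # sequence encoder
    ... ])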

    .. rubric:: Usage and Configuration

    - **Epoch pipeline.** Use **two parallel CNNs** with the first conv sized
      to **Fs/2** (small path) and **4·Fs** (large path), with strides
      **Fs/16** and **Fs/2**, respectively; stack three more conv blocks with
      small kernels, plus **max pooling** in each path. Concatenate the path
      outputs to form epoch embeddings.
    - **Sequence encoder.** Apply a **two-layer BiLSTM (peepholes)** over the
      sequence of embeddings; add a **projection MLP** on the CNN features
      and **sum** it with the BiLSTM outputs (residual). Finish with
      :class:`~torch.nn.Linear` per epoch.
    - **Reference implementation.** See the official repository for a
      faithful implementation and training scripts.

    Parameters
    ----------
    activation_large : nn.Module, default=nn.ELU
        Activation function class to apply in the large-filter CNN path.
        Should be a PyTorch activation module class like ``nn.ReLU`` or
        ``nn.ELU``. Default is ``nn.ELU``.
    activation_small : nn.Module, default=nn.ReLU
        Activation function class to apply in the small-filter CNN path.
        Should be a PyTorch activation module class like ``nn.ReLU`` or
        ``nn.ELU``. Default is ``nn.ReLU``.
    return_feats : bool
        If True, return the features, i.e. the output of the feature
        extractor (before the final linear layer). If False, pass the
        features through the final linear layer.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.

    References
    ----------
    .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
       DeepSleepNet: A model for automatic sleep stage scoring based
       on raw single-channel EEG. IEEE Transactions on Neural Systems
       and Rehabilitation Engineering, 25(11), 1998-2008.
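
    Examples
    --------
    A minimal forward pass (a sketch, assuming single-channel 30-s epochs at
    100 Hz, which matches the hard-coded layer sizes of this implementation):

    >>> import torch
    >>> model = DeepSleepNet(n_outputs=5, n_chans=1, n_times=3000, sfreq=100)
    >>> x = torch.randn(4, 1, 3000)  # (batch, n_chans, n_times)
    >>> model(x).shape
    torch.Size([4, 5])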
    """

    def __init__(
        self,
        n_outputs=5,
        return_feats=False,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        activation_large: type[nn.Module] = nn.ELU,
        activation_small: type[nn.Module] = nn.ReLU,
        drop_prob: float = 0.5,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.cnn1 = _SmallCNN(activation=activation_small, drop_prob=drop_prob)
        self.cnn2 = _LargeCNN(activation=activation_large, drop_prob=drop_prob)
        self.dropout = nn.Dropout(drop_prob)
        self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
        # Shortcut MLP: projects the concatenated CNN features (3072) to the
        # BiLSTM output width (2 * 512 = 1024) for the residual addition.
        self.fc = nn.Sequential(
            nn.Linear(3072, 1024, bias=False), nn.BatchNorm1d(num_features=1024)
        )

        self.features_extractor = nn.Identity()
        self.len_last_layer = 1024
        self.return_feats = return_feats

        # TODO: Add new way to handle return_features == True
        if not return_feats:
            self.final_layer = nn.Linear(1024, self.n_outputs)
        else:
            self.final_layer = nn.Identity()

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).
        """
        if x.ndim == 3:
            x = x.unsqueeze(1)

        # Dual-path representation learning on each epoch.
        x1 = self.cnn1(x)
        x1 = x1.flatten(start_dim=1)

        x2 = self.cnn2(x)
        x2 = x2.flatten(start_dim=1)

        x = torch.cat((x1, x2), dim=1)
        x = self.dropout(x)
        # Residual shortcut: project the CNN features to the BiLSTM width.
        temp = x.clone()
        temp = self.fc(temp)
        # Treat each window as a length-1 sequence for the BiLSTM.
        x = x.unsqueeze(1)
        x = self.bilstm(x)
        x = x.squeeze(1)  # (batch, 1, 2H) -> (batch, 2H); explicit dim keeps batch size 1 safe
        x = torch.add(x, temp)
        x = self.dropout(x)

        feats = self.features_extractor(x)

        if self.return_feats:
            return feats
        else:
            return self.final_layer(feats)


class _SmallCNN(nn.Module):
    r"""Smaller filter sizes to learn temporal information.

    Parameters
    ----------
    activation : nn.Module, default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.
    """

    def __init__(self, activation: type[nn.Module] = nn.ReLU, drop_prob: float = 0.5):
        super().__init__()
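        # First conv: kernel 50 ≈ Fs/2 and stride 6 ≈ Fs/16 at the 100 Hz
        # sampling rate this implementation assumes.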
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=64,
                kernel_size=(1, 50),
                stride=(1, 6),
                padding=(0, 22),
                bias=False,
            ),
            nn.BatchNorm2d(num_features=64),
            activation(),
        )
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 8), stride=(1, 8), padding=(0, 2))
        self.dropout = nn.Dropout(p=drop_prob)
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=(1, 8),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels=128,
                out_channels=128,
                kernel_size=(1, 8),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                in_channels=128,
                out_channels=128,
                kernel_size=(1, 8),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=(0, 1))

    def forward(self, x):
        x = self.conv1(x)
        x = self.dropout(self.pool1(x))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool2(x)
        return x


class _LargeCNN(nn.Module):
    r"""Larger filter sizes to learn frequency information.

    Parameters
    ----------
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.
    """

    def __init__(self, activation: type[nn.Module] = nn.ELU, drop_prob: float = 0.5):
        super().__init__()
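        # First conv: kernel 400 ≈ 4·Fs and stride 50 ≈ Fs/2 at the 100 Hz
        # sampling rate this implementation assumes.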
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=64,
                kernel_size=(1, 400),
                stride=(1, 50),
                padding=(0, 175),
                bias=False,
            ),
            nn.BatchNorm2d(num_features=64),
            activation(),
        )
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))
        self.dropout = nn.Dropout(p=drop_prob)
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=(1, 6),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels=128,
                out_channels=128,
                kernel_size=(1, 6),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                in_channels=128,
                out_channels=128,
                kernel_size=(1, 6),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 1))

    def forward(self, x):
        x = self.conv1(x)
        x = self.dropout(self.pool1(x))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.pool2(x)
        return x


class _BiLSTM(nn.Module):
    """Two-layer bidirectional LSTM over the sequence of epoch embeddings."""

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=0.5,
            bidirectional=True,
        )

    def forward(self, x):
        # Set initial hidden and cell states (2 directions per layer).
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)

        # Forward propagate the LSTM.
        out, _ = self.lstm(x, (h0, c0))
        return out