braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
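Most of the churn above is a reorganization rather than a rewrite: shared building blocks move out of braindecode/models/modules.py and braindecode/models/functions.py (both deleted) into the new top-level braindecode/modules/ and braindecode/functional/ packages, and the deprecated braindecode/datautil shims (mne.py, preprocess.py, windowers.py, xy.py) are removed outright. A minimal migration sketch for downstream code, assuming the public names visible in the eegconformer.py diff below (MultiHeadAttention, FeedForwardBlock) and the existing braindecode.preprocessing entry points:

    # 0.8.1: these modules existed but are deleted in 1.0.0
    # from braindecode.models.modules import ...       # -> braindecode.modules
    # from braindecode.models.functions import ...     # -> braindecode.functional
    # from braindecode.datautil.preprocess import ...  # shim removed

    # 1.0.0: import the shared blocks from the new packages
    from braindecode.modules import FeedForwardBlock, MultiHeadAttention
    from braindecode.preprocessing import Preprocessor, preprocess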
braindecode/models/eegconformer.py

@@ -1,23 +1,30 @@
 # Authors: Yonghao Song <eeyhsong@gmail.com>
 #
 # License: BSD (3-clause)
+import warnings
+from typing import Optional
+
 import torch
 import torch.nn.functional as F
 from einops import rearrange
 from einops.layers.torch import Rearrange
-from torch import
-import warnings
+from torch import Tensor, nn

-from .base import EEGModuleMixin
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import FeedForwardBlock, MultiHeadAttention


 class EEGConformer(EEGModuleMixin, nn.Module):
-    """EEG Conformer.
+    """EEG Conformer from Song et al. (2022) from [song2022]_.
+
+    .. figure:: https://raw.githubusercontent.com/eeyhsong/EEG-Conformer/refs/heads/main/visualization/Fig1.png
+       :align: center
+       :alt: EEGConformer Architecture

     Convolutional Transformer for EEG decoding.

     The paper and original code with more details about the methodological
-    choices are available at the [
+    choices are available at the [song2022]_ and [ConformerCode]_.

     This neural network architecture receives a traditional braindecode input.
     The input shape should be three-dimensional matrix representing the EEG
@@ -68,15 +75,15 @@ class EEGConformer(EEGModuleMixin, nn.Module):
     return_features: bool
         If True, the forward method returns the features before the
         last classification layer. Defaults to False.
-
-
-
-
-
-
+    activation: nn.Module
+        Activation function as parameter. Default is nn.ELU
+    activation_transfor: nn.Module
+        Activation function as parameter, applied at the FeedForwardBlock module
+        inside the transformer. Default is nn.GeLU
+
     References
     ----------
-    .. [
+    .. [song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
        conformer: Convolutional transformer for EEG decoding and visualization.
        IEEE Transactions on Neural Systems and Rehabilitation Engineering,
        31, pp.710-719. https://ieeexplore.ieee.org/document/9991178
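The two parameters documented in this hunk take an activation class rather than an instance; as the later hunks show, the network instantiates it itself via activation(). A usage sketch under that assumption (the misspelled activation_transfor is the parameter's actual name in the released code):

    from torch import nn
    from braindecode.models import EEGConformer

    model = EEGConformer(
        n_outputs=4,
        n_chans=22,
        n_times=1000,
        activation=nn.ELU,            # patch embedding and fc head (default)
        activation_transfor=nn.GELU,  # FeedForwardBlock in the transformer (default)
    )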
@@ -86,34 +93,26 @@ class EEGConformer(EEGModuleMixin, nn.Module):
     """

     def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        input_window_samples=None,
-        add_log_softmax=True,
+        self,
+        n_outputs=None,
+        n_chans=None,
+        n_filters_time=40,
+        filter_time_length=25,
+        pool_time_length=75,
+        pool_time_stride=15,
+        drop_prob=0.5,
+        att_depth=6,
+        att_heads=10,
+        att_drop_prob=0.5,
+        final_fc_length="auto",
+        return_features=False,
+        activation: nn.Module = nn.ELU,
+        activation_transfor: nn.Module = nn.GELU,
+        n_times=None,
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=None,
     ):
-        n_outputs, n_chans, n_times = deprecated_args(
-            self,
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-            ('n_channels', 'n_chans', n_channels, n_chans),
-            ('input_window_samples', 'n_times', input_window_samples, n_times)
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -121,19 +120,22 @@ class EEGConformer(EEGModuleMixin, nn.Module):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         self.mapping = {
-
-
+            "classification_head.fc.6.weight": "final_layer.final_layer.0.weight",
+            "classification_head.fc.6.bias": "final_layer.final_layer.0.bias",
         }

         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del n_classes, n_channels, input_window_samples
         if not (self.n_chans <= 64):
-            warnings.warn(
-
-
+            warnings.warn(
+                "This model has only been tested on no more "
+                + "than 64 channels. no guarantee to work with "
+                + "more channels.",
+                UserWarning,
+            )
+
+        self.return_features = return_features

         self.patch_embedding = _PatchEmbedding(
             n_filters_time=n_filters_time,
@@ -141,38 +143,42 @@ class EEGConformer(EEGModuleMixin, nn.Module):
             n_channels=self.n_chans,
             pool_time_length=pool_time_length,
             stride_avg_pool=pool_time_stride,
-            drop_prob=drop_prob
+            drop_prob=drop_prob,
+            activation=activation,
+        )

         if final_fc_length == "auto":
             assert self.n_times is not None
-            final_fc_length = self.get_fc_size()
+            self.final_fc_length = self.get_fc_size()

         self.transformer = _TransformerEncoder(
             att_depth=att_depth,
             emb_size=n_filters_time,
             att_heads=att_heads,
-            att_drop=att_drop_prob
+            att_drop=att_drop_prob,
+            activation=activation_transfor,
+        )

         self.fc = _FullyConnected(
-            final_fc_length=final_fc_length
+            final_fc_length=self.final_fc_length, activation=activation
+        )

-        self.final_layer =
-            return_features=return_features,
-            add_log_softmax=self.add_log_softmax)
+        self.final_layer = nn.Linear(self.fc.hidden_channels, self.n_outputs)

     def forward(self, x: Tensor) -> Tensor:
         x = torch.unsqueeze(x, dim=1)  # add one extra dimension
         x = self.patch_embedding(x)
-
-
+        feature = self.transformer(x)
+
+        if self.return_features:
+            return feature
+
+        x = self.fc(feature)
         x = self.final_layer(x)
         return x

     def get_fc_size(self):
-
-        out = self.patch_embedding(torch.ones((1, 1,
-                                               self.n_chans,
-                                               self.n_times)))
+        out = self.patch_embedding(torch.ones((1, 1, self.n_chans, self.n_times)))
         size_embedding_1 = out.cpu().data.numpy().shape[1]
         size_embedding_2 = out.cpu().data.numpy().shape[2]

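With the rewritten forward() above, return_features now short-circuits right after the transformer instead of being handled by the removed _FinalLayer. A sketch of both code paths, assuming the constructor from the earlier hunk:

    import torch
    from braindecode.models import EEGConformer

    x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)

    clf = EEGConformer(n_outputs=4, n_chans=22, n_times=1000)
    logits = clf(x)  # fc head + final nn.Linear -> shape (8, 4)

    enc = EEGConformer(n_outputs=4, n_chans=22, n_times=1000, return_features=True)
    feats = enc(x)   # transformer output, before _FullyConnected and final_layer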
@@ -207,26 +213,24 @@ class _PatchEmbedding(nn.Module):
     """

     def __init__(
-
-
-
-
-
-
-
+        self,
+        n_filters_time,
+        filter_time_length,
+        n_channels,
+        pool_time_length,
+        stride_avg_pool,
+        drop_prob,
+        activation: nn.Module = nn.ELU,
     ):
         super().__init__()

         self.shallownet = nn.Sequential(
-            nn.Conv2d(1, n_filters_time,
-
-            nn.Conv2d(n_filters_time, n_filters_time,
-                      (n_channels, 1), (1, 1)),
+            nn.Conv2d(1, n_filters_time, (1, filter_time_length), (1, 1)),
+            nn.Conv2d(n_filters_time, n_filters_time, (n_channels, 1), (1, 1)),
             nn.BatchNorm2d(num_features=n_filters_time),
-
+            activation(),
             nn.AvgPool2d(
-                kernel_size=(1, pool_time_length),
-                stride=(1, stride_avg_pool)
+                kernel_size=(1, pool_time_length), stride=(1, stride_avg_pool)
             ),
             # pooling acts as slicing to obtain 'patch' along the
             # time dimension as in ViT
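The two stacked convolutions and the average pooling above determine how many transformer tokens a recording yields. A quick shape computation under the default hyperparameters (assuming nn.Conv2d's default "valid"-style padding and stride 1):

    # Token count produced by _PatchEmbedding for the defaults:
    n_times, filter_time_length = 1000, 25
    pool_time_length, pool_time_stride = 75, 15
    after_conv = n_times - filter_time_length + 1                       # 976
    n_tokens = (after_conv - pool_time_length) // pool_time_stride + 1  # 61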
@@ -246,79 +250,43 @@ class _PatchEmbedding(nn.Module):
         return x


-class _MultiHeadAttention(nn.Module):
-    def __init__(self, emb_size, num_heads, dropout):
-        super().__init__()
-        self.emb_size = emb_size
-        self.num_heads = num_heads
-        self.keys = nn.Linear(emb_size, emb_size)
-        self.queries = nn.Linear(emb_size, emb_size)
-        self.values = nn.Linear(emb_size, emb_size)
-        self.att_drop = nn.Dropout(dropout)
-        self.projection = nn.Linear(emb_size, emb_size)
-
-    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
-        queries = rearrange(
-            self.queries(x), "b n (h d) -> b h n d", h=self.num_heads
-        )
-        keys = rearrange(
-            self.keys(x), "b n (h d) -> b h n d", h=self.num_heads
-        )
-        values = rearrange(
-            self.values(x), "b n (h d) -> b h n d", h=self.num_heads
-        )
-        energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
-        if mask is not None:
-            fill_value = torch.finfo(torch.float32).min
-            energy.mask_fill(~mask, fill_value)
-
-        scaling = self.emb_size ** (1 / 2)
-        att = F.softmax(energy / scaling, dim=-1)
-        att = self.att_drop(att)
-        out = torch.einsum("bhal, bhlv -> bhav ", att, values)
-        out = rearrange(out, "b h n d -> b n (h d)")
-        out = self.projection(out)
-        return out
-
-
 class _ResidualAdd(nn.Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn

-    def forward(self, x
+    def forward(self, x):
         res = x
-        x = self.fn(x
+        x = self.fn(x)
         x += res
         return x


-class _FeedForwardBlock(nn.Sequential):
-    def __init__(self, emb_size, expansion, drop_p):
-        super().__init__(
-            nn.Linear(emb_size, expansion * emb_size),
-            nn.GELU(),
-            nn.Dropout(drop_p),
-            nn.Linear(expansion * emb_size, emb_size),
-        )
-
-
 class _TransformerEncoderBlock(nn.Sequential):
-    def __init__(
+    def __init__(
+        self,
+        emb_size,
+        att_heads,
+        att_drop,
+        forward_expansion=4,
+        activation: nn.Module = nn.GELU,
+    ):
         super().__init__(
             _ResidualAdd(
                 nn.Sequential(
                     nn.LayerNorm(emb_size),
-
+                    MultiHeadAttention(emb_size, att_heads, att_drop),
                     nn.Dropout(att_drop),
                 )
             ),
             _ResidualAdd(
                 nn.Sequential(
                     nn.LayerNorm(emb_size),
-
-                    emb_size,
-
+                    FeedForwardBlock(
+                        emb_size,
+                        expansion=forward_expansion,
+                        drop_p=att_drop,
+                        activation=activation,
                     ),
                     nn.Dropout(att_drop),
                 )
@@ -344,19 +312,29 @@ class _TransformerEncoder(nn.Sequential):

     """

-    def __init__(
+    def __init__(
+        self, att_depth, emb_size, att_heads, att_drop, activation: nn.Module = nn.GELU
+    ):
         super().__init__(
             *[
-                _TransformerEncoderBlock(
+                _TransformerEncoderBlock(
+                    emb_size, att_heads, att_drop, activation=activation
+                )
                 for _ in range(att_depth)
             ]
         )


 class _FullyConnected(nn.Module):
-    def __init__(
-
-
+    def __init__(
+        self,
+        final_fc_length,
+        drop_prob_1=0.5,
+        drop_prob_2=0.3,
+        out_channels=256,
+        hidden_channels=32,
+        activation: nn.Module = nn.ELU,
+    ):
         """Fully-connected layer for the transformer encoder.

         Parameters
@@ -375,17 +353,16 @@ class _FullyConnected(nn.Module):
             Number of output channels for the second linear layer.
         return_features : bool
             Whether to return input features.
-        add_log_softmax: bool
-            Whether to add LogSoftmax non-linearity as the final layer.
         """

         super().__init__()
+        self.hidden_channels = hidden_channels
         self.fc = nn.Sequential(
             nn.Linear(final_fc_length, out_channels),
-
+            activation(),
             nn.Dropout(drop_prob_1),
             nn.Linear(out_channels, hidden_channels),
-
+            activation(),
             nn.Dropout(drop_prob_2),
         )

@@ -393,40 +370,3 @@ class _FullyConnected(nn.Module):
         x = x.contiguous().view(x.size(0), -1)
         out = self.fc(x)
         return out
-
-
-class _FinalLayer(nn.Module):
-    def __init__(self, n_classes, hidden_channels=32, return_features=False, add_log_softmax=True):
-        """Classification head for the transformer encoder.
-
-        Parameters
-        ----------
-        n_classes : int
-            Number of classes for classification.
-        hidden_channels : int
-            Number of output channels for the second linear layer.
-        return_features : bool
-            Whether to return input features.
-        add_log_softmax : bool
-            Adding LogSoftmax or not.
-        """
-
-        super().__init__()
-        self.final_layer = nn.Sequential(
-            nn.Linear(hidden_channels, n_classes),
-        )
-        self.return_features = return_features
-        if add_log_softmax:
-            classification = nn.LogSoftmax(dim=1)
-        else:
-            classification = nn.Identity()
-        if not self.return_features:
-            self.final_layer.add_module("classification", classification)
-
-    def forward(self, x):
-        if self.return_features:
-            out = self.final_layer(x)
-            return out, x
-        else:
-            out = self.final_layer(x)
-            return out
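With _FinalLayer and add_log_softmax gone, the model now ends in a bare nn.Linear, so 1.0.0 emits raw logits where 0.8.1 (with add_log_softmax=True) emitted log-probabilities. A hedged sketch of adapting 0.8.1-era training code accordingly:

    import torch
    from torch import nn
    from braindecode.models import EEGConformer

    model = EEGConformer(n_outputs=4, n_chans=22, n_times=1000)
    x = torch.randn(8, 22, 1000)
    y = torch.randint(0, 4, (8,))

    logits = model(x)                        # raw scores from the final nn.Linear
    loss = nn.CrossEntropyLoss()(logits, y)  # instead of NLLLoss on log-probs

    # if downstream code still expects 0.8.1-style log-probabilities:
    log_probs = nn.functional.log_softmax(logits, dim=1)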