braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of braindecode has been flagged by the registry as possibly problematic.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
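Two structural changes dominate the file list above: the private helpers braindecode/models/modules.py and braindecode/models/functions.py are deleted in favour of the new public braindecode.modules and braindecode.functional packages, and the deprecated braindecode/datautil shims (mne.py, preprocess.py, windowers.py, xy.py) are removed outright. Downstream code that imported building blocks from the old private module therefore needs an import update. A minimal sketch; only FeedForwardBlock and MultiHeadAttention are confirmed exports in the diff below, and the argument order follows the calls shown there (embedding size, heads, dropout):

from torch import nn

# 1.1.0: transformer building blocks are a public package
# (braindecode.models.modules is gone).
from braindecode.modules import FeedForwardBlock, MultiHeadAttention

attention = MultiHeadAttention(40, 10, 0.5)
feedforward = FeedForwardBlock(40, expansion=4, drop_p=0.5, activation=nn.GELU)

The rest of this page reproduces the rendered diff of braindecode/models/eegconformer.py (+112 -173), which illustrates the release's recurring patterns. Deleted lines whose content the diff viewer truncated appear below as bare "-" markers.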
braindecode/models/eegconformer.py (0.8.1 → 1.1.0):

@@ -1,23 +1,27 @@
 # Authors: Yonghao Song <eeyhsong@gmail.com>
 #
 # License: BSD (3-clause)
+import warnings
+
 import torch
-import torch.nn.functional as F
-from einops import rearrange
 from einops.layers.torch import Rearrange
-from torch import
-import warnings
+from torch import Tensor, nn
 
-from .base import EEGModuleMixin
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import FeedForwardBlock, MultiHeadAttention
 
 
 class EEGConformer(EEGModuleMixin, nn.Module):
-    """EEG Conformer.
+    """EEG Conformer from Song et al. (2022) from [song2022]_.
+
+    .. figure:: https://raw.githubusercontent.com/eeyhsong/EEG-Conformer/refs/heads/main/visualization/Fig1.png
+       :align: center
+       :alt: EEGConformer Architecture
 
     Convolutional Transformer for EEG decoding.
 
     The paper and original code with more details about the methodological
-    choices are available at the [
+    choices are available at the [song2022]_ and [ConformerCode]_.
 
     This neural network architecture receives a traditional braindecode input.
     The input shape should be three-dimensional matrix representing the EEG
@@ -68,15 +72,15 @@ class EEGConformer(EEGModuleMixin, nn.Module):
     return_features: bool
         If True, the forward method returns the features before the
         last classification layer. Defaults to False.
-
-
-
-
-
-
+    activation: nn.Module
+        Activation function as parameter. Default is nn.ELU
+    activation_transfor: nn.Module
+        Activation function as parameter, applied at the FeedForwardBlock module
+        inside the transformer. Default is nn.GeLU
+
     References
     ----------
-    .. [
+    .. [song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
        conformer: Convolutional transformer for EEG decoding and visualization.
        IEEE Transactions on Neural Systems and Rehabilitation Engineering,
        31, pp.710-719. https://ieeexplore.ieee.org/document/9991178
@@ -86,34 +90,26 @@ class EEGConformer(EEGModuleMixin, nn.Module):
     """
 
     def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        input_window_samples=None,
-        add_log_softmax=True,
+        self,
+        n_outputs=None,
+        n_chans=None,
+        n_filters_time=40,
+        filter_time_length=25,
+        pool_time_length=75,
+        pool_time_stride=15,
+        drop_prob=0.5,
+        att_depth=6,
+        att_heads=10,
+        att_drop_prob=0.5,
+        final_fc_length="auto",
+        return_features=False,
+        activation: nn.Module = nn.ELU,
+        activation_transfor: nn.Module = nn.GELU,
+        n_times=None,
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=None,
     ):
-        n_outputs, n_chans, n_times = deprecated_args(
-            self,
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-            ('n_channels', 'n_chans', n_channels, n_chans),
-            ('input_window_samples', 'n_times', input_window_samples, n_times)
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -121,19 +117,22 @@ class EEGConformer(EEGModuleMixin, nn.Module):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         self.mapping = {
-
-
+            "classification_head.fc.6.weight": "final_layer.final_layer.0.weight",
+            "classification_head.fc.6.bias": "final_layer.final_layer.0.bias",
         }
 
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del n_classes, n_channels, input_window_samples
         if not (self.n_chans <= 64):
-            warnings.warn(
-
-
+            warnings.warn(
+                "This model has only been tested on no more "
+                + "than 64 channels. no guarantee to work with "
+                + "more channels.",
+                UserWarning,
+            )
+
+        self.return_features = return_features
 
         self.patch_embedding = _PatchEmbedding(
             n_filters_time=n_filters_time,
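The new self.mapping pairs the 0.8.x names of the classification-head parameters with their 1.1.0 locations, presumably so that old checkpoints can be renamed on load. A manual sketch of the same renaming (the checkpoint path is hypothetical):

import torch

# Key-renaming sketch using the mapping from the hunk above.
mapping = {
    "classification_head.fc.6.weight": "final_layer.final_layer.0.weight",
    "classification_head.fc.6.bias": "final_layer.final_layer.0.bias",
}
state = torch.load("eegconformer-0.8.1.ckpt", map_location="cpu")
# Rename any key listed in the mapping; leave the rest untouched.
state = {mapping.get(key, key): tensor for key, tensor in state.items()}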
@@ -141,38 +140,44 @@ class EEGConformer(EEGModuleMixin, nn.Module):
             n_channels=self.n_chans,
             pool_time_length=pool_time_length,
             stride_avg_pool=pool_time_stride,
-            drop_prob=drop_prob
+            drop_prob=drop_prob,
+            activation=activation,
+        )
 
         if final_fc_length == "auto":
             assert self.n_times is not None
-            final_fc_length = self.get_fc_size()
+            self.final_fc_length = self.get_fc_size()
+        else:
+            self.final_fc_length = final_fc_length
 
         self.transformer = _TransformerEncoder(
             att_depth=att_depth,
             emb_size=n_filters_time,
             att_heads=att_heads,
-            att_drop=att_drop_prob
+            att_drop=att_drop_prob,
+            activation=activation_transfor,
+        )
 
         self.fc = _FullyConnected(
-            final_fc_length=final_fc_length
+            final_fc_length=self.final_fc_length, activation=activation
+        )
 
-        self.final_layer =
-            return_features=return_features,
-            add_log_softmax=self.add_log_softmax)
+        self.final_layer = nn.Linear(self.fc.hidden_channels, self.n_outputs)
 
     def forward(self, x: Tensor) -> Tensor:
         x = torch.unsqueeze(x, dim=1)  # add one extra dimension
         x = self.patch_embedding(x)
-
-
+        feature = self.transformer(x)
+
+        if self.return_features:
+            return feature
+
+        x = self.fc(feature)
         x = self.final_layer(x)
         return x
 
     def get_fc_size(self):
-
-        out = self.patch_embedding(torch.ones((1, 1,
-                                               self.n_chans,
-                                               self.n_times)))
+        out = self.patch_embedding(torch.ones((1, 1, self.n_chans, self.n_times)))
         size_embedding_1 = out.cpu().data.numpy().shape[1]
         size_embedding_2 = out.cpu().data.numpy().shape[2]
 
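With the deprecated-alias plumbing and add_log_softmax gone, construction follows the plain keyword signature from the previous hunks: final_fc_length="auto" requires n_times, forward now returns the transformer features directly when return_features=True, and the head is a bare nn.Linear. A usage sketch with illustrative values:

import torch
from braindecode.models import EEGConformer

# n_outputs/n_chans/n_times replace the removed n_classes/n_channels/
# input_window_samples aliases; the values here are illustrative only.
model = EEGConformer(n_outputs=4, n_chans=22, n_times=1000)

x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
scores = model(x)             # (8, 4): raw linear outputs, no LogSoftmax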
@@ -207,26 +212,24 @@ class _PatchEmbedding(nn.Module):
     """
 
     def __init__(
-
-
-
-
-
-
-
+        self,
+        n_filters_time,
+        filter_time_length,
+        n_channels,
+        pool_time_length,
+        stride_avg_pool,
+        drop_prob,
+        activation: nn.Module = nn.ELU,
     ):
         super().__init__()
 
         self.shallownet = nn.Sequential(
-            nn.Conv2d(1, n_filters_time,
-
-            nn.Conv2d(n_filters_time, n_filters_time,
-                      (n_channels, 1), (1, 1)),
+            nn.Conv2d(1, n_filters_time, (1, filter_time_length), (1, 1)),
+            nn.Conv2d(n_filters_time, n_filters_time, (n_channels, 1), (1, 1)),
             nn.BatchNorm2d(num_features=n_filters_time),
-
+            activation(),
             nn.AvgPool2d(
-                kernel_size=(1, pool_time_length),
-                stride=(1, stride_avg_pool)
+                kernel_size=(1, pool_time_length), stride=(1, stride_avg_pool)
             ),
             # pooling acts as slicing to obtain 'patch' along the
             # time dimension as in ViT
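Apart from taking its activation as a parameter, the shallownet is unchanged, and its output is what get_fc_size measures. A shape walk-through of the three layers shown above, using the 1.1.0 defaults on a hypothetical 22-channel, 1000-sample window:

import torch
from torch import nn

x = torch.randn(1, 1, 22, 1000)
x = nn.Conv2d(1, 40, (1, 25), (1, 1))(x)   # temporal conv   -> (1, 40, 22, 976)
x = nn.Conv2d(40, 40, (22, 1), (1, 1))(x)  # spatial conv    -> (1, 40, 1, 976)
x = nn.AvgPool2d((1, 75), (1, 15))(x)      # 'patch' slicing -> (1, 40, 1, 61)
print(x.shape)  # 61 time patches with 40 filters each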
@@ -246,79 +249,43 @@ class _PatchEmbedding(nn.Module):
         return x
 
 
-class _MultiHeadAttention(nn.Module):
-    def __init__(self, emb_size, num_heads, dropout):
-        super().__init__()
-        self.emb_size = emb_size
-        self.num_heads = num_heads
-        self.keys = nn.Linear(emb_size, emb_size)
-        self.queries = nn.Linear(emb_size, emb_size)
-        self.values = nn.Linear(emb_size, emb_size)
-        self.att_drop = nn.Dropout(dropout)
-        self.projection = nn.Linear(emb_size, emb_size)
-
-    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
-        queries = rearrange(
-            self.queries(x), "b n (h d) -> b h n d", h=self.num_heads
-        )
-        keys = rearrange(
-            self.keys(x), "b n (h d) -> b h n d", h=self.num_heads
-        )
-        values = rearrange(
-            self.values(x), "b n (h d) -> b h n d", h=self.num_heads
-        )
-        energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
-        if mask is not None:
-            fill_value = torch.finfo(torch.float32).min
-            energy.mask_fill(~mask, fill_value)
-
-        scaling = self.emb_size ** (1 / 2)
-        att = F.softmax(energy / scaling, dim=-1)
-        att = self.att_drop(att)
-        out = torch.einsum("bhal, bhlv -> bhav ", att, values)
-        out = rearrange(out, "b h n d -> b n (h d)")
-        out = self.projection(out)
-        return out
-
-
 class _ResidualAdd(nn.Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn
 
-    def forward(self, x
+    def forward(self, x):
         res = x
-        x = self.fn(x
+        x = self.fn(x)
         x += res
         return x
 
 
-class _FeedForwardBlock(nn.Sequential):
-    def __init__(self, emb_size, expansion, drop_p):
-        super().__init__(
-            nn.Linear(emb_size, expansion * emb_size),
-            nn.GELU(),
-            nn.Dropout(drop_p),
-            nn.Linear(expansion * emb_size, emb_size),
-        )
-
-
 class _TransformerEncoderBlock(nn.Sequential):
-    def __init__(
+    def __init__(
+        self,
+        emb_size,
+        att_heads,
+        att_drop,
+        forward_expansion=4,
+        activation: nn.Module = nn.GELU,
+    ):
         super().__init__(
             _ResidualAdd(
                 nn.Sequential(
                     nn.LayerNorm(emb_size),
-
+                    MultiHeadAttention(emb_size, att_heads, att_drop),
                     nn.Dropout(att_drop),
                 )
             ),
             _ResidualAdd(
                 nn.Sequential(
                     nn.LayerNorm(emb_size),
-
-                    emb_size,
-
+                    FeedForwardBlock(
+                        emb_size,
+                        expansion=forward_expansion,
+                        drop_p=att_drop,
+                        activation=activation,
                     ),
                     nn.Dropout(att_drop),
                 )
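This hunk deletes the file-local _MultiHeadAttention and _FeedForwardBlock in favour of the shared braindecode.modules implementations, which also retires a latent bug: the old masking branch called energy.mask_fill(...), but the Tensor method is masked_fill, so passing a mask would have raised an AttributeError. The deleted forward pass was ordinary scaled dot-product self-attention; a sketch of the equivalent computation with stock PyTorch, shapes assumed to be (batch, tokens, embedding):

import torch
from torch import nn

emb_size, att_heads, att_drop = 40, 10, 0.5
mha = nn.MultiheadAttention(emb_size, att_heads, dropout=att_drop, batch_first=True)

x = torch.randn(2, 61, emb_size)
out, _ = mha(x, x, x)  # self-attention: queries = keys = values = x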
@@ -344,19 +311,29 @@ class _TransformerEncoder(nn.Sequential):
 
     """
 
-    def __init__(
+    def __init__(
+        self, att_depth, emb_size, att_heads, att_drop, activation: nn.Module = nn.GELU
+    ):
         super().__init__(
             *[
-                _TransformerEncoderBlock(
+                _TransformerEncoderBlock(
+                    emb_size, att_heads, att_drop, activation=activation
+                )
                 for _ in range(att_depth)
             ]
         )
 
 
 class _FullyConnected(nn.Module):
-    def __init__(
-
-
+    def __init__(
+        self,
+        final_fc_length,
+        drop_prob_1=0.5,
+        drop_prob_2=0.3,
+        out_channels=256,
+        hidden_channels=32,
+        activation: nn.Module = nn.ELU,
+    ):
         """Fully-connected layer for the transformer encoder.
 
         Parameters
@@ -375,17 +352,16 @@ class _FullyConnected(nn.Module):
             Number of output channels for the second linear layer.
         return_features : bool
             Whether to return input features.
-        add_log_softmax: bool
-            Whether to add LogSoftmax non-linearity as the final layer.
         """
 
         super().__init__()
+        self.hidden_channels = hidden_channels
         self.fc = nn.Sequential(
             nn.Linear(final_fc_length, out_channels),
-
+            activation(),
             nn.Dropout(drop_prob_1),
             nn.Linear(out_channels, hidden_channels),
-
+            activation(),
             nn.Dropout(drop_prob_2),
         )
 
@@ -393,40 +369,3 @@ class _FullyConnected(nn.Module):
         x = x.contiguous().view(x.size(0), -1)
         out = self.fc(x)
         return out
-
-
-class _FinalLayer(nn.Module):
-    def __init__(self, n_classes, hidden_channels=32, return_features=False, add_log_softmax=True):
-        """Classification head for the transformer encoder.
-
-        Parameters
-        ----------
-        n_classes : int
-            Number of classes for classification.
-        hidden_channels : int
-            Number of output channels for the second linear layer.
-        return_features : bool
-            Whether to return input features.
-        add_log_softmax : bool
-            Adding LogSoftmax or not.
-        """
-
-        super().__init__()
-        self.final_layer = nn.Sequential(
-            nn.Linear(hidden_channels, n_classes),
-        )
-        self.return_features = return_features
-        if add_log_softmax:
-            classification = nn.LogSoftmax(dim=1)
-        else:
-            classification = nn.Identity()
-        if not self.return_features:
-            self.final_layer.add_module("classification", classification)
-
-    def forward(self, x):
-        if self.return_features:
-            out = self.final_layer(x)
-            return out, x
-        else:
-            out = self.final_layer(x)
-            return out
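Deleting _FinalLayer completes the add_log_softmax removal: in 0.8.1 the default head ended in nn.LogSoftmax(dim=1), so training code typically paired the model with a negative-log-likelihood loss, whereas the 1.1.0 head emits unnormalised scores. A migration sketch:

import torch
from torch import nn

scores = torch.randn(8, 4)           # 1.1.0 model output: raw scores
targets = torch.randint(0, 4, (8,))

# 0.8.1 default (add_log_softmax=True): outputs were log-probabilities,
# matching nn.NLLLoss. In 1.1.0, pair the raw scores with cross-entropy:
loss = nn.CrossEntropyLoss()(scores, targets)

# If log-probabilities are still needed, apply log_softmax explicitly:
log_probs = torch.log_softmax(scores, dim=1)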