braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -4,11 +4,133 @@
|
|
|
4
4
|
import torch
|
|
5
5
|
import torch.nn as nn
|
|
6
6
|
|
|
7
|
-
from .base import EEGModuleMixin
|
|
7
|
+
from braindecode.models.base import EEGModuleMixin
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
class
|
|
11
|
-
|
|
10
|
+
class DeepSleepNet(EEGModuleMixin, nn.Module):
|
|
11
|
+
"""Sleep staging architecture from Supratak et al. (2017) [Supratak2017]_.
|
|
12
|
+
|
|
13
|
+
.. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/refs/heads/master/img/deepsleepnet.png
|
|
14
|
+
:align: center
|
|
15
|
+
:alt: DeepSleepNet Architecture
|
|
16
|
+
|
|
17
|
+
Convolutional neural network and bidirectional Long Short-Term Memory (LSTM)
|
|
18
|
+
for single-channel sleep staging described in [Supratak2017]_.
|
|
19
|
+
|
|
20
|
+
Parameters
|
|
21
|
+
----------
|
|
22
|
+
activation_large: nn.Module, default=nn.ELU
|
|
23
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
24
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
|
|
25
|
+
activation_small: nn.Module, default=nn.ReLU
|
|
26
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
27
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
|
|
28
|
+
return_feats : bool
|
|
29
|
+
If True, return the features, i.e. the output of the feature extractor
|
|
30
|
+
(before the final linear layer). If False, pass the features through
|
|
31
|
+
the final linear layer.
|
|
32
|
+
drop_prob : float, default=0.5
|
|
33
|
+
The dropout rate for regularization. Values should be between 0 and 1.
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
References
|
|
37
|
+
----------
|
|
38
|
+
.. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
|
|
39
|
+
DeepSleepNet: A model for automatic sleep stage scoring based
|
|
40
|
+
on raw single-channel EEG. IEEE Transactions on Neural Systems
|
|
41
|
+
and Rehabilitation Engineering, 25(11), 1998-2008.
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
n_outputs=5,
|
|
47
|
+
return_feats=False,
|
|
48
|
+
n_chans=None,
|
|
49
|
+
chs_info=None,
|
|
50
|
+
n_times=None,
|
|
51
|
+
input_window_seconds=None,
|
|
52
|
+
sfreq=None,
|
|
53
|
+
activation_large: nn.Module = nn.ELU,
|
|
54
|
+
activation_small: nn.Module = nn.ReLU,
|
|
55
|
+
drop_prob: float = 0.5,
|
|
56
|
+
):
|
|
57
|
+
super().__init__(
|
|
58
|
+
n_outputs=n_outputs,
|
|
59
|
+
n_chans=n_chans,
|
|
60
|
+
chs_info=chs_info,
|
|
61
|
+
n_times=n_times,
|
|
62
|
+
input_window_seconds=input_window_seconds,
|
|
63
|
+
sfreq=sfreq,
|
|
64
|
+
)
|
|
65
|
+
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
66
|
+
self.cnn1 = _SmallCNN(activation=activation_small, drop_prob=drop_prob)
|
|
67
|
+
self.cnn2 = _LargeCNN(activation=activation_large, drop_prob=drop_prob)
|
|
68
|
+
self.dropout = nn.Dropout(0.5)
|
|
69
|
+
self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
|
|
70
|
+
self.fc = nn.Sequential(
|
|
71
|
+
nn.Linear(3072, 1024, bias=False), nn.BatchNorm1d(num_features=1024)
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
self.features_extractor = nn.Identity()
|
|
75
|
+
self.len_last_layer = 1024
|
|
76
|
+
self.return_feats = return_feats
|
|
77
|
+
|
|
78
|
+
# TODO: Add new way to handle return_feats == True
|
|
79
|
+
if not return_feats:
|
|
80
|
+
self.final_layer = nn.Linear(1024, self.n_outputs)
|
|
81
|
+
else:
|
|
82
|
+
self.final_layer = nn.Identity()
|
|
83
|
+
|
|
84
|
+
def forward(self, x):
|
|
85
|
+
"""Forward pass.
|
|
86
|
+
|
|
87
|
+
Parameters
|
|
88
|
+
----------
|
|
89
|
+
x: torch.Tensor
|
|
90
|
+
Batch of EEG windows of shape (batch_size, n_channels, n_times).
|
|
91
|
+
"""
|
|
92
|
+
|
|
93
|
+
if x.ndim == 3:
|
|
94
|
+
x = x.unsqueeze(1)
|
|
95
|
+
|
|
96
|
+
x1 = self.cnn1(x)
|
|
97
|
+
x1 = x1.flatten(start_dim=1)
|
|
98
|
+
|
|
99
|
+
x2 = self.cnn2(x)
|
|
100
|
+
x2 = x2.flatten(start_dim=1)
|
|
101
|
+
|
|
102
|
+
x = torch.cat((x1, x2), dim=1)
|
|
103
|
+
x = self.dropout(x)
|
|
104
|
+
temp = x.clone()
|
|
105
|
+
temp = self.fc(temp)
|
|
106
|
+
x = x.unsqueeze(1)
|
|
107
|
+
x = self.bilstm(x)
|
|
108
|
+
x = x.squeeze()
|
|
109
|
+
x = torch.add(x, temp)
|
|
110
|
+
x = self.dropout(x)
|
|
111
|
+
|
|
112
|
+
feats = self.features_extractor(x)
|
|
113
|
+
|
|
114
|
+
if self.return_feats:
|
|
115
|
+
return feats
|
|
116
|
+
else:
|
|
117
|
+
return self.final_layer(feats)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class _SmallCNN(nn.Module):
|
|
121
|
+
"""
|
|
122
|
+
Smaller filter sizes to learn temporal information.
|
|
123
|
+
|
|
124
|
+
Parameters
|
|
125
|
+
----------
|
|
126
|
+
activation: nn.Module, default=nn.ReLU
|
|
127
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
128
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
|
|
129
|
+
drop_prob : float, default=0.5
|
|
130
|
+
The dropout rate for regularization. Values should be between 0 and 1.
|
|
131
|
+
"""
|
|
132
|
+
|
|
133
|
+
def __init__(self, activation: nn.Module = nn.ReLU, drop_prob: float = 0.5):
|
|
12
134
|
super().__init__()
|
|
13
135
|
self.conv1 = nn.Sequential(
|
|
14
136
|
nn.Conv2d(
|
|
@@ -20,10 +142,10 @@ class _SmallCNN(nn.Module): # smaller filter sizes to learn temporal informatio
|
|
|
20
142
|
bias=False,
|
|
21
143
|
),
|
|
22
144
|
nn.BatchNorm2d(num_features=64),
|
|
23
|
-
|
|
145
|
+
activation(),
|
|
24
146
|
)
|
|
25
147
|
self.pool1 = nn.MaxPool2d(kernel_size=(1, 8), stride=(1, 8), padding=(0, 2))
|
|
26
|
-
self.dropout = nn.Dropout(p=
|
|
148
|
+
self.dropout = nn.Dropout(p=drop_prob)
|
|
27
149
|
self.conv2 = nn.Sequential(
|
|
28
150
|
nn.Conv2d(
|
|
29
151
|
in_channels=64,
|
|
@@ -34,7 +156,7 @@ class _SmallCNN(nn.Module): # smaller filter sizes to learn temporal informatio
|
|
|
34
156
|
bias=False,
|
|
35
157
|
),
|
|
36
158
|
nn.BatchNorm2d(num_features=128),
|
|
37
|
-
|
|
159
|
+
activation(),
|
|
38
160
|
)
|
|
39
161
|
self.conv3 = nn.Sequential(
|
|
40
162
|
nn.Conv2d(
|
|
@@ -46,7 +168,7 @@ class _SmallCNN(nn.Module): # smaller filter sizes to learn temporal informatio
|
|
|
46
168
|
bias=False,
|
|
47
169
|
),
|
|
48
170
|
nn.BatchNorm2d(num_features=128),
|
|
49
|
-
|
|
171
|
+
activation(),
|
|
50
172
|
)
|
|
51
173
|
self.conv4 = nn.Sequential(
|
|
52
174
|
nn.Conv2d(
|
|
@@ -58,7 +180,7 @@ class _SmallCNN(nn.Module): # smaller filter sizes to learn temporal informatio
|
|
|
58
180
|
bias=False,
|
|
59
181
|
),
|
|
60
182
|
nn.BatchNorm2d(num_features=128),
|
|
61
|
-
|
|
183
|
+
activation(),
|
|
62
184
|
)
|
|
63
185
|
self.pool2 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=(0, 1))
|
|
64
186
|
|
|
@@ -72,8 +194,19 @@ class _SmallCNN(nn.Module): # smaller filter sizes to learn temporal informatio
|
|
|
72
194
|
return x
|
|
73
195
|
|
|
74
196
|
|
|
75
|
-
class _LargeCNN(nn.Module):
|
|
76
|
-
|
|
197
|
+
class _LargeCNN(nn.Module):
|
|
198
|
+
"""
|
|
199
|
+
Larger filter sizes to learn frequency information.
|
|
200
|
+
|
|
201
|
+
Parameters
|
|
202
|
+
----------
|
|
203
|
+
activation: nn.Module, default=nn.ELU
|
|
204
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
205
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
|
|
206
|
+
|
|
207
|
+
"""
|
|
208
|
+
|
|
209
|
+
def __init__(self, activation: nn.Module = nn.ELU, drop_prob: float = 0.5):
|
|
77
210
|
super().__init__()
|
|
78
211
|
|
|
79
212
|
self.conv1 = nn.Sequential(
|
|
@@ -86,10 +219,10 @@ class _LargeCNN(nn.Module): # larger filter sizes to learn frequency informatio
|
|
|
86
219
|
bias=False,
|
|
87
220
|
),
|
|
88
221
|
nn.BatchNorm2d(num_features=64),
|
|
89
|
-
|
|
222
|
+
activation(),
|
|
90
223
|
)
|
|
91
224
|
self.pool1 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))
|
|
92
|
-
self.dropout = nn.Dropout(p=
|
|
225
|
+
self.dropout = nn.Dropout(p=drop_prob)
|
|
93
226
|
self.conv2 = nn.Sequential(
|
|
94
227
|
nn.Conv2d(
|
|
95
228
|
in_channels=64,
|
|
@@ -100,7 +233,7 @@ class _LargeCNN(nn.Module): # larger filter sizes to learn frequency informatio
|
|
|
100
233
|
bias=False,
|
|
101
234
|
),
|
|
102
235
|
nn.BatchNorm2d(num_features=128),
|
|
103
|
-
|
|
236
|
+
activation(),
|
|
104
237
|
)
|
|
105
238
|
self.conv3 = nn.Sequential(
|
|
106
239
|
nn.Conv2d(
|
|
@@ -112,7 +245,7 @@ class _LargeCNN(nn.Module): # larger filter sizes to learn frequency informatio
|
|
|
112
245
|
bias=False,
|
|
113
246
|
),
|
|
114
247
|
nn.BatchNorm2d(num_features=128),
|
|
115
|
-
|
|
248
|
+
activation(),
|
|
116
249
|
)
|
|
117
250
|
self.conv4 = nn.Sequential(
|
|
118
251
|
nn.Conv2d(
|
|
@@ -124,7 +257,7 @@ class _LargeCNN(nn.Module): # larger filter sizes to learn frequency informatio
|
|
|
124
257
|
bias=False,
|
|
125
258
|
),
|
|
126
259
|
nn.BatchNorm2d(num_features=128),
|
|
127
|
-
|
|
260
|
+
activation(),
|
|
128
261
|
)
|
|
129
262
|
self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 1))
|
|
130
263
|
|
|
@@ -154,112 +287,9 @@ class _BiLSTM(nn.Module):
|
|
|
154
287
|
|
|
155
288
|
def forward(self, x):
|
|
156
289
|
# set initial hidden and cell states
|
|
157
|
-
h0 = torch.zeros(
|
|
158
|
-
self.num_layers * 2, x.size(0), self.hidden_size
|
|
159
|
-
).to(x.device)
|
|
290
|
+
h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
|
|
160
291
|
c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
|
|
161
292
|
|
|
162
293
|
# forward propagate LSTM
|
|
163
294
|
out, _ = self.lstm(x, (h0, c0))
|
|
164
295
|
return out
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
class DeepSleepNet(EEGModuleMixin, nn.Module):
|
|
168
|
-
"""Sleep staging architecture from Supratak et al 2017.
|
|
169
|
-
|
|
170
|
-
Convolutional neural network and bidirectional-Long Short-Term
|
|
171
|
-
for single channels sleep staging described in [Supratak2017]_.
|
|
172
|
-
|
|
173
|
-
Parameters
|
|
174
|
-
----------
|
|
175
|
-
return_feats : bool
|
|
176
|
-
If True, return the features, i.e. the output of the feature extractor
|
|
177
|
-
(before the final linear layer). If False, pass the features through
|
|
178
|
-
the final linear layer.
|
|
179
|
-
n_classes :
|
|
180
|
-
Alias for n_outputs.
|
|
181
|
-
|
|
182
|
-
References
|
|
183
|
-
----------
|
|
184
|
-
.. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
|
|
185
|
-
DeepSleepNet: A model for automatic sleep stage scoring based
|
|
186
|
-
on raw single-channel EEG. IEEE Transactions on Neural Systems
|
|
187
|
-
and Rehabilitation Engineering, 25(11), 1998-2008.
|
|
188
|
-
"""
|
|
189
|
-
|
|
190
|
-
def __init__(
|
|
191
|
-
self,
|
|
192
|
-
n_outputs=5,
|
|
193
|
-
return_feats=False,
|
|
194
|
-
n_chans=None,
|
|
195
|
-
chs_info=None,
|
|
196
|
-
n_times=None,
|
|
197
|
-
input_window_seconds=None,
|
|
198
|
-
sfreq=None,
|
|
199
|
-
n_classes=None,
|
|
200
|
-
):
|
|
201
|
-
n_outputs, = deprecated_args(
|
|
202
|
-
self,
|
|
203
|
-
('n_classes', 'n_outputs', n_classes, n_outputs),
|
|
204
|
-
)
|
|
205
|
-
super().__init__(
|
|
206
|
-
n_outputs=n_outputs,
|
|
207
|
-
n_chans=n_chans,
|
|
208
|
-
chs_info=chs_info,
|
|
209
|
-
n_times=n_times,
|
|
210
|
-
input_window_seconds=input_window_seconds,
|
|
211
|
-
sfreq=sfreq,
|
|
212
|
-
)
|
|
213
|
-
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
214
|
-
del n_classes
|
|
215
|
-
self.cnn1 = _SmallCNN()
|
|
216
|
-
self.cnn2 = _LargeCNN()
|
|
217
|
-
self.dropout = nn.Dropout(0.5)
|
|
218
|
-
self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
|
|
219
|
-
self.fc = nn.Sequential(nn.Linear(3072, 1024, bias=False),
|
|
220
|
-
nn.BatchNorm1d(num_features=1024))
|
|
221
|
-
|
|
222
|
-
self.features_extractor = nn.Identity()
|
|
223
|
-
self.len_last_layer = 1024
|
|
224
|
-
self.return_feats = return_feats
|
|
225
|
-
|
|
226
|
-
# TODO: Add new way to handle return_features == True
|
|
227
|
-
if not return_feats:
|
|
228
|
-
self.final_layer = nn.Linear(1024, self.n_outputs)
|
|
229
|
-
else:
|
|
230
|
-
self.final_layer = nn.Identity()
|
|
231
|
-
|
|
232
|
-
def forward(self, x):
|
|
233
|
-
"""Forward pass.
|
|
234
|
-
|
|
235
|
-
Parameters
|
|
236
|
-
----------
|
|
237
|
-
x: torch.Tensor
|
|
238
|
-
Batch of EEG windows of shape (batch_size, n_channels, n_times).
|
|
239
|
-
"""
|
|
240
|
-
|
|
241
|
-
if x.ndim == 3:
|
|
242
|
-
x = x.unsqueeze(1)
|
|
243
|
-
|
|
244
|
-
x1 = self.cnn1(x)
|
|
245
|
-
x1 = x1.flatten(start_dim=1)
|
|
246
|
-
|
|
247
|
-
x2 = self.cnn2(x)
|
|
248
|
-
x2 = x2.flatten(start_dim=1)
|
|
249
|
-
|
|
250
|
-
x = torch.cat((x1, x2), dim=1)
|
|
251
|
-
x = self.dropout(x)
|
|
252
|
-
temp = x.clone()
|
|
253
|
-
temp = self.fc(temp)
|
|
254
|
-
x = x.unsqueeze(1)
|
|
255
|
-
x = self.bilstm(x)
|
|
256
|
-
x = x.squeeze()
|
|
257
|
-
x = torch.add(x, temp)
|
|
258
|
-
x = self.dropout(x)
|
|
259
|
-
|
|
260
|
-
feats = self.features_extractor(x)
|
|
261
|
-
|
|
262
|
-
if self.return_feats:
|
|
263
|
-
return feats
|
|
264
|
-
else:
|
|
265
|
-
return self.final_layer(feats)
|