braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/deep4.py

@@ -0,0 +1,376 @@
# Authors: Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)

from einops.layers.torch import Rearrange
from mne.utils import warn
from torch import nn
from torch.nn import init

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import (
    AvgPool2dWithConv,
    CombinedConv,
    Ensure4d,
    SqueezeFinalOutput,
)


class Deep4Net(EEGModuleMixin, nn.Sequential):
    r"""Deep ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.

    :bdg-success:`Convolution`

    .. figure:: https://onlinelibrary.wiley.com/cms/asset/fc200ccc-d8c4-45b4-8577-56ce4d15999a/hbm23730-fig-0001-m.jpg
        :align: center
        :alt: Deep4Net Architecture
        :width: 600px

    Model described in [Schirrmeister2017]_.

    Parameters
    ----------
    final_conv_length: int | str
        Length of the final convolution layer.
        If set to "auto", n_times must not be None. Default: "auto".
    n_filters_time: int
        Number of temporal filters.
    n_filters_spat: int
        Number of spatial filters.
    filter_time_length: int
        Length of the temporal filter in layer 1.
    pool_time_length: int
        Length of the temporal pooling filter.
    pool_time_stride: int
        Stride between consecutive temporal pooling windows.
    n_filters_2: int
        Number of temporal filters in layer 2.
    filter_length_2: int
        Length of the temporal filter in layer 2.
    n_filters_3: int
        Number of temporal filters in layer 3.
    filter_length_3: int
        Length of the temporal filter in layer 3.
    n_filters_4: int
        Number of temporal filters in layer 4.
    filter_length_4: int
        Length of the temporal filter in layer 4.
    activation_first_conv_nonlin: nn.Module, default is nn.ELU
        Non-linear activation function applied after the convolution in layer 1.
    first_pool_mode: str
        Pooling mode in layer 1, either "max" or "mean".
    first_pool_nonlin: callable
        Non-linear activation function applied after pooling in layer 1.
    activation_later_conv_nonlin: nn.Module, default is nn.ELU
        Non-linear activation function applied after the convolutions in later layers.
    later_pool_mode: str
        Pooling mode in later layers, either "max" or "mean".
    later_pool_nonlin: callable
        Non-linear activation function applied after pooling in later layers.
    drop_prob: float
        Dropout probability.
    split_first_layer: bool
        Split the first layer into a temporal and a spatial convolution (True)
        or use a single temporal convolution (False).
        No non-linearity is applied between the split layers.
    batch_norm: bool
        Whether to use batch normalisation.
    batch_norm_alpha: float
        Momentum for BatchNorm2d.
    stride_before_pool: bool
        If True, apply the temporal stride in the convolutions instead of in
        the pooling layers.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       Human Brain Mapping, Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        final_conv_length="auto",
        n_filters_time=25,
        n_filters_spat=25,
        filter_time_length=10,
        pool_time_length=3,
        pool_time_stride=3,
        n_filters_2=50,
        filter_length_2=10,
        n_filters_3=100,
        filter_length_3=10,
        n_filters_4=200,
        filter_length_4=10,
        activation_first_conv_nonlin: type[nn.Module] = nn.ELU,
        first_pool_mode="max",
        first_pool_nonlin: type[nn.Module] = nn.Identity,
        activation_later_conv_nonlin: type[nn.Module] = nn.ELU,
        later_pool_mode="max",
        later_pool_nonlin: type[nn.Module] = nn.Identity,
        drop_prob=0.5,
        split_first_layer=True,
        batch_norm=True,
        batch_norm_alpha=0.1,
        stride_before_pool=False,
        # Braindecode EEGModuleMixin parameters
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        if final_conv_length == "auto":
            assert self.n_times is not None
        self.final_conv_length = final_conv_length
        self.n_filters_time = n_filters_time
        self.n_filters_spat = n_filters_spat
        self.filter_time_length = filter_time_length
        self.pool_time_length = pool_time_length
        self.pool_time_stride = pool_time_stride
        self.n_filters_2 = n_filters_2
        self.filter_length_2 = filter_length_2
        self.n_filters_3 = n_filters_3
        self.filter_length_3 = filter_length_3
        self.n_filters_4 = n_filters_4
        self.filter_length_4 = filter_length_4
        self.first_nonlin = activation_first_conv_nonlin
        self.first_pool_mode = first_pool_mode
        self.first_pool_nonlin = first_pool_nonlin
        self.later_conv_nonlin = activation_later_conv_nonlin
        self.later_pool_mode = later_pool_mode
        self.later_pool_nonlin = later_pool_nonlin
        self.drop_prob = drop_prob
        self.split_first_layer = split_first_layer
        self.batch_norm = batch_norm
        self.batch_norm_alpha = batch_norm_alpha
        self.stride_before_pool = stride_before_pool

        min_n_times = self._get_min_n_times()
        if self.n_times < min_n_times:
            # Calculate a scaling factor to adjust the temporal parameters
            scaling_factor = self.n_times / min_n_times
            warn(
                f"n_times ({self.n_times}) is smaller than the minimum required "
                f"({min_n_times}) for the current model parameter configuration. "
                "Adjusting parameters to ensure compatibility. "
                "Reducing the kernel, pooling, and stride sizes accordingly. "
                "Scaling factor: {:.2f}".format(scaling_factor),
                UserWarning,
            )
            # Apply the scaling factor to all temporal kernel and pooling sizes
            self.filter_time_length = max(
                1, int(self.filter_time_length * scaling_factor)
            )
            self.pool_time_length = max(1, int(self.pool_time_length * scaling_factor))
            self.pool_time_stride = max(1, int(self.pool_time_stride * scaling_factor))
            self.filter_length_2 = max(1, int(self.filter_length_2 * scaling_factor))
            self.filter_length_3 = max(1, int(self.filter_length_3 * scaling_factor))
            self.filter_length_4 = max(1, int(self.filter_length_4 * scaling_factor))
        # For load_state_dict: when standardising the layer names,
        # map the old parameter names to the new ones here.
        self.mapping = {
            "conv_time.weight": "conv_time_spat.conv_time.weight",
            "conv_spat.weight": "conv_time_spat.conv_spat.weight",
            "conv_time.bias": "conv_time_spat.conv_time.bias",
            "conv_spat.bias": "conv_time_spat.conv_spat.bias",
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        if self.stride_before_pool:
            conv_stride = self.pool_time_stride
            pool_stride = 1
        else:
            conv_stride = 1
            pool_stride = self.pool_time_stride
        self.add_module("ensuredims", Ensure4d())
        pool_class_dict = dict(max=nn.MaxPool2d, mean=AvgPool2dWithConv)
        first_pool_class = pool_class_dict[self.first_pool_mode]
        later_pool_class = pool_class_dict[self.later_pool_mode]
        if self.split_first_layer:
            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
            self.add_module(
                "conv_time_spat",
                CombinedConv(
                    in_chans=self.n_chans,
                    n_filters_time=self.n_filters_time,
                    n_filters_spat=self.n_filters_spat,
                    filter_time_length=filter_time_length,
                    bias_time=True,
                    bias_spat=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_spat
        else:
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    self.n_chans,
                    self.n_filters_time,
                    (self.filter_time_length, 1),
                    stride=(conv_stride, 1),
                    bias=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_time
        if self.batch_norm:
            self.add_module(
                "bnorm",
                nn.BatchNorm2d(
                    n_filters_conv,
                    momentum=self.batch_norm_alpha,
                    affine=True,
                    eps=1e-5,
                ),
            )
        self.add_module("conv_nonlin", self.first_nonlin())
        self.add_module(
            "pool",
            first_pool_class(
                kernel_size=(self.pool_time_length, 1), stride=(pool_stride, 1)
            ),
        )
        self.add_module("pool_nonlin", self.first_pool_nonlin())

        def add_conv_pool_block(
            model, n_filters_before, n_filters, filter_length, block_nr
        ):
            suffix = "_{:d}".format(block_nr)
            self.add_module("drop" + suffix, nn.Dropout(p=self.drop_prob))
            self.add_module(
                "conv" + suffix,
                nn.Conv2d(
                    n_filters_before,
                    n_filters,
                    (filter_length, 1),
                    stride=(conv_stride, 1),
                    bias=not self.batch_norm,
                ),
            )
            if self.batch_norm:
                self.add_module(
                    "bnorm" + suffix,
                    nn.BatchNorm2d(
                        n_filters,
                        momentum=self.batch_norm_alpha,
                        affine=True,
                        eps=1e-5,
                    ),
                )
            self.add_module("nonlin" + suffix, self.later_conv_nonlin())

            self.add_module(
                "pool" + suffix,
                later_pool_class(
                    kernel_size=(self.pool_time_length, 1),
                    stride=(pool_stride, 1),
                ),
            )
            self.add_module("pool_nonlin" + suffix, self.later_pool_nonlin())

        add_conv_pool_block(
            self, n_filters_conv, self.n_filters_2, self.filter_length_2, 2
        )
        add_conv_pool_block(
            self, self.n_filters_2, self.n_filters_3, self.filter_length_3, 3
        )
        add_conv_pool_block(
            self, self.n_filters_3, self.n_filters_4, self.filter_length_4, 4
        )

        self.eval()
        if self.final_conv_length == "auto":
            self.final_conv_length = self.get_output_shape()[2]

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()

        module.add_module(
            "conv_classifier",
            nn.Conv2d(
                self.n_filters_4,
                self.n_outputs,
                (self.final_conv_length, 1),
                bias=True,
            ),
        )

        module.add_module("squeeze", SqueezeFinalOutput())

        self.add_module("final_layer", module)

        # Initialization, xavier is same as in our paper...
        # was default from lasagne
        init.xavier_uniform_(self.conv_time_spat.conv_time.weight, gain=1)
        # maybe no bias in case of no split layer and batch norm
        if self.split_first_layer or (not self.batch_norm):
            init.constant_(self.conv_time_spat.conv_time.bias, 0)
        if self.split_first_layer:
            init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
            if not self.batch_norm:
                init.constant_(self.conv_time_spat.conv_spat.bias, 0)
        if self.batch_norm:
            init.constant_(self.bnorm.weight, 1)
            init.constant_(self.bnorm.bias, 0)
        param_dict = dict(list(self.named_parameters()))
        for block_nr in range(2, 5):
            conv_weight = param_dict["conv_{:d}.weight".format(block_nr)]
            init.xavier_uniform_(conv_weight, gain=1)
            if not self.batch_norm:
                conv_bias = param_dict["conv_{:d}.bias".format(block_nr)]
                init.constant_(conv_bias, 0)
            else:
                bnorm_weight = param_dict["bnorm_{:d}.weight".format(block_nr)]
                bnorm_bias = param_dict["bnorm_{:d}.bias".format(block_nr)]
                init.constant_(bnorm_weight, 1)
                init.constant_(bnorm_bias, 0)

        init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
        init.constant_(self.final_layer.conv_classifier.bias, 0)

        self.train()

    def _get_min_n_times(self) -> int:
        """
        Calculate the minimum number of time samples required for the model
        to work with the given temporal parameters.
        """
        # Start with the minimum valid output length of the network (1)
        min_len = 1

        # List of conv kernel sizes and pool parameters for the 4 blocks, in reverse order
        # Each tuple: (filter_length, pool_length, pool_stride)
        block_params = [
            (self.filter_length_4, self.pool_time_length, self.pool_time_stride),
            (self.filter_length_3, self.pool_time_length, self.pool_time_stride),
            (self.filter_length_2, self.pool_time_length, self.pool_time_stride),
            (self.filter_time_length, self.pool_time_length, self.pool_time_stride),
        ]

        # Work backward from the last layer to the input
        for filter_len, pool_len, pool_stride in block_params:
            # Reverse the pooling operation
            # L_in = stride * (L_out - 1) + kernel_size
            min_len = pool_stride * (min_len - 1) + pool_len
            # Reverse the convolution operation (assuming stride=1)
            # L_in = L_out + kernel_size - 1
            min_len = min_len + filter_len - 1

        return min_len