braindecode 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/augmentation/transforms.py +0 -1
- braindecode/datautil/__init__.py +3 -0
- braindecode/datautil/serialization.py +13 -2
- braindecode/functional/__init__.py +12 -0
- braindecode/functional/functions.py +0 -1
- braindecode/models/__init__.py +48 -0
- braindecode/models/atcnet.py +46 -11
- braindecode/models/attentionbasenet.py +49 -0
- braindecode/models/biot.py +29 -8
- braindecode/models/contrawr.py +29 -8
- braindecode/models/ctnet.py +99 -13
- braindecode/models/deep4.py +52 -2
- braindecode/models/eegconformer.py +2 -3
- braindecode/models/eeginception_mi.py +9 -3
- braindecode/models/eegitnet.py +0 -1
- braindecode/models/eegminer.py +0 -1
- braindecode/models/eegnet.py +0 -1
- braindecode/models/fbcnet.py +1 -1
- braindecode/models/fbmsnet.py +0 -1
- braindecode/models/labram.py +23 -3
- braindecode/models/msvtnet.py +1 -1
- braindecode/models/sccnet.py +29 -4
- braindecode/models/signal_jepa.py +0 -1
- braindecode/models/sleep_stager_eldele_2021.py +0 -1
- braindecode/models/sparcnet.py +62 -16
- braindecode/models/tcn.py +1 -1
- braindecode/models/tsinception.py +38 -13
- braindecode/models/util.py +2 -6
- braindecode/modules/__init__.py +46 -0
- braindecode/modules/filter.py +0 -4
- braindecode/modules/layers.py +3 -5
- braindecode/modules/linear.py +1 -2
- braindecode/modules/util.py +0 -1
- braindecode/modules/wrapper.py +0 -2
- braindecode/samplers/base.py +0 -2
- braindecode/version.py +1 -1
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/METADATA +5 -5
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/RECORD +42 -42
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/WHEEL +0 -0
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/models/deep4.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
# License: BSD (3-clause)
|
|
4
4
|
|
|
5
5
|
from einops.layers.torch import Rearrange
|
|
6
|
+
from mne.utils import warn
|
|
6
7
|
from torch import nn
|
|
7
8
|
from torch.nn import init
|
|
8
9
|
|
|
@@ -115,6 +116,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
|
|
|
115
116
|
batch_norm=True,
|
|
116
117
|
batch_norm_alpha=0.1,
|
|
117
118
|
stride_before_pool=False,
|
|
119
|
+
# Braindecode EEGModuleMixin parameters
|
|
118
120
|
chs_info=None,
|
|
119
121
|
input_window_seconds=None,
|
|
120
122
|
sfreq=None,
|
|
@@ -155,6 +157,27 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
|
|
|
155
157
|
self.batch_norm_alpha = batch_norm_alpha
|
|
156
158
|
self.stride_before_pool = stride_before_pool
|
|
157
159
|
|
|
160
|
+
min_n_times = self._get_min_n_times()
|
|
161
|
+
if self.n_times < min_n_times:
|
|
162
|
+
scaling_factor = self.n_times / min_n_times
|
|
163
|
+
warn(
|
|
164
|
+
f"n_times ({self.n_times}) is smaller than the minimum required "
|
|
165
|
+
f"({min_n_times}) for the current model parameters configuration. "
|
|
166
|
+
"Adjusting parameters to ensure compatibility."
|
|
167
|
+
"Reducing the kernel, pooling, and stride sizes accordingly."
|
|
168
|
+
"Scaling factor: {:.2f}".format(scaling_factor),
|
|
169
|
+
UserWarning,
|
|
170
|
+
)
|
|
171
|
+
# Calculate a scaling factor to adjust temporal parameters
|
|
172
|
+
# Apply the scaling factor to all temporal kernel and pooling sizes
|
|
173
|
+
self.filter_time_length = max(
|
|
174
|
+
1, int(self.filter_time_length * scaling_factor)
|
|
175
|
+
)
|
|
176
|
+
self.pool_time_length = max(1, int(self.pool_time_length * scaling_factor))
|
|
177
|
+
self.pool_time_stride = max(1, int(self.pool_time_stride * scaling_factor))
|
|
178
|
+
self.filter_length_2 = max(1, int(self.filter_length_2 * scaling_factor))
|
|
179
|
+
self.filter_length_3 = max(1, int(self.filter_length_3 * scaling_factor))
|
|
180
|
+
self.filter_length_4 = max(1, int(self.filter_length_4 * scaling_factor))
|
|
158
181
|
# For the load_state_dict
|
|
159
182
|
# When padronize all layers,
|
|
160
183
|
# add the old's parameters here
|
|
@@ -268,7 +291,6 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
|
|
|
268
291
|
self, self.n_filters_3, self.n_filters_4, self.filter_length_4, 4
|
|
269
292
|
)
|
|
270
293
|
|
|
271
|
-
# self.add_module('drop_classifier', nn.Dropout(p=self.drop_prob))
|
|
272
294
|
self.eval()
|
|
273
295
|
if self.final_conv_length == "auto":
|
|
274
296
|
self.final_conv_length = self.get_output_shape()[2]
|
|
@@ -299,7 +321,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
|
|
|
299
321
|
if self.split_first_layer:
|
|
300
322
|
init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
|
|
301
323
|
if not self.batch_norm:
|
|
302
|
-
init.constant_(self.conv_spat.bias, 0)
|
|
324
|
+
init.constant_(self.conv_time_spat.conv_spat.bias, 0)
|
|
303
325
|
if self.batch_norm:
|
|
304
326
|
init.constant_(self.bnorm.weight, 1)
|
|
305
327
|
init.constant_(self.bnorm.bias, 0)
|
|
@@ -320,3 +342,31 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
|
|
|
320
342
|
init.constant_(self.final_layer.conv_classifier.bias, 0)
|
|
321
343
|
|
|
322
344
|
self.train()
|
|
345
|
+
|
|
346
|
+
def _get_min_n_times(self) -> int:
|
|
347
|
+
"""
|
|
348
|
+
Calculate the minimum number of time samples required for the model
|
|
349
|
+
to work with the given temporal parameters.
|
|
350
|
+
"""
|
|
351
|
+
# Start with the minimum valid output length of the network (1)
|
|
352
|
+
min_len = 1
|
|
353
|
+
|
|
354
|
+
# List of conv kernel sizes and pool parameters for the 4 blocks, in reverse order
|
|
355
|
+
# Each tuple: (filter_length, pool_length, pool_stride)
|
|
356
|
+
block_params = [
|
|
357
|
+
(self.filter_length_4, self.pool_time_length, self.pool_time_stride),
|
|
358
|
+
(self.filter_length_3, self.pool_time_length, self.pool_time_stride),
|
|
359
|
+
(self.filter_length_2, self.pool_time_length, self.pool_time_stride),
|
|
360
|
+
(self.filter_time_length, self.pool_time_length, self.pool_time_stride),
|
|
361
|
+
]
|
|
362
|
+
|
|
363
|
+
# Work backward from the last layer to the input
|
|
364
|
+
for filter_len, pool_len, pool_stride in block_params:
|
|
365
|
+
# Reverse the pooling operation
|
|
366
|
+
# L_in = stride * (L_out - 1) + kernel_size
|
|
367
|
+
min_len = pool_stride * (min_len - 1) + pool_len
|
|
368
|
+
# Reverse the convolution operation (assuming stride=1)
|
|
369
|
+
# L_in = L_out + kernel_size - 1
|
|
370
|
+
min_len = min_len + filter_len - 1
|
|
371
|
+
|
|
372
|
+
return min_len
|
|
@@ -2,11 +2,8 @@
|
|
|
2
2
|
#
|
|
3
3
|
# License: BSD (3-clause)
|
|
4
4
|
import warnings
|
|
5
|
-
from typing import Optional
|
|
6
5
|
|
|
7
6
|
import torch
|
|
8
|
-
import torch.nn.functional as F
|
|
9
|
-
from einops import rearrange
|
|
10
7
|
from einops.layers.torch import Rearrange
|
|
11
8
|
from torch import Tensor, nn
|
|
12
9
|
|
|
@@ -150,6 +147,8 @@ class EEGConformer(EEGModuleMixin, nn.Module):
|
|
|
150
147
|
if final_fc_length == "auto":
|
|
151
148
|
assert self.n_times is not None
|
|
152
149
|
self.final_fc_length = self.get_fc_size()
|
|
150
|
+
else:
|
|
151
|
+
self.final_fc_length = final_fc_length
|
|
153
152
|
|
|
154
153
|
self.transformer = _TransformerEncoder(
|
|
155
154
|
att_depth=att_depth,
|
|
@@ -66,9 +66,9 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
|
|
|
66
66
|
n_outputs=None,
|
|
67
67
|
input_window_seconds=None,
|
|
68
68
|
sfreq=250,
|
|
69
|
-
n_convs=5,
|
|
70
|
-
n_filters=48,
|
|
71
|
-
kernel_unit_s=0.1,
|
|
69
|
+
n_convs: int = 5,
|
|
70
|
+
n_filters: int = 48,
|
|
71
|
+
kernel_unit_s: float = 0.1,
|
|
72
72
|
activation: nn.Module = nn.ReLU,
|
|
73
73
|
chs_info=None,
|
|
74
74
|
n_times=None,
|
|
@@ -307,6 +307,12 @@ class _InceptionModuleMI(nn.Module):
|
|
|
307
307
|
|
|
308
308
|
X2 = self.pooling(X)
|
|
309
309
|
X2 = self.pooling_conv(X2)
|
|
310
|
+
# Get the target length from one of the conv branches
|
|
311
|
+
target_len = X1[0].shape[-1]
|
|
312
|
+
|
|
313
|
+
# Crop the pooling output if its length does not match
|
|
314
|
+
if X2.shape[-1] != target_len:
|
|
315
|
+
X2 = X2[..., :target_len]
|
|
310
316
|
|
|
311
317
|
out = torch.cat(X1 + [X2], 1)
|
|
312
318
|
|
braindecode/models/eegitnet.py
CHANGED
braindecode/models/eegminer.py
CHANGED
braindecode/models/eegnet.py
CHANGED
braindecode/models/fbcnet.py
CHANGED
braindecode/models/fbmsnet.py
CHANGED
braindecode/models/labram.py
CHANGED
|
@@ -166,11 +166,23 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
166
166
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
167
167
|
|
|
168
168
|
self.patch_size = patch_size
|
|
169
|
-
self.n_path = self.n_times // patch_size
|
|
170
169
|
self.num_features = self.emb_size = emb_size
|
|
171
170
|
self.neural_tokenizer = neural_tokenizer
|
|
172
171
|
self.init_scale = init_scale
|
|
173
172
|
|
|
173
|
+
if patch_size > self.n_times:
|
|
174
|
+
warn(
|
|
175
|
+
f"patch_size ({patch_size}) > n_times ({self.n_times}); "
|
|
176
|
+
f"setting patch_size = {self.n_times}.",
|
|
177
|
+
UserWarning,
|
|
178
|
+
)
|
|
179
|
+
self.patch_size = self.n_times
|
|
180
|
+
self.num_features = None
|
|
181
|
+
self.emb_size = None
|
|
182
|
+
else:
|
|
183
|
+
self.patch_size = patch_size
|
|
184
|
+
self.n_path = self.n_times // self.patch_size
|
|
185
|
+
|
|
174
186
|
if neural_tokenizer and in_channels != 1:
|
|
175
187
|
warn(
|
|
176
188
|
"The model is in Neural Tokenizer mode, but the variable "
|
|
@@ -220,8 +232,17 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
220
232
|
emb_dim=self.emb_size,
|
|
221
233
|
),
|
|
222
234
|
)
|
|
235
|
+
|
|
236
|
+
with torch.no_grad():
|
|
237
|
+
dummy = torch.zeros(1, self.n_chans, self.n_times)
|
|
238
|
+
out = self.patch_embed(dummy)
|
|
239
|
+
# out.shape for tokenizer: (1, n_chans, emb_dim)
|
|
240
|
+
# for decoder: (1, n_patch, patch_size, emb_dim), but we want last dim
|
|
241
|
+
self.emb_size = out.shape[-1]
|
|
242
|
+
self.num_features = self.emb_size
|
|
243
|
+
|
|
223
244
|
# Defining the parameters
|
|
224
|
-
# Creating a parameter list with cls token
|
|
245
|
+
# Creating a parameter list with cls token]
|
|
225
246
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.emb_size))
|
|
226
247
|
# Positional embedding and time embedding are complementary
|
|
227
248
|
# one is for the spatial information and the other is for the temporal
|
|
@@ -366,7 +387,6 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
366
387
|
x : torch.Tensor
|
|
367
388
|
The output of the model.
|
|
368
389
|
"""
|
|
369
|
-
|
|
370
390
|
if self.neural_tokenizer:
|
|
371
391
|
batch_size, nch, n_patch, temporal = self.patch_embed.segment_patch(x).shape
|
|
372
392
|
else:
|
braindecode/models/msvtnet.py
CHANGED
braindecode/models/sccnet.py
CHANGED
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
# License: BSD (3-clause)
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
+
from warnings import warn
|
|
7
8
|
|
|
8
9
|
import torch
|
|
9
10
|
from einops.layers.torch import Rearrange
|
|
@@ -98,9 +99,33 @@ class SCCNet(EEGModuleMixin, nn.Module):
|
|
|
98
99
|
self.n_spatial_filters_smooth = n_spatial_filters_smooth
|
|
99
100
|
self.drop_prob = drop_prob
|
|
100
101
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
102
|
+
# Original logical for SCCNet
|
|
103
|
+
conv_kernel_time = 0.1 # 100ms
|
|
104
|
+
pool_kernel_time = 0.5 # 500ms
|
|
105
|
+
|
|
106
|
+
# Calculate sample-based sizes from time durations
|
|
107
|
+
conv_kernel_samples = int(math.floor(self.sfreq * conv_kernel_time))
|
|
108
|
+
pool_kernel_samples = int(math.floor(self.sfreq * pool_kernel_time))
|
|
109
|
+
|
|
110
|
+
# If the input window is too short for the default kernel sizes,
|
|
111
|
+
# scale them down proportionally.
|
|
112
|
+
total_kernel_samples = conv_kernel_samples + pool_kernel_samples
|
|
113
|
+
|
|
114
|
+
if self.n_times < total_kernel_samples:
|
|
115
|
+
warning_msg = (
|
|
116
|
+
f"Input window seconds ({self.input_window_seconds:.2f}s) is smaller than the "
|
|
117
|
+
f"model's combined kernel sizes ({(total_kernel_samples / self.sfreq):.2f}s). "
|
|
118
|
+
"Scaling temporal parameters down proportionally."
|
|
119
|
+
)
|
|
120
|
+
warn(warning_msg, UserWarning, stacklevel=2)
|
|
121
|
+
|
|
122
|
+
scaling_factor = self.n_times / total_kernel_samples
|
|
123
|
+
conv_kernel_samples = int(math.floor(conv_kernel_samples * scaling_factor))
|
|
124
|
+
pool_kernel_samples = int(math.floor(pool_kernel_samples * scaling_factor))
|
|
125
|
+
|
|
126
|
+
# Ensure kernels are at least 1 sample wide
|
|
127
|
+
self.samples_100ms = max(1, conv_kernel_samples)
|
|
128
|
+
self.kernel_size_pool = max(1, pool_kernel_samples)
|
|
104
129
|
|
|
105
130
|
num_features = self._calc_num_features()
|
|
106
131
|
|
|
@@ -135,7 +160,7 @@ class SCCNet(EEGModuleMixin, nn.Module):
|
|
|
135
160
|
|
|
136
161
|
self.dropout = nn.Dropout(self.drop_prob)
|
|
137
162
|
self.temporal_smoothing = nn.AvgPool2d(
|
|
138
|
-
kernel_size=(1,
|
|
163
|
+
kernel_size=(1, self.kernel_size_pool),
|
|
139
164
|
stride=(1, self.samples_100ms),
|
|
140
165
|
)
|
|
141
166
|
|
braindecode/models/sparcnet.py
CHANGED
|
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from collections import OrderedDict
|
|
4
4
|
from math import floor, log2
|
|
5
|
-
from typing import Any
|
|
6
5
|
|
|
7
6
|
import torch
|
|
8
7
|
import torch.nn as nn
|
|
@@ -69,6 +68,19 @@ class SPARCNet(EEGModuleMixin, nn.Module):
|
|
|
69
68
|
conv_bias: bool = True,
|
|
70
69
|
batch_norm: bool = True,
|
|
71
70
|
activation: nn.Module = nn.ELU,
|
|
71
|
+
kernel_size_conv0: int = 7,
|
|
72
|
+
kernel_size_conv1: int = 1,
|
|
73
|
+
kernel_size_conv2: int = 3,
|
|
74
|
+
kernel_size_pool: int = 3,
|
|
75
|
+
stride_pool: int = 2,
|
|
76
|
+
stride_conv0: int = 2,
|
|
77
|
+
stride_conv1: int = 1,
|
|
78
|
+
stride_conv2: int = 1,
|
|
79
|
+
padding_pool: int = 1,
|
|
80
|
+
padding_conv0: int = 3,
|
|
81
|
+
padding_conv2: int = 1,
|
|
82
|
+
kernel_size_trans: int = 2,
|
|
83
|
+
stride_trans: int = 2,
|
|
72
84
|
# EEGModuleMixin parameters
|
|
73
85
|
# (another way to present the same parameters)
|
|
74
86
|
chs_info=None,
|
|
@@ -96,9 +108,9 @@ class SPARCNet(EEGModuleMixin, nn.Module):
|
|
|
96
108
|
nn.Conv1d(
|
|
97
109
|
in_channels=self.n_chans,
|
|
98
110
|
out_channels=out_channels,
|
|
99
|
-
kernel_size=
|
|
100
|
-
stride=
|
|
101
|
-
padding=
|
|
111
|
+
kernel_size=kernel_size_conv0,
|
|
112
|
+
stride=stride_conv0,
|
|
113
|
+
padding=padding_conv0,
|
|
102
114
|
bias=conv_bias,
|
|
103
115
|
),
|
|
104
116
|
)
|
|
@@ -106,7 +118,11 @@ class SPARCNet(EEGModuleMixin, nn.Module):
|
|
|
106
118
|
)
|
|
107
119
|
first_conv["norm0"] = nn.BatchNorm1d(out_channels)
|
|
108
120
|
first_conv["act_layer"] = activation()
|
|
109
|
-
first_conv["pool0"] = nn.MaxPool1d(
|
|
121
|
+
first_conv["pool0"] = nn.MaxPool1d(
|
|
122
|
+
kernel_size=kernel_size_pool,
|
|
123
|
+
stride=stride_pool,
|
|
124
|
+
padding=padding_pool,
|
|
125
|
+
)
|
|
110
126
|
|
|
111
127
|
self.encoder = nn.Sequential(first_conv)
|
|
112
128
|
|
|
@@ -123,6 +139,11 @@ class SPARCNet(EEGModuleMixin, nn.Module):
|
|
|
123
139
|
conv_bias=conv_bias,
|
|
124
140
|
batch_norm=batch_norm,
|
|
125
141
|
activation=activation,
|
|
142
|
+
kernel_size_conv1=kernel_size_conv1,
|
|
143
|
+
kernel_size_conv2=kernel_size_conv2,
|
|
144
|
+
stride_conv1=stride_conv1,
|
|
145
|
+
stride_conv2=stride_conv2,
|
|
146
|
+
padding_conv2=padding_conv2,
|
|
126
147
|
)
|
|
127
148
|
self.encoder.add_module("denseblock%d" % (n_layer + 1), block)
|
|
128
149
|
# update the number of channels after each dense block
|
|
@@ -134,16 +155,19 @@ class SPARCNet(EEGModuleMixin, nn.Module):
|
|
|
134
155
|
conv_bias=conv_bias,
|
|
135
156
|
batch_norm=batch_norm,
|
|
136
157
|
activation=activation,
|
|
158
|
+
kernel_size_trans=kernel_size_trans,
|
|
159
|
+
stride_trans=stride_trans,
|
|
137
160
|
)
|
|
138
161
|
self.encoder.add_module("transition%d" % (n_layer + 1), trans)
|
|
139
162
|
# update the number of channels after each transition layer
|
|
140
163
|
n_channels = n_channels // 2
|
|
141
164
|
|
|
165
|
+
self.adaptative_pool = nn.AdaptiveAvgPool1d(1)
|
|
166
|
+
self.activation_layer = activation()
|
|
167
|
+
self.flatten_layer = nn.Flatten()
|
|
168
|
+
|
|
142
169
|
# add final convolutional layer
|
|
143
|
-
self.final_layer = nn.
|
|
144
|
-
activation(),
|
|
145
|
-
nn.Linear(n_channels, self.n_outputs),
|
|
146
|
-
)
|
|
170
|
+
self.final_layer = nn.Linear(n_channels, self.n_outputs)
|
|
147
171
|
|
|
148
172
|
self._init_weights()
|
|
149
173
|
|
|
@@ -178,7 +202,10 @@ class SPARCNet(EEGModuleMixin, nn.Module):
|
|
|
178
202
|
torch.Tensor
|
|
179
203
|
The output tensor of the model with shape (batch_size, n_outputs)
|
|
180
204
|
"""
|
|
181
|
-
emb = self.encoder(X)
|
|
205
|
+
emb = self.encoder(X)
|
|
206
|
+
emb = self.adaptative_pool(emb)
|
|
207
|
+
emb = self.activation_layer(emb)
|
|
208
|
+
emb = self.flatten_layer(emb)
|
|
182
209
|
out = self.final_layer(emb)
|
|
183
210
|
return out
|
|
184
211
|
|
|
@@ -224,6 +251,11 @@ class _DenseLayer(nn.Sequential):
|
|
|
224
251
|
conv_bias: bool = True,
|
|
225
252
|
batch_norm: bool = True,
|
|
226
253
|
activation: nn.Module = nn.ELU,
|
|
254
|
+
kernel_size_conv1: int = 1,
|
|
255
|
+
kernel_size_conv2: int = 3,
|
|
256
|
+
stride_conv1: int = 1,
|
|
257
|
+
stride_conv2: int = 1,
|
|
258
|
+
padding_conv2: int = 1,
|
|
227
259
|
):
|
|
228
260
|
super().__init__()
|
|
229
261
|
if batch_norm:
|
|
@@ -235,8 +267,8 @@ class _DenseLayer(nn.Sequential):
|
|
|
235
267
|
nn.Conv1d(
|
|
236
268
|
in_channels=in_channels,
|
|
237
269
|
out_channels=bottleneck_size * growth_rate,
|
|
238
|
-
kernel_size=
|
|
239
|
-
stride=
|
|
270
|
+
kernel_size=kernel_size_conv1,
|
|
271
|
+
stride=stride_conv1,
|
|
240
272
|
bias=conv_bias,
|
|
241
273
|
),
|
|
242
274
|
)
|
|
@@ -248,9 +280,9 @@ class _DenseLayer(nn.Sequential):
|
|
|
248
280
|
nn.Conv1d(
|
|
249
281
|
in_channels=bottleneck_size * growth_rate,
|
|
250
282
|
out_channels=growth_rate,
|
|
251
|
-
kernel_size=
|
|
252
|
-
stride=
|
|
253
|
-
padding=
|
|
283
|
+
kernel_size=kernel_size_conv2,
|
|
284
|
+
stride=stride_conv2,
|
|
285
|
+
padding=padding_conv2,
|
|
254
286
|
bias=conv_bias,
|
|
255
287
|
),
|
|
256
288
|
)
|
|
@@ -311,6 +343,11 @@ class _DenseBlock(nn.Sequential):
|
|
|
311
343
|
conv_bias=True,
|
|
312
344
|
batch_norm=True,
|
|
313
345
|
activation: nn.Module = nn.ELU,
|
|
346
|
+
kernel_size_conv1: int = 1,
|
|
347
|
+
kernel_size_conv2: int = 3,
|
|
348
|
+
stride_conv1: int = 1,
|
|
349
|
+
stride_conv2: int = 1,
|
|
350
|
+
padding_conv2: int = 1,
|
|
314
351
|
):
|
|
315
352
|
super(_DenseBlock, self).__init__()
|
|
316
353
|
for idx_layer in range(num_layers):
|
|
@@ -322,6 +359,11 @@ class _DenseBlock(nn.Sequential):
|
|
|
322
359
|
conv_bias=conv_bias,
|
|
323
360
|
batch_norm=batch_norm,
|
|
324
361
|
activation=activation,
|
|
362
|
+
kernel_size_conv1=kernel_size_conv1,
|
|
363
|
+
kernel_size_conv2=kernel_size_conv2,
|
|
364
|
+
stride_conv1=stride_conv1,
|
|
365
|
+
stride_conv2=stride_conv2,
|
|
366
|
+
padding_conv2=padding_conv2,
|
|
325
367
|
)
|
|
326
368
|
self.add_module(f"denselayer{idx_layer + 1}", layer)
|
|
327
369
|
|
|
@@ -360,6 +402,8 @@ class _TransitionLayer(nn.Sequential):
|
|
|
360
402
|
conv_bias=True,
|
|
361
403
|
batch_norm=True,
|
|
362
404
|
activation: nn.Module = nn.ELU,
|
|
405
|
+
kernel_size_trans: int = 2,
|
|
406
|
+
stride_trans: int = 2,
|
|
363
407
|
):
|
|
364
408
|
super(_TransitionLayer, self).__init__()
|
|
365
409
|
if batch_norm:
|
|
@@ -375,4 +419,6 @@ class _TransitionLayer(nn.Sequential):
|
|
|
375
419
|
bias=conv_bias,
|
|
376
420
|
),
|
|
377
421
|
)
|
|
378
|
-
self.add_module(
|
|
422
|
+
self.add_module(
|
|
423
|
+
"pool", nn.AvgPool1d(kernel_size=kernel_size_trans, stride=stride_trans)
|
|
424
|
+
)
|
braindecode/models/tcn.py
CHANGED
|
@@ -8,7 +8,7 @@ from torch.nn import init
|
|
|
8
8
|
from torch.nn.utils.parametrizations import weight_norm
|
|
9
9
|
|
|
10
10
|
from braindecode.models.base import EEGModuleMixin
|
|
11
|
-
from braindecode.modules import Chomp1d, Ensure4d,
|
|
11
|
+
from braindecode.modules import Chomp1d, Ensure4d, SqueezeFinalOutput
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class BDTCN(EEGModuleMixin, nn.Module):
|
|
@@ -7,6 +7,7 @@ from __future__ import annotations
|
|
|
7
7
|
import torch
|
|
8
8
|
import torch.nn as nn
|
|
9
9
|
from einops.layers.torch import Rearrange
|
|
10
|
+
from mne.utils import warn
|
|
10
11
|
|
|
11
12
|
from braindecode.models.base import EEGModuleMixin
|
|
12
13
|
|
|
@@ -98,20 +99,44 @@ class TSceptionV1(EEGModuleMixin, nn.Module):
|
|
|
98
99
|
|
|
99
100
|
### Layers
|
|
100
101
|
self.ensuredim = Rearrange("batch nchans time -> batch 1 nchans time")
|
|
102
|
+
if self.input_window_seconds < max(self.inception_windows):
|
|
103
|
+
inception_windows = (
|
|
104
|
+
self.input_window_seconds,
|
|
105
|
+
self.input_window_seconds / 2,
|
|
106
|
+
self.input_window_seconds / 4,
|
|
107
|
+
)
|
|
108
|
+
warning_msg = (
|
|
109
|
+
"Input window size is smaller than the maximum inception window size. "
|
|
110
|
+
"We are adjusting the input window size to match the maximum inception window size.\n"
|
|
111
|
+
f"Original input window size: {self.inception_windows}, \n"
|
|
112
|
+
f"Adjusted inception windows: {inception_windows}"
|
|
113
|
+
)
|
|
114
|
+
warn(warning_msg, UserWarning)
|
|
115
|
+
self.inception_windows = inception_windows
|
|
101
116
|
# Define temporal convolutional layers (Tception)
|
|
102
|
-
self.temporal_blocks = nn.ModuleList(
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
117
|
+
self.temporal_blocks = nn.ModuleList()
|
|
118
|
+
for window in self.inception_windows:
|
|
119
|
+
# 1. Calculate the temporal kernel size for this block
|
|
120
|
+
kernel_size_t = int(window * self.sfreq)
|
|
121
|
+
|
|
122
|
+
# 2. Calculate the output length of the convolution
|
|
123
|
+
conv_out_len = self.n_times - kernel_size_t + 1
|
|
124
|
+
|
|
125
|
+
# 3. Ensure the pooling size is not larger than the conv output
|
|
126
|
+
# and is at least 1.
|
|
127
|
+
dynamic_pool_size = max(1, min(self.pool_size, conv_out_len))
|
|
128
|
+
|
|
129
|
+
# 4. Create the block with the dynamic pooling size
|
|
130
|
+
block = self._conv_block(
|
|
131
|
+
in_channels=1,
|
|
132
|
+
out_channels=self.number_filter_temp,
|
|
133
|
+
kernel_size=(1, kernel_size_t),
|
|
134
|
+
stride=1,
|
|
135
|
+
pool_size=dynamic_pool_size, # Use the dynamic size
|
|
136
|
+
activation=self.activation,
|
|
137
|
+
)
|
|
138
|
+
self.temporal_blocks.append(block)
|
|
139
|
+
|
|
115
140
|
self.batch_temporal_lay = nn.BatchNorm2d(self.number_filter_temp)
|
|
116
141
|
|
|
117
142
|
# Define spatial convolutional layers (Sception)
|
braindecode/models/util.py
CHANGED
|
@@ -5,11 +5,7 @@
|
|
|
5
5
|
import inspect
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
|
|
8
|
-
import numpy as np
|
|
9
8
|
import pandas as pd
|
|
10
|
-
import torch
|
|
11
|
-
from scipy.special import log_softmax
|
|
12
|
-
from sklearn.utils import deprecated
|
|
13
9
|
|
|
14
10
|
import braindecode.models as models
|
|
15
11
|
|
|
@@ -76,12 +72,12 @@ models_mandatory_parameters = [
|
|
|
76
72
|
), # 1 channel
|
|
77
73
|
("TIDNet", ["n_chans", "n_outputs", "n_times"], None),
|
|
78
74
|
("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=128.0)),
|
|
79
|
-
("BIOT", ["n_chans", "n_outputs", "sfreq"], None),
|
|
75
|
+
("BIOT", ["n_chans", "n_outputs", "sfreq", "n_times"], None),
|
|
80
76
|
("AttentionBaseNet", ["n_chans", "n_outputs", "n_times"], None),
|
|
81
77
|
("Labram", ["n_chans", "n_outputs", "n_times"], None),
|
|
82
78
|
("EEGSimpleConv", ["n_chans", "n_outputs", "sfreq"], None),
|
|
83
79
|
("SPARCNet", ["n_chans", "n_outputs", "n_times"], None),
|
|
84
|
-
("ContraWR", ["n_chans", "n_outputs", "sfreq"], dict(sfreq=200.0)),
|
|
80
|
+
("ContraWR", ["n_chans", "n_outputs", "sfreq", "n_times"], dict(sfreq=200.0)),
|
|
85
81
|
("EEGNeX", ["n_chans", "n_outputs", "n_times"], None),
|
|
86
82
|
("TSceptionV1", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
|
|
87
83
|
("EEGTCNet", ["n_chans", "n_outputs", "n_times"], None),
|
braindecode/modules/__init__.py
CHANGED
|
@@ -36,3 +36,49 @@ from .stats import (
|
|
|
36
36
|
)
|
|
37
37
|
from .util import aggregate_probas
|
|
38
38
|
from .wrapper import Expression, IntermediateOutputWrapper
|
|
39
|
+
|
|
40
|
+
__all__ = [
|
|
41
|
+
"LogActivation",
|
|
42
|
+
"SafeLog",
|
|
43
|
+
"CAT",
|
|
44
|
+
"CBAM",
|
|
45
|
+
"ECA",
|
|
46
|
+
"FCA",
|
|
47
|
+
"GCT",
|
|
48
|
+
"SRM",
|
|
49
|
+
"CATLite",
|
|
50
|
+
"EncNet",
|
|
51
|
+
"GatherExcite",
|
|
52
|
+
"GSoP",
|
|
53
|
+
"MultiHeadAttention",
|
|
54
|
+
"SqueezeAndExcitation",
|
|
55
|
+
"MLP",
|
|
56
|
+
"FeedForwardBlock",
|
|
57
|
+
"InceptionBlock",
|
|
58
|
+
"AvgPool2dWithConv",
|
|
59
|
+
"CausalConv1d",
|
|
60
|
+
"CombinedConv",
|
|
61
|
+
"Conv2dWithConstraint",
|
|
62
|
+
"DepthwiseConv2d",
|
|
63
|
+
"FilterBankLayer",
|
|
64
|
+
"GeneralizedGaussianFilter",
|
|
65
|
+
"Chomp1d",
|
|
66
|
+
"DropPath",
|
|
67
|
+
"Ensure4d",
|
|
68
|
+
"SqueezeFinalOutput",
|
|
69
|
+
"TimeDistributed",
|
|
70
|
+
"LinearWithConstraint",
|
|
71
|
+
"MaxNormLinear",
|
|
72
|
+
"MaxNorm",
|
|
73
|
+
"MaxNormParametrize",
|
|
74
|
+
"LogPowerLayer",
|
|
75
|
+
"LogVarLayer",
|
|
76
|
+
"MaxLayer",
|
|
77
|
+
"MeanLayer",
|
|
78
|
+
"StatLayer",
|
|
79
|
+
"StdLayer",
|
|
80
|
+
"VarLayer",
|
|
81
|
+
"aggregate_probas",
|
|
82
|
+
"Expression",
|
|
83
|
+
"IntermediateOutputWrapper",
|
|
84
|
+
]
|
braindecode/modules/filter.py
CHANGED
|
@@ -1,18 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from functools import partial
|
|
4
3
|
from typing import Optional
|
|
5
4
|
|
|
6
5
|
import torch
|
|
7
|
-
from einops.layers.torch import Rearrange
|
|
8
6
|
from mne.filter import _check_coefficients, create_filter
|
|
9
7
|
from mne.utils import warn
|
|
10
8
|
from torch import Tensor, from_numpy, nn
|
|
11
9
|
from torch.fft import fftfreq
|
|
12
10
|
from torchaudio.functional import fftconvolve, filtfilt
|
|
13
11
|
|
|
14
|
-
import braindecode.functional as F
|
|
15
|
-
|
|
16
12
|
|
|
17
13
|
class FilterBankLayer(nn.Module):
|
|
18
14
|
"""Apply multiple band-pass filters to generate multiview signal representation.
|