braindecode 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- braindecode/augmentation/transforms.py +0 -1
- braindecode/datautil/__init__.py +3 -0
- braindecode/datautil/serialization.py +13 -2
- braindecode/functional/__init__.py +12 -0
- braindecode/functional/functions.py +0 -1
- braindecode/models/__init__.py +48 -0
- braindecode/models/atcnet.py +46 -11
- braindecode/models/attentionbasenet.py +49 -0
- braindecode/models/biot.py +29 -8
- braindecode/models/contrawr.py +29 -8
- braindecode/models/ctnet.py +99 -13
- braindecode/models/deep4.py +52 -2
- braindecode/models/eegconformer.py +2 -3
- braindecode/models/eeginception_mi.py +9 -3
- braindecode/models/eegitnet.py +0 -1
- braindecode/models/eegminer.py +0 -1
- braindecode/models/eegnet.py +0 -1
- braindecode/models/fbcnet.py +1 -1
- braindecode/models/fbmsnet.py +0 -1
- braindecode/models/labram.py +23 -3
- braindecode/models/msvtnet.py +1 -1
- braindecode/models/sccnet.py +29 -4
- braindecode/models/signal_jepa.py +0 -1
- braindecode/models/sleep_stager_eldele_2021.py +0 -1
- braindecode/models/sparcnet.py +62 -16
- braindecode/models/tcn.py +1 -1
- braindecode/models/tsinception.py +38 -13
- braindecode/models/util.py +2 -6
- braindecode/modules/__init__.py +46 -0
- braindecode/modules/filter.py +0 -4
- braindecode/modules/layers.py +3 -5
- braindecode/modules/linear.py +1 -2
- braindecode/modules/util.py +0 -1
- braindecode/modules/wrapper.py +0 -2
- braindecode/samplers/base.py +0 -2
- braindecode/version.py +1 -1
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/METADATA +5 -5
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/RECORD +42 -42
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/WHEEL +0 -0
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.0.0.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/datautil/serialization.py
CHANGED
@@ -107,7 +107,14 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=N
 def _load_signals_and_description(path, preload, is_raw, ids_to_load=None):
     all_signals = []
     file_name = "{}-raw.fif" if is_raw else "{}-epo.fif"
-    description_df = pd.read_json(
+    description_df = pd.read_json(
+        path / "description.json", typ="series", convert_dates=False
+    )
+
+    if "timestamp" in description_df.index:
+        timestamp_numeric = pd.to_numeric(description_df["timestamp"])
+        description_df["timestamp"] = pd.to_datetime(timestamp_numeric)
+
     if ids_to_load is None:
         file_names = path.glob(f"*{file_name.lstrip('{}')}")
         # Extract ids, e.g.,
@@ -242,7 +249,11 @@ def _load_parallel(path, i, preload, is_raw, has_stored_windows):
     signals = _load_signals(fif_file_path, preload, is_raw)
 
     description_file_path = sub_dir / "description.json"
-    description = pd.read_json(description_file_path, typ="series")
+    description = pd.read_json(description_file_path, typ="series", convert_dates=False)
+
+    # if 'timestamp' in description.index:
+    #     timestamp_numeric = pd.to_numeric(description['timestamp'])
+    #     description['timestamp'] = pd.to_datetime(timestamp_numeric, unit='s')
 
     target_file_path = sub_dir / "target_name.json"
     target_name = None
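The `convert_dates=False` change works around pandas' automatic date inference when the stored description is read back; the timestamp is kept numeric and converted explicitly afterwards. A minimal sketch of that round trip (illustrative values, not braindecode code):

# Sketch of the round trip addressed by `convert_dates=False`: read_json would
# otherwise try to auto-convert date-like fields, so the numeric timestamp is
# converted to a datetime explicitly after loading.
from io import StringIO

import pandas as pd

description = pd.Series({"subject": "S01", "timestamp": 1_600_000_000_000_000_000})
restored = pd.read_json(StringIO(description.to_json()), typ="series", convert_dates=False)

if "timestamp" in restored.index:
    timestamp_numeric = pd.to_numeric(restored["timestamp"])
    restored["timestamp"] = pd.to_datetime(timestamp_numeric)

print(restored["timestamp"])  # nanosecond epoch value restored as a pandas Timestamp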
braindecode/functional/__init__.py
CHANGED
@@ -8,3 +8,15 @@ from .functions import (
     square,
 )
 from .initialization import glorot_weight_zero_bias, rescale_parameter
+
+__all__ = [
+    "_get_gaussian_kernel1d",
+    "drop_path",
+    "hilbert_freq",
+    "identity",
+    "plv_time",
+    "safe_log",
+    "square",
+    "glorot_weight_zero_bias",
+    "rescale_parameter",
+]
braindecode/models/__init__.py
CHANGED
@@ -50,3 +50,51 @@ from .util import _init_models_dict, models_mandatory_parameters
 # Call this last in order to make sure the dataset list is populated with
 # the models imported in this file.
 _init_models_dict()
+
+__all__ = [
+    "ATCNet",
+    "AttentionBaseNet",
+    "EEGModuleMixin",
+    "BIOT",
+    "ContraWR",
+    "CTNet",
+    "Deep4Net",
+    "DeepSleepNet",
+    "EEGConformer",
+    "EEGInceptionERP",
+    "EEGInceptionMI",
+    "EEGITNet",
+    "EEGMiner",
+    "EEGNetv1",
+    "EEGNetv4",
+    "EEGNeX",
+    "EEGResNet",
+    "EEGSimpleConv",
+    "EEGTCNet",
+    "FBCNet",
+    "FBLightConvNet",
+    "FBMSNet",
+    "HybridNet",
+    "IFNet",
+    "Labram",
+    "MSVTNet",
+    "SCCNet",
+    "ShallowFBCSPNet",
+    "SignalJEPA",
+    "SignalJEPA_Contextual",
+    "SignalJEPA_PostLocal",
+    "SignalJEPA_PreLocal",
+    "SincShallowNet",
+    "SleepStagerBlanco2020",
+    "SleepStagerChambon2018",
+    "SleepStagerEldele2021",
+    "SPARCNet",
+    "SyncNet",
+    "BDTCN",
+    "TCN",
+    "TIDNet",
+    "TSceptionV1",
+    "USleep",
+    "_init_models_dict",
+    "models_mandatory_parameters",
+]
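The new `__all__` makes the public model list explicit, so imports such as the following (illustrative) are the supported surface:

from braindecode.models import ATCNet, CTNet, Labram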
braindecode/models/atcnet.py
CHANGED
@@ -5,6 +5,7 @@ import math
 
 import torch
 from einops.layers.torch import Rearrange
+from mne.utils import warn
 from torch import nn
 
 from braindecode.models.base import EEGModuleMixin
@@ -69,9 +70,6 @@ class ATCNet(EEGModuleMixin, nn.Module):
     tcn_kernel_size : int
         Temporal kernel size used in TCN block, denoted Kt in table 1 of the
         paper [1]_. Defaults to 4 as in [1]_.
-    tcn_n_filters : int
-        Number of filters used in TCN convolutional layers (Ft). Defaults to
-        32 as in [1]_.
     tcn_dropout : float
         Dropout probability used in the TCN block, denoted pt in table 1
         of the paper [1]_. Defaults to 0.3 as in [1]_.
@@ -117,7 +115,6 @@ class ATCNet(EEGModuleMixin, nn.Module):
         att_drop_prob=0.5,
         tcn_depth=2,
         tcn_kernel_size=4,
-        tcn_n_filters=32,
         tcn_drop_prob=0.3,
         tcn_activation: nn.Module = nn.ELU,
         concat=False,
@@ -134,6 +131,45 @@ class ATCNet(EEGModuleMixin, nn.Module):
             sfreq=sfreq,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+        # Validate and adjust parameters based on input size
+
+        min_len_tcn = (tcn_kernel_size - 1) * (2 ** (tcn_depth - 1)) + 1
+        # Minimum length required to get at least one sliding window
+        min_len_sliding = n_windows + min_len_tcn - 1
+        # Minimum input size that produces the required feature map length
+        min_n_times = min_len_sliding * conv_block_pool_size_1 * conv_block_pool_size_2
+
+        # 2. If the input is shorter, calculate a scaling factor
+        if self.n_times < min_n_times:
+            scaling_factor = self.n_times / min_n_times
+            warn(
+                f"n_times ({self.n_times}) is smaller than the minimum required "
+                f"({min_n_times}) for the current model parameters configuration. "
+                "Adjusting parameters to ensure compatibility."
+                "Reducing the kernel, pooling, and stride sizes accordingly."
+                "Scaling factor: {:.2f}".format(scaling_factor),
+                UserWarning,
+            )
+            conv_block_kernel_length_1 = max(
+                1, int(conv_block_kernel_length_1 * scaling_factor)
+            )
+            conv_block_kernel_length_2 = max(
+                1, int(conv_block_kernel_length_2 * scaling_factor)
+            )
+            conv_block_pool_size_1 = max(
+                1, int(conv_block_pool_size_1 * scaling_factor)
+            )
+            conv_block_pool_size_2 = max(
+                1, int(conv_block_pool_size_2 * scaling_factor)
+            )
+
+            # n_windows should be at least 1
+            n_windows = max(1, int(n_windows * scaling_factor))
+
+            # tcn_kernel_size must be at least 2 for dilation to work
+            tcn_kernel_size = max(2, int(tcn_kernel_size * scaling_factor))
+
         self.conv_block_n_filters = conv_block_n_filters
         self.conv_block_kernel_length_1 = conv_block_kernel_length_1
         self.conv_block_kernel_length_2 = conv_block_kernel_length_2
@@ -147,12 +183,11 @@ class ATCNet(EEGModuleMixin, nn.Module):
         self.att_dropout = att_drop_prob
         self.tcn_depth = tcn_depth
         self.tcn_kernel_size = tcn_kernel_size
-        self.tcn_n_filters = tcn_n_filters
        self.tcn_dropout = tcn_drop_prob
         self.tcn_activation = tcn_activation
         self.concat = concat
         self.max_norm_const = max_norm_const
-
+        self.tcn_n_filters = int(self.conv_block_depth_mult * self.conv_block_n_filters)
         map = dict()
         for w in range(self.n_windows):
             map[f"max_norm_linears.[{w}].weight"] = f"final_layer.[{w}].weight"
@@ -197,13 +232,13 @@ class ATCNet(EEGModuleMixin, nn.Module):
             *[
                 _TCNResidualBlock(
                     in_channels=self.F2,
-                    kernel_size=tcn_kernel_size,
-                    n_filters=tcn_n_filters,
-                    dropout=
-                    activation=tcn_activation,
+                    kernel_size=self.tcn_kernel_size,
+                    n_filters=self.tcn_n_filters,
+                    dropout=self.tcn_dropout,
+                    activation=self.tcn_activation,
                     dilation=2**i,
                 )
-                for i in range(tcn_depth)
+                for i in range(self.tcn_depth)
             ]
         )
         for _ in range(self.n_windows)
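The block added to `ATCNet.__init__` derives the shortest input that still leaves one valid sliding window after both pooling stages and the TCN receptive field, then shrinks kernels, pools and window count when `n_times` falls short. A worked example of that arithmetic; `tcn_kernel_size=4` and `tcn_depth=2` are the defaults visible in the diff, while `n_windows=5` and the pool sizes 8/7 are assumed values for illustration only:

# Worked example of the new minimum-length check in ATCNet.__init__
# (n_windows and the pool sizes below are assumptions, not confirmed defaults).
tcn_kernel_size, tcn_depth, n_windows = 4, 2, 5
conv_block_pool_size_1, conv_block_pool_size_2 = 8, 7

min_len_tcn = (tcn_kernel_size - 1) * (2 ** (tcn_depth - 1)) + 1  # 7, TCN receptive field
min_len_sliding = n_windows + min_len_tcn - 1                     # 11, one step per window
min_n_times = min_len_sliding * conv_block_pool_size_1 * conv_block_pool_size_2  # 616 samples

n_times = 400  # an input shorter than the 616 samples required
scaling_factor = n_times / min_n_times  # ~0.65, used to shrink kernels, pools and n_windows
print(min_n_times, round(scaling_factor, 2))  # 616 0.65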
braindecode/models/attentionbasenet.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 import math
 
 from einops.layers.torch import Rearrange
+from mne.utils import warn
 from torch import nn
 
 from braindecode.models.base import EEGModuleMixin
@@ -162,6 +163,33 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
         )
         del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
 
+        min_n_times_required = self._get_min_n_times(
+            pool_length_inp,
+            pool_stride_inp,
+            pool_length,
+        )
+
+        if self.n_times < min_n_times_required:
+            scaling_factor = self.n_times / min_n_times_required
+            warn(
+                f"n_times ({self.n_times}) is smaller than the minimum required "
+                f"({min_n_times_required}) for the current model parameters configuration. "
+                "Adjusting parameters to ensure compatibility."
+                "Reducing the kernel, pooling, and stride sizes accordingly.\n"
+                "Scaling factor: {:.2f}".format(scaling_factor),
+                UserWarning,
+            )
+            # 3. Scale down all temporal parameters proportionally
+            # Use max(1, ...) to ensure parameters remain valid
+            temp_filter_length_inp = max(
+                1, int(temp_filter_length_inp * scaling_factor)
+            )
+            pool_length_inp = max(1, int(pool_length_inp * scaling_factor))
+            pool_stride_inp = max(1, int(pool_stride_inp * scaling_factor))
+            temp_filter_length = max(1, int(temp_filter_length * scaling_factor))
+            pool_length = max(1, int(pool_length * scaling_factor))
+            pool_stride = max(1, int(pool_stride * scaling_factor))
+
         self.input_block = _FeatureExtractor(
             n_chans=self.n_chans,
             n_temporal_filters=n_temporal_filters,
@@ -231,6 +259,27 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
             seq_lengths.append(int(out))
         return seq_lengths
 
+    @staticmethod
+    def _get_min_n_times(
+        pool_length_inp: int,
+        pool_stride_inp: int,
+        pool_length: int,
+    ) -> int:
+        """
+        Calculates the minimum n_times required for the model to work
+        with the given parameters.
+
+        The calculation is based on reversing the pooling operations to
+        ensure the input to each is valid.
+        """
+        # The input to the second pooling layer must be at least its kernel size.
+        min_len_for_second_pool = pool_length
+
+        # Reverse the first pooling operation to find the required input size.
+        # Formula: min_L_in = Stride * (min_L_out - 1) + Kernel
+        min_len = pool_stride_inp * (min_len_for_second_pool - 1) + pool_length_inp
+        return min_len
+
 
 class _FeatureExtractor(nn.Module):
     """
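`_get_min_n_times` inverts the two pooling stages: the second pool needs at least `pool_length` samples at its input, and inverting the first pool uses min_L_in = stride * (min_L_out - 1) + kernel. A small worked example; the parameter values are chosen for illustration and are not necessarily the class defaults:

# Worked example of the reverse-pooling formula used by _get_min_n_times
# (pool_length_inp=75, pool_stride_inp=15, pool_length=8 are illustrative values).
pool_length_inp, pool_stride_inp = 75, 15
pool_length = 8

# The second pooling layer needs at least `pool_length` samples at its input.
min_len_for_second_pool = pool_length

# Invert the first pooling stage: min_L_in = stride * (min_L_out - 1) + kernel
min_n_times = pool_stride_inp * (min_len_for_second_pool - 1) + pool_length_inp
print(min_n_times)  # 15 * 7 + 75 = 180 samples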
braindecode/models/biot.py
CHANGED
@@ -87,6 +87,10 @@ class BIOT(EEGModuleMixin, nn.Module):
         input_window_seconds=None,
         activation: nn.Module = nn.ELU,
         drop_prob: float = 0.5,
+        # Parameters for the encoder
+        max_seq_len: int = 1024,
+        attn_dropout=0.2,
+        attn_layer_dropout=0.2,
     ):
         super().__init__(
             n_outputs=n_outputs,
@@ -123,14 +127,29 @@ class BIOT(EEGModuleMixin, nn.Module):
                 UserWarning,
             )
         hop_length = self.sfreq // 2
+
+        if self.input_window_seconds < 1.0:
+            warning_msg = (
+                "The input window is less than 1 second, which may not be "
+                "sufficient for the model to learn meaningful representations."
+                "Changing the `n_fft` to `n_times`."
+            )
+            warn(warning_msg, UserWarning)
+            self.n_fft = self.n_times
+        else:
+            self.n_fft = int(self.sfreq)
+
         self.encoder = _BIOTEncoder(
             emb_size=emb_size,
             att_num_heads=att_num_heads,
             n_layers=n_layers,
             n_chans=self.n_chans,
-            n_fft=self.
+            n_fft=self.n_fft,
             hop_length=hop_length,
             drop_prob=drop_prob,
+            max_seq_len=max_seq_len,
+            attn_dropout=attn_dropout,
+            attn_layer_dropout=attn_layer_dropout,
         )
 
         self.final_layer = _ClassificationHead(
@@ -231,12 +250,11 @@ class _ClassificationHead(nn.Sequential):
 
     def __init__(self, emb_size: int, n_outputs: int, activation: nn.Module = nn.ELU):
         super().__init__()
-        self.
-
-            nn.Linear(emb_size, n_outputs),
-        )
+        self.activation_layer = activation()
+        self.classification_head = nn.Linear(emb_size, n_outputs)
 
     def forward(self, x):
+        x = self.activation_layer(x)
         out = self.classification_head(x)
         return out
 
@@ -344,6 +362,9 @@ class _BIOTEncoder(nn.Module):
         n_fft=200,  # Related with the frequency resolution
         hop_length=100,
         drop_prob: float = 0.1,
+        max_seq_len: int = 1024,  # The maximum sequence length
+        attn_dropout=0.2,  # dropout post-attention
+        attn_layer_dropout=0.2,  # dropout right after self-attention layer
     ):
         super().__init__()
 
@@ -357,9 +378,9 @@ class _BIOTEncoder(nn.Module):
             dim=emb_size,
            heads=att_num_heads,
             depth=n_layers,
-            max_seq_len=
-            attn_layer_dropout=
-            attn_dropout=
+            max_seq_len=max_seq_len,
+            attn_layer_dropout=attn_layer_dropout,
+            attn_dropout=attn_dropout,
         )
         self.positional_encoding = _PositionalEncoding(emb_size, drop_prob=drop_prob)
 
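The new branch keeps the STFT size within the available samples: for windows shorter than one second, `n_fft` falls back to `n_times` instead of `sfreq`. A hedged sketch of that selection with illustrative numbers, not the model's own code:

# Sketch of the n_fft fallback mirrored from the BIOT change: pick n_fft so the
# FFT size never exceeds the number of samples in a short window.
import torch

sfreq, n_times = 200, 100  # a 0.5 s window, illustrative values
x = torch.randn(1, n_times)

n_fft = n_times if n_times / sfreq < 1.0 else int(sfreq)
spec = torch.stft(
    x,
    n_fft=n_fft,
    hop_length=sfreq // 2,
    window=torch.hann_window(n_fft),
    return_complex=True,
)
print(spec.shape)  # (1, n_fft // 2 + 1, n_frames)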
braindecode/models/contrawr.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import torch
 import torch.nn as nn
+from mne.utils import warn
 
 from braindecode.models.base import EEGModuleMixin
 
@@ -57,6 +58,9 @@ class ContraWR(EEGModuleMixin, nn.Module):
         steps=20,
         activation: nn.Module = nn.ELU,
         drop_prob: float = 0.5,
+        stride_res: int = 2,
+        kernel_size_res: int = 3,
+        padding_res: int = 1,
         # Another way to pass the EEG parameters
         chs_info=None,
         n_times=None,
@@ -74,7 +78,17 @@
         if not isinstance(res_channels, list):
             raise ValueError("res_channels must be a list of integers.")
 
-        self.
+        if self.input_window_seconds < 1.0:
+            warning_msg = (
+                "The input window is less than 1 second, which may not be "
+                "sufficient for the model to learn meaningful representations."
+                "changing the `n_fft` to `n_times`."
+            )
+            warn(warning_msg, UserWarning)
+            self.n_fft = self.n_times
+        else:
+            self.n_fft = int(self.sfreq)
+
         self.steps = steps
 
         res_channels = [self.n_chans] + res_channels + [emb_size]
@@ -89,19 +103,22 @@
                 _ResBlock(
                     in_channels=res_channels[i],
                     out_channels=res_channels[i + 1],
-                    stride=
+                    stride=stride_res,
                     use_downsampling=True,
                     pooling=True,
                     drop_prob=drop_prob,
+                    kernel_size=kernel_size_res,
+                    padding=padding_res,
+                    activation=activation,
                 )
                 for i in range(len(res_channels) - 1)
             ]
         )
+        self.adaptative_pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.flatten_layer = nn.Flatten()
 
-        self.
-
-            nn.Linear(emb_size, self.n_outputs),
-        )
+        self.activation_layer = activation()
+        self.final_layer = nn.Linear(emb_size, self.n_outputs)
 
     def forward(self, X: torch.Tensor) -> torch.Tensor:
         """
@@ -118,9 +135,13 @@
         """
         X = self.torch_stft(X)
 
-        for conv in self.convs
+        for conv in self.convs:
             X = conv.forward(X)
-
+
+        emb = self.adaptative_pool(X)
+        emb = self.flatten_layer(emb)
+        emb = self.activation_layer(emb)
+
         return self.final_layer(emb)
 
 
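`ContraWR.forward` now finishes with global average pooling, flatten, activation and the linear head added above. A self-contained sketch of that head on a dummy feature map; the shapes and layer ordering are illustrative, mirroring the layers in the diff:

# Sketch of the new pooling/classification head (illustrative shapes only).
import torch
from torch import nn

emb_size, n_outputs = 256, 4
features = torch.randn(8, emb_size, 6, 5)  # assumed (batch, channels, freq, time) after the conv stack

head = nn.Sequential(
    nn.AdaptiveAvgPool2d((1, 1)),  # -> (8, emb_size, 1, 1)
    nn.Flatten(),                  # -> (8, emb_size)
    nn.ELU(),                      # default activation per the signature above
    nn.Linear(emb_size, n_outputs),
)
print(head(features).shape)  # torch.Size([8, 4])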
braindecode/models/ctnet.py
CHANGED
@@ -10,6 +10,7 @@ classification from Wei Zhao et al. (2024).
 from __future__ import annotations
 
 import math
+from typing import Optional
 
 import torch
 from einops.layers.torch import Rearrange
@@ -57,7 +58,7 @@ class CTNet(EEGModuleMixin, nn.Module):
         Activation function to use in the network.
     heads : int, default=4
         Number of attention heads in the Transformer encoder.
-    emb_size : int, default=
+    emb_size : int or None, default=None
         Embedding size (dimensionality) for the Transformer encoder.
     depth : int, default=6
         Number of encoder layers in the Transformer.
@@ -110,11 +111,11 @@ class CTNet(EEGModuleMixin, nn.Module):
         drop_prob_final: float = 0.5,
         # other parameters
         heads: int = 4,
-        emb_size: int = 40,
+        emb_size: Optional[int] = 40,
         depth: int = 6,
-        n_filters_time: int =
+        n_filters_time: Optional[int] = None,
         kernel_size: int = 64,
-        depth_multiplier: int = 2,
+        depth_multiplier: Optional[int] = 2,
         pool_size_1: int = 8,
         pool_size_2: int = 8,
     ):
@@ -128,21 +129,18 @@ class CTNet(EEGModuleMixin, nn.Module):
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
 
-        self.emb_size = emb_size
         self.activation_patch = activation_patch
         self.activation_transformer = activation_transformer
-
-        self.n_filters_time = n_filters_time
         self.drop_prob_cnn = drop_prob_cnn
         self.pool_size_1 = pool_size_1
         self.pool_size_2 = pool_size_2
-        self.depth_multiplier = depth_multiplier
         self.kernel_size = kernel_size
         self.drop_prob_posi = drop_prob_posi
         self.drop_prob_final = drop_prob_final
-
+        self.heads = heads
+        self.depth = depth
         # n_times - pool_size_1 / p
-        sequence_length = math.floor(
+        self.sequence_length = math.floor(
             (
                 math.floor((self.n_times - self.pool_size_1) / self.pool_size_1 + 1)
                 - self.pool_size_2
@@ -151,6 +149,10 @@ class CTNet(EEGModuleMixin, nn.Module):
             + 1
         )
 
+        self.depth_multiplier, self.n_filters_time, self.emb_size = self._resolve_dims(
+            depth_multiplier, n_filters_time, emb_size
+        )
+
         # Layers
         self.ensuredim = Rearrange("batch nchans time -> batch 1 nchans time")
         self.flatten = nn.Flatten()
@@ -167,14 +169,17 @@ class CTNet(EEGModuleMixin, nn.Module):
         )
 
         self.position = _PositionalEncoding(
-            emb_size=emb_size,
+            emb_size=self.emb_size,
             drop_prob=self.drop_prob_posi,
             n_times=self.n_times,
             pool_size=self.pool_size_1,
         )
 
         self.trans = _TransformerEncoder(
-            heads,
+            self.heads,
+            self.depth,
+            self.emb_size,
+            activation=self.activation_transformer,
         )
 
         self.flatten_drop_layer = nn.Sequential(
@@ -183,7 +188,8 @@ class CTNet(EEGModuleMixin, nn.Module):
         )
 
         self.final_layer = nn.Linear(
-            in_features=emb_size * sequence_length,
+            in_features=int(self.emb_size * self.sequence_length),
+            out_features=self.n_outputs,
         )
 
     def forward(self, x: Tensor) -> Tensor:
@@ -210,6 +216,86 @@ class CTNet(EEGModuleMixin, nn.Module):
         out = self.final_layer(flatten_feature)
         return out
 
+    @staticmethod
+    def _resolve_dims(
+        depth_multiplier: Optional[int],
+        n_filters_time: Optional[int],
+        emb_size: Optional[int],
+    ) -> tuple[int, int, int]:
+        # Basic type/positivity checks for provided values
+        for name, val in (
+            ("depth_multiplier", depth_multiplier),
+            ("n_filters_time", n_filters_time),
+            ("emb_size", emb_size),
+        ):
+            if val is not None:
+                if not isinstance(val, int):
+                    raise TypeError(f"{name} must be int, got {type(val).__name__}")
+                if val <= 0:
+                    raise ValueError(f"{name} must be > 0, got {val}")
+
+        missing = [
+            k
+            for k, v in {
+                "depth_multiplier": depth_multiplier,
+                "n_filters_time": n_filters_time,
+                "emb_size": emb_size,
+            }.items()
+            if v is None
+        ]
+
+        if len(missing) >= 2:
+            # Too many unknowns → ambiguous
+            raise ValueError(
+                "Specify exactly two of {depth_multiplier, n_filters_time, emb_size}; the third will be inferred."
+            )
+
+        if len(missing) == 1:
+            # Infer the missing one
+            if missing[0] == "emb_size":
+                assert depth_multiplier is not None and n_filters_time is not None
+                emb_size = depth_multiplier * n_filters_time
+            elif missing[0] == "n_filters_time":
+                assert emb_size is not None and depth_multiplier is not None
+                if emb_size % depth_multiplier != 0:
+                    raise ValueError(
+                        f"emb_size={emb_size} must be divisible by depth_multiplier={depth_multiplier}"
+                    )
+                n_filters_time = emb_size // depth_multiplier
+            else:  # missing depth_multiplier
+                assert emb_size is not None and n_filters_time is not None
+                if emb_size % n_filters_time != 0:
+                    raise ValueError(
+                        f"emb_size={emb_size} must be divisible by n_filters_time={n_filters_time}"
+                    )
+                depth_multiplier = emb_size // n_filters_time
+
+        else:
+            # All provided: enforce consistency
+            assert (
+                depth_multiplier is not None
+                and n_filters_time is not None
+                and emb_size is not None
+            )
+            prod = depth_multiplier * n_filters_time
+            if prod != emb_size:
+                raise ValueError(
+                    "`depth_multiplier * n_filters_time` must equal `emb_size`, "
+                    f"but got {depth_multiplier} * {n_filters_time} = {prod} != {emb_size}. "
+                    "Fix by setting one of: "
+                    f"emb_size={prod}, "
+                    f"n_filters_time={emb_size // depth_multiplier if emb_size % depth_multiplier == 0 else 'not integer'}, "
+                    f"depth_multiplier={emb_size // n_filters_time if emb_size % n_filters_time == 0 else 'not integer'}."
+                )
+
+        # Ensure plain ints for the return type
+        assert (
+            depth_multiplier is not None
+            and n_filters_time is not None
+            and emb_size is not None
+        )
+        return depth_multiplier, n_filters_time, emb_size
+
 
 class _PatchEmbeddingEEGNet(nn.Module):
     def __init__(
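`_resolve_dims` enforces `emb_size = depth_multiplier * n_filters_time`, inferring whichever of the three is left as `None` and rejecting inconsistent combinations. A hedged illustration of the three resolutions; calling the private helper directly on the class is for demonstration only:

# Illustration of the new dimension resolution in CTNet (demonstration only).
from braindecode.models import CTNet

print(CTNet._resolve_dims(depth_multiplier=2, n_filters_time=20, emb_size=None))   # (2, 20, 40)
print(CTNet._resolve_dims(depth_multiplier=2, n_filters_time=None, emb_size=40))   # (2, 20, 40)
print(CTNet._resolve_dims(depth_multiplier=None, n_filters_time=20, emb_size=40))  # (2, 20, 40)
# Inconsistent values raise a ValueError: 2 * 16 != 40
# CTNet._resolve_dims(depth_multiplier=2, n_filters_time=16, emb_size=40)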