braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/models/util.py
CHANGED
|
@@ -3,163 +3,127 @@
|
|
|
3
3
|
#
|
|
4
4
|
# License: BSD (3-clause)
|
|
5
5
|
import inspect
|
|
6
|
+
from pathlib import Path
|
|
6
7
|
|
|
7
|
-
import
|
|
8
|
-
import torch
|
|
9
|
-
from scipy.special import log_softmax
|
|
10
|
-
from sklearn.utils import deprecated
|
|
8
|
+
import pandas as pd
|
|
11
9
|
|
|
12
10
|
import braindecode.models as models
|
|
13
11
|
|
|
12
|
+
# Registry of model classes keyed by the name under which
# ``braindecode.models`` exposes them; filled in by ``_init_models_dict()``.
models_dict = {}
|
|
14
13
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
)
|
|
19
|
-
def to_dense_prediction_model(model, axis=(2, 3)):
|
|
20
|
-
"""
|
|
21
|
-
Transform a sequential model with strides to a model that outputs
|
|
22
|
-
dense predictions by removing the strides and instead inserting dilations.
|
|
23
|
-
Modifies model in-place.
|
|
24
|
-
|
|
25
|
-
Parameters
|
|
26
|
-
----------
|
|
27
|
-
model: torch.nn.Module
|
|
28
|
-
Model which modules will be modified
|
|
29
|
-
axis: int or (int,int)
|
|
30
|
-
Axis to transform (in terms of intermediate output axes)
|
|
31
|
-
can either be 2, 3, or (2,3).
|
|
32
|
-
|
|
33
|
-
Notes
|
|
34
|
-
-----
|
|
35
|
-
Does not yet work correctly for average pooling.
|
|
36
|
-
Prior to version 0.1.7, there had been a bug that could move strides
|
|
37
|
-
backwards one layer.
|
|
38
|
-
|
|
39
|
-
"""
|
|
40
|
-
if not hasattr(axis, "__len__"):
|
|
41
|
-
axis = [axis]
|
|
42
|
-
assert all([ax in [2, 3] for ax in axis]), "Only 2 and 3 allowed for axis"
|
|
43
|
-
axis = np.array(axis) - 2
|
|
44
|
-
stride_so_far = np.array([1, 1])
|
|
45
|
-
for module in model.modules():
|
|
46
|
-
if hasattr(module, "dilation"):
|
|
47
|
-
assert module.dilation == 1 or (module.dilation == (1, 1)), (
|
|
48
|
-
"Dilation should equal 1 before conversion, maybe the model is "
|
|
49
|
-
"already converted?"
|
|
50
|
-
)
|
|
51
|
-
new_dilation = [1, 1]
|
|
52
|
-
for ax in axis:
|
|
53
|
-
new_dilation[ax] = int(stride_so_far[ax])
|
|
54
|
-
module.dilation = tuple(new_dilation)
|
|
55
|
-
if hasattr(module, "stride"):
|
|
56
|
-
if not hasattr(module.stride, "__len__"):
|
|
57
|
-
module.stride = (module.stride, module.stride)
|
|
58
|
-
stride_so_far *= np.array(module.stride)
|
|
59
|
-
new_stride = list(module.stride)
|
|
60
|
-
for ax in axis:
|
|
61
|
-
new_stride[ax] = 1
|
|
62
|
-
module.stride = tuple(new_stride)
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
@deprecated(
|
|
66
|
-
"will be removed in version 1.0. Use EEGModuleMixin.get_output_shape method directly on the "
|
|
67
|
-
"model object."
|
|
68
|
-
)
|
|
69
|
-
def get_output_shape(model, in_chans, input_window_samples):
|
|
70
|
-
"""Returns shape of neural network output for batch size equal 1.
|
|
71
|
-
|
|
72
|
-
Returns
|
|
73
|
-
-------
|
|
74
|
-
output_shape: tuple
|
|
75
|
-
shape of the network output for `batch_size==1` (1, ...)
|
|
76
|
-
"""
|
|
77
|
-
with torch.no_grad():
|
|
78
|
-
dummy_input = torch.ones(
|
|
79
|
-
1, in_chans, input_window_samples,
|
|
80
|
-
dtype=next(model.parameters()).dtype,
|
|
81
|
-
device=next(model.parameters()).device,
|
|
82
|
-
)
|
|
83
|
-
output_shape = model(dummy_input).shape
|
|
84
|
-
return output_shape
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
def _pad_shift_array(x, stride=1):
|
|
88
|
-
"""Zero-pad and shift rows of a 3D array.
|
|
89
|
-
|
|
90
|
-
E.g., used to align predictions of corresponding windows in
|
|
91
|
-
sequence-to-sequence models.
|
|
92
|
-
|
|
93
|
-
Parameters
|
|
94
|
-
----------
|
|
95
|
-
x : np.ndarray
|
|
96
|
-
Array of shape (n_rows, n_classes, n_windows).
|
|
97
|
-
stride : int
|
|
98
|
-
Number of non-overlapping elements between two consecutive sequences.
|
|
99
|
-
|
|
100
|
-
Returns
|
|
101
|
-
-------
|
|
102
|
-
np.ndarray :
|
|
103
|
-
Array of shape (n_rows, n_classes, (n_rows - 1) * stride + n_windows)
|
|
104
|
-
where each row is obtained by zero-padding the corresponding row in
|
|
105
|
-
``x`` before and after in the last dimension.
|
|
106
|
-
"""
|
|
107
|
-
if x.ndim != 3:
|
|
108
|
-
raise NotImplementedError(
|
|
109
|
-
'x must be of shape (n_rows, n_classes, n_windows), got '
|
|
110
|
-
f'{x.shape}')
|
|
111
|
-
x_padded = np.pad(x, ((0, 0), (0, 0), (0, (x.shape[0] - 1) * stride)))
|
|
112
|
-
orig_strides = x_padded.strides
|
|
113
|
-
new_strides = (orig_strides[0] - stride * orig_strides[2],
|
|
114
|
-
orig_strides[1],
|
|
115
|
-
orig_strides[2])
|
|
116
|
-
return np.lib.stride_tricks.as_strided(x_padded, strides=new_strides)
|
|
117
|
-
|
|
14
|
+
# Go through all the classes exported by ``braindecode.models`` (its __init__)
|
|
15
|
+
# and keep those that inherit from EEGModuleMixin, adding them to the
|
|
16
|
+
# ``models_dict`` registry.
|
|
118
17
|
|
|
119
|
-
def aggregate_probas(logits, n_windows_stride=1):
|
|
120
|
-
"""Aggregate predicted probabilities with self-ensembling.
|
|
121
18
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
19
|
+
def _init_models_dict():
    """Register the EEG model classes exposed by ``braindecode.models``.

    Every class found as an attribute of the ``models`` package that is a
    strict subclass of ``models.base.EEGModuleMixin`` (the mixin itself is
    skipped) is stored in the module-level ``models_dict``, keyed by the
    attribute name it is exposed under. Entries with an existing key are
    overwritten. Returns None.
    """
    for cls_name, cls_obj in inspect.getmembers(models, inspect.isclass):
        mixin = models.base.EEGModuleMixin
        is_model_class = issubclass(cls_obj, mixin) and cls_obj != mixin
        if is_model_class:
            models_dict[cls_name] = cls_obj
|
|
125
26
|
|
|
126
|
-
Parameters
|
|
127
|
-
----------
|
|
128
|
-
logits : np.ndarray
|
|
129
|
-
Array of shape (n_sequences, n_classes, n_windows) containing the
|
|
130
|
-
logits (i.e. the raw unnormalized scores for each class) for each
|
|
131
|
-
window of each sequence.
|
|
132
|
-
n_windows_stride : int
|
|
133
|
-
Number of windows between two consecutive sequences. Default is 1
|
|
134
|
-
(maximally overlapping sequences).
|
|
135
27
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
28
|
+
################################################################
|
|
29
|
+
# Test cases for models
|
|
30
|
+
#
|
|
31
|
+
# This list should be updated whenever a new model is added to
|
|
32
|
+
# braindecode (otherwise `test_completeness__models_test_cases`
|
|
33
|
+
# will fail).
|
|
34
|
+
# Each element in the list should be a tuple with structure
|
|
35
|
+
# (model_name, required_params, signal_params), such that:
|
|
36
|
+
#
|
|
37
|
+
# model_name: str
|
|
38
|
+
# The name of the class of the model to be tested.
|
|
39
|
+
# required_params: list[str]
|
|
40
|
+
# The signal-related parameters that are needed to initialize
|
|
41
|
+
# the model.
|
|
42
|
+
# signal_params: dict | None
|
|
43
|
+
# The characteristics of the signal that should be passed to
|
|
44
|
+
# the model tested in case the default_signal_params are not
|
|
45
|
+
# compatible with this model.
|
|
46
|
+
# The keys of this dictionary can only be among those of
|
|
47
|
+
# default_signal_params.
|
|
48
|
+
################################################################
|
|
49
|
+
models_mandatory_parameters = [
    ("ATCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("BDTCN", ["n_chans", "n_outputs"], None),
    ("Deep4Net", ["n_chans", "n_outputs", "n_times"], None),
    ("DeepSleepNet", ["n_outputs"], None),
    ("EEGConformer", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGInceptionERP", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("EEGInceptionMI", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("EEGITNet", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGNetv1", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGNetv4", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGResNet", ["n_chans", "n_outputs", "n_times"], None),
    ("ShallowFBCSPNet", ["n_chans", "n_outputs", "n_times"], None),
    (
        "SleepStagerBlanco2020",
        ["n_chans", "n_outputs", "n_times"],
        dict(n_chans=4),  # n_chans must be divisible by n_groups=2
    ),
    ("SleepStagerChambon2018", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    (
        "SleepStagerEldele2021",
        ["n_outputs", "n_times", "sfreq"],
        # Single-channel signal: chs_info lists exactly one EEG channel.
        dict(sfreq=100.0, n_times=3000, chs_info=[dict(ch_name="C1", kind="eeg")]),
    ),
    ("TIDNet", ["n_chans", "n_outputs", "n_times"], None),
    ("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=128.0)),
    ("BIOT", ["n_chans", "n_outputs", "sfreq", "n_times"], None),
    ("AttentionBaseNet", ["n_chans", "n_outputs", "n_times"], None),
    ("Labram", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGSimpleConv", ["n_chans", "n_outputs", "sfreq"], None),
    ("SPARCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("ContraWR", ["n_chans", "n_outputs", "sfreq", "n_times"], dict(sfreq=200.0)),
    ("EEGNeX", ["n_chans", "n_outputs", "n_times"], None),
    ("TSceptionV1", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("EEGTCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("SyncNet", ["n_chans", "n_outputs", "n_times"], None),
    ("MSVTNet", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGMiner", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("CTNet", ["n_chans", "n_outputs", "n_times"], None),
    ("SincShallowNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=250.0)),
    ("SCCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    # SignalJEPA variants: the base model only needs channel metadata.
    ("SignalJEPA", ["chs_info"], None),
    ("SignalJEPA_Contextual", ["chs_info", "n_times", "n_outputs"], None),
    ("SignalJEPA_PostLocal", ["n_chans", "n_times", "n_outputs"], None),
    ("SignalJEPA_PreLocal", ["n_chans", "n_times", "n_outputs"], None),
    ("FBCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("FBMSNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("FBLightConvNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
]
|
|
99
|
+
|
|
100
|
+
################################################################
|
|
101
|
+
# List of models that are not meant for classification
|
|
102
|
+
#
|
|
103
|
+
# Their output shape may differ from the expected output shape
|
|
104
|
+
# for classification models.
|
|
105
|
+
################################################################
|
|
106
|
+
# Model names, spelled as in the first element of the
# ``models_mandatory_parameters`` entries.
non_classification_models = [
    "SignalJEPA",
]
|
|
142
109
|
|
|
143
|
-
References
|
|
144
|
-
----------
|
|
145
|
-
.. [Phan2018] Phan, H., Andreotti, F., Cooray, N., Chén, O. Y., &
|
|
146
|
-
De Vos, M. (2018). Joint classification and prediction CNN framework
|
|
147
|
-
for automatic sleep stage classification. IEEE Transactions on
|
|
148
|
-
Biomedical Engineering, 66(5), 1285-1296.
|
|
149
|
-
"""
|
|
150
|
-
log_probas = log_softmax(logits, axis=1)
|
|
151
|
-
return _pad_shift_array(log_probas, stride=n_windows_stride).sum(axis=0).T
|
|
152
110
|
|
|
111
|
+
################################################################
|
|
112
|
+
def get_summary_table(dir_name=None):
    """Load the model summary table ``summary.csv`` into a DataFrame.

    Parameters
    ----------
    dir_name : str | os.PathLike | None
        Directory that contains ``summary.csv``. If None (default), the
        directory of this module is used.

    Returns
    -------
    pandas.DataFrame
        The table, indexed by its ``Model`` column. The first row is used
        as the header, and spaces following each comma delimiter are
        skipped while parsing.
    """
    # ``Path`` accepts both strings and existing ``Path`` objects, so no
    # isinstance check is needed before converting.
    dir_path = Path(__file__).parent if dir_name is None else Path(dir_name)
    path = dir_path / "summary.csv"

    return pd.read_csv(
        path,
        header=0,
        index_col="Model",
        skipinitialspace=True,
    )
|
|
159
127
|
|
|
160
128
|
|
|
161
|
-
|
|
162
|
-
for m in inspect.getmembers(models, inspect.isclass):
|
|
163
|
-
if (issubclass(m[1], models.base.EEGModuleMixin)
|
|
164
|
-
and m[1] != models.base.EEGModuleMixin):
|
|
165
|
-
models_dict[m[0]] = m[1]
|
|
129
|
+
# Read once at import time from the ``summary.csv`` located next to this
# module (the default directory used by ``get_summary_table``).
_summary_table = get_summary_table()
|
|
braindecode/modules/__init__.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from .activation import LogActivation, SafeLog
|
|
2
|
+
from .attention import (
|
|
3
|
+
CAT,
|
|
4
|
+
CBAM,
|
|
5
|
+
ECA,
|
|
6
|
+
FCA,
|
|
7
|
+
GCT,
|
|
8
|
+
SRM,
|
|
9
|
+
CATLite,
|
|
10
|
+
EncNet,
|
|
11
|
+
GatherExcite,
|
|
12
|
+
GSoP,
|
|
13
|
+
MultiHeadAttention,
|
|
14
|
+
SqueezeAndExcitation,
|
|
15
|
+
)
|
|
16
|
+
from .blocks import MLP, FeedForwardBlock, InceptionBlock
|
|
17
|
+
from .convolution import (
|
|
18
|
+
AvgPool2dWithConv,
|
|
19
|
+
CausalConv1d,
|
|
20
|
+
CombinedConv,
|
|
21
|
+
Conv2dWithConstraint,
|
|
22
|
+
DepthwiseConv2d,
|
|
23
|
+
)
|
|
24
|
+
from .filter import FilterBankLayer, GeneralizedGaussianFilter
|
|
25
|
+
from .layers import Chomp1d, DropPath, Ensure4d, SqueezeFinalOutput, TimeDistributed
|
|
26
|
+
from .linear import LinearWithConstraint, MaxNormLinear
|
|
27
|
+
from .parametrization import MaxNorm, MaxNormParametrize
|
|
28
|
+
from .stats import (
|
|
29
|
+
LogPowerLayer,
|
|
30
|
+
LogVarLayer,
|
|
31
|
+
MaxLayer,
|
|
32
|
+
MeanLayer,
|
|
33
|
+
StatLayer,
|
|
34
|
+
StdLayer,
|
|
35
|
+
VarLayer,
|
|
36
|
+
)
|
|
37
|
+
from .util import aggregate_probas
|
|
38
|
+
from .wrapper import Expression, IntermediateOutputWrapper
|
|
39
|
+
|
|
40
|
+
# Public API of the package: every name imported above, grouped by the
# submodule it comes from (in the same order as the imports).
__all__ = [
    # .activation
    "LogActivation",
    "SafeLog",
    # .attention
    "CAT",
    "CBAM",
    "ECA",
    "FCA",
    "GCT",
    "SRM",
    "CATLite",
    "EncNet",
    "GatherExcite",
    "GSoP",
    "MultiHeadAttention",
    "SqueezeAndExcitation",
    # .blocks
    "MLP",
    "FeedForwardBlock",
    "InceptionBlock",
    # .convolution
    "AvgPool2dWithConv",
    "CausalConv1d",
    "CombinedConv",
    "Conv2dWithConstraint",
    "DepthwiseConv2d",
    # .filter
    "FilterBankLayer",
    "GeneralizedGaussianFilter",
    # .layers
    "Chomp1d",
    "DropPath",
    "Ensure4d",
    "SqueezeFinalOutput",
    "TimeDistributed",
    # .linear
    "LinearWithConstraint",
    "MaxNormLinear",
    # .parametrization
    "MaxNorm",
    "MaxNormParametrize",
    # .stats
    "LogPowerLayer",
    "LogVarLayer",
    "MaxLayer",
    "MeanLayer",
    "StatLayer",
    "StdLayer",
    "VarLayer",
    # .util
    "aggregate_probas",
    # .wrapper
    "Expression",
    "IntermediateOutputWrapper",
]
|
|
braindecode/modules/activation.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor, nn
|
|
3
|
+
|
|
4
|
+
import braindecode.functional as F
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SafeLog(nn.Module):
    r"""
    Safe logarithm activation function module.

    :math:`\text{SafeLog}(x) = \log\left(\max(x, \epsilon)\right)`

    Parameters
    ----------
    epsilon : float, optional
        A small value to clamp the input tensor to prevent computing log(0)
        or log of negative numbers. Default is 1e-6.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        # Stored under the constructor's keyword name; it is forwarded to the
        # functional implementation as ``eps`` and shown as ``eps`` in repr.
        self.epsilon = epsilon

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the SafeLog module.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor after applying safe logarithm.
        """
        return F.safe_log(x=x, eps=self.epsilon)

    def extra_repr(self) -> str:
        eps_str = f"eps={self.epsilon}"
        return eps_str
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class LogActivation(nn.Module):
    """Logarithm activation function.

    Computes ``log(x + epsilon)`` element-wise. The input is shifted by
    ``epsilon`` rather than clamped, so inputs at or below ``-epsilon``
    still give ``-inf``/``nan``.

    Parameters
    ----------
    epsilon : float
        Small float added to the input before the logarithm to avoid
        ``log(0)``. Default is 1e-6.
    """

    def __init__(self, epsilon: float = 1e-6, *args, **kwargs):
        """Store ``epsilon``; remaining arguments go to ``nn.Module``."""
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.log(x + self.epsilon)  # Adding epsilon to prevent log(0)

    def extra_repr(self) -> str:
        # Like SafeLog.extra_repr: make the epsilon visible when printing a model.
        return f"epsilon={self.epsilon}"
|