braindecode-0.8.1-py3-none-any.whl → braindecode-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/base.py
CHANGED
|
@@ -3,10 +3,11 @@
|
|
|
3
3
|
#
|
|
4
4
|
# License: BSD-3
|
|
5
5
|
|
|
6
|
-
import
|
|
7
|
-
from typing import Dict, Iterable, List, Optional, Tuple
|
|
6
|
+
from __future__ import annotations
|
|
8
7
|
|
|
8
|
+
import warnings
|
|
9
9
|
from collections import OrderedDict
|
|
10
|
+
from typing import Dict, Iterable, Optional
|
|
10
11
|
|
|
11
12
|
import numpy as np
|
|
12
13
|
import torch
|
|
@@ -21,11 +22,11 @@ def deprecated_args(obj, *old_new_args):
|
|
|
21
22
|
out_args.append(new_val)
|
|
22
23
|
else:
|
|
23
24
|
warnings.warn(
|
|
24
|
-
f
|
|
25
|
+
f"{obj.__class__.__name__}: {old_name!r} is depreciated. Use {new_name!r} instead."
|
|
25
26
|
)
|
|
26
27
|
if new_val is not None:
|
|
27
28
|
raise ValueError(
|
|
28
|
-
f
|
|
29
|
+
f"{obj.__class__.__name__}: Both {old_name!r} and {new_name!r} were specified."
|
|
29
30
|
)
|
|
30
31
|
out_args.append(old_val)
|
|
31
32
|
return out_args
|
|
@@ -51,21 +52,12 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
|
|
|
51
52
|
Length of the input window in seconds.
|
|
52
53
|
sfreq : float
|
|
53
54
|
Sampling frequency of the EEG recordings.
|
|
54
|
-
add_log_softmax: bool
|
|
55
|
-
Whether to use log-softmax non-linearity as the output function.
|
|
56
|
-
LogSoftmax final layer will be removed in the future.
|
|
57
|
-
Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!
|
|
58
|
-
Check the documentation of the torch.nn loss functions:
|
|
59
|
-
https://pytorch.org/docs/stable/nn.html#loss-functions.
|
|
60
55
|
|
|
61
56
|
Raises
|
|
62
57
|
------
|
|
63
58
|
ValueError: If some input signal-related parameters are not specified
|
|
64
59
|
and can not be inferred.
|
|
65
60
|
|
|
66
|
-
FutureWarning: If add_log_softmax is True, since LogSoftmax final layer
|
|
67
|
-
will be removed in the future.
|
|
68
|
-
|
|
69
61
|
Notes
|
|
70
62
|
-----
|
|
71
63
|
If some input signal-related parameters are not specified,
|
|
@@ -73,139 +65,132 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
|
|
|
73
65
|
"""
|
|
74
66
|
|
|
75
67
|
def __init__(
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
add_log_softmax: Optional[bool] = False,
|
|
68
|
+
self,
|
|
69
|
+
n_outputs: Optional[int] = None, # type: ignore[assignment]
|
|
70
|
+
n_chans: Optional[int] = None, # type: ignore[assignment]
|
|
71
|
+
chs_info=None, # type: ignore[assignment]
|
|
72
|
+
n_times: Optional[int] = None, # type: ignore[assignment]
|
|
73
|
+
input_window_seconds: Optional[float] = None, # type: ignore[assignment]
|
|
74
|
+
sfreq: Optional[float] = None, # type: ignore[assignment]
|
|
84
75
|
):
|
|
76
|
+
if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
|
|
77
|
+
raise ValueError(f"{n_chans=} different from {chs_info=} length")
|
|
85
78
|
if (
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
raise ValueError(f'{n_chans=} different from {chs_info=} length')
|
|
91
|
-
if (
|
|
92
|
-
n_times is not None and
|
|
93
|
-
input_window_seconds is not None and
|
|
94
|
-
sfreq is not None and
|
|
95
|
-
n_times != int(input_window_seconds * sfreq)
|
|
79
|
+
n_times is not None
|
|
80
|
+
and input_window_seconds is not None
|
|
81
|
+
and sfreq is not None
|
|
82
|
+
and n_times != int(input_window_seconds * sfreq)
|
|
96
83
|
):
|
|
97
84
|
raise ValueError(
|
|
98
|
-
f
|
|
99
|
-
f'{input_window_seconds=} * {sfreq=}'
|
|
85
|
+
f"{n_times=} different from {input_window_seconds=} * {sfreq=}"
|
|
100
86
|
)
|
|
101
|
-
|
|
102
|
-
self.
|
|
103
|
-
self._chs_info = chs_info
|
|
104
|
-
self.
|
|
105
|
-
self.
|
|
106
|
-
self.
|
|
107
|
-
self.
|
|
87
|
+
|
|
88
|
+
self._input_window_seconds = input_window_seconds # type: ignore[assignment]
|
|
89
|
+
self._chs_info = chs_info # type: ignore[assignment]
|
|
90
|
+
self._n_outputs = n_outputs # type: ignore[assignment]
|
|
91
|
+
self._n_chans = n_chans # type: ignore[assignment]
|
|
92
|
+
self._n_times = n_times # type: ignore[assignment]
|
|
93
|
+
self._sfreq = sfreq # type: ignore[assignment]
|
|
94
|
+
|
|
108
95
|
super().__init__()
|
|
109
96
|
|
|
110
97
|
@property
|
|
111
|
-
def n_outputs(self):
|
|
98
|
+
def n_outputs(self) -> int:
|
|
112
99
|
if self._n_outputs is None:
|
|
113
|
-
raise ValueError(
|
|
100
|
+
raise ValueError("n_outputs not specified.")
|
|
114
101
|
return self._n_outputs
|
|
115
102
|
|
|
116
103
|
@property
|
|
117
|
-
def n_chans(self):
|
|
104
|
+
def n_chans(self) -> int:
|
|
118
105
|
if self._n_chans is None and self._chs_info is not None:
|
|
119
106
|
return len(self._chs_info)
|
|
120
107
|
elif self._n_chans is None:
|
|
121
108
|
raise ValueError(
|
|
122
|
-
|
|
109
|
+
"n_chans could not be inferred. Either specify n_chans or chs_info."
|
|
123
110
|
)
|
|
124
111
|
return self._n_chans
|
|
125
112
|
|
|
126
113
|
@property
|
|
127
|
-
def chs_info(self):
|
|
114
|
+
def chs_info(self) -> list[str]:
|
|
128
115
|
if self._chs_info is None:
|
|
129
|
-
raise ValueError(
|
|
116
|
+
raise ValueError("chs_info not specified.")
|
|
130
117
|
return self._chs_info
|
|
131
118
|
|
|
132
119
|
@property
|
|
133
|
-
def n_times(self):
|
|
120
|
+
def n_times(self) -> int:
|
|
134
121
|
if (
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
122
|
+
self._n_times is None
|
|
123
|
+
and self._input_window_seconds is not None
|
|
124
|
+
and self._sfreq is not None
|
|
138
125
|
):
|
|
139
126
|
return int(self._input_window_seconds * self._sfreq)
|
|
140
127
|
elif self._n_times is None:
|
|
141
128
|
raise ValueError(
|
|
142
|
-
|
|
143
|
-
|
|
129
|
+
"n_times could not be inferred. "
|
|
130
|
+
"Either specify n_times or input_window_seconds and sfreq."
|
|
144
131
|
)
|
|
145
132
|
return self._n_times
|
|
146
133
|
|
|
147
134
|
@property
|
|
148
|
-
def input_window_seconds(self):
|
|
135
|
+
def input_window_seconds(self) -> float:
|
|
149
136
|
if (
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
137
|
+
self._input_window_seconds is None
|
|
138
|
+
and self._n_times is not None
|
|
139
|
+
and self._sfreq is not None
|
|
153
140
|
):
|
|
154
|
-
return self._n_times / self._sfreq
|
|
141
|
+
return float(self._n_times / self._sfreq)
|
|
155
142
|
elif self._input_window_seconds is None:
|
|
156
143
|
raise ValueError(
|
|
157
|
-
|
|
158
|
-
|
|
144
|
+
"input_window_seconds could not be inferred. "
|
|
145
|
+
"Either specify input_window_seconds or n_times and sfreq."
|
|
159
146
|
)
|
|
160
147
|
return self._input_window_seconds
|
|
161
148
|
|
|
162
149
|
@property
|
|
163
|
-
def sfreq(self):
|
|
150
|
+
def sfreq(self) -> float:
|
|
164
151
|
if (
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
152
|
+
self._sfreq is None
|
|
153
|
+
and self._input_window_seconds is not None
|
|
154
|
+
and self._n_times is not None
|
|
168
155
|
):
|
|
169
|
-
return self._n_times / self._input_window_seconds
|
|
156
|
+
return float(self._n_times / self._input_window_seconds)
|
|
170
157
|
elif self._sfreq is None:
|
|
171
158
|
raise ValueError(
|
|
172
|
-
|
|
173
|
-
|
|
159
|
+
"sfreq could not be inferred. "
|
|
160
|
+
"Either specify sfreq or input_window_seconds and n_times."
|
|
174
161
|
)
|
|
175
162
|
return self._sfreq
|
|
176
163
|
|
|
177
164
|
@property
|
|
178
|
-
def
|
|
179
|
-
if self._add_log_softmax:
|
|
180
|
-
warnings.warn("LogSoftmax final layer will be removed! " +
|
|
181
|
-
"Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!")
|
|
182
|
-
return self._add_log_softmax
|
|
183
|
-
|
|
184
|
-
@property
|
|
185
|
-
def input_shape(self) -> Tuple[int]:
|
|
165
|
+
def input_shape(self) -> tuple[int, int, int]:
|
|
186
166
|
"""Input data shape."""
|
|
187
167
|
return (1, self.n_chans, self.n_times)
|
|
188
168
|
|
|
189
|
-
def get_output_shape(self) ->
|
|
169
|
+
def get_output_shape(self) -> tuple[int, ...]:
|
|
190
170
|
"""Returns shape of neural network output for batch size equal 1.
|
|
191
171
|
|
|
192
172
|
Returns
|
|
193
173
|
-------
|
|
194
|
-
output_shape:
|
|
174
|
+
output_shape: tuple[int, ...]
|
|
195
175
|
shape of the network output for `batch_size==1` (1, ...)
|
|
196
|
-
|
|
176
|
+
"""
|
|
197
177
|
with torch.inference_mode():
|
|
198
178
|
try:
|
|
199
|
-
return tuple(
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
179
|
+
return tuple(
|
|
180
|
+
self.forward( # type: ignore
|
|
181
|
+
torch.zeros(
|
|
182
|
+
self.input_shape,
|
|
183
|
+
dtype=next(self.parameters()).dtype, # type: ignore
|
|
184
|
+
device=next(self.parameters()).device, # type: ignore
|
|
185
|
+
)
|
|
186
|
+
).shape
|
|
187
|
+
)
|
|
205
188
|
except RuntimeError as exc:
|
|
206
189
|
if str(exc).endswith(
|
|
207
|
-
|
|
208
|
-
|
|
190
|
+
(
|
|
191
|
+
"Output size is too small",
|
|
192
|
+
"Kernel size can't be greater than actual input size",
|
|
193
|
+
)
|
|
209
194
|
):
|
|
210
195
|
msg = (
|
|
211
196
|
"During model prediction RuntimeError was thrown showing that at some "
|
|
@@ -217,10 +202,9 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
|
|
|
217
202
|
raise ValueError(msg) from exc
|
|
218
203
|
raise exc
|
|
219
204
|
|
|
220
|
-
mapping = None
|
|
205
|
+
mapping: Optional[Dict[str, str]] = None
|
|
221
206
|
|
|
222
207
|
def load_state_dict(self, state_dict, *args, **kwargs):
|
|
223
|
-
|
|
224
208
|
mapping = self.mapping if self.mapping else {}
|
|
225
209
|
new_state_dict = OrderedDict()
|
|
226
210
|
for k, v in state_dict.items():
|
|
@@ -231,7 +215,7 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
|
|
|
231
215
|
|
|
232
216
|
return super().load_state_dict(new_state_dict, *args, **kwargs)
|
|
233
217
|
|
|
234
|
-
def to_dense_prediction_model(self, axis:
|
|
218
|
+
def to_dense_prediction_model(self, axis: tuple[int, ...] | int = (2, 3)) -> None:
|
|
235
219
|
"""
|
|
236
220
|
Transform a sequential model with strides to a model that outputs
|
|
237
221
|
dense predictions by removing the strides and instead inserting dilations.
|
|
@@ -250,19 +234,19 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
|
|
|
250
234
|
backwards one layer.
|
|
251
235
|
|
|
252
236
|
"""
|
|
253
|
-
if not hasattr(axis, "
|
|
254
|
-
axis =
|
|
255
|
-
assert all([ax in [2, 3] for ax in axis]), "Only 2 and 3 allowed for axis"
|
|
237
|
+
if not hasattr(axis, "__iter__"):
|
|
238
|
+
axis = (axis,)
|
|
239
|
+
assert all([ax in [2, 3] for ax in axis]), "Only 2 and 3 allowed for axis" # type: ignore[union-attr]
|
|
256
240
|
axis = np.array(axis) - 2
|
|
257
241
|
stride_so_far = np.array([1, 1])
|
|
258
|
-
for module in self.modules():
|
|
242
|
+
for module in self.modules(): # type: ignore
|
|
259
243
|
if hasattr(module, "dilation"):
|
|
260
244
|
assert module.dilation == 1 or (module.dilation == (1, 1)), (
|
|
261
245
|
"Dilation should equal 1 before conversion, maybe the model is "
|
|
262
246
|
"already converted?"
|
|
263
247
|
)
|
|
264
248
|
new_dilation = [1, 1]
|
|
265
|
-
for ax in axis:
|
|
249
|
+
for ax in axis: # type: ignore[union-attr]
|
|
266
250
|
new_dilation[ax] = int(stride_so_far[ax])
|
|
267
251
|
module.dilation = tuple(new_dilation)
|
|
268
252
|
if hasattr(module, "stride"):
|
|
@@ -270,19 +254,19 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
|
|
|
270
254
|
module.stride = (module.stride, module.stride)
|
|
271
255
|
stride_so_far *= np.array(module.stride)
|
|
272
256
|
new_stride = list(module.stride)
|
|
273
|
-
for ax in axis:
|
|
257
|
+
for ax in axis: # type: ignore[union-attr]
|
|
274
258
|
new_stride[ax] = 1
|
|
275
259
|
module.stride = tuple(new_stride)
|
|
276
260
|
|
|
277
261
|
def get_torchinfo_statistics(
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
262
|
+
self,
|
|
263
|
+
col_names: Optional[Iterable[str]] = (
|
|
264
|
+
"input_size",
|
|
265
|
+
"output_size",
|
|
266
|
+
"num_params",
|
|
267
|
+
"kernel_size",
|
|
268
|
+
),
|
|
269
|
+
row_settings: Optional[Iterable[str]] = ("var_names", "depth"),
|
|
286
270
|
) -> ModelStatistics:
|
|
287
271
|
"""Generate table describing the model using torchinfo.summary.
|
|
288
272
|
|