braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
from matplotlib import cm
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def plot_confusion_matrix(
    confusion_mat,
    class_names=None,
    figsize=None,
    colormap=cm.bwr,
    textcolor="black",
    vmin=None,
    vmax=None,
    fontweight="normal",
    rotate_row_labels=90,
    rotate_col_labels=0,
    with_f1_score=False,
    norm_axes=(0, 1),
    rotate_precision=False,
    class_names_fontsize=12,
):
    """
    Generates a confusion matrix with additional precision and sensitivity metrics as in [1]_.

    The input matrix is transposed before plotting, so in the figure rows are
    predictions and columns are targets. An extra column on the right shows
    the precision of each predicted class, an extra row at the bottom shows
    the recall (sensitivity) of each target class, and the bottom-right cell
    shows the overall accuracy.

    Parameters
    ----------
    confusion_mat: 2d array-like
        A confusion matrix, e.g. sklearn confusion matrix:
        https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
        (rows: true classes, columns: predicted classes). Integer counts are
        printed with ``{:d}``, other dtypes with ``{:g}``.
    class_names: array, optional
        List of classes/targets. Defaults to "1", "2", ..., n_classes.
    figsize: tuple, optional
        Size of the generated confusion matrix figure.
    colormap: matplotlib cm colormap, optional
        Colormap for the cell backgrounds. It is brightened before use so that
        dark text stays readable. Its segment data is read, so it must be a
        ``LinearSegmentedColormap`` (such as the default ``cm.bwr``).
    textcolor: str, optional
        Color of the text in the matrix cells.
    vmin, vmax: float, optional
        The data range that the colormap covers. Defaults to the minimum /
        maximum of the normalized matrix.
    fontweight: str, optional
        Weight of the font of the off-diagonal cells (diagonal cells are
        always bold):
        [ 'normal' | 'bold' | 'heavy' | 'light' | 'ultrabold' | 'ultralight']
    rotate_row_labels: int, optional
        The rotation angle of the row labels
    rotate_col_labels: int, optional
        The rotation angle of the column labels
    with_f1_score: bool, optional
        If True, diagonal cells show the cell's share of all samples and the
        F1 score of that class instead of the normalized percentage.
    norm_axes: tuple, optional
        Axes (of the transposed matrix) over which counts are summed to
        normalize the percentages shown in the cells. The default ``(0, 1)``
        divides by the total number of samples.
    rotate_precision: bool, optional
        If True, the "Precision" label is rotated by 90 degrees.
    class_names_fontsize: int, optional
        Font size of the class-name tick labels.

    Returns
    -------
    fig: matplotlib figure

    References
    ----------

    .. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
       Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       Human Brain Mapping, Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730
    """
    # transpose to get confusion matrix same way as matlab
    # (rows: predictions, columns: targets); asarray also accepts nested lists
    confusion_mat = np.asarray(confusion_mat).T
    n_classes = confusion_mat.shape[0]
    if class_names is None:
        class_names = [str(i_class + 1) for i_class in range(n_classes)]

    # norm by all targets (or by the axes given in norm_axes)
    normed_conf_mat = confusion_mat / np.float32(
        np.sum(confusion_mat, axis=norm_axes, keepdims=True)
    )

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    if vmin is None:
        vmin = np.min(normed_conf_mat)
    if vmax is None:
        vmax = np.max(normed_conf_mat)

    # see http://stackoverflow.com/a/31397438/1469195
    # brighten so that black text remains readable (used alpha=0.6 before)
    def _brighten(x):
        return 1 - ((1 - np.array(x)) * 0.4)

    brightened_cmap = _cmap_map(_brighten, colormap)
    ax.imshow(
        normed_conf_mat,
        cmap=brightened_cmap,
        interpolation="nearest",
        vmin=vmin,
        vmax=vmax,
    )

    # Rows (predictions) are drawn along the y axis, columns (targets) along x.
    n_rows, n_cols = confusion_mat.shape
    # make space for precision (extra column) and recall (extra row)
    ax.set_xlim(-0.5, n_cols + 0.5)
    ax.set_ylim(n_rows + 0.5, -0.5)

    if np.issubdtype(confusion_mat.dtype, np.integer):
        count_fmt = "{:d}"
    else:
        # "{:d}" would raise for float-valued matrices
        count_fmt = "{:g}"
    total = float(np.sum(confusion_mat))
    pred_sums = np.sum(confusion_mat, axis=1)  # samples predicted as each class
    target_sums = np.sum(confusion_mat, axis=0)  # samples of each target class

    for i_row in range(n_rows):
        for i_col in range(n_cols):
            on_diagonal = i_row == i_col
            this_font_weight = "bold" if on_diagonal else fontweight
            ax.annotate(
                count_fmt.format(confusion_mat[i_row, i_col]) + "\n",
                xy=(i_col, i_row),
                horizontalalignment="center",
                verticalalignment="center",
                fontsize=12,
                color=textcolor,
                fontweight=this_font_weight,
            )
            if not (on_diagonal and with_f1_score):
                ax.annotate(
                    "\n\n{:4.1f}%".format(normed_conf_mat[i_row, i_col] * 100),
                    xy=(i_col, i_row),
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontsize=10,
                    color=textcolor,
                    fontweight=this_font_weight,
                )
            else:
                n_correct = confusion_mat[i_row, i_row]
                # F1 = 2PR / (P + R) = 2 TP / (n_predicted + n_target).
                # This form stays defined when TP == 0 (F1 = 0) and falls
                # back to 0 when the class never occurs at all.
                f1_denominator = float(pred_sums[i_row] + target_sums[i_row])
                if f1_denominator > 0:
                    f1_score = 2 * n_correct / f1_denominator
                else:
                    f1_score = 0.0
                ax.annotate(
                    "\n{:4.1f}%\n{:4.1f}% (F)".format(
                        (n_correct / total) * 100,
                        f1_score * 100,
                    ),
                    xy=(i_col, i_row + 0.1),
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontsize=10,
                    color=textcolor,
                    fontweight=this_font_weight,
                )

    # Precision of each predicted class, in the extra column on the right
    for i_row in range(n_rows):
        if float(pred_sums[i_row]) == 0:
            annotate_str = "-"
        else:
            precision = confusion_mat[i_row, i_row] / float(pred_sums[i_row])
            annotate_str = "\n{:5.2f}%".format(precision * 100)
        ax.annotate(
            annotate_str,
            xy=(n_cols, i_row),
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=12,
        )

    # Recall (sensitivity) of each target class, in the extra row at the bottom
    for i_col in range(n_cols):
        if float(target_sums[i_col]) == 0:
            annotate_str = "-"
        else:
            recall = confusion_mat[i_col, i_col] / float(target_sums[i_col])
            annotate_str = "\n{:5.2f}%".format(recall * 100)
        ax.annotate(
            annotate_str,
            xy=(i_col, n_rows),
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=12,
        )

    overall_correctness = np.sum(np.diag(confusion_mat)) / total
    ax.annotate(
        "{:5.2f}%".format(overall_correctness * 100),
        xy=(n_cols, n_rows),
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=12,
        fontweight="bold",
    )

    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(
        class_names,
        fontsize=class_names_fontsize,
        rotation=rotate_col_labels,
    )
    ax.set_yticks(np.arange(0, n_rows))
    ax.set_yticklabels(
        class_names,
        va="center",
        fontsize=class_names_fontsize,
        rotation=rotate_row_labels,
    )
    ax.grid(False)
    ax.set_ylabel("Predictions", fontsize=15)
    ax.set_xlabel("Targets", fontsize=15)

    # Label the extra row (recall) and extra column (precision)
    ax.text(-1.2, n_rows + 0.2, "Recall", ha="center", va="center", fontsize=13)
    if rotate_precision:
        rotation = 90
        x_pos = -1.1
        va = "center"
    else:
        rotation = 0
        x_pos = -0.8
        va = "top"
    ax.text(
        n_cols,
        x_pos,
        "Precision",
        ha="center",
        va=va,
        rotation=rotation,
        fontsize=13,
    )

    return fig
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# see http://stackoverflow.com/a/31397438/1469195
|
|
244
|
+
def _cmap_map(function, cmap, name="colormap_mod", N=None, gamma=None):
    """
    Modify a colormap using `function` which must operate on 3-element
    arrays of [r, g, b] values.

    You may specify the number of colors, `N`, and the opacity, `gamma`,
    value of the returned colormap. These values default to the ones in
    the input `cmap`.

    You may also specify a `name` for the colormap, so that it can be
    loaded using plt.get_cmap(name).

    Parameters
    ----------
    function : callable
        Applied to each [r, g, b] row on both sides of every segment
        boundary; must return a 3-element array.
    cmap : matplotlib.colors.LinearSegmentedColormap
        Source colormap. Its private ``_segmentdata`` and ``_gamma``
        attributes are read, so the channel entries must be sequences of
        ``(x, y0, y1)`` tuples (colormaps defined by callables are not
        supported).
    name : str, optional
        Name of the returned colormap.
    N : int, optional
        Number of colors of the returned colormap (default: ``cmap.N``).
    gamma : float, optional
        Gamma of the returned colormap (default: ``cmap._gamma``).

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        The modified colormap (without an alpha channel).
    """
    from matplotlib.colors import LinearSegmentedColormap as lsc

    if N is None:
        N = cmap.N
    if gamma is None:
        gamma = cmap._gamma
    cdict = cmap._segmentdata
    # Cast the steps (first entry of each segment) into lists per channel:
    step_dict = {key: [segment[0] for segment in cdict[key]] for key in cdict}
    # Now get the unique steps over all channels. Channels may have a
    # different number of segments (e.g. ``jet``), so flatten them instead of
    # building a ragged 2D array, which modern NumPy rejects.
    step_list = np.unique(
        np.concatenate([np.asarray(steps, dtype=float) for steps in step_dict.values()])
    )
    # 'y0', 'y1' are as defined in LinearSegmentedColormap docstring:
    y0 = cmap(step_list)[:, :3]
    y1 = y0.copy()[:, :3]
    # Go back to catch the discontinuities, and place them into y0, y1
    for iclr, key in enumerate(["red", "green", "blue"]):
        for istp, step in enumerate(step_list):
            try:
                ind = step_dict[key].index(step)
            except ValueError:
                # This step is not in this color
                continue
            y0[istp, iclr] = cdict[key][ind][1]
            y1[istp, iclr] = cdict[key][ind][2]
    # Map the colors to their new values:
    y0 = np.array(list(map(function, y0)))
    y1 = np.array(list(map(function, y1)))
    # Build the new colormap (overwriting step_dict):
    for iclr, clr in enumerate(["red", "green", "blue"]):
        step_dict[clr] = np.vstack((step_list, y0[:, iclr], y1[:, iclr])).T
    # Remove alpha, otherwise crashes...
    step_dict.pop("alpha", None)
    return lsc(name, step_dict, N=N, gamma=gamma)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from skorch.utils import to_numpy, to_tensor
|
|
8
|
+
|
|
9
|
+
from braindecode.util import set_random_seeds
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def compute_amplitude_gradients(model, dataset, batch_size, seed=20240205):
    """Compute amplitude gradients after seeding for reproducibility.

    The dataset is iterated in order (no shuffling, last partial batch kept)
    and ``compute_amplitude_gradients_for_X`` is applied to every batch.

    Parameters
    ----------
    model : torch.nn.Module
        Model to differentiate. The device of its first parameter decides
        whether CUDA seeding is enabled.
    dataset : torch.utils.data.Dataset
        Dataset whose items are tuples with the input X as first element,
        e.g. ``(X, y)`` or ``(X, y, ind)``; the remaining elements are ignored.
    batch_size : int
        Batch size of the DataLoader.
    seed : int, optional
        Seed passed to ``set_random_seeds``.

    Returns
    -------
    all_amp_grads : np.ndarray
        Per-batch results concatenated along axis 1 (the batch axis; axis 0
        indexes the model outputs).

    Raises
    ------
    ValueError
        If the dataset yields no batches.
    """
    cuda = next(model.parameters()).is_cuda
    set_random_seeds(seed=seed, cuda=cuda)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, drop_last=False, shuffle=False
    )
    all_amp_grads = []
    for batch in loader:
        # Only the input is needed; targets / window indices are ignored.
        batch_X = batch[0] if isinstance(batch, (tuple, list)) else batch
        all_amp_grads.append(compute_amplitude_gradients_for_X(model, batch_X))
    if not all_amp_grads:
        raise ValueError("dataset is empty, cannot compute amplitude gradients")
    return np.concatenate(all_amp_grads, axis=1)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def compute_amplitude_gradients_for_X(model, X):
    """Gradients of each model output w.r.t. the Fourier amplitudes of ``X``.

    ``X`` is transformed with a real FFT along axis 2 and split into
    amplitudes and phases. It is then re-synthesised with an inverse real FFT
    of length ``X.shape[2]`` and fed through ``model``. For every index ``i``
    along axis 1 of the model output, the mean of ``outs[:, i]`` is
    differentiated with respect to the amplitudes. Phases are kept constant.

    The gradients are computed with ``torch.autograd.grad``. Unlike
    ``.backward()``, this does not accumulate anything into the ``.grad``
    attributes of the model parameters.

    Parameters
    ----------
    model : torch.nn.Module
        Model to differentiate; inputs are moved to the device of its first
        parameter.
    X : np.ndarray or torch.Tensor
        Input batch; the FFT runs along axis 2 (presumably
        (batch, channels, time) — confirm against callers).

    Returns
    -------
    amp_grads_per_filter : np.ndarray of float32
        Shape ``(n_filters,) + np.fft.rfft(X, axis=2).shape`` where
        ``n_filters = model(...).shape[1]``. Outputs that do not depend on the
        amplitudes get all-zero gradients.
    """
    device = next(model.parameters()).device
    ffted = np.fft.rfft(X, axis=2)
    amps = np.abs(ffted)
    phases = np.angle(ffted)
    amps_th = torch.as_tensor(amps.astype(np.float32), device=device).requires_grad_(
        True
    )
    # Only amplitude gradients are needed, so phases are treated as constants.
    phases_th = torch.as_tensor(phases.astype(np.float32), device=device)

    # Real and imaginary parts stacked in a trailing dimension of size 2.
    fft_coefs = amps_th.unsqueeze(-1) * torch.stack(
        (torch.cos(phases_th), torch.sin(phases_th)), dim=-1
    )
    # Drops a singleton dim 3 (present when X has a trailing singleton dim);
    # a no-op for 3D X, where dim 3 is the real/imag dimension of size 2.
    fft_coefs = fft_coefs.squeeze(3)
    complex_fft_coefs = torch.view_as_complex(fft_coefs)
    iffted = torch.fft.irfft(complex_fft_coefs, n=X.shape[2], dim=2)

    outs = model(iffted)

    n_filters = outs.shape[1]
    amp_grads_per_filter = np.full((n_filters,) + ffted.shape, np.nan, dtype=np.float32)
    for i_filter in range(n_filters):
        mean_out = torch.mean(outs[:, i_filter])
        (amp_grads,) = torch.autograd.grad(
            mean_out, amps_th, retain_graph=True, allow_unused=True
        )
        if amp_grads is None:
            # This output does not depend on the input amplitudes at all.
            amp_grads_per_filter[i_filter] = 0
        else:
            amp_grads_per_filter[i_filter] = amp_grads.detach().cpu().numpy()
    assert not np.any(np.isnan(amp_grads_per_filter))
    return amp_grads_per_filter
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: braindecode
|
|
3
|
+
Version: 1.3.0.dev177069446
|
|
4
|
+
Summary: Deep learning software to decode EEG, ECG or MEG signals
|
|
5
|
+
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
|
|
6
|
+
Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
7
|
+
License: BSD-3-Clause
|
|
8
|
+
Project-URL: homepage, https://braindecode.org
|
|
9
|
+
Project-URL: repository, https://github.com/braindecode/braindecode
|
|
10
|
+
Project-URL: documentation, https://braindecode.org/stable/index.html
|
|
11
|
+
Keywords: python,deep-learning,neuroscience,pytorch,meg,eeg,neuroimaging,electroencephalography,magnetoencephalography,electrocorticography,ecog,electroencephalogram
|
|
12
|
+
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: Topic :: Software Development :: Build Tools
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
21
|
+
Requires-Python: >=3.11
|
|
22
|
+
Description-Content-Type: text/x-rst
|
|
23
|
+
License-File: LICENSE.txt
|
|
24
|
+
License-File: NOTICE.txt
|
|
25
|
+
Requires-Dist: torch>=2.2
|
|
26
|
+
Requires-Dist: torchaudio>=2.0
|
|
27
|
+
Requires-Dist: mne>=1.11.0
|
|
28
|
+
Requires-Dist: mne_bids>=0.16
|
|
29
|
+
Requires-Dist: h5py
|
|
30
|
+
Requires-Dist: skorch>=1.3.0
|
|
31
|
+
Requires-Dist: joblib
|
|
32
|
+
Requires-Dist: torchinfo
|
|
33
|
+
Requires-Dist: wfdb
|
|
34
|
+
Requires-Dist: linear_attention_transformer
|
|
35
|
+
Requires-Dist: docstring_inheritance
|
|
36
|
+
Requires-Dist: rotary_embedding_torch
|
|
37
|
+
Provides-Extra: moabb
|
|
38
|
+
Requires-Dist: moabb>=1.4.3; extra == "moabb"
|
|
39
|
+
Provides-Extra: eegprep
|
|
40
|
+
Requires-Dist: eegprep[eeglabio]>=0.2.23; extra == "eegprep"
|
|
41
|
+
Provides-Extra: hub
|
|
42
|
+
Requires-Dist: huggingface_hub[torch]>=0.20.0; extra == "hub"
|
|
43
|
+
Requires-Dist: zarr>=3.0; extra == "hub"
|
|
44
|
+
Provides-Extra: tests
|
|
45
|
+
Requires-Dist: pytest; extra == "tests"
|
|
46
|
+
Requires-Dist: pytest-cov; extra == "tests"
|
|
47
|
+
Requires-Dist: codecov; extra == "tests"
|
|
48
|
+
Requires-Dist: pytest_cases; extra == "tests"
|
|
49
|
+
Requires-Dist: mypy; extra == "tests"
|
|
50
|
+
Requires-Dist: transformers>=4.57.0; extra == "tests"
|
|
51
|
+
Requires-Dist: bids_validator; extra == "tests"
|
|
52
|
+
Provides-Extra: typing
|
|
53
|
+
Requires-Dist: exca==0.4; extra == "typing"
|
|
54
|
+
Requires-Dist: numpydantic>=1.7; extra == "typing"
|
|
55
|
+
Provides-Extra: docs
|
|
56
|
+
Requires-Dist: sphinx_gallery; extra == "docs"
|
|
57
|
+
Requires-Dist: sphinx_rtd_theme; extra == "docs"
|
|
58
|
+
Requires-Dist: sphinx-autodoc-typehints; extra == "docs"
|
|
59
|
+
Requires-Dist: sphinx-autobuild; extra == "docs"
|
|
60
|
+
Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
|
|
61
|
+
Requires-Dist: sphinx_sitemap; extra == "docs"
|
|
62
|
+
Requires-Dist: pydata_sphinx_theme; extra == "docs"
|
|
63
|
+
Requires-Dist: numpydoc; extra == "docs"
|
|
64
|
+
Requires-Dist: memory_profiler; extra == "docs"
|
|
65
|
+
Requires-Dist: pillow; extra == "docs"
|
|
66
|
+
Requires-Dist: ipython; extra == "docs"
|
|
67
|
+
Requires-Dist: sphinx_design; extra == "docs"
|
|
68
|
+
Requires-Dist: lightning; extra == "docs"
|
|
69
|
+
Requires-Dist: seaborn; extra == "docs"
|
|
70
|
+
Requires-Dist: pre-commit; extra == "docs"
|
|
71
|
+
Requires-Dist: openneuro-py; extra == "docs"
|
|
72
|
+
Requires-Dist: plotly; extra == "docs"
|
|
73
|
+
Requires-Dist: shap; extra == "docs"
|
|
74
|
+
Requires-Dist: nbformat; extra == "docs"
|
|
75
|
+
Requires-Dist: transformers; extra == "docs"
|
|
76
|
+
Provides-Extra: all
|
|
77
|
+
Requires-Dist: braindecode[moabb]; extra == "all"
|
|
78
|
+
Requires-Dist: braindecode[tests]; extra == "all"
|
|
79
|
+
Requires-Dist: braindecode[docs]; extra == "all"
|
|
80
|
+
Requires-Dist: braindecode[hub]; extra == "all"
|
|
81
|
+
Requires-Dist: braindecode[eegprep]; extra == "all"
|
|
82
|
+
Requires-Dist: braindecode[typing]; extra == "all"
|
|
83
|
+
Dynamic: license-file
|
|
84
|
+
|
|
85
|
+
.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.8214376.svg
|
|
86
|
+
:target: https://doi.org/10.5281/zenodo.8214376
|
|
87
|
+
:alt: DOI
|
|
88
|
+
|
|
89
|
+
.. image:: https://github.com/braindecode/braindecode/workflows/docs/badge.svg
|
|
90
|
+
:target: https://github.com/braindecode/braindecode/actions
|
|
91
|
+
:alt: Docs Build Status
|
|
92
|
+
|
|
93
|
+
.. image:: https://github.com/braindecode/braindecode/workflows/tests/badge.svg
|
|
94
|
+
:target: https://github.com/braindecode/braindecode/actions?query=branch%3Amaster
|
|
95
|
+
:alt: Test Build Status
|
|
96
|
+
|
|
97
|
+
.. image:: https://codecov.io/gh/braindecode/braindecode/branch/master/graph/badge.svg
|
|
98
|
+
:target: https://codecov.io/gh/braindecode/braindecode
|
|
99
|
+
:alt: Code Coverage
|
|
100
|
+
|
|
101
|
+
.. image:: https://img.shields.io/pypi/v/braindecode?color=blue&style=flat-square
|
|
102
|
+
:target: https://pypi.org/project/braindecode/
|
|
103
|
+
:alt: PyPI
|
|
104
|
+
|
|
105
|
+
.. image:: https://img.shields.io/pypi/v/braindecode?label=version&color=orange&style=flat-square
|
|
106
|
+
:target: https://pypi.org/project/braindecode/
|
|
107
|
+
:alt: Version
|
|
108
|
+
|
|
109
|
+
.. image:: https://img.shields.io/pypi/pyversions/braindecode?style=flat-square
|
|
110
|
+
:target: https://pypi.org/project/braindecode/
|
|
111
|
+
:alt: Python versions
|
|
112
|
+
|
|
113
|
+
.. image:: https://pepy.tech/badge/braindecode
|
|
114
|
+
:target: https://pepy.tech/project/braindecode
|
|
115
|
+
:alt: Downloads
|
|
116
|
+
|
|
117
|
+
.. |Braindecode| image:: https://user-images.githubusercontent.com/42702466/177958779-b00628aa-9155-4c51-96d1-d8c345aff575.svg
|
|
118
|
+
|
|
119
|
+
.. _braindecode: https://braindecode.org/
|
|
120
|
+
|
|
121
|
+
#############
|
|
122
|
+
Braindecode
|
|
123
|
+
#############
|
|
124
|
+
|
|
125
|
+
Braindecode is an open-source Python toolbox for decoding raw electrophysiological brain
|
|
126
|
+
data with deep learning models. It includes dataset fetchers, data preprocessing and
|
|
127
|
+
visualization tools, as well as implementations of several deep learning architectures
|
|
128
|
+
and data augmentations for analysis of EEG, ECoG and MEG.
|
|
129
|
+
|
|
130
|
+
It is intended for neuroscientists who want to work with deep learning and for deep learning researchers
|
|
131
|
+
who want to work with neurophysiological data.
|
|
132
|
+
|
|
133
|
+
##########################
|
|
134
|
+
Installing Braindecode
|
|
135
|
+
##########################
|
|
136
|
+
|
|
137
|
+
1. Install PyTorch from https://pytorch.org/ (you don't need to install torchvision).
|
|
138
|
+
2. If you want to download EEG datasets from `MOABB
|
|
139
|
+
<https://github.com/NeuroTechX/moabb>`_, install it:
|
|
140
|
+
|
|
141
|
+
.. code-block:: bash
|
|
142
|
+
|
|
143
|
+
pip install moabb
|
|
144
|
+
|
|
145
|
+
3. Install latest release of braindecode via pip:
|
|
146
|
+
|
|
147
|
+
.. code-block:: bash
|
|
148
|
+
|
|
149
|
+
pip install braindecode
|
|
150
|
+
|
|
151
|
+
If you want to install the latest development version of braindecode, please refer to the
|
|
152
|
+
`contributing page
|
|
153
|
+
<https://github.com/braindecode/braindecode/blob/master/CONTRIBUTING.md>`__
|
|
154
|
+
|
|
155
|
+
###############
|
|
156
|
+
Documentation
|
|
157
|
+
###############
|
|
158
|
+
|
|
159
|
+
Documentation is available online at https://braindecode.org, for both the stable and dev
|
|
160
|
+
versions.
|
|
161
|
+
|
|
162
|
+
#############################
|
|
163
|
+
Contributing to Braindecode
|
|
164
|
+
#############################
|
|
165
|
+
|
|
166
|
+
Guidelines for contributing to the library can be found in the braindecode GitHub repository:
|
|
167
|
+
|
|
168
|
+
https://github.com/braindecode/braindecode/blob/master/CONTRIBUTING.md
|
|
169
|
+
|
|
170
|
+
########
|
|
171
|
+
Citing
|
|
172
|
+
########
|
|
173
|
+
|
|
174
|
+
If you use Braindecode in scientific work, please cite the software using the Zenodo DOI
|
|
175
|
+
shown in the badge below:
|
|
176
|
+
|
|
177
|
+
.. image:: https://zenodo.org/badge/232335424.svg
|
|
178
|
+
:target: https://doi.org/10.5281/zenodo.8214376
|
|
179
|
+
:alt: DOI
|
|
180
|
+
|
|
181
|
+
Additionally, we highly encourage you to cite the article that originally introduced the
|
|
182
|
+
Braindecode library and has served as a foundational reference for many works on deep
|
|
183
|
+
learning with EEG recordings. Please use the following reference:
|
|
184
|
+
|
|
185
|
+
.. code-block:: bibtex
|
|
186
|
+
|
|
187
|
+
@article {HBM:HBM23730,
|
|
188
|
+
author = {Schirrmeister, Robin Tibor and Springenberg, Jost Tobias and Fiederer,
|
|
189
|
+
Lukas Dominique Josef and Glasstetter, Martin and Eggensperger, Katharina and Tangermann, Michael and
|
|
190
|
+
Hutter, Frank and Burgard, Wolfram and Ball, Tonio},
|
|
191
|
+
title = {Deep learning with convolutional neural networks for EEG decoding and visualization},
|
|
192
|
+
journal = {Human Brain Mapping},
|
|
193
|
+
issn = {1097-0193},
|
|
194
|
+
url = {http://dx.doi.org/10.1002/hbm.23730},
|
|
195
|
+
doi = {10.1002/hbm.23730},
|
|
196
|
+
month = {aug},
|
|
197
|
+
year = {2017},
|
|
198
|
+
keywords = {electroencephalography, EEG analysis, machine learning, end-to-end learning, brain–machine interface,
|
|
199
|
+
brain–computer interface, model interpretability, brain mapping},
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
as well as the `MNE-Python <https://mne.tools>`_ software that is used by braindecode:
|
|
203
|
+
|
|
204
|
+
.. code-block:: bibtex
|
|
205
|
+
|
|
206
|
+
@article{10.3389/fnins.2013.00267,
|
|
207
|
+
author={Gramfort, Alexandre and Luessi, Martin and Larson, Eric and Engemann, Denis and Strohmeier, Daniel and Brodbeck, Christian and Goj, Roman and Jas, Mainak and Brooks, Teon and Parkkonen, Lauri and Hämäläinen, Matti},
|
|
208
|
+
title={{MEG and EEG data analysis with MNE-Python}},
|
|
209
|
+
journal={Frontiers in Neuroscience},
|
|
210
|
+
volume={7},
|
|
211
|
+
pages={267},
|
|
212
|
+
year={2013},
|
|
213
|
+
url={https://www.frontiersin.org/article/10.3389/fnins.2013.00267},
|
|
214
|
+
doi={10.3389/fnins.2013.00267},
|
|
215
|
+
issn={1662-453X},
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
***********
|
|
219
|
+
Licensing
|
|
220
|
+
***********
|
|
221
|
+
|
|
222
|
+
This project is primarily licensed under the BSD-3-Clause License.
|
|
223
|
+
|
|
224
|
+
Additional Components
|
|
225
|
+
=====================
|
|
226
|
+
|
|
227
|
+
Some components within this repository are licensed under the Creative Commons
|
|
228
|
+
Attribution-NonCommercial 4.0 International License.
|
|
229
|
+
|
|
230
|
+
Please refer to the ``LICENSE`` and ``NOTICE`` files for more detailed information.
|