braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/util.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
import glob
|
|
5
|
+
import os
|
|
6
|
+
import random
|
|
7
|
+
import re
|
|
8
|
+
from warnings import warn
|
|
9
|
+
|
|
10
|
+
import h5py
|
|
11
|
+
import mne
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
from sklearn.utils import check_random_state
|
|
15
|
+
from torch import Tensor
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def set_random_seeds(seed, cuda, cudnn_benchmark=None):
    """Seed Python's ``random`` module, NumPy's global RNG and PyTorch.

    See https://pytorch.org/docs/stable/notes/randomness.html for a broader
    discussion of reproducibility in PyTorch.

    Parameters
    ----------
    seed : int
        Seed passed to every random number generator.
    cuda : bool
        If True, additionally seed all CUDA devices and handle
        ``cudnn_benchmark`` as described below.
    cudnn_benchmark : bool | None (default=None)
        Only consulted when ``cuda`` is True. A bool is written to
        ``torch.backends.cudnn.benchmark``: True may make results
        non-reproducible, and False may slow computations down. ``None``
        leaves the flag untouched and emits a ``UserWarning`` if the flag is
        currently True.

    Raises
    ------
    ValueError
        If ``cuda`` is True and ``cudnn_benchmark`` is neither a bool nor
        None.

    Notes
    -----
    Full reproducibility may additionally require setting the
    ``PYTHONHASHSEED`` environment variable before the script starts, see
    https://forums.fast.ai/t/solved-reproducibility-where-is-the-randomness-coming-in/31628/14

    This function does not call ``torch.use_deterministic_algorithms(True)``,
    so on its own it cannot guarantee fully reproducible results.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    if cuda:
        if cudnn_benchmark is None:
            # Leave the user's cudnn setting alone, but point out the risk.
            if torch.backends.cudnn.benchmark:
                warn(
                    "torch.backends.cudnn.benchmark was set to True which may results in lack of "
                    "reproducibility. In some cases to ensure reproducibility you may need to "
                    "set torch.backends.cudnn.benchmark to False.",
                    UserWarning,
                )
        elif not isinstance(cudnn_benchmark, bool):
            raise ValueError(
                f"cudnn_benchmark expected to be bool or None, got '{cudnn_benchmark}'"
            )
        else:
            torch.backends.cudnn.benchmark = cudnn_benchmark
        torch.cuda.manual_seed_all(seed)

    np.random.seed(seed)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def np_to_th(X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs):
    """Convert array-like input to a ``torch.Tensor``.

    Objects without ``__len__`` (e.g. plain numbers) are wrapped in a
    one-element list first, so they become a tensor of shape ``(1,)``.
    The input is then passed through :func:`numpy.asarray`.

    Parameters
    ----------
    X : ndarray or list or number
        Input data.
    requires_grad : bool
        Forwarded to :func:`torch.tensor`.
    dtype : numpy dtype, optional
        If given, the NumPy array is cast to this dtype before conversion.
    pin_memory : bool
        If True, the returned tensor is ``tensor.pin_memory()``.
    **tensor_kwargs
        Further keyword arguments forwarded to :func:`torch.tensor`.

    Returns
    -------
    var : torch.Tensor
    """
    wrapped = X if hasattr(X, "__len__") else [X]
    array = np.asarray(wrapped)
    if dtype is not None:
        array = array.astype(dtype)
    tensor = torch.tensor(array, requires_grad=requires_grad, **tensor_kwargs)
    return tensor.pin_memory() if pin_memory else tensor
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def th_to_np(var: Tensor):
    """Convert a ``torch.Tensor`` to a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion, so this works for CPU and GPU tensors alike.
    """
    return var.detach().cpu().numpy()
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def corr(a, b):
    """Correlate every row of ``a`` with every row of ``b``.

    Unlike :func:`numpy.corrcoef`, correlations within ``a`` or within ``b``
    are not computed; only the cross terms are returned.

    Parameters
    ----------
    a, b : 2darray, features x samples
        Both arrays must have the same number of samples (columns).

    Returns
    -------
    corr : 2darray
        Array of shape (n_features_a, n_features_b) holding the correlation
        between each feature of ``a`` and each feature of ``b``.
    """
    return _cov_to_corr(cov(a, b), a, b)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def cov(a, b):
    """Compute the covariance between every row of ``a`` and every row of ``b``.

    Only cross terms are computed (no within-``a`` or within-``b``
    covariances). The normalization is ``n_samples - 1``.

    Parameters
    ----------
    a, b : 2darray, features x samples

    Returns
    -------
    cov : 2darray
        Array of shape (n_features_a, n_features_b).
    """
    n_samples = b.shape[1]
    centered_a = a - np.mean(a, axis=1, keepdims=True)
    centered_b = b - np.mean(b, axis=1, keepdims=True)
    return np.dot(centered_a, centered_b.T) / (n_samples - 1)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _cov_to_corr(this_cov, a, b):
    """Normalize a cross-covariance matrix into a correlation matrix.

    Row variances of ``a`` and ``b`` are computed with ``ddof=1``, i.e. with
    an ``n_samples - 1`` normalization.
    """
    return _cov_and_var_to_corr(
        this_cov,
        np.var(a, axis=1, ddof=1),
        np.var(b, axis=1, ddof=1),
    )
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _cov_and_var_to_corr(this_cov, var_a, var_b):
|
|
156
|
+
divisor = np.outer(np.sqrt(var_a), np.sqrt(var_b))
|
|
157
|
+
return this_cov / divisor
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def wrap_reshape_apply_fn(stat_fn, a, b, axis_a, axis_b):
    """Reshape two nd-arrays into 2d-arrays, apply a function and reshape back.

    For ``a``, all axes not listed in ``axis_a`` are flattened into the first
    (feature) dimension, and the axes in ``axis_a`` into the second (sample)
    dimension. ``b`` is treated the same way with ``axis_b``. ``stat_fn`` is
    then called on the two 2d-arrays. Its result is reshaped to
    ``remaining_dims_a + remaining_dims_b``.

    Parameters
    ----------
    stat_fn : callable
        Function taking two 2d-arrays of shape (n_features_a, n_samples) and
        (n_features_b, n_samples). It must return an array with
        n_features_a * n_features_b elements, e.g. ``cov`` or ``corr``.
    a : nd-array
    b : nd-array
    axis_a : int or list of int
        Sample axis/axes of ``a``. Negative values count from the end.
    axis_b : int or list of int
        Sample axis/axes of ``b``. Negative values count from the end.

    Returns
    -------
    result : nd-array
        The result reshaped to remaining_dims_a + remaining_dims_b. If every
        axis of both arrays is a sample axis, this is a 0-d array.

    Raises
    ------
    AssertionError
        If the sample axes of ``a`` and ``b`` differ in size.
    """
    flat_a, other_shape_a, stat_shape_a = _flatten_to_features_x_samples(a, axis_a)
    flat_b, other_shape_b, stat_shape_b = _flatten_to_features_x_samples(b, axis_b)
    assert np.array_equal(stat_shape_a, stat_shape_b), (
        f"Sample axes must have equal sizes, got {stat_shape_a} and {stat_shape_b}"
    )
    stat_result = stat_fn(flat_a, flat_b)
    return stat_result.reshape(tuple(other_shape_a) + tuple(other_shape_b))


def _flatten_to_features_x_samples(arr, axis):
    """Move ``axis`` to the end of ``arr`` and flatten it to (features, samples).

    Returns the 2d-array, the sizes of the remaining (feature) axes and the
    sizes of the sample axes.
    """
    if not hasattr(axis, "__len__"):
        axis = [axis]
    # Resolve negative axes. Otherwise ``i not in axis`` below never matches
    # them, and the transpose receives one axis too many.
    axis = [ax + arr.ndim if ax < 0 else ax for ax in axis]
    other_axis = [i for i in range(arr.ndim) if i not in axis]
    other_shape = [arr.shape[i] for i in other_axis]
    stat_shape = [arr.shape[i] for i in axis]
    # An integer dtype keeps the product integral. np.prod([]) is the float
    # 1.0, which reshape rejects when every axis is a sample axis.
    n_features = int(np.prod(other_shape, dtype=np.int64))
    n_samples = int(np.prod(stat_shape, dtype=np.int64))
    flat = arr.transpose(tuple(other_axis) + tuple(axis)).reshape(
        n_features, n_samples
    )
    return flat, other_shape, stat_shape
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def get_balanced_batches(n_trials, rng, shuffle, n_batches=None, batch_size=None):
    """Create indices for batches balanced in size.

    Batch sizes differ by at most 1. Supply either ``batch_size`` or
    ``n_batches``; if both are given, ``n_batches`` takes precedence. When
    only ``batch_size`` is given, the number of batches is
    ``np.round(n_trials / batch_size)`` (at least 1). The resulting batch
    size can therefore be somewhat larger or smaller than the requested one.

    Parameters
    ----------
    n_trials : int
        Size of set.
    rng : RandomState
        Random generator used for shuffling (its ``shuffle`` method is called).
    shuffle : bool
        Whether to shuffle indices before splitting set.
    n_batches : int, optional
        Number of batches. Values <= 0 produce a single batch.
    batch_size : int, optional
        Approximate batch size, only used when ``n_batches`` is None.

    Returns
    -------
    batches : list of ndarray of int
        Indices for each batch. The first ``n_trials % n_batches`` batches
        contain one extra trial.

    Raises
    ------
    AssertionError
        If neither ``n_batches`` nor ``batch_size`` is given.
    """
    assert batch_size is not None or n_batches is not None, (
        "Either n_batches or batch_size must be given."
    )
    if n_batches is None:
        n_batches = int(np.round(n_trials / float(batch_size)))

    if n_batches > 0:
        min_batch_size, n_batches_with_extra_trial = divmod(n_trials, n_batches)
    else:
        # Too few trials for a single batch of the requested size:
        # put all trials into one batch.
        n_batches = 1
        min_batch_size, n_batches_with_extra_trial = n_trials, 0

    # np.arange keeps an integer dtype even for n_trials == 0
    # (np.array(range(0)) would be float64).
    all_inds = np.arange(n_trials)
    if shuffle:
        rng.shuffle(all_inds)

    batches = []
    i_start_trial = 0
    for i_batch in range(n_batches):
        # The first n_batches_with_extra_trial batches get one more trial.
        i_stop_trial = (
            i_start_trial + min_batch_size + int(i_batch < n_batches_with_extra_trial)
        )
        batches.append(all_inds[i_start_trial:i_stop_trial])
        i_start_trial = i_stop_trial
    assert i_start_trial == n_trials
    return batches
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def create_mne_dummy_raw(
    n_channels,
    n_times,
    sfreq,
    include_anns=True,
    description=None,
    savedir=None,
    save_format="fif",
    overwrite=True,
    random_state=None,
):
    """Create an mne.io.RawArray with fake data, and optionally save it.

    Parameters
    ----------
    n_channels : int
        Number of channels. Channels are named ``ch0``, ``ch1``, ... and
        typed as 'eeg'.
    n_times : int
        Number of samples.
    sfreq : float
        Sampling frequency.
    include_anns : bool
        If True, also create 10 annotations with duration 1. Their onsets
        are taken at sample indices spaced (approximately) evenly between
        ``2 * sfreq`` and ``n_times - 2 * sfreq``.
    description : list | None
        List of descriptions used for creating annotations. It should contain
        10 elements. If None, every annotation is described as "test".
    savedir : str | None
        If provided, the data is saved under that directory with the base
        file name ``fake_eeg_raw``.
    save_format : str | list
        If ``savedir`` is provided, this specifies the file format(s) the data
        should be saved to. Can be 'fif' or 'hdf5', or a list containing both.
    overwrite : bool
        Whether already existing files may be overwritten. This applies to
        both the 'fif' and the 'hdf5' output. With False, saving to an
        existing file raises an error.
    random_state : int | RandomState | None
        Random state for the generation of random data.

    Returns
    -------
    raw : mne.io.Raw
        The created Raw object.
    save_fname : dict
        Maps each saved format ('fif' and/or 'hdf5') to the file name the raw
        data was saved to. Empty if ``savedir`` is None.
    """
    random_state = check_random_state(random_state)
    data = random_state.rand(n_channels, n_times)
    ch_names = [f"ch{i}" for i in range(n_channels)]
    ch_types = ["eeg"] * n_channels
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)

    raw = mne.io.RawArray(data, info)

    if include_anns:
        n_anns = 10
        inds = np.linspace(
            int(sfreq * 2), int(n_times - sfreq * 2), num=n_anns
        ).astype(int)
        onset = raw.times[inds]
        duration = [1] * n_anns
        if description is None:
            description = ["test"] * n_anns
        anns = mne.Annotations(onset, duration, description)
        raw = raw.set_annotations(anns)

    save_fname = dict()
    if savedir is not None:
        if not isinstance(save_format, list):
            save_format = [save_format]
        fname = os.path.join(savedir, "fake_eeg_raw")

        if "fif" in save_format:
            fif_fname = fname + ".fif"
            raw.save(fif_fname, overwrite=overwrite)
            save_fname["fif"] = fif_fname
        if "hdf5" in save_format:
            h5_fname = fname + ".h5"
            # Mode "w" truncates an existing file; mode "x" fails if the file
            # exists. This makes the hdf5 output honour ``overwrite`` like
            # raw.save does.
            h5_mode = "w" if overwrite else "x"
            with h5py.File(h5_fname, h5_mode) as f:
                f.create_dataset("fake_raw", dtype="f8", data=raw.get_data())
            save_fname["hdf5"] = h5_fname

    return raw, save_fname
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
class ThrowAwayIndexLoader(object):
    """Iterate over a loader, dropping window indices and casting dtypes.

    For 3-element batches ``(x, y, i)``, the indices ``i`` are stored on
    ``net._last_window_inds_`` and ``(x, y)`` is yielded. 2-element batches
    are unpacked as ``(x, y)``. In both cases, if ``x`` has a ``type``
    method (e.g. a ``torch.Tensor``), ``x`` is cast to ``torch.float32``.
    ``y`` is then cast to ``torch.float32`` (regression) or ``torch.int64``
    (otherwise).

    Parameters
    ----------
    net : object
        Object on which ``_last_window_inds_`` is set for 3-element batches.
    loader : iterable
        Loader yielding 2- or 3-element batches.
    is_regression : bool
        Selects the dtype ``y`` is cast to.
    """

    def __init__(self, net, loader, is_regression):
        self.net = net
        self.loader = loader
        # Not read anywhere in this class.
        self.last_i = None
        self.is_regression = is_regression

    def __iter__(self):
        for batch in self.loader.__iter__():
            if len(batch) != 3:
                x, y = batch
            else:
                x, y, window_inds = batch
                # Store for scoring callbacks
                self.net._last_window_inds_ = window_inds

            # TODO: should be on dataset side
            if hasattr(x, "type"):
                x = x.type(torch.float32)
                target_dtype = torch.float32 if self.is_regression else torch.int64
                y = y.type(target_dtype)
            yield x, y
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def update_estimator_docstring(base_class, docstring):
    """Merge ``base_class.__doc__`` into the estimator docstring ``docstring``.

    The result is, in order:

    1. the part of ``docstring`` before its "Parameters" header;
    2. the base docstring from its "Parameters" header up to "Attributes",
       with the ``callbacks`` entry removed;
    3. the rest of ``docstring`` after its "Parameters" header;
    4. the base docstring from "Attributes" onwards.

    Parameters
    ----------
    base_class : type
        Class whose ``__doc__`` provides the inherited parameter and attribute
        descriptions. ``" : "`` is normalized to ``": "`` before processing.
    docstring : str
        Docstring of the new estimator. It must contain the exact header
        string passed to ``split`` below; otherwise ``splitted[1]`` raises
        ``IndexError``.

    Returns
    -------
    out_docstring : str
        The merged docstring.
    """
    # Normalize "name : type" to "name: type" so the "callbacks:" search matches.
    base_doc = base_class.__doc__.replace(" : ", ": ")
    idx = base_doc.find("callbacks:")
    # The callbacks entry ends at the first blank line after it.
    # NOTE(review): if "callbacks:" is absent, idx == -1 and the slicing below
    # silently mangles the text instead of failing.
    idx_end = idx + base_doc[idx:].find("\n\n")
    # remove callback description already included in braindecode docstring
    # NOTE(review): "+ 6" skips the two newline characters plus four more,
    # presumably the indentation of the following entry; confirm against the
    # base class docstring layout.
    filtered_doc = base_doc[:idx] + base_doc[idx_end + 6 :]
    splitted = docstring.split("Parameters\n ----------\n ")
    out_docstring = (
        splitted[0]
        + filtered_doc[
            filtered_doc.find("Parameters") : filtered_doc.find("Attributes")
        ]
        + splitted[1]
        + filtered_doc[filtered_doc.find("Attributes") :]
    )
    return out_docstring
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def _update_moabb_docstring(base_class, docstring):
|
|
388
|
+
base_doc = base_class.__doc__
|
|
389
|
+
# Clean up malformed rubrics from moabb docstrings
|
|
390
|
+
# Remove lines that have ".. rubric::" followed by content on same line or improper formatting
|
|
391
|
+
|
|
392
|
+
base_doc = re.sub(r"\.\. rubric:: (.+?)\n\s+\.\. note::", r".. note::", base_doc)
|
|
393
|
+
out_docstring = base_doc + f"\n\n{docstring}"
|
|
394
|
+
return out_docstring
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def read_all_file_names(directory, extension):
    """Recursively find all files with a given extension below a directory.

    Parameters
    ----------
    directory : str | path-like
        Parent directory to be searched for files of the specified type.
        A trailing path separator is optional.
    extension : str
        File extension including the leading dot, i.e. ".edf" or ".txt".

    Returns
    -------
    file_paths : list(str)
        Lexicographically sorted list of all matching files found in
        ``directory`` and its subdirectories.

    Raises
    ------
    AssertionError
        If ``extension`` does not start with "." or no matching file is found.
    """
    assert extension.startswith(".")
    # Join instead of concatenating. Without a trailing separator,
    # ``directory + "**/*"`` would glue "**" onto the last path component,
    # and the recursive match would be lost.
    pattern = os.path.join(directory, "**", "*" + extension)
    # glob returns files in arbitrary (filesystem) order; sort for determinism.
    file_paths = sorted(glob.glob(pattern, recursive=True))
    assert len(file_paths) > 0, (
        f"something went wrong. Found no {extension} files in {directory}"
    )
    return file_paths
|
braindecode/version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Version string of the braindecode package.
__version__ = "1.3.0.dev177069446"
|