braindecode 1.3.0.dev177628147__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/functional.py +0 -101
- braindecode/augmentation/transforms.py +0 -74
- braindecode/datasets/base.py +3 -18
- braindecode/datautil/serialization.py +1 -0
- braindecode/models/__init__.py +1 -8
- braindecode/models/summary.csv +0 -1
- braindecode/models/util.py +0 -84
- braindecode/preprocessing/__init__.py +0 -5
- braindecode/preprocessing/eegprep_preprocess.py +19 -134
- braindecode/preprocessing/mne_preprocess.py +25 -56
- braindecode/preprocessing/preprocess.py +41 -126
- braindecode/preprocessing/util.py +0 -11
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA +5 -11
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/RECORD +19 -24
- braindecode/datasets/hub.py +0 -962
- braindecode/datasets/hub_validation.py +0 -113
- braindecode/datasets/registry.py +0 -120
- braindecode/datautil/hub_formats.py +0 -180
- braindecode/models/luna.py +0 -836
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/top_level.txt +0 -0
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# Authors: Cédric Rommel <cedric.rommel@inria.fr>
|
|
2
2
|
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
|
3
3
|
# Gustavo Rodrigues <gustavenrique01@gmail.com>
|
|
4
|
-
# Bruna Lopes <brunajaflopes@gmail.com>
|
|
5
4
|
#
|
|
6
5
|
# License: BSD (3-clause)
|
|
7
6
|
|
|
@@ -1195,103 +1194,3 @@ def mask_encoding(
|
|
|
1195
1194
|
X[mask] = 0
|
|
1196
1195
|
|
|
1197
1196
|
return X, y # Return the masked tensor and labels
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
def channels_rereference(
|
|
1201
|
-
X: torch.Tensor,
|
|
1202
|
-
y: torch.Tensor,
|
|
1203
|
-
random_state: int | np.random.RandomState | None = None,
|
|
1204
|
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1205
|
-
"""Randomly re-reference channels in EEG data matrix.
|
|
1206
|
-
|
|
1207
|
-
Part of the augmentations proposed in [1]_
|
|
1208
|
-
|
|
1209
|
-
Parameters
|
|
1210
|
-
----------
|
|
1211
|
-
X : torch.Tensor
|
|
1212
|
-
EEG input example or batch.
|
|
1213
|
-
y : torch.Tensor
|
|
1214
|
-
EEG labels for the example or batch.
|
|
1215
|
-
random_state: int | numpy.random.Generator, optional
|
|
1216
|
-
Seed to be used to instantiate numpy random number generator instance.
|
|
1217
|
-
Defaults to None.
|
|
1218
|
-
|
|
1219
|
-
Returns
|
|
1220
|
-
-------
|
|
1221
|
-
torch.Tensor
|
|
1222
|
-
Transformed inputs.
|
|
1223
|
-
torch.Tensor
|
|
1224
|
-
Transformed labels.
|
|
1225
|
-
|
|
1226
|
-
References
|
|
1227
|
-
----------
|
|
1228
|
-
.. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P.. (2020). Contrastive
|
|
1229
|
-
Representation Learning for Electroencephalogram Classification. Proceedings
|
|
1230
|
-
of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
|
|
1231
|
-
Learning Research 136:238-253
|
|
1232
|
-
|
|
1233
|
-
"""
|
|
1234
|
-
|
|
1235
|
-
rng = check_random_state(random_state)
|
|
1236
|
-
batch_size, n_channels, _ = X.shape
|
|
1237
|
-
|
|
1238
|
-
ch = rng.randint(0, n_channels, size=batch_size)
|
|
1239
|
-
|
|
1240
|
-
X_ch = X[torch.arange(batch_size), ch, :]
|
|
1241
|
-
X = X - X_ch.unsqueeze(1)
|
|
1242
|
-
X[torch.arange(batch_size), ch, :] = -X_ch
|
|
1243
|
-
|
|
1244
|
-
return X, y
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
def amplitude_scale(
|
|
1248
|
-
X: torch.Tensor,
|
|
1249
|
-
y: torch.Tensor,
|
|
1250
|
-
scale: tuple,
|
|
1251
|
-
random_state: int | np.random.RandomState | None = None,
|
|
1252
|
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1253
|
-
"""Rescale amplitude of each channel based on a random sampled scaling value.
|
|
1254
|
-
|
|
1255
|
-
Part of the augmentations proposed in [1]_
|
|
1256
|
-
|
|
1257
|
-
Parameters
|
|
1258
|
-
----------
|
|
1259
|
-
X : torch.Tensor
|
|
1260
|
-
EEG input example or batch.
|
|
1261
|
-
y : torch.Tensor
|
|
1262
|
-
EEG labels for the example or batch.
|
|
1263
|
-
scale : tuple of floats
|
|
1264
|
-
Interval from which ypu sample the scaling value
|
|
1265
|
-
random_state: int | numpy.random.Generator, optional
|
|
1266
|
-
Seed to be used to instantiate numpy random number generator instance.
|
|
1267
|
-
Defaults to None.
|
|
1268
|
-
|
|
1269
|
-
Returns
|
|
1270
|
-
-------
|
|
1271
|
-
torch.Tensor
|
|
1272
|
-
Transformed inputs.
|
|
1273
|
-
torch.Tensor
|
|
1274
|
-
Transformed labels.
|
|
1275
|
-
|
|
1276
|
-
References
|
|
1277
|
-
----------
|
|
1278
|
-
.. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P.. (2020). Contrastive
|
|
1279
|
-
Representation Learning for Electroencephalogram Classification. Proceedings
|
|
1280
|
-
of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
|
|
1281
|
-
Learning Research 136:238-253
|
|
1282
|
-
|
|
1283
|
-
"""
|
|
1284
|
-
|
|
1285
|
-
rng = torch.Generator()
|
|
1286
|
-
rng.manual_seed(random_state)
|
|
1287
|
-
batch_size, n_channels, _ = X.shape
|
|
1288
|
-
|
|
1289
|
-
# Parameter for scaling amplitude / channel / trial
|
|
1290
|
-
l, h = scale
|
|
1291
|
-
s = l + (h - l) * torch.rand(
|
|
1292
|
-
batch_size, n_channels, 1, generator=rng, device=X.device, dtype=X.dtype
|
|
1293
|
-
)
|
|
1294
|
-
|
|
1295
|
-
X = s * X
|
|
1296
|
-
|
|
1297
|
-
return X, y
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# Authors: Cédric Rommel <cedric.rommel@inria.fr>
|
|
2
2
|
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
|
3
3
|
# Gustavo Rodrigues <gustavenrique01@gmail.com>
|
|
4
|
-
# Bruna Lopes <brunajaflopes@gmail.com>
|
|
5
4
|
#
|
|
6
5
|
# License: BSD (3-clause)
|
|
7
6
|
|
|
@@ -14,11 +13,9 @@ from mne.channels import make_standard_montage
|
|
|
14
13
|
|
|
15
14
|
from .base import Transform
|
|
16
15
|
from .functional import (
|
|
17
|
-
amplitude_scale,
|
|
18
16
|
bandstop_filter,
|
|
19
17
|
channels_dropout,
|
|
20
18
|
channels_permute,
|
|
21
|
-
channels_rereference,
|
|
22
19
|
channels_shuffle,
|
|
23
20
|
frequency_shift,
|
|
24
21
|
ft_surrogate,
|
|
@@ -1274,74 +1271,3 @@ class MaskEncoding(Transform):
|
|
|
1274
1271
|
"segment_length": segment_length,
|
|
1275
1272
|
"n_segments": self.n_segments,
|
|
1276
1273
|
}
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
class ChannelsReref(Transform):
|
|
1280
|
-
"""Randomly re-reference channels in EEG data matrix.
|
|
1281
|
-
|
|
1282
|
-
Part of the augmentations proposed in [1]_
|
|
1283
|
-
|
|
1284
|
-
Parameters
|
|
1285
|
-
----------
|
|
1286
|
-
probability: float
|
|
1287
|
-
Float setting the probability of applying the operation.
|
|
1288
|
-
random_state: int | numpy.random.Generator, optional
|
|
1289
|
-
Seed to be used to instantiate numpy random number generator instance.
|
|
1290
|
-
Used to decide whether or not to transform given the probability
|
|
1291
|
-
argument, to sample which channels to shuffle and to carry the shuffle.
|
|
1292
|
-
Defaults to None.
|
|
1293
|
-
|
|
1294
|
-
References
|
|
1295
|
-
----------
|
|
1296
|
-
.. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P.. (2020). Contrastive
|
|
1297
|
-
Representation Learning for Electroencephalogram Classification. Proceedings
|
|
1298
|
-
of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
|
|
1299
|
-
Learning Research 136:238-253 Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
|
|
1300
|
-
|
|
1301
|
-
"""
|
|
1302
|
-
|
|
1303
|
-
operation = staticmethod(channels_rereference) # type: ignore[assignment]
|
|
1304
|
-
|
|
1305
|
-
def __init__(self, probability, random_state=None):
|
|
1306
|
-
super().__init__(probability=probability, random_state=random_state)
|
|
1307
|
-
|
|
1308
|
-
def get_augmentation_params(self, *batch):
|
|
1309
|
-
"""Return transform parameters"""
|
|
1310
|
-
return {
|
|
1311
|
-
"random_state": self.rng,
|
|
1312
|
-
}
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
class AmplitudeScale(Transform):
|
|
1316
|
-
"""Rescale amplitude based on a random sampled scaling value.
|
|
1317
|
-
|
|
1318
|
-
Part of the augmentations proposed in [1]_
|
|
1319
|
-
|
|
1320
|
-
Parameters
|
|
1321
|
-
----------
|
|
1322
|
-
probability: float
|
|
1323
|
-
Float setting the probability of applying the operation.
|
|
1324
|
-
random_state: int | numpy.random.Generator, optional
|
|
1325
|
-
Seed to be used to instantiate numpy random number generator instance.
|
|
1326
|
-
Used to decide whether or not to transform given the probability
|
|
1327
|
-
argument, to sample which channels to shuffle and to carry the shuffle.
|
|
1328
|
-
Defaults to None.
|
|
1329
|
-
|
|
1330
|
-
References
|
|
1331
|
-
----------
|
|
1332
|
-
.. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P.. (2020). Contrastive
|
|
1333
|
-
Representation Learning for Electroencephalogram Classification. Proceedings
|
|
1334
|
-
of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
|
|
1335
|
-
Learning Research 136:238-253 Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
|
|
1336
|
-
|
|
1337
|
-
"""
|
|
1338
|
-
|
|
1339
|
-
operation = staticmethod(amplitude_scale) # type: ignore[assignment]
|
|
1340
|
-
|
|
1341
|
-
def __init__(self, probability, interval=(0.5, 2), random_state=None):
|
|
1342
|
-
super().__init__(probability=probability, random_state=random_state)
|
|
1343
|
-
self.scale = interval
|
|
1344
|
-
|
|
1345
|
-
def get_augmentation_params(self, *batch):
|
|
1346
|
-
"""Return transform parameters"""
|
|
1347
|
-
return {"random_state": self.rng, "scale": self.scale}
|
braindecode/datasets/base.py
CHANGED
|
@@ -19,7 +19,7 @@ import warnings
|
|
|
19
19
|
from abc import abstractmethod
|
|
20
20
|
from collections.abc import Callable
|
|
21
21
|
from glob import glob
|
|
22
|
-
from typing import
|
|
22
|
+
from typing import Generic, Iterable, no_type_check
|
|
23
23
|
|
|
24
24
|
import mne.io
|
|
25
25
|
import numpy as np
|
|
@@ -28,9 +28,6 @@ from mne.utils.docs import deprecated
|
|
|
28
28
|
from torch.utils.data import ConcatDataset, Dataset
|
|
29
29
|
from typing_extensions import TypeVar
|
|
30
30
|
|
|
31
|
-
from .hub import HubDatasetMixin
|
|
32
|
-
from .registry import register_dataset
|
|
33
|
-
|
|
34
31
|
|
|
35
32
|
def _create_description(description) -> pd.Series:
|
|
36
33
|
if description is not None:
|
|
@@ -100,7 +97,6 @@ class RecordDataset(Dataset[tuple[np.ndarray, int | str, tuple[int, int, int]]])
|
|
|
100
97
|
T = TypeVar("T", bound=RecordDataset)
|
|
101
98
|
|
|
102
99
|
|
|
103
|
-
@register_dataset
|
|
104
100
|
class RawDataset(RecordDataset):
|
|
105
101
|
"""Returns samples from an mne.io.Raw object along with a target.
|
|
106
102
|
|
|
@@ -133,7 +129,6 @@ class RawDataset(RecordDataset):
|
|
|
133
129
|
|
|
134
130
|
# save target name for load/save later
|
|
135
131
|
self.target_name = self._target_name(target_name)
|
|
136
|
-
self.raw_preproc_kwargs: list[dict[str, Any]] = []
|
|
137
132
|
|
|
138
133
|
def __getitem__(self, index):
|
|
139
134
|
X = self.raw[:, index][0]
|
|
@@ -181,12 +176,10 @@ class RawDataset(RecordDataset):
|
|
|
181
176
|
"If you want to type a Braindecode dataset (i.e. RawDataset|EEGWindowsDataset|WindowsDataset), "
|
|
182
177
|
"use the RecordDataset class instead."
|
|
183
178
|
)
|
|
184
|
-
@register_dataset
|
|
185
179
|
class BaseDataset(RawDataset):
|
|
186
180
|
pass
|
|
187
181
|
|
|
188
182
|
|
|
189
|
-
@register_dataset
|
|
190
183
|
class EEGWindowsDataset(RecordDataset):
|
|
191
184
|
"""Returns windows from an mne.Raw object, its window indices, along with a target.
|
|
192
185
|
|
|
@@ -242,7 +235,6 @@ class EEGWindowsDataset(RecordDataset):
|
|
|
242
235
|
].to_numpy()
|
|
243
236
|
if self.targets_from == "metadata":
|
|
244
237
|
self.y = metadata.loc[:, "target"].to_list()
|
|
245
|
-
self.raw_preproc_kwargs: list[dict[str, Any]] = []
|
|
246
238
|
|
|
247
239
|
def __getitem__(self, index: int):
|
|
248
240
|
"""Get a window and its target.
|
|
@@ -293,7 +285,6 @@ class EEGWindowsDataset(RecordDataset):
|
|
|
293
285
|
return len(self.crop_inds)
|
|
294
286
|
|
|
295
287
|
|
|
296
|
-
@register_dataset
|
|
297
288
|
class WindowsDataset(RecordDataset):
|
|
298
289
|
"""Returns windows from an mne.Epochs object along with a target.
|
|
299
290
|
|
|
@@ -343,8 +334,6 @@ class WindowsDataset(RecordDataset):
|
|
|
343
334
|
].to_numpy()
|
|
344
335
|
if self.targets_from == "metadata":
|
|
345
336
|
self.y = metadata.loc[:, "target"].to_list()
|
|
346
|
-
self.raw_preproc_kwargs: list[dict[str, Any]] = []
|
|
347
|
-
self.window_preproc_kwargs: list[dict[str, Any]] = []
|
|
348
337
|
|
|
349
338
|
def __getitem__(self, index: int):
|
|
350
339
|
"""Get a window and its target.
|
|
@@ -385,16 +374,12 @@ class WindowsDataset(RecordDataset):
|
|
|
385
374
|
return len(self.windows.events)
|
|
386
375
|
|
|
387
376
|
|
|
388
|
-
|
|
389
|
-
class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
|
|
377
|
+
class BaseConcatDataset(ConcatDataset, Generic[T]):
|
|
390
378
|
"""A base class for concatenated datasets.
|
|
391
379
|
|
|
392
380
|
Holds either mne.Raw or mne.Epoch in self.datasets and has
|
|
393
381
|
a pandas DataFrame with additional description.
|
|
394
382
|
|
|
395
|
-
Includes Hugging Face Hub integration via HubDatasetMixin for
|
|
396
|
-
uploading and downloading datasets.
|
|
397
|
-
|
|
398
383
|
Parameters
|
|
399
384
|
----------
|
|
400
385
|
list_of_ds : list
|
|
@@ -809,7 +794,7 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
|
|
|
809
794
|
kwargs = getattr(ds, kwargs_name)
|
|
810
795
|
if kwargs is not None:
|
|
811
796
|
with open(kwargs_file_path, "w") as f:
|
|
812
|
-
json.dump(kwargs, f
|
|
797
|
+
json.dump(kwargs, f)
|
|
813
798
|
|
|
814
799
|
@staticmethod
|
|
815
800
|
def _save_target_name(sub_dir, ds):
|
|
@@ -302,6 +302,7 @@ def _load_kwargs_json(kwargs_name, sub_dir):
|
|
|
302
302
|
kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
|
|
303
303
|
if os.path.exists(kwargs_file_path):
|
|
304
304
|
kwargs = json.load(open(kwargs_file_path, "r"))
|
|
305
|
+
kwargs = [tuple(kwarg) for kwarg in kwargs]
|
|
305
306
|
return kwargs
|
|
306
307
|
|
|
307
308
|
|
braindecode/models/__init__.py
CHANGED
|
@@ -28,7 +28,6 @@ from .fbmsnet import FBMSNet
|
|
|
28
28
|
from .hybrid import HybridNet
|
|
29
29
|
from .ifnet import IFNet
|
|
30
30
|
from .labram import Labram
|
|
31
|
-
from .luna import LUNA
|
|
32
31
|
from .medformer import MEDFormer
|
|
33
32
|
from .msvtnet import MSVTNet
|
|
34
33
|
from .patchedtransformer import PBT
|
|
@@ -50,11 +49,7 @@ from .tcn import BDTCN, TCN
|
|
|
50
49
|
from .tidnet import TIDNet
|
|
51
50
|
from .tsinception import TSception
|
|
52
51
|
from .usleep import USleep
|
|
53
|
-
from .util import
|
|
54
|
-
_init_models_dict,
|
|
55
|
-
extract_channel_locations_from_chs_info,
|
|
56
|
-
models_mandatory_parameters,
|
|
57
|
-
)
|
|
52
|
+
from .util import _init_models_dict, models_mandatory_parameters
|
|
58
53
|
|
|
59
54
|
# Call this last in order to make sure the dataset list is populated with
|
|
60
55
|
# the models imported in this file.
|
|
@@ -88,8 +83,6 @@ __all__ = [
|
|
|
88
83
|
"HybridNet",
|
|
89
84
|
"IFNet",
|
|
90
85
|
"Labram",
|
|
91
|
-
"LUNA",
|
|
92
|
-
"extract_channel_locations_from_chs_info",
|
|
93
86
|
"MEDFormer",
|
|
94
87
|
"MSVTNet",
|
|
95
88
|
"PBT",
|
braindecode/models/summary.csv
CHANGED
|
@@ -41,5 +41,4 @@ IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860
|
|
|
41
41
|
PBT,General,Classification,250,"n_chans, n_outputs, n_times",818948,"PBT(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Large Brain Model"
|
|
42
42
|
SSTDPN,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",19502,"SSTDPN(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
|
|
43
43
|
BENDR,General,"Classification,Embedding",250,"n_chans, n_times, n_outputs",157141049,"BENDR(n_chans=22, n_outputs=4, n_times=1000)","Large Brain Model,Convolution"
|
|
44
|
-
LUNA,General,"Classification,Embedding",128,"n_chans, n_times, sfreq, chs_info",7100731,"LUNA(n_chans=22, n_times=512, sfreq=128)","Convolution,Channel,Large Brain Model"
|
|
45
44
|
MEDFormer,General,Classification,250,"n_chans, n_outputs, n_times",5313924,"MEDFormer(n_chans=22, n_outputs=4, n_times=1000)","Large Brain Model,Convolution"
|
braindecode/models/util.py
CHANGED
|
@@ -4,9 +4,7 @@
|
|
|
4
4
|
# License: BSD (3-clause)
|
|
5
5
|
import inspect
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import Any, Dict, Optional, Sequence
|
|
8
7
|
|
|
9
|
-
import numpy as np
|
|
10
8
|
import pandas as pd
|
|
11
9
|
|
|
12
10
|
import braindecode.models as models
|
|
@@ -101,7 +99,6 @@ models_mandatory_parameters = [
|
|
|
101
99
|
("PBT", ["n_chans", "n_outputs", "n_times"], None),
|
|
102
100
|
("SSTDPN", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
|
|
103
101
|
("BENDR", ["n_chans", "n_outputs", "n_times"], None),
|
|
104
|
-
("LUNA", ["n_chans", "n_times", "n_outputs"], None),
|
|
105
102
|
("MEDFormer", ["n_chans", "n_outputs", "n_times"], None),
|
|
106
103
|
]
|
|
107
104
|
|
|
@@ -134,85 +131,4 @@ def get_summary_table(dir_name=None):
|
|
|
134
131
|
return df
|
|
135
132
|
|
|
136
133
|
|
|
137
|
-
def extract_channel_locations_from_chs_info(
|
|
138
|
-
chs_info: Optional[Sequence[Dict[str, Any]]],
|
|
139
|
-
num_channels: Optional[int] = None,
|
|
140
|
-
) -> Optional[np.ndarray]:
|
|
141
|
-
"""Extract 3D channel locations from MNE-style channel information.
|
|
142
|
-
|
|
143
|
-
This function provides a unified approach to extract 3D channel locations
|
|
144
|
-
from MNE channel information. It's compatible with models like SignalJEPA
|
|
145
|
-
and LUNA that need to work with channel spatial information.
|
|
146
|
-
|
|
147
|
-
Parameters
|
|
148
|
-
----------
|
|
149
|
-
chs_info : list of dict or None
|
|
150
|
-
Channel information, typically from ``mne.Info.chs``. Each dict should
|
|
151
|
-
contain a 'loc' key with a 12-element array (MNE format) where indices 3:6
|
|
152
|
-
represent the 3D cartesian coordinates.
|
|
153
|
-
num_channels : int or None
|
|
154
|
-
If specified, only extract the first ``num_channels`` channel locations.
|
|
155
|
-
If None, extract all available channels.
|
|
156
|
-
|
|
157
|
-
Returns
|
|
158
|
-
-------
|
|
159
|
-
channel_locations : np.ndarray of shape (n_channels, 3) or None
|
|
160
|
-
Array of 3D channel locations in cartesian coordinates. Returns None if
|
|
161
|
-
no valid locations are found.
|
|
162
|
-
|
|
163
|
-
Notes
|
|
164
|
-
-----
|
|
165
|
-
- This function handles both 12-element MNE location format (using indices 3:6)
|
|
166
|
-
and 3-element location format (using directly).
|
|
167
|
-
- Invalid or missing locations cause extraction to stop at that point.
|
|
168
|
-
- Returns None if no valid locations can be extracted.
|
|
169
|
-
- This is a unified utility compatible with models like SignalJEPA and LUNA.
|
|
170
|
-
|
|
171
|
-
Examples
|
|
172
|
-
--------
|
|
173
|
-
>>> import mne
|
|
174
|
-
>>> from braindecode.models.util import extract_channel_locations_from_chs_info
|
|
175
|
-
>>> raw = mne.io.read_raw_edf("sample.edf")
|
|
176
|
-
>>> locs = extract_channel_locations_from_chs_info(raw.info['chs'], num_channels=22)
|
|
177
|
-
>>> print(locs.shape)
|
|
178
|
-
(22, 3)
|
|
179
|
-
"""
|
|
180
|
-
if chs_info is None:
|
|
181
|
-
return None
|
|
182
|
-
|
|
183
|
-
locations = []
|
|
184
|
-
n_to_extract = num_channels if num_channels is not None else len(chs_info)
|
|
185
|
-
|
|
186
|
-
for i, ch_info in enumerate(chs_info[:n_to_extract]):
|
|
187
|
-
if not isinstance(ch_info, dict):
|
|
188
|
-
break
|
|
189
|
-
|
|
190
|
-
loc = ch_info.get("loc")
|
|
191
|
-
if loc is None:
|
|
192
|
-
break
|
|
193
|
-
|
|
194
|
-
try:
|
|
195
|
-
loc_array = np.asarray(loc, dtype=np.float32)
|
|
196
|
-
|
|
197
|
-
# MNE format: 12-element array with coordinates at indices 3:6
|
|
198
|
-
if loc_array.ndim == 1 and loc_array.size >= 6:
|
|
199
|
-
if loc_array.size == 12:
|
|
200
|
-
# Standard MNE format
|
|
201
|
-
coordinates = loc_array[3:6]
|
|
202
|
-
else:
|
|
203
|
-
# Assume first 3 elements are coordinates
|
|
204
|
-
coordinates = loc_array[:3]
|
|
205
|
-
else:
|
|
206
|
-
break
|
|
207
|
-
|
|
208
|
-
locations.append(coordinates)
|
|
209
|
-
except (ValueError, TypeError):
|
|
210
|
-
break
|
|
211
|
-
|
|
212
|
-
if len(locations) == 0:
|
|
213
|
-
return None
|
|
214
|
-
|
|
215
|
-
return np.stack(locations, axis=0)
|
|
216
|
-
|
|
217
|
-
|
|
218
134
|
_summary_table = get_summary_table()
|
|
@@ -188,17 +188,12 @@ from .preprocess import (
|
|
|
188
188
|
filterbank,
|
|
189
189
|
preprocess,
|
|
190
190
|
)
|
|
191
|
-
from .util import _init_preprocessor_dict
|
|
192
191
|
from .windowers import (
|
|
193
192
|
create_fixed_length_windows,
|
|
194
193
|
create_windows_from_events,
|
|
195
194
|
create_windows_from_target_channels,
|
|
196
195
|
)
|
|
197
196
|
|
|
198
|
-
# Call this last in order to make sure the list is populated with
|
|
199
|
-
# the preprocessors imported in this file.
|
|
200
|
-
_init_preprocessor_dict()
|
|
201
|
-
|
|
202
197
|
__all__ = [
|
|
203
198
|
"exponential_moving_demean",
|
|
204
199
|
"exponential_moving_standardize",
|