braindecode 1.3.0.dev173909672__tar.gz → 1.3.0.dev175435903__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {braindecode-1.3.0.dev173909672/braindecode.egg-info → braindecode-1.3.0.dev175435903}/PKG-INFO +17 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/augmentation/base.py +1 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/augmentation/functional.py +255 -54
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/augmentation/transforms.py +76 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/__init__.py +12 -4
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/base.py +132 -153
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/bcicomp.py +4 -4
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/bids.py +3 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/experimental.py +2 -2
- braindecode-1.3.0.dev175435903/braindecode/datasets/hub.py +962 -0
- braindecode-1.3.0.dev175435903/braindecode/datasets/hub_validation.py +113 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/mne.py +3 -5
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/moabb.py +17 -7
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/nmt.py +2 -2
- braindecode-1.3.0.dev175435903/braindecode/datasets/registry.py +120 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/sleep_physio_challe_18.py +2 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/sleep_physionet.py +2 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/tuh.py +2 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/xy.py +2 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datautil/__init__.py +11 -1
- braindecode-1.3.0.dev175435903/braindecode/datautil/channel_utils.py +114 -0
- braindecode-1.3.0.dev175435903/braindecode/datautil/hub_formats.py +180 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datautil/serialization.py +8 -9
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/eegneuralnet.py +2 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/functional/functions.py +6 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/functional/initialization.py +2 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/__init__.py +18 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/atcnet.py +37 -38
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/attentionbasenet.py +43 -36
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/attn_sleep.py +11 -7
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/base.py +280 -2
- braindecode-1.3.0.dev175435903/braindecode/models/bendr.py +469 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/biot.py +27 -23
- braindecode-1.3.0.dev175435903/braindecode/models/config.py +230 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/contrawr.py +4 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/ctnet.py +41 -36
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/deep4.py +4 -4
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/deepsleepnet.py +32 -23
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/eegconformer.py +29 -24
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/eeginception_erp.py +32 -26
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/eeginception_mi.py +3 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/eegitnet.py +5 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/eegminer.py +2 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/eegnet.py +2 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/eegnex.py +1 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/eegsimpleconv.py +1 -1
- braindecode-1.3.0.dev175435903/braindecode/models/eegsym.py +917 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/eegtcnet.py +5 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/fbcnet.py +6 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/fblightconvnet.py +3 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/fbmsnet.py +21 -7
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/hybrid.py +1 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/ifnet.py +4 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/labram.py +221 -115
- braindecode-1.3.0.dev175435903/braindecode/models/luna.py +836 -0
- braindecode-1.3.0.dev175435903/braindecode/models/medformer.py +758 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/msvtnet.py +11 -9
- braindecode-1.3.0.dev175435903/braindecode/models/patchedtransformer.py +640 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/sccnet.py +1 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/shallow_fbcsp.py +4 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/signal_jepa.py +111 -27
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/sinc_shallow.py +13 -11
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/sleep_stager_blanco_2020.py +1 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/sleep_stager_chambon_2018.py +1 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/sparcnet.py +4 -4
- braindecode-1.3.0.dev175435903/braindecode/models/sstdpn.py +869 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/summary.csv +7 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/syncnet.py +3 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/tcn.py +5 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/tidnet.py +1 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/tsinception.py +1 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/models/usleep.py +29 -24
- braindecode-1.3.0.dev175435903/braindecode/models/util.py +358 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/attention.py +10 -10
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/blocks.py +6 -4
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/filter.py +2 -9
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/layers.py +18 -17
- braindecode-1.3.0.dev175435903/braindecode/preprocessing/__init__.py +271 -0
- braindecode-1.3.0.dev175435903/braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode-1.3.0.dev175435903/braindecode/preprocessing/mne_preprocess.py +240 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/preprocessing/preprocess.py +146 -51
- braindecode-1.3.0.dev175435903/braindecode/preprocessing/util.py +177 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/preprocessing/windowers.py +26 -20
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/samplers/base.py +8 -8
- braindecode-1.3.0.dev175435903/braindecode/version.py +1 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903/braindecode.egg-info}/PKG-INFO +17 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode.egg-info/SOURCES.txt +14 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode.egg-info/requires.txt +18 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/api.rst +140 -13
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/conf.py +53 -14
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/index.rst +5 -5
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/install/install_pip.rst +7 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/install/install_source.rst +1 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/categorization/attention.rst +2 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/categorization/channel.rst +2 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/categorization/convolution.rst +2 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/categorization/filterbank.rst +3 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/categorization/gnn.rst +3 -6
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/categorization/interpretable.rst +3 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/categorization/lbm.rst +2 -2
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/categorization/recurrent.rst +3 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/categorization/spd.rst +3 -3
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/whats_new.rst +38 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/pyproject.toml +21 -2
- braindecode-1.3.0.dev173909672/braindecode/models/util.py +0 -129
- braindecode-1.3.0.dev173909672/braindecode/preprocessing/__init__.py +0 -37
- braindecode-1.3.0.dev173909672/braindecode/preprocessing/mne_preprocess.py +0 -77
- braindecode-1.3.0.dev173909672/braindecode/version.py +0 -1
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/MANIFEST.in +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/README.rst +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/__init__.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/augmentation/__init__.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/classifier.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datasets/bbci.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/datautil/util.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/functional/__init__.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/__init__.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/activation.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/convolution.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/linear.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/parametrization.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/stats.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/util.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/modules/wrapper.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/regressor.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/samplers/__init__.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/samplers/ssl.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/training/__init__.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/training/callbacks.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/training/losses.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/training/scoring.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/util.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/visualization/__init__.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/visualization/confusion_matrices.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/visualization/gradients.py +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode.egg-info/dependency_links.txt +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode.egg-info/top_level.txt +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/Makefile +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/_templates/autosummary/class.rst +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/_templates/autosummary/function.rst +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/cite.rst +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/help.rst +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/install/install.rst +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/models.rst +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/models_categorization.rst +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/models_table.rst +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/models/models_visualization.rst +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/docs/sg_execution_times.rst +0 -0
- {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/setup.cfg +0 -0
{braindecode-1.3.0.dev173909672/braindecode.egg-info → braindecode-1.3.0.dev175435903}/PKG-INFO
RENAMED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: braindecode
|
|
3
|
-
Version: 1.3.0.dev173909672
|
|
3
|
+
Version: 1.3.0.dev175435903
|
|
4
4
|
Summary: Deep learning software to decode EEG, ECG or MEG signals
|
|
5
|
-
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
5
|
+
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
|
|
6
6
|
Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
7
7
|
License: BSD-3-Clause
|
|
8
8
|
Project-URL: homepage, https://braindecode.org
|
|
@@ -38,14 +38,23 @@ Requires-Dist: wfdb
|
|
|
38
38
|
Requires-Dist: h5py
|
|
39
39
|
Requires-Dist: linear_attention_transformer
|
|
40
40
|
Requires-Dist: docstring_inheritance
|
|
41
|
+
Requires-Dist: rotary_embedding_torch
|
|
41
42
|
Provides-Extra: moabb
|
|
42
43
|
Requires-Dist: moabb>=1.2.0; extra == "moabb"
|
|
44
|
+
Provides-Extra: eegprep
|
|
45
|
+
Requires-Dist: eegprep[eeglabio]>=0.1.1; extra == "eegprep"
|
|
46
|
+
Provides-Extra: hub
|
|
47
|
+
Requires-Dist: huggingface_hub[torch]>=0.20.0; extra == "hub"
|
|
48
|
+
Requires-Dist: zarr<3.0,>=2.18; extra == "hub"
|
|
43
49
|
Provides-Extra: tests
|
|
44
50
|
Requires-Dist: pytest; extra == "tests"
|
|
45
51
|
Requires-Dist: pytest-cov; extra == "tests"
|
|
46
52
|
Requires-Dist: codecov; extra == "tests"
|
|
47
53
|
Requires-Dist: pytest_cases; extra == "tests"
|
|
48
54
|
Requires-Dist: mypy; extra == "tests"
|
|
55
|
+
Provides-Extra: typing
|
|
56
|
+
Requires-Dist: pydantic<3.0,>=2.0; extra == "typing"
|
|
57
|
+
Requires-Dist: numpydantic>=1.7; extra == "typing"
|
|
49
58
|
Provides-Extra: docs
|
|
50
59
|
Requires-Dist: sphinx_gallery; extra == "docs"
|
|
51
60
|
Requires-Dist: sphinx_rtd_theme; extra == "docs"
|
|
@@ -65,7 +74,12 @@ Requires-Dist: pre-commit; extra == "docs"
|
|
|
65
74
|
Requires-Dist: openneuro-py; extra == "docs"
|
|
66
75
|
Requires-Dist: plotly; extra == "docs"
|
|
67
76
|
Provides-Extra: all
|
|
68
|
-
Requires-Dist: braindecode[
|
|
77
|
+
Requires-Dist: braindecode[moabb]; extra == "all"
|
|
78
|
+
Requires-Dist: braindecode[hub]; extra == "all"
|
|
79
|
+
Requires-Dist: braindecode[tests]; extra == "all"
|
|
80
|
+
Requires-Dist: braindecode[docs]; extra == "all"
|
|
81
|
+
Requires-Dist: braindecode[eegprep]; extra == "all"
|
|
82
|
+
Requires-Dist: braindecode[typing]; extra == "all"
|
|
69
83
|
Dynamic: license-file
|
|
70
84
|
|
|
71
85
|
.. image:: https://badges.gitter.im/braindecodechat/community.svg
|
{braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev175435903}/braindecode/augmentation/base.py
RENAMED
|
@@ -189,7 +189,7 @@ class AugmentedDataLoader(DataLoader):
|
|
|
189
189
|
|
|
190
190
|
Parameters
|
|
191
191
|
----------
|
|
192
|
-
dataset :
|
|
192
|
+
dataset : RecordDataset
|
|
193
193
|
The dataset containing the signals.
|
|
194
194
|
transforms : list | Transform, optional
|
|
195
195
|
Transform or sequence of Transform to be applied to each batch.
|
|
@@ -1,12 +1,17 @@
|
|
|
1
1
|
# Authors: Cédric Rommel <cedric.rommel@inria.fr>
|
|
2
2
|
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
|
3
3
|
# Gustavo Rodrigues <gustavenrique01@gmail.com>
|
|
4
|
+
# Bruna Lopes <brunajaflopes@gmail.com>
|
|
4
5
|
#
|
|
5
6
|
# License: BSD (3-clause)
|
|
6
7
|
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
7
10
|
from numbers import Real
|
|
11
|
+
from typing import Literal
|
|
8
12
|
|
|
9
13
|
import numpy as np
|
|
14
|
+
import numpy.typing as npt
|
|
10
15
|
import torch
|
|
11
16
|
from mne.filter import notch_filter
|
|
12
17
|
from scipy.interpolate import Rbf
|
|
@@ -15,7 +20,7 @@ from torch.fft import fft, ifft
|
|
|
15
20
|
from torch.nn.functional import one_hot, pad
|
|
16
21
|
|
|
17
22
|
|
|
18
|
-
def identity(X, y):
|
|
23
|
+
def identity(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
19
24
|
"""Identity operation.
|
|
20
25
|
|
|
21
26
|
Parameters
|
|
@@ -35,7 +40,7 @@ def identity(X, y):
|
|
|
35
40
|
return X, y
|
|
36
41
|
|
|
37
42
|
|
|
38
|
-
def time_reverse(X, y):
|
|
43
|
+
def time_reverse(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
39
44
|
"""Flip the time axis of each input.
|
|
40
45
|
|
|
41
46
|
Parameters
|
|
@@ -55,7 +60,7 @@ def time_reverse(X, y):
|
|
|
55
60
|
return torch.flip(X, [-1]), y
|
|
56
61
|
|
|
57
62
|
|
|
58
|
-
def sign_flip(X, y):
|
|
63
|
+
def sign_flip(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
59
64
|
"""Flip the sign axis of each input.
|
|
60
65
|
|
|
61
66
|
Parameters
|
|
@@ -75,7 +80,13 @@ def sign_flip(X, y):
|
|
|
75
80
|
return -X, y
|
|
76
81
|
|
|
77
82
|
|
|
78
|
-
def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
|
|
83
|
+
def _new_random_fft_phase_odd(
|
|
84
|
+
batch_size: int,
|
|
85
|
+
c: int,
|
|
86
|
+
n: int,
|
|
87
|
+
device: torch.device,
|
|
88
|
+
random_state: int | np.random.RandomState | None,
|
|
89
|
+
) -> torch.Tensor:
|
|
79
90
|
rng = check_random_state(random_state)
|
|
80
91
|
random_phase = torch.from_numpy(
|
|
81
92
|
2j * np.pi * rng.random((batch_size, c, (n - 1) // 2))
|
|
@@ -90,7 +101,13 @@ def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
|
|
|
90
101
|
)
|
|
91
102
|
|
|
92
103
|
|
|
93
|
-
def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
|
|
104
|
+
def _new_random_fft_phase_even(
|
|
105
|
+
batch_size: int,
|
|
106
|
+
c: int,
|
|
107
|
+
n: int,
|
|
108
|
+
device: torch.device,
|
|
109
|
+
random_state: int | np.random.RandomState | None,
|
|
110
|
+
) -> torch.Tensor:
|
|
94
111
|
rng = check_random_state(random_state)
|
|
95
112
|
random_phase = torch.from_numpy(
|
|
96
113
|
2j * np.pi * rng.random((batch_size, c, n // 2 - 1))
|
|
@@ -109,7 +126,13 @@ def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
|
|
|
109
126
|
_new_random_fft_phase = {0: _new_random_fft_phase_even, 1: _new_random_fft_phase_odd}
|
|
110
127
|
|
|
111
128
|
|
|
112
|
-
def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
|
|
129
|
+
def ft_surrogate(
|
|
130
|
+
X: torch.Tensor,
|
|
131
|
+
y: torch.Tensor,
|
|
132
|
+
phase_noise_magnitude: float,
|
|
133
|
+
channel_indep: bool,
|
|
134
|
+
random_state: int | np.random.RandomState | None = None,
|
|
135
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
113
136
|
"""FT surrogate augmentation of a single EEG channel, as proposed in [1]_.
|
|
114
137
|
|
|
115
138
|
Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
|
|
@@ -175,7 +198,9 @@ def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
|
|
|
175
198
|
return transformed_X, y
|
|
176
199
|
|
|
177
200
|
|
|
178
|
-
def _pick_channels_randomly(X, p_pick, random_state):
|
|
201
|
+
def _pick_channels_randomly(
|
|
202
|
+
X: torch.Tensor, p_pick: float, random_state: int | np.random.RandomState | None
|
|
203
|
+
) -> torch.Tensor:
|
|
179
204
|
rng = check_random_state(random_state)
|
|
180
205
|
batch_size, n_channels, _ = X.shape
|
|
181
206
|
# allows to use the same RNG
|
|
@@ -188,7 +213,12 @@ def _pick_channels_randomly(X, p_pick, random_state):
|
|
|
188
213
|
return torch.sigmoid(1000 * (unif_samples - p_pick))
|
|
189
214
|
|
|
190
215
|
|
|
191
|
-
def channels_dropout(X, y, p_drop, random_state=None):
|
|
216
|
+
def channels_dropout(
|
|
217
|
+
X: torch.Tensor,
|
|
218
|
+
y: torch.Tensor,
|
|
219
|
+
p_drop: float,
|
|
220
|
+
random_state: int | np.random.RandomState | None = None,
|
|
221
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
192
222
|
"""Randomly set channels to flat signal.
|
|
193
223
|
|
|
194
224
|
Part of the CMSAugment policy proposed in [1]_
|
|
@@ -222,7 +252,9 @@ def channels_dropout(X, y, p_drop, random_state=None):
|
|
|
222
252
|
return X * mask.unsqueeze(-1), y
|
|
223
253
|
|
|
224
254
|
|
|
225
|
-
def _make_permutation_matrix(X, mask, random_state):
|
|
255
|
+
def _make_permutation_matrix(
|
|
256
|
+
X: torch.Tensor, mask: torch.Tensor, random_state: int | np.random.Generator | None
|
|
257
|
+
) -> torch.Tensor:
|
|
226
258
|
rng = check_random_state(random_state)
|
|
227
259
|
batch_size, n_channels, _ = X.shape
|
|
228
260
|
hard_mask = mask.round()
|
|
@@ -241,7 +273,12 @@ def _make_permutation_matrix(X, mask, random_state):
|
|
|
241
273
|
return batch_permutations
|
|
242
274
|
|
|
243
275
|
|
|
244
|
-
def channels_shuffle(X, y, p_shuffle, random_state=None):
|
|
276
|
+
def channels_shuffle(
|
|
277
|
+
X: torch.Tensor,
|
|
278
|
+
y: torch.Tensor,
|
|
279
|
+
p_shuffle: float,
|
|
280
|
+
random_state: int | np.random.RandomState | None = None,
|
|
281
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
245
282
|
"""Randomly shuffle channels in EEG data matrix.
|
|
246
283
|
|
|
247
284
|
Part of the CMSAugment policy proposed in [1]_
|
|
@@ -280,7 +317,12 @@ def channels_shuffle(X, y, p_shuffle, random_state=None):
|
|
|
280
317
|
return torch.matmul(batch_permutations, X), y
|
|
281
318
|
|
|
282
319
|
|
|
283
|
-
def gaussian_noise(X, y, std, random_state=None):
|
|
320
|
+
def gaussian_noise(
|
|
321
|
+
X: torch.Tensor,
|
|
322
|
+
y: torch.Tensor,
|
|
323
|
+
std: float,
|
|
324
|
+
random_state: int | np.random.RandomState | None = None,
|
|
325
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
284
326
|
"""Randomly add white Gaussian noise to all channels.
|
|
285
327
|
|
|
286
328
|
Suggested e.g. in [1]_, [2]_ and [3]_
|
|
@@ -332,7 +374,9 @@ def gaussian_noise(X, y, std, random_state=None):
|
|
|
332
374
|
return transformed_X, y
|
|
333
375
|
|
|
334
376
|
|
|
335
|
-
def channels_permute(X, y, permutation):
|
|
377
|
+
def channels_permute(
|
|
378
|
+
X: torch.Tensor, y: torch.Tensor, permutation: list[int]
|
|
379
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
336
380
|
"""Permute EEG channels according to fixed permutation matrix.
|
|
337
381
|
|
|
338
382
|
Suggested e.g. in [1]_
|
|
@@ -362,7 +406,12 @@ def channels_permute(X, y, permutation):
|
|
|
362
406
|
return X[..., permutation, :], y
|
|
363
407
|
|
|
364
408
|
|
|
365
|
-
def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
|
|
409
|
+
def smooth_time_mask(
|
|
410
|
+
X: torch.Tensor,
|
|
411
|
+
y: torch.Tensor,
|
|
412
|
+
mask_start_per_sample: torch.Tensor,
|
|
413
|
+
mask_len_samples: int,
|
|
414
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
366
415
|
"""Smoothly replace a contiguous part of all channels by zeros.
|
|
367
416
|
|
|
368
417
|
Originally proposed in [1]_ and [2]_
|
|
@@ -412,7 +461,13 @@ def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
|
|
|
412
461
|
return X * mask, y
|
|
413
462
|
|
|
414
463
|
|
|
415
|
-
def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
|
|
464
|
+
def bandstop_filter(
|
|
465
|
+
X: torch.Tensor,
|
|
466
|
+
y: torch.Tensor,
|
|
467
|
+
sfreq: float,
|
|
468
|
+
bandwidth: float,
|
|
469
|
+
freqs_to_notch: npt.ArrayLike | None,
|
|
470
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
416
471
|
"""Apply a band-stop filter with desired bandwidth at the desired frequency
|
|
417
472
|
position.
|
|
418
473
|
|
|
@@ -451,7 +506,7 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
|
|
|
451
506
|
Representation Learning for Electroencephalogram Classification. In
|
|
452
507
|
Machine Learning for Health (pp. 238-253). PMLR.
|
|
453
508
|
"""
|
|
454
|
-
if bandwidth == 0:
|
|
509
|
+
if bandwidth == 0 or freqs_to_notch is None:
|
|
455
510
|
return X, y
|
|
456
511
|
transformed_X = X.clone()
|
|
457
512
|
for c, (sample, notched_freq) in enumerate(zip(transformed_X, freqs_to_notch)):
|
|
@@ -469,7 +524,7 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
|
|
|
469
524
|
return transformed_X, y
|
|
470
525
|
|
|
471
526
|
|
|
472
|
-
def _analytic_transform(x):
|
|
527
|
+
def _analytic_transform(x: torch.Tensor) -> torch.Tensor:
|
|
473
528
|
if torch.is_complex(x):
|
|
474
529
|
raise ValueError("x must be real.")
|
|
475
530
|
|
|
@@ -486,12 +541,12 @@ def _analytic_transform(x):
|
|
|
486
541
|
return ifft(f * h, dim=-1)
|
|
487
542
|
|
|
488
543
|
|
|
489
|
-
def _nextpow2(n):
|
|
544
|
+
def _nextpow2(n: int) -> int:
|
|
490
545
|
"""Return the first integer N such that 2**N >= abs(n)."""
|
|
491
546
|
return int(np.ceil(np.log2(np.abs(n))))
|
|
492
547
|
|
|
493
548
|
|
|
494
|
-
def _frequency_shift(X, fs, f_shift):
|
|
549
|
+
def _frequency_shift(X: torch.Tensor, fs: float, f_shift: float) -> torch.Tensor:
|
|
495
550
|
"""
|
|
496
551
|
Shift the specified signal by the specified frequency.
|
|
497
552
|
|
|
@@ -504,9 +559,13 @@ def _frequency_shift(X, fs, f_shift):
|
|
|
504
559
|
t = torch.arange(N_padded, device=X.device) / fs
|
|
505
560
|
padded = pad(X, (0, N_padded - N_orig))
|
|
506
561
|
analytical = _analytic_transform(padded)
|
|
507
|
-
if isinstance(f_shift,
|
|
508
|
-
|
|
509
|
-
|
|
562
|
+
if isinstance(f_shift, torch.Tensor):
|
|
563
|
+
_f_shift = f_shift
|
|
564
|
+
elif isinstance(f_shift, (float, int, np.ndarray, list)):
|
|
565
|
+
_f_shift = torch.as_tensor(f_shift).float()
|
|
566
|
+
else:
|
|
567
|
+
raise ValueError(f"Invalid f_shift type: {type(f_shift)}")
|
|
568
|
+
f_shift_stack = _f_shift.repeat(N_padded, n_channels, 1)
|
|
510
569
|
reshaped_f_shift = f_shift_stack.permute(
|
|
511
570
|
*torch.arange(f_shift_stack.ndim - 1, -1, -1)
|
|
512
571
|
)
|
|
@@ -514,7 +573,9 @@ def _frequency_shift(X, fs, f_shift):
|
|
|
514
573
|
return shifted[..., :N_orig].real.float()
|
|
515
574
|
|
|
516
575
|
|
|
517
|
-
def frequency_shift(X, y, delta_freq, sfreq):
|
|
576
|
+
def frequency_shift(
|
|
577
|
+
X: torch.Tensor, y: torch.Tensor, delta_freq: float, sfreq: float
|
|
578
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
518
579
|
"""Adds a shift in the frequency domain to all channels.
|
|
519
580
|
|
|
520
581
|
Note that here, the shift is the same for all channels of a single example.
|
|
@@ -545,7 +606,7 @@ def frequency_shift(X, y, delta_freq, sfreq):
|
|
|
545
606
|
return transformed_X, y
|
|
546
607
|
|
|
547
608
|
|
|
548
|
-
def _torch_normalize_vectors(rr):
|
|
609
|
+
def _torch_normalize_vectors(rr: torch.Tensor) -> torch.Tensor:
|
|
549
610
|
"""Normalize surface vertices."""
|
|
550
611
|
norm = torch.linalg.norm(rr, axis=1, keepdim=True)
|
|
551
612
|
mask = norm > 0
|
|
@@ -554,7 +615,9 @@ def _torch_normalize_vectors(rr):
|
|
|
554
615
|
return new_rr
|
|
555
616
|
|
|
556
617
|
|
|
557
|
-
def _torch_legval(x, c, tensor=True):
|
|
618
|
+
def _torch_legval(
|
|
619
|
+
x: torch.Tensor, c: torch.Tensor, tensor: bool = True
|
|
620
|
+
) -> torch.Tensor:
|
|
558
621
|
"""
|
|
559
622
|
Evaluate a Legendre series at points x.
|
|
560
623
|
If `c` is of length `n + 1`, this function returns the value:
|
|
@@ -662,7 +725,9 @@ def _torch_legval(x, c, tensor=True):
|
|
|
662
725
|
return c0 + c1 * x
|
|
663
726
|
|
|
664
727
|
|
|
665
|
-
def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
|
|
728
|
+
def _torch_calc_g(
|
|
729
|
+
cosang: torch.Tensor, stiffness: float = 4, n_legendre_terms: int = 50
|
|
730
|
+
) -> torch.Tensor:
|
|
666
731
|
"""Calculate spherical spline g function between points on a sphere.
|
|
667
732
|
|
|
668
733
|
Parameters
|
|
@@ -718,23 +783,25 @@ def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
|
|
|
718
783
|
return _torch_legval(cosang, [0] + factors)
|
|
719
784
|
|
|
720
785
|
|
|
721
|
-
def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
|
|
786
|
+
def _torch_make_interpolation_matrix(
|
|
787
|
+
pos_from: torch.Tensor, pos_to: torch.Tensor, alpha: float = 1e-5
|
|
788
|
+
) -> torch.Tensor:
|
|
722
789
|
"""Compute interpolation matrix based on spherical splines.
|
|
723
790
|
|
|
724
791
|
Implementation based on [1]_
|
|
725
792
|
|
|
726
793
|
Parameters
|
|
727
794
|
----------
|
|
728
|
-
pos_from :
|
|
795
|
+
pos_from : torch.Tensor of float, shape(n_good_sensors, 3)
|
|
729
796
|
The positions to interpolate from.
|
|
730
|
-
pos_to :
|
|
797
|
+
pos_to : torch.Tensor of float, shape(n_bad_sensors, 3)
|
|
731
798
|
The positions to interpolate.
|
|
732
799
|
alpha : float
|
|
733
800
|
Regularization parameter. Defaults to 1e-5.
|
|
734
801
|
|
|
735
802
|
Returns
|
|
736
803
|
-------
|
|
737
|
-
interpolation :
|
|
804
|
+
interpolation : torch.Tensor of float, shape(len(pos_from), len(pos_to))
|
|
738
805
|
The interpolation matrix that maps good signals to the location
|
|
739
806
|
of bad signals.
|
|
740
807
|
|
|
@@ -822,7 +889,12 @@ def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
|
|
|
822
889
|
return interpolation
|
|
823
890
|
|
|
824
891
|
|
|
825
|
-
def _rotate_signals(
|
|
892
|
+
def _rotate_signals(
|
|
893
|
+
X: torch.Tensor,
|
|
894
|
+
rotations: list[torch.Tensor],
|
|
895
|
+
sensors_positions_matrix: torch.Tensor,
|
|
896
|
+
spherical: bool = True,
|
|
897
|
+
) -> torch.Tensor:
|
|
826
898
|
sensors_positions_matrix = sensors_positions_matrix.to(X.device)
|
|
827
899
|
rot_sensors_matrices = [
|
|
828
900
|
rotation.matmul(sensors_positions_matrix) for rotation in rotations
|
|
@@ -853,22 +925,29 @@ def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
|
|
|
853
925
|
return transformed_X
|
|
854
926
|
|
|
855
927
|
|
|
856
|
-
def _make_rotation_matrix(
|
|
928
|
+
def _make_rotation_matrix(
|
|
929
|
+
axis: Literal["x", "y", "z"],
|
|
930
|
+
angle: float | int | np.ndarray | list | torch.Tensor,
|
|
931
|
+
degrees: bool = True,
|
|
932
|
+
) -> torch.Tensor:
|
|
857
933
|
assert axis in ["x", "y", "z"], "axis should be either x, y or z."
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
934
|
+
if isinstance(angle, torch.Tensor):
|
|
935
|
+
_angle = angle
|
|
936
|
+
elif isinstance(angle, (float, int, np.ndarray, list)):
|
|
937
|
+
_angle = torch.as_tensor(angle)
|
|
938
|
+
else:
|
|
939
|
+
raise ValueError(f"Invalid angle type: {type(angle)}")
|
|
861
940
|
|
|
862
941
|
if degrees:
|
|
863
|
-
|
|
942
|
+
_angle = _angle * np.pi / 180
|
|
864
943
|
|
|
865
|
-
device =
|
|
944
|
+
device = _angle.device
|
|
866
945
|
zero = torch.zeros(1, device=device)
|
|
867
946
|
rot = torch.stack(
|
|
868
947
|
[
|
|
869
948
|
torch.as_tensor([1, 0, 0], device=device),
|
|
870
|
-
torch.hstack([zero, torch.cos(
|
|
871
|
-
torch.hstack([zero, torch.sin(
|
|
949
|
+
torch.hstack([zero, torch.cos(_angle), -torch.sin(_angle)]),
|
|
950
|
+
torch.hstack([zero, torch.sin(_angle), torch.cos(_angle)]),
|
|
872
951
|
]
|
|
873
952
|
)
|
|
874
953
|
if axis == "x":
|
|
@@ -881,7 +960,14 @@ def _make_rotation_matrix(axis, angle, degrees=True):
|
|
|
881
960
|
return rot[:, [1, 2, 0]]
|
|
882
961
|
|
|
883
962
|
|
|
884
|
-
def sensors_rotation(
|
|
963
|
+
def sensors_rotation(
|
|
964
|
+
X: torch.Tensor,
|
|
965
|
+
y: torch.Tensor,
|
|
966
|
+
sensors_positions_matrix: torch.Tensor,
|
|
967
|
+
axis: Literal["x", "y", "z"],
|
|
968
|
+
angles: npt.ArrayLike,
|
|
969
|
+
spherical_splines: bool,
|
|
970
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
885
971
|
"""Interpolates EEG signals over sensors rotated around the desired axis
|
|
886
972
|
with the desired angle.
|
|
887
973
|
|
|
@@ -893,7 +979,7 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_spl
|
|
|
893
979
|
EEG input example or batch.
|
|
894
980
|
y : torch.Tensor
|
|
895
981
|
EEG labels for the example or batch.
|
|
896
|
-
sensors_positions_matrix :
|
|
982
|
+
sensors_positions_matrix : torch.Tensor
|
|
897
983
|
Matrix giving the positions of each sensor in a 3D cartesian coordinate
|
|
898
984
|
system. Should have shape (3, n_channels), where n_channels is the
|
|
899
985
|
number of channels. Standard 10-20 positions can be obtained from
|
|
@@ -924,7 +1010,9 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_spl
|
|
|
924
1010
|
return rotated_X, y
|
|
925
1011
|
|
|
926
1012
|
|
|
927
|
-
def mixup(
|
|
1013
|
+
def mixup(
|
|
1014
|
+
X: torch.Tensor, y: torch.Tensor, lam: torch.Tensor, idx_perm: torch.Tensor
|
|
1015
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
928
1016
|
"""Mixes two channels of EEG data.
|
|
929
1017
|
|
|
930
1018
|
See [1]_ for details.
|
|
@@ -973,8 +1061,13 @@ def mixup(X, y, lam, idx_perm):
|
|
|
973
1061
|
|
|
974
1062
|
|
|
975
1063
|
def segmentation_reconstruction(
|
|
976
|
-
X
|
|
977
|
-
|
|
1064
|
+
X: torch.Tensor,
|
|
1065
|
+
y: torch.Tensor,
|
|
1066
|
+
n_segments: int,
|
|
1067
|
+
data_classes: list[tuple[int, torch.Tensor]],
|
|
1068
|
+
rand_indices: npt.ArrayLike,
|
|
1069
|
+
idx_shuffle: npt.ArrayLike,
|
|
1070
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
978
1071
|
"""Segment and reconstruct EEG data from [1]_.
|
|
979
1072
|
|
|
980
1073
|
See [1]_ for details.
|
|
@@ -987,6 +1080,8 @@ def segmentation_reconstruction(
|
|
|
987
1080
|
EEG labels for the example or batch.
|
|
988
1081
|
n_segments : int
|
|
989
1082
|
Number of segments to use in the batch.
|
|
1083
|
+
data_classes: list[tuple[int, torch.Tensor]]
|
|
1084
|
+
List of tuples. Each tuple contains the class index and the corresponding EEG data.
|
|
990
1085
|
rand_indices: array-like
|
|
991
1086
|
Array of indices that indicates which trial to use in each segment.
|
|
992
1087
|
idx_shuffle: array-like
|
|
@@ -1005,8 +1100,8 @@ def segmentation_reconstruction(
|
|
|
1005
1100
|
"""
|
|
1006
1101
|
|
|
1007
1102
|
# Initialize lists to store augmented data and corresponding labels
|
|
1008
|
-
aug_data = []
|
|
1009
|
-
aug_label = []
|
|
1103
|
+
aug_data: list[torch.Tensor] = []
|
|
1104
|
+
aug_label: list[torch.Tensor] = []
|
|
1010
1105
|
|
|
1011
1106
|
# Iterate through each class to separate and augment data
|
|
1012
1107
|
for class_index, X_class in data_classes:
|
|
@@ -1030,20 +1125,26 @@ def segmentation_reconstruction(
|
|
|
1030
1125
|
aug_data.append(X_aug)
|
|
1031
1126
|
aug_label.append(torch.full((n_trials,), class_index))
|
|
1032
1127
|
# Concatenate the augmented data and labels
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1128
|
+
concat_aug_data = torch.cat(aug_data, dim=0)
|
|
1129
|
+
concat_aug_data = concat_aug_data.to(dtype=X.dtype, device=X.device)
|
|
1130
|
+
concat_aug_data = concat_aug_data[idx_shuffle]
|
|
1036
1131
|
|
|
1037
1132
|
if y is not None:
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
return
|
|
1133
|
+
concat_label = torch.cat(aug_label, dim=0)
|
|
1134
|
+
concat_label = concat_label.to(dtype=y.dtype, device=y.device)
|
|
1135
|
+
concat_label = concat_label[idx_shuffle]
|
|
1136
|
+
return concat_aug_data, concat_label
|
|
1042
1137
|
|
|
1043
|
-
return
|
|
1138
|
+
return concat_aug_data, None
|
|
1044
1139
|
|
|
1045
1140
|
|
|
1046
|
-
def mask_encoding(
|
|
1141
|
+
def mask_encoding(
|
|
1142
|
+
X: torch.Tensor,
|
|
1143
|
+
y: torch.Tensor,
|
|
1144
|
+
time_start: torch.Tensor,
|
|
1145
|
+
segment_length: int,
|
|
1146
|
+
n_segments: int,
|
|
1147
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1047
1148
|
"""Mark encoding from Ding et al. (2024) from [ding2024]_.
|
|
1048
1149
|
|
|
1049
1150
|
Replaces a contiguous part (or parts) of all channels by zeros
|
|
@@ -1094,3 +1195,103 @@ def mask_encoding(X, y, time_start, segment_length, n_segments):
|
|
|
1094
1195
|
X[mask] = 0
|
|
1095
1196
|
|
|
1096
1197
|
return X, y # Return the masked tensor and labels
|
|
1198
|
+
|
|
1199
|
+
|
|
1200
|
+
def channels_rereference(
    X: torch.Tensor,
    y: torch.Tensor,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly re-reference channels in EEG data matrix.

    Part of the augmentations proposed in [1]_

    For every example of the batch, one channel index ``c`` is drawn uniformly
    at random and its signal is subtracted from all channels. The slot of
    channel ``c`` (which would otherwise become identically zero) is filled
    with ``-X[c]``, presumably the former reference expressed against the new
    one. The input tensor is not modified in place.

    Parameters
    ----------
    X : torch.Tensor
        EEG input batch of shape ``(batch_size, n_channels, n_times)``.
    y : torch.Tensor
        EEG labels for the example or batch. Returned unchanged.
    random_state: int | numpy.random.RandomState | None, optional
        Seed or numpy ``RandomState`` used (through ``check_random_state``) to
        draw the new reference channel of each example. Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P.. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. Proceedings
       of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
       Learning Research 136:238-253

    """

    rng = check_random_state(random_state)
    batch_size, n_channels, _ = X.shape

    # One new reference channel per example. Build the index tensors on X's
    # device so the gather and scatter below stay on a single device.
    ch = torch.as_tensor(
        rng.randint(0, n_channels, size=batch_size),
        dtype=torch.long,
        device=X.device,
    )
    batch_idx = torch.arange(batch_size, device=X.device)

    X_ch = X[batch_idx, ch, :]  # (batch_size, n_times)
    # Out-of-place subtraction: the caller's tensor is left untouched.
    X = X - X_ch.unsqueeze(1)
    X[batch_idx, ch, :] = -X_ch

    return X, y
|
|
1245
|
+
|
|
1246
|
+
|
|
1247
|
+
def amplitude_scale(
|
|
1248
|
+
X: torch.Tensor,
|
|
1249
|
+
y: torch.Tensor,
|
|
1250
|
+
scale: tuple,
|
|
1251
|
+
random_state: int | np.random.RandomState | None = None,
|
|
1252
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1253
|
+
"""Rescale amplitude of each channel based on a random sampled scaling value.
|
|
1254
|
+
|
|
1255
|
+
Part of the augmentations proposed in [1]_
|
|
1256
|
+
|
|
1257
|
+
Parameters
|
|
1258
|
+
----------
|
|
1259
|
+
X : torch.Tensor
|
|
1260
|
+
EEG input example or batch.
|
|
1261
|
+
y : torch.Tensor
|
|
1262
|
+
EEG labels for the example or batch.
|
|
1263
|
+
scale : tuple of floats
|
|
1264
|
+
Interval from which ypu sample the scaling value
|
|
1265
|
+
random_state: int | numpy.random.Generator, optional
|
|
1266
|
+
Seed to be used to instantiate numpy random number generator instance.
|
|
1267
|
+
Defaults to None.
|
|
1268
|
+
|
|
1269
|
+
Returns
|
|
1270
|
+
-------
|
|
1271
|
+
torch.Tensor
|
|
1272
|
+
Transformed inputs.
|
|
1273
|
+
torch.Tensor
|
|
1274
|
+
Transformed labels.
|
|
1275
|
+
|
|
1276
|
+
References
|
|
1277
|
+
----------
|
|
1278
|
+
.. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P.. (2020). Contrastive
|
|
1279
|
+
Representation Learning for Electroencephalogram Classification. Proceedings
|
|
1280
|
+
of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
|
|
1281
|
+
Learning Research 136:238-253
|
|
1282
|
+
|
|
1283
|
+
"""
|
|
1284
|
+
|
|
1285
|
+
rng = torch.Generator()
|
|
1286
|
+
rng.manual_seed(random_state)
|
|
1287
|
+
batch_size, n_channels, _ = X.shape
|
|
1288
|
+
|
|
1289
|
+
# Parameter for scaling amplitude / channel / trial
|
|
1290
|
+
l, h = scale
|
|
1291
|
+
s = l + (h - l) * torch.rand(
|
|
1292
|
+
batch_size, n_channels, 1, generator=rng, device=X.device, dtype=X.dtype
|
|
1293
|
+
)
|
|
1294
|
+
|
|
1295
|
+
X = s * X
|
|
1296
|
+
|
|
1297
|
+
return X, y
|