braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,96 @@
+ # Authors: Lukas Gemein <l.gemein@gmail.com>
+ #          Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import logging
+
+ import mne
+ import numpy as np
+ import pandas as pd
+ from numpy.typing import ArrayLike, NDArray
+
+ from .base import BaseConcatDataset, RawDataset
+
+ log = logging.getLogger(__name__)
+
+
+ def create_from_X_y(
+     X: NDArray,
+     y: ArrayLike,
+     drop_last_window: bool,
+     sfreq: float,
+     ch_names: ArrayLike | None = None,
+     window_size_samples: int | None = None,
+     window_stride_samples: int | None = None,
+ ) -> BaseConcatDataset:
+     """Create a BaseConcatDataset of WindowsDatasets from X and y.
+
+     To be used for decoding with skorch and braindecode, where X is a list
+     of pre-cut trials and y are the corresponding targets.
+
+     Parameters
+     ----------
+     X : array-like
+         List of pre-cut trials of shape n_trials x n_channels x n_times.
+     y : array-like
+         Targets corresponding to the trials.
+     drop_last_window : bool
+         Whether to drop the last window, or keep it as a final overlapping
+         window, when the windows do not evenly divide the continuous signal.
+     sfreq : float
+         Sampling frequency of the signals.
+     ch_names : array-like | None
+         Names of the channels. If None, channels are named by their indices.
+     window_size_samples : int | None
+         Window size in samples.
+     window_stride_samples : int | None
+         Stride between windows in samples.
+
+     Returns
+     -------
+     windows_datasets : BaseConcatDataset
+         X and y transformed to a dataset format that is compatible with
+         skorch and braindecode.
+     """
+     # Prevent circular import
+     from ..preprocessing.windowers import (
+         create_fixed_length_windows,
+     )
+
+     n_samples_per_x = []
+     base_datasets = []
+     if ch_names is None:
+         ch_names = [str(i) for i in range(X.shape[1])]
+         log.info(f"No channel names given, set to 0-{X.shape[1] - 1}.")
+
+     for x, target in zip(X, y):
+         n_samples_per_x.append(x.shape[1])
+         info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
+         raw = mne.io.RawArray(x, info)
+         base_dataset = RawDataset(
+             raw, pd.Series({"target": target}), target_name="target"
+         )
+         base_datasets.append(base_dataset)
+     base_datasets = BaseConcatDataset(base_datasets)
+
+     if window_size_samples is None and window_stride_samples is None:
+         if len(np.unique(n_samples_per_x)) != 1:
+             raise ValueError(
+                 "if 'window_size_samples' and "
+                 "'window_stride_samples' are None, "
+                 "all trials must have the same length"
+             )
+         window_size_samples = n_samples_per_x[0]
+         window_stride_samples = n_samples_per_x[0]
+     windows_datasets = create_fixed_length_windows(
+         base_datasets,
+         start_offset_samples=0,
+         stop_offset_samples=None,
+         window_size_samples=window_size_samples,
+         window_stride_samples=window_stride_samples,
+         drop_last_window=drop_last_window,
+     )
+     return windows_datasets
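
For reference, a minimal usage sketch of create_from_X_y. The toy array shapes, labels, and channel names below are illustrative, not taken from the package:

    import numpy as np
    from braindecode.datasets import create_from_X_y

    # Hypothetical toy data: 5 trials, 3 channels, 200 time samples each.
    X = np.random.randn(5, 3, 200)
    y = np.array([0, 1, 0, 1, 0])
    # With window_size_samples/window_stride_samples left as None, each
    # equal-length trial becomes exactly one window (size = stride = 200).
    windows_ds = create_from_X_y(
        X, y, drop_last_window=False, sfreq=100.0, ch_names=["C3", "Cz", "C4"]
    )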
@@ -0,0 +1,62 @@
+ """Utilities for data manipulation."""
+
+ from .channel_utils import (
+     division_channels_idx,
+     match_hemisphere_chans,
+ )
+ from .serialization import (
+     _check_save_dir_empty,
+     load_concat_dataset,
+     save_concat_dataset,
+ )
+ from .util import infer_signal_properties
+
+
+ def __getattr__(name):
+     # ideas from https://stackoverflow.com/a/57110249/1469195
+     import importlib
+     from warnings import warn
+
+     if name == "create_from_X_y":
+         warn(
+             "create_from_X_y has been moved to datasets, please use "
+             "from braindecode.datasets import create_from_X_y"
+         )
+         xy = importlib.import_module("..datasets.xy", __package__)
+         return xy.create_from_X_y
+     if name in ["create_from_mne_raw", "create_from_mne_epochs"]:
+         warn(
+             f"{name} has been moved to datasets, please use "
+             f"from braindecode.datasets import {name}"
+         )
+         mne = importlib.import_module("..datasets.mne", __package__)
+         return mne.__dict__[name]
+     if name in [
+         "scale",
+         "exponential_moving_demean",
+         "exponential_moving_standardize",
+         "filterbank",
+         "preprocess",
+         "Preprocessor",
+     ]:
+         warn(
+             f"{name} has been moved to preprocessing, please use "
+             f"from braindecode.preprocessing import {name}"
+         )
+         preprocess = importlib.import_module(
+             "..preprocessing.preprocess", __package__
+         )
+         return preprocess.__dict__[name]
+     if name in ["create_windows_from_events", "create_fixed_length_windows"]:
+         warn(
+             f"{name} has been moved to preprocessing, please use "
+             f"from braindecode.preprocessing import {name}"
+         )
+         windowers = importlib.import_module("..preprocessing.windowers", __package__)
+         return windowers.__dict__[name]
+
+     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+ __all__ = [
+     "load_concat_dataset",
+     "save_concat_dataset",
+     "_check_save_dir_empty",
+     "match_hemisphere_chans",
+     "division_channels_idx",
+     "infer_signal_properties",
+ ]
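
The module-level __getattr__ above is the PEP 562 lazy-redirection pattern: moved names keep resolving from the old location while warning the caller. A minimal sketch of what that looks like from user code; capturing the warning is only for illustration, and the message text is taken from the module above:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # The old location still resolves via braindecode.datautil.__getattr__ ...
        from braindecode.datautil import create_from_X_y
    # ... but emits "create_from_X_y has been moved to datasets, ...".
    assert any("moved to datasets" in str(w.message) for w in caught)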
@@ -0,0 +1,114 @@
+ """
+ Utilities for EEG channel manipulation and selection.
+
+ This module provides functions for dividing and matching EEG channels,
+ particularly for hemisphere-aware processing.
+ """
+
+ import re
+
+
+ def match_hemisphere_chans(left_chs, right_chs):
+     """
+     Match channels of the left and right hemispheres based on their names.
+
+     This function pairs channels from the left and right hemispheres by matching
+     their numeric identifiers. For a left channel with number N, it finds the
+     corresponding right channel with number N+1.
+
+     Parameters
+     ----------
+     left_chs : list of str
+         A list of channel names from the left hemisphere.
+     right_chs : list of str
+         A list of channel names from the right hemisphere.
+
+     Returns
+     -------
+     list of tuples
+         List of tuples with matched channel names from the left and right
+         hemispheres. Each tuple contains (left_channel, right_channel).
+
+     Raises
+     ------
+     ValueError
+         If the left and right channel lists do not match in length.
+     ValueError
+         If a channel name does not contain a number.
+     ValueError
+         If no matching right hemisphere channel is found for a left channel.
+
+     Examples
+     --------
+     >>> left = ['C3', 'F3']
+     >>> right = ['C4', 'F4']
+     >>> match_hemisphere_chans(left, right)
+     [('C3', 'C4'), ('F3', 'F4')]
+     """
+     if len(left_chs) != len(right_chs):
+         raise ValueError("Left and right channel lists do not match in length.")
+     right_chs = list(right_chs)
+     regexp = r"\d+"
+     out = []
+     for left in left_chs:
+         match = re.search(regexp, left)
+         if match is None:
+             raise ValueError(f"Channel '{left}' does not contain a number.")
+         chan_idx = 1 + int(match.group())
+         target_r = re.sub(regexp, str(chan_idx), left)
+         for right in right_chs:
+             if right == target_r:
+                 out.append((left, right))
+                 right_chs.remove(right)
+                 break
+         else:
+             raise ValueError(
+                 f"Found no right hemisphere matching channel for '{left}'."
+             )
+     return out
+
+
+ def division_channels_idx(ch_names):
+     """
+     Divide EEG channel names into left, right, and middle based on numbering.
+
+     This function categorizes channels by their numeric suffix:
+
+     - Odd-numbered channels → left hemisphere
+     - Even-numbered channels → right hemisphere
+     - Channels without numbers → middle/midline
+
+     Parameters
+     ----------
+     ch_names : list of str
+         A list of EEG channel names to be divided based on their numbering.
+
+     Returns
+     -------
+     tuple of lists
+         Three lists containing the channel names:
+
+         - left: Odd-numbered channels (e.g., C3, F3, P3)
+         - right: Even-numbered channels (e.g., C4, F4, P4)
+         - middle: Channels without numbers (e.g., Cz, Fz, Pz)
+
+     Notes
+     -----
+     The function identifies channel numbers by searching for numeric characters
+     in the channel names. Standard 10-20 system EEG channel naming conventions
+     use odd numbers for the left hemisphere and even numbers for the right
+     hemisphere.
+
+     Examples
+     --------
+     >>> channels = ['FP1', 'FP2', 'O1', 'O2', 'FZ']
+     >>> division_channels_idx(channels)
+     (['FP1', 'O1'], ['FP2', 'O2'], ['FZ'])
+     """
+     left, right, middle = [], [], []
+     for ch in ch_names:
+         number = re.search(r"\d+", ch)
+         if number is not None:
+             (left if int(number.group()) % 2 else right).append(ch)
+         else:
+             middle.append(ch)
+
+     return left, right, middle
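
The two helpers compose naturally: divide a montage into hemispheres, then pair the hemispheres channel by channel. A small sketch using 10-20 names consistent with the docstrings above:

    from braindecode.datautil import division_channels_idx, match_hemisphere_chans

    chans = ["FP1", "FP2", "C3", "C4", "CZ"]
    left, right, middle = division_channels_idx(chans)
    # left == ["FP1", "C3"], right == ["FP2", "C4"], middle == ["CZ"]
    pairs = match_hemisphere_chans(left, right)
    # pairs == [("FP1", "FP2"), ("C3", "C4")]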
@@ -0,0 +1,180 @@
+ """
+ Format converters for Hugging Face Hub integration.
+
+ This module provides Zarr format converters to transform EEG datasets for
+ efficient storage and fast random access during training on the Hugging Face
+ Hub. It exposes a standalone functional API that delegates to the
+ HubDatasetMixin methods for the actual implementations.
+ """
+
+ # Authors: Kuntal Kokate
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import shutil
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ # Import registry for dynamic class lookup
+ from ..datasets.registry import get_dataset_class
+
+ # Import dataset classes for type checking only
+ if TYPE_CHECKING:
+     from ..datasets.base import BaseConcatDataset
+
+
+ # =============================================================================
+ # Zarr Format Converters
+ # =============================================================================
+
+
+ def convert_to_zarr(
+     dataset: BaseConcatDataset,
+     output_path: str | Path,
+     compression: str | None = "blosc",
+     compression_level: int = 5,
+     overwrite: bool = False,
+ ) -> Path:
+     """Convert a BaseConcatDataset to Zarr format.
+
+     Zarr provides cloud-native chunked storage, optimized for random access
+     during training. This is the format used for Hugging Face Hub uploads,
+     based on comprehensive benchmarking showing:
+
+     - Fastest random access: 0.010 ms (critical for PyTorch DataLoader)
+     - Fast save/load: 0.46 s / 0.12 s
+     - Good compression: ~23% size reduction with blosc
+
+     Parameters
+     ----------
+     dataset : BaseConcatDataset
+         The dataset to convert.
+     output_path : str | Path
+         Path where the Zarr directory will be created.
+     compression : str | None, default="blosc"
+         Compression algorithm. Options: "blosc" (recommended), "zstd",
+         "gzip", None. blosc uses the zstd codec by default, providing the
+         best balance of speed and compression.
+     compression_level : int, default=5
+         Compression level (0-9). Level 5 provides the optimal balance based
+         on benchmarks.
+     overwrite : bool, default=False
+         Whether to overwrite an existing directory.
+
+     Returns
+     -------
+     Path
+         Path to the created Zarr directory.
+
+     Notes
+     -----
+     The chunking strategy is optimized for random access:
+
+     - Windowed data: each window is a separate chunk (1, n_channels, n_times)
+     - Raw data: chunks of (n_channels, 10000) samples
+
+     Examples
+     --------
+     >>> dataset = NMT(path=path, preload=True)
+     >>> # Use default settings (optimal from benchmarks)
+     >>> zarr_path = convert_to_zarr(dataset, "dataset.zarr")
+     >>>
+     >>> # Or customize compression
+     >>> zarr_path = convert_to_zarr(
+     ...     dataset, "dataset.zarr",
+     ...     compression="blosc",
+     ...     compression_level=5,
+     ... )
+     """
+     output_path = Path(output_path)
+
+     if output_path.exists():
+         if not overwrite:
+             raise FileExistsError(
+                 f"{output_path} already exists. Set overwrite=True to replace it."
+             )
+         # Remove existing directory if overwrite is True
+         shutil.rmtree(output_path)
+
+     # Delegate to HubDatasetMixin method
+     dataset._convert_to_zarr_inline(output_path, compression, compression_level)
+
+     return output_path
+
+
+ def load_from_zarr(
+     input_path: str | Path,
+     preload: bool = True,
+     ids_to_load: list[int] | None = None,
+ ) -> BaseConcatDataset:
+     """Load a BaseConcatDataset from Zarr format.
+
+     Zarr is the format used for braindecode Hub datasets, providing the
+     fastest random access performance for training with PyTorch.
+
+     Parameters
+     ----------
+     input_path : str | Path
+         Path to the Zarr directory.
+     preload : bool, default=True
+         Whether to load data into memory. If False, uses lazy loading
+         (data is loaded on demand during training).
+     ids_to_load : list of int | None
+         Specific recording IDs to load. If None, loads all.
+
+     Returns
+     -------
+     BaseConcatDataset
+         The loaded dataset.
+
+     Examples
+     --------
+     >>> # Load from a local zarr directory
+     >>> dataset = load_from_zarr("dataset.zarr", preload=True)
+     >>>
+     >>> # Load from the Hugging Face Hub (handled automatically)
+     >>> from braindecode.datasets import BaseConcatDataset
+     >>> dataset = BaseConcatDataset.from_pretrained("username/dataset-name")
+     """
+     # Look up the concrete class dynamically via the registry
+     BaseConcatDataset = get_dataset_class("BaseConcatDataset")
+
+     # Load the full dataset using the HubDatasetMixin static method
+     dataset = BaseConcatDataset._load_from_zarr_inline(Path(input_path), preload)
+
+     # Filter to specific IDs if requested
+     if ids_to_load is not None:
+         # Keep only the requested recordings
+         filtered_datasets = [dataset.datasets[i] for i in ids_to_load]
+         dataset = BaseConcatDataset(filtered_datasets)
+
+     return dataset
+
+
+ # =============================================================================
+ # Utility Functions
+ # =============================================================================
+
+
+ def get_format_info(dataset: BaseConcatDataset) -> dict:
+     """Get dataset information for Hub metadata.
+
+     Validates that all datasets in the concatenation have uniform properties
+     (channels, sampling frequency) and raises an error if not.
+
+     Parameters
+     ----------
+     dataset : BaseConcatDataset
+         The dataset to analyze.
+
+     Returns
+     -------
+     dict
+         Dictionary with dataset statistics and format info.
+
+     Raises
+     ------
+     ValueError
+         If datasets have inconsistent channels or sampling frequencies.
+     """
+     # Delegate to HubDatasetMixin method
+     return dataset._get_format_info_inline()
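
Taken together, the two converters give a save/load round trip. A hedged sketch, assuming `windows_ds` is an existing BaseConcatDataset (e.g. the one built with create_from_X_y above):

    from braindecode.datautil.hub_formats import convert_to_zarr, load_from_zarr

    # Write with the benchmarked defaults (blosc, level 5); overwrite any
    # previous export at this path.
    zarr_path = convert_to_zarr(windows_ds, "windows_ds.zarr", overwrite=True)

    # Full reload into memory, or a lazy subset of recordings 0 and 2.
    reloaded = load_from_zarr(zarr_path, preload=True)
    subset = load_from_zarr(zarr_path, preload=False, ids_to_load=[0, 2])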