braindecode 1.3.0.dev177628147__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,6 @@
 # Authors: Cédric Rommel <cedric.rommel@inria.fr>
 # Alexandre Gramfort <alexandre.gramfort@inria.fr>
 # Gustavo Rodrigues <gustavenrique01@gmail.com>
-# Bruna Lopes <brunajaflopes@gmail.com>
 #
 # License: BSD (3-clause)
 
@@ -1195,103 +1194,3 @@ def mask_encoding(
     X[mask] = 0
 
     return X, y  # Return the masked tensor and labels
-
-
-def channels_rereference(
-    X: torch.Tensor,
-    y: torch.Tensor,
-    random_state: int | np.random.RandomState | None = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """Randomly re-reference channels in EEG data matrix.
-
-    Part of the augmentations proposed in [1]_
-
-    Parameters
-    ----------
-    X : torch.Tensor
-        EEG input example or batch.
-    y : torch.Tensor
-        EEG labels for the example or batch.
-    random_state: int | numpy.random.Generator, optional
-        Seed to be used to instantiate numpy random number generator instance.
-        Defaults to None.
-
-    Returns
-    -------
-    torch.Tensor
-        Transformed inputs.
-    torch.Tensor
-        Transformed labels.
-
-    References
-    ----------
-    .. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P. (2020). Contrastive
-       Representation Learning for Electroencephalogram Classification. Proceedings
-       of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
-       Learning Research 136:238-253
-
-    """
-
-    rng = check_random_state(random_state)
-    batch_size, n_channels, _ = X.shape
-
-    ch = rng.randint(0, n_channels, size=batch_size)
-
-    X_ch = X[torch.arange(batch_size), ch, :]
-    X = X - X_ch.unsqueeze(1)
-    X[torch.arange(batch_size), ch, :] = -X_ch
-
-    return X, y
-
-
-def amplitude_scale(
-    X: torch.Tensor,
-    y: torch.Tensor,
-    scale: tuple,
-    random_state: int | np.random.RandomState | None = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """Rescale amplitude of each channel based on a randomly sampled scaling value.
-
-    Part of the augmentations proposed in [1]_
-
-    Parameters
-    ----------
-    X : torch.Tensor
-        EEG input example or batch.
-    y : torch.Tensor
-        EEG labels for the example or batch.
-    scale : tuple of floats
-        Interval from which you sample the scaling value.
-    random_state: int | numpy.random.Generator, optional
-        Seed to be used to instantiate numpy random number generator instance.
-        Defaults to None.
-
-    Returns
-    -------
-    torch.Tensor
-        Transformed inputs.
-    torch.Tensor
-        Transformed labels.
-
-    References
-    ----------
-    .. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P. (2020). Contrastive
-       Representation Learning for Electroencephalogram Classification. Proceedings
-       of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
-       Learning Research 136:238-253
-
-    """
-
-    rng = torch.Generator()
-    rng.manual_seed(random_state)
-    batch_size, n_channels, _ = X.shape
-
-    # Parameter for scaling amplitude / channel / trial
-    l, h = scale
-    s = l + (h - l) * torch.rand(
-        batch_size, n_channels, 1, generator=rng, device=X.device, dtype=X.dtype
-    )
-
-    X = s * X
-
-    return X, y
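For context on the hunk above: both deleted helpers operate on a (batch, n_channels, n_times) tensor; channels_rereference subtracts a randomly chosen channel from every channel of each trial, and amplitude_scale multiplies each channel of each trial by a factor drawn uniformly from the scale interval. A minimal sketch of how they were called against the previous release; the module path braindecode.augmentation.functional, the tensor shapes, and the seed are assumptions for illustration only.

    import torch

    # Assumed import path for the *previous* release; both helpers are removed
    # in the new version, so this only works against the old wheel.
    from braindecode.augmentation.functional import amplitude_scale, channels_rereference

    # Illustrative batch: 8 trials, 22 EEG channels, 1000 time samples.
    X = torch.randn(8, 22, 1000)
    y = torch.zeros(8, dtype=torch.long)

    # Subtract one randomly picked channel from every channel of each trial.
    X_reref, y_reref = channels_rereference(X, y, random_state=42)

    # Scale every channel of every trial by a factor drawn from [0.5, 2).
    # The removed code seeds a torch.Generator with random_state directly,
    # so an integer seed is required here.
    X_scaled, y_scaled = amplitude_scale(X, y, scale=(0.5, 2), random_state=42)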
@@ -1,7 +1,6 @@
 # Authors: Cédric Rommel <cedric.rommel@inria.fr>
 # Alexandre Gramfort <alexandre.gramfort@inria.fr>
 # Gustavo Rodrigues <gustavenrique01@gmail.com>
-# Bruna Lopes <brunajaflopes@gmail.com>
 #
 # License: BSD (3-clause)
 
@@ -14,11 +13,9 @@ from mne.channels import make_standard_montage
 
 from .base import Transform
 from .functional import (
-    amplitude_scale,
     bandstop_filter,
     channels_dropout,
     channels_permute,
-    channels_rereference,
     channels_shuffle,
     frequency_shift,
     ft_surrogate,
@@ -1274,74 +1271,3 @@ class MaskEncoding(Transform):
             "segment_length": segment_length,
             "n_segments": self.n_segments,
         }
-
-
-class ChannelsReref(Transform):
-    """Randomly re-reference channels in EEG data matrix.
-
-    Part of the augmentations proposed in [1]_
-
-    Parameters
-    ----------
-    probability: float
-        Float setting the probability of applying the operation.
-    random_state: int | numpy.random.Generator, optional
-        Seed to be used to instantiate numpy random number generator instance.
-        Used to decide whether or not to transform given the probability
-        argument, to sample which channels to shuffle and to carry the shuffle.
-        Defaults to None.
-
-    References
-    ----------
-    .. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P. (2020). Contrastive
-       Representation Learning for Electroencephalogram Classification. Proceedings
-       of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
-       Learning Research 136:238-253. Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
-
-    """
-
-    operation = staticmethod(channels_rereference)  # type: ignore[assignment]
-
-    def __init__(self, probability, random_state=None):
-        super().__init__(probability=probability, random_state=random_state)
-
-    def get_augmentation_params(self, *batch):
-        """Return transform parameters"""
-        return {
-            "random_state": self.rng,
-        }
-
-
-class AmplitudeScale(Transform):
-    """Rescale amplitude based on a randomly sampled scaling value.
-
-    Part of the augmentations proposed in [1]_
-
-    Parameters
-    ----------
-    probability: float
-        Float setting the probability of applying the operation.
-    random_state: int | numpy.random.Generator, optional
-        Seed to be used to instantiate numpy random number generator instance.
-        Used to decide whether or not to transform given the probability
-        argument, to sample which channels to shuffle and to carry the shuffle.
-        Defaults to None.
-
-    References
-    ----------
-    .. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P. (2020). Contrastive
-       Representation Learning for Electroencephalogram Classification. Proceedings
-       of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
-       Learning Research 136:238-253. Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
-
-    """
-
-    operation = staticmethod(amplitude_scale)  # type: ignore[assignment]
-
-    def __init__(self, probability, interval=(0.5, 2), random_state=None):
-        super().__init__(probability=probability, random_state=random_state)
-        self.scale = interval
-
-    def get_augmentation_params(self, *batch):
-        """Return transform parameters"""
-        return {"random_state": self.rng, "scale": self.scale}
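The class wrappers removed above expose the same operations through braindecode's Transform interface, applying them to each batch with a given probability. A rough usage sketch against the previous release; the public import path, the batch contents, and the call convention on (X, y) pairs are assumptions, not taken from this diff.

    import torch

    # Assumed public import path in the previous release; both classes are
    # removed in the new version.
    from braindecode.augmentation import AmplitudeScale, ChannelsReref

    X = torch.randn(8, 22, 1000)  # batch, channels, time
    y = torch.zeros(8, dtype=torch.long)

    # With probability 0.5, re-reference each trial to a randomly chosen channel.
    reref = ChannelsReref(probability=0.5, random_state=87)
    X_aug, y_aug = reref(X, y)

    # AmplitudeScale wires get_augmentation_params() to the functional
    # amplitude_scale, forwarding self.rng and the (0.5, 2) interval.
    scaler = AmplitudeScale(probability=0.5, interval=(0.5, 2), random_state=87)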
@@ -19,7 +19,7 @@ import warnings
 from abc import abstractmethod
 from collections.abc import Callable
 from glob import glob
-from typing import Any, Generic, Iterable, no_type_check
+from typing import Generic, Iterable, no_type_check
 
 import mne.io
 import numpy as np
@@ -28,9 +28,6 @@ from mne.utils.docs import deprecated
 from torch.utils.data import ConcatDataset, Dataset
 from typing_extensions import TypeVar
 
-from .hub import HubDatasetMixin
-from .registry import register_dataset
-
 
 def _create_description(description) -> pd.Series:
     if description is not None:
@@ -100,7 +97,6 @@ class RecordDataset(Dataset[tuple[np.ndarray, int | str, tuple[int, int, int]]])
 T = TypeVar("T", bound=RecordDataset)
 
 
-@register_dataset
 class RawDataset(RecordDataset):
     """Returns samples from an mne.io.Raw object along with a target.
 
@@ -133,7 +129,6 @@ class RawDataset(RecordDataset):
 
         # save target name for load/save later
         self.target_name = self._target_name(target_name)
-        self.raw_preproc_kwargs: list[dict[str, Any]] = []
 
     def __getitem__(self, index):
         X = self.raw[:, index][0]
@@ -181,12 +176,10 @@ class RawDataset(RecordDataset):
     "If you want to type a Braindecode dataset (i.e. RawDataset|EEGWindowsDataset|WindowsDataset), "
     "use the RecordDataset class instead."
 )
-@register_dataset
 class BaseDataset(RawDataset):
     pass
 
 
-@register_dataset
 class EEGWindowsDataset(RecordDataset):
     """Returns windows from an mne.Raw object, its window indices, along with a target.
 
@@ -242,7 +235,6 @@ class EEGWindowsDataset(RecordDataset):
         ].to_numpy()
         if self.targets_from == "metadata":
             self.y = metadata.loc[:, "target"].to_list()
-        self.raw_preproc_kwargs: list[dict[str, Any]] = []
 
     def __getitem__(self, index: int):
         """Get a window and its target.
@@ -293,7 +285,6 @@ class EEGWindowsDataset(RecordDataset):
         return len(self.crop_inds)
 
 
-@register_dataset
 class WindowsDataset(RecordDataset):
     """Returns windows from an mne.Epochs object along with a target.
 
@@ -343,8 +334,6 @@ class WindowsDataset(RecordDataset):
         ].to_numpy()
         if self.targets_from == "metadata":
             self.y = metadata.loc[:, "target"].to_list()
-        self.raw_preproc_kwargs: list[dict[str, Any]] = []
-        self.window_preproc_kwargs: list[dict[str, Any]] = []
 
     def __getitem__(self, index: int):
         """Get a window and its target.
@@ -385,16 +374,12 @@ class WindowsDataset(RecordDataset):
         return len(self.windows.events)
 
 
-@register_dataset
-class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
+class BaseConcatDataset(ConcatDataset, Generic[T]):
     """A base class for concatenated datasets.
 
     Holds either mne.Raw or mne.Epoch in self.datasets and has
     a pandas DataFrame with additional description.
 
-    Includes Hugging Face Hub integration via HubDatasetMixin for
-    uploading and downloading datasets.
-
     Parameters
     ----------
     list_of_ds : list
@@ -809,7 +794,7 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
         kwargs = getattr(ds, kwargs_name)
         if kwargs is not None:
             with open(kwargs_file_path, "w") as f:
-                json.dump(kwargs, f, indent=2)
+                json.dump(kwargs, f)
 
     @staticmethod
     def _save_target_name(sub_dir, ds):
@@ -302,6 +302,7 @@ def _load_kwargs_json(kwargs_name, sub_dir):
     kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
    if os.path.exists(kwargs_file_path):
         kwargs = json.load(open(kwargs_file_path, "r"))
+        kwargs = [tuple(kwarg) for kwarg in kwargs]
         return kwargs
 
 
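The line added to _load_kwargs_json above restores tuples after JSON deserialization: JSON has no tuple type, so kwargs saved as tuples come back as lists and are converted back on load. A small self-contained illustration of that round trip; the kwarg content below is hypothetical.

    import json

    # Hypothetical preprocessor kwargs: a list of (name, params) tuples.
    kwargs = [("pick_channels", {"ch_names": ["Cz"]})]

    # json.dump/json.load turn the tuples into lists ...
    restored = json.loads(json.dumps(kwargs))        # [['pick_channels', {'ch_names': ['Cz']}]]

    # ... so the loader converts each entry back to a tuple.
    restored = [tuple(kwarg) for kwarg in restored]  # [('pick_channels', {'ch_names': ['Cz']})]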
@@ -28,7 +28,6 @@ from .fbmsnet import FBMSNet
 from .hybrid import HybridNet
 from .ifnet import IFNet
 from .labram import Labram
-from .luna import LUNA
 from .medformer import MEDFormer
 from .msvtnet import MSVTNet
 from .patchedtransformer import PBT
@@ -50,11 +49,7 @@ from .tcn import BDTCN, TCN
 from .tidnet import TIDNet
 from .tsinception import TSception
 from .usleep import USleep
-from .util import (
-    _init_models_dict,
-    extract_channel_locations_from_chs_info,
-    models_mandatory_parameters,
-)
+from .util import _init_models_dict, models_mandatory_parameters
 
 # Call this last in order to make sure the dataset list is populated with
 # the models imported in this file.
@@ -88,8 +83,6 @@ __all__ = [
     "HybridNet",
     "IFNet",
     "Labram",
-    "LUNA",
-    "extract_channel_locations_from_chs_info",
    "MEDFormer",
     "MSVTNet",
     "PBT",
@@ -41,5 +41,4 @@ IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860
 PBT,General,Classification,250,"n_chans, n_outputs, n_times",818948,"PBT(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Large Brain Model"
 SSTDPN,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",19502,"SSTDPN(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
 BENDR,General,"Classification,Embedding",250,"n_chans, n_times, n_outputs",157141049,"BENDR(n_chans=22, n_outputs=4, n_times=1000)","Large Brain Model,Convolution"
-LUNA,General,"Classification,Embedding",128,"n_chans, n_times, sfreq, chs_info",7100731,"LUNA(n_chans=22, n_times=512, sfreq=128)","Convolution,Channel,Large Brain Model"
 MEDFormer,General,Classification,250,"n_chans, n_outputs, n_times",5313924,"MEDFormer(n_chans=22, n_outputs=4, n_times=1000)","Large Brain Model,Convolution"
@@ -4,9 +4,7 @@
 # License: BSD (3-clause)
 import inspect
 from pathlib import Path
-from typing import Any, Dict, Optional, Sequence
 
-import numpy as np
 import pandas as pd
 
 import braindecode.models as models
@@ -101,7 +99,6 @@ models_mandatory_parameters = [
     ("PBT", ["n_chans", "n_outputs", "n_times"], None),
     ("SSTDPN", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
     ("BENDR", ["n_chans", "n_outputs", "n_times"], None),
-    ("LUNA", ["n_chans", "n_times", "n_outputs"], None),
     ("MEDFormer", ["n_chans", "n_outputs", "n_times"], None),
 ]
 
@@ -134,85 +131,4 @@ def get_summary_table(dir_name=None):
     return df
 
 
-def extract_channel_locations_from_chs_info(
-    chs_info: Optional[Sequence[Dict[str, Any]]],
-    num_channels: Optional[int] = None,
-) -> Optional[np.ndarray]:
-    """Extract 3D channel locations from MNE-style channel information.
-
-    This function provides a unified approach to extract 3D channel locations
-    from MNE channel information. It's compatible with models like SignalJEPA
-    and LUNA that need to work with channel spatial information.
-
-    Parameters
-    ----------
-    chs_info : list of dict or None
-        Channel information, typically from ``mne.Info.chs``. Each dict should
-        contain a 'loc' key with a 12-element array (MNE format) where indices 3:6
-        represent the 3D cartesian coordinates.
-    num_channels : int or None
-        If specified, only extract the first ``num_channels`` channel locations.
-        If None, extract all available channels.
-
-    Returns
-    -------
-    channel_locations : np.ndarray of shape (n_channels, 3) or None
-        Array of 3D channel locations in cartesian coordinates. Returns None if
-        no valid locations are found.
-
-    Notes
-    -----
-    - This function handles both 12-element MNE location format (using indices 3:6)
-      and 3-element location format (using directly).
-    - Invalid or missing locations cause extraction to stop at that point.
-    - Returns None if no valid locations can be extracted.
-    - This is a unified utility compatible with models like SignalJEPA and LUNA.
-
-    Examples
-    --------
-    >>> import mne
-    >>> from braindecode.models.util import extract_channel_locations_from_chs_info
-    >>> raw = mne.io.read_raw_edf("sample.edf")
-    >>> locs = extract_channel_locations_from_chs_info(raw.info['chs'], num_channels=22)
-    >>> print(locs.shape)
-    (22, 3)
-    """
-    if chs_info is None:
-        return None
-
-    locations = []
-    n_to_extract = num_channels if num_channels is not None else len(chs_info)
-
-    for i, ch_info in enumerate(chs_info[:n_to_extract]):
-        if not isinstance(ch_info, dict):
-            break
-
-        loc = ch_info.get("loc")
-        if loc is None:
-            break
-
-        try:
-            loc_array = np.asarray(loc, dtype=np.float32)
-
-            # MNE format: 12-element array with coordinates at indices 3:6
-            if loc_array.ndim == 1 and loc_array.size >= 6:
-                if loc_array.size == 12:
-                    # Standard MNE format
-                    coordinates = loc_array[3:6]
-                else:
-                    # Assume first 3 elements are coordinates
-                    coordinates = loc_array[:3]
-            else:
-                break
-
-            locations.append(coordinates)
-        except (ValueError, TypeError):
-            break
-
-    if len(locations) == 0:
-        return None
-
-    return np.stack(locations, axis=0)
-
-
 _summary_table = get_summary_table()
@@ -188,17 +188,12 @@ from .preprocess import (
     filterbank,
     preprocess,
 )
-from .util import _init_preprocessor_dict
 from .windowers import (
     create_fixed_length_windows,
     create_windows_from_events,
     create_windows_from_target_channels,
 )
 
-# Call this last in order to make sure the list is populated with
-# the preprocessors imported in this file.
-_init_preprocessor_dict()
-
 __all__ = [
     "exponential_moving_demean",
     "exponential_moving_standardize",