braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -10,24 +10,28 @@ Dataset classes.
 #
 # License: BSD (3-clause)

-import os
+from __future__ import annotations
+
 import json
+import os
 import shutil
-from typing import Iterable
 import warnings
+from collections.abc import Callable
 from glob import glob
+from typing import Iterable, no_type_check

+import mne.io
 import numpy as np
 import pandas as pd
-from torch.utils.data import Dataset, ConcatDataset
+from torch.utils.data import ConcatDataset, Dataset


-def _create_description(description):
+def _create_description(description) -> pd.Series:
     if description is not None:
-        if (not isinstance(description, pd.Series) and
-                not isinstance(description, dict)):
-            raise ValueError(f"'{description}' has to be either a "
-                             f"pandas.Series or a dict.")
+        if not isinstance(description, pd.Series) and not isinstance(description, dict):
+            raise ValueError(
+                f"'{description}' has to be either a pandas.Series or a dict."
+            )
         if isinstance(description, dict):
             description = pd.Series(description)
     return description
@@ -52,8 +56,14 @@ class BaseDataset(Dataset):
     transform : callable | None
         On-the-fly transform applied to the example before it is returned.
     """
-    def __init__(self, raw, description=None, target_name=None,
-                 transform=None):
+
+    def __init__(
+        self,
+        raw: mne.io.BaseRaw,
+        description: dict | pd.Series | None = None,
+        target_name: str | tuple[str, ...] | None = None,
+        transform: Callable | None = None,
+    ):
         self.raw = raw
         self._description = _create_description(description)
         self.transform = transform
@@ -82,14 +92,14 @@ class BaseDataset(Dataset):
     @transform.setter
     def transform(self, value):
         if value is not None and not callable(value):
-            raise ValueError('Transform needs to be a callable.')
+            raise ValueError("Transform needs to be a callable.")
         self._transform = value

     @property
-    def description(self):
+    def description(self) -> pd.Series:
         return self._description

-    def set_description(self, description, overwrite=False):
+    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
         """Update (add or overwrite) the dataset description.

         Parameters
@@ -104,8 +114,10 @@ class BaseDataset(Dataset):
         for key, value in description.items():
             # if the key is already in the existing description, drop it
             if self._description is not None and key in self._description:
-                assert overwrite, (f"'{key}' already in description. Please "
-                                   f"rename or set overwrite to True.")
+                assert overwrite, (
+                    f"'{key}' already in description. Please "
+                    f"rename or set overwrite to True."
+                )
                 self._description.pop(key)
         if self._description is None:
             self._description = description
@@ -114,7 +126,7 @@ class BaseDataset(Dataset):

     def _target_name(self, target_name):
         if target_name is not None and not isinstance(target_name, (str, tuple, list)):
-            raise ValueError('target_name has to be None, str, tuple or list')
+            raise ValueError("target_name has to be None, str, tuple or list")
         if target_name is None:
             return target_name
         else:
@@ -128,9 +140,12 @@ class BaseDataset(Dataset):
         # check if target name(s) can be read from description
         for name in target_name:
             if self.description is None or name not in self.description:
-                warnings.warn(f"'{name}' not in description. '__getitem__'"
-                              f"will fail unless an appropriate target is"
-                              f" added to description.", UserWarning)
+                warnings.warn(
+                    f"'{name}' not in description. '__getitem__'"
+                    f"will fail unless an appropriate target is"
+                    f" added to description.",
+                    UserWarning,
+                )
         # return a list of str if there are multiple targets and a str otherwise
         return target_name if len(target_name) > 1 else target_name[0]

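As a usage sketch of the typed constructor above (the file path and description values are hypothetical):

import mne

from braindecode.datasets import BaseDataset

# Any MNE raw object works; the path here is a placeholder.
raw = mne.io.read_raw_fif("subject_01-raw.fif", preload=False)

# target_name must match a key of the description, otherwise the
# UserWarning shown above is emitted and __getitem__ fails until an
# appropriate target is added.
ds = BaseDataset(
    raw,
    description={"subject": 1, "label": "left_hand"},
    target_name="label",
)
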
@@ -168,24 +183,31 @@ class EEGWindowsDataset(BaseDataset):
     as well as `targets`.
     """

-    def __init__(self, raw, metadata, description=None, transform=None, targets_from='metadata',
-                 last_target_only=True, ):
+    def __init__(
+        self,
+        raw: mne.io.BaseRaw | mne.BaseEpochs,
+        metadata: pd.DataFrame,
+        description: dict | pd.Series | None = None,
+        transform: Callable | None = None,
+        targets_from: str = "metadata",
+        last_target_only: bool = True,
+    ):
         self.raw = raw
         self.metadata = metadata
         self._description = _create_description(description)

         self.transform = transform
         self.last_target_only = last_target_only
-        if targets_from not in ('metadata', 'channels'):
-            raise ValueError('Wrong value for parameter `targets_from`.')
+        if targets_from not in ("metadata", "channels"):
+            raise ValueError("Wrong value for parameter `targets_from`.")
         self.targets_from = targets_from
         self.crop_inds = metadata.loc[
-            :, ['i_window_in_trial', 'i_start_in_trial',
-                'i_stop_in_trial']].to_numpy()
-        if self.targets_from == 'metadata':
-            self.y = metadata.loc[:, 'target'].to_list()
+            :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+        ].to_numpy()
+        if self.targets_from == "metadata":
+            self.y = metadata.loc[:, "target"].to_list()

-    def __getitem__(self, index):
+    def __getitem__(self, index: int):
         """Get a window and its target.

         Parameters
@@ -209,16 +231,16 @@ class EEGWindowsDataset(BaseDataset):

         i_window_in_trial, i_start, i_stop = crop_inds
         X = self.raw._getitem((slice(None), slice(i_start, i_stop)), return_times=False)
-        X = X.astype('float32')
+        X = X.astype("float32")
         # ensure we don't give the user the option
         # to accidentally modify the underlying array
         X = X.copy()
         if self.transform is not None:
             X = self.transform(X)
-        if self.targets_from == 'metadata':
+        if self.targets_from == "metadata":
             y = self.y[index]
         else:
-            misc_mask = np.array(self.raw.get_channel_types()) == 'misc'
+            misc_mask = np.array(self.raw.get_channel_types()) == "misc"
             if self.last_target_only:
                 y = X[misc_mask, -1]
             else:
@@ -240,14 +262,14 @@ class EEGWindowsDataset(BaseDataset):
     @transform.setter
     def transform(self, value):
         if value is not None and not callable(value):
-            raise ValueError('Transform needs to be a callable.')
+            raise ValueError("Transform needs to be a callable.")
         self._transform = value

     @property
-    def description(self):
+    def description(self) -> pd.Series:
         return self._description

-    def set_description(self, description, overwrite=False):
+    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
         """Update (add or overwrite) the dataset description.

         Parameters
@@ -262,8 +284,10 @@ class EEGWindowsDataset(BaseDataset):
         for key, value in description.items():
             # if they key is already in the existing description, drop it
             if key in self._description:
-                assert overwrite, (f"'{key}' already in description. Please "
-                                   f"rename or set overwrite to True.")
+                assert overwrite, (
+                    f"'{key}' already in description. Please "
+                    f"rename or set overwrite to True."
+                )
                 self._description.pop(key)
         self._description = pd.concat([self.description, description])

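A short sketch of the two target modes handled in `__getitem__` above (`windows_ds` is a hypothetical windowed dataset):

# targets_from="metadata" (default): y comes from the "target" column of
# the windowing metadata, one label per window.
X, y, crop_inds = windows_ds[0]

# targets_from="channels": y is instead read from the "misc"-typed channels
# of the window itself; with last_target_only=True only the final sample of
# those channels is returned, otherwise the full time-resolved target.
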
@@ -294,23 +318,30 @@ class WindowsDataset(BaseDataset):
         Defines whether targets will be extracted from mne.Epochs metadata or mne.Epochs `misc`
         channels (time series targets). It can be `metadata` (default) or `channels`.
     """
-    def __init__(self, windows, description=None, transform=None, targets_from='metadata',
-                 last_target_only=True):
+
+    def __init__(
+        self,
+        windows: mne.BaseEpochs,
+        description: dict | pd.Series | None = None,
+        transform: Callable | None = None,
+        targets_from: str = "metadata",
+        last_target_only: bool = True,
+    ):
         self.windows = windows
         self._description = _create_description(description)
         self.transform = transform
         self.last_target_only = last_target_only
-        if targets_from not in ('metadata', 'channels'):
-            raise ValueError('Wrong value for parameter `targets_from`.')
+        if targets_from not in ("metadata", "channels"):
+            raise ValueError("Wrong value for parameter `targets_from`.")
         self.targets_from = targets_from

         self.crop_inds = self.windows.metadata.loc[
-            :, ['i_window_in_trial', 'i_start_in_trial',
-                'i_stop_in_trial']].to_numpy()
-        if self.targets_from == 'metadata':
-            self.y = self.windows.metadata.loc[:, 'target'].to_list()
+            :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+        ].to_numpy()
+        if self.targets_from == "metadata":
+            self.y = self.windows.metadata.loc[:, "target"].to_list()

-    def __getitem__(self, index):
+    def __getitem__(self, index: int):
         """Get a window and its target.

         Parameters
@@ -327,13 +358,13 @@ class WindowsDataset(BaseDataset):
         np.ndarray
             Crop indices.
         """
-        X = self.windows.get_data(item=index)[0].astype('float32')
+        X = self.windows.get_data(item=index)[0].astype("float32")
         if self.transform is not None:
             X = self.transform(X)
-        if self.targets_from == 'metadata':
+        if self.targets_from == "metadata":
             y = self.y[index]
         else:
-            misc_mask = np.array(self.windows.get_channel_types()) == 'misc'
+            misc_mask = np.array(self.windows.get_channel_types()) == "misc"
             if self.last_target_only:
                 y = X[misc_mask, -1]
             else:
@@ -345,7 +376,7 @@ class WindowsDataset(BaseDataset):
         crop_inds = self.crop_inds[index].tolist()
         return X, y, crop_inds

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.windows.events)

     @property
@@ -355,14 +386,14 @@ class WindowsDataset(BaseDataset):
     @transform.setter
     def transform(self, value):
         if value is not None and not callable(value):
-            raise ValueError('Transform needs to be a callable.')
+            raise ValueError("Transform needs to be a callable.")
         self._transform = value

     @property
-    def description(self):
+    def description(self) -> pd.Series:
         return self._description

-    def set_description(self, description, overwrite=False):
+    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
         """Update (add or overwrite) the dataset description.

         Parameters
@@ -377,16 +408,19 @@ class WindowsDataset(BaseDataset):
         for key, value in description.items():
             # if they key is already in the existing description, drop it
             if key in self._description:
-                assert overwrite, (f"'{key}' already in description. Please "
-                                   f"rename or set overwrite to True.")
+                assert overwrite, (
+                    f"'{key}' already in description. Please "
+                    f"rename or set overwrite to True."
+                )
                 self._description.pop(key)
         self._description = pd.concat([self.description, description])


 class BaseConcatDataset(ConcatDataset):
-    """A base class for concatenated datasets. Holds either mne.Raw or
-    mne.Epoch in self.datasets and has a pandas DataFrame with additional
-    description.
+    """A base class for concatenated datasets.
+
+    Holds either mne.Raw or mne.Epoch in self.datasets and has
+    a pandas DataFrame with additional description.

     Parameters
     ----------
@@ -394,8 +428,15 @@ class BaseConcatDataset(ConcatDataset):
         list of BaseDataset, BaseConcatDataset or WindowsDataset
     target_transform : callable | None
         Optional function to call on targets before returning them.
+
     """
-    def __init__(self, list_of_ds, target_transform=None):
+
+    def __init__(
+        self,
+        list_of_ds: list[BaseDataset | BaseConcatDataset | WindowsDataset]
+        | None = None,
+        target_transform: Callable | None = None,
+    ):
         # if we get a list of BaseConcatDataset, get all the individual datasets
         if list_of_ds and isinstance(list_of_ds[0], BaseConcatDataset):
             list_of_ds = [d for ds in list_of_ds for d in ds.datasets]
@@ -415,7 +456,7 @@ class BaseConcatDataset(ConcatDataset):

         return X, y

-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int | list):
         """
         Parameters
         ----------
@@ -433,9 +474,16 @@ class BaseConcatDataset(ConcatDataset):
             item = item[:1] + (self.target_transform(item[1]),) + item[2:]
         return item

-    def split(self, by=None, property=None, split_ids=None):
-        """Split the dataset based on information listed in its description
-        DataFrame or based on indices.
+    @no_type_check  # TODO, it's a mess
+    def split(
+        self,
+        by: str | list[int] | list[list[int]] | dict[str, list[int]] | None = None,
+        property: str | None = None,
+        split_ids: list[int] | list[list[int]] | dict[str, list[int]] | None = None,
+    ) -> dict[str, BaseConcatDataset]:
+        """Split the dataset based on information listed in its description.
+
+        The format could be based on a DataFrame or based on indices.

         Parameters
         ----------
@@ -448,8 +496,8 @@ class BaseConcatDataset(ConcatDataset):
             If a dict then each key will be used in the returned
             splits dict and each value should be a list of int.
         property : str
-            Some property which is listed in info DataFrame.
-        split_ids : list | dict
+            Some property which is listed in the info DataFrame.
+        split_ids : list | dict
             List of indices to be combined in a subset.
             It can be a list of int or a list of list of int.

@@ -459,20 +507,22 @@ class BaseConcatDataset(ConcatDataset):
             A dictionary with the name of the split (a string) as key and the
             dataset as value.
         """
-        args_not_none = [
-            by is not None, property is not None, split_ids is not None]
+
+        args_not_none = [by is not None, property is not None, split_ids is not None]
         if sum(args_not_none) != 1:
             raise ValueError("Splitting requires exactly one argument.")

         if property is not None or split_ids is not None:
-            warnings.warn("Keyword arguments `property` and `split_ids` "
-                          "are deprecated and will be removed in the future. "
-                          "Use `by` instead.", DeprecationWarning)
+            warnings.warn(
+                "Keyword arguments `property` and `split_ids` "
+                "are deprecated and will be removed in the future. "
+                "Use `by` instead.",
+                DeprecationWarning,
+            )
             by = property if property is not None else split_ids
         if isinstance(by, str):
             split_ids = {
-                k: list(v)
-                for k, v in self.description.groupby(by).groups.items()
+                k: list(v) for k, v in self.description.groupby(by).groups.items()
             }
         elif isinstance(by, dict):
             split_ids = by
@@ -483,11 +533,15 @@ class BaseConcatDataset(ConcatDataset):
             # assume list(list(int))
             split_ids = {split_i: split for split_i, split in enumerate(by)}

-        return {str(split_name): BaseConcatDataset(
-            [self.datasets[ds_ind] for ds_ind in ds_inds], target_transform=self.target_transform)
-            for split_name, ds_inds in split_ids.items()}
+        return {
+            str(split_name): BaseConcatDataset(
+                [self.datasets[ds_ind] for ds_ind in ds_inds],
+                target_transform=self.target_transform,
+            )
+            for split_name, ds_inds in split_ids.items()
+        }

-    def get_metadata(self):
+    def get_metadata(self) -> pd.DataFrame:
         """Concatenate the metadata and description of the wrapped Epochs.

         Returns
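As a usage sketch of the reworked `split` above (dataset and column names are hypothetical; `property=` and `split_ids=` now emit a DeprecationWarning):

# Split on a description column: one entry per unique value.
splits = concat_ds.split(by="subject")

# Split by explicit dataset indices: a dict maps split names to indices,
# a list of lists produces enumerated, unnamed splits.
splits = concat_ds.split(by={"train": [0, 1, 2], "valid": [3]})
train_ds, valid_ds = splits["train"], splits["valid"]
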
@@ -497,13 +551,20 @@ class BaseConcatDataset(ConcatDataset):
             BaseConcatDataset, with the metadata and description information
             for each window.
         """
-        if not all([isinstance(ds, (WindowsDataset, EEGWindowsDataset)) for ds in self.datasets]):
-            raise TypeError('Metadata dataframe can only be computed when all '
-                            'datasets are WindowsDataset.')
+        if not all(
+            [
+                isinstance(ds, (WindowsDataset, EEGWindowsDataset))
+                for ds in self.datasets
+            ]
+        ):
+            raise TypeError(
+                "Metadata dataframe can only be computed when all "
+                "datasets are WindowsDataset."
+            )

         all_dfs = list()
         for ds in self.datasets:
-            if hasattr(ds, 'windows'):
+            if hasattr(ds, "windows"):
                 df = ds.windows.metadata
             else:
                 df = ds.metadata
@@ -529,7 +590,7 @@ class BaseConcatDataset(ConcatDataset):
     @target_transform.setter
     def target_transform(self, fn):
         if not (callable(fn) or fn is None):
-            raise TypeError('target_transform must be a callable.')
+            raise TypeError("target_transform must be a callable.")
         self._target_transform = fn

     def _outdated_save(self, path, overwrite=False):
@@ -547,39 +608,50 @@ class BaseConcatDataset(ConcatDataset):
             Whether to delete old files (.json, .fif, -epo.fif) in specified
             directory prior to saving.
         """
-        warnings.warn('This function only exists for backwards compatibility '
-                      'purposes. DO NOT USE!', UserWarning)
+        warnings.warn(
+            "This function only exists for backwards compatibility "
+            "purposes. DO NOT USE!",
+            UserWarning,
+        )
         if isinstance(self.datasets[0], EEGWindowsDataset):
-            raise NotImplementedError("Outdated save not implemented for new window datasets.")
+            raise NotImplementedError(
+                "Outdated save not implemented for new window datasets."
+            )
         if len(self.datasets) == 0:
             raise ValueError("Expect at least one dataset")
-        if not (hasattr(self.datasets[0], 'raw') or hasattr(
-                self.datasets[0], 'windows')):
-            raise ValueError("dataset should have either raw or windows "
-                             "attribute")
+        if not (
+            hasattr(self.datasets[0], "raw") or hasattr(self.datasets[0], "windows")
+        ):
+            raise ValueError("dataset should have either raw or windows attribute")
         file_name_templates = ["{}-raw.fif", "{}-epo.fif"]
-        description_file_name = os.path.join(path, 'description.json')
-        target_file_name = os.path.join(path, 'target_name.json')
+        description_file_name = os.path.join(path, "description.json")
+        target_file_name = os.path.join(path, "target_name.json")
         if not overwrite:
-            from braindecode.datautil.serialization import \
-                _check_save_dir_empty  # Import here to avoid circular import
+            from braindecode.datautil.serialization import (  # Import here to avoid circular import
+                _check_save_dir_empty,
+            )
+
            _check_save_dir_empty(path)
         else:
             for file_name_template in file_name_templates:
-                file_names = glob(os.path.join(
-                    path, f"*{file_name_template.lstrip('{}')}"))
+                file_names = glob(
+                    os.path.join(path, f"*{file_name_template.lstrip('{}')}")
+                )
                 _ = [os.remove(f) for f in file_names]
             if os.path.isfile(target_file_name):
                 os.remove(target_file_name)
             if os.path.isfile(description_file_name):
                 os.remove(description_file_name)
-            for kwarg_name in ['raw_preproc_kwargs', 'window_kwargs',
-                               'window_preproc_kwargs']:
-                kwarg_path = os.path.join(path, '.'.join([kwarg_name, 'json']))
+            for kwarg_name in [
+                "raw_preproc_kwargs",
+                "window_kwargs",
+                "window_preproc_kwargs",
+            ]:
+                kwarg_path = os.path.join(path, ".".join([kwarg_name, "json"]))
                 if os.path.exists(kwarg_path):
                     os.remove(kwarg_path)

-        is_raw = hasattr(self.datasets[0], 'raw')
+        is_raw = hasattr(self.datasets[0], "raw")

         if is_raw:
             file_name_template = file_name_templates[0]
@@ -594,21 +666,26 @@ class BaseConcatDataset(ConcatDataset):
                 ds.windows.save(full_file_path, overwrite=overwrite)

         self.description.to_json(description_file_name)
-        for kwarg_name in ['raw_preproc_kwargs', 'window_kwargs',
-                           'window_preproc_kwargs']:
+        for kwarg_name in [
+            "raw_preproc_kwargs",
+            "window_kwargs",
+            "window_preproc_kwargs",
+        ]:
             if hasattr(self, kwarg_name):
-                kwargs_path = os.path.join(path, '.'.join([kwarg_name, 'json']))
+                kwargs_path = os.path.join(path, ".".join([kwarg_name, "json"]))
                 kwargs = getattr(self, kwarg_name)
                 if kwargs is not None:
-                    json.dump(kwargs, open(kwargs_path, 'w'))
+                    json.dump(kwargs, open(kwargs_path, "w"))

     @property
-    def description(self):
+    def description(self) -> pd.DataFrame:
         df = pd.DataFrame([ds.description for ds in self.datasets])
         df.reset_index(inplace=True, drop=True)
         return df

-    def set_description(self, description, overwrite=False):
+    def set_description(
+        self, description: dict | pd.DataFrame, overwrite: bool = False
+    ):
         """Update (add or overwrite) the dataset description.

         Parameters
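A sketch of the concat-level `set_description` above (values hypothetical; each list needs one entry per sub-dataset):

# Adds a "session" column to the description DataFrame; overwrite=True
# replaces the column if it already exists.
concat_ds.set_description({"session": ["S1", "S1", "S2"]}, overwrite=True)
print(concat_ds.description)  # one row per sub-dataset
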
@@ -625,7 +702,7 @@ class BaseConcatDataset(ConcatDataset):
         for ds, value_ in zip(self.datasets, value):
             ds.set_description({key: value_}, overwrite=overwrite)

-    def save(self, path, overwrite=False, offset=0):
+    def save(self, path: str, overwrite: bool = False, offset: int = 0):
         """Save datasets to files by creating one subdirectory for each dataset:
         path/
             0/
@@ -659,10 +736,10 @@ class BaseConcatDataset(ConcatDataset):
         """
         if len(self.datasets) == 0:
             raise ValueError("Expect at least one dataset")
-        if not (hasattr(self.datasets[0], 'raw') or hasattr(
-                self.datasets[0], 'windows')):
-            raise ValueError("dataset should have either raw or windows "
-                             "attribute")
+        if not (
+            hasattr(self.datasets[0], "raw") or hasattr(self.datasets[0], "windows")
+        ):
+            raise ValueError("dataset should have either raw or windows attribute")
         path_contents = os.listdir(path)
         n_sub_dirs = len([os.path.isdir(e) for e in path_contents])
         for i_ds, ds in enumerate(self.datasets):
@@ -676,9 +753,10 @@ class BaseConcatDataset(ConcatDataset):
                     shutil.rmtree(sub_dir)
                 else:
                     raise FileExistsError(
-                        f'Subdirectory {sub_dir} already exists. Please select'
-                        f' a different directory, set overwrite=True, or '
-                        f'resolve manually.')
+                        f"Subdirectory {sub_dir} already exists. Please select"
+                        f" a different directory, set overwrite=True, or "
+                        f"resolve manually."
+                    )
             # save_dir/{i_ds+offset}/
             os.makedirs(sub_dir)
             # save_dir/{i_ds+offset}/{i_ds+offset}-{raw_or_epo}.fif
@@ -696,59 +774,67 @@ class BaseConcatDataset(ConcatDataset):
         if overwrite:
             # the following will be True for all datasets preprocessed and
             # stored in parallel with braindecode.preprocessing.preprocess
-            if i_ds+1+offset < n_sub_dirs:
-                warnings.warn(f"The number of saved datasets ({i_ds+1+offset}) "
-                              f"does not match the number of existing "
-                              f"subdirectories ({n_sub_dirs}). You may now "
-                              f"encounter a mix of differently preprocessed "
-                              f"datasets!", UserWarning)
+            if i_ds + 1 + offset < n_sub_dirs:
+                warnings.warn(
+                    f"The number of saved datasets ({i_ds + 1 + offset}) "
+                    f"does not match the number of existing "
+                    f"subdirectories ({n_sub_dirs}). You may now "
+                    f"encounter a mix of differently preprocessed "
+                    f"datasets!",
+                    UserWarning,
+                )
         # if path contains files or directories that were not touched, raise
         # warning
         if path_contents:
-            warnings.warn(f'Chosen directory {path} contains other '
-                          f'subdirectories or files {path_contents}.')
+            warnings.warn(
+                f"Chosen directory {path} contains other "
+                f"subdirectories or files {path_contents}."
+            )

     @staticmethod
     def _save_signals(sub_dir, ds, i_ds, offset):
-        raw_or_epo = 'raw' if hasattr(ds, 'raw') else 'epo'
-        fif_file_name = f'{i_ds + offset}-{raw_or_epo}.fif'
+        raw_or_epo = "raw" if hasattr(ds, "raw") else "epo"
+        fif_file_name = f"{i_ds + offset}-{raw_or_epo}.fif"
         fif_file_path = os.path.join(sub_dir, fif_file_name)
-        raw_or_windows = 'raw' if raw_or_epo == 'raw' else 'windows'
+        raw_or_windows = "raw" if raw_or_epo == "raw" else "windows"

         # The following appears to be necessary to avoid a CI failure when
         # preprocessing WindowsDatasets with serialization enabled. The failure
         # comes from `mne.epochs._check_consistency` which ensures the Epochs's
         # object `times` attribute is not writeable.
-        getattr(ds, raw_or_windows).times.flags['WRITEABLE'] = False
+        getattr(ds, raw_or_windows).times.flags["WRITEABLE"] = False

         getattr(ds, raw_or_windows).save(fif_file_path)

     @staticmethod
     def _save_metadata(sub_dir, ds):
-        if hasattr(ds, 'metadata'):
-            metadata_file_path = os.path.join(sub_dir, 'metadata_df.pkl')
+        if hasattr(ds, "metadata"):
+            metadata_file_path = os.path.join(sub_dir, "metadata_df.pkl")
             ds.metadata.to_pickle(metadata_file_path)

     @staticmethod
     def _save_description(sub_dir, description):
-        description_file_path = os.path.join(sub_dir, 'description.json')
+        description_file_path = os.path.join(sub_dir, "description.json")
         description.to_json(description_file_path)

     @staticmethod
     def _save_kwargs(sub_dir, ds):
-        for kwargs_name in ['raw_preproc_kwargs', 'window_kwargs',
-                            'window_preproc_kwargs']:
+        for kwargs_name in [
+            "raw_preproc_kwargs",
+            "window_kwargs",
+            "window_preproc_kwargs",
+        ]:
             if hasattr(ds, kwargs_name):
-                kwargs_file_name = '.'.join([kwargs_name, 'json'])
+                kwargs_file_name = ".".join([kwargs_name, "json"])
                 kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
                 kwargs = getattr(ds, kwargs_name)
                 if kwargs is not None:
-                    with open(kwargs_file_path, 'w') as f:
+                    with open(kwargs_file_path, "w") as f:
                         json.dump(kwargs, f)

     @staticmethod
     def _save_target_name(sub_dir, ds):
-        if hasattr(ds, 'target_name'):
-            target_file_path = os.path.join(sub_dir, 'target_name.json')
-            with open(target_file_path, 'w') as f:
-                json.dump({'target_name': ds.target_name}, f)
+        if hasattr(ds, "target_name"):
+            target_file_path = os.path.join(sub_dir, "target_name.json")
+            with open(target_file_path, "w") as f:
+                json.dump({"target_name": ds.target_name}, f)