braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/datasets/base.py
@@ -0,0 +1,840 @@
+ """
+ Dataset classes.
+ """
+
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
+ #          Lukas Gemein <l.gemein@gmail.com>
+ #          Simon Brandt <simonbrandt@protonmail.com>
+ #          David Sabbagh <dav.sabbagh@gmail.com>
+ #          Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import shutil
+ import warnings
+ from collections.abc import Callable
+ from glob import glob
+ from typing import Iterable, no_type_check
+
+ import mne.io
+ import numpy as np
+ import pandas as pd
+ from torch.utils.data import ConcatDataset, Dataset
+
+
+ def _create_description(description) -> pd.Series:
+     if description is not None:
+         if not isinstance(description, pd.Series) and not isinstance(description, dict):
+             raise ValueError(
+                 f"'{description}' has to be either a pandas.Series or a dict."
+             )
+         if isinstance(description, dict):
+             description = pd.Series(description)
+     return description
+
+
+ class BaseDataset(Dataset):
+     """Returns samples from an mne.io.Raw object along with a target.
+
+     Dataset which serves samples from an mne.io.Raw object along with a target.
+     The target is unique for the dataset, and is obtained through the
+     `description` attribute.
+
+     Parameters
+     ----------
+     raw : mne.io.Raw
+         Continuous data.
+     description : dict | pandas.Series | None
+         Holds additional description about the continuous signal / subject.
+     target_name : str | tuple | None
+         Name(s) of the index in `description` that should be used to provide the
+         target (e.g., to be used in a prediction task later on).
+     transform : callable | None
+         On-the-fly transform applied to the example before it is returned.
+     """
+
+     def __init__(
+         self,
+         raw: mne.io.BaseRaw,
+         description: dict | pd.Series | None = None,
+         target_name: str | tuple[str, ...] | None = None,
+         transform: Callable | None = None,
+     ):
+         self.raw = raw
+         self._description = _create_description(description)
+         self.transform = transform
+
+         # save target name for load/save later
+         self.target_name = self._target_name(target_name)
+
+     def __getitem__(self, index):
+         X = self.raw[:, index][0]
+         y = None
+         if self.target_name is not None:
+             y = self.description[self.target_name]
+             if isinstance(y, pd.Series):
+                 y = y.to_list()
+         if self.transform is not None:
+             X = self.transform(X)
+         return X, y
+
+     def __len__(self):
+         return len(self.raw)
+
+     @property
+     def transform(self):
+         return self._transform
+
+     @transform.setter
+     def transform(self, value):
+         if value is not None and not callable(value):
+             raise ValueError("Transform needs to be a callable.")
+         self._transform = value
+
+     @property
+     def description(self) -> pd.Series:
+         return self._description
+
+     def set_description(self, description: dict | pd.Series, overwrite: bool = False):
+         """Update (add or overwrite) the dataset description.
+
+         Parameters
+         ----------
+         description: dict | pd.Series
+             Description in the form key: value.
+         overwrite: bool
+             Has to be True if a key in description already exists in the
+             dataset description.
+         """
+         description = _create_description(description)
+         for key, value in description.items():
+             # if the key is already in the existing description, drop it
+             if self._description is not None and key in self._description:
+                 assert overwrite, (
+                     f"'{key}' already in description. Please "
+                     f"rename or set overwrite to True."
+                 )
+                 self._description.pop(key)
+         if self._description is None:
+             self._description = description
+         else:
+             self._description = pd.concat([self.description, description])
+
+     def _target_name(self, target_name):
+         if target_name is not None and not isinstance(target_name, (str, tuple, list)):
+             raise ValueError("target_name has to be None, str, tuple or list")
+         if target_name is None:
+             return target_name
+         else:
+             # convert tuple of names or single name to list
+             if isinstance(target_name, tuple):
+                 target_name = [name for name in target_name]
+             elif not isinstance(target_name, list):
+                 assert isinstance(target_name, str)
+                 target_name = [target_name]
+             assert isinstance(target_name, list)
+             # check if target name(s) can be read from description
+             for name in target_name:
+                 if self.description is None or name not in self.description:
+                     warnings.warn(
+                         f"'{name}' not in description. '__getitem__' "
+                         f"will fail unless an appropriate target is "
+                         f"added to description.",
+                         UserWarning,
+                     )
+             # return a list of str if there are multiple targets and a str otherwise
+             return target_name if len(target_name) > 1 else target_name[0]
+
+
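For orientation, a minimal usage sketch of the BaseDataset API added above (channel names, sampling rate and description values are made up for illustration):

    import mne
    import numpy as np

    from braindecode.datasets import BaseDataset

    info = mne.create_info(ch_names=["C3", "C4"], sfreq=100.0, ch_types="eeg")
    raw = mne.io.RawArray(np.random.randn(2, 1000), info)

    # target_name points at a key of the description
    ds = BaseDataset(raw, description={"subject": 1, "label": "left_hand"},
                     target_name="label")
    X, y = ds[:100]  # first 100 samples of every channel, plus the target
    assert y == "left_hand"

    ds.set_description({"session": "A"})                         # add a new key
    ds.set_description({"label": "right_hand"}, overwrite=True)  # replace one
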
+ class EEGWindowsDataset(BaseDataset):
+     """Returns windows from an mne.io.Raw object, their window indices, along with a target.
+
+     Dataset which serves windows from an mne.io.Raw object along with their
+     target and additional information. The `metadata` DataFrame passed at
+     construction must contain a column called `target`, which will be used to
+     return the target that corresponds to a window. Additional columns
+     `i_window_in_trial`, `i_start_in_trial`, `i_stop_in_trial` are also
+     required to serve information about the windowing (e.g., useful for cropped
+     training).
+     See `braindecode.preprocessing.windowers` to directly create a
+     `WindowsDataset` from a `BaseDataset` object.
+
+     Parameters
+     ----------
+     raw : mne.io.BaseRaw | mne.BaseEpochs
+         Continuous data from which windows are cut out on the fly (passing an
+         mne.Epochs object is also possible, but outdated).
+     metadata : pandas.DataFrame
+         DataFrame with the crop indices `i_window_in_trial`, `i_start_in_trial`
+         and `i_stop_in_trial` used to cut out windows, as well as the `target`
+         column.
+     description : dict | pandas.Series | None
+         Holds additional info about the windows.
+     transform : callable | None
+         On-the-fly transform applied to a window before it is returned.
+     targets_from : str
+         Defines whether targets will be extracted from metadata or from `misc`
+         channels (time series targets). It can be `metadata` (default) or `channels`.
+     last_target_only : bool
+         If targets are obtained from misc channels, whether only the last
+         target of the (compute) window is returned or all targets in it.
+     """
+
+     def __init__(
+         self,
+         raw: mne.io.BaseRaw | mne.BaseEpochs,
+         metadata: pd.DataFrame,
+         description: dict | pd.Series | None = None,
+         transform: Callable | None = None,
+         targets_from: str = "metadata",
+         last_target_only: bool = True,
+     ):
+         self.raw = raw
+         self.metadata = metadata
+         self._description = _create_description(description)
+
+         self.transform = transform
+         self.last_target_only = last_target_only
+         if targets_from not in ("metadata", "channels"):
+             raise ValueError("Wrong value for parameter `targets_from`.")
+         self.targets_from = targets_from
+         self.crop_inds = metadata.loc[
+             :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+         ].to_numpy()
+         if self.targets_from == "metadata":
+             self.y = metadata.loc[:, "target"].to_list()
+
+     def __getitem__(self, index: int):
+         """Get a window and its target.
+
+         Parameters
+         ----------
+         index : int
+             Index to the window (and target) to return.
+
+         Returns
+         -------
+         np.ndarray
+             Window of shape (n_channels, n_times).
+         int
+             Target for the window.
+         np.ndarray
+             Crop indices.
+         """
+
+         # necessary to cast as list to get list of three tensors from batch,
+         # otherwise get single 2d-tensor...
+         crop_inds = self.crop_inds[index].tolist()
+
+         i_window_in_trial, i_start, i_stop = crop_inds
+         X = self.raw._getitem((slice(None), slice(i_start, i_stop)), return_times=False)
+         X = X.astype("float32")
+         # ensure we don't give the user the option
+         # to accidentally modify the underlying array
+         X = X.copy()
+         if self.transform is not None:
+             X = self.transform(X)
+         if self.targets_from == "metadata":
+             y = self.y[index]
+         else:
+             misc_mask = np.array(self.raw.get_channel_types()) == "misc"
+             if self.last_target_only:
+                 y = X[misc_mask, -1]
+             else:
+                 y = X[misc_mask, :]
+             # ensure we don't give the user the option
+             # to accidentally modify the underlying array
+             y = y.copy()
+             # remove the target channels from raw
+             X = X[~misc_mask, :]
+         return X, y, crop_inds
+
+     def __len__(self):
+         return len(self.crop_inds)
+
+     @property
+     def transform(self):
+         return self._transform
+
+     @transform.setter
+     def transform(self, value):
+         if value is not None and not callable(value):
+             raise ValueError("Transform needs to be a callable.")
+         self._transform = value
+
+     @property
+     def description(self) -> pd.Series:
+         return self._description
+
+     def set_description(self, description: dict | pd.Series, overwrite: bool = False):
+         """Update (add or overwrite) the dataset description.
+
+         Parameters
+         ----------
+         description: dict | pd.Series
+             Description in the form key: value.
+         overwrite: bool
+             Has to be True if a key in description already exists in the
+             dataset description.
+         """
+         description = _create_description(description)
+         for key, value in description.items():
+             # if the key is already in the existing description, drop it
+             if key in self._description:
+                 assert overwrite, (
+                     f"'{key}' already in description. Please "
+                     f"rename or set overwrite to True."
+                 )
+                 self._description.pop(key)
+         self._description = pd.concat([self.description, description])
+
+
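For orientation, a sketch of direct EEGWindowsDataset use (in practice the metadata comes from the windowers in braindecode.preprocessing rather than being written by hand; `raw` is assumed to be a continuous recording as in the previous sketch):

    import pandas as pd

    from braindecode.datasets.base import EEGWindowsDataset

    # two hypothetical 500-sample windows with integer targets
    metadata = pd.DataFrame({
        "i_window_in_trial": [0, 1],
        "i_start_in_trial": [0, 500],
        "i_stop_in_trial": [500, 1000],
        "target": [0, 1],
    })
    win_ds = EEGWindowsDataset(raw, metadata=metadata, targets_from="metadata")
    X, y, crop_inds = win_ds[0]
    # X: float32 array of shape (n_channels, 500); y == 0;
    # crop_inds == [0, 0, 500], i.e. (i_window_in_trial, i_start, i_stop)
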
+ class WindowsDataset(BaseDataset):
+     """Returns windows from an mne.Epochs object along with a target.
+
+     Dataset which serves windows from an mne.Epochs object along with their
+     target and additional information. The `metadata` attribute of the Epochs
+     object must contain a column called `target`, which will be used to return
+     the target that corresponds to a window. Additional columns
+     `i_window_in_trial`, `i_start_in_trial`, `i_stop_in_trial` are also
+     required to serve information about the windowing (e.g., useful for cropped
+     training).
+     See `braindecode.preprocessing.windowers` to directly create a
+     `WindowsDataset` from a `BaseDataset` object.
+
+     Parameters
+     ----------
+     windows : mne.Epochs
+         Windows obtained through the application of a windower to a BaseDataset
+         (see `braindecode.preprocessing.windowers`).
+     description : dict | pandas.Series | None
+         Holds additional info about the windows.
+     transform : callable | None
+         On-the-fly transform applied to a window before it is returned.
+     targets_from : str
+         Defines whether targets will be extracted from mne.Epochs metadata or mne.Epochs `misc`
+         channels (time series targets). It can be `metadata` (default) or `channels`.
+     last_target_only : bool
+         If targets are obtained from misc channels, whether only the last
+         target of the window is returned or all targets in it.
+     """
+
+     def __init__(
+         self,
+         windows: mne.BaseEpochs,
+         description: dict | pd.Series | None = None,
+         transform: Callable | None = None,
+         targets_from: str = "metadata",
+         last_target_only: bool = True,
+     ):
+         self.windows = windows
+         self._description = _create_description(description)
+         self.transform = transform
+         self.last_target_only = last_target_only
+         if targets_from not in ("metadata", "channels"):
+             raise ValueError("Wrong value for parameter `targets_from`.")
+         self.targets_from = targets_from
+
+         self.crop_inds = self.windows.metadata.loc[
+             :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+         ].to_numpy()
+         if self.targets_from == "metadata":
+             self.y = self.windows.metadata.loc[:, "target"].to_list()
+
+     def __getitem__(self, index: int):
+         """Get a window and its target.
+
+         Parameters
+         ----------
+         index : int
+             Index to the window (and target) to return.
+
+         Returns
+         -------
+         np.ndarray
+             Window of shape (n_channels, n_times).
+         int
+             Target for the window.
+         np.ndarray
+             Crop indices.
+         """
+         X = self.windows.get_data(item=index)[0].astype("float32")
+         if self.transform is not None:
+             X = self.transform(X)
+         if self.targets_from == "metadata":
+             y = self.y[index]
+         else:
+             misc_mask = np.array(self.windows.get_channel_types()) == "misc"
+             if self.last_target_only:
+                 y = X[misc_mask, -1]
+             else:
+                 y = X[misc_mask, :]
+             # remove the target channels from raw
+             X = X[~misc_mask, :]
+         # necessary to cast as list to get list of three tensors from batch,
+         # otherwise get single 2d-tensor...
+         crop_inds = self.crop_inds[index].tolist()
+         return X, y, crop_inds
+
+     def __len__(self) -> int:
+         return len(self.windows.events)
+
+     @property
+     def transform(self):
+         return self._transform
+
+     @transform.setter
+     def transform(self, value):
+         if value is not None and not callable(value):
+             raise ValueError("Transform needs to be a callable.")
+         self._transform = value
+
+     @property
+     def description(self) -> pd.Series:
+         return self._description
+
+     def set_description(self, description: dict | pd.Series, overwrite: bool = False):
+         """Update (add or overwrite) the dataset description.
+
+         Parameters
+         ----------
+         description: dict | pd.Series
+             Description in the form key: value.
+         overwrite: bool
+             Has to be True if a key in description already exists in the
+             dataset description.
+         """
+         description = _create_description(description)
+         for key, value in description.items():
+             # if the key is already in the existing description, drop it
+             if key in self._description:
+                 assert overwrite, (
+                     f"'{key}' already in description. Please "
+                     f"rename or set overwrite to True."
+                 )
+                 self._description.pop(key)
+         self._description = pd.concat([self.description, description])
+
+
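WindowsDatasets are usually produced by a windower rather than constructed by hand; a sketch under that assumption, where `concat_ds` is a BaseConcatDataset of BaseDatasets whose raws carry event annotations:

    from braindecode.preprocessing import create_windows_from_events

    windows_ds = create_windows_from_events(
        concat_ds,
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
    )
    X, y, crop_inds = windows_ds[0]  # one window, its target, crop indices
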
+ class BaseConcatDataset(ConcatDataset):
+     """A base class for concatenated datasets.
+
+     Holds either mne.Raw or mne.Epochs in self.datasets and has
+     a pandas DataFrame with additional description.
+
+     Parameters
+     ----------
+     list_of_ds : list
+         list of BaseDataset, BaseConcatDataset or WindowsDataset
+     target_transform : callable | None
+         Optional function to call on targets before returning them.
+     """
+
+     def __init__(
+         self,
+         list_of_ds: list[BaseDataset | BaseConcatDataset | WindowsDataset]
+         | None = None,
+         target_transform: Callable | None = None,
+     ):
+         # if we get a list of BaseConcatDataset, get all the individual datasets
+         if list_of_ds and isinstance(list_of_ds[0], BaseConcatDataset):
+             list_of_ds = [d for ds in list_of_ds for d in ds.datasets]
+         super().__init__(list_of_ds)
+
+         self.target_transform = target_transform
+
+     def _get_sequence(self, indices):
+         X, y = list(), list()
+         for ind in indices:
+             out_i = super().__getitem__(ind)
+             X.append(out_i[0])
+             y.append(out_i[1])
+
+         X = np.stack(X, axis=0)
+         y = np.array(y)
+
+         return X, y
+
+     def __getitem__(self, idx: int | list):
+         """
+         Parameters
+         ----------
+         idx : int | list
+             Index of window and target to return. If provided as a list of
+             ints, multiple windows and targets will be extracted and
+             concatenated. The target output can be modified on the
+             fly by the ``target_transform`` parameter.
+         """
+         if isinstance(idx, Iterable):  # Sample multiple windows
+             item = self._get_sequence(idx)
+         else:
+             item = super().__getitem__(idx)
+         if self.target_transform is not None:
+             item = item[:1] + (self.target_transform(item[1]),) + item[2:]
+         return item
+
+     @no_type_check  # TODO, it's a mess
+     def split(
+         self,
+         by: str | list[int] | list[list[int]] | dict[str, list[int]] | None = None,
+         property: str | None = None,
+         split_ids: list[int] | list[list[int]] | dict[str, list[int]] | None = None,
+     ) -> dict[str, BaseConcatDataset]:
+         """Split the dataset based on information listed in its description.
+
+         The splits can be defined based on a description DataFrame column or
+         on explicit indices.
+
+         Parameters
+         ----------
+         by : str | list | dict
+             If ``by`` is a string, splitting is performed based on the
+             description DataFrame column with this name.
+             If ``by`` is a (list of) list of integers, the position in the first
+             list corresponds to the split id and the integers to the
+             datapoints of that split.
+             If a dict then each key will be used in the returned
+             splits dict and each value should be a list of int.
+         property : str
+             Some property which is listed in the info DataFrame.
+         split_ids : list | dict
+             List of indices to be combined in a subset.
+             It can be a list of int or a list of list of int.
+
+         Returns
+         -------
+         splits : dict
+             A dictionary with the name of the split (a string) as key and the
+             dataset as value.
+         """
+
+         args_not_none = [by is not None, property is not None, split_ids is not None]
+         if sum(args_not_none) != 1:
+             raise ValueError("Splitting requires exactly one argument.")
+
+         if property is not None or split_ids is not None:
+             warnings.warn(
+                 "Keyword arguments `property` and `split_ids` "
+                 "are deprecated and will be removed in the future. "
+                 "Use `by` instead.",
+                 DeprecationWarning,
+             )
+             by = property if property is not None else split_ids
+         if isinstance(by, str):
+             split_ids = {
+                 k: list(v) for k, v in self.description.groupby(by).groups.items()
+             }
+         elif isinstance(by, dict):
+             split_ids = by
+         else:
+             # assume list(int)
+             if not isinstance(by[0], list):
+                 by = [by]
+             # assume list(list(int))
+             split_ids = {split_i: split for split_i, split in enumerate(by)}
+
+         return {
+             str(split_name): BaseConcatDataset(
+                 [self.datasets[ds_ind] for ds_ind in ds_inds],
+                 target_transform=self.target_transform,
+             )
+             for split_name, ds_inds in split_ids.items()
+         }
+
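A sketch of the two splitting modes (the description column "subject" is an assumption about the data at hand):

    # by description column: one split per unique value of "subject"
    splits = windows_ds.split("subject")

    # by explicit dataset indices: keys are the positions in the outer list
    splits = windows_ds.split([[0, 1, 2], [3]])
    train_set, valid_set = splits["0"], splits["1"]
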
+     def get_metadata(self) -> pd.DataFrame:
+         """Concatenate the metadata and description of the wrapped Epochs.
+
+         Returns
+         -------
+         metadata : pd.DataFrame
+             DataFrame containing as many rows as there are windows in the
+             BaseConcatDataset, with the metadata and description information
+             for each window.
+         """
+         if not all(
+             [
+                 isinstance(ds, (WindowsDataset, EEGWindowsDataset))
+                 for ds in self.datasets
+             ]
+         ):
+             raise TypeError(
+                 "Metadata dataframe can only be computed when all "
+                 "datasets are WindowsDataset or EEGWindowsDataset."
+             )
+
+         all_dfs = list()
+         for ds in self.datasets:
+             if hasattr(ds, "windows"):
+                 df = ds.windows.metadata
+             else:
+                 df = ds.metadata
+             for k, v in ds.description.items():
+                 df[k] = v
+             all_dfs.append(df)
+
+         return pd.concat(all_dfs)
+
+     @property
+     def transform(self):
+         return [ds.transform for ds in self.datasets]
+
+     @transform.setter
+     def transform(self, fn):
+         for i in range(len(self.datasets)):
+             self.datasets[i].transform = fn
+
+     @property
+     def target_transform(self):
+         return self._target_transform
+
+     @target_transform.setter
+     def target_transform(self, fn):
+         if not (callable(fn) or fn is None):
+             raise TypeError("target_transform must be a callable.")
+         self._target_transform = fn
+
+     def _outdated_save(self, path, overwrite=False):
+         """This is a copy of the old saving function, which had inconsistent
+         functionality for BaseDataset and WindowsDataset. It only exists to
+         ensure backwards compatibility by still being able to run the old tests.
+
+         Save dataset to files.
+
+         Parameters
+         ----------
+         path : str
+             Directory to which .fif / -epo.fif and .json files are stored.
+         overwrite : bool
+             Whether to delete old files (.json, .fif, -epo.fif) in specified
+             directory prior to saving.
+         """
+         warnings.warn(
+             "This function only exists for backwards compatibility "
+             "purposes. DO NOT USE!",
+             UserWarning,
+         )
+         if isinstance(self.datasets[0], EEGWindowsDataset):
+             raise NotImplementedError(
+                 "Outdated save not implemented for new window datasets."
+             )
+         if len(self.datasets) == 0:
+             raise ValueError("Expect at least one dataset")
+         if not (
+             hasattr(self.datasets[0], "raw") or hasattr(self.datasets[0], "windows")
+         ):
+             raise ValueError("dataset should have either raw or windows attribute")
+         file_name_templates = ["{}-raw.fif", "{}-epo.fif"]
+         description_file_name = os.path.join(path, "description.json")
+         target_file_name = os.path.join(path, "target_name.json")
+         if not overwrite:
+             from braindecode.datautil.serialization import (  # Import here to avoid circular import
+                 _check_save_dir_empty,
+             )
+
+             _check_save_dir_empty(path)
+         else:
+             for file_name_template in file_name_templates:
+                 file_names = glob(
+                     os.path.join(path, f"*{file_name_template.lstrip('{}')}")
+                 )
+                 _ = [os.remove(f) for f in file_names]
+             if os.path.isfile(target_file_name):
+                 os.remove(target_file_name)
+             if os.path.isfile(description_file_name):
+                 os.remove(description_file_name)
+             for kwarg_name in [
+                 "raw_preproc_kwargs",
+                 "window_kwargs",
+                 "window_preproc_kwargs",
+             ]:
+                 kwarg_path = os.path.join(path, ".".join([kwarg_name, "json"]))
+                 if os.path.exists(kwarg_path):
+                     os.remove(kwarg_path)
+
+         is_raw = hasattr(self.datasets[0], "raw")
+
+         if is_raw:
+             file_name_template = file_name_templates[0]
+         else:
+             file_name_template = file_name_templates[1]
+
+         for i_ds, ds in enumerate(self.datasets):
+             full_file_path = os.path.join(path, file_name_template.format(i_ds))
+             if is_raw:
+                 ds.raw.save(full_file_path, overwrite=overwrite)
+             else:
+                 ds.windows.save(full_file_path, overwrite=overwrite)
+
+         self.description.to_json(description_file_name)
+         for kwarg_name in [
+             "raw_preproc_kwargs",
+             "window_kwargs",
+             "window_preproc_kwargs",
+         ]:
+             if hasattr(self, kwarg_name):
+                 kwargs_path = os.path.join(path, ".".join([kwarg_name, "json"]))
+                 kwargs = getattr(self, kwarg_name)
+                 if kwargs is not None:
+                     json.dump(kwargs, open(kwargs_path, "w"))
+
+     @property
+     def description(self) -> pd.DataFrame:
+         df = pd.DataFrame([ds.description for ds in self.datasets])
+         df.reset_index(inplace=True, drop=True)
+         return df
+
+     def set_description(
+         self, description: dict | pd.DataFrame, overwrite: bool = False
+     ):
+         """Update (add or overwrite) the dataset description.
+
+         Parameters
+         ----------
+         description: dict | pd.DataFrame
+             Description in the form key: value where the length of the value
+             has to match the number of datasets.
+         overwrite: bool
+             Has to be True if a key in description already exists in the
+             dataset description.
+         """
+         description = pd.DataFrame(description)
+         for key, value in description.items():
+             for ds, value_ in zip(self.datasets, value):
+                 ds.set_description({key: value_}, overwrite=overwrite)
+
+     def save(self, path: str, overwrite: bool = False, offset: int = 0):
+         """Save datasets to files by creating one subdirectory for each dataset:
+
+         path/
+             0/
+                 0-raw.fif | 0-epo.fif
+                 description.json
+                 raw_preproc_kwargs.json (if raws were preprocessed)
+                 window_kwargs.json (if this is a windowed dataset)
+                 window_preproc_kwargs.json (if windows were preprocessed)
+                 target_name.json (if target_name is not None and dataset is raw)
+             1/
+                 1-raw.fif | 1-epo.fif
+                 description.json
+                 raw_preproc_kwargs.json (if raws were preprocessed)
+                 window_kwargs.json (if this is a windowed dataset)
+                 window_preproc_kwargs.json (if windows were preprocessed)
+                 target_name.json (if target_name is not None and dataset is raw)
+
+         Parameters
+         ----------
+         path : str
+             Directory in which subdirectories are created to store the
+             -raw.fif | -epo.fif and .json files.
+         overwrite : bool
+             Whether to delete old subdirectories that will be saved to in this
+             call.
+         offset : int
+             If provided, the integer is added to the id of the dataset in the
+             concat. This is useful in the setting of very large datasets, where
+             one dataset has to be processed and saved at a time to account for
+             its original position.
+         """
+         if len(self.datasets) == 0:
+             raise ValueError("Expect at least one dataset")
+         if not (
+             hasattr(self.datasets[0], "raw") or hasattr(self.datasets[0], "windows")
+         ):
+             raise ValueError("dataset should have either raw or windows attribute")
+         path_contents = os.listdir(path)
+         n_sub_dirs = len([os.path.isdir(e) for e in path_contents])
+         for i_ds, ds in enumerate(self.datasets):
+             # remove subdirectory from list of untouched files / subdirectories
+             if str(i_ds + offset) in path_contents:
+                 path_contents.remove(str(i_ds + offset))
+             # save_dir/i_ds/
+             sub_dir = os.path.join(path, str(i_ds + offset))
+             if os.path.exists(sub_dir):
+                 if overwrite:
+                     shutil.rmtree(sub_dir)
+                 else:
+                     raise FileExistsError(
+                         f"Subdirectory {sub_dir} already exists. Please select"
+                         f" a different directory, set overwrite=True, or "
+                         f"resolve manually."
+                     )
+             # save_dir/{i_ds+offset}/
+             os.makedirs(sub_dir)
+             # save_dir/{i_ds+offset}/{i_ds+offset}-{raw_or_epo}.fif
+             self._save_signals(sub_dir, ds, i_ds, offset)
+             # save_dir/{i_ds+offset}/metadata_df.pkl
+             self._save_metadata(sub_dir, ds)
+             # save_dir/{i_ds+offset}/description.json
+             self._save_description(sub_dir, ds.description)
+             # save_dir/{i_ds+offset}/raw_preproc_kwargs.json
+             # save_dir/{i_ds+offset}/window_kwargs.json
+             # save_dir/{i_ds+offset}/window_preproc_kwargs.json
+             self._save_kwargs(sub_dir, ds)
+             # save_dir/{i_ds+offset}/target_name.json
+             self._save_target_name(sub_dir, ds)
+         if overwrite:
+             # the following will be True for all datasets preprocessed and
+             # stored in parallel with braindecode.preprocessing.preprocess
+             if i_ds + 1 + offset < n_sub_dirs:
+                 warnings.warn(
+                     f"The number of saved datasets ({i_ds + 1 + offset}) "
+                     f"does not match the number of existing "
+                     f"subdirectories ({n_sub_dirs}). You may now "
+                     f"encounter a mix of differently preprocessed "
+                     f"datasets!",
+                     UserWarning,
+                 )
+         # if path contains files or directories that were not touched, raise
+         # warning
+         if path_contents:
+             warnings.warn(
+                 f"Chosen directory {path} contains other "
+                 f"subdirectories or files {path_contents}."
+             )
+
+     @staticmethod
+     def _save_signals(sub_dir, ds, i_ds, offset):
+         raw_or_epo = "raw" if hasattr(ds, "raw") else "epo"
+         fif_file_name = f"{i_ds + offset}-{raw_or_epo}.fif"
+         fif_file_path = os.path.join(sub_dir, fif_file_name)
+         raw_or_windows = "raw" if raw_or_epo == "raw" else "windows"
+
+         # The following appears to be necessary to avoid a CI failure when
+         # preprocessing WindowsDatasets with serialization enabled. The failure
+         # comes from `mne.epochs._check_consistency` which ensures the Epochs's
+         # object `times` attribute is not writeable.
+         getattr(ds, raw_or_windows).times.flags["WRITEABLE"] = False
+
+         getattr(ds, raw_or_windows).save(fif_file_path)
+
+     @staticmethod
+     def _save_metadata(sub_dir, ds):
+         if hasattr(ds, "metadata"):
+             metadata_file_path = os.path.join(sub_dir, "metadata_df.pkl")
+             ds.metadata.to_pickle(metadata_file_path)
+
+     @staticmethod
+     def _save_description(sub_dir, description):
+         description_file_path = os.path.join(sub_dir, "description.json")
+         description.to_json(description_file_path)
+
+     @staticmethod
+     def _save_kwargs(sub_dir, ds):
+         for kwargs_name in [
+             "raw_preproc_kwargs",
+             "window_kwargs",
+             "window_preproc_kwargs",
+         ]:
+             if hasattr(ds, kwargs_name):
+                 kwargs_file_name = ".".join([kwargs_name, "json"])
+                 kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
+                 kwargs = getattr(ds, kwargs_name)
+                 if kwargs is not None:
+                     with open(kwargs_file_path, "w") as f:
+                         json.dump(kwargs, f)
+
+     @staticmethod
+     def _save_target_name(sub_dir, ds):
+         if hasattr(ds, "target_name"):
+             target_file_path = os.path.join(sub_dir, "target_name.json")
+             with open(target_file_path, "w") as f:
+                 json.dump({"target_name": ds.target_name}, f)
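Finally, a sketch of round-tripping a dataset through save() and the loader in braindecode.datautil (the directory path is hypothetical; note that save() expects the target directory to exist already):

    import os

    from braindecode.datautil import load_concat_dataset

    os.makedirs("./saved_ds", exist_ok=True)
    windows_ds.save("./saved_ds", overwrite=True)
    restored = load_concat_dataset("./saved_ds", preload=False)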