braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,823 @@
+ """Dataset classes."""
+
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
+ #          Lukas Gemein <l.gemein@gmail.com>
+ #          Simon Brandt <simonbrandt@protonmail.com>
+ #          David Sabbagh <dav.sabbagh@gmail.com>
+ #          Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import shutil
+ import warnings
+ from abc import abstractmethod
+ from collections.abc import Callable
+ from glob import glob
+ from typing import Any, Generic, Iterable, no_type_check
+
+ import mne.io
+ import numpy as np
+ import pandas as pd
+ from mne.utils.docs import deprecated
+ from torch.utils.data import ConcatDataset, Dataset
+ from typing_extensions import TypeVar
+
+ from .bids.hub import HubDatasetMixin
+ from .registry import register_dataset
+
+
+ def _create_description(description) -> pd.Series:
+     if description is not None:
+         if not isinstance(description, (pd.Series, dict)):
+             raise ValueError(
+                 f"'{description}' has to be either a pandas.Series or a dict."
+             )
+         if isinstance(description, dict):
+             description = pd.Series(description)
+     return description
+
+
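+ # A minimal usage sketch of the helper above (illustrative only, this is a
+ # private helper): a dict is coerced to a pandas.Series, while an existing
+ # Series passes through unchanged.
+ #
+ #     desc = _create_description({"subject": 1, "session": "A"})
+ #     assert isinstance(desc, pd.Series) and desc["subject"] == 1
+
+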
+ class RecordDataset(Dataset[tuple[np.ndarray, int | str, tuple[int, int, int]]]):
+     def __init__(
+         self,
+         description: dict | pd.Series | None = None,
+         transform: Callable | None = None,
+     ):
+         self._description = _create_description(description)
+         self.transform = transform
+
+     @abstractmethod
+     def __len__(self) -> int:
+         pass
+
+     @property
+     def description(self) -> pd.Series:
+         return self._description
+
+     def set_description(self, description: dict | pd.Series, overwrite: bool = False):
+         """Update (add or overwrite) the dataset description.
+
+         Parameters
+         ----------
+         description : dict | pd.Series
+             Description in the form key: value.
+         overwrite : bool
+             Has to be True if a key in description already exists in the
+             dataset description.
+         """
+         description = _create_description(description)
+         if self.description is None:
+             self._description = description
+         else:
+             for key, value in description.items():
+                 # if the key is already in the existing description, drop it
+                 if key in self._description:
+                     assert overwrite, (
+                         f"'{key}' already in description. Please "
+                         f"rename or set overwrite to True."
+                     )
+                     self._description.pop(key)
+             self._description = pd.concat([self.description, description])
+
+     @property
+     def transform(self) -> Callable | None:
+         return self._transform
+
+     @transform.setter
+     def transform(self, value: Callable | None):
+         if value is not None and not callable(value):
+             raise ValueError("Transform needs to be a callable.")
+         self._transform = value
+
+
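+ # Illustrative sketch of the description API above, assuming `ds` is an
+ # instance of a concrete RecordDataset subclass (e.g., a RawDataset):
+ #
+ #     ds.set_description({"subject": 1})
+ #     ds.set_description({"subject": 2}, overwrite=True)  # replace existing key
+ #     print(ds.description["subject"])  # -> 2
+
+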
+ # Type of the datasets contained in BaseConcatDataset
+ T = TypeVar("T", bound=RecordDataset)
+
+
+ @register_dataset
+ class RawDataset(RecordDataset):
+     """Returns samples from an mne.io.Raw object along with a target.
+
+     Dataset which serves samples from an mne.io.Raw object along with a
+     target. The target is unique for the dataset and is obtained through the
+     `description` attribute.
+
+     Parameters
+     ----------
+     raw : mne.io.Raw
+         Continuous data.
+     description : dict | pandas.Series | None
+         Holds additional description about the continuous signal / subject.
+     target_name : str | tuple | None
+         Name(s) of the index in `description` that should be used to provide
+         the target (e.g., to be used in a prediction task later on).
+     transform : callable | None
+         On-the-fly transform applied to the example before it is returned.
+     """
+
+     def __init__(
+         self,
+         raw: mne.io.BaseRaw,
+         description: dict | pd.Series | None = None,
+         target_name: str | tuple[str, ...] | None = None,
+         transform: Callable | None = None,
+     ):
+         super().__init__(description, transform)
+         self.raw = raw
+
+         # save target name for load/save later
+         self.target_name = self._target_name(target_name)
+         self.raw_preproc_kwargs: list[dict[str, Any]] = []
+
+     def __getitem__(self, index):
+         X = self.raw[:, index][0]
+         y = None
+         if self.target_name is not None:
+             y = self.description[self.target_name]
+             if isinstance(y, pd.Series):
+                 y = y.to_list()
+         if self.transform is not None:
+             X = self.transform(X)
+         return X, y
+
+     def __len__(self):
+         return len(self.raw)
+
+     def _target_name(self, target_name):
+         if target_name is not None and not isinstance(target_name, (str, tuple, list)):
+             raise ValueError("target_name has to be None, str, tuple or list")
+         if target_name is None:
+             return target_name
+         else:
+             # convert tuple of names or single name to list
+             if isinstance(target_name, tuple):
+                 target_name = list(target_name)
+             elif not isinstance(target_name, list):
+                 assert isinstance(target_name, str)
+                 target_name = [target_name]
+             assert isinstance(target_name, list)
+             # check if target name(s) can be read from description
+             for name in target_name:
+                 if self.description is None or name not in self.description:
+                     warnings.warn(
+                         f"'{name}' not in description. '__getitem__' "
+                         f"will fail unless an appropriate target is "
+                         f"added to description.",
+                         UserWarning,
+                     )
+             # return a list of str if there are multiple targets, a str otherwise
+             return target_name if len(target_name) > 1 else target_name[0]
+
+
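+ # Illustrative sketch: wrapping an MNE Raw in a RawDataset. The file path is
+ # a placeholder; any preloaded mne.io.Raw works.
+ #
+ #     raw = mne.io.read_raw_fif("sub-01_task-rest_raw.fif", preload=True)
+ #     ds = RawDataset(raw, description={"subject": 1}, target_name="subject")
+ #     X, y = ds[0]  # one sample across all channels, plus the target
+
+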
+ @deprecated(
+     "The BaseDataset class is deprecated. "
+     "If you want to instantiate a dataset containing raws, use RawDataset instead. "
+     "If you want to type a Braindecode dataset "
+     "(i.e. RawDataset | EEGWindowsDataset | WindowsDataset), "
+     "use the RecordDataset class instead."
+ )
+ @register_dataset
+ class BaseDataset(RawDataset):
+     pass
+
+
+ @register_dataset
+ class EEGWindowsDataset(RecordDataset):
+     """Returns windows from an mne.io.Raw object, their window indices, and a target.
+
+     Dataset which serves windows cut from an mne.io.Raw object along with
+     their target and additional information. The `metadata` DataFrame must
+     contain a column called `target`, which will be used to return the target
+     that corresponds to a window. Additional columns `i_window_in_trial`,
+     `i_start_in_trial`, `i_stop_in_trial` are also required to serve
+     information about the windowing (e.g., useful for cropped training).
+     See `braindecode.preprocessing.windowers` to directly create an
+     `EEGWindowsDataset` from a `RawDataset` object.
+
+     Parameters
+     ----------
+     raw : mne.io.Raw
+         Continuous signal from which windows are extracted (see
+         `braindecode.preprocessing.windowers`).
+     metadata : pandas.DataFrame
+         DataFrame with the crop indices `i_window_in_trial`,
+         `i_start_in_trial` and `i_stop_in_trial`, as well as the `target`
+         column.
+     description : dict | pandas.Series | None
+         Holds additional info about the windows.
+     transform : callable | None
+         On-the-fly transform applied to a window before it is returned.
+     targets_from : str
+         Defines whether targets will be extracted from metadata or from `misc`
+         channels (time series targets). It can be `metadata` (default) or
+         `channels`.
+     last_target_only : bool
+         If targets are obtained from misc channels, whether only the last
+         target of the (compute) window is returned instead of all targets in
+         the window.
+     """
+
+     def __init__(
+         self,
+         raw: mne.io.BaseRaw,
+         metadata: pd.DataFrame,
+         description: dict | pd.Series | None = None,
+         transform: Callable | None = None,
+         targets_from: str = "metadata",
+         last_target_only: bool = True,
+     ):
+         super().__init__(description, transform)
+         self.raw = raw
+         self.metadata = metadata
+
+         self.last_target_only = last_target_only
+         if targets_from not in ("metadata", "channels"):
+             raise ValueError("Wrong value for parameter `targets_from`.")
+         self.targets_from = targets_from
+         self.crop_inds = metadata.loc[
+             :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+         ].to_numpy()
+         if self.targets_from == "metadata":
+             self.y = metadata.loc[:, "target"].to_list()
+         self.raw_preproc_kwargs: list[dict[str, Any]] = []
+
+     def __getitem__(self, index: int):
+         """Get a window and its target.
+
+         Parameters
+         ----------
+         index : int
+             Index of the window (and target) to return.
+
+         Returns
+         -------
+         np.ndarray
+             Window of shape (n_channels, n_times).
+         int
+             Target for the window.
+         np.ndarray
+             Crop indices.
+         """
+         # necessary to cast as list to get a list of three tensors from a
+         # batch, otherwise we get a single 2d-tensor...
+         crop_inds = self.crop_inds[index].tolist()
+
+         i_window_in_trial, i_start, i_stop = crop_inds
+         X = self.raw._getitem((slice(None), slice(i_start, i_stop)), return_times=False)
+         X = X.astype("float32")
+         # ensure we don't give the user the option
+         # to accidentally modify the underlying array
+         X = X.copy()
+         if self.transform is not None:
+             X = self.transform(X)
+         if self.targets_from == "metadata":
+             y = self.y[index]
+         else:
+             misc_mask = np.array(self.raw.get_channel_types()) == "misc"
+             if self.last_target_only:
+                 y = X[misc_mask, -1]
+             else:
+                 y = X[misc_mask, :]
+             # ensure we don't give the user the option
+             # to accidentally modify the underlying array
+             y = y.copy()
+             # remove the target channels from the window
+             X = X[~misc_mask, :]
+         return X, y, crop_inds
+
+     def __len__(self):
+         return len(self.crop_inds)
+
+
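+ # Illustrative sketch of the (X, y, crop_inds) interface above, assuming
+ # `windows_ds` was produced by a windower such as
+ # braindecode.preprocessing.create_windows_from_events:
+ #
+ #     X, y, (i_window_in_trial, i_start, i_stop) = windows_ds[0]
+ #     print(X.shape)  # (n_channels, n_times)
+
+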
+ @register_dataset
+ class WindowsDataset(RecordDataset):
+     """Returns windows from an mne.Epochs object along with a target.
+
+     Dataset which serves windows from an mne.Epochs object along with their
+     target and additional information. The `metadata` attribute of the Epochs
+     object must contain a column called `target`, which will be used to return
+     the target that corresponds to a window. Additional columns
+     `i_window_in_trial`, `i_start_in_trial`, `i_stop_in_trial` are also
+     required to serve information about the windowing (e.g., useful for
+     cropped training).
+     See `braindecode.preprocessing.windowers` to directly create a
+     `WindowsDataset` from a ``RawDataset`` object.
+
+     Parameters
+     ----------
+     windows : mne.Epochs
+         Windows obtained through the application of a windower to a RawDataset
+         (see `braindecode.preprocessing.windowers`).
+     description : dict | pandas.Series | None
+         Holds additional info about the windows.
+     transform : callable | None
+         On-the-fly transform applied to a window before it is returned.
+     targets_from : str
+         Defines whether targets will be extracted from mne.Epochs metadata or
+         from mne.Epochs `misc` channels (time series targets). It can be
+         `metadata` (default) or `channels`.
+     last_target_only : bool
+         If targets are obtained from misc channels, whether only the last
+         target of the window is returned instead of all targets in the window.
+     """
+
+     def __init__(
+         self,
+         windows: mne.BaseEpochs,
+         description: dict | pd.Series | None = None,
+         transform: Callable | None = None,
+         targets_from: str = "metadata",
+         last_target_only: bool = True,
+     ):
+         super().__init__(description, transform)
+         self.windows = windows
+         self.last_target_only = last_target_only
+         if targets_from not in ("metadata", "channels"):
+             raise ValueError("Wrong value for parameter `targets_from`.")
+         self.targets_from = targets_from
+
+         metadata = self.windows.metadata
+         assert metadata is not None, "WindowsDataset requires windows with metadata."
+         self.crop_inds = metadata.loc[
+             :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+         ].to_numpy()
+         if self.targets_from == "metadata":
+             self.y = metadata.loc[:, "target"].to_list()
+         self.raw_preproc_kwargs: list[dict[str, Any]] = []
+         self.window_preproc_kwargs: list[dict[str, Any]] = []
+
+     def __getitem__(self, index: int):
+         """Get a window and its target.
+
+         Parameters
+         ----------
+         index : int
+             Index of the window (and target) to return.
+
+         Returns
+         -------
+         np.ndarray
+             Window of shape (n_channels, n_times).
+         int
+             Target for the window.
+         np.ndarray
+             Crop indices.
+         """
+         X = self.windows.get_data(item=index)[0].astype("float32")
+         if self.transform is not None:
+             X = self.transform(X)
+         if self.targets_from == "metadata":
+             y = self.y[index]
+         else:
+             misc_mask = np.array(self.windows.get_channel_types()) == "misc"
+             if self.last_target_only:
+                 y = X[misc_mask, -1]
+             else:
+                 y = X[misc_mask, :]
+             # remove the target channels from the window
+             X = X[~misc_mask, :]
+         # necessary to cast as list to get a list of three tensors from a
+         # batch, otherwise we get a single 2d-tensor...
+         crop_inds = self.crop_inds[index].tolist()
+         return X, y, crop_inds
+
+     def __len__(self) -> int:
+         return len(self.windows.events)
+
+
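+ # Illustrative sketch: windowing a concatenation of raw recordings. The
+ # helper name comes from braindecode.preprocessing; the exact arguments
+ # depend on your data.
+ #
+ #     from braindecode.preprocessing import create_windows_from_events
+ #
+ #     windows_concat_ds = create_windows_from_events(
+ #         concat_ds, trial_start_offset_samples=0, trial_stop_offset_samples=0
+ #     )
+ #     X, y, crop_inds = windows_concat_ds[0]
+
+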
+ @register_dataset
+ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
+     """A base class for concatenated datasets.
+
+     Holds either mne.io.Raw or mne.Epochs objects in self.datasets and has a
+     pandas DataFrame with additional description.
+
+     Includes Hugging Face Hub integration via HubDatasetMixin for
+     uploading and downloading datasets.
+
+     Parameters
+     ----------
+     list_of_ds : list
+         List of RecordDataset (or BaseConcatDataset) objects.
+     target_transform : callable | None
+         Optional function to call on targets before returning them.
+     """
+
+     datasets: list[T]
+
+     def __init__(
+         self,
+         list_of_ds: list[T | BaseConcatDataset[T]],
+         target_transform: Callable | None = None,
+     ):
+         # if we get a list of BaseConcatDataset, get all the individual datasets
+         flattened_list_of_ds: list[T] = []
+         for ds in list_of_ds:
+             if isinstance(ds, BaseConcatDataset):
+                 flattened_list_of_ds.extend(ds.datasets)
+             else:
+                 flattened_list_of_ds.append(ds)
+         super().__init__(flattened_list_of_ds)
+
+         self.target_transform = target_transform
+
+     def _get_sequence(self, indices):
+         X, y = list(), list()
+         for ind in indices:
+             out_i = super().__getitem__(ind)
+             X.append(out_i[0])
+             y.append(out_i[1])
+
+         X = np.stack(X, axis=0)
+         y = np.array(y)
+
+         return X, y
+
+     def __getitem__(self, idx: int | list):
+         """
+         Parameters
+         ----------
+         idx : int | list
+             Index of the window and target to return. If provided as a list
+             of ints, multiple windows and targets will be extracted and
+             concatenated. The target output can be modified on the fly by the
+             ``target_transform`` parameter.
+         """
+         if isinstance(idx, Iterable):  # Sample multiple windows
+             item = self._get_sequence(idx)
+         else:
+             item = super().__getitem__(idx)
+         if self.target_transform is not None:
+             item = item[:1] + (self.target_transform(item[1]),) + item[2:]
+         return item
+
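+     # Illustrative sketch of the indexing behavior above, assuming the concat
+     # holds windows datasets: an int returns one example, while a list of
+     # ints stacks the windows along a new first axis.
+     #
+     #     X, y, crop_inds = concat_ds[0]
+     #     X_seq, y_seq = concat_ds[[0, 1, 2]]  # X_seq: (3, n_channels, n_times)
+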
+     @no_type_check  # TODO, it's a mess
+     def split(
+         self,
+         by: str | list[int] | list[list[int]] | dict[str, list[int]] | None = None,
+         property: str | None = None,
+         split_ids: list[int] | list[list[int]] | dict[str, list[int]] | None = None,
+     ) -> dict[str, BaseConcatDataset]:
+         """Split the dataset based on information listed in its description.
+
+         The split can be based on a column of the description DataFrame or on
+         explicit dataset indices.
+
+         Parameters
+         ----------
+         by : str | list | dict
+             If ``by`` is a string, splitting is performed based on the
+             description DataFrame column with this name.
+             If ``by`` is a (list of) list of integers, the position in the
+             first list corresponds to the split id and the integers to the
+             datapoints of that split.
+             If ``by`` is a dict, each key will be used in the returned splits
+             dict and each value should be a list of int.
+         property : str
+             Some property which is listed in the info DataFrame.
+         split_ids : list | dict
+             List of indices to be combined in a subset.
+             It can be a list of int or a list of list of int.
+
+         Returns
+         -------
+         splits : dict
+             A dictionary with the name of the split (a string) as key and the
+             dataset as value.
+         """
+         args_not_none = [by is not None, property is not None, split_ids is not None]
+         if sum(args_not_none) != 1:
+             raise ValueError("Splitting requires exactly one argument.")
+
+         if property is not None or split_ids is not None:
+             warnings.warn(
+                 "Keyword arguments `property` and `split_ids` "
+                 "are deprecated and will be removed in the future. "
+                 "Use `by` instead.",
+                 DeprecationWarning,
+             )
+             by = property if property is not None else split_ids
+         if isinstance(by, str):
+             split_ids = {
+                 k: list(v) for k, v in self.description.groupby(by).groups.items()
+             }
+         elif isinstance(by, dict):
+             split_ids = by
+         else:
+             # assume list(int)
+             if not isinstance(by[0], list):
+                 by = [by]
+             # assume list(list(int))
+             split_ids = {split_i: split for split_i, split in enumerate(by)}
+
+         return {
+             str(split_name): BaseConcatDataset(
+                 [self.datasets[ds_ind] for ds_ind in ds_inds],
+                 target_transform=self.target_transform,
+             )
+             for split_name, ds_inds in split_ids.items()
+         }
+
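+     # Illustrative sketch of the three `by` forms accepted above:
+     #
+     #     splits = concat_ds.split("subject")            # by description column
+     #     splits = concat_ds.split([[0, 1], [2, 3]])     # by positional indices
+     #     splits = concat_ds.split({"train": [0, 1], "valid": [2]})
+     #     train_ds = splits["train"]
+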
+     def get_metadata(self) -> pd.DataFrame:
+         """Concatenate the metadata and description of the wrapped Epochs.
+
+         Returns
+         -------
+         metadata : pd.DataFrame
+             DataFrame containing as many rows as there are windows in the
+             BaseConcatDataset, with the metadata and description information
+             for each window.
+         """
+         if not all(
+             [
+                 isinstance(ds, (WindowsDataset, EEGWindowsDataset))
+                 for ds in self.datasets
+             ]
+         ):
+             raise TypeError(
+                 "Metadata dataframe can only be computed when all datasets "
+                 "are WindowsDataset or EEGWindowsDataset."
+             )
+
+         all_dfs = list()
+         for ds in self.datasets:
+             if hasattr(ds, "windows"):
+                 df = ds.windows.metadata
+             else:
+                 df = ds.metadata
+             for k, v in ds.description.items():
+                 df[k] = v
+             all_dfs.append(df)
+
+         return pd.concat(all_dfs)
+
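+     # Illustrative sketch: one row per window, combining the windowing
+     # metadata with the per-recording description columns (here assuming the
+     # description contains a "subject" entry):
+     #
+     #     md = windows_concat_ds.get_metadata()
+     #     print(md[["i_window_in_trial", "target", "subject"]].head())
+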
+     @property
+     def transform(self):
+         return [ds.transform for ds in self.datasets]
+
+     @transform.setter
+     def transform(self, fn):
+         for i in range(len(self.datasets)):
+             self.datasets[i].transform = fn
+
+     @property
+     def target_transform(self):
+         return self._target_transform
+
+     @target_transform.setter
+     def target_transform(self, fn):
+         if not (callable(fn) or fn is None):
+             raise TypeError("target_transform must be a callable.")
+         self._target_transform = fn
+
+     def _outdated_save(self, path, overwrite=False):
+         """Copy of the old saving function, which had inconsistent
+         functionality for BaseDataset and WindowsDataset. It only exists to
+         ensure backwards compatibility by still being able to run the old
+         tests.
+
+         Save dataset to files.
+
+         Parameters
+         ----------
+         path : str
+             Directory to which .fif / -epo.fif and .json files are stored.
+         overwrite : bool
+             Whether to delete old files (.json, .fif, -epo.fif) in specified
+             directory prior to saving.
+         """
+         warnings.warn(
+             "This function only exists for backwards compatibility "
+             "purposes. DO NOT USE!",
+             UserWarning,
+         )
+         if isinstance(self.datasets[0], EEGWindowsDataset):
+             raise NotImplementedError(
+                 "Outdated save not implemented for new window datasets."
+             )
+         if len(self.datasets) == 0:
+             raise ValueError("Expect at least one dataset")
+         if not (
+             hasattr(self.datasets[0], "raw") or hasattr(self.datasets[0], "windows")
+         ):
+             raise ValueError("dataset should have either raw or windows attribute")
+         file_name_templates = ["{}-raw.fif", "{}-epo.fif"]
+         description_file_name = os.path.join(path, "description.json")
+         target_file_name = os.path.join(path, "target_name.json")
+         if not overwrite:
+             from braindecode.datautil.serialization import (  # Import here to avoid circular import
+                 _check_save_dir_empty,
+             )
+
+             _check_save_dir_empty(path)
+         else:
+             for file_name_template in file_name_templates:
+                 file_names = glob(
+                     os.path.join(path, f"*{file_name_template.lstrip('{}')}")
+                 )
+                 _ = [os.remove(f) for f in file_names]
+             if os.path.isfile(target_file_name):
+                 os.remove(target_file_name)
+             if os.path.isfile(description_file_name):
+                 os.remove(description_file_name)
+             for kwarg_name in [
+                 "raw_preproc_kwargs",
+                 "window_kwargs",
+                 "window_preproc_kwargs",
+             ]:
+                 kwarg_path = os.path.join(path, ".".join([kwarg_name, "json"]))
+                 if os.path.exists(kwarg_path):
+                     os.remove(kwarg_path)
+
+         is_raw = hasattr(self.datasets[0], "raw")
+
+         if is_raw:
+             file_name_template = file_name_templates[0]
+         else:
+             file_name_template = file_name_templates[1]
+
+         for i_ds, ds in enumerate(self.datasets):
+             full_file_path = os.path.join(path, file_name_template.format(i_ds))
+             if is_raw:
+                 ds.raw.save(full_file_path, overwrite=overwrite)
+             else:
+                 ds.windows.save(full_file_path, overwrite=overwrite)
+
+         self.description.to_json(description_file_name)
+         for kwarg_name in [
+             "raw_preproc_kwargs",
+             "window_kwargs",
+             "window_preproc_kwargs",
+         ]:
+             if hasattr(self, kwarg_name):
+                 kwargs_path = os.path.join(path, ".".join([kwarg_name, "json"]))
+                 kwargs = getattr(self, kwarg_name)
+                 if kwargs is not None:
+                     with open(kwargs_path, "w") as f:
+                         json.dump(kwargs, f)
+
+     @property
+     def description(self) -> pd.DataFrame:
+         df = pd.DataFrame([ds.description for ds in self.datasets])
+         df.reset_index(inplace=True, drop=True)
+         return df
+
+     def set_description(
+         self, description: dict | pd.DataFrame, overwrite: bool = False
+     ):
+         """Update (add or overwrite) the dataset description.
+
+         Parameters
+         ----------
+         description : dict | pd.DataFrame
+             Description in the form key: value where the length of the value
+             has to match the number of datasets.
+         overwrite : bool
+             Has to be True if a key in description already exists in the
+             dataset description.
+         """
+         description = pd.DataFrame(description)
+         for key, value in description.items():
+             for ds, value_ in zip(self.datasets, value):
+                 ds.set_description({key: value_}, overwrite=overwrite)
+
+     def save(self, path: str, overwrite: bool = False, offset: int = 0):
+         """Save datasets to files by creating one subdirectory for each dataset::
+
+             path/
+                 0/
+                     0-raw.fif | 0-epo.fif
+                     description.json
+                     raw_preproc_kwargs.json (if raws were preprocessed)
+                     window_kwargs.json (if this is a windowed dataset)
+                     window_preproc_kwargs.json (if windows were preprocessed)
+                     target_name.json (if target_name is not None and dataset is raw)
+                 1/
+                     1-raw.fif | 1-epo.fif
+                     description.json
+                     raw_preproc_kwargs.json (if raws were preprocessed)
+                     window_kwargs.json (if this is a windowed dataset)
+                     window_preproc_kwargs.json (if windows were preprocessed)
+                     target_name.json (if target_name is not None and dataset is raw)
+
+         Parameters
+         ----------
+         path : str
+             Directory in which subdirectories are created to store the
+             -raw.fif | -epo.fif and .json files.
+         overwrite : bool
+             Whether to delete old subdirectories that will be saved to in
+             this call.
+         offset : int
+             If provided, the integer is added to the id of the dataset in the
+             concat. This is useful in the setting of very large datasets,
+             where one dataset has to be processed and saved at a time to
+             account for its original position.
+         """
+         if len(self.datasets) == 0:
+             raise ValueError("Expect at least one dataset")
+         if not (
+             hasattr(self.datasets[0], "raw") or hasattr(self.datasets[0], "windows")
+         ):
+             raise ValueError("dataset should have either raw or windows attribute")
+
+         # Create path if it doesn't exist
+         os.makedirs(path, exist_ok=True)
+
+         path_contents = os.listdir(path)
+         n_sub_dirs = len(
+             [e for e in path_contents if os.path.isdir(os.path.join(path, e))]
+         )
+         for i_ds, ds in enumerate(self.datasets):
+             # remove subdirectory from list of untouched files / subdirectories
+             if str(i_ds + offset) in path_contents:
+                 path_contents.remove(str(i_ds + offset))
+             # save_dir/i_ds/
+             sub_dir = os.path.join(path, str(i_ds + offset))
+             if os.path.exists(sub_dir):
+                 if overwrite:
+                     shutil.rmtree(sub_dir)
+                 else:
+                     raise FileExistsError(
+                         f"Subdirectory {sub_dir} already exists. Please select"
+                         f" a different directory, set overwrite=True, or "
+                         f"resolve manually."
+                     )
+             # save_dir/{i_ds+offset}/
+             os.makedirs(sub_dir)
+             # save_dir/{i_ds+offset}/{i_ds+offset}-{raw_or_epo}.fif
+             self._save_signals(sub_dir, ds, i_ds, offset)
+             # save_dir/{i_ds+offset}/metadata_df.pkl
+             self._save_metadata(sub_dir, ds)
+             # save_dir/{i_ds+offset}/description.json
+             self._save_description(sub_dir, ds.description)
+             # save_dir/{i_ds+offset}/raw_preproc_kwargs.json
+             # save_dir/{i_ds+offset}/window_kwargs.json
+             # save_dir/{i_ds+offset}/window_preproc_kwargs.json
+             self._save_kwargs(sub_dir, ds)
+             # save_dir/{i_ds+offset}/target_name.json
+             self._save_target_name(sub_dir, ds)
+             if overwrite:
+                 # the following will be True for all datasets preprocessed and
+                 # stored in parallel with braindecode.preprocessing.preprocess
+                 if i_ds + 1 + offset < n_sub_dirs:
+                     warnings.warn(
+                         f"The number of saved datasets ({i_ds + 1 + offset}) "
+                         f"does not match the number of existing "
+                         f"subdirectories ({n_sub_dirs}). You may now "
+                         f"encounter a mix of differently preprocessed "
+                         f"datasets!",
+                         UserWarning,
+                     )
+         # if path contains files or directories that were not touched, raise
+         # a warning
+         if path_contents:
+             warnings.warn(
+                 f"Chosen directory {path} contains other "
+                 f"subdirectories or files {path_contents}."
+             )
+
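+     # Illustrative round trip, assuming the loader in
+     # braindecode.datautil.serialization keeps its public name:
+     #
+     #     concat_ds.save("./saved_ds", overwrite=True)
+     #
+     #     from braindecode.datautil import load_concat_dataset
+     #     concat_ds2 = load_concat_dataset("./saved_ds", preload=False)
+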
+     @staticmethod
+     def _save_signals(sub_dir, ds, i_ds, offset):
+         raw_or_epo = "raw" if hasattr(ds, "raw") else "epo"
+         fif_file_name = f"{i_ds + offset}-{raw_or_epo}.fif"
+         fif_file_path = os.path.join(sub_dir, fif_file_name)
+         raw_or_windows = "raw" if raw_or_epo == "raw" else "windows"
+
+         # The following appears to be necessary to avoid a CI failure when
+         # preprocessing WindowsDatasets with serialization enabled. The failure
+         # comes from `mne.epochs._check_consistency`, which ensures the Epochs
+         # object's `times` attribute is not writeable.
+         getattr(ds, raw_or_windows).times.flags["WRITEABLE"] = False
+
+         getattr(ds, raw_or_windows).save(fif_file_path)
+
+     @staticmethod
+     def _save_metadata(sub_dir, ds):
+         if hasattr(ds, "metadata"):
+             metadata_file_path = os.path.join(sub_dir, "metadata_df.pkl")
+             ds.metadata.to_pickle(metadata_file_path)
+
+     @staticmethod
+     def _save_description(sub_dir, description):
+         description_file_path = os.path.join(sub_dir, "description.json")
+         description.to_json(description_file_path, default_handler=str)
+
+     @staticmethod
+     def _save_kwargs(sub_dir, ds):
+         for kwargs_name in [
+             "raw_preproc_kwargs",
+             "window_kwargs",
+             "window_preproc_kwargs",
+         ]:
+             if hasattr(ds, kwargs_name):
+                 kwargs_file_name = ".".join([kwargs_name, "json"])
+                 kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
+                 kwargs = getattr(ds, kwargs_name)
+                 if kwargs is not None:
+                     with open(kwargs_file_path, "w") as f:
+                         json.dump(kwargs, f, indent=2)
+
+     @staticmethod
+     def _save_target_name(sub_dir, ds):
+         if hasattr(ds, "target_name"):
+             target_file_path = os.path.join(sub_dir, "target_name.json")
+             with open(target_file_path, "w") as f:
+                 json.dump({"target_name": ds.target_name}, f)