braindecode 1.3.0.dev175955015__py3-none-any.whl → 1.3.0.dev176481332__py3-none-any.whl

This diff shows the changes between two package versions as published to a supported public registry. It is provided for informational purposes only.

Potentially problematic release.
Files changed (65)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/datasets/__init__.py +10 -2
  3. braindecode/datasets/base.py +115 -151
  4. braindecode/datasets/bcicomp.py +4 -4
  5. braindecode/datasets/bids.py +3 -3
  6. braindecode/datasets/experimental.py +2 -2
  7. braindecode/datasets/mne.py +3 -5
  8. braindecode/datasets/moabb.py +2 -2
  9. braindecode/datasets/nmt.py +2 -2
  10. braindecode/datasets/sleep_physio_challe_18.py +2 -2
  11. braindecode/datasets/sleep_physionet.py +2 -2
  12. braindecode/datasets/tuh.py +2 -2
  13. braindecode/datasets/xy.py +2 -2
  14. braindecode/datautil/serialization.py +7 -7
  15. braindecode/functional/functions.py +6 -2
  16. braindecode/functional/initialization.py +2 -3
  17. braindecode/models/__init__.py +2 -0
  18. braindecode/models/atcnet.py +26 -27
  19. braindecode/models/attentionbasenet.py +37 -32
  20. braindecode/models/attn_sleep.py +2 -0
  21. braindecode/models/base.py +2 -2
  22. braindecode/models/bendr.py +469 -0
  23. braindecode/models/biot.py +2 -0
  24. braindecode/models/contrawr.py +2 -0
  25. braindecode/models/ctnet.py +8 -3
  26. braindecode/models/deepsleepnet.py +28 -19
  27. braindecode/models/eegconformer.py +2 -2
  28. braindecode/models/eeginception_erp.py +31 -25
  29. braindecode/models/eegitnet.py +2 -0
  30. braindecode/models/eegminer.py +2 -0
  31. braindecode/models/eegnet.py +1 -1
  32. braindecode/models/eegtcnet.py +2 -0
  33. braindecode/models/fbcnet.py +2 -0
  34. braindecode/models/fblightconvnet.py +2 -0
  35. braindecode/models/fbmsnet.py +2 -0
  36. braindecode/models/ifnet.py +2 -0
  37. braindecode/models/labram.py +33 -26
  38. braindecode/models/msvtnet.py +2 -0
  39. braindecode/models/patchedtransformer.py +1 -1
  40. braindecode/models/signal_jepa.py +8 -0
  41. braindecode/models/sinc_shallow.py +12 -9
  42. braindecode/models/sstdpn.py +11 -11
  43. braindecode/models/summary.csv +1 -0
  44. braindecode/models/syncnet.py +2 -0
  45. braindecode/models/tcn.py +2 -0
  46. braindecode/models/usleep.py +26 -21
  47. braindecode/models/util.py +1 -0
  48. braindecode/modules/attention.py +10 -10
  49. braindecode/modules/blocks.py +3 -3
  50. braindecode/modules/filter.py +2 -3
  51. braindecode/modules/layers.py +18 -17
  52. braindecode/preprocessing/__init__.py +24 -0
  53. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  54. braindecode/preprocessing/preprocess.py +12 -12
  55. braindecode/preprocessing/util.py +166 -0
  56. braindecode/preprocessing/windowers.py +26 -20
  57. braindecode/samplers/base.py +8 -8
  58. braindecode/version.py +1 -1
  59. {braindecode-1.3.0.dev175955015.dist-info → braindecode-1.3.0.dev176481332.dist-info}/METADATA +4 -2
  60. braindecode-1.3.0.dev176481332.dist-info/RECORD +106 -0
  61. braindecode-1.3.0.dev175955015.dist-info/RECORD +0 -103
  62. {braindecode-1.3.0.dev175955015.dist-info → braindecode-1.3.0.dev176481332.dist-info}/WHEEL +0 -0
  63. {braindecode-1.3.0.dev175955015.dist-info → braindecode-1.3.0.dev176481332.dist-info}/licenses/LICENSE.txt +0 -0
  64. {braindecode-1.3.0.dev175955015.dist-info → braindecode-1.3.0.dev176481332.dist-info}/licenses/NOTICE.txt +0 -0
  65. {braindecode-1.3.0.dev175955015.dist-info → braindecode-1.3.0.dev176481332.dist-info}/top_level.txt +0 -0
braindecode/augmentation/base.py
@@ -189,7 +189,7 @@ class AugmentedDataLoader(DataLoader):
 
     Parameters
     ----------
-    dataset : BaseDataset
+    dataset : RecordDataset
         The dataset containing the signals.
     transforms : list | Transform, optional
         Transform or sequence of Transform to be applied to each batch.
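As a usage sketch (illustration, not from the diff): `AugmentedDataLoader` and `SignFlip` are existing braindecode names, but the exact keyword arguments and the `windows_ds` variable are assumptions here.

    from braindecode.augmentation import AugmentedDataLoader, SignFlip

    # `windows_ds` stands for any windowed RecordDataset created earlier.
    flip = SignFlip(probability=0.5)
    loader = AugmentedDataLoader(windows_ds, transforms=[flip], batch_size=64)
    for X, y, _ in loader:  # transforms are applied batch-wise
        break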
braindecode/datasets/__init__.py
@@ -2,7 +2,13 @@
 Loader code for some datasets.
 """
 
-from .base import BaseConcatDataset, BaseDataset, WindowsDataset
+from .base import (
+    BaseConcatDataset,
+    EEGWindowsDataset,
+    RawDataset,
+    RecordDataset,
+    WindowsDataset,
+)
 from .bcicomp import BCICompetitionIVDataset4
 from .bids import BIDSDataset, BIDSEpochsDataset
 from .mne import create_from_mne_epochs, create_from_mne_raw
@@ -15,7 +21,9 @@ from .xy import create_from_X_y
 
 __all__ = [
     "WindowsDataset",
-    "BaseDataset",
+    "EEGWindowsDataset",
+    "RecordDataset",
+    "RawDataset",
     "BaseConcatDataset",
     "BIDSDataset",
     "BIDSEpochsDataset",
braindecode/datasets/base.py
@@ -16,14 +16,17 @@ import json
 import os
 import shutil
 import warnings
+from abc import abstractmethod
 from collections.abc import Callable
 from glob import glob
-from typing import Iterable, no_type_check
+from typing import Generic, Iterable, no_type_check
 
 import mne.io
 import numpy as np
 import pandas as pd
+from mne.utils.docs import deprecated
 from torch.utils.data import ConcatDataset, Dataset
+from typing_extensions import TypeVar
 
 
 def _create_description(description) -> pd.Series:
@@ -37,7 +40,64 @@ def _create_description(description) -> pd.Series:
     return description
 
 
-class BaseDataset(Dataset):
+class RecordDataset(Dataset[tuple[np.ndarray, int | str, tuple[int, int, int]]]):
+    def __init__(
+        self,
+        description: dict | pd.Series | None = None,
+        transform: Callable | None = None,
+    ):
+        self._description = _create_description(description)
+        self.transform = transform
+
+    @abstractmethod
+    def __len__(self) -> int:
+        pass
+
+    @property
+    def description(self) -> pd.Series:
+        return self._description
+
+    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
+        """Update (add or overwrite) the dataset description.
+
+        Parameters
+        ----------
+        description: dict | pd.Series
+            Description in the form key: value.
+        overwrite: bool
+            Has to be True if a key in description already exists in the
+            dataset description.
+        """
+        description = _create_description(description)
+        if self.description is None:
+            self._description = description
+        else:
+            for key, value in description.items():
+                # if the key is already in the existing description, drop it
+                if key in self._description:
+                    assert overwrite, (
+                        f"'{key}' already in description. Please "
+                        f"rename or set overwrite to True."
+                    )
+                    self._description.pop(key)
+            self._description = pd.concat([self.description, description])
+
+    @property
+    def transform(self) -> Callable | None:
+        return self._transform
+
+    @transform.setter
+    def transform(self, value: Callable | None):
+        if value is not None and not callable(value):
+            raise ValueError("Transform needs to be a callable.")
+        self._transform = value
+
+
+# Type of the datasets contained in BaseConcatDataset
+T = TypeVar("T", bound=RecordDataset)
+
+
+class RawDataset(RecordDataset):
     """Returns samples from an mne.io.Raw object along with a target.
 
     Dataset which serves samples from an mne.io.Raw object along with a target.
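With the description/transform machinery hoisted into `RecordDataset`, a subclass now only needs `__len__` and `__getitem__`. A hypothetical toy subclass (illustration, not from the diff):

    import numpy as np
    from braindecode.datasets import RecordDataset

    class ToyRecordDataset(RecordDataset):
        # Hypothetical subclass serving random windows; description,
        # set_description and transform are inherited from RecordDataset.
        def __init__(self, n_windows: int = 10, description=None):
            super().__init__(description=description)
            self.n_windows = n_windows

        def __len__(self) -> int:
            return self.n_windows

        def __getitem__(self, index: int):
            X = np.random.randn(2, 100).astype("float32")
            if self.transform is not None:
                X = self.transform(X)
            # (signal, target, (i_window_in_trial, i_start, i_stop))
            return X, 0, (0, index * 100, (index + 1) * 100)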
@@ -64,9 +124,8 @@ class BaseDataset(Dataset):
         target_name: str | tuple[str, ...] | None = None,
         transform: Callable | None = None,
     ):
+        super().__init__(description, transform)
         self.raw = raw
-        self._description = _create_description(description)
-        self.transform = transform
 
         # save target name for load/save later
         self.target_name = self._target_name(target_name)
@@ -85,45 +144,6 @@ class BaseDataset(Dataset):
     def __len__(self):
         return len(self.raw)
 
-    @property
-    def transform(self):
-        return self._transform
-
-    @transform.setter
-    def transform(self, value):
-        if value is not None and not callable(value):
-            raise ValueError("Transform needs to be a callable.")
-        self._transform = value
-
-    @property
-    def description(self) -> pd.Series:
-        return self._description
-
-    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
-        """Update (add or overwrite) the dataset description.
-
-        Parameters
-        ----------
-        description: dict | pd.Series
-            Description in the form key: value.
-        overwrite: bool
-            Has to be True if a key in description already exists in the
-            dataset description.
-        """
-        description = _create_description(description)
-        for key, value in description.items():
-            # if the key is already in the existing description, drop it
-            if self._description is not None and key in self._description:
-                assert overwrite, (
-                    f"'{key}' already in description. Please "
-                    f"rename or set overwrite to True."
-                )
-                self._description.pop(key)
-        if self._description is None:
-            self._description = description
-        else:
-            self._description = pd.concat([self.description, description])
-
     def _target_name(self, target_name):
         if target_name is not None and not isinstance(target_name, (str, tuple, list)):
             raise ValueError("target_name has to be None, str, tuple or list")
@@ -150,7 +170,17 @@ class BaseDataset(Dataset):
         return target_name if len(target_name) > 1 else target_name[0]
 
 
-class EEGWindowsDataset(BaseDataset):
+@deprecated(
+    "The BaseDataset class is deprecated. "
+    "If you want to instantiate a dataset containing raws, use RawDataset instead. "
+    "If you want to type a Braindecode dataset (i.e. RawDataset|EEGWindowsDataset|WindowsDataset), "
+    "use the RecordDataset class instead."
+)
+class BaseDataset(RawDataset):
+    pass
+
+
+class EEGWindowsDataset(RecordDataset):
     """Returns windows from an mne.Raw object, its window indices, along with a target.
 
     Dataset which serves windows from an mne.Epochs object along with their
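The shim keeps old code importable while steering users to the new names. Assuming mne's `@deprecated` warns on instantiation (its default behavior), something like this sketch should hold (illustration, not from the diff; `raw` stands for an existing mne.io.Raw):

    import warnings
    from braindecode.datasets import BaseDataset  # deprecated alias of RawDataset

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        BaseDataset(raw)  # `raw` is assumed to exist
    assert any("deprecated" in str(w.message) for w in caught)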
@@ -161,12 +191,12 @@ class EEGWindowsDataset(BaseDataset):
     required to serve information about the windowing (e.g., useful for cropped
     training).
     See `braindecode.datautil.windowers` to directly create a `WindowsDataset`
-    from a `BaseDataset` object.
+    from a `RawDataset` object.
 
     Parameters
     ----------
     windows : mne.Raw or mne.Epochs (Epochs is outdated)
-        Windows obtained through the application of a windower to a BaseDataset
+        Windows obtained through the application of a windower to a ``RawDataset``
         (see `braindecode.datautil.windowers`).
     description : dict | pandas.Series | None
         Holds additional info about the windows.
@@ -185,18 +215,17 @@ class EEGWindowsDataset(BaseDataset):
 
     def __init__(
         self,
-        raw: mne.io.BaseRaw | mne.BaseEpochs,
+        raw: mne.io.BaseRaw,
         metadata: pd.DataFrame,
         description: dict | pd.Series | None = None,
         transform: Callable | None = None,
         targets_from: str = "metadata",
         last_target_only: bool = True,
     ):
+        super().__init__(description, transform)
         self.raw = raw
         self.metadata = metadata
-        self._description = _create_description(description)
 
-        self.transform = transform
         self.last_target_only = last_target_only
         if targets_from not in ("metadata", "channels"):
             raise ValueError("Wrong value for parameter `targets_from`.")
@@ -255,44 +284,8 @@ class EEGWindowsDataset(BaseDataset):
     def __len__(self):
         return len(self.crop_inds)
 
-    @property
-    def transform(self):
-        return self._transform
-
-    @transform.setter
-    def transform(self, value):
-        if value is not None and not callable(value):
-            raise ValueError("Transform needs to be a callable.")
-        self._transform = value
-
-    @property
-    def description(self) -> pd.Series:
-        return self._description
-
-    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
-        """Update (add or overwrite) the dataset description.
-
-        Parameters
-        ----------
-        description: dict | pd.Series
-            Description in the form key: value.
-        overwrite: bool
-            Has to be True if a key in description already exists in the
-            dataset description.
-        """
-        description = _create_description(description)
-        for key, value in description.items():
-            # if they key is already in the existing description, drop it
-            if key in self._description:
-                assert overwrite, (
-                    f"'{key}' already in description. Please "
-                    f"rename or set overwrite to True."
-                )
-                self._description.pop(key)
-        self._description = pd.concat([self.description, description])
 
-
-class WindowsDataset(BaseDataset):
+class WindowsDataset(RecordDataset):
     """Returns windows from an mne.Epochs object along with a target.
 
     Dataset which serves windows from an mne.Epochs object along with their
@@ -303,12 +296,12 @@ class WindowsDataset(BaseDataset):
     required to serve information about the windowing (e.g., useful for cropped
     training).
     See `braindecode.datautil.windowers` to directly create a `WindowsDataset`
-    from a `BaseDataset` object.
+    from a ``RawDataset`` object.
 
     Parameters
     ----------
     windows : mne.Epochs
-        Windows obtained through the application of a windower to a BaseDataset
+        Windows obtained through the application of a windower to a RawDataset
         (see `braindecode.datautil.windowers`).
     description : dict | pandas.Series | None
         Holds additional info about the windows.
@@ -327,19 +320,20 @@ class WindowsDataset(BaseDataset):
         targets_from: str = "metadata",
         last_target_only: bool = True,
     ):
+        super().__init__(description, transform)
         self.windows = windows
-        self._description = _create_description(description)
-        self.transform = transform
         self.last_target_only = last_target_only
         if targets_from not in ("metadata", "channels"):
             raise ValueError("Wrong value for parameter `targets_from`.")
         self.targets_from = targets_from
 
-        self.crop_inds = self.windows.metadata.loc[
+        metadata = self.windows.metadata
+        assert metadata is not None, "WindowsDataset requires windows with metadata."
+        self.crop_inds = metadata.loc[
             :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
         ].to_numpy()
         if self.targets_from == "metadata":
-            self.y = self.windows.metadata.loc[:, "target"].to_list()
+            self.y = metadata.loc[:, "target"].to_list()
 
     def __getitem__(self, index: int):
         """Get a window and its target.
@@ -379,44 +373,8 @@ class WindowsDataset(BaseDataset):
     def __len__(self) -> int:
         return len(self.windows.events)
 
-    @property
-    def transform(self):
-        return self._transform
-
-    @transform.setter
-    def transform(self, value):
-        if value is not None and not callable(value):
-            raise ValueError("Transform needs to be a callable.")
-        self._transform = value
 
-    @property
-    def description(self) -> pd.Series:
-        return self._description
-
-    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
-        """Update (add or overwrite) the dataset description.
-
-        Parameters
-        ----------
-        description: dict | pd.Series
-            Description in the form key: value.
-        overwrite: bool
-            Has to be True if a key in description already exists in the
-            dataset description.
-        """
-        description = _create_description(description)
-        for key, value in description.items():
-            # if they key is already in the existing description, drop it
-            if key in self._description:
-                assert overwrite, (
-                    f"'{key}' already in description. Please "
-                    f"rename or set overwrite to True."
-                )
-                self._description.pop(key)
-        self._description = pd.concat([self.description, description])
-
-
-class BaseConcatDataset(ConcatDataset):
+class BaseConcatDataset(ConcatDataset, Generic[T]):
     """A base class for concatenated datasets.
 
     Holds either mne.Raw or mne.Epoch in self.datasets and has
@@ -425,22 +383,27 @@ class BaseConcatDataset(ConcatDataset):
     Parameters
     ----------
     list_of_ds : list
-        list of BaseDataset, BaseConcatDataset or WindowsDataset
+        list of RecordDataset
     target_transform : callable | None
         Optional function to call on targets before returning them.
 
     """
 
+    datasets: list[T]
+
     def __init__(
         self,
-        list_of_ds: list[BaseDataset | BaseConcatDataset | WindowsDataset]
-        | None = None,
+        list_of_ds: list[T | BaseConcatDataset[T]],
         target_transform: Callable | None = None,
     ):
         # if we get a list of BaseConcatDataset, get all the individual datasets
-        if list_of_ds and isinstance(list_of_ds[0], BaseConcatDataset):
-            list_of_ds = [d for ds in list_of_ds for d in ds.datasets]
-        super().__init__(list_of_ds)
+        flattened_list_of_ds: list[T] = []
+        for ds in list_of_ds:
+            if isinstance(ds, BaseConcatDataset):
+                flattened_list_of_ds.extend(ds.datasets)
+            else:
+                flattened_list_of_ds.append(ds)
+        super().__init__(flattened_list_of_ds)
 
         self.target_transform = target_transform
 
@@ -703,22 +666,23 @@ class BaseConcatDataset(ConcatDataset):
             ds.set_description({key: value_}, overwrite=overwrite)
 
     def save(self, path: str, overwrite: bool = False, offset: int = 0):
-        """Save datasets to files by creating one subdirectory for each dataset:
-        path/
-            0/
-                0-raw.fif | 0-epo.fif
-                description.json
-                raw_preproc_kwargs.json (if raws were preprocessed)
-                window_kwargs.json (if this is a windowed dataset)
-                window_preproc_kwargs.json (if windows were preprocessed)
-                target_name.json (if target_name is not None and dataset is raw)
-            1/
-                1-raw.fif | 1-epo.fif
-                description.json
-                raw_preproc_kwargs.json (if raws were preprocessed)
-                window_kwargs.json (if this is a windowed dataset)
-                window_preproc_kwargs.json (if windows were preprocessed)
-                target_name.json (if target_name is not None and dataset is raw)
+        """Save datasets to files by creating one subdirectory for each dataset::
+
+            path/
+                0/
+                    0-raw.fif | 0-epo.fif
+                    description.json
+                    raw_preproc_kwargs.json (if raws were preprocessed)
+                    window_kwargs.json (if this is a windowed dataset)
+                    window_preproc_kwargs.json (if windows were preprocessed)
+                    target_name.json (if target_name is not None and dataset is raw)
+                1/
+                    1-raw.fif | 1-epo.fif
+                    description.json
+                    raw_preproc_kwargs.json (if raws were preprocessed)
+                    window_kwargs.json (if this is a windowed dataset)
+                    window_preproc_kwargs.json (if windows were preprocessed)
+                    target_name.json (if target_name is not None and dataset is raw)
 
         Parameters
         ----------
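A round-trip sketch of the documented layout (illustration, not from the diff): `load_concat_dataset` is braindecode's existing loader in braindecode.datautil, though its exact signature is assumed here, and `concat_ds` stands for any BaseConcatDataset.

    from braindecode.datautil import load_concat_dataset

    concat_ds.save("/tmp/bd_datasets", overwrite=True)  # writes path/0/, path/1/, ...
    restored = load_concat_dataset("/tmp/bd_datasets", preload=False)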
braindecode/datasets/bcicomp.py
@@ -16,7 +16,7 @@ import numpy as np
 from mne.utils import verbose
 from scipy.io import loadmat
 
-from braindecode.datasets import BaseConcatDataset, BaseDataset
+from braindecode.datasets import BaseConcatDataset, RawDataset
 
 DATASET_URL = (
     "https://stacks.stanford.edu/file/druid:zk881ps0522/"
@@ -73,8 +73,8 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
             file_name=file_path.split("/")[-1],
             session="test",
         )
-        datasets.append(BaseDataset(raw_train, description=desc_train))
-        datasets.append(BaseDataset(raw_test, description=desc_test))
+        datasets.append(RawDataset(raw_train, description=desc_train))
+        datasets.append(RawDataset(raw_test, description=desc_test))
         super().__init__(datasets)
 
     @staticmethod
@@ -85,7 +85,7 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
     ----------
     path (None | str) – Location of where to look for the data storing location.
         If None, the environment variable or config parameter
-        MNE_DATASETS_(dataset)_PATH is used. If it doesnt exist, the “~/mne_data”
+        MNE_DATASETS_(dataset)_PATH is used. If it doesn't exist, the “~/mne_data”
         directory is used. If the dataset is not found under the given path, the data
         will be automatically downloaded to the specified folder.
     force_update (bool) – Force update of the dataset even if a local copy exists.
braindecode/datasets/bids.py
@@ -19,7 +19,7 @@ import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed
 
-from .base import BaseConcatDataset, BaseDataset, WindowsDataset
+from .base import BaseConcatDataset, RawDataset, WindowsDataset
 
 
 def _description_from_bids_path(bids_path: mne_bids.BIDSPath) -> dict[str, Any]:
@@ -186,12 +186,12 @@ class BIDSDataset(BaseConcatDataset):
         )
         super().__init__(all_base_ds)
 
-    def _get_dataset(self, bids_path: mne_bids.BIDSPath) -> BaseDataset:
+    def _get_dataset(self, bids_path: mne_bids.BIDSPath) -> RawDataset:
         description = _description_from_bids_path(bids_path)
         raw = mne_bids.read_raw_bids(bids_path, verbose=False)
         if self.preload:
             raw.load_data()
-        return BaseDataset(raw, description)
+        return RawDataset(raw, description)
 
 
 class BIDSEpochsDataset(BIDSDataset):
braindecode/datasets/experimental.py
@@ -25,14 +25,14 @@ class BIDSIterableDataset(IterableDataset):
 
     Examples
     --------
-    >>> from braindecode.datasets import BaseDataset, BaseConcatDataset
+    >>> from braindecode.datasets import RecordDataset, BaseConcatDataset
     >>> from braindecode.datasets.bids import BIDSIterableDataset, _description_from_bids_path
     >>> from braindecode.preprocessing import create_fixed_length_windows
     >>>
     >>> def my_reader_fn(path):
     ...     raw = mne_bids.read_raw_bids(path)
     ...     desc = _description_from_bids_path(path)
-    ...     ds = BaseDataset(raw, description=desc)
+    ...     ds = RawDataset(raw, description=desc)
     ...     windows_ds = create_fixed_length_windows(
     ...         BaseConcatDataset([ds]),
     ...         window_size_samples=400,
braindecode/datasets/mne.py
@@ -9,7 +9,7 @@ import mne
 import numpy as np
 import pandas as pd
 
-from .base import BaseConcatDataset, BaseDataset, WindowsDataset
+from .base import BaseConcatDataset, RawDataset, WindowsDataset
 
 
 def create_from_mne_raw(
@@ -75,11 +75,9 @@ def create_from_mne_raw(
             f"length of 'raws' ({len(raws)}) and 'description' "
             f"({len(descriptions)}) has to match"
         )
-        base_datasets = [
-            BaseDataset(raw, desc) for raw, desc in zip(raws, descriptions)
-        ]
+        base_datasets = [RawDataset(raw, desc) for raw, desc in zip(raws, descriptions)]
     else:
-        base_datasets = [BaseDataset(raw) for raw in raws]
+        base_datasets = [RawDataset(raw) for raw in raws]
 
     base_datasets = BaseConcatDataset(base_datasets)
     windows_datasets = create_windows_from_events(
braindecode/datasets/moabb.py
@@ -18,7 +18,7 @@ import pandas as pd
 
 from braindecode.util import _update_moabb_docstring
 
-from .base import BaseConcatDataset, BaseDataset
+from .base import BaseConcatDataset, RawDataset
 
 
 def _find_dataset_in_moabb(dataset_name, dataset_kwargs=None):
@@ -164,7 +164,7 @@ class MOABBDataset(BaseConcatDataset):
             dataset_load_kwargs=dataset_load_kwargs,
         )
         all_base_ds = [
-            BaseDataset(raw, row) for raw, (_, row) in zip(raws, description.iterrows())
+            RawDataset(raw, row) for raw, (_, row) in zip(raws, description.iterrows())
         ]
         super().__init__(all_base_ds)
 
braindecode/datasets/nmt.py
@@ -31,7 +31,7 @@ import pandas as pd
 from joblib import Parallel, delayed
 from mne.datasets import fetch_dataset
 
-from braindecode.datasets.base import BaseConcatDataset, BaseDataset
+from braindecode.datasets.base import BaseConcatDataset, RawDataset
 
 NMT_URL = "https://zenodo.org/record/10909103/files/NMT.zip"
 NMT_archive_name = "NMT.zip"
@@ -172,7 +172,7 @@ class NMT(BaseConcatDataset):
         d["n_samples"] = raw.n_times
         d["sfreq"] = raw.info["sfreq"]
         d["train"] = "train" in d.path.split(os.sep)
-        base_dataset = BaseDataset(raw, d, target_name)
+        base_dataset = RawDataset(raw, d, target_name)
         return base_dataset
 
 
braindecode/datasets/sleep_physio_challe_18.py
@@ -21,7 +21,7 @@ from mne.datasets.sleep_physionet._utils import _fetch_one
 from mne.datasets.utils import _get_path
 from mne.utils import warn
 
-from braindecode.datasets import BaseConcatDataset, BaseDataset
+from braindecode.datasets import BaseConcatDataset, RawDataset
 
 PC18_DIR = op.join(op.dirname(__file__), "data", "pc18")
 PC18_RECORDS = op.join(PC18_DIR, "sleep_records.csv")
@@ -403,7 +403,7 @@ class SleepPhysionetChallenge2018(BaseConcatDataset):
         },
         name="",
     )
-    base_dataset = BaseDataset(raw_file, desc)
+    base_dataset = RawDataset(raw_file, desc)
 
     if preproc is not None:
         from braindecode.preprocessing.preprocess import _preprocess
braindecode/datasets/sleep_physionet.py
@@ -12,7 +12,7 @@ import numpy as np
 import pandas as pd
 from mne.datasets.sleep_physionet.age import fetch_data
 
-from .base import BaseConcatDataset, BaseDataset
+from .base import BaseConcatDataset, RawDataset
 
 
 class SleepPhysionet(BaseConcatDataset):
@@ -71,7 +71,7 @@ class SleepPhysionet(BaseConcatDataset):
             crop_wake_mins=crop_wake_mins,
             crop=crop,
         )
-        base_ds = BaseDataset(raw, desc)
+        base_ds = RawDataset(raw, desc)
         all_base_ds.append(base_ds)
     super().__init__(all_base_ds)
 
braindecode/datasets/tuh.py
@@ -22,7 +22,7 @@ import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed
 
-from .base import BaseConcatDataset, BaseDataset
+from .base import BaseConcatDataset, RawDataset
 
 
 class TUH(BaseConcatDataset):
@@ -214,7 +214,7 @@ class TUH(BaseConcatDataset):
         d["report"] = physician_report
         additional_description = pd.Series(d)
         description = pd.concat([description, additional_description])
-        base_dataset = BaseDataset(raw, description, target_name=target_name)
+        base_dataset = RawDataset(raw, description, target_name=target_name)
         return base_dataset
 
 
braindecode/datasets/xy.py
@@ -12,7 +12,7 @@ import numpy as np
 import pandas as pd
 from numpy.typing import ArrayLike, NDArray
 
-from .base import BaseConcatDataset, BaseDataset
+from .base import BaseConcatDataset, RawDataset
 
 log = logging.getLogger(__name__)
 
@@ -69,7 +69,7 @@ def create_from_X_y(
         n_samples_per_x.append(x.shape[1])
     info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
     raw = mne.io.RawArray(x, info)
-    base_dataset = BaseDataset(
+    base_dataset = RawDataset(
         raw, pd.Series({"target": target}), target_name="target"
     )
     base_datasets.append(base_dataset)
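With this change, create_from_X_y wraps each trial in a RawDataset around an mne RawArray. A toy end-to-end sketch (illustration, not from the diff; the windowing keyword arguments are assumptions based on the function's documented parameters):

    import numpy as np
    from braindecode.datasets import create_from_X_y

    X = np.random.randn(3, 2, 1000).astype("float32")  # trials x channels x samples
    y = np.array([0, 1, 0])
    windows_ds = create_from_X_y(
        X, y,
        drop_last_window=False,
        sfreq=100.0,
        ch_names=["C3", "C4"],
        window_size_samples=500,
        window_stride_samples=500,
    )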