braindecode 1.2.0.dev184328194__py3-none-any.whl → 1.3.0.dev171178473__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of braindecode might be problematic.

Files changed (70)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/augmentation/functional.py +154 -54
  3. braindecode/augmentation/transforms.py +2 -2
  4. braindecode/datasets/__init__.py +10 -2
  5. braindecode/datasets/base.py +116 -152
  6. braindecode/datasets/bcicomp.py +4 -4
  7. braindecode/datasets/bids.py +3 -3
  8. braindecode/datasets/experimental.py +218 -0
  9. braindecode/datasets/mne.py +3 -5
  10. braindecode/datasets/moabb.py +2 -2
  11. braindecode/datasets/nmt.py +2 -2
  12. braindecode/datasets/sleep_physio_challe_18.py +4 -3
  13. braindecode/datasets/sleep_physionet.py +2 -2
  14. braindecode/datasets/tuh.py +2 -2
  15. braindecode/datasets/xy.py +2 -2
  16. braindecode/datautil/serialization.py +18 -13
  17. braindecode/eegneuralnet.py +2 -0
  18. braindecode/functional/functions.py +6 -2
  19. braindecode/functional/initialization.py +2 -3
  20. braindecode/models/__init__.py +12 -8
  21. braindecode/models/atcnet.py +156 -17
  22. braindecode/models/attentionbasenet.py +148 -16
  23. braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
  24. braindecode/models/base.py +280 -2
  25. braindecode/models/bendr.py +469 -0
  26. braindecode/models/biot.py +3 -1
  27. braindecode/models/ctnet.py +7 -4
  28. braindecode/models/deep4.py +6 -2
  29. braindecode/models/deepsleepnet.py +127 -5
  30. braindecode/models/eegconformer.py +114 -15
  31. braindecode/models/eeginception_erp.py +82 -7
  32. braindecode/models/eeginception_mi.py +2 -0
  33. braindecode/models/eegnet.py +64 -177
  34. braindecode/models/eegnex.py +113 -6
  35. braindecode/models/eegsimpleconv.py +2 -0
  36. braindecode/models/eegtcnet.py +1 -1
  37. braindecode/models/labram.py +188 -84
  38. braindecode/models/patchedtransformer.py +640 -0
  39. braindecode/models/sccnet.py +81 -8
  40. braindecode/models/shallow_fbcsp.py +2 -0
  41. braindecode/models/signal_jepa.py +109 -27
  42. braindecode/models/sinc_shallow.py +10 -9
  43. braindecode/models/sleep_stager_blanco_2020.py +2 -0
  44. braindecode/models/sleep_stager_chambon_2018.py +2 -0
  45. braindecode/models/sparcnet.py +2 -0
  46. braindecode/models/sstdpn.py +869 -0
  47. braindecode/models/summary.csv +42 -41
  48. braindecode/models/tidnet.py +2 -0
  49. braindecode/models/tsinception.py +15 -3
  50. braindecode/models/usleep.py +108 -9
  51. braindecode/models/util.py +8 -5
  52. braindecode/modules/attention.py +10 -10
  53. braindecode/modules/blocks.py +3 -3
  54. braindecode/modules/filter.py +2 -3
  55. braindecode/modules/layers.py +18 -17
  56. braindecode/preprocessing/__init__.py +24 -0
  57. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  58. braindecode/preprocessing/preprocess.py +42 -39
  59. braindecode/preprocessing/util.py +166 -0
  60. braindecode/preprocessing/windowers.py +24 -19
  61. braindecode/samplers/base.py +8 -8
  62. braindecode/version.py +1 -1
  63. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/METADATA +12 -3
  64. braindecode-1.3.0.dev171178473.dist-info/RECORD +106 -0
  65. braindecode/models/eegresnet.py +0 -362
  66. braindecode-1.2.0.dev184328194.dist-info/RECORD +0 -101
  67. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/WHEEL +0 -0
  68. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/LICENSE.txt +0 -0
  69. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/NOTICE.txt +0 -0
  70. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/top_level.txt +0 -0
braindecode/datasets/base.py
@@ -16,14 +16,17 @@ import json
 import os
 import shutil
 import warnings
+from abc import abstractmethod
 from collections.abc import Callable
 from glob import glob
-from typing import Iterable, no_type_check
+from typing import Generic, Iterable, no_type_check

 import mne.io
 import numpy as np
 import pandas as pd
+from mne.utils.docs import deprecated
 from torch.utils.data import ConcatDataset, Dataset
+from typing_extensions import TypeVar


 def _create_description(description) -> pd.Series:
@@ -37,7 +40,64 @@ def _create_description(description) -> pd.Series:
     return description


-class BaseDataset(Dataset):
+class RecordDataset(Dataset[tuple[np.ndarray, int | str, tuple[int, int, int]]]):
+    def __init__(
+        self,
+        description: dict | pd.Series | None = None,
+        transform: Callable | None = None,
+    ):
+        self._description = _create_description(description)
+        self.transform = transform
+
+    @abstractmethod
+    def __len__(self) -> int:
+        pass
+
+    @property
+    def description(self) -> pd.Series:
+        return self._description
+
+    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
+        """Update (add or overwrite) the dataset description.
+
+        Parameters
+        ----------
+        description: dict | pd.Series
+            Description in the form key: value.
+        overwrite: bool
+            Has to be True if a key in description already exists in the
+            dataset description.
+        """
+        description = _create_description(description)
+        if self.description is None:
+            self._description = description
+        else:
+            for key, value in description.items():
+                # if the key is already in the existing description, drop it
+                if key in self._description:
+                    assert overwrite, (
+                        f"'{key}' already in description. Please "
+                        f"rename or set overwrite to True."
+                    )
+                    self._description.pop(key)
+            self._description = pd.concat([self.description, description])
+
+    @property
+    def transform(self) -> Callable | None:
+        return self._transform
+
+    @transform.setter
+    def transform(self, value: Callable | None):
+        if value is not None and not callable(value):
+            raise ValueError("Transform needs to be a callable.")
+        self._transform = value
+
+
+# Type of the datasets contained in BaseConcatDataset
+T = TypeVar("T", bound=RecordDataset)
+
+
+class RawDataset(RecordDataset):
     """Returns samples from an mne.io.Raw object along with a target.

     Dataset which serves samples from an mne.io.Raw object along with a target.
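The hunk above moves the description/transform plumbing that was previously duplicated across three classes onto a single base class. As a usage sketch (not part of this diff; it assumes `RawDataset` is exported from `braindecode.datasets`, as the bcicomp.py hunk below imports it), the consolidated `set_description` semantics look like this:

    import mne
    import numpy as np

    from braindecode.datasets import RawDataset

    # synthetic two-channel recording; channel names and rate are arbitrary
    info = mne.create_info(ch_names=["C3", "C4"], sfreq=100.0, ch_types="eeg")
    raw = mne.io.RawArray(np.zeros((2, 1000)), info)

    ds = RawDataset(raw, description={"subject": 1})
    ds.set_description({"session": "A"})                # new key, no flag needed
    ds.set_description({"subject": 2}, overwrite=True)  # existing key needs overwrite
    print(ds.description)                               # subject == 2, session == "A"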
@@ -64,9 +124,8 @@ class BaseDataset(Dataset):
         target_name: str | tuple[str, ...] | None = None,
         transform: Callable | None = None,
     ):
+        super().__init__(description, transform)
         self.raw = raw
-        self._description = _create_description(description)
-        self.transform = transform

         # save target name for load/save later
         self.target_name = self._target_name(target_name)
@@ -85,45 +144,6 @@ class BaseDataset(Dataset):
     def __len__(self):
         return len(self.raw)

-    @property
-    def transform(self):
-        return self._transform
-
-    @transform.setter
-    def transform(self, value):
-        if value is not None and not callable(value):
-            raise ValueError("Transform needs to be a callable.")
-        self._transform = value
-
-    @property
-    def description(self) -> pd.Series:
-        return self._description
-
-    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
-        """Update (add or overwrite) the dataset description.
-
-        Parameters
-        ----------
-        description: dict | pd.Series
-            Description in the form key: value.
-        overwrite: bool
-            Has to be True if a key in description already exists in the
-            dataset description.
-        """
-        description = _create_description(description)
-        for key, value in description.items():
-            # if the key is already in the existing description, drop it
-            if self._description is not None and key in self._description:
-                assert overwrite, (
-                    f"'{key}' already in description. Please "
-                    f"rename or set overwrite to True."
-                )
-                self._description.pop(key)
-        if self._description is None:
-            self._description = description
-        else:
-            self._description = pd.concat([self.description, description])
-
     def _target_name(self, target_name):
         if target_name is not None and not isinstance(target_name, (str, tuple, list)):
             raise ValueError("target_name has to be None, str, tuple or list")
150
170
  return target_name if len(target_name) > 1 else target_name[0]
151
171
 
152
172
 
153
- class EEGWindowsDataset(BaseDataset):
173
+ @deprecated(
174
+ "The BaseDataset class is deprecated. "
175
+ "If you want to instantiate a dataset containing raws, use RawDataset instead. "
176
+ "If you want to type a Braindecode dataset (i.e. RawDataset|EEGWindowsDataset|WindowsDataset), "
177
+ "use the RecordDataset class instead."
178
+ )
179
+ class BaseDataset(RawDataset):
180
+ pass
181
+
182
+
183
+ class EEGWindowsDataset(RecordDataset):
154
184
  """Returns windows from an mne.Raw object, its window indices, along with a target.
155
185
 
156
186
  Dataset which serves windows from an mne.Epochs object along with their
@@ -161,12 +191,12 @@ class EEGWindowsDataset(BaseDataset):
161
191
  required to serve information about the windowing (e.g., useful for cropped
162
192
  training).
163
193
  See `braindecode.datautil.windowers` to directly create a `WindowsDataset`
164
- from a `BaseDataset` object.
194
+ from a `RawDataset` object.
165
195
 
166
196
  Parameters
167
197
  ----------
168
198
  windows : mne.Raw or mne.Epochs (Epochs is outdated)
169
- Windows obtained through the application of a windower to a BaseDataset
199
+ Windows obtained through the application of a windower to a ``RawDataset``
170
200
  (see `braindecode.datautil.windowers`).
171
201
  description : dict | pandas.Series | None
172
202
  Holds additional info about the windows.
@@ -185,18 +215,17 @@ class EEGWindowsDataset(BaseDataset):
185
215
 
186
216
  def __init__(
187
217
  self,
188
- raw: mne.io.BaseRaw | mne.BaseEpochs,
218
+ raw: mne.io.BaseRaw,
189
219
  metadata: pd.DataFrame,
190
220
  description: dict | pd.Series | None = None,
191
221
  transform: Callable | None = None,
192
222
  targets_from: str = "metadata",
193
223
  last_target_only: bool = True,
194
224
  ):
225
+ super().__init__(description, transform)
195
226
  self.raw = raw
196
227
  self.metadata = metadata
197
- self._description = _create_description(description)
198
228
 
199
- self.transform = transform
200
229
  self.last_target_only = last_target_only
201
230
  if targets_from not in ("metadata", "channels"):
202
231
  raise ValueError("Wrong value for parameter `targets_from`.")
@@ -255,44 +284,8 @@ class EEGWindowsDataset(BaseDataset):
255
284
  def __len__(self):
256
285
  return len(self.crop_inds)
257
286
 
258
- @property
259
- def transform(self):
260
- return self._transform
261
-
262
- @transform.setter
263
- def transform(self, value):
264
- if value is not None and not callable(value):
265
- raise ValueError("Transform needs to be a callable.")
266
- self._transform = value
267
-
268
- @property
269
- def description(self) -> pd.Series:
270
- return self._description
271
-
272
- def set_description(self, description: dict | pd.Series, overwrite: bool = False):
273
- """Update (add or overwrite) the dataset description.
274
-
275
- Parameters
276
- ----------
277
- description: dict | pd.Series
278
- Description in the form key: value.
279
- overwrite: bool
280
- Has to be True if a key in description already exists in the
281
- dataset description.
282
- """
283
- description = _create_description(description)
284
- for key, value in description.items():
285
- # if they key is already in the existing description, drop it
286
- if key in self._description:
287
- assert overwrite, (
288
- f"'{key}' already in description. Please "
289
- f"rename or set overwrite to True."
290
- )
291
- self._description.pop(key)
292
- self._description = pd.concat([self.description, description])
293
287
 
294
-
295
- class WindowsDataset(BaseDataset):
288
+ class WindowsDataset(RecordDataset):
296
289
  """Returns windows from an mne.Epochs object along with a target.
297
290
 
298
291
  Dataset which serves windows from an mne.Epochs object along with their
@@ -303,12 +296,12 @@ class WindowsDataset(BaseDataset):
     required to serve information about the windowing (e.g., useful for cropped
     training).
     See `braindecode.datautil.windowers` to directly create a `WindowsDataset`
-    from a `BaseDataset` object.
+    from a ``RawDataset`` object.

     Parameters
     ----------
     windows : mne.Epochs
-        Windows obtained through the application of a windower to a BaseDataset
+        Windows obtained through the application of a windower to a RawDataset
         (see `braindecode.datautil.windowers`).
     description : dict | pandas.Series | None
         Holds additional info about the windows.
@@ -327,19 +320,20 @@ class WindowsDataset(BaseDataset):
         targets_from: str = "metadata",
         last_target_only: bool = True,
     ):
+        super().__init__(description, transform)
         self.windows = windows
-        self._description = _create_description(description)
-        self.transform = transform
         self.last_target_only = last_target_only
         if targets_from not in ("metadata", "channels"):
             raise ValueError("Wrong value for parameter `targets_from`.")
         self.targets_from = targets_from

-        self.crop_inds = self.windows.metadata.loc[
+        metadata = self.windows.metadata
+        assert metadata is not None, "WindowsDataset requires windows with metadata."
+        self.crop_inds = metadata.loc[
             :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
         ].to_numpy()
         if self.targets_from == "metadata":
-            self.y = self.windows.metadata.loc[:, "target"].to_list()
+            self.y = metadata.loc[:, "target"].to_list()

     def __getitem__(self, index: int):
         """Get a window and its target.
@@ -379,44 +373,8 @@ class WindowsDataset(BaseDataset):
     def __len__(self) -> int:
         return len(self.windows.events)

-    @property
-    def transform(self):
-        return self._transform
-
-    @transform.setter
-    def transform(self, value):
-        if value is not None and not callable(value):
-            raise ValueError("Transform needs to be a callable.")
-        self._transform = value

-    @property
-    def description(self) -> pd.Series:
-        return self._description
-
-    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
-        """Update (add or overwrite) the dataset description.
-
-        Parameters
-        ----------
-        description: dict | pd.Series
-            Description in the form key: value.
-        overwrite: bool
-            Has to be True if a key in description already exists in the
-            dataset description.
-        """
-        description = _create_description(description)
-        for key, value in description.items():
-            # if they key is already in the existing description, drop it
-            if key in self._description:
-                assert overwrite, (
-                    f"'{key}' already in description. Please "
-                    f"rename or set overwrite to True."
-                )
-                self._description.pop(key)
-        self._description = pd.concat([self.description, description])
-
-
-class BaseConcatDataset(ConcatDataset):
+class BaseConcatDataset(ConcatDataset, Generic[T]):
     """A base class for concatenated datasets.

     Holds either mne.Raw or mne.Epoch in self.datasets and has
@@ -425,22 +383,27 @@ class BaseConcatDataset(ConcatDataset):
     Parameters
     ----------
     list_of_ds : list
-        list of BaseDataset, BaseConcatDataset or WindowsDataset
+        list of RecordDataset
     target_transform : callable | None
         Optional function to call on targets before returning them.

     """

+    datasets: list[T]
+
     def __init__(
         self,
-        list_of_ds: list[BaseDataset | BaseConcatDataset | WindowsDataset]
-        | None = None,
+        list_of_ds: list[T | BaseConcatDataset[T]],
         target_transform: Callable | None = None,
     ):
         # if we get a list of BaseConcatDataset, get all the individual datasets
-        if list_of_ds and isinstance(list_of_ds[0], BaseConcatDataset):
-            list_of_ds = [d for ds in list_of_ds for d in ds.datasets]
-        super().__init__(list_of_ds)
+        flattened_list_of_ds: list[T] = []
+        for ds in list_of_ds:
+            if isinstance(ds, BaseConcatDataset):
+                flattened_list_of_ds.extend(ds.datasets)
+            else:
+                flattened_list_of_ds.append(ds)
+        super().__init__(flattened_list_of_ds)

         self.target_transform = target_transform
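With the `Generic[T]` parameter, type checkers can now see what `.datasets` holds, and the rewritten constructor flattens nested concats explicitly instead of only inspecting the first element. A sketch under the same export assumptions as above:

    import mne
    import numpy as np

    from braindecode.datasets import BaseConcatDataset, RawDataset

    def make(i: int) -> RawDataset:
        info = mne.create_info(["C3"], sfreq=100.0, ch_types="eeg")
        return RawDataset(mne.io.RawArray(np.zeros((1, 500)), info),
                          description={"run": i})

    concat: BaseConcatDataset[RawDataset] = BaseConcatDataset([make(0), make(1)])
    nested = BaseConcatDataset([concat, make(2)])  # flattened: 3 datasets, not 2
    assert len(nested.datasets) == 3
    run = nested.datasets[0].description["run"]    # .datasets is list[RawDataset]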
@@ -703,22 +666,23 @@ class BaseConcatDataset(ConcatDataset):
             ds.set_description({key: value_}, overwrite=overwrite)

     def save(self, path: str, overwrite: bool = False, offset: int = 0):
-        """Save datasets to files by creating one subdirectory for each dataset:
-        path/
-            0/
-                0-raw.fif | 0-epo.fif
-                description.json
-                raw_preproc_kwargs.json (if raws were preprocessed)
-                window_kwargs.json (if this is a windowed dataset)
-                window_preproc_kwargs.json (if windows were preprocessed)
-                target_name.json (if target_name is not None and dataset is raw)
-            1/
-                1-raw.fif | 1-epo.fif
-                description.json
-                raw_preproc_kwargs.json (if raws were preprocessed)
-                window_kwargs.json (if this is a windowed dataset)
-                window_preproc_kwargs.json (if windows were preprocessed)
-                target_name.json (if target_name is not None and dataset is raw)
+        """Save datasets to files by creating one subdirectory for each dataset::
+
+            path/
+                0/
+                    0-raw.fif | 0-epo.fif
+                    description.json
+                    raw_preproc_kwargs.json (if raws were preprocessed)
+                    window_kwargs.json (if this is a windowed dataset)
+                    window_preproc_kwargs.json (if windows were preprocessed)
+                    target_name.json (if target_name is not None and dataset is raw)
+                1/
+                    1-raw.fif | 1-epo.fif
+                    description.json
+                    raw_preproc_kwargs.json (if raws were preprocessed)
+                    window_kwargs.json (if this is a windowed dataset)
+                    window_preproc_kwargs.json (if windows were preprocessed)
+                    target_name.json (if target_name is not None and dataset is raw)

         Parameters
         ----------
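The docstring change above is purely presentational (a reST literal block), but the layout it documents is what `load_concat_dataset` reads back. A hedged round-trip sketch, reusing `nested` from the previous sketch (assumes `load_concat_dataset` is exported from `braindecode.datautil`, where this wheel keeps its serialization helpers):

    from braindecode.datautil import load_concat_dataset

    nested.save("my_dataset/", overwrite=True)  # writes my_dataset/0/, 1/, 2/
    restored = load_concat_dataset("my_dataset/", preload=False)
    assert [d.description["run"] for d in restored.datasets] == [0, 1, 2]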
@@ -815,7 +779,7 @@ class BaseConcatDataset(ConcatDataset):
     @staticmethod
     def _save_description(sub_dir, description):
         description_file_path = os.path.join(sub_dir, "description.json")
-        description.to_json(description_file_path)
+        description.to_json(description_file_path, default_handler=str)

     @staticmethod
     def _save_kwargs(sub_dir, ds):
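The `default_handler=str` addition matters because descriptions hold arbitrary user metadata and pandas' JSON writer raises on objects it cannot encode. Standalone pandas illustration (not braindecode code):

    from pathlib import Path

    import pandas as pd

    desc = pd.Series({"subject": 1, "file": Path("sub-01_raw.fif")})
    # without a handler, to_json() raises on the Path object; with
    # default_handler=str it falls back to str() and serialisation succeeds
    desc.to_json("description.json", default_handler=str)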
braindecode/datasets/bcicomp.py
@@ -16,7 +16,7 @@ import numpy as np
 from mne.utils import verbose
 from scipy.io import loadmat

-from braindecode.datasets import BaseConcatDataset, BaseDataset
+from braindecode.datasets import BaseConcatDataset, RawDataset

 DATASET_URL = (
     "https://stacks.stanford.edu/file/druid:zk881ps0522/"
@@ -73,8 +73,8 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
                 file_name=file_path.split("/")[-1],
                 session="test",
             )
-            datasets.append(BaseDataset(raw_train, description=desc_train))
-            datasets.append(BaseDataset(raw_test, description=desc_test))
+            datasets.append(RawDataset(raw_train, description=desc_train))
+            datasets.append(RawDataset(raw_test, description=desc_test))
         super().__init__(datasets)

     @staticmethod
@@ -85,7 +85,7 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
         ----------
         path (None | str) – Location of where to look for the data storing location.
             If None, the environment variable or config parameter
-            MNE_DATASETS_(dataset)_PATH is used. If it doesnt exist, the “~/mne_data”
+            MNE_DATASETS_(dataset)_PATH is used. If it doesn't exist, the “~/mne_data”
             directory is used. If the dataset is not found under the given path, the data
             will be automatically downloaded to the specified folder.
         force_update (bool) – Force update of the dataset even if a local copy exists.
braindecode/datasets/bids.py
@@ -19,7 +19,7 @@ import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed

-from .base import BaseConcatDataset, BaseDataset, WindowsDataset
+from .base import BaseConcatDataset, RawDataset, WindowsDataset


 def _description_from_bids_path(bids_path: mne_bids.BIDSPath) -> dict[str, Any]:
@@ -186,12 +186,12 @@ class BIDSDataset(BaseConcatDataset):
         )
         super().__init__(all_base_ds)

-    def _get_dataset(self, bids_path: mne_bids.BIDSPath) -> BaseDataset:
+    def _get_dataset(self, bids_path: mne_bids.BIDSPath) -> RawDataset:
         description = _description_from_bids_path(bids_path)
         raw = mne_bids.read_raw_bids(bids_path, verbose=False)
         if self.preload:
             raw.load_data()
-        return BaseDataset(raw, description)
+        return RawDataset(raw, description)


 class BIDSEpochsDataset(BIDSDataset):
braindecode/datasets/experimental.py (new file)
@@ -0,0 +1,218 @@
+from __future__ import annotations
+
+import random
+from pathlib import Path
+from typing import Callable, Sequence
+
+import mne_bids
+from torch.utils.data import IterableDataset, get_worker_info
+
+
+class BIDSIterableDataset(IterableDataset):
+    """Dataset for loading BIDS.
+
+    .. warning::
+        This class is experimental and may change in the future.
+
+    .. warning::
+        This dataset is not consistent with the Braindecode API.
+
+    This class has the same parameters as the :func:`mne_bids.find_matching_paths`
+    function, which is used to find the files to load. The default ``extensions`` parameter was changed.
+
+    More information on BIDS (Brain Imaging Data Structure)
+    can be found at https://bids.neuroimaging.io
+
+    Examples
+    --------
+    >>> from braindecode.datasets import RawDataset, BaseConcatDataset
+    >>> from braindecode.datasets.bids import BIDSIterableDataset, _description_from_bids_path
+    >>> from braindecode.preprocessing import create_fixed_length_windows
+    >>>
+    >>> def my_reader_fn(path):
+    ...     raw = mne_bids.read_raw_bids(path)
+    ...     desc = _description_from_bids_path(path)
+    ...     ds = RawDataset(raw, description=desc)
+    ...     windows_ds = create_fixed_length_windows(
+    ...         BaseConcatDataset([ds]),
+    ...         window_size_samples=400,
+    ...         window_stride_samples=200,
+    ...     )
+    ...     return windows_ds
+    >>>
+    >>> dataset = BIDSIterableDataset(
+    ...     reader_fn=my_reader_fn,
+    ...     root="root/of/my/bids/dataset/",
+    ... )
+
+    Parameters
+    ----------
+    reader_fn : Callable[[mne_bids.BIDSPath], Sequence]
+        A function that takes a BIDSPath and returns a dataset.
+    pool_size : int
+        The number of recordings to read and sample from.
+    bids_paths : list[mne_bids.BIDSPath] | None
+        A list of BIDSPaths to load. If None, will use the paths found by
+        :func:`mne_bids.find_matching_paths` and the arguments below.
+    root : pathlib.Path | str
+        The root of the BIDS path.
+    subjects : str | array-like of str | None
+        The subject ID. Corresponds to "sub".
+    sessions : str | array-like of str | None
+        The acquisition session. Corresponds to "ses".
+    tasks : str | array-like of str | None
+        The experimental task. Corresponds to "task".
+    acquisitions : str | array-like of str | None
+        The acquisition parameters. Corresponds to "acq".
+    runs : str | array-like of str | None
+        The run number. Corresponds to "run".
+    processings : str | array-like of str | None
+        The processing label. Corresponds to "proc".
+    recordings : str | array-like of str | None
+        The recording name. Corresponds to "rec".
+    spaces : str | array-like of str | None
+        The coordinate space for anatomical and sensor location
+        files (e.g., ``*_electrodes.tsv``, ``*_markers.mrk``).
+        Corresponds to "space".
+        Note that valid values for ``space`` must come from a list
+        of BIDS keywords as described in the BIDS specification.
+    splits : str | array-like of str | None
+        The split of the continuous recording file for ``.fif`` data.
+        Corresponds to "split".
+    descriptions : str | array-like of str | None
+        This corresponds to the BIDS entity ``desc``. It is used to provide
+        additional information for derivative data, e.g., preprocessed data
+        may be assigned ``description='cleaned'``.
+    suffixes : str | array-like of str | None
+        The filename suffix. This is the entity after the
+        last ``_`` before the extension. E.g., ``'channels'``.
+        The following filename suffixes are accepted:
+        'meg', 'markers', 'eeg', 'ieeg', 'T1w',
+        'participants', 'scans', 'electrodes', 'coordsystem',
+        'channels', 'events', 'headshape', 'digitizer',
+        'beh', 'physio', 'stim'
+    extensions : str | array-like of str | None
+        The extension of the filename. E.g., ``'.json'``.
+        By default, uses the ones accepted by :func:`mne_bids.read_raw_bids`.
+    datatypes : str | array-like of str | None
+        The BIDS data type, e.g., ``'anat'``, ``'func'``, ``'eeg'``, ``'meg'``,
+        ``'ieeg'``.
+    check : bool
+        If ``True``, only returns paths that conform to BIDS. If ``False``
+        (default), the ``.check`` attribute of the returned
+        :class:`mne_bids.BIDSPath` object will be set to ``True`` for paths that
+        do conform to BIDS, and to ``False`` for those that don't.
+    preload : bool
+        If True, preload the data. Defaults to False.
+    n_jobs : int
+        Number of jobs to run in parallel. Defaults to 1.
+    """
+
+    def __init__(
+        self,
+        reader_fn: Callable[[mne_bids.BIDSPath], Sequence],
+        pool_size: int = 4,
+        bids_paths: list[mne_bids.BIDSPath] | None = None,
+        root: Path | str | None = None,
+        subjects: str | list[str] | None = None,
+        sessions: str | list[str] | None = None,
+        tasks: str | list[str] | None = None,
+        acquisitions: str | list[str] | None = None,
+        runs: str | list[str] | None = None,
+        processings: str | list[str] | None = None,
+        recordings: str | list[str] | None = None,
+        spaces: str | list[str] | None = None,
+        splits: str | list[str] | None = None,
+        descriptions: str | list[str] | None = None,
+        suffixes: str | list[str] | None = None,
+        extensions: str | list[str] | None = [
+            ".con",
+            ".sqd",
+            ".pdf",
+            ".fif",
+            ".ds",
+            ".vhdr",
+            ".set",
+            ".edf",
+            ".bdf",
+            ".EDF",
+            ".snirf",
+            ".cdt",
+            ".mef",
+            ".nwb",
+        ],
+        datatypes: str | list[str] | None = None,
+        check: bool = False,
+    ):
+        if bids_paths is None:
+            bids_paths = mne_bids.find_matching_paths(
+                root=root,
+                subjects=subjects,
+                sessions=sessions,
+                tasks=tasks,
+                acquisitions=acquisitions,
+                runs=runs,
+                processings=processings,
+                recordings=recordings,
+                spaces=spaces,
+                splits=splits,
+                descriptions=descriptions,
+                suffixes=suffixes,
+                extensions=extensions,
+                datatypes=datatypes,
+                check=check,
+                ignore_json=True,
+            )
+        # Filter out _epo.fif files:
+        bids_paths = [
+            bids_path
+            for bids_path in bids_paths
+            if not (bids_path.suffix == "epo" and bids_path.extension == ".fif")
+        ]
+        self.bids_paths = bids_paths
+        self.reader_fn = reader_fn
+        self.pool_size = pool_size
+
+    def __add__(self, other):
+        assert isinstance(other, BIDSIterableDataset)
+        return BIDSIterableDataset(
+            reader_fn=self.reader_fn,
+            bids_paths=self.bids_paths + other.bids_paths,
+            pool_size=self.pool_size,
+        )
+
+    def __iadd__(self, other):
+        assert isinstance(other, BIDSIterableDataset)
+        self.bids_paths += other.bids_paths
+        return self
+
+    def __iter__(self):
+        worker_info = get_worker_info()
+        if worker_info is None:  # single-process data loading, return the full iterator
+            bids_paths = self.bids_paths
+        else:  # in a worker process
+            # split workload
+            bids_paths = self.bids_paths[worker_info.id :: worker_info.num_workers]
+
+        pool = []
+        end = False
+        paths_it = iter(random.sample(bids_paths, k=len(bids_paths)))
+        while not (end and len(pool) == 0):
+            while not end and len(pool) < self.pool_size:
+                try:
+                    bids_path = next(paths_it)
+                    ds = self.reader_fn(bids_path)
+                    if ds is None:
+                        print(f"Skipping {bids_path} as it is too short.")
+                        continue
+                    idx = iter(random.sample(range(len(ds)), k=len(ds)))
+                    pool.append((ds, idx))
+                except StopIteration:
+                    end = True
+            i_pool = random.randint(0, len(pool) - 1)
+            ds, idx = pool[i_pool]
+            try:
+                i_ds = next(idx)
+                yield ds[i_ds]
+            except StopIteration:
+                pool.pop(i_pool)
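Finally, a usage sketch for the new iterable dataset (module path taken from the file list above; `my_reader_fn` as in the class docstring; the BIDS root is illustrative). Each DataLoader worker takes a strided slice of the BIDS paths, keeps at most `pool_size` recordings open, and yields windows from a randomly chosen one, which approximates shuffling without loading the whole corpus:

    from torch.utils.data import DataLoader

    from braindecode.datasets.experimental import BIDSIterableDataset

    dataset = BIDSIterableDataset(
        reader_fn=my_reader_fn,   # as defined in the docstring example
        pool_size=4,              # recordings held open per worker
        root="root/of/my/bids/dataset/",
    )
    loader = DataLoader(dataset, batch_size=64, num_workers=2)
    for X, y, crop_inds in loader:  # windowed datasets yield (X, y, crop_inds)
        break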