braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,3 @@
1
- # mypy: ignore-errors
2
1
  """
3
2
  Shared validation utilities for Hub format operations.
4
3
 
@@ -11,7 +10,7 @@ This module provides validation functions used by hub.py to avoid code duplicati
11
10
 
12
11
  from typing import Any, List, Tuple
13
12
 
14
- from ..registry import get_dataset_type
13
+ from .registry import get_dataset_type
15
14
 
16
15
 
17
16
  def validate_dataset_uniformity(
@@ -25,35 +25,35 @@ def create_from_mne_raw(
25
25
  drop_bad_windows: bool = True,
26
26
  accepted_bads_ratio: float = 0.0,
27
27
  ) -> BaseConcatDataset:
28
- """Create WindowsDatasets from mne.RawArrays.
28
+ """Create WindowsDatasets from mne.RawArrays
29
29
 
30
30
  Parameters
31
31
  ----------
32
- raws : array-like
32
+ raws: array-like
33
33
  list of mne.RawArrays
34
- trial_start_offset_samples : int
34
+ trial_start_offset_samples: int
35
35
  start offset from original trial onsets in samples
36
- trial_stop_offset_samples : int
36
+ trial_stop_offset_samples: int
37
37
  stop offset from original trial stop in samples
38
- window_size_samples : int
38
+ window_size_samples: int
39
39
  window size
40
- window_stride_samples : int
40
+ window_stride_samples: int
41
41
  stride between windows
42
- drop_last_window : bool
42
+ drop_last_window: bool
43
43
  whether or not have a last overlapping window, when
44
44
  windows do not equally divide the continuous signal
45
- descriptions : array-like
45
+ descriptions: array-like
46
46
  list of dicts or pandas.Series with additional information about the raws
47
- mapping : dict(str: int)
47
+ mapping: dict(str: int)
48
48
  mapping from event description to target value
49
- preload : bool
49
+ preload: bool
50
50
  if True, preload the data of the Epochs objects.
51
- drop_bad_windows : bool
51
+ drop_bad_windows: bool
52
52
  If True, call `.drop_bad()` on the resulting mne.Epochs object. This
53
53
  step allows identifying e.g., windows that fall outside of the
54
54
  continuous recording. It is suggested to run this step here as otherwise
55
55
  the BaseConcatDataset has to be updated as well.
56
- accepted_bads_ratio : float, optional
56
+ accepted_bads_ratio: float, optional
57
57
  Acceptable proportion of trials withinconsistent length in a raw. If
58
58
  the number of trials whose length is exceeded by the window size is
59
59
  smaller than this, then only the corresponding trials are dropped, but
@@ -62,7 +62,7 @@ def create_from_mne_raw(
62
62
 
63
63
  Returns
64
64
  -------
65
- windows_datasets : BaseConcatDataset
65
+ windows_datasets: BaseConcatDataset
66
66
  X and y transformed to a dataset format that is compatible with skorch
67
67
  and braindecode
68
68
  """
@@ -101,23 +101,23 @@ def create_from_mne_epochs(
101
101
  window_stride_samples: int,
102
102
  drop_last_window: bool,
103
103
  ) -> BaseConcatDataset:
104
- """Create WindowsDatasets from mne.Epochs.
104
+ """Create WindowsDatasets from mne.Epochs
105
105
 
106
106
  Parameters
107
107
  ----------
108
- list_of_epochs : array-like
108
+ list_of_epochs: array-like
109
109
  list of mne.Epochs
110
- window_size_samples : int
110
+ window_size_samples: int
111
111
  window size
112
- window_stride_samples : int
112
+ window_stride_samples: int
113
113
  stride between windows
114
- drop_last_window : bool
114
+ drop_last_window: bool
115
115
  whether or not have a last overlapping window, when
116
116
  windows do not equally divide the continuous signal
117
117
 
118
118
  Returns
119
119
  -------
120
- windows_datasets : BaseConcatDataset
120
+ windows_datasets: BaseConcatDataset
121
121
  X and y transformed to a dataset format that is compatible with skorch
122
122
  and braindecode
123
123
  """
@@ -90,14 +90,14 @@ def fetch_data_with_moabb(
90
90
 
91
91
  Parameters
92
92
  ----------
93
- dataset_name : str | moabb.datasets.base.BaseDataset
93
+ dataset_name: str | moabb.datasets.base.BaseDataset
94
94
  the name of a dataset included in moabb
95
- subject_ids : list(int) | int
95
+ subject_ids: list(int) | int
96
96
  (list of) int of subject(s) to be fetched
97
- dataset_kwargs : dict, optional
97
+ dataset_kwargs: dict, optional
98
98
  optional dictionary containing keyword arguments
99
99
  to pass to the moabb dataset when instantiating it.
100
- data_load_kwargs : dict, optional
100
+ data_load_kwargs: dict, optional
101
101
  optional dictionary containing keyword arguments
102
102
  to pass to the moabb dataset's load_data method.
103
103
  Allows using the moabb cache_config=None and
@@ -105,8 +105,8 @@ def fetch_data_with_moabb(
105
105
 
106
106
  Returns
107
107
  -------
108
- raws : mne.Raw
109
- info : pandas.DataFrame
108
+ raws: mne.Raw
109
+ info: pandas.DataFrame
110
110
  """
111
111
  if isinstance(dataset_name, str):
112
112
  dataset = _find_dataset_in_moabb(dataset_name, dataset_kwargs)
@@ -127,15 +127,15 @@ class MOABBDataset(BaseConcatDataset):
127
127
 
128
128
  Parameters
129
129
  ----------
130
- dataset_name : str
130
+ dataset_name: str
131
131
  name of dataset included in moabb to be fetched
132
- subject_ids : list(int) | int | None
132
+ subject_ids: list(int) | int | None
133
133
  (list of) int of subject(s) to be fetched. If None, data of all
134
134
  subjects is fetched.
135
- dataset_kwargs : dict, optional
135
+ dataset_kwargs: dict, optional
136
136
  optional dictionary containing keyword arguments
137
137
  to pass to the moabb dataset when instantiating it.
138
- dataset_load_kwargs : dict, optional
138
+ dataset_load_kwargs: dict, optional
139
139
  optional dictionary containing keyword arguments
140
140
  to pass to the moabb dataset's load_data method.
141
141
  Allows using the moabb cache_config=None and
@@ -9,6 +9,7 @@ Note:
9
9
  - The signal unit may not be uV and further examination is required.
10
10
  - The spectrum shows that the signal may have been band-pass filtered from about 2 - 33Hz,
11
11
  which needs to be further determined.
12
+
12
13
  """
13
14
 
14
15
  # Authors: Mohammad Bayazi <mj.darvishi92@gmail.com>
@@ -31,7 +32,6 @@ from joblib import Parallel, delayed
31
32
  from mne.datasets import fetch_dataset
32
33
 
33
34
  from braindecode.datasets.base import BaseConcatDataset, RawDataset
34
- from braindecode.datasets.utils import _correct_dataset_path
35
35
 
36
36
  NMT_URL = "https://zenodo.org/record/10909103/files/NMT.zip"
37
37
  NMT_archive_name = "NMT.zip"
@@ -66,17 +66,17 @@ class NMT(BaseConcatDataset):
66
66
 
67
67
  Parameters
68
68
  ----------
69
- path : str
69
+ path: str
70
70
  Parent directory of the dataset.
71
- recording_ids : list(int) | int
71
+ recording_ids: list(int) | int
72
72
  A (list of) int of recording id(s) to be read (order matters and will
73
73
  overwrite default chronological order, e.g. if recording_ids=[1,0],
74
74
  then the first recording returned by this class will be chronologically
75
75
  later than the second recording. Provide recording_ids in ascending
76
76
  order to preserve chronological order.).
77
- target_name : str
77
+ target_name: str
78
78
  Can be "pathological", "gender", or "age".
79
- preload : bool
79
+ preload: bool
80
80
  If True, preload the data of the Raw objects.
81
81
 
82
82
  References
@@ -96,34 +96,22 @@ class NMT(BaseConcatDataset):
96
96
  preload=False,
97
97
  n_jobs=1,
98
98
  ):
99
- # Convert empty string to None for consistency
100
- if path == "":
101
- path = None
99
+ # correct the path if needed
100
+ if path is not None:
101
+ list_csv = glob.glob(f"{path}/**/Labels.csv", recursive=True)
102
+ if isinstance(list_csv, list) and len(list_csv) > 0:
103
+ path = Path(list_csv[0]).parent
102
104
 
103
- # Download dataset if not present
104
- if path is None:
105
+ if path is None or len(list_csv) == 0:
105
106
  path = fetch_dataset(
106
107
  dataset_params=NMT_dataset_params,
107
- path=None,
108
+ path=Path(path) if path is not None else None,
108
109
  processor="unzip",
109
110
  force_update=False,
110
111
  )
111
112
  # First time we fetch the dataset, we need to move the files to the
112
113
  # correct directory.
113
- path = _correct_dataset_path(
114
- path, NMT_archive_name, "nmt_scalp_eeg_dataset"
115
- )
116
- else:
117
- # Validate that the provided path is a valid NMT dataset
118
- if not Path(f"{path}/Labels.csv").exists():
119
- raise ValueError(
120
- f"The provided path {path} does not contain a valid "
121
- "NMT dataset (missing Labels.csv). Please ensure the "
122
- "path points directly to the NMT dataset directory."
123
- )
124
- path = _correct_dataset_path(
125
- path, NMT_archive_name, "nmt_scalp_eeg_dataset"
126
- )
114
+ path = _correct_path(path)
127
115
 
128
116
  # Get all file paths
129
117
  file_paths = glob.glob(
@@ -149,10 +137,7 @@ class NMT(BaseConcatDataset):
149
137
  os.path.join(path, "Labels.csv"), index_col="recordname"
150
138
  )
151
139
  if recording_ids is not None:
152
- # Match metadata by record name instead of position to fix alignment bug
153
- # when CSV order differs from sorted file order
154
- selected_recordnames = [os.path.basename(fp) for fp in file_paths]
155
- description = description.loc[selected_recordnames]
140
+ description = description.iloc[recording_ids]
156
141
  description.replace(
157
142
  {
158
143
  "not specified": "X",
@@ -191,6 +176,39 @@ class NMT(BaseConcatDataset):
191
176
  return base_dataset
192
177
 
193
178
 
179
+ def _correct_path(path: str):
180
+ """
181
+ Check if the path is correct and rename the file if needed.
182
+
183
+ Parameters
184
+ ----------
185
+ path: basestring
186
+ Path to the file.
187
+
188
+ Returns
189
+ -------
190
+ path: basestring
191
+ Corrected path.
192
+ """
193
+ if not Path(path).exists():
194
+ unzip_file_name = f"{NMT_archive_name}.unzip"
195
+ if (Path(path).parent / unzip_file_name).exists():
196
+ try:
197
+ os.rename(
198
+ src=Path(path).parent / unzip_file_name,
199
+ dst=Path(path),
200
+ )
201
+
202
+ except PermissionError:
203
+ raise PermissionError(
204
+ f"Please rename {Path(path).parent / unzip_file_name}"
205
+ + f"manually to {path} and try again."
206
+ )
207
+ path = os.path.join(path, "nmt_scalp_eeg_dataset")
208
+
209
+ return path
210
+
211
+
194
212
  def _get_header(*args):
195
213
  all_paths = {**_NMT_PATHS}
196
214
  return all_paths[args[0]]
@@ -198,24 +216,19 @@ def _get_header(*args):
198
216
 
199
217
  def _fake_pd_read_csv(*args, **kwargs):
200
218
  # Create a list of lists to hold the data
201
- # Updated to match the file IDs from _NMT_PATHS (0000036-0000042)
202
- # to align with the mocked glob.glob return value
203
219
  data = [
204
- ["0000036.edf", "normal", 35, "male", "train"],
205
- ["0000037.edf", "abnormal", 28, "female", "test"],
206
- ["0000038.edf", "normal", 62, "male", "train"],
207
- ["0000039.edf", "abnormal", 41, "female", "test"],
208
- ["0000040.edf", "normal", 19, "male", "train"],
209
- ["0000041.edf", "abnormal", 55, "female", "test"],
210
- ["0000042.edf", "normal", 71, "male", "train"],
220
+ ["0000001.edf", "normal", 35, "male", "train"],
221
+ ["0000002.edf", "abnormal", 28, "female", "test"],
222
+ ["0000003.edf", "normal", 62, "male", "train"],
223
+ ["0000004.edf", "abnormal", 41, "female", "test"],
224
+ ["0000005.edf", "normal", 19, "male", "train"],
225
+ ["0000006.edf", "abnormal", 55, "female", "test"],
226
+ ["0000007.edf", "normal", 71, "male", "train"],
211
227
  ]
212
228
 
213
229
  # Create the DataFrame, specifying column names
214
230
  df = pd.DataFrame(data, columns=["recordname", "label", "age", "gender", "loc"])
215
231
 
216
- # Set recordname as index to match the real pd.read_csv behavior with index_col="recordname"
217
- df.set_index("recordname", inplace=True)
218
-
219
232
  return df
220
233
 
221
234
 
@@ -275,33 +288,18 @@ _NMT_PATHS = {
275
288
  class _NMTMock(NMT):
276
289
  """Mocked class for testing and examples."""
277
290
 
278
- @mock.patch("pathlib.Path.exists", return_value=True)
279
- @mock.patch("braindecode.datasets.nmt._correct_dataset_path")
280
- @mock.patch("mne.datasets.fetch_dataset")
281
- @mock.patch("pandas.read_csv", new=_fake_pd_read_csv)
282
- @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
283
291
  @mock.patch("glob.glob", return_value=_NMT_PATHS.keys())
292
+ @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
293
+ @mock.patch("pandas.read_csv", new=_fake_pd_read_csv)
284
294
  def __init__(
285
295
  self,
286
296
  mock_glob,
287
- mock_fetch,
288
- mock_correct_path,
289
- mock_path_exists,
290
297
  path,
291
298
  recording_ids=None,
292
299
  target_name="pathological",
293
300
  preload=False,
294
301
  n_jobs=1,
295
302
  ):
296
- # Prevent download by providing a dummy path if empty/None
297
- if not path:
298
- path = "mocked_nmt_path"
299
-
300
- # Mock fetch_dataset to return a valid path without downloading
301
- mock_fetch.return_value = path
302
- # Mock _correct_dataset_path to return the path as-is
303
- mock_correct_path.side_effect = lambda p, *args, **kwargs: p
304
-
305
303
  with warnings.catch_warnings():
306
304
  warnings.filterwarnings("ignore", message="Cannot save date file")
307
305
  super().__init__(
@@ -1,4 +1,6 @@
1
- """PhysioNet Challenge 2018 dataset."""
1
+ """
2
+ PhysioNet Challenge 2018 dataset.
3
+ """
2
4
 
3
5
  # Authors: Hubert Banville <hubert.jbanville@gmail.com>
4
6
  # Bruno Aristimunha <b.aristimunha@gmail.com>
@@ -251,7 +253,7 @@ class SleepPhysionetChallenge2018(BaseConcatDataset):
251
253
 
252
254
  Parameters
253
255
  ----------
254
- subject_ids : list(int) | str | None
256
+ subject_ids: list(int) | str | None
255
257
  (list of) int of subject(s) to be loaded.
256
258
  - If `None`, loads all subjects (both training and test sets [no label associated]).
257
259
  - If `"training"`, loads only the training set subjects.
@@ -263,7 +265,7 @@ class SleepPhysionetChallenge2018(BaseConcatDataset):
263
265
  is used. If it doesn't exist, the "~/mne_data" directory is used. If
264
266
  the dataset is not found under the given path, the data will be
265
267
  automatically downloaded to the specified folder.
266
- load_eeg_only : bool
268
+ load_eeg_only: bool
267
269
  If True, only load the EEG channels and discard the others (EOG, EMG,
268
270
  temperature, respiration) to avoid resampling the other signals.
269
271
  preproc : list(Preprocessor) | None
@@ -25,18 +25,18 @@ class SleepPhysionet(BaseConcatDataset):
25
25
 
26
26
  Parameters
27
27
  ----------
28
- subject_ids : list(int) | int | None
28
+ subject_ids: list(int) | int | None
29
29
  (list of) int of subject(s) to be loaded. If None, load all available
30
30
  subjects.
31
- recording_ids : list(int) | None
31
+ recording_ids: list(int) | None
32
32
  Recordings to load per subject (each subject except 13 has two
33
33
  recordings). Can be [1], [2] or [1, 2] (same as None).
34
- preload : bool
34
+ preload: bool
35
35
  If True, preload the data of the Raw objects.
36
- load_eeg_only : bool
36
+ load_eeg_only: bool
37
37
  If True, only load the EEG channels and discard the others (EOG, EMG,
38
38
  temperature, respiration) to avoid resampling the other signals.
39
- crop_wake_mins : float
39
+ crop_wake_mins: float
40
40
  Number of minutes of wake time to keep before the first sleep event
41
41
  and after the last sleep event. Used to reduce the imbalance in this
42
42
  dataset. Default of 30 mins.
@@ -1,6 +1,5 @@
1
1
  """
2
- Dataset classes for the Temple University Hospital (TUH) EEG Corpus and the.
3
-
2
+ Dataset classes for the Temple University Hospital (TUH) EEG Corpus and the
4
3
  TUH Abnormal EEG Corpus.
5
4
  """
6
5
 
@@ -27,32 +26,31 @@ from .base import BaseConcatDataset, RawDataset
27
26
 
28
27
 
29
28
  class TUH(BaseConcatDataset):
30
- """Temple University Hospital (TUH) EEG Corpus.
31
-
29
+ """Temple University Hospital (TUH) EEG Corpus
32
30
  (www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tueg).
33
31
 
34
32
  Parameters
35
33
  ----------
36
- path : str
34
+ path: str
37
35
  Parent directory of the dataset.
38
- recording_ids : list(int) | int
36
+ recording_ids: list(int) | int
39
37
  A (list of) int of recording id(s) to be read (order matters and will
40
38
  overwrite default chronological order, e.g. if recording_ids=[1,0],
41
39
  then the first recording returned by this class will be chronologically
42
40
  later then the second recording. Provide recording_ids in ascending
43
41
  order to preserve chronological order.).
44
- target_name : str
42
+ target_name: str
45
43
  Can be 'gender', or 'age'.
46
- preload : bool
44
+ preload: bool
47
45
  If True, preload the data of the Raw objects.
48
- add_physician_reports : bool
46
+ add_physician_reports: bool
49
47
  If True, the physician reports will be read from disk and added to the
50
48
  description.
51
- rename_channels : bool
49
+ rename_channels: bool
52
50
  If True, rename the EEG channels to the standard 10-05 system.
53
- set_montage : bool
51
+ set_montage: bool
54
52
  If True, set the montage to the standard 10-05 system.
55
- n_jobs : int
53
+ n_jobs: int
56
54
  Number of jobs to be used to read files in parallel.
57
55
  """
58
56
 
@@ -381,31 +379,30 @@ def _parse_age_and_gender_from_edf_header(file_path):
381
379
 
382
380
  class TUHAbnormal(TUH):
383
381
  """Temple University Hospital (TUH) Abnormal EEG Corpus.
384
-
385
382
  see www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tuab
386
383
 
387
384
  Parameters
388
385
  ----------
389
- path : str
386
+ path: str
390
387
  Parent directory of the dataset.
391
- recording_ids : list(int) | int
388
+ recording_ids: list(int) | int
392
389
  A (list of) int of recording id(s) to be read (order matters and will
393
390
  overwrite default chronological order, e.g. if recording_ids=[1,0],
394
391
  then the first recording returned by this class will be chronologically
395
392
  later then the second recording. Provide recording_ids in ascending
396
393
  order to preserve chronological order.).
397
- target_name : str
394
+ target_name: str
398
395
  Can be 'pathological', 'gender', or 'age'.
399
- preload : bool
396
+ preload: bool
400
397
  If True, preload the data of the Raw objects.
401
- add_physician_reports : bool
398
+ add_physician_reports: bool
402
399
  If True, the physician reports will be read from disk and added to the
403
400
  description.
404
- rename_channels : bool
401
+ rename_channels: bool
405
402
  If True, rename the EEG channels to the standard 10-05 system.
406
- set_montage : bool
403
+ set_montage: bool
407
404
  If True, set the montage to the standard 10-05 system.
408
- n_jobs : int
405
+ n_jobs: int
409
406
  Number of jobs to be used to read files in parallel.
410
407
  """
411
408
 
@@ -26,32 +26,31 @@ def create_from_X_y(
26
26
  window_size_samples: int | None = None,
27
27
  window_stride_samples: int | None = None,
28
28
  ) -> BaseConcatDataset:
29
- """Create a BaseConcatDataset of WindowsDatasets from X and y to be used for.
30
-
29
+ """Create a BaseConcatDataset of WindowsDatasets from X and y to be used for
31
30
  decoding with skorch and braindecode, where X is a list of pre-cut trials
32
31
  and y are corresponding targets.
33
32
 
34
33
  Parameters
35
34
  ----------
36
- X : array-like
35
+ X: array-like
37
36
  list of pre-cut trials as n_trials x n_channels x n_times
38
- y : array-like
37
+ y: array-like
39
38
  targets corresponding to the trials
40
- drop_last_window : bool
39
+ drop_last_window: bool
41
40
  whether or not have a last overlapping window, when
42
41
  windows/windows do not equally divide the continuous signal
43
- sfreq : float
42
+ sfreq: float
44
43
  Sampling frequency of signals.
45
- ch_names : array-like
44
+ ch_names: array-like
46
45
  Names of the channels.
47
- window_size_samples : int
46
+ window_size_samples: int
48
47
  window size
49
- window_stride_samples : int
48
+ window_stride_samples: int
50
49
  stride between windows
51
50
 
52
51
  Returns
53
52
  -------
54
- windows_datasets : BaseConcatDataset
53
+ windows_datasets: BaseConcatDataset
55
54
  X and y transformed to a dataset format that is compatible with skorch
56
55
  and braindecode
57
56
  """
@@ -1,4 +1,6 @@
1
- """Utilities for data manipulation."""
1
+ """
2
+ Utilities for data manipulation.
3
+ """
2
4
 
3
5
  from .channel_utils import (
4
6
  division_channels_idx,
@@ -9,7 +11,6 @@ from .serialization import (
9
11
  load_concat_dataset,
10
12
  save_concat_dataset,
11
13
  )
12
- from .util import infer_signal_properties
13
14
 
14
15
 
15
16
  def __getattr__(name):
@@ -58,5 +59,4 @@ __all__ = [
58
59
  "_check_save_dir_empty",
59
60
  "match_hemisphere_chans",
60
61
  "division_channels_idx",
61
- "infer_signal_properties",
62
62
  ]
@@ -1,4 +1,6 @@
1
- """Convenience functions for storing and loading of windows datasets."""
1
+ """
2
+ Convenience functions for storing and loading of windows datasets.
3
+ """
2
4
 
3
5
  # Authors: Lukas Gemein <l.gemein@gmail.com>
4
6
  #
@@ -33,25 +35,24 @@ def save_concat_dataset(path, concat_dataset, overwrite=False):
33
35
 
34
36
 
35
37
  def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
36
- """Load a stored BaseConcatDataset from.
37
-
38
+ """Load a stored BaseConcatDataset from
38
39
  files.
39
40
 
40
41
  Parameters
41
42
  ----------
42
- path : pathlib.Path
43
+ path: pathlib.Path
43
44
  Path to the directory of the .fif / -epo.fif and .json files.
44
- preload : bool
45
+ preload: bool
45
46
  Whether to preload the data.
46
- ids_to_load : None | list(int)
47
+ ids_to_load: None | list(int)
47
48
  Ids of specific files to load.
48
- target_name : None or str
49
+ target_name: None or str
49
50
  Load specific description column as target. If not given, take saved
50
51
  target name.
51
52
 
52
53
  Returns
53
54
  -------
54
- concat_dataset : BaseConcatDataset
55
+ concat_dataset: BaseConcatDataset
55
56
  """
56
57
  # assume we have a single concat dataset to load
57
58
  is_raw = (path / "0-raw.fif").is_file()
@@ -137,7 +138,7 @@ def _load_signals(fif_file, preload, is_raw):
137
138
  with open(pkl_file, "rb") as f:
138
139
  signals = pickle.load(f)
139
140
 
140
- if all(Path(f).exists() for f in signals.filenames):
141
+ if all(f.exists() for f in signals.filenames):
141
142
  if preload:
142
143
  signals.load_data()
143
144
  return signals
@@ -174,27 +175,26 @@ def _load_signals(fif_file, preload, is_raw):
174
175
 
175
176
 
176
177
  def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
177
- """Load a stored BaseConcatDataset from.
178
-
178
+ """Load a stored BaseConcatDataset from
179
179
  files.
180
180
 
181
181
  Parameters
182
182
  ----------
183
- path : str | pathlib.Path
183
+ path: str | pathlib.Path
184
184
  Path to the directory of the .fif / -epo.fif and .json files.
185
- preload : bool
185
+ preload: bool
186
186
  Whether to preload the data.
187
- ids_to_load : list of int | None
187
+ ids_to_load: list of int | None
188
188
  Ids of specific files to load.
189
- target_name : str | list | None
189
+ target_name: str | list | None
190
190
  Load specific description column as target. If not given, take saved
191
191
  target name.
192
- n_jobs : int
192
+ n_jobs: int
193
193
  Number of jobs to be used to read files in parallel.
194
194
 
195
195
  Returns
196
196
  -------
197
- concat_dataset : BaseConcatDataset
197
+ concat_dataset: BaseConcatDataset
198
198
  """
199
199
  # Make sure we always work with a pathlib.Path
200
200
  path = Path(path)
@@ -306,11 +306,9 @@ def _load_kwargs_json(kwargs_name, sub_dir):
306
306
 
307
307
 
308
308
  def _is_outdated_saved(path):
309
- """Data was saved in the old way if there are 'description.json', '-raw.fif'.
310
-
309
+ """Data was saved in the old way if there are 'description.json', '-raw.fif'
311
310
  or '-epo.fif' in path (no subdirectories) OR if there are more 'fif' files
312
- than 'description.json' files.
313
- """
311
+ than 'description.json' files."""
314
312
  description_files = glob(os.path.join(path, "**/description.json"))
315
313
  fif_files = glob(os.path.join(path, "**/*-raw.fif")) + glob(
316
314
  os.path.join(path, "**/*-epo.fif")
@@ -344,7 +342,7 @@ def _check_save_dir_empty(save_dir):
344
342
  Directory under which a `BaseConcatDataset` will be saved.
345
343
 
346
344
  Raises
347
- ------
345
+ -------
348
346
  FileExistsError
349
347
  If ``save_dir`` is not a valid directory for saving.
350
348
  """