braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/datasets/xy.py
@@ -3,19 +3,29 @@
 #
 # License: BSD (3-clause)
 
-import numpy as np
-import pandas as pd
+from __future__ import annotations
+
 import logging
+
 import mne
+import numpy as np
+import pandas as pd
+from numpy.typing import ArrayLike, NDArray
 
-from .base import BaseDataset, BaseConcatDataset
+from .base import BaseConcatDataset, BaseDataset
 
 log = logging.getLogger(__name__)
 
 
 def create_from_X_y(
-        X, y, drop_last_window, sfreq, ch_names=None, window_size_samples=None,
-        window_stride_samples=None):
+    X: NDArray,
+    y: ArrayLike,
+    drop_last_window: bool,
+    sfreq: float,
+    ch_names: ArrayLike = None,
+    window_size_samples: int | None = None,
+    window_stride_samples: int | None = None,
+) -> BaseConcatDataset:
     """Create a BaseConcatDataset of WindowsDatasets from X and y to be used for
     decoding with skorch and braindecode, where X is a list of pre-cut trials
     and y are corresponding targets.
@@ -46,7 +56,9 @@ def create_from_X_y(
     """
     # Prevent circular import
     from ..preprocessing.windowers import (
-        create_fixed_length_windows, )
+        create_fixed_length_windows,
+    )
+
     n_samples_per_x = []
     base_datasets = []
     if ch_names is None:
@@ -57,16 +69,19 @@ def create_from_X_y(
         n_samples_per_x.append(x.shape[1])
         info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
         raw = mne.io.RawArray(x, info)
-        base_dataset = BaseDataset(raw, pd.Series({"target": target}),
-                                   target_name="target")
+        base_dataset = BaseDataset(
+            raw, pd.Series({"target": target}), target_name="target"
+        )
         base_datasets.append(base_dataset)
     base_datasets = BaseConcatDataset(base_datasets)
 
     if window_size_samples is None and window_stride_samples is None:
         if not len(np.unique(n_samples_per_x)) == 1:
-            raise ValueError("if 'window_size_samples' and "
-                             "'window_stride_samples' are None, "
-                             "all trials have to have the same length")
+            raise ValueError(
+                "if 'window_size_samples' and "
+                "'window_stride_samples' are None, "
+                "all trials have to have the same length"
+            )
         window_size_samples = n_samples_per_x[0]
         window_stride_samples = n_samples_per_x[0]
     windows_datasets = create_fixed_length_windows(
@@ -75,6 +90,6 @@ def create_from_X_y(
         stop_offset_samples=None,
         window_size_samples=window_size_samples,
         window_stride_samples=window_stride_samples,
-        drop_last_window=drop_last_window
+        drop_last_window=drop_last_window,
     )
     return windows_datasets
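
The user-visible change in this file is the fully type-annotated signature of create_from_X_y, which now declares a BaseConcatDataset return type. A minimal usage sketch against the new signature; the array shapes, channel names, and window sizes below are illustrative only, not taken from the package documentation:

    import numpy as np

    from braindecode.datasets import create_from_X_y

    # Two synthetic pre-cut trials: 3 channels x 200 samples, sampled at 100 Hz.
    X = np.random.randn(2, 3, 200)
    y = np.array([0, 1])

    windows_ds = create_from_X_y(
        X,
        y,
        drop_last_window=False,
        sfreq=100.0,
        ch_names=["C3", "Cz", "C4"],
        window_size_samples=100,
        window_stride_samples=100,
    )
    print(len(windows_ds))  # total number of windows across both trials
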
braindecode/datautil/__init__.py
@@ -2,32 +2,48 @@
 Utilities for data manipulation.
 """
 
-
 from .serialization import (
-    save_concat_dataset, load_concat_dataset, _check_save_dir_empty)
+    _check_save_dir_empty,
+    load_concat_dataset,
+    save_concat_dataset,
+)
 
 
 def __getattr__(name):
     # ideas from https://stackoverflow.com/a/57110249/1469195
-    from warnings import warn
     import importlib
-    if name == 'create_from_X_y':
-        warn('create_from_X_y has been moved to datasets, please use from braindecode.datasets import create_from_X_y')
-        xy = importlib.import_module('..datasets.xy', __package__)
+    from warnings import warn
+
+    if name == "create_from_X_y":
+        warn(
+            "create_from_X_y has been moved to datasets, please use from braindecode.datasets import create_from_X_y"
+        )
+        xy = importlib.import_module("..datasets.xy", __package__)
         return xy.create_from_X_y
-    if name in ['create_from_mne_raw', 'create_from_mne_epochs']:
-        warn(f'{name} has been moved to datasets, please use from braindecode.datasets import {name}')
-        mne = importlib.import_module('..datasets.mne', __package__)
+    if name in ["create_from_mne_raw", "create_from_mne_epochs"]:
+        warn(
+            f"{name} has been moved to datasets, please use from braindecode.datasets import {name}"
+        )
+        mne = importlib.import_module("..datasets.mne", __package__)
         return mne.__dict__[name]
-    if name in ['zscore', 'scale', 'exponential_moving_demean',
-                'exponential_moving_standardize', 'filterbank',
-                'preprocess', 'Preprocessor']:
-        warn(f'{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}')
-        preprocess = importlib.import_module('..preprocessing.preprocess', __package__)
+    if name in [
+        "scale",
+        "exponential_moving_demean",
+        "exponential_moving_standardize",
+        "filterbank",
+        "preprocess",
+        "Preprocessor",
+    ]:
+        warn(
+            f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
+        )
+        preprocess = importlib.import_module("..preprocessing.preprocess", __package__)
         return preprocess.__dict__[name]
-    if name in ['create_windows_from_events', 'create_fixed_length_windows']:
-        warn(f'{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}')
-        windowers = importlib.import_module('..preprocessing.windowers', __package__)
+    if name in ["create_windows_from_events", "create_fixed_length_windows"]:
+        warn(
+            f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
+        )
+        windowers = importlib.import_module("..preprocessing.windowers", __package__)
         return windowers.__dict__[name]
 
-    raise AttributeError('No possible import named ' + name)
+    raise AttributeError("No possible import named " + name)
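
braindecode.datautil keeps these names importable through the module-level __getattr__ shim, but each access warns and points to the new location (note that zscore is no longer redirected in 1.0.0). A short sketch of old versus new import paths, assuming the shim behaves as shown in the diff:

    import warnings

    # Canonical locations after the reorganisation:
    from braindecode.datasets import create_from_X_y
    from braindecode.preprocessing import create_fixed_length_windows

    # The old path still resolves through datautil.__getattr__, but warns:
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        from braindecode.datautil import create_windows_from_events  # noqa: F401
    print([str(w.message) for w in caught])
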
braindecode/datautil/serialization.py
@@ -6,8 +6,8 @@ Convenience functions for storing and loading of windows datasets.
 #
 # License: BSD (3-clause)
 
-import os
 import json
+import os
 import pickle
 import warnings
 from glob import glob
@@ -17,17 +17,24 @@ import mne
 import pandas as pd
 from joblib import Parallel, delayed
 
-from ..datasets.base import BaseDataset, BaseConcatDataset, WindowsDataset, EEGWindowsDataset
+from ..datasets.base import (
+    BaseConcatDataset,
+    BaseDataset,
+    EEGWindowsDataset,
+    WindowsDataset,
+)
 
 
 def save_concat_dataset(path, concat_dataset, overwrite=False):
-    warnings.warn('"save_concat_dataset()" is deprecated and will be removed in'
-                  ' the future. Use dataset.save() instead.')
+    warnings.warn(
+        '"save_concat_dataset()" is deprecated and will be removed in'
+        " the future. Use dataset.save() instead.",
+        UserWarning,
+    )
     concat_dataset.save(path=path, overwrite=overwrite)
 
 
-def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
-                                  target_name=None):
+def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
     """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
     files.
 
@@ -48,15 +55,16 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
     concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
     """
     # assume we have a single concat dataset to load
-    is_raw = (path / '0-raw.fif').is_file()
+    is_raw = (path / "0-raw.fif").is_file()
     assert not (not is_raw and target_name is not None), (
-        'Setting a new target is only supported for raws.')
-    is_epochs = (path / '0-epo.fif').is_file()
+        "Setting a new target is only supported for raws."
+    )
+    is_epochs = (path / "0-epo.fif").is_file()
     paths = [path]
     # assume we have multiple concat datasets to load
     if not (is_raw or is_epochs):
-        is_raw = (path / '0' / '0-raw.fif').is_file()
-        is_epochs = (path / '0' / '0-epo.fif').is_file()
+        is_raw = (path / "0" / "0-raw.fif").is_file()
+        is_epochs = (path / "0" / "0-epo.fif").is_file()
         paths = path.glob("*/")
         paths = sorted(paths, key=lambda p: int(p.name))
         if ids_to_load is not None:
@@ -64,33 +72,32 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
             ids_to_load = None
     # if we have neither a single nor multiple datasets, something went wrong
     assert is_raw or is_epochs, (
-        f'Expect either raw or epo to exist in {path} or in '
-        f'{path / "0"}')
+        f"Expect either raw or epo to exist in {path} or in {path / '0'}"
+    )
 
     datasets = []
     for path in paths:
         if is_raw and target_name is None:
-            target_file_name = path / 'target_name.json'
-            target_name = json.load(open(target_file_name, "r"))['target_name']
+            target_file_name = path / "target_name.json"
+            target_name = json.load(open(target_file_name, "r"))["target_name"]
 
         all_signals, description = _load_signals_and_description(
-            path=path, preload=preload, is_raw=is_raw,
-            ids_to_load=ids_to_load
+            path=path, preload=preload, is_raw=is_raw, ids_to_load=ids_to_load
         )
         for i_signal, signal in enumerate(all_signals):
             if is_raw:
                 datasets.append(
-                    BaseDataset(signal, description.iloc[i_signal],
-                                target_name=target_name))
-            else:
-                datasets.append(
-                    WindowsDataset(signal, description.iloc[i_signal])
+                    BaseDataset(
+                        signal, description.iloc[i_signal], target_name=target_name
+                    )
                 )
+            else:
+                datasets.append(WindowsDataset(signal, description.iloc[i_signal]))
     concat_ds = BaseConcatDataset(datasets)
-    for kwarg_name in ['raw_preproc_kwargs', 'window_kwargs', 'window_preproc_kwargs']:
-        kwarg_path = path / '.'.join([kwarg_name, 'json'])
+    for kwarg_name in ["raw_preproc_kwargs", "window_kwargs", "window_preproc_kwargs"]:
+        kwarg_path = path / ".".join([kwarg_name, "json"])
         if kwarg_path.exists():
-            with open(kwarg_path, 'r') as f:
+            with open(kwarg_path, "r") as f:
                 kwargs = json.load(f)
                 kwargs = [tuple(kwarg) for kwarg in kwargs]
                 setattr(concat_ds, kwarg_name, kwargs)
@@ -107,7 +114,8 @@ def _load_signals_and_description(path, preload, is_raw, ids_to_load=None):
        # '/home/schirrmr/data/preproced-tuh/all-sensors/11-raw.fif' ->
        # '11-raw.fif' -> 11
        ids_to_load = sorted(
-            [int(os.path.split(f)[-1].split('-')[0]) for f in file_names])
+            [int(os.path.split(f)[-1].split("-")[0]) for f in file_names]
+        )
    for i in ids_to_load:
        fif_file = path / file_name.format(i)
        all_signals.append(_load_signals(fif_file, preload, is_raw))
@@ -133,10 +141,10 @@ def _load_signals(fif_file, preload, is_raw):
     # If pickle didn't exist read via mne (likely slower) and save pkl after
     if is_raw:
         signals = mne.io.read_raw_fif(fif_file, preload=preload)
-    elif fif_file.name.endswith('-epo.fif'):
+    elif fif_file.name.endswith("-epo.fif"):
         signals = mne.read_epochs(fif_file, preload=preload)
     else:
-        raise ValueError('fif_file must end with raw.fif or epo.fif.')
+        raise ValueError("fif_file must end with raw.fif or epo.fif.")
 
     # Only do this for raw objects. Epoch objects are not picklable as they
     # hold references to open files in `signals._raw[0].fid`.
@@ -154,8 +162,7 @@ def _load_signals(fif_file, preload, is_raw):
     return signals
 
 
-def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
-                        n_jobs=1):
+def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
     """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
     files.
 
@@ -183,11 +190,14 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
     # if we encounter a dataset that was saved in 'the old way', call the
     # corresponding 'old' loading function
     if _is_outdated_saved(path):
-        warnings.warn("The way your dataset was saved is deprecated by now. "
-                      "Please save it again using dataset.save().", UserWarning)
+        warnings.warn(
+            "The way your dataset was saved is deprecated by now. "
+            "Please save it again using dataset.save().",
+            UserWarning,
+        )
         return _outdated_load_concat_dataset(
-            path=path, preload=preload, ids_to_load=ids_to_load,
-            target_name=target_name)
+            path=path, preload=preload, ids_to_load=ids_to_load, target_name=target_name
+        )
 
     # else we have a dataset saved in the new way with subdirectories in path
     # for every dataset with description.json and -epo.fif or -raw.fif,
@@ -197,9 +207,9 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
         ids_to_load = [p.name for p in path.iterdir()]
     ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
     ids_to_load = [str(i) for i in ids_to_load]
-    first_raw_fif_path = path / ids_to_load[0] / f'{ids_to_load[0]}-raw.fif'
+    first_raw_fif_path = path / ids_to_load[0] / f"{ids_to_load[0]}-raw.fif"
     is_raw = first_raw_fif_path.exists()
-    metadata_path = path / ids_to_load[0] / 'metadata_df.pkl'
+    metadata_path = path / ids_to_load[0] / "metadata_df.pkl"
     has_stored_windows = metadata_path.exists()
 
     # Parallelization of mne.read_epochs with preload=False fails with
@@ -207,8 +217,10 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
     # So ignore n_jobs in that case and load with a single job.
     if not is_raw and n_jobs != 1:
         warnings.warn(
-            'Parallelized reading with `preload=False` is not supported for '
-            'windowed data. Will use `n_jobs=1`.', UserWarning)
+            "Parallelized reading with `preload=False` is not supported for "
+            "windowed data. Will use `n_jobs=1`.",
+            UserWarning,
+        )
         n_jobs = 1
     datasets = Parallel(n_jobs)(
         delayed(_load_parallel)(path, i, preload, is_raw, has_stored_windows)
@@ -219,9 +231,9 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
 
 def _load_parallel(path, i, preload, is_raw, has_stored_windows):
     sub_dir = path / i
-    file_name_patterns = ['{}-raw.fif', '{}-epo.fif']
+    file_name_patterns = ["{}-raw.fif", "{}-epo.fif"]
     if all([(sub_dir / p.format(i)).exists() for p in file_name_patterns]):
-        raise FileExistsError('Found -raw.fif and -epo.fif in directory.')
+        raise FileExistsError("Found -raw.fif and -epo.fif in directory.")
 
     fif_name_pattern = file_name_patterns[0] if is_raw else file_name_patterns[1]
     fif_file_name = fif_name_pattern.format(i)
@@ -229,48 +241,51 @@ def _load_parallel(path, i, preload, is_raw, has_stored_windows):
 
     signals = _load_signals(fif_file_path, preload, is_raw)
 
-    description_file_path = sub_dir / 'description.json'
-    description = pd.read_json(description_file_path, typ='series')
+    description_file_path = sub_dir / "description.json"
+    description = pd.read_json(description_file_path, typ="series")
 
-    target_file_path = sub_dir / 'target_name.json'
+    target_file_path = sub_dir / "target_name.json"
     target_name = None
     if target_file_path.exists():
-        target_name = json.load(open(target_file_path, "r"))['target_name']
+        target_name = json.load(open(target_file_path, "r"))["target_name"]
 
     if is_raw and (not has_stored_windows):
         dataset = BaseDataset(signals, description, target_name)
     else:
-        window_kwargs = _load_kwargs_json('window_kwargs', sub_dir)
-        windows_ds_kwargs = [kwargs[1] for kwargs in window_kwargs if kwargs[0] == 'WindowsDataset']
+        window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
+        windows_ds_kwargs = [
+            kwargs[1] for kwargs in window_kwargs if kwargs[0] == "WindowsDataset"
+        ]
         windows_ds_kwargs = windows_ds_kwargs[0] if len(windows_ds_kwargs) == 1 else {}
         if is_raw:
-            metadata = pd.read_pickle(path / i / 'metadata_df.pkl')
+            metadata = pd.read_pickle(path / i / "metadata_df.pkl")
             dataset = EEGWindowsDataset(
                 signals,
                 metadata=metadata,
                 description=description,
-                targets_from=windows_ds_kwargs.get('targets_from', 'metadata'),
-                last_target_only=windows_ds_kwargs.get('last_target_only', True),
+                targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
+                last_target_only=windows_ds_kwargs.get("last_target_only", True),
            )
        else:
            # MNE epochs dataset
            dataset = WindowsDataset(
-                signals, description,
-                targets_from=windows_ds_kwargs.get('targets_from', 'metadata'),
-                last_target_only=windows_ds_kwargs.get('last_target_only', True)
+                signals,
+                description,
+                targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
+                last_target_only=windows_ds_kwargs.get("last_target_only", True),
            )
-        setattr(dataset, 'window_kwargs', window_kwargs)
-    for kwargs_name in ['raw_preproc_kwargs', 'window_preproc_kwargs']:
+        setattr(dataset, "window_kwargs", window_kwargs)
+    for kwargs_name in ["raw_preproc_kwargs", "window_preproc_kwargs"]:
        kwargs = _load_kwargs_json(kwargs_name, sub_dir)
        setattr(dataset, kwargs_name, kwargs)
    return dataset
 
 
 def _load_kwargs_json(kwargs_name, sub_dir):
-    kwargs_file_name = '.'.join([kwargs_name, 'json'])
+    kwargs_file_name = ".".join([kwargs_name, "json"])
     kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
     if os.path.exists(kwargs_file_path):
-        kwargs = json.load(open(kwargs_file_path, 'r'))
+        kwargs = json.load(open(kwargs_file_path, "r"))
        kwargs = [tuple(kwarg) for kwarg in kwargs]
        return kwargs
 
@@ -279,18 +294,28 @@ def _is_outdated_saved(path):
     """Data was saved in the old way if there are 'description.json', '-raw.fif'
     or '-epo.fif' in path (no subdirectories) OR if there are more 'fif' files
     than 'description.json' files."""
-    description_files = glob(os.path.join(path, '**/description.json'))
-    fif_files = glob(os.path.join(path, '**/*-raw.fif')) + glob(os.path.join(path, '**/*-epo.fif'))
+    description_files = glob(os.path.join(path, "**/description.json"))
+    fif_files = glob(os.path.join(path, "**/*-raw.fif")) + glob(
+        os.path.join(path, "**/*-epo.fif")
+    )
     multiple = len(description_files) != len(fif_files)
     kwargs_in_path = any(
-        [os.path.exists(os.path.join(path, kwarg_name))
-         for kwarg_name in ['raw_preproc_kwargs', 'window_kwargs',
-                            'window_preproc_kwargs']])
-    return (os.path.exists(os.path.join(path, 'description.json')) or
-            os.path.exists(os.path.join(path, '0-raw.fif')) or
-            os.path.exists(os.path.join(path, '0-epo.fif')) or
-            multiple or
-            kwargs_in_path)
+        [
+            os.path.exists(os.path.join(path, kwarg_name))
+            for kwarg_name in [
+                "raw_preproc_kwargs",
+                "window_kwargs",
+                "window_preproc_kwargs",
+            ]
+        ]
+    )
+    return (
+        os.path.exists(os.path.join(path, "description.json"))
+        or os.path.exists(os.path.join(path, "0-raw.fif"))
+        or os.path.exists(os.path.join(path, "0-epo.fif"))
+        or multiple
+        or kwargs_in_path
+    )
 
 
 def _check_save_dir_empty(save_dir):
@@ -306,10 +331,12 @@ def _check_save_dir_empty(save_dir):
     FileExistsError
         If ``save_dir`` is not a valid directory for saving.
     """
-    sub_dirs = [os.path.basename(s).isdigit()
-                for s in glob(os.path.join(save_dir, '*'))]
+    sub_dirs = [
+        os.path.basename(s).isdigit() for s in glob(os.path.join(save_dir, "*"))
+    ]
    if any(sub_dirs):
        raise FileExistsError(
-            f'Directory {save_dir} already contains subdirectories. Please '
-            'select a different directory, set overwrite=True, or resolve '
-            'manually.')
+            f"Directory {save_dir} already contains subdirectories. Please "
+            "select a different directory, set overwrite=True, or resolve "
+            "manually."
+            )
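
The serialization changes above are mostly cosmetic (double quotes, Black-style wrapping) plus an explicit UserWarning category on the deprecation messages; the supported workflow remains dataset.save() for writing and load_concat_dataset() for reading. A hedged round-trip sketch building on the synthetic example above (the save path and windowing parameters are illustrative):

    from pathlib import Path

    import numpy as np

    from braindecode.datasets import create_from_X_y
    from braindecode.datautil import load_concat_dataset

    # Tiny windowed dataset from synthetic trials (see the xy.py sketch above).
    X = np.random.randn(2, 3, 200)
    y = np.array([0, 1])
    windows = create_from_X_y(
        X, y, drop_last_window=False, sfreq=100.0, ch_names=["C3", "Cz", "C4"],
        window_size_samples=100, window_stride_samples=100,
    )

    save_dir = Path("./saved_windows")  # illustrative location
    save_dir.mkdir(exist_ok=True)

    # save_concat_dataset() is deprecated; dataset.save() writes one numbered
    # subdirectory per recording, which is the layout load_concat_dataset expects.
    windows.save(path=str(save_dir), overwrite=True)

    # Reload lazily; for windowed data with preload=False, n_jobs falls back to 1.
    restored = load_concat_dataset(path=save_dir, preload=False)
    print(len(restored.datasets))
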