braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; the information is provided for informational purposes only.

This version of braindecode might be problematic.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/datasets/xy.py
@@ -3,19 +3,29 @@
 #
 # License: BSD (3-clause)
 
-import numpy as np
-import pandas as pd
+from __future__ import annotations
+
 import logging
+
 import mne
+import numpy as np
+import pandas as pd
+from numpy.typing import ArrayLike, NDArray
 
-from .base import BaseDataset, BaseConcatDataset
+from .base import BaseConcatDataset, BaseDataset
 
 log = logging.getLogger(__name__)
 
 
 def create_from_X_y(
-        X, y, drop_last_window, sfreq, ch_names=None, window_size_samples=None,
-        window_stride_samples=None):
+    X: NDArray,
+    y: ArrayLike,
+    drop_last_window: bool,
+    sfreq: float,
+    ch_names: ArrayLike = None,
+    window_size_samples: int | None = None,
+    window_stride_samples: int | None = None,
+) -> BaseConcatDataset:
     """Create a BaseConcatDataset of WindowsDatasets from X and y to be used for
     decoding with skorch and braindecode, where X is a list of pre-cut trials
     and y are corresponding targets.
@@ -46,7 +56,9 @@ def create_from_X_y(
     """
     # Prevent circular import
     from ..preprocessing.windowers import (
-        create_fixed_length_windows, )
+        create_fixed_length_windows,
+    )
+
     n_samples_per_x = []
     base_datasets = []
     if ch_names is None:
@@ -57,16 +69,19 @@ def create_from_X_y(
         n_samples_per_x.append(x.shape[1])
         info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
         raw = mne.io.RawArray(x, info)
-        base_dataset = BaseDataset(raw, pd.Series({"target": target}),
-                                   target_name="target")
+        base_dataset = BaseDataset(
+            raw, pd.Series({"target": target}), target_name="target"
+        )
         base_datasets.append(base_dataset)
     base_datasets = BaseConcatDataset(base_datasets)
 
     if window_size_samples is None and window_stride_samples is None:
         if not len(np.unique(n_samples_per_x)) == 1:
-            raise ValueError("if 'window_size_samples' and "
-                             "'window_stride_samples' are None, "
-                             "all trials have to have the same length")
+            raise ValueError(
+                "if 'window_size_samples' and "
+                "'window_stride_samples' are None, "
+                "all trials have to have the same length"
+            )
         window_size_samples = n_samples_per_x[0]
         window_stride_samples = n_samples_per_x[0]
     windows_datasets = create_fixed_length_windows(
@@ -75,6 +90,6 @@ def create_from_X_y(
        stop_offset_samples=None,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
-        drop_last_window=drop_last_window
+        drop_last_window=drop_last_window,
    )
    return windows_datasets
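
The change above only adds type annotations and reformatting; the call pattern of create_from_X_y is unchanged. A minimal usage sketch with made-up shapes (2 trials, 3 channels, 1000 samples at 250 Hz):

import numpy as np
from braindecode.datasets import create_from_X_y

X = np.random.randn(2, 3, 1000)  # trials x channels x samples (made-up shapes)
y = np.array([0, 1])             # one target per trial
windows_ds = create_from_X_y(
    X, y, drop_last_window=False, sfreq=250.0, ch_names=["C3", "Cz", "C4"]
)
# With window_size_samples and window_stride_samples left as None, the code above
# uses the common trial length, so each full trial becomes a single window.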
braindecode/datautil/__init__.py
@@ -2,32 +2,51 @@
 Utilities for data manipulation.
 """
 
-
 from .serialization import (
-    save_concat_dataset, load_concat_dataset, _check_save_dir_empty)
+    _check_save_dir_empty,
+    load_concat_dataset,
+    save_concat_dataset,
+)
 
 
 def __getattr__(name):
     # ideas from https://stackoverflow.com/a/57110249/1469195
-    from warnings import warn
     import importlib
-    if name == 'create_from_X_y':
-        warn('create_from_X_y has been moved to datasets, please use from braindecode.datasets import create_from_X_y')
-        xy = importlib.import_module('..datasets.xy', __package__)
+    from warnings import warn
+
+    if name == "create_from_X_y":
+        warn(
+            "create_from_X_y has been moved to datasets, please use from braindecode.datasets import create_from_X_y"
+        )
+        xy = importlib.import_module("..datasets.xy", __package__)
         return xy.create_from_X_y
-    if name in ['create_from_mne_raw', 'create_from_mne_epochs']:
-        warn(f'{name} has been moved to datasets, please use from braindecode.datasets import {name}')
-        mne = importlib.import_module('..datasets.mne', __package__)
+    if name in ["create_from_mne_raw", "create_from_mne_epochs"]:
+        warn(
+            f"{name} has been moved to datasets, please use from braindecode.datasets import {name}"
+        )
+        mne = importlib.import_module("..datasets.mne", __package__)
         return mne.__dict__[name]
-    if name in ['zscore', 'scale', 'exponential_moving_demean',
-                'exponential_moving_standardize', 'filterbank',
-                'preprocess', 'Preprocessor']:
-        warn(f'{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}')
-        preprocess = importlib.import_module('..preprocessing.preprocess', __package__)
+    if name in [
+        "scale",
+        "exponential_moving_demean",
+        "exponential_moving_standardize",
+        "filterbank",
+        "preprocess",
+        "Preprocessor",
+    ]:
+        warn(
+            f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
+        )
+        preprocess = importlib.import_module("..preprocessing.preprocess", __package__)
         return preprocess.__dict__[name]
-    if name in ['create_windows_from_events', 'create_fixed_length_windows']:
-        warn(f'{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}')
-        windowers = importlib.import_module('..preprocessing.windowers', __package__)
+    if name in ["create_windows_from_events", "create_fixed_length_windows"]:
+        warn(
+            f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
+        )
+        windowers = importlib.import_module("..preprocessing.windowers", __package__)
         return windowers.__dict__[name]
 
-    raise AttributeError('No possible import named ' + name)
+    raise AttributeError("No possible import named " + name)
+
+
+__all__ = ["load_concat_dataset", "save_concat_dataset", "_check_save_dir_empty"]
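
The shim above keeps old braindecode.datautil imports working through the module-level __getattr__ (PEP 562) while emitting a warning; note that "zscore" is no longer in the forwarded list. A sketch of the old and new import paths:

# Old path: still resolves via the __getattr__ shim, but emits a UserWarning
from braindecode.datautil import create_windows_from_events

# New paths
from braindecode.datasets import create_from_X_y
from braindecode.preprocessing import Preprocessor, preprocess, create_fixed_length_windows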
braindecode/datautil/serialization.py
@@ -6,8 +6,8 @@ Convenience functions for storing and loading of windows datasets.
 #
 # License: BSD (3-clause)
 
-import os
 import json
+import os
 import pickle
 import warnings
 from glob import glob
@@ -17,17 +17,24 @@ import mne
 import pandas as pd
 from joblib import Parallel, delayed
 
-from ..datasets.base import BaseDataset, BaseConcatDataset, WindowsDataset, EEGWindowsDataset
+from ..datasets.base import (
+    BaseConcatDataset,
+    BaseDataset,
+    EEGWindowsDataset,
+    WindowsDataset,
+)
 
 
 def save_concat_dataset(path, concat_dataset, overwrite=False):
-    warnings.warn('"save_concat_dataset()" is deprecated and will be removed in'
-                  ' the future. Use dataset.save() instead.')
+    warnings.warn(
+        '"save_concat_dataset()" is deprecated and will be removed in'
+        " the future. Use dataset.save() instead.",
+        UserWarning,
+    )
     concat_dataset.save(path=path, overwrite=overwrite)
 
 
-def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
-                                  target_name=None):
+def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
     """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
     files.
 
@@ -48,15 +55,16 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
     concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
     """
     # assume we have a single concat dataset to load
-    is_raw = (path / '0-raw.fif').is_file()
+    is_raw = (path / "0-raw.fif").is_file()
     assert not (not is_raw and target_name is not None), (
-        'Setting a new target is only supported for raws.')
-    is_epochs = (path / '0-epo.fif').is_file()
+        "Setting a new target is only supported for raws."
+    )
+    is_epochs = (path / "0-epo.fif").is_file()
     paths = [path]
     # assume we have multiple concat datasets to load
     if not (is_raw or is_epochs):
-        is_raw = (path / '0' / '0-raw.fif').is_file()
-        is_epochs = (path / '0' / '0-epo.fif').is_file()
+        is_raw = (path / "0" / "0-raw.fif").is_file()
+        is_epochs = (path / "0" / "0-epo.fif").is_file()
         paths = path.glob("*/")
         paths = sorted(paths, key=lambda p: int(p.name))
         if ids_to_load is not None:
@@ -64,33 +72,32 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
             ids_to_load = None
     # if we have neither a single nor multiple datasets, something went wrong
     assert is_raw or is_epochs, (
-        f'Expect either raw or epo to exist in {path} or in '
-        f'{path / "0"}')
+        f"Expect either raw or epo to exist in {path} or in {path / '0'}"
+    )
 
     datasets = []
     for path in paths:
         if is_raw and target_name is None:
-            target_file_name = path / 'target_name.json'
-            target_name = json.load(open(target_file_name, "r"))['target_name']
+            target_file_name = path / "target_name.json"
+            target_name = json.load(open(target_file_name, "r"))["target_name"]
 
         all_signals, description = _load_signals_and_description(
-            path=path, preload=preload, is_raw=is_raw,
-            ids_to_load=ids_to_load
+            path=path, preload=preload, is_raw=is_raw, ids_to_load=ids_to_load
         )
         for i_signal, signal in enumerate(all_signals):
             if is_raw:
                 datasets.append(
-                    BaseDataset(signal, description.iloc[i_signal],
-                                target_name=target_name))
-            else:
-                datasets.append(
-                    WindowsDataset(signal, description.iloc[i_signal])
+                    BaseDataset(
+                        signal, description.iloc[i_signal], target_name=target_name
+                    )
                 )
+            else:
+                datasets.append(WindowsDataset(signal, description.iloc[i_signal]))
     concat_ds = BaseConcatDataset(datasets)
-    for kwarg_name in ['raw_preproc_kwargs', 'window_kwargs', 'window_preproc_kwargs']:
-        kwarg_path = path / '.'.join([kwarg_name, 'json'])
+    for kwarg_name in ["raw_preproc_kwargs", "window_kwargs", "window_preproc_kwargs"]:
+        kwarg_path = path / ".".join([kwarg_name, "json"])
         if kwarg_path.exists():
-            with open(kwarg_path, 'r') as f:
+            with open(kwarg_path, "r") as f:
                 kwargs = json.load(f)
                 kwargs = [tuple(kwarg) for kwarg in kwargs]
             setattr(concat_ds, kwarg_name, kwargs)
@@ -100,14 +107,22 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
 def _load_signals_and_description(path, preload, is_raw, ids_to_load=None):
     all_signals = []
     file_name = "{}-raw.fif" if is_raw else "{}-epo.fif"
-    description_df = pd.read_json(path / "description.json")
+    description_df = pd.read_json(
+        path / "description.json", typ="series", convert_dates=False
+    )
+
+    if "timestamp" in description_df.index:
+        timestamp_numeric = pd.to_numeric(description_df["timestamp"])
+        description_df["timestamp"] = pd.to_datetime(timestamp_numeric)
+
     if ids_to_load is None:
         file_names = path.glob(f"*{file_name.lstrip('{}')}")
         # Extract ids, e.g.,
         # '/home/schirrmr/data/preproced-tuh/all-sensors/11-raw.fif' ->
         # '11-raw.fif' -> 11
         ids_to_load = sorted(
-            [int(os.path.split(f)[-1].split('-')[0]) for f in file_names])
+            [int(os.path.split(f)[-1].split("-")[0]) for f in file_names]
+        )
     for i in ids_to_load:
         fif_file = path / file_name.format(i)
         all_signals.append(_load_signals(fif_file, preload, is_raw))
@@ -133,10 +148,10 @@ def _load_signals(fif_file, preload, is_raw):
     # If pickle didn't exist read via mne (likely slower) and save pkl after
     if is_raw:
         signals = mne.io.read_raw_fif(fif_file, preload=preload)
-    elif fif_file.name.endswith('-epo.fif'):
+    elif fif_file.name.endswith("-epo.fif"):
         signals = mne.read_epochs(fif_file, preload=preload)
     else:
-        raise ValueError('fif_file must end with raw.fif or epo.fif.')
+        raise ValueError("fif_file must end with raw.fif or epo.fif.")
 
     # Only do this for raw objects. Epoch objects are not picklable as they
     # hold references to open files in `signals._raw[0].fid`.
@@ -154,8 +169,7 @@ def _load_signals(fif_file, preload, is_raw):
     return signals
 
 
-def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
-                        n_jobs=1):
+def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
     """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
     files.
 
@@ -183,11 +197,14 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
     # if we encounter a dataset that was saved in 'the old way', call the
     # corresponding 'old' loading function
     if _is_outdated_saved(path):
-        warnings.warn("The way your dataset was saved is deprecated by now. "
-                      "Please save it again using dataset.save().", UserWarning)
+        warnings.warn(
+            "The way your dataset was saved is deprecated by now. "
+            "Please save it again using dataset.save().",
+            UserWarning,
+        )
         return _outdated_load_concat_dataset(
-            path=path, preload=preload, ids_to_load=ids_to_load,
-            target_name=target_name)
+            path=path, preload=preload, ids_to_load=ids_to_load, target_name=target_name
+        )
 
     # else we have a dataset saved in the new way with subdirectories in path
     # for every dataset with description.json and -epo.fif or -raw.fif,
@@ -197,9 +214,9 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
         ids_to_load = [p.name for p in path.iterdir()]
     ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
     ids_to_load = [str(i) for i in ids_to_load]
-    first_raw_fif_path = path / ids_to_load[0] / f'{ids_to_load[0]}-raw.fif'
+    first_raw_fif_path = path / ids_to_load[0] / f"{ids_to_load[0]}-raw.fif"
     is_raw = first_raw_fif_path.exists()
-    metadata_path = path / ids_to_load[0] / 'metadata_df.pkl'
+    metadata_path = path / ids_to_load[0] / "metadata_df.pkl"
     has_stored_windows = metadata_path.exists()
 
     # Parallelization of mne.read_epochs with preload=False fails with
@@ -207,8 +224,10 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
     # So ignore n_jobs in that case and load with a single job.
     if not is_raw and n_jobs != 1:
         warnings.warn(
-            'Parallelized reading with `preload=False` is not supported for '
-            'windowed data. Will use `n_jobs=1`.', UserWarning)
+            "Parallelized reading with `preload=False` is not supported for "
+            "windowed data. Will use `n_jobs=1`.",
+            UserWarning,
+        )
         n_jobs = 1
     datasets = Parallel(n_jobs)(
         delayed(_load_parallel)(path, i, preload, is_raw, has_stored_windows)
@@ -219,9 +238,9 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
 
 def _load_parallel(path, i, preload, is_raw, has_stored_windows):
     sub_dir = path / i
-    file_name_patterns = ['{}-raw.fif', '{}-epo.fif']
+    file_name_patterns = ["{}-raw.fif", "{}-epo.fif"]
     if all([(sub_dir / p.format(i)).exists() for p in file_name_patterns]):
-        raise FileExistsError('Found -raw.fif and -epo.fif in directory.')
+        raise FileExistsError("Found -raw.fif and -epo.fif in directory.")
 
     fif_name_pattern = file_name_patterns[0] if is_raw else file_name_patterns[1]
     fif_file_name = fif_name_pattern.format(i)
@@ -229,48 +248,55 @@ def _load_parallel(path, i, preload, is_raw, has_stored_windows):
 
     signals = _load_signals(fif_file_path, preload, is_raw)
 
-    description_file_path = sub_dir / 'description.json'
-    description = pd.read_json(description_file_path, typ='series')
+    description_file_path = sub_dir / "description.json"
+    description = pd.read_json(description_file_path, typ="series", convert_dates=False)
+
+    # if 'timestamp' in description.index:
+    #     timestamp_numeric = pd.to_numeric(description['timestamp'])
+    #     description['timestamp'] = pd.to_datetime(timestamp_numeric, unit='s')
 
-    target_file_path = sub_dir / 'target_name.json'
+    target_file_path = sub_dir / "target_name.json"
     target_name = None
     if target_file_path.exists():
-        target_name = json.load(open(target_file_path, "r"))['target_name']
+        target_name = json.load(open(target_file_path, "r"))["target_name"]
 
     if is_raw and (not has_stored_windows):
         dataset = BaseDataset(signals, description, target_name)
     else:
-        window_kwargs = _load_kwargs_json('window_kwargs', sub_dir)
-        windows_ds_kwargs = [kwargs[1] for kwargs in window_kwargs if kwargs[0] == 'WindowsDataset']
+        window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
+        windows_ds_kwargs = [
+            kwargs[1] for kwargs in window_kwargs if kwargs[0] == "WindowsDataset"
+        ]
        windows_ds_kwargs = windows_ds_kwargs[0] if len(windows_ds_kwargs) == 1 else {}
        if is_raw:
-            metadata = pd.read_pickle(path / i / 'metadata_df.pkl')
+            metadata = pd.read_pickle(path / i / "metadata_df.pkl")
            dataset = EEGWindowsDataset(
                signals,
                metadata=metadata,
                description=description,
-                targets_from=windows_ds_kwargs.get('targets_from', 'metadata'),
-                last_target_only=windows_ds_kwargs.get('last_target_only', True),
+                targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
+                last_target_only=windows_ds_kwargs.get("last_target_only", True),
            )
        else:
            # MNE epochs dataset
            dataset = WindowsDataset(
-                signals, description,
-                targets_from=windows_ds_kwargs.get('targets_from', 'metadata'),
-                last_target_only=windows_ds_kwargs.get('last_target_only', True)
+                signals,
+                description,
+                targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
+                last_target_only=windows_ds_kwargs.get("last_target_only", True),
            )
-        setattr(dataset, 'window_kwargs', window_kwargs)
-    for kwargs_name in ['raw_preproc_kwargs', 'window_preproc_kwargs']:
+        setattr(dataset, "window_kwargs", window_kwargs)
+    for kwargs_name in ["raw_preproc_kwargs", "window_preproc_kwargs"]:
        kwargs = _load_kwargs_json(kwargs_name, sub_dir)
        setattr(dataset, kwargs_name, kwargs)
    return dataset
 
 
 def _load_kwargs_json(kwargs_name, sub_dir):
-    kwargs_file_name = '.'.join([kwargs_name, 'json'])
+    kwargs_file_name = ".".join([kwargs_name, "json"])
     kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
     if os.path.exists(kwargs_file_path):
-        kwargs = json.load(open(kwargs_file_path, 'r'))
+        kwargs = json.load(open(kwargs_file_path, "r"))
         kwargs = [tuple(kwarg) for kwarg in kwargs]
         return kwargs
 
@@ -279,18 +305,28 @@ def _is_outdated_saved(path):
     """Data was saved in the old way if there are 'description.json', '-raw.fif'
     or '-epo.fif' in path (no subdirectories) OR if there are more 'fif' files
     than 'description.json' files."""
-    description_files = glob(os.path.join(path, '**/description.json'))
-    fif_files = glob(os.path.join(path, '**/*-raw.fif')) + glob(os.path.join(path, '**/*-epo.fif'))
+    description_files = glob(os.path.join(path, "**/description.json"))
+    fif_files = glob(os.path.join(path, "**/*-raw.fif")) + glob(
+        os.path.join(path, "**/*-epo.fif")
+    )
     multiple = len(description_files) != len(fif_files)
     kwargs_in_path = any(
-        [os.path.exists(os.path.join(path, kwarg_name))
-         for kwarg_name in ['raw_preproc_kwargs', 'window_kwargs',
-                            'window_preproc_kwargs']])
-    return (os.path.exists(os.path.join(path, 'description.json')) or
-            os.path.exists(os.path.join(path, '0-raw.fif')) or
-            os.path.exists(os.path.join(path, '0-epo.fif')) or
-            multiple or
-            kwargs_in_path)
+        [
+            os.path.exists(os.path.join(path, kwarg_name))
+            for kwarg_name in [
+                "raw_preproc_kwargs",
+                "window_kwargs",
+                "window_preproc_kwargs",
+            ]
+        ]
+    )
+    return (
+        os.path.exists(os.path.join(path, "description.json"))
+        or os.path.exists(os.path.join(path, "0-raw.fif"))
+        or os.path.exists(os.path.join(path, "0-epo.fif"))
+        or multiple
+        or kwargs_in_path
+    )
 
 
 def _check_save_dir_empty(save_dir):
@@ -306,10 +342,12 @@ def _check_save_dir_empty(save_dir):
     FileExistsError
         If ``save_dir`` is not a valid directory for saving.
     """
-    sub_dirs = [os.path.basename(s).isdigit()
-                for s in glob(os.path.join(save_dir, '*'))]
+    sub_dirs = [
+        os.path.basename(s).isdigit() for s in glob(os.path.join(save_dir, "*"))
+    ]
     if any(sub_dirs):
         raise FileExistsError(
-            f'Directory {save_dir} already contains subdirectories. Please '
-            'select a different directory, set overwrite=True, or resolve '
-            'manually.')
+            f"Directory {save_dir} already contains subdirectories. Please "
+            "select a different directory, set overwrite=True, or resolve "
+            "manually."
+        )
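
The deprecation warnings above point users to the dataset's own save() method and to load_concat_dataset(), which takes an explicit preload argument and now reads description.json with convert_dates=False. A hedged round-trip sketch (paths and shapes are made up):

from pathlib import Path

import numpy as np

from braindecode.datasets import create_from_X_y
from braindecode.datautil import load_concat_dataset

# Build a tiny windowed dataset, save one subdirectory per recording, reload it.
X = np.random.randn(2, 3, 1000)
windows_ds = create_from_X_y(
    X, y=[0, 1], drop_last_window=False, sfreq=250.0, ch_names=["C3", "Cz", "C4"]
)
windows_ds.save(path="./saved_dataset", overwrite=True)
reloaded = load_concat_dataset(Path("./saved_dataset"), preload=False)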