braindecode 0.8-py3-none-any.whl → 1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/datautil/__init__.py
@@ -0,0 +1,49 @@
+ """
+ Utilities for data manipulation.
+ """
+
+ from .serialization import (
+     _check_save_dir_empty,
+     load_concat_dataset,
+     save_concat_dataset,
+ )
+
+
+ def __getattr__(name):
+     # ideas from https://stackoverflow.com/a/57110249/1469195
+     import importlib
+     from warnings import warn
+
+     if name == "create_from_X_y":
+         warn(
+             "create_from_X_y has been moved to datasets, please use from braindecode.datasets import create_from_X_y"
+         )
+         xy = importlib.import_module("..datasets.xy", __package__)
+         return xy.create_from_X_y
+     if name in ["create_from_mne_raw", "create_from_mne_epochs"]:
+         warn(
+             f"{name} has been moved to datasets, please use from braindecode.datasets import {name}"
+         )
+         mne = importlib.import_module("..datasets.mne", __package__)
+         return mne.__dict__[name]
+     if name in [
+         "scale",
+         "exponential_moving_demean",
+         "exponential_moving_standardize",
+         "filterbank",
+         "preprocess",
+         "Preprocessor",
+     ]:
+         warn(
+             f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
+         )
+         preprocess = importlib.import_module("..preprocessing.preprocess", __package__)
+         return preprocess.__dict__[name]
+     if name in ["create_windows_from_events", "create_fixed_length_windows"]:
+         warn(
+             f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
+         )
+         windowers = importlib.import_module("..preprocessing.windowers", __package__)
+         return windowers.__dict__[name]
+
+     raise AttributeError("No possible import named " + name)
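
Note: the module-level __getattr__ above is a PEP 562 lazy re-export shim, so old braindecode.datautil import paths keep working but warn and hand back the object from its new home. A minimal behavioural sketch (not part of the diff, assuming braindecode 1.0.0 is installed):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # old location; resolved through datautil.__getattr__
    from braindecode.datautil import create_windows_from_events

print(caught[0].message)                      # "... has been moved to preprocessing ..."
print(create_windows_from_events.__module__)  # braindecode.preprocessing.windowers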
braindecode/datautil/serialization.py
@@ -0,0 +1,342 @@
+ """
+ Convenience functions for storing and loading of windows datasets.
+ """
+
+ # Authors: Lukas Gemein <l.gemein@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import json
+ import os
+ import pickle
+ import warnings
+ from glob import glob
+ from pathlib import Path
+
+ import mne
+ import pandas as pd
+ from joblib import Parallel, delayed
+
+ from ..datasets.base import (
+     BaseConcatDataset,
+     BaseDataset,
+     EEGWindowsDataset,
+     WindowsDataset,
+ )
+
+
+ def save_concat_dataset(path, concat_dataset, overwrite=False):
+     warnings.warn(
+         '"save_concat_dataset()" is deprecated and will be removed in'
+         " the future. Use dataset.save() instead.",
+         UserWarning,
+     )
+     concat_dataset.save(path=path, overwrite=overwrite)
+
+
+ def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
+     """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
+     files.
+
+     Parameters
+     ----------
+     path: pathlib.Path
+         Path to the directory of the .fif / -epo.fif and .json files.
+     preload: bool
+         Whether to preload the data.
+     ids_to_load: None | list(int)
+         Ids of specific files to load.
+     target_name: None or str
+         Load specific description column as target. If not given, take saved
+         target name.
+
+     Returns
+     -------
+     concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
+     """
+     # assume we have a single concat dataset to load
+     is_raw = (path / "0-raw.fif").is_file()
+     assert not (not is_raw and target_name is not None), (
+         "Setting a new target is only supported for raws."
+     )
+     is_epochs = (path / "0-epo.fif").is_file()
+     paths = [path]
+     # assume we have multiple concat datasets to load
+     if not (is_raw or is_epochs):
+         is_raw = (path / "0" / "0-raw.fif").is_file()
+         is_epochs = (path / "0" / "0-epo.fif").is_file()
+         paths = path.glob("*/")
+         paths = sorted(paths, key=lambda p: int(p.name))
+         if ids_to_load is not None:
+             paths = [paths[i] for i in ids_to_load]
+             ids_to_load = None
+     # if we have neither a single nor multiple datasets, something went wrong
+     assert is_raw or is_epochs, (
+         f"Expect either raw or epo to exist in {path} or in {path / '0'}"
+     )
+
+     datasets = []
+     for path in paths:
+         if is_raw and target_name is None:
+             target_file_name = path / "target_name.json"
+             target_name = json.load(open(target_file_name, "r"))["target_name"]
+
+         all_signals, description = _load_signals_and_description(
+             path=path, preload=preload, is_raw=is_raw, ids_to_load=ids_to_load
+         )
+         for i_signal, signal in enumerate(all_signals):
+             if is_raw:
+                 datasets.append(
+                     BaseDataset(
+                         signal, description.iloc[i_signal], target_name=target_name
+                     )
+                 )
+             else:
+                 datasets.append(WindowsDataset(signal, description.iloc[i_signal]))
+     concat_ds = BaseConcatDataset(datasets)
+     for kwarg_name in ["raw_preproc_kwargs", "window_kwargs", "window_preproc_kwargs"]:
+         kwarg_path = path / ".".join([kwarg_name, "json"])
+         if kwarg_path.exists():
+             with open(kwarg_path, "r") as f:
+                 kwargs = json.load(f)
+             kwargs = [tuple(kwarg) for kwarg in kwargs]
+             setattr(concat_ds, kwarg_name, kwargs)
+     return concat_ds
+
+
+ def _load_signals_and_description(path, preload, is_raw, ids_to_load=None):
+     all_signals = []
+     file_name = "{}-raw.fif" if is_raw else "{}-epo.fif"
+     description_df = pd.read_json(path / "description.json")
+     if ids_to_load is None:
+         file_names = path.glob(f"*{file_name.lstrip('{}')}")
+         # Extract ids, e.g.,
+         # '/home/schirrmr/data/preproced-tuh/all-sensors/11-raw.fif' ->
+         # '11-raw.fif' -> 11
+         ids_to_load = sorted(
+             [int(os.path.split(f)[-1].split("-")[0]) for f in file_names]
+         )
+     for i in ids_to_load:
+         fif_file = path / file_name.format(i)
+         all_signals.append(_load_signals(fif_file, preload, is_raw))
+     description_df = description_df.iloc[ids_to_load]
+     return all_signals, description_df
+
+
+ def _load_signals(fif_file, preload, is_raw):
+     # Reading the raw file from pickle if it has been save before.
+     # The pickle file only contain the raw object without the data.
+     pkl_file = fif_file.with_suffix(".pkl")
+     if pkl_file.exists():
+         with open(pkl_file, "rb") as f:
+             signals = pickle.load(f)
+
+         # If the file has been moved together with the pickle file, make sure
+         # the path links to correct fif file.
+         signals._fname = str(fif_file)
+         if preload:
+             signals.load_data()
+         return signals
+
+     # If pickle didn't exist read via mne (likely slower) and save pkl after
+     if is_raw:
+         signals = mne.io.read_raw_fif(fif_file, preload=preload)
+     elif fif_file.name.endswith("-epo.fif"):
+         signals = mne.read_epochs(fif_file, preload=preload)
+     else:
+         raise ValueError("fif_file must end with raw.fif or epo.fif.")
+
+     # Only do this for raw objects. Epoch objects are not picklable as they
+     # hold references to open files in `signals._raw[0].fid`.
+     if is_raw:
+         # Saving the raw file without data into a pickle file, so it can be
+         # retrieved faster on the next use of this dataset.
+         with open(pkl_file, "wb") as f:
+             if preload:
+                 data = signals._data
+                 signals._data, signals.preload = None, False
+             pickle.dump(signals, f)
+             if preload:
+                 signals._data, signals.preload = data, True
+
+     return signals
+
+
+ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
+     """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
+     files.
+
+     Parameters
+     ----------
+     path: str | pathlib.Path
+         Path to the directory of the .fif / -epo.fif and .json files.
+     preload: bool
+         Whether to preload the data.
+     ids_to_load: list of int | None
+         Ids of specific files to load.
+     target_name: str | list | None
+         Load specific description column as target. If not given, take saved
+         target name.
+     n_jobs: int
+         Number of jobs to be used to read files in parallel.
+
+     Returns
+     -------
+     concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
+     """
+     # Make sure we always work with a pathlib.Path
+     path = Path(path)
+
+     # if we encounter a dataset that was saved in 'the old way', call the
+     # corresponding 'old' loading function
+     if _is_outdated_saved(path):
+         warnings.warn(
+             "The way your dataset was saved is deprecated by now. "
+             "Please save it again using dataset.save().",
+             UserWarning,
+         )
+         return _outdated_load_concat_dataset(
+             path=path, preload=preload, ids_to_load=ids_to_load, target_name=target_name
+         )
+
+     # else we have a dataset saved in the new way with subdirectories in path
+     # for every dataset with description.json and -epo.fif or -raw.fif,
+     # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
+     # window_preproc_kwargs.json
+     if ids_to_load is None:
+         ids_to_load = [p.name for p in path.iterdir()]
+         ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
+     ids_to_load = [str(i) for i in ids_to_load]
+     first_raw_fif_path = path / ids_to_load[0] / f"{ids_to_load[0]}-raw.fif"
+     is_raw = first_raw_fif_path.exists()
+     metadata_path = path / ids_to_load[0] / "metadata_df.pkl"
+     has_stored_windows = metadata_path.exists()
+
+     # Parallelization of mne.read_epochs with preload=False fails with
+     # 'TypeError: cannot pickle '_io.BufferedReader' object'.
+     # So ignore n_jobs in that case and load with a single job.
+     if not is_raw and n_jobs != 1:
+         warnings.warn(
+             "Parallelized reading with `preload=False` is not supported for "
+             "windowed data. Will use `n_jobs=1`.",
+             UserWarning,
+         )
+         n_jobs = 1
+     datasets = Parallel(n_jobs)(
+         delayed(_load_parallel)(path, i, preload, is_raw, has_stored_windows)
+         for i in ids_to_load
+     )
+     return BaseConcatDataset(datasets)
+
+
+ def _load_parallel(path, i, preload, is_raw, has_stored_windows):
+     sub_dir = path / i
+     file_name_patterns = ["{}-raw.fif", "{}-epo.fif"]
+     if all([(sub_dir / p.format(i)).exists() for p in file_name_patterns]):
+         raise FileExistsError("Found -raw.fif and -epo.fif in directory.")
+
+     fif_name_pattern = file_name_patterns[0] if is_raw else file_name_patterns[1]
+     fif_file_name = fif_name_pattern.format(i)
+     fif_file_path = sub_dir / fif_file_name
+
+     signals = _load_signals(fif_file_path, preload, is_raw)
+
+     description_file_path = sub_dir / "description.json"
+     description = pd.read_json(description_file_path, typ="series")
+
+     target_file_path = sub_dir / "target_name.json"
+     target_name = None
+     if target_file_path.exists():
+         target_name = json.load(open(target_file_path, "r"))["target_name"]
+
+     if is_raw and (not has_stored_windows):
+         dataset = BaseDataset(signals, description, target_name)
+     else:
+         window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
+         windows_ds_kwargs = [
+             kwargs[1] for kwargs in window_kwargs if kwargs[0] == "WindowsDataset"
+         ]
+         windows_ds_kwargs = windows_ds_kwargs[0] if len(windows_ds_kwargs) == 1 else {}
+         if is_raw:
+             metadata = pd.read_pickle(path / i / "metadata_df.pkl")
+             dataset = EEGWindowsDataset(
+                 signals,
+                 metadata=metadata,
+                 description=description,
+                 targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
+                 last_target_only=windows_ds_kwargs.get("last_target_only", True),
+             )
+         else:
+             # MNE epochs dataset
+             dataset = WindowsDataset(
+                 signals,
+                 description,
+                 targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
+                 last_target_only=windows_ds_kwargs.get("last_target_only", True),
+             )
+         setattr(dataset, "window_kwargs", window_kwargs)
+     for kwargs_name in ["raw_preproc_kwargs", "window_preproc_kwargs"]:
+         kwargs = _load_kwargs_json(kwargs_name, sub_dir)
+         setattr(dataset, kwargs_name, kwargs)
+     return dataset
+
+
+ def _load_kwargs_json(kwargs_name, sub_dir):
+     kwargs_file_name = ".".join([kwargs_name, "json"])
+     kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
+     if os.path.exists(kwargs_file_path):
+         kwargs = json.load(open(kwargs_file_path, "r"))
+         kwargs = [tuple(kwarg) for kwarg in kwargs]
+         return kwargs
+
+
+ def _is_outdated_saved(path):
+     """Data was saved in the old way if there are 'description.json', '-raw.fif'
+     or '-epo.fif' in path (no subdirectories) OR if there are more 'fif' files
+     than 'description.json' files."""
+     description_files = glob(os.path.join(path, "**/description.json"))
+     fif_files = glob(os.path.join(path, "**/*-raw.fif")) + glob(
+         os.path.join(path, "**/*-epo.fif")
+     )
+     multiple = len(description_files) != len(fif_files)
+     kwargs_in_path = any(
+         [
+             os.path.exists(os.path.join(path, kwarg_name))
+             for kwarg_name in [
+                 "raw_preproc_kwargs",
+                 "window_kwargs",
+                 "window_preproc_kwargs",
+             ]
+         ]
+     )
+     return (
+         os.path.exists(os.path.join(path, "description.json"))
+         or os.path.exists(os.path.join(path, "0-raw.fif"))
+         or os.path.exists(os.path.join(path, "0-epo.fif"))
+         or multiple
+         or kwargs_in_path
+     )
+
+
+ def _check_save_dir_empty(save_dir):
+     """Make sure a BaseConcatDataset can be saved under a given directory.
+
+     Parameters
+     ----------
+     save_dir : str
+         Directory under which a `BaseConcatDataset` will be saved.
+
+     Raises
+     -------
+     FileExistsError
+         If ``save_dir`` is not a valid directory for saving.
+     """
+     sub_dirs = [
+         os.path.basename(s).isdigit() for s in glob(os.path.join(save_dir, "*"))
+     ]
+     if any(sub_dirs):
+         raise FileExistsError(
+             f"Directory {save_dir} already contains subdirectories. Please "
+             "select a different directory, set overwrite=True, or resolve "
+             "manually."
+         )
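
Note: load_concat_dataset is the public entry point of this module; it falls back to the legacy loader for datasets saved in the old flat layout and otherwise reads one numbered subdirectory per dataset, optionally in parallel. A hedged round-trip sketch (not part of the diff; windows_ds is assumed to be an existing BaseConcatDataset, e.g. produced by create_windows_from_events):

from braindecode.datautil import load_concat_dataset

windows_ds.save(path="./saved_ds", overwrite=True)   # one subdirectory per dataset

# Reload lazily; n_jobs > 1 is only honoured for raw (non-epoched) recordings.
loaded_ds = load_concat_dataset("./saved_ds", preload=False, n_jobs=2)
print(len(loaded_ds.datasets))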
braindecode/datautil/util.py
@@ -0,0 +1,41 @@
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+
+ def ms_to_samples(ms, fs):
+     """
+     Compute milliseconds to number of samples.
+
+     Parameters
+     ----------
+     ms: number
+         Milliseconds
+     fs: number
+         Sampling rate
+
+     Returns
+     -------
+     n_samples: int
+         Number of samples
+
+     """
+     return ms * fs / 1000.0
+
+
+ def samples_to_ms(n_samples, fs):
+     """
+     Compute milliseconds to number of samples.
+
+     Parameters
+     ----------
+     n_samples: number
+         Number of samples
+     fs: number
+         Sampling rate
+
+     Returns
+     -------
+     milliseconds: int
+     """
+     return n_samples * 1000.0 / fs
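
Note: both helpers are plain unit conversions and return floats, so the result should be rounded or cast before being used as a sample index. A small usage sketch (not part of the diff), importing from the module path shown in the file list:

from braindecode.datautil.util import ms_to_samples, samples_to_ms

n_samples = ms_to_samples(500, fs=250)    # 125.0 samples
duration_ms = samples_to_ms(125, fs=250)  # 500.0 ms
print(int(n_samples), duration_ms)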
braindecode/eegneuralnet.py
@@ -5,32 +5,36 @@
 
 
  import abc
- import logging
  import inspect
+ import logging
 
  import mne
  import numpy as np
  import torch
- from skorch import NeuralNet
  from sklearn.metrics import get_scorer
+ from skorch import NeuralNet
  from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
- from skorch.utils import noop, to_numpy, train_loss_score, valid_loss_score, is_dataset
+ from skorch.helper import SliceDataset
+ from skorch.utils import is_dataset, noop, to_numpy, train_loss_score, valid_loss_score
 
- from .training.scoring import (CroppedTimeSeriesEpochScoring,
-                                CroppedTrialEpochScoring, PostEpochTrainScoring)
- from .models.util import models_dict
  from .datasets.base import BaseConcatDataset, WindowsDataset
+ from .models.util import models_dict
+ from .training.scoring import (
+     CroppedTimeSeriesEpochScoring,
+     CroppedTrialEpochScoring,
+     PostEpochTrainScoring,
+ )
 
  log = logging.getLogger(__name__)
 
 
- def _get_model(model):
-     ''' Returns the corresponding class in case the model passed is a string. '''
+ def _get_model(model: str):
+     """Returns the corresponding class in case the model passed is a string."""
      if isinstance(model, str):
          if model in models_dict:
              model = models_dict[model]
          else:
-             raise ValueError(f'Unknown model name {model!r}.')
+             raise ValueError(f"Unknown model name {model!r}.")
      return model
 
 
@@ -50,7 +54,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
          will be left as is.
 
          """
-         kwargs = self.get_params_for('module')
+         kwargs = self.get_params_for("module")
          module = _get_model(self.module)
          module = self.initialized_instance(module, kwargs)
          # pylint: disable=attribute-defined-outside-init
@@ -61,7 +65,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
          # Here we parse the callbacks supplied as strings,
          # e.g. 'accuracy', to the callbacks skorch expects
          for name, cb, named_by_user in super()._yield_callbacks():
-             if name == 'str':
+             if name == "str":
                  train_cb, valid_cb = self._parse_str_callback(cb)
                  yield train_cb
                  if self.train_split is not None:
@@ -72,15 +76,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
      def _parse_str_callback(self, cb_supplied_name):
          scoring = get_scorer(cb_supplied_name)
          scoring_name = scoring._score_func.__name__
-         assert scoring_name.endswith(
-             ('_score', '_error', '_deviance', '_loss'))
-         if (scoring_name.endswith('_score') or
-                 cb_supplied_name.startswith('neg_')):
+         assert scoring_name.endswith(("_score", "_error", "_deviance", "_loss"))
+         if scoring_name.endswith("_score") or cb_supplied_name.startswith("neg_"):
              lower_is_better = False
          else:
              lower_is_better = True
-         train_name = f'train_{cb_supplied_name}'
-         valid_name = f'valid_{cb_supplied_name}'
+         train_name = f"train_{cb_supplied_name}"
+         valid_name = f"valid_{cb_supplied_name}"
          if self.cropped:
              train_scoring = CroppedTrialEpochScoring(
                  cb_supplied_name, lower_is_better, on_train=True, name=train_name
@@ -98,7 +100,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
          named_by_user = True
          train_valid_callbacks = [
              (train_name, train_scoring, named_by_user),
-             (valid_name, valid_scoring, named_by_user)
+             (valid_name, valid_scoring, named_by_user),
          ]
          return train_valid_callbacks
 
@@ -108,8 +110,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
          if not training:
              epoch_cbs = []
              for name, cb in self.callbacks_:
-                 if isinstance(cb, (CroppedTrialEpochScoring, CroppedTimeSeriesEpochScoring)) and (
-                         hasattr(cb, 'window_inds_')) and (not cb.on_train):
+                 if (
+                     isinstance(
+                         cb, (CroppedTrialEpochScoring, CroppedTimeSeriesEpochScoring)
+                     )
+                     and (hasattr(cb, "window_inds_"))
+                     and (not cb.on_train)
+                 ):
                      epoch_cbs.append(cb)
              # for trialwise decoding stuffs it might also be we don't have
              # cropped loader, so no indices there
@@ -136,8 +143,11 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
          i_window_stops = np.concatenate(i_window_stops)
          window_ys = np.concatenate(window_ys)
          return dict(
-             preds=preds, i_window_in_trials=i_window_in_trials,
-             i_window_stops=i_window_stops, window_ys=window_ys)
+             preds=preds,
+             i_window_in_trials=i_window_in_trials,
+             i_window_stops=i_window_stops,
+             window_ys=window_ys,
+         )
 
      # Changes the default target extractor to noop
      @property
@@ -156,7 +166,9 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
              (
                  "valid_loss",
                  BatchScoring(
-                     valid_loss_score, name="valid_loss", target_extractor=noop,
+                     valid_loss_score,
+                     name="valid_loss",
+                     target_extractor=noop,
                  ),
              ),
              ("print_log", PrintLog()),
@@ -179,17 +191,27 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
              return
          # get kwargs from signal:
          signal_kwargs = dict()
-         if isinstance(X, mne.BaseEpochs) or isinstance(X, np.ndarray):
+         # Using shape to work both with torch.tensor and numpy.array:
+         if (
+             isinstance(X, mne.BaseEpochs)
+             or (hasattr(X, "shape") and len(X.shape) >= 2)
+             or isinstance(X, SliceDataset)
+         ):
              if y is None:
-                 raise ValueError("y must be specified if X is a numpy array.")
-             signal_kwargs['n_outputs'] = self._get_n_outputs(y, classes)
+                 raise ValueError("y must be specified if X is array-like.")
+             signal_kwargs["n_outputs"] = self._get_n_outputs(y, classes)
              if isinstance(X, mne.BaseEpochs):
                  self.log.info("Using mne.Epochs to find signal-related parameters.")
                  signal_kwargs["n_times"] = len(X.times)
-                 signal_kwargs["sfreq"] = X.info['sfreq']
-                 signal_kwargs["chs_info"] = X.info['chs']
+                 signal_kwargs["sfreq"] = X.info["sfreq"]
+                 signal_kwargs["chs_info"] = X.info["chs"]
+             elif isinstance(X, SliceDataset):
+                 self.log.info("Using SliceDataset to find signal-related parameters.")
+                 Xshape = X[0].shape
+                 signal_kwargs["n_times"] = Xshape[-1]
+                 signal_kwargs["n_chans"] = Xshape[-2]
              else:
-                 self.log.info("Using numpy array to find signal-related parameters.")
+                 self.log.info("Using array-like to find signal-related parameters.")
                  signal_kwargs["n_times"] = X.shape[-1]
                  signal_kwargs["n_chans"] = X.shape[-2]
          elif is_dataset(X):
@@ -198,21 +220,17 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
              Xshape = X0.shape
              signal_kwargs["n_times"] = Xshape[-1]
              signal_kwargs["n_chans"] = Xshape[-2]
-             if (
-                 isinstance(X, BaseConcatDataset) and
-                 all(ds.targets_from == 'metadata' for ds in X.datasets)
+             if isinstance(X, BaseConcatDataset) and all(
+                 ds.targets_from == "metadata" for ds in X.datasets
              ):
                  y_target = X.get_metadata().target
-                 signal_kwargs['n_outputs'] = self._get_n_outputs(y_target, classes)
-             elif (
-                 isinstance(X, WindowsDataset) and
-                 X.targets_from == "metadata"
-             ):
+                 signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
+             elif isinstance(X, WindowsDataset) and X.targets_from == "metadata":
                  y_target = X.windows.metadata.target
-                 signal_kwargs['n_outputs'] = self._get_n_outputs(y_target, classes)
+                 signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
          else:
              self.log.warning(
-                 "Can only infer signal shape of numpy arrays or and Datasets, "
+                 "Can only infer signal shape of array-like and Datasets, "
                  f"got {type(X)!r}."
              )
              return
@@ -227,15 +245,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
              if k in all_module_kwargs:
                  module_kwargs[k] = v
              else:
-                 self.log.warning(
-                     f"Module {self.module!r} "
-                     f"is missing parameter {k!r}."
-                 )
+                 self.log.warning(f"Module {self.module!r} is missing parameter {k!r}.")
 
          # save kwargs to self:
          self.log.info(
              f"Passing additional parameters {module_kwargs!r} "
-             f"to module {self.module!r}.")
+             f"to module {self.module!r}."
+         )
          module_kwargs = {f"module__{k}": v for k, v in module_kwargs.items()}
          self.set_params(**module_kwargs)
 
@@ -275,7 +291,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
 
          """
          if isinstance(X, mne.BaseEpochs):
-             X = X.get_data(units='uV')
+             X = X.get_data(units="uV")
          return super().get_dataset(X, y)
 
      def partial_fit(self, X, y=None, classes=None, **fit_params):
@@ -291,7 +307,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
 
          * mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
            ``sfreq``, ``input_window_seconds``
-         * numpy array: ``n_times``, ``n_chans``, ``n_outputs``
+         * array-like: ``n_times``, ``n_chans``, ``n_outputs``
          * WindowsDataset with ``targets_from='metadata'``
            (or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
          * other Dataset: ``n_times``, ``n_chans``
@@ -345,7 +361,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
 
          * mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
            ``sfreq``, ``input_window_seconds``
-         * numpy array: ``n_times``, ``n_chans``, ``n_outputs``
+         * array-like: ``n_times``, ``n_chans``, ``n_outputs``
          * WindowsDataset with ``targets_from='metadata'``
           (or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
          * other Dataset: ``n_times``, ``n_chans``
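
Note: the net effect of this change is that _EEGNeuralNet (and therefore EEGClassifier / EEGRegressor) can infer n_chans, n_times and n_outputs from any array-like X with a shape, from a skorch SliceDataset, or from mne.Epochs, not only from numpy arrays. A hedged sketch of what that enables (not part of the diff; real use may still require setting criterion, optimizer and the train/validation split explicitly):

import numpy as np
from braindecode import EEGClassifier

# Toy data: 100 trials, 22 channels, 400 time points, 4 classes.
X = np.random.randn(100, 22, 400).astype("float32")
y = np.random.randint(0, 4, size=100)

# X is array-like with a shape, so n_chans=22, n_times=400 and n_outputs=4
# are inferred and forwarded to the model given by name.
clf = EEGClassifier("ShallowFBCSPNet", max_epochs=2, batch_size=16)
clf.fit(X, y)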
braindecode/functional/__init__.py
@@ -0,0 +1,10 @@
+ from .functions import (
+     _get_gaussian_kernel1d,
+     drop_path,
+     hilbert_freq,
+     identity,
+     plv_time,
+     safe_log,
+     square,
+ )
+ from .initialization import glorot_weight_zero_bias, rescale_parameter
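
Note: the new braindecode.functional package exposes these helpers directly. A small sketch (not part of the diff) of the square → safe_log pairing used by the ConvNet-style models:

import torch
from braindecode.functional import safe_log, square

x = torch.randn(8, 40, 100)
out = safe_log(square(x))   # safe_log clamps its input before the log, avoiding log(0)
print(out.shape)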