braindecode-1.3.0.dev177069446-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/datautil/serialization.py
@@ -0,0 +1,359 @@
+ """Convenience functions for storing and loading of windows datasets."""
+
+ # Authors: Lukas Gemein <l.gemein@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import json
+ import os
+ import pickle
+ import warnings
+ from glob import glob
+ from pathlib import Path
+
+ import mne
+ import pandas as pd
+ from joblib import Parallel, delayed
+
+ from ..datasets.base import (
+     BaseConcatDataset,
+     EEGWindowsDataset,
+     RawDataset,
+     WindowsDataset,
+ )
+
+
+ def save_concat_dataset(path, concat_dataset, overwrite=False):
+     warnings.warn(
+         '"save_concat_dataset()" is deprecated and will be removed in'
+         " the future. Use dataset.save() instead.",
+         UserWarning,
+     )
+     concat_dataset.save(path=path, overwrite=overwrite)
+
+
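For callers migrating off the deprecated helper above, a minimal sketch of the replacement call (assuming `concat_ds` is an existing `BaseConcatDataset` and `./saved_ds` is a writable directory):

    concat_ds.save(path="./saved_ds", overwrite=True)

This writes one numbered subdirectory per recording, which is the layout `load_concat_dataset()` below expects.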
+ def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
+     """Load a stored BaseConcatDataset from files.
+
+     Parameters
+     ----------
+     path : pathlib.Path
+         Path to the directory of the .fif / -epo.fif and .json files.
+     preload : bool
+         Whether to preload the data.
+     ids_to_load : None | list(int)
+         Ids of specific files to load.
+     target_name : None or str
+         Load specific description column as target. If not given, take saved
+         target name.
+
+     Returns
+     -------
+     concat_dataset : BaseConcatDataset
+     """
+     # assume we have a single concat dataset to load
+     is_raw = (path / "0-raw.fif").is_file()
+     assert not (not is_raw and target_name is not None), (
+         "Setting a new target is only supported for raws."
+     )
+     is_epochs = (path / "0-epo.fif").is_file()
+     paths = [path]
+     # assume we have multiple concat datasets to load
+     if not (is_raw or is_epochs):
+         is_raw = (path / "0" / "0-raw.fif").is_file()
+         is_epochs = (path / "0" / "0-epo.fif").is_file()
+         paths = path.glob("*/")
+         paths = sorted(paths, key=lambda p: int(p.name))
+         if ids_to_load is not None:
+             paths = [paths[i] for i in ids_to_load]
+             ids_to_load = None
+     # if we have neither a single nor multiple datasets, something went wrong
+     assert is_raw or is_epochs, (
+         f"Expect either raw or epo to exist in {path} or in {path / '0'}"
+     )
+
+     datasets = []
+     for path in paths:
+         if is_raw and target_name is None:
+             target_file_name = path / "target_name.json"
+             target_name = json.load(open(target_file_name, "r"))["target_name"]
+
+         all_signals, description = _load_signals_and_description(
+             path=path, preload=preload, is_raw=is_raw, ids_to_load=ids_to_load
+         )
+         for i_signal, signal in enumerate(all_signals):
+             if is_raw:
+                 datasets.append(
+                     RawDataset(
+                         signal, description.iloc[i_signal], target_name=target_name
+                     )
+                 )
+             else:
+                 datasets.append(WindowsDataset(signal, description.iloc[i_signal]))
+     concat_ds = BaseConcatDataset(datasets)
+     for kwarg_name in ["raw_preproc_kwargs", "window_kwargs", "window_preproc_kwargs"]:
+         kwarg_path = path / ".".join([kwarg_name, "json"])
+         if kwarg_path.exists():
+             with open(kwarg_path, "r") as f:
+                 kwargs = json.load(f)
+             kwargs = [tuple(kwarg) for kwarg in kwargs]
+             setattr(concat_ds, kwarg_name, kwargs)
+     return concat_ds
+
+
+ def _load_signals_and_description(path, preload, is_raw, ids_to_load=None):
+     all_signals = []
+     file_name = "{}-raw.fif" if is_raw else "{}-epo.fif"
+     description_df = pd.read_json(
+         path / "description.json", typ="series", convert_dates=False
+     )
+
+     if "timestamp" in description_df.index:
+         timestamp_numeric = pd.to_numeric(description_df["timestamp"])
+         description_df["timestamp"] = pd.to_datetime(timestamp_numeric)
+
+     if ids_to_load is None:
+         file_names = path.glob(f"*{file_name.lstrip('{}')}")
+         # Extract ids, e.g.,
+         # '/home/schirrmr/data/preproced-tuh/all-sensors/11-raw.fif' ->
+         # '11-raw.fif' -> 11
+         ids_to_load = sorted(
+             [int(os.path.split(f)[-1].split("-")[0]) for f in file_names]
+         )
+     for i in ids_to_load:
+         fif_file = path / file_name.format(i)
+         all_signals.append(_load_signals(fif_file, preload, is_raw))
+     description_df = description_df.iloc[ids_to_load]
+     return all_signals, description_df
+
+
+ def _load_signals(fif_file, preload, is_raw):
+     # Read the raw object from a pickle file if it has been saved before.
+     # The pickle file only contains the raw object without the data.
+     pkl_file = fif_file.with_suffix(".pkl")
+     if pkl_file.exists():
+         with open(pkl_file, "rb") as f:
+             signals = pickle.load(f)
+
+         if all(Path(f).exists() for f in signals.filenames):
+             if preload:
+                 signals.load_data()
+             return signals
+         else:  # This may happen if the file has been moved together with the pickle file.
+             warnings.warn(
+                 f"Pickle file {pkl_file} exists, but the referenced fif "
+                 "file(s) do not exist. Will read the fif file(s) directly "
+                 "and re-create the pickle file.",
+                 UserWarning,
+             )
+
+     # If the pickle file didn't exist, read via MNE (likely slower) and save
+     # a pickle afterwards.
+     if is_raw:
+         signals = mne.io.read_raw_fif(fif_file, preload=preload)
+     elif fif_file.name.endswith("-epo.fif"):
+         signals = mne.read_epochs(fif_file, preload=preload)
+     else:
+         raise ValueError("fif_file must end with raw.fif or epo.fif.")
+
+     # Only do this for raw objects. Epoch objects are not picklable as they
+     # hold references to open files in `signals._raw[0].fid`.
+     if is_raw:
+         # Save the raw object without data into a pickle file, so it can be
+         # retrieved faster on the next use of this dataset.
+         with open(pkl_file, "wb") as f:
+             if preload:
+                 data = signals._data
+                 signals._data, signals.preload = None, False
+             pickle.dump(signals, f)
+             if preload:
+                 signals._data, signals.preload = data, True
+
+     return signals
+
+
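Note that the cache file sits next to the `.fif` it mirrors: `Path.with_suffix()` swaps only the final suffix, so the hyphenated stem is preserved. A quick illustration:

    from pathlib import Path
    Path("0/0-raw.fif").with_suffix(".pkl")  # Path('0/0-raw.pkl')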
+ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
+     """Load a stored BaseConcatDataset from files.
+
+     Parameters
+     ----------
+     path : str | pathlib.Path
+         Path to the directory of the .fif / -epo.fif and .json files.
+     preload : bool
+         Whether to preload the data.
+     ids_to_load : list of int | None
+         Ids of specific files to load.
+     target_name : str | list | None
+         Load specific description column as target. If not given, take saved
+         target name.
+     n_jobs : int
+         Number of jobs to be used to read files in parallel.
+
+     Returns
+     -------
+     concat_dataset : BaseConcatDataset
+     """
+     # Make sure we always work with a pathlib.Path
+     path = Path(path)
+
+     # if we encounter a dataset that was saved in 'the old way', call the
+     # corresponding 'old' loading function
+     if _is_outdated_saved(path):
+         warnings.warn(
+             "The way your dataset was saved is deprecated by now. "
+             "Please save it again using dataset.save().",
+             UserWarning,
+         )
+         return _outdated_load_concat_dataset(
+             path=path, preload=preload, ids_to_load=ids_to_load, target_name=target_name
+         )
+
+     # else we have a dataset saved in the new way with subdirectories in path
+     # for every dataset with description.json and -epo.fif or -raw.fif,
+     # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
+     # window_preproc_kwargs.json
+     if ids_to_load is None:
+         ids_to_load = [p.name for p in path.iterdir()]
+         ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
+     ids_to_load = [str(i) for i in ids_to_load]
+     first_raw_fif_path = path / ids_to_load[0] / f"{ids_to_load[0]}-raw.fif"
+     is_raw = first_raw_fif_path.exists()
+     metadata_path = path / ids_to_load[0] / "metadata_df.pkl"
+     has_stored_windows = metadata_path.exists()
+
+     # Parallelization of mne.read_epochs with preload=False fails with
+     # 'TypeError: cannot pickle '_io.BufferedReader' object'.
+     # So ignore n_jobs in that case and load with a single job.
+     if not is_raw and n_jobs != 1:
+         warnings.warn(
+             "Parallelized reading with `preload=False` is not supported for "
+             "windowed data. Will use `n_jobs=1`.",
+             UserWarning,
+         )
+         n_jobs = 1
+     datasets = Parallel(n_jobs)(
+         delayed(_load_parallel)(path, i, preload, is_raw, has_stored_windows)
+         for i in ids_to_load
+     )
+     return BaseConcatDataset(datasets)
+
+
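Putting this together, the "new way" layout referenced in the comments above keeps one numbered subdirectory per recording, e.g. (hypothetical names):

    saved_ds/
        0/
            0-raw.fif
            description.json
            target_name.json
        1/
            1-raw.fif
            description.json
            target_name.json

and loading it back is a single call (a sketch, assuming the function is re-exported from `braindecode.datautil` as in released versions):

    from braindecode.datautil import load_concat_dataset

    concat_ds = load_concat_dataset("saved_ds", preload=False, n_jobs=2)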
+ def _load_parallel(path, i, preload, is_raw, has_stored_windows):
+     sub_dir = path / i
+     file_name_patterns = ["{}-raw.fif", "{}-epo.fif"]
+     if all([(sub_dir / p.format(i)).exists() for p in file_name_patterns]):
+         raise FileExistsError("Found -raw.fif and -epo.fif in directory.")
+
+     fif_name_pattern = file_name_patterns[0] if is_raw else file_name_patterns[1]
+     fif_file_name = fif_name_pattern.format(i)
+     fif_file_path = sub_dir / fif_file_name
+
+     signals = _load_signals(fif_file_path, preload, is_raw)
+
+     description_file_path = sub_dir / "description.json"
+     description = pd.read_json(description_file_path, typ="series", convert_dates=False)
+
+     # if 'timestamp' in description.index:
+     #     timestamp_numeric = pd.to_numeric(description['timestamp'])
+     #     description['timestamp'] = pd.to_datetime(timestamp_numeric, unit='s')
+
+     target_file_path = sub_dir / "target_name.json"
+     target_name = None
+     if target_file_path.exists():
+         target_name = json.load(open(target_file_path, "r"))["target_name"]
+
+     if is_raw and (not has_stored_windows):
+         dataset = RawDataset(signals, description, target_name)
+     else:
+         window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
+         windows_ds_kwargs = [
+             kwargs[1] for kwargs in window_kwargs if kwargs[0] == "WindowsDataset"
+         ]
+         windows_ds_kwargs = windows_ds_kwargs[0] if len(windows_ds_kwargs) == 1 else {}
+         if is_raw:
+             metadata = pd.read_pickle(path / i / "metadata_df.pkl")
+             dataset = EEGWindowsDataset(
+                 signals,
+                 metadata=metadata,
+                 description=description,
+                 targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
+                 last_target_only=windows_ds_kwargs.get("last_target_only", True),
+             )
+         else:
+             # MNE epochs dataset
+             dataset = WindowsDataset(
+                 signals,
+                 description,
+                 targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
+                 last_target_only=windows_ds_kwargs.get("last_target_only", True),
+             )
+         setattr(dataset, "window_kwargs", window_kwargs)
+     for kwargs_name in ["raw_preproc_kwargs", "window_preproc_kwargs"]:
+         kwargs = _load_kwargs_json(kwargs_name, sub_dir)
+         setattr(dataset, kwargs_name, kwargs)
+     return dataset
+
+
+ def _load_kwargs_json(kwargs_name, sub_dir):
+     kwargs_file_name = ".".join([kwargs_name, "json"])
+     kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
+     if os.path.exists(kwargs_file_path):
+         kwargs = json.load(open(kwargs_file_path, "r"))
+         return kwargs
+
+
+ def _is_outdated_saved(path):
+     """Data was saved in the old way if there are 'description.json', '-raw.fif'
+     or '-epo.fif' files directly in path (no subdirectories), if the number of
+     'fif' files does not match the number of 'description.json' files, or if
+     kwargs files (e.g. 'window_kwargs') are present directly in path.
+     """
+     description_files = glob(os.path.join(path, "**/description.json"))
+     fif_files = glob(os.path.join(path, "**/*-raw.fif")) + glob(
+         os.path.join(path, "**/*-epo.fif")
+     )
+     multiple = len(description_files) != len(fif_files)
+     kwargs_in_path = any(
+         [
+             os.path.exists(os.path.join(path, kwarg_name))
+             for kwarg_name in [
+                 "raw_preproc_kwargs",
+                 "window_kwargs",
+                 "window_preproc_kwargs",
+             ]
+         ]
+     )
+     return (
+         os.path.exists(os.path.join(path, "description.json"))
+         or os.path.exists(os.path.join(path, "0-raw.fif"))
+         or os.path.exists(os.path.join(path, "0-epo.fif"))
+         or multiple
+         or kwargs_in_path
+     )
+
+
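For contrast with the per-recording layout shown earlier, the deprecated flat layout this predicate detects looks like (hypothetical names):

    saved_ds/
        0-raw.fif
        1-raw.fif
        description.json
        target_name.json

i.e., all fif files and a single description.json side by side instead of one description.json per numbered subdirectory.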
+ def _check_save_dir_empty(save_dir):
+     """Make sure a BaseConcatDataset can be saved under a given directory.
+
+     Parameters
+     ----------
+     save_dir : str
+         Directory under which a `BaseConcatDataset` will be saved.
+
+     Raises
+     ------
+     FileExistsError
+         If ``save_dir`` is not a valid directory for saving.
+     """
+     sub_dirs = [
+         os.path.basename(s).isdigit() for s in glob(os.path.join(save_dir, "*"))
+     ]
+     if any(sub_dirs):
+         raise FileExistsError(
+             f"Directory {save_dir} already contains subdirectories. Please "
+             "select a different directory, set overwrite=True, or resolve "
+             "manually."
+         )
braindecode/datautil/util.py
@@ -0,0 +1,154 @@
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import logging
+ from typing import Any, Literal
+
+ import mne
+ import numpy as np
+ from skorch.helper import SliceDataset
+ from skorch.utils import is_dataset
+
+ from braindecode.datasets.base import BaseConcatDataset, WindowsDataset
+ from braindecode.models.util import SigArgName
+
+ log = logging.getLogger(__name__)
+
+
+ def ms_to_samples(ms, fs):
+     """
+     Convert milliseconds to number of samples.
+
+     Parameters
+     ----------
+     ms : number
+         Milliseconds
+     fs : number
+         Sampling rate
+
+     Returns
+     -------
+     n_samples : float
+         Number of samples
+     """
+     return ms * fs / 1000.0
+
+
+ def samples_to_ms(n_samples, fs):
+     """
+     Convert number of samples to milliseconds.
+
+     Parameters
+     ----------
+     n_samples : number
+         Number of samples
+     fs : number
+         Sampling rate
+
+     Returns
+     -------
+     milliseconds : float
+         Duration in milliseconds
+     """
+     return n_samples * 1000.0 / fs
+
+
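A quick sanity check of the two conversions at fs = 250 Hz: 500 ms corresponds to 500 * 250 / 1000 = 125 samples, and the round trip recovers the duration:

    ms_to_samples(500, 250)  # 125.0
    samples_to_ms(125, 250)  # 500.0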
+ def _get_n_outputs(y, classes, mode):
+     if mode == "classification":
+         classes_y = np.unique(y)
+         if classes is not None:
+             assert set(classes_y) <= set(classes)
+         else:
+             classes = classes_y
+         return len(classes)
+     elif mode == "regression":
+         if y is None:
+             return None
+         if y.ndim == 1:
+             return 1
+         else:
+             return y.shape[-1]
+     else:
+         raise ValueError(f"Unknown mode {mode}")
+
+
+ def infer_signal_properties(
+     X,
+     y=None,
+     mode: Literal["classification", "regression"] = "classification",
+     classes: list | None = None,
+ ) -> dict[SigArgName, Any]:
+     """Infer signal properties from the data.
+
+     The extracted signal properties are:
+
+     + n_chans: number of channels
+     + n_times: number of time points
+     + n_outputs: number of outputs
+     + chs_info: channel information
+     + sfreq: sampling frequency
+
+     The returned dictionary can serve as kwargs for model initialization.
+
+     Depending on the type of input passed, not all properties can be inferred.
+
+     Parameters
+     ----------
+     X : array-like or mne.BaseEpochs or Dataset
+         Input data
+     y : array-like or None
+         Targets
+     mode : "classification" or "regression"
+         Mode of the task
+     classes : list or None
+         List of classes for classification
+
+     Returns
+     -------
+     signal_kwargs : dict
+         Dictionary with signal properties. Can serve as kwargs for model
+         initialization.
+     """
+     signal_kwargs: dict[SigArgName, Any] = {}
+     # Using shape to work both with torch.tensor and numpy.array:
+     if (
+         isinstance(X, mne.BaseEpochs)
+         or (hasattr(X, "shape") and len(X.shape) >= 2)
+         or isinstance(X, SliceDataset)
+     ):
+         if y is None:
+             raise ValueError("y must be specified if X is array-like.")
+         signal_kwargs["n_outputs"] = _get_n_outputs(y, classes, mode)
+         if isinstance(X, mne.BaseEpochs):
+             log.info("Using mne.Epochs to find signal-related parameters.")
+             signal_kwargs["n_times"] = len(X.times)
+             signal_kwargs["sfreq"] = X.info["sfreq"]
+             signal_kwargs["chs_info"] = X.info["chs"]
+         elif isinstance(X, SliceDataset):
+             log.info("Using SliceDataset to find signal-related parameters.")
+             Xshape = X[0].shape
+             signal_kwargs["n_times"] = Xshape[-1]
+             signal_kwargs["n_chans"] = Xshape[-2]
+         else:
+             log.info("Using array-like to find signal-related parameters.")
+             signal_kwargs["n_times"] = X.shape[-1]
+             signal_kwargs["n_chans"] = X.shape[-2]
+     elif is_dataset(X):
+         log.info(f"Using Dataset {X!r} to find signal-related parameters.")
+         X0 = X[0][0]
+         Xshape = X0.shape
+         signal_kwargs["n_times"] = Xshape[-1]
+         signal_kwargs["n_chans"] = Xshape[-2]
+         if isinstance(X, BaseConcatDataset) and all(
+             ds.targets_from == "metadata" for ds in X.datasets
+         ):
+             y_target = X.get_metadata().target
+             signal_kwargs["n_outputs"] = _get_n_outputs(y_target, classes, mode)
+         elif isinstance(X, WindowsDataset) and X.targets_from == "metadata":
+             y_target = X.windows.metadata.target
+             signal_kwargs["n_outputs"] = _get_n_outputs(y_target, classes, mode)
+     else:
+         log.warning(
+             f"Can only infer signal shape of array-like and Datasets, got {type(X)!r}."
+         )
+     return signal_kwargs
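To make the array-like branch concrete, a minimal sketch (labels constructed so that all four classes occur):

    import numpy as np

    X = np.random.randn(100, 22, 500)  # (n_epochs, n_chans, n_times)
    y = np.arange(100) % 4             # four classes: 0, 1, 2, 3
    infer_signal_properties(X, y, mode="classification")
    # -> {'n_outputs': 4, 'n_times': 500, 'n_chans': 22}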