braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,823 @@
|
|
|
1
|
+
"""Dataset classes."""
|
|
2
|
+
|
|
3
|
+
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
|
|
4
|
+
# Lukas Gemein <l.gemein@gmail.com>
|
|
5
|
+
# Simon Brandt <simonbrandt@protonmail.com>
|
|
6
|
+
# David Sabbagh <dav.sabbagh@gmail.com>
|
|
7
|
+
# Robin Schirrmeister <robintibor@gmail.com>
|
|
8
|
+
#
|
|
9
|
+
# License: BSD (3-clause)
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
import os
|
|
15
|
+
import shutil
|
|
16
|
+
import warnings
|
|
17
|
+
from abc import abstractmethod
|
|
18
|
+
from collections.abc import Callable
|
|
19
|
+
from glob import glob
|
|
20
|
+
from typing import Any, Generic, Iterable, no_type_check
|
|
21
|
+
|
|
22
|
+
import mne.io
|
|
23
|
+
import numpy as np
|
|
24
|
+
import pandas as pd
|
|
25
|
+
from mne.utils.docs import deprecated
|
|
26
|
+
from torch.utils.data import ConcatDataset, Dataset
|
|
27
|
+
from typing_extensions import TypeVar
|
|
28
|
+
|
|
29
|
+
from .bids.hub import HubDatasetMixin
|
|
30
|
+
from .registry import register_dataset
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _create_description(description) -> pd.Series:
    """Normalize a dataset description.

    Parameters
    ----------
    description : dict | pandas.Series | None
        Description to normalize.

    Returns
    -------
    pandas.Series | None
        ``None`` is passed through unchanged, a dict is wrapped in a new
        ``pandas.Series`` and a ``pandas.Series`` is returned as is (the very
        same object, not a copy).

    Raises
    ------
    ValueError
        If ``description`` is neither ``None``, a dict nor a pandas.Series.
    """
    if description is None:
        return description
    if isinstance(description, dict):
        return pd.Series(description)
    if isinstance(description, pd.Series):
        return description
    raise ValueError(f"'{description}' has to be either a pandas.Series or a dict.")
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class RecordDataset(Dataset[tuple[np.ndarray, int | str, tuple[int, int, int]]]):
    """Base class for datasets carrying a description and a transform.

    Parameters
    ----------
    description : dict | pandas.Series | None
        Additional information about the record. A dict is converted to a
        ``pandas.Series``.
    transform : callable | None
        Stored on the instance; subclasses decide when to apply it.
    """

    def __init__(
        self,
        description: dict | pd.Series | None = None,
        transform: Callable | None = None,
    ):
        self._description = _create_description(description)
        # goes through the property setter below, which validates the value
        self.transform = transform

    @abstractmethod
    def __len__(self) -> int:
        pass

    @property
    def description(self) -> pd.Series:
        """Description of the record (``None`` if none was provided)."""
        return self._description

    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
        """Update (add or overwrite) the dataset description.

        Parameters
        ----------
        description : dict | pd.Series
            Description in the form key: value.
        overwrite : bool
            Has to be True if a key in description already exists in the
            dataset description.

        Raises
        ------
        AssertionError
            If a key already exists in the description and ``overwrite`` is
            False.
        """
        description = _create_description(description)
        if self.description is None:
            self._description = description
            return

        for key in description.keys():
            if key not in self._description:
                continue
            if not overwrite:
                # Raised explicitly (instead of ``assert``) so the check is
                # still enforced under ``python -O``; the exception type is
                # kept for callers that already catch AssertionError.
                raise AssertionError(
                    f"'{key}' already in description. Please "
                    f"rename or set overwrite to True."
                )
            # ``drop`` returns a new Series. The stored Series may be the very
            # object the caller passed to the constructor, so it must not be
            # modified in place.
            self._description = self._description.drop(key)
        self._description = pd.concat([self._description, description])

    @property
    def transform(self) -> Callable | None:
        """Transform stored on this dataset, or ``None``."""
        return self._transform

    @transform.setter
    def transform(self, value: Callable | None):
        if value is not None and not callable(value):
            raise ValueError("Transform needs to be a callable.")
        self._transform = value
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# Type variable for the datasets contained in BaseConcatDataset; it is bound
# to RecordDataset, so type checkers only accept RecordDataset subclasses.
T = TypeVar("T", bound=RecordDataset)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@register_dataset
class RawDataset(RecordDataset):
    """Returns samples from an mne.io.Raw object along with a target.

    Dataset which serves samples from an mne.io.Raw object along with a target.
    The target is unique for the dataset, and is obtained through the
    `description` attribute.

    Parameters
    ----------
    raw : mne.io.Raw
        Continuous data.
    description : dict | pandas.Series | None
        Holds additional description about the continuous signal / subject.
    target_name : str | tuple | list | None
        Name(s) of the index in `description` that should be used to provide the
        target (e.g., to be used in a prediction task later on).
    transform : callable | None
        On-the-fly transform applied to the example before it is returned.
    """

    def __init__(
        self,
        raw: mne.io.BaseRaw,
        description: dict | pd.Series | None = None,
        target_name: str | tuple[str, ...] | None = None,
        transform: Callable | None = None,
    ):
        super().__init__(description, transform)
        self.raw = raw

        # save target name for load/save later
        self.target_name = self._target_name(target_name)
        self.raw_preproc_kwargs: list[dict[str, Any]] = []

    def __getitem__(self, index):
        """Return the sample at ``index`` and the dataset-wide target.

        Returns
        -------
        X
            First element of ``self.raw[:, index]``, passed through
            ``transform`` if one is set.
        y
            ``None`` if no target name is set, otherwise the description
            entry (a list when several target names are configured).
        """
        X = self.raw[:, index][0]
        y = None
        if self.target_name is not None:
            y = self.description[self.target_name]
        # selecting several names yields a Series; hand out a plain list
        if isinstance(y, pd.Series):
            y = y.to_list()
        if self.transform is not None:
            X = self.transform(X)
        return X, y

    def __len__(self):
        return len(self.raw)

    def _target_name(self, target_name):
        """Validate ``target_name`` and bring it into its stored form.

        Returns
        -------
        str | list of str | None
            ``None`` if no target was requested, a single str for exactly one
            name and a list of str for several names.

        Raises
        ------
        ValueError
            If ``target_name`` is not None, str, tuple or list, or if it is an
            empty tuple / list.
        """
        if target_name is None:
            return None
        if not isinstance(target_name, (str, tuple, list)):
            raise ValueError("target_name has to be None, str, tuple or list")

        # single name -> one-element list; tuple / list -> fresh list, so the
        # stored value never aliases a list owned by the caller
        names = [target_name] if isinstance(target_name, str) else list(target_name)
        if not names:
            # previously surfaced as a bare IndexError from ``target_name[0]``
            raise ValueError("target_name must contain at least one name")

        # check if target name(s) can be read from description
        for name in names:
            if self.description is None or name not in self.description:
                warnings.warn(
                    f"'{name}' not in description. '__getitem__' "
                    f"will fail unless an appropriate target is"
                    f" added to description.",
                    UserWarning,
                )
        # return a list of str if there are multiple targets and a str otherwise
        return names if len(names) > 1 else names[0]
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@deprecated(
    "The BaseDataset class is deprecated. "
    "If you want to instantiate a dataset containing raws, use RawDataset instead. "
    "If you want to type a Braindecode dataset (i.e. RawDataset|EEGWindowsDataset|WindowsDataset), "
    "use the RecordDataset class instead."
)
@register_dataset
class BaseDataset(RawDataset):
    # Deprecated subclass of RawDataset: it adds no attributes or methods of
    # its own; the decorator above carries the deprecation message.
    pass
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@register_dataset
class EEGWindowsDataset(RecordDataset):
    """Returns windows from an mne.Raw object, its window indices, along with a target.

    Dataset which serves windows cut out of a continuous mne.io.Raw object
    along with their target and crop indices. The windows are described by
    ``metadata``: the columns `i_window_in_trial`, `i_start_in_trial` and
    `i_stop_in_trial` are required to locate each window (e.g., useful for
    cropped training), and a column called `target` is required when targets
    are taken from the metadata.
    See `braindecode.preprocessing.windowers` to directly create a windowed
    dataset from a ``RawDataset`` object.

    Parameters
    ----------
    raw : mne.io.Raw
        Continuous data the windows are extracted from.
    metadata : pandas.DataFrame
        Dataframe with crop indices, so `i_window_in_trial`, `i_start_in_trial`,
        `i_stop_in_trial`, as well as `target` (only read if
        ``targets_from="metadata"``).
    description : dict | pandas.Series | None
        Holds additional info about the windows.
    transform : callable | None
        On-the-fly transform applied to a window before it is returned.
    targets_from : str
        Defines whether targets will be extracted from metadata or from `misc`
        channels (time series targets). It can be `metadata` (default) or `channels`.
    last_target_only : bool
        If targets are obtained from misc channels, whether all targets of the
        entire (compute) window will be returned or only the last target in
        the window.

    Raises
    ------
    ValueError
        If ``targets_from`` is neither ``"metadata"`` nor ``"channels"``.
    """

    def __init__(
        self,
        raw: mne.io.BaseRaw,
        metadata: pd.DataFrame,
        description: dict | pd.Series | None = None,
        transform: Callable | None = None,
        targets_from: str = "metadata",
        last_target_only: bool = True,
    ):
        super().__init__(description, transform)
        self.raw = raw
        self.metadata = metadata

        self.last_target_only = last_target_only
        if targets_from not in ("metadata", "channels"):
            raise ValueError("Wrong value for parameter `targets_from`.")
        self.targets_from = targets_from
        self.crop_inds = metadata.loc[
            :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
        ].to_numpy()
        # ``self.y`` only exists when targets come from the metadata
        if self.targets_from == "metadata":
            self.y = metadata.loc[:, "target"].to_list()
        self.raw_preproc_kwargs: list[dict[str, Any]] = []

    def __getitem__(self, index: int):
        """Get a window and its target.

        Parameters
        ----------
        index : int
            Index to the window (and target) to return.

        Returns
        -------
        np.ndarray
            Window of shape (n_channels, n_times), as float32 before the
            transform is applied.
        int | np.ndarray
            Target for the window: the metadata entry, or the values of the
            `misc` channels when ``targets_from="channels"``.
        list
            Crop indices ``[i_window_in_trial, i_start, i_stop]``.
        """

        # necessary to cast as list to get list of three tensors from batch,
        # otherwise get single 2d-tensor...
        crop_inds = self.crop_inds[index].tolist()

        i_window_in_trial, i_start, i_stop = crop_inds
        X = self.raw._getitem((slice(None), slice(i_start, i_stop)), return_times=False)
        # ``astype`` returns a newly allocated array by default, so the user
        # cannot accidentally modify the underlying array through X and no
        # additional ``copy()`` is needed.
        X = X.astype("float32")
        if self.transform is not None:
            X = self.transform(X)
        if self.targets_from == "metadata":
            y = self.y[index]
        else:
            # targets are read from the (already transformed) window
            misc_mask = np.array(self.raw.get_channel_types()) == "misc"
            if self.last_target_only:
                y = X[misc_mask, -1]
            else:
                y = X[misc_mask, :]
            # ensure we don't give the user the option
            # to accidentally modify the underlying array
            y = y.copy()
            # remove the target channels from raw
            X = X[~misc_mask, :]
        return X, y, crop_inds

    def __len__(self):
        return len(self.crop_inds)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
@register_dataset
class WindowsDataset(RecordDataset):
    """Returns windows from an mne.Epochs object along with a target.

    Dataset which serves windows from an mne.Epochs object along with their
    target and additional information. The `metadata` attribute of the Epochs
    object must contain a column called `target`, which will be used to return
    the target that corresponds to a window (when targets are taken from the
    metadata). Additional columns `i_window_in_trial`, `i_start_in_trial`,
    `i_stop_in_trial` are also required to serve information about the
    windowing (e.g., useful for cropped training).
    See `braindecode.preprocessing.windowers` to directly create a
    `WindowsDataset` from a ``RawDataset`` object.

    Parameters
    ----------
    windows : mne.Epochs
        Windows obtained through the application of a windower to a RawDataset
        (see `braindecode.preprocessing.windowers`).
    description : dict | pandas.Series | None
        Holds additional info about the windows.
    transform : callable | None
        On-the-fly transform applied to a window before it is returned.
    targets_from : str
        Defines whether targets will be extracted from mne.Epochs metadata or mne.Epochs `misc`
        channels (time series targets). It can be `metadata` (default) or `channels`.
    last_target_only : bool
        Only used when targets are read from `misc` channels: if True only the
        last time point of those channels is returned as target, otherwise
        all time points of the window.

    Raises
    ------
    ValueError
        If ``targets_from`` is neither ``"metadata"`` nor ``"channels"``.
    """

    def __init__(
        self,
        windows: mne.BaseEpochs,
        description: dict | pd.Series | None = None,
        transform: Callable | None = None,
        targets_from: str = "metadata",
        last_target_only: bool = True,
    ):
        super().__init__(description, transform)
        self.windows = windows
        self.last_target_only = last_target_only
        if targets_from not in ("metadata", "channels"):
            raise ValueError("Wrong value for parameter `targets_from`.")
        self.targets_from = targets_from

        metadata = self.windows.metadata
        assert metadata is not None, "WindowsDataset requires windows with metadata."
        # one row per window: [i_window_in_trial, i_start_in_trial, i_stop_in_trial]
        self.crop_inds = metadata.loc[
            :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
        ].to_numpy()
        # ``self.y`` only exists when targets come from the metadata
        if self.targets_from == "metadata":
            self.y = metadata.loc[:, "target"].to_list()
        self.raw_preproc_kwargs: list[dict[str, Any]] = []
        self.window_preproc_kwargs: list[dict[str, Any]] = []

    def __getitem__(self, index: int):
        """Get a window and its target.

        Parameters
        ----------
        index : int
            Index to the window (and target) to return.

        Returns
        -------
        np.ndarray
            Window of shape (n_channels, n_times).
        int | np.ndarray
            Target for the window: the metadata entry, or the values of the
            `misc` channels when ``targets_from="channels"``.
        list
            Crop indices.
        """
        X = self.windows.get_data(item=index)[0].astype("float32")
        # the transform runs first, so channel targets below are read from
        # the transformed window
        if self.transform is not None:
            X = self.transform(X)
        if self.targets_from == "metadata":
            y = self.y[index]
        else:
            misc_mask = np.array(self.windows.get_channel_types()) == "misc"
            if self.last_target_only:
                y = X[misc_mask, -1]
            else:
                y = X[misc_mask, :]
            # remove the target channels from raw
            X = X[~misc_mask, :]
        # necessary to cast as list to get list of three tensors from batch,
        # otherwise get single 2d-tensor...
        crop_inds = self.crop_inds[index].tolist()
        return X, y, crop_inds

    def __len__(self) -> int:
        # number of windows equals the number of events of the Epochs object
        return len(self.windows.events)
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
@register_dataset
|
|
387
|
+
class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
|
|
388
|
+
"""A base class for concatenated datasets.
|
|
389
|
+
|
|
390
|
+
Holds either mne.Raw or mne.Epoch in self.datasets and has
|
|
391
|
+
a pandas DataFrame with additional description.
|
|
392
|
+
|
|
393
|
+
Includes Hugging Face Hub integration via HubDatasetMixin for
|
|
394
|
+
uploading and downloading datasets.
|
|
395
|
+
|
|
396
|
+
Parameters
|
|
397
|
+
----------
|
|
398
|
+
list_of_ds : list
|
|
399
|
+
list of RecordDataset
|
|
400
|
+
target_transform : callable | None
|
|
401
|
+
Optional function to call on targets before returning them.
|
|
402
|
+
"""
|
|
403
|
+
|
|
404
|
+
datasets: list[T]
|
|
405
|
+
|
|
406
|
+
def __init__(
    self,
    list_of_ds: list[T | BaseConcatDataset[T]],
    target_transform: Callable | None = None,
):
    """Concatenate the given datasets, unpacking nested concat datasets."""
    # A BaseConcatDataset entry contributes its individual datasets, so that
    # ``self.datasets`` ends up as one flat list.
    leaves: list[T] = []
    for entry in list_of_ds:
        if not isinstance(entry, BaseConcatDataset):
            leaves.append(entry)
            continue
        leaves.extend(entry.datasets)
    super().__init__(leaves)

    # validated by the ``target_transform`` property setter
    self.target_transform = target_transform
|
|
421
|
+
|
|
422
|
+
def _get_sequence(self, indices):
    """Fetch several items and stack their windows and targets.

    Parameters
    ----------
    indices : iterable of int
        Indices of the items to fetch.

    Returns
    -------
    X : np.ndarray
        First element of every fetched item, stacked along a new first axis.
    y : np.ndarray
        Second element of every fetched item, as an array.
    """
    # bind the parent lookup once; zero-argument ``super()`` is not available
    # inside a comprehension on all supported Python versions
    fetch = super().__getitem__
    samples = [fetch(ind) for ind in indices]

    X = np.stack([sample[0] for sample in samples], axis=0)
    y = np.array([sample[1] for sample in samples])
    return X, y
|
|
433
|
+
|
|
434
|
+
def __getitem__(self, idx: int | list):
    """Return one item, or a stacked sequence of items.

    Parameters
    ----------
    idx : int | list
        Index of window and target to return. If provided as a list of
        ints, multiple windows and targets will be extracted and
        concatenated. The target output can be modified on the
        fly by the ``target_transform`` parameter.
    """
    if isinstance(idx, Iterable):
        # several indices: sample multiple windows at once
        item = self._get_sequence(idx)
    else:
        item = super().__getitem__(idx)

    target_fn = self.target_transform
    if target_fn is None:
        return item
    # only the second element (the target) is replaced
    return item[:1] + (target_fn(item[1]),) + item[2:]
|
|
451
|
+
|
|
452
|
+
@no_type_check  # TODO, it's a mess
def split(
    self,
    by: str | list[int] | list[list[int]] | dict[str, list[int]] | None = None,
    property: str | None = None,
    split_ids: list[int] | list[list[int]] | dict[str, list[int]] | None = None,
) -> dict[str, BaseConcatDataset]:
    """Split the dataset into several BaseConcatDataset objects.

    The split is defined either by a column of the description DataFrame
    or by explicit dataset indices.

    Parameters
    ----------
    by : str | list | dict
        If ``by`` is a string, splitting is performed based on the
        description DataFrame column with this name.
        If ``by`` is a (list of) list of integers, the position in the first
        list corresponds to the split id and the integers to the
        datapoints of that split.
        If a dict then each key will be used in the returned
        splits dict and each value should be a list of int.
    property : str
        Deprecated, use ``by``. Some property which is listed in the
        description DataFrame.
    split_ids : list | dict
        Deprecated, use ``by``. List of indices to be combined in a subset.
        It can be a list of int or a list of list of int.

    Returns
    -------
    splits : dict
        A dictionary with the name of the split (a string) as key and the
        dataset as value.

    Raises
    ------
    ValueError
        If not exactly one of ``by``, ``property`` and ``split_ids`` is given.
    """
    n_given = sum(arg is not None for arg in (by, property, split_ids))
    if n_given != 1:
        raise ValueError("Splitting requires exactly one argument.")

    # exactly one argument is set, so ``by is None`` means that one of the
    # deprecated keywords was used
    if by is None:
        warnings.warn(
            "Keyword arguments `property` and `split_ids` "
            "are deprecated and will be removed in the future. "
            "Use `by` instead.",
            DeprecationWarning,
        )
        by = split_ids if property is None else property

    if isinstance(by, str):
        groups = self.description.groupby(by).groups
        split_ids = {name: list(inds) for name, inds in groups.items()}
    elif isinstance(by, dict):
        split_ids = by
    else:
        # a flat list of ints is treated as a single split
        nested = by if isinstance(by[0], list) else [by]
        split_ids = dict(enumerate(nested))

    splits = {}
    for split_name, ds_inds in split_ids.items():
        members = [self.datasets[ds_ind] for ds_ind in ds_inds]
        splits[str(split_name)] = BaseConcatDataset(
            members, target_transform=self.target_transform
        )
    return splits
|
|
518
|
+
|
|
519
|
+
def get_metadata(self) -> pd.DataFrame:
    """Concatenate the metadata and description of the wrapped datasets.

    Returns
    -------
    metadata : pd.DataFrame
        DataFrame containing as many rows as there are windows in the
        BaseConcatDataset, with the metadata and description information
        for each window.

    Raises
    ------
    TypeError
        If any wrapped dataset is neither a WindowsDataset nor an
        EEGWindowsDataset.
    """
    if not all(
        isinstance(ds, (WindowsDataset, EEGWindowsDataset)) for ds in self.datasets
    ):
        raise TypeError(
            "Metadata dataframe can only be computed when all "
            "datasets are WindowsDataset."
        )

    all_dfs = []
    for ds in self.datasets:
        source = ds.windows.metadata if hasattr(ds, "windows") else ds.metadata
        # Work on a copy: adding the description columns directly to the
        # frame handed out by the dataset would modify that frame (which may
        # be the dataset's own metadata) as a side effect of a getter.
        df = source.copy()
        if ds.description is not None:
            for k, v in ds.description.items():
                df[k] = v
        all_dfs.append(df)

    return pd.concat(all_dfs)
|
|
551
|
+
|
|
552
|
+
@property
def transform(self):
    """List with the transform of every wrapped dataset, in dataset order."""
    return [dataset.transform for dataset in self.datasets]

@transform.setter
def transform(self, fn):
    # the same value is assigned to every wrapped dataset
    for dataset in self.datasets:
        dataset.transform = fn
|
|
560
|
+
|
|
561
|
+
@property
def target_transform(self):
    """Callable applied to targets in ``__getitem__``, or ``None``."""
    return self._target_transform

@target_transform.setter
def target_transform(self, fn):
    # ``None`` disables the target transform; anything else must be callable
    if fn is not None and not callable(fn):
        raise TypeError("target_transform must be a callable.")
    self._target_transform = fn
|
|
570
|
+
|
|
571
|
+
def _outdated_save(self, path, overwrite=False):
    """Save dataset to files (old, deprecated on-disk layout).

    This is a copy of the old saving function, which had inconsistent
    functionality for BaseDataset and WindowsDataset. It only exists to
    assure backwards compatibility by still being able to run the old tests.

    Parameters
    ----------
    path : str
        Directory to which .fif / -epo.fif and .json files are stored.
    overwrite : bool
        Whether to delete old files (.json, .fif, -epo.fif) in specified
        directory prior to saving.

    Raises
    ------
    ValueError
        If there is no dataset to save, or if the first dataset has neither
        a ``raw`` nor a ``windows`` attribute.
    NotImplementedError
        If the first dataset is an EEGWindowsDataset.
    """
    warnings.warn(
        "This function only exists for backwards compatibility "
        "purposes. DO NOT USE!",
        UserWarning,
    )
    # Check for emptiness before looking at the first dataset, otherwise an
    # empty list fails with an IndexError instead of the intended error.
    if len(self.datasets) == 0:
        raise ValueError("Expect at least one dataset")
    first_ds = self.datasets[0]
    if isinstance(first_ds, EEGWindowsDataset):
        raise NotImplementedError(
            "Outdated save not implemented for new window datasets."
        )
    if not (hasattr(first_ds, "raw") or hasattr(first_ds, "windows")):
        raise ValueError("dataset should have either raw or windows attribute")

    file_name_templates = ["{}-raw.fif", "{}-epo.fif"]
    kwarg_names = ["raw_preproc_kwargs", "window_kwargs", "window_preproc_kwargs"]
    description_file_name = os.path.join(path, "description.json")
    target_file_name = os.path.join(path, "target_name.json")
    if not overwrite:
        from braindecode.datautil.serialization import (  # Import here to avoid circular import
            _check_save_dir_empty,
        )

        _check_save_dir_empty(path)
    else:
        # remove everything a previous save may have left in ``path``
        for file_name_template in file_name_templates:
            # "{}-raw.fif" -> "-raw.fif": match the files of any dataset index
            suffix = file_name_template.lstrip("{}")
            for file_name in glob(os.path.join(path, f"*{suffix}")):
                os.remove(file_name)
        for stale_file in (target_file_name, description_file_name):
            if os.path.isfile(stale_file):
                os.remove(stale_file)
        for kwarg_name in kwarg_names:
            kwarg_path = os.path.join(path, ".".join([kwarg_name, "json"]))
            if os.path.exists(kwarg_path):
                os.remove(kwarg_path)

    # the first dataset decides whether raws or epochs are written
    is_raw = hasattr(first_ds, "raw")
    file_name_template = file_name_templates[0] if is_raw else file_name_templates[1]

    for i_ds, ds in enumerate(self.datasets):
        full_file_path = os.path.join(path, file_name_template.format(i_ds))
        if is_raw:
            ds.raw.save(full_file_path, overwrite=overwrite)
        else:
            ds.windows.save(full_file_path, overwrite=overwrite)

    self.description.to_json(description_file_name)
    for kwarg_name in kwarg_names:
        if not hasattr(self, kwarg_name):
            continue
        kwargs = getattr(self, kwarg_name)
        if kwargs is None:
            continue
        kwargs_path = os.path.join(path, ".".join([kwarg_name, "json"]))
        # context manager so the file is flushed and closed deterministically
        with open(kwargs_path, "w") as kwargs_file:
            json.dump(kwargs, kwargs_file)
|
|
655
|
+
|
|
656
|
+
@property
def description(self) -> pd.DataFrame:
    """Collect the descriptions of all contained datasets in one table.

    Returns
    -------
    pd.DataFrame
        One row per entry of ``self.datasets`` (in order), built from each
        dataset's ``description``. The row index is reset to a plain
        ``0..n-1`` range, discarding whatever labels the individual
        descriptions carried.
    """
    per_dataset = [dataset.description for dataset in self.datasets]
    table = pd.DataFrame(per_dataset)
    # Drop the original labels so rows are addressed by dataset position.
    return table.reset_index(drop=True)
|
|
661
|
+
|
|
662
|
+
def set_description(
|
|
663
|
+
self, description: dict | pd.DataFrame, overwrite: bool = False
|
|
664
|
+
):
|
|
665
|
+
"""Update (add or overwrite) the dataset description.
|
|
666
|
+
|
|
667
|
+
Parameters
|
|
668
|
+
----------
|
|
669
|
+
description : dict | pd.DataFrame
|
|
670
|
+
Description in the form key: value where the length of the value
|
|
671
|
+
has to match the number of datasets.
|
|
672
|
+
overwrite : bool
|
|
673
|
+
Has to be True if a key in description already exists in the
|
|
674
|
+
dataset description.
|
|
675
|
+
"""
|
|
676
|
+
description = pd.DataFrame(description)
|
|
677
|
+
for key, value in description.items():
|
|
678
|
+
for ds, value_ in zip(self.datasets, value):
|
|
679
|
+
ds.set_description({key: value_}, overwrite=overwrite)
|
|
680
|
+
|
|
681
|
+
def save(self, path: str, overwrite: bool = False, offset: int = 0):
    """Write every dataset into its own numbered subdirectory of ``path``::

        path/
            0/
                0-raw.fif | 0-epo.fif
                description.json
                raw_preproc_kwargs.json (if raws were preprocessed)
                window_kwargs.json (if this is a windowed dataset)
                window_preproc_kwargs.json (if windows were preprocessed)
                target_name.json (if target_name is not None and dataset is raw)
            1/
                1-raw.fif | 1-epo.fif
                ...

    Parameters
    ----------
    path : str
        Parent directory; it is created if missing. One subdirectory per
        dataset is created inside it to hold the -raw.fif | -epo.fif and
        .json files.
    overwrite : bool
        If True, subdirectories that this call writes to are deleted
        first. If False, an already existing target subdirectory is an
        error.
    offset : int
        Added to the position of each dataset in the concat to form the
        subdirectory / file id. Useful for very large datasets that are
        processed and saved one at a time while keeping their original
        position.

    Raises
    ------
    ValueError
        If there are no datasets, or the first dataset has neither a
        ``raw`` nor a ``windows`` attribute.
    FileExistsError
        If a target subdirectory exists and ``overwrite`` is False.
    """
    datasets = self.datasets
    if len(datasets) == 0:
        raise ValueError("Expect at least one dataset")
    first = datasets[0]
    if not hasattr(first, "raw") and not hasattr(first, "windows"):
        raise ValueError("dataset should have either raw or windows attribute")

    # Make sure the parent directory is there before inspecting it.
    os.makedirs(path, exist_ok=True)

    # Entries present before saving; whatever this call writes to is
    # removed from the list, so the remainder is "untouched".
    untouched = os.listdir(path)
    n_sub_dirs = sum(
        1 for entry in untouched if os.path.isdir(os.path.join(path, entry))
    )
    for i_ds, ds in enumerate(datasets):
        ds_id = str(i_ds + offset)
        if ds_id in untouched:
            untouched.remove(ds_id)
        # path/{i_ds+offset}/
        sub_dir = os.path.join(path, ds_id)
        if os.path.exists(sub_dir):
            if not overwrite:
                raise FileExistsError(
                    f"Subdirectory {sub_dir} already exists. Please select"
                    f" a different directory, set overwrite=True, or "
                    f"resolve manually."
                )
            shutil.rmtree(sub_dir)
        os.makedirs(sub_dir)
        # {i_ds+offset}-raw.fif or {i_ds+offset}-epo.fif
        self._save_signals(sub_dir, ds, i_ds, offset)
        # metadata_df.pkl
        self._save_metadata(sub_dir, ds)
        # description.json
        self._save_description(sub_dir, ds.description)
        # raw_preproc_kwargs.json, window_kwargs.json,
        # window_preproc_kwargs.json
        self._save_kwargs(sub_dir, ds)
        # target_name.json
        self._save_target_name(sub_dir, ds)

    # Highest id written by this call, plus one. The comparison below is
    # True for all datasets preprocessed and stored in parallel with
    # braindecode.preprocessing.preprocess.
    n_saved = i_ds + 1 + offset
    if overwrite and n_saved < n_sub_dirs:
        warnings.warn(
            f"The number of saved datasets ({n_saved}) "
            f"does not match the number of existing "
            f"subdirectories ({n_sub_dirs}). You may now "
            f"encounter a mix of differently preprocessed "
            f"datasets!",
            UserWarning,
        )
    # Anything left in the listing was not written by this call.
    if untouched:
        warnings.warn(
            f"Chosen directory {path} contains other "
            f"subdirectories or files {untouched}."
        )
|
|
776
|
+
|
|
777
|
+
@staticmethod
def _save_signals(sub_dir, ds, i_ds, offset):
    """Save the signal object of ``ds`` as a .fif file inside ``sub_dir``.

    Parameters
    ----------
    sub_dir : str
        Directory the file is written to.
    ds : object
        Dataset exposing ``raw``; if it has no ``raw`` attribute its
        ``windows`` attribute is saved instead.
    i_ds : int
        Position of the dataset in the concat.
    offset : int
        Added to ``i_ds`` to form the numeric prefix of the file name.
    """
    if hasattr(ds, "raw"):
        suffix, attr_name = "raw", "raw"
    else:
        suffix, attr_name = "epo", "windows"
    target_path = os.path.join(sub_dir, f"{i_ds + offset}-{suffix}.fif")

    # Mark ``times`` read-only before saving. This appears to be needed to
    # avoid a CI failure when preprocessing WindowsDatasets with
    # serialization enabled: `mne.epochs._check_consistency` ensures the
    # Epochs object's `times` attribute is not writeable.
    getattr(ds, attr_name).times.flags["WRITEABLE"] = False

    getattr(ds, attr_name).save(target_path)
|
|
791
|
+
|
|
792
|
+
@staticmethod
def _save_metadata(sub_dir, ds):
    """Pickle ``ds.metadata`` to ``sub_dir/metadata_df.pkl``.

    Does nothing when ``ds`` has no ``metadata`` attribute.
    """
    if not hasattr(ds, "metadata"):
        return
    pickle_path = os.path.join(sub_dir, "metadata_df.pkl")
    ds.metadata.to_pickle(pickle_path)
|
|
797
|
+
|
|
798
|
+
@staticmethod
def _save_description(sub_dir, description):
    """Write ``description`` to ``sub_dir/description.json``.

    ``description`` must provide a ``to_json`` method. Values it cannot
    serialize natively are passed through ``str``.
    """
    json_path = os.path.join(sub_dir, "description.json")
    description.to_json(json_path, default_handler=str)
|
|
802
|
+
|
|
803
|
+
@staticmethod
def _save_kwargs(sub_dir, ds):
    """Dump the preprocessing / windowing kwargs of ``ds`` as JSON files.

    For each of ``raw_preproc_kwargs``, ``window_kwargs`` and
    ``window_preproc_kwargs`` that exists on ``ds`` and is not None, a file
    ``<name>.json`` is written into ``sub_dir``.
    """
    candidate_names = (
        "raw_preproc_kwargs",
        "window_kwargs",
        "window_preproc_kwargs",
    )
    for name in candidate_names:
        if not hasattr(ds, name):
            continue
        json_path = os.path.join(sub_dir, f"{name}.json")
        stored = getattr(ds, name)
        if stored is None:
            # Attribute exists but nothing was recorded: write no file.
            continue
        with open(json_path, "w") as handle:
            json.dump(stored, handle, indent=2)
|
|
817
|
+
|
|
818
|
+
@staticmethod
def _save_target_name(sub_dir, ds):
    """Store ``ds.target_name`` in ``sub_dir/target_name.json``.

    The value is written under the key ``"target_name"``. Does nothing
    when ``ds`` has no ``target_name`` attribute.
    """
    if not hasattr(ds, "target_name"):
        return
    json_path = os.path.join(sub_dir, "target_name.json")
    with open(json_path, "w") as handle:
        json.dump({"target_name": ds.target_name}, handle)
|