braindecode 1.3.0.dev177628147__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/functional.py +0 -101
- braindecode/augmentation/transforms.py +0 -74
- braindecode/datasets/base.py +3 -18
- braindecode/datautil/serialization.py +1 -0
- braindecode/models/__init__.py +1 -8
- braindecode/models/summary.csv +0 -1
- braindecode/models/util.py +0 -84
- braindecode/preprocessing/__init__.py +0 -5
- braindecode/preprocessing/eegprep_preprocess.py +19 -134
- braindecode/preprocessing/mne_preprocess.py +25 -56
- braindecode/preprocessing/preprocess.py +41 -126
- braindecode/preprocessing/util.py +0 -11
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA +5 -11
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/RECORD +19 -24
- braindecode/datasets/hub.py +0 -962
- braindecode/datasets/hub_validation.py +0 -113
- braindecode/datasets/registry.py +0 -120
- braindecode/datautil/hub_formats.py +0 -180
- braindecode/models/luna.py +0 -836
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/top_level.txt +0 -0
|
@@ -77,14 +77,6 @@ class EEGPrepBasePreprocessor(Preprocessor):
|
|
|
77
77
|
self.record_orig_chanlocs = record_orig_chanlocs
|
|
78
78
|
self.force_dtype = np.dtype(force_dtype) if force_dtype is not None else None
|
|
79
79
|
|
|
80
|
-
@property
|
|
81
|
-
def _all_attrs(self):
|
|
82
|
-
return super()._all_attrs + [
|
|
83
|
-
"can_change_duration",
|
|
84
|
-
"record_orig_chanlocs",
|
|
85
|
-
"force_dtype",
|
|
86
|
-
]
|
|
87
|
-
|
|
88
80
|
@abstractmethod
|
|
89
81
|
def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
|
|
90
82
|
"""Apply the preprocessor to an EEGLAB EEG structure. Overridden by subclass."""
|
|
@@ -388,46 +380,31 @@ class EEGPrep(EEGPrepBasePreprocessor):
|
|
|
388
380
|
can_change_duration=can_change_duration or False,
|
|
389
381
|
)
|
|
390
382
|
self.resample_to = resample_to
|
|
391
|
-
self.bad_channel_reinterpolate = bad_channel_reinterpolate
|
|
383
|
+
self.reinterpolate = bad_channel_reinterpolate
|
|
392
384
|
self.common_avg_ref = common_avg_ref
|
|
393
385
|
self.burst_removal_cutoff = burst_removal_cutoff
|
|
394
386
|
self.bad_window_max_bad_channels = bad_window_max_bad_channels
|
|
395
|
-
self.bad_channel_corr_threshold = bad_channel_corr_threshold
|
|
396
|
-
self.highpass_frequencies = highpass_frequencies
|
|
397
|
-
self.flatline_maxdur = flatline_maxdur
|
|
398
|
-
self.bad_channel_hf_threshold = bad_channel_hf_threshold
|
|
399
|
-
self.bad_channel_max_broken_time = bad_channel_max_broken_time
|
|
400
|
-
self.bad_window_tolerances = bad_window_tolerances
|
|
401
|
-
self.refdata_max_bad_channels = refdata_max_bad_channels
|
|
402
|
-
self.refdata_max_tolerances = refdata_max_tolerances
|
|
403
|
-
self.num_samples = num_samples
|
|
404
|
-
self.subset_size = subset_size
|
|
405
|
-
self.bad_channel_nolocs_threshold = bad_channel_nolocs_threshold
|
|
406
|
-
self.bad_channel_nolocs_exclude_frac = bad_channel_nolocs_exclude_frac
|
|
407
|
-
self.max_mem_mb = max_mem_mb
|
|
408
387
|
|
|
409
|
-
|
|
410
|
-
def clean_artifacts_params(self):
|
|
411
|
-
if self.bad_channel_corr_threshold is None:
|
|
388
|
+
if bad_channel_corr_threshold is None:
|
|
412
389
|
line_noise_crit = None
|
|
413
390
|
else:
|
|
414
|
-
line_noise_crit = self.bad_channel_hf_threshold
|
|
415
|
-
|
|
416
|
-
ChannelCriterion=self.bad_channel_corr_threshold,
|
|
391
|
+
line_noise_crit = bad_channel_hf_threshold
|
|
392
|
+
self.clean_artifacts_params = dict(
|
|
393
|
+
ChannelCriterion=bad_channel_corr_threshold,
|
|
417
394
|
LineNoiseCriterion=line_noise_crit,
|
|
418
|
-
BurstCriterion=self.burst_removal_cutoff,
|
|
419
|
-
WindowCriterion=self.bad_window_max_bad_channels,
|
|
420
|
-
Highpass=self.highpass_frequencies,
|
|
421
|
-
ChannelCriterionMaxBadTime=self.bad_channel_max_broken_time,
|
|
422
|
-
BurstCriterionRefMaxBadChns=self.refdata_max_bad_channels,
|
|
423
|
-
BurstCriterionRefTolerances=self.refdata_max_tolerances,
|
|
424
|
-
WindowCriterionTolerances=self.bad_window_tolerances,
|
|
425
|
-
FlatlineCriterion=self.flatline_maxdur,
|
|
426
|
-
NumSamples=self.num_samples,
|
|
427
|
-
SubsetSize=self.subset_size,
|
|
428
|
-
NoLocsChannelCriterion=self.bad_channel_nolocs_threshold,
|
|
429
|
-
NoLocsChannelCriterionExcluded=self.bad_channel_nolocs_exclude_frac,
|
|
430
|
-
MaxMem=self.max_mem_mb,
|
|
395
|
+
BurstCriterion=burst_removal_cutoff,
|
|
396
|
+
WindowCriterion=bad_window_max_bad_channels,
|
|
397
|
+
Highpass=highpass_frequencies,
|
|
398
|
+
ChannelCriterionMaxBadTime=bad_channel_max_broken_time,
|
|
399
|
+
BurstCriterionRefMaxBadChns=refdata_max_bad_channels,
|
|
400
|
+
BurstCriterionRefTolerances=refdata_max_tolerances,
|
|
401
|
+
WindowCriterionTolerances=bad_window_tolerances,
|
|
402
|
+
FlatlineCriterion=flatline_maxdur,
|
|
403
|
+
NumSamples=num_samples,
|
|
404
|
+
SubsetSize=subset_size,
|
|
405
|
+
NoLocsChannelCriterion=bad_channel_nolocs_threshold,
|
|
406
|
+
NoLocsChannelCriterionExcluded=bad_channel_nolocs_exclude_frac,
|
|
407
|
+
MaxMem=max_mem_mb,
|
|
431
408
|
# For reference, the function additionally accepts these (legacy etc.)
|
|
432
409
|
# arguments, which we're not exposing here (current defaults as below):
|
|
433
410
|
# BurstRejection='off',
|
|
@@ -437,29 +414,6 @@ class EEGPrep(EEGPrepBasePreprocessor):
|
|
|
437
414
|
# availableRAM_GB=None,
|
|
438
415
|
)
|
|
439
416
|
|
|
440
|
-
@property
|
|
441
|
-
def _all_attrs(self):
|
|
442
|
-
return super()._all_attrs + [
|
|
443
|
-
"resample_to",
|
|
444
|
-
"bad_channel_reinterpolate",
|
|
445
|
-
"common_avg_ref",
|
|
446
|
-
"burst_removal_cutoff",
|
|
447
|
-
"bad_window_max_bad_channels",
|
|
448
|
-
"bad_channel_corr_threshold",
|
|
449
|
-
"highpass_frequencies",
|
|
450
|
-
"flatline_maxdur",
|
|
451
|
-
"bad_channel_hf_threshold",
|
|
452
|
-
"bad_channel_max_broken_time",
|
|
453
|
-
"bad_window_tolerances",
|
|
454
|
-
"refdata_max_bad_channels",
|
|
455
|
-
"refdata_max_tolerances",
|
|
456
|
-
"num_samples",
|
|
457
|
-
"subset_size",
|
|
458
|
-
"bad_channel_nolocs_threshold",
|
|
459
|
-
"bad_channel_nolocs_exclude_frac",
|
|
460
|
-
"max_mem_mb",
|
|
461
|
-
]
|
|
462
|
-
|
|
463
417
|
def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
|
|
464
418
|
"""Apply the preprocessor to an EEGLAB EEG structure."""
|
|
465
419
|
# remove per-channel DC offset (can be huge)
|
|
@@ -495,9 +449,7 @@ class EEGPrep(EEGPrepBasePreprocessor):
|
|
|
495
449
|
eeg["data"] = eeg["data"].astype(np.float32)
|
|
496
450
|
|
|
497
451
|
# optionally reinterpolate dropped channels
|
|
498
|
-
if self.bad_channel_reinterpolate and (
|
|
499
|
-
len(orig_chanlocs) > len(eeg["chanlocs"])
|
|
500
|
-
):
|
|
452
|
+
if self.reinterpolate and (len(orig_chanlocs) > len(eeg["chanlocs"])):
|
|
501
453
|
eeg = eegprep.eeg_interp(eeg, orig_chanlocs)
|
|
502
454
|
|
|
503
455
|
# optionally apply common average reference
|
|
@@ -561,13 +513,6 @@ class RemoveFlatChannels(EEGPrepBasePreprocessor):
|
|
|
561
513
|
self.max_flatline_duration = max_flatline_duration
|
|
562
514
|
self.max_allowed_jitter = max_allowed_jitter
|
|
563
515
|
|
|
564
|
-
@property
|
|
565
|
-
def _all_attrs(self):
|
|
566
|
-
return super()._all_attrs + [
|
|
567
|
-
"max_flatline_duration",
|
|
568
|
-
"max_allowed_jitter",
|
|
569
|
-
]
|
|
570
|
-
|
|
571
516
|
def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
|
|
572
517
|
"""Apply the preprocessor to an EEGLAB EEG structure."""
|
|
573
518
|
eeg = eegprep.clean_flatlines(
|
|
@@ -664,14 +609,6 @@ class RemoveDrifts(EEGPrepBasePreprocessor):
|
|
|
664
609
|
self.attenuation = attenuation
|
|
665
610
|
self.method = method
|
|
666
611
|
|
|
667
|
-
@property
|
|
668
|
-
def _all_attrs(self):
|
|
669
|
-
return super()._all_attrs + [
|
|
670
|
-
"transition",
|
|
671
|
-
"attenuation",
|
|
672
|
-
"method",
|
|
673
|
-
]
|
|
674
|
-
|
|
675
612
|
def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
|
|
676
613
|
"""Apply the preprocessor to an EEGLAB EEG structure."""
|
|
677
614
|
eeg = eegprep.clean_drifts(
|
|
@@ -738,10 +675,6 @@ class Resampling(EEGPrepBasePreprocessor):
|
|
|
738
675
|
super().__init__(can_change_duration=True)
|
|
739
676
|
self.sfreq = sfreq
|
|
740
677
|
|
|
741
|
-
@property
|
|
742
|
-
def _all_attrs(self):
|
|
743
|
-
return super()._all_attrs + ["sfreq"]
|
|
744
|
-
|
|
745
678
|
def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
|
|
746
679
|
"""Apply the preprocessor to an EEGLAB EEG structure."""
|
|
747
680
|
if self.sfreq is not None:
|
|
@@ -846,17 +779,6 @@ class RemoveBadChannels(EEGPrepBasePreprocessor):
|
|
|
846
779
|
self.num_samples = num_samples
|
|
847
780
|
self.subset_size = subset_size
|
|
848
781
|
|
|
849
|
-
@property
|
|
850
|
-
def _all_attrs(self):
|
|
851
|
-
return super()._all_attrs + [
|
|
852
|
-
"corr_threshold",
|
|
853
|
-
"noise_threshold",
|
|
854
|
-
"window_len",
|
|
855
|
-
"max_broken_time",
|
|
856
|
-
"num_samples",
|
|
857
|
-
"subset_size",
|
|
858
|
-
]
|
|
859
|
-
|
|
860
782
|
def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
|
|
861
783
|
"""Apply the preprocessor to an EEGLAB EEG structure."""
|
|
862
784
|
eeg = eegprep.clean_channels(
|
|
@@ -947,16 +869,6 @@ class RemoveBadChannelsNoLocs(EEGPrepBasePreprocessor):
|
|
|
947
869
|
self.max_broken_time = max_broken_time
|
|
948
870
|
self.linenoise_aware = linenoise_aware
|
|
949
871
|
|
|
950
|
-
@property
|
|
951
|
-
def _all_attrs(self):
|
|
952
|
-
return super()._all_attrs + [
|
|
953
|
-
"min_corr",
|
|
954
|
-
"ignored_quantile",
|
|
955
|
-
"window_len",
|
|
956
|
-
"max_broken_time",
|
|
957
|
-
"linenoise_aware",
|
|
958
|
-
]
|
|
959
|
-
|
|
960
872
|
def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
|
|
961
873
|
"""Apply the preprocessor to an EEGLAB EEG structure."""
|
|
962
874
|
eeg, _ = eegprep.clean_channels_nolocs(
|
|
@@ -1070,19 +982,6 @@ class RemoveBursts(EEGPrepBasePreprocessor):
|
|
|
1070
982
|
self.ref_wndlen = ref_wndlen
|
|
1071
983
|
self.maxmem = maxmem
|
|
1072
984
|
|
|
1073
|
-
@property
|
|
1074
|
-
def _all_attrs(self):
|
|
1075
|
-
return super()._all_attrs + [
|
|
1076
|
-
"cutoff",
|
|
1077
|
-
"window_len",
|
|
1078
|
-
"step_size",
|
|
1079
|
-
"max_dims",
|
|
1080
|
-
"ref_maxbadchannels",
|
|
1081
|
-
"ref_tolerances",
|
|
1082
|
-
"ref_wndlen",
|
|
1083
|
-
"maxmem",
|
|
1084
|
-
]
|
|
1085
|
-
|
|
1086
985
|
def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
|
|
1087
986
|
"""Apply the preprocessor to an EEGLAB EEG structure."""
|
|
1088
987
|
eeg = eegprep.clean_asr(
|
|
@@ -1199,20 +1098,6 @@ class RemoveBadWindows(EEGPrepBasePreprocessor):
|
|
|
1199
1098
|
self.step_sizes = step_sizes
|
|
1200
1099
|
self.shape_range = shape_range
|
|
1201
1100
|
|
|
1202
|
-
@property
|
|
1203
|
-
def _all_attrs(self):
|
|
1204
|
-
return super()._all_attrs + [
|
|
1205
|
-
"max_bad_channels",
|
|
1206
|
-
"zthresholds",
|
|
1207
|
-
"window_len",
|
|
1208
|
-
"window_overlap",
|
|
1209
|
-
"max_dropout_fraction",
|
|
1210
|
-
"min_clean_fraction",
|
|
1211
|
-
"truncate_quant",
|
|
1212
|
-
"step_sizes",
|
|
1213
|
-
"shape_range",
|
|
1214
|
-
]
|
|
1215
|
-
|
|
1216
1101
|
def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
|
|
1217
1102
|
"""Apply the preprocessor to an EEGLAB EEG structure."""
|
|
1218
1103
|
eeg, _ = eegprep.clean_windows(
|
|
@@ -38,60 +38,24 @@ def _generate_init_method(func, force_copy_false=False):
|
|
|
38
38
|
force_copy_false : bool
|
|
39
39
|
If True, forces copy=False by default for functions that have a copy parameter.
|
|
40
40
|
"""
|
|
41
|
-
func_name = func.__name__
|
|
42
41
|
parameters = list(inspect.signature(func).parameters.values())
|
|
43
|
-
param_names = [
|
|
44
|
-
param.name
|
|
45
|
-
for param in parameters[1:] # Skip 'self' or 'raw' or 'epochs'
|
|
46
|
-
]
|
|
47
|
-
all_mandatory = [
|
|
48
|
-
param.name
|
|
49
|
-
for param in parameters[1:] # Skip 'self' or 'raw' or 'epochs'
|
|
50
|
-
if param.default == inspect.Parameter.empty
|
|
51
|
-
]
|
|
42
|
+
param_names = [param.name for param in parameters]
|
|
52
43
|
|
|
53
44
|
def init_method(self, *args, **kwargs):
|
|
54
|
-
used = []
|
|
55
|
-
mandatory = list(all_mandatory)
|
|
56
|
-
init_kwargs = {}
|
|
57
|
-
|
|
58
45
|
# For standalone functions with copy parameter, set copy=False by default
|
|
59
46
|
if force_copy_false and "copy" in param_names and "copy" not in kwargs:
|
|
60
47
|
kwargs["copy"] = False
|
|
61
48
|
|
|
62
49
|
for name, value in zip(param_names, args):
|
|
63
|
-
|
|
64
|
-
used.append(name)
|
|
65
|
-
if name in mandatory:
|
|
66
|
-
mandatory.remove(name)
|
|
50
|
+
setattr(self, name, value)
|
|
67
51
|
for name, value in kwargs.items():
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
if name not in param_names:
|
|
71
|
-
raise TypeError(
|
|
72
|
-
f"'{name}' is an invalid keyword argument for {func_name}()"
|
|
73
|
-
)
|
|
74
|
-
init_kwargs[name] = value
|
|
75
|
-
if name in mandatory:
|
|
76
|
-
mandatory.remove(name)
|
|
77
|
-
if len(mandatory) > 0:
|
|
78
|
-
raise TypeError(
|
|
79
|
-
f"{func_name}() missing required arguments: {', '.join(mandatory)}"
|
|
80
|
-
)
|
|
81
|
-
Preprocessor.__init__(self, fn=func_name, apply_on_array=False, **init_kwargs)
|
|
52
|
+
setattr(self, name, value)
|
|
53
|
+
self.kwargs = kwargs
|
|
82
54
|
|
|
83
55
|
init_method.__signature__ = inspect.signature(func)
|
|
84
56
|
return init_method
|
|
85
57
|
|
|
86
58
|
|
|
87
|
-
def _generate_repr_method(class_name):
|
|
88
|
-
def repr_method(self):
|
|
89
|
-
args_str = ", ".join(f"{k}={v.__repr__()}" for k, v in self.kwargs.items())
|
|
90
|
-
return f"{class_name}({args_str})"
|
|
91
|
-
|
|
92
|
-
return repr_method
|
|
93
|
-
|
|
94
|
-
|
|
95
59
|
def _generate_mne_pre_processor(function):
|
|
96
60
|
"""
|
|
97
61
|
Generate a class based on an MNE function for preprocessing.
|
|
@@ -105,14 +69,10 @@ def _generate_mne_pre_processor(function):
|
|
|
105
69
|
class_name = "".join(word.title() for word in function.__name__.split("_")).replace(
|
|
106
70
|
"Eeg", "EEG"
|
|
107
71
|
)
|
|
108
|
-
|
|
109
|
-
# Automatically determine if function is standalone
|
|
110
|
-
is_standalone = _is_standalone_function(function)
|
|
111
|
-
|
|
112
72
|
# Create a wrapper note that references the original MNE function
|
|
113
73
|
# For Raw methods, use mne.io.Raw.method_name format with :meth:
|
|
114
74
|
# For standalone functions, use the function name only with :func:
|
|
115
|
-
if not is_standalone:
|
|
75
|
+
if hasattr(mne.io.Raw, function.__name__):
|
|
116
76
|
ref_path = f"mne.io.Raw.{function.__name__}"
|
|
117
77
|
ref_role = "meth"
|
|
118
78
|
else:
|
|
@@ -136,10 +96,6 @@ def _generate_mne_pre_processor(function):
|
|
|
136
96
|
|
|
137
97
|
base_classes = (Preprocessor,)
|
|
138
98
|
|
|
139
|
-
# Check if function has a 'copy' parameter
|
|
140
|
-
sig = inspect.signature(function)
|
|
141
|
-
has_copy_param = "copy" in sig.parameters
|
|
142
|
-
force_copy_false = is_standalone and has_copy_param
|
|
143
99
|
# Automatically determine if function is standalone
|
|
144
100
|
is_standalone = _is_standalone_function(function)
|
|
145
101
|
|
|
@@ -147,13 +103,26 @@ def _generate_mne_pre_processor(function):
|
|
|
147
103
|
sig = inspect.signature(function)
|
|
148
104
|
has_copy_param = "copy" in sig.parameters
|
|
149
105
|
force_copy_false = is_standalone and has_copy_param
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
106
|
+
|
|
107
|
+
if is_standalone:
|
|
108
|
+
# For standalone functions, store the actual function object
|
|
109
|
+
class_attrs = {
|
|
110
|
+
"__init__": _generate_init_method(
|
|
111
|
+
function, force_copy_false=force_copy_false
|
|
112
|
+
),
|
|
113
|
+
"__doc__": wrapper_note + (function.__doc__ or ""),
|
|
114
|
+
"fn": function, # Store the function itself, not the name
|
|
115
|
+
"_is_standalone": True,
|
|
116
|
+
}
|
|
117
|
+
else:
|
|
118
|
+
# For methods, store the function name as before
|
|
119
|
+
class_attrs = {
|
|
120
|
+
"__init__": _generate_init_method(function),
|
|
121
|
+
"__doc__": wrapper_note + (function.__doc__ or ""),
|
|
122
|
+
"fn": function.__name__,
|
|
123
|
+
"_is_standalone": False,
|
|
124
|
+
}
|
|
125
|
+
|
|
157
126
|
generated_class = type(class_name, base_classes, class_attrs)
|
|
158
127
|
|
|
159
128
|
return generated_class
|
|
@@ -13,9 +13,7 @@ from __future__ import annotations
|
|
|
13
13
|
import platform
|
|
14
14
|
import sys
|
|
15
15
|
from collections.abc import Iterable
|
|
16
|
-
from functools import cached_property, partial
|
|
17
|
-
from importlib import import_module
|
|
18
|
-
from inspect import signature
|
|
16
|
+
from functools import partial
|
|
19
17
|
from warnings import warn
|
|
20
18
|
|
|
21
19
|
if sys.version_info < (3, 9):
|
|
@@ -34,7 +32,6 @@ from braindecode.datasets.base import (
|
|
|
34
32
|
BaseConcatDataset,
|
|
35
33
|
EEGWindowsDataset,
|
|
36
34
|
RawDataset,
|
|
37
|
-
RecordDataset,
|
|
38
35
|
WindowsDataset,
|
|
39
36
|
)
|
|
40
37
|
from braindecode.datautil.serialization import (
|
|
@@ -72,30 +69,7 @@ class Preprocessor(object):
|
|
|
72
69
|
def __init__(self, fn: Callable | str, *, apply_on_array: bool = True, **kwargs):
|
|
73
70
|
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>":
|
|
74
71
|
warn("Preprocessing choices with lambda functions cannot be saved.")
|
|
75
|
-
if apply_on_array and not callable(fn):
|
|
76
|
-
warn(
|
|
77
|
-
"apply_on_array can only be True if fn is a callable function. "
|
|
78
|
-
"Automatically correcting to apply_on_array=False."
|
|
79
|
-
)
|
|
80
|
-
apply_on_array = False
|
|
81
|
-
# We store the exact input parameters. Simpler for serialization.
|
|
82
|
-
self.fn = fn
|
|
83
|
-
self.apply_on_array = apply_on_array
|
|
84
|
-
self.kwargs = kwargs
|
|
85
|
-
|
|
86
|
-
@property
|
|
87
|
-
def _all_attrs(self):
|
|
88
|
-
return ["fn", "apply_on_array", "kwargs"]
|
|
89
|
-
|
|
90
|
-
@property
|
|
91
|
-
def _init_attrs(self):
|
|
92
|
-
return [k for k in self._all_attrs if k in signature(self.__init__).parameters]
|
|
93
|
-
|
|
94
|
-
@cached_property
|
|
95
|
-
def _function(self):
|
|
96
|
-
kwargs = dict(self.kwargs)
|
|
97
|
-
fn = self.fn
|
|
98
|
-
if self.apply_on_array:
|
|
72
|
+
if callable(fn) and apply_on_array:
|
|
99
73
|
channel_wise = kwargs.pop("channel_wise", False)
|
|
100
74
|
picks = kwargs.pop("picks", None)
|
|
101
75
|
n_jobs = kwargs.pop("n_jobs", 1)
|
|
@@ -106,21 +80,12 @@ class Preprocessor(object):
|
|
|
106
80
|
n_jobs=n_jobs,
|
|
107
81
|
)
|
|
108
82
|
fn = "apply_function"
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
return partial(fn, **kwargs)
|
|
112
|
-
return partial(self._apply_str, fn=fn, **kwargs)
|
|
113
|
-
|
|
114
|
-
@staticmethod
|
|
115
|
-
def _apply_str(raw_or_epochs: BaseRaw | BaseEpochs, fn: str, **kwargs):
|
|
116
|
-
if not hasattr(raw_or_epochs, fn):
|
|
117
|
-
raise AttributeError(f"MNE object does not have a {fn} method.")
|
|
118
|
-
return getattr(raw_or_epochs, fn)(**kwargs)
|
|
83
|
+
self.fn = fn
|
|
84
|
+
self.kwargs = kwargs
|
|
119
85
|
|
|
120
86
|
def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
|
|
121
|
-
function = self._function
|
|
122
87
|
try:
|
|
123
|
-
|
|
88
|
+
return self._try_apply(raw_or_epochs)
|
|
124
89
|
except RuntimeError:
|
|
125
90
|
# Maybe the function needs the data to be loaded and the data was
|
|
126
91
|
# not loaded yet. Not all MNE functions need data to be loaded,
|
|
@@ -128,83 +93,20 @@ class Preprocessor(object):
|
|
|
128
93
|
# without preloading data which can make the overall preprocessing
|
|
129
94
|
# pipeline substantially faster.
|
|
130
95
|
raw_or_epochs.load_data()
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
out = {k: getattr(self, k) for k in self._init_attrs}
|
|
146
|
-
if "fn" in out and callable(self.fn):
|
|
147
|
-
out["fn"] = self.fn.__module__ + "." + self.fn.__name__
|
|
148
|
-
out["__class_path__"] = (
|
|
149
|
-
self.__class__.__module__ + "." + self.__class__.__name__
|
|
150
|
-
)
|
|
151
|
-
if "kwargs" not in out and self.kwargs:
|
|
152
|
-
out["kwargs"] = self.kwargs
|
|
153
|
-
return out
|
|
154
|
-
|
|
155
|
-
@classmethod
|
|
156
|
-
def deserialize(cls_parent, data: dict):
|
|
157
|
-
"""Create a Preprocessor from its serializable representation.
|
|
158
|
-
|
|
159
|
-
Parameters
|
|
160
|
-
----------
|
|
161
|
-
data : dict
|
|
162
|
-
Dictionary with keys 'fn' and 'kwargs' representing the
|
|
163
|
-
Preprocessor.
|
|
164
|
-
Returns
|
|
165
|
-
-------
|
|
166
|
-
Preprocessor
|
|
167
|
-
The deserialized Preprocessor object.
|
|
168
|
-
"""
|
|
169
|
-
class_path = data.pop("__class_path__")
|
|
170
|
-
cls_name = class_path.split(".")[-1]
|
|
171
|
-
cls_module_name = ".".join(class_path.split(".")[:-1])
|
|
172
|
-
cls_module = import_module(cls_module_name)
|
|
173
|
-
cls = getattr(cls_module, cls_name)
|
|
174
|
-
|
|
175
|
-
kwargs = data.pop("kwargs") if "kwargs" in data else {}
|
|
176
|
-
|
|
177
|
-
fn = data.get("fn", None)
|
|
178
|
-
if fn is not None and "." in fn: # callable function
|
|
179
|
-
fn_name = fn.split(".")[-1]
|
|
180
|
-
module_name = ".".join(fn.split(".")[:-1])
|
|
181
|
-
module = import_module(module_name)
|
|
182
|
-
data["fn"] = getattr(module, fn_name)
|
|
183
|
-
|
|
184
|
-
return cls(**data, **kwargs)
|
|
185
|
-
|
|
186
|
-
def __repr__(self):
|
|
187
|
-
cls_name = self.__class__.__name__
|
|
188
|
-
args_str = ", ".join(
|
|
189
|
-
f"{k}={getattr(self, k).__repr__()}" for k in self._init_attrs
|
|
190
|
-
)
|
|
191
|
-
return f"{cls_name}({args_str})"
|
|
192
|
-
|
|
193
|
-
def _same_attr(self, other, attr):
|
|
194
|
-
a = getattr(self, attr)
|
|
195
|
-
b = getattr(other, attr)
|
|
196
|
-
if attr == "fn" and callable(a):
|
|
197
|
-
return a.__module__ == b.__module__ and a.__name__ == b.__name__
|
|
198
|
-
if isinstance(a, np.ndarray):
|
|
199
|
-
return np.array_equal(a, b)
|
|
200
|
-
return a == b
|
|
201
|
-
|
|
202
|
-
def __eq__(self, other):
|
|
203
|
-
if not isinstance(other, Preprocessor):
|
|
204
|
-
return False
|
|
205
|
-
return all(self._same_attr(other, attr) for attr in self._all_attrs) and (
|
|
206
|
-
self.__class__ == other.__class__
|
|
207
|
-
)
|
|
96
|
+
return self._try_apply(raw_or_epochs)
|
|
97
|
+
|
|
98
|
+
def _try_apply(self, raw_or_epochs):
|
|
99
|
+
if callable(self.fn):
|
|
100
|
+
result = self.fn(raw_or_epochs, **self.kwargs)
|
|
101
|
+
# For standalone functions that return a new object, propagate it back
|
|
102
|
+
if result is not None and result is not raw_or_epochs:
|
|
103
|
+
return result
|
|
104
|
+
return raw_or_epochs
|
|
105
|
+
else:
|
|
106
|
+
if not hasattr(raw_or_epochs, self.fn):
|
|
107
|
+
raise AttributeError(f"MNE object does not have a {self.fn} method.")
|
|
108
|
+
getattr(raw_or_epochs, self.fn)(**self.kwargs)
|
|
109
|
+
return raw_or_epochs
|
|
208
110
|
|
|
209
111
|
|
|
210
112
|
def preprocess(
|
|
@@ -326,12 +228,7 @@ def _replace_inplace(concat_ds, new_concat_ds):
|
|
|
326
228
|
|
|
327
229
|
|
|
328
230
|
def _preprocess(
|
|
329
|
-
ds,
|
|
330
|
-
ds_index,
|
|
331
|
-
preprocessors,
|
|
332
|
-
save_dir=None,
|
|
333
|
-
overwrite=False,
|
|
334
|
-
copy_data=False,
|
|
231
|
+
ds, ds_index, preprocessors, save_dir=None, overwrite=False, copy_data=False
|
|
335
232
|
):
|
|
336
233
|
"""Apply preprocessor(s) to Raw or Epochs object.
|
|
337
234
|
|
|
@@ -390,6 +287,25 @@ def _preprocess(
|
|
|
390
287
|
return ds
|
|
391
288
|
|
|
392
289
|
|
|
290
|
+
def _get_preproc_kwargs(preprocessors):
|
|
291
|
+
preproc_kwargs = []
|
|
292
|
+
for p in preprocessors:
|
|
293
|
+
# in case of a mne function, fn is a str, kwargs is a dict
|
|
294
|
+
func_name = p.fn
|
|
295
|
+
func_kwargs = p.kwargs
|
|
296
|
+
# in case of another function
|
|
297
|
+
# if apply_on_array=False
|
|
298
|
+
if callable(p.fn):
|
|
299
|
+
func_name = p.fn.__name__
|
|
300
|
+
# if apply_on_array=True
|
|
301
|
+
else:
|
|
302
|
+
if "fun" in p.fn:
|
|
303
|
+
func_name = p.kwargs["fun"].func.__name__
|
|
304
|
+
func_kwargs = p.kwargs["fun"].keywords
|
|
305
|
+
preproc_kwargs.append((func_name, func_kwargs))
|
|
306
|
+
return preproc_kwargs
|
|
307
|
+
|
|
308
|
+
|
|
393
309
|
def _set_preproc_kwargs(ds, preprocessors):
|
|
394
310
|
"""Record preprocessing keyword arguments in RecordDataset.
|
|
395
311
|
|
|
@@ -400,7 +316,7 @@ def _set_preproc_kwargs(ds, preprocessors):
|
|
|
400
316
|
preprocessors : list
|
|
401
317
|
List of preprocessors.
|
|
402
318
|
"""
|
|
403
|
-
preproc_kwargs =
|
|
319
|
+
preproc_kwargs = _get_preproc_kwargs(preprocessors)
|
|
404
320
|
if isinstance(ds, WindowsDataset):
|
|
405
321
|
kind = "window"
|
|
406
322
|
elif isinstance(ds, EEGWindowsDataset):
|
|
@@ -409,8 +325,7 @@ def _set_preproc_kwargs(ds, preprocessors):
|
|
|
409
325
|
kind = "raw"
|
|
410
326
|
else:
|
|
411
327
|
raise TypeError(f"ds must be a RecordDataset, got {type(ds)}")
|
|
412
|
-
|
|
413
|
-
old_preproc_kwargs.extend(preproc_kwargs)
|
|
328
|
+
setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)
|
|
414
329
|
|
|
415
330
|
|
|
416
331
|
def exponential_moving_standardize(
|
|
@@ -5,7 +5,6 @@
|
|
|
5
5
|
# License: BSD-3
|
|
6
6
|
|
|
7
7
|
import base64
|
|
8
|
-
import inspect
|
|
9
8
|
import json
|
|
10
9
|
import re
|
|
11
10
|
from typing import Any
|
|
@@ -13,8 +12,6 @@ from typing import Any
|
|
|
13
12
|
import numpy as np
|
|
14
13
|
from mne.io.base import BaseRaw
|
|
15
14
|
|
|
16
|
-
from braindecode import preprocessing
|
|
17
|
-
|
|
18
15
|
__all__ = ["mne_store_metadata", "mne_load_metadata"]
|
|
19
16
|
|
|
20
17
|
|
|
@@ -26,14 +23,6 @@ _MARKER_END = "-->"
|
|
|
26
23
|
# Marker key for numpy arrays
|
|
27
24
|
_NP_ARRAY_TAG = "__numpy_array__"
|
|
28
25
|
|
|
29
|
-
preprocessor_dict = {}
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def _init_preprocessor_dict():
|
|
33
|
-
for m in inspect.getmembers(preprocessing, inspect.isclass):
|
|
34
|
-
if issubclass(m[1], preprocessing.Preprocessor):
|
|
35
|
-
preprocessor_dict[m[0]] = m[1]
|
|
36
|
-
|
|
37
26
|
|
|
38
27
|
def _numpy_decoder(dct):
|
|
39
28
|
"""Internal JSON decoder hook to handle numpy arrays."""
|
braindecode/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "1.3.0.dev177628147"
|
|
1
|
+
__version__ = "1.3.0.dev182330353"
|
{braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA
RENAMED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: braindecode
|
|
3
|
-
Version: 1.3.0.dev177628147
|
|
3
|
+
Version: 1.3.0.dev182330353
|
|
4
4
|
Summary: Deep learning software to decode EEG, ECG or MEG signals
|
|
5
|
-
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com
|
|
5
|
+
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
6
6
|
Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
7
7
|
License: BSD-3-Clause
|
|
8
8
|
Project-URL: homepage, https://braindecode.org
|
|
@@ -38,14 +38,12 @@ Requires-Dist: wfdb
|
|
|
38
38
|
Requires-Dist: h5py
|
|
39
39
|
Requires-Dist: linear_attention_transformer
|
|
40
40
|
Requires-Dist: docstring_inheritance
|
|
41
|
-
Requires-Dist: rotary_embedding_torch
|
|
42
41
|
Provides-Extra: moabb
|
|
43
42
|
Requires-Dist: moabb>=1.2.0; extra == "moabb"
|
|
44
43
|
Provides-Extra: eegprep
|
|
45
44
|
Requires-Dist: eegprep[eeglabio]>=0.1.1; extra == "eegprep"
|
|
46
|
-
Provides-Extra: hub
|
|
47
|
-
Requires-Dist: huggingface_hub[torch]>=0.20.0; extra == "hub"
|
|
48
|
-
Requires-Dist: zarr<3.0,>=2.18; extra == "hub"
|
|
45
|
+
Provides-Extra: hug
|
|
46
|
+
Requires-Dist: huggingface_hub[torch]>=0.20.0; extra == "hug"
|
|
49
47
|
Provides-Extra: tests
|
|
50
48
|
Requires-Dist: pytest; extra == "tests"
|
|
51
49
|
Requires-Dist: pytest-cov; extra == "tests"
|
|
@@ -71,11 +69,7 @@ Requires-Dist: pre-commit; extra == "docs"
|
|
|
71
69
|
Requires-Dist: openneuro-py; extra == "docs"
|
|
72
70
|
Requires-Dist: plotly; extra == "docs"
|
|
73
71
|
Provides-Extra: all
|
|
74
|
-
Requires-Dist: braindecode[moabb]; extra == "all"
|
|
75
|
-
Requires-Dist: braindecode[hub]; extra == "all"
|
|
76
|
-
Requires-Dist: braindecode[tests]; extra == "all"
|
|
77
|
-
Requires-Dist: braindecode[docs]; extra == "all"
|
|
78
|
-
Requires-Dist: braindecode[eegprep]; extra == "all"
|
|
72
|
+
Requires-Dist: braindecode[docs,eegprep,hug,moabb,tests]; extra == "all"
|
|
79
73
|
Dynamic: license-file
|
|
80
74
|
|
|
81
75
|
.. image:: https://badges.gitter.im/braindecodechat/community.svg
|