braindecode 1.3.0.dev177628147__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -77,14 +77,6 @@ class EEGPrepBasePreprocessor(Preprocessor):
77
77
  self.record_orig_chanlocs = record_orig_chanlocs
78
78
  self.force_dtype = np.dtype(force_dtype) if force_dtype is not None else None
79
79
 
80
- @property
81
- def _all_attrs(self):
82
- return super()._all_attrs + [
83
- "can_change_duration",
84
- "record_orig_chanlocs",
85
- "force_dtype",
86
- ]
87
-
88
80
  @abstractmethod
89
81
  def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
90
82
  """Apply the preprocessor to an EEGLAB EEG structure. Overridden by subclass."""
@@ -388,46 +380,31 @@ class EEGPrep(EEGPrepBasePreprocessor):
388
380
  can_change_duration=can_change_duration or False,
389
381
  )
390
382
  self.resample_to = resample_to
391
- self.bad_channel_reinterpolate = bad_channel_reinterpolate
383
+ self.reinterpolate = bad_channel_reinterpolate
392
384
  self.common_avg_ref = common_avg_ref
393
385
  self.burst_removal_cutoff = burst_removal_cutoff
394
386
  self.bad_window_max_bad_channels = bad_window_max_bad_channels
395
- self.bad_channel_corr_threshold = bad_channel_corr_threshold
396
- self.highpass_frequencies = highpass_frequencies
397
- self.flatline_maxdur = flatline_maxdur
398
- self.bad_channel_hf_threshold = bad_channel_hf_threshold
399
- self.bad_channel_max_broken_time = bad_channel_max_broken_time
400
- self.bad_window_tolerances = bad_window_tolerances
401
- self.refdata_max_bad_channels = refdata_max_bad_channels
402
- self.refdata_max_tolerances = refdata_max_tolerances
403
- self.num_samples = num_samples
404
- self.subset_size = subset_size
405
- self.bad_channel_nolocs_threshold = bad_channel_nolocs_threshold
406
- self.bad_channel_nolocs_exclude_frac = bad_channel_nolocs_exclude_frac
407
- self.max_mem_mb = max_mem_mb
408
387
 
409
- @property
410
- def clean_artifacts_params(self):
411
- if self.bad_channel_corr_threshold is None:
388
+ if bad_channel_corr_threshold is None:
412
389
  line_noise_crit = None
413
390
  else:
414
- line_noise_crit = self.bad_channel_hf_threshold
415
- return dict(
416
- ChannelCriterion=self.bad_channel_corr_threshold,
391
+ line_noise_crit = bad_channel_hf_threshold
392
+ self.clean_artifacts_params = dict(
393
+ ChannelCriterion=bad_channel_corr_threshold,
417
394
  LineNoiseCriterion=line_noise_crit,
418
- BurstCriterion=self.burst_removal_cutoff,
419
- WindowCriterion=self.bad_window_max_bad_channels,
420
- Highpass=self.highpass_frequencies,
421
- ChannelCriterionMaxBadTime=self.bad_channel_max_broken_time,
422
- BurstCriterionRefMaxBadChns=self.refdata_max_bad_channels,
423
- BurstCriterionRefTolerances=self.refdata_max_tolerances,
424
- WindowCriterionTolerances=self.bad_window_tolerances,
425
- FlatlineCriterion=self.flatline_maxdur,
426
- NumSamples=self.num_samples,
427
- SubsetSize=self.subset_size,
428
- NoLocsChannelCriterion=self.bad_channel_nolocs_threshold,
429
- NoLocsChannelCriterionExcluded=self.bad_channel_nolocs_exclude_frac,
430
- MaxMem=self.max_mem_mb,
395
+ BurstCriterion=burst_removal_cutoff,
396
+ WindowCriterion=bad_window_max_bad_channels,
397
+ Highpass=highpass_frequencies,
398
+ ChannelCriterionMaxBadTime=bad_channel_max_broken_time,
399
+ BurstCriterionRefMaxBadChns=refdata_max_bad_channels,
400
+ BurstCriterionRefTolerances=refdata_max_tolerances,
401
+ WindowCriterionTolerances=bad_window_tolerances,
402
+ FlatlineCriterion=flatline_maxdur,
403
+ NumSamples=num_samples,
404
+ SubsetSize=subset_size,
405
+ NoLocsChannelCriterion=bad_channel_nolocs_threshold,
406
+ NoLocsChannelCriterionExcluded=bad_channel_nolocs_exclude_frac,
407
+ MaxMem=max_mem_mb,
431
408
  # For reference, the function additionally accepts these (legacy etc.)
432
409
  # arguments, which we're not exposing here (current defaults as below):
433
410
  # BurstRejection='off',
@@ -437,29 +414,6 @@ class EEGPrep(EEGPrepBasePreprocessor):
437
414
  # availableRAM_GB=None,
438
415
  )
439
416
 
440
- @property
441
- def _all_attrs(self):
442
- return super()._all_attrs + [
443
- "resample_to",
444
- "bad_channel_reinterpolate",
445
- "common_avg_ref",
446
- "burst_removal_cutoff",
447
- "bad_window_max_bad_channels",
448
- "bad_channel_corr_threshold",
449
- "highpass_frequencies",
450
- "flatline_maxdur",
451
- "bad_channel_hf_threshold",
452
- "bad_channel_max_broken_time",
453
- "bad_window_tolerances",
454
- "refdata_max_bad_channels",
455
- "refdata_max_tolerances",
456
- "num_samples",
457
- "subset_size",
458
- "bad_channel_nolocs_threshold",
459
- "bad_channel_nolocs_exclude_frac",
460
- "max_mem_mb",
461
- ]
462
-
463
417
  def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
464
418
  """Apply the preprocessor to an EEGLAB EEG structure."""
465
419
  # remove per-channel DC offset (can be huge)
@@ -495,9 +449,7 @@ class EEGPrep(EEGPrepBasePreprocessor):
495
449
  eeg["data"] = eeg["data"].astype(np.float32)
496
450
 
497
451
  # optionally reinterpolate dropped channels
498
- if self.bad_channel_reinterpolate and (
499
- len(orig_chanlocs) > len(eeg["chanlocs"])
500
- ):
452
+ if self.reinterpolate and (len(orig_chanlocs) > len(eeg["chanlocs"])):
501
453
  eeg = eegprep.eeg_interp(eeg, orig_chanlocs)
502
454
 
503
455
  # optionally apply common average reference
@@ -561,13 +513,6 @@ class RemoveFlatChannels(EEGPrepBasePreprocessor):
561
513
  self.max_flatline_duration = max_flatline_duration
562
514
  self.max_allowed_jitter = max_allowed_jitter
563
515
 
564
- @property
565
- def _all_attrs(self):
566
- return super()._all_attrs + [
567
- "max_flatline_duration",
568
- "max_allowed_jitter",
569
- ]
570
-
571
516
  def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
572
517
  """Apply the preprocessor to an EEGLAB EEG structure."""
573
518
  eeg = eegprep.clean_flatlines(
@@ -664,14 +609,6 @@ class RemoveDrifts(EEGPrepBasePreprocessor):
664
609
  self.attenuation = attenuation
665
610
  self.method = method
666
611
 
667
- @property
668
- def _all_attrs(self):
669
- return super()._all_attrs + [
670
- "transition",
671
- "attenuation",
672
- "method",
673
- ]
674
-
675
612
  def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
676
613
  """Apply the preprocessor to an EEGLAB EEG structure."""
677
614
  eeg = eegprep.clean_drifts(
@@ -738,10 +675,6 @@ class Resampling(EEGPrepBasePreprocessor):
738
675
  super().__init__(can_change_duration=True)
739
676
  self.sfreq = sfreq
740
677
 
741
- @property
742
- def _all_attrs(self):
743
- return super()._all_attrs + ["sfreq"]
744
-
745
678
  def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
746
679
  """Apply the preprocessor to an EEGLAB EEG structure."""
747
680
  if self.sfreq is not None:
@@ -846,17 +779,6 @@ class RemoveBadChannels(EEGPrepBasePreprocessor):
846
779
  self.num_samples = num_samples
847
780
  self.subset_size = subset_size
848
781
 
849
- @property
850
- def _all_attrs(self):
851
- return super()._all_attrs + [
852
- "corr_threshold",
853
- "noise_threshold",
854
- "window_len",
855
- "max_broken_time",
856
- "num_samples",
857
- "subset_size",
858
- ]
859
-
860
782
  def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
861
783
  """Apply the preprocessor to an EEGLAB EEG structure."""
862
784
  eeg = eegprep.clean_channels(
@@ -947,16 +869,6 @@ class RemoveBadChannelsNoLocs(EEGPrepBasePreprocessor):
947
869
  self.max_broken_time = max_broken_time
948
870
  self.linenoise_aware = linenoise_aware
949
871
 
950
- @property
951
- def _all_attrs(self):
952
- return super()._all_attrs + [
953
- "min_corr",
954
- "ignored_quantile",
955
- "window_len",
956
- "max_broken_time",
957
- "linenoise_aware",
958
- ]
959
-
960
872
  def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
961
873
  """Apply the preprocessor to an EEGLAB EEG structure."""
962
874
  eeg, _ = eegprep.clean_channels_nolocs(
@@ -1070,19 +982,6 @@ class RemoveBursts(EEGPrepBasePreprocessor):
1070
982
  self.ref_wndlen = ref_wndlen
1071
983
  self.maxmem = maxmem
1072
984
 
1073
- @property
1074
- def _all_attrs(self):
1075
- return super()._all_attrs + [
1076
- "cutoff",
1077
- "window_len",
1078
- "step_size",
1079
- "max_dims",
1080
- "ref_maxbadchannels",
1081
- "ref_tolerances",
1082
- "ref_wndlen",
1083
- "maxmem",
1084
- ]
1085
-
1086
985
  def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
1087
986
  """Apply the preprocessor to an EEGLAB EEG structure."""
1088
987
  eeg = eegprep.clean_asr(
@@ -1199,20 +1098,6 @@ class RemoveBadWindows(EEGPrepBasePreprocessor):
1199
1098
  self.step_sizes = step_sizes
1200
1099
  self.shape_range = shape_range
1201
1100
 
1202
- @property
1203
- def _all_attrs(self):
1204
- return super()._all_attrs + [
1205
- "max_bad_channels",
1206
- "zthresholds",
1207
- "window_len",
1208
- "window_overlap",
1209
- "max_dropout_fraction",
1210
- "min_clean_fraction",
1211
- "truncate_quant",
1212
- "step_sizes",
1213
- "shape_range",
1214
- ]
1215
-
1216
1101
  def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
1217
1102
  """Apply the preprocessor to an EEGLAB EEG structure."""
1218
1103
  eeg, _ = eegprep.clean_windows(
@@ -38,60 +38,24 @@ def _generate_init_method(func, force_copy_false=False):
38
38
  force_copy_false : bool
39
39
  If True, forces copy=False by default for functions that have a copy parameter.
40
40
  """
41
- func_name = func.__name__
42
41
  parameters = list(inspect.signature(func).parameters.values())
43
- param_names = [
44
- param.name
45
- for param in parameters[1:] # Skip 'self' or 'raw' or 'epochs'
46
- ]
47
- all_mandatory = [
48
- param.name
49
- for param in parameters[1:] # Skip 'self' or 'raw' or 'epochs'
50
- if param.default == inspect.Parameter.empty
51
- ]
42
+ param_names = [param.name for param in parameters]
52
43
 
53
44
  def init_method(self, *args, **kwargs):
54
- used = []
55
- mandatory = list(all_mandatory)
56
- init_kwargs = {}
57
-
58
45
  # For standalone functions with copy parameter, set copy=False by default
59
46
  if force_copy_false and "copy" in param_names and "copy" not in kwargs:
60
47
  kwargs["copy"] = False
61
48
 
62
49
  for name, value in zip(param_names, args):
63
- init_kwargs[name] = value
64
- used.append(name)
65
- if name in mandatory:
66
- mandatory.remove(name)
50
+ setattr(self, name, value)
67
51
  for name, value in kwargs.items():
68
- if name in used:
69
- raise TypeError(f"Multiple values for argument '{name}'")
70
- if name not in param_names:
71
- raise TypeError(
72
- f"'{name}' is an invalid keyword argument for {func_name}()"
73
- )
74
- init_kwargs[name] = value
75
- if name in mandatory:
76
- mandatory.remove(name)
77
- if len(mandatory) > 0:
78
- raise TypeError(
79
- f"{func_name}() missing required arguments: {', '.join(mandatory)}"
80
- )
81
- Preprocessor.__init__(self, fn=func_name, apply_on_array=False, **init_kwargs)
52
+ setattr(self, name, value)
53
+ self.kwargs = kwargs
82
54
 
83
55
  init_method.__signature__ = inspect.signature(func)
84
56
  return init_method
85
57
 
86
58
 
87
- def _generate_repr_method(class_name):
88
- def repr_method(self):
89
- args_str = ", ".join(f"{k}={v.__repr__()}" for k, v in self.kwargs.items())
90
- return f"{class_name}({args_str})"
91
-
92
- return repr_method
93
-
94
-
95
59
  def _generate_mne_pre_processor(function):
96
60
  """
97
61
  Generate a class based on an MNE function for preprocessing.
@@ -105,14 +69,10 @@ def _generate_mne_pre_processor(function):
105
69
  class_name = "".join(word.title() for word in function.__name__.split("_")).replace(
106
70
  "Eeg", "EEG"
107
71
  )
108
-
109
- # Automatically determine if function is standalone
110
- is_standalone = _is_standalone_function(function)
111
-
112
72
  # Create a wrapper note that references the original MNE function
113
73
  # For Raw methods, use mne.io.Raw.method_name format with :meth:
114
74
  # For standalone functions, use the function name only with :func:
115
- if not is_standalone:
75
+ if hasattr(mne.io.Raw, function.__name__):
116
76
  ref_path = f"mne.io.Raw.{function.__name__}"
117
77
  ref_role = "meth"
118
78
  else:
@@ -136,10 +96,6 @@ def _generate_mne_pre_processor(function):
136
96
 
137
97
  base_classes = (Preprocessor,)
138
98
 
139
- # Check if function has a 'copy' parameter
140
- sig = inspect.signature(function)
141
- has_copy_param = "copy" in sig.parameters
142
- force_copy_false = is_standalone and has_copy_param
143
99
  # Automatically determine if function is standalone
144
100
  is_standalone = _is_standalone_function(function)
145
101
 
@@ -147,13 +103,26 @@ def _generate_mne_pre_processor(function):
147
103
  sig = inspect.signature(function)
148
104
  has_copy_param = "copy" in sig.parameters
149
105
  force_copy_false = is_standalone and has_copy_param
150
- class_attrs = {
151
- "__init__": _generate_init_method(function, force_copy_false),
152
- "__doc__": wrapper_note + (function.__doc__ or ""),
153
- "__repr__": _generate_repr_method(class_name),
154
- "fn": function if is_standalone else function.__name__,
155
- "_is_standalone": is_standalone,
156
- }
106
+
107
+ if is_standalone:
108
+ # For standalone functions, store the actual function object
109
+ class_attrs = {
110
+ "__init__": _generate_init_method(
111
+ function, force_copy_false=force_copy_false
112
+ ),
113
+ "__doc__": wrapper_note + (function.__doc__ or ""),
114
+ "fn": function, # Store the function itself, not the name
115
+ "_is_standalone": True,
116
+ }
117
+ else:
118
+ # For methods, store the function name as before
119
+ class_attrs = {
120
+ "__init__": _generate_init_method(function),
121
+ "__doc__": wrapper_note + (function.__doc__ or ""),
122
+ "fn": function.__name__,
123
+ "_is_standalone": False,
124
+ }
125
+
157
126
  generated_class = type(class_name, base_classes, class_attrs)
158
127
 
159
128
  return generated_class
@@ -13,9 +13,7 @@ from __future__ import annotations
13
13
  import platform
14
14
  import sys
15
15
  from collections.abc import Iterable
16
- from functools import cached_property, partial
17
- from importlib import import_module
18
- from inspect import signature
16
+ from functools import partial
19
17
  from warnings import warn
20
18
 
21
19
  if sys.version_info < (3, 9):
@@ -34,7 +32,6 @@ from braindecode.datasets.base import (
34
32
  BaseConcatDataset,
35
33
  EEGWindowsDataset,
36
34
  RawDataset,
37
- RecordDataset,
38
35
  WindowsDataset,
39
36
  )
40
37
  from braindecode.datautil.serialization import (
@@ -72,30 +69,7 @@ class Preprocessor(object):
72
69
  def __init__(self, fn: Callable | str, *, apply_on_array: bool = True, **kwargs):
73
70
  if hasattr(fn, "__name__") and fn.__name__ == "<lambda>":
74
71
  warn("Preprocessing choices with lambda functions cannot be saved.")
75
- if apply_on_array and not callable(fn):
76
- warn(
77
- "apply_on_array can only be True if fn is a callable function. "
78
- "Automatically correcting to apply_on_array=False."
79
- )
80
- apply_on_array = False
81
- # We store the exact input parameters. Simpler for serialization.
82
- self.fn = fn
83
- self.apply_on_array = apply_on_array
84
- self.kwargs = kwargs
85
-
86
- @property
87
- def _all_attrs(self):
88
- return ["fn", "apply_on_array", "kwargs"]
89
-
90
- @property
91
- def _init_attrs(self):
92
- return [k for k in self._all_attrs if k in signature(self.__init__).parameters]
93
-
94
- @cached_property
95
- def _function(self):
96
- kwargs = dict(self.kwargs)
97
- fn = self.fn
98
- if self.apply_on_array:
72
+ if callable(fn) and apply_on_array:
99
73
  channel_wise = kwargs.pop("channel_wise", False)
100
74
  picks = kwargs.pop("picks", None)
101
75
  n_jobs = kwargs.pop("n_jobs", 1)
@@ -106,21 +80,12 @@ class Preprocessor(object):
106
80
  n_jobs=n_jobs,
107
81
  )
108
82
  fn = "apply_function"
109
-
110
- if callable(fn):
111
- return partial(fn, **kwargs)
112
- return partial(self._apply_str, fn=fn, **kwargs)
113
-
114
- @staticmethod
115
- def _apply_str(raw_or_epochs: BaseRaw | BaseEpochs, fn: str, **kwargs):
116
- if not hasattr(raw_or_epochs, fn):
117
- raise AttributeError(f"MNE object does not have a {fn} method.")
118
- return getattr(raw_or_epochs, fn)(**kwargs)
83
+ self.fn = fn
84
+ self.kwargs = kwargs
119
85
 
120
86
  def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
121
- function = self._function
122
87
  try:
123
- result = function(raw_or_epochs)
88
+ return self._try_apply(raw_or_epochs)
124
89
  except RuntimeError:
125
90
  # Maybe the function needs the data to be loaded and the data was
126
91
  # not loaded yet. Not all MNE functions need data to be loaded,
@@ -128,83 +93,20 @@ class Preprocessor(object):
128
93
  # without preloading data which can make the overall preprocessing
129
94
  # pipeline substantially faster.
130
95
  raw_or_epochs.load_data()
131
- result = function(raw_or_epochs)
132
- if result is not None:
133
- return result
134
- return raw_or_epochs
135
-
136
- def serialize(self):
137
- """Return a serializable representation of the Preprocessor.
138
-
139
- Returns
140
- -------
141
- dict
142
- Dictionary with keys 'fn' and 'kwargs' representing the
143
- Preprocessor.
144
- """
145
- out = {k: getattr(self, k) for k in self._init_attrs}
146
- if "fn" in out and callable(self.fn):
147
- out["fn"] = self.fn.__module__ + "." + self.fn.__name__
148
- out["__class_path__"] = (
149
- self.__class__.__module__ + "." + self.__class__.__name__
150
- )
151
- if "kwargs" not in out and self.kwargs:
152
- out["kwargs"] = self.kwargs
153
- return out
154
-
155
- @classmethod
156
- def deserialize(cls_parent, data: dict):
157
- """Create a Preprocessor from its serializable representation.
158
-
159
- Parameters
160
- ----------
161
- data : dict
162
- Dictionary with keys 'fn' and 'kwargs' representing the
163
- Preprocessor.
164
- Returns
165
- -------
166
- Preprocessor
167
- The deserialized Preprocessor object.
168
- """
169
- class_path = data.pop("__class_path__")
170
- cls_name = class_path.split(".")[-1]
171
- cls_module_name = ".".join(class_path.split(".")[:-1])
172
- cls_module = import_module(cls_module_name)
173
- cls = getattr(cls_module, cls_name)
174
-
175
- kwargs = data.pop("kwargs") if "kwargs" in data else {}
176
-
177
- fn = data.get("fn", None)
178
- if fn is not None and "." in fn: # callable function
179
- fn_name = fn.split(".")[-1]
180
- module_name = ".".join(fn.split(".")[:-1])
181
- module = import_module(module_name)
182
- data["fn"] = getattr(module, fn_name)
183
-
184
- return cls(**data, **kwargs)
185
-
186
- def __repr__(self):
187
- cls_name = self.__class__.__name__
188
- args_str = ", ".join(
189
- f"{k}={getattr(self, k).__repr__()}" for k in self._init_attrs
190
- )
191
- return f"{cls_name}({args_str})"
192
-
193
- def _same_attr(self, other, attr):
194
- a = getattr(self, attr)
195
- b = getattr(other, attr)
196
- if attr == "fn" and callable(a):
197
- return a.__module__ == b.__module__ and a.__name__ == b.__name__
198
- if isinstance(a, np.ndarray):
199
- return np.array_equal(a, b)
200
- return a == b
201
-
202
- def __eq__(self, other):
203
- if not isinstance(other, Preprocessor):
204
- return False
205
- return all(self._same_attr(other, attr) for attr in self._all_attrs) and (
206
- self.__class__ == other.__class__
207
- )
96
+ return self._try_apply(raw_or_epochs)
97
+
98
+ def _try_apply(self, raw_or_epochs):
99
+ if callable(self.fn):
100
+ result = self.fn(raw_or_epochs, **self.kwargs)
101
+ # For standalone functions that return a new object, propagate it back
102
+ if result is not None and result is not raw_or_epochs:
103
+ return result
104
+ return raw_or_epochs
105
+ else:
106
+ if not hasattr(raw_or_epochs, self.fn):
107
+ raise AttributeError(f"MNE object does not have a {self.fn} method.")
108
+ getattr(raw_or_epochs, self.fn)(**self.kwargs)
109
+ return raw_or_epochs
208
110
 
209
111
 
210
112
  def preprocess(
@@ -326,12 +228,7 @@ def _replace_inplace(concat_ds, new_concat_ds):
326
228
 
327
229
 
328
230
  def _preprocess(
329
- ds: RecordDataset,
330
- ds_index,
331
- preprocessors,
332
- save_dir=None,
333
- overwrite=False,
334
- copy_data=False,
231
+ ds, ds_index, preprocessors, save_dir=None, overwrite=False, copy_data=False
335
232
  ):
336
233
  """Apply preprocessor(s) to Raw or Epochs object.
337
234
 
@@ -390,6 +287,25 @@ def _preprocess(
390
287
  return ds
391
288
 
392
289
 
290
+ def _get_preproc_kwargs(preprocessors):
291
+ preproc_kwargs = []
292
+ for p in preprocessors:
293
+ # in case of a mne function, fn is a str, kwargs is a dict
294
+ func_name = p.fn
295
+ func_kwargs = p.kwargs
296
+ # in case of another function
297
+ # if apply_on_array=False
298
+ if callable(p.fn):
299
+ func_name = p.fn.__name__
300
+ # if apply_on_array=True
301
+ else:
302
+ if "fun" in p.fn:
303
+ func_name = p.kwargs["fun"].func.__name__
304
+ func_kwargs = p.kwargs["fun"].keywords
305
+ preproc_kwargs.append((func_name, func_kwargs))
306
+ return preproc_kwargs
307
+
308
+
393
309
  def _set_preproc_kwargs(ds, preprocessors):
394
310
  """Record preprocessing keyword arguments in RecordDataset.
395
311
 
@@ -400,7 +316,7 @@ def _set_preproc_kwargs(ds, preprocessors):
400
316
  preprocessors : list
401
317
  List of preprocessors.
402
318
  """
403
- preproc_kwargs = [p.serialize() for p in preprocessors]
319
+ preproc_kwargs = _get_preproc_kwargs(preprocessors)
404
320
  if isinstance(ds, WindowsDataset):
405
321
  kind = "window"
406
322
  elif isinstance(ds, EEGWindowsDataset):
@@ -409,8 +325,7 @@ def _set_preproc_kwargs(ds, preprocessors):
409
325
  kind = "raw"
410
326
  else:
411
327
  raise TypeError(f"ds must be a RecordDataset, got {type(ds)}")
412
- old_preproc_kwargs = getattr(ds, kind + "_preproc_kwargs")
413
- old_preproc_kwargs.extend(preproc_kwargs)
328
+ setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)
414
329
 
415
330
 
416
331
  def exponential_moving_standardize(
@@ -5,7 +5,6 @@
5
5
  # License: BSD-3
6
6
 
7
7
  import base64
8
- import inspect
9
8
  import json
10
9
  import re
11
10
  from typing import Any
@@ -13,8 +12,6 @@ from typing import Any
13
12
  import numpy as np
14
13
  from mne.io.base import BaseRaw
15
14
 
16
- from braindecode import preprocessing
17
-
18
15
  __all__ = ["mne_store_metadata", "mne_load_metadata"]
19
16
 
20
17
 
@@ -26,14 +23,6 @@ _MARKER_END = "-->"
26
23
  # Marker key for numpy arrays
27
24
  _NP_ARRAY_TAG = "__numpy_array__"
28
25
 
29
- preprocessor_dict = {}
30
-
31
-
32
- def _init_preprocessor_dict():
33
- for m in inspect.getmembers(preprocessing, inspect.isclass):
34
- if issubclass(m[1], preprocessing.Preprocessor):
35
- preprocessor_dict[m[0]] = m[1]
36
-
37
26
 
38
27
  def _numpy_decoder(dct):
39
28
  """Internal JSON decoder hook to handle numpy arrays."""
braindecode/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.3.0.dev177628147"
1
+ __version__ = "1.3.0.dev182330353"
@@ -1,8 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.3.0.dev177628147
3
+ Version: 1.3.0.dev182330353
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
- Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
5
+ Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
7
7
  License: BSD-3-Clause
8
8
  Project-URL: homepage, https://braindecode.org
@@ -38,14 +38,12 @@ Requires-Dist: wfdb
38
38
  Requires-Dist: h5py
39
39
  Requires-Dist: linear_attention_transformer
40
40
  Requires-Dist: docstring_inheritance
41
- Requires-Dist: rotary_embedding_torch
42
41
  Provides-Extra: moabb
43
42
  Requires-Dist: moabb>=1.2.0; extra == "moabb"
44
43
  Provides-Extra: eegprep
45
44
  Requires-Dist: eegprep[eeglabio]>=0.1.1; extra == "eegprep"
46
- Provides-Extra: hub
47
- Requires-Dist: huggingface_hub[torch]>=0.20.0; extra == "hub"
48
- Requires-Dist: zarr<3.0,>=2.18; extra == "hub"
45
+ Provides-Extra: hug
46
+ Requires-Dist: huggingface_hub[torch]>=0.20.0; extra == "hug"
49
47
  Provides-Extra: tests
50
48
  Requires-Dist: pytest; extra == "tests"
51
49
  Requires-Dist: pytest-cov; extra == "tests"
@@ -71,11 +69,7 @@ Requires-Dist: pre-commit; extra == "docs"
71
69
  Requires-Dist: openneuro-py; extra == "docs"
72
70
  Requires-Dist: plotly; extra == "docs"
73
71
  Provides-Extra: all
74
- Requires-Dist: braindecode[moabb]; extra == "all"
75
- Requires-Dist: braindecode[hub]; extra == "all"
76
- Requires-Dist: braindecode[tests]; extra == "all"
77
- Requires-Dist: braindecode[docs]; extra == "all"
78
- Requires-Dist: braindecode[eegprep]; extra == "all"
72
+ Requires-Dist: braindecode[docs,eegprep,hug,moabb,tests]; extra == "all"
79
73
  Dynamic: license-file
80
74
 
81
75
  .. image:: https://badges.gitter.im/braindecodechat/community.svg