pyprep 0.5.0__tar.gz → 0.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,6 @@
1
+ .claude
2
+ CLAUDE.md
3
+
1
4
  .vscode
2
5
  .idea/*
3
6
  /idea
@@ -55,6 +55,10 @@ authors:
55
55
  family-names: Veillette
56
56
  affiliation: 'Department of Psychology, University of Chicago, Chicago, IL, USA'
57
57
  orcid: 'https://orcid.org/0000-0002-0332-4372'
58
+ - given-names: Roy Eric
59
+ family-names: Wieske
60
+ affiliation: 'Biopsychology and Neuroergonomics, Technische Universität Berlin, Berlin, Germany'
61
+ orcid: 'https://orcid.org/0009-0006-2018-1074'
58
62
  type: software
59
63
  repository-code: 'https://github.com/sappelhoff/pyprep'
60
64
  license: MIT
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyprep
3
- Version: 0.5.0
3
+ Version: 0.7.0
4
4
  Summary: PyPREP: A Python implementation of the preprocessing pipeline (PREP) for EEG data.
5
5
  Project-URL: Bug Tracker, https://github.com/sappelhoff/pyprep/issues/
6
6
  Project-URL: Documentation, https://pyprep.readthedocs.io/en/latest
@@ -38,12 +38,13 @@ Classifier: Operating System :: MacOS
38
38
  Classifier: Operating System :: Microsoft :: Windows
39
39
  Classifier: Operating System :: POSIX :: Linux
40
40
  Classifier: Programming Language :: Python
41
- Classifier: Programming Language :: Python :: 3.9
42
41
  Classifier: Programming Language :: Python :: 3.10
43
42
  Classifier: Programming Language :: Python :: 3.11
44
43
  Classifier: Programming Language :: Python :: 3.12
44
+ Classifier: Programming Language :: Python :: 3.13
45
+ Classifier: Programming Language :: Python :: 3.14
45
46
  Classifier: Topic :: Scientific/Engineering
46
- Requires-Python: >=3.9
47
+ Requires-Python: >=3.10
47
48
  Requires-Dist: mne>=1.3.0
48
49
  Requires-Dist: numpy>=1.20.2
49
50
  Requires-Dist: psutil>=5.4.3
@@ -134,7 +135,7 @@ for EEG data, working with `MNE-Python <https://mne.tools>`_.
134
135
  Installation
135
136
  ============
136
137
 
137
- ``pyprep`` runs on Python version 3.9 or higher.
138
+ ``pyprep`` runs on Python version 3.10 or higher.
138
139
 
139
140
  We recommend to run ``pyprep`` in a dedicated virtual environment
140
141
  (for example using `conda <https://docs.conda.io/en/latest/miniconda.html>`_).
@@ -48,7 +48,7 @@ for EEG data, working with `MNE-Python <https://mne.tools>`_.
48
48
  Installation
49
49
  ============
50
50
 
51
- ``pyprep`` runs on Python version 3.9 or higher.
51
+ ``pyprep`` runs on Python version 3.10 or higher.
52
52
 
53
53
  We recommend to run ``pyprep`` in a dedicated virtual environment
54
54
  (for example using `conda <https://docs.conda.io/en/latest/miniconda.html>`_).
@@ -29,18 +29,19 @@ class NoisyChannels:
29
29
  Parameters
30
30
  ----------
31
31
  raw : mne.io.Raw
32
- An MNE Raw object to check for bad EEG channels.
33
- do_detrend : bool, optional
32
+ An MNE Raw object to check for bad EEG channels. Channels set to bad
33
+ in ``raw.info["bads"]`` will not be used to find additional bad channels.
34
+ do_detrend : bool
34
35
  Whether or not low-frequency (<1.0 Hz) trends should be removed from the
35
36
  EEG signal prior to bad channel detection. This should always be set to
36
37
  ``True`` unless the signal has already had low-frequency trends removed.
37
38
  Defaults to ``True``.
38
- random_state : {int, None, np.random.RandomState}, optional
39
+ random_state : {int, None, np.random.RandomState} | None
39
40
  The seed to use for random number generation within RANSAC. This can be
40
41
  ``None``, an integer, or a :class:`~numpy.random.RandomState` object.
41
42
  If ``None``, a random seed will be obtained from the operating system.
42
43
  Defaults to ``None``.
43
- matlab_strict : bool, optional
44
+ matlab_strict : bool
44
45
  Whether or not PyPREP should strictly follow MATLAB PREP's internal
45
46
  math, ignoring any improvements made in PyPREP over the original code
46
47
  (see :ref:`matlab-diffs` for more details). Defaults to ``False``.
@@ -49,6 +50,21 @@ class NoisyChannels:
49
50
  to other methods. RANSAC can detect bad channels that other
50
51
  methods are unable to catch, but also slows down noisy channel
51
52
  detection considerably. Defaults to ``True``.
53
+ correlation : bool
54
+ Whether correlation should be used for bad channel detection, in addition
55
+ to other methods. Defaults to ``True``.
56
+ bad_by_manual : list of str | None
57
+ List of channels that are bad. These channels will be excluded when
58
+ trying to find additional bad channels. Note that the union of these channels
59
+ and those declared in ``raw.info["bads"]`` will be used. Defaults to ``None``.
60
+ reject_by_annotation : {None, 'omit'} | None
61
+ How to handle BAD-annotated time segments (annotations starting with
62
+ "BAD" or "bad") during channel quality assessment. If ``'omit'``,
63
+ annotated segments are excluded from analysis (clean segments are
64
+ concatenated). If ``None`` (default), annotations are ignored and the
65
+ full recording is used. This is useful when recordings contain breaks
66
+ or movement artifacts that shouldn't influence channel rejection
67
+ decisions.
52
68
 
53
69
  References
54
70
  ----------
@@ -66,13 +82,17 @@ class NoisyChannels:
66
82
  matlab_strict=False,
67
83
  *,
68
84
  ransac=True,
85
+ correlation=True,
86
+ bad_by_manual=None,
87
+ reject_by_annotation=None,
69
88
  ):
70
89
  # Make sure that we got an MNE object
71
90
  assert isinstance(raw, mne.io.BaseRaw)
72
91
 
73
92
  raw.load_data()
74
93
  self.raw_mne = raw.copy()
75
- self.bad_by_manual = raw.info["bads"]
94
+ bad_by_manual = bad_by_manual if bad_by_manual else []
95
+ self.bad_by_manual = list(set(bad_by_manual + raw.info["bads"]))
76
96
  self.raw_mne.pick("eeg") # excludes bads
77
97
  self.sample_rate = raw.info["sfreq"]
78
98
  if do_detrend:
@@ -81,15 +101,59 @@ class NoisyChannels:
81
101
  )
82
102
  self.matlab_strict = matlab_strict
83
103
 
84
- assert isinstance(ransac, bool), f"ransac must be boolean, got: {ransac}"
104
+ msg = f"ransac must be boolean, got: {ransac}"
105
+ assert isinstance(ransac, bool), msg
85
106
  self.ransac = ransac
86
107
 
108
+ msg = f"correlation must be boolean, got: {correlation}"
109
+ assert isinstance(correlation, bool), msg
110
+ self.correlation = correlation
111
+
112
+ # Validate reject_by_annotation parameter
113
+ if reject_by_annotation is not None and reject_by_annotation != "omit":
114
+ raise ValueError(
115
+ f"reject_by_annotation must be None or 'omit', "
116
+ f"got: {reject_by_annotation}"
117
+ )
118
+ # reject_by_annotation is not available in MATLAB PREP
119
+ if matlab_strict and reject_by_annotation is not None:
120
+ logger.warning(
121
+ "reject_by_annotation is not available in MATLAB PREP. "
122
+ f"Setting reject_by_annotation to None (was '{reject_by_annotation}')."
123
+ )
124
+ reject_by_annotation = None
125
+ self.reject_by_annotation = reject_by_annotation
126
+
127
+ # Warn if many small BAD segments are present (potential edge effects)
128
+ if reject_by_annotation is not None:
129
+ bad_annots = [
130
+ a
131
+ for a in raw.annotations
132
+ if a["description"].startswith(("BAD", "bad"))
133
+ ]
134
+ n_bad_segments = len(bad_annots)
135
+ if n_bad_segments > 0:
136
+ total_bad_time = sum(a["duration"] for a in bad_annots)
137
+ recording_length = raw.times[-1]
138
+ bad_percentage = (total_bad_time / recording_length) * 100
139
+ mean_duration = total_bad_time / n_bad_segments
140
+ if bad_percentage > 15 and mean_duration < 5.0:
141
+ logger.warning(
142
+ f"Found {n_bad_segments} BAD segments covering "
143
+ f"{bad_percentage:.1f}% of the recording with mean duration "
144
+ f"{mean_duration:.1f}s. Using reject_by_annotation with many "
145
+ "short segments may introduce edge effects from concatenation. "
146
+ "This feature is intended for excluding a small number of "
147
+ "longer segments (e.g., recording breaks)."
148
+ )
149
+
87
150
  # Extra data for debugging
88
151
  self._extra_info = {
89
152
  "bad_by_deviation": {},
90
153
  "bad_by_hf_noise": {},
91
154
  "bad_by_correlation": {},
92
155
  "bad_by_dropout": {},
156
+ "bad_by_psd": {},
93
157
  "bad_by_ransac": {},
94
158
  }
95
159
 
@@ -104,21 +168,29 @@ class NoisyChannels:
104
168
  self.bad_by_correlation = []
105
169
  self.bad_by_SNR = []
106
170
  self.bad_by_dropout = []
171
+ self.bad_by_psd = []
107
172
  self.bad_by_ransac = []
108
173
 
109
174
  # Get original EEG channel names, channel count & samples
110
175
  ch_names = np.asarray(self.raw_mne.info["ch_names"])
111
176
  self.ch_names_original = ch_names
112
177
  self.n_chans_original = len(ch_names)
113
- self.n_samples = raw.get_data().shape[1]
178
+ self.n_samples_original = raw.n_times
114
179
 
115
180
  # Before anything else, flag bad-by-NaNs and bad-by-flats
116
181
  self.find_bad_by_nan_flat()
117
182
  bads_by_nan_flat = self.bad_by_nan + self.bad_by_flat
118
183
 
184
+ # unusable channels are also those manually marked as bad
185
+ bads_unusable = self.bad_by_manual + bads_by_nan_flat
186
+
119
187
  # Make a subset of the data containing only usable EEG channels
120
- self.usable_idx = np.isin(ch_names, bads_by_nan_flat, invert=True)
121
- self.EEGData = self.raw_mne.get_data(picks=ch_names[self.usable_idx])
188
+ self.usable_idx = np.isin(ch_names, bads_unusable, invert=True)
189
+ self.EEGData = self.raw_mne.get_data(
190
+ picks=ch_names[self.usable_idx],
191
+ reject_by_annotation=self.reject_by_annotation,
192
+ )
193
+ self.n_samples = self.EEGData.shape[1]
122
194
  self.EEGFiltered = None
123
195
 
124
196
  # Get usable EEG channel names & channel counts
@@ -154,10 +226,10 @@ class NoisyChannels:
154
226
 
155
227
  Parameters
156
228
  ----------
157
- verbose : bool, optional
229
+ verbose : bool | None
158
230
  If ``True``, a summary of the channels currently flagged as by bad per
159
231
  category is printed. Defaults to ``False``.
160
- as_dict: bool, optional
232
+ as_dict: bool | None
161
233
  If ``True``, this method will return a dict of the channels currently
162
234
  flagged as bad by each individual bad channel type. If ``False``, this
163
235
  method will return a list of all unique bad channels detected so far.
@@ -178,6 +250,7 @@ class NoisyChannels:
178
250
  "bad_by_correlation": self.bad_by_correlation,
179
251
  "bad_by_SNR": self.bad_by_SNR,
180
252
  "bad_by_dropout": self.bad_by_dropout,
253
+ "bad_by_psd": self.bad_by_psd,
181
254
  "bad_by_ransac": self.bad_by_ransac,
182
255
  "bad_by_manual": self.bad_by_manual,
183
256
  }
@@ -186,7 +259,12 @@ class NoisyChannels:
186
259
  for bad_chs in bads.values():
187
260
  all_bads.update(bad_chs)
188
261
 
189
- name_map = {"nan": "NaN", "hf_noise": "HF noise", "ransac": "RANSAC"}
262
+ name_map = {
263
+ "nan": "NaN",
264
+ "hf_noise": "HF noise",
265
+ "psd": "PSD",
266
+ "ransac": "RANSAC",
267
+ }
190
268
  if verbose:
191
269
  out = f"Found {len(all_bads)} uniquely bad channels:\n"
192
270
  for bad_type, bad_chs in bads.items():
@@ -203,7 +281,15 @@ class NoisyChannels:
203
281
 
204
282
  return bads
205
283
 
206
- def find_all_bads(self, ransac=None, channel_wise=False, max_chunk_size=None):
284
+ def find_all_bads(
285
+ self,
286
+ *,
287
+ ransac=None,
288
+ channel_wise=False,
289
+ max_chunk_size=None,
290
+ correlation=None,
291
+ reject_by_annotation=None,
292
+ ):
207
293
  """Call all the functions to detect bad channels.
208
294
 
209
295
  This function calls all the bad-channel detecting functions.
@@ -217,7 +303,7 @@ class NoisyChannels:
217
303
  detection considerably. If ``None`` (default), then the value at
218
304
  instantiation of the ``NoisyChannels`` class is taken (defaults
219
305
  to ``True``), else the instantiation value is overwritten.
220
- channel_wise : bool, optional
306
+ channel_wise : bool | None
221
307
  Whether RANSAC should predict signals for chunks of channels over the
222
308
  entire signal length ("channel-wise RANSAC", see `max_chunk_size`
223
309
  parameter). If ``False``, RANSAC will instead predict signals for all
@@ -227,28 +313,57 @@ class NoisyChannels:
227
313
  (especially if `max_chunk_size` is ``None``), but can be faster on
228
314
  systems with lots of RAM to spare. Has no effect if not using RANSAC.
229
315
  Defaults to ``False``.
230
- max_chunk_size : {int, None}, optional
316
+ max_chunk_size : {int, None} | None
231
317
  The maximum number of channels to predict at once during
232
318
  channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk
233
319
  size that will fit into the available RAM, which may slow down
234
320
  other programs on the host system. If using window-wise RANSAC
235
321
  (the default) or not using RANSAC at all, this parameter has no
236
322
  effect. Defaults to ``None``.
323
+ correlation : bool | None
324
+ Whether correlation should be used for bad channel detection, in addition
325
+ to the other methods. If ``None`` (default), then the value at
326
+ instantiation of the ``NoisyChannels`` class is taken (defaults
327
+ to ``True``), else the instantiation value is overwritten.
328
+ reject_by_annotation : {None, 'omit'} | None
329
+ This parameter is accepted for compatibility but is ignored here.
330
+ Annotation rejection is applied during ``NoisyChannels`` initialization,
331
+ not during ``find_all_bads``. To use annotation rejection, pass
332
+ ``reject_by_annotation`` to the ``NoisyChannels`` constructor.
237
333
 
238
334
  """
335
+ # Note: reject_by_annotation is accepted but ignored here - it's applied
336
+ # during __init__ when data is extracted. This parameter exists only for
337
+ # compatibility with ransac_settings dict unpacking.
338
+ del reject_by_annotation # unused, applied in __init__
239
339
  if ransac is not None and ransac != self.ransac:
240
- assert isinstance(ransac, bool), f"ransac must be boolean, got: {ransac}"
340
+ msg = f"ransac must be boolean, got: {ransac}"
341
+ assert isinstance(ransac, bool), msg
241
342
  logger.warning(
242
- f"Overwriting `ransac` value. Was `{self.ransac}` at instantiation "
343
+ "Overwriting `ransac` value. "
344
+ f"Was `{self.ransac}` at instantiation "
243
345
  f"of NoisyChannels. Now setting to `{ransac}`."
244
346
  )
245
347
  self.ransac = ransac
246
348
 
349
+ if correlation is not None and correlation != self.correlation:
350
+ msg = f"correlation must be boolean, got: {correlation}"
351
+ assert isinstance(correlation, bool), msg
352
+ logger.warning(
353
+ "Overwriting `correlation` value. "
354
+ f"Was `{self.correlation}` at instantiation "
355
+ f"of NoisyChannels. Now setting to `{correlation}`."
356
+ )
357
+ self.correlation = correlation
358
+
247
359
  # NOTE: Bad-by-NaN/flat is already run during init, no need to re-run here
248
360
  self.find_bad_by_deviation()
249
361
  self.find_bad_by_hfnoise()
250
- self.find_bad_by_correlation()
362
+ if self.correlation:
363
+ self.find_bad_by_correlation()
251
364
  self.find_bad_by_SNR()
365
+ if not self.matlab_strict:
366
+ self.find_bad_by_PSD()
252
367
  if self.ransac:
253
368
  self.find_bad_by_ransac(
254
369
  channel_wise=channel_wise, max_chunk_size=max_chunk_size
@@ -263,17 +378,19 @@ class NoisyChannels:
263
378
 
264
379
  This method is run automatically when a ``NoisyChannels`` object is
265
380
  initialized, preventing flat or NaN-containing channels from interfering
266
- with the detection of other types of bad channels.
381
+ with the detection of other types of bad channels. The
382
+ ``reject_by_annotation`` setting of the :class:`NoisyChannels` instance
383
+ is respected when retrieving the data.
267
384
 
268
385
  Parameters
269
386
  ----------
270
- flat_threshold : float, optional
387
+ flat_threshold : float | None
271
388
  The lowest standard deviation or MAD value for a channel to be
272
389
  considered bad-by-flat. Defaults to ``1e-15`` volts (corresponds to
273
390
  10e-10 µV in MATLAB PREP).
274
391
  """
275
392
  # Get all EEG channels from original copy of data
276
- EEGData = self.raw_mne.get_data()
393
+ EEGData = self.raw_mne.get_data(reject_by_annotation=self.reject_by_annotation)
277
394
 
278
395
  # Detect channels containing any NaN values
279
396
  nan_channel_mask = np.isnan(np.sum(EEGData, axis=1))
@@ -304,7 +421,7 @@ class NoisyChannels:
304
421
 
305
422
  Parameters
306
423
  ----------
307
- deviation_threshold : float, optional
424
+ deviation_threshold : float | None
308
425
  The minimum absolute z-score of a channel for it to be considered
309
426
  bad-by-deviation. Defaults to ``5.0``.
310
427
 
@@ -350,7 +467,7 @@ class NoisyChannels:
350
467
 
351
468
  Parameters
352
469
  ----------
353
- HF_zscore_threshold : float, optional
470
+ HF_zscore_threshold : float | None
354
471
  The minimum noisiness z-score of a channel for it to be considered
355
472
  bad-by-high-frequency-noise. Defaults to ``5.0``.
356
473
 
@@ -415,12 +532,12 @@ class NoisyChannels:
415
532
 
416
533
  Parameters
417
534
  ----------
418
- correlation_secs : float, optional
535
+ correlation_secs : float | None
419
536
  The length (in seconds) of each correlation window. Defaults to ``1.0``.
420
- correlation_threshold : float, optional
537
+ correlation_threshold : float | None
421
538
  The lowest maximum inter-channel correlation for a channel to be
422
539
  considered "bad" within a given window. Defaults to ``0.4``.
423
- frac_bad : float, optional
540
+ frac_bad : float | None
424
541
  The minimum proportion of bad windows for a channel to be considered
425
542
  "bad-by-correlation" or "bad-by-dropout". Defaults to ``0.01`` (1% of
426
543
  all windows).
@@ -509,7 +626,7 @@ class NoisyChannels:
509
626
  # Get names of bad-by-HF-noise and bad-by-correlation channels
510
627
  if not len(self._extra_info["bad_by_hf_noise"]) > 1:
511
628
  self.find_bad_by_hfnoise()
512
- if not len(self._extra_info["bad_by_correlation"]):
629
+ if not len(self._extra_info["bad_by_correlation"]) and self.correlation:
513
630
  self.find_bad_by_correlation()
514
631
  bad_by_hf = set(self.bad_by_hf_noise)
515
632
  bad_by_corr = set(self.bad_by_correlation)
@@ -517,6 +634,151 @@ class NoisyChannels:
517
634
  # Flag channels bad by both HF noise and low correlation as bad by low SNR
518
635
  self.bad_by_SNR = list(bad_by_corr.intersection(bad_by_hf))
519
636
 
637
+ def find_bad_by_PSD(self, zscore_threshold=3.0, fmin=1.0, fmax=45.0):
638
+ """Detect channels with abnormally high or low power spectral density.
639
+
640
+ This is a PyPREP-only method not present in the original MATLAB PREP.
641
+
642
+ A channel is considered "bad-by-psd" if:
643
+
644
+ 1. Its power in any frequency band (low: 1-15 Hz, mid: 15-30 Hz,
645
+ high: 30-45 Hz) is abnormally HIGH compared to other channels, OR
646
+ 2. Its high-frequency band has more power than its low-frequency band
647
+ (violating the typical 1/f spectral profile of EEG).
648
+
649
+ Note: Only excess power (positive z-scores) is flagged, as abnormally
650
+ low power could reflect normal topographic variation.
651
+
652
+ PSD is computed using Welch's method over the specified frequency range.
653
+ The default range (1-45 Hz) excludes line noise frequencies (50/60 Hz).
654
+
655
+ Parameters
656
+ ----------
657
+ zscore_threshold : float, optional
658
+ The minimum absolute z-score of a channel for it to be considered
659
+ bad-by-psd. Defaults to ``3.0``.
660
+ fmin : float, optional
661
+ The lower frequency bound (in Hz) for PSD computation.
662
+ Defaults to ``1.0``.
663
+ fmax : float, optional
664
+ The upper frequency bound (in Hz) for PSD computation. The default
665
+ of ``45.0`` excludes 50/60 Hz line noise from the analysis.
666
+
667
+ """
668
+ MAD_TO_SD = 1.4826 # Scales units of MAD to units of SD, assuming normality
669
+ # Reference: https://stat.ethz.ch/R-manual/R-devel/library/stats/html/mad.html
670
+
671
+ # Define frequency bands (in Hz)
672
+ BAND_LOW = (fmin, 15.0) # ~ delta, theta, alpha
673
+ BAND_MID = (15.0, 30.0) # ~ beta
674
+ BAND_HIGH = (30.0, fmax) # ~ gamma
675
+
676
+ if self.EEGFiltered is None:
677
+ self.EEGFiltered = self._get_filtered_data()
678
+
679
+ # Create a temporary Raw object from filtered data for PSD computation
680
+ info = mne.create_info(
681
+ ch_names=self.ch_names_new.tolist(),
682
+ sfreq=self.sample_rate,
683
+ ch_types="eeg",
684
+ )
685
+ raw_filtered = mne.io.RawArray(self.EEGFiltered, info, verbose=False)
686
+
687
+ # Compute PSD using Welch method and convert to log scale (dB)
688
+ psd = raw_filtered.compute_psd(
689
+ method="welch", fmin=fmin, fmax=fmax, verbose=False
690
+ )
691
+ psd_data = psd.get_data()
692
+ freqs = psd.freqs
693
+ log_psd = 10 * np.log10(psd_data)
694
+
695
+ # Get frequency indices for each band
696
+ idx_low = (freqs >= BAND_LOW[0]) & (freqs < BAND_LOW[1])
697
+ idx_mid = (freqs >= BAND_MID[0]) & (freqs < BAND_MID[1])
698
+ idx_high = (freqs >= BAND_HIGH[0]) & (freqs <= BAND_HIGH[1])
699
+
700
+ # Compute band power (sum of log PSD within each band) for each channel
701
+ band_power_low = np.sum(log_psd[:, idx_low], axis=1)
702
+ band_power_mid = np.sum(log_psd[:, idx_mid], axis=1)
703
+ band_power_high = np.sum(log_psd[:, idx_high], axis=1)
704
+
705
+ def robust_zscore(values):
706
+ """Compute robust z-scores using MAD."""
707
+ median = np.median(values)
708
+ mad = np.median(np.abs(values - median))
709
+ sd = mad * MAD_TO_SD
710
+ if sd > 0:
711
+ return (values - median) / sd
712
+ return np.zeros_like(values)
713
+
714
+ # Criterion 1: Outlier with abnormally HIGH power in any band
715
+ # Note: Only positive z-scores (excess power) are flagged, as low power
716
+ # could reflect normal topographic variation rather than a bad channel
717
+ zscore_low = robust_zscore(band_power_low)
718
+ zscore_mid = robust_zscore(band_power_mid)
719
+ zscore_high = robust_zscore(band_power_high)
720
+
721
+ bad_by_band = (
722
+ (zscore_low > zscore_threshold)
723
+ | (zscore_mid > zscore_threshold)
724
+ | (zscore_high > zscore_threshold)
725
+ )
726
+
727
+ # Criterion 2: 1/f violation (high freq band has more power than low freq band)
728
+ # This is unusual for normal EEG and suggests muscle artifact or bad contact
729
+ bad_by_1f_violation = band_power_high > band_power_low
730
+
731
+ # Criterion 3: Abnormal band ratios compared to other channels
732
+ # Use small epsilon to avoid division by zero
733
+ eps = np.finfo(float).eps
734
+ ratio_low_mid = band_power_low / (band_power_mid + eps)
735
+ ratio_low_high = band_power_low / (band_power_high + eps)
736
+ ratio_mid_high = band_power_mid / (band_power_high + eps)
737
+
738
+ zscore_ratio_low_mid = robust_zscore(ratio_low_mid)
739
+ zscore_ratio_low_high = robust_zscore(ratio_low_high)
740
+ zscore_ratio_mid_high = robust_zscore(ratio_mid_high)
741
+
742
+ bad_by_ratio = (
743
+ (np.abs(zscore_ratio_low_mid) > zscore_threshold)
744
+ | (np.abs(zscore_ratio_low_high) > zscore_threshold)
745
+ | (np.abs(zscore_ratio_mid_high) > zscore_threshold)
746
+ )
747
+
748
+ # Combine criteria (bad if ANY criterion is met)
749
+ # Note: bad_by_ratio is computed for diagnostics but not used in final
750
+ # decision as it tends to be overly sensitive and theoretically debatable
751
+ bad_by_psd_usable = bad_by_band | bad_by_1f_violation
752
+
753
+ # Map back to original channel indices
754
+ psd_channel_mask = np.zeros(self.n_chans_original, dtype=bool)
755
+ psd_channel_mask[self.usable_idx] = bad_by_psd_usable
756
+ abnormal_psd_channels = self.ch_names_original[psd_channel_mask]
757
+
758
+ # Compute combined z-score for reporting (max absolute z-score across bands)
759
+ psd_zscore = np.zeros(self.n_chans_original)
760
+ max_band_zscore = np.maximum(
761
+ np.abs(zscore_low), np.maximum(np.abs(zscore_mid), np.abs(zscore_high))
762
+ )
763
+ psd_zscore[self.usable_idx] = max_band_zscore
764
+
765
+ # Update names of bad channels by abnormal PSD & save additional info
766
+ self.bad_by_psd = abnormal_psd_channels.tolist()
767
+ self._extra_info["bad_by_psd"].update(
768
+ {
769
+ "psd_zscore": psd_zscore,
770
+ "band_power_low": band_power_low,
771
+ "band_power_mid": band_power_mid,
772
+ "band_power_high": band_power_high,
773
+ "zscore_low": zscore_low,
774
+ "zscore_mid": zscore_mid,
775
+ "zscore_high": zscore_high,
776
+ "bad_by_band": bad_by_band,
777
+ "bad_by_1f_violation": bad_by_1f_violation,
778
+ "bad_by_ratio": bad_by_ratio,
779
+ }
780
+ )
781
+
520
782
  def find_bad_by_ransac(
521
783
  self,
522
784
  n_samples=50,
@@ -559,26 +821,26 @@ class NoisyChannels:
559
821
 
560
822
  Parameters
561
823
  ----------
562
- n_samples : int, optional
824
+ n_samples : int | None
563
825
  Number of random channel samples to use for RANSAC. Defaults
564
826
  to ``50``.
565
- sample_prop : float, optional
827
+ sample_prop : float | None
566
828
  Proportion of total channels to use for signal prediction per RANSAC
567
829
  sample. This needs to be in the range [0, 1], where 0 would mean no
568
830
  channels would be used and 1 would mean all channels would be used
569
831
  (neither of which would be useful values). Defaults to ``0.25``
570
832
  (e.g., 16 channels per sample for a 64-channel dataset).
571
- corr_thresh : float, optional
833
+ corr_thresh : float | None
572
834
  The minimum predicted vs. actual signal correlation for a channel to
573
835
  be considered good within a given RANSAC window. Defaults
574
836
  to ``0.75``.
575
- frac_bad : float, optional
837
+ frac_bad : float | None
576
838
  The minimum fraction of bad (i.e., below-threshold) RANSAC windows
577
839
  for a channel to be considered bad-by-RANSAC. Defaults to ``0.4``.
578
- corr_window_secs : float, optional
840
+ corr_window_secs : float | None
579
841
  The duration (in seconds) of each RANSAC correlation window. Defaults
580
842
  to 5 seconds.
581
- channel_wise : bool, optional
843
+ channel_wise : bool | None
582
844
  Whether RANSAC should predict signals for chunks of channels over the
583
845
  entire signal length ("channel-wise RANSAC", see `max_chunk_size`
584
846
  parameter). If ``False``, RANSAC will instead predict signals for all
@@ -587,7 +849,7 @@ class NoisyChannels:
587
849
  RANSAC generally has higher RAM demands than window-wise RANSAC
588
850
  (especially if `max_chunk_size` is ``None``), but can be faster on
589
851
  systems with lots of RAM to spare. Defaults to ``False``.
590
- max_chunk_size : {int, None}, optional
852
+ max_chunk_size : {int, None} | None
591
853
  The maximum number of channels to predict at once during
592
854
  channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk
593
855
  size that will fit into the available RAM, which may slow down
@@ -622,7 +884,9 @@ class NoisyChannels:
622
884
  self.EEGFiltered,
623
885
  self.sample_rate,
624
886
  self.ch_names_new,
625
- self.raw_mne._get_channel_positions()[self.usable_idx, :],
887
+ self.raw_mne._get_channel_positions(self.raw_mne.ch_names)[
888
+ self.usable_idx, :
889
+ ],
626
890
  exclude_from_ransac,
627
891
  n_samples,
628
892
  sample_prop,
@@ -3,12 +3,13 @@
3
3
  # Authors: The PyPREP developers
4
4
  # SPDX-License-Identifier: MIT
5
5
 
6
+ import warnings
7
+
6
8
  import mne
7
9
  from mne.utils import check_random_state
8
10
 
9
11
  from pyprep.reference import Reference
10
12
  from pyprep.removeTrend import removeTrend
11
- from pyprep.utils import _set_diff, _union # noqa: F401
12
13
 
13
14
 
14
15
  class PrepPipeline:
@@ -39,15 +40,15 @@ class PrepPipeline:
39
40
  For example, for 60Hz you may specify
40
41
  ``np.arange(60, sfreq / 2, 60)``. Specify an empty list to
41
42
  skip the line noise removal step.
42
- - max_iterations : int, optional
43
+ - max_iterations : int | None
43
44
  - The maximum number of iterations of noisy channel removal to
44
45
  perform during robust referencing. Defaults to ``4``.
45
46
  montage : mne.channels.DigMontage
46
47
  Digital montage of EEG data.
47
- ransac : bool, optional
48
+ ransac : bool | None
48
49
  Whether or not to use RANSAC for noisy channel detection in addition to
49
50
  the other methods in :class:`~pyprep.NoisyChannels`. Defaults to True.
50
- channel_wise : bool, optional
51
+ channel_wise : bool | None
51
52
  Whether RANSAC should predict signals for chunks of channels over the
52
53
  entire signal length ("channel-wise RANSAC", see `max_chunk_size`
53
54
  parameter). If ``False``, RANSAC will instead predict signals for all
@@ -57,24 +58,32 @@ class PrepPipeline:
57
58
  (especially if `max_chunk_size` is ``None``), but can be faster on
58
59
  systems with lots of RAM to spare. Has no effect if not using RANSAC.
59
60
  Defaults to ``False``.
60
- max_chunk_size : {int, None}, optional
61
+ max_chunk_size : {int, None} | None
61
62
  The maximum number of channels to predict at once during channel-wise
62
63
  RANSAC. If ``None``, RANSAC will use the largest chunk size that will
63
64
  fit into the available RAM, which may slow down other programs on the
64
65
  host system. If using window-wise RANSAC (the default) or not using
65
66
  RANSAC at all, this parameter has no effect. Defaults to ``None``.
66
- random_state : {int, None, np.random.RandomState}, optional
67
+ random_state : {int, None, np.random.RandomState} | None
67
68
  The random seed at which to initialize the class. If random_state is
68
69
  an int, it will be used as a seed for RandomState.
69
70
  If None, the seed will be obtained from the operating system
70
71
  (see RandomState for details). Default is None.
71
- filter_kwargs : {dict, None}, optional
72
+ filter_kwargs : {dict, None} | None
72
73
  Optional keywords arguments to be passed on to mne.filter.notch_filter.
73
74
  Do not set the "x", Fs", and "freqs" arguments via the filter_kwargs
74
75
  parameter, but use the "raw" and "prep_params" parameters instead.
75
76
  If None is passed, the pyprep default settings for filtering are used
76
77
  instead.
77
- matlab_strict : bool, optional
78
+ reject_by_annotation : {None, 'omit'} | None
79
+ How to handle BAD-annotated time segments (annotations starting with
80
+ "BAD" or "bad") during channel quality assessment. If ``'omit'``,
81
+ annotated segments are excluded from analysis (clean segments are
82
+ concatenated). If ``None`` (default), annotations are ignored and the
83
+ full recording is used. This is useful when recordings contain breaks
84
+ or movement artifacts that shouldn't influence channel rejection
85
+ decisions.
86
+ matlab_strict : bool | None
78
87
  Whether or not PyPREP should strictly follow MATLAB PREP's internal
79
88
  math, ignoring any improvements made in PyPREP over the original code
80
89
  (see :ref:`matlab-diffs` for more details). Defaults to False.
@@ -128,6 +137,7 @@ class PrepPipeline:
128
137
  max_chunk_size=None,
129
138
  random_state=None,
130
139
  filter_kwargs=None,
140
+ reject_by_annotation=None,
131
141
  matlab_strict=False,
132
142
  ):
133
143
  """Initialize PREP class."""
@@ -167,11 +177,24 @@ class PrepPipeline:
167
177
  "ransac": ransac,
168
178
  "channel_wise": channel_wise,
169
179
  "max_chunk_size": max_chunk_size,
180
+ "reject_by_annotation": reject_by_annotation,
170
181
  }
171
182
  self.random_state = check_random_state(random_state)
172
183
  self.filter_kwargs = filter_kwargs
173
184
  self.matlab_strict = matlab_strict
174
185
 
186
+ # Initialize attributes to be filled in later
187
+ self._line_noise_removed = False
188
+ self.noisy_channels_original = None
189
+ self.noisy_channels_before_interpolation = None
190
+ self.noisy_channels_after_interpolation = None
191
+ self.bad_before_interpolation = None
192
+ self.EEG_before_interpolation = None
193
+ self.reference_before_interpolation = None
194
+ self.reference_after_interpolation = None
195
+ self.interpolated_channels = None
196
+ self.still_noisy_channels = None
197
+
175
198
  @property
176
199
  def raw(self):
177
200
  """Return a version of self.raw_eeg that includes the non-eeg channels."""
@@ -181,39 +204,96 @@ class PrepPipeline:
181
204
  else:
182
205
  return full_raw.add_channels([self.raw_non_eeg], force_update_info=True)
183
206
 
184
- def fit(self):
185
- """Run the whole PREP pipeline."""
186
- # Step 1: 1Hz high pass filtering
187
- if len(self.prep_params["line_freqs"]) != 0:
188
- self.EEG_new = removeTrend(
189
- self.EEG_raw, self.sfreq, matlab_strict=self.matlab_strict
207
+ def remove_line_noise(self, line_freqs=None):
208
+ """Remove line noise from all EEG channels.
209
+
210
+ Line noise is removed by detrending the signal, applying a notch filter,
211
+ and adding the slow drifts back. By default the notch filter uses MNE's
212
+ ``spectrum_fit`` method, which attempts to isolate and remove line noise
213
+ while preserving unrelated background signal in the same frequency ranges
214
+ (to minimize distortions in the power-spectral density). The filter can be
215
+ configured via the ``filter_kwargs`` argument of :class:`PrepPipeline`.
216
+
217
+ Parameters
218
+ ----------
219
+ line_freqs : {np.ndarray, list, None}, optional
220
+ A list of the frequencies (in Hz) at which line noise should be removed
221
+ (e.g., ``np.arange(60, sfreq / 2, 60)`` for a recording with a powerline
222
+ noise of 60 Hz). If ``None`` (default), the ``"line_freqs"`` entry of the
223
+ ``prep_params`` passed to :class:`PrepPipeline` is used.
224
+
225
+ """
226
+ if line_freqs is None:
227
+ line_freqs = self.prep_params["line_freqs"]
228
+
229
+ # Remove slow drifts from the recording prior to filtering
230
+ self.EEG_new = removeTrend(
231
+ self.EEG_raw, self.sfreq, matlab_strict=self.matlab_strict
232
+ )
233
+
234
+ # Remove line noise. When no filter kwargs are given, fall back to PREP's
235
+ # default ``spectrum_fit`` settings; otherwise use the provided kwargs as-is.
236
+ if self.filter_kwargs is None:
237
+ self.EEG_clean = mne.filter.notch_filter(
238
+ self.EEG_new,
239
+ Fs=self.sfreq,
240
+ freqs=line_freqs,
241
+ method="spectrum_fit",
242
+ mt_bandwidth=2,
243
+ p_value=0.01,
244
+ filter_length="10s",
245
+ )
246
+ else:
247
+ self.EEG_clean = mne.filter.notch_filter(
248
+ self.EEG_new,
249
+ Fs=self.sfreq,
250
+ freqs=line_freqs,
251
+ **self.filter_kwargs,
252
+ )
253
+
254
+ # Add the slow drifts back
255
+ self.EEG = self.EEG_raw - self.EEG_new + self.EEG_clean
256
+ self.raw_eeg._data = self.EEG
257
+ self._line_noise_removed = True
258
+
259
+ def robust_reference(self, max_iterations=None, interpolate_bads=True):
260
+ """Perform robust referencing on the EEG signal and detect bad channels.
261
+
262
+ This method uses an iterative approach to estimate a robust average
263
+ reference signal free of contamination from bad channels, as detected
264
+ automatically using the methods of :class:`~pyprep.NoisyChannels`. Once
265
+ estimated, the robust average reference is applied to the data and bad
266
+ channel detection is re-run to flag any noisy or unusable channels
267
+ post-reference.
268
+
269
+ By default, this method will also interpolate the signals of any channels
270
+ detected as bad following robust referencing, re-reference the data
271
+ accordingly, and re-detect any remaining bad channels.
272
+
273
+ Parameters
274
+ ----------
275
+ max_iterations : {int, None}, optional
276
+ The maximum number of iterations of noisy channel removal to perform
277
+ during robust referencing. If ``None`` (default), the ``"max_iterations"``
278
+ entry of the ``prep_params`` passed to :class:`PrepPipeline` is used.
279
+ interpolate_bads : bool, optional
280
+ Whether or not any remaining bad channels following robust referencing
281
+ should be interpolated. Defaults to ``True``.
282
+
283
+ """
284
+ if max_iterations is None:
285
+ max_iterations = self.prep_params["max_iterations"]
286
+
287
+ if not self._line_noise_removed:
288
+ warnings.warn(
289
+ "Robust referencing is being performed without prior line-noise "
290
+ "removal. If this is intentional, you can safely ignore this "
291
+ "warning; otherwise, call `remove_line_noise` first or use `fit`.",
292
+ UserWarning,
293
+ stacklevel=2,
190
294
  )
191
295
 
192
- # Step 2: Removing line noise
193
- linenoise = self.prep_params["line_freqs"]
194
- if self.filter_kwargs is None:
195
- self.EEG_clean = mne.filter.notch_filter(
196
- self.EEG_new,
197
- Fs=self.sfreq,
198
- freqs=linenoise,
199
- method="spectrum_fit",
200
- mt_bandwidth=2,
201
- p_value=0.01,
202
- filter_length="10s",
203
- )
204
- else:
205
- self.EEG_clean = mne.filter.notch_filter(
206
- self.EEG_new,
207
- Fs=self.sfreq,
208
- freqs=linenoise,
209
- **self.filter_kwargs,
210
- )
211
-
212
- # Add Trend back
213
- self.EEG = self.EEG_raw - self.EEG_new + self.EEG_clean
214
- self.raw_eeg._data = self.EEG
215
-
216
- # Step 3: Referencing
296
+ # Perform robust referencing on the signal
217
297
  reference = Reference(
218
298
  self.raw_eeg,
219
299
  self.prep_params,
@@ -221,7 +301,8 @@ class PrepPipeline:
221
301
  matlab_strict=self.matlab_strict,
222
302
  **self.ransac_settings,
223
303
  )
224
- reference.perform_reference(self.prep_params["max_iterations"])
304
+ reference.perform_reference(max_iterations, interpolate_bads)
305
+
225
306
  self.raw_eeg = reference.raw
226
307
  self.noisy_channels_original = reference.noisy_channels_original
227
308
  self.noisy_channels_before_interpolation = (
@@ -237,4 +318,17 @@ class PrepPipeline:
237
318
  self.interpolated_channels = reference.interpolated_channels
238
319
  self.still_noisy_channels = reference.still_noisy_channels
239
320
 
321
+ def fit(self):
322
+ """Run the whole PREP pipeline."""
323
+ # Step 1: Adaptive line noise removal
324
+ if len(self.prep_params["line_freqs"]) != 0:
325
+ self.remove_line_noise(self.prep_params["line_freqs"])
326
+ else:
327
+ # No line noise to remove: mark the stage as deliberately skipped so
328
+ # that `robust_reference` does not emit a spurious warning.
329
+ self._line_noise_removed = True
330
+
331
+ # Step 2: Robust Referencing
332
+ self.robust_reference(self.prep_params["max_iterations"])
333
+
240
334
  return self
@@ -60,24 +60,24 @@ def find_bad_by_ransac(
60
60
  exclude : list
61
61
  Labels of channels to exclude as signal predictors during RANSAC
62
62
  (i.e., channels already flagged as bad by metrics other than HF noise).
63
- n_samples : int, optional
63
+ n_samples : int | None
64
64
  Number of random channel samples to use for RANSAC. Defaults to ``50``.
65
- sample_prop : float, optional
65
+ sample_prop : float | None
66
66
  Proportion of total channels to use for signal prediction per RANSAC
67
67
  sample. This needs to be in the range [0, 1], where 0 would mean no
68
68
  channels would be used and 1 would mean all channels would be used
69
69
  (neither of which would be useful values). Defaults to ``0.25`` (e.g.,
70
70
  16 channels per sample for a 64-channel dataset).
71
- corr_thresh : float, optional
71
+ corr_thresh : float | None
72
72
  The minimum predicted vs. actual signal correlation for a channel to
73
73
  be considered good within a given RANSAC window. Defaults to ``0.75``.
74
- frac_bad : float, optional
74
+ frac_bad : float | None
75
75
  The minimum fraction of bad (i.e., below-threshold) RANSAC windows for a
76
76
  channel to be considered bad-by-RANSAC. Defaults to ``0.4``.
77
- corr_window_secs : float, optional
77
+ corr_window_secs : float | None
78
78
  The duration (in seconds) of each RANSAC correlation window. Defaults to
79
79
  5 seconds.
80
- channel_wise : bool, optional
80
+ channel_wise : bool | None
81
81
  Whether RANSAC should predict signals for chunks of channels over the
82
82
  entire signal length ("channel-wise RANSAC", see `max_chunk_size`
83
83
  parameter). If ``False``, RANSAC will instead predict signals for all
@@ -86,18 +86,18 @@ def find_bad_by_ransac(
86
86
  RANSAC generally has higher RAM demands than window-wise RANSAC
87
87
  (especially if `max_chunk_size` is ``None``), but can be faster on
88
88
  systems with lots of RAM to spare. Defaults to ``False``.
89
- max_chunk_size : {int, None}, optional
89
+ max_chunk_size : {int, None} | None
90
90
  The maximum number of channels to predict at once during channel-wise
91
91
  RANSAC. If ``None``, RANSAC will use the largest chunk size that will
92
92
  fit into the available RAM, which may slow down other programs on the
93
93
  host system. If using window-wise RANSAC (the default), this parameter
94
94
  has no effect. Defaults to ``None``.
95
- random_state : {int, None, np.random.RandomState}, optional
95
+ random_state : {int, None, np.random.RandomState} | None
96
96
  The random seed with which to generate random samples of channels during
97
97
  RANSAC. If random_state is an int, it will be used as a seed for RandomState.
98
98
  If ``None``, the seed will be obtained from the operating system
99
99
  (see RandomState for details). Defaults to ``None``.
100
- matlab_strict : bool, optional
100
+ matlab_strict : bool | None
101
101
  Whether or not RANSAC should strictly follow MATLAB PREP's internal
102
102
  math, ignoring any improvements made in PyPREP over the original code
103
103
  (see :ref:`matlab-diffs` for more details). Defaults to ``False``.
@@ -12,9 +12,6 @@ from pyprep.find_noisy_channels import NoisyChannels
12
12
  from pyprep.removeTrend import removeTrend
13
13
  from pyprep.utils import _eeglab_interpolate_bads, _set_diff, _union
14
14
 
15
- logging.basicConfig(
16
- level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
17
- )
18
15
  logger = logging.getLogger(__name__)
19
16
 
20
17
 
@@ -32,10 +29,10 @@ class Reference:
32
29
  Parameters of PREP which include at least the following keys:
33
30
  - ``ref_chs``
34
31
  - ``reref_chs``
35
- ransac : bool, optional
32
+ ransac : bool | None
36
33
  Whether or not to use RANSAC for noisy channel detection in addition to
37
34
  the other methods in :class:`~pyprep.NoisyChannels`. Defaults to True.
38
- channel_wise : bool, optional
35
+ channel_wise : bool | None
39
36
  Whether RANSAC should predict signals for chunks of channels over the
40
37
  entire signal length ("channel-wise RANSAC", see `max_chunk_size`
41
38
  parameter). If ``False``, RANSAC will instead predict signals for all
@@ -45,18 +42,22 @@ class Reference:
45
42
  (especially if `max_chunk_size` is ``None``), but can be faster on
46
43
  systems with lots of RAM to spare. Has no effect if not using RANSAC.
47
44
  Defaults to ``False``.
48
- max_chunk_size : {int, None}, optional
45
+ max_chunk_size : {int, None} | None
49
46
  The maximum number of channels to predict at once during channel-wise
50
47
  RANSAC. If ``None``, RANSAC will use the largest chunk size that will
51
48
  fit into the available RAM, which may slow down other programs on the
52
49
  host system. If using window-wise RANSAC (the default) or not using
53
50
  RANSAC at all, this parameter has no effect. Defaults to ``None``.
54
- random_state : {int, None, np.random.RandomState}, optional
51
+ random_state : {int, None, np.random.RandomState} | None
55
52
  The random seed at which to initialize the class. If random_state is
56
53
  an int, it will be used as a seed for RandomState.
57
54
  If None, the seed will be obtained from the operating system
58
55
  (see RandomState for details). Default is None.
59
- matlab_strict : bool, optional
56
+ reject_by_annotation : {None, 'omit'} | None
57
+ How to handle BAD-annotated time segments (annotations starting with
58
+ "BAD" or "bad") during channel quality assessment. If ``'omit'``,
59
+ annotated segments are excluded. Defaults to ``None`` (ignore).
60
+ matlab_strict : bool | None
60
61
  Whether or not PyPREP should strictly follow MATLAB PREP's internal
61
62
  math, ignoring any improvements made in PyPREP over the original code.
62
63
  Defaults to False.
@@ -77,6 +78,7 @@ class Reference:
77
78
  channel_wise=False,
78
79
  max_chunk_size=None,
79
80
  random_state=None,
81
+ reject_by_annotation=None,
80
82
  matlab_strict=False,
81
83
  ):
82
84
  """Initialize the class."""
@@ -94,44 +96,62 @@ class Reference:
94
96
  "ransac": ransac,
95
97
  "channel_wise": channel_wise,
96
98
  "max_chunk_size": max_chunk_size,
99
+ "reject_by_annotation": reject_by_annotation,
97
100
  }
98
101
  self.random_state = check_random_state(random_state)
99
- self._extra_info = {}
100
102
  self.matlab_strict = matlab_strict
101
103
 
102
- def perform_reference(self, max_iterations=4):
104
+ # Initialize attributes that get filled in during referencing
105
+ self.bad_before_interpolation = None
106
+ self.EEG_before_interpolation = None
107
+ self.noisy_channels_before_interpolation = None
108
+ self.reference_signal_new = None
109
+ self.interpolated_channels = None
110
+ self.still_noisy_channels = None
111
+ self.noisy_channels_after_interpolation = None
112
+ self._extra_info = {
113
+ "initial_bad": None,
114
+ "interpolated": None,
115
+ "remaining_bad": None,
116
+ }
117
+
118
+ def perform_reference(self, max_iterations=4, interpolate_bads=True):
103
119
  """Estimate the true signal mean and interpolate bad channels.
104
120
 
121
+ This function implements the functionality of the `performReference` function
122
+ as part of the PREP pipeline on mne raw object.
123
+
105
124
  Parameters
106
125
  ----------
107
- max_iterations : int, optional
126
+ max_iterations : int | None
108
127
  The maximum number of iterations of noisy channel removal to perform
109
128
  during robust referencing. Defaults to ``4``.
110
-
111
- This function implements the functionality of the `performReference` function
112
- as part of the PREP pipeline on mne raw object.
129
+ interpolate_bads : bool, optional
130
+ Whether or not any remaining bad channels following robust referencing
131
+ should be interpolated or left as-is. Defaults to ``True``.
113
132
 
114
133
  Notes
115
134
  -----
116
135
  This function calls ``robust_reference`` first.
117
- Currently this function only implements the functionality of default
118
- settings, i.e., ``doRobustPost``.
119
136
 
120
137
  """
121
- # Phase 1: Estimate the true signal mean with robust referencing
138
+ # Estimate the true signal mean with robust referencing
122
139
  self.robust_reference(max_iterations)
123
140
  # If we interpolate the raw here we would be interpolating
124
141
  # more than what we later actually account for (in interpolated channels).
125
142
  dummy = self.raw.copy()
126
143
  dummy.info["bads"] = self.noisy_channels["bad_all"]
127
- if self.matlab_strict:
128
- _eeglab_interpolate_bads(dummy)
129
- else:
130
- dummy.interpolate_bads()
144
+ if len(dummy.info["bads"]) > 0:
145
+ if self.matlab_strict:
146
+ _eeglab_interpolate_bads(dummy)
147
+ else:
148
+ dummy.interpolate_bads()
131
149
  self.reference_signal = np.nanmean(
132
150
  dummy.get_data(picks=self.reference_channels), axis=0
133
151
  )
134
152
  del dummy
153
+
154
+ # Re-reference the data using the calculated robust average reference
135
155
  rereferenced_index = [
136
156
  self.ch_names_eeg.index(ch) for ch in self.rereferenced_channels
137
157
  ]
@@ -139,42 +159,80 @@ class Reference:
139
159
  self.EEG, self.reference_signal, rereferenced_index
140
160
  )
141
161
 
142
- # Phase 2: Find the bad channels and interpolate
162
+ # Detect which channels are still bad following robust referencing
143
163
  self.raw._data = self.EEG
144
164
  noisy_detector = NoisyChannels(
145
- self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict
165
+ self.raw,
166
+ random_state=self.random_state,
167
+ matlab_strict=self.matlab_strict,
168
+ reject_by_annotation=self.ransac_settings.get("reject_by_annotation"),
146
169
  )
147
170
  noisy_detector.find_all_bads(**self.ransac_settings)
148
-
149
- # Record Noisy channels and EEG before interpolation
150
171
  self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
151
172
  self.EEG_before_interpolation = self.EEG.copy()
152
173
  self.noisy_channels_before_interpolation = noisy_detector.get_bads(as_dict=True)
153
174
  self.noisy_channels_before_interpolation["bad_by_manual"] = self.bads_manual
154
175
  self._extra_info["interpolated"] = noisy_detector._extra_info
155
176
 
177
+ # Update bad channels in MNE raw object
156
178
  bad_channels = _union(self.bad_before_interpolation, self.unusable_channels)
157
179
  self.raw.info["bads"] = bad_channels
158
- if self.matlab_strict:
159
- _eeglab_interpolate_bads(self.raw)
160
- else:
161
- self.raw.interpolate_bads()
180
+
181
+ # If enabled, interpolate all bad channels and detect any remaining bads
182
+ if interpolate_bads:
183
+ self.interpolate_bads()
184
+
185
+ return self
186
+
187
+ def interpolate_bads(self):
188
+ """Interpolate any remaining bad channels following robust referencing.
189
+
190
+ This method can only be called if :meth:`~.perform_reference` has already
191
+ been run with the ``interpolate_bads`` parameter set to ``False``. It cannot
192
+ be run more than once per instance of :class:`~pyprep.Reference`.
193
+
194
+ """
195
+ if self.bad_before_interpolation is None:
196
+ raise RuntimeError(
197
+ "Robust referencing must be performed before remaining bad channels "
198
+ "can be interpolated."
199
+ )
200
+ elif self.interpolated_channels is not None:
201
+ raise RuntimeError(
202
+ "Bad channel interpolation cannot be performed more than once - "
203
+ "interpolating signals using other interpolated signals is likely "
204
+ "to have poor results."
205
+ )
206
+
207
+ # Interpolate any channels flagged as bad following robust referencing
208
+ bad_channels = self.raw.info["bads"]
209
+ if len(bad_channels) > 0:
210
+ if self.matlab_strict:
211
+ _eeglab_interpolate_bads(self.raw)
212
+ else:
213
+ self.raw.interpolate_bads()
214
+
215
+ # Calculate and remove the new average reference following interpolation
162
216
  reference_correct = np.nanmean(
163
217
  self.raw.get_data(picks=self.reference_channels), axis=0
164
218
  )
219
+ rereferenced_index = [
220
+ self.ch_names_eeg.index(ch) for ch in self.rereferenced_channels
221
+ ]
165
222
  self.EEG = self.raw.get_data()
166
223
  self.EEG = self.remove_reference(
167
224
  self.EEG, reference_correct, rereferenced_index
168
225
  )
169
- # reference signal after interpolation
170
226
  self.reference_signal_new = self.reference_signal + reference_correct
171
- # MNE Raw object after interpolation
172
- self.raw._data = self.EEG
227
+ self.raw._data = self.EEG # Update the MNE Raw object
173
228
 
174
- # Still noisy channels after interpolation
229
+ # Detect any remaining noisy channels following interpolation
175
230
  self.interpolated_channels = bad_channels
176
231
  noisy_detector = NoisyChannels(
177
- self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict
232
+ self.raw,
233
+ random_state=self.random_state,
234
+ matlab_strict=self.matlab_strict,
235
+ reject_by_annotation=self.ransac_settings.get("reject_by_annotation"),
178
236
  )
179
237
  noisy_detector.find_all_bads(**self.ransac_settings)
180
238
  self.still_noisy_channels = noisy_detector.get_bads()
@@ -192,7 +250,7 @@ class Reference:
192
250
 
193
251
  Parameters
194
252
  ----------
195
- max_iterations : int, optional
253
+ max_iterations : int | None
196
254
  The maximum number of iterations of noisy channel removal to perform
197
255
  during robust referencing. Defaults to ``4``.
198
256
 
@@ -216,6 +274,7 @@ class Reference:
216
274
  do_detrend=False,
217
275
  random_state=self.random_state,
218
276
  matlab_strict=self.matlab_strict,
277
+ reject_by_annotation=self.ransac_settings.get("reject_by_annotation"),
219
278
  )
220
279
  noisy_detector.find_all_bads(**self.ransac_settings)
221
280
  self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
@@ -238,6 +297,7 @@ class Reference:
238
297
  "bad_by_correlation": [],
239
298
  "bad_by_SNR": [],
240
299
  "bad_by_dropout": [],
300
+ "bad_by_psd": [],
241
301
  "bad_by_ransac": [],
242
302
  "bad_by_manual": self.bads_manual,
243
303
  "bad_all": [],
@@ -265,6 +325,7 @@ class Reference:
265
325
  do_detrend=False,
266
326
  random_state=self.random_state,
267
327
  matlab_strict=self.matlab_strict,
328
+ reject_by_annotation=self.ransac_settings.get("reject_by_annotation"),
268
329
  )
269
330
  # Detrend applied at the beginning of the function.
270
331
 
@@ -338,7 +399,7 @@ class Reference:
338
399
  The original EEG signal.
339
400
  reference : np.ndarray, shape(times,)
340
401
  The reference signal.
341
- index : {list, None}, optional
402
+ index : {list, None} | None
342
403
  A list of channel indices from which the reference signal should be
343
404
  subtracted. Defaults to all channels in `signal`.
344
405
 
@@ -27,16 +27,16 @@ def removeTrend(
27
27
  A 2-D array of EEG data to detrend.
28
28
  sample_rate : float
29
29
  The sample rate (in Hz) of the input EEG data.
30
- detrendType : str, optional
30
+ detrendType : str | None
31
31
  Type of detrending to be performed: must be one of 'high pass',
32
32
  'high pass sinc, or 'local detrend'. Defaults to 'high pass'.
33
- detrendCutoff : float, optional
33
+ detrendCutoff : float | None
34
34
  The high-pass cutoff frequency (in Hz) to use for detrending. Defaults
35
35
  to 1.0 Hz.
36
- detrendChannels : {list, None}, optional
36
+ detrendChannels : {list, None} | None
37
37
  List of the indices of all channels that require detrending/filtering.
38
38
  If ``None``, all channels are used (default).
39
- matlab_strict : bool, optional
39
+ matlab_strict : bool | None
40
40
  Whether or not detrending should strictly follow MATLAB PREP's internal
41
41
  math, ignoring any improvements made in PyPREP over the original code
42
42
  (see :ref:`matlab-diffs` for more details). Defaults to ``False``.
@@ -56,7 +56,7 @@ def _mat_quantile(arr, q, axis=None):
56
56
  q : float
57
57
  The quantile to calculate for the input data. Must be between 0 and 1,
58
58
  inclusive.
59
- axis : {int, tuple of int, None}, optional
59
+ axis : {int, tuple of int, None} | None
60
60
  Axis along which quantile values should be calculated. Defaults to
61
61
  calculating the value at the given quantile for the entire array.
62
62
 
@@ -130,7 +130,7 @@ def _mat_iqr(arr, axis=None):
130
130
  ----------
131
131
  arr : np.ndarray
132
132
  Input array containing samples from the distribution to summarize.
133
- axis : {int, tuple of int, None}, optional
133
+ axis : {int, tuple of int, None} | None
134
134
  Axis along which IQRs should be calculated. Defaults to calculating the
135
135
  IQR for the entire array.
136
136
 
@@ -435,7 +435,7 @@ def _correlate_arrays(a, b, matlab_strict=False):
435
435
  A 2-D array to correlate with `a`.
436
436
  b : np.ndarray
437
437
  A 2-D array to correlate with `b`.
438
- matlab_strict : bool, optional
438
+ matlab_strict : bool | None
439
439
  Whether or not correlations should be calculated identically to MATLAB
440
440
  PREP (i.e., without mean subtraction) instead of by traditional Pearson
441
441
  product-moment correlation (see Notes for details). Defaults to
@@ -14,7 +14,8 @@ classifiers = [
14
14
  "Programming Language :: Python :: 3.10",
15
15
  "Programming Language :: Python :: 3.11",
16
16
  "Programming Language :: Python :: 3.12",
17
- "Programming Language :: Python :: 3.9",
17
+ "Programming Language :: Python :: 3.13",
18
+ "Programming Language :: Python :: 3.14",
18
19
  "Programming Language :: Python",
19
20
  "Topic :: Scientific/Engineering",
20
21
  ]
@@ -43,7 +44,7 @@ maintainers = [
43
44
  ]
44
45
  name = "pyprep"
45
46
  readme = {content-type = "text/x-rst", file = "README.rst"}
46
- requires-python = ">=3.9"
47
+ requires-python = ">=3.10"
47
48
 
48
49
  [project.optional-dependencies]
49
50
  dev = ["ipykernel", "ipython", "pyprep[test,docs]"]
@@ -86,8 +87,7 @@ exclude = [
86
87
  "/.github/**",
87
88
  "/docs",
88
89
  "/examples",
89
- "matprep_artifacts",
90
- "matprep_artifacts/**",
90
+ "/tools",
91
91
  "tests/**",
92
92
  ]
93
93
 
@@ -103,11 +103,6 @@ addopts = """. --cov=pyprep/ --cov-report=xml --cov-config=pyproject.toml --verb
103
103
  filterwarnings = [
104
104
  ]
105
105
 
106
- [tool.ruff]
107
- extend-exclude = [
108
- "matprep_artifacts/**",
109
- ]
110
-
111
106
  [tool.ruff.lint]
112
107
  ignore = ["A002"]
113
108
  select = ["A", "D", "E", "F", "I", "UP", "W"]
File without changes
File without changes