py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
  3. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
  4. py_neuromodulation/FieldTrip.py +589 -589
  5. py_neuromodulation/__init__.py +74 -13
  6. py_neuromodulation/_write_example_dataset_helper.py +83 -65
  7. py_neuromodulation/data/README +6 -6
  8. py_neuromodulation/data/dataset_description.json +8 -8
  9. py_neuromodulation/data/participants.json +32 -32
  10. py_neuromodulation/data/participants.tsv +2 -2
  11. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
  12. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
  13. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
  14. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
  15. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
  18. py_neuromodulation/grid_cortex.tsv +40 -40
  19. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  20. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  21. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  22. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  23. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  24. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  25. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  26. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  27. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  28. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  29. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  30. py_neuromodulation/nm_IO.py +413 -417
  31. py_neuromodulation/nm_RMAP.py +496 -531
  32. py_neuromodulation/nm_analysis.py +993 -1074
  33. py_neuromodulation/nm_artifacts.py +30 -25
  34. py_neuromodulation/nm_bispectra.py +154 -168
  35. py_neuromodulation/nm_bursts.py +292 -198
  36. py_neuromodulation/nm_coherence.py +251 -205
  37. py_neuromodulation/nm_database.py +149 -0
  38. py_neuromodulation/nm_decode.py +918 -992
  39. py_neuromodulation/nm_define_nmchannels.py +300 -302
  40. py_neuromodulation/nm_features.py +144 -116
  41. py_neuromodulation/nm_filter.py +219 -219
  42. py_neuromodulation/nm_filter_preprocessing.py +79 -91
  43. py_neuromodulation/nm_fooof.py +139 -159
  44. py_neuromodulation/nm_generator.py +45 -37
  45. py_neuromodulation/nm_hjorth_raw.py +52 -73
  46. py_neuromodulation/nm_kalmanfilter.py +71 -58
  47. py_neuromodulation/nm_linelength.py +21 -33
  48. py_neuromodulation/nm_logger.py +66 -0
  49. py_neuromodulation/nm_mne_connectivity.py +149 -112
  50. py_neuromodulation/nm_mnelsl_generator.py +90 -0
  51. py_neuromodulation/nm_mnelsl_stream.py +116 -0
  52. py_neuromodulation/nm_nolds.py +96 -93
  53. py_neuromodulation/nm_normalization.py +173 -214
  54. py_neuromodulation/nm_oscillatory.py +423 -448
  55. py_neuromodulation/nm_plots.py +585 -612
  56. py_neuromodulation/nm_preprocessing.py +83 -0
  57. py_neuromodulation/nm_projection.py +370 -394
  58. py_neuromodulation/nm_rereference.py +97 -95
  59. py_neuromodulation/nm_resample.py +59 -50
  60. py_neuromodulation/nm_run_analysis.py +325 -435
  61. py_neuromodulation/nm_settings.py +289 -68
  62. py_neuromodulation/nm_settings.yaml +244 -0
  63. py_neuromodulation/nm_sharpwaves.py +423 -401
  64. py_neuromodulation/nm_stats.py +464 -480
  65. py_neuromodulation/nm_stream.py +398 -0
  66. py_neuromodulation/nm_stream_abc.py +166 -218
  67. py_neuromodulation/nm_types.py +193 -0
  68. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +29 -26
  69. py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
  70. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -1
  71. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/licenses/LICENSE +21 -21
  72. py_neuromodulation/nm_EpochStream.py +0 -92
  73. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  74. py_neuromodulation/nm_cohortwrapper.py +0 -435
  75. py_neuromodulation/nm_eval_timing.py +0 -239
  76. py_neuromodulation/nm_features_abc.py +0 -39
  77. py_neuromodulation/nm_settings.json +0 -338
  78. py_neuromodulation/nm_stream_offline.py +0 -359
  79. py_neuromodulation/utils/_logging.py +0 -24
  80. py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
@@ -1,219 +1,219 @@
1
- """Module for filter functionality."""
2
-
3
- import logging
4
-
5
- logger = logging.getLogger("PynmLogger")
6
-
7
- import mne
8
- from mne.filter import _overlap_add_filter
9
- import numpy as np
10
-
11
-
12
- class MNEFilter:
13
- """mne.filter wrapper
14
-
15
- This class stores for given frequency band ranges the filter
16
- coefficients with length "filter_len".
17
- The filters can then be used sequentially for band power estimation with
18
- apply_filter().
19
- Note that this filter can be a bandpass, bandstop, lowpass, or highpass filter
20
- depending on the frequency ranges given (see further details in mne.filter.create_filter).
21
-
22
- Parameters
23
- ----------
24
- f_ranges : list of lists
25
- Frequency ranges. Inner lists must be of length 2.
26
- sfreq : int | float
27
- Sampling frequency.
28
- filter_length : str, optional
29
- Filter length. Human readable (e.g. "1000ms", "1s"), by default "999ms"
30
- l_trans_bandwidth : int | float | str, optional
31
- Length of the lower transition band or "auto", by default 4
32
- h_trans_bandwidth : int | float | str, optional
33
- Length of the higher transition band or "auto", by default 4
34
- verbose : bool | None, optional
35
- Verbosity level, by default None
36
-
37
- Attributes
38
- ----------
39
- filter_bank: np.ndarray shape (n,)
40
- Factor to upsample by.
41
- """
42
-
43
- def __init__(
44
- self,
45
- f_ranges: list[list[int | float | None]] | list[int | float | None],
46
- sfreq: int | float,
47
- filter_length: str | float = "999ms",
48
- l_trans_bandwidth: int | float | str = 4,
49
- h_trans_bandwidth: int | float | str = 4,
50
- verbose: bool | int | str | None = None,
51
- ) -> None:
52
- filter_bank = []
53
- # mne create_filter function only accepts str and int for filter_length
54
- if isinstance(filter_length, float):
55
- filter_length = int(filter_length)
56
-
57
- if not isinstance(f_ranges[0], list):
58
- f_ranges = [f_ranges]
59
-
60
- for f_range in f_ranges:
61
- try:
62
- filt = mne.filter.create_filter(
63
- None,
64
- sfreq,
65
- l_freq=f_range[0],
66
- h_freq=f_range[1],
67
- fir_design="firwin",
68
- l_trans_bandwidth=l_trans_bandwidth, # type: ignore
69
- h_trans_bandwidth=h_trans_bandwidth, # type: ignore
70
- filter_length=filter_length, # type: ignore
71
- verbose=verbose,
72
- )
73
- except:
74
- filt = mne.filter.create_filter(
75
- None,
76
- sfreq,
77
- l_freq=f_range[0],
78
- h_freq=f_range[1],
79
- fir_design="firwin",
80
- verbose=verbose,
81
- # filter_length=filter_length,
82
- )
83
- filter_bank.append(filt)
84
- self.filter_bank = np.vstack(filter_bank)
85
-
86
- def filter_data(self, data: np.ndarray) -> np.ndarray:
87
- """Apply previously calculated (bandpass) filters to data.
88
-
89
- Parameters
90
- ----------
91
- data : np.ndarray (n_samples, ) or (n_channels, n_samples)
92
- Data to be filtered
93
- filter_bank : np.ndarray, shape (n_fbands, filter_len)
94
- Output of calc_bandpass_filters.
95
-
96
- Returns
97
- -------
98
- np.ndarray, shape (n_channels, n_fbands, n_samples)
99
- Filtered data.
100
-
101
- """
102
- if data.ndim > 2:
103
- raise ValueError(
104
- f"Data must have one or two dimensions. Got:"
105
- f" {data.ndim} dimensions."
106
- )
107
- if data.ndim == 1:
108
- data = np.expand_dims(data, axis=0)
109
-
110
- filtered = np.array(
111
- [
112
- [
113
- np.convolve(filt, chan, mode="same")
114
- for filt in self.filter_bank
115
- ]
116
- for chan in data
117
- ]
118
- )
119
-
120
- # ensure here that the output dimension matches the input dimension
121
- if data.shape[1] != filtered.shape[-1]:
122
- # select the middle part of the filtered data
123
- middle_index = filtered.shape[-1] // 2
124
- filtered = filtered[
125
- :,
126
- :,
127
- middle_index
128
- - data.shape[1] // 2 : middle_index
129
- + data.shape[1] // 2,
130
- ]
131
-
132
- return filtered
133
-
134
-
135
- class NotchFilter:
136
- def __init__(
137
- self,
138
- sfreq: int | float,
139
- line_noise: int | float | None = None,
140
- freqs: np.ndarray | None = None,
141
- notch_widths: int | np.ndarray | None = 3,
142
- trans_bandwidth: int = 6.8,
143
- ) -> None:
144
- if line_noise is None and freqs is None:
145
- raise ValueError(
146
- "Either line_noise or freqs must be defined if notch_filter is"
147
- "activated."
148
- )
149
- if freqs is None:
150
- freqs = np.arange(line_noise, sfreq / 2, line_noise, dtype=int)
151
-
152
- if freqs.size > 0:
153
- if freqs[-1] >= sfreq / 2:
154
- freqs = freqs[:-1]
155
-
156
- # Code is copied from filter.py notch_filter
157
- if freqs.size == 0:
158
- self.filter_bank = None
159
- logger.warning(
160
- "WARNING: notch_filter is activated but data is not being"
161
- f" filtered. This may be due to a low sampling frequency or"
162
- f" incorrect specifications. Make sure your settings are"
163
- f" correct. Got: {sfreq = }, {line_noise = }, {freqs = }."
164
- )
165
- return
166
-
167
- filter_length = int(sfreq - 1)
168
- if notch_widths is None:
169
- notch_widths = freqs / 200.0
170
- elif np.any(notch_widths < 0):
171
- raise ValueError("notch_widths must be >= 0")
172
- else:
173
- notch_widths = np.atleast_1d(notch_widths)
174
- if len(notch_widths) == 1:
175
- notch_widths = notch_widths[0] * np.ones_like(freqs)
176
- elif len(notch_widths) != len(freqs):
177
- raise ValueError(
178
- "notch_widths must be None, scalar, or the "
179
- "same length as freqs"
180
- )
181
-
182
- # Speed this up by computing the fourier coefficients once
183
- tb_half = trans_bandwidth / 2.0
184
- lows = [
185
- freq - nw / 2.0 - tb_half for freq, nw in zip(freqs, notch_widths)
186
- ]
187
- highs = [
188
- freq + nw / 2.0 + tb_half for freq, nw in zip(freqs, notch_widths)
189
- ]
190
-
191
- self.filter_bank = mne.filter.create_filter(
192
- data=None,
193
- sfreq=sfreq,
194
- l_freq=highs,
195
- h_freq=lows,
196
- filter_length=filter_length, # type: ignore
197
- l_trans_bandwidth=tb_half, # type: ignore
198
- h_trans_bandwidth=tb_half, # type: ignore
199
- method="fir",
200
- iir_params=None,
201
- phase="zero",
202
- fir_window="hamming",
203
- fir_design="firwin",
204
- verbose=False,
205
- )
206
-
207
- def process(self, data: np.ndarray) -> np.ndarray:
208
- if self.filter_bank is None:
209
- return data
210
- return _overlap_add_filter(
211
- x=data,
212
- h=self.filter_bank,
213
- n_fft=None,
214
- phase="zero",
215
- picks=None,
216
- n_jobs=1,
217
- copy=True,
218
- pad="reflect_limited",
219
- )
1
+ """Module for filter functionality."""
2
+
3
+ import numpy as np
4
+ from typing import cast
5
+ from collections.abc import Sequence
6
+
7
+ from py_neuromodulation.nm_preprocessing import NMPreprocessor
8
+ from py_neuromodulation import logger
9
+
10
+ from mne.filter import create_filter
11
+
12
+
13
+ class MNEFilter:
14
+ """mne.filter wrapper
15
+
16
+ This class stores for given frequency band ranges the filter
17
+ coefficients with length "filter_len".
18
+ The filters can then be used sequentially for band power estimation with
19
+ apply_filter().
20
+ Note that this filter can be a bandpass, bandstop, lowpass, or highpass filter
21
+ depending on the frequency ranges given (see further details in mne.filter.create_filter).
22
+
23
+ Parameters
24
+ ----------
25
+ f_ranges : list[tuple[float | None, float | None]]
26
+ sfreq : float
27
+ Sampling frequency.
28
+ filter_length : str, optional
29
+ Filter length. Human readable (e.g. "1000ms", "1s"), by default "999ms"
30
+ l_trans_bandwidth : float | str, optional
31
+ Length of the lower transition band or "auto", by default 4
32
+ h_trans_bandwidth : float | str, optional
33
+ Length of the higher transition band or "auto", by default 4
34
+ verbose : bool | None, optional
35
+ Verbosity level, by default None
36
+
37
+ Attributes
38
+ ----------
39
+ filter_bank: np.ndarray shape (n,)
40
+ Factor to upsample by.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ f_ranges: Sequence[tuple[float | None, float | None]],
46
+ sfreq: float,
47
+ filter_length: str | float = "999ms",
48
+ l_trans_bandwidth: float | str = 4,
49
+ h_trans_bandwidth: float | str = 4,
50
+ verbose: bool | int | str | None = None,
51
+ ) -> None:
52
+ filter_bank = []
53
+ # mne create_filter function only accepts str and int for filter_length
54
+ if isinstance(filter_length, float):
55
+ filter_length = int(filter_length)
56
+
57
+ for f_range in f_ranges:
58
+ try:
59
+ filt = create_filter(
60
+ None,
61
+ sfreq,
62
+ l_freq=f_range[0],
63
+ h_freq=f_range[1],
64
+ fir_design="firwin",
65
+ l_trans_bandwidth=l_trans_bandwidth, # type: ignore
66
+ h_trans_bandwidth=h_trans_bandwidth, # type: ignore
67
+ filter_length=filter_length, # type: ignore
68
+ verbose=verbose,
69
+ )
70
+ except ValueError:
71
+ filt = create_filter(
72
+ None,
73
+ sfreq,
74
+ l_freq=f_range[0],
75
+ h_freq=f_range[1],
76
+ fir_design="firwin",
77
+ verbose=verbose,
78
+ # filter_length=filter_length,
79
+ )
80
+ filter_bank.append(filt)
81
+
82
+ self.num_filters = len(filter_bank)
83
+ self.filter_bank = np.vstack(filter_bank)
84
+
85
+ self.filters: np.ndarray
86
+ self.num_channels = -1
87
+
88
+
89
+ def filter_data(self, data: np.ndarray) -> np.ndarray:
90
+ """Apply previously calculated (bandpass) filters to data.
91
+
92
+ Parameters
93
+ ----------
94
+ data : np.ndarray (n_samples, ) or (n_channels, n_samples)
95
+ Data to be filtered
96
+ filter_bank : np.ndarray, shape (n_fbands, filter_len)
97
+ Output of calc_bandpass_filters.
98
+
99
+ Returns
100
+ -------
101
+ np.ndarray, shape (n_channels, n_fbands, n_samples)
102
+ Filtered data.
103
+
104
+ """
105
+ from scipy.signal import fftconvolve
106
+
107
+ if data.ndim > 2:
108
+ raise ValueError(
109
+ f"Data must have one or two dimensions. Got:"
110
+ f" {data.ndim} dimensions."
111
+ )
112
+ if data.ndim == 1:
113
+ data = np.expand_dims(data, axis=0)
114
+
115
+ if self.num_channels == -1:
116
+ self.num_channels = data.shape[0]
117
+ self.filters = np.tile(self.filter_bank[None, :, :], (self.num_channels, 1, 1))
118
+
119
+ data_tiled = np.tile(data[:, None, :], (1, self.num_filters, 1))
120
+
121
+ filtered = fftconvolve(data_tiled, self.filters, axes=2, mode="same")
122
+
123
+ # ensure here that the output dimension matches the input dimension
124
+ if data.shape[1] != filtered.shape[-1]:
125
+ # select the middle part of the filtered data
126
+ middle_index = filtered.shape[-1] // 2
127
+ filtered = filtered[
128
+ :,
129
+ :,
130
+ middle_index - data.shape[1] // 2 : middle_index + data.shape[1] // 2,
131
+ ]
132
+
133
+ return filtered
134
+
135
+
136
+ class NotchFilter(NMPreprocessor):
137
+ def __init__(
138
+ self,
139
+ sfreq: float,
140
+ line_noise: float | None = None,
141
+ freqs: np.ndarray | None = None,
142
+ notch_widths: int | np.ndarray | None = 3,
143
+ trans_bandwidth: float = 6.8,
144
+ ) -> None:
145
+ if line_noise is None and freqs is None:
146
+ raise ValueError(
147
+ "Either line_noise or freqs must be defined if notch_filter is"
148
+ "activated."
149
+ )
150
+
151
+ if freqs is None:
152
+ freqs = np.arange(line_noise, sfreq / 2, line_noise, dtype=int)
153
+
154
+ if freqs.size > 0 and freqs[-1] >= sfreq / 2:
155
+ freqs = freqs[:-1]
156
+
157
+ # Code is copied from filter.py notch_filter
158
+ if freqs.size == 0:
159
+ self.filter_bank = None
160
+ logger.warning(
161
+ "WARNING: notch_filter is activated but data is not being"
162
+ " filtered. This may be due to a low sampling frequency or"
163
+ " incorrect specifications. Make sure your settings are"
164
+ f" correct. Got: {sfreq = }, {line_noise = }, {freqs = }."
165
+ )
166
+ return
167
+
168
+ filter_length = int(sfreq - 1)
169
+ if notch_widths is None:
170
+ notch_widths = freqs / 200.0
171
+ elif np.any(notch_widths < 0):
172
+ raise ValueError("notch_widths must be >= 0")
173
+ else:
174
+ notch_widths = np.atleast_1d(notch_widths)
175
+ if len(notch_widths) == 1:
176
+ notch_widths = notch_widths[0] * np.ones_like(freqs)
177
+ elif len(notch_widths) != len(freqs):
178
+ raise ValueError(
179
+ "notch_widths must be None, scalar, or the " "same length as freqs"
180
+ )
181
+ notch_widths = cast(np.ndarray, notch_widths) # For MyPy only, no runtime cost
182
+
183
+ # Speed this up by computing the fourier coefficients once
184
+ tb_half = trans_bandwidth / 2.0
185
+ lows = [freq - nw / 2.0 - tb_half for freq, nw in zip(freqs, notch_widths)]
186
+ highs = [freq + nw / 2.0 + tb_half for freq, nw in zip(freqs, notch_widths)]
187
+
188
+ self.filter_bank = create_filter(
189
+ data=None,
190
+ sfreq=sfreq,
191
+ l_freq=highs,
192
+ h_freq=lows,
193
+ filter_length=filter_length, # type: ignore
194
+ l_trans_bandwidth=tb_half, # type: ignore
195
+ h_trans_bandwidth=tb_half, # type: ignore
196
+ method="fir",
197
+ iir_params=None,
198
+ phase="zero",
199
+ fir_window="hamming",
200
+ fir_design="firwin",
201
+ verbose=False,
202
+ )
203
+
204
+ def process(self, data: np.ndarray) -> np.ndarray:
205
+ if self.filter_bank is None:
206
+ return data
207
+
208
+ from mne.filter import _overlap_add_filter
209
+
210
+ return _overlap_add_filter(
211
+ x=data,
212
+ h=self.filter_bank,
213
+ n_fft=None,
214
+ phase="zero",
215
+ picks=None,
216
+ n_jobs=1,
217
+ copy=True,
218
+ pad="reflect_limited",
219
+ )
@@ -1,91 +1,79 @@
1
- import numpy as np
2
-
3
- from py_neuromodulation import nm_filter
4
-
5
-
6
- class PreprocessingFilter:
7
-
8
- def __init__(self, settings: dict, sfreq: int | float) -> None:
9
- self.s = settings
10
- self.sfreq = sfreq
11
- self.filters = []
12
-
13
- if self.s["preprocessing_filter"]["bandstop_filter"] is True:
14
- self.filters.append(
15
- nm_filter.MNEFilter(
16
- f_ranges=[
17
- self.s["preprocessing_filter"][
18
- "bandstop_filter_settings"
19
- ]["frequency_high_hz"],
20
- self.s["preprocessing_filter"][
21
- "bandstop_filter_settings"
22
- ]["frequency_low_hz"],
23
- ],
24
- sfreq=self.sfreq,
25
- filter_length=self.sfreq - 1,
26
- verbose=False,
27
- )
28
- )
29
-
30
- if self.s["preprocessing_filter"]["bandpass_filter"] is True:
31
- self.filters.append(
32
- nm_filter.MNEFilter(
33
- f_ranges=[
34
- self.s["preprocessing_filter"][
35
- "bandpass_filter_settings"
36
- ]["frequency_low_hz"],
37
- self.s["preprocessing_filter"][
38
- "bandpass_filter_settings"
39
- ]["frequency_high_hz"],
40
- ],
41
- sfreq=self.sfreq,
42
- filter_length=self.sfreq - 1,
43
- verbose=False,
44
- )
45
- )
46
- if self.s["preprocessing_filter"]["lowpass_filter"] is True:
47
- self.filters.append(
48
- nm_filter.MNEFilter(
49
- f_ranges=[
50
- None,
51
- self.s["preprocessing_filter"][
52
- "lowpass_filter_settings"
53
- ]["frequency_cutoff_hz"],
54
- ],
55
- sfreq=self.sfreq,
56
- filter_length=self.sfreq - 1,
57
- verbose=False,
58
- )
59
- )
60
- if self.s["preprocessing_filter"]["highpass_filter"] is True:
61
- self.filters.append(
62
- nm_filter.MNEFilter(
63
- f_ranges=[
64
- self.s["preprocessing_filter"][
65
- "highpass_filter_settings"
66
- ]["frequency_cutoff_hz"],
67
- None,
68
- ],
69
- sfreq=self.sfreq,
70
- filter_length=self.sfreq - 1,
71
- verbose=False,
72
- )
73
- )
74
-
75
- def process(self, data: np.ndarray) -> np.ndarray:
76
- """Preprocess data according to the initialized list of PreprocessingFilter objects
77
-
78
- Args:
79
- data (numpy ndarray) :
80
- shape(n_channels, n_samples) - data to be preprocessed.
81
-
82
- Returns:
83
- preprocessed_data (numpy ndarray):
84
- shape(n_channels, n_samples) - preprocessed data
85
- """
86
-
87
- for filter in self.filters:
88
- data = filter.filter_data(
89
- data if len(data.shape) == 2 else data[:, 0, :]
90
- )
91
- return data if len(data.shape) == 2 else data[:, 0, :]
1
+ import numpy as np
2
+
3
+ from pydantic import Field
4
+ from typing import TYPE_CHECKING
5
+
6
+ from py_neuromodulation.nm_types import BoolSelector, FrequencyRange
7
+ from py_neuromodulation.nm_preprocessing import NMPreprocessor
8
+
9
+ if TYPE_CHECKING:
10
+ from py_neuromodulation.nm_settings import NMSettings
11
+
12
+
13
+ FILTER_SETTINGS_MAP = {
14
+ "bandstop_filter": "bandstop_filter_settings",
15
+ "bandpass_filter": "bandpass_filter_settings",
16
+ "lowpass_filter": "lowpass_filter_cutoff_hz",
17
+ "highpass_filter": "highpass_filter_cutoff_hz",
18
+ }
19
+
20
+
21
+ class FilterSettings(BoolSelector):
22
+ bandstop_filter: bool = True
23
+ bandpass_filter: bool = True
24
+ lowpass_filter: bool = True
25
+ highpass_filter: bool = True
26
+
27
+ bandstop_filter_settings: FrequencyRange = FrequencyRange(100, 160)
28
+ bandpass_filter_settings: FrequencyRange = FrequencyRange(2, 200)
29
+ lowpass_filter_cutoff_hz: float = Field(default=200)
30
+ highpass_filter_cutoff_hz: float = Field(default=3)
31
+
32
+ def get_filter_tuple(self, filter_name) -> tuple[float | None, float | None]:
33
+ filter_value = self[FILTER_SETTINGS_MAP[filter_name]]
34
+
35
+ match filter_name:
36
+ case "bandstop_filter":
37
+ return (filter_value.frequency_high_hz, filter_value.frequency_low_hz)
38
+ case "bandpass_filter":
39
+ return (filter_value.frequency_low_hz, filter_value.frequency_high_hz)
40
+ case "lowpass_filter":
41
+ return (None, filter_value)
42
+ case "highpass_filter":
43
+ return (filter_value, None)
44
+ case _:
45
+ raise ValueError(
46
+ "Filter name must be one of 'bandstop_filter', 'lowpass_filter', "
47
+ "'highpass_filter', 'bandpass_filter'"
48
+ )
49
+
50
+
51
+ class PreprocessingFilter(NMPreprocessor):
52
+ def __init__(self, settings: "NMSettings", sfreq: float) -> None:
53
+ from py_neuromodulation.nm_filter import MNEFilter
54
+
55
+ self.filters: list[MNEFilter] = [
56
+ MNEFilter(
57
+ f_ranges=[settings.preprocessing_filter.get_filter_tuple(filter_name)], # type: ignore
58
+ sfreq=sfreq,
59
+ filter_length=sfreq - 1,
60
+ verbose=False,
61
+ )
62
+ for filter_name in settings.preprocessing_filter.get_enabled()
63
+ ]
64
+
65
+ def process(self, data: np.ndarray) -> np.ndarray:
66
+ """Preprocess data according to the initialized list of PreprocessingFilter objects
67
+
68
+ Args:
69
+ data (numpy ndarray) :
70
+ shape(n_channels, n_samples) - data to be preprocessed.
71
+
72
+ Returns:
73
+ preprocessed_data (numpy ndarray):
74
+ shape(n_channels, n_samples) - preprocessed data
75
+ """
76
+
77
+ for filter in self.filters:
78
+ data = filter.filter_data(data if len(data.shape) == 2 else data[:, 0, :])
79
+ return data if len(data.shape) == 2 else data[:, 0, :]