py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
  3. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
  4. py_neuromodulation/__init__.py +80 -13
  5. py_neuromodulation/{nm_RMAP.py → analysis/RMAP.py} +496 -531
  6. py_neuromodulation/analysis/__init__.py +4 -0
  7. py_neuromodulation/{nm_decode.py → analysis/decode.py} +918 -992
  8. py_neuromodulation/{nm_analysis.py → analysis/feature_reader.py} +994 -1074
  9. py_neuromodulation/{nm_plots.py → analysis/plots.py} +627 -612
  10. py_neuromodulation/{nm_stats.py → analysis/stats.py} +458 -480
  11. py_neuromodulation/data/README +6 -6
  12. py_neuromodulation/data/dataset_description.json +8 -8
  13. py_neuromodulation/data/participants.json +32 -32
  14. py_neuromodulation/data/participants.tsv +2 -2
  15. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
  18. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
  19. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
  20. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
  21. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
  22. py_neuromodulation/default_settings.yaml +241 -0
  23. py_neuromodulation/features/__init__.py +31 -0
  24. py_neuromodulation/features/bandpower.py +165 -0
  25. py_neuromodulation/features/bispectra.py +157 -0
  26. py_neuromodulation/features/bursts.py +297 -0
  27. py_neuromodulation/features/coherence.py +255 -0
  28. py_neuromodulation/features/feature_processor.py +121 -0
  29. py_neuromodulation/features/fooof.py +142 -0
  30. py_neuromodulation/features/hjorth_raw.py +57 -0
  31. py_neuromodulation/features/linelength.py +21 -0
  32. py_neuromodulation/features/mne_connectivity.py +148 -0
  33. py_neuromodulation/features/nolds.py +94 -0
  34. py_neuromodulation/features/oscillatory.py +249 -0
  35. py_neuromodulation/features/sharpwaves.py +432 -0
  36. py_neuromodulation/filter/__init__.py +3 -0
  37. py_neuromodulation/filter/kalman_filter.py +67 -0
  38. py_neuromodulation/filter/kalman_filter_external.py +1890 -0
  39. py_neuromodulation/filter/mne_filter.py +128 -0
  40. py_neuromodulation/filter/notch_filter.py +93 -0
  41. py_neuromodulation/grid_cortex.tsv +40 -40
  42. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  43. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  44. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  45. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  46. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  47. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  48. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  49. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  50. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  51. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  52. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  53. py_neuromodulation/processing/__init__.py +10 -0
  54. py_neuromodulation/{nm_artifacts.py → processing/artifacts.py} +29 -25
  55. py_neuromodulation/processing/data_preprocessor.py +77 -0
  56. py_neuromodulation/processing/filter_preprocessing.py +78 -0
  57. py_neuromodulation/processing/normalization.py +175 -0
  58. py_neuromodulation/{nm_projection.py → processing/projection.py} +370 -394
  59. py_neuromodulation/{nm_rereference.py → processing/rereference.py} +97 -95
  60. py_neuromodulation/{nm_resample.py → processing/resample.py} +56 -50
  61. py_neuromodulation/stream/__init__.py +3 -0
  62. py_neuromodulation/stream/data_processor.py +325 -0
  63. py_neuromodulation/stream/generator.py +53 -0
  64. py_neuromodulation/stream/mnelsl_player.py +94 -0
  65. py_neuromodulation/stream/mnelsl_stream.py +120 -0
  66. py_neuromodulation/stream/settings.py +292 -0
  67. py_neuromodulation/stream/stream.py +427 -0
  68. py_neuromodulation/utils/__init__.py +2 -0
  69. py_neuromodulation/{nm_define_nmchannels.py → utils/channels.py} +305 -302
  70. py_neuromodulation/utils/database.py +149 -0
  71. py_neuromodulation/utils/io.py +378 -0
  72. py_neuromodulation/utils/keyboard.py +52 -0
  73. py_neuromodulation/utils/logging.py +66 -0
  74. py_neuromodulation/utils/types.py +251 -0
  75. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/METADATA +28 -33
  76. py_neuromodulation-0.0.6.dist-info/RECORD +89 -0
  77. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/WHEEL +1 -1
  78. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/licenses/LICENSE +21 -21
  79. py_neuromodulation/FieldTrip.py +0 -589
  80. py_neuromodulation/_write_example_dataset_helper.py +0 -65
  81. py_neuromodulation/nm_EpochStream.py +0 -92
  82. py_neuromodulation/nm_IO.py +0 -417
  83. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  84. py_neuromodulation/nm_bispectra.py +0 -168
  85. py_neuromodulation/nm_bursts.py +0 -198
  86. py_neuromodulation/nm_coherence.py +0 -205
  87. py_neuromodulation/nm_cohortwrapper.py +0 -435
  88. py_neuromodulation/nm_eval_timing.py +0 -239
  89. py_neuromodulation/nm_features.py +0 -116
  90. py_neuromodulation/nm_features_abc.py +0 -39
  91. py_neuromodulation/nm_filter.py +0 -219
  92. py_neuromodulation/nm_filter_preprocessing.py +0 -91
  93. py_neuromodulation/nm_fooof.py +0 -159
  94. py_neuromodulation/nm_generator.py +0 -37
  95. py_neuromodulation/nm_hjorth_raw.py +0 -73
  96. py_neuromodulation/nm_kalmanfilter.py +0 -58
  97. py_neuromodulation/nm_linelength.py +0 -33
  98. py_neuromodulation/nm_mne_connectivity.py +0 -112
  99. py_neuromodulation/nm_nolds.py +0 -93
  100. py_neuromodulation/nm_normalization.py +0 -214
  101. py_neuromodulation/nm_oscillatory.py +0 -448
  102. py_neuromodulation/nm_run_analysis.py +0 -435
  103. py_neuromodulation/nm_settings.json +0 -338
  104. py_neuromodulation/nm_settings.py +0 -68
  105. py_neuromodulation/nm_sharpwaves.py +0 -401
  106. py_neuromodulation/nm_stream_abc.py +0 -218
  107. py_neuromodulation/nm_stream_offline.py +0 -359
  108. py_neuromodulation/utils/_logging.py +0 -24
  109. py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
@@ -0,0 +1,128 @@
1
+ import numpy as np
2
+ from collections.abc import Sequence
3
+
4
+
5
+ class MNEFilter:
6
+ """mne.filter wrapper
7
+
8
+ This class stores for given frequency band ranges the filter
9
+ coefficients with length "filter_len".
10
+ The filters can then be used sequentially for band power estimation with
11
+ apply_filter().
12
+ Note that this filter can be a bandpass, bandstop, lowpass, or highpass filter
13
+ depending on the frequency ranges given (see further details in mne.filter.create_filter).
14
+
15
+ Parameters
16
+ ----------
17
+ f_ranges : list[tuple[float | None, float | None]]
18
+ sfreq : float
19
+ Sampling frequency.
20
+ filter_length : str, optional
21
+ Filter length. Human readable (e.g. "1000ms", "1s"), by default "999ms"
22
+ l_trans_bandwidth : float | str, optional
23
+ Length of the lower transition band or "auto", by default 4
24
+ h_trans_bandwidth : float | str, optional
25
+ Length of the higher transition band or "auto", by default 4
26
+ verbose : bool | None, optional
27
+ Verbosity level, by default None
28
+
29
+ Attributes
30
+ ----------
31
+ filter_bank: np.ndarray shape (n,)
32
+ Factor to upsample by.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ f_ranges: Sequence[tuple[float | None, float | None]],
38
+ sfreq: float,
39
+ filter_length: str | float = "999ms",
40
+ l_trans_bandwidth: float | str = 4,
41
+ h_trans_bandwidth: float | str = 4,
42
+ verbose: bool | int | str | None = None,
43
+ ) -> None:
44
+ from mne.filter import create_filter
45
+
46
+ filter_bank = []
47
+ # mne create_filter function only accepts str and int for filter_length
48
+ if isinstance(filter_length, float):
49
+ filter_length = int(filter_length)
50
+
51
+ for f_range in f_ranges:
52
+ try:
53
+ filt = create_filter(
54
+ None,
55
+ sfreq,
56
+ l_freq=f_range[0],
57
+ h_freq=f_range[1],
58
+ fir_design="firwin",
59
+ l_trans_bandwidth=l_trans_bandwidth, # type: ignore
60
+ h_trans_bandwidth=h_trans_bandwidth, # type: ignore
61
+ filter_length=filter_length, # type: ignore
62
+ verbose=verbose,
63
+ )
64
+ except ValueError:
65
+ filt = create_filter(
66
+ None,
67
+ sfreq,
68
+ l_freq=f_range[0],
69
+ h_freq=f_range[1],
70
+ fir_design="firwin",
71
+ verbose=verbose,
72
+ # filter_length=filter_length,
73
+ )
74
+ filter_bank.append(filt)
75
+
76
+ self.num_filters = len(filter_bank)
77
+ self.filter_bank = np.vstack(filter_bank)
78
+
79
+ self.filters: np.ndarray
80
+ self.num_channels = -1
81
+
82
+ def filter_data(self, data: np.ndarray) -> np.ndarray:
83
+ """Apply previously calculated (bandpass) filters to data.
84
+
85
+ Parameters
86
+ ----------
87
+ data : np.ndarray (n_samples, ) or (n_channels, n_samples)
88
+ Data to be filtered
89
+ filter_bank : np.ndarray, shape (n_fbands, filter_len)
90
+ Output of calc_bandpass_filters.
91
+
92
+ Returns
93
+ -------
94
+ np.ndarray, shape (n_channels, n_fbands, n_samples)
95
+ Filtered data.
96
+
97
+ """
98
+ from scipy.signal import fftconvolve
99
+
100
+ if data.ndim > 2:
101
+ raise ValueError(
102
+ f"Data must have one or two dimensions. Got:"
103
+ f" {data.ndim} dimensions."
104
+ )
105
+ if data.ndim == 1:
106
+ data = np.expand_dims(data, axis=0)
107
+
108
+ if self.num_channels == -1:
109
+ self.num_channels = data.shape[0]
110
+ self.filters = np.tile(
111
+ self.filter_bank[None, :, :], (self.num_channels, 1, 1)
112
+ )
113
+
114
+ data_tiled = np.tile(data[:, None, :], (1, self.num_filters, 1))
115
+
116
+ filtered = fftconvolve(data_tiled, self.filters, axes=2, mode="same")
117
+
118
+ # ensure here that the output dimension matches the input dimension
119
+ if data.shape[1] != filtered.shape[-1]:
120
+ # select the middle part of the filtered data
121
+ middle_index = filtered.shape[-1] // 2
122
+ filtered = filtered[
123
+ :,
124
+ :,
125
+ middle_index - data.shape[1] // 2 : middle_index + data.shape[1] // 2,
126
+ ]
127
+
128
+ return filtered
@@ -0,0 +1,93 @@
1
+ import numpy as np
2
+ from typing import cast
3
+
4
+ from py_neuromodulation.utils.types import NMPreprocessor
5
+ from py_neuromodulation import logger
6
+
7
+
8
class NotchFilter(NMPreprocessor):
    """Band-stop FIR preprocessor for line noise and its harmonics.

    The filter coefficients are designed once in ``__init__`` with
    ``mne.filter.create_filter`` and applied to each data batch in
    :meth:`process`.

    Parameters
    ----------
    sfreq : float
        Sampling frequency.
    line_noise : float | None, optional
        Line noise frequency. Used to derive ``freqs`` (the multiples of
        ``line_noise`` below ``sfreq / 2``) when ``freqs`` is not given.
    freqs : np.ndarray | None, optional
        Frequencies to notch out. Takes precedence over ``line_noise``.
    notch_widths : int | np.ndarray | None, optional
        Width of each stop band: a scalar, one value per frequency, or None
        for ``freqs / 200``, by default 3
    trans_bandwidth : float, optional
        Width of the transition band, by default 6.8

    Raises
    ------
    ValueError
        If neither ``line_noise`` nor ``freqs`` is given, if a notch width is
        negative, or if ``notch_widths`` does not match ``freqs`` in length.
    """

    def __init__(
        self,
        sfreq: float,
        line_noise: float | None = None,
        freqs: np.ndarray | None = None,
        notch_widths: int | np.ndarray | None = 3,
        trans_bandwidth: float = 6.8,
    ) -> None:
        from mne.filter import create_filter

        if line_noise is None and freqs is None:
            raise ValueError(
                "Either line_noise or freqs must be defined if notch_filter is"
                " activated."
            )

        if freqs is None:
            freqs = np.arange(line_noise, sfreq / 2, line_noise, dtype=int)
        else:
            # Accept scalars and plain sequences as well as arrays
            freqs = np.atleast_1d(freqs)

        # Drop a last frequency that lies at or above the Nyquist frequency
        if freqs.size > 0 and freqs[-1] >= sfreq / 2:
            freqs = freqs[:-1]

        # Code is copied from filter.py notch_filter
        if freqs.size == 0:
            # Nothing to filter: process() returns the data unchanged
            self.filter_bank = None
            logger.warning(
                "WARNING: notch_filter is activated but data is not being"
                " filtered. This may be due to a low sampling frequency or"
                " incorrect specifications. Make sure your settings are"
                f" correct. Got: {sfreq = }, {line_noise = }, {freqs = }."
            )
            return

        filter_length = int(sfreq - 1)
        if notch_widths is None:
            notch_widths = freqs / 200.0
        elif np.any(np.asarray(notch_widths) < 0):
            raise ValueError("notch_widths must be >= 0")
        else:
            notch_widths = np.atleast_1d(notch_widths)
            if len(notch_widths) == 1:
                # Broadcast a single width to every frequency
                notch_widths = notch_widths[0] * np.ones_like(freqs)
            elif len(notch_widths) != len(freqs):
                raise ValueError(
                    "notch_widths must be None, scalar, or the same length as freqs"
                )
        notch_widths = cast(np.ndarray, notch_widths)  # For MyPy only, no runtime cost

        # Speed this up by computing the fourier coefficients once
        tb_half = trans_bandwidth / 2.0
        lows = [freq - nw / 2.0 - tb_half for freq, nw in zip(freqs, notch_widths)]
        highs = [freq + nw / 2.0 + tb_half for freq, nw in zip(freqs, notch_widths)]

        # l_freq > h_freq requests a band-stop design from create_filter,
        # hence the upper edges go to l_freq and the lower edges to h_freq.
        self.filter_bank = create_filter(
            data=None,
            sfreq=sfreq,
            l_freq=highs,
            h_freq=lows,
            filter_length=filter_length,  # type: ignore
            l_trans_bandwidth=tb_half,  # type: ignore
            h_trans_bandwidth=tb_half,  # type: ignore
            method="fir",
            iir_params=None,
            phase="zero",
            fir_window="hamming",
            fir_design="firwin",
            verbose=False,
        )

    def process(self, data: np.ndarray) -> np.ndarray:
        """Apply the notch filter to ``data``.

        Returns ``data`` unchanged when no filter was designed (no valid
        notch frequency below the Nyquist frequency).
        """
        if self.filter_bank is None:
            return data

        # NOTE(review): _overlap_add_filter is a private MNE function, so its
        # signature may differ between MNE versions - confirm on upgrades.
        from mne.filter import _overlap_add_filter

        return _overlap_add_filter(
            x=data,
            h=self.filter_bank,
            n_fft=None,
            phase="zero",
            picks=None,
            n_jobs=1,
            copy=True,
            pad="reflect_limited",
        )
@@ -1,40 +1,40 @@
1
- x y z
2
- -13.1 -67.7 69.1
3
- -35.5 -60.0 66.0
4
- -48.3 -55.1 58.2
5
- -60.0 -51.8 48.0
6
- -16.9 -51.6 78.0
7
- -34.8 -49.3 71.7
8
- -67.5 -47.1 31.0
9
- -46.1 -43.7 61.1
10
- -59.8 -39.6 53.3
11
- -14.2 -39.1 81.1
12
- -28.3 -31.2 76.0
13
- -42.3 -30.7 70.2
14
- -67.6 -30.1 41.2
15
- -50.5 -24.4 64.4
16
- -14.6 -22.7 80.2
17
- -60.9 -18.7 50.9
18
- -31.6 -16.9 75.2
19
- -5.1 -12.6 77.3
20
- -65.6 -10.8 37.8
21
- -41.8 -10.2 67.0
22
- -55.1 -4.01 53.2
23
- -22.7 1.2 72.0
24
- -5.8 2.8 74.8
25
- -49.2 3.7 54.7
26
- -34.5 3.9 66.5
27
- -61.55 6.2 35.9
28
- -63.6 8.3 25.7
29
- -40.4 11.8 60.7
30
- -48.7 14.5 50.5
31
- -21.8 16.0 68.9
32
- -58.2 18.2 27.3
33
- -7.0 18.4 70.3
34
- -36.3 19.9 59.6
35
- -48.1 24.6 44.0
36
- -56.8 28.52 20.8
37
- -7.3 33.8 61.7
38
- -22.2 35.0 57.2
39
- -36.8 35.4 47.0
40
- -46.8 35.6 36.0
1
+ x y z
2
+ -13.1 -67.7 69.1
3
+ -35.5 -60.0 66.0
4
+ -48.3 -55.1 58.2
5
+ -60.0 -51.8 48.0
6
+ -16.9 -51.6 78.0
7
+ -34.8 -49.3 71.7
8
+ -67.5 -47.1 31.0
9
+ -46.1 -43.7 61.1
10
+ -59.8 -39.6 53.3
11
+ -14.2 -39.1 81.1
12
+ -28.3 -31.2 76.0
13
+ -42.3 -30.7 70.2
14
+ -67.6 -30.1 41.2
15
+ -50.5 -24.4 64.4
16
+ -14.6 -22.7 80.2
17
+ -60.9 -18.7 50.9
18
+ -31.6 -16.9 75.2
19
+ -5.1 -12.6 77.3
20
+ -65.6 -10.8 37.8
21
+ -41.8 -10.2 67.0
22
+ -55.1 -4.01 53.2
23
+ -22.7 1.2 72.0
24
+ -5.8 2.8 74.8
25
+ -49.2 3.7 54.7
26
+ -34.5 3.9 66.5
27
+ -61.55 6.2 35.9
28
+ -63.6 8.3 25.7
29
+ -40.4 11.8 60.7
30
+ -48.7 14.5 50.5
31
+ -21.8 16.0 68.9
32
+ -58.2 18.2 27.3
33
+ -7.0 18.4 70.3
34
+ -36.3 19.9 59.6
35
+ -48.1 24.6 44.0
36
+ -56.8 28.52 20.8
37
+ -7.3 33.8 61.7
38
+ -22.2 35.0 57.2
39
+ -36.8 35.4 47.0
40
+ -46.8 35.6 36.0
@@ -0,0 +1,10 @@
1
+ from .artifacts import PARRMArtifactRejection
2
+ from .data_preprocessor import DataPreprocessor
3
+ from .projection import Projection, ProjectionSettings
4
+ from .normalization import FeatureNormalizer, RawNormalizer, NormalizationSettings
5
+ from .resample import Resampler, ResamplerSettings
6
+ from .rereference import ReReferencer
7
+ from .filter_preprocessing import PreprocessingFilter, FilterSettings
8
+
9
+ # Expose Notch filter also in the processing module, as it is used as a data preprocessing step
10
+ from py_neuromodulation.filter import NotchFilter
@@ -1,25 +1,29 @@
1
- from pyparrm import PARRM
2
-
3
-
4
class PARRMArtifactRejection:
    """Wrapper around ``pyparrm.PARRM`` for artefact removal.

    Parameters
    ----------
    data
        Data handed to ``PARRM``.
    sampling_freq
        Sampling frequency of ``data``.
    artefact_freq
        Frequency of the artefact to remove.
    verbose : bool, optional
        Forwarded to ``PARRM``, by default False
    """

    def __init__(self, data, sampling_freq, artefact_freq, verbose=False):
        self.data = data
        self.sampling_freq = sampling_freq
        self.artefact_freq = artefact_freq
        self.verbose = verbose

        # Forward the caller's verbosity instead of a hard-coded False, so
        # that the ``verbose`` argument actually takes effect.
        self.parrm = PARRM(
            data=data,
            sampling_freq=sampling_freq,
            artefact_freq=artefact_freq,
            verbose=verbose,
        )

    def filter_data(self):
        """Find the artefact period, create the filter and return the filtered data."""
        self.parrm.find_period()
        self.parrm.create_filter(
            filter_direction="both",
        )
        return self.parrm.filter_data()
1
class PARRMArtifactRejection:
    """Wrapper around ``pyparrm.PARRM`` for artefact removal.

    This module enables training of a PARRM filter before computation,
    which can then be applied in real time.
    https://pyparrm.readthedocs.io/en/stable/

    Parameters
    ----------
    data
        Data handed to ``PARRM``.
    sampling_freq
        Sampling frequency of ``data``.
    artefact_freq
        Frequency of the artefact to remove.
    verbose : bool, optional
        Forwarded to ``PARRM``, by default False
    """

    def __init__(self, data, sampling_freq, artefact_freq, verbose=False):
        # Imported lazily so that pyparrm is only needed when this class is used
        from pyparrm import PARRM

        self.data = data
        self.sampling_freq = sampling_freq
        self.artefact_freq = artefact_freq
        self.verbose = verbose

        # Forward the caller's verbosity instead of a hard-coded False, so
        # that the ``verbose`` argument actually takes effect.
        self.parrm = PARRM(
            data=data,
            sampling_freq=sampling_freq,
            artefact_freq=artefact_freq,
            verbose=verbose,
        )

    def filter_data(self):
        """Find the artefact period, create the filter and return the filtered data."""
        self.parrm.find_period()
        self.parrm.create_filter(
            filter_direction="both",
        )
        return self.parrm.filter_data()
@@ -0,0 +1,77 @@
1
+ from typing import TYPE_CHECKING, Type
2
+ from py_neuromodulation.utils.types import PreprocessorName, NMPreprocessor
3
+
4
+ if TYPE_CHECKING:
5
+ import numpy as np
6
+ import pandas as pd
7
+ from py_neuromodulation.stream.settings import NMSettings
8
+
9
# Maps each preprocessing method name used in the settings to the name of the
# class in ``py_neuromodulation.processing`` that implements it.
# The insertion order is also the order in which DataPreprocessor instantiates
# and applies the enabled methods.
PREPROCESSOR_DICT: dict[PreprocessorName, str] = {
    "preprocessing_filter": "PreprocessingFilter",
    "notch_filter": "NotchFilter",
    "raw_resampling": "Resampler",
    "re_referencing": "ReReferencer",
    "raw_normalization": "RawNormalizer",
}


class DataPreprocessor:
    """Create and hold the preprocessors enabled in the settings.

    Every method listed in ``settings.preprocessing`` must be a key of
    ``PREPROCESSOR_DICT``, otherwise a ValueError is raised. Preprocessors
    are created and later applied in the order of ``PREPROCESSOR_DICT``,
    not in the order of ``settings.preprocessing``.
    """

    def __init__(
        self,
        settings: "NMSettings",
        channels: "pd.DataFrame",
        sfreq: float,
        line_noise: float | None = None,
    ) -> None:
        from importlib import import_module
        from inspect import getfullargspec

        # Reject unknown method names before anything is imported
        for preprocessing_method in settings.preprocessing:
            if preprocessing_method not in PREPROCESSOR_DICT.keys():
                raise ValueError(
                    f"Invalid preprocessing method '{preprocessing_method}'. Must be one of {PREPROCESSOR_DICT.keys()}"
                )

        # Values a preprocessor constructor may ask for by parameter name
        possible_arguments = {
            "sfreq": sfreq,
            "settings": settings,
            "channels": channels,
            "line_noise": line_noise,
        }

        # Resolve the class of every enabled method
        enabled_classes: dict[str, Type[NMPreprocessor]] = {}
        for method_name, class_name in PREPROCESSOR_DICT.items():
            if method_name in settings.preprocessing:
                enabled_classes[method_name] = getattr(
                    import_module("py_neuromodulation.processing"), class_name
                )

        # Instantiate each class with the arguments its constructor accepts
        instances: list[NMPreprocessor] = []
        for method_name, method_class in enabled_classes.items():
            args = {}
            for arg in getfullargspec(method_class).args:
                if arg in possible_arguments:
                    args[arg] = possible_arguments[arg]
            # Add the method-specific "<name>_settings" entries, if any
            args |= getattr(settings, f"{method_name}_settings", {})
            instances.append(method_class(**args))

        self.preprocessors: list[NMPreprocessor] = instances

    def process_data(self, data: "np.ndarray") -> "np.ndarray":
        """Run ``data`` through all preprocessors in order and return the result."""
        processed = data
        for step in self.preprocessors:
            processed = step.process(processed)
        return processed
@@ -0,0 +1,78 @@
1
+ import numpy as np
2
+
3
+ from pydantic import Field
4
+ from typing import TYPE_CHECKING
5
+
6
+ from py_neuromodulation.utils.types import BoolSelector, FrequencyRange, NMPreprocessor
7
+
8
+ if TYPE_CHECKING:
9
+ from py_neuromodulation import NMSettings
10
+
11
+
12
+ FILTER_SETTINGS_MAP = {
13
+ "bandstop_filter": "bandstop_filter_settings",
14
+ "bandpass_filter": "bandpass_filter_settings",
15
+ "lowpass_filter": "lowpass_filter_cutoff_hz",
16
+ "highpass_filter": "highpass_filter_cutoff_hz",
17
+ }
18
+
19
+
20
+ class FilterSettings(BoolSelector):
21
+ bandstop_filter: bool = True
22
+ bandpass_filter: bool = True
23
+ lowpass_filter: bool = True
24
+ highpass_filter: bool = True
25
+
26
+ bandstop_filter_settings: FrequencyRange = FrequencyRange(100, 160)
27
+ bandpass_filter_settings: FrequencyRange = FrequencyRange(2, 200)
28
+ lowpass_filter_cutoff_hz: float = Field(default=200)
29
+ highpass_filter_cutoff_hz: float = Field(default=3)
30
+
31
+ def get_filter_tuple(self, filter_name) -> tuple[float | None, float | None]:
32
+ filter_value = self[FILTER_SETTINGS_MAP[filter_name]]
33
+
34
+ match filter_name:
35
+ case "bandstop_filter":
36
+ return (filter_value.frequency_high_hz, filter_value.frequency_low_hz)
37
+ case "bandpass_filter":
38
+ return (filter_value.frequency_low_hz, filter_value.frequency_high_hz)
39
+ case "lowpass_filter":
40
+ return (None, filter_value)
41
+ case "highpass_filter":
42
+ return (filter_value, None)
43
+ case _:
44
+ raise ValueError(
45
+ "Filter name must be one of 'bandstop_filter', 'lowpass_filter', "
46
+ "'highpass_filter', 'bandpass_filter'"
47
+ )
48
+
49
+
50
class PreprocessingFilter(NMPreprocessor):
    """Preprocessor that applies every enabled preprocessing filter in turn.

    One single-range ``MNEFilter`` is created per filter name returned by
    ``settings.preprocessing_filter.get_enabled()``.
    """

    def __init__(self, settings: "NMSettings", sfreq: float) -> None:
        from py_neuromodulation.filter import MNEFilter

        filter_settings = settings.preprocessing_filter

        created: list[MNEFilter] = []
        for filter_name in filter_settings.get_enabled():
            created.append(
                MNEFilter(
                    f_ranges=[filter_settings.get_filter_tuple(filter_name)],  # type: ignore
                    sfreq=sfreq,
                    filter_length=sfreq - 1,
                    verbose=False,
                )
            )
        self.filters: list[MNEFilter] = created

    def process(self, data: np.ndarray) -> np.ndarray:
        """Preprocess data with the initialized list of filters.

        Args:
            data (numpy ndarray) :
                shape(n_channels, n_samples) - data to be preprocessed.

        Returns:
            preprocessed_data (numpy ndarray):
                shape(n_channels, n_samples) - preprocessed data
        """

        def drop_band_axis(array: np.ndarray) -> np.ndarray:
            # MNEFilter.filter_data returns (n_channels, n_fbands, n_samples);
            # each filter here has a single band, so keep band 0 only.
            if len(array.shape) == 2:
                return array
            return array[:, 0, :]

        for mne_filter in self.filters:
            data = mne_filter.filter_data(drop_band_axis(data))
        return drop_band_axis(data)