py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
  3. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
  4. py_neuromodulation/FieldTrip.py +589 -589
  5. py_neuromodulation/__init__.py +74 -13
  6. py_neuromodulation/_write_example_dataset_helper.py +83 -65
  7. py_neuromodulation/data/README +6 -6
  8. py_neuromodulation/data/dataset_description.json +8 -8
  9. py_neuromodulation/data/participants.json +32 -32
  10. py_neuromodulation/data/participants.tsv +2 -2
  11. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
  12. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
  13. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
  14. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
  15. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
  18. py_neuromodulation/grid_cortex.tsv +40 -40
  19. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  20. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  21. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  22. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  23. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  24. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  25. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  26. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  27. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  28. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  29. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  30. py_neuromodulation/nm_IO.py +413 -417
  31. py_neuromodulation/nm_RMAP.py +496 -531
  32. py_neuromodulation/nm_analysis.py +993 -1074
  33. py_neuromodulation/nm_artifacts.py +30 -25
  34. py_neuromodulation/nm_bispectra.py +154 -168
  35. py_neuromodulation/nm_bursts.py +292 -198
  36. py_neuromodulation/nm_coherence.py +251 -205
  37. py_neuromodulation/nm_database.py +149 -0
  38. py_neuromodulation/nm_decode.py +918 -992
  39. py_neuromodulation/nm_define_nmchannels.py +300 -302
  40. py_neuromodulation/nm_features.py +144 -116
  41. py_neuromodulation/nm_filter.py +219 -219
  42. py_neuromodulation/nm_filter_preprocessing.py +79 -91
  43. py_neuromodulation/nm_fooof.py +139 -159
  44. py_neuromodulation/nm_generator.py +45 -37
  45. py_neuromodulation/nm_hjorth_raw.py +52 -73
  46. py_neuromodulation/nm_kalmanfilter.py +71 -58
  47. py_neuromodulation/nm_linelength.py +21 -33
  48. py_neuromodulation/nm_logger.py +66 -0
  49. py_neuromodulation/nm_mne_connectivity.py +149 -112
  50. py_neuromodulation/nm_mnelsl_generator.py +90 -0
  51. py_neuromodulation/nm_mnelsl_stream.py +116 -0
  52. py_neuromodulation/nm_nolds.py +96 -93
  53. py_neuromodulation/nm_normalization.py +173 -214
  54. py_neuromodulation/nm_oscillatory.py +423 -448
  55. py_neuromodulation/nm_plots.py +585 -612
  56. py_neuromodulation/nm_preprocessing.py +83 -0
  57. py_neuromodulation/nm_projection.py +370 -394
  58. py_neuromodulation/nm_rereference.py +97 -95
  59. py_neuromodulation/nm_resample.py +59 -50
  60. py_neuromodulation/nm_run_analysis.py +325 -435
  61. py_neuromodulation/nm_settings.py +289 -68
  62. py_neuromodulation/nm_settings.yaml +244 -0
  63. py_neuromodulation/nm_sharpwaves.py +423 -401
  64. py_neuromodulation/nm_stats.py +464 -480
  65. py_neuromodulation/nm_stream.py +398 -0
  66. py_neuromodulation/nm_stream_abc.py +166 -218
  67. py_neuromodulation/nm_types.py +193 -0
  68. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +29 -26
  69. py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
  70. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -1
  71. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/licenses/LICENSE +21 -21
  72. py_neuromodulation/nm_EpochStream.py +0 -92
  73. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  74. py_neuromodulation/nm_cohortwrapper.py +0 -435
  75. py_neuromodulation/nm_eval_timing.py +0 -239
  76. py_neuromodulation/nm_features_abc.py +0 -39
  77. py_neuromodulation/nm_settings.json +0 -338
  78. py_neuromodulation/nm_stream_offline.py +0 -359
  79. py_neuromodulation/utils/_logging.py +0 -24
  80. py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
@@ -1,159 +1,139 @@
1
- import logging
2
- from typing import Iterable
3
-
4
- import numpy as np
5
- from fooof import FOOOF
6
- from scipy import fft
7
-
8
- from py_neuromodulation import nm_features_abc
9
-
10
-
11
class FooofAnalyzer(nm_features_abc.Feature):
    """Fit a FOOOF model to each channel's spectrum and store its parameters.

    Depending on ``settings["fooof"]``, aperiodic parameters (exponent,
    offset, knee frequency) and up to ``max_n_peaks`` sets of periodic
    parameters (bandwidth, center frequency, power over the aperiodic fit)
    are written into ``features_compute``. A channel whose fit fails reports
    ``None`` for every enabled feature.
    """

    def __init__(
        self, settings: dict, ch_names: Iterable[str], sfreq: float
    ) -> None:
        self.settings_fooof = settings["fooof"]
        self.sfreq = sfreq
        self.ch_names = ch_names

        self.freq_range = self.settings_fooof["freq_range_hz"]
        self.ap_mode = "knee" if self.settings_fooof["knee"] else "fixed"
        self.max_n_peaks = self.settings_fooof["max_n_peaks"]

        self.num_samples = int(
            self.settings_fooof["windowlength_ms"] * sfreq / 1000
        )

        # Frequency in Hz of each rfft bin: bin k lies at k * sfreq / num_samples.
        # The bare bin index is only a frequency in Hz for a 1 s window.
        self.f_vec = np.fft.rfftfreq(self.num_samples, d=1 / sfreq)

    @staticmethod
    def test_settings(
        s: dict,
        ch_names: Iterable[str],
        sfreq: int | float,
    ):
        """Assert that ``s["fooof"]`` is well-formed for a stream at ``sfreq``.

        ``ch_names`` is unused.

        Raises
        ------
        AssertionError
            On a wrongly typed flag, a window longer than the feature
            segment, or a fit range not below ``sfreq``.
        """
        assert isinstance(s["fooof"]["aperiodic"]["exponent"], bool)
        assert isinstance(s["fooof"]["aperiodic"]["offset"], bool)
        assert isinstance(s["fooof"]["aperiodic"]["knee"], bool)
        assert isinstance(s["fooof"]["periodic"]["center_frequency"], bool)
        assert isinstance(s["fooof"]["periodic"]["band_width"], bool)
        assert isinstance(s["fooof"]["periodic"]["height_over_ap"], bool)
        assert isinstance(s["fooof"]["knee"], bool)
        assert isinstance(s["fooof"]["windowlength_ms"], (int, float))
        assert (
            s["fooof"]["windowlength_ms"] <= s["segment_length_features_ms"]
        ), (
            "fooof windowlength_ms needs to be smaller equal than segment_length_features_ms "
            f"got windowlength_ms: {s['fooof']['windowlength_ms']} and {s['segment_length_features_ms']}"
        )

        assert (
            s["fooof"]["freq_range_hz"][0] < sfreq
            and s["fooof"]["freq_range_hz"][1] < sfreq
        ), f"fooof frequency range needs to be below sfreq, got {s['fooof']['freq_range_hz']}"

    def _get_spectrum(self, data: np.array):
        """Return the absolute-value rfft of the last ``num_samples`` samples."""
        data = data[-self.num_samples :]
        return np.abs(fft.rfft(data))

    def _fit(self, spectrum: np.ndarray):
        """Fit a fresh FOOOF model to ``spectrum``.

        Returns
        -------
        tuple
            ``(model, fit_passed)``. ``fit_passed`` is False if fitting raised
            or produced no model spectrum.
        """
        fm = FOOOF(
            aperiodic_mode=self.ap_mode,
            peak_width_limits=self.settings_fooof["peak_width_limits"],
            max_n_peaks=self.settings_fooof["max_n_peaks"],
            min_peak_height=self.settings_fooof["min_peak_height"],
            peak_threshold=self.settings_fooof["peak_threshold"],
            verbose=False,
        )
        try:
            fm.fit(self.f_vec, spectrum, self.freq_range)
        except Exception as e:
            # Best effort: log and mark this channel as failed. Never fall
            # through to a model left over from a previous channel.
            logging.critical(e, exc_info=True)
            return fm, False
        return fm, fm.fooofed_spectrum_ is not None

    def _knee_frequency(self, fm):
        """Return ``knee ** (1 / exponent)``, or None when it is undefined."""
        # "fixed" aperiodic mode fits no knee parameter.
        if self.ap_mode != "knee":
            return None
        exponent = fm.get_params("aperiodic_params", "exponent")
        if exponent == 0:
            return None
        knee = fm.get_params("aperiodic_params", "knee")
        return np.nan_to_num(knee ** (1 / exponent))

    def calc_feature(
        self,
        data: np.array,
        features_compute: dict,
    ) -> dict:
        """Add the enabled FOOOF features of every channel to ``features_compute``.

        Parameters
        ----------
        data : np.array
            Signal indexed as ``data[channel, sample]``.
        features_compute : dict
            Feature dict that is updated in place and returned.
        """
        aperiodic = self.settings_fooof["aperiodic"]
        periodic = self.settings_fooof["periodic"]
        # (settings flag, FOOOF peak column, feature-name suffix)
        peak_features = (
            ("band_width", "BW", "bw"),
            ("center_frequency", "CF", "cf"),
            ("height_over_ap", "PW", "pw"),
        )

        for ch_idx, ch_name in enumerate(self.ch_names):
            fm, fit_passed = self._fit(self._get_spectrum(data[ch_idx, :]))

            if aperiodic["exponent"]:
                features_compute[f"{ch_name}_fooof_a_exp"] = (
                    np.nan_to_num(fm.get_params("aperiodic_params", "exponent"))
                    if fit_passed
                    else None
                )

            if aperiodic["offset"]:
                features_compute[f"{ch_name}_fooof_a_offset"] = (
                    np.nan_to_num(fm.get_params("aperiodic_params", "offset"))
                    if fit_passed
                    else None
                )

            if aperiodic["knee"]:
                features_compute[f"{ch_name}_fooof_a_knee_frequency"] = (
                    self._knee_frequency(fm) if fit_passed else None
                )

            peaks = {}
            for _, col, _ in peak_features:
                params = fm.get_params("peak_params", col) if fit_passed else None
                # Wrap a bare np.float64 (and the "no fit" None) so every
                # entry can be indexed by peak number.
                if type(params) is np.float64 or params is None:
                    params = [params]
                peaks[col] = params

            for peak_idx in range(self.max_n_peaks):
                for flag, col, suffix in peak_features:
                    if periodic[flag]:
                        values = peaks[col]
                        features_compute[f"{ch_name}_fooof_p_{peak_idx}_{suffix}"] = (
                            values[peak_idx] if peak_idx < len(values) else None
                        )

        return features_compute
1
+ from collections.abc import Iterable
2
+ import numpy as np
3
+
4
+ from typing import TYPE_CHECKING
5
+ from py_neuromodulation.nm_types import NMBaseModel
6
+
7
+ from py_neuromodulation.nm_features import NMFeature
8
+ from py_neuromodulation.nm_types import BoolSelector, FrequencyRange
9
+
10
+ if TYPE_CHECKING:
11
+ from py_neuromodulation.nm_settings import NMSettings
12
+
13
+
14
class FooofAperiodicSettings(BoolSelector):
    """On/off switches for the aperiodic FOOOF features of each channel."""

    exponent: bool = True  # feature suffix "fooof_a_exp"
    offset: bool = True  # feature suffix "fooof_a_offset"
    knee: bool = True  # feature suffix "fooof_a_knee_frequency"
18
+
19
+
20
class FooofPeriodicSettings(BoolSelector):
    """On/off switches for the per-peak FOOOF features of each channel."""

    center_frequency: bool = False  # feature suffix "cf"
    band_width: bool = False  # feature suffix "bw"
    height_over_ap: bool = False  # feature suffix "pw" (peak height over aperiodic fit)
24
+
25
+
26
class FooofSettings(NMBaseModel):
    """Configuration of the FOOOF feature.

    ``windowlength_ms`` sets how many trailing samples are transformed,
    ``freq_range_hz`` is the range handed to the fit, and ``knee`` selects the
    "knee" (True) or "fixed" (False) aperiodic mode. The remaining peak
    settings configure the FOOOF model itself.
    """

    # Which features to emit.
    aperiodic: FooofAperiodicSettings = FooofAperiodicSettings()
    periodic: FooofPeriodicSettings = FooofPeriodicSettings()
    # Length of the analysed window.
    windowlength_ms: float = 800
    # FOOOF model parameters.
    peak_width_limits: FrequencyRange = FrequencyRange(0.5, 12)
    max_n_peaks: int = 3
    min_peak_height: float = 0
    peak_threshold: float = 2
    # Fit range and aperiodic mode.
    freq_range_hz: FrequencyRange = FrequencyRange(2, 40)
    knee: bool = True
36
+
37
+
38
class FooofAnalyzer(NMFeature):
    """Spectral parameterization (FOOOF) features for every channel.

    On each call, the trailing ``windowlength_ms`` of every channel is
    rfft-transformed. All channels are then fitted at once with a
    ``fooof.FOOOFGroup``, and the enabled aperiodic and periodic parameters
    are returned as features.
    """

    # Settings field name -> suffix used in the feature name.
    feat_name_map = {
        "exponent": "exp",
        "offset": "offset",
        "knee": "knee_frequency",
        "center_frequency": "cf",
        "band_width": "bw",
        "height_over_ap": "pw",
    }

    # Periodic settings field name -> FOOOF ``peak_params`` column label.
    peak_col_map = {
        "center_frequency": "CF",
        "band_width": "BW",
        "height_over_ap": "PW",
    }

    def __init__(
        self, settings: "NMSettings", ch_names: Iterable[str], sfreq: float
    ) -> None:
        """Validate the settings and build the FOOOFGroup model.

        Raises
        ------
        AssertionError
            If the FOOOF window is longer than the feature segment, or the
            fit range is not below ``sfreq``.
        """
        self.settings = settings.fooof
        self.sfreq = sfreq
        self.ch_names = ch_names

        self.ap_mode = "knee" if self.settings.knee else "fixed"

        self.num_samples = int(self.settings.windowlength_ms * sfreq / 1000)

        # Frequency in Hz of each rfft bin: bin k lies at k * sfreq / num_samples.
        # The bare bin index is only a frequency in Hz for a 1 s window.
        self.f_vec = np.fft.rfftfreq(self.num_samples, d=1 / sfreq)

        assert (
            settings.fooof.windowlength_ms <= settings.segment_length_features_ms
        ), f"fooof windowlength_ms ({settings.fooof.windowlength_ms}) needs to be smaller equal than segment_length_features_ms ({settings.segment_length_features_ms})."

        assert (
            settings.fooof.freq_range_hz[0] < sfreq
            and settings.fooof.freq_range_hz[1] < sfreq
        ), f"fooof frequency range needs to be below sfreq, got {settings.fooof.freq_range_hz}"

        from fooof import FOOOFGroup

        self.fm = FOOOFGroup(
            aperiodic_mode=self.ap_mode,
            peak_width_limits=tuple(self.settings.peak_width_limits),
            max_n_peaks=self.settings.max_n_peaks,
            min_peak_height=self.settings.min_peak_height,
            peak_threshold=self.settings.peak_threshold,
            verbose=False,
        )

    @staticmethod
    def _channel_peaks(params: np.ndarray, ch_idx: int) -> np.ndarray:
        """Select one channel's values from FOOOFGroup's stacked peak output.

        ``FOOOFGroup.get_params("peak_params", col)`` returns every peak of
        every channel as rows ``[value, channel_index]``.
        """
        params = np.asarray(params).reshape(-1, 2)
        return params[params[:, -1] == ch_idx, 0]

    @staticmethod
    def _aperiodic_value(feat: str, aperiodic_params: dict, ch_idx: int, exp):
        """Return one aperiodic feature of a successfully fitted channel."""
        if feat != "knee":
            return np.nan_to_num(aperiodic_params[feat][ch_idx])
        # The knee frequency knee ** (1 / exponent) is undefined without a
        # knee term ("fixed" mode) or for a zero exponent.
        if "knee" not in aperiodic_params or exp == 0:
            return None
        knee = aperiodic_params["knee"][ch_idx]
        # A negative knee parameter has no real knee frequency; report 0.
        return np.nan_to_num(0 if knee < 0 else knee ** (1 / exp))

    def calc_feature(self, data: np.ndarray) -> dict:
        """Fit FOOOF to the last ``num_samples`` samples of every channel.

        Parameters
        ----------
        data : np.ndarray
            Signal indexed as ``data[channel, sample]``.

        Returns
        -------
        dict
            ``"{ch}_fooof_a_{exp|offset|knee_frequency}"`` and
            ``"{ch}_fooof_p_{peak_idx}_{cf|bw|pw}"`` for the enabled settings.
            A value is ``None`` when the channel's fit failed, when the peak
            does not exist, or when the value is undefined.

        Raises
        ------
        RuntimeError
            If FOOOF produced no model at all.
        """
        from scipy.fft import rfft

        # NOTE(review): FOOOF expects a power spectrum; this passes the rfft
        # magnitude - confirm that is intended.
        spectra = np.abs(rfft(data[:, -self.num_samples :]))  # type: ignore

        self.fm.fit(self.f_vec, spectra, self.settings.freq_range_hz)

        if not self.fm.has_model or self.fm.null_inds_ is None:
            raise RuntimeError("FOOOF failed to fit model to data.")

        failed_fits = set(self.fm.null_inds_)
        aperiodic_feats = list(self.settings.aperiodic.get_enabled())
        periodic_feats = list(self.settings.periodic.get_enabled())

        # Fetch every parameter once for all channels.
        exponents = self.fm.get_params("aperiodic_params", "exponent")
        aperiodic_params = {
            feat: self.fm.get_params("aperiodic_params", feat)
            for feat in aperiodic_feats
            # "fixed" mode has no knee column to fetch.
            if feat != "knee" or self.ap_mode == "knee"
        }
        peak_params = {
            feat: self.fm.get_params("peak_params", self.peak_col_map[feat])
            for feat in periodic_feats
        }

        feature_results = {}
        for ch_idx, ch_name in enumerate(self.ch_names):
            fit_passed = ch_idx not in failed_fits

            for feat in aperiodic_feats:
                f_name = f"{ch_name}_fooof_a_{self.feat_name_map[feat]}"
                feature_results[f_name] = (
                    self._aperiodic_value(
                        feat, aperiodic_params, ch_idx, exponents[ch_idx]
                    )
                    if fit_passed
                    else None
                )

            ch_peaks = {
                feat: (
                    self._channel_peaks(params, ch_idx) if fit_passed else np.empty(0)
                )
                for feat, params in peak_params.items()
            }

            for peak_idx in range(self.settings.max_n_peaks):
                for feat in periodic_feats:
                    f_name = f"{ch_name}_fooof_p_{peak_idx}_{self.feat_name_map[feat]}"
                    values = ch_peaks[feat]
                    feature_results[f_name] = (
                        values[peak_idx] if peak_idx < len(values) else None
                    )

        return feature_results
@@ -1,37 +1,45 @@
1
- from typing import Iterator
2
-
3
- import numpy as np
4
-
5
-
6
def raw_data_generator(
    data: np.ndarray,
    settings: dict,
    sfreq: int,
) -> Iterator[np.ndarray]:
    """Simulate online acquisition by yielding successive windows of ``data``.

    A window is emitted every ``sfreq / settings["sampling_rate_features_hz"]``
    samples. Each window covers the preceding
    ``settings["segment_length_features_ms"]`` of signal.

    Arguments
    ---------
    data (np.ndarray): shape (channels, time)
    settings (dict): must contain "sampling_rate_features_hz" and
        "segment_length_features_ms"
    sfreq (int): sampling frequency of ``data`` in Hz

    Yields
    ------
    np.ndarray: the window ``data[:, start:stop]``
    """
    step = sfreq / settings["sampling_rate_features_hz"]
    window = settings["segment_length_features_ms"] / 1000 * sfreq

    n_emitted = 0
    # Run up to and including data.shape[1] so the last sample can be the
    # end of a window.
    for stop in range(data.shape[1] + 1):
        elapsed = stop - window
        if elapsed < step * n_emitted:
            continue
        n_emitted += 1
        yield data[:, np.floor(elapsed).astype(int) : stop]
1
+ from collections.abc import Iterator
2
+ from typing import TYPE_CHECKING
3
+ import numpy as np
4
+
5
+ if TYPE_CHECKING:
6
+ from py_neuromodulation.nm_settings import NMSettings
7
+
8
+
9
def raw_data_generator(
    data: np.ndarray,
    settings: "NMSettings",
    sfreq: float,
) -> Iterator[tuple[np.ndarray, np.ndarray]]:
    """Mimic online data acquisition by yielding successive data windows.

    A window is emitted every ``sfreq / settings.sampling_rate_features_hz``
    samples. Each window covers the preceding
    ``settings.segment_length_features_ms`` of signal.

    Arguments
    ---------
    data (np.ndarray): shape (channels, time)
    settings (nm_settings.NMSettings): only ``sampling_rate_features_hz`` and
        ``segment_length_features_ms`` are read
    sfreq (float): sampling frequency of ``data`` in Hz

    Yields
    ------
    np.ndarray: 1D array with the time stamp (s) of every sample in the window
    np.ndarray: the window ``data[:, start:stop]``
    """
    sfreq_new = settings.sampling_rate_features_hz
    offset_time = settings.segment_length_features_ms
    offset_start = offset_time / 1000 * sfreq

    ratio_samples_features = sfreq / sfreq_new

    ratio_counter = 0
    # shape + 1 guarantees that the last sample is also included
    for cnt in range(data.shape[1] + 1):
        if (cnt - offset_start) >= ratio_samples_features * ratio_counter:
            ratio_counter += 1

            # Derive the stamps from the same integer start index as the
            # slice, so stamp i is the time of sample start + i. A fractional
            # start would shift every stamp by part of a sample when the
            # segment is not a whole number of samples long.
            start = int(np.floor(cnt - offset_start))
            yield (
                np.arange(start, cnt) / sfreq,
                data[:, start:cnt],
            )
@@ -1,73 +1,52 @@
1
- import enum
2
- import numpy as np
3
- from typing import Iterable
4
-
5
- from py_neuromodulation import nm_features_abc
6
-
7
-
8
class Hjorth(nm_features_abc.Feature):
    """Per-channel Hjorth activity, mobility and complexity.

    With x a channel and x', x'' its first and second differences:
    activity = var(x), mobility = sqrt(var(x') / var(x)),
    complexity = sqrt(var(x'') / var(x')) / mobility.
    Intermediate and final values are passed through np.nan_to_num.
    """

    def __init__(
        self, settings: dict, ch_names: Iterable[str], sfreq: float
    ) -> None:
        self.s = settings
        self.ch_names = ch_names

    @staticmethod
    def test_settings(
        settings: dict,
        ch_names: Iterable[str],
        sfreq: int | float,
    ):
        """Hjorth has no settings to validate."""

    def calc_feature(self, data: np.array, features_compute: dict) -> dict:
        """Add the three Hjorth parameters of every channel to ``features_compute``."""
        for idx, name in enumerate(self.ch_names):
            signal = data[idx, :]
            first_diff = np.diff(signal)
            second_diff = np.diff(first_diff)

            signal_var = np.var(signal)
            first_var = np.nan_to_num(np.var(first_diff))
            second_var = np.nan_to_num(np.var(second_diff))

            mobility = np.nan_to_num(np.sqrt(first_var / signal_var))
            first_mobility = np.nan_to_num(np.sqrt(second_var / first_var))

            features_compute["_".join([name, "RawHjorth_Activity"])] = (
                np.nan_to_num(signal_var)
            )
            features_compute["_".join([name, "RawHjorth_Mobility"])] = mobility
            features_compute["_".join([name, "RawHjorth_Complexity"])] = (
                np.nan_to_num(first_mobility / mobility)
            )

        return features_compute
48
-
49
-
50
class Raw(nm_features_abc.Feature):
    """Report the most recent sample of every channel as ``{ch}_raw``."""

    def __init__(
        self, settings: dict, ch_names: Iterable[str], sfreq: float
    ) -> None:
        self.ch_names = ch_names

    def calc_feature(
        self,
        data: np.array,
        features_compute: dict,
    ) -> dict:
        """Store ``data[ch, -1]`` for every channel and return the dict."""
        for idx, name in enumerate(self.ch_names):
            features_compute["_".join((name, "raw"))] = data[idx, -1]
        return features_compute

    @staticmethod
    def test_settings(
        settings: dict,
        ch_names: Iterable[str],
        sfreq: int | float,
    ):
        """Raw has no settings to validate."""
1
+ """
2
+ Reference: B Hjorth
3
+ EEG analysis based on time domain properties
4
+ Electroencephalogr Clin Neurophysiol. 1970 Sep;29(3):306-10.
5
+ DOI: 10.1016/0013-4694(70)90143-4
6
+ """
7
+
8
+ import numpy as np
9
+ from collections.abc import Iterable
10
+
11
+ from py_neuromodulation.nm_features import NMFeature
12
+ from py_neuromodulation.nm_settings import NMSettings
13
+
14
+
15
class Hjorth(NMFeature):
    """Hjorth activity, mobility and complexity for every channel.

    With x a channel and x', x'' its first and second differences along the
    last axis: activity = var(x), mobility = sqrt(var(x') / var(x)),
    complexity = sqrt(var(x'') / var(x')) / mobility. Results are passed
    through np.nan_to_num.
    """

    def __init__(
        self, settings: NMSettings, ch_names: Iterable[str], sfreq: float
    ) -> None:
        self.ch_names = ch_names

    def calc_feature(self, data: np.ndarray) -> dict:
        """Return ``{ch}_RawHjorth_{Activity|Mobility|Complexity}`` per channel."""
        first_diff = np.diff(data, axis=-1)
        second_diff = np.diff(first_diff, axis=-1)

        signal_var = np.var(data, axis=-1)
        first_var = np.var(first_diff, axis=-1)
        second_var = np.var(second_diff, axis=-1)

        mobility = np.nan_to_num(np.sqrt(first_var / signal_var))
        complexity = np.nan_to_num(np.sqrt(second_var / first_var) / mobility)
        activity = np.nan_to_num(signal_var)

        results = {}
        for idx, name in enumerate(self.ch_names):
            results.update(
                {
                    f"{name}_RawHjorth_Activity": activity[idx],
                    f"{name}_RawHjorth_Mobility": mobility[idx],
                    f"{name}_RawHjorth_Complexity": complexity[idx],
                }
            )
        return results
40
+
41
+
42
class Raw(NMFeature):
    """Expose the last sample of each channel as the ``{ch}_raw`` feature."""

    def __init__(self, settings: dict, ch_names: Iterable[str], sfreq: float) -> None:
        self.ch_names = ch_names

    def calc_feature(self, data: np.ndarray) -> dict:
        """Map ``"{ch}_raw"`` to ``data[ch, -1]`` for every channel."""
        return {
            "_".join([ch_name, "raw"]): data[ch_idx, -1]
            for ch_idx, ch_name in enumerate(self.ch_names)
        }
@@ -1,58 +1,71 @@
1
- from numpy import array, cov
2
- from typing import Iterable
3
-
4
- from filterpy.kalman import KalmanFilter
5
-
6
-
7
def define_KF(Tp, sigma_w, sigma_v):
    """Build a 2-state Kalman filter for a white-noise-acceleration model.

    The state is (signal, first derivative), and only the signal is observed.
    Background: DOI 10.1109/TBME.2009.2038990.
    Implementation: https://filterpy.readthedocs.io/en/latest/kalman/KalmanFilter.html#r64ca38088676-2

    Parameters
    ----------
    Tp : float
        prediction interval
    sigma_w : float
        process noise
    sigma_v : float
        measurement noise

    Returns
    -------
    filterpy.KalmanFilter
        initialized KalmanFilter object
    """
    kf = KalmanFilter(dim_x=2, dim_z=1)
    kf.x = array([0, 1])  # initial (signal, derivative)
    kf.F = array([[1, Tp], [0, 1]])  # signal advances by Tp * derivative
    kf.H = array([[1, 0]])  # only the signal is measured
    kf.R = sigma_v

    var_w = sigma_w**2
    kf.Q = array(
        [
            [var_w * (Tp**3) / 3, var_w * (Tp**2) / 2],
            [var_w * (Tp**2) / 2, var_w * Tp],
        ]
    )
    # NOTE(review): np.cov of the rows [1, 0] and [0, 1] is
    # [[0.5, -0.5], [-0.5, 0.5]] (singular), not the identity - confirm intent.
    kf.P = cov([[1, 0], [0, 1]])
    return kf
36
-
37
- def test_kf_settings(
38
- s: dict,
39
- ch_names: Iterable[str],
40
- sfreq: int | float,
41
- ):
42
- assert isinstance(s["kalman_filter_settings"]["Tp"], (float, int))
43
- assert isinstance(s["kalman_filter_settings"]["sigma_w"], (float, int))
44
- assert isinstance(s["kalman_filter_settings"]["sigma_v"], (float, int))
45
- assert s["kalman_filter_settings"][
46
- "frequency_bands"
47
- ], "No frequency bands specified for Kalman filter."
48
- assert isinstance(
49
- s["kalman_filter_settings"]["frequency_bands"], list
50
- ), "Frequency bands for Kalman filter must be specified as a list."
51
- assert (
52
- item
53
- in s["frequency_ranges_hz"].values()
54
- for item in s["kalman_filter_settings"]["frequency_bands"]
55
- ), (
56
- "Frequency bands for Kalman filter must also be specified in "
57
- "bandpass_filter_settings."
58
- )
1
+ from numpy import array, cov
2
+ from py_neuromodulation.nm_types import NMBaseModel
3
+ from typing import TYPE_CHECKING
4
+
5
+ from pydantic import field_validator
6
+
7
+ if TYPE_CHECKING:
8
+ from py_neuromodulation.nm_settings import NMSettings
9
+
10
+
11
class KalmanSettings(NMBaseModel):
    """Settings for the Kalman filter.

    ``Tp``, ``sigma_w`` and ``sigma_v`` share their names with the
    ``define_KF`` parameters (prediction interval, process noise, measurement
    noise); presumably they are passed through - confirm at the call site.
    ``frequency_bands`` are band names that must also be keys of
    ``settings.frequency_ranges_hz`` (see ``validate_fbands``).
    """

    Tp: float = 0.1
    sigma_w: float = 0.7
    sigma_v: float = 1.0
    frequency_bands: list[str] = [
        "theta",
        "alpha",
        "low_beta",
        "high_beta",
        "low_gamma",
        "high_gamma",
        "HFA",
    ]

    @field_validator("frequency_bands")
    @classmethod
    def fbands_spaces_to_underscores(cls, frequency_bands):
        """Normalise band names by replacing spaces with underscores."""
        return [f.replace(" ", "_") for f in frequency_bands]

    def validate_fbands(self, settings: "NMSettings") -> None:
        """Check that every Kalman band is defined in ``settings.frequency_ranges_hz``.

        Raises
        ------
        AssertionError
            Listing the bands that are missing from ``frequency_ranges_hz``.
        """
        missing = [
            band
            for band in self.frequency_bands
            if band not in settings.frequency_ranges_hz
        ]
        assert not missing, (
            "Frequency bands for Kalman filter must also be specified in "
            f"frequency_ranges_hz, missing: {missing}"
        )
36
+
37
+
38
def define_KF(Tp, sigma_w, sigma_v):
    """Build a 2-state Kalman filter for a white-noise-acceleration model.

    The state is (signal, first derivative), and only the signal is observed.
    Background: DOI 10.1109/TBME.2009.2038990.
    Implementation: https://filterpy.readthedocs.io/en/latest/kalman/KalmanFilter.html#r64ca38088676-2

    Parameters
    ----------
    Tp : float
        prediction interval
    sigma_w : float
        process noise
    sigma_v : float
        measurement noise

    Returns
    -------
    filterpy.KalmanFilter
        initialized KalmanFilter object
    """
    from filterpy.kalman import KalmanFilter

    kf = KalmanFilter(dim_x=2, dim_z=1)
    kf.x = array([0, 1])  # initial (signal, derivative)
    kf.F = array([[1, Tp], [0, 1]])  # signal advances by Tp * derivative
    kf.H = array([[1, 0]])  # only the signal is measured
    kf.R = sigma_v

    var_w = sigma_w**2
    kf.Q = array(
        [
            [var_w * (Tp**3) / 3, var_w * (Tp**2) / 2],
            [var_w * (Tp**2) / 2, var_w * Tp],
        ]
    )
    # NOTE(review): np.cov of the rows [1, 0] and [0, 1] is
    # [[0.5, -0.5], [-0.5, 0.5]] (singular), not the identity - confirm intent.
    kf.P = cov([[1, 0], [0, 1]])
    return kf