py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
- py_neuromodulation/__init__.py +80 -13
- py_neuromodulation/{nm_RMAP.py → analysis/RMAP.py} +496 -531
- py_neuromodulation/analysis/__init__.py +4 -0
- py_neuromodulation/{nm_decode.py → analysis/decode.py} +918 -992
- py_neuromodulation/{nm_analysis.py → analysis/feature_reader.py} +994 -1074
- py_neuromodulation/{nm_plots.py → analysis/plots.py} +627 -612
- py_neuromodulation/{nm_stats.py → analysis/stats.py} +458 -480
- py_neuromodulation/data/README +6 -6
- py_neuromodulation/data/dataset_description.json +8 -8
- py_neuromodulation/data/participants.json +32 -32
- py_neuromodulation/data/participants.tsv +2 -2
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
- py_neuromodulation/default_settings.yaml +241 -0
- py_neuromodulation/features/__init__.py +31 -0
- py_neuromodulation/features/bandpower.py +165 -0
- py_neuromodulation/features/bispectra.py +157 -0
- py_neuromodulation/features/bursts.py +297 -0
- py_neuromodulation/features/coherence.py +255 -0
- py_neuromodulation/features/feature_processor.py +121 -0
- py_neuromodulation/features/fooof.py +142 -0
- py_neuromodulation/features/hjorth_raw.py +57 -0
- py_neuromodulation/features/linelength.py +21 -0
- py_neuromodulation/features/mne_connectivity.py +148 -0
- py_neuromodulation/features/nolds.py +94 -0
- py_neuromodulation/features/oscillatory.py +249 -0
- py_neuromodulation/features/sharpwaves.py +432 -0
- py_neuromodulation/filter/__init__.py +3 -0
- py_neuromodulation/filter/kalman_filter.py +67 -0
- py_neuromodulation/filter/kalman_filter_external.py +1890 -0
- py_neuromodulation/filter/mne_filter.py +128 -0
- py_neuromodulation/filter/notch_filter.py +93 -0
- py_neuromodulation/grid_cortex.tsv +40 -40
- py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
- py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
- py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
- py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/processing/__init__.py +10 -0
- py_neuromodulation/{nm_artifacts.py → processing/artifacts.py} +29 -25
- py_neuromodulation/processing/data_preprocessor.py +77 -0
- py_neuromodulation/processing/filter_preprocessing.py +78 -0
- py_neuromodulation/processing/normalization.py +175 -0
- py_neuromodulation/{nm_projection.py → processing/projection.py} +370 -394
- py_neuromodulation/{nm_rereference.py → processing/rereference.py} +97 -95
- py_neuromodulation/{nm_resample.py → processing/resample.py} +56 -50
- py_neuromodulation/stream/__init__.py +3 -0
- py_neuromodulation/stream/data_processor.py +325 -0
- py_neuromodulation/stream/generator.py +53 -0
- py_neuromodulation/stream/mnelsl_player.py +94 -0
- py_neuromodulation/stream/mnelsl_stream.py +120 -0
- py_neuromodulation/stream/settings.py +292 -0
- py_neuromodulation/stream/stream.py +427 -0
- py_neuromodulation/utils/__init__.py +2 -0
- py_neuromodulation/{nm_define_nmchannels.py → utils/channels.py} +305 -302
- py_neuromodulation/utils/database.py +149 -0
- py_neuromodulation/utils/io.py +378 -0
- py_neuromodulation/utils/keyboard.py +52 -0
- py_neuromodulation/utils/logging.py +66 -0
- py_neuromodulation/utils/types.py +251 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/METADATA +28 -33
- py_neuromodulation-0.0.6.dist-info/RECORD +89 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/WHEEL +1 -1
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/licenses/LICENSE +21 -21
- py_neuromodulation/FieldTrip.py +0 -589
- py_neuromodulation/_write_example_dataset_helper.py +0 -65
- py_neuromodulation/nm_EpochStream.py +0 -92
- py_neuromodulation/nm_IO.py +0 -417
- py_neuromodulation/nm_across_patient_decoding.py +0 -927
- py_neuromodulation/nm_bispectra.py +0 -168
- py_neuromodulation/nm_bursts.py +0 -198
- py_neuromodulation/nm_coherence.py +0 -205
- py_neuromodulation/nm_cohortwrapper.py +0 -435
- py_neuromodulation/nm_eval_timing.py +0 -239
- py_neuromodulation/nm_features.py +0 -116
- py_neuromodulation/nm_features_abc.py +0 -39
- py_neuromodulation/nm_filter.py +0 -219
- py_neuromodulation/nm_filter_preprocessing.py +0 -91
- py_neuromodulation/nm_fooof.py +0 -159
- py_neuromodulation/nm_generator.py +0 -37
- py_neuromodulation/nm_hjorth_raw.py +0 -73
- py_neuromodulation/nm_kalmanfilter.py +0 -58
- py_neuromodulation/nm_linelength.py +0 -33
- py_neuromodulation/nm_mne_connectivity.py +0 -112
- py_neuromodulation/nm_nolds.py +0 -93
- py_neuromodulation/nm_normalization.py +0 -214
- py_neuromodulation/nm_oscillatory.py +0 -448
- py_neuromodulation/nm_run_analysis.py +0 -435
- py_neuromodulation/nm_settings.json +0 -338
- py_neuromodulation/nm_settings.py +0 -68
- py_neuromodulation/nm_sharpwaves.py +0 -401
- py_neuromodulation/nm_stream_abc.py +0 -218
- py_neuromodulation/nm_stream_offline.py +0 -359
- py_neuromodulation/utils/_logging.py +0 -24
- py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
from pydantic import field_validator
|
|
5
|
+
|
|
6
|
+
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from py_neuromodulation.stream.settings import NMSettings
|
|
10
|
+
from py_neuromodulation.filter import KalmanSettings
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BandpowerFeatures(BoolSelector):
    """Selects which Hjorth statistics ``BandPower`` computes per band.

    * ``activity``   - variance of the band-limited signal (enabled by default)
    * ``mobility``   - sqrt(var(first difference) / var(signal))
    * ``complexity`` - mobility of the first difference divided by mobility
    """

    # Field order matters: it is the order in which enabled statistics are emitted.
    activity: bool = True
    mobility: bool = False
    complexity: bool = False
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BandPowerSettings(NMBaseModel):
    """Settings of the bandpass-filter (band power) feature.

    Attributes:
        segment_lengths_ms: Per frequency band, the length in milliseconds of the
            trailing window of filtered data on which the statistics are computed.
            Every band of ``settings.frequency_ranges_hz`` needs an entry here.
        bandpower_features: Statistics to compute; at least one must be enabled.
        log_transform: Apply ``log10`` to the ``activity`` statistic.
        kalman_filter: Smooth the ``activity`` statistic with a Kalman filter.
    """

    segment_lengths_ms: dict[str, int] = {
        "theta": 1000,
        "alpha": 500,
        "low beta": 333,
        "high beta": 333,
        "low gamma": 100,
        "high gamma": 100,
        "HFA": 100,
    }
    bandpower_features: BandpowerFeatures = BandpowerFeatures()
    log_transform: bool = True
    kalman_filter: bool = False

    @field_validator("bandpower_features")
    @classmethod
    def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
        """Reject a configuration in which every band power statistic is disabled.

        Raises:
            ValueError: if no statistic is enabled (pydantic reports it as a
                ValidationError).
        """
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
        # which would silently accept an empty selection.
        if len(bandpower_features.get_enabled()) == 0:
            raise ValueError("Set at least one bandpower_feature to True.")

        return bandpower_features

    def validate_fbands(self, settings: "NMSettings") -> None:
        """Check the per-band segment lengths against the global settings.

        Args:
            settings: Global settings providing ``segment_length_features_ms`` and
                ``frequency_ranges_hz``.

        Raises:
            AssertionError: if a band's segment length exceeds
                ``settings.segment_length_features_ms``, or if a band of
                ``settings.frequency_ranges_hz`` has no segment length defined.
        """
        # AssertionError is raised explicitly (not via `assert`) so the checks also
        # run under `python -O`, while keeping the exception type callers expect.
        for fband_name, seg_length_fband in self.segment_lengths_ms.items():
            if seg_length_fband > settings.segment_length_features_ms:
                raise AssertionError(
                    f"segment length {seg_length_fband} of frequency band {fband_name} "
                    "needs to be smaller than or equal to "
                    f"settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
                )

        for fband_name in settings.frequency_ranges_hz.keys():
            if fband_name not in self.segment_lengths_ms:
                raise AssertionError(
                    f"frequency range {fband_name} "
                    "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms"
                )
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class BandPower(NMFeature):
    """Hjorth statistics of bandpass-filtered signals.

    Every call filters the input into the bands of ``settings.frequency_ranges_hz``
    with an ``MNEFilter``. For each (channel, band) pair, the statistics enabled in
    ``bandpass_filter_settings.bandpower_features`` are computed on the last
    ``segment_lengths_ms[band]`` milliseconds of the filtered signal. The
    ``activity`` statistic can be log10-transformed and Kalman-smoothed.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Sequence[str],
        sfreq: float,
        use_kf: bool | None = None,
    ) -> None:
        """Build the filter bank and the list of (channel, band, statistic) jobs.

        Args:
            settings: Global settings; validated before anything else.
            ch_names: Channel names; ``ch_names[i]`` labels row ``i`` of the data.
            sfreq: Sampling frequency in Hz.
            use_kf: True/False forces Kalman smoothing on/off; ``None`` defers to
                ``bandpass_filter_settings.kalman_filter``.
        """
        settings.validate()

        self.bp_settings: BandPowerSettings = settings.bandpass_filter_settings
        self.kalman_filter_settings: KalmanSettings = settings.kalman_filter_settings
        self.sfreq = sfreq
        self.ch_names = ch_names
        self.KF_dict: dict = {}

        from py_neuromodulation.filter import MNEFilter

        self.bandpass_filter = MNEFilter(
            f_ranges=[tuple(band) for band in settings.frequency_ranges_hz.values()],
            sfreq=self.sfreq,
            filter_length=self.sfreq - 1,
            verbose=False,
        )

        # An explicit use_kf wins; None falls back to the settings flag.
        kalman_enabled = self.bp_settings.kalman_filter if use_kf is None else use_kf
        if kalman_enabled:
            self.init_KF("bandpass_activity")

        window_ms = self.bp_settings.segment_lengths_ms

        # One tuple per output feature:
        # (channel index, band index, window length in samples, statistic, feature name)
        self.feature_params = []
        for ch_idx, ch_name in enumerate(self.ch_names):
            for band_idx, band in enumerate(settings.frequency_ranges_hz.keys()):
                n_samples = int(np.floor(self.sfreq / 1000 * window_ms[band]))
                self.feature_params.extend(
                    (
                        ch_idx,
                        band_idx,
                        n_samples,
                        stat,
                        "_".join([ch_name, "bandpass", stat, band]),
                    )
                    for stat in self.bp_settings.bandpower_features.get_enabled()
                )

    def init_KF(self, feature: str) -> None:
        """Create one Kalman filter per (channel, Kalman frequency band) pair.

        Filters are stored in ``KF_dict`` under ``"<channel>_<feature>_<band>"``.
        """
        from py_neuromodulation.filter import define_KF

        kf_cfg = self.kalman_filter_settings
        for band in kf_cfg.frequency_bands:
            for channel in self.ch_names:
                key = "_".join([channel, feature, band])
                self.KF_dict[key] = define_KF(kf_cfg.Tp, kf_cfg.sigma_w, kf_cfg.sigma_v)

    def update_KF(self, feature_calc: np.floating, KF_name: str) -> np.floating:
        """Return the Kalman-filtered value for ``KF_name``, or the input unchanged
        when no filter is registered under that name."""
        if KF_name not in self.KF_dict:
            return feature_calc

        kalman = self.KF_dict[KF_name]
        kalman.predict()
        kalman.update(feature_calc)
        return kalman.x[0]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Filter ``data`` and compute every configured feature.

        Returns:
            Mapping of feature name to value, in ``feature_params`` order.
        """
        band_data = self.bandpass_filter.filter_data(data)

        return {
            name: self.calc_bp_feature(stat, name, band_data[ch_idx, band_idx, -n_samples:])
            for ch_idx, band_idx, n_samples, stat, name in self.feature_params
        }

    def calc_bp_feature(self, bp_feature, feature_name, data):
        """Compute one statistic on a 1-D window; NaN/inf are mapped to finite values.

        Raises:
            ValueError: for an unknown statistic name.
        """
        if bp_feature == "activity":
            value = np.var(data)
            if self.bp_settings.log_transform:
                value = np.log10(value)
            if self.KF_dict:
                value = self.update_KF(value, feature_name)
        elif bp_feature == "mobility":
            value = np.sqrt(np.var(np.diff(data)) / np.var(data))
        elif bp_feature == "complexity":
            value = self.calc_complexity(data)
        else:
            raise ValueError(f"Unknown bandpower feature: {bp_feature}")

        return np.nan_to_num(value)

    @staticmethod
    def calc_complexity(data: np.ndarray) -> float:
        """Hjorth complexity: mobility of the first difference divided by mobility."""
        first_diff = np.diff(data)
        first_diff_var = np.var(first_diff)
        mobility = np.sqrt(first_diff_var / np.var(data))
        second_diff_var = np.var(np.diff(first_diff))
        diff_mobility = np.sqrt(second_diff_var / first_diff_var)

        return diff_mobility / mobility
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
from pydantic import field_validator
|
|
3
|
+
from typing import TYPE_CHECKING, Callable
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from py_neuromodulation.utils.types import (
|
|
8
|
+
NMBaseModel,
|
|
9
|
+
NMFeature,
|
|
10
|
+
BoolSelector,
|
|
11
|
+
FrequencyRange,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from py_neuromodulation import NMSettings
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BispectraComponents(BoolSelector):
    """Which parts of each channel's bispectrum the statistics are computed on.

    * ``absolute`` - magnitude (``np.abs``)
    * ``real``     - real part
    * ``imag``     - imaginary part
    * ``phase``    - angle (``np.angle``)
    """

    # All components are enabled by default; the order defines the output order.
    absolute: bool = True
    real: bool = True
    imag: bool = True
    phase: bool = True
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class BispectraFeatures(BoolSelector):
    """Statistics reduced over a bispectrum region (all NaN-aware).

    * ``mean`` - ``np.nanmean``
    * ``sum``  - ``np.nansum``
    * ``var``  - ``np.nanvar``
    """

    # All statistics are enabled by default; the order defines the output order.
    mean: bool = True
    sum: bool = True
    var: bool = True
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class BispectraSettings(NMBaseModel):
    """Settings of the bispectrum feature.

    Attributes:
        f1s, f2s: (low, high) frequency ranges in Hz passed as ``f1s``/``f2s`` to
            PyBispectra's ``WaveShape.compute``; high must exceed low.
        compute_features_for_whole_fband_range: Also reduce over the whole
            computed f1 x f2 range, not only per frequency band.
        frequency_bands: Bands (keys of ``settings.frequency_ranges_hz``) to reduce
            over; spaces in names are converted to underscores.
        components: Bispectrum components to use.
        bispectrum_features: Statistics to compute.
    """

    f1s: FrequencyRange = FrequencyRange(5, 35)
    f2s: FrequencyRange = FrequencyRange(5, 35)
    compute_features_for_whole_fband_range: bool = True
    frequency_bands: list[str] = ["theta", "alpha", "low_beta", "high_beta"]

    components: BispectraComponents = BispectraComponents()
    bispectrum_features: BispectraFeatures = BispectraFeatures()

    @field_validator("f1s", "f2s")
    @classmethod
    def test_range(cls, filter_range):
        """Ensure the upper frequency of a range is strictly above the lower one.

        Raises:
            ValueError: otherwise (pydantic reports it as a ValidationError).
        """
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not filter_range[1] > filter_range[0]:
            raise ValueError(
                f"second frequency range value needs to be higher than first one, got {filter_range}"
            )
        return filter_range

    @field_validator("frequency_bands")
    @classmethod
    def fbands_spaces_to_underscores(cls, frequency_bands):
        """Normalise band names by replacing spaces with underscores."""
        return [f.replace(" ", "_") for f in frequency_bands]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Statistic name -> NaN-aware reducer applied to a bispectrum region.
FEATURE_DICT: dict[str, Callable] = dict(
    mean=np.nanmean,
    sum=np.nansum,
    var=np.nanvar,
)


def _real_part(obj):
    """Return ``obj.real``."""
    return obj.real


def _imag_part(obj):
    """Return ``obj.imag``."""
    return obj.imag


# Component name -> function extracting that component from the bispectrum values
# (presumably complex-valued, given the real/imag/phase components).
COMPONENT_DICT: dict[str, Callable] = dict(
    real=_real_part,
    imag=_imag_part,
    absolute=np.abs,
    phase=np.angle,
)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class Bispectra(NMFeature):
    """Bispectrum (waveshape) features computed with PyBispectra.

    For every channel and every enabled component (see ``COMPONENT_DICT``), the
    enabled NaN-aware statistics (see ``FEATURE_DICT``) are computed over the
    f1/f2 square of each band in ``bispectrum_settings.frequency_bands`` and,
    optionally, over the whole computed f1 x f2 range.
    """

    def __init__(
        self, settings: "NMSettings", ch_names: Iterable[str], sfreq: float
    ) -> None:
        """Store settings and check that all bispectrum bands are defined.

        Raises:
            AssertionError: if a band in ``bispectrum_settings.frequency_bands`` is
                missing from ``settings.frequency_ranges_hz``.
        """
        self.sfreq = sfreq
        self.ch_names = ch_names
        self.frequency_ranges_hz = settings.frequency_ranges_hz
        self.settings: BispectraSettings = settings.bispectrum_settings

        # Raised explicitly (not via `assert`) so the check survives `python -O`;
        # the exception type is unchanged for existing callers.
        if not all(
            f_band_bispectrum in settings.frequency_ranges_hz
            for f_band_bispectrum in self.settings.frequency_bands
        ):
            raise AssertionError(
                "bispectrum selected frequency bands don't match the ones "
                "specified in s['frequency_ranges_hz']. "
                f"bispectrum frequency bands: {self.settings.frequency_bands}; "
                f"specified frequency_ranges_hz: {settings.frequency_ranges_hz}"
            )

        self.used_features = self.settings.bispectrum_features.get_enabled()

        # Overall frequency span covered by f1s and f2s together
        self.min_freq = min(
            self.settings.f1s.frequency_low_hz, self.settings.f2s.frequency_low_hz
        )
        self.max_freq = max(
            self.settings.f1s.frequency_high_hz, self.settings.f2s.frequency_high_hz
        )

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute the bispectrum of each channel in ``data`` and reduce it.

        Args:
            data: Array whose first axis is channels (in ``ch_names`` order) and
                whose second axis is time.

        Returns:
            Mapping ``"<ch>_Bispectrum_<component>_<stat>_<band>"`` (plus
            ``..._whole_fband_range`` keys when enabled) to the statistic value.
        """
        from pybispectra import compute_fft, WaveShape

        # PyBispectra's compute_fft parallelizes per channel; for a single batch the
        # parallelization overhead may not pay off.
        fft_coeffs, freqs = compute_fft(
            data=np.expand_dims(data, axis=(0)),
            sampling_freq=self.sfreq,
            n_points=data.shape[1],
            verbose=False,
        )
        # fft_coeffs shape: [epochs, channels, frequencies]

        # Frequencies inside the overall f1/f2 span, used to map band limits onto
        # the axes of the computed bispectrum.
        f_spectrum_range = freqs[
            np.logical_and(freqs >= self.min_freq, freqs <= self.max_freq)
        ]

        waveshape = WaveShape(
            data=fft_coeffs,
            freqs=freqs,
            sampling_freq=self.sfreq,
            verbose=False,
        )

        waveshape.compute(
            f1s=tuple(self.settings.f1s),  # type: ignore
            f2s=tuple(self.settings.f2s),  # type: ignore
        )

        compute_whole_range = (
            self.settings.compute_features_for_whole_fband_range
            and bool(self.settings.frequency_bands)
        )

        feature_results = {}
        for ch_idx, ch_name in enumerate(self.ch_names):
            # Same as waveshape.results._data, skips a copy
            bispectrum = waveshape._bicoherence[ch_idx]

            for component in self.settings.components.get_enabled():
                spectrum_ch = COMPONENT_DICT[component](bispectrum)

                # Whole-range statistics don't depend on the band: compute them once
                # per component instead of once per band.
                whole_range_values = (
                    {
                        feature: FEATURE_DICT[feature](spectrum_ch)
                        for feature in self.used_features
                    }
                    if compute_whole_range
                    else {}
                )

                for fb in self.settings.frequency_bands:
                    range_ = (f_spectrum_range >= self.frequency_ranges_hz[fb][0]) & (
                        f_spectrum_range <= self.frequency_ranges_hz[fb][1]
                    )
                    # NOTE(review): assumes spectrum_ch is square over
                    # f_spectrum_range, i.e. f1s and f2s cover the same frequencies.
                    data_bs = spectrum_ch[range_, range_]

                    for bispectrum_feature in self.used_features:
                        feature_results[
                            f"{ch_name}_Bispectrum_{component}_{bispectrum_feature}_{fb}"
                        ] = FEATURE_DICT[bispectrum_feature](data_bs)

                        # Inserted here (not after the loops) to keep the original
                        # key order of the result dict.
                        if bispectrum_feature in whole_range_values:
                            feature_results[
                                f"{ch_name}_Bispectrum_{component}_{bispectrum_feature}_whole_fband_range"
                            ] = whole_range_values[bispectrum_feature]

        return feature_results
|
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
# `_quantile` is numpy's private, check-free quantile kernel; it moved to a new
# module in numpy 2.0. Versions are compared with NumpyVersion because a plain
# string comparison mis-orders versions such as "10.0.0" vs "2.0.0".
# NOTE: unlike `np.quantile` (which copies its input unless overwrite_input=True),
# this kernel works on its argument directly and may reorder it in place, so pass
# a copy when the input's order matters.
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
    from numpy.lib._function_base_impl import _quantile as np_quantile  # type:ignore
else:
    from numpy.lib.function_base import _quantile as np_quantile  # type:ignore
|
|
7
|
+
from collections.abc import Sequence
|
|
8
|
+
from itertools import product
|
|
9
|
+
|
|
10
|
+
from pydantic import Field, field_validator
|
|
11
|
+
from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature
|
|
12
|
+
|
|
13
|
+
from typing import TYPE_CHECKING, Callable
|
|
14
|
+
from py_neuromodulation.utils.types import create_validation_error
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from py_neuromodulation import NMSettings
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
LARGE_NUM = 2**24


def get_label_pos(burst_labels, valid_labels):
    """Map each valid burst label to the flattened (channel, band) row containing it.

    Parameters
    ----------
    burst_labels : np.ndarray
        Integer labels of shape (n_channels, n_fbands, n_samples), as produced by
        ``scipy.ndimage.label`` with connectivity restricted to the last axis;
        0 marks samples outside any burst.
    valid_labels : array-like of int
        Sorted, non-zero labels to locate.

    Returns
    -------
    np.ndarray
        For every entry of ``valid_labels``, the flattened row index
        ``channel * n_fbands + band``, with the same dtype as ``valid_labels``.

    Notes
    -----
    Labels are assumed to be assigned in raster order, so each row owns a
    contiguous, increasing range of labels. The running maximum of the per-row
    maximum label is then non-decreasing, and the first row whose running maximum
    reaches a label is the row containing that label. ``np.searchsorted`` finds
    it for all labels at once, replacing a per-label Python loop that ran on
    every batch.
    """
    valid_labels = np.asarray(valid_labels)

    # Largest label of each (channel, band) row; rows without bursts give 0
    max_label = np.max(burst_labels, axis=2).ravel()

    # Running maximum keeps the row upper bounds sorted despite empty rows
    row_upper_bounds = np.maximum.accumulate(max_label)

    label_positions = np.searchsorted(row_upper_bounds, valid_labels, side="left")
    return label_positions.astype(valid_labels.dtype, copy=False)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class BurstFeatures(BoolSelector):
    """Which burst features ``Bursts`` reports per channel and band.

    * ``duration``         - mean and max burst duration in seconds
    * ``amplitude``        - mean and max envelope amplitude during bursts
    * ``burst_rate_per_s`` - mean duration divided by the feature segment length
    * ``in_burst``         - whether the last sample of the segment is in a burst
    """

    # All features enabled by default; field order defines the output order.
    duration: bool = True
    amplitude: bool = True
    burst_rate_per_s: bool = True
    in_burst: bool = True
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class BurstsSettings(NMBaseModel):
    """Settings of the burst feature.

    Attributes:
        threshold: Percentile (0-100) of the buffered envelope used as the
            burst threshold.
        time_duration_s: Length in seconds of the envelope buffer the threshold
            is computed over.
        frequency_bands: Bands (keys of ``settings.frequency_ranges_hz``) to
            analyse; spaces in names become underscores.
        burst_features: Features to compute.
    """

    threshold: float = Field(default=75, ge=0, le=100)
    time_duration_s: float = Field(default=30, ge=0)
    frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"]
    burst_features: BurstFeatures = BurstFeatures()

    @field_validator("frequency_bands")
    @classmethod
    def fbands_spaces_to_underscores(cls, frequency_bands):
        """Normalise band names, e.g. ``"low beta"`` -> ``"low_beta"``."""
        return list(map(lambda band: band.replace(" ", "_"), frequency_bands))
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class Bursts(NMFeature):
    """Burst features (duration, amplitude, rate, in-burst flag) per channel and band.

    Each batch is bandpass filtered into ``burst_settings.frequency_bands`` and its
    Hilbert envelope is taken. A ring buffer holding the last ``time_duration_s``
    seconds of envelope defines the burst threshold (the ``threshold``
    percentile). Samples of the current batch at or above the threshold are burst
    samples. A burst still ongoing at the last sample is excluded from the
    duration and amplitude statistics because its true length is unknown.
    """

    def __init__(
        self, settings: "NMSettings", ch_names: Sequence[str], sfreq: float
    ) -> None:
        """Validate settings and set up the filter bank and ring buffer.

        Raises:
            Whatever ``create_validation_error`` builds, if a burst band is not
            defined in ``settings.frequency_ranges_hz``.
        """
        # Test settings
        settings.validate()

        # Validate that all frequency bands are defined in the settings
        for fband_burst in settings.burst_settings.frequency_bands:
            if fband_burst not in settings.frequency_ranges_hz:
                raise create_validation_error(
                    f"bursting {fband_burst} needs to be defined in settings['frequency_ranges_hz']",
                    loc=["burst_settings", "frequency_bands"],
                )

        from py_neuromodulation.filter import MNEFilter

        self.settings = settings.burst_settings
        self.sfreq = sfreq
        self.ch_names = ch_names
        self.segment_length_features_s = settings.segment_length_features_ms / 1000
        # Number of newest samples appended to the buffer on every batch after the
        # first. NOTE(review): new data per batch is presumably
        # sfreq / sampling_rate_features_hz samples; the extra segment-length factor
        # only agrees with that for 1 s segments - confirm.
        self.samples_overlap = int(
            self.sfreq
            * self.segment_length_features_s
            / settings.sampling_rate_features_hz
        )

        self.fband_names = settings.burst_settings.frequency_bands

        f_ranges: list[tuple[float, float]] = [
            (
                settings.frequency_ranges_hz[fband_name][0],
                settings.frequency_ranges_hz[fband_name][1],
            )
            for fband_name in self.fband_names
        ]

        self.bandpass_filter = MNEFilter(
            f_ranges=f_ranges,
            sfreq=self.sfreq,
            filter_length=self.sfreq - 1,
            verbose=False,
        )
        self.filter_data = self.bandpass_filter.filter_data

        # Maximum number of envelope samples kept for the threshold computation
        self.num_max_samples_ring_buffer = int(
            self.sfreq * self.settings.time_duration_s
        )

        self.n_channels = len(self.ch_names)
        self.n_fbands = len(self.fband_names)

        # Ring buffer of the most recent envelope samples, (n_channels, n_fbands, n_samples)
        self.data_buffer = np.empty(
            (self.n_channels, self.n_fbands, 0), dtype=np.float64
        )

        self.used_features = self.settings.burst_features.get_enabled()

        # Every (channel, band, feature) combination, in output order
        self.feature_combinations = list(
            product(
                enumerate(self.ch_names),
                enumerate(self.fband_names),
                self.settings.burst_features.get_enabled(),
            )
        )

        # Variables to store results, each of shape (n_channels, n_fbands)
        self.burst_duration_mean: np.ndarray
        self.burst_duration_max: np.ndarray
        self.burst_amplitude_max: np.ndarray
        self.burst_amplitude_mean: np.ndarray
        self.burst_rate_per_s: np.ndarray
        self.end_in_burst: np.ndarray

        self.STORE_FEAT_DICT: dict[str, Callable] = {
            "duration": self.store_duration,
            "amplitude": self.store_amplitude,
            "burst_rate_per_s": self.store_burst_rate,
            "in_burst": self.store_in_burst,
        }

        self.batch = 0

        # Structure matrix for scipy.ndimage.label:
        # samples are connected only to adjacent neighbors along the last (time) axis
        self.label_structure_matrix = np.zeros((3, 3, 3))
        self.label_structure_matrix[1, 1, :] = 1

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute the enabled burst features for one batch of data.

        Returns:
            Mapping ``"<ch>_bursts_<band>_<feature>"`` to value.
        """
        from scipy.signal import hilbert
        from scipy.ndimage import label, sum_labels as label_sum, mean as label_mean

        # Amplitude envelope of every band-filtered signal
        filtered_data = np.abs(np.array(hilbert(self.filter_data(data))))

        # Update buffer array: the whole first batch, then only the newest samples
        batch_size = (
            filtered_data.shape[-1] if self.batch == 0 else self.samples_overlap
        )

        self.batch += 1
        self.data_buffer = np.concatenate(
            (
                self.data_buffer,
                filtered_data[:, :, -batch_size:],
            ),
            axis=2,
        )[:, :, -self.num_max_samples_ring_buffer :]

        # Burst threshold is the percentile defined in the settings.
        # `np_quantile` is numpy's low-level `_quantile`, which skips the input copy
        # done by `np.quantile` and partitions its argument in place along the time
        # axis. Passing the buffer itself would scramble its temporal order, so the
        # trimming above would later drop the smallest samples instead of the
        # oldest ones. Hence the explicit copy.
        burst_thr = np_quantile(self.data_buffer.copy(), self.settings.threshold / 100)[
            :, :, None
        ]  # Add back the extra dimension

        # Boolean burst mask: True where the envelope reaches the threshold
        bursts = filtered_data >= burst_thr

        # np.diff finds where bursts start and end. Prepending False ensures data
        # never starts on a burst. Floor division ignores a burst still running at
        # the end (true burst length unknown).
        num_bursts = (
            np.sum(np.diff(bursts, axis=2, prepend=False), axis=2) // 2
        ).astype(np.float64)  # float64 to avoid a casting error in np.divide

        # Label each burst with a unique id, limiting connectivity to the last axis
        burst_labels = label(bursts, self.label_structure_matrix)[0]  # type: ignore # wrong return type in scipy

        # Remove label 0 and the labels of bursts still running at the last sample
        labels_at_end = np.concatenate((np.unique(burst_labels[:, :, -1]), (0,)))
        valid_labels = np.unique(burst_labels)
        valid_labels = valid_labels[
            ~np.isin(valid_labels, labels_at_end, assume_unique=True)
        ]

        # Flattened (channel, band) row of each valid label
        label_positions = get_label_pos(burst_labels, valid_labels)

        # Samples belonging to completed bursts only, so that every statistic
        # excludes the unterminated trailing burst, as num_bursts does
        in_complete_burst = np.isin(burst_labels, valid_labels)

        if "duration" in self.used_features or "burst_rate_per_s" in self.used_features:
            # Where num_bursts is 0, the result is 0 (no division by zero)
            self.burst_duration_mean = (
                np.divide(
                    np.sum(in_complete_burst, axis=2),
                    num_bursts,
                    out=np.zeros_like(num_bursts),
                    where=num_bursts != 0,
                )
                / self.sfreq
            )

        if "duration" in self.used_features:
            # Length in seconds of each valid burst
            burst_lengths = (
                label_sum(bursts, burst_labels, index=valid_labels) / self.sfreq
            )

            # Max per channel/band; initial=0 gives 0 when there are no bursts
            # TODO: it might be interesting to write a C function for this
            duration_max_flat = np.zeros(self.n_channels * self.n_fbands)
            for idx in range(self.n_channels * self.n_fbands):
                duration_max_flat[idx] = np.max(
                    burst_lengths[label_positions == idx], initial=0
                )

            self.burst_duration_max = duration_max_flat.reshape(
                (self.n_channels, self.n_fbands)
            )

        if "amplitude" in self.used_features:
            # Max envelope over the samples of completed bursts
            self.burst_amplitude_max = (filtered_data * in_complete_burst).max(axis=2)

            # The mean is a mean of per-burst means
            label_means = label_mean(filtered_data, burst_labels, index=valid_labels)
            # TODO: it might be interesting to write a C function for this
            amplitude_mean_flat = np.zeros(self.n_channels * self.n_fbands)
            for idx in range(self.n_channels * self.n_fbands):
                mask = label_positions == idx
                amplitude_mean_flat[idx] = (
                    np.mean(label_means[mask]) if np.any(mask) else 0
                )

            self.burst_amplitude_mean = amplitude_mean_flat.reshape(
                (self.n_channels, self.n_fbands)
            )

        if "burst_rate_per_s" in self.used_features:
            self.burst_rate_per_s = (
                self.burst_duration_mean / self.segment_length_features_s
            )

        if "in_burst" in self.used_features:
            self.end_in_burst = bursts[:, :, -1]  # End in burst

        # Create dictionary of features which is the correct return format
        feature_results = {}
        for (ch_i, ch), (fb_i, fb), feat in self.feature_combinations:
            self.STORE_FEAT_DICT[feat](feature_results, ch_i, ch, fb_i, fb)

        return feature_results

    def store_duration(
        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
    ):
        """Store mean and max burst duration of one channel/band."""
        feature_results[f"{ch}_bursts_{fb}_duration_mean"] = self.burst_duration_mean[
            ch_i, fb_i
        ]

        feature_results[f"{ch}_bursts_{fb}_duration_max"] = self.burst_duration_max[
            ch_i, fb_i
        ]

    def store_amplitude(
        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
    ):
        """Store mean and max burst amplitude of one channel/band."""
        feature_results[f"{ch}_bursts_{fb}_amplitude_mean"] = self.burst_amplitude_mean[
            ch_i, fb_i
        ]
        feature_results[f"{ch}_bursts_{fb}_amplitude_max"] = self.burst_amplitude_max[
            ch_i, fb_i
        ]

    def store_burst_rate(
        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
    ):
        """Store the burst rate of one channel/band."""
        feature_results[f"{ch}_bursts_{fb}_burst_rate_per_s"] = self.burst_rate_per_s[
            ch_i, fb_i
        ]

    def store_in_burst(
        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
    ):
        """Store whether the segment of one channel/band ends in a burst."""
        feature_results[f"{ch}_bursts_{fb}_in_burst"] = self.end_in_burst[ch_i, fb_i]
|