py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
- py_neuromodulation/__init__.py +80 -13
- py_neuromodulation/{nm_RMAP.py → analysis/RMAP.py} +496 -531
- py_neuromodulation/analysis/__init__.py +4 -0
- py_neuromodulation/{nm_decode.py → analysis/decode.py} +918 -992
- py_neuromodulation/{nm_analysis.py → analysis/feature_reader.py} +994 -1074
- py_neuromodulation/{nm_plots.py → analysis/plots.py} +627 -612
- py_neuromodulation/{nm_stats.py → analysis/stats.py} +458 -480
- py_neuromodulation/data/README +6 -6
- py_neuromodulation/data/dataset_description.json +8 -8
- py_neuromodulation/data/participants.json +32 -32
- py_neuromodulation/data/participants.tsv +2 -2
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
- py_neuromodulation/default_settings.yaml +241 -0
- py_neuromodulation/features/__init__.py +31 -0
- py_neuromodulation/features/bandpower.py +165 -0
- py_neuromodulation/features/bispectra.py +157 -0
- py_neuromodulation/features/bursts.py +297 -0
- py_neuromodulation/features/coherence.py +255 -0
- py_neuromodulation/features/feature_processor.py +121 -0
- py_neuromodulation/features/fooof.py +142 -0
- py_neuromodulation/features/hjorth_raw.py +57 -0
- py_neuromodulation/features/linelength.py +21 -0
- py_neuromodulation/features/mne_connectivity.py +148 -0
- py_neuromodulation/features/nolds.py +94 -0
- py_neuromodulation/features/oscillatory.py +249 -0
- py_neuromodulation/features/sharpwaves.py +432 -0
- py_neuromodulation/filter/__init__.py +3 -0
- py_neuromodulation/filter/kalman_filter.py +67 -0
- py_neuromodulation/filter/kalman_filter_external.py +1890 -0
- py_neuromodulation/filter/mne_filter.py +128 -0
- py_neuromodulation/filter/notch_filter.py +93 -0
- py_neuromodulation/grid_cortex.tsv +40 -40
- py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
- py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
- py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
- py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/processing/__init__.py +10 -0
- py_neuromodulation/{nm_artifacts.py → processing/artifacts.py} +29 -25
- py_neuromodulation/processing/data_preprocessor.py +77 -0
- py_neuromodulation/processing/filter_preprocessing.py +78 -0
- py_neuromodulation/processing/normalization.py +175 -0
- py_neuromodulation/{nm_projection.py → processing/projection.py} +370 -394
- py_neuromodulation/{nm_rereference.py → processing/rereference.py} +97 -95
- py_neuromodulation/{nm_resample.py → processing/resample.py} +56 -50
- py_neuromodulation/stream/__init__.py +3 -0
- py_neuromodulation/stream/data_processor.py +325 -0
- py_neuromodulation/stream/generator.py +53 -0
- py_neuromodulation/stream/mnelsl_player.py +94 -0
- py_neuromodulation/stream/mnelsl_stream.py +120 -0
- py_neuromodulation/stream/settings.py +292 -0
- py_neuromodulation/stream/stream.py +427 -0
- py_neuromodulation/utils/__init__.py +2 -0
- py_neuromodulation/{nm_define_nmchannels.py → utils/channels.py} +305 -302
- py_neuromodulation/utils/database.py +149 -0
- py_neuromodulation/utils/io.py +378 -0
- py_neuromodulation/utils/keyboard.py +52 -0
- py_neuromodulation/utils/logging.py +66 -0
- py_neuromodulation/utils/types.py +251 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/METADATA +28 -33
- py_neuromodulation-0.0.6.dist-info/RECORD +89 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/WHEEL +1 -1
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/licenses/LICENSE +21 -21
- py_neuromodulation/FieldTrip.py +0 -589
- py_neuromodulation/_write_example_dataset_helper.py +0 -65
- py_neuromodulation/nm_EpochStream.py +0 -92
- py_neuromodulation/nm_IO.py +0 -417
- py_neuromodulation/nm_across_patient_decoding.py +0 -927
- py_neuromodulation/nm_bispectra.py +0 -168
- py_neuromodulation/nm_bursts.py +0 -198
- py_neuromodulation/nm_coherence.py +0 -205
- py_neuromodulation/nm_cohortwrapper.py +0 -435
- py_neuromodulation/nm_eval_timing.py +0 -239
- py_neuromodulation/nm_features.py +0 -116
- py_neuromodulation/nm_features_abc.py +0 -39
- py_neuromodulation/nm_filter.py +0 -219
- py_neuromodulation/nm_filter_preprocessing.py +0 -91
- py_neuromodulation/nm_fooof.py +0 -159
- py_neuromodulation/nm_generator.py +0 -37
- py_neuromodulation/nm_hjorth_raw.py +0 -73
- py_neuromodulation/nm_kalmanfilter.py +0 -58
- py_neuromodulation/nm_linelength.py +0 -33
- py_neuromodulation/nm_mne_connectivity.py +0 -112
- py_neuromodulation/nm_nolds.py +0 -93
- py_neuromodulation/nm_normalization.py +0 -214
- py_neuromodulation/nm_oscillatory.py +0 -448
- py_neuromodulation/nm_run_analysis.py +0 -435
- py_neuromodulation/nm_settings.json +0 -338
- py_neuromodulation/nm_settings.py +0 -68
- py_neuromodulation/nm_sharpwaves.py +0 -401
- py_neuromodulation/nm_stream_abc.py +0 -218
- py_neuromodulation/nm_stream_offline.py +0 -359
- py_neuromodulation/utils/_logging.py +0 -24
- py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
|
@@ -0,0 +1,432 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from itertools import product
|
|
4
|
+
|
|
5
|
+
from pydantic import model_validator
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
if np.__version__ >= "2.0.0":
|
|
11
|
+
from numpy._core._methods import _mean as np_mean # type: ignore
|
|
12
|
+
else:
|
|
13
|
+
from numpy.core._methods import _mean as np_mean
|
|
14
|
+
|
|
15
|
+
from py_neuromodulation.utils.types import (
|
|
16
|
+
NMFeature,
|
|
17
|
+
NMBaseModel,
|
|
18
|
+
BoolSelector,
|
|
19
|
+
FrequencyRange,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from py_neuromodulation import NMSettings
|
|
24
|
+
|
|
25
|
+
# Using low-level numpy mean function for performance, could do the same for the other estimators
# Maps estimator names (the field names of SharpwaveEstimators) to the numpy
# reduction applied over per-event feature values in SharpwaveAnalyzer.
# np_mean is numpy's private low-level mean (imported version-dependently above).
ESTIMATOR_DICT = {
    "mean": np_mean,
    "median": np.median,
    "max": np.max,
    "min": np.min,
    "var": np.var,
}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class PeakDetectionSettings(NMBaseModel):
    """Settings controlling scipy.signal.find_peaks based peak/trough detection.

    NOTE(review): in SharpwaveAnalyzer.analyze_waveform these *_ms values are
    passed directly as the find_peaks ``distance`` argument, which is measured
    in samples — no ms-to-samples conversion by sfreq is visible; confirm
    whether that conversion is intended elsewhere.
    """

    estimate: bool = True  # whether events of this polarity are analyzed at all
    distance_troughs_ms: float = 10  # minimum separation between troughs
    distance_peaks_ms: float = 5  # minimum separation between peaks
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class SharpwaveFeatures(BoolSelector):
    """Toggle set of sharpwave morphology features to compute.

    Each flag enables one entry in the dict returned by
    SharpwaveAnalyzer.analyze_waveform. Some features imply others at compute
    time (e.g. prominence requires peak_left/peak_right/trough; slope_ratio
    requires both steepness features) — see the ``need_*`` flags in
    SharpwaveAnalyzer.__init__.
    """

    peak_left: bool = False
    peak_right: bool = False
    num_peaks: bool = False
    trough: bool = False
    width: bool = False
    prominence: bool = True
    interval: bool = True
    decay_time: bool = False
    rise_time: bool = False
    sharpness: bool = True
    rise_steepness: bool = False
    decay_steepness: bool = False
    slope_ratio: bool = False
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class SharpwaveEstimators(NMBaseModel):
    """Maps each estimator name to the list of feature names it reduces.

    Field names correspond to the keys of ESTIMATOR_DICT; list entries are
    SharpwaveFeatures field names. A feature may appear under several
    estimators.
    """

    mean: list[str] = ["interval"]
    median: list[str] = []
    max: list[str] = ["prominence", "sharpness"]
    min: list[str] = []
    var: list[str] = []

    def keys(self):
        # Fixed field order; must stay in sync with ESTIMATOR_DICT and values().
        return ["mean", "median", "max", "min", "var"]

    def values(self):
        # Same order as keys(), so zip(keys(), values()) pairs up correctly.
        return [self.mean, self.median, self.max, self.min, self.var]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class SharpwaveSettings(NMBaseModel):
    """Settings for sharpwave feature extraction.

    Combines the feature toggles, the FIR filter bands the features are
    computed in, peak/trough detection parameters, and the estimators used to
    reduce per-event values into one feature value per batch.
    """

    sharpwave_features: SharpwaveFeatures = SharpwaveFeatures()
    filter_ranges_hz: list[FrequencyRange] = [
        FrequencyRange(5, 80),
        FrequencyRange(5, 30),
    ]
    detect_troughs: PeakDetectionSettings = PeakDetectionSettings()
    detect_peaks: PeakDetectionSettings = PeakDetectionSettings()
    estimator: SharpwaveEstimators = SharpwaveEstimators()
    apply_estimator_between_peaks_and_troughs: bool = True

    def disable_all_features(self):
        """Disable every sharpwave feature and clear all estimator lists."""
        self.sharpwave_features.disable_all()
        for est in self.estimator.keys():
            self.estimator[est] = []

    @model_validator(mode="after")
    def test_settings(cls, settings):
        """Check that every enabled sharpwave feature has an estimator.

        Raises
        ------
        ValueError
            If an enabled feature is missing from all estimator lists.
            Pydantic converts this into a ValidationError, exactly as it
            would an AssertionError, so validation behavior is unchanged.
        """
        # check if all features are also enabled via an estimator
        # Use a set for O(1) membership tests.
        estimator_set = {est for list_ in settings.estimator.values() for est in list_}

        for used_feature in settings.sharpwave_features.get_enabled():
            # raise instead of assert so validation survives `python -O`
            if used_feature not in estimator_set:
                raise ValueError(f"Add estimator key for {used_feature}")

        return settings
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class SharpwaveAnalyzer(NMFeature):
    """Compute sharpwave morphology features per channel, filter band and polarity.

    On construction, one FIR kernel per configured frequency band is created
    and tiled per channel so a single fftconvolve call filters all
    (channel, band) pairs in ``calc_feature``. Per batch, peaks and troughs
    are located and the enabled morphology features are reduced by the
    configured estimators (see ESTIMATOR_DICT).
    """

    def __init__(
        self, settings: "NMSettings", ch_names: Sequence[str], sfreq: float
    ) -> None:
        """Initialize filters, estimator mappings and feature-need flags.

        Parameters
        ----------
        settings : NMSettings
            Global settings; only ``sharpwave_analysis_settings`` is used here.
        ch_names : Sequence[str]
            Channel names; used to build output feature keys.
        sfreq : float
            Sampling frequency in Hz; filter bands must lie below it.
        """
        self.sw_settings = settings.sharpwave_analysis_settings
        self.sfreq = sfreq
        self.ch_names = ch_names
        # (name, FIR kernel) pairs; kernel is None for the "no_filter" entry.
        self.list_filter: list[tuple[str, Any]] = []
        self.trough: list = []
        self.troughs_idx: list = []

        settings.validate()

        # FrequencyRange's are already ensured to have high > low
        # Test that the higher frequency is smaller than the sampling frequency
        for filter_range in settings.sharpwave_analysis_settings.filter_ranges_hz:
            assert filter_range[1] < sfreq, (
                "Filter range has to be smaller than sfreq, "
                f"got sfreq {sfreq} and filter range {filter_range}"
            )

        for filter_range in settings.sharpwave_analysis_settings.filter_ranges_hz:
            # Test settings
            # TODO: handle None values
            if filter_range[0] is None:
                self.list_filter.append(("no_filter", None))
            else:
                # local import keeps mne out of module import time
                from mne.filter import create_filter

                self.list_filter.append(
                    (
                        f"range_{filter_range[0]:.0f}_{filter_range[1]:.0f}",
                        create_filter(
                            None,
                            sfreq,
                            l_freq=filter_range[0],
                            h_freq=filter_range[1],
                            fir_design="firwin",
                            # l_trans_bandwidth=None,
                            # h_trans_bandwidth=None,
                            # filter_length=str(sfreq) + "ms",
                            verbose=False,
                        ),
                    )
                )

        self.filter_names = [name for name, _ in self.list_filter]
        # Stack all FIR kernels, then replicate per channel so calc_feature can
        # run one vectorized fftconvolve over all (channel, band) pairs.
        # NOTE(review): np.vstack would fail on the None kernel of a
        # "no_filter" entry appended above — confirm mixed configurations.
        self.filters = np.vstack([filter for _, filter in self.list_filter])
        self.filters = np.tile(self.filters[None, :, :], (len(self.ch_names), 1, 1))

        self.used_features = self.sw_settings.sharpwave_features.get_enabled()

        # initializing estimator functions, respective for all sharpwave features
        # feature name -> {estimator name -> numpy reduction}
        self.estimator_dict: dict[str, dict[str, Callable]] = {
            feat: {
                est: ESTIMATOR_DICT[est]
                for est in self.sw_settings.estimator.keys()
                if feat in self.sw_settings.estimator[est]
            }
            for feat_list in self.sw_settings.estimator.values()
            for feat in feat_list
        }

        # All (feature, estimator name, estimator fn) triples to evaluate.
        estimator_combinations = [
            (feature_name, estimator_name, estimator)
            for feature_name in self.used_features
            for estimator_name, estimator in self.estimator_dict[feature_name].items()
        ]

        # Cartesian product: channel x filter band x polarity (False=peaks, True=troughs).
        filter_combinations = list(
            product(
                enumerate(self.ch_names), enumerate(self.filter_names), [False, True]
            )
        )

        self.estimator_key_map: dict[str, Callable] = {}
        self.combinations = []
        for (ch_idx, ch_name), (
            filter_idx,
            filter_name,
        ), detect_troughs in filter_combinations:
            for feature_name, estimator_name, estimator in estimator_combinations:
                key_name = f"{ch_name}_Sharpwave_{estimator_name.title()}_{feature_name}_{filter_name}"
                self.estimator_key_map[key_name] = estimator
                # NOTE(review): this append runs once per estimator combination,
                # so each (channel, filter, polarity) tuple appears
                # len(estimator_combinations) times in self.combinations and
                # calc_feature re-analyzes the same sub-signal that many times —
                # looks like unintended duplication; confirm.
                self.combinations.append(
                    (
                        (ch_idx, ch_name),
                        (filter_idx, filter_name),
                        detect_troughs,
                        estimator_combinations,
                    )
                )

        # Check required feature computations according to settings
        # prominence needs left/right peaks and the trough value.
        self.need_peak_left = (
            self.sw_settings.sharpwave_features.peak_left
            or self.sw_settings.sharpwave_features.prominence
        )
        self.need_peak_right = (
            self.sw_settings.sharpwave_features.peak_right
            or self.sw_settings.sharpwave_features.prominence
        )
        self.need_trough = (
            self.sw_settings.sharpwave_features.trough
            or self.sw_settings.sharpwave_features.prominence
        )

        # slope_ratio needs both steepness directions.
        self.need_decay_steepness = (
            self.sw_settings.sharpwave_features.decay_steepness
            or self.sw_settings.sharpwave_features.slope_ratio
        )

        self.need_rise_steepness = (
            self.sw_settings.sharpwave_features.rise_steepness
            or self.sw_settings.sharpwave_features.slope_ratio
        )

        self.need_steepness = self.need_rise_steepness or self.need_decay_steepness

    def calc_feature(self, data: np.ndarray) -> dict:
        """Filter a data batch and compute sharpwave features for it.

        All pre-initialized FIR filters are applied to every channel in one
        vectorized fftconvolve; peaks/troughs are then detected per
        (channel, band, polarity) combination and reduced by the configured
        estimators.

        Parameters
        ----------
        data : np.ndarray
            2d array with shape [num_channels, samples].

        Returns
        -------
        dict
            Feature name -> scalar value. If
            ``apply_estimator_between_peaks_and_troughs`` is set, the
            estimator is additionally applied across the Peak and Trough
            values of each key; otherwise both are emitted separately with
            an ``_analyze_Peak`` / ``_analyze_Trough`` suffix.
        """
        # key -> {"Peak"/"Trough" -> estimated value}
        dict_ch_features: dict[str, dict[str, float]] = defaultdict(lambda: {})

        from scipy.signal import fftconvolve

        # Replicate data per filter band, then filter all bands at once.
        data = np.tile(data[:, None, :], (1, len(self.list_filter), 1))
        data = fftconvolve(data, self.filters, axes=2, mode="same")

        self.filtered_data = (
            data  # TONI: Expose filtered data for example 3, need a better way
        )

        feature_results = {}

        for (
            (ch_idx, ch_name),
            (filter_idx, filter_name),
            detect_troughs,
            estimator_combinations,
        ) in self.combinations:
            sub_data = data[ch_idx, filter_idx, :]

            key_name_pt = "Trough" if detect_troughs else "Peak"

            # Skip polarities whose detection is disabled in the settings.
            if (not detect_troughs and not self.sw_settings.detect_peaks.estimate) or (
                detect_troughs and not self.sw_settings.detect_troughs.estimate
            ):
                continue

            # the detect_troughs loop start with peaks, s.t. data does not need to be flipped
            sub_data = -sub_data if detect_troughs else sub_data
            # sub_data *= 1 - 2 * detect_troughs # branchless version

            waveform_results = self.analyze_waveform(sub_data)

            # for each feature take the respective fun.
            for feature_name, estimator_name, estimator in estimator_combinations:
                feature_data = waveform_results[feature_name]
                key_name = f"{ch_name}_Sharpwave_{estimator_name.title()}_{feature_name}_{filter_name}"

                # zero check because no peaks can be also detected
                feature_data = estimator(feature_data) if len(feature_data) != 0 else 0
                dict_ch_features[key_name][key_name_pt] = feature_data

        if self.sw_settings.apply_estimator_between_peaks_and_troughs:
            # apply between 'Trough' and 'Peak' the respective function again
            # save only the 'est_fun' (e.g. max) between them

            # the key_name stays, since the estimator function stays between peaks and troughs
            # NOTE(review): indexes [0] and [1] assume BOTH a Peak and a Trough
            # entry were produced for every key; if one polarity's `estimate`
            # flag is disabled this raises IndexError — confirm.
            for key_name, estimator in self.estimator_key_map.items():
                feature_results[key_name] = estimator(
                    [
                        list(dict_ch_features[key_name].values())[0],
                        list(dict_ch_features[key_name].values())[1],
                    ]
                )
        else:
            # otherwise, save all write all "flattened" key value pairs in feature_results
            for key, subdict in dict_ch_features.items():
                for key_sub, value_sub in subdict.items():
                    feature_results[key + "_analyze_" + key_sub] = value_sub

        return feature_results

    def analyze_waveform(self, data) -> dict:
        """Detect troughs with flanking peaks in *data* and compute features.

        Given the scipy.signal.find_peaks trough/peak distance settings,
        the specified sharpwave features are estimated. Returns a dict of
        feature name -> per-event value array (or list).
        """

        from scipy.signal import find_peaks

        # TODO: find peaks is actually not that big a performance hit, but the rest
        # of this function is. Perhaps find_peaks can be put in a loop and the rest optimized somehow?
        # NOTE(review): both calls read the detect_troughs settings object —
        # the peaks call presumably should use detect_peaks; also `distance`
        # is in samples while the settings are named *_ms — confirm both.
        peak_idx: np.ndarray = find_peaks(
            data, distance=self.sw_settings.detect_troughs.distance_peaks_ms
        )[0]
        trough_idx: np.ndarray = find_peaks(
            -data, distance=self.sw_settings.detect_troughs.distance_troughs_ms
        )[0]

        """ Find left and right peak indexes for each trough """
        peak_pointer = first_valid = last_valid = 0
        peak_idx_left_list: list[int] = []
        peak_idx_right_list: list[int] = []

        for i in range(len(trough_idx)):
            # Locate peak right of current trough
            while (
                peak_pointer < peak_idx.size and peak_idx[peak_pointer] < trough_idx[i]
            ):
                peak_pointer += 1

            if peak_pointer - 1 < 0:
                # If trough has no peak to it's left, it's not valid
                first_valid = i + 1  # Try with next one
                continue

            if peak_pointer == peak_idx.size:
                # If we went past the end of the peaks list, trough had no peak to its right
                continue

            last_valid = i
            peak_idx_left_list.append(peak_idx[peak_pointer - 1])
            peak_idx_right_list.append(peak_idx[peak_pointer])

        # Remove non valid troughs and make array of left and right peaks for each trough
        trough_idx = trough_idx[first_valid : last_valid + 1]
        peak_idx_left = np.array(peak_idx_left_list, dtype=int)
        peak_idx_right = np.array(peak_idx_right_list, dtype=int)

        """ Calculate features (vectorized) """
        results: dict = {}

        if self.need_peak_left:
            results["peak_left"] = data[peak_idx_left]

        if self.need_peak_right:
            results["peak_right"] = data[peak_idx_right]

        if self.need_trough:
            results["trough"] = data[trough_idx]

        if self.sw_settings.sharpwave_features.interval:
            # inter-trough interval in ms; first event has no predecessor -> 0
            results["interval"] = np.concatenate((np.zeros(1), np.diff(trough_idx))) * (
                1000 / self.sfreq
            )

        if self.sw_settings.sharpwave_features.sharpness:
            # sharpess is calculated on a +- 5 ms window
            # valid troughs need 5 ms of margin on both sides
            troughs_valid = trough_idx[
                np.logical_and(
                    trough_idx - int(5 * (1000 / self.sfreq)) > 0,
                    trough_idx + int(5 * (1000 / self.sfreq)) < data.shape[0],
                )
            ]
            trough_height = data[troughs_valid]
            left_height = data[troughs_valid - int(5 * (1000 / self.sfreq))]
            right_height = data[troughs_valid + int(5 * (1000 / self.sfreq))]
            # results["sharpness"] = ((trough_height - left_height) + (trough_height - right_height)) / 2
            results["sharpness"] = trough_height - 0.5 * (left_height + right_height)

        if self.sw_settings.sharpwave_features.num_peaks:
            results["num_peaks"] = [
                trough_idx.shape[0]
            ]  # keep as list so the estimator can be applied

        if self.need_steepness:
            # steepness is calculated as the first derivative
            steepness: np.ndarray = np.concatenate((np.zeros(1), np.diff(data)))

            # Create an array with the rise and decay steepness for each trough
            # 0th dimension for rise/decay, 1st for trough index, 2nd for timepoint
            steepness_troughs = np.zeros((2, trough_idx.shape[0], steepness.shape[0]))
            if self.need_rise_steepness or self.need_decay_steepness:
                for i in range(len(trough_idx)):
                    steepness_troughs[
                        0, i, 0 : trough_idx[i] - peak_idx_left[i] + 1
                    ] = steepness[peak_idx_left[i] : trough_idx[i] + 1]
                    steepness_troughs[
                        1, i, 0 : peak_idx_right[i] - trough_idx[i] + 1
                    ] = steepness[trough_idx[i] : peak_idx_right[i] + 1]

            if self.need_rise_steepness:
                # left peak -> trough
                # + 1 due to python syntax, s.t. the last element is included
                results["rise_steepness"] = np.max(
                    np.abs(steepness_troughs[0, :, :]), axis=1
                )

            if self.need_decay_steepness:
                # trough -> right peak
                results["decay_steepness"] = np.max(
                    np.abs(steepness_troughs[1, :, :]), axis=1
                )

            if self.sw_settings.sharpwave_features.slope_ratio:
                results["slope_ratio"] = (
                    results["rise_steepness"] - results["decay_steepness"]
                )

        if self.sw_settings.sharpwave_features.prominence:
            # mean flank height minus trough value, per event
            results["prominence"] = np.abs(
                (results["peak_right"] + results["peak_left"]) / 2 - results["trough"]
            )

        if self.sw_settings.sharpwave_features.decay_time:
            # NOTE(review): peak_idx_left < trough_idx, so these values are
            # negative (or zero) — confirm the sign is intended.
            results["decay_time"] = (peak_idx_left - trough_idx) * (
                1000 / self.sfreq
            )  # ms

        if self.sw_settings.sharpwave_features.rise_time:
            results["rise_time"] = (peak_idx_right - trough_idx) * (
                1000 / self.sfreq
            )  # ms

        if self.sw_settings.sharpwave_features.width:
            # NOTE(review): unlike rise/decay_time, no 1000/sfreq scaling is
            # applied — this value is in samples despite the ms comment; confirm.
            results["width"] = peak_idx_right - peak_idx_left  # ms

        return results
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
from py_neuromodulation.utils.types import NMBaseModel
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from py_neuromodulation.stream.settings import NMSettings
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class KalmanSettings(NMBaseModel):
    """Parameters of the white-noise-acceleration Kalman filter.

    Tp is the prediction interval, sigma_w the process noise and sigma_v the
    measurement noise (see define_KF). frequency_bands lists the band names
    the filter is applied to.
    """

    Tp: float = 0.1
    sigma_w: float = 0.7
    sigma_v: float = 1.0
    frequency_bands: list[str] = [
        "theta",
        "alpha",
        "low_beta",
        "high_beta",
        "low_gamma",
        "high_gamma",
        "HFA",
    ]

    def validate_fbands(self, settings: "NMSettings") -> None:
        """Assert that every configured band exists in settings.frequency_ranges_hz."""
        known_bands = settings.frequency_ranges_hz
        assert all(band in known_bands for band in self.frequency_bands), (
            "Frequency bands for Kalman filter must also be specified in "
            "bandpass_filter_settings."
        )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def define_KF(Tp, sigma_w, sigma_v):
    """Define Kalman filter according to white noise acceleration model.

    See DOI: 10.1109/TBME.2009.2038990 for explanation
    See https://filterpy.readthedocs.io/en/latest/kalman/KalmanFilter.html#r64ca38088676-2 for implementation details

    Parameters
    ----------
    Tp : float
        prediction interval
    sigma_w : float
        process noise
    sigma_v : float
        measurement noise

    Returns
    -------
    filterpy.KalmanFilter
        initialized KalmanFilter object
    """
    from .kalman_filter_external import KalmanFilter

    kf = KalmanFilter(dim_x=2, dim_z=1)

    # State vector: the sensor signal and its first derivative,
    # evolved by a constant-velocity transition matrix.
    kf.x = np.array([0, 1])
    kf.F = np.array([[1, Tp], [0, 1]])
    # Only the signal itself is observed.
    kf.H = np.array([[1, 0]])
    kf.R = sigma_v

    # Process-noise covariance of the white-noise-acceleration model.
    w_sq = sigma_w**2
    kf.Q = np.array(
        [
            [w_sq * (Tp**3) / 3, w_sq * (Tp**2) / 2],
            [w_sq * (Tp**2) / 2, w_sq * Tp],
        ]
    )
    # NOTE(review): np.cov of these rows evaluates to [[0.5, -0.5], [-0.5, 0.5]],
    # not the identity — confirm this initial state covariance is intended.
    kf.P = np.cov([[1, 0], [0, 1]])
    return kf
|