py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
- py_neuromodulation/__init__.py +80 -13
- py_neuromodulation/{nm_RMAP.py → analysis/RMAP.py} +496 -531
- py_neuromodulation/analysis/__init__.py +4 -0
- py_neuromodulation/{nm_decode.py → analysis/decode.py} +918 -992
- py_neuromodulation/{nm_analysis.py → analysis/feature_reader.py} +994 -1074
- py_neuromodulation/{nm_plots.py → analysis/plots.py} +627 -612
- py_neuromodulation/{nm_stats.py → analysis/stats.py} +458 -480
- py_neuromodulation/data/README +6 -6
- py_neuromodulation/data/dataset_description.json +8 -8
- py_neuromodulation/data/participants.json +32 -32
- py_neuromodulation/data/participants.tsv +2 -2
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
- py_neuromodulation/default_settings.yaml +241 -0
- py_neuromodulation/features/__init__.py +31 -0
- py_neuromodulation/features/bandpower.py +165 -0
- py_neuromodulation/features/bispectra.py +157 -0
- py_neuromodulation/features/bursts.py +297 -0
- py_neuromodulation/features/coherence.py +255 -0
- py_neuromodulation/features/feature_processor.py +121 -0
- py_neuromodulation/features/fooof.py +142 -0
- py_neuromodulation/features/hjorth_raw.py +57 -0
- py_neuromodulation/features/linelength.py +21 -0
- py_neuromodulation/features/mne_connectivity.py +148 -0
- py_neuromodulation/features/nolds.py +94 -0
- py_neuromodulation/features/oscillatory.py +249 -0
- py_neuromodulation/features/sharpwaves.py +432 -0
- py_neuromodulation/filter/__init__.py +3 -0
- py_neuromodulation/filter/kalman_filter.py +67 -0
- py_neuromodulation/filter/kalman_filter_external.py +1890 -0
- py_neuromodulation/filter/mne_filter.py +128 -0
- py_neuromodulation/filter/notch_filter.py +93 -0
- py_neuromodulation/grid_cortex.tsv +40 -40
- py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
- py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
- py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
- py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/processing/__init__.py +10 -0
- py_neuromodulation/{nm_artifacts.py → processing/artifacts.py} +29 -25
- py_neuromodulation/processing/data_preprocessor.py +77 -0
- py_neuromodulation/processing/filter_preprocessing.py +78 -0
- py_neuromodulation/processing/normalization.py +175 -0
- py_neuromodulation/{nm_projection.py → processing/projection.py} +370 -394
- py_neuromodulation/{nm_rereference.py → processing/rereference.py} +97 -95
- py_neuromodulation/{nm_resample.py → processing/resample.py} +56 -50
- py_neuromodulation/stream/__init__.py +3 -0
- py_neuromodulation/stream/data_processor.py +325 -0
- py_neuromodulation/stream/generator.py +53 -0
- py_neuromodulation/stream/mnelsl_player.py +94 -0
- py_neuromodulation/stream/mnelsl_stream.py +120 -0
- py_neuromodulation/stream/settings.py +292 -0
- py_neuromodulation/stream/stream.py +427 -0
- py_neuromodulation/utils/__init__.py +2 -0
- py_neuromodulation/{nm_define_nmchannels.py → utils/channels.py} +305 -302
- py_neuromodulation/utils/database.py +149 -0
- py_neuromodulation/utils/io.py +378 -0
- py_neuromodulation/utils/keyboard.py +52 -0
- py_neuromodulation/utils/logging.py +66 -0
- py_neuromodulation/utils/types.py +251 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/METADATA +28 -33
- py_neuromodulation-0.0.6.dist-info/RECORD +89 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/WHEEL +1 -1
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/licenses/LICENSE +21 -21
- py_neuromodulation/FieldTrip.py +0 -589
- py_neuromodulation/_write_example_dataset_helper.py +0 -65
- py_neuromodulation/nm_EpochStream.py +0 -92
- py_neuromodulation/nm_IO.py +0 -417
- py_neuromodulation/nm_across_patient_decoding.py +0 -927
- py_neuromodulation/nm_bispectra.py +0 -168
- py_neuromodulation/nm_bursts.py +0 -198
- py_neuromodulation/nm_coherence.py +0 -205
- py_neuromodulation/nm_cohortwrapper.py +0 -435
- py_neuromodulation/nm_eval_timing.py +0 -239
- py_neuromodulation/nm_features.py +0 -116
- py_neuromodulation/nm_features_abc.py +0 -39
- py_neuromodulation/nm_filter.py +0 -219
- py_neuromodulation/nm_filter_preprocessing.py +0 -91
- py_neuromodulation/nm_fooof.py +0 -159
- py_neuromodulation/nm_generator.py +0 -37
- py_neuromodulation/nm_hjorth_raw.py +0 -73
- py_neuromodulation/nm_kalmanfilter.py +0 -58
- py_neuromodulation/nm_linelength.py +0 -33
- py_neuromodulation/nm_mne_connectivity.py +0 -112
- py_neuromodulation/nm_nolds.py +0 -93
- py_neuromodulation/nm_normalization.py +0 -214
- py_neuromodulation/nm_oscillatory.py +0 -448
- py_neuromodulation/nm_run_analysis.py +0 -435
- py_neuromodulation/nm_settings.json +0 -338
- py_neuromodulation/nm_settings.py +0 -68
- py_neuromodulation/nm_sharpwaves.py +0 -401
- py_neuromodulation/nm_stream_abc.py +0 -218
- py_neuromodulation/nm_stream_offline.py +0 -359
- py_neuromodulation/utils/_logging.py +0 -24
- py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
import numpy as np
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from py_neuromodulation.utils.types import NMFeature, NMBaseModel
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from py_neuromodulation import NMSettings
|
|
9
|
+
from mne.io import RawArray
|
|
10
|
+
from mne import Epochs
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MNEConnectivitySettings(NMBaseModel):
    """Options for the MNE-based connectivity feature.

    ``MNEConnectivity`` reads ``method`` and ``mode`` from
    ``settings.mne_connectivity_settings`` (presumably an instance of this
    model) and forwards them unchanged to
    ``mne_connectivity.spectral_connectivity_epochs``.
    """

    # Forwarded as the ``method=`` argument.
    method: str = "plv"
    # Forwarded as the ``mode=`` argument.
    mode: str = "multitaper"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class MNEConnectivity(NMFeature):
    """Band-averaged spectral connectivity computed with ``mne_connectivity``.

    Every batch is cut into fixed-length epochs, passed to
    ``spectral_connectivity_epochs`` and the returned spectra are averaged
    inside each band of ``settings.frequency_ranges_hz``.

    The channel pairs are hard-coded as (0, 2), (0, 3), (1, 2) and (1, 3),
    which presupposes at least four channels, and every result key starts
    with the literal prefix ``"ch1"`` regardless of the pair it describes.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: float,
    ) -> None:
        """Store the configuration and pre-build the MNE ``Info`` object.

        Args:
            settings: Global settings; ``mne_connectivity_settings`` and
                ``frequency_ranges_hz`` are read from it.
            ch_names: Channel names used to create the MNE ``Info``.
            sfreq: Sampling frequency in Hz.
        """
        from mne import create_info

        self.settings = settings
        self.ch_names = ch_names
        self.sfreq = sfreq

        # Arguments handed to spectral_connectivity_epochs on every call.
        conn_settings = settings.mne_connectivity_settings
        self.mode = conn_settings.mode
        self.method = conn_settings.method

        self.fbands = settings.frequency_ranges_hz
        # Frequency-bin indices per band; filled lazily on the first batch.
        self.fband_ranges: list = []
        self.result_keys = []

        self.raw_info = create_info(ch_names=self.ch_names, sfreq=self.sfreq)
        # Created on the first call to calc_feature and re-used afterwards.
        self.raw_array: "RawArray"
        self.epochs: "Epochs"
        # Sentinel shape that no real batch can have, forcing initialisation.
        self.prev_batch_shape: tuple = (-1, -1)

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute band-averaged connectivity for one batch.

        Args:
            data: Two-dimensional batch; ``data.shape[1]`` is the number of
                samples.

        Returns:
            Dict keyed ``"ch1_<method>_<pair index>_<band>"`` holding the
            mean connectivity of each channel pair within each band.

        Raises:
            ValueError: If the batch is shorter than one epoch.
            RuntimeError: If fewer than two epoch onsets fit into the batch.
        """
        from mne.io import RawArray
        from mne import Epochs
        from mne_connectivity import spectral_connectivity_epochs
        import pandas as pd

        epoch_length: float = 1  # TODO: Make this a parameter?
        batch_duration_s = data.shape[1] / self.sfreq

        if batch_duration_s < epoch_length:
            raise ValueError(
                f"the intended epoch length for mne connectivity: {epoch_length}s"
                f" are longer than the passed data array {np.round(batch_duration_s, 2)}s"
            )

        # TODO: If sfreq or channels change, do we re-initialize the whole Stream object?
        if data.shape == self.prev_batch_shape:
            # Same shape as the previous batch: keep the existing Epochs
            # object and only swap in the new samples.
            self.raw_array._data = data
            self.epochs._raw = self.raw_array
        else:
            # A different shape means re-selected channels or a shorter final
            # batch, so the MNE containers have to be rebuilt.
            self.raw_array = RawArray(
                data=data,
                info=self.raw_info,
                copy=None,  # type: ignore
                verbose=False,
            )

            # Hand-rolled replacement for
            # make_fixed_length_events(self.raw_array, duration=epoch_length).
            onsets = np.arange(
                0, data.shape[-1], self.sfreq * epoch_length, dtype=int
            )
            events = np.column_stack(
                (
                    onsets,
                    np.zeros_like(onsets, dtype=int),
                    np.ones_like(onsets, dtype=int),
                )
            )

            # With fewer than two epochs mne_connectivity is not correctly
            # initialized.
            n_events = events.shape[0]
            if n_events < 2:
                raise RuntimeError(
                    "A minimum of 2 epochs is required for mne_connectivity,"
                    f" got only {n_events}. Increase settings['segment_length_features_ms']"
                )

            self.epochs = Epochs(
                self.raw_array,
                events=events,
                event_id={"rest": 1},
                tmin=0,
                tmax=epoch_length,
                baseline=None,
                reject_by_annotation=True,
                verbose=False,
            )

            # HACK: pre-populating the metadata keeps
            # spectral_connectivity_epochs from calling
            # add_annotations_to_metadata.
            # TODO: this may need a fix in the mne_connectivity library.
            self.epochs._metadata = pd.DataFrame(index=np.arange(n_events))

        # n_jobs stays at its default of 1: setting up a multiprocessing pool
        # takes longer than most batch computations.
        spec_out = spectral_connectivity_epochs(
            data=self.epochs,
            sfreq=self.sfreq,
            method=self.method,
            mode=self.mode,
            indices=(np.array([0, 0, 1, 1]), np.array([2, 3, 2, 3])),
            verbose=False,
        )
        dat_conn: np.ndarray = spec_out.get_data()

        # The frequency-bin indices of each band are derived once, on the
        # first batch, and re-used for all later batches.
        if not self.fband_ranges:
            self.fband_ranges.extend(
                np.where(
                    (np.array(spec_out.freqs) > fband_range[0])
                    & (np.array(spec_out.freqs) < fband_range[1])
                )[0]
                for fband_range in self.fbands.values()
            )

        # TODO: Averaging over the entire band gives almost the same value
        # before normalization (0.9999999... vs 1.0), but some values change
        # wildly after normalization (-3 vs 0). Investigate whether this is a
        # bug in normalization.
        feature_results = {
            "_".join(["ch1", self.method, str(conn), fband]): np.mean(
                dat_conn[conn, self.fband_ranges[fband_idx]]
            )
            for conn in range(dat_conn.shape[0])
            for fband_idx, fband in enumerate(self.fbands)
        }

        # Remember the shape so the next call can decide whether the MNE
        # containers may be re-used.
        self.prev_batch_shape = data.shape

        return feature_results
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from collections.abc import Iterable
|
|
3
|
+
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from py_neuromodulation.utils.types import NMFeature, BoolSelector, NMBaseModel
|
|
7
|
+
|
|
8
|
+
from pydantic import field_validator
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from py_neuromodulation import NMSettings
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class NoldsFeatures(BoolSelector):
    """Switches selecting which nolds measures are computed.

    Each field name is one of the feature names dispatched by
    ``Nolds.calc_nolds_feature``.
    """

    # nolds.sampen
    sample_entropy: bool = False
    # nolds.corr_dim (called with emb_dim=2)
    correlation_dimension: bool = False
    # nolds.lyap_r -- the only measure enabled by default
    lyapunov_exponent: bool = True
    # nolds.hurst_rs
    hurst_exponent: bool = False
    # nolds.dfa
    detrended_fluctuation_analysis: bool = False
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class NoldsSettings(NMBaseModel):
    """Configuration of the nolds feature."""

    # Compute the measures on the unfiltered signal.
    raw: bool = True
    # Bands on which the measures are additionally computed after band-pass
    # filtering; each name must be a key of the global frequency ranges.
    frequency_bands: list[str] = ["low_beta"]
    # Which nolds measures to compute.
    features: NoldsFeatures = NoldsFeatures()

    @field_validator("frequency_bands")
    def fbands_spaces_to_underscores(cls, frequency_bands):
        """Return the band names with every space replaced by an underscore."""
        return [band_name.replace(" ", "_") for band_name in frequency_bands]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Nolds(NMFeature):
|
|
33
|
+
def __init__(
|
|
34
|
+
self, settings: "NMSettings", ch_names: Iterable[str], sfreq: float
|
|
35
|
+
) -> None:
|
|
36
|
+
self.settings = settings.nolds_settings
|
|
37
|
+
self.ch_names = ch_names
|
|
38
|
+
|
|
39
|
+
if len(self.settings.frequency_bands) > 0:
|
|
40
|
+
from py_neuromodulation.features.bandpower import BandPower
|
|
41
|
+
|
|
42
|
+
self.bp_filter = BandPower(settings, ch_names, sfreq, use_kf=False)
|
|
43
|
+
|
|
44
|
+
# Check if the selected frequency bands are defined in the global settings
|
|
45
|
+
for fb in settings.nolds_settings.frequency_bands:
|
|
46
|
+
assert (
|
|
47
|
+
fb in settings.frequency_ranges_hz
|
|
48
|
+
), f"{fb} selected in nolds_features, but not defined in s['frequency_ranges_hz']"
|
|
49
|
+
|
|
50
|
+
def calc_feature(self, data: np.ndarray) -> dict:
|
|
51
|
+
feature_results = {}
|
|
52
|
+
data = np.nan_to_num(data)
|
|
53
|
+
if self.settings.raw:
|
|
54
|
+
feature_results = self.calc_nolds(data, feature_results)
|
|
55
|
+
if len(self.settings.frequency_bands) > 0:
|
|
56
|
+
data_filt = self.bp_filter.bandpass_filter.filter_data(data)
|
|
57
|
+
|
|
58
|
+
for f_band_idx, f_band in enumerate(self.settings.frequency_bands):
|
|
59
|
+
# filter data now for a specific fband and pass to calc_nolds
|
|
60
|
+
feature_results = self.calc_nolds(
|
|
61
|
+
data_filt[:, f_band_idx, :], feature_results, f_band
|
|
62
|
+
) # ch, bands, samples
|
|
63
|
+
return feature_results
|
|
64
|
+
|
|
65
|
+
def calc_nolds(
|
|
66
|
+
self, data: np.ndarray, feature_results: dict, data_str: str = "raw"
|
|
67
|
+
) -> dict:
|
|
68
|
+
for ch_idx, ch_name in enumerate(self.ch_names):
|
|
69
|
+
for f_name in self.settings.features.get_enabled():
|
|
70
|
+
feature_results[f"{ch_name}_nolds_{f_name}_{data_str}"] = (
|
|
71
|
+
self.calc_nolds_feature(f_name, data[ch_idx, :])
|
|
72
|
+
if data[ch_idx, :].sum()
|
|
73
|
+
else 0
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
return feature_results
|
|
77
|
+
|
|
78
|
+
@staticmethod
|
|
79
|
+
def calc_nolds_feature(f_name: str, dat: np.ndarray):
|
|
80
|
+
import nolds
|
|
81
|
+
|
|
82
|
+
match f_name:
|
|
83
|
+
case "sample_entropy":
|
|
84
|
+
return nolds.sampen(dat)
|
|
85
|
+
case "correlation_dimension":
|
|
86
|
+
return nolds.corr_dim(dat, emb_dim=2)
|
|
87
|
+
case "lyapunov_exponent":
|
|
88
|
+
return nolds.lyap_r(dat)
|
|
89
|
+
case "hurst_exponent":
|
|
90
|
+
return nolds.hurst_rs(dat)
|
|
91
|
+
case "detrended_fluctuation_analysis":
|
|
92
|
+
return nolds.dfa(dat)
|
|
93
|
+
case _:
|
|
94
|
+
raise ValueError(f"Invalid nolds feature name: {f_name}")
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
import numpy as np
|
|
3
|
+
from itertools import product
|
|
4
|
+
|
|
5
|
+
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from py_neuromodulation.stream.settings import NMSettings
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class OscillatoryFeatures(BoolSelector):
    """Switches selecting the statistics computed for each frequency band.

    Every field name is a key of ``ESTIMATOR_DICT``, which supplies the
    NumPy function that computes the statistic.
    """

    # Only the mean is enabled by default.
    mean: bool = True
    median: bool = False
    std: bool = False
    max: bool = False
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OscillatorySettings(NMBaseModel):
    """Settings model read by the FFT, Welch and STFT features."""

    # Analysis window in milliseconds; ``OscillatoryFeature`` asserts that it
    # does not exceed ``segment_length_features_ms``.
    windowlength_ms: int = 1000
    # When true, ``np.log10`` is applied to the spectrum before the band
    # statistics are computed.
    log_transform: bool = True
    # Statistics computed for every frequency band.
    features: OscillatoryFeatures = OscillatoryFeatures(
        mean=True, median=False, std=False, max=False
    )
    # When true, one extra result per channel and frequency bin is returned.
    return_spectrum: bool = False
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Maps each statistic name that can be enabled in ``OscillatoryFeatures`` to
# the NaN-ignoring NumPy reduction of the same name (``np.nanmean`` etc.).
ESTIMATOR_DICT = {
    stat_name: getattr(np, f"nan{stat_name}")
    for stat_name in ("mean", "median", "std", "max")
}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class OscillatoryFeature(NMFeature):
    """Shared initialisation for the spectral features of this module.

    Subclasses must assign ``self.settings`` and ``self.osc_feature_name``
    *before* calling ``super().__init__``, because this constructor reads
    ``self.settings`` for its sanity check.
    """

    def __init__(
        self, settings: "NMSettings", ch_names: Sequence[str], sfreq: int
    ) -> None:
        """Validate the settings and store the common attributes.

        Args:
            settings: Global settings; ``frequency_ranges_hz`` and
                ``segment_length_features_ms`` are read from it.
            ch_names: Names of the channels, in the row order of the data.
            sfreq: Sampling frequency in Hz; stored as ``int``.

        Raises:
            AssertionError: If the feature's ``windowlength_ms`` exceeds
                ``settings.segment_length_features_ms``.
        """
        settings.validate()
        self.settings: OscillatorySettings  # Assignment in subclass __init__
        self.osc_feature_name: str  # Required for output

        self.sfreq = int(sfreq)
        self.ch_names = ch_names

        self.frequency_ranges = settings.frequency_ranges_hz

        # Test settings. The message is one plain string; the fragments carry
        # their own separating spaces.
        assert self.settings.windowlength_ms <= settings.segment_length_features_ms, (
            f"oscillatory feature windowlength_ms = ({self.settings.windowlength_ms}) "
            f"needs to be smaller than or equal to "
            f"settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
        )
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class FFT(OscillatoryFeature):
    """Band statistics of the FFT magnitude of the most recent window."""

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Sequence[str],
        sfreq: int,
    ) -> None:
        """Pre-compute the frequency axis, band indices and estimators.

        Args:
            settings: Global settings; ``fft_settings`` is used.
            ch_names: Names of the channels, in the row order of the data.
            sfreq: Sampling frequency in Hz.
        """
        from scipy.fft import rfftfreq

        # The base-class constructor reads self.settings, so both attributes
        # are assigned before it runs.
        self.osc_feature_name = "fft"
        self.settings = settings.fft_settings
        super().__init__(settings, ch_names, sfreq)

        n_window = int(np.floor(self.settings.windowlength_ms / 1000 * sfreq))
        # Kept negative so it can be used directly as the start of a
        # "last n samples" slice.
        self.window_samples = -n_window
        self.freqs = rfftfreq(n_window, 1 / np.floor(self.sfreq))

        # Frequency-bin indices of every band: lower edge inclusive, upper
        # edge exclusive.
        self.idx_range = []
        for band_name, band_limits in self.frequency_ranges.items():
            in_band = (self.freqs >= band_limits[0]) & (self.freqs < band_limits[1])
            self.idx_range.append((band_name, np.where(in_band)[0]))

        self.estimators = [
            (stat_name, ESTIMATOR_DICT[stat_name])
            for stat_name in self.settings.features.get_enabled()
        ]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute the enabled statistics of each band for each channel.

        Args:
            data: Batch indexed as ``data[channel, sample]``.

        Returns:
            Dict keyed ``"<channel>_fft_<band>_<statistic>"``; when
            ``return_spectrum`` is set, additionally
            ``"<channel>_fft_psd_<int(frequency)>"`` per frequency bin.
        """
        from scipy.fft import rfft

        # Only the trailing window of the batch is transformed.
        recent = data[:, self.window_samples :]
        spectrum = np.abs(rfft(recent))  # type: ignore

        if self.settings.log_transform:
            spectrum = np.log10(spectrum)

        feature_results = {}

        # TODO Can we get rid of this for-loop? Hard to vectorize windows of
        # different lengths...
        for band_name, band_idx in self.idx_range:
            band_spectrum = spectrum[:, band_idx]  # Data for all channels

            for stat_name, stat_fun in self.estimators:
                per_channel = stat_fun(band_spectrum, axis=1)

                for ch_idx, ch_name in enumerate(self.ch_names):
                    key = f"{ch_name}_{self.osc_feature_name}_{band_name}_{stat_name}"
                    feature_results[key] = per_channel[ch_idx]

        if self.settings.return_spectrum:
            # NOTE(review): the key uses int(freq), so bins that truncate to
            # the same integer overwrite each other when the frequency
            # resolution is finer than 1 Hz.
            for ch_idx, ch_name in enumerate(self.ch_names):
                for freq_idx, freq in enumerate(self.freqs):
                    feature_results[f"{ch_name}_fft_psd_{int(freq)}"] = spectrum[
                        ch_idx
                    ][freq_idx]

        return feature_results
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class Welch(OscillatoryFeature):
    """Band statistics of a Welch power spectral density estimate.

    The estimate always uses segments of ``sfreq`` samples (one second);
    ``windowlength_ms`` is only checked by the base class and is not used
    for the computation itself.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Sequence[str],
        sfreq: int,
    ) -> None:
        """Pre-compute the frequency axis, band indices and estimators.

        Args:
            settings: Global settings; ``welch_settings`` is used.
            ch_names: Names of the channels, in the row order of the data.
            sfreq: Sampling frequency in Hz.
        """
        from scipy.fft import rfftfreq

        # The base-class constructor reads self.settings, so both attributes
        # are assigned before it runs.
        self.osc_feature_name = "welch"
        self.settings = settings.welch_settings
        super().__init__(settings, ch_names, sfreq)

        # Frequency axis of a segment that is self.sfreq samples long.
        self.freqs = rfftfreq(self.sfreq, 1 / self.sfreq)

        # Frequency-bin indices of every band: lower edge inclusive, upper
        # edge exclusive.
        self.idx_range = []
        for band_name, band_limits in self.frequency_ranges.items():
            in_band = (self.freqs >= band_limits[0]) & (self.freqs < band_limits[1])
            self.idx_range.append((band_name, np.where(in_band)[0]))

        self.estimators = [
            (stat_name, ESTIMATOR_DICT[stat_name])
            for stat_name in self.settings.features.get_enabled()
        ]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute the enabled statistics of each band for each channel.

        Args:
            data: Batch indexed as ``data[channel, sample]``.

        Returns:
            Dict keyed ``"<channel>_welch_<band>_<statistic>"``; when
            ``return_spectrum`` is set, additionally
            ``"<channel>_welch_psd_<frequency>"`` per frequency bin.
        """
        from scipy.signal import welch

        # The frequency axis returned by welch is discarded; self.freqs,
        # computed in __init__, is used instead.
        _, psd = welch(
            data,
            fs=self.sfreq,
            window="hann",
            nperseg=self.sfreq,
            noverlap=None,
        )

        if self.settings.log_transform:
            psd = np.log10(psd)

        feature_results = {}

        for band_name, band_idx in self.idx_range:
            band_psd = psd[:, band_idx]

            for stat_name, stat_fun in self.estimators:
                per_channel = stat_fun(band_psd, axis=1)

                for ch_idx, ch_name in enumerate(self.ch_names):
                    key = f"{ch_name}_{self.osc_feature_name}_{band_name}_{stat_name}"
                    feature_results[key] = per_channel[ch_idx]

        if self.settings.return_spectrum:
            for ch_idx, ch_name in enumerate(self.ch_names):
                for freq_idx, freq in enumerate(self.freqs):
                    feature_results[f"{ch_name}_welch_psd_{str(freq)}"] = psd[ch_idx][
                        freq_idx
                    ]

        return feature_results
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class STFT(OscillatoryFeature):
    """Band statistics of the short-time Fourier transform magnitude."""

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Sequence[str],
        sfreq: int,
    ) -> None:
        """Pre-compute the frequency axis, band indices and estimators.

        Args:
            settings: Global settings; ``stft_settings`` is used.
            ch_names: Names of the channels, in the row order of the data.
            sfreq: Sampling frequency in Hz.
        """
        from scipy.fft import rfftfreq

        self.osc_feature_name = "stft"
        self.settings = settings.stft_settings
        # The base-class constructor reads self.settings, so both attributes
        # are assigned before it runs.
        super().__init__(settings, ch_names, sfreq)

        # NOTE(review): the millisecond value is used directly as the segment
        # length in samples; the two only coincide at sfreq == 1000 Hz.
        # Confirm whether a conversion via sfreq is intended.
        self.nperseg = self.settings.windowlength_ms

        self.freqs = rfftfreq(self.nperseg, 1 / self.sfreq)

        # Frequency-bin indices of every band.
        # NOTE(review): both band edges are inclusive here, whereas FFT and
        # Welch exclude the upper edge.
        self.idx_range = [
            (
                f_band,
                np.where((self.freqs >= f_range[0]) & (self.freqs <= f_range[1]))[0],
            )
            for f_band, f_range in self.frequency_ranges.items()
        ]

        self.estimators = [
            (est, ESTIMATOR_DICT[est]) for est in self.settings.features.get_enabled()
        ]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute the enabled statistics of each band for each channel.

        Args:
            data: Batch indexed as ``data[channel, sample]``.

        Returns:
            Dict keyed ``"<channel>_stft_<band>_<statistic>"``; when
            ``return_spectrum`` is set, additionally
            ``"<channel>_stft_psd_<frequency>"`` holding the spectrum
            averaged over the STFT segments.
        """
        from scipy.signal import stft

        _, _, Zxx = stft(
            data,
            fs=self.sfreq,
            window="hamming",
            nperseg=self.nperseg,
            boundary="even",
        )

        # Z is indexed as Z[channel, frequency bin, segment].
        Z = np.abs(Zxx)
        if self.settings.log_transform:
            Z = np.log10(Z)

        feature_results = {}

        for f_band_name, idx_range in self.idx_range:
            Z_band = Z[:, idx_range, :]

            for est_name, est_fun in self.estimators:
                # Reduce over the frequency bins and the segments at once.
                result = est_fun(Z_band, axis=(1, 2))

                for ch_idx, ch_name in enumerate(self.ch_names):
                    feature_results[
                        f"{ch_name}_{self.osc_feature_name}_{f_band_name}_{est_name}"
                    ] = result[ch_idx]

        if self.settings.return_spectrum:
            for ch_idx, ch_name in enumerate(self.ch_names):
                # The segment-averaged spectrum of a channel is computed once
                # here instead of once per frequency bin.
                ch_spectrum = Z[ch_idx].mean(axis=1)
                for idx, f in enumerate(self.freqs):
                    feature_results[f"{ch_name}_stft_psd_{str(f)}"] = ch_spectrum[idx]

        return feature_results
|