py-neuromodulation 0.0.7__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +0 -1
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +0 -2
- py_neuromodulation/__init__.py +12 -4
- py_neuromodulation/analysis/RMAP.py +3 -3
- py_neuromodulation/analysis/decode.py +55 -2
- py_neuromodulation/analysis/feature_reader.py +1 -0
- py_neuromodulation/analysis/stats.py +3 -3
- py_neuromodulation/default_settings.yaml +24 -17
- py_neuromodulation/features/bandpower.py +65 -23
- py_neuromodulation/features/bursts.py +9 -8
- py_neuromodulation/features/coherence.py +7 -4
- py_neuromodulation/features/feature_processor.py +4 -4
- py_neuromodulation/features/fooof.py +7 -6
- py_neuromodulation/features/mne_connectivity.py +25 -3
- py_neuromodulation/features/oscillatory.py +5 -4
- py_neuromodulation/features/sharpwaves.py +21 -0
- py_neuromodulation/filter/kalman_filter.py +17 -6
- py_neuromodulation/gui/__init__.py +3 -0
- py_neuromodulation/gui/backend/app_backend.py +419 -0
- py_neuromodulation/gui/backend/app_manager.py +345 -0
- py_neuromodulation/gui/backend/app_pynm.py +244 -0
- py_neuromodulation/gui/backend/app_socket.py +95 -0
- py_neuromodulation/gui/backend/app_utils.py +306 -0
- py_neuromodulation/gui/backend/app_window.py +202 -0
- py_neuromodulation/gui/frontend/assets/Figtree-VariableFont_wght-CkXbWBDP.ttf +0 -0
- py_neuromodulation/gui/frontend/assets/index-NbJiOU5a.js +300133 -0
- py_neuromodulation/gui/frontend/assets/plotly-DTCwMlpS.js +23594 -0
- py_neuromodulation/gui/frontend/charite.svg +16 -0
- py_neuromodulation/gui/frontend/index.html +14 -0
- py_neuromodulation/gui/window_api.py +115 -0
- py_neuromodulation/lsl_api.cfg +3 -0
- py_neuromodulation/processing/data_preprocessor.py +9 -2
- py_neuromodulation/processing/filter_preprocessing.py +43 -27
- py_neuromodulation/processing/normalization.py +32 -17
- py_neuromodulation/processing/projection.py +2 -2
- py_neuromodulation/processing/resample.py +6 -2
- py_neuromodulation/run_gui.py +36 -0
- py_neuromodulation/stream/__init__.py +7 -1
- py_neuromodulation/stream/backend_interface.py +47 -0
- py_neuromodulation/stream/data_processor.py +24 -3
- py_neuromodulation/stream/mnelsl_player.py +121 -21
- py_neuromodulation/stream/mnelsl_stream.py +9 -17
- py_neuromodulation/stream/settings.py +80 -34
- py_neuromodulation/stream/stream.py +82 -62
- py_neuromodulation/utils/channels.py +1 -1
- py_neuromodulation/utils/file_writer.py +110 -0
- py_neuromodulation/utils/io.py +46 -5
- py_neuromodulation/utils/perf.py +156 -0
- py_neuromodulation/utils/pydantic_extensions.py +322 -0
- py_neuromodulation/utils/types.py +33 -107
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/METADATA +18 -4
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/RECORD +55 -35
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/WHEEL +1 -1
- py_neuromodulation-0.1.0.dist-info/entry_points.txt +2 -0
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/licenses/LICENSE +0 -0
py_neuromodulation/__init__.py
CHANGED
|
@@ -4,11 +4,12 @@ from pathlib import PurePath
|
|
|
4
4
|
from importlib.metadata import version
|
|
5
5
|
from py_neuromodulation.utils.logging import NMLogger
|
|
6
6
|
|
|
7
|
+
|
|
7
8
|
#####################################
|
|
8
9
|
# Globals and environment variables #
|
|
9
10
|
#####################################
|
|
10
11
|
|
|
11
|
-
__version__ = version("py_neuromodulation")
|
|
12
|
+
__version__ = version("py_neuromodulation")
|
|
12
13
|
|
|
13
14
|
# Check if the module is running headless (no display) for tests and doc builds
|
|
14
15
|
PYNM_HEADLESS: bool = not os.environ.get("DISPLAY")
|
|
@@ -16,6 +17,9 @@ PYNM_DIR = PurePath(__file__).parent # Define constant for py_nm directory
|
|
|
16
17
|
|
|
17
18
|
os.environ["MPLBACKEND"] = "agg" if PYNM_HEADLESS else "qtagg" # Set matplotlib backend
|
|
18
19
|
|
|
20
|
+
os.environ["LSLAPICFG"] = str(PYNM_DIR / "lsl_api.cfg") # LSL config file
|
|
21
|
+
|
|
22
|
+
|
|
19
23
|
# Set environment variable MNE_LSL_LIB (required to import Stream below)
|
|
20
24
|
LSL_DICT = {
|
|
21
25
|
"windows_32bit": "windows/x86/liblsl.1.16.2.dll",
|
|
@@ -34,15 +38,17 @@ LSL_DICT = {
|
|
|
34
38
|
|
|
35
39
|
PLATFORM = platform.system().lower().strip()
|
|
36
40
|
ARCH = platform.architecture()[0]
|
|
41
|
+
|
|
37
42
|
match PLATFORM:
|
|
38
43
|
case "windows":
|
|
39
44
|
KEY = PLATFORM + "_" + ARCH
|
|
40
45
|
case "darwin":
|
|
41
46
|
KEY = PLATFORM + "_" + platform.processor()
|
|
42
47
|
case "linux":
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
48
|
+
if "VERSION_CODENAME" in platform.freedesktop_os_release().keys():
|
|
49
|
+
DIST = platform.freedesktop_os_release()["VERSION_CODENAME"]
|
|
50
|
+
KEY = PLATFORM + "_" + DIST + "_" + ARCH
|
|
51
|
+
else:
|
|
46
52
|
KEY = PLATFORM + "_" + ARCH
|
|
47
53
|
case _:
|
|
48
54
|
KEY = ""
|
|
@@ -78,3 +84,5 @@ from . import stream
|
|
|
78
84
|
from . import analysis
|
|
79
85
|
|
|
80
86
|
from .stream.settings import get_default_settings, get_fast_compute, reset_settings
|
|
87
|
+
|
|
88
|
+
from .gui.backend.app_manager import AppManager as App
|
|
@@ -8,8 +8,8 @@ import pandas as pd
|
|
|
8
8
|
import nibabel as nib
|
|
9
9
|
from matplotlib import pyplot as plt
|
|
10
10
|
|
|
11
|
-
from py_neuromodulation.
|
|
12
|
-
from py_neuromodulation.types import _PathLike
|
|
11
|
+
from py_neuromodulation.analysis import reg_plot
|
|
12
|
+
from py_neuromodulation.utils.types import _PathLike
|
|
13
13
|
from py_neuromodulation import PYNM_DIR
|
|
14
14
|
|
|
15
15
|
LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL = [256, 385, 417, 447, 819, 914]
|
|
@@ -335,7 +335,7 @@ class RMAPCross_Val_ChannelSelector:
|
|
|
335
335
|
reshape: bool = True,
|
|
336
336
|
):
|
|
337
337
|
if reshape:
|
|
338
|
-
fp = np.reshape(fp, (91, 109, 91), order="
|
|
338
|
+
fp = np.reshape(fp, (91, 109, 91), order="C")
|
|
339
339
|
|
|
340
340
|
img = nib.nifti1.Nifti1Image(fp, affine=affine)
|
|
341
341
|
|
|
@@ -6,14 +6,67 @@ from sklearn.metrics import r2_score
|
|
|
6
6
|
import pandas as pd
|
|
7
7
|
import numpy as np
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from pathlib import PurePath
|
|
9
|
+
from pathlib import Path, PurePath
|
|
10
10
|
import pickle
|
|
11
11
|
|
|
12
12
|
from py_neuromodulation import logger
|
|
13
|
+
from py_neuromodulation.utils.types import _PathLike
|
|
13
14
|
|
|
14
15
|
from typing import Callable
|
|
15
16
|
|
|
16
17
|
|
|
18
|
+
class RealTimeDecoder:
|
|
19
|
+
def __init__(self, model_path: _PathLike):
|
|
20
|
+
self.model_path = Path(model_path)
|
|
21
|
+
if not self.model_path.exists():
|
|
22
|
+
raise FileNotFoundError(f"Model file {self.model_path} not found")
|
|
23
|
+
if not self.model_path.is_file():
|
|
24
|
+
raise IsADirectoryError(f"Model file {self.model_path} is a directory")
|
|
25
|
+
|
|
26
|
+
if self.model_path.suffix == ".skops":
|
|
27
|
+
from skops import io as skops_io
|
|
28
|
+
|
|
29
|
+
self.model = skops_io.load(self.model_path)
|
|
30
|
+
else:
|
|
31
|
+
return NotImplementedError("Only skops models are supported")
|
|
32
|
+
|
|
33
|
+
def predict(
|
|
34
|
+
self,
|
|
35
|
+
feature_dict: dict,
|
|
36
|
+
channel: str | None = None,
|
|
37
|
+
fft_bands_only: bool = True,
|
|
38
|
+
) -> dict:
|
|
39
|
+
try:
|
|
40
|
+
if channel is not None:
|
|
41
|
+
features_ch = {
|
|
42
|
+
f: feature_dict[f]
|
|
43
|
+
for f in feature_dict.keys()
|
|
44
|
+
if f.startswith(channel)
|
|
45
|
+
}
|
|
46
|
+
if fft_bands_only is True:
|
|
47
|
+
features_ch_fft = {
|
|
48
|
+
f: features_ch[f]
|
|
49
|
+
for f in features_ch.keys()
|
|
50
|
+
if "fft" in f and "psd" not in f
|
|
51
|
+
}
|
|
52
|
+
out = self.model.predict_proba(
|
|
53
|
+
np.array(list(features_ch_fft.values())).reshape(1, -1)
|
|
54
|
+
)
|
|
55
|
+
else:
|
|
56
|
+
out = self.model.predict_proba(features_ch)
|
|
57
|
+
else:
|
|
58
|
+
out = self.model.predict(feature_dict)
|
|
59
|
+
for decode_output_idx in range(out.shape[1]):
|
|
60
|
+
feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[
|
|
61
|
+
decode_output_idx
|
|
62
|
+
]
|
|
63
|
+
logger.debug(f"Decoded values: {out}")
|
|
64
|
+
return feature_dict
|
|
65
|
+
except Exception as e:
|
|
66
|
+
logger.error(f"Error in decoding: {e}")
|
|
67
|
+
return feature_dict
|
|
68
|
+
|
|
69
|
+
|
|
17
70
|
class CV_res:
|
|
18
71
|
def __init__(
|
|
19
72
|
self,
|
|
@@ -168,7 +221,7 @@ class Decoder:
|
|
|
168
221
|
self.feature_names = [
|
|
169
222
|
col
|
|
170
223
|
for col in self.features.columns
|
|
171
|
-
if
|
|
224
|
+
if any(col.startswith(used_ch) for used_ch in self.used_chs)
|
|
172
225
|
]
|
|
173
226
|
self.data = np.nan_to_num(np.array(self.features[self.feature_names]))
|
|
174
227
|
|
|
@@ -51,6 +51,7 @@ class FeatureReader:
|
|
|
51
51
|
self.feature_file = feature_file if feature_file else self.feature_list[0]
|
|
52
52
|
|
|
53
53
|
FILE_BASENAME = PurePath(self.feature_file).stem
|
|
54
|
+
# the features are saved in a directory feature_file, and each file hold the name starting with feature_file
|
|
54
55
|
PATH_READ_FILE = str(PurePath(self.feature_dir, FILE_BASENAME, FILE_BASENAME))
|
|
55
56
|
|
|
56
57
|
self.settings = NMSettings.from_file(PATH_READ_FILE)
|
|
@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
|
|
|
6
6
|
# from numba import njit
|
|
7
7
|
import numpy as np
|
|
8
8
|
import pandas as pd
|
|
9
|
-
import scipy.stats as
|
|
9
|
+
import scipy.stats as scipy_stats
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
def fitlm_kfold(x, y, kfold_splits=5):
|
|
@@ -51,7 +51,7 @@ def permutationTestSpearmansRho(x, y, plot_distr=True, x_unit=None, p=5000):
|
|
|
51
51
|
"""
|
|
52
52
|
|
|
53
53
|
# compute ground truth difference
|
|
54
|
-
gT =
|
|
54
|
+
gT = scipy_stats.spearmanr(x, y)[0]
|
|
55
55
|
#
|
|
56
56
|
pV = np.array((x, y))
|
|
57
57
|
# Initialize permutation:
|
|
@@ -65,7 +65,7 @@ def permutationTestSpearmansRho(x, y, plot_distr=True, x_unit=None, p=5000):
|
|
|
65
65
|
random.shuffle(args_order_2)
|
|
66
66
|
# Compute permuted absolute difference of your two sampled
|
|
67
67
|
# distributions and store it in pD:
|
|
68
|
-
pD.append(
|
|
68
|
+
pD.append(scipy_stats.spearmanr(pV[0, args_order], pV[1, args_order_2])[0])
|
|
69
69
|
|
|
70
70
|
# calculate p value
|
|
71
71
|
if gT < 0:
|
|
@@ -1,4 +1,15 @@
|
|
|
1
|
-
|
|
1
|
+
# We
|
|
2
|
+
# should
|
|
3
|
+
# have
|
|
4
|
+
# a
|
|
5
|
+
# brief
|
|
6
|
+
# explanation
|
|
7
|
+
# of
|
|
8
|
+
# the
|
|
9
|
+
# settings
|
|
10
|
+
# format
|
|
11
|
+
# here
|
|
12
|
+
|
|
2
13
|
########################
|
|
3
14
|
### General settings ###
|
|
4
15
|
########################
|
|
@@ -51,12 +62,8 @@ preprocessing_filter:
|
|
|
51
62
|
lowpass_filter: true
|
|
52
63
|
highpass_filter: true
|
|
53
64
|
bandpass_filter: true
|
|
54
|
-
bandstop_filter_settings:
|
|
55
|
-
|
|
56
|
-
frequency_high_hz: 160
|
|
57
|
-
bandpass_filter_settings:
|
|
58
|
-
frequency_low_hz: 3
|
|
59
|
-
frequency_high_hz: 200
|
|
65
|
+
bandstop_filter_settings: [100, 160] # [low_hz, high_hz]
|
|
66
|
+
bandpass_filter_settings: [3, 200] # [hz, _hz]
|
|
60
67
|
lowpass_filter_cutoff_hz: 200
|
|
61
68
|
highpass_filter_cutoff_hz: 3
|
|
62
69
|
|
|
@@ -71,6 +78,7 @@ postprocessing:
|
|
|
71
78
|
feature_normalization_settings:
|
|
72
79
|
normalization_time_s: 30
|
|
73
80
|
normalization_method: zscore # supported methods: mean, median, zscore, zscore-median, quantile, power, robust, minmax
|
|
81
|
+
normalize_psd: false
|
|
74
82
|
clip: 3
|
|
75
83
|
|
|
76
84
|
project_cortex_settings:
|
|
@@ -91,7 +99,7 @@ fft_settings:
|
|
|
91
99
|
median: false
|
|
92
100
|
std: false
|
|
93
101
|
max: false
|
|
94
|
-
return_spectrum:
|
|
102
|
+
return_spectrum: true
|
|
95
103
|
|
|
96
104
|
welch_settings:
|
|
97
105
|
windowlength_ms: 1000
|
|
@@ -101,7 +109,7 @@ welch_settings:
|
|
|
101
109
|
median: false
|
|
102
110
|
std: false
|
|
103
111
|
max: false
|
|
104
|
-
return_spectrum:
|
|
112
|
+
return_spectrum: true
|
|
105
113
|
|
|
106
114
|
stft_settings:
|
|
107
115
|
windowlength_ms: 500
|
|
@@ -136,7 +144,7 @@ kalman_filter_settings:
|
|
|
136
144
|
frequency_bands:
|
|
137
145
|
[theta, alpha, low_beta, high_beta, low_gamma, high_gamma, HFA]
|
|
138
146
|
|
|
139
|
-
|
|
147
|
+
bursts_settings:
|
|
140
148
|
threshold: 75
|
|
141
149
|
time_duration_s: 30
|
|
142
150
|
frequency_bands: [low_beta, high_beta, low_gamma]
|
|
@@ -161,11 +169,9 @@ sharpwave_analysis_settings:
|
|
|
161
169
|
rise_steepness: false
|
|
162
170
|
decay_steepness: false
|
|
163
171
|
slope_ratio: false
|
|
164
|
-
filter_ranges_hz:
|
|
165
|
-
-
|
|
166
|
-
|
|
167
|
-
- frequency_low_hz: 5
|
|
168
|
-
frequency_high_hz: 30
|
|
172
|
+
filter_ranges_hz: # list of [low_hz, high_hz]
|
|
173
|
+
- [5, 80]
|
|
174
|
+
- [5, 30]
|
|
169
175
|
detect_troughs:
|
|
170
176
|
estimate: true
|
|
171
177
|
distance_troughs_ms: 10
|
|
@@ -174,6 +180,7 @@ sharpwave_analysis_settings:
|
|
|
174
180
|
estimate: true
|
|
175
181
|
distance_troughs_ms: 5
|
|
176
182
|
distance_peaks_ms: 10
|
|
183
|
+
# TONI: Reverse this setting? e.g. interval: [mean, var]
|
|
177
184
|
estimator:
|
|
178
185
|
mean: [interval]
|
|
179
186
|
median: []
|
|
@@ -223,8 +230,8 @@ nolds_settings:
|
|
|
223
230
|
frequency_bands: [low_beta]
|
|
224
231
|
|
|
225
232
|
mne_connectiviy_settings:
|
|
226
|
-
method: plv
|
|
227
|
-
mode: multitaper
|
|
233
|
+
method: plv # One of ['coh', 'cohy', 'imcoh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli','wpli', 'wpli2_debiased', 'gc', 'gc_tr']
|
|
234
|
+
mode: multitaper # One of ['multitaper', 'fourier', 'cwt_morlet']
|
|
228
235
|
|
|
229
236
|
bispectrum_settings:
|
|
230
237
|
f1s: [5, 35]
|
|
@@ -2,8 +2,13 @@ import numpy as np
|
|
|
2
2
|
from collections.abc import Sequence
|
|
3
3
|
from typing import TYPE_CHECKING
|
|
4
4
|
from pydantic import field_validator
|
|
5
|
-
|
|
6
5
|
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
|
|
6
|
+
from py_neuromodulation.utils.pydantic_extensions import (
|
|
7
|
+
NMField,
|
|
8
|
+
NMErrorList,
|
|
9
|
+
create_validation_error,
|
|
10
|
+
)
|
|
11
|
+
from py_neuromodulation import logger
|
|
7
12
|
|
|
8
13
|
if TYPE_CHECKING:
|
|
9
14
|
from py_neuromodulation.stream.settings import NMSettings
|
|
@@ -17,15 +22,18 @@ class BandpowerFeatures(BoolSelector):
|
|
|
17
22
|
|
|
18
23
|
|
|
19
24
|
class BandPowerSettings(NMBaseModel):
|
|
20
|
-
segment_lengths_ms: dict[str, int] =
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
25
|
+
segment_lengths_ms: dict[str, int] = NMField(
|
|
26
|
+
default={
|
|
27
|
+
"theta": 1000,
|
|
28
|
+
"alpha": 500,
|
|
29
|
+
"low beta": 333,
|
|
30
|
+
"high beta": 333,
|
|
31
|
+
"low gamma": 100,
|
|
32
|
+
"high gamma": 100,
|
|
33
|
+
"HFA": 100,
|
|
34
|
+
},
|
|
35
|
+
custom_metadata={"field_type": "FrequencySegmentLength"},
|
|
36
|
+
)
|
|
29
37
|
bandpower_features: BandpowerFeatures = BandpowerFeatures()
|
|
30
38
|
log_transform: bool = True
|
|
31
39
|
kalman_filter: bool = False
|
|
@@ -33,24 +41,58 @@ class BandPowerSettings(NMBaseModel):
|
|
|
33
41
|
@field_validator("bandpower_features")
|
|
34
42
|
@classmethod
|
|
35
43
|
def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
44
|
+
if not len(bandpower_features.get_enabled()) > 0:
|
|
45
|
+
raise create_validation_error(
|
|
46
|
+
error_message="Set at least one bandpower_feature to True.",
|
|
47
|
+
location=["bandpass_filter_settings", "bandpower_features"],
|
|
48
|
+
)
|
|
40
49
|
return bandpower_features
|
|
41
50
|
|
|
42
|
-
def validate_fbands(self, settings: "NMSettings") ->
|
|
51
|
+
def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
|
|
52
|
+
"""_summary_
|
|
53
|
+
|
|
54
|
+
:param settings: _description_
|
|
55
|
+
:type settings: NMSettings
|
|
56
|
+
:raises create_validation_error: _description_
|
|
57
|
+
:raises create_validation_error: _description_
|
|
58
|
+
:raises ValueError: _description_
|
|
59
|
+
"""
|
|
60
|
+
errors: NMErrorList = NMErrorList()
|
|
61
|
+
|
|
43
62
|
for fband_name, seg_length_fband in self.segment_lengths_ms.items():
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
63
|
+
# Check that all frequency bands are defined in settings.frequency_ranges_hz
|
|
64
|
+
if fband_name not in settings.frequency_ranges_hz:
|
|
65
|
+
logger.warning(
|
|
66
|
+
f"Frequency band {fband_name} in bandpass_filter_settings.segment_lengths_ms"
|
|
67
|
+
" is not defined in settings.frequency_ranges_hz"
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
# Check that all segment lengths are smaller than settings.segment_length_features_ms
|
|
71
|
+
if not seg_length_fband <= settings.segment_length_features_ms:
|
|
72
|
+
errors.add_error(
|
|
73
|
+
f"segment length {seg_length_fband} needs to be smaller than "
|
|
74
|
+
f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}",
|
|
75
|
+
location=[
|
|
76
|
+
"bandpass_filter_settings",
|
|
77
|
+
"segment_lengths_ms",
|
|
78
|
+
fband_name,
|
|
79
|
+
],
|
|
80
|
+
)
|
|
48
81
|
|
|
82
|
+
# Check that all frequency bands defined in settings.frequency_ranges_hz
|
|
49
83
|
for fband_name in settings.frequency_ranges_hz.keys():
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
84
|
+
if fband_name not in self.segment_lengths_ms:
|
|
85
|
+
errors.add_error(
|
|
86
|
+
f"frequency range {fband_name} "
|
|
87
|
+
"needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms",
|
|
88
|
+
location=[
|
|
89
|
+
"bandpass_filter_settings",
|
|
90
|
+
"segment_lengths_ms",
|
|
91
|
+
fband_name,
|
|
92
|
+
],
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
return errors
|
|
54
96
|
|
|
55
97
|
|
|
56
98
|
class BandPower(NMFeature):
|
|
@@ -7,11 +7,12 @@ else:
|
|
|
7
7
|
from collections.abc import Sequence
|
|
8
8
|
from itertools import product
|
|
9
9
|
|
|
10
|
-
from pydantic import
|
|
10
|
+
from pydantic import field_validator
|
|
11
11
|
from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature
|
|
12
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
12
13
|
|
|
13
14
|
from typing import TYPE_CHECKING, Callable
|
|
14
|
-
from py_neuromodulation.utils.
|
|
15
|
+
from py_neuromodulation.utils.pydantic_extensions import create_validation_error
|
|
15
16
|
|
|
16
17
|
if TYPE_CHECKING:
|
|
17
18
|
from py_neuromodulation import NMSettings
|
|
@@ -46,8 +47,8 @@ class BurstFeatures(BoolSelector):
|
|
|
46
47
|
|
|
47
48
|
|
|
48
49
|
class BurstsSettings(NMBaseModel):
|
|
49
|
-
threshold: float =
|
|
50
|
-
time_duration_s: float =
|
|
50
|
+
threshold: float = NMField(default=75, ge=0)
|
|
51
|
+
time_duration_s: float = NMField(default=30, ge=0, custom_metadata={"unit": "s"})
|
|
51
52
|
frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"]
|
|
52
53
|
burst_features: BurstFeatures = BurstFeatures()
|
|
53
54
|
|
|
@@ -64,16 +65,16 @@ class Bursts(NMFeature):
|
|
|
64
65
|
settings.validate()
|
|
65
66
|
|
|
66
67
|
# Validate that all frequency bands are defined in the settings
|
|
67
|
-
for fband_burst in settings.
|
|
68
|
+
for fband_burst in settings.bursts_settings.frequency_bands:
|
|
68
69
|
if fband_burst not in list(settings.frequency_ranges_hz.keys()):
|
|
69
70
|
raise create_validation_error(
|
|
70
71
|
f"bursting {fband_burst} needs to be defined in settings['frequency_ranges_hz']",
|
|
71
|
-
|
|
72
|
+
location=["burst_settings", "frequency_bands"],
|
|
72
73
|
)
|
|
73
74
|
|
|
74
75
|
from py_neuromodulation.filter import MNEFilter
|
|
75
76
|
|
|
76
|
-
self.settings = settings.
|
|
77
|
+
self.settings = settings.bursts_settings
|
|
77
78
|
self.sfreq = sfreq
|
|
78
79
|
self.ch_names = ch_names
|
|
79
80
|
self.segment_length_features_s = settings.segment_length_features_ms / 1000
|
|
@@ -83,7 +84,7 @@ class Bursts(NMFeature):
|
|
|
83
84
|
/ settings.sampling_rate_features_hz
|
|
84
85
|
)
|
|
85
86
|
|
|
86
|
-
self.fband_names = settings.
|
|
87
|
+
self.fband_names = settings.bursts_settings.frequency_bands
|
|
87
88
|
|
|
88
89
|
f_ranges: list[tuple[float, float]] = [
|
|
89
90
|
(
|
|
@@ -3,7 +3,7 @@ from collections.abc import Iterable
|
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
from typing import TYPE_CHECKING, Annotated
|
|
6
|
-
from pydantic import
|
|
6
|
+
from pydantic import field_validator
|
|
7
7
|
|
|
8
8
|
from py_neuromodulation.utils.types import (
|
|
9
9
|
NMFeature,
|
|
@@ -11,6 +11,7 @@ from py_neuromodulation.utils.types import (
|
|
|
11
11
|
FrequencyRange,
|
|
12
12
|
NMBaseModel,
|
|
13
13
|
)
|
|
14
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
14
15
|
from py_neuromodulation import logger
|
|
15
16
|
|
|
16
17
|
if TYPE_CHECKING:
|
|
@@ -28,15 +29,17 @@ class CoherenceFeatures(BoolSelector):
|
|
|
28
29
|
max_allfbands: bool = True
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
|
|
32
|
+
# TODO: make this into a pydantic model that only accepts names from
|
|
33
|
+
# the channels objects and does not accept the same string twice
|
|
34
|
+
ListOfTwoStr = Annotated[list[str], NMField(min_length=2, max_length=2)]
|
|
32
35
|
|
|
33
36
|
|
|
34
37
|
class CoherenceSettings(NMBaseModel):
|
|
35
38
|
features: CoherenceFeatures = CoherenceFeatures()
|
|
36
39
|
method: CoherenceMethods = CoherenceMethods()
|
|
37
40
|
channels: list[ListOfTwoStr] = []
|
|
38
|
-
nperseg: int =
|
|
39
|
-
frequency_bands: list[str] =
|
|
41
|
+
nperseg: int = NMField(default=256, ge=1)
|
|
42
|
+
frequency_bands: list[str] = NMField(default=["high_beta"], min_length=1)
|
|
40
43
|
|
|
41
44
|
@field_validator("frequency_bands")
|
|
42
45
|
def fbands_spaces_to_underscores(cls, frequency_bands):
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
from typing import Type, TYPE_CHECKING
|
|
2
2
|
|
|
3
|
-
from py_neuromodulation.utils.types import NMFeature,
|
|
3
|
+
from py_neuromodulation.utils.types import NMFeature, FEATURE_NAME
|
|
4
4
|
|
|
5
5
|
if TYPE_CHECKING:
|
|
6
6
|
import numpy as np
|
|
7
7
|
from py_neuromodulation import NMSettings
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
FEATURE_DICT: dict[
|
|
10
|
+
FEATURE_DICT: dict[FEATURE_NAME | str, str] = {
|
|
11
11
|
"raw_hjorth": "Hjorth",
|
|
12
12
|
"return_raw": "Raw",
|
|
13
13
|
"bandpass_filter": "BandPower",
|
|
@@ -42,7 +42,7 @@ class FeatureProcessors:
|
|
|
42
42
|
from importlib import import_module
|
|
43
43
|
|
|
44
44
|
# Accept 'str' for custom features
|
|
45
|
-
self.features: dict[
|
|
45
|
+
self.features: dict[FEATURE_NAME | str, NMFeature] = {
|
|
46
46
|
feature_name: getattr(
|
|
47
47
|
import_module("py_neuromodulation.features"), FEATURE_DICT[feature_name]
|
|
48
48
|
)(settings, ch_names, sfreq)
|
|
@@ -83,7 +83,7 @@ class FeatureProcessors:
|
|
|
83
83
|
|
|
84
84
|
return feature_results
|
|
85
85
|
|
|
86
|
-
def get_feature(self, fname:
|
|
86
|
+
def get_feature(self, fname: FEATURE_NAME) -> NMFeature:
|
|
87
87
|
return self.features[fname]
|
|
88
88
|
|
|
89
89
|
|
|
@@ -9,6 +9,7 @@ from py_neuromodulation.utils.types import (
|
|
|
9
9
|
BoolSelector,
|
|
10
10
|
FrequencyRange,
|
|
11
11
|
)
|
|
12
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
12
13
|
|
|
13
14
|
if TYPE_CHECKING:
|
|
14
15
|
from py_neuromodulation import NMSettings
|
|
@@ -29,11 +30,11 @@ class FooofPeriodicSettings(BoolSelector):
|
|
|
29
30
|
class FooofSettings(NMBaseModel):
|
|
30
31
|
aperiodic: FooofAperiodicSettings = FooofAperiodicSettings()
|
|
31
32
|
periodic: FooofPeriodicSettings = FooofPeriodicSettings()
|
|
32
|
-
windowlength_ms: float = 800
|
|
33
|
+
windowlength_ms: float = NMField(800, gt=0, custom_metadata={"unit": "ms"})
|
|
33
34
|
peak_width_limits: FrequencyRange = FrequencyRange(0.5, 12)
|
|
34
|
-
max_n_peaks: int = 3
|
|
35
|
-
min_peak_height: float = 0
|
|
36
|
-
peak_threshold: float = 2
|
|
35
|
+
max_n_peaks: int = NMField(3, ge=0)
|
|
36
|
+
min_peak_height: float = NMField(0, ge=0)
|
|
37
|
+
peak_threshold: float = NMField(2, ge=0)
|
|
37
38
|
freq_range_hz: FrequencyRange = FrequencyRange(2, 40)
|
|
38
39
|
knee: bool = True
|
|
39
40
|
|
|
@@ -63,12 +64,12 @@ class FooofAnalyzer(NMFeature):
|
|
|
63
64
|
|
|
64
65
|
assert (
|
|
65
66
|
self.settings.windowlength_ms <= settings.segment_length_features_ms
|
|
66
|
-
), f"fooof windowlength_ms ({settings.
|
|
67
|
+
), f"fooof windowlength_ms ({settings.fooof_settings.windowlength_ms}) needs to be smaller equal than segment_length_features_ms ({settings.segment_length_features_ms})."
|
|
67
68
|
|
|
68
69
|
assert (
|
|
69
70
|
self.settings.freq_range_hz[0] < sfreq
|
|
70
71
|
and self.settings.freq_range_hz[1] < sfreq
|
|
71
|
-
), f"fooof frequency range needs to be below sfreq, got {settings.
|
|
72
|
+
), f"fooof frequency range needs to be below sfreq, got {settings.fooof_settings.freq_range_hz}"
|
|
72
73
|
|
|
73
74
|
from fooof import FOOOFGroup
|
|
74
75
|
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from collections.abc import Iterable
|
|
2
2
|
import numpy as np
|
|
3
|
-
from typing import TYPE_CHECKING
|
|
3
|
+
from typing import TYPE_CHECKING, Literal
|
|
4
4
|
|
|
5
5
|
from py_neuromodulation.utils.types import NMFeature, NMBaseModel
|
|
6
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
6
7
|
|
|
7
8
|
if TYPE_CHECKING:
|
|
8
9
|
from py_neuromodulation import NMSettings
|
|
@@ -10,9 +11,30 @@ if TYPE_CHECKING:
|
|
|
10
11
|
from mne import Epochs
|
|
11
12
|
|
|
12
13
|
|
|
14
|
+
MNE_CONNECTIVITY_METHOD = Literal[
|
|
15
|
+
"coh",
|
|
16
|
+
"cohy",
|
|
17
|
+
"imcoh",
|
|
18
|
+
"cacoh",
|
|
19
|
+
"mic",
|
|
20
|
+
"mim",
|
|
21
|
+
"plv",
|
|
22
|
+
"ciplv",
|
|
23
|
+
"ppc",
|
|
24
|
+
"pli",
|
|
25
|
+
"dpli",
|
|
26
|
+
"wpli",
|
|
27
|
+
"wpli2_debiased",
|
|
28
|
+
"gc",
|
|
29
|
+
"gc_tr",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
MNE_CONNECTIVITY_MODE = Literal["multitaper", "fourier", "cwt_morlet"]
|
|
33
|
+
|
|
34
|
+
|
|
13
35
|
class MNEConnectivitySettings(NMBaseModel):
|
|
14
|
-
method:
|
|
15
|
-
mode:
|
|
36
|
+
method: MNE_CONNECTIVITY_METHOD = NMField(default="plv")
|
|
37
|
+
mode: MNE_CONNECTIVITY_MODE = NMField(default="multitaper")
|
|
16
38
|
|
|
17
39
|
|
|
18
40
|
class MNEConnectivity(NMFeature):
|
|
@@ -3,6 +3,7 @@ import numpy as np
|
|
|
3
3
|
from itertools import product
|
|
4
4
|
|
|
5
5
|
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
|
|
6
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
6
7
|
from typing import TYPE_CHECKING
|
|
7
8
|
|
|
8
9
|
if TYPE_CHECKING:
|
|
@@ -17,12 +18,12 @@ class OscillatoryFeatures(BoolSelector):
|
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class OscillatorySettings(NMBaseModel):
|
|
20
|
-
windowlength_ms: int = 1000
|
|
21
|
+
windowlength_ms: int = NMField(1000, gt=0, custom_metadata={"unit": "ms"})
|
|
21
22
|
log_transform: bool = True
|
|
22
23
|
features: OscillatoryFeatures = OscillatoryFeatures(
|
|
23
24
|
mean=True, median=False, std=False, max=False
|
|
24
25
|
)
|
|
25
|
-
return_spectrum: bool =
|
|
26
|
+
return_spectrum: bool = True
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
ESTIMATOR_DICT = {
|
|
@@ -176,7 +177,7 @@ class Welch(OscillatoryFeature):
|
|
|
176
177
|
if self.settings.return_spectrum:
|
|
177
178
|
combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
|
|
178
179
|
for (ch_idx, ch_name), (idx, f) in combinations:
|
|
179
|
-
feature_results[f"{ch_name}_welch_psd_{
|
|
180
|
+
feature_results[f"{ch_name}_welch_psd_{int(f)}"] = Z[ch_idx][idx]
|
|
180
181
|
|
|
181
182
|
return feature_results
|
|
182
183
|
|
|
@@ -242,7 +243,7 @@ class STFT(OscillatoryFeature):
|
|
|
242
243
|
if self.settings.return_spectrum:
|
|
243
244
|
combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
|
|
244
245
|
for (ch_idx, ch_name), (idx, f) in combinations:
|
|
245
|
-
feature_results[f"{ch_name}_stft_psd_{
|
|
246
|
+
feature_results[f"{ch_name}_stft_psd_{int(f)}"] = Z[ch_idx].mean(
|
|
246
247
|
axis=1
|
|
247
248
|
)[idx]
|
|
248
249
|
|