py-neuromodulation 0.0.7__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +0 -1
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +0 -2
- py_neuromodulation/__init__.py +12 -4
- py_neuromodulation/analysis/RMAP.py +3 -3
- py_neuromodulation/analysis/decode.py +55 -2
- py_neuromodulation/analysis/feature_reader.py +1 -0
- py_neuromodulation/analysis/stats.py +3 -3
- py_neuromodulation/default_settings.yaml +25 -20
- py_neuromodulation/features/bandpower.py +65 -23
- py_neuromodulation/features/bursts.py +9 -8
- py_neuromodulation/features/coherence.py +7 -4
- py_neuromodulation/features/feature_processor.py +4 -4
- py_neuromodulation/features/fooof.py +7 -6
- py_neuromodulation/features/mne_connectivity.py +60 -87
- py_neuromodulation/features/oscillatory.py +5 -4
- py_neuromodulation/features/sharpwaves.py +21 -0
- py_neuromodulation/filter/kalman_filter.py +17 -6
- py_neuromodulation/gui/__init__.py +3 -0
- py_neuromodulation/gui/backend/app_backend.py +419 -0
- py_neuromodulation/gui/backend/app_manager.py +345 -0
- py_neuromodulation/gui/backend/app_pynm.py +253 -0
- py_neuromodulation/gui/backend/app_socket.py +97 -0
- py_neuromodulation/gui/backend/app_utils.py +306 -0
- py_neuromodulation/gui/backend/app_window.py +202 -0
- py_neuromodulation/gui/frontend/assets/Figtree-VariableFont_wght-CkXbWBDP.ttf +0 -0
- py_neuromodulation/gui/frontend/assets/index-_6V8ZfAS.js +300137 -0
- py_neuromodulation/gui/frontend/assets/plotly-DTCwMlpS.js +23594 -0
- py_neuromodulation/gui/frontend/charite.svg +16 -0
- py_neuromodulation/gui/frontend/index.html +14 -0
- py_neuromodulation/gui/window_api.py +115 -0
- py_neuromodulation/lsl_api.cfg +3 -0
- py_neuromodulation/processing/data_preprocessor.py +9 -2
- py_neuromodulation/processing/filter_preprocessing.py +43 -27
- py_neuromodulation/processing/normalization.py +32 -17
- py_neuromodulation/processing/projection.py +2 -2
- py_neuromodulation/processing/resample.py +6 -2
- py_neuromodulation/run_gui.py +36 -0
- py_neuromodulation/stream/__init__.py +7 -1
- py_neuromodulation/stream/backend_interface.py +47 -0
- py_neuromodulation/stream/data_processor.py +24 -3
- py_neuromodulation/stream/mnelsl_player.py +121 -21
- py_neuromodulation/stream/mnelsl_stream.py +9 -17
- py_neuromodulation/stream/settings.py +80 -34
- py_neuromodulation/stream/stream.py +83 -62
- py_neuromodulation/utils/channels.py +1 -1
- py_neuromodulation/utils/file_writer.py +110 -0
- py_neuromodulation/utils/io.py +46 -5
- py_neuromodulation/utils/perf.py +156 -0
- py_neuromodulation/utils/pydantic_extensions.py +322 -0
- py_neuromodulation/utils/types.py +33 -107
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/METADATA +23 -4
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/RECORD +55 -35
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/WHEEL +1 -1
- py_neuromodulation-0.1.1.dist-info/entry_points.txt +2 -0
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/licenses/LICENSE +0 -0
py_neuromodulation/__init__.py
CHANGED
|
@@ -4,11 +4,12 @@ from pathlib import PurePath
|
|
|
4
4
|
from importlib.metadata import version
|
|
5
5
|
from py_neuromodulation.utils.logging import NMLogger
|
|
6
6
|
|
|
7
|
+
|
|
7
8
|
#####################################
|
|
8
9
|
# Globals and environment variables #
|
|
9
10
|
#####################################
|
|
10
11
|
|
|
11
|
-
__version__ = version("py_neuromodulation")
|
|
12
|
+
__version__ = version("py_neuromodulation")
|
|
12
13
|
|
|
13
14
|
# Check if the module is running headless (no display) for tests and doc builds
|
|
14
15
|
PYNM_HEADLESS: bool = not os.environ.get("DISPLAY")
|
|
@@ -16,6 +17,9 @@ PYNM_DIR = PurePath(__file__).parent # Define constant for py_nm directory
|
|
|
16
17
|
|
|
17
18
|
os.environ["MPLBACKEND"] = "agg" if PYNM_HEADLESS else "qtagg" # Set matplotlib backend
|
|
18
19
|
|
|
20
|
+
os.environ["LSLAPICFG"] = str(PYNM_DIR / "lsl_api.cfg") # LSL config file
|
|
21
|
+
|
|
22
|
+
|
|
19
23
|
# Set environment variable MNE_LSL_LIB (required to import Stream below)
|
|
20
24
|
LSL_DICT = {
|
|
21
25
|
"windows_32bit": "windows/x86/liblsl.1.16.2.dll",
|
|
@@ -34,15 +38,17 @@ LSL_DICT = {
|
|
|
34
38
|
|
|
35
39
|
PLATFORM = platform.system().lower().strip()
|
|
36
40
|
ARCH = platform.architecture()[0]
|
|
41
|
+
|
|
37
42
|
match PLATFORM:
|
|
38
43
|
case "windows":
|
|
39
44
|
KEY = PLATFORM + "_" + ARCH
|
|
40
45
|
case "darwin":
|
|
41
46
|
KEY = PLATFORM + "_" + platform.processor()
|
|
42
47
|
case "linux":
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
48
|
+
if "VERSION_CODENAME" in platform.freedesktop_os_release().keys():
|
|
49
|
+
DIST = platform.freedesktop_os_release()["VERSION_CODENAME"]
|
|
50
|
+
KEY = PLATFORM + "_" + DIST + "_" + ARCH
|
|
51
|
+
else:
|
|
46
52
|
KEY = PLATFORM + "_" + ARCH
|
|
47
53
|
case _:
|
|
48
54
|
KEY = ""
|
|
@@ -78,3 +84,5 @@ from . import stream
|
|
|
78
84
|
from . import analysis
|
|
79
85
|
|
|
80
86
|
from .stream.settings import get_default_settings, get_fast_compute, reset_settings
|
|
87
|
+
|
|
88
|
+
from .gui.backend.app_manager import AppManager as App
|
|
@@ -8,8 +8,8 @@ import pandas as pd
|
|
|
8
8
|
import nibabel as nib
|
|
9
9
|
from matplotlib import pyplot as plt
|
|
10
10
|
|
|
11
|
-
from py_neuromodulation.
|
|
12
|
-
from py_neuromodulation.types import _PathLike
|
|
11
|
+
from py_neuromodulation.analysis import reg_plot
|
|
12
|
+
from py_neuromodulation.utils.types import _PathLike
|
|
13
13
|
from py_neuromodulation import PYNM_DIR
|
|
14
14
|
|
|
15
15
|
LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL = [256, 385, 417, 447, 819, 914]
|
|
@@ -335,7 +335,7 @@ class RMAPCross_Val_ChannelSelector:
|
|
|
335
335
|
reshape: bool = True,
|
|
336
336
|
):
|
|
337
337
|
if reshape:
|
|
338
|
-
fp = np.reshape(fp, (91, 109, 91), order="
|
|
338
|
+
fp = np.reshape(fp, (91, 109, 91), order="C")
|
|
339
339
|
|
|
340
340
|
img = nib.nifti1.Nifti1Image(fp, affine=affine)
|
|
341
341
|
|
|
@@ -6,14 +6,67 @@ from sklearn.metrics import r2_score
|
|
|
6
6
|
import pandas as pd
|
|
7
7
|
import numpy as np
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from pathlib import PurePath
|
|
9
|
+
from pathlib import Path, PurePath
|
|
10
10
|
import pickle
|
|
11
11
|
|
|
12
12
|
from py_neuromodulation import logger
|
|
13
|
+
from py_neuromodulation.utils.types import _PathLike
|
|
13
14
|
|
|
14
15
|
from typing import Callable
|
|
15
16
|
|
|
16
17
|
|
|
18
|
+
class RealTimeDecoder:
    """Apply a previously trained model to the features of one processing step.

    Only models serialized with ``skops`` (file suffix ``.skops``) are
    supported. The loaded estimator is stored in ``self.model``.
    """

    def __init__(self, model_path: _PathLike):
        """Load the model stored at *model_path*.

        :param model_path: path to a ``.skops`` model file
        :raises FileNotFoundError: if *model_path* does not exist
        :raises IsADirectoryError: if *model_path* is not a regular file
        :raises NotImplementedError: if the file suffix is not ``.skops``
        """
        self.model_path = Path(model_path)
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file {self.model_path} not found")
        if not self.model_path.is_file():
            raise IsADirectoryError(f"Model file {self.model_path} is a directory")

        if self.model_path.suffix != ".skops":
            # Must be raised: returning a value from __init__ makes Python
            # fail with an unrelated "__init__() should return None" TypeError.
            raise NotImplementedError("Only skops models are supported")

        # Imported lazily so that skops is only required when a decoder is used.
        from skops import io as skops_io

        self.model = skops_io.load(self.model_path)

    @staticmethod
    def _as_row(features: dict) -> np.ndarray:
        """Stack the values of *features* into a ``(1, n_features)`` array."""
        return np.array(list(features.values())).reshape(1, -1)

    def predict(
        self,
        feature_dict: dict,
        channel: str | None = None,
        fft_bands_only: bool = True,
    ) -> dict:
        """Decode one feature sample and append the outputs to *feature_dict*.

        The model outputs are written in place as ``decode_0``, ``decode_1``,
        ... and the same dictionary is returned. Decoding is best-effort: any
        exception is logged and *feature_dict* is returned without raising.

        :param feature_dict: mapping of feature name to feature value
        :param channel: if given, only features whose name starts with this
            string are used and ``predict_proba`` is called; if ``None``, all
            features are passed to ``predict``
        :param fft_bands_only: with *channel* set, additionally keep only
            features whose name contains ``"fft"`` but not ``"psd"``
        :return: *feature_dict*, extended by the ``decode_*`` entries on success
        """
        try:
            if channel is not None:
                selected = {
                    name: value
                    for name, value in feature_dict.items()
                    if name.startswith(channel)
                }
                if fft_bands_only is True:
                    selected = {
                        name: value
                        for name, value in selected.items()
                        if "fft" in name and "psd" not in name
                    }
                out = self.model.predict_proba(self._as_row(selected))
            else:
                out = self.model.predict(self._as_row(feature_dict))

            # predict() may return a 1-D array and a single-column output would
            # collapse to 0-d under squeeze; normalise to (1, n_outputs).
            out_2d = np.atleast_2d(out)
            for decode_output_idx in range(out_2d.shape[1]):
                feature_dict[f"decode_{decode_output_idx}"] = out_2d[
                    0, decode_output_idx
                ]
            logger.debug(f"Decoded values: {out}")
        except Exception as e:
            # Deliberately broad: a failing decoder must not stop the caller.
            logger.error(f"Error in decoding: {e}")
        return feature_dict
|
|
68
|
+
|
|
69
|
+
|
|
17
70
|
class CV_res:
|
|
18
71
|
def __init__(
|
|
19
72
|
self,
|
|
@@ -168,7 +221,7 @@ class Decoder:
|
|
|
168
221
|
self.feature_names = [
|
|
169
222
|
col
|
|
170
223
|
for col in self.features.columns
|
|
171
|
-
if
|
|
224
|
+
if any(col.startswith(used_ch) for used_ch in self.used_chs)
|
|
172
225
|
]
|
|
173
226
|
self.data = np.nan_to_num(np.array(self.features[self.feature_names]))
|
|
174
227
|
|
|
@@ -51,6 +51,7 @@ class FeatureReader:
|
|
|
51
51
|
self.feature_file = feature_file if feature_file else self.feature_list[0]
|
|
52
52
|
|
|
53
53
|
FILE_BASENAME = PurePath(self.feature_file).stem
|
|
54
|
+
# The features are saved in a directory named after feature_file, and each file name in it starts with feature_file
|
|
54
55
|
PATH_READ_FILE = str(PurePath(self.feature_dir, FILE_BASENAME, FILE_BASENAME))
|
|
55
56
|
|
|
56
57
|
self.settings = NMSettings.from_file(PATH_READ_FILE)
|
|
@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
|
|
|
6
6
|
# from numba import njit
|
|
7
7
|
import numpy as np
|
|
8
8
|
import pandas as pd
|
|
9
|
-
import scipy.stats as
|
|
9
|
+
import scipy.stats as scipy_stats
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
def fitlm_kfold(x, y, kfold_splits=5):
|
|
@@ -51,7 +51,7 @@ def permutationTestSpearmansRho(x, y, plot_distr=True, x_unit=None, p=5000):
|
|
|
51
51
|
"""
|
|
52
52
|
|
|
53
53
|
# compute ground truth difference
|
|
54
|
-
gT =
|
|
54
|
+
gT = scipy_stats.spearmanr(x, y)[0]
|
|
55
55
|
#
|
|
56
56
|
pV = np.array((x, y))
|
|
57
57
|
# Initialize permutation:
|
|
@@ -65,7 +65,7 @@ def permutationTestSpearmansRho(x, y, plot_distr=True, x_unit=None, p=5000):
|
|
|
65
65
|
random.shuffle(args_order_2)
|
|
66
66
|
# Compute permuted absolute difference of your two sampled
|
|
67
67
|
# distributions and store it in pD:
|
|
68
|
-
pD.append(
|
|
68
|
+
pD.append(scipy_stats.spearmanr(pV[0, args_order], pV[1, args_order_2])[0])
|
|
69
69
|
|
|
70
70
|
# calculate p value
|
|
71
71
|
if gT < 0:
|
|
@@ -1,4 +1,11 @@
|
|
|
1
|
-
|
|
1
|
+
# Settings should be modified either directly through
|
|
2
|
+
# the default_settings.yaml file or by creating a new
|
|
3
|
+
# settings.yaml file that can be loaded with
|
|
4
|
+
# settings.NMSettings.load(FILE_PATH)
|
|
5
|
+
#
|
|
6
|
+
# Alternatively, the settings can also be modified through the
|
|
7
|
+
# settings object directly, e.g. settings.features.raw_hjorth = False
|
|
8
|
+
|
|
2
9
|
########################
|
|
3
10
|
### General settings ###
|
|
4
11
|
########################
|
|
@@ -10,9 +17,9 @@ frequency_ranges_hz: # frequency band ranges can be added, removed and altered
|
|
|
10
17
|
alpha: [8, 12]
|
|
11
18
|
low_beta: [13, 20]
|
|
12
19
|
high_beta: [20, 35]
|
|
13
|
-
low_gamma: [60, 80]
|
|
14
|
-
high_gamma: [90, 200]
|
|
15
|
-
HFA: [200, 400]
|
|
20
|
+
#low_gamma: [60, 80]
|
|
21
|
+
#high_gamma: [90, 200]
|
|
22
|
+
#HFA: [200, 400]
|
|
16
23
|
|
|
17
24
|
# Enabled features
|
|
18
25
|
features:
|
|
@@ -51,12 +58,8 @@ preprocessing_filter:
|
|
|
51
58
|
lowpass_filter: true
|
|
52
59
|
highpass_filter: true
|
|
53
60
|
bandpass_filter: true
|
|
54
|
-
bandstop_filter_settings:
|
|
55
|
-
|
|
56
|
-
frequency_high_hz: 160
|
|
57
|
-
bandpass_filter_settings:
|
|
58
|
-
frequency_low_hz: 3
|
|
59
|
-
frequency_high_hz: 200
|
|
61
|
+
bandstop_filter_settings: [100, 160] # [low_hz, high_hz]
|
|
62
|
+
bandpass_filter_settings: [3, 200] # [low_hz, high_hz]
|
|
60
63
|
lowpass_filter_cutoff_hz: 200
|
|
61
64
|
highpass_filter_cutoff_hz: 3
|
|
62
65
|
|
|
@@ -71,6 +74,7 @@ postprocessing:
|
|
|
71
74
|
feature_normalization_settings:
|
|
72
75
|
normalization_time_s: 30
|
|
73
76
|
normalization_method: zscore # supported methods: mean, median, zscore, zscore-median, quantile, power, robust, minmax
|
|
77
|
+
normalize_psd: false
|
|
74
78
|
clip: 3
|
|
75
79
|
|
|
76
80
|
project_cortex_settings:
|
|
@@ -136,10 +140,10 @@ kalman_filter_settings:
|
|
|
136
140
|
frequency_bands:
|
|
137
141
|
[theta, alpha, low_beta, high_beta, low_gamma, high_gamma, HFA]
|
|
138
142
|
|
|
139
|
-
|
|
143
|
+
bursts_settings:
|
|
140
144
|
threshold: 75
|
|
141
145
|
time_duration_s: 30
|
|
142
|
-
frequency_bands: [low_beta, high_beta
|
|
146
|
+
frequency_bands: [low_beta, high_beta] # low_gamma
|
|
143
147
|
burst_features:
|
|
144
148
|
duration: true
|
|
145
149
|
amplitude: true
|
|
@@ -161,11 +165,9 @@ sharpwave_analysis_settings:
|
|
|
161
165
|
rise_steepness: false
|
|
162
166
|
decay_steepness: false
|
|
163
167
|
slope_ratio: false
|
|
164
|
-
filter_ranges_hz:
|
|
165
|
-
-
|
|
166
|
-
|
|
167
|
-
- frequency_low_hz: 5
|
|
168
|
-
frequency_high_hz: 30
|
|
168
|
+
filter_ranges_hz: # list of [low_hz, high_hz]
|
|
169
|
+
- [5, 80]
|
|
170
|
+
- [5, 30]
|
|
169
171
|
detect_troughs:
|
|
170
172
|
estimate: true
|
|
171
173
|
distance_troughs_ms: 10
|
|
@@ -174,6 +176,7 @@ sharpwave_analysis_settings:
|
|
|
174
176
|
estimate: true
|
|
175
177
|
distance_troughs_ms: 5
|
|
176
178
|
distance_peaks_ms: 10
|
|
179
|
+
# TONI: Reverse this setting? e.g. interval: [mean, var]
|
|
177
180
|
estimator:
|
|
178
181
|
mean: [interval]
|
|
179
182
|
median: []
|
|
@@ -183,7 +186,7 @@ sharpwave_analysis_settings:
|
|
|
183
186
|
apply_estimator_between_peaks_and_troughs: true
|
|
184
187
|
|
|
185
188
|
coherence_settings:
|
|
186
|
-
channels: [] # List of channel pairs, empty by default. Each pair is a list of two channels.
|
|
189
|
+
channels: [] # List of channel pairs, empty by default. Each pair is a list of two channels, where the first channel is the seed and the second channel is the target.
|
|
187
190
|
# Example channels: [[STN_RIGHT_0, ECOG_RIGHT_0], [STN_RIGHT_1, ECOG_RIGHT_1]]
|
|
188
191
|
frequency_bands: [high_beta]
|
|
189
192
|
features:
|
|
@@ -223,8 +226,10 @@ nolds_settings:
|
|
|
223
226
|
frequency_bands: [low_beta]
|
|
224
227
|
|
|
225
228
|
mne_connectiviy_settings:
|
|
226
|
-
|
|
227
|
-
|
|
229
|
+
channels: [] # List of channel pairs, empty by default. Each pair is a list of two channels, where the first channel is the seed and the second channel is the target.
|
|
230
|
+
# Example channels: [[STN_RIGHT_0, ECOG_RIGHT_0], [STN_RIGHT_1, ECOG_RIGHT_1]]
|
|
231
|
+
method: plv # One of ['coh', 'cohy', 'imcoh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli','wpli', 'wpli2_debiased', 'gc', 'gc_tr']
|
|
232
|
+
mode: multitaper # One of ['multitaper', 'fourier', 'cwt_morlet']
|
|
228
233
|
|
|
229
234
|
bispectrum_settings:
|
|
230
235
|
f1s: [5, 35]
|
|
@@ -2,8 +2,13 @@ import numpy as np
|
|
|
2
2
|
from collections.abc import Sequence
|
|
3
3
|
from typing import TYPE_CHECKING
|
|
4
4
|
from pydantic import field_validator
|
|
5
|
-
|
|
6
5
|
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
|
|
6
|
+
from py_neuromodulation.utils.pydantic_extensions import (
|
|
7
|
+
NMField,
|
|
8
|
+
NMErrorList,
|
|
9
|
+
create_validation_error,
|
|
10
|
+
)
|
|
11
|
+
from py_neuromodulation import logger
|
|
7
12
|
|
|
8
13
|
if TYPE_CHECKING:
|
|
9
14
|
from py_neuromodulation.stream.settings import NMSettings
|
|
@@ -17,15 +22,18 @@ class BandpowerFeatures(BoolSelector):
|
|
|
17
22
|
|
|
18
23
|
|
|
19
24
|
class BandPowerSettings(NMBaseModel):
|
|
20
|
-
segment_lengths_ms: dict[str, int] =
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
25
|
+
segment_lengths_ms: dict[str, int] = NMField(
|
|
26
|
+
default={
|
|
27
|
+
"theta": 1000,
|
|
28
|
+
"alpha": 500,
|
|
29
|
+
"low beta": 333,
|
|
30
|
+
"high beta": 333,
|
|
31
|
+
"low gamma": 100,
|
|
32
|
+
"high gamma": 100,
|
|
33
|
+
"HFA": 100,
|
|
34
|
+
},
|
|
35
|
+
custom_metadata={"field_type": "FrequencySegmentLength"},
|
|
36
|
+
)
|
|
29
37
|
bandpower_features: BandpowerFeatures = BandpowerFeatures()
|
|
30
38
|
log_transform: bool = True
|
|
31
39
|
kalman_filter: bool = False
|
|
@@ -33,24 +41,58 @@ class BandPowerSettings(NMBaseModel):
|
|
|
33
41
|
@field_validator("bandpower_features")
|
|
34
42
|
@classmethod
|
|
35
43
|
def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
44
|
+
if not len(bandpower_features.get_enabled()) > 0:
|
|
45
|
+
raise create_validation_error(
|
|
46
|
+
error_message="Set at least one bandpower_feature to True.",
|
|
47
|
+
location=["bandpass_filter_settings", "bandpower_features"],
|
|
48
|
+
)
|
|
40
49
|
return bandpower_features
|
|
41
50
|
|
|
42
|
-
def validate_fbands(self, settings: "NMSettings") ->
|
|
51
|
+
def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
    """Check ``segment_lengths_ms`` against the global settings.

    Problems are collected and returned instead of being raised:

    * every band whose segment length exceeds
      ``settings.segment_length_features_ms`` adds an error;
    * every band of ``settings.frequency_ranges_hz`` without an entry in
      ``segment_lengths_ms`` adds an error;
    * a band that only appears in ``segment_lengths_ms`` is logged as a
      warning and does not add an error.

    :param settings: settings object providing ``frequency_ranges_hz`` and
        ``segment_length_features_ms``
    :type settings: NMSettings
    :return: the collected validation errors
    :rtype: NMErrorList
    """
    errors: NMErrorList = NMErrorList()

    for fband_name, seg_length_fband in self.segment_lengths_ms.items():
        # Segment lengths for bands that are not defined in
        # settings.frequency_ranges_hz are only reported as a warning
        if fband_name not in settings.frequency_ranges_hz:
            logger.warning(
                f"Frequency band {fband_name} in bandpass_filter_settings.segment_lengths_ms"
                " is not defined in settings.frequency_ranges_hz"
            )

        # No band may use a segment longer than settings.segment_length_features_ms
        if seg_length_fband > settings.segment_length_features_ms:
            errors.add_error(
                f"segment length {seg_length_fband} needs to be smaller than or equal to "
                f"settings['segment_length_features_ms'] = {settings.segment_length_features_ms}",
                location=[
                    "bandpass_filter_settings",
                    "segment_lengths_ms",
                    fband_name,
                ],
            )

    # Every frequency band defined in settings.frequency_ranges_hz needs a
    # segment length
    for fband_name in settings.frequency_ranges_hz.keys():
        if fband_name not in self.segment_lengths_ms:
            errors.add_error(
                f"frequency range {fband_name} "
                "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms",
                location=[
                    "bandpass_filter_settings",
                    "segment_lengths_ms",
                    fband_name,
                ],
            )

    return errors
|
|
54
96
|
|
|
55
97
|
|
|
56
98
|
class BandPower(NMFeature):
|
|
@@ -7,11 +7,12 @@ else:
|
|
|
7
7
|
from collections.abc import Sequence
|
|
8
8
|
from itertools import product
|
|
9
9
|
|
|
10
|
-
from pydantic import
|
|
10
|
+
from pydantic import field_validator
|
|
11
11
|
from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature
|
|
12
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
12
13
|
|
|
13
14
|
from typing import TYPE_CHECKING, Callable
|
|
14
|
-
from py_neuromodulation.utils.
|
|
15
|
+
from py_neuromodulation.utils.pydantic_extensions import create_validation_error
|
|
15
16
|
|
|
16
17
|
if TYPE_CHECKING:
|
|
17
18
|
from py_neuromodulation import NMSettings
|
|
@@ -46,8 +47,8 @@ class BurstFeatures(BoolSelector):
|
|
|
46
47
|
|
|
47
48
|
|
|
48
49
|
class BurstsSettings(NMBaseModel):
|
|
49
|
-
threshold: float =
|
|
50
|
-
time_duration_s: float =
|
|
50
|
+
threshold: float = NMField(default=75, ge=0)
|
|
51
|
+
time_duration_s: float = NMField(default=30, ge=0, custom_metadata={"unit": "s"})
|
|
51
52
|
frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"]
|
|
52
53
|
burst_features: BurstFeatures = BurstFeatures()
|
|
53
54
|
|
|
@@ -64,16 +65,16 @@ class Bursts(NMFeature):
|
|
|
64
65
|
settings.validate()
|
|
65
66
|
|
|
66
67
|
# Validate that all frequency bands are defined in the settings
|
|
67
|
-
for fband_burst in settings.
|
|
68
|
+
for fband_burst in settings.bursts_settings.frequency_bands:
|
|
68
69
|
if fband_burst not in list(settings.frequency_ranges_hz.keys()):
|
|
69
70
|
raise create_validation_error(
|
|
70
71
|
f"bursting {fband_burst} needs to be defined in settings['frequency_ranges_hz']",
|
|
71
|
-
|
|
72
|
+
location=["burst_settings", "frequency_bands"],
|
|
72
73
|
)
|
|
73
74
|
|
|
74
75
|
from py_neuromodulation.filter import MNEFilter
|
|
75
76
|
|
|
76
|
-
self.settings = settings.
|
|
77
|
+
self.settings = settings.bursts_settings
|
|
77
78
|
self.sfreq = sfreq
|
|
78
79
|
self.ch_names = ch_names
|
|
79
80
|
self.segment_length_features_s = settings.segment_length_features_ms / 1000
|
|
@@ -83,7 +84,7 @@ class Bursts(NMFeature):
|
|
|
83
84
|
/ settings.sampling_rate_features_hz
|
|
84
85
|
)
|
|
85
86
|
|
|
86
|
-
self.fband_names = settings.
|
|
87
|
+
self.fband_names = settings.bursts_settings.frequency_bands
|
|
87
88
|
|
|
88
89
|
f_ranges: list[tuple[float, float]] = [
|
|
89
90
|
(
|
|
@@ -3,7 +3,7 @@ from collections.abc import Iterable
|
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
from typing import TYPE_CHECKING, Annotated
|
|
6
|
-
from pydantic import
|
|
6
|
+
from pydantic import field_validator
|
|
7
7
|
|
|
8
8
|
from py_neuromodulation.utils.types import (
|
|
9
9
|
NMFeature,
|
|
@@ -11,6 +11,7 @@ from py_neuromodulation.utils.types import (
|
|
|
11
11
|
FrequencyRange,
|
|
12
12
|
NMBaseModel,
|
|
13
13
|
)
|
|
14
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
14
15
|
from py_neuromodulation import logger
|
|
15
16
|
|
|
16
17
|
if TYPE_CHECKING:
|
|
@@ -28,15 +29,17 @@ class CoherenceFeatures(BoolSelector):
|
|
|
28
29
|
max_allfbands: bool = True
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
|
|
32
|
+
# TODO: make this into a pydantic model that only accepts names from
|
|
33
|
+
# the channels objects and does not accept the same string twice
|
|
34
|
+
ListOfTwoStr = Annotated[list[str], NMField(min_length=2, max_length=2)]
|
|
32
35
|
|
|
33
36
|
|
|
34
37
|
class CoherenceSettings(NMBaseModel):
|
|
35
38
|
features: CoherenceFeatures = CoherenceFeatures()
|
|
36
39
|
method: CoherenceMethods = CoherenceMethods()
|
|
37
40
|
channels: list[ListOfTwoStr] = []
|
|
38
|
-
nperseg: int =
|
|
39
|
-
frequency_bands: list[str] =
|
|
41
|
+
nperseg: int = NMField(default=256, ge=1)
|
|
42
|
+
frequency_bands: list[str] = NMField(default=["high_beta"], min_length=1)
|
|
40
43
|
|
|
41
44
|
@field_validator("frequency_bands")
|
|
42
45
|
def fbands_spaces_to_underscores(cls, frequency_bands):
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
from typing import Type, TYPE_CHECKING
|
|
2
2
|
|
|
3
|
-
from py_neuromodulation.utils.types import NMFeature,
|
|
3
|
+
from py_neuromodulation.utils.types import NMFeature, FEATURE_NAME
|
|
4
4
|
|
|
5
5
|
if TYPE_CHECKING:
|
|
6
6
|
import numpy as np
|
|
7
7
|
from py_neuromodulation import NMSettings
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
FEATURE_DICT: dict[
|
|
10
|
+
FEATURE_DICT: dict[FEATURE_NAME | str, str] = {
|
|
11
11
|
"raw_hjorth": "Hjorth",
|
|
12
12
|
"return_raw": "Raw",
|
|
13
13
|
"bandpass_filter": "BandPower",
|
|
@@ -42,7 +42,7 @@ class FeatureProcessors:
|
|
|
42
42
|
from importlib import import_module
|
|
43
43
|
|
|
44
44
|
# Accept 'str' for custom features
|
|
45
|
-
self.features: dict[
|
|
45
|
+
self.features: dict[FEATURE_NAME | str, NMFeature] = {
|
|
46
46
|
feature_name: getattr(
|
|
47
47
|
import_module("py_neuromodulation.features"), FEATURE_DICT[feature_name]
|
|
48
48
|
)(settings, ch_names, sfreq)
|
|
@@ -83,7 +83,7 @@ class FeatureProcessors:
|
|
|
83
83
|
|
|
84
84
|
return feature_results
|
|
85
85
|
|
|
86
|
-
def get_feature(self, fname:
|
|
86
|
+
def get_feature(self, fname: FEATURE_NAME) -> NMFeature:
|
|
87
87
|
return self.features[fname]
|
|
88
88
|
|
|
89
89
|
|
|
@@ -9,6 +9,7 @@ from py_neuromodulation.utils.types import (
|
|
|
9
9
|
BoolSelector,
|
|
10
10
|
FrequencyRange,
|
|
11
11
|
)
|
|
12
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
12
13
|
|
|
13
14
|
if TYPE_CHECKING:
|
|
14
15
|
from py_neuromodulation import NMSettings
|
|
@@ -29,11 +30,11 @@ class FooofPeriodicSettings(BoolSelector):
|
|
|
29
30
|
class FooofSettings(NMBaseModel):
|
|
30
31
|
aperiodic: FooofAperiodicSettings = FooofAperiodicSettings()
|
|
31
32
|
periodic: FooofPeriodicSettings = FooofPeriodicSettings()
|
|
32
|
-
windowlength_ms: float = 800
|
|
33
|
+
windowlength_ms: float = NMField(800, gt=0, custom_metadata={"unit": "ms"})
|
|
33
34
|
peak_width_limits: FrequencyRange = FrequencyRange(0.5, 12)
|
|
34
|
-
max_n_peaks: int = 3
|
|
35
|
-
min_peak_height: float = 0
|
|
36
|
-
peak_threshold: float = 2
|
|
35
|
+
max_n_peaks: int = NMField(3, ge=0)
|
|
36
|
+
min_peak_height: float = NMField(0, ge=0)
|
|
37
|
+
peak_threshold: float = NMField(2, ge=0)
|
|
37
38
|
freq_range_hz: FrequencyRange = FrequencyRange(2, 40)
|
|
38
39
|
knee: bool = True
|
|
39
40
|
|
|
@@ -63,12 +64,12 @@ class FooofAnalyzer(NMFeature):
|
|
|
63
64
|
|
|
64
65
|
assert (
|
|
65
66
|
self.settings.windowlength_ms <= settings.segment_length_features_ms
|
|
66
|
-
), f"fooof windowlength_ms ({settings.
|
|
67
|
+
), f"fooof windowlength_ms ({settings.fooof_settings.windowlength_ms}) needs to be smaller equal than segment_length_features_ms ({settings.segment_length_features_ms})."
|
|
67
68
|
|
|
68
69
|
assert (
|
|
69
70
|
self.settings.freq_range_hz[0] < sfreq
|
|
70
71
|
and self.settings.freq_range_hz[1] < sfreq
|
|
71
|
-
), f"fooof frequency range needs to be below sfreq, got {settings.
|
|
72
|
+
), f"fooof frequency range needs to be below sfreq, got {settings.fooof_settings.freq_range_hz}"
|
|
72
73
|
|
|
73
74
|
from fooof import FOOOFGroup
|
|
74
75
|
|