py-neuromodulation 0.0.7__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +0 -1
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +0 -2
- py_neuromodulation/__init__.py +12 -4
- py_neuromodulation/analysis/RMAP.py +3 -3
- py_neuromodulation/analysis/decode.py +55 -2
- py_neuromodulation/analysis/feature_reader.py +1 -0
- py_neuromodulation/analysis/stats.py +3 -3
- py_neuromodulation/default_settings.yaml +25 -20
- py_neuromodulation/features/bandpower.py +65 -23
- py_neuromodulation/features/bursts.py +9 -8
- py_neuromodulation/features/coherence.py +7 -4
- py_neuromodulation/features/feature_processor.py +4 -4
- py_neuromodulation/features/fooof.py +7 -6
- py_neuromodulation/features/mne_connectivity.py +60 -87
- py_neuromodulation/features/oscillatory.py +5 -4
- py_neuromodulation/features/sharpwaves.py +21 -0
- py_neuromodulation/filter/kalman_filter.py +17 -6
- py_neuromodulation/gui/__init__.py +3 -0
- py_neuromodulation/gui/backend/app_backend.py +419 -0
- py_neuromodulation/gui/backend/app_manager.py +345 -0
- py_neuromodulation/gui/backend/app_pynm.py +253 -0
- py_neuromodulation/gui/backend/app_socket.py +97 -0
- py_neuromodulation/gui/backend/app_utils.py +306 -0
- py_neuromodulation/gui/backend/app_window.py +202 -0
- py_neuromodulation/gui/frontend/assets/Figtree-VariableFont_wght-CkXbWBDP.ttf +0 -0
- py_neuromodulation/gui/frontend/assets/index-_6V8ZfAS.js +300137 -0
- py_neuromodulation/gui/frontend/assets/plotly-DTCwMlpS.js +23594 -0
- py_neuromodulation/gui/frontend/charite.svg +16 -0
- py_neuromodulation/gui/frontend/index.html +14 -0
- py_neuromodulation/gui/window_api.py +115 -0
- py_neuromodulation/lsl_api.cfg +3 -0
- py_neuromodulation/processing/data_preprocessor.py +9 -2
- py_neuromodulation/processing/filter_preprocessing.py +43 -27
- py_neuromodulation/processing/normalization.py +32 -17
- py_neuromodulation/processing/projection.py +2 -2
- py_neuromodulation/processing/resample.py +6 -2
- py_neuromodulation/run_gui.py +36 -0
- py_neuromodulation/stream/__init__.py +7 -1
- py_neuromodulation/stream/backend_interface.py +47 -0
- py_neuromodulation/stream/data_processor.py +24 -3
- py_neuromodulation/stream/mnelsl_player.py +121 -21
- py_neuromodulation/stream/mnelsl_stream.py +9 -17
- py_neuromodulation/stream/settings.py +80 -34
- py_neuromodulation/stream/stream.py +83 -62
- py_neuromodulation/utils/channels.py +1 -1
- py_neuromodulation/utils/file_writer.py +110 -0
- py_neuromodulation/utils/io.py +46 -5
- py_neuromodulation/utils/perf.py +156 -0
- py_neuromodulation/utils/pydantic_extensions.py +322 -0
- py_neuromodulation/utils/types.py +33 -107
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/METADATA +23 -4
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/RECORD +55 -35
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/WHEEL +1 -1
- py_neuromodulation-0.1.1.dist-info/entry_points.txt +2 -0
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,18 +1,44 @@
|
|
|
1
1
|
from collections.abc import Iterable
|
|
2
2
|
import numpy as np
|
|
3
|
-
|
|
3
|
+
|
|
4
|
+
from typing import TYPE_CHECKING, Annotated, Literal
|
|
5
|
+
from pydantic import Field
|
|
4
6
|
|
|
5
7
|
from py_neuromodulation.utils.types import NMFeature, NMBaseModel
|
|
8
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
6
9
|
|
|
7
10
|
if TYPE_CHECKING:
|
|
8
11
|
from py_neuromodulation import NMSettings
|
|
9
|
-
|
|
10
|
-
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
MNE_CONNECTIVITY_METHOD = Literal[
|
|
18
|
+
"coh",
|
|
19
|
+
"cohy",
|
|
20
|
+
"imcoh",
|
|
21
|
+
"cacoh",
|
|
22
|
+
"mic",
|
|
23
|
+
"mim",
|
|
24
|
+
"plv",
|
|
25
|
+
"ciplv",
|
|
26
|
+
"ppc",
|
|
27
|
+
"pli",
|
|
28
|
+
"dpli",
|
|
29
|
+
"wpli",
|
|
30
|
+
"wpli2_debiased",
|
|
31
|
+
"gc",
|
|
32
|
+
"gc_tr",
|
|
33
|
+
]
|
|
34
|
+
|
|
35
|
+
MNE_CONNECTIVITY_MODE = Literal["multitaper", "fourier", "cwt_morlet"]
|
|
11
36
|
|
|
12
37
|
|
|
13
38
|
class MNEConnectivitySettings(NMBaseModel):
|
|
14
|
-
method:
|
|
15
|
-
mode:
|
|
39
|
+
method: MNE_CONNECTIVITY_METHOD = NMField(default="plv")
|
|
40
|
+
mode: MNE_CONNECTIVITY_MODE = NMField(default="multitaper")
|
|
41
|
+
channels: list[ListOfTwoStr] = []
|
|
16
42
|
|
|
17
43
|
|
|
18
44
|
class MNEConnectivity(NMFeature):
|
|
@@ -22,102 +48,42 @@ class MNEConnectivity(NMFeature):
|
|
|
22
48
|
ch_names: Iterable[str],
|
|
23
49
|
sfreq: float,
|
|
24
50
|
) -> None:
|
|
25
|
-
from mne import create_info
|
|
26
|
-
|
|
27
51
|
self.settings = settings
|
|
28
52
|
|
|
29
53
|
self.ch_names = ch_names
|
|
30
54
|
self.sfreq = sfreq
|
|
31
55
|
|
|
56
|
+
self.channels = settings.mne_connectivity_settings.channels
|
|
57
|
+
|
|
32
58
|
# Params used by spectral_connectivity_epochs
|
|
33
59
|
self.mode = settings.mne_connectivity_settings.mode
|
|
34
60
|
self.method = settings.mne_connectivity_settings.method
|
|
61
|
+
self.indices = ([], []) # convert channel names to channel indices in data
|
|
62
|
+
for con_idx in range(len(self.channels)):
|
|
63
|
+
seed_name = self.channels[con_idx][0]
|
|
64
|
+
target_name = self.channels[con_idx][1]
|
|
65
|
+
seed_name_reref = [ch for ch in self.ch_names if ch.startswith(seed_name)][0]
|
|
66
|
+
target_name_reref = [ch for ch in self.ch_names if ch.startswith(target_name)][0]
|
|
67
|
+
self.indices[0].append(self.ch_names.index(seed_name_reref))
|
|
68
|
+
self.indices[1].append(self.ch_names.index(target_name_reref))
|
|
35
69
|
|
|
36
70
|
self.fbands = settings.frequency_ranges_hz
|
|
37
71
|
self.fband_ranges: list = []
|
|
38
72
|
self.result_keys = []
|
|
39
73
|
|
|
40
|
-
self.raw_info = create_info(ch_names=self.ch_names, sfreq=self.sfreq)
|
|
41
|
-
self.raw_array: "RawArray"
|
|
42
|
-
self.epochs: "Epochs"
|
|
43
74
|
self.prev_batch_shape: tuple = (-1, -1) # sentinel value
|
|
44
75
|
|
|
45
76
|
def calc_feature(self, data: np.ndarray) -> dict:
|
|
46
|
-
from mne.io import RawArray
|
|
47
|
-
from mne import Epochs
|
|
48
77
|
from mne_connectivity import spectral_connectivity_epochs
|
|
49
|
-
import pandas as pd
|
|
50
|
-
|
|
51
|
-
time_samples_s = data.shape[1] / self.sfreq
|
|
52
|
-
epoch_length: float = 1 # TODO: Make this a parameter?
|
|
53
|
-
|
|
54
|
-
if epoch_length > time_samples_s:
|
|
55
|
-
raise ValueError(
|
|
56
|
-
f"the intended epoch length for mne connectivity: {epoch_length}s"
|
|
57
|
-
f" are longer than the passed data array {np.round(time_samples_s, 2)}s"
|
|
58
|
-
)
|
|
59
|
-
|
|
60
|
-
# Only reinitialize the raw_array and epochs object if the data shape has changed
|
|
61
|
-
# That could mean that the channels have been re-selected, or we're in the last batch
|
|
62
|
-
# TODO: If sfreq or channels change, do we re-initialize the whole Stream object?
|
|
63
|
-
if data.shape != self.prev_batch_shape:
|
|
64
|
-
self.raw_array = RawArray(
|
|
65
|
-
data=data,
|
|
66
|
-
info=self.raw_info,
|
|
67
|
-
copy=None, # type: ignore
|
|
68
|
-
verbose=False,
|
|
69
|
-
)
|
|
70
|
-
|
|
71
|
-
# self.events = make_fixed_length_events(self.raw_array, duration=epoch_length)
|
|
72
|
-
# Equivalent code for those parameters:
|
|
73
|
-
event_times = np.arange(
|
|
74
|
-
0, data.shape[-1], self.sfreq * epoch_length, dtype=int
|
|
75
|
-
)
|
|
76
|
-
events = np.column_stack(
|
|
77
|
-
(
|
|
78
|
-
event_times,
|
|
79
|
-
np.zeros_like(event_times, dtype=int),
|
|
80
|
-
np.ones_like(event_times, dtype=int),
|
|
81
|
-
)
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
# there need to be minimum 2 of two epochs, otherwise mne_connectivity
|
|
85
|
-
# is not correctly initialized
|
|
86
|
-
if events.shape[0] < 2:
|
|
87
|
-
raise RuntimeError(
|
|
88
|
-
f"A minimum of 2 epochs is required for mne_connectivity,"
|
|
89
|
-
f" got only {events.shape[0]}. Increase settings['segment_length_features_ms']"
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
self.epochs = Epochs(
|
|
93
|
-
self.raw_array,
|
|
94
|
-
events=events,
|
|
95
|
-
event_id={"rest": 1},
|
|
96
|
-
tmin=0,
|
|
97
|
-
tmax=epoch_length,
|
|
98
|
-
baseline=None,
|
|
99
|
-
reject_by_annotation=True,
|
|
100
|
-
verbose=False,
|
|
101
|
-
)
|
|
102
|
-
|
|
103
|
-
# Trick the function "spectral_connectivity_epochs" into not calling "add_annotations_to_metadata"
|
|
104
|
-
# TODO: This is a hack, and maybe needs a fix in the mne_connectivity library
|
|
105
|
-
self.epochs._metadata = pd.DataFrame(index=np.arange(events.shape[0]))
|
|
106
|
-
|
|
107
|
-
else:
|
|
108
|
-
# As long as the initialization parameters, channels, sfreq and batch size are the same
|
|
109
|
-
# We can re-use the existing epochs object by updating the raw data
|
|
110
|
-
self.raw_array._data = data
|
|
111
|
-
self.epochs._raw = self.raw_array
|
|
112
78
|
|
|
113
79
|
# n_jobs is here kept to 1, since setup of the multiprocessing Pool
|
|
114
80
|
# takes longer than most batch computing sizes
|
|
115
81
|
spec_out = spectral_connectivity_epochs(
|
|
116
|
-
data=
|
|
82
|
+
data=np.expand_dims(data, axis=0), # add singleton epoch dimension
|
|
117
83
|
sfreq=self.sfreq,
|
|
118
84
|
method=self.method,
|
|
119
85
|
mode=self.mode,
|
|
120
|
-
indices=
|
|
86
|
+
indices=self.indices,
|
|
121
87
|
verbose=False,
|
|
122
88
|
)
|
|
123
89
|
dat_conn: np.ndarray = spec_out.get_data()
|
|
@@ -127,20 +93,27 @@ class MNEConnectivity(NMFeature):
|
|
|
127
93
|
for fband_range in self.fbands.values():
|
|
128
94
|
self.fband_ranges.append(
|
|
129
95
|
np.where(
|
|
130
|
-
(np.array(spec_out.freqs)
|
|
131
|
-
& (np.array(spec_out.freqs)
|
|
96
|
+
(np.array(spec_out.freqs) >= fband_range[0])
|
|
97
|
+
& (np.array(spec_out.freqs) <= fband_range[1])
|
|
132
98
|
)[0]
|
|
133
99
|
)
|
|
134
100
|
|
|
135
|
-
# TODO: If I compute the mean for the entire fband, results are almost the same before
|
|
136
|
-
# normalization (0.9999999... vs 1.0), but some change wildly after normalization (-3 vs 0)
|
|
137
|
-
# Investigate why, is this a bug in normalization?
|
|
138
101
|
feature_results = {}
|
|
139
|
-
for
|
|
140
|
-
for fband_idx,
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
102
|
+
for con_idx in np.arange(dat_conn.shape[0]):
|
|
103
|
+
for fband_idx, fband_name in enumerate(self.fbands):
|
|
104
|
+
# TODO: Add support for max_fband and max_allfbands
|
|
105
|
+
feature_results[
|
|
106
|
+
"_".join(
|
|
107
|
+
[
|
|
108
|
+
self.method,
|
|
109
|
+
self.channels[con_idx][0], # seed channel name
|
|
110
|
+
"to",
|
|
111
|
+
self.channels[con_idx][1], # target channel name
|
|
112
|
+
"mean_fband",
|
|
113
|
+
fband_name,
|
|
114
|
+
]
|
|
115
|
+
)
|
|
116
|
+
] = np.mean(dat_conn[con_idx, self.fband_ranges[fband_idx]])
|
|
144
117
|
|
|
145
118
|
# Store current experiment parameters to check if re-initialization is needed
|
|
146
119
|
self.prev_batch_shape = data.shape
|
|
@@ -3,6 +3,7 @@ import numpy as np
|
|
|
3
3
|
from itertools import product
|
|
4
4
|
|
|
5
5
|
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
|
|
6
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
6
7
|
from typing import TYPE_CHECKING
|
|
7
8
|
|
|
8
9
|
if TYPE_CHECKING:
|
|
@@ -17,12 +18,12 @@ class OscillatoryFeatures(BoolSelector):
|
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class OscillatorySettings(NMBaseModel):
|
|
20
|
-
windowlength_ms: int = 1000
|
|
21
|
+
windowlength_ms: int = NMField(1000, gt=0, custom_metadata={"unit": "ms"})
|
|
21
22
|
log_transform: bool = True
|
|
22
23
|
features: OscillatoryFeatures = OscillatoryFeatures(
|
|
23
24
|
mean=True, median=False, std=False, max=False
|
|
24
25
|
)
|
|
25
|
-
return_spectrum: bool =
|
|
26
|
+
return_spectrum: bool = True
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
ESTIMATOR_DICT = {
|
|
@@ -176,7 +177,7 @@ class Welch(OscillatoryFeature):
|
|
|
176
177
|
if self.settings.return_spectrum:
|
|
177
178
|
combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
|
|
178
179
|
for (ch_idx, ch_name), (idx, f) in combinations:
|
|
179
|
-
feature_results[f"{ch_name}_welch_psd_{
|
|
180
|
+
feature_results[f"{ch_name}_welch_psd_{int(f)}"] = Z[ch_idx][idx]
|
|
180
181
|
|
|
181
182
|
return feature_results
|
|
182
183
|
|
|
@@ -242,7 +243,7 @@ class STFT(OscillatoryFeature):
|
|
|
242
243
|
if self.settings.return_spectrum:
|
|
243
244
|
combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
|
|
244
245
|
for (ch_idx, ch_name), (idx, f) in combinations:
|
|
245
|
-
feature_results[f"{ch_name}_stft_psd_{
|
|
246
|
+
feature_results[f"{ch_name}_stft_psd_{int(f)}"] = Z[ch_idx].mean(
|
|
246
247
|
axis=1
|
|
247
248
|
)[idx]
|
|
248
249
|
|
|
@@ -267,6 +267,14 @@ class SharpwaveAnalyzer(NMFeature):
|
|
|
267
267
|
|
|
268
268
|
# for each feature take the respective fun.
|
|
269
269
|
for feature_name, estimator_name, estimator in estimator_combinations:
|
|
270
|
+
if feature_name == "num_peaks":
|
|
271
|
+
key_name = f"{ch_name}_Sharpwave_{feature_name}_{filter_name}"
|
|
272
|
+
if len(waveform_results[feature_name]) == 1:
|
|
273
|
+
dict_ch_features[key_name][key_name_pt] = waveform_results[feature_name][0]
|
|
274
|
+
continue
|
|
275
|
+
else:
|
|
276
|
+
raise ValueError("num_peaks should be a list with length 1")
|
|
277
|
+
# there can be only one num_peak in each batch
|
|
270
278
|
feature_data = waveform_results[feature_name]
|
|
271
279
|
key_name = f"{ch_name}_Sharpwave_{estimator_name.title()}_{feature_name}_{filter_name}"
|
|
272
280
|
|
|
@@ -280,12 +288,25 @@ class SharpwaveAnalyzer(NMFeature):
|
|
|
280
288
|
|
|
281
289
|
# the key_name stays, since the estimator function stays between peaks and troughs
|
|
282
290
|
for key_name, estimator in self.estimator_key_map.items():
|
|
291
|
+
if len(dict_ch_features[key_name]) == 0:
|
|
292
|
+
# might happen if num_peaks was written in estimator
|
|
293
|
+
# e.g. estimator["mean"] = ["num_peaks"]
|
|
294
|
+
# for conveniance this doesn't raise an exception
|
|
295
|
+
continue
|
|
296
|
+
|
|
283
297
|
feature_results[key_name] = estimator(
|
|
284
298
|
[
|
|
285
299
|
list(dict_ch_features[key_name].values())[0],
|
|
286
300
|
list(dict_ch_features[key_name].values())[1],
|
|
287
301
|
]
|
|
288
302
|
)
|
|
303
|
+
# add here also the num_peaks features
|
|
304
|
+
if self.sw_settings.sharpwave_features.num_peaks:
|
|
305
|
+
for ch_name in self.ch_names:
|
|
306
|
+
for filter_name in self.filter_names:
|
|
307
|
+
key_name = f"{ch_name}_Sharpwave_num_peaks_{filter_name}"
|
|
308
|
+
feature_results[key_name] = np_mean([dict_ch_features[key_name]["Peak"],
|
|
309
|
+
dict_ch_features[key_name]["Trough"]])
|
|
289
310
|
else:
|
|
290
311
|
# otherwise, save all write all "flattened" key value pairs in feature_results
|
|
291
312
|
for key, subdict in dict_ch_features.items():
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
from typing import TYPE_CHECKING
|
|
3
3
|
|
|
4
|
+
|
|
4
5
|
from py_neuromodulation.utils.types import NMBaseModel
|
|
6
|
+
from py_neuromodulation.utils.pydantic_extensions import NMErrorList
|
|
5
7
|
|
|
6
8
|
|
|
7
9
|
if TYPE_CHECKING:
|
|
@@ -22,13 +24,22 @@ class KalmanSettings(NMBaseModel):
|
|
|
22
24
|
"HFA",
|
|
23
25
|
]
|
|
24
26
|
|
|
25
|
-
def validate_fbands(self, settings: "NMSettings") ->
|
|
26
|
-
|
|
27
|
+
def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
|
|
28
|
+
errors: NMErrorList = NMErrorList()
|
|
29
|
+
|
|
30
|
+
if not all(
|
|
27
31
|
(item in settings.frequency_ranges_hz for item in self.frequency_bands)
|
|
28
|
-
)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
+
):
|
|
33
|
+
errors.add_error(
|
|
34
|
+
"Frequency bands for Kalman filter must also be specified in "
|
|
35
|
+
"frequency_ranges_hz.",
|
|
36
|
+
location=[
|
|
37
|
+
"kalman_filter_settings",
|
|
38
|
+
"frequency_bands",
|
|
39
|
+
],
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
return errors
|
|
32
43
|
|
|
33
44
|
|
|
34
45
|
def define_KF(Tp, sigma_w, sigma_v):
|