py-neuromodulation 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/Automated Anatomical Labeling 3 (Rolls 2020).nii +0 -0
- py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -0
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +106 -0
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +119 -0
- py_neuromodulation/ConnectivityDecoding/mni_coords_cortical_surface.mat +0 -0
- py_neuromodulation/ConnectivityDecoding/mni_coords_whole_brain.mat +0 -0
- py_neuromodulation/ConnectivityDecoding/rmap_func_all.nii +0 -0
- py_neuromodulation/ConnectivityDecoding/rmap_struc.nii +0 -0
- py_neuromodulation/{helper.py → _write_example_dataset_helper.py} +1 -1
- py_neuromodulation/nm_EpochStream.py +2 -3
- py_neuromodulation/nm_IO.py +43 -70
- py_neuromodulation/nm_RMAP.py +308 -11
- py_neuromodulation/nm_analysis.py +1 -1
- py_neuromodulation/nm_artifacts.py +25 -0
- py_neuromodulation/nm_bispectra.py +64 -29
- py_neuromodulation/nm_bursts.py +44 -30
- py_neuromodulation/nm_coherence.py +2 -1
- py_neuromodulation/nm_features.py +4 -2
- py_neuromodulation/nm_filter.py +63 -32
- py_neuromodulation/nm_filter_preprocessing.py +91 -0
- py_neuromodulation/nm_fooof.py +47 -29
- py_neuromodulation/nm_mne_connectivity.py +1 -1
- py_neuromodulation/nm_normalization.py +50 -74
- py_neuromodulation/nm_oscillatory.py +151 -31
- py_neuromodulation/nm_plots.py +13 -10
- py_neuromodulation/nm_rereference.py +10 -8
- py_neuromodulation/nm_run_analysis.py +28 -13
- py_neuromodulation/nm_settings.json +51 -3
- py_neuromodulation/nm_sharpwaves.py +103 -136
- py_neuromodulation/nm_stats.py +44 -30
- py_neuromodulation/nm_stream_abc.py +18 -10
- py_neuromodulation/nm_stream_offline.py +188 -46
- py_neuromodulation/utils/_logging.py +24 -0
- {py_neuromodulation-0.0.2.dist-info → py_neuromodulation-0.0.4.dist-info}/METADATA +72 -32
- py_neuromodulation-0.0.4.dist-info/RECORD +72 -0
- {py_neuromodulation-0.0.2.dist-info → py_neuromodulation-0.0.4.dist-info}/WHEEL +1 -1
- py_neuromodulation/data/derivatives/sub-testsub_ses-EphysMedOff_task-gripforce_run-0/MOV_aligned_features_ch_ECOG_RIGHT_0_all.png +0 -0
- py_neuromodulation/data/derivatives/sub-testsub_ses-EphysMedOff_task-gripforce_run-0/all_feature_plt.pdf +0 -0
- py_neuromodulation/data/derivatives/sub-testsub_ses-EphysMedOff_task-gripforce_run-0/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_FEATURES.csv +0 -182
- py_neuromodulation/data/derivatives/sub-testsub_ses-EphysMedOff_task-gripforce_run-0/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_LM_ML_RES.p +0 -0
- py_neuromodulation/data/derivatives/sub-testsub_ses-EphysMedOff_task-gripforce_run-0/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_SETTINGS.json +0 -273
- py_neuromodulation/data/derivatives/sub-testsub_ses-EphysMedOff_task-gripforce_run-0/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_SIDECAR.json +0 -6
- py_neuromodulation/data/derivatives/sub-testsub_ses-EphysMedOff_task-gripforce_run-0/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_decoding_performance.png +0 -0
- py_neuromodulation/data/derivatives/sub-testsub_ses-EphysMedOff_task-gripforce_run-0/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_nm_channels.csv +0 -11
- py_neuromodulation/py_neuromodulation.egg-info/PKG-INFO +0 -104
- py_neuromodulation/py_neuromodulation.egg-info/dependency_links.txt +0 -1
- py_neuromodulation/py_neuromodulation.egg-info/requires.txt +0 -26
- py_neuromodulation/py_neuromodulation.egg-info/top_level.txt +0 -1
- py_neuromodulation-0.0.2.dist-info/RECORD +0 -73
- /py_neuromodulation/{py_neuromodulation.egg-info/SOURCES.txt → utils/__init__.py} +0 -0
- {py_neuromodulation-0.0.2.dist-info → py_neuromodulation-0.0.4.dist-info/licenses}/LICENSE +0 -0
|
@@ -3,8 +3,6 @@ from enum import Enum
|
|
|
3
3
|
|
|
4
4
|
from sklearn import preprocessing
|
|
5
5
|
import numpy as np
|
|
6
|
-
|
|
7
|
-
|
|
8
6
|
class NORM_METHODS(Enum):
|
|
9
7
|
MEAN = "mean"
|
|
10
8
|
MEDIAN = "median"
|
|
@@ -138,6 +136,17 @@ class FeatureNormalizer:
|
|
|
138
136
|
|
|
139
137
|
return data
|
|
140
138
|
|
|
139
|
+
"""
|
|
140
|
+
Functions to check for NaN's before deciding which Numpy function to call
|
|
141
|
+
"""
|
|
142
|
+
def nan_mean(data, axis):
|
|
143
|
+
return np.nanmean(data, axis=axis) if np.any(np.isnan(sum(data))) else np.mean(data, axis=axis)
|
|
144
|
+
|
|
145
|
+
def nan_std(data, axis):
|
|
146
|
+
return np.nanstd(data, axis=axis) if np.any(np.isnan(sum(data))) else np.std(data, axis=axis)
|
|
147
|
+
|
|
148
|
+
def nan_median(data, axis):
|
|
149
|
+
return np.nanmedian(data, axis=axis) if np.any(np.isnan(sum(data))) else np.median(data, axis=axis)
|
|
141
150
|
|
|
142
151
|
def _normalize_and_clip(
|
|
143
152
|
current: np.ndarray,
|
|
@@ -147,82 +156,49 @@ def _normalize_and_clip(
|
|
|
147
156
|
description: str,
|
|
148
157
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
149
158
|
"""Normalize data."""
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
.
|
|
173
|
-
.
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
.fit(np.nan_to_num(previous))
|
|
179
|
-
.transform(current)
|
|
180
|
-
)
|
|
181
|
-
elif method == NORM_METHODS.ROBUST.value:
|
|
182
|
-
if len(current.shape) == 1:
|
|
183
|
-
current = (
|
|
184
|
-
preprocessing.RobustScaler()
|
|
185
|
-
.fit(np.nan_to_num(previous))
|
|
186
|
-
.transform(np.expand_dims(current, axis=0))[0, :]
|
|
187
|
-
)
|
|
188
|
-
else:
|
|
159
|
+
match method:
|
|
160
|
+
case NORM_METHODS.MEAN.value:
|
|
161
|
+
mean = nan_mean(previous, axis=0)
|
|
162
|
+
current = (current - mean) / mean
|
|
163
|
+
case NORM_METHODS.MEDIAN.value:
|
|
164
|
+
median = nan_median(previous, axis=0)
|
|
165
|
+
current = (current - median) / median
|
|
166
|
+
case NORM_METHODS.ZSCORE.value:
|
|
167
|
+
current = (current - nan_mean(previous, axis=0)) / nan_std(previous, axis=0)
|
|
168
|
+
case NORM_METHODS.ZSCORE_MEDIAN.value:
|
|
169
|
+
current = (current - nan_median(previous, axis=0)) / nan_std(previous, axis=0)
|
|
170
|
+
# For the following methods we check for the shape of current
|
|
171
|
+
# when current is a 1D array, then it is the post-processing normalization,
|
|
172
|
+
# and we need to expand, and remove the extra dimension afterwards
|
|
173
|
+
# When current is a 2D array, then it is pre-processing normalization, and
|
|
174
|
+
# there's no need for expanding.
|
|
175
|
+
case (NORM_METHODS.QUANTILE.value |
|
|
176
|
+
NORM_METHODS.ROBUST.value |
|
|
177
|
+
NORM_METHODS.MINMAX.value |
|
|
178
|
+
NORM_METHODS.POWER.value):
|
|
179
|
+
|
|
180
|
+
norm_methods = {
|
|
181
|
+
NORM_METHODS.QUANTILE.value : lambda: preprocessing.QuantileTransformer(n_quantiles=300),
|
|
182
|
+
NORM_METHODS.ROBUST.value : preprocessing.RobustScaler,
|
|
183
|
+
NORM_METHODS.MINMAX.value : preprocessing.MinMaxScaler,
|
|
184
|
+
NORM_METHODS.POWER.value : preprocessing.PowerTransformer
|
|
185
|
+
}
|
|
186
|
+
|
|
189
187
|
current = (
|
|
190
|
-
|
|
188
|
+
norm_methods[method]()
|
|
191
189
|
.fit(np.nan_to_num(previous))
|
|
192
|
-
.transform(
|
|
190
|
+
.transform(
|
|
191
|
+
# if post-processing: pad dimensions to 2
|
|
192
|
+
np.reshape(current, (2-len(current.shape))*(1,) + current.shape)
|
|
193
|
+
)
|
|
194
|
+
.squeeze() # if post-processing: remove extra dimension
|
|
193
195
|
)
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
.fit(np.nan_to_num(previous))
|
|
200
|
-
.transform(np.expand_dims(current, axis=0))[0, :]
|
|
201
|
-
)
|
|
202
|
-
else:
|
|
203
|
-
current = (
|
|
204
|
-
preprocessing.MinMaxScaler()
|
|
205
|
-
.fit(np.nan_to_num(previous))
|
|
206
|
-
.transform(current)
|
|
207
|
-
)
|
|
208
|
-
elif method == NORM_METHODS.POWER.value:
|
|
209
|
-
if len(current.shape) == 1:
|
|
210
|
-
current = (
|
|
211
|
-
preprocessing.PowerTransformer()
|
|
212
|
-
.fit(np.nan_to_num(previous))
|
|
213
|
-
.transform(np.expand_dims(current, axis=0))[0, :]
|
|
196
|
+
|
|
197
|
+
case _:
|
|
198
|
+
raise ValueError(
|
|
199
|
+
f"Only {[e.value for e in NORM_METHODS]} are supported as "
|
|
200
|
+
f"{description} normalization methods. Got {method}."
|
|
214
201
|
)
|
|
215
|
-
else:
|
|
216
|
-
current = (
|
|
217
|
-
preprocessing.PowerTransformer()
|
|
218
|
-
.fit(np.nan_to_num(previous))
|
|
219
|
-
.transform(current)
|
|
220
|
-
)
|
|
221
|
-
else:
|
|
222
|
-
raise ValueError(
|
|
223
|
-
f"Only {[e.value for e in NORM_METHODS]} are supported as "
|
|
224
|
-
f"{description} normalization methods. Got {method}."
|
|
225
|
-
)
|
|
226
202
|
|
|
227
203
|
if clip:
|
|
228
204
|
current = _clip(data=current, clip=clip)
|
|
@@ -38,6 +38,15 @@ class OscillatoryFeature(nm_features_abc.Feature):
|
|
|
38
38
|
assert isinstance(
|
|
39
39
|
s[osc_feature_name]["windowlength_ms"], int
|
|
40
40
|
), f"windowlength_ms needs to be type int, got {s[osc_feature_name]['windowlength_ms']}"
|
|
41
|
+
|
|
42
|
+
assert (
|
|
43
|
+
s[osc_feature_name]["windowlength_ms"]
|
|
44
|
+
<= s["segment_length_features_ms"]
|
|
45
|
+
), (
|
|
46
|
+
f"oscillatory feature windowlength_ms = ({s[osc_feature_name]['windowlength_ms']})"
|
|
47
|
+
f"needs to be smaller than"
|
|
48
|
+
f"s['segment_length_features_ms'] = {s['segment_length_features_ms']}",
|
|
49
|
+
)
|
|
41
50
|
else:
|
|
42
51
|
for seg_length in s[osc_feature_name][
|
|
43
52
|
"segment_lengths_ms"
|
|
@@ -48,12 +57,6 @@ class OscillatoryFeature(nm_features_abc.Feature):
|
|
|
48
57
|
assert isinstance(
|
|
49
58
|
s[osc_feature_name]["log_transform"], bool
|
|
50
59
|
), f"log_transform needs to be type bool, got {s[osc_feature_name]['log_transform']}"
|
|
51
|
-
assert isinstance(
|
|
52
|
-
s[osc_feature_name]["kalman_filter"], bool
|
|
53
|
-
), f"kalman_filter needs to be type bool, got {s[osc_feature_name]['kalman_filter']}"
|
|
54
|
-
|
|
55
|
-
if s[osc_feature_name]["kalman_filter"] is True:
|
|
56
|
-
nm_kalmanfilter.test_kf_settings(s, ch_names, sfreq)
|
|
57
60
|
|
|
58
61
|
assert isinstance(s["frequency_ranges_hz"], dict)
|
|
59
62
|
|
|
@@ -95,6 +98,36 @@ class OscillatoryFeature(nm_features_abc.Feature):
|
|
|
95
98
|
feature_calc = self.KF_dict[KF_name].x[0]
|
|
96
99
|
return feature_calc
|
|
97
100
|
|
|
101
|
+
def estimate_osc_features(
|
|
102
|
+
self,
|
|
103
|
+
features_compute: dict,
|
|
104
|
+
data: np.ndarray,
|
|
105
|
+
feature_name: np.ndarray,
|
|
106
|
+
est_name: str,
|
|
107
|
+
):
|
|
108
|
+
for feature_est_name in list(self.s[est_name]["features"].keys()):
|
|
109
|
+
if self.s[est_name]["features"][feature_est_name] is True:
|
|
110
|
+
# switch case for feature_est_name
|
|
111
|
+
match feature_est_name:
|
|
112
|
+
case "mean":
|
|
113
|
+
features_compute[
|
|
114
|
+
f"{feature_name}_{feature_est_name}"
|
|
115
|
+
] = np.nanmean(data)
|
|
116
|
+
case "median":
|
|
117
|
+
features_compute[
|
|
118
|
+
f"{feature_name}_{feature_est_name}"
|
|
119
|
+
] = np.nanmedian(data)
|
|
120
|
+
case "std":
|
|
121
|
+
features_compute[
|
|
122
|
+
f"{feature_name}_{feature_est_name}"
|
|
123
|
+
] = np.nanstd(data)
|
|
124
|
+
case "max":
|
|
125
|
+
features_compute[
|
|
126
|
+
f"{feature_name}_{feature_est_name}"
|
|
127
|
+
] = np.nanmax(data)
|
|
128
|
+
|
|
129
|
+
return features_compute
|
|
130
|
+
|
|
98
131
|
|
|
99
132
|
class FFT(OscillatoryFeature):
|
|
100
133
|
def __init__(
|
|
@@ -104,8 +137,6 @@ class FFT(OscillatoryFeature):
|
|
|
104
137
|
sfreq: float,
|
|
105
138
|
) -> None:
|
|
106
139
|
super().__init__(settings, ch_names, sfreq)
|
|
107
|
-
if self.s["fft_settings"]["kalman_filter"]:
|
|
108
|
-
self.init_KF("fft")
|
|
109
140
|
|
|
110
141
|
if self.s["fft_settings"]["log_transform"]:
|
|
111
142
|
self.log_transform = True
|
|
@@ -114,13 +145,15 @@ class FFT(OscillatoryFeature):
|
|
|
114
145
|
|
|
115
146
|
window_ms = self.s["fft_settings"]["windowlength_ms"]
|
|
116
147
|
self.window_samples = int(-np.floor(window_ms / 1000 * sfreq))
|
|
117
|
-
freqs = fft.rfftfreq(
|
|
148
|
+
self.freqs = fft.rfftfreq(
|
|
149
|
+
-self.window_samples, 1 / np.floor(self.sfreq)
|
|
150
|
+
)
|
|
118
151
|
|
|
119
152
|
self.feature_params = []
|
|
120
153
|
for ch_idx, ch_name in enumerate(self.ch_names):
|
|
121
154
|
for fband, f_range in self.f_ranges_dict.items():
|
|
122
155
|
idx_range = np.where(
|
|
123
|
-
(freqs >= f_range[0]) & (freqs < f_range[1])
|
|
156
|
+
(self.freqs >= f_range[0]) & (self.freqs < f_range[1])
|
|
124
157
|
)[0]
|
|
125
158
|
feature_name = "_".join([ch_name, "fft", fband])
|
|
126
159
|
self.feature_params.append((ch_idx, feature_name, idx_range))
|
|
@@ -132,17 +165,87 @@ class FFT(OscillatoryFeature):
|
|
|
132
165
|
def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
|
|
133
166
|
data = data[:, self.window_samples :]
|
|
134
167
|
Z = np.abs(fft.rfft(data))
|
|
168
|
+
|
|
169
|
+
if self.log_transform:
|
|
170
|
+
Z = np.log10(Z)
|
|
171
|
+
|
|
135
172
|
for ch_idx, feature_name, idx_range in self.feature_params:
|
|
136
173
|
Z_ch = Z[ch_idx, idx_range]
|
|
137
|
-
feature_calc = np.mean(Z_ch)
|
|
138
174
|
|
|
139
|
-
|
|
140
|
-
|
|
175
|
+
features_compute = self.estimate_osc_features(
|
|
176
|
+
features_compute, Z_ch, feature_name, "fft_settings"
|
|
177
|
+
)
|
|
141
178
|
|
|
142
|
-
|
|
143
|
-
|
|
179
|
+
for ch_idx, ch_name in enumerate(self.ch_names):
|
|
180
|
+
if self.s["fft_settings"]["return_spectrum"]:
|
|
181
|
+
features_compute.update(
|
|
182
|
+
{
|
|
183
|
+
f"{ch_name}_fft_psd_{str(f)}": Z[ch_idx][idx]
|
|
184
|
+
for idx, f in enumerate(self.freqs.astype(int))
|
|
185
|
+
}
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
return features_compute
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class Welch(OscillatoryFeature):
|
|
192
|
+
def __init__(
|
|
193
|
+
self,
|
|
194
|
+
settings: dict,
|
|
195
|
+
ch_names: Iterable[str],
|
|
196
|
+
sfreq: float,
|
|
197
|
+
) -> None:
|
|
198
|
+
super().__init__(settings, ch_names, sfreq)
|
|
199
|
+
|
|
200
|
+
self.log_transform = self.s["welch_settings"]["log_transform"]
|
|
201
|
+
|
|
202
|
+
self.feature_params = []
|
|
203
|
+
for ch_idx, ch_name in enumerate(self.ch_names):
|
|
204
|
+
for fband, f_range in self.f_ranges_dict.items():
|
|
205
|
+
feature_name = "_".join([ch_name, "welch", fband])
|
|
206
|
+
self.feature_params.append((ch_idx, feature_name, f_range))
|
|
207
|
+
|
|
208
|
+
@staticmethod
|
|
209
|
+
def test_settings(s: dict, ch_names: Iterable[str], sfreq: int | float):
|
|
210
|
+
OscillatoryFeature.test_settings_osc(
|
|
211
|
+
s, ch_names, sfreq, "welch_settings"
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
|
|
215
|
+
freqs, Z = signal.welch(
|
|
216
|
+
data,
|
|
217
|
+
fs=self.sfreq,
|
|
218
|
+
window="hann",
|
|
219
|
+
nperseg=self.sfreq,
|
|
220
|
+
noverlap=None,
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
if self.log_transform:
|
|
224
|
+
Z = np.log10(Z)
|
|
225
|
+
|
|
226
|
+
for ch_idx, feature_name, f_range in self.feature_params:
|
|
227
|
+
Z_ch = Z[ch_idx]
|
|
228
|
+
|
|
229
|
+
idx_range = np.where((freqs >= f_range[0]) & (freqs <= f_range[1]))[
|
|
230
|
+
0
|
|
231
|
+
]
|
|
232
|
+
|
|
233
|
+
features_compute = self.estimate_osc_features(
|
|
234
|
+
features_compute,
|
|
235
|
+
Z_ch[idx_range],
|
|
236
|
+
feature_name,
|
|
237
|
+
"welch_settings",
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
for ch_idx, ch_name in enumerate(self.ch_names):
|
|
241
|
+
if self.s["welch_settings"]["return_spectrum"]:
|
|
242
|
+
features_compute.update(
|
|
243
|
+
{
|
|
244
|
+
f"{ch_name}_welch_psd_{str(f)}": Z[ch_idx][idx]
|
|
245
|
+
for idx, f in enumerate(freqs.astype(int))
|
|
246
|
+
}
|
|
247
|
+
)
|
|
144
248
|
|
|
145
|
-
features_compute[feature_name] = feature_calc
|
|
146
249
|
return features_compute
|
|
147
250
|
|
|
148
251
|
|
|
@@ -154,10 +257,9 @@ class STFT(OscillatoryFeature):
|
|
|
154
257
|
sfreq: float,
|
|
155
258
|
) -> None:
|
|
156
259
|
super().__init__(settings, ch_names, sfreq)
|
|
157
|
-
if self.s["stft_settings"]["kalman_filter"]:
|
|
158
|
-
self.init_KF("stft")
|
|
159
260
|
|
|
160
261
|
self.nperseg = int(self.s["stft_settings"]["windowlength_ms"])
|
|
262
|
+
self.log_transform = self.s["stft_settings"]["log_transform"]
|
|
161
263
|
|
|
162
264
|
self.feature_params = []
|
|
163
265
|
for ch_idx, ch_name in enumerate(self.ch_names):
|
|
@@ -172,7 +274,7 @@ class STFT(OscillatoryFeature):
|
|
|
172
274
|
)
|
|
173
275
|
|
|
174
276
|
def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
|
|
175
|
-
|
|
277
|
+
freqs, _, Zxx = signal.stft(
|
|
176
278
|
data,
|
|
177
279
|
fs=self.sfreq,
|
|
178
280
|
window="hamming",
|
|
@@ -180,15 +282,30 @@ class STFT(OscillatoryFeature):
|
|
|
180
282
|
boundary="even",
|
|
181
283
|
)
|
|
182
284
|
Z = np.abs(Zxx)
|
|
285
|
+
if self.log_transform:
|
|
286
|
+
Z = np.log10(Z)
|
|
183
287
|
for ch_idx, feature_name, f_range in self.feature_params:
|
|
184
288
|
Z_ch = Z[ch_idx]
|
|
185
|
-
idx_range = np.where((
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
289
|
+
idx_range = np.where((freqs >= f_range[0]) & (freqs <= f_range[1]))[
|
|
290
|
+
0
|
|
291
|
+
]
|
|
292
|
+
|
|
293
|
+
features_compute = self.estimate_osc_features(
|
|
294
|
+
features_compute,
|
|
295
|
+
Z_ch[idx_range, :],
|
|
296
|
+
feature_name,
|
|
297
|
+
"stft_settings",
|
|
298
|
+
)
|
|
190
299
|
|
|
191
|
-
|
|
300
|
+
for ch_idx, ch_name in enumerate(self.ch_names):
|
|
301
|
+
if self.s["stft_settings"]["return_spectrum"]:
|
|
302
|
+
Z_ch_mean = Z[ch_idx].mean(axis=1)
|
|
303
|
+
features_compute.update(
|
|
304
|
+
{
|
|
305
|
+
f"{ch_name}_stft_psd_{str(f)}": Z_ch_mean[idx]
|
|
306
|
+
for idx, f in enumerate(freqs.astype(int))
|
|
307
|
+
}
|
|
308
|
+
)
|
|
192
309
|
|
|
193
310
|
return features_compute
|
|
194
311
|
|
|
@@ -204,7 +321,7 @@ class BandPower(OscillatoryFeature):
|
|
|
204
321
|
super().__init__(settings, ch_names, sfreq)
|
|
205
322
|
bp_settings = self.s["bandpass_filter_settings"]
|
|
206
323
|
|
|
207
|
-
self.bandpass_filter = nm_filter.
|
|
324
|
+
self.bandpass_filter = nm_filter.MNEFilter(
|
|
208
325
|
f_ranges=list(self.f_ranges_dict.values()),
|
|
209
326
|
sfreq=self.sfreq,
|
|
210
327
|
filter_length=self.sfreq - 1,
|
|
@@ -265,7 +382,9 @@ class BandPower(OscillatoryFeature):
|
|
|
265
382
|
].values()
|
|
266
383
|
), "Set at least one bandpower_feature to True."
|
|
267
384
|
|
|
268
|
-
for fband_name, seg_length_fband in s["bandpass_filter_settings"][
|
|
385
|
+
for fband_name, seg_length_fband in s["bandpass_filter_settings"][
|
|
386
|
+
"segment_lengths_ms"
|
|
387
|
+
].items():
|
|
269
388
|
assert isinstance(seg_length_fband, int), (
|
|
270
389
|
f"bandpass segment_lengths_ms for {fband_name} "
|
|
271
390
|
f"needs to be of type int, got {seg_length_fband}"
|
|
@@ -275,15 +394,16 @@ class BandPower(OscillatoryFeature):
|
|
|
275
394
|
f"segment length {seg_length_fband} needs to be smaller than "
|
|
276
395
|
f" s['segment_length_features_ms'] = {s['segment_length_features_ms']}"
|
|
277
396
|
)
|
|
278
|
-
|
|
397
|
+
|
|
279
398
|
for fband_name in list(s["frequency_ranges_hz"].keys()):
|
|
280
|
-
assert fband_name in list(
|
|
399
|
+
assert fband_name in list(
|
|
400
|
+
s["bandpass_filter_settings"]["segment_lengths_ms"].keys()
|
|
401
|
+
), (
|
|
281
402
|
f"frequency range {fband_name} "
|
|
282
403
|
"needs to be defined in s['bandpass_filter_settings']['segment_lengths_ms']"
|
|
283
404
|
)
|
|
284
405
|
|
|
285
406
|
def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
|
|
286
|
-
|
|
287
407
|
data = self.bandpass_filter.filter_data(data)
|
|
288
408
|
|
|
289
409
|
for (
|
|
@@ -297,7 +417,7 @@ class BandPower(OscillatoryFeature):
|
|
|
297
417
|
) in self.feature_params:
|
|
298
418
|
if bp_feature == "activity":
|
|
299
419
|
if self.log_transform:
|
|
300
|
-
feature_calc = np.
|
|
420
|
+
feature_calc = np.log10(
|
|
301
421
|
np.var(data[ch_idx, f_band_idx, -seglen:])
|
|
302
422
|
)
|
|
303
423
|
else:
|
py_neuromodulation/nm_plots.py
CHANGED
|
@@ -6,6 +6,9 @@ from matplotlib import gridspec
|
|
|
6
6
|
from typing import Optional
|
|
7
7
|
import seaborn as sb
|
|
8
8
|
import pandas as pd
|
|
9
|
+
import logging
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger("PynmLogger")
|
|
9
12
|
|
|
10
13
|
from py_neuromodulation import nm_IO, nm_stats
|
|
11
14
|
|
|
@@ -87,7 +90,9 @@ def plot_epoch(
|
|
|
87
90
|
):
|
|
88
91
|
if z_score is None:
|
|
89
92
|
X_epoch = stats.zscore(
|
|
90
|
-
np.nan_to_num(np.nanmean(np.squeeze(X_epoch), axis=0)),
|
|
93
|
+
np.nan_to_num(np.nanmean(np.squeeze(X_epoch), axis=0)),
|
|
94
|
+
axis=0,
|
|
95
|
+
nan_policy="omit",
|
|
91
96
|
).T
|
|
92
97
|
y_epoch = np.stack(np.array(y_epoch))
|
|
93
98
|
plt.figure(figsize=(6, 6))
|
|
@@ -237,7 +242,7 @@ def plot_corr_matrix(
|
|
|
237
242
|
|
|
238
243
|
if save_plot:
|
|
239
244
|
plt.savefig(plt_path, bbox_inches="tight")
|
|
240
|
-
|
|
245
|
+
logger.info(f"Correlation matrix figure saved to {plt_path}")
|
|
241
246
|
|
|
242
247
|
if show_plot is False:
|
|
243
248
|
plt.close()
|
|
@@ -329,7 +334,7 @@ def plot_epochs_avg(
|
|
|
329
334
|
|
|
330
335
|
if normalize_data:
|
|
331
336
|
X_epoch_mean = stats.zscore(
|
|
332
|
-
np.nanmean(np.squeeze(X_epoch), axis=0), axis=0
|
|
337
|
+
np.nanmean(np.squeeze(X_epoch), axis=0), axis=0, nan_policy="omit"
|
|
333
338
|
).T
|
|
334
339
|
else:
|
|
335
340
|
X_epoch_mean = np.nanmean(np.squeeze(X_epoch), axis=0).T
|
|
@@ -385,7 +390,7 @@ def plot_epochs_avg(
|
|
|
385
390
|
feature_name=feature_str_add,
|
|
386
391
|
)
|
|
387
392
|
plt.savefig(plt_path, bbox_inches="tight")
|
|
388
|
-
|
|
393
|
+
logger.info(f"Feature epoch average figure saved to: {str(plt_path)}")
|
|
389
394
|
if show_plot is False:
|
|
390
395
|
plt.close()
|
|
391
396
|
|
|
@@ -441,7 +446,6 @@ def plot_all_features(
|
|
|
441
446
|
OUT_PATH: str = None,
|
|
442
447
|
feature_file: str = None,
|
|
443
448
|
):
|
|
444
|
-
|
|
445
449
|
if time_limit_high_s is not None:
|
|
446
450
|
df = df[df["time"] < time_limit_high_s * 1000]
|
|
447
451
|
if time_limit_low_s is not None:
|
|
@@ -449,7 +453,7 @@ def plot_all_features(
|
|
|
449
453
|
|
|
450
454
|
cols_plt = [c for c in df.columns if c != "time"]
|
|
451
455
|
if normalize is True:
|
|
452
|
-
data_plt = stats.zscore(df[cols_plt])
|
|
456
|
+
data_plt = stats.zscore(df[cols_plt], nan_policy="omit")
|
|
453
457
|
else:
|
|
454
458
|
data_plt = df[cols_plt]
|
|
455
459
|
|
|
@@ -487,7 +491,6 @@ class NM_Plot:
|
|
|
487
491
|
sess_right: Optional[bool] = False,
|
|
488
492
|
proj_matrix_cortex: np.ndarray | None = None,
|
|
489
493
|
) -> None:
|
|
490
|
-
|
|
491
494
|
self.grid_cortex = grid_cortex
|
|
492
495
|
self.grid_subcortex = grid_subcortex
|
|
493
496
|
self.ecog_strip = ecog_strip
|
|
@@ -510,7 +513,6 @@ class NM_Plot:
|
|
|
510
513
|
) = nm_IO.read_plot_modules()
|
|
511
514
|
|
|
512
515
|
def plot_grid_elec_3d(self) -> None:
|
|
513
|
-
|
|
514
516
|
plot_grid_elec_3d(np.array(self.grid_cortex), np.array(self.ecog_strip))
|
|
515
517
|
|
|
516
518
|
def plot_cortex(
|
|
@@ -552,7 +554,6 @@ class NM_Plot:
|
|
|
552
554
|
axes.axes.set_aspect("equal", anchor="C")
|
|
553
555
|
|
|
554
556
|
if grid_cortex is not None:
|
|
555
|
-
|
|
556
557
|
grid_color = (
|
|
557
558
|
np.ones(grid_cortex.shape[0])
|
|
558
559
|
if grid_color is None
|
|
@@ -604,6 +605,8 @@ class NM_Plot:
|
|
|
604
605
|
feature_name=feature_str_add,
|
|
605
606
|
)
|
|
606
607
|
plt.savefig(plt_path, bbox_inches="tight")
|
|
607
|
-
|
|
608
|
+
logger.info(
|
|
609
|
+
f"Feature epoch average figure saved to: {str(plt_path)}"
|
|
610
|
+
)
|
|
608
611
|
if show_plot is False:
|
|
609
612
|
plt.close()
|
|
@@ -4,10 +4,8 @@ import pandas as pd
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class ReReferencer:
|
|
7
|
-
|
|
8
7
|
ref_matrix: np.ndarray
|
|
9
8
|
|
|
10
|
-
|
|
11
9
|
def __init__(
|
|
12
10
|
self,
|
|
13
11
|
sfreq: int | float,
|
|
@@ -28,11 +26,15 @@ class ReReferencer:
|
|
|
28
26
|
ValueError: rereferencing using undefined channel
|
|
29
27
|
ValueError: rereferencing to same channel
|
|
30
28
|
"""
|
|
31
|
-
|
|
29
|
+
nm_channels = nm_channels[nm_channels["used"] == 1].reset_index(
|
|
30
|
+
drop=True
|
|
31
|
+
)
|
|
32
|
+
# (channels_used,) = np.where((nm_channels.used == 1))
|
|
32
33
|
|
|
33
34
|
ch_names = nm_channels["name"].tolist()
|
|
34
35
|
|
|
35
|
-
|
|
36
|
+
# no re-referencing is being performed when there is a single channel present only
|
|
37
|
+
if nm_channels.shape[0] in (0, 1):
|
|
36
38
|
self.ref_matrix = None
|
|
37
39
|
return
|
|
38
40
|
|
|
@@ -48,8 +50,8 @@ class ReReferencer:
|
|
|
48
50
|
ref_matrix = np.zeros((len(nm_channels), len(nm_channels)))
|
|
49
51
|
for ind in range(len(nm_channels)):
|
|
50
52
|
ref_matrix[ind, ind] = 1
|
|
51
|
-
if ind not in channels_used:
|
|
52
|
-
|
|
53
|
+
# if ind not in channels_used:
|
|
54
|
+
# continue
|
|
53
55
|
ref = refs[ind]
|
|
54
56
|
if ref.lower() == "none" or pd.isnull(ref):
|
|
55
57
|
ref_idx = None
|
|
@@ -84,10 +86,10 @@ class ReReferencer:
|
|
|
84
86
|
shape(n_channels, n_samples) - data to be rereferenced.
|
|
85
87
|
|
|
86
88
|
Returns:
|
|
87
|
-
reref_data (numpy ndarray):
|
|
89
|
+
reref_data (numpy ndarray):
|
|
88
90
|
shape(n_channels, n_samples) - rereferenced data
|
|
89
91
|
"""
|
|
90
92
|
if self.ref_matrix is not None:
|
|
91
93
|
return self.ref_matrix @ data
|
|
92
94
|
else:
|
|
93
|
-
return data
|
|
95
|
+
return data
|