py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
- py_neuromodulation/FieldTrip.py +589 -589
- py_neuromodulation/__init__.py +74 -13
- py_neuromodulation/_write_example_dataset_helper.py +83 -65
- py_neuromodulation/data/README +6 -6
- py_neuromodulation/data/dataset_description.json +8 -8
- py_neuromodulation/data/participants.json +32 -32
- py_neuromodulation/data/participants.tsv +2 -2
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
- py_neuromodulation/grid_cortex.tsv +40 -40
- py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
- py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
- py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
- py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/nm_IO.py +413 -417
- py_neuromodulation/nm_RMAP.py +496 -531
- py_neuromodulation/nm_analysis.py +993 -1074
- py_neuromodulation/nm_artifacts.py +30 -25
- py_neuromodulation/nm_bispectra.py +154 -168
- py_neuromodulation/nm_bursts.py +292 -198
- py_neuromodulation/nm_coherence.py +251 -205
- py_neuromodulation/nm_database.py +149 -0
- py_neuromodulation/nm_decode.py +918 -992
- py_neuromodulation/nm_define_nmchannels.py +300 -302
- py_neuromodulation/nm_features.py +144 -116
- py_neuromodulation/nm_filter.py +219 -219
- py_neuromodulation/nm_filter_preprocessing.py +79 -91
- py_neuromodulation/nm_fooof.py +139 -159
- py_neuromodulation/nm_generator.py +45 -37
- py_neuromodulation/nm_hjorth_raw.py +52 -73
- py_neuromodulation/nm_kalmanfilter.py +71 -58
- py_neuromodulation/nm_linelength.py +21 -33
- py_neuromodulation/nm_logger.py +66 -0
- py_neuromodulation/nm_mne_connectivity.py +149 -112
- py_neuromodulation/nm_mnelsl_generator.py +90 -0
- py_neuromodulation/nm_mnelsl_stream.py +116 -0
- py_neuromodulation/nm_nolds.py +96 -93
- py_neuromodulation/nm_normalization.py +173 -214
- py_neuromodulation/nm_oscillatory.py +423 -448
- py_neuromodulation/nm_plots.py +585 -612
- py_neuromodulation/nm_preprocessing.py +83 -0
- py_neuromodulation/nm_projection.py +370 -394
- py_neuromodulation/nm_rereference.py +97 -95
- py_neuromodulation/nm_resample.py +59 -50
- py_neuromodulation/nm_run_analysis.py +325 -435
- py_neuromodulation/nm_settings.py +289 -68
- py_neuromodulation/nm_settings.yaml +244 -0
- py_neuromodulation/nm_sharpwaves.py +423 -401
- py_neuromodulation/nm_stats.py +464 -480
- py_neuromodulation/nm_stream.py +398 -0
- py_neuromodulation/nm_stream_abc.py +166 -218
- py_neuromodulation/nm_types.py +193 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +29 -26
- py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -1
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/licenses/LICENSE +21 -21
- py_neuromodulation/nm_EpochStream.py +0 -92
- py_neuromodulation/nm_across_patient_decoding.py +0 -927
- py_neuromodulation/nm_cohortwrapper.py +0 -435
- py_neuromodulation/nm_eval_timing.py +0 -239
- py_neuromodulation/nm_features_abc.py +0 -39
- py_neuromodulation/nm_settings.json +0 -338
- py_neuromodulation/nm_stream_offline.py +0 -359
- py_neuromodulation/utils/_logging.py +0 -24
- py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
|
@@ -1,448 +1,423 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
from
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
)
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
def
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
f"{ch_name}
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
class BandPower(
|
|
314
|
-
def __init__(
|
|
315
|
-
self,
|
|
316
|
-
settings:
|
|
317
|
-
ch_names: Iterable[str],
|
|
318
|
-
sfreq: float,
|
|
319
|
-
use_kf: bool = None,
|
|
320
|
-
) -> None:
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
self.
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
self.
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
self.
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
].
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
for
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
feature_calc = np.var(data[ch_idx, f_band_idx, -seglen:])
|
|
425
|
-
elif bp_feature == "mobility":
|
|
426
|
-
deriv_variance = np.var(
|
|
427
|
-
np.diff(data[ch_idx, f_band_idx, -seglen:])
|
|
428
|
-
)
|
|
429
|
-
feature_calc = np.sqrt(
|
|
430
|
-
deriv_variance / np.var(data[ch_idx, f_band_idx, -seglen:])
|
|
431
|
-
)
|
|
432
|
-
elif bp_feature == "complexity":
|
|
433
|
-
dat_deriv = np.diff(data[ch_idx, f_band_idx, -seglen:])
|
|
434
|
-
deriv_variance = np.var(dat_deriv)
|
|
435
|
-
mobility = np.sqrt(
|
|
436
|
-
deriv_variance / np.var(data[ch_idx, f_band_idx, -seglen:])
|
|
437
|
-
)
|
|
438
|
-
dat_deriv_2 = np.diff(dat_deriv)
|
|
439
|
-
dat_deriv_2_var = np.var(dat_deriv_2)
|
|
440
|
-
deriv_mobility = np.sqrt(dat_deriv_2_var / deriv_variance)
|
|
441
|
-
feature_calc = deriv_mobility / mobility
|
|
442
|
-
|
|
443
|
-
if self.KF_dict and (bp_feature == "activity"):
|
|
444
|
-
feature_calc = self.update_KF(feature_calc, feature_name)
|
|
445
|
-
|
|
446
|
-
features_compute[feature_name] = np.nan_to_num(feature_calc)
|
|
447
|
-
|
|
448
|
-
return features_compute
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
import numpy as np
|
|
3
|
+
from itertools import product
|
|
4
|
+
|
|
5
|
+
from py_neuromodulation.nm_types import NMBaseModel
|
|
6
|
+
from pydantic import field_validator
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
from py_neuromodulation.nm_features import NMFeature
|
|
10
|
+
from py_neuromodulation.nm_types import BoolSelector
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from py_neuromodulation.nm_settings import NMSettings
|
|
14
|
+
from py_neuromodulation.nm_kalmanfilter import KalmanSettings
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class OscillatoryFeatures(BoolSelector):
    """Toggle which summary statistics are computed for each frequency band.

    Each field name is also a key of ``ESTIMATOR_DICT`` in this module; the
    oscillatory features (FFT, Welch, STFT) look up the enabled names via
    ``get_enabled()`` and apply the matching NaN-ignoring NumPy reducer.
    """

    # Field order determines the order in which estimators (and therefore
    # output features) are produced.
    mean: bool = True
    median: bool = False
    std: bool = False
    max: bool = False
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class OscillatorySettings(NMBaseModel):
    """Settings consumed by the oscillatory features (FFT, Welch, STFT)."""

    # Analysis window length in milliseconds. Read by FFT and STFT; Welch does
    # not use it beyond the base-class check that it does not exceed
    # ``segment_length_features_ms``.
    windowlength_ms: int = 1000
    # Apply ``np.log10`` to the spectral values before summarizing them.
    log_transform: bool = True
    # Which per-band statistics to compute (names map into ``ESTIMATOR_DICT``).
    features: OscillatoryFeatures = OscillatoryFeatures(
        mean=True, median=False, std=False, max=False
    )
    # Additionally emit one feature per channel and frequency bin.
    return_spectrum: bool = False
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Name -> NaN-ignoring NumPy reducer. The keys mirror the fields of
# ``OscillatoryFeatures``; each reducer is called as ``reducer(values, axis=...)``.
ESTIMATOR_DICT = dict(
    mean=np.nanmean,
    median=np.nanmedian,
    std=np.nanstd,
    max=np.nanmax,
)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class OscillatoryFeature(NMFeature):
    """Common base class for the spectral features (FFT, Welch, STFT).

    Subclasses must assign ``self.settings`` (an ``OscillatorySettings``) and
    ``self.osc_feature_name`` *before* calling ``super().__init__``, because
    this initializer reads ``self.settings`` to validate the window length.

    Args:
        settings: Global settings. ``validate()`` is called on it, and
            ``frequency_ranges_hz`` / ``segment_length_features_ms`` are read.
        ch_names: Channel names used to label the output features.
        sfreq: Sampling frequency in Hz; stored as ``int``.

    Raises:
        AssertionError: If ``windowlength_ms`` exceeds
            ``settings.segment_length_features_ms``.
    """

    def __init__(
        self, settings: "NMSettings", ch_names: Iterable[str], sfreq: int
    ) -> None:
        settings.validate()
        self.settings: OscillatorySettings  # Assignment in subclass __init__
        self.osc_feature_name: str  # Required for output

        self.sfreq = int(sfreq)
        self.ch_names = ch_names

        self.frequency_ranges = settings.frequency_ranges_hz

        # The analysis window has to fit inside one feature segment. The
        # message is kept a single string: a trailing comma inside the
        # parentheses would turn it into a one-element tuple.
        assert self.settings.windowlength_ms <= settings.segment_length_features_ms, (
            f"oscillatory feature windowlength_ms = ({self.settings.windowlength_ms}) "
            "needs to be smaller than or equal to "
            f"settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
        )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class FFT(OscillatoryFeature):
    """Per-band amplitude statistics from a single real FFT of each channel.

    Only the trailing ``fft_settings.windowlength_ms`` of the input is
    transformed. For every band in ``frequency_ranges_hz``, the bins with
    ``low <= f < high`` are reduced with each enabled estimator. Optionally,
    the full amplitude spectrum is emitted as well.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: int,
    ) -> None:
        from scipy.fft import rfftfreq

        # OscillatoryFeature.__init__ reads both attributes, so bind them first.
        self.osc_feature_name = "fft"
        self.settings = settings.fft_settings
        super().__init__(settings, ch_names, sfreq)

        # Window length in samples, stored negated so it can be used directly
        # as a slice start that keeps only the most recent samples.
        n_window = np.floor(self.settings.windowlength_ms / 1000 * sfreq)
        self.window_samples = int(-n_window)
        self.freqs = rfftfreq(-self.window_samples, 1 / np.floor(self.sfreq))

        # Frequency-bin indices of every band, resolved once at construction.
        self.idx_range = []
        for band_name, band_limits in self.frequency_ranges.items():
            in_band = (self.freqs >= band_limits[0]) & (self.freqs < band_limits[1])
            self.idx_range.append((band_name, np.flatnonzero(in_band)))

        enabled = self.settings.features.get_enabled()
        self.estimators = [(name, ESTIMATOR_DICT[name]) for name in enabled]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute the FFT band features for one batch of data.

        Args:
            data: 2-D array indexed as ``data[channel, sample]`` (presumably
                one row per entry of ``ch_names`` -- confirm with caller).

        Returns:
            Mapping from feature name to value.
        """
        from scipy.fft import rfft

        recent = data[:, self.window_samples :]
        amplitudes = np.abs(rfft(recent))  # type: ignore
        if self.settings.log_transform:
            amplitudes = np.log10(amplitudes)

        results = {}
        for band_name, band_idx in self.idx_range:
            band_amps = amplitudes[:, band_idx]
            for est_name, est_fun in self.estimators:
                per_channel = est_fun(band_amps, axis=1)
                for ch_idx, ch_name in enumerate(self.ch_names):
                    key = f"{ch_name}_{self.osc_feature_name}_{band_name}_{est_name}"
                    results[key] = per_channel[ch_idx]

        if self.settings.return_spectrum:
            # Channel-major order, one entry per frequency bin.
            for ch_idx, ch_name in enumerate(self.ch_names):
                for idx, f in enumerate(self.freqs):
                    results[f"{ch_name}_fft_psd_{int(f)}"] = amplitudes[ch_idx][idx]

        return results
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class Welch(OscillatoryFeature):
    """Per-band statistics of Welch power spectral density estimates.

    ``scipy.signal.welch`` runs with a Hann window of ``sfreq`` samples and
    scipy's default overlap, so the precomputed frequency grid has 1 Hz
    spacing. Bins with ``low <= f < high`` form each band.

    NOTE(review): if a batch has fewer than ``sfreq`` samples, scipy shortens
    ``nperseg`` and the precomputed grid no longer matches -- confirm callers
    always pass at least one second of data.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: int,
    ) -> None:
        from scipy.fft import rfftfreq

        # The base initializer needs these two attributes already set.
        self.osc_feature_name = "welch"
        self.settings = settings.welch_settings
        super().__init__(settings, ch_names, sfreq)

        self.freqs = rfftfreq(self.sfreq, 1 / self.sfreq)

        # Resolve the bin indices of each band up front.
        self.idx_range = []
        for band_name, band_limits in self.frequency_ranges.items():
            in_band = (self.freqs >= band_limits[0]) & (self.freqs < band_limits[1])
            self.idx_range.append((band_name, np.flatnonzero(in_band)))

        enabled = self.settings.features.get_enabled()
        self.estimators = [(name, ESTIMATOR_DICT[name]) for name in enabled]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute the Welch band features for one batch of data.

        Args:
            data: 2-D array indexed as ``data[channel, sample]`` (presumably
                one row per entry of ``ch_names`` -- confirm with caller).

        Returns:
            Mapping from feature name to value.
        """
        from scipy.signal import welch

        _, psd = welch(
            data,
            fs=self.sfreq,
            window="hann",
            nperseg=self.sfreq,
            noverlap=None,
        )
        if self.settings.log_transform:
            psd = np.log10(psd)

        results = {}
        for band_name, band_idx in self.idx_range:
            band_psd = psd[:, band_idx]
            for est_name, est_fun in self.estimators:
                per_channel = est_fun(band_psd, axis=1)
                for ch_idx, ch_name in enumerate(self.ch_names):
                    key = f"{ch_name}_{self.osc_feature_name}_{band_name}_{est_name}"
                    results[key] = per_channel[ch_idx]

        if self.settings.return_spectrum:
            # Channel-major order, one entry per frequency bin.
            for ch_idx, ch_name in enumerate(self.ch_names):
                for idx, f in enumerate(self.freqs):
                    results[f"{ch_name}_welch_psd_{str(f)}"] = psd[ch_idx][idx]

        return results
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class STFT(OscillatoryFeature):
    """Per-band statistics of short-time Fourier transform magnitudes.

    ``scipy.signal.stft`` (Hamming window, ``boundary="even"``) is applied to
    the input. For each band, the magnitudes of all bins with
    ``low <= f <= high`` across all time frames are reduced with every
    enabled estimator.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: int,
    ) -> None:
        from scipy.fft import rfftfreq

        # The base initializer reads these two attributes, so set them first.
        self.osc_feature_name = "stft"
        self.settings = settings.stft_settings
        super().__init__(settings, ch_names, sfreq)

        # ``windowlength_ms`` is a duration, but ``nperseg`` is a sample count.
        # Using the millisecond value directly is only correct at 1000 Hz; at
        # other rates the grid below would not match the STFT output. Integer
        # arithmetic gives an exact floor (self.sfreq is an int).
        self.nperseg = int(self.settings.windowlength_ms * self.sfreq // 1000)

        self.freqs = rfftfreq(self.nperseg, 1 / self.sfreq)

        # Bin indices per band. Note: the upper limit is inclusive here,
        # unlike FFT and Welch.
        self.idx_range = [
            (
                f_band,
                np.where((self.freqs >= f_range[0]) & (self.freqs <= f_range[1]))[0],
            )
            for f_band, f_range in self.frequency_ranges.items()
        ]

        self.estimators = [
            (est, ESTIMATOR_DICT[est]) for est in self.settings.features.get_enabled()
        ]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute the STFT band features for one batch of data.

        Args:
            data: Array whose last axis is time (presumably
                ``(channels, samples)`` -- confirm with caller). The STFT
                output is indexed as ``[channel, freq_bin, frame]``.

        Returns:
            Mapping from feature name to value.
        """
        from scipy.signal import stft

        _, _, Zxx = stft(
            data,
            fs=self.sfreq,
            window="hamming",
            nperseg=self.nperseg,
            boundary="even",
        )

        Z = np.abs(Zxx)
        if self.settings.log_transform:
            Z = np.log10(Z)

        feature_results = {}

        for f_band_name, idx_range in self.idx_range:
            Z_band = Z[:, idx_range, :]

            for est_name, est_fun in self.estimators:
                # Reduce over both frequency bins and time frames.
                result = est_fun(Z_band, axis=(1, 2))

                for ch_idx, ch_name in enumerate(self.ch_names):
                    feature_results[
                        f"{ch_name}_{self.osc_feature_name}_{f_band_name}_{est_name}"
                    ] = result[ch_idx]

        if self.settings.return_spectrum:
            for ch_idx, ch_name in enumerate(self.ch_names):
                # Time-averaged spectrum, computed once per channel rather
                # than once per frequency bin.
                mean_spectrum = Z[ch_idx].mean(axis=1)
                for idx, f in enumerate(self.freqs):
                    feature_results[f"{ch_name}_stft_psd_{str(f)}"] = mean_spectrum[
                        idx
                    ]

        return feature_results
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class BandpowerFeatures(BoolSelector):
    """Toggle which statistics ``BandPower`` computes for each filtered band.

    ``activity`` is the variance of the band signal, ``mobility`` is
    ``sqrt(var(diff(x)) / var(x))``, and ``complexity`` is the mobility of the
    first derivative divided by the mobility of the signal (see
    ``BandPower.calc_bp_feature``).
    """

    # Field order determines the order in which features are generated.
    activity: bool = True
    mobility: bool = False
    complexity: bool = False
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
###################################
|
|
264
|
+
######## BANDPOWER FEATURE ########
|
|
265
|
+
###################################
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
class BandpassSettings(NMBaseModel):
    """Settings for the band-pass-filter based ``BandPower`` feature."""

    # Trailing segment length (ms) analyzed for each frequency band. The keys
    # must cover every band in ``frequency_ranges_hz`` (see ``validate_fbands``).
    segment_lengths_ms: dict[str, int] = {
        "theta": 1000,
        "alpha": 500,
        "low_beta": 333,
        "high_beta": 333,
        "low_gamma": 100,
        "high_gamma": 100,
        "HFA": 100,
    }
    # Statistics computed per band; at least one must be enabled.
    bandpower_features: BandpowerFeatures = BandpowerFeatures()
    # Apply log10 to the "activity" (variance) feature.
    log_transform: bool = True
    # Kalman-filter the "activity" feature (BandPower's ``use_kf`` can override).
    kalman_filter: bool = False

    @field_validator("segment_lengths_ms")
    @classmethod
    def fbands_spaces_to_underscores(cls, segment_lengths_ms: dict[str, int]):
        """Replace spaces with underscores in frequency band names."""
        return {k.replace(" ", "_"): v for k, v in segment_lengths_ms.items()}

    @field_validator("bandpower_features")
    @classmethod
    def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
        """Reject configurations in which no bandpower feature is enabled."""
        assert (
            len(bandpower_features.get_enabled()) > 0
        ), "Set at least one bandpower_feature to True."

        return bandpower_features

    def validate_fbands(self, settings: "NMSettings") -> None:
        """Check the per-band segment lengths against the global settings.

        Args:
            settings: Global settings providing ``frequency_ranges_hz`` and
                ``segment_length_features_ms``.

        Raises:
            AssertionError: If a band in ``settings.frequency_ranges_hz`` has
                no entry in ``segment_lengths_ms``, or if a segment length
                exceeds ``settings.segment_length_features_ms``.
        """
        # Every globally defined frequency band needs a segment length here.
        for fband_name in settings.frequency_ranges_hz:
            assert fband_name in self.segment_lengths_ms, (
                f"frequency range {fband_name} "
                "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms"
            )

        # No band may look further back than one feature segment.
        for fband_name, seg_length_fband in self.segment_lengths_ms.items():
            assert seg_length_fband <= settings.segment_length_features_ms, (
                f"segment length {seg_length_fband} of frequency band {fband_name} "
                "needs to be smaller than or equal to "
                f"settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
            )
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
class BandPower(NMFeature):
    """Band-pass filter the input and compute per-band signal statistics.

    For every channel, every band in ``settings.frequency_ranges_hz`` and
    every enabled ``bandpower_features`` entry, one feature named
    ``"{ch}_bandpass_{feature}_{band}"`` is produced from the trailing
    ``segment_lengths_ms[band]`` of the filtered signal.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: float,
        use_kf: bool | None = None,
    ) -> None:
        """Build the filter bank and the flat per-feature lookup table.

        Args:
            settings: Global settings; ``validate()`` is called first.
            ch_names: Channel names, iterated here and (if Kalman filtering
                is enabled) once per Kalman band in ``init_KF``, so it should
                be a re-iterable sequence.
            sfreq: Sampling frequency in Hz.
            use_kf: Truthy forces Kalman filtering of "activity" on, ``False``
                forces it off, and ``None`` defers to
                ``bandpass_filter_settings.kalman_filter``.
        """
        settings.validate()

        self.bp_settings: BandpassSettings = settings.bandpass_filter_settings
        self.kalman_filter_settings: KalmanSettings = settings.kalman_filter_settings
        self.sfreq = sfreq
        self.ch_names = ch_names
        # Kalman filters keyed by feature name; empty means no filtering.
        self.KF_dict: dict = {}

        from py_neuromodulation.nm_filter import MNEFilter

        # One band-pass filter per frequency range, in frequency_ranges_hz
        # order; feature_params below indexes the filtered bands by that order.
        self.bandpass_filter = MNEFilter(
            f_ranges=[
                tuple(frange) for frange in settings.frequency_ranges_hz.values()
            ],
            sfreq=self.sfreq,
            filter_length=self.sfreq - 1,
            verbose=False,
        )

        if use_kf or (use_kf is None and self.bp_settings.kalman_filter):
            self.init_KF("bandpass_activity")

        seglengths = self.bp_settings.segment_lengths_ms

        # Precompute (ch_idx, band_idx, segment samples, feature, name) tuples
        # so calc_feature is a single flat loop.
        self.feature_params = []
        for ch_idx, ch_name in enumerate(self.ch_names):
            for f_band_idx, f_band in enumerate(settings.frequency_ranges_hz.keys()):
                seglength_ms = seglengths[f_band]
                # Segment length converted from milliseconds to samples.
                seglen = int(np.floor(self.sfreq / 1000 * seglength_ms))
                for bp_feature in self.bp_settings.bandpower_features.get_enabled():
                    feature_name = "_".join([ch_name, "bandpass", bp_feature, f_band])
                    self.feature_params.append(
                        (
                            ch_idx,
                            f_band_idx,
                            seglen,
                            bp_feature,
                            feature_name,
                        )
                    )

    def init_KF(self, feature: str) -> None:
        """Create one Kalman filter per (Kalman frequency band, channel).

        Keys are ``"{channel}_{feature}_{f_band}"``. With
        ``feature == "bandpass_activity"`` they match the feature names built
        in ``__init__`` for the "activity" statistic.
        """
        from py_neuromodulation.nm_kalmanfilter import define_KF

        for f_band in self.kalman_filter_settings.frequency_bands:
            for channel in self.ch_names:
                self.KF_dict["_".join([channel, feature, f_band])] = define_KF(
                    self.kalman_filter_settings.Tp,
                    self.kalman_filter_settings.sigma_w,
                    self.kalman_filter_settings.sigma_v,
                )

    def update_KF(self, feature_calc: np.floating, KF_name: str) -> np.floating:
        """Run one predict/update step and return the filtered estimate.

        If no filter exists for ``KF_name`` (e.g. its band is not listed in
        ``kalman_filter_settings.frequency_bands``), ``feature_calc`` is
        returned unchanged.
        """
        if KF_name in self.KF_dict:
            self.KF_dict[KF_name].predict()
            self.KF_dict[KF_name].update(feature_calc)
            # First state component is taken as the filtered value
            # (assumes define_KF's state layout -- confirm in nm_kalmanfilter).
            feature_calc = self.KF_dict[KF_name].x[0]
        return feature_calc

    def calc_feature(self, data: np.ndarray) -> dict:
        """Filter ``data`` and compute every configured bandpower feature.

        The filtered array is indexed as ``[channel, band, sample]``
        (presumably the shape MNEFilter.filter_data returns -- confirm).
        """
        data = self.bandpass_filter.filter_data(data)

        feature_results = {}

        for (
            ch_idx,
            f_band_idx,
            seglen,
            bp_feature,
            feature_name,
        ) in self.feature_params:
            feature_results[feature_name] = self.calc_bp_feature(
                bp_feature, feature_name, data[ch_idx, f_band_idx, -seglen:]
            )

        return feature_results

    def calc_bp_feature(self, bp_feature: str, feature_name: str, data: np.ndarray):
        """Compute one statistic of a band-limited segment.

        Args:
            bp_feature: "activity", "mobility" or "complexity".
            feature_name: Output name; also the Kalman-filter key for "activity".
            data: The band-filtered trailing segment of one channel.

        Returns:
            The value passed through ``np.nan_to_num``.

        Raises:
            ValueError: For an unknown ``bp_feature``.
        """
        match bp_feature:
            case "activity":
                feature_calc = np.var(data)
                if self.bp_settings.log_transform:
                    feature_calc = np.log10(feature_calc)
                if self.KF_dict:
                    feature_calc = self.update_KF(feature_calc, feature_name)
            case "mobility":
                feature_calc = np.sqrt(np.var(np.diff(data)) / np.var(data))
            case "complexity":
                feature_calc = self.calc_complexity(data)
            case _:
                raise ValueError(f"Unknown bandpower feature: {bp_feature}")

        return np.nan_to_num(feature_calc)

    @staticmethod
    def calc_complexity(data: np.ndarray) -> float:
        """Mobility of the first difference divided by mobility of ``data``."""
        dat_deriv = np.diff(data)
        deriv_variance = np.var(dat_deriv)
        mobility = np.sqrt(deriv_variance / np.var(data))
        dat_deriv_2_var = np.var(np.diff(dat_deriv))
        deriv_mobility = np.sqrt(dat_deriv_2_var / deriv_variance)

        return deriv_mobility / mobility
|