py-neuromodulation 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (233)
  1. py_neuromodulation/ConnectivityDecoding/Automated Anatomical Labeling 3 (Rolls 2020).nii +0 -0
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -0
  3. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -0
  4. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -0
  5. py_neuromodulation/ConnectivityDecoding/mni_coords_cortical_surface.mat +0 -0
  6. py_neuromodulation/ConnectivityDecoding/mni_coords_whole_brain.mat +0 -0
  7. py_neuromodulation/ConnectivityDecoding/rmap_func_all.nii +0 -0
  8. py_neuromodulation/ConnectivityDecoding/rmap_struc.nii +0 -0
  9. py_neuromodulation/FieldTrip.py +589 -589
  10. py_neuromodulation/__init__.py +74 -13
  11. py_neuromodulation/_write_example_dataset_helper.py +83 -65
  12. py_neuromodulation/data/README +6 -0
  13. py_neuromodulation/data/dataset_description.json +8 -0
  14. py_neuromodulation/data/participants.json +32 -0
  15. py_neuromodulation/data/participants.tsv +2 -0
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -0
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -0
  18. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -0
  19. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.eeg +0 -0
  20. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -0
  21. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -0
  22. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -0
  23. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -0
  24. py_neuromodulation/grid_cortex.tsv +40 -0
  25. py_neuromodulation/grid_subcortex.tsv +1429 -0
  26. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  27. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  28. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  29. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  30. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  31. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  32. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  33. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  34. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  35. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  36. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  37. py_neuromodulation/nm_IO.py +413 -417
  38. py_neuromodulation/nm_RMAP.py +496 -531
  39. py_neuromodulation/nm_analysis.py +993 -1074
  40. py_neuromodulation/nm_artifacts.py +30 -25
  41. py_neuromodulation/nm_bispectra.py +154 -168
  42. py_neuromodulation/nm_bursts.py +292 -198
  43. py_neuromodulation/nm_coherence.py +251 -205
  44. py_neuromodulation/nm_database.py +149 -0
  45. py_neuromodulation/nm_decode.py +918 -992
  46. py_neuromodulation/nm_define_nmchannels.py +300 -302
  47. py_neuromodulation/nm_features.py +144 -116
  48. py_neuromodulation/nm_filter.py +219 -219
  49. py_neuromodulation/nm_filter_preprocessing.py +79 -91
  50. py_neuromodulation/nm_fooof.py +139 -159
  51. py_neuromodulation/nm_generator.py +45 -37
  52. py_neuromodulation/nm_hjorth_raw.py +52 -73
  53. py_neuromodulation/nm_kalmanfilter.py +71 -58
  54. py_neuromodulation/nm_linelength.py +21 -33
  55. py_neuromodulation/nm_logger.py +66 -0
  56. py_neuromodulation/nm_mne_connectivity.py +149 -112
  57. py_neuromodulation/nm_mnelsl_generator.py +90 -0
  58. py_neuromodulation/nm_mnelsl_stream.py +116 -0
  59. py_neuromodulation/nm_nolds.py +96 -93
  60. py_neuromodulation/nm_normalization.py +173 -214
  61. py_neuromodulation/nm_oscillatory.py +423 -448
  62. py_neuromodulation/nm_plots.py +585 -612
  63. py_neuromodulation/nm_preprocessing.py +83 -0
  64. py_neuromodulation/nm_projection.py +370 -394
  65. py_neuromodulation/nm_rereference.py +97 -95
  66. py_neuromodulation/nm_resample.py +59 -50
  67. py_neuromodulation/nm_run_analysis.py +325 -435
  68. py_neuromodulation/nm_settings.py +289 -68
  69. py_neuromodulation/nm_settings.yaml +244 -0
  70. py_neuromodulation/nm_sharpwaves.py +423 -401
  71. py_neuromodulation/nm_stats.py +464 -480
  72. py_neuromodulation/nm_stream.py +398 -0
  73. py_neuromodulation/nm_stream_abc.py +166 -218
  74. py_neuromodulation/nm_types.py +193 -0
  75. py_neuromodulation/plots/STN_surf.mat +0 -0
  76. py_neuromodulation/plots/Vertices.mat +0 -0
  77. py_neuromodulation/plots/faces.mat +0 -0
  78. py_neuromodulation/plots/grid.mat +0 -0
  79. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +185 -182
  80. py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
  81. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -2
  82. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info/licenses}/LICENSE +21 -21
  83. docs/build/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -68
  84. docs/build/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -233
  85. docs/build/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  86. docs/build/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -97
  87. docs/build/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  88. docs/build/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -192
  89. docs/build/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  90. docs/build/html/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -68
  91. docs/build/html/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -239
  92. docs/build/html/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  93. docs/build/html/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -97
  94. docs/build/html/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  95. docs/build/html/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -192
  96. docs/build/html/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  97. docs/source/_build/html/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -76
  98. docs/source/_build/html/_downloads/0d0d0a76e8f648d5d3cbc47da6351932/plot_real_time_demo.py +0 -97
  99. docs/source/_build/html/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -240
  100. docs/source/_build/html/_downloads/5d73cadc59a8805c47e3b84063afc157/plot_example_BIDS.py +0 -233
  101. docs/source/_build/html/_downloads/7660317fa5a6bfbd12fcca9961457fc4/plot_example_rmap_computing.py +0 -63
  102. docs/source/_build/html/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  103. docs/source/_build/html/_downloads/839e5b319379f7fd9e867deb00fd797f/plot_example_gridPointProjection.py +0 -210
  104. docs/source/_build/html/_downloads/ae8be19afe5e559f011fc9b138968ba0/plot_first_demo.py +0 -192
  105. docs/source/_build/html/_downloads/b8b06cacc17969d3725a0b6f1d7741c5/plot_example_sharpwave_analysis.py +0 -219
  106. docs/source/_build/html/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -121
  107. docs/source/_build/html/_downloads/c31a86c0b68cb4167d968091ace8080d/plot_example_add_feature.py +0 -68
  108. docs/source/_build/html/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  109. docs/source/_build/html/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -189
  110. docs/source/_build/html/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  111. docs/source/auto_examples/plot_0_first_demo.py +0 -189
  112. docs/source/auto_examples/plot_1_example_BIDS.py +0 -240
  113. docs/source/auto_examples/plot_2_example_add_feature.py +0 -76
  114. docs/source/auto_examples/plot_3_example_sharpwave_analysis.py +0 -219
  115. docs/source/auto_examples/plot_4_example_gridPointProjection.py +0 -210
  116. docs/source/auto_examples/plot_5_example_rmap_computing.py +0 -64
  117. docs/source/auto_examples/plot_6_real_time_demo.py +0 -121
  118. docs/source/conf.py +0 -105
  119. examples/plot_0_first_demo.py +0 -189
  120. examples/plot_1_example_BIDS.py +0 -240
  121. examples/plot_2_example_add_feature.py +0 -76
  122. examples/plot_3_example_sharpwave_analysis.py +0 -219
  123. examples/plot_4_example_gridPointProjection.py +0 -210
  124. examples/plot_5_example_rmap_computing.py +0 -64
  125. examples/plot_6_real_time_demo.py +0 -121
  126. packages/realtime_decoding/build/lib/realtime_decoding/__init__.py +0 -4
  127. packages/realtime_decoding/build/lib/realtime_decoding/decoder.py +0 -104
  128. packages/realtime_decoding/build/lib/realtime_decoding/features.py +0 -163
  129. packages/realtime_decoding/build/lib/realtime_decoding/helpers.py +0 -15
  130. packages/realtime_decoding/build/lib/realtime_decoding/run_decoding.py +0 -345
  131. packages/realtime_decoding/build/lib/realtime_decoding/trainer.py +0 -54
  132. packages/tmsi/build/lib/TMSiFileFormats/__init__.py +0 -37
  133. packages/tmsi/build/lib/TMSiFileFormats/file_formats/__init__.py +0 -36
  134. packages/tmsi/build/lib/TMSiFileFormats/file_formats/lsl_stream_writer.py +0 -200
  135. packages/tmsi/build/lib/TMSiFileFormats/file_formats/poly5_file_writer.py +0 -496
  136. packages/tmsi/build/lib/TMSiFileFormats/file_formats/poly5_to_edf_converter.py +0 -236
  137. packages/tmsi/build/lib/TMSiFileFormats/file_formats/xdf_file_writer.py +0 -977
  138. packages/tmsi/build/lib/TMSiFileFormats/file_readers/__init__.py +0 -35
  139. packages/tmsi/build/lib/TMSiFileFormats/file_readers/edf_reader.py +0 -116
  140. packages/tmsi/build/lib/TMSiFileFormats/file_readers/poly5reader.py +0 -294
  141. packages/tmsi/build/lib/TMSiFileFormats/file_readers/xdf_reader.py +0 -229
  142. packages/tmsi/build/lib/TMSiFileFormats/file_writer.py +0 -102
  143. packages/tmsi/build/lib/TMSiPlotters/__init__.py +0 -2
  144. packages/tmsi/build/lib/TMSiPlotters/gui/__init__.py +0 -39
  145. packages/tmsi/build/lib/TMSiPlotters/gui/_plotter_gui.py +0 -234
  146. packages/tmsi/build/lib/TMSiPlotters/gui/plotting_gui.py +0 -440
  147. packages/tmsi/build/lib/TMSiPlotters/plotters/__init__.py +0 -44
  148. packages/tmsi/build/lib/TMSiPlotters/plotters/hd_emg_plotter.py +0 -446
  149. packages/tmsi/build/lib/TMSiPlotters/plotters/impedance_plotter.py +0 -589
  150. packages/tmsi/build/lib/TMSiPlotters/plotters/signal_plotter.py +0 -1326
  151. packages/tmsi/build/lib/TMSiSDK/__init__.py +0 -54
  152. packages/tmsi/build/lib/TMSiSDK/device.py +0 -588
  153. packages/tmsi/build/lib/TMSiSDK/devices/__init__.py +0 -34
  154. packages/tmsi/build/lib/TMSiSDK/devices/saga/TMSi_Device_API.py +0 -1764
  155. packages/tmsi/build/lib/TMSiSDK/devices/saga/__init__.py +0 -34
  156. packages/tmsi/build/lib/TMSiSDK/devices/saga/saga_device.py +0 -1366
  157. packages/tmsi/build/lib/TMSiSDK/devices/saga/saga_types.py +0 -520
  158. packages/tmsi/build/lib/TMSiSDK/devices/saga/xml_saga_config.py +0 -165
  159. packages/tmsi/build/lib/TMSiSDK/error.py +0 -95
  160. packages/tmsi/build/lib/TMSiSDK/sample_data.py +0 -63
  161. packages/tmsi/build/lib/TMSiSDK/sample_data_server.py +0 -99
  162. packages/tmsi/build/lib/TMSiSDK/settings.py +0 -45
  163. packages/tmsi/build/lib/TMSiSDK/tmsi_device.py +0 -111
  164. packages/tmsi/build/lib/__init__.py +0 -4
  165. packages/tmsi/build/lib/apex_sdk/__init__.py +0 -34
  166. packages/tmsi/build/lib/apex_sdk/device/__init__.py +0 -41
  167. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API.py +0 -1009
  168. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API_enums.py +0 -239
  169. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API_structures.py +0 -668
  170. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_device.py +0 -1611
  171. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_dongle.py +0 -38
  172. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_event_reader.py +0 -57
  173. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_channel.py +0 -44
  174. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_config.py +0 -150
  175. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_const.py +0 -36
  176. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_impedance_channel.py +0 -48
  177. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_info.py +0 -108
  178. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/dongle_info.py +0 -39
  179. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/download_measurement.py +0 -77
  180. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/eeg_measurement.py +0 -150
  181. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/impedance_measurement.py +0 -129
  182. packages/tmsi/build/lib/apex_sdk/device/threads/conversion_thread.py +0 -59
  183. packages/tmsi/build/lib/apex_sdk/device/threads/sampling_thread.py +0 -57
  184. packages/tmsi/build/lib/apex_sdk/device/tmsi_channel.py +0 -83
  185. packages/tmsi/build/lib/apex_sdk/device/tmsi_device.py +0 -201
  186. packages/tmsi/build/lib/apex_sdk/device/tmsi_device_enums.py +0 -103
  187. packages/tmsi/build/lib/apex_sdk/device/tmsi_dongle.py +0 -43
  188. packages/tmsi/build/lib/apex_sdk/device/tmsi_event_reader.py +0 -50
  189. packages/tmsi/build/lib/apex_sdk/device/tmsi_measurement.py +0 -118
  190. packages/tmsi/build/lib/apex_sdk/sample_data_server/__init__.py +0 -33
  191. packages/tmsi/build/lib/apex_sdk/sample_data_server/event_data.py +0 -44
  192. packages/tmsi/build/lib/apex_sdk/sample_data_server/sample_data.py +0 -50
  193. packages/tmsi/build/lib/apex_sdk/sample_data_server/sample_data_server.py +0 -136
  194. packages/tmsi/build/lib/apex_sdk/tmsi_errors/error.py +0 -126
  195. packages/tmsi/build/lib/apex_sdk/tmsi_sdk.py +0 -113
  196. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/apex/apex_structure_generator.py +0 -134
  197. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/decorators.py +0 -60
  198. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/logger_filter.py +0 -42
  199. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/singleton.py +0 -42
  200. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/support_functions.py +0 -72
  201. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/tmsi_logger.py +0 -98
  202. py_neuromodulation/nm_EpochStream.py +0 -92
  203. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  204. py_neuromodulation/nm_cohortwrapper.py +0 -435
  205. py_neuromodulation/nm_eval_timing.py +0 -239
  206. py_neuromodulation/nm_features_abc.py +0 -39
  207. py_neuromodulation/nm_stream_offline.py +0 -358
  208. py_neuromodulation/utils/_logging.py +0 -24
  209. py_neuromodulation-0.0.3.dist-info/RECORD +0 -188
  210. py_neuromodulation-0.0.3.dist-info/top_level.txt +0 -5
  211. tests/__init__.py +0 -0
  212. tests/conftest.py +0 -117
  213. tests/test_all_examples.py +0 -10
  214. tests/test_all_features.py +0 -63
  215. tests/test_bispectra.py +0 -70
  216. tests/test_bursts.py +0 -105
  217. tests/test_feature_sampling_rates.py +0 -143
  218. tests/test_fooof.py +0 -16
  219. tests/test_initalization_offline_stream.py +0 -41
  220. tests/test_multiprocessing.py +0 -58
  221. tests/test_nan_values.py +0 -29
  222. tests/test_nm_filter.py +0 -95
  223. tests/test_nm_resample.py +0 -63
  224. tests/test_normalization_settings.py +0 -146
  225. tests/test_notch_filter.py +0 -31
  226. tests/test_osc_features.py +0 -424
  227. tests/test_preprocessing_filter.py +0 -151
  228. tests/test_rereference.py +0 -171
  229. tests/test_sampling.py +0 -57
  230. tests/test_settings_change_after_init.py +0 -76
  231. tests/test_sharpwave.py +0 -165
  232. tests/test_target_channel_add.py +0 -100
  233. tests/test_timing.py +0 -80
@@ -1,448 +1,423 @@
1
- from typing import Iterable
2
-
3
- import numpy as np
4
- from scipy import fft, signal
5
-
6
- from py_neuromodulation import nm_filter, nm_features_abc, nm_kalmanfilter
7
-
8
-
9
class OscillatoryFeature(nm_features_abc.Feature):
    """Shared base for the FFT, Welch, STFT and band-power feature estimators.

    Stores the settings dict, channel names and sampling rate, and provides
    settings validation, optional Kalman smoothing and per-band summary
    statistics used by the subclasses.
    """

    # Estimator name (as used in ``<est_name>["features"]``) -> NaN-ignoring
    # NumPy reduction applied to the selected spectral values.
    _OSC_ESTIMATORS = {
        "mean": np.nanmean,
        "median": np.nanmedian,
        "std": np.nanstd,
        "max": np.nanmax,
    }

    def __init__(
        self, settings: dict, ch_names: Iterable[str], sfreq: float
    ) -> None:
        self.s = settings
        self.sfreq = sfreq
        self.ch_names = ch_names
        # Kalman filters keyed by "<channel>_<feature>_<band>"; filled by init_KF.
        self.KF_dict = {}

        self.f_ranges_dict = settings["frequency_ranges_hz"]
        self.fband_names = list(settings["frequency_ranges_hz"].keys())
        self.f_ranges = list(settings["frequency_ranges_hz"].values())

    @staticmethod
    def test_settings_osc(
        s: dict,
        ch_names: Iterable[str],
        sfreq: int | float,
        osc_feature_name: str,
    ):
        """Validate the settings shared by all oscillatory features.

        Parameters
        ----------
        s : dict
            Full settings dict.
        ch_names : Iterable[str]
            Channel names (not used by these checks).
        sfreq : int | float
            Sampling rate in Hz.
        osc_feature_name : str
            Settings key of the feature, e.g. ``"fft_settings"`` or
            ``"bandpass_filter_settings"``.

        Raises
        ------
        AssertionError
            If any check fails.
        """
        import numbers

        f_ranges = s["frequency_ranges_hz"]
        assert isinstance(
            f_ranges, dict
        ), f"frequency_ranges_hz needs to be a dict, got {type(f_ranges)}"

        # The previous checks were written as ``assert (... for ...)``, i.e.
        # they asserted a generator object, which is always truthy, so none of
        # them ever fired. Bands are used as ``[low, high]`` numbers throughout
        # (f_range[0] / f_range[1] compared against frequency bins), so that is
        # the structure enforced here. The former ``value[0]`` is-a-list checks
        # contradicted that usage (apparently a stale schema) and are dropped.
        for fband_name, f_range in f_ranges.items():
            assert isinstance(f_range, (list, tuple)) and len(f_range) == 2, (
                f"frequency range {fband_name} needs to be [low, high], "
                f"got {f_range}"
            )
            assert all(isinstance(f, numbers.Real) for f in f_range), (
                f"frequency range {fband_name} bounds need to be numeric, "
                f"got {f_range}"
            )

        assert all(
            fb[0] < sfreq / 2 and fb[1] < sfreq / 2 for fb in f_ranges.values()
        ), (
            "the frequency band ranges need to be smaller than the nyquist "
            f"frequency, got sfreq = {sfreq} and fband ranges {f_ranges}"
        )

        feature_settings = s[osc_feature_name]
        if osc_feature_name != "bandpass_filter_settings":
            window_ms = feature_settings["windowlength_ms"]
            assert isinstance(
                window_ms, int
            ), f"windowlength_ms needs to be type int, got {window_ms}"

            assert window_ms <= s["segment_length_features_ms"], (
                f"oscillatory feature windowlength_ms = ({window_ms}) "
                "needs to be smaller than or equal to "
                f"s['segment_length_features_ms'] = {s['segment_length_features_ms']}"
            )
        else:
            for seg_length in feature_settings["segment_lengths_ms"].values():
                assert isinstance(
                    seg_length, int
                ), f"segment length has to be type int, got {seg_length}"

        assert isinstance(
            feature_settings["log_transform"], bool
        ), f"log_transform needs to be type bool, got {feature_settings['log_transform']}"

    def init_KF(self, feature: str) -> None:
        """Create one Kalman filter per (band, channel) for ``feature``.

        Only bands listed in ``kalman_filter_settings["frequency_bands"]`` get
        a filter; it is stored under ``"<channel>_<feature>_<band>"``.
        """
        kf_settings = self.s["kalman_filter_settings"]
        for f_band in kf_settings["frequency_bands"]:
            for channel in self.ch_names:
                self.KF_dict["_".join([channel, feature, f_band])] = (
                    nm_kalmanfilter.define_KF(
                        kf_settings["Tp"],
                        kf_settings["sigma_w"],
                        kf_settings["sigma_v"],
                    )
                )

    def update_KF(self, feature_calc: float, KF_name: str) -> float:
        """Return the Kalman-filtered value of ``feature_calc``.

        Runs one predict/update step of the filter registered under
        ``KF_name`` and returns its first state component. If no such filter
        exists, ``feature_calc`` is returned unchanged.
        """
        if KF_name in self.KF_dict:
            kf = self.KF_dict[KF_name]
            kf.predict()
            kf.update(feature_calc)
            feature_calc = kf.x[0]
        return feature_calc

    def estimate_osc_features(
        self,
        features_compute: dict,
        data: np.ndarray,
        feature_name: str,
        est_name: str,
    ) -> dict:
        """Add the enabled summary statistics of ``data`` to ``features_compute``.

        For every entry of ``self.s[est_name]["features"]`` that is exactly
        ``True`` and is one of "mean", "median", "std", "max", the matching
        NaN-ignoring NumPy reduction of ``data`` is stored under
        ``f"{feature_name}_{estimator}"``. Other names are ignored.
        ``features_compute`` is updated in place and also returned.
        """
        for est, enabled in self.s[est_name]["features"].items():
            if enabled is True and est in self._OSC_ESTIMATORS:
                features_compute[f"{feature_name}_{est}"] = self._OSC_ESTIMATORS[
                    est
                ](data)
        return features_compute
130
-
131
-
132
class FFT(OscillatoryFeature):
    """Per-band summary statistics of the FFT amplitude spectrum.

    Only the last ``fft_settings["windowlength_ms"]`` of every data batch is
    transformed. Bin selection per band is half-open: ``low <= f < high``.
    """

    def __init__(
        self,
        settings: dict,
        ch_names: Iterable[str],
        sfreq: float,
    ) -> None:
        super().__init__(settings, ch_names, sfreq)

        fft_cfg = self.s["fft_settings"]
        self.log_transform = bool(fft_cfg["log_transform"])

        n_window = int(np.floor(fft_cfg["windowlength_ms"] / 1000 * sfreq))
        # Kept negative so it can be used directly as "last N samples" slice start.
        self.window_samples = -n_window
        self.freqs = fft.rfftfreq(n_window, 1 / np.floor(self.sfreq))

        # (channel index, output name, FFT bin indices inside the band)
        self.feature_params = [
            (
                ch_idx,
                "_".join([ch_name, "fft", fband]),
                np.where(
                    (self.freqs >= f_range[0]) & (self.freqs < f_range[1])
                )[0],
            )
            for ch_idx, ch_name in enumerate(self.ch_names)
            for fband, f_range in self.f_ranges_dict.items()
        ]

    @staticmethod
    def test_settings(s: dict, ch_names: Iterable[str], sfreq: int | float):
        OscillatoryFeature.test_settings_osc(s, ch_names, sfreq, "fft_settings")

    def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
        """Add FFT band features (and optionally the spectrum) to ``features_compute``."""
        spectrum = np.abs(fft.rfft(data[:, self.window_samples :]))
        if self.log_transform:
            spectrum = np.log10(spectrum)

        for ch_idx, feature_name, band_bins in self.feature_params:
            features_compute = self.estimate_osc_features(
                features_compute,
                spectrum[ch_idx, band_bins],
                feature_name,
                "fft_settings",
            )

        freq_labels = self.freqs.astype(int)
        for ch_idx, ch_name in enumerate(self.ch_names):
            if not self.s["fft_settings"]["return_spectrum"]:
                continue
            for bin_idx, f in enumerate(freq_labels):
                features_compute[f"{ch_name}_fft_psd_{str(f)}"] = spectrum[ch_idx][
                    bin_idx
                ]

        return features_compute
189
-
190
-
191
class Welch(OscillatoryFeature):
    """Per-band summary statistics of a Welch power spectral density.

    The PSD uses a Hann window with segments of ``sfreq`` samples and scipy's
    default overlap. Band edges are inclusive on both sides.
    """

    def __init__(
        self,
        settings: dict,
        ch_names: Iterable[str],
        sfreq: float,
    ) -> None:
        super().__init__(settings, ch_names, sfreq)

        self.log_transform = self.s["welch_settings"]["log_transform"]

        # (channel index, output name, [low, high] band in Hz)
        self.feature_params = [
            (ch_idx, "_".join([ch_name, "welch", fband]), f_range)
            for ch_idx, ch_name in enumerate(self.ch_names)
            for fband, f_range in self.f_ranges_dict.items()
        ]

    @staticmethod
    def test_settings(s: dict, ch_names: Iterable[str], sfreq: int | float):
        OscillatoryFeature.test_settings_osc(s, ch_names, sfreq, "welch_settings")

    def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
        """Add Welch band features (and optionally the PSD) to ``features_compute``."""
        freqs, psd = signal.welch(
            data, fs=self.sfreq, window="hann", nperseg=self.sfreq, noverlap=None
        )
        if self.log_transform:
            psd = np.log10(psd)

        for ch_idx, feature_name, f_range in self.feature_params:
            band_bins = np.where((freqs >= f_range[0]) & (freqs <= f_range[1]))[0]
            features_compute = self.estimate_osc_features(
                features_compute,
                psd[ch_idx][band_bins],
                feature_name,
                "welch_settings",
            )

        freq_labels = freqs.astype(int)
        for ch_idx, ch_name in enumerate(self.ch_names):
            if not self.s["welch_settings"]["return_spectrum"]:
                continue
            for bin_idx, f in enumerate(freq_labels):
                features_compute[f"{ch_name}_welch_psd_{str(f)}"] = psd[ch_idx][
                    bin_idx
                ]

        return features_compute
250
-
251
-
252
class STFT(OscillatoryFeature):
    """Per-band summary statistics of the short-time Fourier transform magnitude.

    Uses a Hamming window of ``stft_settings["windowlength_ms"]``. Band edges
    are inclusive on both sides. The statistics are taken over all selected
    (frequency, time) bins of a band.
    """

    def __init__(
        self,
        settings: dict,
        ch_names: Iterable[str],
        sfreq: float,
    ) -> None:
        super().__init__(settings, ch_names, sfreq)

        stft_cfg = self.s["stft_settings"]
        # windowlength_ms is in milliseconds, but scipy's ``nperseg`` counts
        # samples. The value used to be passed through unconverted, which is
        # only correct at sfreq = 1000 Hz. It could also exceed the data length
        # at lower rates, despite the windowlength_ms <= segment length check.
        # The conversion matches the one used by FFT.
        self.nperseg = int(np.floor(stft_cfg["windowlength_ms"] / 1000 * self.sfreq))
        self.log_transform = stft_cfg["log_transform"]

        # (channel index, output name, [low, high] band in Hz)
        self.feature_params = [
            (ch_idx, "_".join([ch_name, "stft", fband]), f_range)
            for ch_idx, ch_name in enumerate(self.ch_names)
            for fband, f_range in self.f_ranges_dict.items()
        ]

    @staticmethod
    def test_settings(s: dict, ch_names: Iterable[str], sfreq: int | float):
        OscillatoryFeature.test_settings_osc(s, ch_names, sfreq, "stft_settings")

    def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
        """Add STFT band features (and optionally the time-averaged spectrum)."""
        freqs, _, Zxx = signal.stft(
            data,
            fs=self.sfreq,
            window="hamming",
            nperseg=self.nperseg,
            boundary="even",
        )
        Z = np.abs(Zxx)
        if self.log_transform:
            Z = np.log10(Z)

        for ch_idx, feature_name, f_range in self.feature_params:
            band_bins = np.where((freqs >= f_range[0]) & (freqs <= f_range[1]))[0]
            features_compute = self.estimate_osc_features(
                features_compute,
                Z[ch_idx][band_bins, :],
                feature_name,
                "stft_settings",
            )

        if self.s["stft_settings"]["return_spectrum"]:
            freq_labels = freqs.astype(int)
            for ch_idx, ch_name in enumerate(self.ch_names):
                # Average the magnitude over time segments -> one value per frequency.
                Z_ch_mean = Z[ch_idx].mean(axis=1)
                features_compute.update(
                    {
                        f"{ch_name}_stft_psd_{str(f)}": Z_ch_mean[idx]
                        for idx, f in enumerate(freq_labels)
                    }
                )

        return features_compute
311
-
312
-
313
class BandPower(OscillatoryFeature):
    """Hjorth parameters of band-pass filtered data.

    Each channel is band-pass filtered into every band of
    ``frequency_ranges_hz``. The enabled features ("activity", "mobility",
    "complexity") are then computed on the last
    ``bandpass_filter_settings["segment_lengths_ms"][band]`` milliseconds.
    """

    _BP_FEATURES = ("activity", "mobility", "complexity")

    def __init__(
        self,
        settings: dict,
        ch_names: Iterable[str],
        sfreq: float,
        use_kf: bool | None = None,
    ) -> None:
        """
        Parameters
        ----------
        use_kf : bool | None
            Force Kalman filtering of "activity" on (True) or off (False).
            ``None`` defers to ``bandpass_filter_settings["kalman_filter"]``.

        Raises
        ------
        ValueError
            If an enabled entry of ``bandpower_features`` is not one of
            "activity", "mobility", "complexity".
        """
        super().__init__(settings, ch_names, sfreq)
        bp_settings = self.s["bandpass_filter_settings"]

        self.bandpass_filter = nm_filter.MNEFilter(
            f_ranges=list(self.f_ranges_dict.values()),
            sfreq=self.sfreq,
            filter_length=self.sfreq - 1,
            verbose=False,
        )

        self.log_transform = bp_settings["log_transform"]

        if use_kf is True or (
            use_kf is None and bp_settings["kalman_filter"] is True
        ):
            self.init_KF("bandpass_activity")

        # Validate once, up front. This used to happen inside the channel/band
        # loop and raised a bare ValueError() without a message.
        enabled_features = [
            name
            for name, enabled in bp_settings["bandpower_features"].items()
            if enabled is True
        ]
        unknown = [f for f in enabled_features if f not in self._BP_FEATURES]
        if unknown:
            raise ValueError(
                f"Unknown bandpower_features {unknown}; "
                f"supported are {list(self._BP_FEATURES)}"
            )

        seglengths = bp_settings["segment_lengths_ms"]

        self.feature_params = []
        for ch_idx, ch_name in enumerate(self.ch_names):
            for f_band_idx, f_band in enumerate(self.f_ranges_dict):
                seglen = int(np.floor(self.sfreq / 1000 * seglengths[f_band]))
                for bp_feature in enabled_features:
                    self.feature_params.append(
                        (
                            ch_idx,
                            ch_name,
                            f_band,
                            f_band_idx,
                            seglen,
                            bp_feature,
                            "_".join([ch_name, "bandpass", bp_feature, f_band]),
                        )
                    )

    @staticmethod
    def test_settings(s: dict, ch_names: Iterable[str], sfreq: int | float):
        """Validate ``bandpass_filter_settings``; raises AssertionError on failure."""
        OscillatoryFeature.test_settings_osc(
            s, ch_names, sfreq, "bandpass_filter_settings"
        )

        bp_features = s["bandpass_filter_settings"]["bandpower_features"]

        # Previously ``assert (... for ...)``: a generator is always truthy, so
        # this check never ran.
        assert all(
            isinstance(value, bool) for value in bp_features.values()
        ), f"bandpower_features values need to be type bool, got {bp_features}"

        assert any(
            value is True for value in bp_features.values()
        ), "Set at least one bandpower_feature to True."

        segment_lengths = s["bandpass_filter_settings"]["segment_lengths_ms"]
        for fband_name, seg_length_fband in segment_lengths.items():
            assert isinstance(seg_length_fband, int), (
                f"bandpass segment_lengths_ms for {fband_name} "
                f"needs to be of type int, got {seg_length_fband}"
            )

            assert seg_length_fband <= s["segment_length_features_ms"], (
                f"segment length {seg_length_fband} needs to be smaller than "
                f" s['segment_length_features_ms'] = {s['segment_length_features_ms']}"
            )

        for fband_name in s["frequency_ranges_hz"]:
            assert fband_name in segment_lengths, (
                f"frequency range {fband_name} "
                "needs to be defined in s['bandpass_filter_settings']['segment_lengths_ms']"
            )

    def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
        """Filter ``data`` into bands and add the enabled Hjorth features.

        Non-finite results (e.g. from a zero-variance segment) are replaced
        via ``np.nan_to_num`` before being stored.
        """
        data = self.bandpass_filter.filter_data(data)

        for (
            ch_idx,
            _ch_name,
            _f_band,
            f_band_idx,
            seglen,
            bp_feature,
            feature_name,
        ) in self.feature_params:
            # NOTE(review): seglen == 0 makes ``-seglen:`` select the whole
            # signal; kept as-is to preserve existing behavior.
            segment = data[ch_idx, f_band_idx, -seglen:]

            if bp_feature == "activity":
                feature_calc = np.var(segment)
                if self.log_transform:
                    feature_calc = np.log10(feature_calc)
            elif bp_feature == "mobility":
                feature_calc = np.sqrt(np.var(np.diff(segment)) / np.var(segment))
            elif bp_feature == "complexity":
                deriv = np.diff(segment)
                deriv_variance = np.var(deriv)
                mobility = np.sqrt(deriv_variance / np.var(segment))
                deriv_mobility = np.sqrt(np.var(np.diff(deriv)) / deriv_variance)
                feature_calc = deriv_mobility / mobility
            else:
                # Unreachable after __init__ validation, unless feature_params
                # was modified externally.
                raise ValueError(f"Unknown bandpower feature {bp_feature}")

            if self.KF_dict and (bp_feature == "activity"):
                feature_calc = self.update_KF(feature_calc, feature_name)

            features_compute[feature_name] = np.nan_to_num(feature_calc)

        return features_compute
1
+ from collections.abc import Iterable
2
+ import numpy as np
3
+ from itertools import product
4
+
5
+ from py_neuromodulation.nm_types import NMBaseModel
6
+ from pydantic import field_validator
7
+ from typing import TYPE_CHECKING
8
+
9
+ from py_neuromodulation.nm_features import NMFeature
10
+ from py_neuromodulation.nm_types import BoolSelector
11
+
12
+ if TYPE_CHECKING:
13
+ from py_neuromodulation.nm_settings import NMSettings
14
+ from py_neuromodulation.nm_kalmanfilter import KalmanSettings
15
+
16
+
17
class OscillatoryFeatures(BoolSelector):
    """On/off switches for the summary statistics of oscillatory features.

    Each enabled field name is used as a key into ``ESTIMATOR_DICT`` to pick
    the NumPy reduction that is applied.
    """

    mean: bool = True  # np.nanmean
    median: bool = False  # np.nanmedian
    std: bool = False  # np.nanstd
    max: bool = False  # np.nanmax
22
+
23
+
24
class OscillatorySettings(NMBaseModel):
    """Settings shared by the oscillatory (spectral) feature estimators."""

    # Analysed window length in milliseconds; OscillatoryFeature asserts it is
    # <= segment_length_features_ms.
    windowlength_ms: int = 1000
    # Apply np.log10 to the spectrum before estimating band features.
    log_transform: bool = True
    # Which per-band summary statistics to compute (keys of ESTIMATOR_DICT).
    # Constructed with explicit values, which pydantic records as explicitly set.
    features: OscillatoryFeatures = OscillatoryFeatures(
        mean=True, median=False, std=False, max=False
    )
    # NOTE(review): presumably adds the per-frequency spectrum to the output;
    # the code that reads this flag is not visible here — confirm.
    return_spectrum: bool = False
31
+
32
+
33
# Maps each OscillatoryFeatures flag name to the NaN-ignoring NumPy reduction
# that is applied when the flag is enabled.
ESTIMATOR_DICT = dict(
    mean=np.nanmean,
    median=np.nanmedian,
    std=np.nanstd,
    max=np.nanmax,
)
39
+
40
+
41
class OscillatoryFeature(NMFeature):
    """Common base for the oscillatory (spectral) feature estimators.

    Subclasses must assign ``self.settings`` (their ``OscillatorySettings``)
    and ``self.osc_feature_name`` *before* calling ``super().__init__``,
    because this initializer reads ``self.settings.windowlength_ms``.
    """

    def __init__(
        self, settings: "NMSettings", ch_names: Iterable[str], sfreq: int
    ) -> None:
        """
        Parameters
        ----------
        settings : NMSettings
            Global settings; ``settings.validate()`` is called first.
        ch_names : Iterable[str]
            Names of the channels features are computed for.
        sfreq : int
            Sampling rate in Hz (stored as ``int``).

        Raises
        ------
        AssertionError
            If the feature's ``windowlength_ms`` exceeds
            ``settings.segment_length_features_ms``.
        """
        settings.validate()
        self.settings: OscillatorySettings  # Assignment in subclass __init__
        self.osc_feature_name: str  # Required for output

        self.sfreq = int(sfreq)
        self.ch_names = ch_names

        self.frequency_ranges = settings.frequency_ranges_hz

        # Test settings. The message used to end in a trailing comma, which
        # turned it into a one-element tuple; its fragments also lacked spaces.
        assert self.settings.windowlength_ms <= settings.segment_length_features_ms, (
            f"oscillatory feature windowlength_ms = ({self.settings.windowlength_ms}) "
            "needs to be smaller than or equal to "
            f"settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
        )
60
+
61
+
62
class FFT(OscillatoryFeature):
    """Per-band statistics of the FFT magnitude spectrum of the latest window.

    Only the trailing ``windowlength_ms`` of each data segment is transformed.
    Features are named ``{channel}_fft_{band}_{estimator}``. When
    ``return_spectrum`` is set, ``{channel}_fft_psd_{int(freq)}`` entries are
    added as well.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: int,
    ) -> None:
        from scipy.fft import rfftfreq

        # The base-class __init__ reads both attributes, so set them first.
        self.osc_feature_name = "fft"
        self.settings = settings.fft_settings
        super().__init__(settings, ch_names, sfreq)

        n_window = int(np.floor(self.settings.windowlength_ms / 1000 * sfreq))
        # Stored negated so calc_feature can take the trailing window with
        # ``data[:, self.window_samples:]``.
        self.window_samples = -n_window
        self.freqs = rfftfreq(n_window, 1 / np.floor(self.sfreq))

        # Half-open [low, high) frequency-bin indices for every band, computed once.
        self.idx_range = []
        for band_name, f_range in self.frequency_ranges.items():
            in_band = (self.freqs >= f_range[0]) & (self.freqs < f_range[1])
            self.idx_range.append((band_name, np.where(in_band)[0]))

        self.estimators = [
            (name, ESTIMATOR_DICT[name])
            for name in self.settings.features.get_enabled()
        ]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Return a dict of band statistics (and optionally the spectrum).

        ``data`` is indexed as (channel, sample); only its trailing window is used.
        """
        from scipy.fft import rfft

        windowed = data[:, self.window_samples :]
        spectrum = np.abs(rfft(windowed))  # type: ignore
        if self.settings.log_transform:
            spectrum = np.log10(spectrum)

        feature_results = {}
        for band_name, band_bins in self.idx_range:
            band_spectrum = spectrum[:, band_bins]
            for est_name, estimator in self.estimators:
                per_channel = estimator(band_spectrum, axis=1)
                feature_results.update(
                    (
                        f"{ch_name}_{self.osc_feature_name}_{band_name}_{est_name}",
                        per_channel[ch_idx],
                    )
                    for ch_idx, ch_name in enumerate(self.ch_names)
                )

        if self.settings.return_spectrum:
            # NOTE(review): int() truncation makes bins closer than 1 Hz apart
            # (windows longer than 1 s) share a key; the later bin wins.
            for ch_idx, ch_name in enumerate(self.ch_names):
                for bin_idx, freq in enumerate(self.freqs):
                    feature_results[f"{ch_name}_fft_psd_{int(freq)}"] = spectrum[
                        ch_idx
                    ][bin_idx]

        return feature_results
124
+
125
+
126
class Welch(OscillatoryFeature):
    """Per-band statistics of the Welch power spectral density.

    Features are named ``{channel}_welch_{band}_{estimator}``. When
    ``return_spectrum`` is set, ``{channel}_welch_psd_{freq}`` entries are
    added as well.

    NOTE(review): ``nperseg`` is fixed at ``sfreq`` samples, so the configured
    ``windowlength_ms`` only takes part in the base-class length check and
    not in the PSD computation itself.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: int,
    ) -> None:
        from scipy.fft import rfftfreq

        # The base-class __init__ reads both attributes, so set them first.
        self.osc_feature_name = "welch"
        self.settings = settings.welch_settings
        super().__init__(settings, ch_names, sfreq)

        # Frequency axis of a PSD computed with nperseg == sfreq (1 Hz spacing).
        self.freqs = rfftfreq(self.sfreq, 1 / self.sfreq)

        # Half-open [low, high) frequency-bin indices for every band.
        self.idx_range = []
        for band_name, f_range in self.frequency_ranges.items():
            in_band = (self.freqs >= f_range[0]) & (self.freqs < f_range[1])
            self.idx_range.append((band_name, np.where(in_band)[0]))

        self.estimators = [
            (name, ESTIMATOR_DICT[name])
            for name in self.settings.features.get_enabled()
        ]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Return a dict of band statistics (and optionally the PSD).

        ``data`` is indexed as (channel, sample).
        """
        from scipy.signal import welch

        _, psd = welch(
            data,
            fs=self.sfreq,
            window="hann",
            nperseg=self.sfreq,
            noverlap=None,
        )
        if self.settings.log_transform:
            psd = np.log10(psd)

        feature_results = {}
        for band_name, band_bins in self.idx_range:
            band_psd = psd[:, band_bins]
            for est_name, estimator in self.estimators:
                per_channel = estimator(band_psd, axis=1)
                feature_results.update(
                    (
                        f"{ch_name}_{self.osc_feature_name}_{band_name}_{est_name}",
                        per_channel[ch_idx],
                    )
                    for ch_idx, ch_name in enumerate(self.ch_names)
                )

        if self.settings.return_spectrum:
            pairs = product(enumerate(self.ch_names), enumerate(self.freqs))
            for (ch_idx, ch_name), (bin_idx, freq) in pairs:
                feature_results[f"{ch_name}_welch_psd_{str(freq)}"] = psd[ch_idx][
                    bin_idx
                ]

        return feature_results
187
+
188
+
189
class STFT(OscillatoryFeature):
    """Per-band statistics of the short-time Fourier transform magnitude.

    Band statistics are taken jointly over frequency bins and time frames.
    Features are named ``{channel}_stft_{band}_{estimator}``. When
    ``return_spectrum`` is set, time-averaged ``{channel}_stft_psd_{freq}``
    entries are added as well.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: int,
    ) -> None:
        from scipy.fft import rfftfreq

        self.osc_feature_name = "stft"
        self.settings = settings.stft_settings
        # super.__init__ needs osc_feature_name and settings
        super().__init__(settings, ch_names, sfreq)

        # scipy.signal.stft takes nperseg in *samples*, but the setting is in
        # milliseconds, so convert via the sampling rate. This is identical to
        # the previous behaviour at 1000 Hz. It also keeps the window within
        # the (asserted) feature segment, so scipy never shrinks nperseg and
        # the spectrum always matches self.freqs.
        self.nperseg = int(np.floor(self.settings.windowlength_ms / 1000 * self.sfreq))

        self.freqs = rfftfreq(self.nperseg, 1 / self.sfreq)

        # Half-open [low, high) bins, consistent with FFT and Welch, so a bin
        # on the shared edge of two adjacent bands is not counted in both.
        self.idx_range = [
            (
                f_band,
                np.where((self.freqs >= f_range[0]) & (self.freqs < f_range[1]))[0],
            )
            for f_band, f_range in self.frequency_ranges.items()
        ]

        self.estimators = [
            (est, ESTIMATOR_DICT[est]) for est in self.settings.features.get_enabled()
        ]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Return a dict of band statistics (and optionally the mean spectrum).

        ``data`` is indexed as (channel, sample). The STFT magnitude is then
        indexed as (channel, frequency, time frame).
        """
        from scipy.signal import stft

        _, _, Zxx = stft(
            data,
            fs=self.sfreq,
            window="hamming",
            nperseg=self.nperseg,
            boundary="even",
        )

        Z = np.abs(Zxx)
        if self.settings.log_transform:
            Z = np.log10(Z)

        feature_results = {}

        for f_band_name, idx_range in self.idx_range:
            Z_band = Z[:, idx_range, :]

            for est_name, est_fun in self.estimators:
                # Reduce over frequency bins and time frames together.
                result = est_fun(Z_band, axis=(1, 2))

                for ch_idx, ch_name in enumerate(self.ch_names):
                    feature_results[
                        f"{ch_name}_{self.osc_feature_name}_{f_band_name}_{est_name}"
                    ] = result[ch_idx]

        if self.settings.return_spectrum:
            # Average over time once, instead of once per (channel, frequency) pair.
            Z_time_mean = Z.mean(axis=2)
            combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
            for (ch_idx, ch_name), (idx, f) in combinations:
                feature_results[f"{ch_name}_stft_psd_{str(f)}"] = Z_time_mean[ch_idx, idx]

        return feature_results
255
+
256
+
257
class BandpowerFeatures(BoolSelector):
    """Select which statistics ``BandPower`` computes per band-passed segment.

    activity: variance of the segment (optionally log10 / Kalman-filtered);
    mobility: sqrt(var(diff(x)) / var(x));
    complexity: mobility of diff(x) divided by mobility of x.
    """

    activity: bool = True
    mobility: bool = False
    complexity: bool = False
261
+
262
+
263
+ ###################################
264
+ ######## BANDPOWER FEATURE ########
265
+ ###################################
266
+
267
+
268
class BandpassSettings(NMBaseModel):
    """Settings for the band-pass filter based features computed by ``BandPower``."""

    # Length in milliseconds of the trailing segment analysed for each band.
    segment_lengths_ms: dict[str, int] = {
        "theta": 1000,
        "alpha": 500,
        "low_beta": 333,
        "high_beta": 333,
        "low_gamma": 100,
        "high_gamma": 100,
        "HFA": 100,
    }
    # Which of activity / mobility / complexity to compute.
    bandpower_features: BandpowerFeatures = BandpowerFeatures()
    # Apply log10 to the "activity" (variance) feature.
    log_transform: bool = True
    # Apply a Kalman filter to the "activity" feature (see BandPower).
    kalman_filter: bool = False

    @field_validator("segment_lengths_ms")
    @classmethod
    def fbands_spaces_to_underscores(cls, segment_lengths_ms: dict[str, int]):
        """Replace spaces with underscores in frequency band names."""
        return {k.replace(" ", "_"): v for k, v in segment_lengths_ms.items()}

    @field_validator("bandpower_features")
    @classmethod
    def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
        """Require at least one enabled bandpower feature."""
        assert (
            len(bandpower_features.get_enabled()) > 0
        ), "Set at least one bandpower_feature to True."

        return bandpower_features

    def validate_fbands(self, settings: "NMSettings") -> None:
        """Check ``segment_lengths_ms`` against the global settings.

        Raises:
            AssertionError: If a band of ``settings.frequency_ranges_hz`` has no
                entry in ``segment_lengths_ms``, or if the segment length of a
                configured band exceeds ``settings.segment_length_features_ms``.
        """
        # Every globally configured frequency band needs a segment length here.
        for fband_name in settings.frequency_ranges_hz.keys():
            assert fband_name in self.segment_lengths_ms, (
                f"frequency range {fband_name} "
                "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms"
            )

        # Only bands that are actually used must fit into the feature segment.
        # Entries for unused bands (e.g. the default "theta": 1000) are ignored.
        for fband_name in settings.frequency_ranges_hz.keys():
            seg_length_fband = self.segment_lengths_ms[fband_name]
            assert seg_length_fband <= settings.segment_length_features_ms, (
                f"segment length {seg_length_fband} of frequency band {fband_name} "
                "needs to be smaller than or equal to "
                f"settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
            )
311
+
312
+
313
class BandPower(NMFeature):
    """Hjorth-parameter features of band-pass filtered signals.

    For every channel x frequency band x enabled feature, the trailing
    ``segment_lengths_ms[band]`` milliseconds of the band-passed signal are
    reduced to activity (variance, optionally log10 and Kalman-filtered),
    mobility or complexity. Output keys follow
    ``{channel}_bandpass_{feature}_{band}``.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: float,
        use_kf: bool | None = None,
    ) -> None:
        settings.validate()

        self.bp_settings: BandpassSettings = settings.bandpass_filter_settings
        self.kalman_filter_settings: KalmanSettings = settings.kalman_filter_settings
        self.sfreq = sfreq
        self.ch_names = ch_names
        self.KF_dict: dict = {}

        from py_neuromodulation.nm_filter import MNEFilter

        self.bandpass_filter = MNEFilter(
            f_ranges=list(map(tuple, settings.frequency_ranges_hz.values())),
            sfreq=self.sfreq,
            filter_length=self.sfreq - 1,
            verbose=False,
        )

        # An explicit use_kf overrides the setting; None defers to it.
        wants_kf = self.bp_settings.kalman_filter if use_kf is None else use_kf
        if wants_kf:
            self.init_KF("bandpass_activity")

        segment_ms = self.bp_settings.segment_lengths_ms
        band_names = settings.frequency_ranges_hz.keys()

        # One (channel idx, band idx, n samples, feature, output name) tuple per
        # computed value, in channel -> band -> feature order.
        self.feature_params = []
        for ch_idx, ch_name in enumerate(self.ch_names):
            for band_idx, band in enumerate(band_names):
                n_samples = int(np.floor(self.sfreq / 1000 * segment_ms[band]))
                self.feature_params.extend(
                    (
                        ch_idx,
                        band_idx,
                        n_samples,
                        feat,
                        "_".join([ch_name, "bandpass", feat, band]),
                    )
                    for feat in self.bp_settings.bandpower_features.get_enabled()
                )

    def init_KF(self, feature: str) -> None:
        """Create one Kalman filter per (Kalman band, channel) for ``feature``."""
        from py_neuromodulation.nm_kalmanfilter import define_KF

        kf_cfg = self.kalman_filter_settings
        for band in kf_cfg.frequency_bands:
            for channel in self.ch_names:
                key = "_".join([channel, feature, band])
                self.KF_dict[key] = define_KF(kf_cfg.Tp, kf_cfg.sigma_w, kf_cfg.sigma_v)

    def update_KF(self, feature_calc: np.floating, KF_name: str) -> np.floating:
        """Run one predict/update step and return the filtered value.

        Values without a matching filter in ``KF_dict`` are returned unchanged.
        """
        if KF_name not in self.KF_dict:
            return feature_calc
        kalman = self.KF_dict[KF_name]
        kalman.predict()
        kalman.update(feature_calc)
        return kalman.x[0]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Band-pass filter ``data`` and compute every configured feature.

        The filtered output is indexed as (channel, band, sample).
        """
        filtered = self.bandpass_filter.filter_data(data)
        return {
            name: self.calc_bp_feature(kind, name, filtered[ch, band, -n:])
            for ch, band, n, kind, name in self.feature_params
        }

    def calc_bp_feature(self, bp_feature, feature_name, data):
        """Compute one bandpower feature on a 1-D segment; NaN/inf are sanitised.

        Raises:
            ValueError: If ``bp_feature`` is not activity/mobility/complexity.
        """
        if bp_feature == "activity":
            value = np.var(data)
            if self.bp_settings.log_transform:
                value = np.log10(value)
            if self.KF_dict:
                value = self.update_KF(value, feature_name)
        elif bp_feature == "mobility":
            deriv_var = np.var(np.diff(data))
            value = np.sqrt(deriv_var / np.var(data))
        elif bp_feature == "complexity":
            value = self.calc_complexity(data)
        else:
            raise ValueError(f"Unknown bandpower feature: {bp_feature}")

        return np.nan_to_num(value)

    @staticmethod
    def calc_complexity(data: np.ndarray) -> float:
        """Hjorth complexity: mobility of the first difference over mobility of ``data``."""
        first_diff = np.diff(data)
        first_var = np.var(first_diff)
        mobility = np.sqrt(first_var / np.var(data))
        second_var = np.var(np.diff(first_diff))
        return np.sqrt(second_var / first_var) / mobility