py-neuromodulation 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (233) hide show
  1. py_neuromodulation/ConnectivityDecoding/Automated Anatomical Labeling 3 (Rolls 2020).nii +0 -0
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -0
  3. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -0
  4. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -0
  5. py_neuromodulation/ConnectivityDecoding/mni_coords_cortical_surface.mat +0 -0
  6. py_neuromodulation/ConnectivityDecoding/mni_coords_whole_brain.mat +0 -0
  7. py_neuromodulation/ConnectivityDecoding/rmap_func_all.nii +0 -0
  8. py_neuromodulation/ConnectivityDecoding/rmap_struc.nii +0 -0
  9. py_neuromodulation/FieldTrip.py +589 -589
  10. py_neuromodulation/__init__.py +74 -13
  11. py_neuromodulation/_write_example_dataset_helper.py +83 -65
  12. py_neuromodulation/data/README +6 -0
  13. py_neuromodulation/data/dataset_description.json +8 -0
  14. py_neuromodulation/data/participants.json +32 -0
  15. py_neuromodulation/data/participants.tsv +2 -0
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -0
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -0
  18. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -0
  19. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.eeg +0 -0
  20. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -0
  21. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -0
  22. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -0
  23. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -0
  24. py_neuromodulation/grid_cortex.tsv +40 -0
  25. py_neuromodulation/grid_subcortex.tsv +1429 -0
  26. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  27. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  28. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  29. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  30. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  31. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  32. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  33. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  34. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  35. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  36. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  37. py_neuromodulation/nm_IO.py +413 -417
  38. py_neuromodulation/nm_RMAP.py +496 -531
  39. py_neuromodulation/nm_analysis.py +993 -1074
  40. py_neuromodulation/nm_artifacts.py +30 -25
  41. py_neuromodulation/nm_bispectra.py +154 -168
  42. py_neuromodulation/nm_bursts.py +292 -198
  43. py_neuromodulation/nm_coherence.py +251 -205
  44. py_neuromodulation/nm_database.py +149 -0
  45. py_neuromodulation/nm_decode.py +918 -992
  46. py_neuromodulation/nm_define_nmchannels.py +300 -302
  47. py_neuromodulation/nm_features.py +144 -116
  48. py_neuromodulation/nm_filter.py +219 -219
  49. py_neuromodulation/nm_filter_preprocessing.py +79 -91
  50. py_neuromodulation/nm_fooof.py +139 -159
  51. py_neuromodulation/nm_generator.py +45 -37
  52. py_neuromodulation/nm_hjorth_raw.py +52 -73
  53. py_neuromodulation/nm_kalmanfilter.py +71 -58
  54. py_neuromodulation/nm_linelength.py +21 -33
  55. py_neuromodulation/nm_logger.py +66 -0
  56. py_neuromodulation/nm_mne_connectivity.py +149 -112
  57. py_neuromodulation/nm_mnelsl_generator.py +90 -0
  58. py_neuromodulation/nm_mnelsl_stream.py +116 -0
  59. py_neuromodulation/nm_nolds.py +96 -93
  60. py_neuromodulation/nm_normalization.py +173 -214
  61. py_neuromodulation/nm_oscillatory.py +423 -448
  62. py_neuromodulation/nm_plots.py +585 -612
  63. py_neuromodulation/nm_preprocessing.py +83 -0
  64. py_neuromodulation/nm_projection.py +370 -394
  65. py_neuromodulation/nm_rereference.py +97 -95
  66. py_neuromodulation/nm_resample.py +59 -50
  67. py_neuromodulation/nm_run_analysis.py +325 -435
  68. py_neuromodulation/nm_settings.py +289 -68
  69. py_neuromodulation/nm_settings.yaml +244 -0
  70. py_neuromodulation/nm_sharpwaves.py +423 -401
  71. py_neuromodulation/nm_stats.py +464 -480
  72. py_neuromodulation/nm_stream.py +398 -0
  73. py_neuromodulation/nm_stream_abc.py +166 -218
  74. py_neuromodulation/nm_types.py +193 -0
  75. py_neuromodulation/plots/STN_surf.mat +0 -0
  76. py_neuromodulation/plots/Vertices.mat +0 -0
  77. py_neuromodulation/plots/faces.mat +0 -0
  78. py_neuromodulation/plots/grid.mat +0 -0
  79. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +185 -182
  80. py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
  81. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -2
  82. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info/licenses}/LICENSE +21 -21
  83. docs/build/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -68
  84. docs/build/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -233
  85. docs/build/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  86. docs/build/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -97
  87. docs/build/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  88. docs/build/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -192
  89. docs/build/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  90. docs/build/html/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -68
  91. docs/build/html/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -239
  92. docs/build/html/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  93. docs/build/html/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -97
  94. docs/build/html/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  95. docs/build/html/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -192
  96. docs/build/html/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  97. docs/source/_build/html/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -76
  98. docs/source/_build/html/_downloads/0d0d0a76e8f648d5d3cbc47da6351932/plot_real_time_demo.py +0 -97
  99. docs/source/_build/html/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -240
  100. docs/source/_build/html/_downloads/5d73cadc59a8805c47e3b84063afc157/plot_example_BIDS.py +0 -233
  101. docs/source/_build/html/_downloads/7660317fa5a6bfbd12fcca9961457fc4/plot_example_rmap_computing.py +0 -63
  102. docs/source/_build/html/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  103. docs/source/_build/html/_downloads/839e5b319379f7fd9e867deb00fd797f/plot_example_gridPointProjection.py +0 -210
  104. docs/source/_build/html/_downloads/ae8be19afe5e559f011fc9b138968ba0/plot_first_demo.py +0 -192
  105. docs/source/_build/html/_downloads/b8b06cacc17969d3725a0b6f1d7741c5/plot_example_sharpwave_analysis.py +0 -219
  106. docs/source/_build/html/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -121
  107. docs/source/_build/html/_downloads/c31a86c0b68cb4167d968091ace8080d/plot_example_add_feature.py +0 -68
  108. docs/source/_build/html/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  109. docs/source/_build/html/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -189
  110. docs/source/_build/html/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  111. docs/source/auto_examples/plot_0_first_demo.py +0 -189
  112. docs/source/auto_examples/plot_1_example_BIDS.py +0 -240
  113. docs/source/auto_examples/plot_2_example_add_feature.py +0 -76
  114. docs/source/auto_examples/plot_3_example_sharpwave_analysis.py +0 -219
  115. docs/source/auto_examples/plot_4_example_gridPointProjection.py +0 -210
  116. docs/source/auto_examples/plot_5_example_rmap_computing.py +0 -64
  117. docs/source/auto_examples/plot_6_real_time_demo.py +0 -121
  118. docs/source/conf.py +0 -105
  119. examples/plot_0_first_demo.py +0 -189
  120. examples/plot_1_example_BIDS.py +0 -240
  121. examples/plot_2_example_add_feature.py +0 -76
  122. examples/plot_3_example_sharpwave_analysis.py +0 -219
  123. examples/plot_4_example_gridPointProjection.py +0 -210
  124. examples/plot_5_example_rmap_computing.py +0 -64
  125. examples/plot_6_real_time_demo.py +0 -121
  126. packages/realtime_decoding/build/lib/realtime_decoding/__init__.py +0 -4
  127. packages/realtime_decoding/build/lib/realtime_decoding/decoder.py +0 -104
  128. packages/realtime_decoding/build/lib/realtime_decoding/features.py +0 -163
  129. packages/realtime_decoding/build/lib/realtime_decoding/helpers.py +0 -15
  130. packages/realtime_decoding/build/lib/realtime_decoding/run_decoding.py +0 -345
  131. packages/realtime_decoding/build/lib/realtime_decoding/trainer.py +0 -54
  132. packages/tmsi/build/lib/TMSiFileFormats/__init__.py +0 -37
  133. packages/tmsi/build/lib/TMSiFileFormats/file_formats/__init__.py +0 -36
  134. packages/tmsi/build/lib/TMSiFileFormats/file_formats/lsl_stream_writer.py +0 -200
  135. packages/tmsi/build/lib/TMSiFileFormats/file_formats/poly5_file_writer.py +0 -496
  136. packages/tmsi/build/lib/TMSiFileFormats/file_formats/poly5_to_edf_converter.py +0 -236
  137. packages/tmsi/build/lib/TMSiFileFormats/file_formats/xdf_file_writer.py +0 -977
  138. packages/tmsi/build/lib/TMSiFileFormats/file_readers/__init__.py +0 -35
  139. packages/tmsi/build/lib/TMSiFileFormats/file_readers/edf_reader.py +0 -116
  140. packages/tmsi/build/lib/TMSiFileFormats/file_readers/poly5reader.py +0 -294
  141. packages/tmsi/build/lib/TMSiFileFormats/file_readers/xdf_reader.py +0 -229
  142. packages/tmsi/build/lib/TMSiFileFormats/file_writer.py +0 -102
  143. packages/tmsi/build/lib/TMSiPlotters/__init__.py +0 -2
  144. packages/tmsi/build/lib/TMSiPlotters/gui/__init__.py +0 -39
  145. packages/tmsi/build/lib/TMSiPlotters/gui/_plotter_gui.py +0 -234
  146. packages/tmsi/build/lib/TMSiPlotters/gui/plotting_gui.py +0 -440
  147. packages/tmsi/build/lib/TMSiPlotters/plotters/__init__.py +0 -44
  148. packages/tmsi/build/lib/TMSiPlotters/plotters/hd_emg_plotter.py +0 -446
  149. packages/tmsi/build/lib/TMSiPlotters/plotters/impedance_plotter.py +0 -589
  150. packages/tmsi/build/lib/TMSiPlotters/plotters/signal_plotter.py +0 -1326
  151. packages/tmsi/build/lib/TMSiSDK/__init__.py +0 -54
  152. packages/tmsi/build/lib/TMSiSDK/device.py +0 -588
  153. packages/tmsi/build/lib/TMSiSDK/devices/__init__.py +0 -34
  154. packages/tmsi/build/lib/TMSiSDK/devices/saga/TMSi_Device_API.py +0 -1764
  155. packages/tmsi/build/lib/TMSiSDK/devices/saga/__init__.py +0 -34
  156. packages/tmsi/build/lib/TMSiSDK/devices/saga/saga_device.py +0 -1366
  157. packages/tmsi/build/lib/TMSiSDK/devices/saga/saga_types.py +0 -520
  158. packages/tmsi/build/lib/TMSiSDK/devices/saga/xml_saga_config.py +0 -165
  159. packages/tmsi/build/lib/TMSiSDK/error.py +0 -95
  160. packages/tmsi/build/lib/TMSiSDK/sample_data.py +0 -63
  161. packages/tmsi/build/lib/TMSiSDK/sample_data_server.py +0 -99
  162. packages/tmsi/build/lib/TMSiSDK/settings.py +0 -45
  163. packages/tmsi/build/lib/TMSiSDK/tmsi_device.py +0 -111
  164. packages/tmsi/build/lib/__init__.py +0 -4
  165. packages/tmsi/build/lib/apex_sdk/__init__.py +0 -34
  166. packages/tmsi/build/lib/apex_sdk/device/__init__.py +0 -41
  167. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API.py +0 -1009
  168. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API_enums.py +0 -239
  169. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API_structures.py +0 -668
  170. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_device.py +0 -1611
  171. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_dongle.py +0 -38
  172. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_event_reader.py +0 -57
  173. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_channel.py +0 -44
  174. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_config.py +0 -150
  175. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_const.py +0 -36
  176. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_impedance_channel.py +0 -48
  177. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_info.py +0 -108
  178. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/dongle_info.py +0 -39
  179. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/download_measurement.py +0 -77
  180. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/eeg_measurement.py +0 -150
  181. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/impedance_measurement.py +0 -129
  182. packages/tmsi/build/lib/apex_sdk/device/threads/conversion_thread.py +0 -59
  183. packages/tmsi/build/lib/apex_sdk/device/threads/sampling_thread.py +0 -57
  184. packages/tmsi/build/lib/apex_sdk/device/tmsi_channel.py +0 -83
  185. packages/tmsi/build/lib/apex_sdk/device/tmsi_device.py +0 -201
  186. packages/tmsi/build/lib/apex_sdk/device/tmsi_device_enums.py +0 -103
  187. packages/tmsi/build/lib/apex_sdk/device/tmsi_dongle.py +0 -43
  188. packages/tmsi/build/lib/apex_sdk/device/tmsi_event_reader.py +0 -50
  189. packages/tmsi/build/lib/apex_sdk/device/tmsi_measurement.py +0 -118
  190. packages/tmsi/build/lib/apex_sdk/sample_data_server/__init__.py +0 -33
  191. packages/tmsi/build/lib/apex_sdk/sample_data_server/event_data.py +0 -44
  192. packages/tmsi/build/lib/apex_sdk/sample_data_server/sample_data.py +0 -50
  193. packages/tmsi/build/lib/apex_sdk/sample_data_server/sample_data_server.py +0 -136
  194. packages/tmsi/build/lib/apex_sdk/tmsi_errors/error.py +0 -126
  195. packages/tmsi/build/lib/apex_sdk/tmsi_sdk.py +0 -113
  196. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/apex/apex_structure_generator.py +0 -134
  197. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/decorators.py +0 -60
  198. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/logger_filter.py +0 -42
  199. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/singleton.py +0 -42
  200. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/support_functions.py +0 -72
  201. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/tmsi_logger.py +0 -98
  202. py_neuromodulation/nm_EpochStream.py +0 -92
  203. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  204. py_neuromodulation/nm_cohortwrapper.py +0 -435
  205. py_neuromodulation/nm_eval_timing.py +0 -239
  206. py_neuromodulation/nm_features_abc.py +0 -39
  207. py_neuromodulation/nm_stream_offline.py +0 -358
  208. py_neuromodulation/utils/_logging.py +0 -24
  209. py_neuromodulation-0.0.3.dist-info/RECORD +0 -188
  210. py_neuromodulation-0.0.3.dist-info/top_level.txt +0 -5
  211. tests/__init__.py +0 -0
  212. tests/conftest.py +0 -117
  213. tests/test_all_examples.py +0 -10
  214. tests/test_all_features.py +0 -63
  215. tests/test_bispectra.py +0 -70
  216. tests/test_bursts.py +0 -105
  217. tests/test_feature_sampling_rates.py +0 -143
  218. tests/test_fooof.py +0 -16
  219. tests/test_initalization_offline_stream.py +0 -41
  220. tests/test_multiprocessing.py +0 -58
  221. tests/test_nan_values.py +0 -29
  222. tests/test_nm_filter.py +0 -95
  223. tests/test_nm_resample.py +0 -63
  224. tests/test_normalization_settings.py +0 -146
  225. tests/test_notch_filter.py +0 -31
  226. tests/test_osc_features.py +0 -424
  227. tests/test_preprocessing_filter.py +0 -151
  228. tests/test_rereference.py +0 -171
  229. tests/test_sampling.py +0 -57
  230. tests/test_settings_change_after_init.py +0 -76
  231. tests/test_sharpwave.py +0 -165
  232. tests/test_target_channel_add.py +0 -100
  233. tests/test_timing.py +0 -80
@@ -1,1074 +1,993 @@
1
- import os
2
- from pathlib import Path
3
- from re import VERBOSE
4
- import re
5
- from typing import Optional
6
-
7
- import _pickle as cPickle
8
- import numpy as np
9
- import pandas as pd
10
- from sklearn import base, linear_model, metrics, model_selection
11
- from scipy import stats
12
-
13
- from py_neuromodulation import nm_decode, nm_IO, nm_plots
14
-
15
# Substrings matched case-insensitively against target channel names in
# Feature_Reader._get_target_ch; a target containing one of them is preferred
# as the decoding label.
target_filter_str = {
    "CLEAN",
    "SQUARED_EMG",
    "SQUARED_INTERPOLATED_EMG",
    "SQUARED_ROTAWHEEL",
    # These two used to be fused by implicit string concatenation (missing
    # comma) into "SQUARED_ROTATIONrota_squared", so neither ever matched.
    "SQUARED_ROTATION",
    "rota_squared",
}
# Feature keywords for which Feature_Reader.filter_features reverses the
# column order before plotting.
features_reverse_order_plotting = {"stft", "fft", "bandpass"}
23
-
24
-
25
- class Feature_Reader:
26
-
27
- feature_dir: str
28
- feature_list: list[str]
29
- settings: dict
30
- sidecar: dict
31
- sfreq: int
32
- line_noise: int
33
- nm_channels: pd.DataFrame
34
- feature_arr: pd.DataFrame
35
- ch_names: list[str]
36
- ch_names_ECOG: list[str]
37
- decoder: nm_decode.Decoder = None
38
-
39
def __init__(
    self,
    feature_dir: str,
    feature_file: Optional[str] = None,
    binarize_label: bool = True,
    binarize_th: float = 0.3,
) -> None:
    """Feature_Reader enables analysis methods on top of NM_reader and NM_Decoder.

    Parameters
    ----------
    feature_dir : str
        Path to py_neuromodulation estimated feature runs, where each
        feature run is a folder.
    feature_file : str, optional
        Specific feature run. If None (default) the first entry returned by
        ``nm_IO.get_run_list_indir(feature_dir)`` is used.
    binarize_label : bool, optional
        Threshold the target channel into a boolean label, by default True.
    binarize_th : float, optional
        Threshold used when ``binarize_label`` is True, by default 0.3.

    Notes
    -----
    ``label_name`` and ``label`` are always defined after construction;
    they are None when ``nm_channels`` marks no channel as target.
    """
    self.feature_dir = feature_dir
    self.feature_list = nm_IO.get_run_list_indir(self.feature_dir)
    self.feature_file = (
        self.feature_list[0] if feature_file is None else feature_file
    )

    # The nm_IO readers below take the prefix <feature_dir>/<run>/<run>.
    file_basename = Path(self.feature_file).stem
    path_read_file = str(Path(self.feature_dir, file_basename, file_basename))

    self.settings = nm_IO.read_settings(path_read_file)
    self.sidecar = nm_IO.read_sidecar(path_read_file)
    if self.sidecar["sess_right"] is None and "coords" in self.sidecar:
        # Infer the session side from which cortex side has channels;
        # the right side wins if both are populated.
        coords = self.sidecar["coords"]
        if len(coords["cortex_left"]["ch_names"]) > 0:
            self.sidecar["sess_right"] = False
        if len(coords["cortex_right"]["ch_names"]) > 0:
            self.sidecar["sess_right"] = True
    self.sfreq = self.sidecar["sfreq"]
    self.nm_channels = nm_IO.read_nm_channels(path_read_file)
    self.feature_arr = nm_IO.read_features(path_read_file)

    self.ch_names = self.nm_channels.new_name
    self.used_chs = list(
        self.nm_channels[
            (self.nm_channels["target"] == 0) & (self.nm_channels["used"] == 1)
        ]["new_name"]
    )
    self.ch_names_ECOG = self.nm_channels.query(
        '(type=="ecog") and (used == 1) and (status=="good")'
    ).new_name.to_list()

    # init plotter
    self.nmplotter = nm_plots.NM_Plot()

    # Defined up front so plotting methods never hit AttributeError when the
    # run has no target channel.
    self.label_name = None
    self.label = None
    if self.nm_channels["target"].sum() > 0:
        self.label_name = self._get_target_ch()
        self.label = self.read_target_ch(
            self.feature_arr,
            self.label_name,
            binarize=binarize_label,
            binarize_th=binarize_th,
        )
99
-
100
def _get_target_ch(self) -> str:
    """Choose which target channel is used as the decoding label.

    Candidates are the ``name`` entries of ``nm_channels`` rows with
    ``target == 1``.

    * If none contains a marker from ``target_filter_str``
      (case-insensitive), return the first target, unless its name contains
      "ARTIFACT" and a second target exists, in which case return the second.
    * Otherwise, return the first marked target on the contralateral side:
      "LEFT" in the name when ``sidecar["sess_right"]`` is True, "RIGHT"
      when it is False.
    * If there is no such target, or ``sess_right`` is neither True nor
      False, return the first marked target.

    Returns
    -------
    str
        Name of the selected target channel.
    """
    target_names = list(
        self.nm_channels[self.nm_channels["target"] == 1]["name"]
    )
    target_clean = [
        target_name
        for target_name in target_names
        for filter_str in target_filter_str
        if filter_str.lower() in target_name.lower()
    ]

    if not target_clean:
        if "ARTIFACT" in target_names[0] and len(target_names) > 1:
            return target_names[1]
        return target_names[0]

    # try to select contralateral label
    sess_right = self.sidecar["sess_right"]
    if sess_right is True:
        contralateral = "LEFT"
    elif sess_right is False:
        contralateral = "RIGHT"
    else:
        contralateral = None

    if contralateral is not None:
        for target_ in target_clean:
            if contralateral in target_:
                return target_
    return target_clean[0]
130
-
131
@staticmethod
def read_target_ch(
    feature_arr: pd.DataFrame,
    label_name: str,
    binarize: bool = True,
    binarize_th: float = 0.3,
) -> np.ndarray:
    """Extract a target column from the feature table as a NumPy array.

    Parameters
    ----------
    feature_arr : pd.DataFrame
        Feature table containing the column ``label_name``.
    label_name : str
        Name of the target column to read.
    binarize : bool, optional
        If True, return ``label > binarize_th`` (boolean array),
        by default True.
    binarize_th : float, optional
        Threshold for binarization, by default 0.3.

    Returns
    -------
    np.ndarray
        The label values after ``np.nan_to_num`` (NaN -> 0, +/-inf -> large
        finite values), boolean if ``binarize`` is True.
    """
    label = np.nan_to_num(np.array(feature_arr[label_name]))
    if binarize:
        label = label > binarize_th
    return label
161
-
162
@staticmethod
def filter_features(
    feature_columns: list,
    ch_name: Optional[str] = None,
    list_feature_keywords: Optional[list[str]] = None,
) -> list:
    """Filter feature column names by channel name and/or feature keywords.

    Parameters
    ----------
    feature_columns : list
        Column names to filter (any iterable of str, e.g. a pandas Index).
    ch_name : str, optional
        Keep only columns containing this substring, by default None.
    list_feature_keywords : list[str], optional
        Keep only columns containing at least one of these substrings,
        by default None.

    Returns
    -------
    list
        Matching column names, always as a new list. When any keyword equals
        an entry of ``features_reverse_order_plotting`` the order is reversed.
    """
    if ch_name is not None:
        feature_select = [col for col in feature_columns if ch_name in col]
    else:
        feature_select = list(feature_columns)

    if list_feature_keywords is not None:
        feature_select = [
            col
            for col in feature_select
            if any(keyword in col for keyword in list_feature_keywords)
        ]

        if any(
            mod in list_feature_keywords
            for mod in features_reverse_order_plotting
        ):
            # flip list s.t. theta band is lowest in subsequent plot
            feature_select = feature_select[::-1]

    return feature_select
211
-
212
def set_target_ch(
    self, ch_name: str, binarize: bool = True, binarize_th: float = 0.3
) -> None:
    """Select ``ch_name`` as target and reload the label from ``feature_arr``.

    ``self.label`` holds label *values* (as built in ``__init__`` through
    ``read_target_ch``) and ``self.label_name`` holds the column name, so
    both are updated here.

    Parameters
    ----------
    ch_name : str
        Column of ``feature_arr`` to use as target.
    binarize : bool, optional
        Threshold the label into booleans, by default True.
    binarize_th : float, optional
        Binarization threshold, by default 0.3.
    """
    self.label_name = ch_name
    self.label = self.read_target_ch(
        self.feature_arr, ch_name, binarize=binarize, binarize_th=binarize_th
    )
214
-
215
def normalize_features(
    self,
) -> pd.DataFrame:
    """Z-score every feature column of ``feature_arr``.

    Columns whose name contains "time" are excluded from normalization.
    NaNs are ignored when computing mean/std (``nan_policy="omit"``).

    Returns
    -------
    pd.DataFrame
        Z-scored features with ``feature_arr``'s index, plus the original
        "time" column appended when present.
    """
    cols_norm = [c for c in self.feature_arr.columns if "time" not in c]
    z = stats.zscore(self.feature_arr[cols_norm], nan_policy="omit")
    # Depending on the SciPy version zscore returns a DataFrame or a plain
    # ndarray; rebuild the frame explicitly so column assignment always works.
    feature_arr_norm = pd.DataFrame(
        np.asarray(z), columns=cols_norm, index=self.feature_arr.index
    )
    if "time" in self.feature_arr.columns:
        feature_arr_norm["time"] = self.feature_arr["time"]
    return feature_arr_norm
227
-
228
def plot_cort_projection(self) -> None:
    """Plot the session's cortical contacts and, if available, the cortex grid.

    Contacts come from ``sidecar["coords"]["cortex_right"]["positions"]``
    when ``sidecar["sess_right"]`` is truthy, otherwise from
    ``cortex_left``. The grid is passed when ``sidecar`` has "grid_cortex".
    It is colored by the row sums of ``sidecar["proj_matrix_cortex"]``, or
    by None when that matrix is absent.
    """
    side = "cortex_right" if self.sidecar["sess_right"] else "cortex_left"
    ecog_strip = np.array(self.sidecar["coords"][side]["positions"])

    has_grid = "grid_cortex" in self.sidecar
    grid_cortex = np.array(self.sidecar["grid_cortex"]) if has_grid else None
    # Check the key that is actually read: a sidecar holding a grid but no
    # projection matrix previously raised KeyError here.
    grid_color = (
        np.array(self.sidecar["proj_matrix_cortex"]).sum(axis=1)
        if has_grid and "proj_matrix_cortex" in self.sidecar
        else None
    )

    self.nmplotter.plot_cortex(
        grid_cortex=grid_cortex,
        ecog_strip=ecog_strip,
        grid_color=grid_color,
        set_clim=False,
    )
250
-
251
def plot_target_avg_all_channels(
    self,
    ch_names_ECOG=None,
    list_feature_keywords: Optional[list[str]] = ("stft",),
    epoch_len: int = 4,
    threshold: float = 0.1,
):
    """Call ``plot_target_averaged_channel`` for every given ECoG channel.

    Parameters
    ----------
    ch_names_ECOG : list, optional
        ECoG channels to plot features for; None (default) uses
        ``self.ch_names_ECOG``.
    list_feature_keywords : list[str], optional
        Keywords to plot, by default ["stft"]. Passing None forwards None
        (no keyword filtering).
    epoch_len : int, optional
        Epoch length in seconds, by default 4.
    threshold : float, optional
        Threshold for event detection, by default 0.1.
    """
    # Immutable default + per-call list copy: avoids the shared mutable
    # default list while keeping explicit None meaningful.
    keywords = (
        list(list_feature_keywords) if list_feature_keywords is not None else None
    )
    if ch_names_ECOG is None:
        ch_names_ECOG = self.ch_names_ECOG
    for ch_name_ECOG in ch_names_ECOG:
        self.plot_target_averaged_channel(
            ch=ch_name_ECOG,
            list_feature_keywords=keywords,
            epoch_len=epoch_len,
            threshold=threshold,
        )
282
-
283
def plot_target_averaged_channel(
    self,
    ch: str = None,
    list_feature_keywords: Optional[list[str]] = None,
    features_to_plt: list = None,
    epoch_len: int = 4,
    threshold: float = 0.1,
    normalize_data: bool = True,
    show_plot: bool = True,
    title: str = "Movement aligned features",
    ytick_labelsize=None,
    figsize_x: float = 8,
    figsize_y: float = 8,
) -> None:
    """Epoch selected features around the label and plot their average.

    The selected columns are converted to an array of shape
    (n_samples, 1, n_features). They are epoched with ``self.get_epochs``
    against ``self.label`` and drawn by ``nm_plots.plot_epochs_avg``, which
    is always called with ``save=True`` into ``self.feature_dir``.

    Parameters
    ----------
    ch : str, optional
        Channel substring used to select columns (ignored when
        ``features_to_plt`` is given), by default None.
    list_feature_keywords : list[str], optional
        Feature keywords used for column selection and the output name
        suffix ("all" when None), by default None.
    features_to_plt : list, optional
        Explicit column list; overrides ``ch``/keyword selection,
        by default None.
    epoch_len : int, optional
        Epoch length passed to ``get_epochs`` and the plot, by default 4.
    threshold : float, optional
        Event threshold passed to ``get_epochs``, by default 0.1.
    normalize_data : bool, optional
        Forwarded to ``plot_epochs_avg``, by default True.
    show_plot : bool, optional
        Forwarded to ``plot_epochs_avg``, by default True.
    title : str, optional
        Plot title, by default "Movement aligned features".
    ytick_labelsize : optional
        Forwarded to ``plot_epochs_avg``, by default None.
    figsize_x, figsize_y : float, optional
        Figure size, by default 8 each.
    """
    # TODO: This does not work properly when we have bipolar rereferencing
    if features_to_plt is not None:
        selected_cols = features_to_plt
    else:
        selected_cols = self.filter_features(
            self.feature_arr.columns, ch, list_feature_keywords
        )[::-1]
    filtered_df = self.feature_arr[selected_cols]

    # singleton middle axis -> (n_samples, 1, n_features)
    data = np.expand_dims(np.array(filtered_df), axis=1)
    sfreq_features = self.settings["sampling_rate_features_hz"]

    X_epoch, y_epoch = self.get_epochs(
        data,
        self.label,
        epoch_len=epoch_len,
        sfreq=sfreq_features,
        threshold=threshold,
    )

    if list_feature_keywords is None:
        suffix = "all"
    else:
        suffix = "_".join(list_feature_keywords)

    nm_plots.plot_epochs_avg(
        X_epoch=X_epoch,
        y_epoch=y_epoch,
        epoch_len=epoch_len,
        sfreq=self.settings["sampling_rate_features_hz"],
        feature_names=list(filtered_df.columns),
        feature_str_add=suffix,
        cut_ch_name_cols=True,
        ch_name=ch,
        label_name=self.label_name,
        normalize_data=normalize_data,
        show_plot=show_plot,
        save=True,
        OUT_PATH=self.feature_dir,
        feature_file=self.feature_file,
        str_title=title,
        ytick_labelsize=ytick_labelsize,
        figsize_x=figsize_x,
        figsize_y=figsize_y,
    )
369
-
370
def plot_all_features(
    self,
    ch_used: str = None,
    time_limit_low_s: float = None,
    time_limit_high_s: float = None,
    normalize: bool = True,
    save: bool = False,
    title="all_feature_plt.pdf",
    ytick_labelsize: int = 10,
    clim_low: float = None,
    clim_high: float = None,
):
    """Plot all feature columns (optionally of a single channel) over time.

    Columns are passed to ``nm_plots.plot_all_features`` in reverse order.
    When ``ch_used`` is given, only columns starting with it are kept, plus
    "time" and any column containing "LABEL" or "MOV".

    Parameters
    ----------
    ch_used : str, optional
        Channel prefix to restrict the plot to, by default None (all columns).
    time_limit_low_s, time_limit_high_s : float, optional
        Time window limits forwarded to the plot, by default None.
    normalize : bool, optional
        Forwarded to the plot, by default True.
    save : bool, optional
        Forwarded to the plot, by default False.
    title : str, optional
        Output/plot title, by default "all_feature_plt.pdf".
    ytick_labelsize : int, optional
        By default 10.
    clim_low, clim_high : float, optional
        Color limits forwarded to the plot, by default None.
    """
    all_cols = self.feature_arr.columns
    if ch_used is None:
        selected = all_cols[::-1]
    else:
        def _keep(col):
            return (
                col.startswith(ch_used)
                or col == "time"
                or "LABEL" in col
                or "MOV" in col
            )

        selected = [col for col in all_cols if _keep(col)][::-1]

    nm_plots.plot_all_features(
        df=self.feature_arr[selected],
        time_limit_low_s=time_limit_low_s,
        time_limit_high_s=time_limit_high_s,
        normalize=normalize,
        save=save,
        title=title,
        ytick_labelsize=ytick_labelsize,
        feature_file=self.feature_file,
        OUT_PATH=self.feature_dir,
        clim_low=clim_low,
        clim_high=clim_high,
    )
432
-
433
@staticmethod
def get_performace_sub_strip(performance_sub: dict, plt_grid: bool = False):
    """Split one subject's per-channel performance dict into strip and grid parts.

    Keys containing neither "grid" nor "combined" are treated as strip
    channels. Keys containing "gridcortex_" are collected as cortex grid
    points, but only when ``plt_grid`` is exactly True. Each value must
    provide "coord" and "performance_test".

    Parameters
    ----------
    performance_sub : dict
        Mapping channel/grid name -> {"coord": ..., "performance_test": ...}.
    plt_grid : bool, optional
        Collect cortex grid entries too, by default False.

    Returns
    -------
    tuple
        (strip performances, strip coords, grid coords, grid performances).
        Strip coords are ``np.vstack``-ed when non-empty, otherwise an empty
        list; the other three are lists.
    """
    strip_perf, strip_coords = [], []
    grid_coords, grid_perf = [], []

    for ch_key, entry in performance_sub.items():
        is_strip = "grid" not in ch_key and "combined" not in ch_key
        if is_strip:
            strip_coords.append(entry["coord"])
            strip_perf.append(entry["performance_test"])
        elif plt_grid is True and "gridcortex_" in ch_key:
            grid_coords.append(entry["coord"])
            grid_perf.append(entry["performance_test"])

    if strip_coords:
        strip_coords = np.vstack(strip_coords)

    return strip_perf, strip_coords, grid_coords, grid_perf
476
-
477
def plot_across_subject_grd_ch_performance(
    self,
    performance_dict=None,
    plt_grid=False,
    feature_str_add="performance_allch_allgrid",
):
    """Plot strip-channel and grid performances pooled across subjects.

    Strip contacts of all subjects are plotted together after taking the
    absolute value of their first (x) coordinate. Grid performances are
    averaged across subjects (row-wise mean of the stacked per-subject
    vectors). The figure is saved into ``self.feature_dir``.

    Parameters
    ----------
    performance_dict : dict, optional
        Mapping subject -> per-channel performance dict, as consumed by
        ``get_performace_sub_strip``. None is treated as empty.
    plt_grid : bool, optional
        Include cortex grid performances, by default False.
    feature_str_add : str, optional
        Output file name suffix, by default "performance_allch_allgrid".
    """
    if performance_dict is None:
        performance_dict = {}

    ecog_strip_performance = []
    ecog_coords_strip = []
    grid_performance = []
    for sub in performance_dict:
        (
            ecog_strip_performance_sub,
            ecog_coords_strip_sub,
            _,
            grid_performance_sub,
        ) = self.get_performace_sub_strip(
            performance_dict[sub], plt_grid=plt_grid
        )
        ecog_strip_performance.extend(ecog_strip_performance_sub)
        ecog_coords_strip.extend(ecog_coords_strip_sub)
        grid_performance.append(grid_performance_sub)

    # np.vstack raises on an empty sequence, so only average when subjects exist.
    grid_performance = (
        list(np.vstack(grid_performance).mean(axis=0)) if grid_performance else []
    )

    if ecog_coords_strip:
        coords_all = np.array(ecog_coords_strip)
        # Fold all contacts onto positive x (presumably MNI x, i.e. mapping
        # both hemispheres onto one side — confirm against the plotter).
        coords_all[:, 0] = np.abs(coords_all[:, 0])
    else:
        # Without strip contacts there is no 2-D array to index.
        coords_all = None

    self.nmplotter.plot_cortex(
        grid_cortex=np.array(self.sidecar["grid_cortex"])
        if "grid_cortex" in self.sidecar
        else None,
        ecog_strip=coords_all,
        grid_color=grid_performance if len(grid_performance) > 0 else None,
        strip_color=np.array(ecog_strip_performance)
        if len(ecog_strip_performance) > 0
        else None,
        sess_right=self.sidecar["sess_right"],
        save=True,
        OUT_PATH=self.feature_dir,
        feature_file=self.feature_file,
        feature_str_add=feature_str_add,
        show_plot=True,
    )
518
-
519
- def plot_subject_grid_ch_performance(
520
- self,
521
- subject_name=None,
522
- performance_dict=None,
523
- plt_grid=False,
524
- feature_str_add="performance_allch_allgrid",
525
- ):
526
- """plot subject specific performance for individual channeal and optional grid points
527
-
528
- Parameters
529
- ----------
530
- subject_name : string, optional
531
- used subject, by default None
532
- performance_dict : dict, optional
533
- [description], by default None
534
- plt_grid : bool, optional
535
- True to plot grid performances, by default False
536
- feature_str_add : string, optional
537
- figure output_name
538
- """
539
-
540
- ecog_strip_performance = []
541
- ecog_coords_strip = []
542
- cortex_grid = []
543
- grid_performance = []
544
-
545
- if subject_name is None:
546
- subject_name = self.feature_file[
547
- self.feature_file.find("sub-") : self.feature_file.find("_ses")
548
- ][4:]
549
-
550
- (
551
- ecog_strip_performance,
552
- ecog_coords_strip,
553
- cortex_grid,
554
- grid_performance,
555
- ) = self.get_performace_sub_strip(
556
- performance_dict[subject_name], plt_grid=plt_grid
557
- )
558
-
559
- self.nmplotter.plot_cortex(
560
- grid_cortex=np.array(self.sidecar["grid_cortex"])
561
- if "grid_cortex" in self.sidecar
562
- else None,
563
- ecog_strip=ecog_coords_strip
564
- if len(ecog_coords_strip) > 0
565
- else None,
566
- grid_color=grid_performance if len(grid_performance) > 0 else None,
567
- strip_color=ecog_strip_performance
568
- if len(ecog_strip_performance) > 0
569
- else None,
570
- sess_right=self.sidecar["sess_right"],
571
- save=True,
572
- OUT_PATH=self.feature_dir,
573
- feature_file=self.feature_file,
574
- feature_str_add=feature_str_add,
575
- show_plot=True,
576
- )
577
-
578
- def plot_feature_series_time(
579
- self,
580
- ):
581
- self.nmplotter.plot_feature_series_time(self.feature_arr)
582
-
583
- def plot_corr_matrix(
584
- self,
585
- ):
586
- return nm_plots.plot_corr_matrix(
587
- self.feature_arr,
588
- )
589
-
590
- @staticmethod
591
- def get_epochs(data, y_, epoch_len, sfreq, threshold=0) -> (np.ndarray, np.ndarray):
592
- """Return epoched data.
593
-
594
- Parameters
595
- ----------
596
- data : np.ndarray
597
- array of extracted features of shape (n_samples, n_channels, n_features)
598
- y_ : np.ndarray
599
- array of labels e.g. ones for movement and zeros for
600
- no movement or baseline corr. rotameter data
601
- epoch_len : int
602
- length of epoch in seconds
603
- sfreq : int/float
604
- sampling frequency of data
605
- threshold : int/float
606
- (Optional) threshold to be used for identifying events
607
- (default=0 for y_tr with only ones
608
- and zeros)
609
-
610
- Returns
611
- -------
612
- epoch_ : np.ndarray
613
- array of epoched ieeg data with shape (epochs,samples,channels,features)
614
- y_arr : np.ndarray
615
- array of epoched event label data with shape (epochs,samples)
616
- """
617
-
618
- epoch_lim = int(epoch_len * sfreq)
619
-
620
- ind_mov = np.where(np.diff(np.array(y_ > threshold) * 1) == 1)[0]
621
-
622
- low_limit = ind_mov > epoch_lim / 2
623
- up_limit = ind_mov < y_.shape[0] - epoch_lim / 2
624
-
625
- ind_mov = ind_mov[low_limit & up_limit]
626
-
627
- epoch_ = np.zeros(
628
- [ind_mov.shape[0], epoch_lim, data.shape[1], data.shape[2]]
629
- )
630
-
631
- y_arr = np.zeros([ind_mov.shape[0], int(epoch_lim)])
632
-
633
- for idx, i in enumerate(ind_mov):
634
-
635
- epoch_[idx, :, :, :] = data[
636
- i - epoch_lim // 2 : i + epoch_lim // 2, :, :
637
- ]
638
-
639
- y_arr[idx, :] = y_[i - epoch_lim // 2 : i + epoch_lim // 2]
640
-
641
- return epoch_, y_arr
642
-
643
- def set_decoder(
644
- self,
645
- decoder: nm_decode.Decoder = None,
646
- TRAIN_VAL_SPLIT=False,
647
- RUN_BAY_OPT=False,
648
- save_coef=False,
649
- model: base.BaseEstimator = linear_model.LogisticRegression,
650
- eval_method=metrics.r2_score,
651
- cv_method: model_selection.BaseCrossValidator = model_selection.KFold(
652
- n_splits=3, shuffle=False
653
- ),
654
- get_movement_detection_rate: bool = False,
655
- mov_detection_threshold=0.5,
656
- min_consequent_count=3,
657
- threshold_score=True,
658
- bay_opt_param_space: list = [],
659
- STACK_FEATURES_N_SAMPLES=False,
660
- time_stack_n_samples=5,
661
- use_nested_cv=False,
662
- VERBOSE=False,
663
- undersampling=False,
664
- oversampling=False,
665
- mrmr_select=False,
666
- pca=False,
667
- cca=False,
668
- ):
669
- if decoder is not None:
670
- self.decoder = decoder
671
- else:
672
-
673
- self.decoder = nm_decode.Decoder(
674
- features=self.feature_arr,
675
- label=self.label,
676
- label_name=self.label_name,
677
- used_chs=self.used_chs,
678
- model=model,
679
- eval_method=eval_method,
680
- cv_method=cv_method,
681
- threshold_score=threshold_score,
682
- TRAIN_VAL_SPLIT=TRAIN_VAL_SPLIT,
683
- RUN_BAY_OPT=RUN_BAY_OPT,
684
- save_coef=save_coef,
685
- get_movement_detection_rate=get_movement_detection_rate,
686
- min_consequent_count=min_consequent_count,
687
- mov_detection_threshold=mov_detection_threshold,
688
- bay_opt_param_space=bay_opt_param_space,
689
- STACK_FEATURES_N_SAMPLES=STACK_FEATURES_N_SAMPLES,
690
- time_stack_n_samples=time_stack_n_samples,
691
- VERBOSE=VERBOSE,
692
- use_nested_cv=use_nested_cv,
693
- undersampling=undersampling,
694
- oversampling=oversampling,
695
- mrmr_select=mrmr_select,
696
- sfreq=self.sfreq,
697
- pca=pca,
698
- cca=cca,
699
- )
700
-
701
- def run_ML_model(
702
- self,
703
- feature_file: str = None,
704
- estimate_gridpoints: bool = False,
705
- estimate_channels: bool = True,
706
- estimate_all_channels_combined: bool = False,
707
- output_name: str = "LM",
708
- save_results: bool = True,
709
- ):
710
- """machine learning model evaluation for ECoG strip channels and/or grid points
711
-
712
- Parameters
713
- ----------
714
- feature_file : string, optional
715
- [description], by default None
716
- estimate_gridpoints : bool, optional
717
- run ML analysis for grid points, by default True
718
- estimate_channels : bool, optional
719
- run ML analysis for ECoG strip channel, by default True
720
- estimate_all_channels_combined : bool, optional
721
- run ML analysis features of all channels concatenated, by default False
722
- model : sklearn model, optional
723
- ML model, needs to obtain fit and predict functions,
724
- by default linear_model.LogisticRegression(class_weight="balanced")
725
- eval_method : sklearn.metrics, optional
726
- evaluation performance metric, by default metrics.balanced_accuracy_score
727
- cv_method : sklearn.model_selection, optional
728
- valdation strategy, by default model_selection.KFold(n_splits=3, shuffle=False)
729
- output_name : str, optional
730
- saving name, by default "LM"
731
- save_results : boolean
732
- if true, save model._coef trained coefficients
733
- """
734
- if feature_file is None:
735
- feature_file = self.feature_file
736
-
737
- if estimate_gridpoints:
738
- self.decoder.set_data_grid_points()
739
- _ = self.decoder.run_CV_caller("grid_points")
740
- if estimate_channels:
741
- self.decoder.set_data_ind_channels()
742
- _ = self.decoder.run_CV_caller("ind_channels")
743
- if estimate_all_channels_combined:
744
- _ = self.decoder.run_CV_caller("all_channels_combined")
745
-
746
- if save_results:
747
- self.decoder.save(
748
- self.feature_dir,
749
- self.feature_file
750
- if ".vhdr" in self.feature_file
751
- else self.feature_file,
752
- output_name,
753
- )
754
-
755
- return self.read_results(
756
- read_grid_points=estimate_gridpoints,
757
- read_all_combined=estimate_all_channels_combined,
758
- read_channels=estimate_channels,
759
- ML_model_name=output_name,
760
- read_mov_detection_rates=self.decoder.get_movement_detection_rate,
761
- read_bay_opt_params=self.decoder.RUN_BAY_OPT,
762
- read_mrmr=self.decoder.mrmr_select,
763
- model_save=self.decoder.model_save,
764
- )
765
-
766
- def read_results(
767
- self,
768
- performance_dict: dict = {},
769
- subject_name: str = None,
770
- DEFAULT_PERFORMANCE: float = 0.5,
771
- read_grid_points: bool = True,
772
- read_channels: bool = True,
773
- read_all_combined: bool = False,
774
- ML_model_name: str = "LM",
775
- read_mov_detection_rates: bool = False,
776
- read_bay_opt_params: bool = False,
777
- read_mrmr: bool = False,
778
- model_save: bool = False,
779
- save_results: bool = False,
780
- PATH_OUT: str = None,
781
- folder_name: str = None,
782
- str_add: str = None,
783
- ):
784
- """Save performances of a given patient into performance_dict from saved nm_decoder
785
-
786
- Parameters
787
- ----------
788
- performance_dict : dictionary
789
- dictionary including decoding performances, by default dictionary
790
- subject_name : string, optional
791
- subject name, by default None
792
- DEFAULT_PERFORMANCE : float, optional
793
- chance performance, by default 0.5
794
- read_grid_points : bool, optional
795
- true if grid point performances are read, by default True
796
- read_channels : bool, optional
797
- true if channels performances are read, by default True
798
- read_all_combined : bool, optional
799
- true if all combined channel performances are read, by default False
800
- ML_model_name : str, optional
801
- machine learning model name, by default 'LM'
802
- read_mov_detection_rates : boolean, by defaulte False
803
- if True, read movement detection rates, as well as fpr's and tpr's
804
- read_bay_opt_params : boolean, by default False
805
- read_mrmr : boolean, by default False
806
- model_save : boolean, by default False
807
- save_results : boolean, by default False
808
- PATH_OUT : string, by default None
809
- folder_name : string, by default None
810
- str_add : string, by default None
811
-
812
- Returns
813
- -------
814
- performance_dict : dictionary
815
-
816
- """
817
-
818
- if ".vhdr" in self.feature_file:
819
- feature_file = self.feature_file[: -len(".vhdr")]
820
- else:
821
- feature_file = self.feature_file
822
-
823
- if subject_name is None:
824
- subject_name = feature_file[
825
- feature_file.find("sub-") : feature_file.find("_ses")
826
- ][4:]
827
-
828
- PATH_ML_ = os.path.join(
829
- self.feature_dir,
830
- feature_file,
831
- feature_file + "_" + ML_model_name + "_ML_RES.p",
832
- )
833
-
834
- # read ML results
835
- with open(PATH_ML_, "rb") as input:
836
- ML_res = cPickle.load(input)
837
- if self.decoder is None:
838
- self.decoder = ML_res
839
-
840
- performance_dict[subject_name] = {}
841
-
842
- def write_CV_res_in_performance_dict(
843
- obj_read,
844
- obj_write,
845
- read_mov_detection_rates=read_mov_detection_rates,
846
- read_bay_opt_params=False,
847
- ):
848
- def transform_list_of_dicts_into_dict_of_lists(l_):
849
- dict_out = {}
850
- for key_, _ in l_[0].items():
851
- key_l = []
852
- for dict_ in l_:
853
- key_l.append(dict_[key_])
854
- dict_out[key_] = key_l
855
- return dict_out
856
-
857
- def read_ML_performances(
858
- obj_read, obj_write, set_inner_CV_res: bool = False
859
- ):
860
- def set_score(
861
- key_set: str,
862
- key_get: str,
863
- take_mean: bool = True,
864
- val=None,
865
- ):
866
- if set_inner_CV_res is True:
867
- key_set = "InnerCV_" + key_set
868
- key_get = "InnerCV_" + key_get
869
- if take_mean is True:
870
- val = np.mean(obj_read[key_get])
871
- obj_write[key_set] = val
872
-
873
- set_score(
874
- key_set="performance_test",
875
- key_get="score_test",
876
- take_mean=True,
877
- )
878
- set_score(
879
- key_set="performance_train",
880
- key_get="score_train",
881
- take_mean=True,
882
- )
883
-
884
- if "coef" in obj_read:
885
- set_score(
886
- key_set="coef",
887
- key_get="coef",
888
- take_mean=False,
889
- val=np.concatenate(obj_read["coef"]),
890
- )
891
-
892
- if read_mov_detection_rates:
893
- set_score(
894
- key_set="mov_detection_rates_test",
895
- key_get="mov_detection_rates_test",
896
- take_mean=True,
897
- )
898
- set_score(
899
- key_set="mov_detection_rates_train",
900
- key_get="mov_detection_rates_train",
901
- take_mean=True,
902
- )
903
- set_score(
904
- key_set="fprate_test",
905
- key_get="fprate_test",
906
- take_mean=True,
907
- )
908
- set_score(
909
- key_set="fprate_train",
910
- key_get="fprate_train",
911
- take_mean=True,
912
- )
913
- set_score(
914
- key_set="tprate_test",
915
- key_get="tprate_test",
916
- take_mean=True,
917
- )
918
- set_score(
919
- key_set="tprate_train",
920
- key_get="tprate_train",
921
- take_mean=True,
922
- )
923
-
924
- if read_bay_opt_params is True:
925
- # transform dict into keys for json saving
926
- dict_to_save = transform_list_of_dicts_into_dict_of_lists(
927
- obj_read["best_bay_opt_params"]
928
- )
929
- set_score(
930
- key_set="bay_opt_best_params",
931
- key_get=None,
932
- take_mean=False,
933
- val=dict_to_save,
934
- )
935
-
936
- if read_mrmr is True:
937
- # transform dict into keys for json saving
938
-
939
- set_score(
940
- key_set="mrmr_select",
941
- key_get=None,
942
- take_mean=False,
943
- val=obj_read["mrmr_select"],
944
- )
945
- if model_save is True:
946
- set_score(
947
- key_set="model_save",
948
- key_get=None,
949
- take_mean=False,
950
- val=obj_read["model_save"],
951
- )
952
-
953
- read_ML_performances(obj_read, obj_write)
954
-
955
- if (
956
- len([key_ for key_ in obj_read.keys() if "InnerCV_" in key_])
957
- > 0
958
- ):
959
- read_ML_performances(obj_read, obj_write, set_inner_CV_res=True)
960
-
961
- if read_channels:
962
-
963
- ch_to_use = self.ch_names_ECOG
964
- ch_to_use = self.decoder.used_chs
965
- for ch in ch_to_use:
966
-
967
- performance_dict[subject_name][ch] = {}
968
-
969
- if "coords" in self.sidecar:
970
- if (
971
- len(self.sidecar["coords"]) > 0
972
- ): # check if coords are empty
973
-
974
- coords_exist = False
975
- for cortex_loc in self.sidecar["coords"].keys():
976
- for ch_name_coord_idx, ch_name_coord in enumerate(
977
- self.sidecar["coords"][cortex_loc]["ch_names"]
978
- ):
979
- if ch.startswith(ch_name_coord):
980
- coords = self.sidecar["coords"][cortex_loc][
981
- "positions"
982
- ][ch_name_coord_idx]
983
- coords_exist = True # optimally break out of the two loops...
984
- if coords_exist is False:
985
- coords = None
986
- performance_dict[subject_name][ch]["coord"] = coords
987
- write_CV_res_in_performance_dict(
988
- ML_res.ch_ind_results[ch],
989
- performance_dict[subject_name][ch],
990
- read_mov_detection_rates=read_mov_detection_rates,
991
- read_bay_opt_params=read_bay_opt_params,
992
- )
993
-
994
- if read_all_combined:
995
- performance_dict[subject_name]["all_ch_combined"] = {}
996
- write_CV_res_in_performance_dict(
997
- ML_res.all_ch_results,
998
- performance_dict[subject_name]["all_ch_combined"],
999
- read_mov_detection_rates=read_mov_detection_rates,
1000
- read_bay_opt_params=read_bay_opt_params,
1001
- )
1002
-
1003
- if read_grid_points:
1004
- performance_dict[subject_name][
1005
- "active_gridpoints"
1006
- ] = ML_res.active_gridpoints
1007
-
1008
- for project_settings, grid_type in zip(
1009
- ["project_cortex", "project_subcortex"],
1010
- ["gridcortex_", "gridsubcortex_"],
1011
- ):
1012
- if self.settings["postprocessing"][project_settings] is False:
1013
- continue
1014
-
1015
- # the sidecar keys are grid_cortex and subcortex_grid
1016
- for grid_point in range(
1017
- len(self.sidecar["grid_" + project_settings.split("_")[1]])
1018
- ):
1019
-
1020
- gp_str = grid_type + str(grid_point)
1021
-
1022
- performance_dict[subject_name][gp_str] = {}
1023
- performance_dict[subject_name][gp_str][
1024
- "coord"
1025
- ] = self.sidecar["grid_" + project_settings.split("_")[1]][
1026
- grid_point
1027
- ]
1028
-
1029
- if gp_str in ML_res.active_gridpoints:
1030
- write_CV_res_in_performance_dict(
1031
- ML_res.gridpoint_ind_results[gp_str],
1032
- performance_dict[subject_name][gp_str],
1033
- read_mov_detection_rates=read_mov_detection_rates,
1034
- read_bay_opt_params=read_bay_opt_params,
1035
- )
1036
- else:
1037
- # set non interpolated grid point to default performance
1038
- performance_dict[subject_name][gp_str][
1039
- "performance_test"
1040
- ] = DEFAULT_PERFORMANCE
1041
- performance_dict[subject_name][gp_str][
1042
- "performance_train"
1043
- ] = DEFAULT_PERFORMANCE
1044
-
1045
- if save_results:
1046
- nm_IO.save_general_dict(
1047
- dict_=performance_dict,
1048
- path_out=PATH_OUT,
1049
- str_add=str_add,
1050
- folder_name=folder_name,
1051
- )
1052
- return performance_dict
1053
-
1054
- @staticmethod
1055
- def get_dataframe_performances(p: dict) -> pd.DataFrame:
1056
- performances = []
1057
- for sub in p.keys():
1058
- for ch in p[sub].keys():
1059
- if "active_gridpoints" in ch:
1060
- continue
1061
- dict_add = p[sub][ch].copy()
1062
- dict_add["sub"] = sub
1063
- dict_add["ch"] = ch
1064
-
1065
- if "all_ch_" in ch:
1066
- dict_add["ch_type"] = "all ch combinded"
1067
- elif "gridcortex" in ch:
1068
- dict_add["ch_type"] = "cortex grid"
1069
- else:
1070
- dict_add["ch_type"] = "electrode ch"
1071
- performances.append(dict_add)
1072
- df = pd.DataFrame(performances)
1073
-
1074
- return df
1
+ from pathlib import PurePath
2
+
3
+ import pickle
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ from sklearn.linear_model import LogisticRegression
8
+ from sklearn.metrics import r2_score
9
+ from sklearn.model_selection import KFold
10
+
11
+ from scipy.stats import zscore as scipy_zscore
12
+
13
+ from py_neuromodulation import nm_IO, nm_plots
14
+ from py_neuromodulation.nm_decode import Decoder
15
+ from py_neuromodulation.nm_types import _PathLike
16
+ from py_neuromodulation.nm_settings import NMSettings
17
+
18
+
19
+ target_filter_str = {
20
+ "CLEAN",
21
+ "SQUARED_EMG",
22
+ "SQUARED_INTERPOLATED_EMG",
23
+ "SQUARED_ROTAWHEEL",
24
+ "SQUARED_ROTATION" "rota_squared",
25
+ }
26
+ features_reverse_order_plotting = {"stft", "fft", "bandpass"}
27
+
28
+
29
+ class FeatureReader:
30
+ def __init__(
31
+ self,
32
+ feature_dir: _PathLike,
33
+ feature_file: _PathLike = "",
34
+ binarize_label: bool = True,
35
+ ) -> None:
36
+ """Feature_Reader enables analysis methods on top of NM_reader and NM_Decoder
37
+
38
+ Parameters
39
+ ----------
40
+ feature_dir : str, optional
41
+ Path to py_neuromodulation estimated feature runs, where each feature is a folder,
42
+ feature_file : str, optional
43
+ specific feature run, if None it is set to the first feature folder in feature_dir
44
+ binarize_label : bool
45
+ binarize label, by default True
46
+
47
+ """
48
+ self.feature_dir = feature_dir
49
+ self.feature_list: list[str] = nm_IO.get_run_list_indir(self.feature_dir)
50
+ self.feature_file = feature_file if feature_file else self.feature_list[0]
51
+
52
+ FILE_BASENAME = PurePath(self.feature_file).stem
53
+ PATH_READ_FILE = str(PurePath(self.feature_dir, FILE_BASENAME, FILE_BASENAME))
54
+
55
+ self.settings = NMSettings.from_file(PATH_READ_FILE)
56
+ self.sidecar = nm_IO.read_sidecar(PATH_READ_FILE)
57
+ if self.sidecar["sess_right"] is None:
58
+ if "coords" in self.sidecar:
59
+ if len(self.sidecar["coords"]["cortex_left"]["ch_names"]) > 0:
60
+ self.sidecar["sess_right"] = False
61
+ if len(self.sidecar["coords"]["cortex_right"]["ch_names"]) > 0:
62
+ self.sidecar["sess_right"] = True
63
+ self.sfreq = self.sidecar["sfreq"]
64
+ self.nm_channels = nm_IO.read_nm_channels(PATH_READ_FILE)
65
+ self.feature_arr = nm_IO.read_features(PATH_READ_FILE)
66
+
67
+ self.ch_names = self.nm_channels.new_name
68
+ self.used_chs = list(
69
+ self.nm_channels[
70
+ (self.nm_channels["target"] == 0) & (self.nm_channels["used"] == 1)
71
+ ]["new_name"]
72
+ )
73
+ self.ch_names_ECOG = self.nm_channels.query(
74
+ '(type=="ecog") and (used == 1) and (status=="good")'
75
+ ).new_name.to_list()
76
+
77
+ # init plotter
78
+ self.nmplotter = nm_plots.NM_Plot()
79
+ if self.nm_channels["target"].sum() > 0:
80
+ self.label_name = self._get_target_ch()
81
+ self.label = self.read_target_ch(
82
+ self.feature_arr,
83
+ self.label_name,
84
+ binarize=binarize_label,
85
+ binarize_th=0.3,
86
+ )
87
+
88
+ def _get_target_ch(self) -> str:
89
+ target_names = list(self.nm_channels[self.nm_channels["target"] == 1]["name"])
90
+ target_clean = [
91
+ target_name
92
+ for target_name in target_names
93
+ for filter_str in target_filter_str
94
+ if filter_str.lower() in target_name.lower()
95
+ ]
96
+
97
+ if len(target_clean) == 0:
98
+ if "ARTIFACT" not in target_names[0]:
99
+ target = target_names[0]
100
+ elif len(target_names) > 1:
101
+ target = target_names[1]
102
+ else:
103
+ target = target_names[0]
104
+ else:
105
+ for target_ in target_clean:
106
+ # try to select contralateral label
107
+ if self.sidecar["sess_right"] and "LEFT" in target_:
108
+ target = target_
109
+ continue
110
+ elif not self.sidecar["sess_right"] and "RIGHT" in target_:
111
+ target = target_
112
+ continue
113
+ if target_ == target_clean[-1]:
114
+ target = target_clean[0] # set label to last element
115
+ return target
116
+
117
+ @staticmethod
118
+ def read_target_ch(
119
+ feature_arr: "pd.DataFrame",
120
+ label_name: str,
121
+ binarize: bool = True,
122
+ binarize_th: float = 0.3,
123
+ ) -> np.ndarray:
124
+ """_summary_
125
+
126
+ Parameters
127
+ ----------
128
+ feature_arr : pd.DataFrame
129
+ _description_
130
+ label_name : str
131
+ _description_
132
+ binarize : bool, optional
133
+ _description_, by default True
134
+ binarize_th : float, optional
135
+ _description_, by default 0.3
136
+
137
+ Returns
138
+ -------
139
+ _type_
140
+ _description_
141
+ """
142
+
143
+ label = np.nan_to_num(np.array(feature_arr[label_name]))
144
+ if binarize:
145
+ label = label > binarize_th
146
+ return label
147
+
148
+ @staticmethod
149
+ def filter_features(
150
+ feature_columns: list,
151
+ ch_name: str | None = None,
152
+ list_feature_keywords: list[str] | None = None,
153
+ ) -> list:
154
+ """filters read features by ch_name and/or modality
155
+
156
+ Parameters
157
+ ----------
158
+ feature_columns : list
159
+ ch_name : str, optional
160
+ list_feature_keywords : list[str], optional
161
+ list of feature strings that need to be in the columns, by default None
162
+
163
+ Returns
164
+ -------
165
+ features : list
166
+ column list that suffice the ch_name and list_feature_keywords
167
+ """
168
+
169
+ if ch_name is not None:
170
+ feature_select = [i for i in list(feature_columns) if ch_name in i]
171
+ else:
172
+ feature_select = feature_columns
173
+
174
+ if list_feature_keywords is not None:
175
+ feature_select = [
176
+ f for f in feature_select if any(x in f for x in list_feature_keywords)
177
+ ]
178
+
179
+ if (
180
+ len(
181
+ [
182
+ mod
183
+ for mod in features_reverse_order_plotting
184
+ if mod in list_feature_keywords
185
+ ]
186
+ )
187
+ > 0
188
+ ):
189
+ # flip list s.t. theta band is lowest in subsequent plot
190
+ feature_select = feature_select[::-1]
191
+
192
+ return feature_select
193
+
194
+ def set_target_ch(self, ch_name: str) -> None:
195
+ self.label_name = ch_name
196
+
197
+ def normalize_features(
198
+ self,
199
+ ) -> "pd.DataFrame":
200
+ """Normalize feature_arr feature columns
201
+
202
+ Returns:
203
+ pd.DataFrame: z-scored feature_arr
204
+ """
205
+ cols_norm = [c for c in self.feature_arr.columns if "time" not in c]
206
+ feature_arr_norm = scipy_zscore(self.feature_arr[cols_norm], nan_policy="omit")
207
+ feature_arr_norm["time"] = self.feature_arr["time"]
208
+ return feature_arr_norm
209
+
210
+ def plot_cort_projection(self) -> None:
211
+ """_summary_"""
212
+
213
+ if self.sidecar["sess_right"]:
214
+ ecog_strip = np.array(self.sidecar["coords"]["cortex_right"]["positions"])
215
+ else:
216
+ ecog_strip = np.array(self.sidecar["coords"]["cortex_left"]["positions"])
217
+ self.nmplotter.plot_cortex(
218
+ grid_cortex=np.array(self.sidecar["grid_cortex"])
219
+ if "grid_cortex" in self.sidecar
220
+ else None,
221
+ ecog_strip=ecog_strip,
222
+ grid_color=np.array(self.sidecar["proj_matrix_cortex"]).sum(axis=1)
223
+ if "grid_cortex" in self.sidecar
224
+ else None,
225
+ set_clim=False,
226
+ )
227
+
228
+ def plot_target_avg_all_channels(
229
+ self,
230
+ ch_names_ECOG=None,
231
+ list_feature_keywords: list[str] = ["stft"],
232
+ epoch_len: int = 4,
233
+ threshold: float = 0.1,
234
+ ):
235
+ """Wrapper that call plot_features_per_channel
236
+ for every given ECoG channel
237
+
238
+ Parameters
239
+ ----------
240
+ ch_names_ECOG : list, optional
241
+ list of ECoG channel to plot features for, by default None
242
+ list_feature_keywords : list[str], optional
243
+ keywords to plot, by default ["stft"]
244
+ epoch_len : int, optional
245
+ epoch length in seconds, by default 4
246
+ threshold : float, optional
247
+ threshold for event detection, by default 0.1
248
+ """
249
+
250
+ if ch_names_ECOG is None:
251
+ ch_names_ECOG = self.ch_names_ECOG
252
+ for ch_name_ECOG in ch_names_ECOG:
253
+ self.plot_target_averaged_channel(
254
+ ch=ch_name_ECOG,
255
+ list_feature_keywords=list_feature_keywords,
256
+ epoch_len=epoch_len,
257
+ threshold=threshold,
258
+ )
259
+
260
+ def plot_target_averaged_channel(
261
+ self,
262
+ ch: str = "",
263
+ list_feature_keywords: list[str] | None = None,
264
+ features_to_plt: list | None = None,
265
+ epoch_len: int = 4,
266
+ threshold: float = 0.1,
267
+ normalize_data: bool = True,
268
+ show_plot: bool = True,
269
+ title: str = "Movement aligned features",
270
+ ytick_labelsize=None,
271
+ figsize_x: float = 8,
272
+ figsize_y: float = 8,
273
+ ) -> None:
274
+ """_summary_
275
+
276
+ Parameters
277
+ ----------
278
+ ch : str, optional
279
+ list_feature_keywords : Optional[list[str]], optional
280
+ features_to_plt : list, optional
281
+ epoch_len : int, optional
282
+ threshold : float, optional
283
+ normalize_data : bool, optional
284
+ show_plot : bool, optional
285
+ title : str, optional
286
+ by default "Movement aligned features"
287
+ ytick_labelsize : _type_, optional
288
+ figsize_x : float, optional
289
+ by default 8
290
+ figsize_y : float, optional
291
+ by default 8
292
+ """
293
+
294
+ # TODO: This does not work properly when we have bipolar rereferencing
295
+
296
+ if features_to_plt is None:
297
+ filtered_df = self.feature_arr[
298
+ self.filter_features(
299
+ self.feature_arr.columns, ch, list_feature_keywords
300
+ )[::-1]
301
+ ]
302
+ else:
303
+ filtered_df = self.feature_arr[features_to_plt]
304
+
305
+ data = np.expand_dims(np.array(filtered_df), axis=1)
306
+
307
+ X_epoch, y_epoch = self.get_epochs(
308
+ data,
309
+ self.label,
310
+ epoch_len=epoch_len,
311
+ sfreq=self.settings.sampling_rate_features_hz,
312
+ threshold=threshold,
313
+ )
314
+
315
+ nm_plots.plot_epochs_avg(
316
+ X_epoch=X_epoch,
317
+ y_epoch=y_epoch,
318
+ epoch_len=epoch_len,
319
+ sfreq=self.settings.sampling_rate_features_hz,
320
+ feature_names=list(filtered_df.columns),
321
+ feature_str_add="_".join(list_feature_keywords)
322
+ if list_feature_keywords is not None
323
+ else "all",
324
+ cut_ch_name_cols=True,
325
+ ch_name=ch,
326
+ label_name=self.label_name,
327
+ normalize_data=normalize_data,
328
+ show_plot=show_plot,
329
+ save=True,
330
+ OUT_PATH=self.feature_dir,
331
+ feature_file=self.feature_file,
332
+ str_title=title,
333
+ ytick_labelsize=ytick_labelsize,
334
+ figsize_x=figsize_x,
335
+ figsize_y=figsize_y,
336
+ )
337
+
338
+ def plot_all_features(
339
+ self,
340
+ ch_used: str | None = None,
341
+ time_limit_low_s: float | None = None,
342
+ time_limit_high_s: float | None = None,
343
+ normalize: bool = True,
344
+ save: bool = False,
345
+ title="all_feature_plt.pdf",
346
+ ytick_labelsize: int = 10,
347
+ clim_low: float | None = None,
348
+ clim_high: float | None = None,
349
+ ):
350
+ """_summary_
351
+
352
+ Parameters
353
+ ----------
354
+ ch_used : str, optional
355
+ time_limit_low_s : float, optional
356
+ time_limit_high_s : float, optional
357
+ normalize : bool, optional
358
+ save : bool, optional
359
+ title : str, optional
360
+ default "all_feature_plt.pdf"
361
+ ytick_labelsize : int, optional
362
+ by default 10
363
+ clim_low : float, optional
364
+ by default None
365
+ clim_high : float, optional
366
+ by default None
367
+ """
368
+
369
+ if ch_used is not None:
370
+ col_used = [
371
+ c
372
+ for c in self.feature_arr.columns
373
+ if c.startswith(ch_used) or c == "time" or "LABEL" in c or "MOV" in c
374
+ ]
375
+ df = self.feature_arr[col_used[::-1]]
376
+ else:
377
+ df = self.feature_arr[self.feature_arr.columns[::-1]]
378
+
379
+ nm_plots.plot_all_features(
380
+ df=df,
381
+ time_limit_low_s=time_limit_low_s,
382
+ time_limit_high_s=time_limit_high_s,
383
+ normalize=normalize,
384
+ save=save,
385
+ title=title,
386
+ ytick_labelsize=ytick_labelsize,
387
+ feature_file=self.feature_file,
388
+ OUT_PATH=self.feature_dir,
389
+ clim_low=clim_low,
390
+ clim_high=clim_high,
391
+ )
392
+
393
+ @staticmethod
394
+ def get_performace_sub_strip(performance_sub: dict, plt_grid: bool = False):
395
+ ecog_strip_performance = []
396
+ ecog_coords_strip = []
397
+ cortex_grid = []
398
+ grid_performance = []
399
+
400
+ channels_ = performance_sub.keys()
401
+
402
+ for ch in channels_:
403
+ if "grid" not in ch and "combined" not in ch:
404
+ ecog_coords_strip.append(performance_sub[ch]["coord"])
405
+ ecog_strip_performance.append(performance_sub[ch]["performance_test"])
406
+ elif plt_grid and "gridcortex_" in ch:
407
+ cortex_grid.append(performance_sub[ch]["coord"])
408
+ grid_performance.append(performance_sub[ch]["performance_test"])
409
+
410
+ if len(ecog_coords_strip) > 0:
411
+ ecog_coords_strip = np.vstack(ecog_coords_strip)
412
+
413
+ return (
414
+ ecog_strip_performance,
415
+ ecog_coords_strip,
416
+ cortex_grid,
417
+ grid_performance,
418
+ )
419
+
420
+ def plot_across_subject_grd_ch_performance(
421
+ self,
422
+ performance_dict=None,
423
+ plt_grid=False,
424
+ feature_str_add="performance_allch_allgrid",
425
+ ):
426
+ ecog_strip_performance = []
427
+ ecog_coords_strip = []
428
+ grid_performance = []
429
+ for sub in performance_dict.keys():
430
+ (
431
+ ecog_strip_performance_sub,
432
+ ecog_coords_strip_sub,
433
+ _,
434
+ grid_performance_sub,
435
+ ) = self.get_performace_sub_strip(performance_dict[sub], plt_grid=plt_grid)
436
+ ecog_strip_performance.extend(ecog_strip_performance_sub)
437
+ ecog_coords_strip.extend(ecog_coords_strip_sub)
438
+ grid_performance.append(grid_performance_sub)
439
+ grid_performance = list(np.vstack(grid_performance).mean(axis=0))
440
+ coords_all = np.array(ecog_coords_strip)
441
+ coords_all[:, 0] = np.abs(coords_all[:, 0])
442
+
443
+ self.nmplotter.plot_cortex(
444
+ grid_cortex=np.array(self.sidecar["grid_cortex"])
445
+ if "grid_cortex" in self.sidecar
446
+ else None,
447
+ ecog_strip=coords_all if len(ecog_coords_strip) > 0 else None,
448
+ grid_color=grid_performance if len(grid_performance) > 0 else None,
449
+ strip_color=np.array(ecog_strip_performance)
450
+ if len(ecog_strip_performance) > 0
451
+ else None,
452
+ sess_right=self.sidecar["sess_right"],
453
+ save=True,
454
+ OUT_PATH=self.feature_dir,
455
+ feature_file=self.feature_file,
456
+ feature_str_add=feature_str_add,
457
+ show_plot=True,
458
+ )
459
+
460
def plot_subject_grid_ch_performance(
    self,
    subject_name=None,
    performance_dict=None,
    plt_grid=False,
    feature_str_add="performance_allch_allgrid",
):
    """Plot subject-specific performance for individual channels and optional grid points.

    Parameters
    ----------
    subject_name : string, optional
        Subject key into ``performance_dict``. If None, it is taken from
        ``self.feature_file`` as the text between ``"sub-"`` and ``"_ses"``
        (without the ``"sub-"`` prefix), by default None.
    performance_dict : dict
        Mapping subject name -> per-subject performance dict. Required
        despite the ``None`` default.
    plt_grid : bool, optional
        True to plot grid performances, by default False.
    feature_str_add : string, optional
        Suffix used for the saved figure name.

    Raises
    ------
    ValueError
        If ``performance_dict`` is None.
    """
    if performance_dict is None:
        raise ValueError("performance_dict must be provided")

    if subject_name is None:
        # NOTE(review): assumes feature_file contains both "sub-" and "_ses";
        # str.find returns -1 otherwise and the slice is meaningless.
        subject_name = self.feature_file[
            self.feature_file.find("sub-") : self.feature_file.find("_ses")
        ][4:]

    (
        ecog_strip_performance,
        ecog_coords_strip,
        _,  # cortex grid coordinates are not needed; the sidecar grid is used
        grid_performance,
    ) = self.get_performace_sub_strip(
        performance_dict[subject_name], plt_grid=plt_grid
    )

    self.nmplotter.plot_cortex(
        grid_cortex=np.array(self.sidecar["grid_cortex"])
        if "grid_cortex" in self.sidecar
        else None,
        ecog_strip=ecog_coords_strip if len(ecog_coords_strip) > 0 else None,
        grid_color=grid_performance if len(grid_performance) > 0 else None,
        strip_color=ecog_strip_performance
        if len(ecog_strip_performance) > 0
        else None,
        sess_right=self.sidecar["sess_right"],
        save=True,
        OUT_PATH=self.feature_dir,
        feature_file=self.feature_file,
        feature_str_add=feature_str_add,
        show_plot=True,
    )

def plot_feature_series_time(self):
    """Plot the feature time series stored in ``self.feature_arr``.

    Delegates to ``self.nmplotter.plot_feature_series_time`` and returns None.
    """
    plot_fn = self.nmplotter.plot_feature_series_time
    plot_fn(self.feature_arr)

def plot_corr_matrix(self):
    """Plot the correlation matrix of ``self.feature_arr``.

    Returns whatever ``nm_plots.plot_corr_matrix`` returns.
    """
    corr_plot = nm_plots.plot_corr_matrix
    return corr_plot(self.feature_arr)

+ @staticmethod
530
+ def get_epochs(
531
+ data, y_, epoch_len, sfreq, threshold=0
532
+ ) -> tuple[np.ndarray, np.ndarray]:
533
+ """Return epoched data.
534
+
535
+ Parameters
536
+ ----------
537
+ data : np.ndarray
538
+ array of extracted features of shape (n_samples, n_channels, n_features)
539
+ y_ : np.ndarray
540
+ array of labels e.g. ones for movement and zeros for
541
+ no movement or baseline corr. rotameter data
542
+ epoch_len : int
543
+ length of epoch in seconds
544
+ sfreq : int/float
545
+ sampling frequency of data
546
+ threshold : int/float
547
+ (Optional) threshold to be used for identifying events
548
+ (default=0 for y_tr with only ones
549
+ and zeros)
550
+
551
+ Returns
552
+ -------
553
+ epoch_ : np.ndarray
554
+ array of epoched ieeg data with shape (epochs,samples,channels,features)
555
+ y_arr : np.ndarray
556
+ array of epoched event label data with shape (epochs,samples)
557
+ """
558
+
559
+ epoch_lim = int(epoch_len * sfreq)
560
+
561
+ ind_mov = np.where(np.diff(np.array(y_ > threshold) * 1) == 1)[0]
562
+
563
+ low_limit = ind_mov > epoch_lim / 2
564
+ up_limit = ind_mov < y_.shape[0] - epoch_lim / 2
565
+
566
+ ind_mov = ind_mov[low_limit & up_limit]
567
+
568
+ epoch_ = np.zeros([ind_mov.shape[0], epoch_lim, data.shape[1], data.shape[2]])
569
+
570
+ y_arr = np.zeros([ind_mov.shape[0], int(epoch_lim)])
571
+
572
+ for idx, i in enumerate(ind_mov):
573
+ epoch_[idx, :, :, :] = data[i - epoch_lim // 2 : i + epoch_lim // 2, :, :]
574
+
575
+ y_arr[idx, :] = y_[i - epoch_lim // 2 : i + epoch_lim // 2]
576
+
577
+ return epoch_, y_arr
578
+
579
def set_decoder(
    self,
    decoder: Decoder | None = None,
    TRAIN_VAL_SPLIT=False,
    RUN_BAY_OPT=False,
    save_coef=False,
    model=LogisticRegression,
    eval_method=r2_score,
    cv_method=KFold(n_splits=3, shuffle=False),
    get_movement_detection_rate: bool = False,
    mov_detection_threshold=0.5,
    min_consequent_count=3,
    threshold_score=True,
    bay_opt_param_space: list | None = None,
    STACK_FEATURES_N_SAMPLES=False,
    time_stack_n_samples=5,
    use_nested_cv=False,
    VERBOSE=False,
    undersampling=False,
    oversampling=False,
    mrmr_select=False,
    pca=False,
    cca=False,
):
    """Set ``self.decoder``.

    If ``decoder`` is given, it is stored as-is and all other arguments are
    ignored. Otherwise a new ``Decoder`` is built from ``self.feature_arr``,
    ``self.label``, ``self.label_name``, ``self.used_chs`` and ``self.sfreq``,
    together with the remaining keyword arguments, which are forwarded
    unchanged.

    Parameters
    ----------
    decoder : Decoder | None, optional
        Pre-built decoder to use directly, by default None.
    bay_opt_param_space : list | None, optional
        Bayesian-optimisation search space. None means an empty list. A
        fresh list is created per call so decoders never share one default
        list object.
    (all other parameters are passed through to ``Decoder`` verbatim)
    """
    if decoder is not None:
        self.decoder = decoder
        return

    if bay_opt_param_space is None:
        bay_opt_param_space = []

    # NOTE(review): the default cv_method instance is created once at
    # definition time and shared between calls; fine as long as Decoder does
    # not mutate it.
    self.decoder = Decoder(
        features=self.feature_arr,
        label=self.label,
        label_name=self.label_name,
        used_chs=self.used_chs,
        model=model,
        eval_method=eval_method,
        cv_method=cv_method,
        threshold_score=threshold_score,
        TRAIN_VAL_SPLIT=TRAIN_VAL_SPLIT,
        RUN_BAY_OPT=RUN_BAY_OPT,
        save_coef=save_coef,
        get_movement_detection_rate=get_movement_detection_rate,
        min_consequent_count=min_consequent_count,
        mov_detection_threshold=mov_detection_threshold,
        bay_opt_param_space=bay_opt_param_space,
        STACK_FEATURES_N_SAMPLES=STACK_FEATURES_N_SAMPLES,
        time_stack_n_samples=time_stack_n_samples,
        VERBOSE=VERBOSE,
        use_nested_cv=use_nested_cv,
        undersampling=undersampling,
        oversampling=oversampling,
        mrmr_select=mrmr_select,
        sfreq=self.sfreq,
        pca=pca,
        cca=cca,
    )

+ def run_ML_model(
635
+ self,
636
+ feature_file: str | None = None,
637
+ estimate_gridpoints: bool = False,
638
+ estimate_channels: bool = True,
639
+ estimate_all_channels_combined: bool = False,
640
+ output_name: str = "LM",
641
+ save_results: bool = True,
642
+ ):
643
+ """machine learning model evaluation for ECoG strip channels and/or grid points
644
+
645
+ Parameters
646
+ ----------
647
+ feature_file : string, optional
648
+ estimate_gridpoints : bool, optional
649
+ run ML analysis for grid points, by default True
650
+ estimate_channels : bool, optional
651
+ run ML analysis for ECoG strip channel, by default True
652
+ estimate_all_channels_combined : bool, optional
653
+ run ML analysis features of all channels concatenated, by default False
654
+ model : sklearn model, optional
655
+ ML model, needs to obtain fit and predict functions,
656
+ by default linear_model.LogisticRegression(class_weight="balanced")
657
+ eval_method : sklearn.metrics, optional
658
+ evaluation performance metric, by default metrics.balanced_accuracy_score
659
+ cv_method : sklearn.model_selection, optional
660
+ valdation strategy, by default model_selection.KFold(n_splits=3, shuffle=False)
661
+ output_name : str, optional
662
+ saving name, by default "LM"
663
+ save_results : boolean
664
+ if true, save model._coef trained coefficients
665
+ """
666
+ if feature_file is None:
667
+ feature_file = self.feature_file
668
+
669
+ if estimate_gridpoints:
670
+ self.decoder.set_data_grid_points()
671
+ _ = self.decoder.run_CV_caller("grid_points")
672
+ if estimate_channels:
673
+ self.decoder.set_data_ind_channels()
674
+ _ = self.decoder.run_CV_caller("ind_channels")
675
+ if estimate_all_channels_combined:
676
+ _ = self.decoder.run_CV_caller("all_channels_combined")
677
+
678
+ if save_results:
679
+ self.decoder.save(
680
+ self.feature_dir,
681
+ self.feature_file
682
+ if ".vhdr" in self.feature_file
683
+ else self.feature_file,
684
+ output_name,
685
+ )
686
+
687
+ return self.read_results(
688
+ read_grid_points=estimate_gridpoints,
689
+ read_all_combined=estimate_all_channels_combined,
690
+ read_channels=estimate_channels,
691
+ ML_model_name=output_name,
692
+ read_mov_detection_rates=self.decoder.get_movement_detection_rate,
693
+ read_bay_opt_params=self.decoder.RUN_BAY_OPT,
694
+ read_mrmr=self.decoder.mrmr_select,
695
+ model_save=self.decoder.model_save,
696
+ )
697
+
698
+ def read_results(
699
+ self,
700
+ performance_dict: dict = {},
701
+ subject_name: str | None = None,
702
+ DEFAULT_PERFORMANCE: float = 0.5,
703
+ read_grid_points: bool = True,
704
+ read_channels: bool = True,
705
+ read_all_combined: bool = False,
706
+ ML_model_name: str = "LM",
707
+ read_mov_detection_rates: bool = False,
708
+ read_bay_opt_params: bool = False,
709
+ read_mrmr: bool = False,
710
+ model_save: bool = False,
711
+ save_results: bool = False,
712
+ PATH_OUT: str = "", # Removed None default, save_general_dict does not handle None anyway
713
+ folder_name: str = "",
714
+ str_add: str = "",
715
+ ):
716
+ """Save performances of a given patient into performance_dict from saved nm_decoder
717
+
718
+ Parameters
719
+ ----------
720
+ performance_dict : dictionary
721
+ dictionary including decoding performances, by default dictionary
722
+ subject_name : string, optional
723
+ subject name, by default None
724
+ DEFAULT_PERFORMANCE : float, optional
725
+ chance performance, by default 0.5
726
+ read_grid_points : bool, optional
727
+ true if grid point performances are read, by default True
728
+ read_channels : bool, optional
729
+ true if channels performances are read, by default True
730
+ read_all_combined : bool, optional
731
+ true if all combined channel performances are read, by default False
732
+ ML_model_name : str, optional
733
+ machine learning model name, by default 'LM'
734
+ read_mov_detection_rates : boolean, by defaulte False
735
+ if True, read movement detection rates, as well as fpr's and tpr's
736
+ read_bay_opt_params : boolean, by default False
737
+ read_mrmr : boolean, by default False
738
+ model_save : boolean, by default False
739
+ save_results : boolean, by default False
740
+ PATH_OUT : string, by default None
741
+ folder_name : string, by default None
742
+ str_add : string, by default None
743
+
744
+ Returns
745
+ -------
746
+ performance_dict : dictionary
747
+
748
+ """
749
+
750
+ if ".vhdr" in self.feature_file:
751
+ feature_file = self.feature_file[: -len(".vhdr")]
752
+ else:
753
+ feature_file = self.feature_file
754
+
755
+ if subject_name is None:
756
+ subject_name = feature_file[
757
+ feature_file.find("sub-") : feature_file.find("_ses")
758
+ ][4:]
759
+
760
+ PATH_ML_ = PurePath(
761
+ self.feature_dir,
762
+ feature_file,
763
+ feature_file + "_" + ML_model_name + "_ML_RES.p",
764
+ )
765
+
766
+ # read ML results
767
+ with open(PATH_ML_, "rb") as input:
768
+ ML_res = pickle.load(input)
769
+ if self.decoder is None:
770
+ self.decoder = ML_res
771
+
772
+ performance_dict[subject_name] = {}
773
+
774
+ def write_CV_res_in_performance_dict(
775
+ obj_read,
776
+ obj_write,
777
+ read_mov_detection_rates=read_mov_detection_rates,
778
+ read_bay_opt_params=False,
779
+ ):
780
+ def transform_list_of_dicts_into_dict_of_lists(l_):
781
+ dict_out = {}
782
+ for key_, _ in l_[0].items():
783
+ key_l = []
784
+ for dict_ in l_:
785
+ key_l.append(dict_[key_])
786
+ dict_out[key_] = key_l
787
+ return dict_out
788
+
789
+ def read_ML_performances(
790
+ obj_read, obj_write, set_inner_CV_res: bool = False
791
+ ):
792
+ def set_score(
793
+ key_set: str = "",
794
+ key_get: str = "",
795
+ take_mean: bool = True,
796
+ val=None,
797
+ ):
798
+ if set_inner_CV_res:
799
+ key_set = "InnerCV_" + key_set
800
+ key_get = "InnerCV_" + key_get
801
+ if take_mean:
802
+ val = np.mean(obj_read[key_get])
803
+ obj_write[key_set] = val
804
+
805
+ set_score(
806
+ key_set="performance_test",
807
+ key_get="score_test",
808
+ take_mean=True,
809
+ )
810
+ set_score(
811
+ key_set="performance_train",
812
+ key_get="score_train",
813
+ take_mean=True,
814
+ )
815
+
816
+ if "coef" in obj_read:
817
+ set_score(
818
+ key_set="coef",
819
+ key_get="coef",
820
+ take_mean=False,
821
+ val=np.concatenate(obj_read["coef"]),
822
+ )
823
+
824
+ if read_mov_detection_rates:
825
+ set_score(
826
+ key_set="mov_detection_rates_test",
827
+ key_get="mov_detection_rates_test",
828
+ take_mean=True,
829
+ )
830
+ set_score(
831
+ key_set="mov_detection_rates_train",
832
+ key_get="mov_detection_rates_train",
833
+ take_mean=True,
834
+ )
835
+ set_score(
836
+ key_set="fprate_test",
837
+ key_get="fprate_test",
838
+ take_mean=True,
839
+ )
840
+ set_score(
841
+ key_set="fprate_train",
842
+ key_get="fprate_train",
843
+ take_mean=True,
844
+ )
845
+ set_score(
846
+ key_set="tprate_test",
847
+ key_get="tprate_test",
848
+ take_mean=True,
849
+ )
850
+ set_score(
851
+ key_set="tprate_train",
852
+ key_get="tprate_train",
853
+ take_mean=True,
854
+ )
855
+
856
+ if read_bay_opt_params:
857
+ # transform dict into keys for json saving
858
+ dict_to_save = transform_list_of_dicts_into_dict_of_lists(
859
+ obj_read["best_bay_opt_params"]
860
+ )
861
+ set_score(
862
+ key_set="bay_opt_best_params",
863
+ take_mean=False,
864
+ val=dict_to_save,
865
+ )
866
+
867
+ if read_mrmr:
868
+ # transform dict into keys for json saving
869
+
870
+ set_score(
871
+ key_set="mrmr_select",
872
+ take_mean=False,
873
+ val=obj_read["mrmr_select"],
874
+ )
875
+ if model_save:
876
+ set_score(
877
+ key_set="model_save",
878
+ take_mean=False,
879
+ val=obj_read["model_save"],
880
+ )
881
+
882
+ read_ML_performances(obj_read, obj_write)
883
+
884
+ if len([key_ for key_ in obj_read.keys() if "InnerCV_" in key_]) > 0:
885
+ read_ML_performances(obj_read, obj_write, set_inner_CV_res=True)
886
+
887
+ if read_channels:
888
+ ch_to_use = self.ch_names_ECOG
889
+ ch_to_use = self.decoder.used_chs
890
+ for ch in ch_to_use:
891
+ performance_dict[subject_name][ch] = {}
892
+
893
+ if "coords" in self.sidecar:
894
+ if len(self.sidecar["coords"]) > 0: # check if coords are empty
895
+ coords_exist = False
896
+ for cortex_loc in self.sidecar["coords"].keys():
897
+ for ch_name_coord_idx, ch_name_coord in enumerate(
898
+ self.sidecar["coords"][cortex_loc]["ch_names"]
899
+ ):
900
+ if ch.startswith(ch_name_coord):
901
+ coords = self.sidecar["coords"][cortex_loc][
902
+ "positions"
903
+ ][ch_name_coord_idx]
904
+ coords_exist = (
905
+ True # optimally break out of the two loops...
906
+ )
907
+ if not coords_exist:
908
+ coords = None
909
+ performance_dict[subject_name][ch]["coord"] = coords
910
+ write_CV_res_in_performance_dict(
911
+ ML_res.ch_ind_results[ch],
912
+ performance_dict[subject_name][ch],
913
+ read_mov_detection_rates=read_mov_detection_rates,
914
+ read_bay_opt_params=read_bay_opt_params,
915
+ )
916
+
917
+ if read_all_combined:
918
+ performance_dict[subject_name]["all_ch_combined"] = {}
919
+ write_CV_res_in_performance_dict(
920
+ ML_res.all_ch_results,
921
+ performance_dict[subject_name]["all_ch_combined"],
922
+ read_mov_detection_rates=read_mov_detection_rates,
923
+ read_bay_opt_params=read_bay_opt_params,
924
+ )
925
+
926
+ if read_grid_points:
927
+ performance_dict[subject_name]["active_gridpoints"] = (
928
+ ML_res.active_gridpoints
929
+ )
930
+
931
+ for project_settings, grid_type in zip(
932
+ ["project_cortex", "project_subcortex"],
933
+ ["gridcortex_", "gridsubcortex_"],
934
+ ):
935
+ if not self.settings.postprocessing[project_settings]:
936
+ continue
937
+
938
+ # the sidecar keys are grid_cortex and subcortex_grid
939
+ for grid_point in range(
940
+ len(self.sidecar["grid_" + project_settings.split("_")[1]])
941
+ ):
942
+ gp_str = grid_type + str(grid_point)
943
+
944
+ performance_dict[subject_name][gp_str] = {}
945
+ performance_dict[subject_name][gp_str]["coord"] = self.sidecar[
946
+ "grid_" + project_settings.split("_")[1]
947
+ ][grid_point]
948
+
949
+ if gp_str in ML_res.active_gridpoints:
950
+ write_CV_res_in_performance_dict(
951
+ ML_res.gridpoint_ind_results[gp_str],
952
+ performance_dict[subject_name][gp_str],
953
+ read_mov_detection_rates=read_mov_detection_rates,
954
+ read_bay_opt_params=read_bay_opt_params,
955
+ )
956
+ else:
957
+ # set non interpolated grid point to default performance
958
+ performance_dict[subject_name][gp_str]["performance_test"] = (
959
+ DEFAULT_PERFORMANCE
960
+ )
961
+ performance_dict[subject_name][gp_str]["performance_train"] = (
962
+ DEFAULT_PERFORMANCE
963
+ )
964
+
965
+ if save_results:
966
+ nm_IO.save_general_dict(
967
+ dict_=performance_dict,
968
+ path_out=PATH_OUT,
969
+ prefix=folder_name,
970
+ str_add=str_add,
971
+ )
972
+ return performance_dict
973
+
974
+ @staticmethod
975
+ def get_dataframe_performances(p: dict) -> "pd.DataFrame":
976
+ performances = []
977
+ for sub in p.keys():
978
+ for ch in p[sub].keys():
979
+ if "active_gridpoints" in ch:
980
+ continue
981
+ dict_add = p[sub][ch].copy()
982
+ dict_add["sub"] = sub
983
+ dict_add["ch"] = ch
984
+
985
+ if "all_ch_" in ch:
986
+ dict_add["ch_type"] = "all ch combinded"
987
+ elif "gridcortex" in ch:
988
+ dict_add["ch_type"] = "cortex grid"
989
+ else:
990
+ dict_add["ch_type"] = "electrode ch"
991
+ performances.append(dict_add)
992
+
993
+ return pd.DataFrame(performances)