py-neuromodulation 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (233)
  1. py_neuromodulation/ConnectivityDecoding/Automated Anatomical Labeling 3 (Rolls 2020).nii +0 -0
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -0
  3. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -0
  4. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -0
  5. py_neuromodulation/ConnectivityDecoding/mni_coords_cortical_surface.mat +0 -0
  6. py_neuromodulation/ConnectivityDecoding/mni_coords_whole_brain.mat +0 -0
  7. py_neuromodulation/ConnectivityDecoding/rmap_func_all.nii +0 -0
  8. py_neuromodulation/ConnectivityDecoding/rmap_struc.nii +0 -0
  9. py_neuromodulation/FieldTrip.py +589 -589
  10. py_neuromodulation/__init__.py +74 -13
  11. py_neuromodulation/_write_example_dataset_helper.py +83 -65
  12. py_neuromodulation/data/README +6 -0
  13. py_neuromodulation/data/dataset_description.json +8 -0
  14. py_neuromodulation/data/participants.json +32 -0
  15. py_neuromodulation/data/participants.tsv +2 -0
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -0
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -0
  18. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -0
  19. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.eeg +0 -0
  20. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -0
  21. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -0
  22. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -0
  23. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -0
  24. py_neuromodulation/grid_cortex.tsv +40 -0
  25. py_neuromodulation/grid_subcortex.tsv +1429 -0
  26. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  27. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  28. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  29. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  30. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  31. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  32. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  33. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  34. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  35. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  36. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  37. py_neuromodulation/nm_IO.py +413 -417
  38. py_neuromodulation/nm_RMAP.py +496 -531
  39. py_neuromodulation/nm_analysis.py +993 -1074
  40. py_neuromodulation/nm_artifacts.py +30 -25
  41. py_neuromodulation/nm_bispectra.py +154 -168
  42. py_neuromodulation/nm_bursts.py +292 -198
  43. py_neuromodulation/nm_coherence.py +251 -205
  44. py_neuromodulation/nm_database.py +149 -0
  45. py_neuromodulation/nm_decode.py +918 -992
  46. py_neuromodulation/nm_define_nmchannels.py +300 -302
  47. py_neuromodulation/nm_features.py +144 -116
  48. py_neuromodulation/nm_filter.py +219 -219
  49. py_neuromodulation/nm_filter_preprocessing.py +79 -91
  50. py_neuromodulation/nm_fooof.py +139 -159
  51. py_neuromodulation/nm_generator.py +45 -37
  52. py_neuromodulation/nm_hjorth_raw.py +52 -73
  53. py_neuromodulation/nm_kalmanfilter.py +71 -58
  54. py_neuromodulation/nm_linelength.py +21 -33
  55. py_neuromodulation/nm_logger.py +66 -0
  56. py_neuromodulation/nm_mne_connectivity.py +149 -112
  57. py_neuromodulation/nm_mnelsl_generator.py +90 -0
  58. py_neuromodulation/nm_mnelsl_stream.py +116 -0
  59. py_neuromodulation/nm_nolds.py +96 -93
  60. py_neuromodulation/nm_normalization.py +173 -214
  61. py_neuromodulation/nm_oscillatory.py +423 -448
  62. py_neuromodulation/nm_plots.py +585 -612
  63. py_neuromodulation/nm_preprocessing.py +83 -0
  64. py_neuromodulation/nm_projection.py +370 -394
  65. py_neuromodulation/nm_rereference.py +97 -95
  66. py_neuromodulation/nm_resample.py +59 -50
  67. py_neuromodulation/nm_run_analysis.py +325 -435
  68. py_neuromodulation/nm_settings.py +289 -68
  69. py_neuromodulation/nm_settings.yaml +244 -0
  70. py_neuromodulation/nm_sharpwaves.py +423 -401
  71. py_neuromodulation/nm_stats.py +464 -480
  72. py_neuromodulation/nm_stream.py +398 -0
  73. py_neuromodulation/nm_stream_abc.py +166 -218
  74. py_neuromodulation/nm_types.py +193 -0
  75. py_neuromodulation/plots/STN_surf.mat +0 -0
  76. py_neuromodulation/plots/Vertices.mat +0 -0
  77. py_neuromodulation/plots/faces.mat +0 -0
  78. py_neuromodulation/plots/grid.mat +0 -0
  79. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +185 -182
  80. py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
  81. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -2
  82. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info/licenses}/LICENSE +21 -21
  83. docs/build/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -68
  84. docs/build/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -233
  85. docs/build/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  86. docs/build/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -97
  87. docs/build/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  88. docs/build/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -192
  89. docs/build/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  90. docs/build/html/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -68
  91. docs/build/html/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -239
  92. docs/build/html/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  93. docs/build/html/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -97
  94. docs/build/html/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  95. docs/build/html/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -192
  96. docs/build/html/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  97. docs/source/_build/html/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -76
  98. docs/source/_build/html/_downloads/0d0d0a76e8f648d5d3cbc47da6351932/plot_real_time_demo.py +0 -97
  99. docs/source/_build/html/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -240
  100. docs/source/_build/html/_downloads/5d73cadc59a8805c47e3b84063afc157/plot_example_BIDS.py +0 -233
  101. docs/source/_build/html/_downloads/7660317fa5a6bfbd12fcca9961457fc4/plot_example_rmap_computing.py +0 -63
  102. docs/source/_build/html/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  103. docs/source/_build/html/_downloads/839e5b319379f7fd9e867deb00fd797f/plot_example_gridPointProjection.py +0 -210
  104. docs/source/_build/html/_downloads/ae8be19afe5e559f011fc9b138968ba0/plot_first_demo.py +0 -192
  105. docs/source/_build/html/_downloads/b8b06cacc17969d3725a0b6f1d7741c5/plot_example_sharpwave_analysis.py +0 -219
  106. docs/source/_build/html/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -121
  107. docs/source/_build/html/_downloads/c31a86c0b68cb4167d968091ace8080d/plot_example_add_feature.py +0 -68
  108. docs/source/_build/html/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  109. docs/source/_build/html/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -189
  110. docs/source/_build/html/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  111. docs/source/auto_examples/plot_0_first_demo.py +0 -189
  112. docs/source/auto_examples/plot_1_example_BIDS.py +0 -240
  113. docs/source/auto_examples/plot_2_example_add_feature.py +0 -76
  114. docs/source/auto_examples/plot_3_example_sharpwave_analysis.py +0 -219
  115. docs/source/auto_examples/plot_4_example_gridPointProjection.py +0 -210
  116. docs/source/auto_examples/plot_5_example_rmap_computing.py +0 -64
  117. docs/source/auto_examples/plot_6_real_time_demo.py +0 -121
  118. docs/source/conf.py +0 -105
  119. examples/plot_0_first_demo.py +0 -189
  120. examples/plot_1_example_BIDS.py +0 -240
  121. examples/plot_2_example_add_feature.py +0 -76
  122. examples/plot_3_example_sharpwave_analysis.py +0 -219
  123. examples/plot_4_example_gridPointProjection.py +0 -210
  124. examples/plot_5_example_rmap_computing.py +0 -64
  125. examples/plot_6_real_time_demo.py +0 -121
  126. packages/realtime_decoding/build/lib/realtime_decoding/__init__.py +0 -4
  127. packages/realtime_decoding/build/lib/realtime_decoding/decoder.py +0 -104
  128. packages/realtime_decoding/build/lib/realtime_decoding/features.py +0 -163
  129. packages/realtime_decoding/build/lib/realtime_decoding/helpers.py +0 -15
  130. packages/realtime_decoding/build/lib/realtime_decoding/run_decoding.py +0 -345
  131. packages/realtime_decoding/build/lib/realtime_decoding/trainer.py +0 -54
  132. packages/tmsi/build/lib/TMSiFileFormats/__init__.py +0 -37
  133. packages/tmsi/build/lib/TMSiFileFormats/file_formats/__init__.py +0 -36
  134. packages/tmsi/build/lib/TMSiFileFormats/file_formats/lsl_stream_writer.py +0 -200
  135. packages/tmsi/build/lib/TMSiFileFormats/file_formats/poly5_file_writer.py +0 -496
  136. packages/tmsi/build/lib/TMSiFileFormats/file_formats/poly5_to_edf_converter.py +0 -236
  137. packages/tmsi/build/lib/TMSiFileFormats/file_formats/xdf_file_writer.py +0 -977
  138. packages/tmsi/build/lib/TMSiFileFormats/file_readers/__init__.py +0 -35
  139. packages/tmsi/build/lib/TMSiFileFormats/file_readers/edf_reader.py +0 -116
  140. packages/tmsi/build/lib/TMSiFileFormats/file_readers/poly5reader.py +0 -294
  141. packages/tmsi/build/lib/TMSiFileFormats/file_readers/xdf_reader.py +0 -229
  142. packages/tmsi/build/lib/TMSiFileFormats/file_writer.py +0 -102
  143. packages/tmsi/build/lib/TMSiPlotters/__init__.py +0 -2
  144. packages/tmsi/build/lib/TMSiPlotters/gui/__init__.py +0 -39
  145. packages/tmsi/build/lib/TMSiPlotters/gui/_plotter_gui.py +0 -234
  146. packages/tmsi/build/lib/TMSiPlotters/gui/plotting_gui.py +0 -440
  147. packages/tmsi/build/lib/TMSiPlotters/plotters/__init__.py +0 -44
  148. packages/tmsi/build/lib/TMSiPlotters/plotters/hd_emg_plotter.py +0 -446
  149. packages/tmsi/build/lib/TMSiPlotters/plotters/impedance_plotter.py +0 -589
  150. packages/tmsi/build/lib/TMSiPlotters/plotters/signal_plotter.py +0 -1326
  151. packages/tmsi/build/lib/TMSiSDK/__init__.py +0 -54
  152. packages/tmsi/build/lib/TMSiSDK/device.py +0 -588
  153. packages/tmsi/build/lib/TMSiSDK/devices/__init__.py +0 -34
  154. packages/tmsi/build/lib/TMSiSDK/devices/saga/TMSi_Device_API.py +0 -1764
  155. packages/tmsi/build/lib/TMSiSDK/devices/saga/__init__.py +0 -34
  156. packages/tmsi/build/lib/TMSiSDK/devices/saga/saga_device.py +0 -1366
  157. packages/tmsi/build/lib/TMSiSDK/devices/saga/saga_types.py +0 -520
  158. packages/tmsi/build/lib/TMSiSDK/devices/saga/xml_saga_config.py +0 -165
  159. packages/tmsi/build/lib/TMSiSDK/error.py +0 -95
  160. packages/tmsi/build/lib/TMSiSDK/sample_data.py +0 -63
  161. packages/tmsi/build/lib/TMSiSDK/sample_data_server.py +0 -99
  162. packages/tmsi/build/lib/TMSiSDK/settings.py +0 -45
  163. packages/tmsi/build/lib/TMSiSDK/tmsi_device.py +0 -111
  164. packages/tmsi/build/lib/__init__.py +0 -4
  165. packages/tmsi/build/lib/apex_sdk/__init__.py +0 -34
  166. packages/tmsi/build/lib/apex_sdk/device/__init__.py +0 -41
  167. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API.py +0 -1009
  168. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API_enums.py +0 -239
  169. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API_structures.py +0 -668
  170. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_device.py +0 -1611
  171. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_dongle.py +0 -38
  172. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_event_reader.py +0 -57
  173. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_channel.py +0 -44
  174. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_config.py +0 -150
  175. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_const.py +0 -36
  176. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_impedance_channel.py +0 -48
  177. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_info.py +0 -108
  178. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/dongle_info.py +0 -39
  179. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/download_measurement.py +0 -77
  180. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/eeg_measurement.py +0 -150
  181. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/impedance_measurement.py +0 -129
  182. packages/tmsi/build/lib/apex_sdk/device/threads/conversion_thread.py +0 -59
  183. packages/tmsi/build/lib/apex_sdk/device/threads/sampling_thread.py +0 -57
  184. packages/tmsi/build/lib/apex_sdk/device/tmsi_channel.py +0 -83
  185. packages/tmsi/build/lib/apex_sdk/device/tmsi_device.py +0 -201
  186. packages/tmsi/build/lib/apex_sdk/device/tmsi_device_enums.py +0 -103
  187. packages/tmsi/build/lib/apex_sdk/device/tmsi_dongle.py +0 -43
  188. packages/tmsi/build/lib/apex_sdk/device/tmsi_event_reader.py +0 -50
  189. packages/tmsi/build/lib/apex_sdk/device/tmsi_measurement.py +0 -118
  190. packages/tmsi/build/lib/apex_sdk/sample_data_server/__init__.py +0 -33
  191. packages/tmsi/build/lib/apex_sdk/sample_data_server/event_data.py +0 -44
  192. packages/tmsi/build/lib/apex_sdk/sample_data_server/sample_data.py +0 -50
  193. packages/tmsi/build/lib/apex_sdk/sample_data_server/sample_data_server.py +0 -136
  194. packages/tmsi/build/lib/apex_sdk/tmsi_errors/error.py +0 -126
  195. packages/tmsi/build/lib/apex_sdk/tmsi_sdk.py +0 -113
  196. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/apex/apex_structure_generator.py +0 -134
  197. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/decorators.py +0 -60
  198. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/logger_filter.py +0 -42
  199. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/singleton.py +0 -42
  200. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/support_functions.py +0 -72
  201. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/tmsi_logger.py +0 -98
  202. py_neuromodulation/nm_EpochStream.py +0 -92
  203. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  204. py_neuromodulation/nm_cohortwrapper.py +0 -435
  205. py_neuromodulation/nm_eval_timing.py +0 -239
  206. py_neuromodulation/nm_features_abc.py +0 -39
  207. py_neuromodulation/nm_stream_offline.py +0 -358
  208. py_neuromodulation/utils/_logging.py +0 -24
  209. py_neuromodulation-0.0.3.dist-info/RECORD +0 -188
  210. py_neuromodulation-0.0.3.dist-info/top_level.txt +0 -5
  211. tests/__init__.py +0 -0
  212. tests/conftest.py +0 -117
  213. tests/test_all_examples.py +0 -10
  214. tests/test_all_features.py +0 -63
  215. tests/test_bispectra.py +0 -70
  216. tests/test_bursts.py +0 -105
  217. tests/test_feature_sampling_rates.py +0 -143
  218. tests/test_fooof.py +0 -16
  219. tests/test_initalization_offline_stream.py +0 -41
  220. tests/test_multiprocessing.py +0 -58
  221. tests/test_nan_values.py +0 -29
  222. tests/test_nm_filter.py +0 -95
  223. tests/test_nm_resample.py +0 -63
  224. tests/test_normalization_settings.py +0 -146
  225. tests/test_notch_filter.py +0 -31
  226. tests/test_osc_features.py +0 -424
  227. tests/test_preprocessing_filter.py +0 -151
  228. tests/test_rereference.py +0 -171
  229. tests/test_sampling.py +0 -57
  230. tests/test_settings_change_after_init.py +0 -76
  231. tests/test_sharpwave.py +0 -165
  232. tests/test_target_channel_add.py +0 -100
  233. tests/test_timing.py +0 -80
@@ -1,612 +1,585 @@
1
- from scipy import stats
2
- import os
3
- import numpy as np
4
- from matplotlib import pyplot as plt
5
- from matplotlib import gridspec
6
- from typing import Optional
7
- import seaborn as sb
8
- import pandas as pd
9
- import logging
10
-
11
- logger = logging.getLogger("PynmLogger")
12
-
13
- from py_neuromodulation import nm_IO, nm_stats
14
-
15
-
16
- def plot_df_subjects(
17
- df,
18
- x_col="sub",
19
- y_col="performance_test",
20
- hue=None,
21
- title="channel specific performances",
22
- PATH_SAVE: str = None,
23
- figsize_tuple: tuple = (5, 3),
24
- ):
25
- alpha_box = 0.4
26
- plt.figure(figsize=figsize_tuple, dpi=300)
27
- sb.boxplot(
28
- x=x_col,
29
- y=y_col,
30
- hue=hue,
31
- data=df,
32
- palette="viridis",
33
- showmeans=False,
34
- boxprops=dict(alpha=alpha_box),
35
- showcaps=True,
36
- showbox=True,
37
- showfliers=False,
38
- notch=False,
39
- whiskerprops={"linewidth": 2, "zorder": 10, "alpha": alpha_box},
40
- capprops={"alpha": alpha_box},
41
- medianprops=dict(
42
- linestyle="-", linewidth=5, color="gray", alpha=alpha_box
43
- ),
44
- )
45
-
46
- ax = sb.stripplot(
47
- x=x_col,
48
- y=y_col,
49
- hue=hue,
50
- data=df,
51
- palette="viridis",
52
- dodge=True,
53
- s=5,
54
- )
55
-
56
- if hue is not None:
57
- n_hues = df[hue].nunique()
58
-
59
- handles, labels = ax.get_legend_handles_labels()
60
- l = plt.legend(
61
- handles[0:n_hues],
62
- labels[0:n_hues],
63
- bbox_to_anchor=(1.05, 1),
64
- loc=2,
65
- title=hue,
66
- borderaxespad=0.0,
67
- )
68
- plt.title(title)
69
- plt.ylabel(y_col)
70
- plt.xticks(rotation=90)
71
- if PATH_SAVE is not None:
72
- plt.savefig(
73
- PATH_SAVE,
74
- bbox_inches="tight",
75
- )
76
- # plt.show()
77
- return plt.gca()
78
-
79
-
80
- def plot_epoch(
81
- X_epoch: np.array,
82
- y_epoch: np.array,
83
- feature_names: list,
84
- z_score: bool = None,
85
- epoch_len: int = 4,
86
- sfreq: int = 10,
87
- str_title: str = None,
88
- str_label: str = None,
89
- ytick_labelsize: float = None,
90
- ):
91
- if z_score is None:
92
- X_epoch = stats.zscore(
93
- np.nan_to_num(np.nanmean(np.squeeze(X_epoch), axis=0)),
94
- axis=0,
95
- nan_policy="omit",
96
- ).T
97
- y_epoch = np.stack(np.array(y_epoch))
98
- plt.figure(figsize=(6, 6))
99
- plt.subplot(211)
100
- plt.imshow(X_epoch, aspect="auto")
101
- plt.yticks(
102
- np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize
103
- )
104
- plt.xticks(
105
- np.arange(0, X_epoch.shape[1], 1),
106
- np.round(np.arange(-epoch_len / 2, epoch_len / 2, 1 / sfreq), 2),
107
- rotation=90,
108
- )
109
- plt.gca().invert_yaxis()
110
- plt.xlabel("Time [s]")
111
- plt.title(str_title)
112
-
113
- plt.subplot(212)
114
- for i in range(y_epoch.shape[0]):
115
- plt.plot(y_epoch[i, :], color="black", alpha=0.4)
116
- plt.plot(
117
- y_epoch.mean(axis=0),
118
- color="black",
119
- alpha=1,
120
- linewidth=3.0,
121
- label="mean target",
122
- )
123
- plt.legend()
124
- plt.ylabel("Target")
125
- plt.title(str_label)
126
- plt.xticks(
127
- np.arange(0, X_epoch.shape[1], 1),
128
- np.round(np.arange(-epoch_len / 2, epoch_len / 2, 1 / sfreq), 2),
129
- rotation=90,
130
- )
131
- plt.xlabel("Time [s]")
132
- plt.tight_layout()
133
-
134
-
135
- def reg_plot(
136
- x_col: str, y_col: str, data: pd.DataFrame, out_path_save: str = None
137
- ):
138
- plt.figure(figsize=(4, 4), dpi=300)
139
- rho, p = nm_stats.permutationTestSpearmansRho(
140
- data[x_col],
141
- data[y_col],
142
- False,
143
- "R^2",
144
- 5000,
145
- )
146
- sb.regplot(x=x_col, y=y_col, data=data)
147
- plt.title(f"{y_col}~{x_col} p={np.round(p, 2)} rho={np.round(rho, 2)}")
148
-
149
- if out_path_save is not None:
150
- plt.savefig(
151
- out_path_save,
152
- bbox_inches="tight",
153
- )
154
-
155
-
156
- def plot_bar_performance_per_channel(
157
- ch_names,
158
- performances: dict,
159
- PATH_OUT: str,
160
- sub: str = None,
161
- save_str: str = "ch_comp_bar_plt.png",
162
- performance_metric: str = "Balanced Accuracy",
163
- ):
164
- """
165
- performances dict is output of ml_decode
166
- """
167
- plt.figure(figsize=(4, 3), dpi=300)
168
- if sub is None:
169
- sub = list(performances.keys())[0]
170
- plt.bar(
171
- np.arange(len(ch_names)),
172
- [performances[sub][p]["performance_test"] for p in performances[sub]],
173
- )
174
- plt.xticks(np.arange(len(ch_names)), ch_names, rotation=90)
175
- plt.xlabel("channels")
176
- plt.ylabel(performance_metric)
177
- plt.savefig(
178
- os.path.join(PATH_OUT, save_str),
179
- bbox_inches="tight",
180
- )
181
- plt.close()
182
-
183
-
184
- def plot_corr_matrix(
185
- feature: pd.DataFrame,
186
- feature_file: str = None,
187
- ch_name: str = None,
188
- feature_names: list[str] = None,
189
- show_plot=True,
190
- OUT_PATH: str = None,
191
- feature_name_plt="Features_corr_matr",
192
- save_plot: bool = False,
193
- save_plot_name: str = None,
194
- figsize: tuple[int] = (7, 7),
195
- title: str = None,
196
- cbar_vmin: float = -1,
197
- cbar_vmax: float = 1.0,
198
- ):
199
- # cut out channel name for each column
200
- if ch_name is not None:
201
- feature_col_name = [
202
- i[len(ch_name) + 1 :] for i in feature_names if ch_name in i
203
- ]
204
- else:
205
- feature_col_name = feature.columns
206
-
207
- plt.figure(figsize=figsize)
208
- if feature_names is not None:
209
- corr = feature[feature_names].corr()
210
- else:
211
- corr = feature.corr()
212
- sb.heatmap(
213
- corr,
214
- xticklabels=feature_col_name,
215
- yticklabels=feature_col_name,
216
- vmin=cbar_vmin,
217
- vmax=cbar_vmax,
218
- cmap="viridis",
219
- )
220
- if title is None:
221
- if ch_name is not None:
222
- plt.title("Correlation matrix features channel: " + str(ch_name))
223
- else:
224
- plt.title("Correlation matrix")
225
- else:
226
- plt.title(title)
227
-
228
- # if len(feature_col_name) > 50:
229
- # plt.xticks([])
230
- # plt.yticks([])
231
-
232
- if save_plot and save_plot_name is None:
233
- plt_path = get_plt_path(
234
- OUT_PATH=OUT_PATH,
235
- feature_file=feature_file,
236
- ch_name=ch_name,
237
- str_plt_type=feature_name_plt,
238
- # feature_name=feature_names.__str__, # This here raises an error in os.path.join in line 251
239
- )
240
- if save_plot and save_plot_name is not None:
241
- plt_path = os.path.join(OUT_PATH, save_plot_name)
242
-
243
- if save_plot:
244
- plt.savefig(plt_path, bbox_inches="tight")
245
- logger.info(f"Correlation matrix figure saved to {plt_path}")
246
-
247
- if show_plot is False:
248
- plt.close()
249
-
250
- plt.tight_layout()
251
-
252
- return plt.gca()
253
-
254
-
255
- def plot_feature_series_time(features) -> None:
256
- plt.imshow(features.T, aspect="auto")
257
-
258
-
259
- def get_plt_path(
260
- OUT_PATH: str | None = None,
261
- feature_file: str | None = None,
262
- ch_name: str | None = None,
263
- str_plt_type: str | None = None,
264
- feature_name: str | None = None,
265
- ) -> None:
266
- """[summary]
267
-
268
- Parameters
269
- ----------
270
- OUT_PATH : str, optional
271
- folder of preprocessed runs, by default None
272
- feature_file : str, optional
273
- run_name, by default None
274
- ch_name : str, optional
275
- ch_name, by default None
276
- str_plt_type : str, optional
277
- type of plot, e.g. mov_avg_feature or corr_matr, by default None
278
- feature_name : str, optional
279
- e.g. bandpower, stft, sharpwave_prominence, by default None
280
- """
281
- if None not in (ch_name, OUT_PATH, feature_file):
282
- if feature_name is None:
283
- plt_path = os.path.join(
284
- OUT_PATH,
285
- feature_file,
286
- str_plt_type + "_ch_" + ch_name + ".png",
287
- )
288
- else:
289
- plt_path = os.path.join(
290
- OUT_PATH,
291
- feature_file,
292
- str_plt_type + "_ch_" + ch_name + "_" + feature_name + ".png",
293
- )
294
- elif None not in (OUT_PATH, feature_file) and ch_name is None:
295
- plt_path = os.path.join(
296
- OUT_PATH,
297
- feature_file,
298
- str_plt_type + "_ch_" + feature_name + ".png",
299
- )
300
-
301
- else:
302
- plt_path = os.getcwd() + ".png"
303
- return plt_path
304
-
305
-
306
- def plot_epochs_avg(
307
- X_epoch: np.ndarray,
308
- y_epoch: np.ndarray,
309
- epoch_len: int,
310
- sfreq: int,
311
- feature_names: list[str] = None,
312
- feature_str_add: str = None,
313
- cut_ch_name_cols: bool = True,
314
- ch_name: str = None,
315
- label_name: str = None,
316
- normalize_data: bool = True,
317
- show_plot: bool = True,
318
- save: bool = False,
319
- OUT_PATH: str = None,
320
- feature_file: str = None,
321
- str_title: str = "Movement aligned features",
322
- ytick_labelsize=None,
323
- figsize_x: float = 8,
324
- figsize_y: float = 8,
325
- ) -> None:
326
- # cut channel name of for axis + "_" for more dense plot
327
- if feature_names is None:
328
- if cut_ch_name_cols and None not in (ch_name, feature_names):
329
- feature_names = [
330
- i[len(ch_name) + 1 :]
331
- for i in list(feature_names)
332
- if ch_name in i
333
- ]
334
-
335
- if normalize_data:
336
- X_epoch_mean = stats.zscore(
337
- np.nanmean(np.squeeze(X_epoch), axis=0), axis=0, nan_policy="omit"
338
- ).T
339
- else:
340
- X_epoch_mean = np.nanmean(np.squeeze(X_epoch), axis=0).T
341
-
342
- if len(X_epoch_mean.shape) == 1:
343
- X_epoch_mean = np.expand_dims(X_epoch_mean, axis=0)
344
-
345
- plt.figure(figsize=(figsize_x, figsize_y))
346
- gs = gridspec.GridSpec(2, 1, height_ratios=[2.5, 1])
347
- plt.subplot(gs[0])
348
- plt.imshow(X_epoch_mean, aspect="auto")
349
- plt.yticks(
350
- np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize
351
- )
352
- plt.xticks(
353
- np.arange(0, X_epoch.shape[1], int(X_epoch.shape[1] / 10)),
354
- np.round(np.arange(-epoch_len / 2, epoch_len / 2, epoch_len / 10), 2),
355
- rotation=90,
356
- )
357
- plt.xlabel("Time [s]")
358
- str_title = str_title
359
- if ch_name:
360
- str_title += f" channel: {ch_name}"
361
- plt.title(str_title)
362
-
363
- plt.subplot(gs[1])
364
- for i in range(y_epoch.shape[0]):
365
- plt.plot(y_epoch[i, :], color="black", alpha=0.4)
366
- plt.plot(
367
- y_epoch.mean(axis=0),
368
- color="black",
369
- alpha=1,
370
- linewidth=3.0,
371
- label="mean target",
372
- )
373
- plt.legend()
374
- plt.ylabel("Target")
375
- plt.title(label_name)
376
- plt.xticks(
377
- np.arange(0, X_epoch.shape[1], int(X_epoch.shape[1] / 10)),
378
- np.round(np.arange(-epoch_len / 2, epoch_len / 2, epoch_len / 10), 2),
379
- rotation=90,
380
- )
381
- plt.xlabel("Time [s]")
382
- plt.tight_layout()
383
-
384
- if save:
385
- plt_path = get_plt_path(
386
- OUT_PATH,
387
- feature_file,
388
- ch_name,
389
- str_plt_type="MOV_aligned_features",
390
- feature_name=feature_str_add,
391
- )
392
- plt.savefig(plt_path, bbox_inches="tight")
393
- logger.info(f"Feature epoch average figure saved to: {str(plt_path)}")
394
- if show_plot is False:
395
- plt.close()
396
-
397
-
398
- def plot_grid_elec_3d(
399
- cortex_grid: np.ndarray | None = None,
400
- ecog_strip: np.ndarray | None = None,
401
- grid_color: np.ndarray | None = None,
402
- strip_color: np.ndarray | None = None,
403
- ):
404
- ax = plt.axes(projection="3d")
405
-
406
- if cortex_grid is not None:
407
- grid_color = (
408
- np.ones(cortex_grid.shape[0]) if grid_color is None else grid_color
409
- )
410
- _ = ax.scatter3D(
411
- cortex_grid[:, 0],
412
- cortex_grid[:, 1],
413
- cortex_grid[:, 2],
414
- c=grid_color,
415
- s=300,
416
- alpha=0.8,
417
- cmap="viridis",
418
- )
419
-
420
- if ecog_strip is not None:
421
- strip_color = (
422
- np.ones(ecog_strip.shape[0]) if strip_color is None else strip_color
423
- )
424
- _ = ax.scatter(
425
- ecog_strip[:, 0],
426
- ecog_strip[:, 1],
427
- ecog_strip[:, 2],
428
- c=strip_color,
429
- s=500,
430
- alpha=0.8,
431
- cmap="gray",
432
- marker="o",
433
- )
434
-
435
-
436
- def plot_all_features(
437
- df: pd.DataFrame,
438
- time_limit_low_s: float = None,
439
- time_limit_high_s: float = None,
440
- normalize: bool = True,
441
- ytick_labelsize: int = 4,
442
- clim_low: float = None,
443
- clim_high: float = None,
444
- save: bool = False,
445
- title="all_feature_plt.pdf",
446
- OUT_PATH: str = None,
447
- feature_file: str = None,
448
- ):
449
- if time_limit_high_s is not None:
450
- df = df[df["time"] < time_limit_high_s * 1000]
451
- if time_limit_low_s is not None:
452
- df = df[df["time"] > time_limit_low_s * 1000]
453
-
454
- cols_plt = [c for c in df.columns if c != "time"]
455
- if normalize is True:
456
- data_plt = stats.zscore(df[cols_plt], nan_policy="omit")
457
- else:
458
- data_plt = df[cols_plt]
459
-
460
- plt.figure() # figsize=(7, 5), dpi=300
461
- plt.imshow(data_plt.T, aspect="auto")
462
- plt.xlabel("Time [s]")
463
- plt.ylabel("Feature Names")
464
- plt.yticks(np.arange(len(cols_plt)), cols_plt, size=ytick_labelsize)
465
-
466
- tick_num = np.arange(0, df.shape[0], int(df.shape[0] / 10))
467
- tick_labels = np.array(np.rint(df["time"].iloc[tick_num] / 1000), dtype=int)
468
- plt.xticks(tick_num, tick_labels)
469
-
470
- plt.title(f"Feature Plot {feature_file}")
471
-
472
- if clim_low is not None:
473
- plt.clim(vmin=clim_low)
474
- if clim_high is not None:
475
- plt.clim(vmax=clim_high)
476
-
477
- plt.colorbar()
478
- plt.tight_layout()
479
-
480
- if save is True:
481
- plt_path = os.path.join(OUT_PATH, feature_file, title)
482
- plt.savefig(plt_path, bbox_inches="tight")
483
-
484
-
485
class NM_Plot:
    """Plot helper holding electrode/grid coordinates and the plot template data.

    The template data (faces, vertices, grid, STN surface and the x/y/z
    coordinate arrays) is whatever ``nm_IO.read_plot_modules()`` returns; it is
    loaded once in the constructor.
    """

    def __init__(
        self,
        ecog_strip: np.ndarray | None = None,
        grid_cortex: np.ndarray | None = None,
        grid_subcortex: np.ndarray | None = None,
        sess_right: Optional[bool] = False,
        proj_matrix_cortex: np.ndarray | None = None,
    ) -> None:
        """Store the supplied coordinates/flags and load the plot template data."""
        self.grid_cortex = grid_cortex
        self.grid_subcortex = grid_subcortex
        self.ecog_strip = ecog_strip
        self.sess_right = sess_right
        self.proj_matrix_cortex = proj_matrix_cortex

        (
            self.faces,
            self.vertices,
            self.grid,
            self.stn_surf,
            self.x_ver,
            self.y_ver,
            self.x_ecog,
            self.y_ecog,
            self.z_ecog,
            self.x_stn,
            self.y_stn,
            self.z_stn,
        ) = nm_IO.read_plot_modules()

    def plot_grid_elec_3d(self) -> None:
        """Plot the stored cortical grid and ECoG strip in 3D."""
        # np.array(None) is a 0-d array, not None, and would defeat the
        # "is not None" checks of the module-level function.
        plot_grid_elec_3d(
            None if self.grid_cortex is None else np.array(self.grid_cortex),
            None if self.ecog_strip is None else np.array(self.ecog_strip),
        )

    def plot_cortex(
        self,
        grid_cortex: Optional[np.ndarray] = None,
        grid_color: Optional[np.ndarray] = None,
        ecog_strip: Optional[np.ndarray] = None,
        strip_color: Optional[np.ndarray] = None,
        sess_right: Optional[bool] = None,
        save: bool = False,
        OUT_PATH: str = None,
        feature_file: str = None,
        feature_str_add: str = None,
        show_plot: bool = True,
        title: str = "Cortical grid",
        set_clim: bool = True,
        lower_clim: float = 0.5,
        upper_clim: float = 0.7,
        cbar_label: str = "Balanced Accuracy",
    ):
        """Plot MNI brain including selected MNI cortical projection grid + used strip ECoG electrodes
        Colorcoded by grid_color

        ``grid_cortex`` and ``ecog_strip`` fall back to the values stored on the
        instance when omitted. Both are indexed as ``[:, 0]`` (x) and ``[:, 1]``
        (y), i.e. one row per point. When ``sess_right`` is truthy the x
        coordinates of the grid are mirrored on a copy, so the caller's data is
        never modified.
        """
        if grid_cortex is None:
            grid_cortex = self.grid_cortex
        if ecog_strip is None:
            ecog_strip = self.ecog_strip

        if grid_cortex is not None:
            # np.array copies and also accepts a DataFrame.
            grid_cortex = np.array(grid_cortex)
            if sess_right:
                # Mirror the x column; rows are points, columns are coordinates.
                grid_cortex[:, 0] = grid_cortex[:, 0] * -1
        if ecog_strip is not None:
            ecog_strip = np.asarray(ecog_strip)

        fig, axes = plt.subplots(1, 1, facecolor=(1, 1, 1), figsize=(14, 9))
        axes.scatter(self.x_ecog, self.y_ecog, c="gray", s=0.01)
        axes.axes.set_aspect("equal", anchor="C")

        # Last scatter drawn; it is the mappable used for the colorbar.
        pos_ecog = None

        if grid_cortex is not None:
            if grid_color is None:
                grid_color = np.ones(grid_cortex.shape[0])
            pos_ecog = axes.scatter(
                grid_cortex[:, 0],
                grid_cortex[:, 1],
                c=grid_color,
                s=150,
                alpha=0.8,
                cmap="viridis",
                label="grid points",
            )
            if set_clim:
                pos_ecog.set_clim(lower_clim, upper_clim)

        if ecog_strip is not None:
            if strip_color is None:
                strip_color = np.ones(ecog_strip.shape[0])
            pos_ecog = axes.scatter(
                ecog_strip[:, 0],
                ecog_strip[:, 1],
                c=strip_color,
                s=400,
                alpha=0.8,
                cmap="viridis",
                marker="x",
                label="ecog electrode",
            )

        plt.axis("off")
        plt.legend()
        plt.title(title)
        # Without any scatter there is nothing to attach a colorbar to.
        if set_clim and pos_ecog is not None:
            pos_ecog.set_clim(lower_clim, upper_clim)
            cbar = fig.colorbar(pos_ecog)
            cbar.set_label(cbar_label)

        if save:
            plt_path = get_plt_path(
                OUT_PATH,
                feature_file,
                ch_name=None,
                str_plt_type="PLOT_CORTEX",
                feature_name=feature_str_add,
            )
            plt.savefig(plt_path, bbox_inches="tight")
            logger.info(f"Cortex plot figure saved to: {str(plt_path)}")
        if not show_plot:
            plt.close()
1
+ import numpy as np
2
+ import pandas as pd
3
+ from scipy.stats import zscore as scipy_zscore
4
+ from matplotlib import pyplot as plt
5
+ from matplotlib import gridspec
6
+ import seaborn as sb
7
+ from pathlib import PurePath
8
+
9
+ from py_neuromodulation.nm_types import _PathLike
10
+ from py_neuromodulation import logger
11
+
12
+
13
def plot_df_subjects(
    df,
    x_col="sub",
    y_col="performance_test",
    hue=None,
    title="channel specific performances",
    PATH_SAVE: _PathLike = "",
    figsize_tuple: tuple[float, float] = (5, 3),
):
    """Draw a semi-transparent box plot of ``y_col`` per ``x_col`` with a strip plot on top.

    Parameters
    ----------
    df : data handed to seaborn as ``data``
    x_col, y_col : column names for the x and y axes
    hue : optional column name used for colour grouping
    title : figure title
    PATH_SAVE : if truthy, the figure is saved to this path
    figsize_tuple : figure size passed to ``plt.figure``

    Returns
    -------
    The current matplotlib axes (``plt.gca()``).
    """
    box_alpha = 0.4
    # Arguments shared by the box plot and the strip plot.
    shared_kwargs = {
        "x": x_col,
        "y": y_col,
        "hue": hue,
        "data": df,
        "palette": "viridis",
    }

    plt.figure(figsize=figsize_tuple, dpi=300)
    sb.boxplot(
        showmeans=False,
        boxprops={"alpha": box_alpha},
        showcaps=True,
        showbox=True,
        showfliers=False,
        notch=False,
        whiskerprops={"linewidth": 2, "zorder": 10, "alpha": box_alpha},
        capprops={"alpha": box_alpha},
        medianprops={
            "linestyle": "-",
            "linewidth": 5,
            "color": "gray",
            "alpha": box_alpha,
        },
        **shared_kwargs,
    )
    strip_axes = sb.stripplot(dodge=True, s=5, **shared_kwargs)

    if hue is not None:
        # Keep only the first entries of the legend, one per distinct hue value.
        hue_count = df[hue].nunique()
        all_handles, all_labels = strip_axes.get_legend_handles_labels()
        plt.legend(
            all_handles[:hue_count],
            all_labels[:hue_count],
            bbox_to_anchor=(1.05, 1),
            loc=2,
            title=hue,
            borderaxespad=0.0,
        )

    plt.title(title)
    plt.ylabel(y_col)
    plt.xticks(rotation=90)

    if PATH_SAVE:
        plt.savefig(PATH_SAVE, bbox_inches="tight")

    return plt.gca()
73
+
74
+
75
def plot_epoch(
    X_epoch: np.ndarray,
    y_epoch: np.ndarray,
    feature_names: list,
    z_score: bool | None = None,
    epoch_len: int = 4,
    sfreq: int = 10,
    str_title: str = "",
    str_label: str = "",
    ytick_labelsize: float | None = None,
):
    """Plot epoched features as an image above the corresponding target trace.

    Parameters
    ----------
    X_epoch : feature data; shown with ``imshow`` (after averaging over axis 0
        and z-scoring when ``z_score`` is None)
    y_epoch : target data, wrapped into a leading axis before plotting
    feature_names : labels for the image rows
    z_score : see note in the body
    epoch_len : epoch length in seconds; the time axis is centred on zero
    sfreq : sampling rate used to convert sample index to seconds
    str_title, str_label : titles of the upper and lower subplot
    ytick_labelsize : font size of the feature labels
    """
    # NOTE(review): averaging and z-scoring only run when z_score is None; an
    # explicit True/False leaves X_epoch untouched. Kept unchanged, confirm intent.
    if z_score is None:
        X_epoch = scipy_zscore(
            np.nan_to_num(np.nanmean(np.squeeze(X_epoch), axis=0)),
            axis=0,
            nan_policy="omit",
        ).T
    y_epoch = np.stack([np.array(y_epoch)])

    # Derive the labels from the tick positions so both always have the same
    # length; a float-step np.arange over the time range can be off by one.
    tick_positions = np.arange(0, X_epoch.shape[1], 1)
    tick_labels = np.round(tick_positions / sfreq - epoch_len / 2, 2)

    plt.figure(figsize=(6, 6))
    plt.subplot(211)
    plt.imshow(X_epoch, aspect="auto")
    plt.yticks(np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize)
    plt.xticks(tick_positions, tick_labels, rotation=90)
    plt.gca().invert_yaxis()
    plt.xlabel("Time [s]")
    plt.title(str_title)

    plt.subplot(212)
    for i in range(y_epoch.shape[0]):
        plt.plot(y_epoch[i, :], color="black", alpha=0.4)
    plt.plot(
        y_epoch.mean(axis=0),
        color="black",
        alpha=1,
        linewidth=3.0,
        label="mean target",
    )
    plt.legend()
    plt.ylabel("Target")
    plt.title(str_label)
    plt.xticks(tick_positions, tick_labels, rotation=90)
    plt.xlabel("Time [s]")
    plt.tight_layout()
126
+
127
+
128
def reg_plot(
    x_col: str, y_col: str, data: pd.DataFrame, out_path_save: str | None = None
):
    """Draw a seaborn regression plot of ``y_col`` against ``x_col``.

    The title reports the ``rho`` and ``p`` values returned by
    ``permutationTestSpearmansRho`` for the two columns, rounded to two
    decimals. The figure is saved when ``out_path_save`` is given.
    """
    # Imported here rather than at module level, as in the rest of this file.
    from py_neuromodulation.nm_stats import permutationTestSpearmansRho

    plt.figure(figsize=(4, 4), dpi=300)

    x_values = data[x_col]
    y_values = data[y_col]
    rho, p = permutationTestSpearmansRho(x_values, y_values, False, "R^2", 5000)

    sb.regplot(x=x_col, y=y_col, data=data)
    plot_title = f"{y_col}~{x_col} p={np.round(p, 2)} rho={np.round(rho, 2)}"
    plt.title(plot_title)

    if out_path_save is None:
        return
    plt.savefig(out_path_save, bbox_inches="tight")
150
+
151
+
152
def plot_bar_performance_per_channel(
    ch_names,
    performances: dict,
    PATH_OUT: _PathLike,
    sub: str | None = None,
    save_str: str = "ch_comp_bar_plt.png",
    performance_metric: str = "Balanced Accuracy",
):
    """Save a bar chart of the ``"performance_test"`` value of each entry of one subject.

    ``performances`` is the output dict of ml_decode. When ``sub`` is None the
    first key of ``performances`` is used. Bars are drawn in the iteration
    order of ``performances[sub]`` and labelled with ``ch_names``; the figure
    is written to ``PATH_OUT / save_str`` and then closed.
    """
    plt.figure(figsize=(4, 3), dpi=300)

    if sub is None:
        sub = list(performances)[0]

    bar_positions = np.arange(len(ch_names))
    subject_results = performances[sub]
    bar_heights = [subject_results[key]["performance_test"] for key in subject_results]
    plt.bar(bar_positions, bar_heights)

    plt.xticks(np.arange(len(ch_names)), ch_names, rotation=90)
    plt.xlabel("channels")
    plt.ylabel(performance_metric)

    target_file = PurePath(PATH_OUT, save_str)
    plt.savefig(target_file, bbox_inches="tight")
    plt.close()
178
+
179
+
180
def plot_corr_matrix(
    feature: pd.DataFrame,
    feature_file: _PathLike = "",
    ch_name: str = "",
    feature_names: list[str] = [],
    show_plot=True,
    OUT_PATH: _PathLike = "",
    feature_name_plt="Features_corr_matr",
    save_plot: bool = False,
    save_plot_name: str = "",
    figsize: tuple[float, float] = (7, 7),
    title: str = "",
    cbar_vmin: float = -1,
    cbar_vmax: float = 1.0,
):
    """Plot the correlation matrix of ``feature`` as a heatmap.

    Parameters
    ----------
    feature : DataFrame whose column correlations are plotted
    feature_file : run name, used for the default save path
    ch_name : if given, this prefix (plus one separator character) is cut from
        the tick labels of columns that start with it
    feature_names : columns to restrict the correlation to; all columns when
        empty (never modified here)
    show_plot : when False the figure is closed before returning
    OUT_PATH : output folder for saving
    feature_name_plt : plot-type prefix of the default file name
    save_plot : save the figure
    save_plot_name : explicit file name; overrides the generated one
    figsize : figure size
    title : figure title; a default is generated when empty
    cbar_vmin, cbar_vmax : colour scale limits

    Returns
    -------
    The axes the heatmap was drawn on.
    """
    # len() rather than truthiness so a pandas Index or numpy array works too.
    if feature_names is not None and len(feature_names) > 0:
        selected = list(feature_names)
    else:
        selected = []
    corr = feature[selected].corr() if selected else feature.corr()

    # Build the labels from the matrix columns so their number always matches.
    tick_labels = [
        (col[len(ch_name) + 1 :] or col)
        if ch_name and isinstance(col, str) and col.startswith(ch_name)
        else col
        for col in corr.columns
    ]

    plt.figure(figsize=figsize)
    ax = sb.heatmap(
        corr,
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        vmin=cbar_vmin,
        vmax=cbar_vmax,
        cmap="viridis",
    )

    if title:
        plt.title(title)
    elif ch_name:
        plt.title("Correlation matrix features channel: " + str(ch_name))
    else:
        plt.title("Correlation matrix")

    # Lay out before saving/closing; afterwards it would act on a new figure.
    plt.tight_layout()

    if save_plot:
        if save_plot_name:
            plt_path = PurePath(OUT_PATH, save_plot_name)
        else:
            plt_path = get_plt_path(
                OUT_PATH=OUT_PATH,
                feature_file=feature_file,
                ch_name=ch_name,
                str_plt_type=feature_name_plt,
                feature_name="_".join(map(str, selected)),
            )
        plt.savefig(plt_path, bbox_inches="tight")
        logger.info(f"Correlation matrix figure saved to {plt_path}")

    if not show_plot:
        plt.close()

    return ax
252
+
253
+
254
def plot_feature_series_time(features) -> None:
    """Show the transpose of ``features`` as an image with automatic aspect ratio."""
    transposed = features.T
    plt.imshow(transposed, aspect="auto")
256
+
257
+
258
def get_plt_path(
    OUT_PATH: _PathLike = "",
    feature_file: str = "",
    ch_name: str = "",
    str_plt_type: str = "",
    feature_name: str = "",
) -> _PathLike:
    """Build the figure path ``OUT_PATH / feature_file / <file name>.png``.

    The file name is ``str_plt_type``, followed by ``_ch_<ch_name>`` and
    ``_<feature_name>`` for each of those that is non-empty, plus ``.png``.
    ``None`` is accepted for every argument and treated like an empty string.

    Parameters
    ----------
    OUT_PATH : path-like, optional
        folder of preprocessed runs, by default ""
    feature_file : str, optional
        run name, by default ""
    ch_name : str, optional
        channel name, by default ""
    str_plt_type : str, optional
        type of plot, e.g. mov_avg_feature or corr_matr, by default ""
    feature_name : str, optional
        e.g. bandpower, stft, sharpwave_prominence, by default ""

    Returns
    -------
    PurePath
        the assembled path; nothing is created on disk
    """
    name_parts = [str_plt_type or ""]
    if ch_name:
        name_parts.append(f"_ch_{ch_name}")
    if feature_name:
        name_parts.append(f"_{feature_name}")
    filename = "".join(name_parts) + ".png"

    # PurePath drops empty segments, so missing folders simply vanish.
    return PurePath(OUT_PATH or "", feature_file or "", filename)
288
+
289
+
290
def plot_epochs_avg(
    X_epoch: np.ndarray,
    y_epoch: np.ndarray,
    epoch_len: int,
    sfreq: int,
    feature_names: list[str] = [],
    feature_str_add: str = "",
    cut_ch_name_cols: bool = True,
    ch_name: str = "",
    label_name: str = "",
    normalize_data: bool = True,
    show_plot: bool = True,
    save: bool = False,
    OUT_PATH: _PathLike = "",
    feature_file: str = "",
    str_title: str = "Movement aligned features",
    ytick_labelsize=None,
    figsize_x: float = 8,
    figsize_y: float = 8,
) -> None:
    """Plot epoch-averaged features as an image above the target traces.

    Parameters
    ----------
    X_epoch : feature epochs; squeezed and averaged over axis 0
    y_epoch : target epochs; each row is drawn plus the mean over axis 0
    epoch_len : epoch length in seconds; the time axis is centred on zero
    sfreq : accepted for interface compatibility, not used here
    feature_names : labels for the image rows (never modified here)
    feature_str_add : suffix for the saved file name
    cut_ch_name_cols : cut the ``ch_name`` prefix from the row labels
    ch_name : channel name; also appended to the title when given
    label_name : title of the target subplot
    normalize_data : z-score the averaged features along axis 0
    show_plot : when False the figure is closed at the end
    save : save the figure to the path built by ``get_plt_path``
    OUT_PATH, feature_file : folder and run name for saving
    str_title : title of the feature subplot
    ytick_labelsize : font size of the row labels
    figsize_x, figsize_y : figure size
    """
    # Work on a list copy: accepts None/Index/array and leaves the default alone.
    feature_names = [] if feature_names is None else list(feature_names)

    # cut channel name of for axis + "_" for more dense plot
    if cut_ch_name_cols and ch_name and feature_names:
        feature_names = [
            name[len(ch_name) + 1 :] for name in feature_names if ch_name in name
        ]

    X_epoch_mean = np.nanmean(np.squeeze(X_epoch), axis=0)
    if normalize_data:
        X_epoch_mean = scipy_zscore(X_epoch_mean, axis=0, nan_policy="omit")
    X_epoch_mean = X_epoch_mean.T

    if len(X_epoch_mean.shape) == 1:
        X_epoch_mean = np.expand_dims(X_epoch_mean, axis=0)

    # About ten ticks across the image width. The step is at least 1, and the
    # labels are computed from the positions so both have the same length.
    n_samples = X_epoch_mean.shape[1]
    tick_positions = np.arange(0, n_samples, max(1, n_samples // 10))
    tick_labels = np.round(
        tick_positions * epoch_len / max(1, n_samples) - epoch_len / 2, 2
    )

    plt.figure(figsize=(figsize_x, figsize_y))
    gs = gridspec.GridSpec(2, 1, height_ratios=[2.5, 1])

    plt.subplot(gs[0])
    plt.imshow(X_epoch_mean, aspect="auto")
    plt.yticks(np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize)
    plt.xticks(tick_positions, tick_labels, rotation=90)
    plt.xlabel("Time [s]")
    if ch_name:
        str_title += f" channel: {ch_name}"
    plt.title(str_title)

    plt.subplot(gs[1])
    for i in range(y_epoch.shape[0]):
        plt.plot(y_epoch[i, :], color="black", alpha=0.4)
    plt.plot(
        y_epoch.mean(axis=0),
        color="black",
        alpha=1,
        linewidth=3.0,
        label="mean target",
    )
    plt.legend()
    plt.ylabel("Target")
    plt.title(label_name)
    plt.xticks(tick_positions, tick_labels, rotation=90)
    plt.xlabel("Time [s]")
    plt.tight_layout()

    if save:
        plt_path = get_plt_path(
            OUT_PATH,
            feature_file,
            ch_name,
            str_plt_type="MOV_aligned_features",
            feature_name=feature_str_add,
        )
        plt.savefig(plt_path, bbox_inches="tight")
        logger.info(f"Feature epoch average figure saved to: {str(plt_path)}")
    if not show_plot:
        plt.close()
376
+
377
+
378
def plot_grid_elec_3d(
    cortex_grid: np.ndarray | None = None,
    ecog_strip: np.ndarray | None = None,
    grid_color: np.ndarray | None = None,
    strip_color: np.ndarray | None = None,
):
    """Scatter the cortical grid and the ECoG strip on a new 3D axes.

    Both inputs are indexed as ``[:, 0]``, ``[:, 1]``, ``[:, 2]`` for x, y, z.
    A missing colour array defaults to ones (one value per row). Inputs that
    are None are skipped.
    """
    axes_3d = plt.axes(projection="3d")

    if cortex_grid is not None:
        if grid_color is None:
            grid_color = np.ones(cortex_grid.shape[0])
        grid_x, grid_y, grid_z = cortex_grid[:, 0], cortex_grid[:, 1], cortex_grid[:, 2]
        axes_3d.scatter3D(
            grid_x,
            grid_y,
            grid_z,
            c=grid_color,
            s=300,
            alpha=0.8,
            cmap="viridis",
        )

    if ecog_strip is not None:
        if strip_color is None:
            strip_color = np.ones(ecog_strip.shape[0])
        strip_x, strip_y, strip_z = ecog_strip[:, 0], ecog_strip[:, 1], ecog_strip[:, 2]
        # On a 3D axes the third positional argument is the z coordinate;
        # the marker size is given by the `s` keyword.
        axes_3d.scatter(
            strip_x,
            strip_y,
            strip_z,
            c=strip_color,
            s=500,
            alpha=0.8,
            cmap="gray",
            marker="o",
        )
412
+
413
+
414
def plot_all_features(
    df: pd.DataFrame,
    time_limit_low_s: float | None = None,
    time_limit_high_s: float | None = None,
    normalize: bool = True,
    ytick_labelsize: int = 4,
    clim_low: float | None = None,
    clim_high: float | None = None,
    save: bool = False,
    title="all_feature_plt.pdf",
    OUT_PATH: _PathLike = "",
    feature_file: str = "",
):
    """Plot every column of ``df`` except ``"time"`` as an image over time.

    Parameters
    ----------
    df : feature DataFrame with a ``"time"`` column; the limits are compared
        against ``time`` after multiplying them by 1000
    time_limit_low_s, time_limit_high_s : optional exclusive time limits in
        seconds
    normalize : z-score the plotted columns
    ytick_labelsize : font size of the feature labels
    clim_low, clim_high : optional colour scale limits
    save : save the figure to ``OUT_PATH / feature_file / title``
    title : file name used when saving
    OUT_PATH, feature_file : folder and run name for saving
    """
    if time_limit_high_s is not None:
        df = df[df["time"] < time_limit_high_s * 1000]
    if time_limit_low_s is not None:
        df = df[df["time"] > time_limit_low_s * 1000]

    cols_plt = [c for c in df.columns if c != "time"]
    if normalize:
        data_plt = scipy_zscore(df[cols_plt], nan_policy="omit")
    else:
        data_plt = df[cols_plt]

    plt.figure()
    plt.imshow(data_plt.T, aspect="auto")
    plt.xlabel("Time [s]")
    plt.ylabel("Feature Names")
    plt.yticks(np.arange(len(cols_plt)), cols_plt, size=ytick_labelsize)

    # About ten ticks; the step must be at least 1 or np.arange raises when
    # fewer than ten rows remain.
    n_rows = df.shape[0]
    tick_num = np.arange(0, n_rows, max(1, n_rows // 10))
    tick_labels = np.array(np.rint(df["time"].iloc[tick_num] / 1000), dtype=int)
    plt.xticks(tick_num, tick_labels)

    plt.title(f"Feature Plot {feature_file}")

    if clim_low is not None:
        plt.clim(vmin=clim_low)
    if clim_high is not None:
        plt.clim(vmax=clim_high)

    plt.colorbar()
    plt.tight_layout()

    if save:
        plt_path = PurePath(OUT_PATH, feature_file, title)
        plt.savefig(plt_path, bbox_inches="tight")
461
+
462
+
463
class NM_Plot:
    """Plot helper holding electrode/grid coordinates and the plot template data.

    The template data (faces, vertices, grid, STN surface and the x/y/z
    coordinate arrays) is whatever ``read_plot_modules()`` returns; it is
    loaded once in the constructor.
    """

    def __init__(
        self,
        ecog_strip: np.ndarray | None = None,
        grid_cortex: np.ndarray | None = None,
        grid_subcortex: np.ndarray | None = None,
        sess_right: bool | None = False,
        proj_matrix_cortex: np.ndarray | None = None,
    ) -> None:
        """Store the supplied coordinates/flags and load the plot template data."""
        self.grid_cortex = grid_cortex
        self.grid_subcortex = grid_subcortex
        self.ecog_strip = ecog_strip
        self.sess_right = sess_right
        self.proj_matrix_cortex = proj_matrix_cortex

        # Imported here rather than at module level, as in the original code.
        from py_neuromodulation.nm_IO import read_plot_modules

        (
            self.faces,
            self.vertices,
            self.grid,
            self.stn_surf,
            self.x_ver,
            self.y_ver,
            self.x_ecog,
            self.y_ecog,
            self.z_ecog,
            self.x_stn,
            self.y_stn,
            self.z_stn,
        ) = read_plot_modules()

    def plot_grid_elec_3d(self) -> None:
        """Plot the stored cortical grid and ECoG strip in 3D."""
        # np.array(None) is a 0-d array, not None, and would defeat the
        # "is not None" checks of the module-level function.
        plot_grid_elec_3d(
            None if self.grid_cortex is None else np.array(self.grid_cortex),
            None if self.ecog_strip is None else np.array(self.ecog_strip),
        )

    def plot_cortex(
        self,
        grid_cortex: np.ndarray | pd.DataFrame | None = None,
        grid_color: np.ndarray | None = None,
        ecog_strip: np.ndarray | None = None,
        strip_color: np.ndarray | None = None,
        sess_right: bool | None = None,
        save: bool = False,
        OUT_PATH: _PathLike = "",
        feature_file: str = "",
        feature_str_add: str = "",
        show_plot: bool = True,
        title: str = "Cortical grid",
        set_clim: bool = True,
        lower_clim: float = 0.5,
        upper_clim: float = 0.7,
        cbar_label: str = "Balanced Accuracy",
    ):
        """Plot MNI brain including selected MNI cortical projection grid + used strip ECoG electrodes
        Colorcoded by grid_color

        ``grid_cortex`` and ``ecog_strip`` fall back to the values stored on the
        instance when omitted. Both are indexed as ``[:, 0]`` (x) and ``[:, 1]``
        (y), i.e. one row per point. When ``sess_right`` is truthy the x
        coordinates of the grid are mirrored on a copy, so the caller's data is
        never modified.
        """
        if grid_cortex is None:
            grid_cortex = self.grid_cortex
        if ecog_strip is None:
            ecog_strip = self.ecog_strip

        if grid_cortex is not None:
            # np.array copies and also accepts a DataFrame, whether it was
            # passed in or taken from the instance.
            grid_cortex = np.array(grid_cortex)
            if sess_right:
                # Mirror the x column; rows are points, columns are coordinates.
                grid_cortex[:, 0] = grid_cortex[:, 0] * -1
        if ecog_strip is not None:
            ecog_strip = np.asarray(ecog_strip)

        fig, axes = plt.subplots(1, 1, facecolor=(1, 1, 1), figsize=(14, 9))
        axes.scatter(self.x_ecog, self.y_ecog, c="gray", s=0.01)
        axes.axes.set_aspect("equal", anchor="C")

        # Last scatter drawn; it is the mappable used for the colorbar.
        pos_ecog = None

        if grid_cortex is not None:
            if grid_color is None:
                grid_color = np.ones(grid_cortex.shape[0])
            pos_ecog = axes.scatter(
                grid_cortex[:, 0],
                grid_cortex[:, 1],
                c=grid_color,
                s=150,
                alpha=0.8,
                cmap="viridis",
                label="grid points",
            )
            if set_clim:
                pos_ecog.set_clim(lower_clim, upper_clim)

        if ecog_strip is not None:
            if strip_color is None:
                strip_color = np.ones(ecog_strip.shape[0])
            pos_ecog = axes.scatter(
                ecog_strip[:, 0],
                ecog_strip[:, 1],
                c=strip_color,
                s=400,
                alpha=0.8,
                cmap="viridis",
                marker="x",
                label="ecog electrode",
            )

        plt.axis("off")
        plt.legend()
        plt.title(title)
        # Without any scatter there is nothing to attach a colorbar to.
        if set_clim and pos_ecog is not None:
            pos_ecog.set_clim(lower_clim, upper_clim)
            cbar = fig.colorbar(pos_ecog)
            cbar.set_label(cbar_label)

        if save:
            plt_path = get_plt_path(
                OUT_PATH,
                feature_file,
                str_plt_type="PLOT_CORTEX",
                feature_name=feature_str_add,
            )
            plt.savefig(plt_path, bbox_inches="tight")
            logger.info(f"Cortex plot figure saved to: {str(plt_path)}")
        if not show_plot:
            plt.close()