py-neuromodulation 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (233)
  1. py_neuromodulation/ConnectivityDecoding/Automated Anatomical Labeling 3 (Rolls 2020).nii +0 -0
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -0
  3. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -0
  4. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -0
  5. py_neuromodulation/ConnectivityDecoding/mni_coords_cortical_surface.mat +0 -0
  6. py_neuromodulation/ConnectivityDecoding/mni_coords_whole_brain.mat +0 -0
  7. py_neuromodulation/ConnectivityDecoding/rmap_func_all.nii +0 -0
  8. py_neuromodulation/ConnectivityDecoding/rmap_struc.nii +0 -0
  9. py_neuromodulation/FieldTrip.py +589 -589
  10. py_neuromodulation/__init__.py +74 -13
  11. py_neuromodulation/_write_example_dataset_helper.py +83 -65
  12. py_neuromodulation/data/README +6 -0
  13. py_neuromodulation/data/dataset_description.json +8 -0
  14. py_neuromodulation/data/participants.json +32 -0
  15. py_neuromodulation/data/participants.tsv +2 -0
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -0
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -0
  18. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -0
  19. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.eeg +0 -0
  20. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -0
  21. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -0
  22. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -0
  23. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -0
  24. py_neuromodulation/grid_cortex.tsv +40 -0
  25. py_neuromodulation/grid_subcortex.tsv +1429 -0
  26. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  27. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  28. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  29. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  30. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  31. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  32. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  33. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  34. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  35. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  36. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  37. py_neuromodulation/nm_IO.py +413 -417
  38. py_neuromodulation/nm_RMAP.py +496 -531
  39. py_neuromodulation/nm_analysis.py +993 -1074
  40. py_neuromodulation/nm_artifacts.py +30 -25
  41. py_neuromodulation/nm_bispectra.py +154 -168
  42. py_neuromodulation/nm_bursts.py +292 -198
  43. py_neuromodulation/nm_coherence.py +251 -205
  44. py_neuromodulation/nm_database.py +149 -0
  45. py_neuromodulation/nm_decode.py +918 -992
  46. py_neuromodulation/nm_define_nmchannels.py +300 -302
  47. py_neuromodulation/nm_features.py +144 -116
  48. py_neuromodulation/nm_filter.py +219 -219
  49. py_neuromodulation/nm_filter_preprocessing.py +79 -91
  50. py_neuromodulation/nm_fooof.py +139 -159
  51. py_neuromodulation/nm_generator.py +45 -37
  52. py_neuromodulation/nm_hjorth_raw.py +52 -73
  53. py_neuromodulation/nm_kalmanfilter.py +71 -58
  54. py_neuromodulation/nm_linelength.py +21 -33
  55. py_neuromodulation/nm_logger.py +66 -0
  56. py_neuromodulation/nm_mne_connectivity.py +149 -112
  57. py_neuromodulation/nm_mnelsl_generator.py +90 -0
  58. py_neuromodulation/nm_mnelsl_stream.py +116 -0
  59. py_neuromodulation/nm_nolds.py +96 -93
  60. py_neuromodulation/nm_normalization.py +173 -214
  61. py_neuromodulation/nm_oscillatory.py +423 -448
  62. py_neuromodulation/nm_plots.py +585 -612
  63. py_neuromodulation/nm_preprocessing.py +83 -0
  64. py_neuromodulation/nm_projection.py +370 -394
  65. py_neuromodulation/nm_rereference.py +97 -95
  66. py_neuromodulation/nm_resample.py +59 -50
  67. py_neuromodulation/nm_run_analysis.py +325 -435
  68. py_neuromodulation/nm_settings.py +289 -68
  69. py_neuromodulation/nm_settings.yaml +244 -0
  70. py_neuromodulation/nm_sharpwaves.py +423 -401
  71. py_neuromodulation/nm_stats.py +464 -480
  72. py_neuromodulation/nm_stream.py +398 -0
  73. py_neuromodulation/nm_stream_abc.py +166 -218
  74. py_neuromodulation/nm_types.py +193 -0
  75. py_neuromodulation/plots/STN_surf.mat +0 -0
  76. py_neuromodulation/plots/Vertices.mat +0 -0
  77. py_neuromodulation/plots/faces.mat +0 -0
  78. py_neuromodulation/plots/grid.mat +0 -0
  79. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +185 -182
  80. py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
  81. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -2
  82. {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.5.dist-info/licenses}/LICENSE +21 -21
  83. docs/build/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -68
  84. docs/build/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -233
  85. docs/build/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  86. docs/build/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -97
  87. docs/build/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  88. docs/build/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -192
  89. docs/build/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  90. docs/build/html/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -68
  91. docs/build/html/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -239
  92. docs/build/html/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  93. docs/build/html/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -97
  94. docs/build/html/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  95. docs/build/html/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -192
  96. docs/build/html/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  97. docs/source/_build/html/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -76
  98. docs/source/_build/html/_downloads/0d0d0a76e8f648d5d3cbc47da6351932/plot_real_time_demo.py +0 -97
  99. docs/source/_build/html/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -240
  100. docs/source/_build/html/_downloads/5d73cadc59a8805c47e3b84063afc157/plot_example_BIDS.py +0 -233
  101. docs/source/_build/html/_downloads/7660317fa5a6bfbd12fcca9961457fc4/plot_example_rmap_computing.py +0 -63
  102. docs/source/_build/html/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
  103. docs/source/_build/html/_downloads/839e5b319379f7fd9e867deb00fd797f/plot_example_gridPointProjection.py +0 -210
  104. docs/source/_build/html/_downloads/ae8be19afe5e559f011fc9b138968ba0/plot_first_demo.py +0 -192
  105. docs/source/_build/html/_downloads/b8b06cacc17969d3725a0b6f1d7741c5/plot_example_sharpwave_analysis.py +0 -219
  106. docs/source/_build/html/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -121
  107. docs/source/_build/html/_downloads/c31a86c0b68cb4167d968091ace8080d/plot_example_add_feature.py +0 -68
  108. docs/source/_build/html/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
  109. docs/source/_build/html/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -189
  110. docs/source/_build/html/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
  111. docs/source/auto_examples/plot_0_first_demo.py +0 -189
  112. docs/source/auto_examples/plot_1_example_BIDS.py +0 -240
  113. docs/source/auto_examples/plot_2_example_add_feature.py +0 -76
  114. docs/source/auto_examples/plot_3_example_sharpwave_analysis.py +0 -219
  115. docs/source/auto_examples/plot_4_example_gridPointProjection.py +0 -210
  116. docs/source/auto_examples/plot_5_example_rmap_computing.py +0 -64
  117. docs/source/auto_examples/plot_6_real_time_demo.py +0 -121
  118. docs/source/conf.py +0 -105
  119. examples/plot_0_first_demo.py +0 -189
  120. examples/plot_1_example_BIDS.py +0 -240
  121. examples/plot_2_example_add_feature.py +0 -76
  122. examples/plot_3_example_sharpwave_analysis.py +0 -219
  123. examples/plot_4_example_gridPointProjection.py +0 -210
  124. examples/plot_5_example_rmap_computing.py +0 -64
  125. examples/plot_6_real_time_demo.py +0 -121
  126. packages/realtime_decoding/build/lib/realtime_decoding/__init__.py +0 -4
  127. packages/realtime_decoding/build/lib/realtime_decoding/decoder.py +0 -104
  128. packages/realtime_decoding/build/lib/realtime_decoding/features.py +0 -163
  129. packages/realtime_decoding/build/lib/realtime_decoding/helpers.py +0 -15
  130. packages/realtime_decoding/build/lib/realtime_decoding/run_decoding.py +0 -345
  131. packages/realtime_decoding/build/lib/realtime_decoding/trainer.py +0 -54
  132. packages/tmsi/build/lib/TMSiFileFormats/__init__.py +0 -37
  133. packages/tmsi/build/lib/TMSiFileFormats/file_formats/__init__.py +0 -36
  134. packages/tmsi/build/lib/TMSiFileFormats/file_formats/lsl_stream_writer.py +0 -200
  135. packages/tmsi/build/lib/TMSiFileFormats/file_formats/poly5_file_writer.py +0 -496
  136. packages/tmsi/build/lib/TMSiFileFormats/file_formats/poly5_to_edf_converter.py +0 -236
  137. packages/tmsi/build/lib/TMSiFileFormats/file_formats/xdf_file_writer.py +0 -977
  138. packages/tmsi/build/lib/TMSiFileFormats/file_readers/__init__.py +0 -35
  139. packages/tmsi/build/lib/TMSiFileFormats/file_readers/edf_reader.py +0 -116
  140. packages/tmsi/build/lib/TMSiFileFormats/file_readers/poly5reader.py +0 -294
  141. packages/tmsi/build/lib/TMSiFileFormats/file_readers/xdf_reader.py +0 -229
  142. packages/tmsi/build/lib/TMSiFileFormats/file_writer.py +0 -102
  143. packages/tmsi/build/lib/TMSiPlotters/__init__.py +0 -2
  144. packages/tmsi/build/lib/TMSiPlotters/gui/__init__.py +0 -39
  145. packages/tmsi/build/lib/TMSiPlotters/gui/_plotter_gui.py +0 -234
  146. packages/tmsi/build/lib/TMSiPlotters/gui/plotting_gui.py +0 -440
  147. packages/tmsi/build/lib/TMSiPlotters/plotters/__init__.py +0 -44
  148. packages/tmsi/build/lib/TMSiPlotters/plotters/hd_emg_plotter.py +0 -446
  149. packages/tmsi/build/lib/TMSiPlotters/plotters/impedance_plotter.py +0 -589
  150. packages/tmsi/build/lib/TMSiPlotters/plotters/signal_plotter.py +0 -1326
  151. packages/tmsi/build/lib/TMSiSDK/__init__.py +0 -54
  152. packages/tmsi/build/lib/TMSiSDK/device.py +0 -588
  153. packages/tmsi/build/lib/TMSiSDK/devices/__init__.py +0 -34
  154. packages/tmsi/build/lib/TMSiSDK/devices/saga/TMSi_Device_API.py +0 -1764
  155. packages/tmsi/build/lib/TMSiSDK/devices/saga/__init__.py +0 -34
  156. packages/tmsi/build/lib/TMSiSDK/devices/saga/saga_device.py +0 -1366
  157. packages/tmsi/build/lib/TMSiSDK/devices/saga/saga_types.py +0 -520
  158. packages/tmsi/build/lib/TMSiSDK/devices/saga/xml_saga_config.py +0 -165
  159. packages/tmsi/build/lib/TMSiSDK/error.py +0 -95
  160. packages/tmsi/build/lib/TMSiSDK/sample_data.py +0 -63
  161. packages/tmsi/build/lib/TMSiSDK/sample_data_server.py +0 -99
  162. packages/tmsi/build/lib/TMSiSDK/settings.py +0 -45
  163. packages/tmsi/build/lib/TMSiSDK/tmsi_device.py +0 -111
  164. packages/tmsi/build/lib/__init__.py +0 -4
  165. packages/tmsi/build/lib/apex_sdk/__init__.py +0 -34
  166. packages/tmsi/build/lib/apex_sdk/device/__init__.py +0 -41
  167. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API.py +0 -1009
  168. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API_enums.py +0 -239
  169. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API_structures.py +0 -668
  170. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_device.py +0 -1611
  171. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_dongle.py +0 -38
  172. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_event_reader.py +0 -57
  173. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_channel.py +0 -44
  174. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_config.py +0 -150
  175. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_const.py +0 -36
  176. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_impedance_channel.py +0 -48
  177. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_info.py +0 -108
  178. packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/dongle_info.py +0 -39
  179. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/download_measurement.py +0 -77
  180. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/eeg_measurement.py +0 -150
  181. packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/impedance_measurement.py +0 -129
  182. packages/tmsi/build/lib/apex_sdk/device/threads/conversion_thread.py +0 -59
  183. packages/tmsi/build/lib/apex_sdk/device/threads/sampling_thread.py +0 -57
  184. packages/tmsi/build/lib/apex_sdk/device/tmsi_channel.py +0 -83
  185. packages/tmsi/build/lib/apex_sdk/device/tmsi_device.py +0 -201
  186. packages/tmsi/build/lib/apex_sdk/device/tmsi_device_enums.py +0 -103
  187. packages/tmsi/build/lib/apex_sdk/device/tmsi_dongle.py +0 -43
  188. packages/tmsi/build/lib/apex_sdk/device/tmsi_event_reader.py +0 -50
  189. packages/tmsi/build/lib/apex_sdk/device/tmsi_measurement.py +0 -118
  190. packages/tmsi/build/lib/apex_sdk/sample_data_server/__init__.py +0 -33
  191. packages/tmsi/build/lib/apex_sdk/sample_data_server/event_data.py +0 -44
  192. packages/tmsi/build/lib/apex_sdk/sample_data_server/sample_data.py +0 -50
  193. packages/tmsi/build/lib/apex_sdk/sample_data_server/sample_data_server.py +0 -136
  194. packages/tmsi/build/lib/apex_sdk/tmsi_errors/error.py +0 -126
  195. packages/tmsi/build/lib/apex_sdk/tmsi_sdk.py +0 -113
  196. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/apex/apex_structure_generator.py +0 -134
  197. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/decorators.py +0 -60
  198. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/logger_filter.py +0 -42
  199. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/singleton.py +0 -42
  200. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/support_functions.py +0 -72
  201. packages/tmsi/build/lib/apex_sdk/tmsi_utilities/tmsi_logger.py +0 -98
  202. py_neuromodulation/nm_EpochStream.py +0 -92
  203. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  204. py_neuromodulation/nm_cohortwrapper.py +0 -435
  205. py_neuromodulation/nm_eval_timing.py +0 -239
  206. py_neuromodulation/nm_features_abc.py +0 -39
  207. py_neuromodulation/nm_stream_offline.py +0 -358
  208. py_neuromodulation/utils/_logging.py +0 -24
  209. py_neuromodulation-0.0.3.dist-info/RECORD +0 -188
  210. py_neuromodulation-0.0.3.dist-info/top_level.txt +0 -5
  211. tests/__init__.py +0 -0
  212. tests/conftest.py +0 -117
  213. tests/test_all_examples.py +0 -10
  214. tests/test_all_features.py +0 -63
  215. tests/test_bispectra.py +0 -70
  216. tests/test_bursts.py +0 -105
  217. tests/test_feature_sampling_rates.py +0 -143
  218. tests/test_fooof.py +0 -16
  219. tests/test_initalization_offline_stream.py +0 -41
  220. tests/test_multiprocessing.py +0 -58
  221. tests/test_nan_values.py +0 -29
  222. tests/test_nm_filter.py +0 -95
  223. tests/test_nm_resample.py +0 -63
  224. tests/test_normalization_settings.py +0 -146
  225. tests/test_notch_filter.py +0 -31
  226. tests/test_osc_features.py +0 -424
  227. tests/test_preprocessing_filter.py +0 -151
  228. tests/test_rereference.py +0 -171
  229. tests/test_sampling.py +0 -57
  230. tests/test_settings_change_after_init.py +0 -76
  231. tests/test_sharpwave.py +0 -165
  232. tests/test_target_channel_add.py +0 -100
  233. tests/test_timing.py +0 -80
@@ -1,531 +1,496 @@
1
- import numpy as np
2
- import os
3
- import wget
4
-
5
-
6
- # from numba import jit
7
- from scipy import stats
8
- import scipy.io as sio
9
- import pandas as pd
10
- from typing import Union, Tuple, List
11
- import nibabel as nib
12
- from matplotlib import pyplot as plt
13
-
14
- import py_neuromodulation
15
-
16
- from py_neuromodulation import nm_plots
17
-
18
# Row indices into the template grids that are treated as unconnected for
# structural connectivity. They are removed from the grid with np.delete
# (axis=0) when structural (non-functional) connectivity is selected.

# cortical-hull grid
LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL = [256, 385, 417, 447, 819, 914]

# whole-brain grid
LIST_STRUC_UNCONNECTED_GRIDPOINTS_WHOLEBRAIN = [
    1, 8, 16, 33, 34, 35, 36, 37, 51, 75,
    77, 78, 99, 109, 115, 136, 155, 170, 210, 215,
    243, 352, 359, 361, 415, 416, 422, 529, 567, 569,
    622, 623, 625, 627, 632, 633, 634, 635, 639, 640,
    641, 643, 644, 650, 661, 663, 667, 683, 684, 685,
    704, 708, 722, 839, 840, 905, 993, 1011,
]
79
-
80
-
81
class ConnectivityChannelSelector:
    """Match coordinates to a template grid and correlate connectivity
    fingerprints with an R-map.

    The template grid (``.mat``) and the R-map (NIfTI) are loaded from the
    package's ``ConnectivityDecoding`` folder at construction time. The
    per-grid-point connectome is loaded on demand with ``load_connectome``.
    """

    def __init__(
        self,
        whole_brain_connectome: bool = True,
        func_connectivity: bool = True,
    ) -> None:
        """ConnectivityChannelSelector

        Parameters
        ----------
        whole_brain_connectome : bool, optional
            if True a 1236 whole-brain point grid is chosen,
            if False, a 1025 point grid of the cortical hull is loaded,
            by default True
        func_connectivity : bool, optional
            if True, functional connectivity (fMRI) is loaded,
            if False, structural connectivity (dMRI), by default True
        """
        self.connectome_name = self._get_connectome_name(
            whole_brain_connectome, func_connectivity
        )

        self.whole_brain_connectome = whole_brain_connectome
        self.func_connectivity = func_connectivity

        self.PATH_CONN_DECODING = os.path.join(
            py_neuromodulation.__path__[0],
            "ConnectivityDecoding",
        )

        if whole_brain_connectome:
            grid_file = "mni_coords_whole_brain.mat"
            unconnected_points = LIST_STRUC_UNCONNECTED_GRIDPOINTS_WHOLEBRAIN
        else:
            grid_file = "mni_coords_cortical_surface.mat"
            unconnected_points = LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL

        self.PATH_GRID = os.path.join(self.PATH_CONN_DECODING, grid_file)
        self.grid = sio.loadmat(self.PATH_GRID)["downsample_ctx"]

        # The same truthiness test decides both the grid reduction and the
        # R-map file below, so both always refer to the same modality.
        # (Previously a falsy non-False value such as None kept the full
        # grid but loaded the structural R-map.)
        if not func_connectivity:
            # reduce the grid to points not listed as structurally unconnected
            self.grid = np.delete(self.grid, unconnected_points, axis=0)

        rmap_file = "RMAP_func_all.nii" if func_connectivity else "RMAP_struc.nii"
        self.RMAP_arr = nib.load(
            os.path.join(self.PATH_CONN_DECODING, rmap_file)
        ).get_fdata()

    def _get_connectome_name(
        self, whole_brain_connectome: bool, func_connectivity: bool
    ) -> str:
        """Return the connectome file stem, e.g. ``"connectome_hull_struc"``."""
        grid_part = "whole_brain_" if whole_brain_connectome else "hull_"
        modality = "func" if func_connectivity else "struc"
        return "connectome_" + grid_part + modality

    def get_available_connectomes(self) -> list:
        """Return the file names saved in
        ``<package>/ConnectivityDecoding/connectome_folder/``.

        Returns
        -------
        list_connectomes: list
            File names in that folder. Empty if the folder does not exist
            yet; it is created by ``download_connectome``.
        """
        folder = os.path.join(self.PATH_CONN_DECODING, "connectome_folder")
        if not os.path.isdir(folder):
            return []
        return os.listdir(folder)

    def plot_grid(self) -> None:
        """Plot the loaded template grid that passed coordinates are matched to."""
        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")
        ax.scatter(
            self.grid[:, 0], self.grid[:, 1], self.grid[:, 2], s=50, alpha=0.2
        )
        plt.show()

    def get_closest_node(
        self, coord: Union[List, np.array]
    ) -> Tuple[List, List]:
        """Return the nearest grid node (Euclidean distance) for each coordinate.

        Parameters
        ----------
        coord : np.array
            MNI coordinates with shape (num_channels, 3)

        Returns
        -------
        Tuple[List, List]
            Grid coordinates, grid indices
        """
        idx_ = [np.argmin(np.linalg.norm(self.grid - c, axis=1)) for c in coord]
        return [self.grid[idx] for idx in idx_], idx_

    def get_rmap_correlations(
        self, fps: Union[list, np.array], RMAP_use: np.array = None
    ) -> List:
        """Pearson-correlate each fingerprint with the (flattened) R-map.

        Parameters
        ----------
        fps : Union[list, np.array]
            List of fingerprints (array-likes of the R-map's size)
        RMAP_use : np.array, optional
            R-map to use instead of ``self.RMAP_arr``, by default None

        Returns
        -------
        List
            correlation values, one per fingerprint
        """
        rmap = np.asarray(self.RMAP_arr if RMAP_use is None else RMAP_use).flatten()
        return [np.corrcoef(rmap, np.asarray(fp).flatten())[0][1] for fp in fps]

    def load_connectome(
        self,
        whole_brain_connectome: Union[bool, None] = None,
        func_connectivity: Union[bool, None] = None,
    ) -> None:
        """Load the connectome into ``self.connectome``.

        If the file is missing, the user is asked whether to download it
        from Zenodo.

        Parameters
        ----------
        whole_brain_connectome : bool, optional
            if true whole brain connectome
            if false cortical hull grid connectome, by default None
            (keep the current setting)
        func_connectivity : bool, optional
            if true fMRI if false dMRI, by default None
            (keep the current setting)

        Raises
        ------
        FileNotFoundError
            If the connectome is still missing after the prompt, e.g. when
            the download was declined.
        """
        if whole_brain_connectome is not None:
            self.whole_brain_connectome = whole_brain_connectome
        if func_connectivity is not None:
            self.func_connectivity = func_connectivity

        self.connectome_name = self._get_connectome_name(
            self.whole_brain_connectome, self.func_connectivity
        )

        PATH_CONNECTOME = os.path.join(
            self.PATH_CONN_DECODING,
            "connectome_folder",
            self.connectome_name + ".mat",
        )

        if not os.path.exists(PATH_CONNECTOME):
            user_input = (
                input("Do you want to download the connectome? (yes/no): ")
                .strip()
                .lower()
            )
            if user_input == "yes":
                # was self._download_connectome(), a method that does not exist
                self.download_connectome()
            elif user_input == "no":
                print("Connectome missing, has to be downloaded")

            if not os.path.exists(PATH_CONNECTOME):
                raise FileNotFoundError(
                    f"Connectome file not found: {PATH_CONNECTOME}"
                )

        self.connectome = sio.loadmat(PATH_CONNECTOME)

    def get_grid_fingerprints(self, grid_idx: Union[list, np.array]) -> list:
        """Return the connectome entries for the given grid indices."""
        return [self.connectome[str(idx)] for idx in grid_idx]

    def download_connectome(
        self,
    ):
        """Download the current connectome from the Zenodo API into
        ``<package>/ConnectivityDecoding/connectome_folder/``."""
        print("Downloading the connectome...")

        record_id = "10804702"
        file_name = self.connectome_name

        out_dir = os.path.join(self.PATH_CONN_DECODING, "connectome_folder")
        # the folder is not guaranteed to exist before the first download
        os.makedirs(out_dir, exist_ok=True)

        wget.download(
            f"https://zenodo.org/api/records/{record_id}/files/{file_name}/content",
            out=os.path.join(out_dir, f"{self.connectome_name}.mat"),
        )
299
-
300
-
301
class RMAPCross_Val_ChannelSelector:
    """Compute R-maps from connectivity fingerprints and cross-validate
    them (leave-one-channel-out / leave-one-subject-out)."""

    def __init__(self) -> None:
        pass

    def load_fingerprint(self, path_nii) -> np.ndarray:
        """Return the voxel data of a NIfTI fingerprint.

        Side effect: stores the image affine in ``self.affine``.
        """
        epi_img = nib.load(path_nii)
        self.affine = epi_img.affine
        return epi_img.get_fdata()

    def load_all_fingerprints(
        self, path_dir: str, cond_str: str = "_AvgR_Fz.nii"
    ):
        """Load every file in ``path_dir`` whose name contains ``cond_str``.

        If ``cond_str`` is None, all files are loaded.

        Returns
        -------
        tuple(list, list)
            File names and the corresponding fingerprint arrays.
        """
        l_fps = os.listdir(path_dir)
        if cond_str is not None:
            l_fps = [f for f in l_fps if cond_str in f]
        return l_fps, [
            self.load_fingerprint(os.path.join(path_dir, f)) for f in l_fps
        ]

    def get_fingerprints_from_path_with_cond(
        self,
        path_dir: str,
        str_to_omit: str = None,
        str_to_keep: str = None,
        keep: bool = True,
    ):
        """Load ``*_AvgR_Fz.nii`` fingerprints filtered by file name.

        With ``keep=True`` only names containing ``str_to_keep`` are used.
        With ``keep=False`` names containing ``str_to_omit`` are skipped.
        A ``None`` filter string disables that filter; previously it raised
        TypeError.

        Returns
        -------
        tuple(list, list)
            File names and the corresponding fingerprint arrays.
        """

        def _selected(fname: str) -> bool:
            if "_AvgR_Fz.nii" not in fname:
                return False
            if keep:
                return str_to_keep is None or str_to_keep in fname
            return str_to_omit is None or str_to_omit not in fname

        l_fps = [f for f in os.listdir(path_dir) if _selected(f)]
        return l_fps, [
            self.load_fingerprint(os.path.join(path_dir, f)) for f in l_fps
        ]

    @staticmethod
    def save_Nii(
        fp: np.array,
        affine: np.array,
        name: str = "img.nii",
        reshape: bool = True,
    ):
        """Save ``fp`` as a NIfTI image.

        With ``reshape=True`` it is first reshaped (Fortran order) to
        (91, 109, 91).
        """
        if reshape:
            fp = np.reshape(fp, (91, 109, 91), order="F")

        img = nib.nifti1.Nifti1Image(fp, affine=affine)

        nib.save(img, name)

    def get_RMAP(self, X: np.array, y: np.array):
        """Row-wise Pearson correlation of ``X`` with the vector ``y``.

        This is a vectorised form, faster than ``calculate_RMap_numba``. See
        https://stackoverflow.com/questions/71252740/correlating-an-array-row-wise-with-a-vector/71253141#71253141
        """
        n = len(y)
        sum_x = np.sum(X, axis=-1)
        sum_y = np.sum(y)
        numerator = n * np.sum(X * y[None, :], axis=-1) - (sum_x * sum_y)
        denominator = np.sqrt(
            (n * np.sum(X**2, axis=-1) - sum_x**2)
            * (n * np.sum(y**2) - sum_y**2)
        )
        return numerator / denominator

    @staticmethod
    # @jit(nopython=True)
    def calculate_RMap_numba(fp, performances):
        """Correlate every voxel across fingerprints with ``performances``.

        This is a loop-based version of ``get_RMAP``.
        """
        arr = fp[0].flatten()
        NUM_VOXELS = arr.shape[0]
        LEN_FPS = len(fp)
        fp_arr = np.empty((NUM_VOXELS, LEN_FPS))
        for fp_idx, fp_ in enumerate(fp):
            fp_arr[:, fp_idx] = fp_.flatten()

        RMAP = np.zeros(NUM_VOXELS)
        for voxel in range(NUM_VOXELS):
            RMAP[voxel] = np.corrcoef(fp_arr[voxel, :], performances)[0][1]

        return RMAP

    @staticmethod
    # @jit(nopython=True)
    def get_corr_numba(fp, fp_test):
        """Pearson correlation between two 1-D arrays."""
        return np.corrcoef(fp_test, fp)[0][1]

    def leave_one_ch_out_cv(
        self, l_fps_names: list, l_fps_dat: list, l_per: list
    ):
        """Leave-one-channel-out R-map cross-validation.

        For each channel, an R-map is built from all other fingerprints and
        their performances. The left-out fingerprint is then correlated with
        that R-map. ``l_fps_dat`` holds non-flattened arrays.

        Returns
        -------
        tuple(list, list)
            True performances of the left-out channels and their predicted
            (correlation) values.
        """
        per_left_out = []
        per_predict = []

        for idx_left_out, _ in enumerate(l_fps_names):
            conn_arr = np.array(
                [
                    f.flatten()
                    for i, f in enumerate(l_fps_dat)
                    if i != idx_left_out
                ]
            )
            per_cv = np.array(
                [p for i, p in enumerate(l_per) if i != idx_left_out]
            )

            rmap_cv = np.nan_to_num(self.get_RMAP(conn_arr.T, per_cv))

            per_predict.append(
                np.nan_to_num(
                    self.get_corr_numba(
                        rmap_cv, l_fps_dat[idx_left_out].flatten()
                    )
                )
            )
            per_left_out.append(l_per[idx_left_out])
        return per_left_out, per_predict

    def leave_one_sub_out_cv(
        self, l_fps_names: list, l_fps_dat: list, l_per: list, sub_list: list
    ):
        """Leave-one-subject-out R-map cross-validation.

        For each entry of ``sub_list``, every fingerprint whose name contains
        it forms the test set. The R-map is built from the remaining ones.
        ``l_fps_dat`` holds non-flattened arrays.

        Returns
        -------
        tuple(list, list)
            True performances of the test fingerprints and their predicted
            (correlation) values.
        """
        per_predict = []
        per_left_out = []

        for subject_test in sub_list:
            idx_test = [
                idx for idx, f in enumerate(l_fps_names) if subject_test in f
            ]
            idx_train = [
                idx
                for idx, f in enumerate(l_fps_names)
                if subject_test not in f
            ]
            # Select the training rows directly instead of stacking all
            # fingerprints into one large array on every fold.
            conn_arr = np.array(
                [np.asarray(l_fps_dat[idx]).flatten() for idx in idx_train]
            )
            per_cv = np.array([l_per[idx] for idx in idx_train])
            rmap_cv = np.nan_to_num(self.get_RMAP(conn_arr.T, per_cv))

            for idx in idx_test:
                per_predict.append(
                    np.nan_to_num(
                        self.get_corr_numba(rmap_cv, l_fps_dat[idx].flatten())
                    )
                )
                per_left_out.append(l_per[idx])
        return per_left_out, per_predict

    @staticmethod
    def _best_pair(fp_pairs: list) -> list:
        """Return ``[cohort, sub, ch]`` of the entry with the highest
        correlation (element 3).

        Comparison is numeric. NaN correlations rank lowest.
        """
        if not fp_pairs:
            raise ValueError("No training fingerprints found to compare against.")
        # Build a float array from the correlation column only. The original
        # np.array(fp_pairs) mixed strings and floats, which produces a string
        # array and makes argmax compare lexicographically
        # (e.g. "-0.5" > "-0.1").
        corrs = np.array([pair[3] for pair in fp_pairs], dtype=float)
        idx_max = int(np.argmax(np.nan_to_num(corrs, nan=-np.inf)))
        return fp_pairs[idx_max][0:3]

    def get_highest_corr_sub_ch(
        self,
        cohort_test: str,
        sub_test: str,
        ch_test: str,
        cohorts_train: dict,
        path_dir: str = r"C:\Users\ICN_admin\OneDrive - Charité - Universitätsmedizin Berlin\Connectomics\DecodingToolbox_BerlinPittsburgh_Beijing\functional_connectivity",
    ):
        """Find the training channel whose fingerprint correlates best with
        the test channel's fingerprint.

        ``cohorts_train`` maps a cohort name to its list of subjects.
        NOTE(review): the default ``path_dir`` is a machine-specific path;
        pass ``path_dir`` explicitly.

        Returns
        -------
        list
            ``[cohort, sub, ch]`` of the best-matching training channel.

        Raises
        ------
        FileNotFoundError
            If no fingerprint exists for the test channel.
        ValueError
            If no training fingerprints are found.
        """
        test_key = f"{cohort_test}_{sub_test}_ROI_{ch_test}"
        _, test_fps = self.get_fingerprints_from_path_with_cond(
            path_dir=path_dir,
            str_to_keep=test_key,
            keep=True,
        )
        if not test_fps:
            raise FileNotFoundError(
                f"No fingerprint matching '{test_key}' found in {path_dir}"
            )
        fp_test = test_fps[0].flatten()

        fp_pairs = []
        for cohort, subs in cohorts_train.items():
            for sub in subs:
                fps_name, fps = self.get_fingerprints_from_path_with_cond(
                    path_dir=path_dir,
                    str_to_keep=f"{cohort}_{sub}_ROI",
                    keep=True,
                )

                for fp, fp_name in zip(fps, fps_name):
                    ch = fp_name[
                        fp_name.find("ROI") + 4 : fp_name.find("func") - 1
                    ]
                    # np.corrcoef rejects inputs with more than 2 dimensions,
                    # so the fingerprint must be flattened like fp_test.
                    corr_val = self.get_corr_numba(fp_test, fp.flatten())
                    fp_pairs.append([cohort, sub, ch, corr_val])

        return self._best_pair(fp_pairs)

    @staticmethod
    def plot_performance_prediction_correlation(
        per_left_out, per_predict, out_path_save: str = None
    ):
        """Regression plot of true vs. predicted performances.

        This was declared without ``self`` or ``@staticmethod``, so calling
        it on an instance shifted every argument by one.
        """
        df_plt_corr = pd.DataFrame()
        df_plt_corr["test_performance"] = per_left_out
        # change "struct" with "funct" for functional connectivity
        df_plt_corr["struct_conn_predict"] = per_predict

        nm_plots.reg_plot(
            x_col="test_performance",
            y_col="struct_conn_predict",
            data=df_plt_corr,
            out_path_save=out_path_save,
        )
1
+ import numpy as np
2
+ from pathlib import PurePath, Path
3
+
4
+
5
+ # from numba import jit
6
+ import scipy.io as sio
7
+ import pandas as pd
8
+ import nibabel as nib
9
+ from matplotlib import pyplot as plt
10
+
11
+ from py_neuromodulation.nm_plots import reg_plot
12
+ from py_neuromodulation.nm_types import _PathLike
13
+ from py_neuromodulation import PYNM_DIR
14
+
15
# Row indices into the template grids that are treated as unconnected for
# structural connectivity. They are meant to be dropped from the grid when
# structural (non-functional) connectivity is selected.

# cortical-hull grid
LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL = [256, 385, 417, 447, 819, 914]

# whole-brain grid
LIST_STRUC_UNCONNECTED_GRIDPOINTS_WHOLEBRAIN = [
    1, 8, 16, 33, 34, 35, 36, 37, 51, 75,
    77, 78, 99, 109, 115, 136, 155, 170, 210, 215,
    243, 352, 359, 361, 415, 416, 422, 529, 567, 569,
    622, 623, 625, 627, 632, 633, 634, 635, 639, 640,
    641, 643, 644, 650, 661, 663, 667, 683, 684, 685,
    704, 708, 722, 839, 840, 905, 993, 1011,
]
76
+
77
+
78
+ class ConnectivityChannelSelector:
79
+ def __init__(
80
+ self,
81
+ whole_brain_connectome: bool = True,
82
+ func_connectivity: bool = True,
83
+ ) -> None:
84
+ """ConnectivityChannelSelector
85
+
86
+ Parameters
87
+ ----------
88
+ whole_brain_connectome : bool, optional
89
+ if True a 1236 whole-brain point grid is chosen,
90
+ if False, a 1025 point grid of the cortical hull is loaded,
91
+ by default True
92
+ func_connectivity : bool, optional
93
+ if true, functional connectivity fMRI is loaded,
94
+ if false structural dMRIby, default True
95
+ """
96
+
97
+ self.connectome_name = self._get_connectome_name(
98
+ whole_brain_connectome, func_connectivity
99
+ )
100
+
101
+ self.whole_brain_connectome = whole_brain_connectome
102
+ self.func_connectivity = func_connectivity
103
+
104
+ self.PATH_CONN_DECODING = PYNM_DIR / "ConnectivityDecoding"
105
+
106
+ if whole_brain_connectome:
107
+ self.PATH_GRID = PurePath(
108
+ self.PATH_CONN_DECODING,
109
+ "mni_coords_whole_brain.mat",
110
+ )
111
+ self.grid = sio.loadmat(self.PATH_GRID)["downsample_ctx"]
112
+ if not func_connectivity:
113
+ # reduce the grid to only valid points that are not in LIST_STRUC_UNCONNECTED_GRIDPOINTS_WHOLEBRAIN
114
+ self.grid = np.delete(
115
+ self.grid,
116
+ LIST_STRUC_UNCONNECTED_GRIDPOINTS_WHOLEBRAIN,
117
+ axis=0,
118
+ )
119
+ else:
120
+ self.PATH_GRID = PurePath(
121
+ self.PATH_CONN_DECODING,
122
+ "mni_coords_cortical_surface.mat",
123
+ )
124
+ self.grid = sio.loadmat(self.PATH_GRID)["downsample_ctx"]
125
+ if not func_connectivity:
126
+ # reduce the grid to only valid points that are not in LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL
127
+ self.grid = np.delete(
128
+ self.grid, LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL, axis=0
129
+ )
130
+
131
+ if func_connectivity:
132
+ self.RMAP_arr = nib.load(
133
+ PurePath(self.PATH_CONN_DECODING, "RMAP_func_all.nii")
134
+ ).get_fdata()
135
+ else:
136
+ self.RMAP_arr = nib.load(
137
+ PurePath(self.PATH_CONN_DECODING, "RMAP_struc.nii")
138
+ ).get_fdata()
139
+
140
+ def _get_connectome_name(self, whole_brain_connectome: str, func_connectivity: str):
141
+ connectome_name = "connectome_"
142
+ if whole_brain_connectome:
143
+ connectome_name += "whole_brain_"
144
+ else:
145
+ connectome_name += "hull_"
146
+ if func_connectivity:
147
+ connectome_name += "func"
148
+ else:
149
+ connectome_name += "struc"
150
+ return connectome_name
151
+
152
+ def get_available_connectomes(self) -> list:
153
+ """Return list of saved connectomes in the
154
+ package folder/ConnectivityDecoding/connectome_folder/ folder.
155
+
156
+ Returns
157
+ -------
158
+ list_connectomes: list
159
+ """
160
+ return list(Path(self.PATH_CONN_DECODING, "connectome_folder").iterdir())
161
+
162
+ def plot_grid(self) -> None:
163
+ """Plot the loaded template grid that passed coordinates are matched to."""
164
+
165
+ fig = plt.figure()
166
+ ax = fig.add_subplot(111, projection="3d")
167
+ ax.scatter(self.grid[:, 0], self.grid[:, 1], self.grid[:, 2], s=50, alpha=0.2)
168
+ plt.show()
169
+
170
+ def get_closest_node(self, coord: list | np.ndarray) -> tuple[list, list]:
171
+ """Given a list or np.array of coordinates, return the closest nodes in the
172
+ grid and their indices.
173
+
174
+ Parameters
175
+ ----------
176
+ coord : np.ndarray
177
+ MNI coordinates with shape (num_channels, 3)
178
+
179
+ Returns
180
+ -------
181
+ Tuple[list, list]
182
+ Grid coordinates, grid indices
183
+ """
184
+
185
+ idx_ = []
186
+ for c in coord:
187
+ dist = np.linalg.norm(self.grid - c, axis=1)
188
+ idx_.append(np.argmin(dist))
189
+
190
+ return [self.grid[idx] for idx in idx_], idx_
191
+
192
+ def get_rmap_correlations(
193
+ self, fps: list | np.ndarray, RMAP_use: np.ndarray | None = None
194
+ ) -> list:
195
+ """Calculate correlations of passed fingerprints with the RMAP
196
+
197
+ Parameters
198
+ ----------
199
+ fps : Union[list, np.array]
200
+ List of fingerprints
201
+ RMAP_use : np.ndarray, optional
202
+ Passed RMAP, by default None
203
+
204
+ Returns
205
+ -------
206
+ List
207
+ correlation values
208
+ """
209
+
210
+ RMAP_ = self.RMAP_arr if RMAP_use is None else RMAP_use
211
+ RMAP_ = RMAP_.flatten()
212
+ corrs = []
213
+ for fp in fps:
214
+ corrs.append(np.corrcoef(RMAP_, fp.flatten())[0][1])
215
+ return corrs
216
+
217
+ def load_connectome(
218
+ self,
219
+ whole_brain_connectome: bool | None = None,
220
+ func_connectivity: bool | None = None,
221
+ ) -> None:
222
+ """Load connectome, if not available download connectome from
223
+ Zenodo.
224
+
225
+ Parameters
226
+ ----------
227
+ whole_brain_connectome : bool, optional
228
+ if true whole brain connectome
229
+ if false cortical hull grid connectome, by default None
230
+ func_connectivity : bool, optional
231
+ if true fMRI if false dMRI, by default None
232
+ """
233
+
234
+ if whole_brain_connectome is not None:
235
+ self.whole_brain_connectome = whole_brain_connectome
236
+ if func_connectivity is not None:
237
+ self.func_connectivity = func_connectivity
238
+
239
+ self.connectome_name = self._get_connectome_name(
240
+ self.whole_brain_connectome, self.func_connectivity
241
+ )
242
+
243
+ PATH_CONNECTOME = Path(
244
+ self.PATH_CONN_DECODING,
245
+ "connectome_folder",
246
+ self.connectome_name + ".mat",
247
+ )
248
+
249
+ if not PATH_CONNECTOME.exists():
250
+ user_input = input(
251
+ "Do you want to download the connectome? (yes/no): "
252
+ ).lower()
253
+ if user_input == "yes":
254
+ self._download_connectome()
255
+ elif user_input == "no":
256
+ print("Connectome missing, has to be downloaded")
257
+
258
+ self.connectome = sio.loadmat(PATH_CONNECTOME)
259
+
260
+ def get_grid_fingerprints(self, grid_idx: list | np.ndarray) -> list:
261
+ return [self.connectome[str(grid_idx)] for grid_idx in grid_idx]
262
+
263
+ def download_connectome(
264
+ self,
265
+ ):
266
+
267
+ from urllib.request import urlretrieve
268
+
269
+ # download the connectome from the Zenodo API
270
+ print("Downloading the connectome...")
271
+
272
+ record_id = "10804702"
273
+ file_name = self.connectome_name
274
+
275
+ filepath = Path(self.PATH_CONN_DECODING, "connectome_folder")
276
+ filepath.mkdir(parents=True, exist_ok=True)
277
+
278
+ urlretrieve(
279
+ f"https://zenodo.org/api/records/{record_id}/files/{file_name}/content",
280
+ filepath / f"{self.connectome_name}.mat",
281
+ )
282
+
283
+
284
+ class RMAPCross_Val_ChannelSelector:
285
+ def __init__(self) -> None:
286
+ pass
287
+
288
+ def load_fingerprint(self, path_nii) -> None:
289
+ """Return Nifti fingerprint"""
290
+ epi_img = nib.load(path_nii)
291
+ self.affine = epi_img.affine
292
+ fp = epi_img.get_fdata()
293
+ return fp
294
+
295
+ def load_all_fingerprints(self, path_dir: str, cond_str: str = "_AvgR_Fz.nii"):
296
+ if cond_str is not None:
297
+ l_fps = list(filter(lambda k: cond_str in str(k), Path(path_dir).iterdir()))
298
+ else:
299
+ l_fps = list(Path(path_dir).iterdir())
300
+
301
+ return l_fps, [self.load_fingerprint(PurePath(path_dir, f)) for f in l_fps]
302
+
303
+ def get_fingerprints_from_path_with_cond(
304
+ self,
305
+ path_dir: _PathLike,
306
+ str_to_omit: str = "",
307
+ str_to_keep: str = "",
308
+ keep: bool = True,
309
+ ) -> tuple[list, list]:
310
+ l_fps = []
311
+
312
+ if keep and str_to_keep:
313
+ l_fps = list(
314
+ filter(
315
+ lambda k: "_AvgR_Fz.nii" in str(k) and str_to_keep in str(k),
316
+ Path(path_dir).iterdir(),
317
+ )
318
+ )
319
+
320
+ elif not keep and str_to_omit:
321
+ l_fps = list(
322
+ filter(
323
+ lambda k: "_AvgR_Fz.nii" in str(k) and str_to_omit not in str(k),
324
+ Path(path_dir).iterdir(),
325
+ )
326
+ )
327
+
328
+ return l_fps, [self.load_fingerprint(PurePath(path_dir, f)) for f in l_fps]
329
+
330
+ @staticmethod
331
+ def save_Nii(
332
+ fp: np.ndarray,
333
+ affine: np.ndarray,
334
+ name: str = "img.nii",
335
+ reshape: bool = True,
336
+ ):
337
+ if reshape:
338
+ fp = np.reshape(fp, (91, 109, 91), order="F")
339
+
340
+ img = nib.nifti1.Nifti1Image(fp, affine=affine)
341
+
342
+ nib.save(img, name)
343
+
344
+ def get_RMAP(self, X: np.ndarray, y: np.ndarray):
345
+ # faster than calculate_RMap_numba
346
+ # https://stackoverflow.com/questions/71252740/correlating-an-array-row-wise-with-a-vector/71253141#71253141
347
+
348
+ r = (
349
+ len(y) * np.sum(X * y[None, :], axis=-1) - (np.sum(X, axis=-1) * np.sum(y))
350
+ ) / (
351
+ np.sqrt(
352
+ (len(y) * np.sum(X**2, axis=-1) - np.sum(X, axis=-1) ** 2)
353
+ * (len(y) * np.sum(y**2) - np.sum(y) ** 2)
354
+ )
355
+ )
356
+ return r
357
+
358
+ @staticmethod
359
+ # @jit(nopython=True)
360
+ def calculate_RMap_numba(fp, performances):
361
+ # The RMap also needs performances; for every fingerprint / channel
362
+ # Save the corresponding performance
363
+ # for every voxel; correlate it with performances
364
+
365
+ arr = fp[0].flatten()
366
+ NUM_VOXELS = arr.shape[0]
367
+ LEN_FPS = len(fp)
368
+ fp_arr = np.empty((NUM_VOXELS, LEN_FPS))
369
+ for fp_idx, fp_ in enumerate(fp):
370
+ fp_arr[:, fp_idx] = fp_.flatten()
371
+
372
+ RMAP = np.zeros(NUM_VOXELS)
373
+ for voxel in range(NUM_VOXELS):
374
+ corr_val = np.corrcoef(fp_arr[voxel, :], performances)[0][1]
375
+
376
+ RMAP[voxel] = corr_val
377
+
378
+ return RMAP
379
+
380
+ @staticmethod
381
+ # @jit(nopython=True)
382
+ def get_corr_numba(fp, fp_test):
383
+ val = np.corrcoef(fp_test, fp)[0][1]
384
+ return val
385
+
386
+ def leave_one_ch_out_cv(self, l_fps_names: list, l_fps_dat: list, l_per: list):
387
+ # l_fps_dat is not flattened
388
+
389
+ per_left_out = []
390
+ per_predict = []
391
+
392
+ for idx_left_out, f_left_out in enumerate(l_fps_names):
393
+ # print(idx_left_out)
394
+ l_cv = l_fps_dat.copy()
395
+ per_cv = l_per.copy()
396
+
397
+ l_cv.pop(idx_left_out)
398
+ per_cv.pop(idx_left_out)
399
+
400
+ conn_arr = []
401
+ for f in l_cv:
402
+ conn_arr.append(f.flatten())
403
+ conn_arr = np.array(conn_arr)
404
+
405
+ rmap_cv = np.nan_to_num(self.get_RMAP(conn_arr.T, np.array(per_cv)))
406
+
407
+ per_predict.append(
408
+ np.nan_to_num(
409
+ self.get_corr_numba(rmap_cv, l_fps_dat[idx_left_out].flatten())
410
+ )
411
+ )
412
+ per_left_out.append(l_per[idx_left_out])
413
+ return per_left_out, per_predict
414
+
415
+ def leave_one_sub_out_cv(
416
+ self, l_fps_names: list, l_fps_dat: list, l_per: list, sub_list: list
417
+ ):
418
+ # l_fps_dat assume non flatted arrays
419
+ # each fp including the sub_list string will be iteratively removed for test set
420
+
421
+ per_predict = []
422
+ per_left_out = []
423
+
424
+ for subject_test in sub_list:
425
+ # print(subject_test)
426
+ idx_test = [idx for idx, f in enumerate(l_fps_names) if subject_test in f]
427
+ idx_train = [
428
+ idx for idx, f in enumerate(l_fps_names) if subject_test not in f
429
+ ]
430
+ l_cv = list(np.array(l_fps_dat)[idx_train])
431
+ per_cv = list(np.array(l_per)[idx_train])
432
+
433
+ conn_arr = []
434
+ for f in l_cv:
435
+ conn_arr.append(f.flatten())
436
+ conn_arr = np.array(conn_arr)
437
+ rmap_cv = np.nan_to_num(self.get_RMAP(conn_arr.T, np.array(per_cv)))
438
+
439
+ for idx in idx_test:
440
+ per_predict.append(
441
+ np.nan_to_num(
442
+ self.get_corr_numba(rmap_cv, l_fps_dat[idx].flatten())
443
+ )
444
+ )
445
+ per_left_out.append(l_per[idx])
446
+ return per_left_out, per_predict
447
+
448
+ def get_highest_corr_sub_ch(
449
+ self,
450
+ cohort_test: str,
451
+ sub_test: str,
452
+ ch_test: str,
453
+ cohorts_train: dict,
454
+ path_dir: str = r"C:\Users\ICN_admin\OneDrive - Charité - Universitätsmedizin Berlin\Connectomics\DecodingToolbox_BerlinPittsburgh_Beijing\functional_connectivity",
455
+ ):
456
+ fp_test = self.get_fingerprints_from_path_with_cond(
457
+ path_dir=path_dir,
458
+ str_to_keep=f"{cohort_test}_{sub_test}_ROI_{ch_test}",
459
+ keep=True,
460
+ )[1][
461
+ 0
462
+ ].flatten() # index 1 for getting the array, 0 for the list fp that was found
463
+
464
+ fp_pairs = []
465
+
466
+ for cohort in cohorts_train.keys():
467
+ for sub in cohorts_train[cohort]:
468
+ fps_name, fps = self.get_fingerprints_from_path_with_cond(
469
+ path_dir=path_dir,
470
+ str_to_keep=f"{cohort}_{sub}_ROI",
471
+ keep=True,
472
+ )
473
+
474
+ for fp, fp_name in zip(fps, fps_name):
475
+ ch = fp_name[fp_name.find("ROI") + 4 : fp_name.find("func") - 1]
476
+ corr_val = self.get_corr_numba(fp_test, fp)
477
+ fp_pairs.append([cohort, sub, ch, corr_val])
478
+
479
+ idx_max = np.argmax(np.array(fp_pairs)[:, 3])
480
+ return fp_pairs[idx_max][0:3]
481
+
482
+ def plot_performance_prediction_correlation(
483
+ per_left_out, per_predict, out_path_save: str | None = None
484
+ ):
485
+ df_plt_corr = pd.DataFrame()
486
+ df_plt_corr["test_performance"] = per_left_out
487
+ df_plt_corr["struct_conn_predict"] = (
488
+ per_predict # change "struct" with "funct" for functional connectivity
489
+ )
490
+
491
+ reg_plot(
492
+ x_col="test_performance",
493
+ y_col="struct_conn_predict",
494
+ data=df_plt_corr,
495
+ out_path_save=out_path_save,
496
+ )