py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
  3. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
  4. py_neuromodulation/__init__.py +80 -13
  5. py_neuromodulation/{nm_RMAP.py → analysis/RMAP.py} +496 -531
  6. py_neuromodulation/analysis/__init__.py +4 -0
  7. py_neuromodulation/{nm_decode.py → analysis/decode.py} +918 -992
  8. py_neuromodulation/{nm_analysis.py → analysis/feature_reader.py} +994 -1074
  9. py_neuromodulation/{nm_plots.py → analysis/plots.py} +627 -612
  10. py_neuromodulation/{nm_stats.py → analysis/stats.py} +458 -480
  11. py_neuromodulation/data/README +6 -6
  12. py_neuromodulation/data/dataset_description.json +8 -8
  13. py_neuromodulation/data/participants.json +32 -32
  14. py_neuromodulation/data/participants.tsv +2 -2
  15. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
  18. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
  19. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
  20. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
  21. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
  22. py_neuromodulation/default_settings.yaml +241 -0
  23. py_neuromodulation/features/__init__.py +31 -0
  24. py_neuromodulation/features/bandpower.py +165 -0
  25. py_neuromodulation/features/bispectra.py +157 -0
  26. py_neuromodulation/features/bursts.py +297 -0
  27. py_neuromodulation/features/coherence.py +255 -0
  28. py_neuromodulation/features/feature_processor.py +121 -0
  29. py_neuromodulation/features/fooof.py +142 -0
  30. py_neuromodulation/features/hjorth_raw.py +57 -0
  31. py_neuromodulation/features/linelength.py +21 -0
  32. py_neuromodulation/features/mne_connectivity.py +148 -0
  33. py_neuromodulation/features/nolds.py +94 -0
  34. py_neuromodulation/features/oscillatory.py +249 -0
  35. py_neuromodulation/features/sharpwaves.py +432 -0
  36. py_neuromodulation/filter/__init__.py +3 -0
  37. py_neuromodulation/filter/kalman_filter.py +67 -0
  38. py_neuromodulation/filter/kalman_filter_external.py +1890 -0
  39. py_neuromodulation/filter/mne_filter.py +128 -0
  40. py_neuromodulation/filter/notch_filter.py +93 -0
  41. py_neuromodulation/grid_cortex.tsv +40 -40
  42. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  43. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  44. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  45. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  46. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  47. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  48. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  49. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  50. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  51. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  52. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  53. py_neuromodulation/processing/__init__.py +10 -0
  54. py_neuromodulation/{nm_artifacts.py → processing/artifacts.py} +29 -25
  55. py_neuromodulation/processing/data_preprocessor.py +77 -0
  56. py_neuromodulation/processing/filter_preprocessing.py +78 -0
  57. py_neuromodulation/processing/normalization.py +175 -0
  58. py_neuromodulation/{nm_projection.py → processing/projection.py} +370 -394
  59. py_neuromodulation/{nm_rereference.py → processing/rereference.py} +97 -95
  60. py_neuromodulation/{nm_resample.py → processing/resample.py} +56 -50
  61. py_neuromodulation/stream/__init__.py +3 -0
  62. py_neuromodulation/stream/data_processor.py +325 -0
  63. py_neuromodulation/stream/generator.py +53 -0
  64. py_neuromodulation/stream/mnelsl_player.py +94 -0
  65. py_neuromodulation/stream/mnelsl_stream.py +120 -0
  66. py_neuromodulation/stream/settings.py +292 -0
  67. py_neuromodulation/stream/stream.py +427 -0
  68. py_neuromodulation/utils/__init__.py +2 -0
  69. py_neuromodulation/{nm_define_nmchannels.py → utils/channels.py} +305 -302
  70. py_neuromodulation/utils/database.py +149 -0
  71. py_neuromodulation/utils/io.py +378 -0
  72. py_neuromodulation/utils/keyboard.py +52 -0
  73. py_neuromodulation/utils/logging.py +66 -0
  74. py_neuromodulation/utils/types.py +251 -0
  75. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/METADATA +28 -33
  76. py_neuromodulation-0.0.6.dist-info/RECORD +89 -0
  77. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/WHEEL +1 -1
  78. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/licenses/LICENSE +21 -21
  79. py_neuromodulation/FieldTrip.py +0 -589
  80. py_neuromodulation/_write_example_dataset_helper.py +0 -65
  81. py_neuromodulation/nm_EpochStream.py +0 -92
  82. py_neuromodulation/nm_IO.py +0 -417
  83. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  84. py_neuromodulation/nm_bispectra.py +0 -168
  85. py_neuromodulation/nm_bursts.py +0 -198
  86. py_neuromodulation/nm_coherence.py +0 -205
  87. py_neuromodulation/nm_cohortwrapper.py +0 -435
  88. py_neuromodulation/nm_eval_timing.py +0 -239
  89. py_neuromodulation/nm_features.py +0 -116
  90. py_neuromodulation/nm_features_abc.py +0 -39
  91. py_neuromodulation/nm_filter.py +0 -219
  92. py_neuromodulation/nm_filter_preprocessing.py +0 -91
  93. py_neuromodulation/nm_fooof.py +0 -159
  94. py_neuromodulation/nm_generator.py +0 -37
  95. py_neuromodulation/nm_hjorth_raw.py +0 -73
  96. py_neuromodulation/nm_kalmanfilter.py +0 -58
  97. py_neuromodulation/nm_linelength.py +0 -33
  98. py_neuromodulation/nm_mne_connectivity.py +0 -112
  99. py_neuromodulation/nm_nolds.py +0 -93
  100. py_neuromodulation/nm_normalization.py +0 -214
  101. py_neuromodulation/nm_oscillatory.py +0 -448
  102. py_neuromodulation/nm_run_analysis.py +0 -435
  103. py_neuromodulation/nm_settings.json +0 -338
  104. py_neuromodulation/nm_settings.py +0 -68
  105. py_neuromodulation/nm_sharpwaves.py +0 -401
  106. py_neuromodulation/nm_stream_abc.py +0 -218
  107. py_neuromodulation/nm_stream_offline.py +0 -359
  108. py_neuromodulation/utils/_logging.py +0 -24
  109. py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
@@ -1,612 +1,627 @@
1
- from scipy import stats
2
- import os
3
- import numpy as np
4
- from matplotlib import pyplot as plt
5
- from matplotlib import gridspec
6
- from typing import Optional
7
- import seaborn as sb
8
- import pandas as pd
9
- import logging
10
-
11
- logger = logging.getLogger("PynmLogger")
12
-
13
- from py_neuromodulation import nm_IO, nm_stats
14
-
15
-
16
- def plot_df_subjects(
17
- df,
18
- x_col="sub",
19
- y_col="performance_test",
20
- hue=None,
21
- title="channel specific performances",
22
- PATH_SAVE: str = None,
23
- figsize_tuple: tuple = (5, 3),
24
- ):
25
- alpha_box = 0.4
26
- plt.figure(figsize=figsize_tuple, dpi=300)
27
- sb.boxplot(
28
- x=x_col,
29
- y=y_col,
30
- hue=hue,
31
- data=df,
32
- palette="viridis",
33
- showmeans=False,
34
- boxprops=dict(alpha=alpha_box),
35
- showcaps=True,
36
- showbox=True,
37
- showfliers=False,
38
- notch=False,
39
- whiskerprops={"linewidth": 2, "zorder": 10, "alpha": alpha_box},
40
- capprops={"alpha": alpha_box},
41
- medianprops=dict(
42
- linestyle="-", linewidth=5, color="gray", alpha=alpha_box
43
- ),
44
- )
45
-
46
- ax = sb.stripplot(
47
- x=x_col,
48
- y=y_col,
49
- hue=hue,
50
- data=df,
51
- palette="viridis",
52
- dodge=True,
53
- s=5,
54
- )
55
-
56
- if hue is not None:
57
- n_hues = df[hue].nunique()
58
-
59
- handles, labels = ax.get_legend_handles_labels()
60
- l = plt.legend(
61
- handles[0:n_hues],
62
- labels[0:n_hues],
63
- bbox_to_anchor=(1.05, 1),
64
- loc=2,
65
- title=hue,
66
- borderaxespad=0.0,
67
- )
68
- plt.title(title)
69
- plt.ylabel(y_col)
70
- plt.xticks(rotation=90)
71
- if PATH_SAVE is not None:
72
- plt.savefig(
73
- PATH_SAVE,
74
- bbox_inches="tight",
75
- )
76
- # plt.show()
77
- return plt.gca()
78
-
79
-
80
- def plot_epoch(
81
- X_epoch: np.array,
82
- y_epoch: np.array,
83
- feature_names: list,
84
- z_score: bool = None,
85
- epoch_len: int = 4,
86
- sfreq: int = 10,
87
- str_title: str = None,
88
- str_label: str = None,
89
- ytick_labelsize: float = None,
90
- ):
91
- if z_score is None:
92
- X_epoch = stats.zscore(
93
- np.nan_to_num(np.nanmean(np.squeeze(X_epoch), axis=0)),
94
- axis=0,
95
- nan_policy="omit",
96
- ).T
97
- y_epoch = np.stack(np.array(y_epoch))
98
- plt.figure(figsize=(6, 6))
99
- plt.subplot(211)
100
- plt.imshow(X_epoch, aspect="auto")
101
- plt.yticks(
102
- np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize
103
- )
104
- plt.xticks(
105
- np.arange(0, X_epoch.shape[1], 1),
106
- np.round(np.arange(-epoch_len / 2, epoch_len / 2, 1 / sfreq), 2),
107
- rotation=90,
108
- )
109
- plt.gca().invert_yaxis()
110
- plt.xlabel("Time [s]")
111
- plt.title(str_title)
112
-
113
- plt.subplot(212)
114
- for i in range(y_epoch.shape[0]):
115
- plt.plot(y_epoch[i, :], color="black", alpha=0.4)
116
- plt.plot(
117
- y_epoch.mean(axis=0),
118
- color="black",
119
- alpha=1,
120
- linewidth=3.0,
121
- label="mean target",
122
- )
123
- plt.legend()
124
- plt.ylabel("Target")
125
- plt.title(str_label)
126
- plt.xticks(
127
- np.arange(0, X_epoch.shape[1], 1),
128
- np.round(np.arange(-epoch_len / 2, epoch_len / 2, 1 / sfreq), 2),
129
- rotation=90,
130
- )
131
- plt.xlabel("Time [s]")
132
- plt.tight_layout()
133
-
134
-
135
- def reg_plot(
136
- x_col: str, y_col: str, data: pd.DataFrame, out_path_save: str = None
137
- ):
138
- plt.figure(figsize=(4, 4), dpi=300)
139
- rho, p = nm_stats.permutationTestSpearmansRho(
140
- data[x_col],
141
- data[y_col],
142
- False,
143
- "R^2",
144
- 5000,
145
- )
146
- sb.regplot(x=x_col, y=y_col, data=data)
147
- plt.title(f"{y_col}~{x_col} p={np.round(p, 2)} rho={np.round(rho, 2)}")
148
-
149
- if out_path_save is not None:
150
- plt.savefig(
151
- out_path_save,
152
- bbox_inches="tight",
153
- )
154
-
155
-
156
- def plot_bar_performance_per_channel(
157
- ch_names,
158
- performances: dict,
159
- PATH_OUT: str,
160
- sub: str = None,
161
- save_str: str = "ch_comp_bar_plt.png",
162
- performance_metric: str = "Balanced Accuracy",
163
- ):
164
- """
165
- performances dict is output of ml_decode
166
- """
167
- plt.figure(figsize=(4, 3), dpi=300)
168
- if sub is None:
169
- sub = list(performances.keys())[0]
170
- plt.bar(
171
- np.arange(len(ch_names)),
172
- [performances[sub][p]["performance_test"] for p in performances[sub]],
173
- )
174
- plt.xticks(np.arange(len(ch_names)), ch_names, rotation=90)
175
- plt.xlabel("channels")
176
- plt.ylabel(performance_metric)
177
- plt.savefig(
178
- os.path.join(PATH_OUT, save_str),
179
- bbox_inches="tight",
180
- )
181
- plt.close()
182
-
183
-
184
- def plot_corr_matrix(
185
- feature: pd.DataFrame,
186
- feature_file: str = None,
187
- ch_name: str = None,
188
- feature_names: list[str] = None,
189
- show_plot=True,
190
- OUT_PATH: str = None,
191
- feature_name_plt="Features_corr_matr",
192
- save_plot: bool = False,
193
- save_plot_name: str = None,
194
- figsize: tuple[int] = (7, 7),
195
- title: str = None,
196
- cbar_vmin: float = -1,
197
- cbar_vmax: float = 1.0,
198
- ):
199
- # cut out channel name for each column
200
- if ch_name is not None:
201
- feature_col_name = [
202
- i[len(ch_name) + 1 :] for i in feature_names if ch_name in i
203
- ]
204
- else:
205
- feature_col_name = feature.columns
206
-
207
- plt.figure(figsize=figsize)
208
- if feature_names is not None:
209
- corr = feature[feature_names].corr()
210
- else:
211
- corr = feature.corr()
212
- sb.heatmap(
213
- corr,
214
- xticklabels=feature_col_name,
215
- yticklabels=feature_col_name,
216
- vmin=cbar_vmin,
217
- vmax=cbar_vmax,
218
- cmap="viridis",
219
- )
220
- if title is None:
221
- if ch_name is not None:
222
- plt.title("Correlation matrix features channel: " + str(ch_name))
223
- else:
224
- plt.title("Correlation matrix")
225
- else:
226
- plt.title(title)
227
-
228
- # if len(feature_col_name) > 50:
229
- # plt.xticks([])
230
- # plt.yticks([])
231
-
232
- if save_plot and save_plot_name is None:
233
- plt_path = get_plt_path(
234
- OUT_PATH=OUT_PATH,
235
- feature_file=feature_file,
236
- ch_name=ch_name,
237
- str_plt_type=feature_name_plt,
238
- # feature_name=feature_names.__str__, # This here raises an error in os.path.join in line 251
239
- )
240
- if save_plot and save_plot_name is not None:
241
- plt_path = os.path.join(OUT_PATH, save_plot_name)
242
-
243
- if save_plot:
244
- plt.savefig(plt_path, bbox_inches="tight")
245
- logger.info(f"Correlation matrix figure saved to {plt_path}")
246
-
247
- if show_plot is False:
248
- plt.close()
249
-
250
- plt.tight_layout()
251
-
252
- return plt.gca()
253
-
254
-
255
- def plot_feature_series_time(features) -> None:
256
- plt.imshow(features.T, aspect="auto")
257
-
258
-
259
- def get_plt_path(
260
- OUT_PATH: str | None = None,
261
- feature_file: str | None = None,
262
- ch_name: str | None = None,
263
- str_plt_type: str | None = None,
264
- feature_name: str | None = None,
265
- ) -> None:
266
- """[summary]
267
-
268
- Parameters
269
- ----------
270
- OUT_PATH : str, optional
271
- folder of preprocessed runs, by default None
272
- feature_file : str, optional
273
- run_name, by default None
274
- ch_name : str, optional
275
- ch_name, by default None
276
- str_plt_type : str, optional
277
- type of plot, e.g. mov_avg_feature or corr_matr, by default None
278
- feature_name : str, optional
279
- e.g. bandpower, stft, sharpwave_prominence, by default None
280
- """
281
- if None not in (ch_name, OUT_PATH, feature_file):
282
- if feature_name is None:
283
- plt_path = os.path.join(
284
- OUT_PATH,
285
- feature_file,
286
- str_plt_type + "_ch_" + ch_name + ".png",
287
- )
288
- else:
289
- plt_path = os.path.join(
290
- OUT_PATH,
291
- feature_file,
292
- str_plt_type + "_ch_" + ch_name + "_" + feature_name + ".png",
293
- )
294
- elif None not in (OUT_PATH, feature_file) and ch_name is None:
295
- plt_path = os.path.join(
296
- OUT_PATH,
297
- feature_file,
298
- str_plt_type + "_ch_" + feature_name + ".png",
299
- )
300
-
301
- else:
302
- plt_path = os.getcwd() + ".png"
303
- return plt_path
304
-
305
-
306
- def plot_epochs_avg(
307
- X_epoch: np.ndarray,
308
- y_epoch: np.ndarray,
309
- epoch_len: int,
310
- sfreq: int,
311
- feature_names: list[str] = None,
312
- feature_str_add: str = None,
313
- cut_ch_name_cols: bool = True,
314
- ch_name: str = None,
315
- label_name: str = None,
316
- normalize_data: bool = True,
317
- show_plot: bool = True,
318
- save: bool = False,
319
- OUT_PATH: str = None,
320
- feature_file: str = None,
321
- str_title: str = "Movement aligned features",
322
- ytick_labelsize=None,
323
- figsize_x: float = 8,
324
- figsize_y: float = 8,
325
- ) -> None:
326
- # cut channel name of for axis + "_" for more dense plot
327
- if feature_names is None:
328
- if cut_ch_name_cols and None not in (ch_name, feature_names):
329
- feature_names = [
330
- i[len(ch_name) + 1 :]
331
- for i in list(feature_names)
332
- if ch_name in i
333
- ]
334
-
335
- if normalize_data:
336
- X_epoch_mean = stats.zscore(
337
- np.nanmean(np.squeeze(X_epoch), axis=0), axis=0, nan_policy="omit"
338
- ).T
339
- else:
340
- X_epoch_mean = np.nanmean(np.squeeze(X_epoch), axis=0).T
341
-
342
- if len(X_epoch_mean.shape) == 1:
343
- X_epoch_mean = np.expand_dims(X_epoch_mean, axis=0)
344
-
345
- plt.figure(figsize=(figsize_x, figsize_y))
346
- gs = gridspec.GridSpec(2, 1, height_ratios=[2.5, 1])
347
- plt.subplot(gs[0])
348
- plt.imshow(X_epoch_mean, aspect="auto")
349
- plt.yticks(
350
- np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize
351
- )
352
- plt.xticks(
353
- np.arange(0, X_epoch.shape[1], int(X_epoch.shape[1] / 10)),
354
- np.round(np.arange(-epoch_len / 2, epoch_len / 2, epoch_len / 10), 2),
355
- rotation=90,
356
- )
357
- plt.xlabel("Time [s]")
358
- str_title = str_title
359
- if ch_name:
360
- str_title += f" channel: {ch_name}"
361
- plt.title(str_title)
362
-
363
- plt.subplot(gs[1])
364
- for i in range(y_epoch.shape[0]):
365
- plt.plot(y_epoch[i, :], color="black", alpha=0.4)
366
- plt.plot(
367
- y_epoch.mean(axis=0),
368
- color="black",
369
- alpha=1,
370
- linewidth=3.0,
371
- label="mean target",
372
- )
373
- plt.legend()
374
- plt.ylabel("Target")
375
- plt.title(label_name)
376
- plt.xticks(
377
- np.arange(0, X_epoch.shape[1], int(X_epoch.shape[1] / 10)),
378
- np.round(np.arange(-epoch_len / 2, epoch_len / 2, epoch_len / 10), 2),
379
- rotation=90,
380
- )
381
- plt.xlabel("Time [s]")
382
- plt.tight_layout()
383
-
384
- if save:
385
- plt_path = get_plt_path(
386
- OUT_PATH,
387
- feature_file,
388
- ch_name,
389
- str_plt_type="MOV_aligned_features",
390
- feature_name=feature_str_add,
391
- )
392
- plt.savefig(plt_path, bbox_inches="tight")
393
- logger.info(f"Feature epoch average figure saved to: {str(plt_path)}")
394
- if show_plot is False:
395
- plt.close()
396
-
397
-
398
- def plot_grid_elec_3d(
399
- cortex_grid: np.ndarray | None = None,
400
- ecog_strip: np.ndarray | None = None,
401
- grid_color: np.ndarray | None = None,
402
- strip_color: np.ndarray | None = None,
403
- ):
404
- ax = plt.axes(projection="3d")
405
-
406
- if cortex_grid is not None:
407
- grid_color = (
408
- np.ones(cortex_grid.shape[0]) if grid_color is None else grid_color
409
- )
410
- _ = ax.scatter3D(
411
- cortex_grid[:, 0],
412
- cortex_grid[:, 1],
413
- cortex_grid[:, 2],
414
- c=grid_color,
415
- s=300,
416
- alpha=0.8,
417
- cmap="viridis",
418
- )
419
-
420
- if ecog_strip is not None:
421
- strip_color = (
422
- np.ones(ecog_strip.shape[0]) if strip_color is None else strip_color
423
- )
424
- _ = ax.scatter(
425
- ecog_strip[:, 0],
426
- ecog_strip[:, 1],
427
- ecog_strip[:, 2],
428
- c=strip_color,
429
- s=500,
430
- alpha=0.8,
431
- cmap="gray",
432
- marker="o",
433
- )
434
-
435
-
436
- def plot_all_features(
437
- df: pd.DataFrame,
438
- time_limit_low_s: float = None,
439
- time_limit_high_s: float = None,
440
- normalize: bool = True,
441
- ytick_labelsize: int = 4,
442
- clim_low: float = None,
443
- clim_high: float = None,
444
- save: bool = False,
445
- title="all_feature_plt.pdf",
446
- OUT_PATH: str = None,
447
- feature_file: str = None,
448
- ):
449
- if time_limit_high_s is not None:
450
- df = df[df["time"] < time_limit_high_s * 1000]
451
- if time_limit_low_s is not None:
452
- df = df[df["time"] > time_limit_low_s * 1000]
453
-
454
- cols_plt = [c for c in df.columns if c != "time"]
455
- if normalize is True:
456
- data_plt = stats.zscore(df[cols_plt], nan_policy="omit")
457
- else:
458
- data_plt = df[cols_plt]
459
-
460
- plt.figure() # figsize=(7, 5), dpi=300
461
- plt.imshow(data_plt.T, aspect="auto")
462
- plt.xlabel("Time [s]")
463
- plt.ylabel("Feature Names")
464
- plt.yticks(np.arange(len(cols_plt)), cols_plt, size=ytick_labelsize)
465
-
466
- tick_num = np.arange(0, df.shape[0], int(df.shape[0] / 10))
467
- tick_labels = np.array(np.rint(df["time"].iloc[tick_num] / 1000), dtype=int)
468
- plt.xticks(tick_num, tick_labels)
469
-
470
- plt.title(f"Feature Plot {feature_file}")
471
-
472
- if clim_low is not None:
473
- plt.clim(vmin=clim_low)
474
- if clim_high is not None:
475
- plt.clim(vmax=clim_high)
476
-
477
- plt.colorbar()
478
- plt.tight_layout()
479
-
480
- if save is True:
481
- plt_path = os.path.join(OUT_PATH, feature_file, title)
482
- plt.savefig(plt_path, bbox_inches="tight")
483
-
484
-
485
- class NM_Plot:
486
- def __init__(
487
- self,
488
- ecog_strip: np.ndarray | None = None,
489
- grid_cortex: np.ndarray | None = None,
490
- grid_subcortex: np.ndarray | None = None,
491
- sess_right: Optional[bool] = False,
492
- proj_matrix_cortex: np.ndarray | None = None,
493
- ) -> None:
494
- self.grid_cortex = grid_cortex
495
- self.grid_subcortex = grid_subcortex
496
- self.ecog_strip = ecog_strip
497
- self.sess_right = sess_right
498
- self.proj_matrix_cortex = proj_matrix_cortex
499
-
500
- (
501
- self.faces,
502
- self.vertices,
503
- self.grid,
504
- self.stn_surf,
505
- self.x_ver,
506
- self.y_ver,
507
- self.x_ecog,
508
- self.y_ecog,
509
- self.z_ecog,
510
- self.x_stn,
511
- self.y_stn,
512
- self.z_stn,
513
- ) = nm_IO.read_plot_modules()
514
-
515
- def plot_grid_elec_3d(self) -> None:
516
- plot_grid_elec_3d(np.array(self.grid_cortex), np.array(self.ecog_strip))
517
-
518
- def plot_cortex(
519
- self,
520
- grid_cortex: Optional[np.ndarray] = None,
521
- grid_color: Optional[np.ndarray] = None,
522
- ecog_strip: Optional[np.ndarray] = None,
523
- strip_color: Optional[np.ndarray] = None,
524
- sess_right: Optional[bool] = None,
525
- save: bool = False,
526
- OUT_PATH: str = None,
527
- feature_file: str = None,
528
- feature_str_add: str = None,
529
- show_plot: bool = True,
530
- title: str = "Cortical grid",
531
- set_clim: bool = True,
532
- lower_clim: float = 0.5,
533
- upper_clim: float = 0.7,
534
- cbar_label: str = "Balanced Accuracy",
535
- ):
536
- """Plot MNI brain including selected MNI cortical projection grid + used strip ECoG electrodes
537
- Colorcoded by grid_color
538
- """
539
-
540
- if grid_cortex is None:
541
- if type(self.grid_cortex) is pd.DataFrame:
542
- grid_cortex = np.array(self.grid_cortex)
543
- else:
544
- grid_cortex = self.grid_cortex
545
-
546
- if ecog_strip is None:
547
- ecog_strip = self.ecog_strip
548
-
549
- if sess_right is True:
550
- grid_cortex[0, :] = grid_cortex[0, :] * -1
551
-
552
- fig, axes = plt.subplots(1, 1, facecolor=(1, 1, 1), figsize=(14, 9))
553
- axes.scatter(self.x_ecog, self.y_ecog, c="gray", s=0.01)
554
- axes.axes.set_aspect("equal", anchor="C")
555
-
556
- if grid_cortex is not None:
557
- grid_color = (
558
- np.ones(grid_cortex.shape[0])
559
- if grid_color is None
560
- else grid_color
561
- )
562
-
563
- pos_ecog = axes.scatter(
564
- grid_cortex[:, 0],
565
- grid_cortex[:, 1],
566
- c=grid_color,
567
- s=150,
568
- alpha=0.8,
569
- cmap="viridis",
570
- label="grid points",
571
- )
572
- if set_clim:
573
- pos_ecog.set_clim(lower_clim, upper_clim)
574
- if ecog_strip is not None:
575
- strip_color = (
576
- np.ones(ecog_strip.shape[0])
577
- if strip_color is None
578
- else strip_color
579
- )
580
-
581
- pos_ecog = axes.scatter(
582
- ecog_strip[:, 0],
583
- ecog_strip[:, 1],
584
- c=strip_color,
585
- s=400,
586
- alpha=0.8,
587
- cmap="viridis",
588
- marker="x",
589
- label="ecog electrode",
590
- )
591
- plt.axis("off")
592
- plt.legend()
593
- plt.title(title)
594
- if set_clim:
595
- pos_ecog.set_clim(lower_clim, upper_clim)
596
- cbar = fig.colorbar(pos_ecog)
597
- cbar.set_label(cbar_label)
598
-
599
- if save:
600
- plt_path = get_plt_path(
601
- OUT_PATH,
602
- feature_file,
603
- ch_name=None,
604
- str_plt_type="PLOT_CORTEX",
605
- feature_name=feature_str_add,
606
- )
607
- plt.savefig(plt_path, bbox_inches="tight")
608
- logger.info(
609
- f"Feature epoch average figure saved to: {str(plt_path)}"
610
- )
611
- if show_plot is False:
612
- plt.close()
1
+ import numpy as np
2
+ import pandas as pd
3
+ from matplotlib import pyplot as plt
4
+ from matplotlib import gridspec
5
+ import seaborn as sb
6
+ from pathlib import PurePath
7
+ from py_neuromodulation import logger, PYNM_DIR
8
+ from py_neuromodulation.utils.types import _PathLike
9
+
10
+
11
+ def plot_df_subjects(
12
+ df,
13
+ x_col="sub",
14
+ y_col="performance_test",
15
+ hue=None,
16
+ title="channel specific performances",
17
+ PATH_SAVE: _PathLike = "",
18
+ figsize_tuple: tuple[float, float] = (5, 3),
19
+ ):
20
+ alpha_box = 0.4
21
+ plt.figure(figsize=figsize_tuple, dpi=300)
22
+ sb.boxplot(
23
+ x=x_col,
24
+ y=y_col,
25
+ hue=hue,
26
+ data=df,
27
+ palette="viridis",
28
+ showmeans=False,
29
+ boxprops=dict(alpha=alpha_box),
30
+ showcaps=True,
31
+ showbox=True,
32
+ showfliers=False,
33
+ notch=False,
34
+ whiskerprops={"linewidth": 2, "zorder": 10, "alpha": alpha_box},
35
+ capprops={"alpha": alpha_box},
36
+ medianprops=dict(linestyle="-", linewidth=5, color="gray", alpha=alpha_box),
37
+ )
38
+
39
+ ax = sb.stripplot(
40
+ x=x_col,
41
+ y=y_col,
42
+ hue=hue,
43
+ data=df,
44
+ palette="viridis",
45
+ dodge=True,
46
+ s=5,
47
+ )
48
+
49
+ if hue is not None:
50
+ n_hues = df[hue].nunique()
51
+
52
+ handles, labels = ax.get_legend_handles_labels()
53
+ plt.legend(
54
+ handles[0:n_hues],
55
+ labels[0:n_hues],
56
+ bbox_to_anchor=(1.05, 1),
57
+ loc=2,
58
+ title=hue,
59
+ borderaxespad=0.0,
60
+ )
61
+ plt.title(title)
62
+ plt.ylabel(y_col)
63
+ plt.xticks(rotation=90)
64
+ if PATH_SAVE:
65
+ plt.savefig(
66
+ PATH_SAVE,
67
+ bbox_inches="tight",
68
+ )
69
+ # plt.show()
70
+ return plt.gca()
71
+
72
+
73
+ def plot_epoch(
74
+ X_epoch: np.ndarray,
75
+ y_epoch: np.ndarray,
76
+ feature_names: list,
77
+ z_score: bool | None = None,
78
+ epoch_len: int = 4,
79
+ sfreq: int = 10,
80
+ str_title: str = "",
81
+ str_label: str = "",
82
+ ytick_labelsize: float | None = None,
83
+ ):
84
+ from scipy.stats import zscore
85
+
86
+ if z_score is None:
87
+ X_epoch = zscore(
88
+ np.nan_to_num(np.nanmean(np.squeeze(X_epoch), axis=0)),
89
+ axis=0,
90
+ nan_policy="omit",
91
+ ).T
92
+ y_epoch = np.stack([np.array(y_epoch)])
93
+ plt.figure(figsize=(6, 6))
94
+ plt.subplot(211)
95
+ plt.imshow(X_epoch, aspect="auto")
96
+ plt.yticks(np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize)
97
+ plt.xticks(
98
+ np.arange(0, X_epoch.shape[1], 1),
99
+ np.round(np.arange(-epoch_len / 2, epoch_len / 2, 1 / sfreq), 2),
100
+ rotation=90,
101
+ )
102
+ plt.gca().invert_yaxis()
103
+ plt.xlabel("Time [s]")
104
+ plt.title(str_title)
105
+
106
+ plt.subplot(212)
107
+ for i in range(y_epoch.shape[0]):
108
+ plt.plot(y_epoch[i, :], color="black", alpha=0.4)
109
+ plt.plot(
110
+ y_epoch.mean(axis=0),
111
+ color="black",
112
+ alpha=1,
113
+ linewidth=3.0,
114
+ label="mean target",
115
+ )
116
+ plt.legend()
117
+ plt.ylabel("Target")
118
+ plt.title(str_label)
119
+ plt.xticks(
120
+ np.arange(0, X_epoch.shape[1], 1),
121
+ np.round(np.arange(-epoch_len / 2, epoch_len / 2, 1 / sfreq), 2),
122
+ rotation=90,
123
+ )
124
+ plt.xlabel("Time [s]")
125
+ plt.tight_layout()
126
+
127
+
128
+ def reg_plot(
129
+ x_col: str, y_col: str, data: pd.DataFrame, out_path_save: str | None = None
130
+ ):
131
+ from py_neuromodulation.analysis.stats import permutationTestSpearmansRho
132
+
133
+ plt.figure(figsize=(4, 4), dpi=300)
134
+ rho, p = permutationTestSpearmansRho(
135
+ data[x_col],
136
+ data[y_col],
137
+ False,
138
+ "R^2",
139
+ 5000,
140
+ )
141
+ sb.regplot(x=x_col, y=y_col, data=data)
142
+ plt.title(f"{y_col}~{x_col} p={np.round(p, 2)} rho={np.round(rho, 2)}")
143
+
144
+ if out_path_save is not None:
145
+ plt.savefig(
146
+ out_path_save,
147
+ bbox_inches="tight",
148
+ )
149
+
150
+
151
+ def plot_bar_performance_per_channel(
152
+ ch_names,
153
+ performances: dict,
154
+ PATH_OUT: _PathLike,
155
+ sub: str | None = None,
156
+ save_str: str = "ch_comp_bar_plt.png",
157
+ performance_metric: str = "Balanced Accuracy",
158
+ ):
159
+ """
160
+ performances dict is output of ml_decode
161
+ """
162
+ plt.figure(figsize=(4, 3), dpi=300)
163
+ if sub is None:
164
+ sub = list(performances.keys())[0]
165
+ plt.bar(
166
+ np.arange(len(ch_names)),
167
+ [performances[sub][p]["performance_test"] for p in performances[sub]],
168
+ )
169
+ plt.xticks(np.arange(len(ch_names)), ch_names, rotation=90)
170
+ plt.xlabel("channels")
171
+ plt.ylabel(performance_metric)
172
+ plt.savefig(
173
+ PurePath(PATH_OUT, save_str),
174
+ bbox_inches="tight",
175
+ )
176
+ plt.close()
177
+
178
+
179
+ def plot_corr_matrix(
180
+ feature: pd.DataFrame,
181
+ feature_file: _PathLike = "",
182
+ ch_name: str = "",
183
+ feature_names: list[str] = [],
184
+ show_plot=True,
185
+ OUT_PATH: _PathLike = "",
186
+ feature_name_plt="Features_corr_matr",
187
+ save_plot: bool = False,
188
+ save_plot_name: str = "",
189
+ figsize: tuple[float, float] = (7, 7),
190
+ title: str = "",
191
+ cbar_vmin: float = -1,
192
+ cbar_vmax: float = 1.0,
193
+ ):
194
+ # cut out channel name for each column
195
+ if not ch_name:
196
+ feature_col_name = [
197
+ i[len(ch_name) + 1 :] for i in feature_names if ch_name in i
198
+ ]
199
+ else:
200
+ feature_col_name = feature.columns
201
+
202
+ plt.figure(figsize=figsize)
203
+ if (
204
+ len(feature_names) > 0
205
+ ): # Checking length to accomodate for tests passing a pandas Index
206
+ corr = feature[feature_names].corr()
207
+ else:
208
+ corr = feature.corr()
209
+ sb.heatmap(
210
+ corr,
211
+ xticklabels=feature_col_name,
212
+ yticklabels=feature_col_name,
213
+ vmin=cbar_vmin,
214
+ vmax=cbar_vmax,
215
+ cmap="viridis",
216
+ )
217
+ if not title:
218
+ if ch_name:
219
+ plt.title("Correlation matrix features channel: " + str(ch_name))
220
+ else:
221
+ plt.title("Correlation matrix")
222
+ else:
223
+ plt.title(title)
224
+
225
+ # if len(feature_col_name) > 50:
226
+ # plt.xticks([])
227
+ # plt.yticks([])
228
+
229
+ if save_plot:
230
+ plt_path = (
231
+ PurePath(OUT_PATH, save_plot_name)
232
+ if save_plot_name
233
+ else get_plt_path(
234
+ OUT_PATH=OUT_PATH,
235
+ feature_file=feature_file,
236
+ ch_name=ch_name,
237
+ str_plt_type=feature_name_plt,
238
+ feature_name="_".join(feature_names),
239
+ )
240
+ )
241
+
242
+ plt.savefig(plt_path, bbox_inches="tight")
243
+ logger.info(f"Correlation matrix figure saved to {plt_path}")
244
+
245
+ if not show_plot:
246
+ plt.close()
247
+
248
+ plt.tight_layout()
249
+
250
+ return plt.gca()
251
+
252
+
253
+ def plot_feature_series_time(features) -> None:
254
+ plt.imshow(features.T, aspect="auto")
255
+
256
+
257
+ def get_plt_path(
258
+ OUT_PATH: _PathLike = "",
259
+ feature_file: str = "",
260
+ ch_name: str = "",
261
+ str_plt_type: str = "",
262
+ feature_name: str = "",
263
+ ) -> _PathLike:
264
+ """[summary]
265
+
266
+ Parameters
267
+ ----------
268
+ OUT_PATH : str, optional
269
+ folder of preprocessed runs, by default None
270
+ feature_file : str, optional
271
+ run_name, by default None
272
+ ch_name : str, optional
273
+ ch_name, by default None
274
+ str_plt_type : str, optional
275
+ type of plot, e.g. mov_avg_feature or corr_matr, by default None
276
+ feature_name : str, optional
277
+ e.g. bandpower, stft, sharpwave_prominence, by default None
278
+ """
279
+ filename = (
280
+ str_plt_type
281
+ + (("_ch_" + ch_name) if ch_name else "")
282
+ + (("_" + feature_name) if feature_name else "")
283
+ + ".png"
284
+ )
285
+
286
+ return PurePath(OUT_PATH, feature_file, filename)
287
+
288
+
289
def plot_epochs_avg(
    X_epoch: np.ndarray,
    y_epoch: np.ndarray,
    epoch_len: int,
    sfreq: int,
    feature_names: list[str] | None = None,
    feature_str_add: str = "",
    cut_ch_name_cols: bool = True,
    ch_name: str = "",
    label_name: str = "",
    normalize_data: bool = True,
    show_plot: bool = True,
    save: bool = False,
    OUT_PATH: _PathLike = "",
    feature_file: str = "",
    str_title: str = "Movement aligned features",
    ytick_labelsize=None,
    figsize_x: float = 8,
    figsize_y: float = 8,
) -> None:
    """Plot epoch-averaged features above the individual and mean target traces.

    The top panel shows ``nanmean(squeeze(X_epoch), axis=0)`` (optionally
    z-scored along axis 0) as an image, one row per feature. The bottom panel
    overlays every epoch of ``y_epoch`` plus their mean.

    Parameters
    ----------
    X_epoch : np.ndarray
        Epoched features, averaged over axis 0; axis 1 is used as the time
        axis for the x ticks. Presumably (epochs, samples, features) — TODO
        confirm against callers.
    y_epoch : np.ndarray
        Epoched target; each ``y_epoch[i, :]`` is drawn as one trace.
    epoch_len : int
        Epoch length in seconds; x tick labels span
        ``[-epoch_len / 2, epoch_len / 2)`` in ten steps.
    sfreq : int
        Not used by this function; kept for interface compatibility.
    feature_names : sequence of str, optional
        Y tick labels, one per image row. Any sequence (list, numpy array,
        pandas Index) is accepted. By default no labels are drawn.
    feature_str_add : str, optional
        Feature descriptor used in the saved file name.
    cut_ch_name_cols : bool, optional
        If True and ``ch_name`` is non-empty, strip a leading
        ``"<ch_name>_"`` from each feature label for a denser axis.
    ch_name : str, optional
        Channel name, appended to the title and used in the file name.
    label_name : str, optional
        Title of the target panel.
    normalize_data : bool, optional
        Z-score the averaged features (NaNs omitted) before plotting.
    show_plot : bool, optional
        If False, the figure is closed after drawing (and saving).
    save : bool, optional
        Save the figure to ``get_plt_path(OUT_PATH, feature_file, ch_name,
        str_plt_type="MOV_aligned_features", feature_name=feature_str_add)``.
    OUT_PATH, feature_file : optional
        Output folder and run sub-folder for the saved figure.
    str_title : str, optional
        Title of the feature panel.
    ytick_labelsize : optional
        Font size of the y tick labels.
    figsize_x, figsize_y : float, optional
        Figure width and height passed to ``plt.figure``.
    """
    from scipy.stats import zscore

    # Materialise as a list: the truth value of a numpy array / pandas Index
    # is ambiguous, and None replaces the former mutable [] default.
    feature_names = [] if feature_names is None else list(feature_names)

    # Strip the "<ch_name>_" prefix from labels for a denser y axis. Labels
    # without the prefix are kept as-is so that labels stay aligned with the
    # rows of the plotted matrix.
    if cut_ch_name_cols and ch_name:
        prefix = f"{ch_name}_"
        feature_names = [
            name[len(prefix) :]
            if isinstance(name, str) and name.startswith(prefix)
            else name
            for name in feature_names
        ]

    X_epoch_mean = np.nanmean(np.squeeze(X_epoch), axis=0)
    if normalize_data:
        X_epoch_mean = zscore(X_epoch_mean, axis=0, nan_policy="omit")
    X_epoch_mean = X_epoch_mean.T

    if X_epoch_mean.ndim == 1:
        X_epoch_mean = np.expand_dims(X_epoch_mean, axis=0)

    # Ten evenly spaced x ticks shared by both panels. Float positions avoid
    # the zero step (n_samples < 10) and the tick/label count mismatch
    # (n_samples not a multiple of 10) of an integer-step np.arange; for
    # multiples of 10 the positions are unchanged.
    n_samples = X_epoch.shape[1]
    tick_pos = np.linspace(0, n_samples, 10, endpoint=False)
    tick_labels = np.round(
        np.linspace(-epoch_len / 2, epoch_len / 2, 10, endpoint=False), 2
    )

    plt.figure(figsize=(figsize_x, figsize_y))
    gs = gridspec.GridSpec(2, 1, height_ratios=[2.5, 1])

    # Top panel: averaged features.
    plt.subplot(gs[0])
    plt.imshow(X_epoch_mean, aspect="auto")
    plt.yticks(np.arange(len(feature_names)), feature_names, size=ytick_labelsize)
    plt.xticks(tick_pos, tick_labels, rotation=90)
    plt.xlabel("Time [s]")
    plt.title(f"{str_title} channel: {ch_name}" if ch_name else str_title)

    # Bottom panel: every target epoch plus their mean.
    plt.subplot(gs[1])
    for i in range(y_epoch.shape[0]):
        plt.plot(y_epoch[i, :], color="black", alpha=0.4)
    plt.plot(
        y_epoch.mean(axis=0),
        color="black",
        alpha=1,
        linewidth=3.0,
        label="mean target",
    )
    plt.legend()
    plt.ylabel("Target")
    plt.title(label_name)
    plt.xticks(tick_pos, tick_labels, rotation=90)
    plt.xlabel("Time [s]")
    plt.tight_layout()

    if save:
        plt_path = get_plt_path(
            OUT_PATH,
            feature_file,
            ch_name,
            str_plt_type="MOV_aligned_features",
            feature_name=feature_str_add,
        )
        plt.savefig(plt_path, bbox_inches="tight")
        logger.info(f"Feature epoch average figure saved to: {str(plt_path)}")
    if not show_plot:
        plt.close()
377
+
378
+
379
def plot_grid_elec_3d(
    cortex_grid: np.ndarray | None = None,
    ecog_strip: np.ndarray | None = None,
    grid_color: np.ndarray | None = None,
    strip_color: np.ndarray | None = None,
):
    """Draw cortical grid points and ECoG strip contacts on a new 3D axes.

    Parameters
    ----------
    cortex_grid : np.ndarray, optional
        Grid coordinates; columns 0, 1 and 2 are plotted as x, y and z.
        Skipped when None.
    ecog_strip : np.ndarray, optional
        Strip contact coordinates with the same column layout. Skipped when
        None.
    grid_color, strip_color : np.ndarray, optional
        Per-point colour values, mapped through the ``viridis`` (grid) and
        ``gray`` (strip) colormaps. Both default to all ones.
    """
    ax = plt.axes(projection="3d")

    if cortex_grid is not None:
        colors = grid_color
        if colors is None:
            colors = np.ones(cortex_grid.shape[0])
        ax.scatter3D(
            cortex_grid[:, 0],
            cortex_grid[:, 1],
            cortex_grid[:, 2],
            c=colors,
            s=300,
            alpha=0.8,
            cmap="viridis",
        )

    if ecog_strip is not None:
        colors = strip_color
        if colors is None:
            colors = np.ones(ecog_strip.shape[0])
        # Axes3D.scatter takes (xs, ys, zs, ...): the third positional argument
        # is the z coordinate, and the marker size is the separate keyword `s`.
        ax.scatter(
            ecog_strip[:, 0],
            ecog_strip[:, 1],
            ecog_strip[:, 2],
            c=colors,
            s=500,
            alpha=0.8,
            cmap="gray",
            marker="o",
        )
413
+
414
+
415
def plot_all_features(
    df: pd.DataFrame,
    time_limit_low_s: float | None = None,
    time_limit_high_s: float | None = None,
    normalize: bool = True,
    ytick_labelsize: int = 4,
    clim_low: float | None = None,
    clim_high: float | None = None,
    save: bool = False,
    title="all_feature_plt.pdf",
    OUT_PATH: _PathLike = "",
    feature_file: str = "",
):
    """Plot every feature column of ``df`` over time as one image.

    Parameters
    ----------
    df : pd.DataFrame
        Feature table with a ``"time"`` column (compared against seconds
        * 1000 and divided by 1000 for tick labels, i.e. milliseconds). All
        other columns are plotted as image rows.
    time_limit_low_s, time_limit_high_s : float, optional
        Keep only rows strictly inside these time bounds (in seconds).
    normalize : bool, optional
        Z-score each feature column (NaNs omitted) before plotting.
    ytick_labelsize : int, optional
        Font size of the feature-name tick labels.
    clim_low, clim_high : float, optional
        Lower / upper colour limits; left at the image default when None.
    save : bool, optional
        Save the figure to ``OUT_PATH / feature_file / title``.
    title : str, optional
        File name of the saved figure (not the plot title).
    OUT_PATH : str or path-like, optional
        Output folder for the saved figure.
    feature_file : str, optional
        Run name: used in the plot title and as a sub-folder of ``OUT_PATH``.
    """
    from scipy.stats import zscore

    if time_limit_high_s is not None:
        df = df[df["time"] < time_limit_high_s * 1000]
    if time_limit_low_s is not None:
        df = df[df["time"] > time_limit_low_s * 1000]

    cols_plt = [c for c in df.columns if c != "time"]
    if normalize:
        data_plt = zscore(df[cols_plt], nan_policy="omit")
    else:
        data_plt = df[cols_plt]

    plt.figure()  # figsize=(7, 5), dpi=300
    plt.imshow(data_plt.T, aspect="auto")
    plt.xlabel("Time [s]")
    plt.ylabel("Feature Names")
    plt.yticks(np.arange(len(cols_plt)), cols_plt, size=ytick_labelsize)

    # Roughly ten x ticks, labelled in whole seconds. The step is clamped to 1
    # so frames with fewer than 10 rows (e.g. after time filtering) do not
    # produce a zero step, which np.arange rejects.
    n_rows = df.shape[0]
    tick_num = np.arange(0, n_rows, max(n_rows // 10, 1))
    tick_labels = np.array(np.rint(df["time"].iloc[tick_num] / 1000), dtype=int)
    plt.xticks(tick_num, tick_labels)

    plt.title(f"Feature Plot {feature_file}")

    if clim_low is not None:
        plt.clim(vmin=clim_low)
    if clim_high is not None:
        plt.clim(vmax=clim_high)

    plt.colorbar()
    plt.tight_layout()

    if save:
        plt_path = PurePath(OUT_PATH, feature_file, title)
        plt.savefig(plt_path, bbox_inches="tight")
        logger.info(f"All-features figure saved to: {plt_path}")
464
+
465
+
466
def read_plot_modules(
    PATH_PLOT: _PathLike = PYNM_DIR / "plots",
):
    """Load the MATLAB surface and grid files required for brain plots.

    Parameters
    ----------
    PATH_PLOT : str or path-like, optional
        Folder containing ``faces.mat``, ``Vertices.mat``, ``grid.mat`` and
        ``STN_surf.mat``, by default ``PYNM_DIR / "plots"``.

    Returns
    -------
    tuple
        ``(faces, vertices, grid, stn_surf, x_ver, y_ver, x_ecog, y_ecog,
        z_ecog, x_stn, y_stn, z_stn)``:

        - the loaded contents of ``faces.mat``, ``Vertices.mat``, the
          ``"grid"`` entry of ``grid.mat``, and ``STN_surf.mat``;
        - columns 0/1 of every second STN vertex (``x_ver``, ``y_ver``);
        - columns 0/1/2 of all cortex vertices (``x_ecog``, ``y_ecog``,
          ``z_ecog``);
        - columns 0/1/2 of all STN vertices (``x_stn``, ``y_stn``, ``z_stn``).
    """
    from py_neuromodulation.utils.io import loadmat

    def _load(file_name):
        return loadmat(PurePath(PATH_PLOT, file_name))

    faces = _load("faces.mat")
    vertices = _load("Vertices.mat")
    grid = _load("grid.mat")["grid"]
    stn_surf = _load("STN_surf.mat")

    cortex_xyz = vertices["Vertices"]
    stn_xyz = stn_surf["vertices"]

    # Every second STN vertex (x, y only).
    x_ver, y_ver = stn_xyz[::2, 0], stn_xyz[::2, 1]
    # All cortex vertices.
    x_ecog, y_ecog, z_ecog = cortex_xyz[:, 0], cortex_xyz[:, 1], cortex_xyz[:, 2]
    # All STN vertices.
    x_stn, y_stn, z_stn = stn_xyz[:, 0], stn_xyz[:, 1], stn_xyz[:, 2]

    return (
        faces,
        vertices,
        grid,
        stn_surf,
        x_ver,
        y_ver,
        x_ecog,
        y_ecog,
        z_ecog,
        x_stn,
        y_stn,
        z_stn,
    )
505
+
506
+
507
class NM_Plot:
    """Plot cortical projection grids and ECoG strips on the MNI brain surface.

    The surface meshes and the default grid are loaded once at construction
    via ``read_plot_modules()`` and stored as attributes (``faces``,
    ``vertices``, ``grid``, ``stn_surf``, ``x_ecog``/``y_ecog``/``z_ecog``,
    ...).
    """

    def __init__(
        self,
        ecog_strip: np.ndarray | None = None,
        grid_cortex: np.ndarray | None = None,
        grid_subcortex: np.ndarray | None = None,
        sess_right: bool | None = False,
        proj_matrix_cortex: np.ndarray | None = None,
    ) -> None:
        self.grid_cortex = grid_cortex
        self.grid_subcortex = grid_subcortex
        self.ecog_strip = ecog_strip
        self.sess_right = sess_right
        self.proj_matrix_cortex = proj_matrix_cortex

        (
            self.faces,
            self.vertices,
            self.grid,
            self.stn_surf,
            self.x_ver,
            self.y_ver,
            self.x_ecog,
            self.y_ecog,
            self.z_ecog,
            self.x_stn,
            self.y_stn,
            self.z_stn,
        ) = read_plot_modules()

    def plot_grid_elec_3d(self) -> None:
        """3D scatter of the stored cortical grid and ECoG strip."""
        # np.array(None) is a 0-d object array rather than None, which would
        # crash the helper; only convert attributes that are actually set.
        plot_grid_elec_3d(
            None if self.grid_cortex is None else np.array(self.grid_cortex),
            None if self.ecog_strip is None else np.array(self.ecog_strip),
        )

    def plot_cortex(
        self,
        grid_cortex: np.ndarray | pd.DataFrame | None = None,
        grid_color: np.ndarray | None = None,
        ecog_strip: np.ndarray | None = None,
        strip_color: np.ndarray | None = None,
        sess_right: bool | None = None,
        save: bool = False,
        OUT_PATH: _PathLike = "",
        feature_file: str = "",
        feature_str_add: str = "",
        show_plot: bool = True,
        title: str = "Cortical grid",
        set_clim: bool = True,
        lower_clim: float = 0.5,
        upper_clim: float = 0.7,
        cbar_label: str = "Balanced Accuracy",
    ):
        """Plot the MNI brain with a cortical projection grid and ECoG strip.

        Points are drawn in the x/y plane on top of the cortex vertices. Grid
        points are colour-coded by ``grid_color`` and strip contacts by
        ``strip_color``; both default to all ones.

        Parameters
        ----------
        grid_cortex : array-like or pd.DataFrame, optional
            Grid coordinates, one row per point; columns 0 and 1 are used as
            x and y. Defaults to ``self.grid_cortex``.
        grid_color : np.ndarray, optional
            Per-grid-point colour values (``viridis`` colormap).
        ecog_strip : array-like, optional
            Strip contact coordinates, same layout. Defaults to
            ``self.ecog_strip``.
        strip_color : np.ndarray, optional
            Per-contact colour values (``viridis`` colormap).
        sess_right : bool, optional
            If True, the grid's x coordinates are negated (mirrored) before
            plotting. Neither ``self.grid_cortex`` nor the caller's array is
            modified.
        save : bool, optional
            Save the figure to ``get_plt_path(OUT_PATH, feature_file,
            str_plt_type="PLOT_CORTEX", feature_name=feature_str_add)``.
        OUT_PATH, feature_file, feature_str_add : optional
            Components of the saved file path.
        show_plot : bool, optional
            If False, the figure is closed after drawing (and saving).
        title : str, optional
            Plot title.
        set_clim : bool, optional
            Apply ``(lower_clim, upper_clim)`` as colour limits.
        lower_clim, upper_clim : float, optional
            Colour limits used when ``set_clim`` is True.
        cbar_label : str, optional
            Colour-bar label.
        """
        if grid_cortex is None:
            grid_cortex = self.grid_cortex
        if ecog_strip is None:
            ecog_strip = self.ecog_strip

        # Convert DataFrames / nested lists so [:, i] column indexing works.
        if grid_cortex is not None:
            grid_cortex = np.asarray(grid_cortex)
            if sess_right:
                # Mirror across the midline by negating the x column of every
                # point. Work on a copy so the stored grid is not flipped
                # again on every call.
                grid_cortex = np.array(grid_cortex, dtype=float)
                grid_cortex[:, 0] *= -1
        if ecog_strip is not None:
            ecog_strip = np.asarray(ecog_strip)

        fig, axes = plt.subplots(1, 1, facecolor=(1, 1, 1), figsize=(14, 9))
        axes.scatter(self.x_ecog, self.y_ecog, c="gray", s=0.01)
        axes.axes.set_aspect("equal", anchor="C")

        # Last scatter drawn; it drives the colour bar.
        pos_ecog = None

        if grid_cortex is not None:
            if grid_color is None:
                grid_color = np.ones(grid_cortex.shape[0])

            pos_ecog = axes.scatter(
                grid_cortex[:, 0],
                grid_cortex[:, 1],
                c=grid_color,
                s=150,
                alpha=0.8,
                cmap="viridis",
                label="grid points",
            )
            if set_clim:
                pos_ecog.set_clim(lower_clim, upper_clim)

        if ecog_strip is not None:
            if strip_color is None:
                strip_color = np.ones(ecog_strip.shape[0])

            pos_ecog = axes.scatter(
                ecog_strip[:, 0],
                ecog_strip[:, 1],
                c=strip_color,
                s=400,
                alpha=0.8,
                cmap="viridis",
                marker="x",
                label="ecog electrode",
            )

        plt.axis("off")
        plt.legend()
        plt.title(title)

        # Without any grid or strip there is no mappable to attach a colour
        # bar to (previously an UnboundLocalError).
        if pos_ecog is not None:
            if set_clim:
                pos_ecog.set_clim(lower_clim, upper_clim)
            cbar = fig.colorbar(pos_ecog)
            cbar.set_label(cbar_label)

        if save:
            plt_path = get_plt_path(
                OUT_PATH,
                feature_file,
                str_plt_type="PLOT_CORTEX",
                feature_name=feature_str_add,
            )
            plt.savefig(plt_path, bbox_inches="tight")
            logger.info(f"Cortex plot figure saved to: {str(plt_path)}")
        if not show_plot:
            plt.close()