py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
  3. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
  4. py_neuromodulation/FieldTrip.py +589 -589
  5. py_neuromodulation/__init__.py +74 -13
  6. py_neuromodulation/_write_example_dataset_helper.py +83 -65
  7. py_neuromodulation/data/README +6 -6
  8. py_neuromodulation/data/dataset_description.json +8 -8
  9. py_neuromodulation/data/participants.json +32 -32
  10. py_neuromodulation/data/participants.tsv +2 -2
  11. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
  12. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
  13. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
  14. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
  15. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
  18. py_neuromodulation/grid_cortex.tsv +40 -40
  19. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  20. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  21. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  22. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  23. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  24. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  25. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  26. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  27. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  28. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  29. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  30. py_neuromodulation/nm_IO.py +413 -417
  31. py_neuromodulation/nm_RMAP.py +496 -531
  32. py_neuromodulation/nm_analysis.py +993 -1074
  33. py_neuromodulation/nm_artifacts.py +30 -25
  34. py_neuromodulation/nm_bispectra.py +154 -168
  35. py_neuromodulation/nm_bursts.py +292 -198
  36. py_neuromodulation/nm_coherence.py +251 -205
  37. py_neuromodulation/nm_database.py +149 -0
  38. py_neuromodulation/nm_decode.py +918 -992
  39. py_neuromodulation/nm_define_nmchannels.py +300 -302
  40. py_neuromodulation/nm_features.py +144 -116
  41. py_neuromodulation/nm_filter.py +219 -219
  42. py_neuromodulation/nm_filter_preprocessing.py +79 -91
  43. py_neuromodulation/nm_fooof.py +139 -159
  44. py_neuromodulation/nm_generator.py +45 -37
  45. py_neuromodulation/nm_hjorth_raw.py +52 -73
  46. py_neuromodulation/nm_kalmanfilter.py +71 -58
  47. py_neuromodulation/nm_linelength.py +21 -33
  48. py_neuromodulation/nm_logger.py +66 -0
  49. py_neuromodulation/nm_mne_connectivity.py +149 -112
  50. py_neuromodulation/nm_mnelsl_generator.py +90 -0
  51. py_neuromodulation/nm_mnelsl_stream.py +116 -0
  52. py_neuromodulation/nm_nolds.py +96 -93
  53. py_neuromodulation/nm_normalization.py +173 -214
  54. py_neuromodulation/nm_oscillatory.py +423 -448
  55. py_neuromodulation/nm_plots.py +585 -612
  56. py_neuromodulation/nm_preprocessing.py +83 -0
  57. py_neuromodulation/nm_projection.py +370 -394
  58. py_neuromodulation/nm_rereference.py +97 -95
  59. py_neuromodulation/nm_resample.py +59 -50
  60. py_neuromodulation/nm_run_analysis.py +325 -435
  61. py_neuromodulation/nm_settings.py +289 -68
  62. py_neuromodulation/nm_settings.yaml +244 -0
  63. py_neuromodulation/nm_sharpwaves.py +423 -401
  64. py_neuromodulation/nm_stats.py +464 -480
  65. py_neuromodulation/nm_stream.py +398 -0
  66. py_neuromodulation/nm_stream_abc.py +166 -218
  67. py_neuromodulation/nm_types.py +193 -0
  68. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +29 -26
  69. py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
  70. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -1
  71. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/licenses/LICENSE +21 -21
  72. py_neuromodulation/nm_EpochStream.py +0 -92
  73. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  74. py_neuromodulation/nm_cohortwrapper.py +0 -435
  75. py_neuromodulation/nm_eval_timing.py +0 -239
  76. py_neuromodulation/nm_features_abc.py +0 -39
  77. py_neuromodulation/nm_settings.json +0 -338
  78. py_neuromodulation/nm_stream_offline.py +0 -359
  79. py_neuromodulation/utils/_logging.py +0 -24
  80. py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
@@ -1,612 +1,585 @@
1
- from scipy import stats
2
- import os
3
- import numpy as np
4
- from matplotlib import pyplot as plt
5
- from matplotlib import gridspec
6
- from typing import Optional
7
- import seaborn as sb
8
- import pandas as pd
9
- import logging
10
-
11
- logger = logging.getLogger("PynmLogger")
12
-
13
- from py_neuromodulation import nm_IO, nm_stats
14
-
15
-
16
def plot_df_subjects(
    df,
    x_col="sub",
    y_col="performance_test",
    hue=None,
    title="channel specific performances",
    PATH_SAVE: str = None,
    figsize_tuple: tuple = (5, 3),
):
    """Overlay a faded box plot and a strip plot of ``y_col`` grouped by ``x_col``.

    Parameters
    ----------
    df : pd.DataFrame
        Table holding the ``x_col`` and ``y_col`` columns (and ``hue`` if given).
    x_col, y_col : str
        Column names used for the x and y axes.
    hue : str, optional
        Column used to split groups by color. When given, a legend with one
        entry per unique ``hue`` value is placed to the right of the axes.
    title : str
        Figure title.
    PATH_SAVE : str, optional
        When not None, the figure is saved to this path.
    figsize_tuple : tuple
        Figure size passed to ``plt.figure``.

    Returns
    -------
    matplotlib.axes.Axes
        The current axes.
    """
    faded = 0.4
    common = {"x": x_col, "y": y_col, "hue": hue, "data": df, "palette": "viridis"}

    plt.figure(figsize=figsize_tuple, dpi=300)
    sb.boxplot(
        **common,
        showmeans=False,
        boxprops={"alpha": faded},
        showcaps=True,
        showbox=True,
        showfliers=False,
        notch=False,
        whiskerprops={"linewidth": 2, "zorder": 10, "alpha": faded},
        capprops={"alpha": faded},
        medianprops={"linestyle": "-", "linewidth": 5, "color": "gray", "alpha": faded},
    )
    strip_ax = sb.stripplot(**common, dodge=True, s=5)

    if hue is not None:
        # Keep only the first n_hues legend entries so each hue appears once.
        n_hues = df[hue].nunique()
        handles, labels = strip_ax.get_legend_handles_labels()
        plt.legend(
            handles[:n_hues],
            labels[:n_hues],
            bbox_to_anchor=(1.05, 1),
            loc=2,
            title=hue,
            borderaxespad=0.0,
        )

    plt.title(title)
    plt.ylabel(y_col)
    plt.xticks(rotation=90)
    if PATH_SAVE is not None:
        plt.savefig(PATH_SAVE, bbox_inches="tight")
    return plt.gca()
78
-
79
-
80
def plot_epoch(
    X_epoch: np.array,
    y_epoch: np.array,
    feature_names: list,
    z_score: bool = None,
    epoch_len: int = 4,
    sfreq: int = 10,
    str_title: str = None,
    str_label: str = None,
    ytick_labelsize: float = None,
):
    """Plot epoch-averaged features (top) and the epoched target (bottom).

    Parameters
    ----------
    X_epoch : np.ndarray
        Epoched features. When normalization is applied it is squeezed,
        averaged over axis 0 and z-scored over time; presumably shaped
        (n_epochs, n_times, n_features) -- TODO confirm against callers.
    y_epoch : array-like
        Epoched target, one row per epoch. A single 1-D epoch is accepted.
    feature_names : list
        Y-tick labels, one per feature row of the image.
    z_score : bool, optional
        None (default) or True: average over epochs and z-score.
        False: plot ``X_epoch`` unchanged.
    epoch_len : int
        Epoch length in seconds; time axis is centred on 0.
    sfreq : int
        Sampling rate of the epoched data in Hz.
    str_title, str_label : str, optional
        Titles of the feature and target subplots.
    ytick_labelsize : float, optional
        Font size of the feature tick labels.
    """
    # Previously only ``None`` triggered normalization, so an explicit
    # ``z_score=True`` silently skipped both averaging and z-scoring.
    if z_score is None or z_score:
        X_epoch = stats.zscore(
            np.nan_to_num(np.nanmean(np.squeeze(X_epoch), axis=0)),
            axis=0,
            nan_policy="omit",
        ).T
    # One row per epoch; a single 1-D epoch becomes a (1, n_times) array.
    y_epoch = np.atleast_2d(np.asarray(y_epoch))

    # Derive labels from the tick positions so both always have equal length
    # (a float ``arange`` over [-len/2, len/2) can be off by one element).
    tick_pos = np.arange(X_epoch.shape[1])
    tick_lbl = np.round(tick_pos / sfreq - epoch_len / 2, 2)

    plt.figure(figsize=(6, 6))
    plt.subplot(211)
    plt.imshow(X_epoch, aspect="auto")
    plt.yticks(np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize)
    plt.xticks(tick_pos, tick_lbl, rotation=90)
    plt.gca().invert_yaxis()
    plt.xlabel("Time [s]")
    plt.title(str_title)

    plt.subplot(212)
    for epoch in y_epoch:
        plt.plot(epoch, color="black", alpha=0.4)
    plt.plot(
        y_epoch.mean(axis=0),
        color="black",
        alpha=1,
        linewidth=3.0,
        label="mean target",
    )
    plt.legend()
    plt.ylabel("Target")
    plt.title(str_label)
    plt.xticks(tick_pos, tick_lbl, rotation=90)
    plt.xlabel("Time [s]")
    plt.tight_layout()
133
-
134
-
135
def reg_plot(
    x_col: str, y_col: str, data: pd.DataFrame, out_path_save: str = None
):
    """Regression scatter of ``y_col`` on ``x_col`` with a permutation Spearman test.

    The title reports the p-value and rho returned by
    ``nm_stats.permutationTestSpearmansRho`` (5000 permutations), both rounded
    to two decimals. When ``out_path_save`` is not None the figure is saved there.
    """
    plt.figure(figsize=(4, 4), dpi=300)
    x_values = data[x_col]
    y_values = data[y_col]
    rho, p = nm_stats.permutationTestSpearmansRho(x_values, y_values, False, "R^2", 5000)

    sb.regplot(x=x_col, y=y_col, data=data)
    p_txt = np.round(p, 2)
    rho_txt = np.round(rho, 2)
    plt.title(f"{y_col}~{x_col} p={p_txt} rho={rho_txt}")

    if out_path_save is None:
        return
    plt.savefig(out_path_save, bbox_inches="tight")
154
-
155
-
156
def plot_bar_performance_per_channel(
    ch_names,
    performances: dict,
    PATH_OUT: str,
    sub: str = None,
    save_str: str = "ch_comp_bar_plt.png",
    performance_metric: str = "Balanced Accuracy",
):
    """Save a bar chart of per-channel test performance for one subject.

    ``performances`` is the output of ml_decode: ``performances[sub][key]``
    must provide a ``"performance_test"`` entry. When ``sub`` is None the
    first subject key is used. The figure is written to
    ``PATH_OUT/save_str`` and then closed.
    """
    plt.figure(figsize=(4, 3), dpi=300)
    subject = sub if sub is not None else list(performances)[0]
    subject_perf = performances[subject]
    heights = [subject_perf[key]["performance_test"] for key in subject_perf]

    positions = np.arange(len(ch_names))
    plt.bar(positions, heights)
    plt.xticks(positions, ch_names, rotation=90)
    plt.xlabel("channels")
    plt.ylabel(performance_metric)

    out_file = os.path.join(PATH_OUT, save_str)
    plt.savefig(out_file, bbox_inches="tight")
    plt.close()
182
-
183
-
184
def plot_corr_matrix(
    feature: pd.DataFrame,
    feature_file: str = None,
    ch_name: str = None,
    feature_names: list[str] = None,
    show_plot=True,
    OUT_PATH: str = None,
    feature_name_plt="Features_corr_matr",
    save_plot: bool = False,
    save_plot_name: str = None,
    figsize: tuple[int] = (7, 7),
    title: str = None,
    cbar_vmin: float = -1,
    cbar_vmax: float = 1.0,
):
    """Plot the feature-feature correlation matrix as a heatmap.

    Parameters
    ----------
    feature : pd.DataFrame
        Feature table; columns are correlated pairwise.
    feature_file : str, optional
        Run name used to build the default save path.
    ch_name : str, optional
        If given, the ``"<ch_name>_"`` prefix is stripped from tick labels of
        columns that contain ``ch_name``.
    feature_names : list[str], optional
        Subset of columns to correlate; all columns when None.
    show_plot : bool
        If False the figure is closed after plotting/saving.
    OUT_PATH : str, optional
        Output folder used when saving.
    feature_name_plt : str
        Plot-type name used for the default file name.
    save_plot : bool
        Save the figure to ``OUT_PATH/save_plot_name`` or, when no name is
        given, to the path built by ``get_plt_path``.
    figsize : tuple
        Figure size.
    title : str, optional
        Custom title; a default is built from ``ch_name`` otherwise.
    cbar_vmin, cbar_vmax : float
        Colorbar limits.

    Returns
    -------
    matplotlib.axes.Axes
        The heatmap axes.
    """
    plt.figure(figsize=figsize)
    if feature_names is not None:
        corr = feature[feature_names].corr()
    else:
        corr = feature.corr()

    # Labels are derived from the correlation matrix itself so their count
    # always matches its size (and ``feature_names=None`` no longer crashes).
    if ch_name is not None:
        prefix_len = len(ch_name) + 1
        feature_col_name = [
            str(col)[prefix_len:] if ch_name in str(col) else col
            for col in corr.columns
        ]
    else:
        feature_col_name = corr.columns

    ax = sb.heatmap(
        corr,
        xticklabels=feature_col_name,
        yticklabels=feature_col_name,
        vmin=cbar_vmin,
        vmax=cbar_vmax,
        cmap="viridis",
    )
    if title is not None:
        plt.title(title)
    elif ch_name is not None:
        plt.title("Correlation matrix features channel: " + str(ch_name))
    else:
        plt.title("Correlation matrix")

    # Layout must be finalized before saving/closing; previously it ran after
    # plt.close() and acted on a freshly created empty figure.
    plt.tight_layout()

    if save_plot:
        if save_plot_name is None:
            plt_path = get_plt_path(
                OUT_PATH=OUT_PATH,
                feature_file=feature_file,
                ch_name=ch_name,
                str_plt_type=feature_name_plt,
            )
        else:
            plt_path = os.path.join(OUT_PATH, save_plot_name)
        plt.savefig(plt_path, bbox_inches="tight")
        logger.info(f"Correlation matrix figure saved to {plt_path}")

    if show_plot is False:
        plt.close()

    return ax
253
-
254
-
255
def plot_feature_series_time(features) -> None:
    """Show ``features`` transposed as an auto-aspect image on the current axes."""
    transposed = features.T
    plt.imshow(transposed, aspect="auto")
257
-
258
-
259
- def get_plt_path(
260
- OUT_PATH: str | None = None,
261
- feature_file: str | None = None,
262
- ch_name: str | None = None,
263
- str_plt_type: str | None = None,
264
- feature_name: str | None = None,
265
- ) -> None:
266
- """[summary]
267
-
268
- Parameters
269
- ----------
270
- OUT_PATH : str, optional
271
- folder of preprocessed runs, by default None
272
- feature_file : str, optional
273
- run_name, by default None
274
- ch_name : str, optional
275
- ch_name, by default None
276
- str_plt_type : str, optional
277
- type of plot, e.g. mov_avg_feature or corr_matr, by default None
278
- feature_name : str, optional
279
- e.g. bandpower, stft, sharpwave_prominence, by default None
280
- """
281
- if None not in (ch_name, OUT_PATH, feature_file):
282
- if feature_name is None:
283
- plt_path = os.path.join(
284
- OUT_PATH,
285
- feature_file,
286
- str_plt_type + "_ch_" + ch_name + ".png",
287
- )
288
- else:
289
- plt_path = os.path.join(
290
- OUT_PATH,
291
- feature_file,
292
- str_plt_type + "_ch_" + ch_name + "_" + feature_name + ".png",
293
- )
294
- elif None not in (OUT_PATH, feature_file) and ch_name is None:
295
- plt_path = os.path.join(
296
- OUT_PATH,
297
- feature_file,
298
- str_plt_type + "_ch_" + feature_name + ".png",
299
- )
300
-
301
- else:
302
- plt_path = os.getcwd() + ".png"
303
- return plt_path
304
-
305
-
306
def plot_epochs_avg(
    X_epoch: np.ndarray,
    y_epoch: np.ndarray,
    epoch_len: int,
    sfreq: int,
    feature_names: list[str] = None,
    feature_str_add: str = None,
    cut_ch_name_cols: bool = True,
    ch_name: str = None,
    label_name: str = None,
    normalize_data: bool = True,
    show_plot: bool = True,
    save: bool = False,
    OUT_PATH: str = None,
    feature_file: str = None,
    str_title: str = "Movement aligned features",
    ytick_labelsize=None,
    figsize_x: float = 8,
    figsize_y: float = 8,
) -> None:
    """Plot epoch-averaged features (top) and all target epochs with their mean (bottom).

    Parameters
    ----------
    X_epoch : np.ndarray
        Epoched features, squeezed and averaged over axis 0; presumably
        (n_epochs, n_times, n_features) -- TODO confirm. ``X_epoch.shape[1]``
        is used as the number of time samples.
    y_epoch : np.ndarray
        Target epochs, one row per epoch.
    epoch_len : int
        Epoch length in seconds; the time axis spans [-epoch_len/2, epoch_len/2).
    sfreq : int
        Sampling rate (not used for plotting).
    feature_names : list[str], optional
        Y-tick labels; when ``cut_ch_name_cols`` and ``ch_name`` are set, the
        ``"<ch_name>_"`` prefix is stripped from names containing ``ch_name``.
    feature_str_add, OUT_PATH, feature_file : str, optional
        Forwarded to ``get_plt_path`` when ``save`` is True.
    normalize_data : bool
        z-score the epoch average over time before plotting.
    show_plot : bool
        If False the figure is closed at the end.
    """
    if feature_names is None:
        feature_names = []
    # Cut "<ch_name>_" from the labels for a denser axis. This branch was
    # unreachable before (it only ran when feature_names was None).
    if cut_ch_name_cols and ch_name:
        prefix_len = len(ch_name) + 1
        feature_names = [
            name[prefix_len:] if ch_name in name else name for name in feature_names
        ]

    if normalize_data:
        X_epoch_mean = stats.zscore(
            np.nanmean(np.squeeze(X_epoch), axis=0), axis=0, nan_policy="omit"
        ).T
    else:
        X_epoch_mean = np.nanmean(np.squeeze(X_epoch), axis=0).T

    if len(X_epoch_mean.shape) == 1:
        X_epoch_mean = np.expand_dims(X_epoch_mean, axis=0)

    # Roughly ten ticks; step >= 1 so short epochs (< 10 samples) work, and
    # labels computed from positions so tick and label counts always match.
    n_times = X_epoch.shape[1]
    tick_pos = np.arange(0, n_times, max(1, n_times // 10))
    tick_lbl = np.round(tick_pos * epoch_len / max(n_times, 1) - epoch_len / 2, 2)

    plt.figure(figsize=(figsize_x, figsize_y))
    gs = gridspec.GridSpec(2, 1, height_ratios=[2.5, 1])
    plt.subplot(gs[0])
    plt.imshow(X_epoch_mean, aspect="auto")
    plt.yticks(np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize)
    plt.xticks(tick_pos, tick_lbl, rotation=90)
    plt.xlabel("Time [s]")
    if ch_name:
        str_title += f" channel: {ch_name}"
    plt.title(str_title)

    plt.subplot(gs[1])
    for i in range(y_epoch.shape[0]):
        plt.plot(y_epoch[i, :], color="black", alpha=0.4)
    plt.plot(
        y_epoch.mean(axis=0),
        color="black",
        alpha=1,
        linewidth=3.0,
        label="mean target",
    )
    plt.legend()
    plt.ylabel("Target")
    plt.title(label_name)
    plt.xticks(tick_pos, tick_lbl, rotation=90)
    plt.xlabel("Time [s]")
    plt.tight_layout()

    if save:
        plt_path = get_plt_path(
            OUT_PATH,
            feature_file,
            ch_name,
            str_plt_type="MOV_aligned_features",
            feature_name=feature_str_add,
        )
        plt.savefig(plt_path, bbox_inches="tight")
        logger.info(f"Feature epoch average figure saved to: {str(plt_path)}")
    if show_plot is False:
        plt.close()
396
-
397
-
398
def plot_grid_elec_3d(
    cortex_grid: np.ndarray | None = None,
    ecog_strip: np.ndarray | None = None,
    grid_color: np.ndarray | None = None,
    strip_color: np.ndarray | None = None,
):
    """Scatter cortical grid points and ECoG strip contacts on a new 3-D axes.

    Each point set is an array whose first three columns are x, y, z. Missing
    colors default to all ones. Grid points use the viridis colormap (size
    300); strip contacts use gray circles (size 500).
    """
    ax = plt.axes(projection="3d")
    layers = (
        (cortex_grid, grid_color, {"s": 300, "alpha": 0.8, "cmap": "viridis"}),
        (ecog_strip, strip_color, {"s": 500, "alpha": 0.8, "cmap": "gray", "marker": "o"}),
    )
    for points, colors, style in layers:
        if points is None:
            continue
        if colors is None:
            colors = np.ones(points.shape[0])
        # Axes3D.scatter3D is an alias of Axes3D.scatter (xs, ys, zs, ...).
        ax.scatter3D(points[:, 0], points[:, 1], points[:, 2], c=colors, **style)
434
-
435
-
436
def plot_all_features(
    df: pd.DataFrame,
    time_limit_low_s: float = None,
    time_limit_high_s: float = None,
    normalize: bool = True,
    ytick_labelsize: int = 4,
    clim_low: float = None,
    clim_high: float = None,
    save: bool = False,
    title="all_feature_plt.pdf",
    OUT_PATH: str = None,
    feature_file: str = None,
):
    """Show every non-"time" column of ``df`` over time as an image.

    Parameters
    ----------
    df : pd.DataFrame
        Feature table with a ``"time"`` column; limits are compared against
        ``seconds * 1000``, so ``"time"`` is treated as milliseconds.
    time_limit_low_s, time_limit_high_s : float, optional
        Keep only rows strictly inside these limits (in seconds).
    normalize : bool
        z-score each feature column before plotting.
    ytick_labelsize : int
        Font size of the feature names.
    clim_low, clim_high : float, optional
        Color limits of the image.
    save : bool
        Save to ``OUT_PATH/feature_file/title``.
    """
    if time_limit_high_s is not None:
        df = df[df["time"] < time_limit_high_s * 1000]
    if time_limit_low_s is not None:
        df = df[df["time"] > time_limit_low_s * 1000]

    cols_plt = [c for c in df.columns if c != "time"]
    if normalize is True:
        data_plt = stats.zscore(df[cols_plt], nan_policy="omit")
    else:
        data_plt = df[cols_plt]

    plt.figure()  # figsize=(7, 5), dpi=300
    plt.imshow(data_plt.T, aspect="auto")
    plt.xlabel("Time [s]")
    plt.ylabel("Feature Names")
    plt.yticks(np.arange(len(cols_plt)), cols_plt, size=ytick_labelsize)

    # Roughly ten x-ticks. The step must be >= 1: int(n / 10) was 0 for fewer
    # than 10 rows, which made np.arange raise.
    n_rows = df.shape[0]
    tick_num = np.arange(0, n_rows, max(1, n_rows // 10))
    tick_labels = np.array(np.rint(df["time"].iloc[tick_num] / 1000), dtype=int)
    plt.xticks(tick_num, tick_labels)

    plt.title(f"Feature Plot {feature_file}")

    if clim_low is not None:
        plt.clim(vmin=clim_low)
    if clim_high is not None:
        plt.clim(vmax=clim_high)

    plt.colorbar()
    plt.tight_layout()

    if save is True:
        plt_path = os.path.join(OUT_PATH, feature_file, title)
        plt.savefig(plt_path, bbox_inches="tight")
483
-
484
-
485
class NM_Plot:
    """Plot helper for cortical grids and ECoG strips on the MNI brain outline.

    The brain surface data is loaded once in ``__init__`` through
    ``nm_IO.read_plot_modules()``.
    """

    def __init__(
        self,
        ecog_strip: np.ndarray | None = None,
        grid_cortex: np.ndarray | None = None,
        grid_subcortex: np.ndarray | None = None,
        sess_right: Optional[bool] = False,
        proj_matrix_cortex: np.ndarray | None = None,
    ) -> None:
        self.grid_cortex = grid_cortex
        self.grid_subcortex = grid_subcortex
        self.ecog_strip = ecog_strip
        self.sess_right = sess_right
        self.proj_matrix_cortex = proj_matrix_cortex

        (
            self.faces,
            self.vertices,
            self.grid,
            self.stn_surf,
            self.x_ver,
            self.y_ver,
            self.x_ecog,
            self.y_ecog,
            self.z_ecog,
            self.x_stn,
            self.y_stn,
            self.z_stn,
        ) = nm_IO.read_plot_modules()

    def plot_grid_elec_3d(self) -> None:
        """3-D scatter of the stored cortical grid and ECoG strip."""
        # np.array(None) is a 0-d object array, not None, and would break the
        # None checks of the module-level helper; pass None through instead.
        plot_grid_elec_3d(
            None if self.grid_cortex is None else np.array(self.grid_cortex),
            None if self.ecog_strip is None else np.array(self.ecog_strip),
        )

    def plot_cortex(
        self,
        grid_cortex: Optional[np.ndarray] = None,
        grid_color: Optional[np.ndarray] = None,
        ecog_strip: Optional[np.ndarray] = None,
        strip_color: Optional[np.ndarray] = None,
        sess_right: Optional[bool] = None,
        save: bool = False,
        OUT_PATH: str = None,
        feature_file: str = None,
        feature_str_add: str = None,
        show_plot: bool = True,
        title: str = "Cortical grid",
        set_clim: bool = True,
        lower_clim: float = 0.5,
        upper_clim: float = 0.7,
        cbar_label: str = "Balanced Accuracy",
    ):
        """Plot MNI brain including selected MNI cortical projection grid + used strip ECoG electrodes.

        Points are color-coded by ``grid_color`` / ``strip_color`` (all ones
        when omitted). The stored instance grid/strip are used when the
        arguments are None. ``sess_right=True`` mirrors the grid by negating
        its x coordinate (column 0).
        """
        if grid_cortex is None:
            grid_cortex = self.grid_cortex
        if ecog_strip is None:
            ecog_strip = self.ecog_strip

        # Copy so that mirroring never mutates the caller's or instance's data
        # (also converts DataFrames to arrays).
        if grid_cortex is not None:
            grid_cortex = np.array(grid_cortex)
        if ecog_strip is not None:
            ecog_strip = np.asarray(ecog_strip)

        if sess_right is True and grid_cortex is not None:
            # Negate x of every point; previously row 0 (a single point) was negated.
            grid_cortex[:, 0] = grid_cortex[:, 0] * -1

        fig, axes = plt.subplots(1, 1, facecolor=(1, 1, 1), figsize=(14, 9))
        axes.scatter(self.x_ecog, self.y_ecog, c="gray", s=0.01)
        axes.axes.set_aspect("equal", anchor="C")

        pos_ecog = None
        if grid_cortex is not None:
            grid_color = np.ones(grid_cortex.shape[0]) if grid_color is None else grid_color
            pos_ecog = axes.scatter(
                grid_cortex[:, 0],
                grid_cortex[:, 1],
                c=grid_color,
                s=150,
                alpha=0.8,
                cmap="viridis",
                label="grid points",
            )
            if set_clim:
                pos_ecog.set_clim(lower_clim, upper_clim)
        if ecog_strip is not None:
            strip_color = np.ones(ecog_strip.shape[0]) if strip_color is None else strip_color
            pos_ecog = axes.scatter(
                ecog_strip[:, 0],
                ecog_strip[:, 1],
                c=strip_color,
                s=400,
                alpha=0.8,
                cmap="viridis",
                marker="x",
                label="ecog electrode",
            )
        plt.axis("off")
        plt.legend()
        plt.title(title)
        # Without grid and strip there is no mappable to attach a colorbar to.
        if pos_ecog is not None:
            if set_clim:
                pos_ecog.set_clim(lower_clim, upper_clim)
            cbar = fig.colorbar(pos_ecog)
            cbar.set_label(cbar_label)

        if save:
            plt_path = get_plt_path(
                OUT_PATH,
                feature_file,
                ch_name=None,
                str_plt_type="PLOT_CORTEX",
                feature_name=feature_str_add,
            )
            plt.savefig(plt_path, bbox_inches="tight")
            logger.info(f"Cortex plot figure saved to: {str(plt_path)}")
        if show_plot is False:
            plt.close()
1
+ import numpy as np
2
+ import pandas as pd
3
+ from scipy.stats import zscore as scipy_zscore
4
+ from matplotlib import pyplot as plt
5
+ from matplotlib import gridspec
6
+ import seaborn as sb
7
+ from pathlib import PurePath
8
+
9
+ from py_neuromodulation.nm_types import _PathLike
10
+ from py_neuromodulation import logger
11
+
12
+
13
def plot_df_subjects(
    df,
    x_col="sub",
    y_col="performance_test",
    hue=None,
    title="channel specific performances",
    PATH_SAVE: _PathLike = "",
    figsize_tuple: tuple[float, float] = (5, 3),
):
    """Overlay a faded box plot and a strip plot of ``y_col`` grouped by ``x_col``.

    Parameters
    ----------
    df : pd.DataFrame
        Table holding the ``x_col`` and ``y_col`` columns (and ``hue`` if given).
    x_col, y_col : str
        Column names for the x and y axes.
    hue : str, optional
        Column used to split groups by color. When given, a legend with one
        entry per unique ``hue`` value is placed to the right of the axes.
    title : str
        Figure title.
    PATH_SAVE : path-like
        When non-empty, the figure is saved to this path.
    figsize_tuple : tuple[float, float]
        Figure size passed to ``plt.figure``.

    Returns
    -------
    matplotlib.axes.Axes
        The current axes.
    """
    faded = 0.4
    common = {"x": x_col, "y": y_col, "hue": hue, "data": df, "palette": "viridis"}

    plt.figure(figsize=figsize_tuple, dpi=300)
    sb.boxplot(
        **common,
        showmeans=False,
        boxprops={"alpha": faded},
        showcaps=True,
        showbox=True,
        showfliers=False,
        notch=False,
        whiskerprops={"linewidth": 2, "zorder": 10, "alpha": faded},
        capprops={"alpha": faded},
        medianprops={"linestyle": "-", "linewidth": 5, "color": "gray", "alpha": faded},
    )
    strip_ax = sb.stripplot(**common, dodge=True, s=5)

    if hue is not None:
        # Keep only the first n_hues legend entries so each hue appears once.
        n_hues = df[hue].nunique()
        handles, labels = strip_ax.get_legend_handles_labels()
        plt.legend(
            handles[:n_hues],
            labels[:n_hues],
            bbox_to_anchor=(1.05, 1),
            loc=2,
            title=hue,
            borderaxespad=0.0,
        )

    plt.title(title)
    plt.ylabel(y_col)
    plt.xticks(rotation=90)
    if PATH_SAVE:
        plt.savefig(PATH_SAVE, bbox_inches="tight")
    return plt.gca()
73
+
74
+
75
def plot_epoch(
    X_epoch: np.ndarray,
    y_epoch: np.ndarray,
    feature_names: list,
    z_score: bool | None = None,
    epoch_len: int = 4,
    sfreq: int = 10,
    str_title: str = "",
    str_label: str = "",
    ytick_labelsize: float | None = None,
):
    """Plot epoch-averaged features (top) and the epoched target (bottom).

    Parameters
    ----------
    X_epoch : np.ndarray
        Epoched features. When normalization is applied it is squeezed,
        averaged over axis 0 and z-scored over time; presumably shaped
        (n_epochs, n_times, n_features) -- TODO confirm against callers.
    y_epoch : array-like
        Epoched target, one row per epoch. A single 1-D epoch is accepted.
    feature_names : list
        Y-tick labels, one per feature row of the image.
    z_score : bool, optional
        None (default) or True: average over epochs and z-score.
        False: plot ``X_epoch`` unchanged.
    epoch_len : int
        Epoch length in seconds; time axis is centred on 0.
    sfreq : int
        Sampling rate of the epoched data in Hz.
    str_title, str_label : str
        Titles of the feature and target subplots.
    ytick_labelsize : float, optional
        Font size of the feature tick labels.
    """
    # Previously only ``None`` triggered normalization, so an explicit
    # ``z_score=True`` silently skipped both averaging and z-scoring.
    if z_score is None or z_score:
        X_epoch = scipy_zscore(
            np.nan_to_num(np.nanmean(np.squeeze(X_epoch), axis=0)),
            axis=0,
            nan_policy="omit",
        ).T
    # One row per epoch. ``np.stack([np.array(y)])`` added a leading axis of
    # size 1, so a (n_epochs, n_times) target was plotted transposed.
    y_epoch = np.atleast_2d(np.asarray(y_epoch))

    # Derive labels from tick positions so both always have equal length.
    tick_pos = np.arange(X_epoch.shape[1])
    tick_lbl = np.round(tick_pos / sfreq - epoch_len / 2, 2)

    plt.figure(figsize=(6, 6))
    plt.subplot(211)
    plt.imshow(X_epoch, aspect="auto")
    plt.yticks(np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize)
    plt.xticks(tick_pos, tick_lbl, rotation=90)
    plt.gca().invert_yaxis()
    plt.xlabel("Time [s]")
    plt.title(str_title)

    plt.subplot(212)
    for epoch in y_epoch:
        plt.plot(epoch, color="black", alpha=0.4)
    plt.plot(
        y_epoch.mean(axis=0),
        color="black",
        alpha=1,
        linewidth=3.0,
        label="mean target",
    )
    plt.legend()
    plt.ylabel("Target")
    plt.title(str_label)
    plt.xticks(tick_pos, tick_lbl, rotation=90)
    plt.xlabel("Time [s]")
    plt.tight_layout()
126
+
127
+
128
def reg_plot(
    x_col: str, y_col: str, data: pd.DataFrame, out_path_save: str | None = None
):
    """Regression scatter of ``y_col`` on ``x_col`` with a permutation Spearman test.

    The title reports the p-value and rho from ``permutationTestSpearmansRho``
    (5000 permutations), rounded to two decimals. When ``out_path_save`` is
    not None the figure is saved there.
    """
    # Imported lazily, as in the original implementation.
    from py_neuromodulation.nm_stats import permutationTestSpearmansRho

    plt.figure(figsize=(4, 4), dpi=300)
    x_values = data[x_col]
    y_values = data[y_col]
    rho, p = permutationTestSpearmansRho(x_values, y_values, False, "R^2", 5000)

    sb.regplot(x=x_col, y=y_col, data=data)
    p_txt = np.round(p, 2)
    rho_txt = np.round(rho, 2)
    plt.title(f"{y_col}~{x_col} p={p_txt} rho={rho_txt}")

    if out_path_save is None:
        return
    plt.savefig(out_path_save, bbox_inches="tight")
150
+
151
+
152
def plot_bar_performance_per_channel(
    ch_names,
    performances: dict,
    PATH_OUT: _PathLike,
    sub: str | None = None,
    save_str: str = "ch_comp_bar_plt.png",
    performance_metric: str = "Balanced Accuracy",
):
    """Save a bar chart of per-channel test performance for one subject.

    ``performances`` is the output of ml_decode: ``performances[sub][key]``
    must provide a ``"performance_test"`` entry. When ``sub`` is None the
    first subject key is used. The figure is written to
    ``PATH_OUT/save_str`` and then closed.
    """
    plt.figure(figsize=(4, 3), dpi=300)
    subject = sub if sub is not None else list(performances)[0]
    subject_perf = performances[subject]
    heights = [subject_perf[key]["performance_test"] for key in subject_perf]

    positions = np.arange(len(ch_names))
    plt.bar(positions, heights)
    plt.xticks(positions, ch_names, rotation=90)
    plt.xlabel("channels")
    plt.ylabel(performance_metric)

    out_file = PurePath(PATH_OUT, save_str)
    plt.savefig(out_file, bbox_inches="tight")
    plt.close()
178
+
179
+
180
def plot_corr_matrix(
    feature: pd.DataFrame,
    feature_file: _PathLike = "",
    ch_name: str = "",
    feature_names: list[str] | None = None,
    show_plot=True,
    OUT_PATH: _PathLike = "",
    feature_name_plt="Features_corr_matr",
    save_plot: bool = False,
    save_plot_name: str = "",
    figsize: tuple[float, float] = (7, 7),
    title: str = "",
    cbar_vmin: float = -1,
    cbar_vmax: float = 1.0,
):
    """Plot the feature-feature correlation matrix as a heatmap.

    Parameters
    ----------
    feature : pd.DataFrame
        Feature table; columns are correlated pairwise.
    feature_file : path-like
        Run name used to build the default save path.
    ch_name : str
        If non-empty, the ``"<ch_name>_"`` prefix is stripped from tick labels
        of columns containing ``ch_name``.
    feature_names : list[str] or pandas Index, optional
        Subset of columns to correlate; all columns when empty/None.
    show_plot : bool
        If False the figure is closed after plotting/saving.
    OUT_PATH : path-like
        Output folder used when saving.
    feature_name_plt : str
        Plot-type name used for the default file name.
    save_plot : bool
        Save to ``OUT_PATH/save_plot_name`` or, without a name, to the path
        built by ``get_plt_path``.
    figsize : tuple[float, float]
        Figure size.
    title : str
        Custom title; a default is built from ``ch_name`` otherwise.
    cbar_vmin, cbar_vmax : float
        Colorbar limits.

    Returns
    -------
    matplotlib.axes.Axes
        The heatmap axes.
    """
    if feature_names is None:
        feature_names = []

    plt.figure(figsize=figsize)
    # len() rather than truthiness so a pandas Index can be passed as well.
    if len(feature_names) > 0:
        corr = feature[feature_names].corr()
    else:
        corr = feature.corr()

    # The condition was inverted (``if not ch_name``): with no channel every
    # label lost its first character. Labels come from ``corr`` so their
    # count always matches the matrix size.
    if ch_name:
        prefix_len = len(ch_name) + 1
        feature_col_name = [
            str(col)[prefix_len:] if ch_name in str(col) else col
            for col in corr.columns
        ]
    else:
        feature_col_name = corr.columns

    ax = sb.heatmap(
        corr,
        xticklabels=feature_col_name,
        yticklabels=feature_col_name,
        vmin=cbar_vmin,
        vmax=cbar_vmax,
        cmap="viridis",
    )
    if title:
        plt.title(title)
    elif ch_name:
        plt.title("Correlation matrix features channel: " + str(ch_name))
    else:
        plt.title("Correlation matrix")

    # Finalize the layout before saving/closing; it previously ran after
    # plt.close() and acted on a freshly created empty figure.
    plt.tight_layout()

    if save_plot:
        plt_path = (
            PurePath(OUT_PATH, save_plot_name)
            if save_plot_name
            else get_plt_path(
                OUT_PATH=OUT_PATH,
                feature_file=feature_file,
                ch_name=ch_name,
                str_plt_type=feature_name_plt,
                feature_name="_".join(map(str, feature_names)),
            )
        )
        plt.savefig(plt_path, bbox_inches="tight")
        logger.info(f"Correlation matrix figure saved to {plt_path}")

    if not show_plot:
        plt.close()

    return ax
252
+
253
+
254
def plot_feature_series_time(features) -> None:
    """Show ``features`` transposed as an auto-aspect image on the current axes."""
    transposed = features.T
    plt.imshow(transposed, aspect="auto")
256
+
257
+
258
def get_plt_path(
    OUT_PATH: _PathLike = "",
    feature_file: str = "",
    ch_name: str = "",
    str_plt_type: str = "",
    feature_name: str = "",
) -> _PathLike:
    """Return ``PurePath(OUT_PATH, feature_file, <file name>)`` for saving a figure.

    The file name is ``str_plt_type``, followed by ``"_ch_" + ch_name`` when
    ``ch_name`` is non-empty and ``"_" + feature_name`` when ``feature_name``
    is non-empty, ending in ``.png``.

    Parameters
    ----------
    OUT_PATH : path-like
        Folder of preprocessed runs, by default "".
    feature_file : str
        Run name, by default "".
    ch_name : str
        Channel name, by default "".
    str_plt_type : str
        Type of plot, e.g. mov_avg_feature or corr_matr, by default "".
    feature_name : str
        E.g. bandpower, stft, sharpwave_prominence, by default "".
    """
    pieces = [str_plt_type]
    if ch_name:
        pieces.append("_ch_" + ch_name)
    if feature_name:
        pieces.append("_" + feature_name)
    pieces.append(".png")
    return PurePath(OUT_PATH, feature_file, "".join(pieces))
288
+
289
+
290
def plot_epochs_avg(
    X_epoch: np.ndarray,
    y_epoch: np.ndarray,
    epoch_len: int,
    sfreq: int,
    feature_names: list[str] | None = None,
    feature_str_add: str = "",
    cut_ch_name_cols: bool = True,
    ch_name: str = "",
    label_name: str = "",
    normalize_data: bool = True,
    show_plot: bool = True,
    save: bool = False,
    OUT_PATH: _PathLike = "",
    feature_file: str = "",
    str_title: str = "Movement aligned features",
    ytick_labelsize=None,
    figsize_x: float = 8,
    figsize_y: float = 8,
) -> None:
    """Plot epoch-averaged features (top) and all target epochs with their mean (bottom).

    Parameters
    ----------
    X_epoch : np.ndarray
        Epoched features, squeezed and averaged over axis 0; presumably
        (n_epochs, n_times, n_features) -- TODO confirm. ``X_epoch.shape[1]``
        is used as the number of time samples.
    y_epoch : np.ndarray
        Target epochs, one row per epoch.
    epoch_len : int
        Epoch length in seconds; the time axis spans [-epoch_len/2, epoch_len/2).
    sfreq : int
        Sampling rate (not used for plotting).
    feature_names : list[str], optional
        Y-tick labels; when ``cut_ch_name_cols`` and ``ch_name`` are set, the
        ``"<ch_name>_"`` prefix is stripped from names containing ``ch_name``.
    feature_str_add, OUT_PATH, feature_file
        Forwarded to ``get_plt_path`` when ``save`` is True.
    normalize_data : bool
        z-score the epoch average over time before plotting.
    show_plot : bool
        If False the figure is closed at the end.
    """
    feature_names = [] if feature_names is None else list(feature_names)
    # Cut "<ch_name>_" from the labels for a denser axis. This branch was
    # unreachable before (it only ran when feature_names was empty).
    if cut_ch_name_cols and ch_name:
        prefix_len = len(ch_name) + 1
        feature_names = [
            name[prefix_len:] if ch_name in name else name for name in feature_names
        ]

    if normalize_data:
        X_epoch_mean = scipy_zscore(
            np.nanmean(np.squeeze(X_epoch), axis=0), axis=0, nan_policy="omit"
        ).T
    else:
        X_epoch_mean = np.nanmean(np.squeeze(X_epoch), axis=0).T

    if len(X_epoch_mean.shape) == 1:
        X_epoch_mean = np.expand_dims(X_epoch_mean, axis=0)

    # Roughly ten ticks; step >= 1 so short epochs (< 10 samples) work, and
    # labels computed from positions so tick and label counts always match.
    n_times = X_epoch.shape[1]
    tick_pos = np.arange(0, n_times, max(1, n_times // 10))
    tick_lbl = np.round(tick_pos * epoch_len / max(n_times, 1) - epoch_len / 2, 2)

    plt.figure(figsize=(figsize_x, figsize_y))
    gs = gridspec.GridSpec(2, 1, height_ratios=[2.5, 1])
    plt.subplot(gs[0])
    plt.imshow(X_epoch_mean, aspect="auto")
    plt.yticks(np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize)
    plt.xticks(tick_pos, tick_lbl, rotation=90)
    plt.xlabel("Time [s]")
    if ch_name:
        str_title += f" channel: {ch_name}"
    plt.title(str_title)

    plt.subplot(gs[1])
    for i in range(y_epoch.shape[0]):
        plt.plot(y_epoch[i, :], color="black", alpha=0.4)
    plt.plot(
        y_epoch.mean(axis=0),
        color="black",
        alpha=1,
        linewidth=3.0,
        label="mean target",
    )
    plt.legend()
    plt.ylabel("Target")
    plt.title(label_name)
    plt.xticks(tick_pos, tick_lbl, rotation=90)
    plt.xlabel("Time [s]")
    plt.tight_layout()

    if save:
        plt_path = get_plt_path(
            OUT_PATH,
            feature_file,
            ch_name,
            str_plt_type="MOV_aligned_features",
            feature_name=feature_str_add,
        )
        plt.savefig(plt_path, bbox_inches="tight")
        logger.info(f"Feature epoch average figure saved to: {str(plt_path)}")
    if not show_plot:
        plt.close()
376
+
377
+
378
def plot_grid_elec_3d(
    cortex_grid: np.ndarray | None = None,
    ecog_strip: np.ndarray | None = None,
    grid_color: np.ndarray | None = None,
    strip_color: np.ndarray | None = None,
):
    """Scatter cortical grid points and ECoG strip contacts on a new 3D axes.

    Parameters
    ----------
    cortex_grid : np.ndarray | None
        Grid point coordinates; columns 0, 1 and 2 are used as x, y and z.
        Nothing is drawn for the grid when None.
    ecog_strip : np.ndarray | None
        ECoG contact coordinates; columns 0, 1 and 2 are used as x, y and z.
        Nothing is drawn for the strip when None.
    grid_color : np.ndarray | None
        Per-point color values for the grid (viridis colormap).
        Defaults to all ones.
    strip_color : np.ndarray | None
        Per-contact color values for the strip (gray colormap).
        Defaults to all ones.

    Returns
    -------
    The 3D axes created with ``plt.axes(projection="3d")``, so callers can
    further customize or save the plot.
    """
    ax = plt.axes(projection="3d")

    if cortex_grid is not None:
        if grid_color is None:
            grid_color = np.ones(cortex_grid.shape[0])
        ax.scatter3D(
            cortex_grid[:, 0],
            cortex_grid[:, 1],
            cortex_grid[:, 2],
            c=grid_color,
            s=300,
            alpha=0.8,
            cmap="viridis",
        )

    if ecog_strip is not None:
        if strip_color is None:
            strip_color = np.ones(ecog_strip.shape[0])
        # On a 3D axes the third positional argument is the z coordinate
        # (zs), not the marker size; the size is given explicitly via s.
        ax.scatter3D(
            ecog_strip[:, 0],
            ecog_strip[:, 1],
            ecog_strip[:, 2],
            c=strip_color,
            s=500,
            alpha=0.8,
            cmap="gray",
            marker="o",
        )

    return ax
412
+
413
+
414
def plot_all_features(
    df: pd.DataFrame,
    time_limit_low_s: float | None = None,
    time_limit_high_s: float | None = None,
    normalize: bool = True,
    ytick_labelsize: int = 4,
    clim_low: float | None = None,
    clim_high: float | None = None,
    save: bool = False,
    title="all_feature_plt.pdf",
    OUT_PATH: _PathLike = "",
    feature_file: str = "",
):
    """Plot every non-"time" column of ``df`` as a feature-by-time heatmap.

    Parameters
    ----------
    df : pd.DataFrame
        Feature table with a ``"time"`` column. The time limits (given in
        seconds) are multiplied by 1000 before comparing against it, and tick
        labels divide it by 1000, so ``"time"`` is treated as milliseconds.
    time_limit_low_s, time_limit_high_s : float | None
        Keep only rows strictly inside these limits (seconds); None = no limit.
    normalize : bool
        If True, z-score the feature columns with ``scipy_zscore``
        (NaNs omitted) before plotting.
    ytick_labelsize : int
        Font size of the feature-name y tick labels.
    clim_low, clim_high : float | None
        Optional lower/upper color limits of the heatmap.
    save : bool
        If True, save the figure to ``PurePath(OUT_PATH, feature_file, title)``.
    title : str
        File name used when saving.
    OUT_PATH : _PathLike
        Output directory used when saving.
    feature_file : str
        Sub-folder name used when saving; also shown in the plot title.
    """
    if time_limit_high_s is not None:
        df = df[df["time"] < time_limit_high_s * 1000]
    if time_limit_low_s is not None:
        df = df[df["time"] > time_limit_low_s * 1000]

    cols_plt = [c for c in df.columns if c != "time"]
    if normalize:
        data_plt = scipy_zscore(df[cols_plt], nan_policy="omit")
    else:
        data_plt = df[cols_plt]

    plt.figure()  # figsize=(7, 5), dpi=300
    # Features on the y axis, time samples on the x axis.
    plt.imshow(data_plt.T, aspect="auto")
    plt.xlabel("Time [s]")
    plt.ylabel("Feature Names")
    plt.yticks(np.arange(len(cols_plt)), cols_plt, size=ytick_labelsize)

    # Roughly 10 x ticks. The step is clamped to >= 1: with fewer than 10 rows
    # (or none left after time filtering) n_rows // 10 would be 0, and
    # np.arange rejects a zero step.
    n_rows = df.shape[0]
    tick_num = np.arange(0, n_rows, max(1, n_rows // 10))
    tick_labels = np.array(np.rint(df["time"].iloc[tick_num] / 1000), dtype=int)
    plt.xticks(tick_num, tick_labels)

    plt.title(f"Feature Plot {feature_file}")

    if clim_low is not None:
        plt.clim(vmin=clim_low)
    if clim_high is not None:
        plt.clim(vmax=clim_high)

    plt.colorbar()
    plt.tight_layout()

    if save:
        plt_path = PurePath(OUT_PATH, feature_file, title)
        plt.savefig(plt_path, bbox_inches="tight")
461
+
462
+
463
class NM_Plot:
    """Hold electrode/grid coordinates plus brain-surface data for plotting.

    The surface and background data (faces, vertices, grid, STN surface and
    the scatter coordinates) are loaded once in ``__init__`` through
    ``read_plot_modules``.
    """

    def __init__(
        self,
        ecog_strip: np.ndarray | None = None,
        grid_cortex: np.ndarray | None = None,
        grid_subcortex: np.ndarray | None = None,
        sess_right: bool | None = False,
        proj_matrix_cortex: np.ndarray | None = None,
    ) -> None:
        self.grid_cortex = grid_cortex
        self.grid_subcortex = grid_subcortex
        self.ecog_strip = ecog_strip
        self.sess_right = sess_right
        self.proj_matrix_cortex = proj_matrix_cortex

        # Imported inside __init__ rather than at module level; presumably to
        # avoid an import cycle with nm_IO — TODO confirm.
        from py_neuromodulation.nm_IO import read_plot_modules

        (
            self.faces,
            self.vertices,
            self.grid,
            self.stn_surf,
            self.x_ver,
            self.y_ver,
            self.x_ecog,
            self.y_ecog,
            self.z_ecog,
            self.x_stn,
            self.y_stn,
            self.z_stn,
        ) = read_plot_modules()

    def plot_grid_elec_3d(self) -> None:
        """Draw the stored cortical grid and ECoG strip on a 3D axes.

        Attributes that are None are passed through as None. Wrapping them in
        ``np.array`` would produce a 0-d object array that is "not None", and
        the plotting function would then fail on ``.shape[0]``.
        """
        plot_grid_elec_3d(
            None if self.grid_cortex is None else np.array(self.grid_cortex),
            None if self.ecog_strip is None else np.array(self.ecog_strip),
        )

    def plot_cortex(
        self,
        grid_cortex: np.ndarray | pd.DataFrame | None = None,
        grid_color: np.ndarray | None = None,
        ecog_strip: np.ndarray | None = None,
        strip_color: np.ndarray | None = None,
        sess_right: bool | None = None,
        save: bool = False,
        OUT_PATH: _PathLike = "",
        feature_file: str = "",
        feature_str_add: str = "",
        show_plot: bool = True,
        title: str = "Cortical grid",
        set_clim: bool = True,
        lower_clim: float = 0.5,
        upper_clim: float = 0.7,
        cbar_label: str = "Balanced Accuracy",
    ):
        """Plot the MNI brain outline with a cortical grid and ECoG strip overlaid.

        Coordinates are drawn in the x/y plane (columns 0 and 1) on top of the
        background surface points ``self.x_ecog`` / ``self.y_ecog``. Grid points
        are colored by ``grid_color`` and strip contacts (marker "x") by
        ``strip_color``. The colorbar belongs to the last colored scatter drawn:
        the strip when present, otherwise the grid. Without either, no legend or
        colorbar is drawn.

        Parameters
        ----------
        grid_cortex : np.ndarray | pd.DataFrame | None
            Grid coordinates; falls back to ``self.grid_cortex`` when None.
            Always converted to a fresh array, so the input is never modified.
        grid_color : np.ndarray | None
            Color value per grid point; defaults to all ones.
        ecog_strip : np.ndarray | None
            Strip contact coordinates; falls back to ``self.ecog_strip``.
        strip_color : np.ndarray | None
            Color value per strip contact; defaults to all ones.
        sess_right : bool | None
            If truthy, mirror the grid by negating its x coordinate (column 0).
            NOTE(review): ``self.sess_right`` is not consulted as a fallback
            here — confirm whether it should be.
        save : bool
            Save the figure to the path returned by ``get_plt_path``.
        OUT_PATH, feature_file, feature_str_add
            Forwarded to ``get_plt_path`` to build the output path.
        show_plot : bool
            If False, close the figure after (optionally) saving.
        title : str
            Figure title.
        set_clim : bool
            Apply ``lower_clim`` / ``upper_clim`` as color limits.
        lower_clim, upper_clim : float
            Color limits used when ``set_clim`` is True.
        cbar_label : str
            Colorbar label.
        """
        if grid_cortex is None:
            grid_cortex = self.grid_cortex
        if grid_cortex is not None:
            # np.array accepts DataFrames as well as arrays and always copies,
            # so the in-place mirroring below never touches caller/instance data.
            grid_cortex = np.array(grid_cortex)

        if ecog_strip is None:
            ecog_strip = self.ecog_strip
        if ecog_strip is not None:
            ecog_strip = np.asarray(ecog_strip)

        if sess_right and grid_cortex is not None:
            # Mirror across the midline: negate the x coordinate of every point
            # (column 0), not the first point's row.
            grid_cortex[:, 0] = grid_cortex[:, 0] * -1

        fig, axes = plt.subplots(1, 1, facecolor=(1, 1, 1), figsize=(14, 9))
        # Background brain surface points.
        axes.scatter(self.x_ecog, self.y_ecog, c="gray", s=0.01)
        axes.set_aspect("equal", anchor="C")

        # Last colored scatter; drives the color limits and the colorbar.
        pos_ecog = None

        if grid_cortex is not None:
            if grid_color is None:
                grid_color = np.ones(grid_cortex.shape[0])
            pos_ecog = axes.scatter(
                grid_cortex[:, 0],
                grid_cortex[:, 1],
                c=grid_color,
                s=150,
                alpha=0.8,
                cmap="viridis",
                label="grid points",
            )
            if set_clim:
                pos_ecog.set_clim(lower_clim, upper_clim)

        if ecog_strip is not None:
            if strip_color is None:
                strip_color = np.ones(ecog_strip.shape[0])
            pos_ecog = axes.scatter(
                ecog_strip[:, 0],
                ecog_strip[:, 1],
                c=strip_color,
                s=400,
                alpha=0.8,
                cmap="viridis",
                marker="x",
                label="ecog electrode",
            )

        plt.axis("off")
        if pos_ecog is not None:
            plt.legend()
        plt.title(title)

        # Without grid or strip there is nothing to attach a colorbar to.
        if pos_ecog is not None:
            if set_clim:
                pos_ecog.set_clim(lower_clim, upper_clim)
            cbar = fig.colorbar(pos_ecog)
            cbar.set_label(cbar_label)

        if save:
            plt_path = get_plt_path(
                OUT_PATH,
                feature_file,
                str_plt_type="PLOT_CORTEX",
                feature_name=feature_str_add,
            )
            plt.savefig(plt_path, bbox_inches="tight")
            logger.info(f"Cortex plot saved to: {str(plt_path)}")
        if not show_plot:
            plt.close()