bmtool 0.7.7__py3-none-any.whl → 0.7.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bmtool might be problematic. Click here for more details.
- bmtool/analysis/entrainment.py +113 -0
- bmtool/analysis/lfp.py +1 -1
- bmtool/bmplot/connections.py +756 -339
- bmtool/bmplot/entrainment.py +169 -49
- bmtool/bmplot/lfp.py +146 -11
- bmtool/bmplot/netcon_reports.py +1 -0
- bmtool/bmplot/spikes.py +175 -18
- bmtool/singlecell.py +47 -2
- bmtool/synapses.py +1337 -633
- {bmtool-0.7.7.dist-info → bmtool-0.7.8.dist-info}/METADATA +1 -1
- {bmtool-0.7.7.dist-info → bmtool-0.7.8.dist-info}/RECORD +15 -15
- {bmtool-0.7.7.dist-info → bmtool-0.7.8.dist-info}/WHEEL +0 -0
- {bmtool-0.7.7.dist-info → bmtool-0.7.8.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.7.dist-info → bmtool-0.7.8.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.7.dist-info → bmtool-0.7.8.dist-info}/top_level.txt +0 -0
bmtool/bmplot/entrainment.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
from typing import List, Tuple, Union
|
|
1
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
2
2
|
|
|
3
3
|
import matplotlib.pyplot as plt
|
|
4
|
+
from matplotlib.figure import Figure
|
|
4
5
|
import numpy as np
|
|
5
6
|
import pandas as pd
|
|
6
7
|
import seaborn as sns
|
|
@@ -18,16 +19,16 @@ def plot_spike_power_correlation(
|
|
|
18
19
|
lfp_data: xr.DataArray,
|
|
19
20
|
firing_quantile: float,
|
|
20
21
|
fs: float,
|
|
21
|
-
pop_names:
|
|
22
|
+
pop_names: List[str],
|
|
22
23
|
filter_method: str = "wavelet",
|
|
23
24
|
bandwidth: float = 2.0,
|
|
24
|
-
lowcut: float = None,
|
|
25
|
-
highcut: float = None,
|
|
26
|
-
freq_range:
|
|
25
|
+
lowcut: Optional[float] = None,
|
|
26
|
+
highcut: Optional[float] = None,
|
|
27
|
+
freq_range: Tuple[float, float] = (10, 100),
|
|
27
28
|
freq_step: float = 5,
|
|
28
29
|
type_name: str = "raw",
|
|
29
|
-
time_windows:
|
|
30
|
-
error_type: str = "ci",
|
|
30
|
+
time_windows: Optional[List[Tuple[float, float]]] = None,
|
|
31
|
+
error_type: str = "ci",
|
|
31
32
|
):
|
|
32
33
|
"""
|
|
33
34
|
Calculate and plot correlation between population spike rates and LFP power across frequencies.
|
|
@@ -35,13 +36,15 @@ def plot_spike_power_correlation(
|
|
|
35
36
|
|
|
36
37
|
Parameters
|
|
37
38
|
----------
|
|
38
|
-
|
|
39
|
-
|
|
39
|
+
spike_df : pd.DataFrame
|
|
40
|
+
DataFrame containing spike data with columns 'timestamps', 'node_ids', and 'pop_name'.
|
|
40
41
|
lfp_data : xr.DataArray
|
|
41
42
|
LFP data
|
|
43
|
+
firing_quantile : float
|
|
44
|
+
Upper quantile threshold for selecting high-firing cells (e.g., 0.8 for top 20%)
|
|
42
45
|
fs : float
|
|
43
46
|
Sampling frequency
|
|
44
|
-
pop_names :
|
|
47
|
+
pop_names : List[str]
|
|
45
48
|
List of population names to analyze
|
|
46
49
|
filter_method : str, optional
|
|
47
50
|
Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
|
|
@@ -51,16 +54,21 @@ def plot_spike_power_correlation(
|
|
|
51
54
|
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
|
52
55
|
highcut : float, optional
|
|
53
56
|
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
|
54
|
-
freq_range :
|
|
57
|
+
freq_range : Tuple[float, float], optional
|
|
55
58
|
Min and max frequency to analyze (default: (10, 100))
|
|
56
59
|
freq_step : float, optional
|
|
57
60
|
Step size for frequency analysis (default: 5)
|
|
58
61
|
type_name : str, optional
|
|
59
62
|
Which type of spike rate to use if 'type' dimension exists (default: 'raw')
|
|
60
|
-
time_windows :
|
|
63
|
+
time_windows : List[Tuple[float, float]], optional
|
|
61
64
|
List of (start, end) time tuples for trial-based analysis. If None, analyze entire signal
|
|
62
65
|
error_type : str, optional
|
|
63
66
|
Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
|
|
67
|
+
|
|
68
|
+
Returns
|
|
69
|
+
-------
|
|
70
|
+
matplotlib.figure.Figure
|
|
71
|
+
The figure containing the correlation plot
|
|
64
72
|
"""
|
|
65
73
|
|
|
66
74
|
if not (0 <= firing_quantile < 1):
|
|
@@ -206,7 +214,7 @@ def plot_spike_power_correlation(
|
|
|
206
214
|
|
|
207
215
|
# Plotting
|
|
208
216
|
sns.set_style("whitegrid")
|
|
209
|
-
plt.figure(figsize=(12, 8))
|
|
217
|
+
fig = plt.figure(figsize=(12, 8))
|
|
210
218
|
|
|
211
219
|
for i, pop in enumerate(pop_names):
|
|
212
220
|
# Extract data for plotting
|
|
@@ -294,28 +302,31 @@ def plot_spike_power_correlation(
|
|
|
294
302
|
plt.ylim(min(y_min, -0.1), max(y_max, 0.1))
|
|
295
303
|
|
|
296
304
|
plt.tight_layout()
|
|
297
|
-
|
|
305
|
+
return fig
|
|
298
306
|
|
|
299
307
|
|
|
300
|
-
def plot_cycle_with_spike_histograms(phase_data, bins=36
|
|
308
|
+
def plot_cycle_with_spike_histograms(phase_data, pop_names: List[str], bins: int = 36):
|
|
301
309
|
"""
|
|
302
310
|
Plot an idealized cycle with spike histograms for different neuron populations.
|
|
303
311
|
|
|
304
|
-
Parameters
|
|
312
|
+
Parameters
|
|
305
313
|
-----------
|
|
306
314
|
phase_data : dict
|
|
307
315
|
Dictionary containing phase values for each spike and neuron population
|
|
308
|
-
|
|
309
|
-
Sampling frequency of LFP in Hz
|
|
310
|
-
bins : int
|
|
311
|
-
Number of bins for the phase histogram (default 36 gives 10-degree bins)
|
|
312
|
-
pop_name : list
|
|
316
|
+
pop_names : List[str]
|
|
313
317
|
List of population names to be plotted
|
|
318
|
+
bins : int, optional
|
|
319
|
+
Number of bins for the phase histogram (default 36 gives 10-degree bins)
|
|
320
|
+
|
|
321
|
+
Returns
|
|
322
|
+
-------
|
|
323
|
+
matplotlib.figure.Figure
|
|
324
|
+
The figure containing the cycle and histograms
|
|
314
325
|
"""
|
|
315
326
|
sns.set_style("whitegrid")
|
|
316
327
|
# Create a figure with subplots
|
|
317
328
|
fig = plt.figure(figsize=(12, 8))
|
|
318
|
-
gs = GridSpec(len(
|
|
329
|
+
gs = GridSpec(len(pop_names) + 1, 1, height_ratios=[1.5] + [1] * len(pop_names))
|
|
319
330
|
|
|
320
331
|
# Top subplot: Idealized gamma cycle
|
|
321
332
|
ax_gamma = fig.add_subplot(gs[0])
|
|
@@ -335,10 +346,10 @@ def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
|
|
|
335
346
|
ax_gamma.axvline(x=0, color="k", linestyle="--", alpha=0.3)
|
|
336
347
|
|
|
337
348
|
# Generate a color map for the different populations
|
|
338
|
-
colors = plt.cm.tab10(np.linspace(0, 1, len(
|
|
349
|
+
colors = plt.cm.tab10(np.linspace(0, 1, len(pop_names)))
|
|
339
350
|
|
|
340
351
|
# Add histograms for each neuron population
|
|
341
|
-
for i, pop_name in enumerate(
|
|
352
|
+
for i, pop_name in enumerate(pop_names):
|
|
342
353
|
ax_hist = fig.add_subplot(gs[i + 1], sharex=ax_gamma)
|
|
343
354
|
|
|
344
355
|
# Compute histogram
|
|
@@ -361,27 +372,34 @@ def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
|
|
|
361
372
|
ax_hist.set_xlabel("Phase (degrees)", fontsize=12)
|
|
362
373
|
|
|
363
374
|
plt.tight_layout()
|
|
364
|
-
|
|
375
|
+
return fig
|
|
365
376
|
|
|
366
377
|
|
|
367
|
-
def plot_entrainment_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=None):
|
|
378
|
+
def plot_entrainment_by_population(ppc_dict: Dict[str, Dict[str, Dict[float, float]]], pop_names: List[str], freqs: List[float], figsize: Tuple[float, float] = (15, 8), title: Optional[str] = None):
|
|
368
379
|
"""
|
|
369
380
|
Plot PPC for all node populations on one graph with mean and standard error.
|
|
370
381
|
|
|
371
382
|
Parameters:
|
|
372
383
|
-----------
|
|
373
|
-
ppc_dict :
|
|
384
|
+
ppc_dict : Dict[str, Dict[str, Dict[float, float]]]
|
|
374
385
|
Dictionary containing PPC data organized by population, node, and frequency
|
|
375
|
-
pop_names :
|
|
386
|
+
pop_names : List[str]
|
|
376
387
|
List of population names to plot data for
|
|
377
|
-
freqs :
|
|
388
|
+
freqs : List[float]
|
|
378
389
|
List of frequencies to plot
|
|
379
|
-
figsize :
|
|
390
|
+
figsize : Tuple[float, float], optional
|
|
380
391
|
Figure size for the plot
|
|
392
|
+
title : str, optional
|
|
393
|
+
Title for the plot
|
|
394
|
+
|
|
395
|
+
Returns
|
|
396
|
+
-------
|
|
397
|
+
matplotlib.figure.Figure
|
|
398
|
+
The figure containing the bar plot
|
|
381
399
|
"""
|
|
382
400
|
# Set up the visualization style
|
|
383
401
|
sns.set_style("whitegrid")
|
|
384
|
-
plt.figure(figsize=figsize)
|
|
402
|
+
fig = plt.figure(figsize=figsize)
|
|
385
403
|
|
|
386
404
|
# Calculate the width of each group of bars
|
|
387
405
|
n_groups = len(freqs)
|
|
@@ -443,23 +461,25 @@ def plot_entrainment_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8),
|
|
|
443
461
|
|
|
444
462
|
# Adjust layout and save
|
|
445
463
|
plt.tight_layout()
|
|
446
|
-
|
|
464
|
+
return fig
|
|
447
465
|
|
|
448
466
|
|
|
449
|
-
def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title=None):
|
|
467
|
+
def plot_entrainment_swarm_plot(ppc_dict: Dict[str, Dict[str, Dict[float, float]]], pop_names: List[str], freq: Union[float, int], save_path: Optional[str] = None, title: Optional[str] = None):
|
|
450
468
|
"""
|
|
451
469
|
Plot a swarm plot of the entrainment for different populations at a single frequency.
|
|
452
470
|
|
|
453
471
|
Parameters:
|
|
454
472
|
-----------
|
|
455
|
-
ppc_dict :
|
|
473
|
+
ppc_dict : Dict[str, Dict[str, Dict[float, float]]]
|
|
456
474
|
Dictionary containing PPC values organized by population, node, and frequency
|
|
457
|
-
pop_names :
|
|
475
|
+
pop_names : List[str]
|
|
458
476
|
List of population names to include in the plot
|
|
459
|
-
freq : float
|
|
477
|
+
freq : Union[float, int]
|
|
460
478
|
The specific frequency to plot
|
|
461
479
|
save_path : str, optional
|
|
462
480
|
Path to save the figure. If None, figure is just displayed.
|
|
481
|
+
title : str, optional
|
|
482
|
+
Title for the plot
|
|
463
483
|
|
|
464
484
|
Returns:
|
|
465
485
|
--------
|
|
@@ -500,7 +520,7 @@ def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title
|
|
|
500
520
|
print(f"{pop}: {mean_val:.4f} ± {sem_val:.4f} (n={n})")
|
|
501
521
|
|
|
502
522
|
# Create figure
|
|
503
|
-
plt.figure(figsize=(max(8, len(pop_names) * 1.5), 8))
|
|
523
|
+
fig = plt.figure(figsize=(max(8, len(pop_names) * 1.5), 8))
|
|
504
524
|
|
|
505
525
|
# Create swarm plot
|
|
506
526
|
ax = sns.swarmplot(
|
|
@@ -608,20 +628,20 @@ def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title
|
|
|
608
628
|
if save_path:
|
|
609
629
|
plt.savefig(f"{save_path}/ppc_change_swarm_plot_{freq}Hz.png", dpi=300, bbox_inches="tight")
|
|
610
630
|
|
|
611
|
-
|
|
631
|
+
return fig
|
|
612
632
|
|
|
613
633
|
|
|
614
634
|
def plot_trial_avg_entrainment(
|
|
615
635
|
spike_df: pd.DataFrame,
|
|
616
|
-
lfp:
|
|
636
|
+
lfp: xr.DataArray,
|
|
617
637
|
time_windows: List[Tuple[float, float]],
|
|
618
638
|
entrainment_method: str,
|
|
619
639
|
pop_names: List[str],
|
|
620
640
|
freqs: Union[List[float], np.ndarray],
|
|
621
641
|
firing_quantile: float,
|
|
622
642
|
spike_fs: float = 1000,
|
|
623
|
-
error_type: str = "ci",
|
|
624
|
-
) ->
|
|
643
|
+
error_type: str = "ci",
|
|
644
|
+
) -> Figure:
|
|
625
645
|
"""
|
|
626
646
|
Plot trial-averaged entrainment for specified population names. Only supports wavelet filter current, could easily add other support
|
|
627
647
|
|
|
@@ -629,9 +649,7 @@ def plot_trial_avg_entrainment(
|
|
|
629
649
|
-----------
|
|
630
650
|
spike_df : pd.DataFrame
|
|
631
651
|
Spike data containing timestamps, node_ids, and pop_name columns
|
|
632
|
-
|
|
633
|
-
fs for spike data. Default is 1000
|
|
634
|
-
lfp : xarray
|
|
652
|
+
lfp : xr.DataArray
|
|
635
653
|
Xarray for a channel of the lfp data
|
|
636
654
|
time_windows : List[Tuple[float, float]]
|
|
637
655
|
List of windows to analysis with start and stp time [(start_time, end_time), ...] for each trial
|
|
@@ -643,7 +661,9 @@ def plot_trial_avg_entrainment(
|
|
|
643
661
|
Array of frequencies to analyze (Hz)
|
|
644
662
|
firing_quantile : float
|
|
645
663
|
Upper quantile threshold for selecting high-firing cells (e.g., 0.8 for top 20%)
|
|
646
|
-
|
|
664
|
+
spike_fs : float, optional
|
|
665
|
+
fs for spike data. Default is 1000
|
|
666
|
+
error_type : str, optional
|
|
647
667
|
Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
|
|
648
668
|
|
|
649
669
|
Raises:
|
|
@@ -655,8 +675,8 @@ def plot_trial_avg_entrainment(
|
|
|
655
675
|
|
|
656
676
|
Returns:
|
|
657
677
|
--------
|
|
658
|
-
|
|
659
|
-
|
|
678
|
+
matplotlib.figure.Figure
|
|
679
|
+
The figure containing the plot
|
|
660
680
|
"""
|
|
661
681
|
sns.set_style("whitegrid")
|
|
662
682
|
# Validate inputs
|
|
@@ -813,7 +833,7 @@ def plot_trial_avg_entrainment(
|
|
|
813
833
|
error_plv[pop_name] = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1)
|
|
814
834
|
|
|
815
835
|
# Create the combined plot
|
|
816
|
-
plt.figure(figsize=(12, 8))
|
|
836
|
+
fig = plt.figure(figsize=(12, 8))
|
|
817
837
|
|
|
818
838
|
# Define markers and colors for different populations
|
|
819
839
|
markers = ["o-", "s-", "^-", "D-", "v-", "<-", ">-", "p-"]
|
|
@@ -862,4 +882,104 @@ def plot_trial_avg_entrainment(
|
|
|
862
882
|
plt.legend(fontsize=10)
|
|
863
883
|
plt.grid(True, alpha=0.3)
|
|
864
884
|
plt.tight_layout()
|
|
865
|
-
|
|
885
|
+
return fig
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
def plot_fr_hist_phase_amplitude(
|
|
889
|
+
fr_hist: np.ndarray,
|
|
890
|
+
pop_names: List[str],
|
|
891
|
+
freq_labels: List[str],
|
|
892
|
+
nbins_pha: int = 16,
|
|
893
|
+
nbins_amp: int = 16,
|
|
894
|
+
common_clim: bool = True,
|
|
895
|
+
figsize: Tuple[float, float] = (3, 2),
|
|
896
|
+
cmap: str = 'viridis',
|
|
897
|
+
title: Optional[str] = None
|
|
898
|
+
) -> Tuple[plt.Figure, np.ndarray]:
|
|
899
|
+
"""
|
|
900
|
+
Plot firing rate histograms binned by LFP phase and amplitude.
|
|
901
|
+
Check out the bmtool/bmtool/analysis/entrainment.py function
|
|
902
|
+
compute_fr_hist_phase_amplitude
|
|
903
|
+
|
|
904
|
+
Parameters
|
|
905
|
+
----------
|
|
906
|
+
fr_hist : np.ndarray
|
|
907
|
+
Firing rate histogram of shape (n_pop, n_freq, nbins_pha, nbins_amp)
|
|
908
|
+
pop_names : List[str]
|
|
909
|
+
List of population names
|
|
910
|
+
freq_labels : List[str]
|
|
911
|
+
List of frequency labels for subplot titles (e.g., ['Beta', 'Gamma'])
|
|
912
|
+
nbins_pha : int, default=16
|
|
913
|
+
Number of phase bins
|
|
914
|
+
nbins_amp : int, default=16
|
|
915
|
+
Number of amplitude bins
|
|
916
|
+
common_clim : bool, default=True
|
|
917
|
+
Whether to use common color limits across all subplots
|
|
918
|
+
figsize : Tuple[float, float], default=(3, 2)
|
|
919
|
+
Size of each subplot
|
|
920
|
+
cmap : str, default='RdBu_r'
|
|
921
|
+
Colormap to use
|
|
922
|
+
title : Optional[str], default=None
|
|
923
|
+
Overall title for the figure
|
|
924
|
+
|
|
925
|
+
Returns
|
|
926
|
+
-------
|
|
927
|
+
Tuple[plt.Figure, np.ndarray]
|
|
928
|
+
Figure and axes objects
|
|
929
|
+
|
|
930
|
+
Examples
|
|
931
|
+
--------
|
|
932
|
+
>>> fig, axs = plot_fr_hist_phase_amplitude(
|
|
933
|
+
... fr_hist, ['PV', 'SST'], ['Beta', 'Gamma'],
|
|
934
|
+
... common_clim=True, cmap='RdBu_r', title='LFP Phase-Amplitude Coupling'
|
|
935
|
+
... )
|
|
936
|
+
"""
|
|
937
|
+
pha_bins = np.linspace(-np.pi, np.pi, nbins_pha + 1)
|
|
938
|
+
quantiles = np.linspace(0, 1, nbins_amp + 1)
|
|
939
|
+
|
|
940
|
+
n_pop = len(pop_names)
|
|
941
|
+
n_freq = len(freq_labels)
|
|
942
|
+
|
|
943
|
+
fig, axs = plt.subplots(n_pop, n_freq,
|
|
944
|
+
figsize=(figsize[0] * n_freq, figsize[1] * n_pop),
|
|
945
|
+
squeeze=False)
|
|
946
|
+
|
|
947
|
+
|
|
948
|
+
# Add overall title if provided
|
|
949
|
+
if title:
|
|
950
|
+
fig.suptitle(title, fontsize=14, y=0.98)
|
|
951
|
+
|
|
952
|
+
for i, p in enumerate(pop_names):
|
|
953
|
+
if common_clim:
|
|
954
|
+
vmin, vmax = fr_hist.min(), fr_hist.max()
|
|
955
|
+
else:
|
|
956
|
+
vmin, vmax = None, None
|
|
957
|
+
|
|
958
|
+
for j, freq_label in enumerate(freq_labels):
|
|
959
|
+
ax = axs[i, j]
|
|
960
|
+
pcm = ax.pcolormesh(pha_bins, quantiles, fr_hist[i, j].T,
|
|
961
|
+
vmin=vmin, vmax=vmax, cmap=cmap)
|
|
962
|
+
ax.set_title(p)
|
|
963
|
+
|
|
964
|
+
if i < n_pop - 1:
|
|
965
|
+
ax.get_xaxis().set_visible(False)
|
|
966
|
+
else:
|
|
967
|
+
ax.set_xlabel(freq_label.title() + ' Phase')
|
|
968
|
+
ax.set_xticks((-np.pi, 0, np.pi))
|
|
969
|
+
ax.set_xticklabels([r'$-\pi$', '0', r'$\pi$'])
|
|
970
|
+
|
|
971
|
+
if j > 0:
|
|
972
|
+
ax.get_yaxis().set_visible(False)
|
|
973
|
+
else:
|
|
974
|
+
ax.set_ylabel('Amplitude (quantile)')
|
|
975
|
+
|
|
976
|
+
if not common_clim:
|
|
977
|
+
plt.colorbar(mappable=pcm, ax=ax,
|
|
978
|
+
label='Firing rate (% Change)' if j == n_freq - 1 else None,
|
|
979
|
+
pad=0.02)
|
|
980
|
+
|
|
981
|
+
if common_clim:
|
|
982
|
+
plt.colorbar(mappable=pcm, ax=axs[i],
|
|
983
|
+
label='Firing rate (% Change)', pad=0.02)
|
|
984
|
+
|
|
985
|
+
return fig, axs
|
bmtool/bmplot/lfp.py
CHANGED
|
@@ -1,18 +1,54 @@
|
|
|
1
1
|
import matplotlib.pyplot as plt
|
|
2
2
|
import numpy as np
|
|
3
3
|
from fooof.sim.gen import gen_aperiodic
|
|
4
|
+
from typing import Optional, List, Dict, Tuple, Any
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from ..analysis.spikes import get_population_spike_rate
|
|
7
|
+
from ..analysis.lfp import get_lfp_power, load_ecp_to_xarray, ecp_to_lfp
|
|
8
|
+
from matplotlib.figure import Figure
|
|
4
9
|
|
|
5
10
|
|
|
6
11
|
def plot_spectrogram(
|
|
7
|
-
sxx_xarray,
|
|
8
|
-
remove_aperiodic=None,
|
|
9
|
-
log_power=False,
|
|
10
|
-
plt_range=None,
|
|
11
|
-
clr_freq_range=None,
|
|
12
|
-
pad=0.03,
|
|
13
|
-
ax=None,
|
|
14
|
-
):
|
|
15
|
-
"""
|
|
12
|
+
sxx_xarray: Any,
|
|
13
|
+
remove_aperiodic: Optional[Any] = None,
|
|
14
|
+
log_power: bool = False,
|
|
15
|
+
plt_range: Optional[Tuple[float, float]] = None,
|
|
16
|
+
clr_freq_range: Optional[Tuple[float, float]] = None,
|
|
17
|
+
pad: float = 0.03,
|
|
18
|
+
ax: Optional[plt.Axes] = None,
|
|
19
|
+
) -> Figure:
|
|
20
|
+
"""
|
|
21
|
+
Plot a power spectrogram with optional aperiodic removal and frequency-based coloring.
|
|
22
|
+
|
|
23
|
+
Parameters
|
|
24
|
+
----------
|
|
25
|
+
sxx_xarray : array-like
|
|
26
|
+
Spectrogram data as an xarray DataArray with PSD values.
|
|
27
|
+
remove_aperiodic : optional
|
|
28
|
+
FOOOF model object for aperiodic subtraction. If None, raw spectrum is displayed.
|
|
29
|
+
log_power : bool or str, optional
|
|
30
|
+
If True or 'dB', convert power to log scale. Default is False.
|
|
31
|
+
plt_range : tuple of float, optional
|
|
32
|
+
Frequency range to display as (f_min, f_max). If None, displays full range.
|
|
33
|
+
clr_freq_range : tuple of float, optional
|
|
34
|
+
Frequency range to use for determining color limits. If None, uses full range.
|
|
35
|
+
pad : float, optional
|
|
36
|
+
Padding for colorbar. Default is 0.03.
|
|
37
|
+
ax : matplotlib.axes.Axes, optional
|
|
38
|
+
Axes to plot on. If None, creates a new figure and axes.
|
|
39
|
+
|
|
40
|
+
Returns
|
|
41
|
+
-------
|
|
42
|
+
matplotlib.figure.Figure
|
|
43
|
+
The figure object containing the spectrogram.
|
|
44
|
+
|
|
45
|
+
Examples
|
|
46
|
+
--------
|
|
47
|
+
>>> fig = plot_spectrogram(
|
|
48
|
+
... sxx_xarray, log_power='dB',
|
|
49
|
+
... plt_range=(10, 100), clr_freq_range=(20, 50)
|
|
50
|
+
... )
|
|
51
|
+
"""
|
|
16
52
|
sxx = sxx_xarray.PSD.values.copy()
|
|
17
53
|
t = sxx_xarray.time.values.copy()
|
|
18
54
|
f = sxx_xarray.frequency.values.copy()
|
|
@@ -45,7 +81,7 @@ def plot_spectrogram(
|
|
|
45
81
|
vmin, vmax = sxx[c_idx, :].min(), sxx[c_idx, :].max()
|
|
46
82
|
|
|
47
83
|
f = f[f_idx]
|
|
48
|
-
pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading="gouraud", vmin=vmin, vmax=vmax)
|
|
84
|
+
pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading="gouraud", vmin=vmin, vmax=vmax, rasterized=True)
|
|
49
85
|
if "cone_of_influence_frequency" in sxx_xarray:
|
|
50
86
|
coif = sxx_xarray.cone_of_influence_frequency
|
|
51
87
|
ax.plot(t, coif)
|
|
@@ -56,4 +92,103 @@ def plot_spectrogram(
|
|
|
56
92
|
plt.colorbar(mappable=pcm, ax=ax, label=cbar_label, pad=pad)
|
|
57
93
|
ax.set_xlabel("Time (sec)")
|
|
58
94
|
ax.set_ylabel("Frequency (Hz)")
|
|
59
|
-
return
|
|
95
|
+
return ax.figure
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def plot_population_spike_rates_with_lfp(
|
|
99
|
+
spikes_df: pd.DataFrame,
|
|
100
|
+
freq_of_interest: List[float],
|
|
101
|
+
freq_labels: List[str],
|
|
102
|
+
freq_colors: List[str],
|
|
103
|
+
time_range: Tuple[float, float],
|
|
104
|
+
pop_names: List[str],
|
|
105
|
+
pop_color: Dict[str, str],
|
|
106
|
+
trial_path: str,
|
|
107
|
+
filter_column: Optional[str] = None,
|
|
108
|
+
filter_value: Optional[Any] = None,
|
|
109
|
+
) -> Optional[Figure]:
|
|
110
|
+
"""
|
|
111
|
+
Plot population spike rates with LFP power overlays.
|
|
112
|
+
|
|
113
|
+
Parameters
|
|
114
|
+
----------
|
|
115
|
+
spikes_df : pd.DataFrame
|
|
116
|
+
DataFrame with spike data.
|
|
117
|
+
freq_of_interest : list of float
|
|
118
|
+
List of frequencies for LFP power analysis (required).
|
|
119
|
+
freq_labels : list of str
|
|
120
|
+
Labels for the frequencies (required).
|
|
121
|
+
freq_colors : list of str
|
|
122
|
+
Colors for the frequency plots (required).
|
|
123
|
+
time_range : tuple of float
|
|
124
|
+
Tuple (start, end) for x-axis time limits (required).
|
|
125
|
+
pop_names : list of str
|
|
126
|
+
List of population names (required).
|
|
127
|
+
pop_color : dict
|
|
128
|
+
Dictionary mapping population names to colors (required).
|
|
129
|
+
trial_path : str
|
|
130
|
+
Path to trial data (required).
|
|
131
|
+
filter_column : str, optional
|
|
132
|
+
Column name to filter spikes_df on (optional).
|
|
133
|
+
filter_value : any, optional
|
|
134
|
+
Value to filter for in filter_column (optional).
|
|
135
|
+
|
|
136
|
+
Returns
|
|
137
|
+
-------
|
|
138
|
+
matplotlib.figure.Figure or None
|
|
139
|
+
Figure object containing the plot, or None if no data to plot.
|
|
140
|
+
|
|
141
|
+
Examples
|
|
142
|
+
--------
|
|
143
|
+
>>> fig = plot_population_spike_rates_with_lfp(
|
|
144
|
+
... spikes_df, [40, 80], ['Beta', 'Gamma'],
|
|
145
|
+
... ['blue', 'red'], (0, 10), ['PV', 'SST'],
|
|
146
|
+
... {'PV': 'blue', 'SST': 'red'}, 'trial_data.h5'
|
|
147
|
+
... )
|
|
148
|
+
"""
|
|
149
|
+
# Compute spike rates based on filtering
|
|
150
|
+
if filter_column and filter_column in spikes_df.columns:
|
|
151
|
+
filtered_df = spikes_df[spikes_df[filter_column] == filter_value]
|
|
152
|
+
if not filtered_df.empty:
|
|
153
|
+
spike_rate = get_population_spike_rate(filtered_df, fs=400, network_name='cortex')
|
|
154
|
+
plot_title = f'{filter_column} {filter_value}'
|
|
155
|
+
save_suffix = f'_{filter_column}_{filter_value}'
|
|
156
|
+
else:
|
|
157
|
+
print(f"No data found for {filter_column} == {filter_value}.")
|
|
158
|
+
return
|
|
159
|
+
else:
|
|
160
|
+
spike_rate = get_population_spike_rate(spikes_df, fs=400, network_name='cortex')
|
|
161
|
+
plot_title = 'Overall Spike Rates'
|
|
162
|
+
save_suffix = '_overall'
|
|
163
|
+
|
|
164
|
+
# Load LFP data and compute power for each frequency of interest
|
|
165
|
+
ecp = load_ecp_to_xarray(ecp_file=trial_path + "/ecp.h5")
|
|
166
|
+
lfp = ecp_to_lfp(ecp)
|
|
167
|
+
powers = [
|
|
168
|
+
get_lfp_power(lfp, freq_of_interest=freq, fs=lfp.fs, filter_method="wavelet", bandwidth=1.0)
|
|
169
|
+
for freq in freq_of_interest
|
|
170
|
+
]
|
|
171
|
+
|
|
172
|
+
# Plotting
|
|
173
|
+
fig, axes = plt.subplots(len(spike_rate.population), 1, figsize=(12, 10))
|
|
174
|
+
for i, pop in enumerate(pop_names):
|
|
175
|
+
if pop in spike_rate.population.values:
|
|
176
|
+
ax = axes.flat[i]
|
|
177
|
+
spike_rate.sel(type='raw', population=pop).plot(ax=ax, color=pop_color[pop])
|
|
178
|
+
ax.set_title(f'{pop}')
|
|
179
|
+
ax.set_ylabel('Spike Rate (Hz)', color=pop_color[pop])
|
|
180
|
+
ax.tick_params(axis='y', labelcolor=pop_color[pop])
|
|
181
|
+
|
|
182
|
+
# Twin axis for LFP power
|
|
183
|
+
ax2 = ax.twinx()
|
|
184
|
+
for power, label, color in zip(powers, freq_labels, freq_colors):
|
|
185
|
+
ax2.plot(power['time'], power.values.squeeze(), color=color, label=label)
|
|
186
|
+
ax2.set_ylabel('LFP Power', color='black')
|
|
187
|
+
ax2.tick_params(axis='y', labelcolor='black')
|
|
188
|
+
ax2.legend(loc='upper right')
|
|
189
|
+
|
|
190
|
+
ax.set_xlim(time_range)
|
|
191
|
+
|
|
192
|
+
fig.suptitle(plot_title, fontsize=16, y=0.98)
|
|
193
|
+
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
|
194
|
+
return fig
|
bmtool/bmplot/netcon_reports.py
CHANGED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# NOT IMPLETMENTED YET
|