bmtool 0.6.9.28__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +2 -50
- bmtool/analysis/lfp.py +0 -51
- bmtool/bmplot/__init__.py +0 -0
- bmtool/{bmplot.py → bmplot/connections.py} +18 -444
- bmtool/bmplot/entrainment.py +51 -0
- bmtool/bmplot/lfp.py +53 -0
- bmtool/bmplot/netcon_reports.py +4 -0
- bmtool/bmplot/spikes.py +259 -0
- {bmtool-0.6.9.28.dist-info → bmtool-0.7.0.dist-info}/METADATA +1 -1
- {bmtool-0.6.9.28.dist-info → bmtool-0.7.0.dist-info}/RECORD +14 -9
- {bmtool-0.6.9.28.dist-info → bmtool-0.7.0.dist-info}/WHEEL +1 -1
- {bmtool-0.6.9.28.dist-info → bmtool-0.7.0.dist-info}/entry_points.txt +0 -0
- {bmtool-0.6.9.28.dist-info → bmtool-0.7.0.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.6.9.28.dist-info → bmtool-0.7.0.dist-info}/top_level.txt +0 -0
bmtool/analysis/entrainment.py
CHANGED
@@ -12,6 +12,8 @@ from .lfp import wavelet_filter,butter_bandpass_filter
|
|
12
12
|
from typing import Dict, List
|
13
13
|
from tqdm.notebook import tqdm
|
14
14
|
import scipy.stats as stats
|
15
|
+
import seaborn as sns
|
16
|
+
import matplotlib.pyplot as plt
|
15
17
|
|
16
18
|
|
17
19
|
def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
|
@@ -486,53 +488,3 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
|
|
486
488
|
|
487
489
|
return correlation_results, frequencies
|
488
490
|
|
489
|
-
|
490
|
-
def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
|
491
|
-
"""
|
492
|
-
Plot the correlation between population spike rates and LFP power.
|
493
|
-
|
494
|
-
Parameters:
|
495
|
-
-----------
|
496
|
-
correlation_results : dict
|
497
|
-
Dictionary with correlation results for calculate_spike_rate_power_correlation
|
498
|
-
frequencies : array
|
499
|
-
Array of frequencies analyzed
|
500
|
-
pop_names : list
|
501
|
-
List of population names
|
502
|
-
"""
|
503
|
-
sns.set_style("whitegrid")
|
504
|
-
plt.figure(figsize=(10, 6))
|
505
|
-
|
506
|
-
for pop in pop_names:
|
507
|
-
# Extract correlation values for each frequency
|
508
|
-
corr_values = []
|
509
|
-
valid_freqs = []
|
510
|
-
|
511
|
-
for freq in frequencies:
|
512
|
-
if freq in correlation_results[pop]:
|
513
|
-
corr_values.append(correlation_results[pop][freq]['correlation'])
|
514
|
-
valid_freqs.append(freq)
|
515
|
-
|
516
|
-
# Plot correlation line
|
517
|
-
plt.plot(valid_freqs, corr_values, marker='o', label=pop,
|
518
|
-
linewidth=2, markersize=6)
|
519
|
-
|
520
|
-
plt.xlabel('Frequency (Hz)', fontsize=12)
|
521
|
-
plt.ylabel('Spike Rate-Power Correlation', fontsize=12)
|
522
|
-
plt.title('Spike rate LFP power correlation during stimulus', fontsize=14)
|
523
|
-
plt.grid(True, alpha=0.3)
|
524
|
-
plt.legend(fontsize=12)
|
525
|
-
plt.xticks(frequencies[::2]) # Display every other frequency on x-axis
|
526
|
-
|
527
|
-
# Add horizontal line at zero for reference
|
528
|
-
plt.axhline(y=0, color='gray', linestyle='-', alpha=0.5)
|
529
|
-
|
530
|
-
# Set y-axis limits to make zero visible
|
531
|
-
y_min, y_max = plt.ylim()
|
532
|
-
plt.ylim(min(y_min, -0.1), max(y_max, 0.1))
|
533
|
-
|
534
|
-
plt.tight_layout()
|
535
|
-
|
536
|
-
plt.show()
|
537
|
-
|
538
|
-
|
bmtool/analysis/lfp.py
CHANGED
@@ -10,7 +10,6 @@ from fooof.sim.gen import gen_model, gen_aperiodic
|
|
10
10
|
import matplotlib.pyplot as plt
|
11
11
|
from scipy import signal
|
12
12
|
import pywt
|
13
|
-
from bmtool.bmplot import is_notebook
|
14
13
|
import pandas as pd
|
15
14
|
|
16
15
|
|
@@ -434,53 +433,3 @@ def cwt_spectrogram_xarray(x, fs, time=None, axis=-1, downsample_fs=None,
|
|
434
433
|
sxx.update(dict(cone_of_influence_frequency=xr.DataArray(coif, coords={'time': t})))
|
435
434
|
return sxx
|
436
435
|
|
437
|
-
|
438
|
-
# will probs move to bmplot later
|
439
|
-
def plot_spectrogram(sxx_xarray, remove_aperiodic=None, log_power=False,
|
440
|
-
plt_range=None, clr_freq_range=None, pad=0.03, ax=None):
|
441
|
-
"""Plot spectrogram. Determine color limits using value in frequency band clr_freq_range"""
|
442
|
-
sxx = sxx_xarray.PSD.values.copy()
|
443
|
-
t = sxx_xarray.time.values.copy()
|
444
|
-
f = sxx_xarray.frequency.values.copy()
|
445
|
-
|
446
|
-
cbar_label = 'PSD' if remove_aperiodic is None else 'PSD Residual'
|
447
|
-
if log_power:
|
448
|
-
with np.errstate(divide='ignore'):
|
449
|
-
sxx = np.log10(sxx)
|
450
|
-
cbar_label += ' dB' if log_power == 'dB' else ' log(power)'
|
451
|
-
|
452
|
-
if remove_aperiodic is not None:
|
453
|
-
f1_idx = 0 if f[0] else 1
|
454
|
-
ap_fit = gen_aperiodic(f[f1_idx:], remove_aperiodic.aperiodic_params)
|
455
|
-
sxx[f1_idx:, :] -= (ap_fit if log_power else 10 ** ap_fit)[:, None]
|
456
|
-
sxx[:f1_idx, :] = 0.
|
457
|
-
|
458
|
-
if log_power == 'dB':
|
459
|
-
sxx *= 10
|
460
|
-
|
461
|
-
if ax is None:
|
462
|
-
_, ax = plt.subplots(1, 1)
|
463
|
-
plt_range = np.array(f[-1]) if plt_range is None else np.array(plt_range)
|
464
|
-
if plt_range.size == 1:
|
465
|
-
plt_range = [f[0 if f[0] else 1] if log_power else 0., plt_range.item()]
|
466
|
-
f_idx = (f >= plt_range[0]) & (f <= plt_range[1])
|
467
|
-
if clr_freq_range is None:
|
468
|
-
vmin, vmax = None, None
|
469
|
-
else:
|
470
|
-
c_idx = (f >= clr_freq_range[0]) & (f <= clr_freq_range[1])
|
471
|
-
vmin, vmax = sxx[c_idx, :].min(), sxx[c_idx, :].max()
|
472
|
-
|
473
|
-
f = f[f_idx]
|
474
|
-
pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading='gouraud', vmin=vmin, vmax=vmax)
|
475
|
-
if 'cone_of_influence_frequency' in sxx_xarray:
|
476
|
-
coif = sxx_xarray.cone_of_influence_frequency
|
477
|
-
ax.plot(t, coif)
|
478
|
-
ax.fill_between(t, coif, step='mid', alpha=0.2)
|
479
|
-
ax.set_xlim(t[0], t[-1])
|
480
|
-
#ax.set_xlim(t[0],0.2)
|
481
|
-
ax.set_ylim(f[0], f[-1])
|
482
|
-
plt.colorbar(mappable=pcm, ax=ax, label=cbar_label, pad=pad)
|
483
|
-
ax.set_xlabel('Time (sec)')
|
484
|
-
ax.set_ylabel('Frequency (Hz)')
|
485
|
-
return sxx
|
486
|
-
|
bmtool/bmplot/__init__.py
File without changes
|
bmtool/{bmplot.py → bmplot/connections.py}
CHANGED
@@ -2,18 +2,12 @@
|
|
2
2
|
Want to be able to take multiple plot names in and plot them all at the same time, to save time
|
3
3
|
https://stackoverflow.com/questions/458209/is-there-a-way-to-detach-matplotlib-plots-so-that-the-computation-can-continue
|
4
4
|
"""
|
5
|
-
from .util import util
|
6
|
-
import argparse,os,sys
|
7
|
-
|
5
|
+
from ..util import util
|
8
6
|
import numpy as np
|
9
7
|
import matplotlib
|
10
8
|
import matplotlib.pyplot as plt
|
11
9
|
import matplotlib.cm as cmx
|
12
10
|
import matplotlib.colors as colors
|
13
|
-
import matplotlib.gridspec as gridspec
|
14
|
-
from mpl_toolkits.mplot3d import Axes3D
|
15
|
-
from matplotlib.axes import Axes
|
16
|
-
import seaborn as sns
|
17
11
|
from IPython import get_ipython
|
18
12
|
from IPython.display import display, HTML
|
19
13
|
import statistics
|
@@ -21,9 +15,8 @@ import pandas as pd
|
|
21
15
|
import os
|
22
16
|
import sys
|
23
17
|
import re
|
24
|
-
from typing import Optional, Dict, Union, List
|
25
18
|
|
26
|
-
from .util.util import CellVarsFile,load_nodes_from_config,load_templates_from_config #, missing_units
|
19
|
+
from ..util.util import CellVarsFile,load_nodes_from_config,load_templates_from_config #, missing_units
|
27
20
|
from bmtk.analyzer.utils import listify
|
28
21
|
from neuron import h
|
29
22
|
|
@@ -59,6 +52,7 @@ def is_notebook() -> bool:
|
|
59
52
|
except NameError:
|
60
53
|
return False # Probably standard Python interpreter
|
61
54
|
|
55
|
+
|
62
56
|
def total_connection_matrix(config=None, title=None, sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False, save_file=None, synaptic_info='0', include_gap=True):
|
63
57
|
"""
|
64
58
|
Generate a plot displaying total connections or other synaptic statistics.
|
@@ -122,7 +116,8 @@ def total_connection_matrix(config=None, title=None, sources=None, targets=None,
|
|
122
116
|
title = "All Synapse .json Files Used"
|
123
117
|
plot_connection_info(text,num,source_labels,target_labels,title, syn_info=synaptic_info, save_file=save_file)
|
124
118
|
return
|
125
|
-
|
119
|
+
|
120
|
+
|
126
121
|
def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,method = 'total',include_gap=True):
|
127
122
|
"""
|
128
123
|
Generates a plot showing the percent connectivity of a network
|
@@ -159,6 +154,7 @@ def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sourc
|
|
159
154
|
plot_connection_info(text,num,source_labels,target_labels,title, save_file=save_file)
|
160
155
|
return
|
161
156
|
|
157
|
+
|
162
158
|
def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None,
|
163
159
|
no_prepend_pop=False,save_file=None, dist_X=True,dist_Y=True,dist_Z=True,bins=8,line_plot=False,verbose=False,include_gap=True):
|
164
160
|
"""
|
@@ -235,6 +231,7 @@ def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,s
|
|
235
231
|
|
236
232
|
return
|
237
233
|
|
234
|
+
|
238
235
|
def convergence_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=True,method='mean+std',include_gap=True,return_dict=None):
|
239
236
|
"""
|
240
237
|
Generates connection plot displaying convergence data
|
@@ -253,6 +250,7 @@ def convergence_connection_matrix(config=None,title=None,sources=None, targets=N
|
|
253
250
|
raise Exception("Sources or targets not defined")
|
254
251
|
return divergence_connection_matrix(config,title ,sources, targets, sids, tids, no_prepend_pop, save_file ,convergence, method,include_gap=include_gap,return_dict=return_dict)
|
255
252
|
|
253
|
+
|
256
254
|
def divergence_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=False,method='mean+std',include_gap=True,return_dict=None):
|
257
255
|
"""
|
258
256
|
Generates connection plot displaying divergence data
|
@@ -309,6 +307,7 @@ def divergence_connection_matrix(config=None,title=None,sources=None, targets=No
|
|
309
307
|
plot_connection_info(syn_info,data,source_labels,target_labels,title, save_file=save_file)
|
310
308
|
return
|
311
309
|
|
310
|
+
|
312
311
|
def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=None,tids=None, no_prepend_pop=False,save_file=None,method='convergence'):
|
313
312
|
"""
|
314
313
|
Generates connection plot displaying gap junction data.
|
@@ -431,7 +430,8 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
|
|
431
430
|
title+=' Percent Connectivity'
|
432
431
|
plot_connection_info(syn_info,data,source_labels,target_labels,title, save_file=save_file)
|
433
432
|
return
|
434
|
-
|
433
|
+
|
434
|
+
|
435
435
|
def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],no_prepend_pop=True,synaptic_info='0',
|
436
436
|
source_cell = None,target_cell = None,include_gap=True):
|
437
437
|
"""
|
@@ -520,6 +520,7 @@ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[]
|
|
520
520
|
tids = []
|
521
521
|
util.relation_matrix(config,nodes,edges,sources,targets,sids,tids,not no_prepend_pop,relation_func=connection_pair_histogram,synaptic_info=synaptic_info)
|
522
522
|
|
523
|
+
|
523
524
|
def connection_distance(config: str,sources: str,targets: str,
|
524
525
|
source_cell_id: int,target_id_type: str,ignore_z:bool=False) -> None:
|
525
526
|
"""
|
@@ -600,6 +601,7 @@ def connection_distance(config: str,sources: str,targets: str,
|
|
600
601
|
plt.grid(True)
|
601
602
|
plt.show()
|
602
603
|
|
604
|
+
|
603
605
|
def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids=None,no_prepend_pop=None,edge_property = None,time = None,time_compare = None,report=None,title=None,save_file=None):
|
604
606
|
"""
|
605
607
|
Generates a matrix of histograms showing the distribution of edge properties between different populations.
|
@@ -683,6 +685,7 @@ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids
|
|
683
685
|
fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
|
684
686
|
plt.draw()
|
685
687
|
|
688
|
+
|
686
689
|
def distance_delay_plot(simulation_config: str,source: str,target: str,
|
687
690
|
group_by: str,sid: str,tid: str) -> None:
|
688
691
|
"""
|
@@ -739,6 +742,7 @@ def distance_delay_plot(simulation_config: str,source: str,target: str,
|
|
739
742
|
plt.title(f'Distance vs Delay for edge between {sid} and {tid}')
|
740
743
|
plt.show()
|
741
744
|
|
745
|
+
|
742
746
|
def plot_synapse_location_histograms(config, target_model, source=None, target=None):
|
743
747
|
"""
|
744
748
|
generates a histogram of the positions of the synapses on a cell broken down by section
|
@@ -832,6 +836,7 @@ def plot_synapse_location_histograms(config, target_model, source=None, target=N
|
|
832
836
|
)
|
833
837
|
print(pivot_table)
|
834
838
|
|
839
|
+
|
835
840
|
def plot_connection_info(text, num, source_labels, target_labels, title, syn_info='0', save_file=None, return_dict=None):
|
836
841
|
"""
|
837
842
|
Function to plot connection information as a heatmap, including handling missing source and target values.
|
@@ -909,6 +914,7 @@ def plot_connection_info(text, num, source_labels, target_labels, title, syn_inf
|
|
909
914
|
else:
|
910
915
|
return
|
911
916
|
|
917
|
+
|
912
918
|
def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_key=None, title: str = 'Percent connection matrix', pop_order=None) -> None:
|
913
919
|
"""
|
914
920
|
Generates and plots a connection matrix based on connection probabilities from a CSV file produced by bmtool.connector.
|
@@ -1063,275 +1069,7 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
|
|
1063
1069
|
plt.tight_layout()
|
1064
1070
|
plt.show()
|
1065
1071
|
|
1066
|
-
def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None, groupby:Optional[str] = 'pop_name',
|
1067
|
-
ax: Optional[Axes] = None,tstart: Optional[float] = None,tstop: Optional[float] = None,
|
1068
|
-
color_map: Optional[Dict[str, str]] = None) -> Axes:
|
1069
|
-
"""
|
1070
|
-
Plots a raster plot of neural spikes, with different colors for each population.
|
1071
|
-
|
1072
|
-
Parameters:
|
1073
|
-
----------
|
1074
|
-
spikes_df : pd.DataFrame, optional
|
1075
|
-
DataFrame containing spike data with columns 'timestamps', 'node_ids', and optional 'pop_name'.
|
1076
|
-
config : str, optional
|
1077
|
-
Path to the configuration file used to load node data.
|
1078
|
-
network_name : str, optional
|
1079
|
-
Specific network name to select from the configuration; if not provided, uses the first network.
|
1080
|
-
ax : matplotlib.axes.Axes, optional
|
1081
|
-
Axes on which to plot the raster; if None, a new figure and axes are created.
|
1082
|
-
tstart : float, optional
|
1083
|
-
Start time for filtering spikes; only spikes with timestamps greater than `tstart` will be plotted.
|
1084
|
-
tstop : float, optional
|
1085
|
-
Stop time for filtering spikes; only spikes with timestamps less than `tstop` will be plotted.
|
1086
|
-
color_map : dict, optional
|
1087
|
-
Dictionary specifying colors for each population. Keys should be population names, and values should be color values.
|
1088
|
-
|
1089
|
-
Returns:
|
1090
|
-
-------
|
1091
|
-
matplotlib.axes.Axes
|
1092
|
-
Axes with the raster plot.
|
1093
|
-
|
1094
|
-
Notes:
|
1095
|
-
-----
|
1096
|
-
- If `config` is provided, the function merges population names from the node data with `spikes_df`.
|
1097
|
-
- Each unique population from groupby in `spikes_df` will be represented by a different color if `color_map` is not specified.
|
1098
|
-
- If `color_map` is provided, it should contain colors for all unique `pop_name` values in `spikes_df`.
|
1099
|
-
"""
|
1100
|
-
# Initialize axes if none provided
|
1101
|
-
if ax is None:
|
1102
|
-
_, ax = plt.subplots(1, 1)
|
1103
|
-
|
1104
|
-
# Filter spikes by time range if specified
|
1105
|
-
if tstart is not None:
|
1106
|
-
spikes_df = spikes_df[spikes_df['timestamps'] > tstart]
|
1107
|
-
if tstop is not None:
|
1108
|
-
spikes_df = spikes_df[spikes_df['timestamps'] < tstop]
|
1109
|
-
|
1110
|
-
# Load and merge node population data if config is provided
|
1111
|
-
if config:
|
1112
|
-
nodes = load_nodes_from_config(config)
|
1113
|
-
if network_name:
|
1114
|
-
nodes = nodes.get(network_name, {})
|
1115
|
-
else:
|
1116
|
-
nodes = list(nodes.values())[0] if nodes else {}
|
1117
|
-
print("Grabbing first network; specify a network name to ensure correct node population is selected.")
|
1118
|
-
|
1119
|
-
# Find common columns, but exclude the join key from the list
|
1120
|
-
common_columns = spikes_df.columns.intersection(nodes.columns).tolist()
|
1121
|
-
common_columns = [col for col in common_columns if col != 'node_ids'] # Remove our join key from the common list
|
1122
|
-
|
1123
|
-
# Drop all intersecting columns except the join key column from df2
|
1124
|
-
spikes_df = spikes_df.drop(columns=common_columns)
|
1125
|
-
# merge nodes and spikes df
|
1126
|
-
spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')
|
1127
|
-
|
1128
1072
|
|
1129
|
-
# Get unique population names
|
1130
|
-
unique_pop_names = spikes_df[groupby].unique()
|
1131
|
-
|
1132
|
-
# Generate colors if no color_map is provided
|
1133
|
-
if color_map is None:
|
1134
|
-
cmap = plt.get_cmap('tab10') # Default colormap
|
1135
|
-
color_map = {pop_name: cmap(i / len(unique_pop_names)) for i, pop_name in enumerate(unique_pop_names)}
|
1136
|
-
else:
|
1137
|
-
# Ensure color_map contains all population names
|
1138
|
-
missing_colors = [pop for pop in unique_pop_names if pop not in color_map]
|
1139
|
-
if missing_colors:
|
1140
|
-
raise ValueError(f"color_map is missing colors for populations: {missing_colors}")
|
1141
|
-
|
1142
|
-
# Plot each population with its specified or generated color
|
1143
|
-
for pop_name, group in spikes_df.groupby(groupby):
|
1144
|
-
ax.scatter(group['timestamps'], group['node_ids'], label=pop_name, color=color_map[pop_name], s=0.5)
|
1145
|
-
|
1146
|
-
# Label axes
|
1147
|
-
ax.set_xlabel("Time")
|
1148
|
-
ax.set_ylabel("Node ID")
|
1149
|
-
ax.legend(title="Population", loc='upper right', framealpha=0.9, markerfirst=False)
|
1150
|
-
|
1151
|
-
return ax
|
1152
|
-
|
1153
|
-
# uses df from bmtool.analysis.spikes compute_firing_rate_stats
|
1154
|
-
def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, List[str]], ax: Optional[Axes] = None,
|
1155
|
-
color_map: Optional[Dict[str, str]] = None) -> Axes:
|
1156
|
-
"""
|
1157
|
-
Plots a bar graph of mean firing rates with error bars (standard deviation).
|
1158
|
-
|
1159
|
-
Parameters:
|
1160
|
-
----------
|
1161
|
-
firing_stats : pd.DataFrame
|
1162
|
-
Dataframe containing 'firing_rate_mean' and 'firing_rate_std'.
|
1163
|
-
groupby : str or list of str
|
1164
|
-
Column(s) used for grouping.
|
1165
|
-
ax : matplotlib.axes.Axes, optional
|
1166
|
-
Axes on which to plot the bar chart; if None, a new figure and axes are created.
|
1167
|
-
color_map : dict, optional
|
1168
|
-
Dictionary specifying colors for each group. Keys should be group names, and values should be color values.
|
1169
|
-
|
1170
|
-
Returns:
|
1171
|
-
-------
|
1172
|
-
matplotlib.axes.Axes
|
1173
|
-
Axes with the bar plot.
|
1174
|
-
"""
|
1175
|
-
# Ensure groupby is a list for consistent handling
|
1176
|
-
if isinstance(groupby, str):
|
1177
|
-
groupby = [groupby]
|
1178
|
-
|
1179
|
-
# Create a categorical column for grouping
|
1180
|
-
firing_stats["group"] = firing_stats[groupby].astype(str).agg("_".join, axis=1)
|
1181
|
-
|
1182
|
-
# Get unique group names
|
1183
|
-
unique_groups = firing_stats["group"].unique()
|
1184
|
-
|
1185
|
-
# Generate colors if no color_map is provided
|
1186
|
-
if color_map is None:
|
1187
|
-
cmap = plt.get_cmap('viridis')
|
1188
|
-
color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
|
1189
|
-
else:
|
1190
|
-
# Ensure color_map contains all groups
|
1191
|
-
missing_colors = [group for group in unique_groups if group not in color_map]
|
1192
|
-
if missing_colors:
|
1193
|
-
raise ValueError(f"color_map is missing colors for groups: {missing_colors}")
|
1194
|
-
|
1195
|
-
# Create new figure and axes if ax is not provided
|
1196
|
-
if ax is None:
|
1197
|
-
fig, ax = plt.subplots(figsize=(10, 6))
|
1198
|
-
|
1199
|
-
# Sort data for consistent plotting
|
1200
|
-
firing_stats = firing_stats.sort_values(by="group")
|
1201
|
-
|
1202
|
-
# Extract values for plotting
|
1203
|
-
x_labels = firing_stats["group"]
|
1204
|
-
means = firing_stats["firing_rate_mean"]
|
1205
|
-
std_devs = firing_stats["firing_rate_std"]
|
1206
|
-
|
1207
|
-
# Get colors for each group
|
1208
|
-
colors = [color_map[group] for group in x_labels]
|
1209
|
-
|
1210
|
-
# Create bar plot
|
1211
|
-
bars = ax.bar(x_labels, means, yerr=std_devs, capsize=5, color=colors, edgecolor="black")
|
1212
|
-
|
1213
|
-
# Add error bars manually with caps
|
1214
|
-
_, caps, _ = ax.errorbar(
|
1215
|
-
x=np.arange(len(x_labels)),
|
1216
|
-
y=means,
|
1217
|
-
yerr=std_devs,
|
1218
|
-
fmt='none',
|
1219
|
-
capsize=5,
|
1220
|
-
capthick=2,
|
1221
|
-
color="black"
|
1222
|
-
)
|
1223
|
-
|
1224
|
-
# Formatting
|
1225
|
-
ax.set_xticks(np.arange(len(x_labels)))
|
1226
|
-
ax.set_xticklabels(x_labels, rotation=45, ha="right")
|
1227
|
-
ax.set_xlabel("Population Group")
|
1228
|
-
ax.set_ylabel("Mean Firing Rate (spikes/s)")
|
1229
|
-
ax.set_title("Firing Rate Statistics by Population")
|
1230
|
-
ax.grid(axis='y', linestyle='--', alpha=0.7)
|
1231
|
-
|
1232
|
-
return ax
|
1233
|
-
|
1234
|
-
# uses df from bmtool.analysis.spikes compute_firing_rate_stats
|
1235
|
-
def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union[str, list], ax: Optional[Axes] = None,
|
1236
|
-
color_map: Optional[Dict[str, str]] = None,
|
1237
|
-
plot_type: Union[str, list] = "box", swarm_alpha: float = 0.6) -> Axes:
|
1238
|
-
"""
|
1239
|
-
Plots a distribution of individual firing rates using one or more plot types
|
1240
|
-
(box plot, violin plot, or swarm plot), overlaying them on top of each other.
|
1241
|
-
|
1242
|
-
Parameters:
|
1243
|
-
----------
|
1244
|
-
individual_stats : pd.DataFrame
|
1245
|
-
Dataframe containing individual firing rates and corresponding group labels.
|
1246
|
-
groupby : str or list of str
|
1247
|
-
Column(s) used for grouping.
|
1248
|
-
ax : matplotlib.axes.Axes, optional
|
1249
|
-
Axes on which to plot the graph; if None, a new figure and axes are created.
|
1250
|
-
color_map : dict, optional
|
1251
|
-
Dictionary specifying colors for each group. Keys should be group names, and values should be color values.
|
1252
|
-
plot_type : str or list of str, optional
|
1253
|
-
List of plot types to generate. Options: "box", "violin", "swarm". Default is "box".
|
1254
|
-
swarm_alpha : float, optional
|
1255
|
-
Transparency of swarm plot points. Default is 0.6.
|
1256
|
-
|
1257
|
-
Returns:
|
1258
|
-
-------
|
1259
|
-
matplotlib.axes.Axes
|
1260
|
-
Axes with the selected plot type(s) overlayed.
|
1261
|
-
"""
|
1262
|
-
# Ensure groupby is a list for consistent handling
|
1263
|
-
if isinstance(groupby, str):
|
1264
|
-
groupby = [groupby]
|
1265
|
-
|
1266
|
-
# Create a categorical column for grouping
|
1267
|
-
individual_stats["group"] = individual_stats[groupby].astype(str).agg("_".join, axis=1)
|
1268
|
-
|
1269
|
-
# Validate plot_type (it can be a list or a single type)
|
1270
|
-
if isinstance(plot_type, str):
|
1271
|
-
plot_type = [plot_type]
|
1272
|
-
|
1273
|
-
for pt in plot_type:
|
1274
|
-
if pt not in ["box", "violin", "swarm"]:
|
1275
|
-
raise ValueError("plot_type must be one of: 'box', 'violin', 'swarm'.")
|
1276
|
-
|
1277
|
-
# Get unique groups for coloring
|
1278
|
-
unique_groups = individual_stats["group"].unique()
|
1279
|
-
|
1280
|
-
# Generate colors if no color_map is provided
|
1281
|
-
if color_map is None:
|
1282
|
-
cmap = plt.get_cmap('viridis')
|
1283
|
-
color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
|
1284
|
-
|
1285
|
-
# Ensure color_map contains all groups
|
1286
|
-
missing_colors = [group for group in unique_groups if group not in color_map]
|
1287
|
-
if missing_colors:
|
1288
|
-
raise ValueError(f"color_map is missing colors for groups: {missing_colors}")
|
1289
|
-
|
1290
|
-
# Create new figure and axes if ax is not provided
|
1291
|
-
if ax is None:
|
1292
|
-
fig, ax = plt.subplots(figsize=(10, 6))
|
1293
|
-
|
1294
|
-
# Sort data for consistent plotting
|
1295
|
-
individual_stats = individual_stats.sort_values(by="group")
|
1296
|
-
|
1297
|
-
# Loop over each plot type and overlay them
|
1298
|
-
for pt in plot_type:
|
1299
|
-
if pt == "box":
|
1300
|
-
sns.boxplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, width=0.5)
|
1301
|
-
elif pt == "violin":
|
1302
|
-
sns.violinplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, inner="quartile", alpha=0.4)
|
1303
|
-
elif pt == "swarm":
|
1304
|
-
sns.swarmplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, alpha=swarm_alpha)
|
1305
|
-
|
1306
|
-
# Formatting
|
1307
|
-
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
|
1308
|
-
ax.set_xlabel("Population Group")
|
1309
|
-
ax.set_ylabel("Firing Rate (spikes/s)")
|
1310
|
-
ax.set_title("Firing Rate Distribution for individual cells")
|
1311
|
-
ax.grid(axis='y', linestyle='--', alpha=0.7)
|
1312
|
-
|
1313
|
-
return ax
|
1314
|
-
|
1315
|
-
def plot_entrainment():
|
1316
|
-
"""
|
1317
|
-
Plots entrainment analysis for oscillatory network activity.
|
1318
|
-
|
1319
|
-
This function analyzes and visualizes how well neural populations entrain to rhythmic
|
1320
|
-
input or how synchronized they become during oscillatory activity. It can show phase
|
1321
|
-
locking, coherence, or other entrainment metrics.
|
1322
|
-
|
1323
|
-
Note: This is currently a placeholder function and not yet implemented.
|
1324
|
-
|
1325
|
-
Parameters:
|
1326
|
-
-----------
|
1327
|
-
None
|
1328
|
-
|
1329
|
-
Returns:
|
1330
|
-
--------
|
1331
|
-
None
|
1332
|
-
"""
|
1333
|
-
pass
|
1334
|
-
|
1335
1073
|
def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file=None, subset=None):
|
1336
1074
|
"""
|
1337
1075
|
Plots a 3D graph of all cells with x, y, z location.
|
@@ -1431,6 +1169,7 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
|
|
1431
1169
|
|
1432
1170
|
return ax
|
1433
1171
|
|
1172
|
+
|
1434
1173
|
def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save_file=None, quiver_length=None, arrow_length_ratio=None, group=None, subset=None):
|
1435
1174
|
from scipy.spatial.transform import Rotation as R
|
1436
1175
|
if not config:
|
@@ -1531,168 +1270,3 @@ def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save
|
|
1531
1270
|
notebook = is_notebook
|
1532
1271
|
if notebook == False:
|
1533
1272
|
plt.show()
|
1534
|
-
|
1535
|
-
def plot_network_graph(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,edge_property='model_template'):
|
1536
|
-
"""
|
1537
|
-
Creates a directed graph visualization of the network connectivity using NetworkX.
|
1538
|
-
|
1539
|
-
This function generates a network diagram showing the connections between different
|
1540
|
-
cell populations, with edge labels indicating the connection types based on the specified
|
1541
|
-
edge property.
|
1542
|
-
|
1543
|
-
Parameters:
|
1544
|
-
-----------
|
1545
|
-
config : str
|
1546
|
-
Path to a BMTK simulation configuration file.
|
1547
|
-
nodes : dict, optional
|
1548
|
-
Dictionary of node information (if already loaded).
|
1549
|
-
edges : dict, optional
|
1550
|
-
Dictionary of edge information (if already loaded).
|
1551
|
-
title : str, optional
|
1552
|
-
Custom title for the plot. If None, defaults to "Network Graph".
|
1553
|
-
sources : str
|
1554
|
-
Comma-separated list of source network names.
|
1555
|
-
targets : str
|
1556
|
-
Comma-separated list of target network names.
|
1557
|
-
sids : str, optional
|
1558
|
-
Comma-separated list of source node identifiers to filter by.
|
1559
|
-
tids : str, optional
|
1560
|
-
Comma-separated list of target node identifiers to filter by.
|
1561
|
-
no_prepend_pop : bool, default=False
|
1562
|
-
If True, population names are not prepended to node identifiers in the display.
|
1563
|
-
save_file : str, optional
|
1564
|
-
Path to save the generated plot.
|
1565
|
-
edge_property : str, default='model_template'
|
1566
|
-
The edge property to use for labeling connections in the graph.
|
1567
|
-
|
1568
|
-
Returns:
|
1569
|
-
--------
|
1570
|
-
None
|
1571
|
-
Displays a network graph visualization.
|
1572
|
-
"""
|
1573
|
-
if not config:
|
1574
|
-
raise Exception("config not defined")
|
1575
|
-
if not sources or not targets:
|
1576
|
-
raise Exception("Sources or targets not defined")
|
1577
|
-
sources = sources.split(",")
|
1578
|
-
targets = targets.split(",")
|
1579
|
-
if sids:
|
1580
|
-
sids = sids.split(",")
|
1581
|
-
else:
|
1582
|
-
sids = []
|
1583
|
-
if tids:
|
1584
|
-
tids = tids.split(",")
|
1585
|
-
else:
|
1586
|
-
tids = []
|
1587
|
-
throw_away, data, source_labels, target_labels = util.connection_graph_edge_types(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,edge_property=edge_property)
|
1588
|
-
|
1589
|
-
if title == None or title=="":
|
1590
|
-
title = "Network Graph"
|
1591
|
-
|
1592
|
-
import networkx as nx
|
1593
|
-
|
1594
|
-
net_graph = nx.MultiDiGraph() #or G = nx.MultiDiGraph()
|
1595
|
-
|
1596
|
-
edges = []
|
1597
|
-
edge_labels = {}
|
1598
|
-
for node in list(set(source_labels+target_labels)):
|
1599
|
-
net_graph.add_node(node)
|
1600
|
-
|
1601
|
-
for s, source in enumerate(source_labels):
|
1602
|
-
for t, target in enumerate(target_labels):
|
1603
|
-
relationship = data[s][t]
|
1604
|
-
for i, relation in enumerate(relationship):
|
1605
|
-
edge_labels[(source,target)]=relation
|
1606
|
-
edges.append([source,target])
|
1607
|
-
|
1608
|
-
net_graph.add_edges_from(edges)
|
1609
|
-
#pos = nx.spring_layout(net_graph,k=0.50,iterations=20)
|
1610
|
-
pos = nx.shell_layout(net_graph)
|
1611
|
-
plt.figure()
|
1612
|
-
nx.draw(net_graph,pos,edge_color='black', width=1,linewidths=1,\
|
1613
|
-
node_size=500,node_color='white',arrowstyle='->',alpha=0.9,\
|
1614
|
-
labels={node:node for node in net_graph.nodes()})
|
1615
|
-
|
1616
|
-
nx.draw_networkx_edge_labels(net_graph,pos,edge_labels=edge_labels,font_color='red')
|
1617
|
-
plt.show()
|
1618
|
-
|
1619
|
-
return
|
1620
|
-
|
1621
|
-
def plot_report(config_file=None, report_file=None, report_name=None, variables=None, gids=None):
    """
    Plot variables stored in a cell report, one trace per gid.

    Parameters:
    -----------
    config_file : str, optional
        Passed to `_get_cell_report` to locate the report when `report_file` is None.
    report_file : str, optional
        File opened with `CellVarsFile`. Looked up via `_get_cell_report` when None.
    report_name : str, optional
        Passed to `_get_cell_report` when `report_file` is None.
    variables : optional
        Variable name(s) to plot (run through `listify`); defaults to all variables in the report.
    gids : optional
        Cell id(s) to plot (run through `listify`); defaults to all gids in the report.

    Returns:
    --------
    None
        Shows the figure. Returns without plotting when there is nothing to plot.
    """
    if report_file is None:
        report_name, report_file = _get_cell_report(config_file, report_name)

    var_report = CellVarsFile(report_file)
    if variables is None:
        variables = var_report.variables
    else:
        variables = listify(variables)
    if gids is None:
        gids = var_report.gids
    else:
        gids = listify(gids)
    time_steps = var_report.time_trace

    def _axis_label(var):
        # Fall back to the module-level missing_units table when the report has no units
        units = var_report.units(var)
        if units == CellVarsFile.UNITS_UNKNOWN:
            units = missing_units.get(var, '')
        suffix = f'({units})' if units else ''
        return f'{var} {suffix}'

    n_plots = len(variables)
    if n_plots == 0:
        return

    if n_plots == 1:
        # A single variable gets a plain figure
        only_var = variables[0]
        plt.figure()
        for gid in gids:
            plt.plot(time_steps, var_report.data(gid=gid, var_name=only_var), label=f'gid {gid}')
        plt.ylabel(_axis_label(only_var))
        plt.xlabel('time (ms)')
        plt.legend()
    else:
        # Several variables are stacked vertically, one subplot each
        _, panels = plt.subplots(n_plots, 1)
        last_row = n_plots - 1
        for row, var in enumerate(variables):
            panel = panels[row]
            for gid in gids:
                panel.plot(time_steps, var_report.data(gid=gid, var_name=var), label=f'gid {gid}')

            panel.legend()
            panel.set_ylabel(_axis_label(var))
            if row < last_row:
                # Only the bottom subplot keeps its x tick labels
                panel.set_xticklabels([])

        panels[last_row].set_xlabel('time (ms)')

    plt.show()
|
1663
|
-
|
1664
|
-
def plot_report_default(config, report_name, variables, gids):
    """
    A simplified interface for plotting cell report variables from BMTK simulations.

    Parses the comma-separated `variables` and `gids` strings, builds the report
    path from the config's output directory, and delegates to `plot_report`.

    Parameters:
    -----------
    config : str
        Path to a BMTK simulation configuration file.
    report_name : str
        Name of the report to plot (without file extension). When empty or None,
        no path is built and `plot_report` is left to locate the report itself.
    variables : str
        Comma-separated list of variable names to plot (e.g., 'v,i_na,i_k').
        A falsy value is passed through unchanged.
    gids : str
        Comma-separated list of cell IDs (gids) to plot data for.
        A falsy value is passed through unchanged.

    Returns:
    --------
    None
        Displays plots of the specified variables for the specified cells.

    Raises:
    -------
    ValueError
        If an entry of `gids` cannot be converted to int.
    """
    if variables:
        variables = variables.split(',')
    if gids:
        gids = [int(gid) for gid in gids.split(',')]

    # Start from None so a missing report name falls back to plot_report's own
    # lookup instead of hitting an unbound local variable.
    report_file = None
    if report_name:
        cfg = util.load_config(config)
        report_file = os.path.join(cfg['output']['output_dir'], report_name + '.h5')

    plot_report(config_file=config, report_file=report_file, report_name=report_name,
                variables=variables, gids=gids)
|
@@ -0,0 +1,51 @@
|
|
1
|
+
|
2
|
+
import matplotlib.pyplot as plt
|
3
|
+
import seaborn as sns
|
4
|
+
|
5
|
+
def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
    """
    Draw one spike rate vs. LFP power correlation curve per population.

    Parameters:
    -----------
    correlation_results : dict
        Nested mapping read as correlation_results[pop][freq]['correlation'].
    frequencies : array
        Frequencies to look up; every other entry is used as an x-axis tick.
    pop_names : list
        Populations to plot, one labelled line each.
    """
    sns.set_style("whitegrid")
    plt.figure(figsize=(10, 6))

    for pop in pop_names:
        # Frequencies without an entry for this population are left out of its curve
        valid_freqs = [freq for freq in frequencies if freq in correlation_results[pop]]
        corr_values = [correlation_results[pop][freq]['correlation'] for freq in valid_freqs]

        plt.plot(valid_freqs, corr_values, marker='o', label=pop, linewidth=2, markersize=6)

    plt.xlabel('Frequency (Hz)', fontsize=12)
    plt.ylabel('Spike Rate-Power Correlation', fontsize=12)
    plt.title('Spike rate LFP power correlation during stimulus', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12)
    plt.xticks(frequencies[::2])

    # Zero reference line
    plt.axhline(y=0, color='gray', linestyle='-', alpha=0.5)

    # Widen the y-range to at least [-0.1, 0.1] so the zero line stays in view
    lower, upper = plt.ylim()
    plt.ylim(min(lower, -0.1), max(upper, 0.1))

    plt.tight_layout()

    plt.show()
|
bmtool/bmplot/lfp.py
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from bmtool.analysis.lfp import gen_aperiodic
|
3
|
+
import matplotlib.pyplot as plt
|
4
|
+
|
5
|
+
|
6
|
+
def plot_spectrogram(sxx_xarray, remove_aperiodic=None, log_power=False,
                     plt_range=None, clr_freq_range=None, pad=0.03, ax=None):
    """
    Plot a spectrogram as a color mesh and return the plotted power values.

    Parameters:
    -----------
    sxx_xarray : dataset-like
        Must expose `PSD`, `time` and `frequency`, each with a `.values` array.
        PSD is indexed with frequency along the first axis. If it contains
        'cone_of_influence_frequency', that curve is drawn over the mesh.
    remove_aperiodic : object, optional
        Its `aperiodic_params` are passed to `gen_aperiodic` and the resulting fit
        is subtracted from the power (presumably a fitted spectral model; confirm).
    log_power : bool or 'dB'
        Truthy values plot log10 of the power; 'dB' additionally scales by 10.
    plt_range : scalar or pair, optional
        Frequency range to display. A scalar is the upper limit; None uses the
        highest frequency.
    clr_freq_range : pair, optional
        Frequency band whose min/max power set the color limits.
    pad : float
        Passed to the colorbar.
    ax : matplotlib axes, optional
        Axes to draw on; a new one is created when None.

    Returns:
    --------
    Processed power array over the full frequency axis (not cropped to plt_range).
    """
    power = sxx_xarray.PSD.values.copy()
    times = sxx_xarray.time.values.copy()
    freqs = sxx_xarray.frequency.values.copy()

    if remove_aperiodic is None:
        cbar_label = 'PSD'
    else:
        cbar_label = 'PSD Residual'

    if log_power:
        # Zero power becomes -inf; silence the divide warning
        with np.errstate(divide='ignore'):
            power = np.log10(power)
        if log_power == 'dB':
            cbar_label += ' dB'
        else:
            cbar_label += ' log(power)'

    if remove_aperiodic is not None:
        # Skip the first bin when it is 0 Hz
        start = 0 if freqs[0] else 1
        aperiodic = gen_aperiodic(freqs[start:], remove_aperiodic.aperiodic_params)
        if not log_power:
            # The fit is used as-is for log power, so convert it for linear power
            aperiodic = 10 ** aperiodic
        power[start:, :] -= aperiodic[:, None]
        power[:start, :] = 0.

    if log_power == 'dB':
        power *= 10

    if ax is None:
        _, ax = plt.subplots(1, 1)

    if plt_range is None:
        plt_range = np.array(freqs[-1])
    else:
        plt_range = np.array(plt_range)
    if plt_range.size == 1:
        # Scalar upper limit: start at the first non-zero frequency for log power, else 0
        low = freqs[0 if freqs[0] else 1] if log_power else 0.
        plt_range = [low, plt_range.item()]
    shown = (freqs >= plt_range[0]) & (freqs <= plt_range[1])

    vmin = vmax = None
    if clr_freq_range is not None:
        band = (freqs >= clr_freq_range[0]) & (freqs <= clr_freq_range[1])
        vmin, vmax = power[band, :].min(), power[band, :].max()

    freqs = freqs[shown]
    mesh = ax.pcolormesh(times, freqs, power[shown, :], shading='gouraud', vmin=vmin, vmax=vmax)
    if 'cone_of_influence_frequency' in sxx_xarray:
        cone = sxx_xarray.cone_of_influence_frequency
        ax.plot(times, cone)
        ax.fill_between(times, cone, step='mid', alpha=0.2)
    ax.set_xlim(times[0], times[-1])
    ax.set_ylim(freqs[0], freqs[-1])
    plt.colorbar(mappable=mesh, ax=ax, label=cbar_label, pad=pad)
    ax.set_xlabel('Time (sec)')
    ax.set_ylabel('Frequency (Hz)')
    return power
|
53
|
+
|
bmtool/bmplot/spikes.py
ADDED
@@ -0,0 +1,259 @@
|
|
1
|
+
from ..util.util import load_nodes_from_config
|
2
|
+
from typing import Optional, Dict, List, Union
|
3
|
+
import pandas as pd
|
4
|
+
from matplotlib.axes import Axes
|
5
|
+
import matplotlib.pyplot as plt
|
6
|
+
import seaborn as sns
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
|
10
|
+
|
11
|
+
def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None, groupby:Optional[str] = 'pop_name',
           ax: Optional[Axes] = None,tstart: Optional[float] = None,tstop: Optional[float] = None,
           color_map: Optional[Dict[str, str]] = None) -> Axes:
    """
    Plots a raster plot of neural spikes, with different colors for each population.

    Parameters:
    ----------
    spikes_df : pd.DataFrame
        DataFrame containing spike data with columns 'timestamps', 'node_ids', and, unless
        `config` is given, the `groupby` column. Required despite the None default.
    config : str, optional
        Path to the configuration file used to load node data.
    network_name : str, optional
        Specific network name to select from the configuration; if not provided, uses the first network.
    groupby : str, optional
        Column whose values define the populations to color and label. Default is 'pop_name'.
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot the raster; if None, a new figure and axes are created.
    tstart : float, optional
        Start time for filtering spikes; only spikes with timestamps greater than `tstart` will be plotted.
    tstop : float, optional
        Stop time for filtering spikes; only spikes with timestamps less than `tstop` will be plotted.
    color_map : dict, optional
        Dictionary specifying colors for each population. Keys should be population names, and values should be color values.

    Returns:
    -------
    matplotlib.axes.Axes
        Axes with the raster plot.

    Raises:
    ------
    ValueError
        If `spikes_df` is None, if the requested network (or any network) cannot be found
        in the loaded node data, or if `color_map` lacks a color for a population.

    Notes:
    -----
    - If `config` is provided, the `groupby` column from the node data is merged into `spikes_df` on 'node_ids'.
    - Each unique value of `groupby` is given its own color from 'tab10' if `color_map` is not specified.
    - If `color_map` is provided, it should contain colors for all unique `groupby` values in `spikes_df`.
    """
    # Fail early with a clear message instead of a TypeError on None indexing
    if spikes_df is None:
        raise ValueError("spikes_df must be provided to plot a raster.")

    # Filter spikes by time range if specified
    if tstart is not None:
        spikes_df = spikes_df[spikes_df['timestamps'] > tstart]
    if tstop is not None:
        spikes_df = spikes_df[spikes_df['timestamps'] < tstop]

    # Load and merge node population data if config is provided
    if config:
        nodes = load_nodes_from_config(config)
        if network_name:
            if network_name not in nodes:
                raise ValueError(
                    f"Network '{network_name}' not found in config; available networks: {list(nodes)}"
                )
            nodes = nodes[network_name]
        else:
            if not nodes:
                raise ValueError("No node populations were loaded from the config.")
            nodes = next(iter(nodes.values()))
            print("Grabbing first network; specify a network name to ensure correct node population is selected.")

        # Drop columns the spikes share with the nodes (except the join key) so the
        # merge below does not produce suffixed duplicates
        common_columns = [col for col in spikes_df.columns.intersection(nodes.columns)
                          if col != 'node_ids']
        spikes_df = spikes_df.drop(columns=common_columns)
        spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')

    # Get unique population names
    unique_pop_names = spikes_df[groupby].unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap('tab10')  # Default colormap
        color_map = {pop_name: cmap(i / len(unique_pop_names)) for i, pop_name in enumerate(unique_pop_names)}
    else:
        # Ensure color_map contains all population names
        missing_colors = [pop for pop in unique_pop_names if pop not in color_map]
        if missing_colors:
            raise ValueError(f"color_map is missing colors for populations: {missing_colors}")

    # Create the axes only after validation so a failed call leaves no empty figure
    if ax is None:
        _, ax = plt.subplots(1, 1)

    # Plot each population with its specified or generated color
    for pop_name, group in spikes_df.groupby(groupby):
        ax.scatter(group['timestamps'], group['node_ids'], label=pop_name, color=color_map[pop_name], s=0.5)

    # Label axes
    ax.set_xlabel("Time")
    ax.set_ylabel("Node ID")
    ax.legend(title="Population", loc='upper right', framealpha=0.9, markerfirst=False)

    return ax
|
97
|
+
|
98
|
+
# uses df from bmtool.analysis.spikes compute_firing_rate_stats
|
99
|
+
def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, List[str]], ax: Optional[Axes] = None,
                               color_map: Optional[Dict[str, str]] = None) -> Axes:
    """
    Plots a bar graph of mean firing rates with error bars (standard deviation).

    The input DataFrame is not modified; grouping labels are built on a copy.

    Parameters:
    ----------
    firing_stats : pd.DataFrame
        Dataframe containing 'firing_rate_mean' and 'firing_rate_std' plus the `groupby` column(s).
    groupby : str or list of str
        Column(s) used for grouping. Multiple columns are joined with '_' into one label.
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot the bar chart; if None, a new figure and axes are created.
    color_map : dict, optional
        Dictionary specifying colors for each group. Keys should be group names, and values should be color values.

    Returns:
    -------
    matplotlib.axes.Axes
        Axes with the bar plot.

    Raises:
    ------
    ValueError
        If `color_map` is given but lacks a color for one of the groups.
    """
    # Ensure groupby is a list for consistent handling
    if isinstance(groupby, str):
        groupby = [groupby]

    # Build the group labels separately so the caller's DataFrame is left untouched
    group_labels = firing_stats[groupby].astype(str).agg("_".join, axis=1)

    # Unique groups in their original row order; this order drives the default colors
    unique_groups = group_labels.unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap('viridis')
        color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
    else:
        # Ensure color_map contains all groups
        missing_colors = [group for group in unique_groups if group not in color_map]
        if missing_colors:
            raise ValueError(f"color_map is missing colors for groups: {missing_colors}")

    # Create new figure and axes if ax is not provided
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Sort a labelled copy for consistent plotting
    sorted_stats = firing_stats.assign(group=group_labels).sort_values(by="group")

    # Extract values for plotting
    x_labels = sorted_stats["group"]
    means = sorted_stats["firing_rate_mean"]
    std_devs = sorted_stats["firing_rate_std"]

    # Get colors for each group
    colors = [color_map[group] for group in x_labels]

    # Draw bars and their error bars in one call; error_kw thickens the caps so a
    # second errorbar pass over the same data is not needed
    ax.bar(x_labels, means, yerr=std_devs, capsize=5, color=colors, edgecolor="black",
           error_kw={"capthick": 2, "ecolor": "black"})

    # Formatting
    ax.set_xticks(np.arange(len(x_labels)))
    ax.set_xticklabels(x_labels, rotation=45, ha="right")
    ax.set_xlabel("Population Group")
    ax.set_ylabel("Mean Firing Rate (spikes/s)")
    ax.set_title("Firing Rate Statistics by Population")
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    return ax
|
178
|
+
|
179
|
+
# uses df from bmtool.analysis.spikes compute_firing_rate_stats
|
180
|
+
def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union[str, list], ax: Optional[Axes] = None,
                                  color_map: Optional[Dict[str, str]] = None,
                                  plot_type: Union[str, list] = "box", swarm_alpha: float = 0.6) -> Axes:
    """
    Plots a distribution of individual firing rates using one or more plot types
    (box plot, violin plot, or swarm plot), overlaying them on top of each other.

    The input DataFrame is not modified; grouping labels are built on a copy.

    Parameters:
    ----------
    individual_stats : pd.DataFrame
        Dataframe containing a 'firing_rate' column and the `groupby` column(s).
    groupby : str or list of str
        Column(s) used for grouping. Multiple columns are joined with '_' into one label.
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot the graph; if None, a new figure and axes are created.
    color_map : dict, optional
        Dictionary specifying colors for each group. Keys should be group names, and values should be color values.
    plot_type : str or list of str, optional
        List of plot types to generate. Options: "box", "violin", "swarm". Default is "box".
    swarm_alpha : float, optional
        Transparency of swarm plot points. Default is 0.6.

    Returns:
    -------
    matplotlib.axes.Axes
        Axes with the selected plot type(s) overlayed.

    Raises:
    ------
    ValueError
        If `plot_type` contains an unsupported value, or `color_map` lacks a color for a group.
    """
    # Validate plot_type first (it can be a list or a single type) so bad input
    # is rejected before any work is done
    if isinstance(plot_type, str):
        plot_type = [plot_type]

    for pt in plot_type:
        if pt not in ["box", "violin", "swarm"]:
            raise ValueError("plot_type must be one of: 'box', 'violin', 'swarm'.")

    # Ensure groupby is a list for consistent handling
    if isinstance(groupby, str):
        groupby = [groupby]

    # Build the group labels separately so the caller's DataFrame is left untouched
    group_labels = individual_stats[groupby].astype(str).agg("_".join, axis=1)

    # Unique groups in their original row order; this order drives the default colors
    unique_groups = group_labels.unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap('viridis')
        color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}

    # Ensure color_map contains all groups
    missing_colors = [group for group in unique_groups if group not in color_map]
    if missing_colors:
        raise ValueError(f"color_map is missing colors for groups: {missing_colors}")

    # Create new figure and axes if ax is not provided
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Sort a labelled copy for consistent plotting
    sorted_stats = individual_stats.assign(group=group_labels).sort_values(by="group")

    # Loop over each plot type and overlay them
    for pt in plot_type:
        if pt == "box":
            sns.boxplot(data=sorted_stats, x="group", y="firing_rate", ax=ax, palette=color_map, width=0.5)
        elif pt == "violin":
            sns.violinplot(data=sorted_stats, x="group", y="firing_rate", ax=ax, palette=color_map, inner="quartile", alpha=0.4)
        elif pt == "swarm":
            sns.swarmplot(data=sorted_stats, x="group", y="firing_rate", ax=ax, palette=color_map, alpha=swarm_alpha)

    # Formatting
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set_xlabel("Population Group")
    ax.set_ylabel("Firing Rate (spikes/s)")
    ax.set_title("Firing Rate Distribution for individual cells")
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    return ax
|
259
|
+
|
@@ -1,7 +1,6 @@
|
|
1
1
|
bmtool/SLURM.py,sha256=PST_jOD5ZmwbJj15Tgq3UIvdq4FYN4EkPuDt66P8OXU,20136
|
2
2
|
bmtool/__init__.py,sha256=ZStTNkAJHJxG7Pwiy5UgCzC4KlhMS5pUNPtUJZVwL_Y,136
|
3
3
|
bmtool/__main__.py,sha256=TmFkmDxjZ6250nYD4cgGhn-tbJeEm0u-EMz2ajAN9vE,650
|
4
|
-
bmtool/bmplot.py,sha256=GmXn4qAlgkPwhM9fwUcVKSbJDMRJBWiH6U90oE03ZPE,68757
|
5
4
|
bmtool/connectors.py,sha256=uLhZIjur0_jWOtSZ9w6-PHftB9Xj6FFXWL5tndEMDYY,73570
|
6
5
|
bmtool/graphs.py,sha256=ShBgJr1iZrM3ugU2wT6hbhmBAkc3mmf7yZQfPuPEqPM,6691
|
7
6
|
bmtool/manage.py,sha256=_lCU0qBQZ4jSxjzAJUd09JEetb--cud7KZgxQFbLGSY,657
|
@@ -9,10 +8,16 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
|
|
9
8
|
bmtool/singlecell.py,sha256=imcdxIzvYVkaOLSGDxYp8WGGssGwXXBCRhzhlqVp7hA,44267
|
10
9
|
bmtool/synapses.py,sha256=Ow2fZavA_3_5BYCjcgPjW0YsyVOetn1wvLxL7hQvbZo,64556
|
11
10
|
bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
bmtool/analysis/entrainment.py,sha256=
|
13
|
-
bmtool/analysis/lfp.py,sha256=
|
11
|
+
bmtool/analysis/entrainment.py,sha256=IMhjbLYw-rL-MfRuFT3uCkyUFObNVJhcxmYV0R9Uh-M,20007
|
12
|
+
bmtool/analysis/lfp.py,sha256=1gsOUAtxM3eA8YqJl9jNtxMPG5cdSNqzeUX0s5WXfdQ,16710
|
14
13
|
bmtool/analysis/netcon_reports.py,sha256=7moyoUC45Cl1_6sGqwZ5aKphK_8i4AimroePXcgUnIo,3057
|
15
14
|
bmtool/analysis/spikes.py,sha256=x24kd0RUhumJkiunfHNEE7mM6JUqdWy1gqabmkMM4cU,14129
|
15
|
+
bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
|
+
bmtool/bmplot/connections.py,sha256=re6QZX_NfQnIaWayGt3EhMINhCeMMSQ6rFR2sJbFeWk,51385
|
17
|
+
bmtool/bmplot/entrainment.py,sha256=3IBD6tfW7lvkuB6DTan7rAVAeznOOzmHLr1qA2rgtCY,1671
|
18
|
+
bmtool/bmplot/lfp.py,sha256=bfjUGt6al0t5mWTcSIyl03usLIqoQVXfmsvpZl4lol4,2023
|
19
|
+
bmtool/bmplot/netcon_reports.py,sha256=VFw4sJIt4Zc0-__eYnksN8Ku9qMhbPpHJEkXMWUiD30,4
|
20
|
+
bmtool/bmplot/spikes.py,sha256=sy8u6Ng-EVHdIa70uE0_CAsrxT8pBmcmQ-7S2RLxg5Y,10770
|
16
21
|
bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
22
|
bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
|
18
23
|
bmtool/debug/debug.py,sha256=xqnkzLiH3s-tS26Y5lZZL62qR2evJdi46Gud-HzxEN4,207
|
@@ -21,9 +26,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
|
|
21
26
|
bmtool/util/util.py,sha256=XR0qZnv_Q47jMBKQpFzCSkCuKe9u8L3YSGJAOpP2zT0,57630
|
22
27
|
bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
28
|
bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
|
24
|
-
bmtool-0.
|
25
|
-
bmtool-0.
|
26
|
-
bmtool-0.
|
27
|
-
bmtool-0.
|
28
|
-
bmtool-0.
|
29
|
-
bmtool-0.
|
29
|
+
bmtool-0.7.0.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
|
30
|
+
bmtool-0.7.0.dist-info/METADATA,sha256=r99nJv4MFuLf7bkDsMWv3dPDciCJ1BSl6gwAwWE5FP0,2766
|
31
|
+
bmtool-0.7.0.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
|
32
|
+
bmtool-0.7.0.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
|
33
|
+
bmtool-0.7.0.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
|
34
|
+
bmtool-0.7.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|