bmtool 0.7.1.7__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +1 -81
- bmtool/analysis/spikes.py +55 -6
- bmtool/bmplot/connections.py +116 -67
- bmtool/bmplot/entrainment.py +525 -28
- bmtool/bmplot/spikes.py +118 -5
- bmtool/synapses.py +3 -3
- bmtool/util/util.py +3 -0
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.dist-info}/METADATA +1 -1
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.dist-info}/RECORD +13 -13
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.dist-info}/WHEEL +1 -1
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.dist-info}/top_level.txt +0 -0
bmtool/analysis/entrainment.py
CHANGED
@@ -7,13 +7,12 @@ from typing import Dict, List, Optional, Union
|
|
7
7
|
import numba
|
8
8
|
import numpy as np
|
9
9
|
import pandas as pd
|
10
|
-
import scipy.stats as stats
|
11
10
|
import xarray as xr
|
12
11
|
from numba import cuda
|
13
12
|
from scipy import signal
|
14
13
|
from tqdm.notebook import tqdm
|
15
14
|
|
16
|
-
from .lfp import butter_bandpass_filter, get_lfp_phase,
|
15
|
+
from .lfp import butter_bandpass_filter, get_lfp_phase, wavelet_filter
|
17
16
|
|
18
17
|
|
19
18
|
def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.ndarray:
|
@@ -635,85 +634,6 @@ def calculate_entrainment_per_cell(
|
|
635
634
|
return entrainment_dict
|
636
635
|
|
637
636
|
|
638
|
-
def calculate_spike_rate_power_correlation(
    spike_rate: xr.DataArray,
    lfp_data: np.ndarray,
    fs: float,
    pop_names: list,
    filter_method: str = "wavelet",
    bandwidth: float = 2.0,
    lowcut: float = None,
    highcut: float = None,
    freq_range: tuple = (10, 100),
    freq_step: float = 5,
    type_name: str = "raw",  # 'raw' or 'smoothed'
):
    """
    Spearman-correlate each population's spike rate with LFP band power over a frequency sweep.

    Parameters
    ----------
    spike_rate : xr.DataArray
        Population spike rates with dimensions (time, population[, type])
    lfp_data : np.ndarray
        LFP data
    fs : float
        Sampling frequency
    pop_names : list
        Population names to analyze; each is selected from the 'population' dimension
    filter_method : str, optional
        Filtering method passed to ``get_lfp_power``, either 'wavelet' or 'butter' (default: 'wavelet')
    bandwidth : float, optional
        Bandwidth parameter for the wavelet filter when method='wavelet' (default: 2.0)
    lowcut : float, optional
        Lower frequency bound (Hz) for the butterworth bandpass filter, required if filter_method='butter'
    highcut : float, optional
        Upper frequency bound (Hz) for the butterworth bandpass filter, required if filter_method='butter'
    freq_range : tuple, optional
        Min and max frequency to analyze (default: (10, 100))
    freq_step : float, optional
        Step size for the frequency sweep (default: 5)
    type_name : str, optional
        Spike-rate type to select when a 'type' dimension exists (default: 'raw')

    Returns
    -------
    correlation_results : dict
        Nested mapping ``{pop: {freq: {"correlation": rho, "p_value": p}}}``
    frequencies : np.ndarray
        ``np.arange(freq_range[0], freq_range[1] + 1, freq_step)``
    """
    lo, hi = freq_range[0], freq_range[1]
    frequencies = np.arange(lo, hi + 1, freq_step)

    # Band power depends only on the frequency, so compute it once and reuse it for every population.
    band_power = {
        f: get_lfp_power(
            lfp_data, f, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
        )
        for f in frequencies
    }

    has_type = "type" in spike_rate.dims
    correlation_results = {}
    for pop in pop_names:
        selector = {"population": pop}
        if has_type:
            selector["type"] = type_name
        rate = spike_rate.sel(**selector).values

        per_freq = {}
        for f in frequencies:
            power = band_power[f]
            # Compare only the overlapping prefix when the two series differ in length.
            n = min(len(rate), len(power))
            if len(rate) != len(power):
                print(f"Warning: Length mismatch for {pop} at {f} Hz, truncating to {n}")
            rho, p = stats.spearmanr(rate[:n], power[:n])
            per_freq[f] = {"correlation": rho, "p_value": p}
        correlation_results[pop] = per_freq

    return correlation_results, frequencies
|
715
|
-
|
716
|
-
|
717
637
|
def get_spikes_in_cycle(
|
718
638
|
spike_df,
|
719
639
|
lfp_data,
|
bmtool/analysis/spikes.py
CHANGED
@@ -281,12 +281,13 @@ def get_population_spike_rate(
|
|
281
281
|
node_number = {}
|
282
282
|
|
283
283
|
if config is None:
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
284
|
+
pass
|
285
|
+
# print(
|
286
|
+
# "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, or not all cells fired,\nthen this count will not include all nodes so the firing rate will not be of the whole population!"
|
287
|
+
# )
|
288
|
+
# print(
|
289
|
+
# "You can provide a config to calculate the correct amount of nodes! for a true population rate."
|
290
|
+
# )
|
290
291
|
|
291
292
|
if config:
|
292
293
|
if not network_name:
|
@@ -602,3 +603,51 @@ def find_bursting_cells(
|
|
602
603
|
)
|
603
604
|
|
604
605
|
return burst_cells
|
606
|
+
|
607
|
+
|
608
|
+
def find_highest_firing_cells(
    df: pd.DataFrame, upper_quantile: float, groupby: str = "pop_name"
) -> pd.DataFrame:
    """
    Identifies and returns spikes from cells whose firing rate is at or above a given
    upper quantile, with the quantile computed separately within each population group.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing spike data with at least the following columns:
        - 'timestamps': Time of each spike event
        - 'node_ids': Identifier for each neuron
        - groupby (e.g., 'pop_name'): Population labels or grouping identifiers for neurons

    upper_quantile : float
        The upper quantile threshold (between 0 and 1).
        Cells with firing rates in the top (1 - upper_quantile) fraction are selected.
        For example, upper_quantile=0.8 selects the top 20% of high-firing cells.
        A value of 0 returns ``df`` itself, unchanged.

    groupby : str, optional
        The column name used to group neurons by population. Default is 'pop_name'.

    Returns
    -------
    pd.DataFrame
        A DataFrame containing only the spikes from the high-firing cells across all
        groups, re-indexed 0..n-1. If ``df`` contains no groups (for example an empty
        trial window), an empty DataFrame with the same columns is returned instead
        of raising.
    """
    if upper_quantile == 0:
        return df

    df_list = []
    for pop in df[groupby].unique():
        pop_df = df[df[groupby] == pop]
        # The second value is used as a per-cell table with 'node_ids' and 'firing_rate'
        # columns, computed for this group only so the threshold is population-relative.
        _, pop_fr = compute_firing_rate_stats(pop_df, groupby=groupby)

        # Cells whose rate is at or above this group's quantile threshold.
        threshold = pop_fr["firing_rate"].quantile(upper_quantile)
        high_firing_cells = pop_fr.loc[pop_fr["firing_rate"] >= threshold, "node_ids"]

        # Keep every spike emitted by those cells.
        df_list.append(pop_df[pop_df["node_ids"].isin(high_firing_cells)])

    if not df_list:
        # pd.concat([]) raises ValueError; return an empty frame with the same schema instead.
        return df.iloc[0:0].reset_index(drop=True)

    # Combine all high-firing spikes into one DataFrame.
    return pd.concat(df_list, ignore_index=True)
|
bmtool/bmplot/connections.py
CHANGED
@@ -1088,7 +1088,6 @@ def plot_connection_info(
|
|
1088
1088
|
Function to plot connection information as a heatmap, including handling missing source and target values.
|
1089
1089
|
If there is no source or target, set the value to 0.
|
1090
1090
|
"""
|
1091
|
-
|
1092
1091
|
# Ensure text dimensions match num dimensions
|
1093
1092
|
num_source = len(source_labels)
|
1094
1093
|
num_target = len(target_labels)
|
@@ -1096,103 +1095,149 @@ def plot_connection_info(
|
|
1096
1095
|
# Set color map
|
1097
1096
|
matplotlib.rc("image", cmap="viridis")
|
1098
1097
|
|
1099
|
-
#
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1098
|
+
# Calculate square cell size to ensure proper aspect ratio
|
1099
|
+
base_cell_size = 0.6 # Base size per cell
|
1100
|
+
|
1101
|
+
# Calculate figure dimensions with proper aspect ratio
|
1102
|
+
# Make sure width and height are proportional to the matrix dimensions
|
1103
|
+
fig_width = max(8, num_target * base_cell_size + 4) # Width based on columns
|
1104
|
+
fig_height = max(6, num_source * base_cell_size + 3) # Height based on rows
|
1105
|
+
|
1106
|
+
# Ensure minimum readable size
|
1107
|
+
min_fig_size = 8
|
1108
|
+
if fig_width < min_fig_size or fig_height < min_fig_size:
|
1109
|
+
scale_factor = min_fig_size / min(fig_width, fig_height)
|
1110
|
+
fig_width *= scale_factor
|
1111
|
+
fig_height *= scale_factor
|
1112
|
+
|
1113
|
+
# Create figure and axis
|
1114
|
+
fig1, ax1 = plt.subplots(figsize=(fig_width, fig_height))
|
1103
1115
|
|
1104
|
-
#
|
1116
|
+
# Replace NaN with 0 and create heatmap
|
1117
|
+
num_clean = np.nan_to_num(num, nan=0)
|
1118
|
+
# if string is nan\nnan make it 0
|
1119
|
+
|
1120
|
+
# Use 'auto' aspect ratio to let matplotlib handle it properly
|
1121
|
+
# This prevents the stretching issue
|
1122
|
+
im1 = ax1.imshow(num_clean, aspect="auto", interpolation="nearest")
|
1123
|
+
|
1124
|
+
# Set ticks and labels
|
1105
1125
|
ax1.set_xticks(list(np.arange(len(target_labels))))
|
1106
1126
|
ax1.set_yticks(list(np.arange(len(source_labels))))
|
1107
1127
|
ax1.set_xticklabels(target_labels)
|
1108
|
-
ax1.set_yticklabels(source_labels
|
1128
|
+
ax1.set_yticklabels(source_labels)
|
1109
1129
|
|
1110
|
-
#
|
1130
|
+
# Improved font sizing based on matrix size
|
1131
|
+
label_font_size = max(8, min(14, 120 / max(num_source, num_target)))
|
1132
|
+
|
1133
|
+
# Style the tick labels
|
1134
|
+
ax1.tick_params(axis="y", labelsize=label_font_size, pad=5)
|
1111
1135
|
plt.setp(
|
1112
1136
|
ax1.get_xticklabels(),
|
1113
1137
|
rotation=45,
|
1114
1138
|
ha="right",
|
1115
1139
|
rotation_mode="anchor",
|
1116
|
-
|
1117
|
-
weight="semibold",
|
1140
|
+
fontsize=label_font_size,
|
1118
1141
|
)
|
1119
1142
|
|
1120
1143
|
# Dictionary to store connection information
|
1121
1144
|
graph_dict = {}
|
1122
1145
|
|
1146
|
+
# Improved text size calculation - more readable for larger matrices
|
1147
|
+
text_size = max(6, min(12, 80 / max(num_source, num_target)))
|
1148
|
+
|
1123
1149
|
# Loop over data dimensions and create text annotations
|
1124
1150
|
for i in range(num_source):
|
1125
1151
|
for j in range(num_target):
|
1126
|
-
|
1127
|
-
edge_info = text[i, j] if text[i, j] is not None else 0
|
1152
|
+
edge_info = text[i, j] if text[i, j] is not None else "0\n0"
|
1128
1153
|
|
1129
|
-
# Initialize the dictionary for the source node if not already done
|
1130
1154
|
if source_labels[i] not in graph_dict:
|
1131
1155
|
graph_dict[source_labels[i]] = {}
|
1132
|
-
|
1133
|
-
# Add edge info for the target node
|
1134
1156
|
graph_dict[source_labels[i]][target_labels[j]] = edge_info
|
1135
1157
|
|
1136
|
-
#
|
1137
|
-
if
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
fig_text = ax1.text(
|
1152
|
-
j,
|
1153
|
-
i,
|
1154
|
-
edge_info,
|
1155
|
-
ha="center",
|
1156
|
-
va="center",
|
1157
|
-
color="w",
|
1158
|
-
rotation=37.5,
|
1159
|
-
size=7,
|
1160
|
-
weight="semibold",
|
1161
|
-
)
|
1158
|
+
# Skip displaying text for NaN values to reduce clutter
|
1159
|
+
if edge_info == "nan\nnan":
|
1160
|
+
edge_info = "0\n±0"
|
1161
|
+
|
1162
|
+
# Format the text display
|
1163
|
+
if isinstance(edge_info, str) and "\n" in edge_info:
|
1164
|
+
# For mean/std format (e.g. "15.5\n4.0")
|
1165
|
+
parts = edge_info.split("\n")
|
1166
|
+
if len(parts) == 2:
|
1167
|
+
try:
|
1168
|
+
mean_val = float(parts[0])
|
1169
|
+
std_val = float(parts[1])
|
1170
|
+
display_text = f"{mean_val:.1f}\n±{std_val:.1f}"
|
1171
|
+
except ValueError:
|
1172
|
+
display_text = edge_info
|
1162
1173
|
else:
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
|
1174
|
+
display_text = edge_info
|
1175
|
+
else:
|
1176
|
+
display_text = str(edge_info)
|
1177
|
+
|
1178
|
+
# Add text to plot with better contrast
|
1179
|
+
text_color = "white" if num_clean[i, j] < (np.nanmax(num_clean) * 0.9) else "black"
|
1180
|
+
|
1181
|
+
if syn_info == "2" or syn_info == "3":
|
1182
|
+
ax1.text(
|
1183
|
+
j,
|
1184
|
+
i,
|
1185
|
+
display_text,
|
1186
|
+
ha="center",
|
1187
|
+
va="center",
|
1188
|
+
color=text_color,
|
1189
|
+
rotation=37.5,
|
1190
|
+
fontsize=text_size,
|
1191
|
+
weight="bold",
|
1192
|
+
)
|
1174
1193
|
else:
|
1175
|
-
|
1176
|
-
j,
|
1194
|
+
ax1.text(
|
1195
|
+
j,
|
1196
|
+
i,
|
1197
|
+
display_text,
|
1198
|
+
ha="center",
|
1199
|
+
va="center",
|
1200
|
+
color=text_color,
|
1201
|
+
fontsize=text_size,
|
1202
|
+
weight="bold",
|
1177
1203
|
)
|
1178
1204
|
|
1179
|
-
# Set labels and title
|
1180
|
-
|
1181
|
-
ax1.
|
1182
|
-
ax1.
|
1205
|
+
# Set labels and title
|
1206
|
+
title_font_size = max(12, min(18, label_font_size + 4))
|
1207
|
+
ax1.set_ylabel("Source", fontsize=title_font_size, weight="bold", labelpad=10)
|
1208
|
+
ax1.set_xlabel("Target", fontsize=title_font_size, weight="bold", labelpad=10)
|
1209
|
+
ax1.set_title(title, fontsize=title_font_size + 2, weight="bold", pad=20)
|
1210
|
+
|
1211
|
+
# Add colorbar
|
1212
|
+
cbar = plt.colorbar(im1, shrink=0.8)
|
1213
|
+
cbar.ax.tick_params(labelsize=label_font_size)
|
1214
|
+
|
1215
|
+
# Adjust layout to minimize whitespace and prevent stretching
|
1216
|
+
plt.tight_layout(pad=1.5)
|
1217
|
+
|
1218
|
+
# Force square cells by setting equal axis limits if needed
|
1219
|
+
ax1.set_xlim(-0.5, num_target - 0.5)
|
1220
|
+
ax1.set_ylim(num_source - 0.5, -0.5) # Inverted for proper matrix orientation
|
1221
|
+
|
1222
|
+
# Display or save the plot
|
1223
|
+
try:
|
1224
|
+
# Check if running in notebook
|
1225
|
+
from IPython import get_ipython
|
1226
|
+
|
1227
|
+
notebook = get_ipython() is not None
|
1228
|
+
except ImportError:
|
1229
|
+
notebook = False
|
1183
1230
|
|
1184
|
-
# Display the plot or save it based on the environment and arguments
|
1185
|
-
notebook = is_notebook() # Check if running in a Jupyter notebook
|
1186
1231
|
if not notebook:
|
1187
|
-
|
1232
|
+
plt.show()
|
1188
1233
|
|
1189
1234
|
if save_file:
|
1190
|
-
plt.savefig(save_file)
|
1235
|
+
plt.savefig(save_file, dpi=300, bbox_inches="tight", pad_inches=0.1)
|
1191
1236
|
|
1192
1237
|
if return_dict:
|
1193
1238
|
return graph_dict
|
1194
1239
|
else:
|
1195
|
-
return
|
1240
|
+
return fig1, ax1
|
1196
1241
|
|
1197
1242
|
|
1198
1243
|
def connector_percent_matrix(
|
@@ -1467,19 +1512,23 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
|
|
1467
1512
|
plt.title(title)
|
1468
1513
|
plt.legend(handles=handles)
|
1469
1514
|
|
1515
|
+
# Add axis labels
|
1516
|
+
ax.set_xlabel("X Position (μm)")
|
1517
|
+
ax.set_ylabel("Y Position (μm)")
|
1518
|
+
ax.set_zlabel("Z Position (μm)")
|
1519
|
+
|
1470
1520
|
# Draw the plot
|
1471
1521
|
plt.draw()
|
1522
|
+
plt.tight_layout()
|
1472
1523
|
|
1473
1524
|
# Save the plot if save_file is provided
|
1474
1525
|
if save_file:
|
1475
1526
|
plt.savefig(save_file)
|
1476
1527
|
|
1477
|
-
# Show
|
1478
|
-
if
|
1528
|
+
# Show if running in notebook
|
1529
|
+
if is_notebook:
|
1479
1530
|
plt.show()
|
1480
1531
|
|
1481
|
-
return ax
|
1482
|
-
|
1483
1532
|
|
1484
1533
|
def plot_3d_cell_rotation(
|
1485
1534
|
config=None,
|
bmtool/bmplot/entrainment.py
CHANGED
@@ -1,56 +1,299 @@
|
|
1
|
+
from typing import List, Tuple, Union
|
2
|
+
|
1
3
|
import matplotlib.pyplot as plt
|
2
4
|
import numpy as np
|
3
5
|
import pandas as pd
|
4
6
|
import seaborn as sns
|
7
|
+
import xarray as xr
|
5
8
|
from matplotlib.gridspec import GridSpec
|
6
9
|
from scipy import stats
|
7
10
|
|
8
|
-
|
9
|
-
|
11
|
+
from bmtool.analysis import entrainment as bmentr
|
12
|
+
from bmtool.analysis import spikes as bmspikes
|
13
|
+
from bmtool.analysis.lfp import get_lfp_power
|
14
|
+
|
15
|
+
|
16
|
+
def _spike_power_error_bounds(trial_correlations, mean_corr, error_type):
    """Return (lower, upper) error bounds around ``mean_corr`` for per-trial correlations.

    With fewer than two trials no spread can be estimated, so both bounds equal ``mean_corr``.
    ``error_type`` is assumed to be already validated as 'ci', 'sem' or 'std'.
    """
    n = len(trial_correlations)
    if n < 2:
        return mean_corr, mean_corr
    if error_type == "ci":
        # 95% confidence interval of the mean using the t-distribution (two-tailed).
        half_width = stats.t.ppf(0.975, n - 1) * stats.sem(trial_correlations)
    elif error_type == "sem":
        half_width = stats.sem(trial_correlations)
    else:  # "std"
        half_width = np.std(trial_correlations, ddof=1)
    return mean_corr - half_width, mean_corr + half_width


def plot_spike_power_correlation(
    spike_df: pd.DataFrame,
    lfp_data: xr.DataArray,
    firing_quantile: float,
    fs: float,
    pop_names: list,
    filter_method: str = "wavelet",
    bandwidth: float = 2.0,
    lowcut: float = None,
    highcut: float = None,
    freq_range: tuple = (10, 100),
    freq_step: float = 5,
    type_name: str = "raw",
    time_windows: list = None,
    error_type: str = "ci",  # "ci" for confidence interval, "sem" for standard error, "std" for standard deviation
):
    """
    Calculate and plot the correlation between population spike rates and LFP power across frequencies.
    Supports both single-signal and trial-based analysis with error bands.

    Parameters
    ----------
    spike_df : pd.DataFrame
        Spike data with 'timestamps' and 'node_ids' columns plus a population column
        ('pop_name', the default grouping of ``bmspikes.find_highest_firing_cells``)
    lfp_data : xr.DataArray
        LFP data. In trial-based mode the computed power is sliced by a 'time'
        coordinate, so presumably this carries one (confirm against get_lfp_power)
    firing_quantile : float
        Upper quantile in [0, 1) used to keep only high-firing cells
        (e.g. 0.8 keeps the top 20%); 0 keeps all cells
    fs : float
        Sampling frequency, used for the LFP power and for the population spike rate
    pop_names : list
        List of population names to analyze
    filter_method : str, optional
        Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
    bandwidth : float, optional
        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
    lowcut : float, optional
        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    highcut : float, optional
        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    freq_range : tuple, optional
        Min and max frequency to analyze (default: (10, 100))
    freq_step : float, optional
        Step size for frequency analysis (default: 5)
    type_name : str, optional
        Which type of spike rate to use if a 'type' dimension exists (default: 'raw')
    time_windows : list, optional
        List of (start, end) time tuples for trial-based analysis. If None, analyze the entire signal
    error_type : str, optional
        Type of error band to plot in trial-based mode: "ci" for 95% confidence interval,
        "sem" for standard error, "std" for standard deviation

    Raises
    ------
    ValueError
        If ``firing_quantile`` is outside [0, 1) or ``error_type`` is not 'ci', 'sem' or 'std'.
    """
    from matplotlib.lines import Line2D

    if not (0 <= firing_quantile < 1):
        raise ValueError("firing_quantile must be in the range [0, 1)")

    error_labels = {"ci": "95% CI", "sem": "±SEM", "std": "±1 SD"}
    if error_type not in error_labels:
        raise ValueError(
            "error_type must be 'ci' for confidence interval, 'sem' for standard error, or 'std' for standard deviation"
        )

    is_trial_based = time_windows is not None

    # Build population spike rates from only the high-firing cells.
    if is_trial_based:
        trial_rates = []
        for start_time, end_time in time_windows:
            trial_spikes = spike_df[
                (spike_df["timestamps"] >= start_time) & (spike_df["timestamps"] <= end_time)
            ].copy()
            # High-firing cells are selected within each trial, not over the whole recording.
            trial_spikes = bmspikes.find_highest_firing_cells(
                trial_spikes, upper_quantile=firing_quantile
            )
            trial_rates.append(
                bmspikes.get_population_spike_rate(
                    trial_spikes, fs=fs, t_start=start_time, t_stop=end_time
                )
            )
        # Trials cover different time ranges; xr.concat aligns them on shared coordinates.
        spike_rate = xr.concat(trial_rates, dim="trial")
    else:
        high_firing = bmspikes.find_highest_firing_cells(spike_df, upper_quantile=firing_quantile)
        # Sample the rate at `fs`, as the trial branch does, so it can line up with the LFP power.
        spike_rate = bmspikes.get_population_spike_rate(high_firing, fs=fs)

    frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)

    # LFP power depends only on the frequency; compute it once per band.
    power_by_freq = {
        freq: get_lfp_power(
            lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
        )
        for freq in frequencies
    }

    results = {}
    for pop in pop_names:
        # Only select on 'type' when that dimension actually exists.
        if "type" in spike_rate.dims:
            pop_spike_rate = spike_rate.sel(population=pop, type=type_name)
        else:
            pop_spike_rate = spike_rate.sel(population=pop)
        results[pop] = {}

        for freq in frequencies:
            lfp_power = power_by_freq[freq]

            if not is_trial_based:
                # Single signal analysis
                if len(pop_spike_rate) != len(lfp_power):
                    print(f"Warning: Length mismatch for {pop} at {freq} Hz")
                    continue
                corr, p_val = stats.spearmanr(pop_spike_rate, lfp_power)
                results[pop][freq] = {"correlation": corr, "p_value": p_val}
                continue

            # Trial-based analysis using the per-trial filtered rates.
            trial_correlations = []
            for trial_idx, (start_time, end_time) in enumerate(time_windows):
                # 'trial' is a new dimension without coordinates, so index it by position.
                trial_spike_rate = pop_spike_rate.isel(trial=trial_idx)
                trial_lfp_power = lfp_power.sel(time=slice(start_time, end_time))

                # Correlate only on time points present in both signals.
                common_times = np.intersect1d(trial_spike_rate.time, trial_lfp_power.time)
                if len(common_times) < 2:  # need at least 2 points for a correlation
                    continue
                trial_sr = trial_spike_rate.sel(time=common_times).values
                trial_lfp = trial_lfp_power.sel(time=common_times).values
                corr, _ = stats.spearmanr(trial_sr, trial_lfp)
                if not np.isnan(corr):
                    trial_correlations.append(corr)

            if trial_correlations:
                trial_correlations = np.array(trial_correlations)
                mean_corr = np.mean(trial_correlations)
                error_lower, error_upper = _spike_power_error_bounds(
                    trial_correlations, mean_corr, error_type
                )
                results[pop][freq] = {
                    "correlation": mean_corr,
                    "error_lower": error_lower,
                    "error_upper": error_upper,
                    "n_trials": len(trial_correlations),
                    "trial_correlations": trial_correlations,
                }
            else:
                # No valid trials
                results[pop][freq] = {
                    "correlation": np.nan,
                    "error_lower": np.nan,
                    "error_upper": np.nan,
                    "n_trials": 0,
                    "trial_correlations": np.array([]),
                }

    # Plotting
    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 8))
    colors = plt.get_cmap("tab10")
    legend_elements = []

    for i, pop in enumerate(pop_names):
        valid_freqs = [
            freq
            for freq in frequencies
            if freq in results[pop] and not np.isnan(results[pop][freq]["correlation"])
        ]
        if not valid_freqs:
            continue

        # Colour is keyed to the population's position so it is stable across calls.
        color = colors(i)
        plot_freqs = np.array(valid_freqs)
        plot_corrs = np.array([results[pop][f]["correlation"] for f in valid_freqs])
        plt.plot(
            plot_freqs, plot_corrs, marker="o", label=pop, linewidth=2, markersize=6, color=color
        )

        # Error bands for trial-based analysis
        if is_trial_based:
            plt.fill_between(
                plot_freqs,
                np.array([results[pop][f]["error_lower"] for f in valid_freqs]),
                np.array([results[pop][f]["error_upper"] for f in valid_freqs]),
                alpha=0.2,
                color=color,
            )

        # Only list populations that were actually drawn.
        legend_elements.append(
            Line2D([0], [0], color=color, marker="o", linestyle="-", label=pop)
        )

    # Formatting
    plt.xlabel("Frequency (Hz)", fontsize=12)
    plt.ylabel("Spike Rate-Power Correlation", fontsize=12)

    firing_percentage = round(float((1 - firing_quantile) * 100), 1)
    if is_trial_based:
        # Name the error measure actually plotted rather than always "95% CI".
        title = (
            "Trial-averaged Spike Rate-LFP Power Correlation\n"
            f"Top {firing_percentage}% Firing Cells ({error_labels[error_type]})"
        )
    else:
        title = f"Spike Rate-LFP Power Correlation\nTop {firing_percentage}% Firing Cells"

    plt.title(title, fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.axhline(y=0, color="gray", linestyle="-", alpha=0.5)

    # Legend
    if is_trial_based:
        legend_elements.append(
            Line2D([0], [0], color="gray", alpha=0.3, linewidth=10, label=error_labels[error_type])
        )
    if legend_elements:
        plt.legend(handles=legend_elements, fontsize=10, loc="best")

    # Axis formatting
    plt.xticks(frequencies[::2] if len(frequencies) > 10 else frequencies)
    plt.xlim(frequencies[0], frequencies[-1])

    y_min, y_max = plt.ylim()
    plt.ylim(min(y_min, -0.1), max(y_max, 0.1))

    plt.tight_layout()
    plt.show()
|
55
298
|
|
56
299
|
|
@@ -366,3 +609,257 @@ def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title
|
|
366
609
|
plt.savefig(f"{save_path}/ppc_change_swarm_plot_{freq}Hz.png", dpi=300, bbox_inches="tight")
|
367
610
|
|
368
611
|
plt.show()
|
612
|
+
|
613
|
+
|
614
|
+
def _compute_entrainment(
    entrainment_method: str,
    spike_times: np.ndarray,
    lfp,
    spike_fs: float,
    freq: float,
):
    """Return one wavelet-filtered spike-LFP entrainment value at ``freq``.

    Dispatches to the bmentr calculator matching ``entrainment_method``
    ('ppc', 'plv' or 'ppc2'). ``lfp`` must expose an ``fs`` attribute, which is
    forwarded as ``lfp_fs``.
    """
    shared_kwargs = dict(
        spike_fs=spike_fs,
        lfp_fs=lfp.fs,
        freq_of_interest=freq,
        filter_method="wavelet",
    )
    if entrainment_method == "ppc":
        return bmentr.calculate_ppc(spike_times, lfp, ppc_method="gpu", **shared_kwargs)
    if entrainment_method == "plv":
        return bmentr.calculate_spike_lfp_plv(spike_times, lfp, **shared_kwargs)
    if entrainment_method == "ppc2":
        return bmentr.calculate_ppc2(spike_times, lfp, **shared_kwargs)
    raise ValueError(f"Unsupported entrainment_method: {entrainment_method!r}")


def _across_trial_error(trial_values: np.ndarray, error_type: str) -> np.ndarray:
    """Return the per-frequency error half-width across trials, ignoring NaNs.

    ``trial_values`` has one row per trial and one column per frequency.
    ``error_type`` is 'ci' (95% two-tailed t-interval), 'sem' or 'std'.
    Columns with fewer than two valid trials yield NaN.
    """
    import warnings

    import scipy.stats as stats

    valid_counts = np.sum(~np.isnan(trial_values), axis=0)
    with warnings.catch_warnings():
        # nanstd warns through the warnings module (not np.errstate) for
        # columns with < 2 valid trials; NaN is the intended result there.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        std = np.nanstd(trial_values, axis=0, ddof=1)
        sem = std / np.sqrt(valid_counts)

    if error_type == "std":
        return std
    if error_type == "sem":
        return sem

    # "ci": use the smallest valid trial count (> 1) across frequencies for a
    # conservative t-value. If no frequency has two valid trials, there is no
    # interval to draw (the original code crashed here on np.min of an empty array).
    usable_counts = valid_counts[valid_counts > 1]
    if usable_counts.size == 0:
        return np.full_like(sem, np.nan, dtype=float)
    t_value = stats.t.ppf(0.975, usable_counts.min() - 1)
    return t_value * sem


def plot_trial_avg_entrainment(
    spike_df: pd.DataFrame,
    lfp: np.ndarray,
    time_windows: List[Tuple[float, float]],
    entrainment_method: str,
    pop_names: List[str],
    freqs: Union[List[float], np.ndarray],
    firing_quantile: float,
    spike_fs: float = 1000,
    error_type: str = "ci",  # "ci" for confidence interval, "sem" for standard error, "std" for standard deviation
) -> None:
    """
    Plot trial-averaged spike-LFP entrainment versus frequency for each population.

    Only the wavelet filter is currently supported for extracting LFP phase.

    Parameters:
    -----------
    spike_df : pd.DataFrame
        Spike data containing 'timestamps', 'node_ids', and 'pop_name' columns
    lfp : xarray.DataArray (annotated as np.ndarray)
        A single LFP channel. It must expose an ``fs`` attribute (sampling rate),
        which is passed to the entrainment calculators as ``lfp_fs``.
    time_windows : List[Tuple[float, float]]
        One (start_time, end_time) window per trial; both bounds are inclusive
    entrainment_method : str
        Method for entrainment calculation ('ppc', 'ppc2' or 'plv')
    pop_names : List[str]
        List of population names to process (e.g., ['FSI', 'LTS'])
    freqs : Union[List[float], np.ndarray]
        Frequencies to analyze (Hz)
    firing_quantile : float
        Upper quantile threshold for selecting high-firing cells
        (e.g., 0.8 keeps the top 20%)
    spike_fs : float
        Sampling rate for spike data. Default is 1000
    error_type : str
        Error band to plot: "ci" for 95% confidence interval, "sem" for standard
        error, "std" for standard deviation

    Raises:
    -------
    ValueError
        If entrainment_method is not 'ppc', 'ppc2' or 'plv'
        If error_type is not 'ci', 'sem', or 'std'
        If firing_quantile is not in [0, 1)

    Notes:
    ------
    A population with no (high-firing) spikes in a trial, or a calculation error
    at a given frequency, is reported with a printed warning and recorded as NaN
    for that trial; NaNs are ignored when averaging across trials.

    Returns:
    --------
    None
        Displays the plot
    """
    import warnings

    sns.set_style("whitegrid")
    # Validate inputs
    if entrainment_method not in ["ppc", "plv", "ppc2"]:
        raise ValueError("entrainment_method must be 'ppc', ppc2 or 'plv'")

    if error_type not in ["ci", "sem", "std"]:
        raise ValueError(
            "error_type must be 'ci' for confidence interval, 'sem' for standard error, or 'std' for standard deviation"
        )

    if not (0 <= firing_quantile < 1):
        raise ValueError("firing_quantile must be between 0 and 1")

    freqs = np.atleast_1d(np.array(freqs))
    n_freqs = len(freqs)

    # One row per trial, one column per frequency, for each population.
    all_plv_data = {pop_name: [] for pop_name in pop_names}

    for trial_idx, window in enumerate(time_windows):
        t_start, t_stop = window[0], window[1]
        trial_spikes = spike_df[
            (spike_df["timestamps"] >= t_start) & (spike_df["timestamps"] <= t_stop)
        ].copy()

        for pop_name in pop_names:
            pop_spikes = trial_spikes[trial_spikes["pop_name"] == pop_name]
            if len(pop_spikes) == 0:
                print(f"Warning: No spikes found for population {pop_name} in trial {trial_idx}")
                all_plv_data[pop_name].append([np.nan] * n_freqs)
                continue

            # firing_quantile of 0.8 keeps the top 20% of firing cells
            pop_spikes = bmspikes.find_highest_firing_cells(
                pop_spikes, upper_quantile=firing_quantile
            )
            if len(pop_spikes) == 0:
                print(
                    f"Warning: No high-firing spikes found for population {pop_name} in trial {trial_idx}"
                )
                all_plv_data[pop_name].append([np.nan] * n_freqs)
                continue

            spike_times = pop_spikes["timestamps"].values
            trial_row = []
            for freq in freqs:
                try:
                    trial_row.append(
                        _compute_entrainment(entrainment_method, spike_times, lfp, spike_fs, freq)
                    )
                except Exception as e:
                    # Best-effort: one failing frequency should not abort the plot.
                    print(
                        f"Warning: Error calculating {entrainment_method} for {pop_name} at {freq}Hz in trial {trial_idx}: {e}"
                    )
                    trial_row.append(np.nan)
            all_plv_data[pop_name].append(trial_row)

    # Statistics across trials (rows), ignoring NaN values
    mean_plv = {}
    error_plv = {}
    for pop_name in pop_names:
        rows = all_plv_data[pop_name]
        # dtype=float turns missing (None) results into NaN; the reshape keeps a
        # 2-D (n_trials, n_freqs) shape even when there are no trials.
        trial_matrix = np.array(rows, dtype=float).reshape(len(rows), n_freqs)
        with warnings.catch_warnings():
            # All-NaN columns legitimately average to NaN.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            mean_plv[pop_name] = np.nanmean(trial_matrix, axis=0)
        error_plv[pop_name] = _across_trial_error(trial_matrix, error_type)

    # Create the combined plot
    plt.figure(figsize=(12, 8))

    # Markers cycle if there are more populations than markers
    markers = ["o-", "s-", "^-", "D-", "v-", "<-", ">-", "p-"]
    colors = sns.color_palette(n_colors=len(pop_names))

    for i, pop_name in enumerate(pop_names):
        marker = markers[i % len(markers)]
        color = colors[i]

        # Only plot frequencies with a valid mean
        valid_mask = ~np.isnan(mean_plv[pop_name])
        if not np.any(valid_mask):
            continue

        plt.plot(
            freqs[valid_mask],
            mean_plv[pop_name][valid_mask],
            marker,
            linewidth=2,
            label=pop_name,
            color=color,
            markersize=6,
        )

        # Error shading, if any error values are available
        if not np.all(np.isnan(error_plv[pop_name])):
            plt.fill_between(
                freqs[valid_mask],
                (mean_plv[pop_name] - error_plv[pop_name])[valid_mask],
                (mean_plv[pop_name] + error_plv[pop_name])[valid_mask],
                alpha=0.3,
                color=color,
            )

    plt.xlabel("Frequency (Hz)", fontsize=12)
    plt.ylabel(f"{entrainment_method.upper()}", fontsize=12)

    # Title reports the kept firing percentage and the error band type
    firing_percentage = round(float((1 - firing_quantile) * 100), 1)
    error_labels = {"ci": "95% CI", "sem": "±SEM", "std": "±1 SD"}
    error_label = error_labels[error_type]
    plt.title(
        f"{entrainment_method.upper()} Across Trials for Top {firing_percentage}% Firing Cells ({error_label})",
        fontsize=14,
    )

    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
|
bmtool/bmplot/spikes.py
CHANGED
@@ -20,6 +20,7 @@ def raster(
|
|
20
20
|
tstart: Optional[float] = None,
|
21
21
|
tstop: Optional[float] = None,
|
22
22
|
color_map: Optional[Dict[str, str]] = None,
|
23
|
+
dot_size: Optional[float] = 0.3,
|
23
24
|
) -> Axes:
|
24
25
|
"""
|
25
26
|
Plots a raster plot of neural spikes, with different colors for each population.
|
@@ -40,6 +41,8 @@ def raster(
|
|
40
41
|
Stop time for filtering spikes; only spikes with timestamps less than `tstop` will be plotted.
|
41
42
|
color_map : dict, optional
|
42
43
|
Dictionary specifying colors for each population. Keys should be population names, and values should be color values.
|
44
|
+
dot_size: float, optional
|
45
|
+
Size of the dot to display on the scatterplot
|
43
46
|
|
44
47
|
Returns:
|
45
48
|
-------
|
@@ -53,6 +56,7 @@ def raster(
|
|
53
56
|
- If `color_map` is provided, it should contain colors for all unique `pop_name` values in `spikes_df`.
|
54
57
|
"""
|
55
58
|
# Initialize axes if none provided
|
59
|
+
sns.set_style("whitegrid")
|
56
60
|
if ax is None:
|
57
61
|
_, ax = plt.subplots(1, 1)
|
58
62
|
|
@@ -102,15 +106,17 @@ def raster(
|
|
102
106
|
raise ValueError(f"color_map is missing colors for populations: {missing_colors}")
|
103
107
|
|
104
108
|
# Plot each population with its specified or generated color
|
109
|
+
legend_handles = []
|
105
110
|
for pop_name, group in spikes_df.groupby(groupby):
|
106
|
-
ax.scatter(
|
107
|
-
|
108
|
-
)
|
111
|
+
ax.scatter(group["timestamps"], group["node_ids"], color=color_map[pop_name], s=dot_size)
|
112
|
+
# Dummy scatter for consistent legend appearance
|
113
|
+
handle = ax.scatter([], [], color=color_map[pop_name], label=pop_name, s=20)
|
114
|
+
legend_handles.append(handle)
|
109
115
|
|
110
116
|
# Label axes
|
111
117
|
ax.set_xlabel("Time")
|
112
118
|
ax.set_ylabel("Node ID")
|
113
|
-
ax.legend(title="Population", loc="upper right", framealpha=0.9
|
119
|
+
ax.legend(handles=legend_handles, title="Population", loc="upper right", framealpha=0.9)
|
114
120
|
|
115
121
|
return ax
|
116
122
|
|
@@ -142,6 +148,7 @@ def plot_firing_rate_pop_stats(
|
|
142
148
|
Axes with the bar plot.
|
143
149
|
"""
|
144
150
|
# Ensure groupby is a list for consistent handling
|
151
|
+
sns.set_style("whitegrid")
|
145
152
|
if isinstance(groupby, str):
|
146
153
|
groupby = [groupby]
|
147
154
|
|
@@ -234,6 +241,7 @@ def plot_firing_rate_distribution(
|
|
234
241
|
matplotlib.axes.Axes
|
235
242
|
Axes with the selected plot type(s) overlayed.
|
236
243
|
"""
|
244
|
+
sns.set_style("whitegrid")
|
237
245
|
# Ensure groupby is a list for consistent handling
|
238
246
|
if isinstance(groupby, str):
|
239
247
|
groupby = [groupby]
|
@@ -287,8 +295,9 @@ def plot_firing_rate_distribution(
|
|
287
295
|
y="firing_rate",
|
288
296
|
ax=ax,
|
289
297
|
palette=color_map,
|
290
|
-
inner="
|
298
|
+
inner="box",
|
291
299
|
alpha=0.4,
|
300
|
+
cut=0, # This prevents the KDE from extending beyond the data range
|
292
301
|
)
|
293
302
|
elif pt == "swarm":
|
294
303
|
sns.swarmplot(
|
@@ -308,3 +317,107 @@ def plot_firing_rate_distribution(
|
|
308
317
|
ax.grid(axis="y", linestyle="--", alpha=0.7)
|
309
318
|
|
310
319
|
return ax
|
320
|
+
|
321
|
+
|
322
|
+
def plot_firing_rate_vs_node_attribute(
    individual_stats: Optional[pd.DataFrame] = None,
    config: Optional[str] = None,
    nodes: Optional[pd.DataFrame] = None,
    groupby: Optional[str] = None,
    network_name: Optional[str] = None,
    attribute: Optional[str] = None,
    figsize=(12, 8),
    dot_size: float = 3,
) -> plt.Figure:
    """
    Plot firing rate vs node attribute for each group in separate subplots.

    At most 15 groups are shown; extra groups are dropped with a printed warning.

    Parameters
    ----------
    individual_stats : pd.DataFrame
        DataFrame of individual cell firing rates (e.g. from compute_firing_rate_stats);
        must contain 'node_ids', 'firing_rate' and the ``groupby`` column
    config : str, optional
        Path to configuration file for loading node data (used only when ``nodes`` is None)
    nodes : pd.DataFrame, optional
        Pre-loaded node data, indexed by node id, as an alternative to ``config``
    groupby : str
        Column name in individual_stats to group plots by
    network_name : str, optional
        Name of network to load from config file (required when using ``config``)
    attribute : str
        Node attribute column name to plot against firing rate
    figsize : tuple[int, int], optional
        Figure dimensions (width, height) in inches
    dot_size : float, optional
        Size of scatter plot points

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the subplots

    Raises
    ------
    ValueError
        If individual_stats, groupby or attribute is not provided
        If neither config nor nodes is provided
        If network_name is missing (or not found) when using config
        If attribute is not found in nodes DataFrame
        If node_ids or groupby column is missing from individual_stats
        If nodes index is not unique
        If there are no groups to plot
    """
    # Input validation
    if individual_stats is None:
        raise ValueError("individual_stats is required")
    if groupby is None:
        raise ValueError("groupby is required")
    if attribute is None:
        raise ValueError("attribute is required")
    if config is None and nodes is None:
        raise ValueError("Must provide either config or nodes")
    if nodes is None:
        if network_name is None:
            raise ValueError("network_name required when using config")
        loaded_nodes = load_nodes_from_config(config)
        # NOTE(review): load_nodes_from_config presumably returns a mapping of
        # network name -> nodes DataFrame; select the requested network then.
        if isinstance(loaded_nodes, dict):
            if network_name not in loaded_nodes:
                raise ValueError(
                    f"Network '{network_name}' not found in config; available: {list(loaded_nodes)}"
                )
            nodes = loaded_nodes[network_name]
        else:
            nodes = loaded_nodes
    if attribute not in nodes.columns:
        raise ValueError(f"Attribute '{attribute}' not found in nodes DataFrame")

    # Validate data structure
    if "node_ids" not in individual_stats.columns:
        raise ValueError("individual_stats missing required 'node_ids' column")
    if groupby not in individual_stats.columns:
        raise ValueError(f"individual_stats missing groupby column '{groupby}'")
    if not nodes.index.is_unique:
        raise ValueError("nodes DataFrame must have unique index for merging")

    # Attach the node attribute to each cell's firing rate via node_ids -> nodes index
    merged_df = individual_stats.merge(
        nodes[attribute], left_on="node_ids", right_index=True, how="left"
    )

    # Setup subplot layout
    max_groups = 15  # Maximum number of subplots to avoid overcrowding
    unique_groups = merged_df[groupby].unique()
    if len(unique_groups) == 0:
        raise ValueError("individual_stats contains no groups to plot")
    if len(unique_groups) > max_groups:
        print(f"Warning: Limiting display to {max_groups} groups out of {len(unique_groups)}")
        unique_groups = unique_groups[:max_groups]

    n_groups = len(unique_groups)
    n_cols = min(3, n_groups)
    n_rows = (n_groups + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array of axes, whatever the grid size
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # Plot each group
    for ax, group in zip(axes, unique_groups):
        group_df = merged_df[merged_df[groupby] == group]
        ax.scatter(group_df["firing_rate"], group_df[attribute], s=dot_size)
        ax.set_xlabel("Firing Rate (Hz)")
        ax.set_ylabel(attribute)
        ax.set_title(f"{groupby}: {group}")

    # Hide unused subplots in the last row
    for ax in axes[n_groups:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    return fig
|
bmtool/synapses.py
CHANGED
@@ -18,7 +18,7 @@ from scipy.optimize import curve_fit, minimize, minimize_scalar
|
|
18
18
|
from scipy.signal import find_peaks
|
19
19
|
from tqdm.notebook import tqdm
|
20
20
|
|
21
|
-
from bmtool.util.util import
|
21
|
+
from bmtool.util.util import load_templates_from_config
|
22
22
|
|
23
23
|
|
24
24
|
class SynapseTuner:
|
@@ -69,7 +69,7 @@ class SynapseTuner:
|
|
69
69
|
neuron.load_mechanisms(mechanisms_dir)
|
70
70
|
h.load_file(templates_dir)
|
71
71
|
else:
|
72
|
-
|
72
|
+
# loads both mech and templates
|
73
73
|
load_templates_from_config(config)
|
74
74
|
|
75
75
|
self.conn_type_settings = conn_type_settings
|
@@ -983,7 +983,7 @@ class GapJunctionTuner:
|
|
983
983
|
neuron.load_mechanisms(mechanisms_dir)
|
984
984
|
h.load_file(templates_dir)
|
985
985
|
else:
|
986
|
-
|
986
|
+
# this will load both mechs and templates
|
987
987
|
load_templates_from_config(config)
|
988
988
|
|
989
989
|
self.general_settings = general_settings
|
bmtool/util/util.py
CHANGED
@@ -447,6 +447,9 @@ def load_mechanisms_from_config(config=None):
|
|
447
447
|
|
448
448
|
|
449
449
|
def load_templates_from_config(config=None):
|
450
|
+
"""
|
451
|
+
loads the neuron mechanisms and templates provided from BMTK config
|
452
|
+
"""
|
450
453
|
if config is None:
|
451
454
|
config = "simulation_config.json"
|
452
455
|
config = load_config(config)
|
@@ -6,29 +6,29 @@ bmtool/graphs.py,sha256=gBTzI6c2BBK49dWGcfWh9c56TAooyn-KaiEy0Im1HcI,6717
|
|
6
6
|
bmtool/manage.py,sha256=lsgRejp02P-x6QpA7SXcyXdalPhRmypoviIA2uAitQs,608
|
7
7
|
bmtool/plot_commands.py,sha256=Dxm_RaT4CtHnfsltTtUopJ4KVbfhxtktEB_b7bFEXII,12716
|
8
8
|
bmtool/singlecell.py,sha256=I2yolbAnNC8qpnRkNdnDCLidNW7CktmBuRrcowMZJ3A,45041
|
9
|
-
bmtool/synapses.py,sha256=
|
9
|
+
bmtool/synapses.py,sha256=y8UJAqO1jpZY-mY9gVVMN8Dj1r9jD2fI1nAaNQeQfz4,66148
|
10
10
|
bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
bmtool/analysis/entrainment.py,sha256=
|
11
|
+
bmtool/analysis/entrainment.py,sha256=NQloQtVpEWjDzmkZwMWVcm3hSjErHBZfQl1mrBVoIE8,25321
|
12
12
|
bmtool/analysis/lfp.py,sha256=S2JvxkjcK3-EH93wCrhqNSFY6cX7fOq74pz64ibHKrc,26556
|
13
13
|
bmtool/analysis/netcon_reports.py,sha256=VnPZNKPaQA7oh1q9cIatsqQudm4cOtzNtbGPXoiDCD0,2909
|
14
|
-
bmtool/analysis/spikes.py,sha256=
|
14
|
+
bmtool/analysis/spikes.py,sha256=3n-xmyEZ7w6CKEND7-aKOAvdDg0lwDuPI5sMdOuPwa0,24637
|
15
15
|
bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
|
-
bmtool/bmplot/connections.py,sha256=
|
17
|
-
bmtool/bmplot/entrainment.py,sha256=
|
16
|
+
bmtool/bmplot/connections.py,sha256=KSORZ43v1B5xfiBN6AnAD7tJySVTkLIY3j_zb2r-YPA,55696
|
17
|
+
bmtool/bmplot/entrainment.py,sha256=BrBMerqyiG2YWAO_OEFv7OJf3yeFz3l9jUt4NamluLc,32837
|
18
18
|
bmtool/bmplot/lfp.py,sha256=SNpbWGOUnYEgnkeBw5S--aPN5mIGD22Gw2Pwus0_lvY,2034
|
19
19
|
bmtool/bmplot/netcon_reports.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
|
-
bmtool/bmplot/spikes.py,sha256=
|
20
|
+
bmtool/bmplot/spikes.py,sha256=odzCSMbFRHp9qthSGQ0WzMWUwNQ7R1Z6gLT6VPF_o5Q,15326
|
21
21
|
bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
22
|
bmtool/debug/commands.py,sha256=VV00f6q5gzZI503vUPeG40ABLLen0bw_k4-EX-H5WZE,580
|
23
23
|
bmtool/debug/debug.py,sha256=9yUFvA4_Bl-x9s29quIEG3pY-S8hNJF3RKBfRBHCl28,208
|
24
24
|
bmtool/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
25
25
|
bmtool/util/commands.py,sha256=Nn-R-4e9g8ZhSPZvTkr38xeKRPfEMANB9Lugppj82UI,68564
|
26
|
-
bmtool/util/util.py,sha256=
|
26
|
+
bmtool/util/util.py,sha256=S8sAXwDiISGAqnSXRIgFqxqCRzL5YcxAqP1UGxGA5Z4,62906
|
27
27
|
bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
28
28
|
bmtool/util/neuron/celltuner.py,sha256=lokRLUM1rsdSYBYrNbLBBo39j14mm8TBNVNRnSlhHCk,94868
|
29
|
-
bmtool-0.7.
|
30
|
-
bmtool-0.7.
|
31
|
-
bmtool-0.7.
|
32
|
-
bmtool-0.7.
|
33
|
-
bmtool-0.7.
|
34
|
-
bmtool-0.7.
|
29
|
+
bmtool-0.7.2.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
|
30
|
+
bmtool-0.7.2.dist-info/METADATA,sha256=JacxbP2RvvbSuuveedTUJjl8KeqOCKf4FlW-UrRmfCk,3575
|
31
|
+
bmtool-0.7.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
32
|
+
bmtool-0.7.2.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
|
33
|
+
bmtool-0.7.2.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
|
34
|
+
bmtool-0.7.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|