bmtool 0.7.1.6__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +1 -81
- bmtool/analysis/spikes.py +55 -4
- bmtool/bmplot/connections.py +116 -67
- bmtool/bmplot/entrainment.py +693 -30
- bmtool/bmplot/spikes.py +118 -5
- bmtool/synapses.py +76 -25
- bmtool/util/util.py +3 -0
- {bmtool-0.7.1.6.dist-info → bmtool-0.7.2.dist-info}/METADATA +1 -1
- {bmtool-0.7.1.6.dist-info → bmtool-0.7.2.dist-info}/RECORD +13 -13
- {bmtool-0.7.1.6.dist-info → bmtool-0.7.2.dist-info}/WHEEL +1 -1
- {bmtool-0.7.1.6.dist-info → bmtool-0.7.2.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.1.6.dist-info → bmtool-0.7.2.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.1.6.dist-info → bmtool-0.7.2.dist-info}/top_level.txt +0 -0
bmtool/analysis/entrainment.py
CHANGED
@@ -7,13 +7,12 @@ from typing import Dict, List, Optional, Union
|
|
7
7
|
import numba
|
8
8
|
import numpy as np
|
9
9
|
import pandas as pd
|
10
|
-
import scipy.stats as stats
|
11
10
|
import xarray as xr
|
12
11
|
from numba import cuda
|
13
12
|
from scipy import signal
|
14
13
|
from tqdm.notebook import tqdm
|
15
14
|
|
16
|
-
from .lfp import butter_bandpass_filter, get_lfp_phase,
|
15
|
+
from .lfp import butter_bandpass_filter, get_lfp_phase, wavelet_filter
|
17
16
|
|
18
17
|
|
19
18
|
def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.ndarray:
|
@@ -635,85 +634,6 @@ def calculate_entrainment_per_cell(
|
|
635
634
|
return entrainment_dict
|
636
635
|
|
637
636
|
|
638
|
-
def calculate_spike_rate_power_correlation(
|
639
|
-
spike_rate,
|
640
|
-
lfp_data,
|
641
|
-
fs,
|
642
|
-
pop_names,
|
643
|
-
filter_method="wavelet",
|
644
|
-
bandwidth=2.0,
|
645
|
-
lowcut=None,
|
646
|
-
highcut=None,
|
647
|
-
freq_range=(10, 100),
|
648
|
-
freq_step=5,
|
649
|
-
):
|
650
|
-
"""
|
651
|
-
Calculate correlation between population spike rates and LFP power across frequencies
|
652
|
-
using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
|
653
|
-
|
654
|
-
Parameters:
|
655
|
-
-----------
|
656
|
-
spike_rate : DataFrame
|
657
|
-
Pre-calculated population spike rates at the same fs as lfp
|
658
|
-
lfp_data : np.array
|
659
|
-
LFP data
|
660
|
-
fs : float
|
661
|
-
Sampling frequency
|
662
|
-
pop_names : list
|
663
|
-
List of population names to analyze
|
664
|
-
filter_method : str, optional
|
665
|
-
Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
|
666
|
-
bandwidth : float, optional
|
667
|
-
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
668
|
-
lowcut : float, optional
|
669
|
-
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
670
|
-
highcut : float, optional
|
671
|
-
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
672
|
-
freq_range : tuple, optional
|
673
|
-
Min and max frequency to analyze (default: (10, 100))
|
674
|
-
freq_step : float, optional
|
675
|
-
Step size for frequency analysis (default: 5)
|
676
|
-
|
677
|
-
Returns:
|
678
|
-
--------
|
679
|
-
correlation_results : dict
|
680
|
-
Dictionary with correlation results for each population and frequency
|
681
|
-
frequencies : array
|
682
|
-
Array of frequencies analyzed
|
683
|
-
"""
|
684
|
-
|
685
|
-
# Define frequency bands to analyze
|
686
|
-
frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
|
687
|
-
|
688
|
-
# Dictionary to store results
|
689
|
-
correlation_results = {pop: {} for pop in pop_names}
|
690
|
-
|
691
|
-
# Calculate power at each frequency band using specified filter
|
692
|
-
power_by_freq = {}
|
693
|
-
for freq in frequencies:
|
694
|
-
power_by_freq[freq] = get_lfp_power(
|
695
|
-
lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
|
696
|
-
)
|
697
|
-
|
698
|
-
# Calculate correlation for each population
|
699
|
-
for pop in pop_names:
|
700
|
-
# Extract spike rate for this population
|
701
|
-
pop_rate = spike_rate[pop]
|
702
|
-
|
703
|
-
# Calculate correlation with power at each frequency
|
704
|
-
for freq in frequencies:
|
705
|
-
# Make sure the lengths match
|
706
|
-
if len(pop_rate) != len(power_by_freq[freq]):
|
707
|
-
raise ValueError(
|
708
|
-
f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}"
|
709
|
-
)
|
710
|
-
# use spearman for non-parametric correlation
|
711
|
-
corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
|
712
|
-
correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
|
713
|
-
|
714
|
-
return correlation_results, frequencies
|
715
|
-
|
716
|
-
|
717
637
|
def get_spikes_in_cycle(
|
718
638
|
spike_df,
|
719
639
|
lfp_data,
|
bmtool/analysis/spikes.py
CHANGED
@@ -281,10 +281,13 @@ def get_population_spike_rate(
|
|
281
281
|
node_number = {}
|
282
282
|
|
283
283
|
if config is None:
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
284
|
+
pass
|
285
|
+
# print(
|
286
|
+
# "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, or not all cells fired,\nthen this count will not include all nodes so the firing rate will not be of the whole population!"
|
287
|
+
# )
|
288
|
+
# print(
|
289
|
+
# "You can provide a config to calculate the correct amount of nodes! for a true population rate."
|
290
|
+
# )
|
288
291
|
|
289
292
|
if config:
|
290
293
|
if not network_name:
|
@@ -600,3 +603,51 @@ def find_bursting_cells(
|
|
600
603
|
)
|
601
604
|
|
602
605
|
return burst_cells
|
606
|
+
|
607
|
+
|
608
|
+
def find_highest_firing_cells(
|
609
|
+
df: pd.DataFrame, upper_quantile: float, groupby: str = "pop_name"
|
610
|
+
) -> pd.DataFrame:
|
611
|
+
"""
|
612
|
+
Identifies and returns spikes from cells with firing rates above a specified upper quantile,
|
613
|
+
grouped by a population label.
|
614
|
+
|
615
|
+
Parameters
|
616
|
+
----------
|
617
|
+
df : pd.DataFrame
|
618
|
+
DataFrame containing spike data with at least the following columns:
|
619
|
+
- 'timestamps': Time of each spike event
|
620
|
+
- 'node_ids': Identifier for each neuron
|
621
|
+
- groupby (e.g., 'pop_name'): Population labels or grouping identifiers for neurons
|
622
|
+
|
623
|
+
upper_quantile : float
|
624
|
+
The upper quantile threshold (between 0 and 1).
|
625
|
+
Cells with firing rates in the top (1 - upper_quantile) fraction are selected.
|
626
|
+
For example, upper_quantile=0.8 selects the top 20% of high-firing cells.
|
627
|
+
|
628
|
+
groupby : str, optional
|
629
|
+
The column name used to group neurons by population. Default is 'pop_name'.
|
630
|
+
|
631
|
+
Returns
|
632
|
+
-------
|
633
|
+
pd.DataFrame
|
634
|
+
A DataFrame containing only the spikes from the high-firing cells across all groupbys.
|
635
|
+
"""
|
636
|
+
if upper_quantile == 0:
|
637
|
+
return df
|
638
|
+
df_list = []
|
639
|
+
for pop in df[groupby].unique():
|
640
|
+
pop_df = df[df[groupby] == pop]
|
641
|
+
_, pop_fr = compute_firing_rate_stats(pop_df, groupby=groupby)
|
642
|
+
|
643
|
+
# Identify high firing cells
|
644
|
+
threshold = pop_fr["firing_rate"].quantile(upper_quantile)
|
645
|
+
high_firing_cells = pop_fr[pop_fr["firing_rate"] >= threshold]["node_ids"]
|
646
|
+
|
647
|
+
# Filter spikes for high firing cells
|
648
|
+
pop_spikes = pop_df[pop_df["node_ids"].isin(high_firing_cells)]
|
649
|
+
df_list.append(pop_spikes)
|
650
|
+
|
651
|
+
# Combine all high firing spikes into one DataFrame
|
652
|
+
result_df = pd.concat(df_list, ignore_index=True)
|
653
|
+
return result_df
|
bmtool/bmplot/connections.py
CHANGED
@@ -1088,7 +1088,6 @@ def plot_connection_info(
|
|
1088
1088
|
Function to plot connection information as a heatmap, including handling missing source and target values.
|
1089
1089
|
If there is no source or target, set the value to 0.
|
1090
1090
|
"""
|
1091
|
-
|
1092
1091
|
# Ensure text dimensions match num dimensions
|
1093
1092
|
num_source = len(source_labels)
|
1094
1093
|
num_target = len(target_labels)
|
@@ -1096,103 +1095,149 @@ def plot_connection_info(
|
|
1096
1095
|
# Set color map
|
1097
1096
|
matplotlib.rc("image", cmap="viridis")
|
1098
1097
|
|
1099
|
-
#
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1098
|
+
# Calculate square cell size to ensure proper aspect ratio
|
1099
|
+
base_cell_size = 0.6 # Base size per cell
|
1100
|
+
|
1101
|
+
# Calculate figure dimensions with proper aspect ratio
|
1102
|
+
# Make sure width and height are proportional to the matrix dimensions
|
1103
|
+
fig_width = max(8, num_target * base_cell_size + 4) # Width based on columns
|
1104
|
+
fig_height = max(6, num_source * base_cell_size + 3) # Height based on rows
|
1105
|
+
|
1106
|
+
# Ensure minimum readable size
|
1107
|
+
min_fig_size = 8
|
1108
|
+
if fig_width < min_fig_size or fig_height < min_fig_size:
|
1109
|
+
scale_factor = min_fig_size / min(fig_width, fig_height)
|
1110
|
+
fig_width *= scale_factor
|
1111
|
+
fig_height *= scale_factor
|
1112
|
+
|
1113
|
+
# Create figure and axis
|
1114
|
+
fig1, ax1 = plt.subplots(figsize=(fig_width, fig_height))
|
1103
1115
|
|
1104
|
-
#
|
1116
|
+
# Replace NaN with 0 and create heatmap
|
1117
|
+
num_clean = np.nan_to_num(num, nan=0)
|
1118
|
+
# Text entries equal to "nan\nnan" are likewise shown as zero (handled in the annotation loop below)
|
1119
|
+
|
1120
|
+
# Use 'auto' aspect ratio to let matplotlib handle it properly
|
1121
|
+
# This prevents the stretching issue
|
1122
|
+
im1 = ax1.imshow(num_clean, aspect="auto", interpolation="nearest")
|
1123
|
+
|
1124
|
+
# Set ticks and labels
|
1105
1125
|
ax1.set_xticks(list(np.arange(len(target_labels))))
|
1106
1126
|
ax1.set_yticks(list(np.arange(len(source_labels))))
|
1107
1127
|
ax1.set_xticklabels(target_labels)
|
1108
|
-
ax1.set_yticklabels(source_labels
|
1128
|
+
ax1.set_yticklabels(source_labels)
|
1109
1129
|
|
1110
|
-
#
|
1130
|
+
# Improved font sizing based on matrix size
|
1131
|
+
label_font_size = max(8, min(14, 120 / max(num_source, num_target)))
|
1132
|
+
|
1133
|
+
# Style the tick labels
|
1134
|
+
ax1.tick_params(axis="y", labelsize=label_font_size, pad=5)
|
1111
1135
|
plt.setp(
|
1112
1136
|
ax1.get_xticklabels(),
|
1113
1137
|
rotation=45,
|
1114
1138
|
ha="right",
|
1115
1139
|
rotation_mode="anchor",
|
1116
|
-
|
1117
|
-
weight="semibold",
|
1140
|
+
fontsize=label_font_size,
|
1118
1141
|
)
|
1119
1142
|
|
1120
1143
|
# Dictionary to store connection information
|
1121
1144
|
graph_dict = {}
|
1122
1145
|
|
1146
|
+
# Improved text size calculation - more readable for larger matrices
|
1147
|
+
text_size = max(6, min(12, 80 / max(num_source, num_target)))
|
1148
|
+
|
1123
1149
|
# Loop over data dimensions and create text annotations
|
1124
1150
|
for i in range(num_source):
|
1125
1151
|
for j in range(num_target):
|
1126
|
-
|
1127
|
-
edge_info = text[i, j] if text[i, j] is not None else 0
|
1152
|
+
edge_info = text[i, j] if text[i, j] is not None else "0\n0"
|
1128
1153
|
|
1129
|
-
# Initialize the dictionary for the source node if not already done
|
1130
1154
|
if source_labels[i] not in graph_dict:
|
1131
1155
|
graph_dict[source_labels[i]] = {}
|
1132
|
-
|
1133
|
-
# Add edge info for the target node
|
1134
1156
|
graph_dict[source_labels[i]][target_labels[j]] = edge_info
|
1135
1157
|
|
1136
|
-
#
|
1137
|
-
if
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
fig_text = ax1.text(
|
1152
|
-
j,
|
1153
|
-
i,
|
1154
|
-
edge_info,
|
1155
|
-
ha="center",
|
1156
|
-
va="center",
|
1157
|
-
color="w",
|
1158
|
-
rotation=37.5,
|
1159
|
-
size=7,
|
1160
|
-
weight="semibold",
|
1161
|
-
)
|
1158
|
+
# Show missing statistics ("nan\nnan") as "0\n±0" rather than printing NaN text in the cell
|
1159
|
+
if edge_info == "nan\nnan":
|
1160
|
+
edge_info = "0\n±0"
|
1161
|
+
|
1162
|
+
# Format the text display
|
1163
|
+
if isinstance(edge_info, str) and "\n" in edge_info:
|
1164
|
+
# For mean/std format (e.g. "15.5\n4.0")
|
1165
|
+
parts = edge_info.split("\n")
|
1166
|
+
if len(parts) == 2:
|
1167
|
+
try:
|
1168
|
+
mean_val = float(parts[0])
|
1169
|
+
std_val = float(parts[1])
|
1170
|
+
display_text = f"{mean_val:.1f}\n±{std_val:.1f}"
|
1171
|
+
except ValueError:
|
1172
|
+
display_text = edge_info
|
1162
1173
|
else:
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
|
1174
|
+
display_text = edge_info
|
1175
|
+
else:
|
1176
|
+
display_text = str(edge_info)
|
1177
|
+
|
1178
|
+
# Add text to plot with better contrast
|
1179
|
+
text_color = "white" if num_clean[i, j] < (np.nanmax(num_clean) * 0.9) else "black"
|
1180
|
+
|
1181
|
+
if syn_info == "2" or syn_info == "3":
|
1182
|
+
ax1.text(
|
1183
|
+
j,
|
1184
|
+
i,
|
1185
|
+
display_text,
|
1186
|
+
ha="center",
|
1187
|
+
va="center",
|
1188
|
+
color=text_color,
|
1189
|
+
rotation=37.5,
|
1190
|
+
fontsize=text_size,
|
1191
|
+
weight="bold",
|
1192
|
+
)
|
1174
1193
|
else:
|
1175
|
-
|
1176
|
-
j,
|
1194
|
+
ax1.text(
|
1195
|
+
j,
|
1196
|
+
i,
|
1197
|
+
display_text,
|
1198
|
+
ha="center",
|
1199
|
+
va="center",
|
1200
|
+
color=text_color,
|
1201
|
+
fontsize=text_size,
|
1202
|
+
weight="bold",
|
1177
1203
|
)
|
1178
1204
|
|
1179
|
-
# Set labels and title
|
1180
|
-
|
1181
|
-
ax1.
|
1182
|
-
ax1.
|
1205
|
+
# Set labels and title
|
1206
|
+
title_font_size = max(12, min(18, label_font_size + 4))
|
1207
|
+
ax1.set_ylabel("Source", fontsize=title_font_size, weight="bold", labelpad=10)
|
1208
|
+
ax1.set_xlabel("Target", fontsize=title_font_size, weight="bold", labelpad=10)
|
1209
|
+
ax1.set_title(title, fontsize=title_font_size + 2, weight="bold", pad=20)
|
1210
|
+
|
1211
|
+
# Add colorbar
|
1212
|
+
cbar = plt.colorbar(im1, shrink=0.8)
|
1213
|
+
cbar.ax.tick_params(labelsize=label_font_size)
|
1214
|
+
|
1215
|
+
# Adjust layout to minimize whitespace and prevent stretching
|
1216
|
+
plt.tight_layout(pad=1.5)
|
1217
|
+
|
1218
|
+
# Set axis limits to span every cell exactly (cell centers sit at integer coordinates); this does not force square cells
|
1219
|
+
ax1.set_xlim(-0.5, num_target - 0.5)
|
1220
|
+
ax1.set_ylim(num_source - 0.5, -0.5) # Inverted for proper matrix orientation
|
1221
|
+
|
1222
|
+
# Display or save the plot
|
1223
|
+
try:
|
1224
|
+
# Check if running in notebook
|
1225
|
+
from IPython import get_ipython
|
1226
|
+
|
1227
|
+
notebook = get_ipython() is not None
|
1228
|
+
except ImportError:
|
1229
|
+
notebook = False
|
1183
1230
|
|
1184
|
-
# Display the plot or save it based on the environment and arguments
|
1185
|
-
notebook = is_notebook() # Check if running in a Jupyter notebook
|
1186
1231
|
if not notebook:
|
1187
|
-
|
1232
|
+
plt.show()
|
1188
1233
|
|
1189
1234
|
if save_file:
|
1190
|
-
plt.savefig(save_file)
|
1235
|
+
plt.savefig(save_file, dpi=300, bbox_inches="tight", pad_inches=0.1)
|
1191
1236
|
|
1192
1237
|
if return_dict:
|
1193
1238
|
return graph_dict
|
1194
1239
|
else:
|
1195
|
-
return
|
1240
|
+
return fig1, ax1
|
1196
1241
|
|
1197
1242
|
|
1198
1243
|
def connector_percent_matrix(
|
@@ -1467,19 +1512,23 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
|
|
1467
1512
|
plt.title(title)
|
1468
1513
|
plt.legend(handles=handles)
|
1469
1514
|
|
1515
|
+
# Add axis labels
|
1516
|
+
ax.set_xlabel("X Position (μm)")
|
1517
|
+
ax.set_ylabel("Y Position (μm)")
|
1518
|
+
ax.set_zlabel("Z Position (μm)")
|
1519
|
+
|
1470
1520
|
# Draw the plot
|
1471
1521
|
plt.draw()
|
1522
|
+
plt.tight_layout()
|
1472
1523
|
|
1473
1524
|
# Save the plot if save_file is provided
|
1474
1525
|
if save_file:
|
1475
1526
|
plt.savefig(save_file)
|
1476
1527
|
|
1477
|
-
# Show
|
1478
|
-
if
|
1528
|
+
# Show if running in notebook
|
1529
|
+
if is_notebook:
|
1479
1530
|
plt.show()
|
1480
1531
|
|
1481
|
-
return ax
|
1482
|
-
|
1483
1532
|
|
1484
1533
|
def plot_3d_cell_rotation(
|
1485
1534
|
config=None,
|