bmtool-0.7.1.7-py3-none-any.whl → bmtool-0.7.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +1 -81
- bmtool/analysis/spikes.py +55 -6
- bmtool/bmplot/connections.py +255 -153
- bmtool/bmplot/entrainment.py +525 -28
- bmtool/bmplot/spikes.py +118 -5
- bmtool/synapses.py +3 -3
- bmtool/util/util.py +3 -0
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/METADATA +1 -1
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/RECORD +13 -13
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/WHEEL +1 -1
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/top_level.txt +0 -0
bmtool/analysis/entrainment.py
CHANGED
@@ -7,13 +7,12 @@ from typing import Dict, List, Optional, Union
|
|
7
7
|
import numba
|
8
8
|
import numpy as np
|
9
9
|
import pandas as pd
|
10
|
-
import scipy.stats as stats
|
11
10
|
import xarray as xr
|
12
11
|
from numba import cuda
|
13
12
|
from scipy import signal
|
14
13
|
from tqdm.notebook import tqdm
|
15
14
|
|
16
|
-
from .lfp import butter_bandpass_filter, get_lfp_phase,
|
15
|
+
from .lfp import butter_bandpass_filter, get_lfp_phase, wavelet_filter
|
17
16
|
|
18
17
|
|
19
18
|
def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.ndarray:
|
@@ -635,85 +634,6 @@ def calculate_entrainment_per_cell(
|
|
635
634
|
return entrainment_dict
|
636
635
|
|
637
636
|
|
638
|
-
def calculate_spike_rate_power_correlation(
|
639
|
-
spike_rate: xr.DataArray,
|
640
|
-
lfp_data: np.ndarray,
|
641
|
-
fs: float,
|
642
|
-
pop_names: list,
|
643
|
-
filter_method: str = "wavelet",
|
644
|
-
bandwidth: float = 2.0,
|
645
|
-
lowcut: float = None,
|
646
|
-
highcut: float = None,
|
647
|
-
freq_range: tuple = (10, 100),
|
648
|
-
freq_step: float = 5,
|
649
|
-
type_name: str = "raw", # 'raw' or 'smoothed'
|
650
|
-
):
|
651
|
-
"""
|
652
|
-
Calculate correlation between population spike rates (xarray) and LFP power across frequencies.
|
653
|
-
|
654
|
-
Parameters
|
655
|
-
----------
|
656
|
-
spike_rate : xr.DataArray
|
657
|
-
Population spike rates with dimensions (time, population[, type])
|
658
|
-
lfp_data : np.ndarray
|
659
|
-
LFP data
|
660
|
-
fs : float
|
661
|
-
Sampling frequency
|
662
|
-
pop_names : list
|
663
|
-
List of population names to analyze
|
664
|
-
filter_method : str, optional
|
665
|
-
Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
|
666
|
-
bandwidth : float, optional
|
667
|
-
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
668
|
-
lowcut : float, optional
|
669
|
-
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
670
|
-
highcut : float, optional
|
671
|
-
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
672
|
-
freq_range : tuple, optional
|
673
|
-
Min and max frequency to analyze (default: (10, 100))
|
674
|
-
freq_step : float, optional
|
675
|
-
Step size for frequency analysis (default: 5)
|
676
|
-
type_name : str, optional
|
677
|
-
Which type of spike rate to use if 'type' dimension exists (default: 'raw')
|
678
|
-
|
679
|
-
Returns
|
680
|
-
-------
|
681
|
-
correlation_results : dict
|
682
|
-
Dictionary with correlation results for each population and frequency
|
683
|
-
frequencies : array
|
684
|
-
Array of frequencies analyzed
|
685
|
-
"""
|
686
|
-
frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
|
687
|
-
correlation_results = {pop: {} for pop in pop_names}
|
688
|
-
|
689
|
-
# Calculate power at each frequency band using specified filter
|
690
|
-
power_by_freq = {}
|
691
|
-
for freq in frequencies:
|
692
|
-
power_by_freq[freq] = get_lfp_power(
|
693
|
-
lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
|
694
|
-
)
|
695
|
-
|
696
|
-
# For each population, extract the correct spike rate
|
697
|
-
for pop in pop_names:
|
698
|
-
# If 'type' dimension exists, select the type
|
699
|
-
if "type" in spike_rate.dims:
|
700
|
-
pop_rate = spike_rate.sel(population=pop, type=type_name).values
|
701
|
-
else:
|
702
|
-
pop_rate = spike_rate.sel(population=pop).values
|
703
|
-
|
704
|
-
# Calculate correlation with power at each frequency
|
705
|
-
for freq in frequencies:
|
706
|
-
lfp_power = power_by_freq[freq]
|
707
|
-
# Ensure lengths match
|
708
|
-
min_len = min(len(pop_rate), len(lfp_power))
|
709
|
-
if len(pop_rate) != len(lfp_power):
|
710
|
-
print(f"Warning: Length mismatch for {pop} at {freq} Hz, truncating to {min_len}")
|
711
|
-
corr, p_val = stats.spearmanr(pop_rate[:min_len], lfp_power[:min_len])
|
712
|
-
correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
|
713
|
-
|
714
|
-
return correlation_results, frequencies
|
715
|
-
|
716
|
-
|
717
637
|
def get_spikes_in_cycle(
|
718
638
|
spike_df,
|
719
639
|
lfp_data,
|
bmtool/analysis/spikes.py
CHANGED
@@ -281,12 +281,13 @@ def get_population_spike_rate(
|
|
281
281
|
node_number = {}
|
282
282
|
|
283
283
|
if config is None:
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
284
|
+
pass
|
285
|
+
# print(
|
286
|
+
# "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, or not all cells fired,\nthen this count will not include all nodes so the firing rate will not be of the whole population!"
|
287
|
+
# )
|
288
|
+
# print(
|
289
|
+
# "You can provide a config to calculate the correct amount of nodes! for a true population rate."
|
290
|
+
# )
|
290
291
|
|
291
292
|
if config:
|
292
293
|
if not network_name:
|
@@ -602,3 +603,51 @@ def find_bursting_cells(
|
|
602
603
|
)
|
603
604
|
|
604
605
|
return burst_cells
|
606
|
+
|
607
|
+
|
608
|
+
def find_highest_firing_cells(
|
609
|
+
df: pd.DataFrame, upper_quantile: float, groupby: str = "pop_name"
|
610
|
+
) -> pd.DataFrame:
|
611
|
+
"""
|
612
|
+
Identifies and returns spikes from cells with firing rates above a specified upper quantile,
|
613
|
+
grouped by a population label.
|
614
|
+
|
615
|
+
Parameters
|
616
|
+
----------
|
617
|
+
df : pd.DataFrame
|
618
|
+
DataFrame containing spike data with at least the following columns:
|
619
|
+
- 'timestamps': Time of each spike event
|
620
|
+
- 'node_ids': Identifier for each neuron
|
621
|
+
- groupby (e.g., 'pop_name'): Population labels or grouping identifiers for neurons
|
622
|
+
|
623
|
+
upper_quantile : float
|
624
|
+
The upper quantile threshold (between 0 and 1).
|
625
|
+
Cells with firing rates in the top (1 - upper_quantile) fraction are selected.
|
626
|
+
For example, upper_quantile=0.8 selects the top 20% of high-firing cells.
|
627
|
+
|
628
|
+
groupby : str, optional
|
629
|
+
The column name used to group neurons by population. Default is 'pop_name'.
|
630
|
+
|
631
|
+
Returns
|
632
|
+
-------
|
633
|
+
pd.DataFrame
|
634
|
+
A DataFrame containing only the spikes from the high-firing cells across all groupbys.
|
635
|
+
"""
|
636
|
+
if upper_quantile == 0:
|
637
|
+
return df
|
638
|
+
df_list = []
|
639
|
+
for pop in df[groupby].unique():
|
640
|
+
pop_df = df[df[groupby] == pop]
|
641
|
+
_, pop_fr = compute_firing_rate_stats(pop_df, groupby=groupby)
|
642
|
+
|
643
|
+
# Identify high firing cells
|
644
|
+
threshold = pop_fr["firing_rate"].quantile(upper_quantile)
|
645
|
+
high_firing_cells = pop_fr[pop_fr["firing_rate"] >= threshold]["node_ids"]
|
646
|
+
|
647
|
+
# Filter spikes for high firing cells
|
648
|
+
pop_spikes = pop_df[pop_df["node_ids"].isin(high_firing_cells)]
|
649
|
+
df_list.append(pop_spikes)
|
650
|
+
|
651
|
+
# Combine all high firing spikes into one DataFrame
|
652
|
+
result_df = pd.concat(df_list, ignore_index=True)
|
653
|
+
return result_df
|
bmtool/bmplot/connections.py
CHANGED
@@ -12,7 +12,6 @@ import matplotlib.pyplot as plt
|
|
12
12
|
import numpy as np
|
13
13
|
import pandas as pd
|
14
14
|
from IPython import get_ipython
|
15
|
-
from neuron import h
|
16
15
|
|
17
16
|
from ..util import util
|
18
17
|
|
@@ -981,104 +980,158 @@ def distance_delay_plot(
|
|
981
980
|
plt.show()
|
982
981
|
|
983
982
|
|
984
|
-
def
|
983
|
+
def plot_synapse_location(config: str, source: str, target: str, sids: str, tids: str) -> tuple:
|
985
984
|
"""
|
986
|
-
|
987
|
-
config: a BMTK config
|
988
|
-
target_model: the name of the model_template used when building the BMTK node
|
989
|
-
source: The source BMTK network
|
990
|
-
target: The target BMTK network
|
991
|
-
"""
|
992
|
-
# Load mechanisms and template
|
993
|
-
|
994
|
-
util.load_templates_from_config(config)
|
995
|
-
|
996
|
-
# Load node and edge data
|
997
|
-
nodes, edges = util.load_nodes_edges_from_config(config)
|
998
|
-
nodes = nodes[source]
|
999
|
-
edges = edges[f"{source}_to_{target}"]
|
1000
|
-
|
1001
|
-
# Map target_node_id to model_template
|
1002
|
-
edges["target_model_template"] = edges["target_node_id"].map(nodes["model_template"])
|
1003
|
-
|
1004
|
-
# Map source_node_id to pop_name
|
1005
|
-
edges["source_pop_name"] = edges["source_node_id"].map(nodes["pop_name"])
|
985
|
+
Generates a connectivity matrix showing synaptic distribution across different cell sections.
|
1006
986
|
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
987
|
+
Parameters
|
988
|
+
----------
|
989
|
+
config : str
|
990
|
+
Path to BMTK config file
|
991
|
+
source : str
|
992
|
+
The source BMTK network name
|
993
|
+
target : str
|
994
|
+
The target BMTK network name
|
995
|
+
sids : str
|
996
|
+
Column name in nodes file containing source population identifiers
|
997
|
+
tids : str
|
998
|
+
Column name in nodes file containing target population identifiers
|
999
|
+
|
1000
|
+
Returns
|
1001
|
+
-------
|
1002
|
+
tuple
|
1003
|
+
(matplotlib.figure.Figure, matplotlib.axes.Axes) containing the plot
|
1004
|
+
|
1005
|
+
Raises
|
1006
|
+
------
|
1007
|
+
ValueError
|
1008
|
+
If required parameters are missing or invalid
|
1009
|
+
RuntimeError
|
1010
|
+
If template loading or cell instantiation fails
|
1011
|
+
"""
|
1012
|
+
import matplotlib.pyplot as plt
|
1013
|
+
import numpy as np
|
1014
|
+
from neuron import h
|
1015
|
+
|
1016
|
+
# Validate inputs
|
1017
|
+
if not all([config, source, target, sids, tids]):
|
1018
|
+
raise ValueError(
|
1019
|
+
"Missing required parameters: config, source, target, sids, and tids must be provided"
|
1020
|
+
)
|
1026
1021
|
|
1027
|
-
|
1028
|
-
|
1022
|
+
try:
|
1023
|
+
# Load mechanisms and template
|
1024
|
+
util.load_templates_from_config(config)
|
1025
|
+
except Exception as e:
|
1026
|
+
raise RuntimeError(f"Failed to load templates from config: {str(e)}")
|
1029
1027
|
|
1030
|
-
|
1031
|
-
|
1032
|
-
|
1028
|
+
try:
|
1029
|
+
# Load node and edge data
|
1030
|
+
nodes, edges = util.load_nodes_edges_from_config(config)
|
1031
|
+
if source not in nodes or f"{source}_to_{target}" not in edges:
|
1032
|
+
raise ValueError(f"Source '{source}' or target '{target}' networks not found in data")
|
1033
1033
|
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1034
|
+
nodes = nodes[source]
|
1035
|
+
edges = edges[f"{source}_to_{target}"]
|
1036
|
+
except Exception as e:
|
1037
|
+
raise RuntimeError(f"Failed to load nodes and edges: {str(e)}")
|
1037
1038
|
|
1038
|
-
|
1039
|
-
|
1039
|
+
# Map identifiers while checking for missing values
|
1040
|
+
edges["target_model_template"] = edges["target_node_id"].map(nodes["model_template"])
|
1041
|
+
edges["target_pop_name"] = edges["target_node_id"].map(nodes[tids])
|
1042
|
+
edges["source_pop_name"] = edges["source_node_id"].map(nodes[sids])
|
1043
|
+
|
1044
|
+
if edges["target_model_template"].isnull().any():
|
1045
|
+
print("Warning: Some target nodes missing model template")
|
1046
|
+
if edges["target_pop_name"].isnull().any():
|
1047
|
+
print("Warning: Some target nodes missing population name")
|
1048
|
+
if edges["source_pop_name"].isnull().any():
|
1049
|
+
print("Warning: Some source nodes missing population name")
|
1050
|
+
|
1051
|
+
# Get unique populations
|
1052
|
+
source_pops = edges["source_pop_name"].unique()
|
1053
|
+
target_pops = edges["target_pop_name"].unique()
|
1054
|
+
|
1055
|
+
# Initialize matrices
|
1056
|
+
num_connections = np.zeros((len(source_pops), len(target_pops)))
|
1057
|
+
text_data = np.empty((len(source_pops), len(target_pops)), dtype=object)
|
1058
|
+
|
1059
|
+
# Create mappings for indices
|
1060
|
+
source_pop_to_idx = {pop: idx for idx, pop in enumerate(source_pops)}
|
1061
|
+
target_pop_to_idx = {pop: idx for idx, pop in enumerate(target_pops)}
|
1062
|
+
|
1063
|
+
# Cache for section mappings to avoid recreating cells
|
1064
|
+
section_mappings = {}
|
1065
|
+
|
1066
|
+
# Calculate connectivity statistics
|
1067
|
+
for source_pop in source_pops:
|
1068
|
+
for target_pop in target_pops:
|
1069
|
+
# Filter edges for this source-target pair
|
1070
|
+
filtered_edges = edges[
|
1071
|
+
(edges["source_pop_name"] == source_pop) & (edges["target_pop_name"] == target_pop)
|
1072
|
+
]
|
1040
1073
|
|
1041
|
-
|
1042
|
-
|
1043
|
-
if len(pop_group) > 0:
|
1044
|
-
ax.hist(
|
1045
|
-
pop_group["afferent_section_pos"],
|
1046
|
-
bins=15,
|
1047
|
-
alpha=0.7,
|
1048
|
-
label=pop_name,
|
1049
|
-
color=pop_colors[pop_name],
|
1050
|
-
)
|
1074
|
+
source_idx = source_pop_to_idx[source_pop]
|
1075
|
+
target_idx = target_pop_to_idx[target_pop]
|
1051
1076
|
|
1052
|
-
|
1053
|
-
|
1054
|
-
|
1055
|
-
|
1056
|
-
ax.tick_params(labelsize=7)
|
1057
|
-
ax.grid(True, alpha=0.3)
|
1077
|
+
if len(filtered_edges) == 0:
|
1078
|
+
num_connections[source_idx, target_idx] = 0
|
1079
|
+
text_data[source_idx, target_idx] = "No connections"
|
1080
|
+
continue
|
1058
1081
|
|
1059
|
-
|
1060
|
-
|
1061
|
-
ax.legend(fontsize=8)
|
1082
|
+
total_connections = len(filtered_edges)
|
1083
|
+
target_model_template = filtered_edges["target_model_template"].iloc[0]
|
1062
1084
|
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1085
|
+
try:
|
1086
|
+
# Get or create section mapping for this model
|
1087
|
+
if target_model_template not in section_mappings:
|
1088
|
+
cell_class_name = (
|
1089
|
+
target_model_template.split(":")[1]
|
1090
|
+
if ":" in target_model_template
|
1091
|
+
else target_model_template
|
1092
|
+
)
|
1093
|
+
cell = getattr(h, cell_class_name)()
|
1094
|
+
|
1095
|
+
# Create section mapping
|
1096
|
+
section_mapping = {}
|
1097
|
+
for idx, sec in enumerate(cell.all):
|
1098
|
+
section_mapping[idx] = sec.name().split(".")[-1] # Clean name
|
1099
|
+
section_mappings[target_model_template] = section_mapping
|
1100
|
+
|
1101
|
+
section_mapping = section_mappings[target_model_template]
|
1102
|
+
|
1103
|
+
# Calculate section distribution
|
1104
|
+
section_counts = filtered_edges["afferent_section_id"].value_counts()
|
1105
|
+
section_percentages = (section_counts / total_connections * 100).round(1)
|
1106
|
+
|
1107
|
+
# Format section distribution text - show all sections
|
1108
|
+
section_display = []
|
1109
|
+
for section_id, percentage in section_percentages.items():
|
1110
|
+
section_name = section_mapping.get(section_id, f"sec_{section_id}")
|
1111
|
+
section_display.append(f"{section_name}:{percentage}%")
|
1112
|
+
|
1113
|
+
num_connections[source_idx, target_idx] = total_connections
|
1114
|
+
text_data[source_idx, target_idx] = "\n".join(section_display)
|
1115
|
+
|
1116
|
+
except Exception as e:
|
1117
|
+
print(f"Warning: Error processing {target_model_template}: {str(e)}")
|
1118
|
+
num_connections[source_idx, target_idx] = total_connections
|
1119
|
+
text_data[source_idx, target_idx] = "Section info N/A"
|
1120
|
+
|
1121
|
+
# Create the plot
|
1122
|
+
title = f"Synaptic Distribution by Section: {source} to {target}"
|
1123
|
+
fig, ax = plot_connection_info(
|
1124
|
+
text=text_data,
|
1125
|
+
num=num_connections,
|
1126
|
+
source_labels=list(source_pops),
|
1127
|
+
target_labels=list(target_pops),
|
1128
|
+
title=title,
|
1129
|
+
syn_info="1",
|
1066
1130
|
)
|
1067
|
-
if is_notebook:
|
1131
|
+
if is_notebook():
|
1068
1132
|
plt.show()
|
1069
1133
|
else:
|
1070
|
-
|
1071
|
-
|
1072
|
-
# Create a summary table
|
1073
|
-
print("Summary of connections by section and source population:")
|
1074
|
-
pivot_table = edges.pivot_table(
|
1075
|
-
values="afferent_section_id",
|
1076
|
-
index="afferent_section_name",
|
1077
|
-
columns="source_pop_name",
|
1078
|
-
aggfunc="count",
|
1079
|
-
fill_value=0,
|
1080
|
-
)
|
1081
|
-
print(pivot_table)
|
1134
|
+
return fig, ax
|
1082
1135
|
|
1083
1136
|
|
1084
1137
|
def plot_connection_info(
|
@@ -1088,7 +1141,6 @@ def plot_connection_info(
|
|
1088
1141
|
Function to plot connection information as a heatmap, including handling missing source and target values.
|
1089
1142
|
If there is no source or target, set the value to 0.
|
1090
1143
|
"""
|
1091
|
-
|
1092
1144
|
# Ensure text dimensions match num dimensions
|
1093
1145
|
num_source = len(source_labels)
|
1094
1146
|
num_target = len(target_labels)
|
@@ -1096,103 +1148,149 @@ def plot_connection_info(
|
|
1096
1148
|
# Set color map
|
1097
1149
|
matplotlib.rc("image", cmap="viridis")
|
1098
1150
|
|
1099
|
-
#
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1151
|
+
# Calculate square cell size to ensure proper aspect ratio
|
1152
|
+
base_cell_size = 0.6 # Base size per cell
|
1153
|
+
|
1154
|
+
# Calculate figure dimensions with proper aspect ratio
|
1155
|
+
# Make sure width and height are proportional to the matrix dimensions
|
1156
|
+
fig_width = max(8, num_target * base_cell_size + 4) # Width based on columns
|
1157
|
+
fig_height = max(6, num_source * base_cell_size + 3) # Height based on rows
|
1103
1158
|
|
1104
|
-
#
|
1159
|
+
# Ensure minimum readable size
|
1160
|
+
min_fig_size = 8
|
1161
|
+
if fig_width < min_fig_size or fig_height < min_fig_size:
|
1162
|
+
scale_factor = min_fig_size / min(fig_width, fig_height)
|
1163
|
+
fig_width *= scale_factor
|
1164
|
+
fig_height *= scale_factor
|
1165
|
+
|
1166
|
+
# Create figure and axis
|
1167
|
+
fig1, ax1 = plt.subplots(figsize=(fig_width, fig_height))
|
1168
|
+
|
1169
|
+
# Replace NaN with 0 and create heatmap
|
1170
|
+
num_clean = np.nan_to_num(num, nan=0)
|
1171
|
+
# if string is nan\nnan make it 0
|
1172
|
+
|
1173
|
+
# Use 'auto' aspect ratio to let matplotlib handle it properly
|
1174
|
+
# This prevents the stretching issue
|
1175
|
+
im1 = ax1.imshow(num_clean, aspect="auto", interpolation="nearest")
|
1176
|
+
|
1177
|
+
# Set ticks and labels
|
1105
1178
|
ax1.set_xticks(list(np.arange(len(target_labels))))
|
1106
1179
|
ax1.set_yticks(list(np.arange(len(source_labels))))
|
1107
1180
|
ax1.set_xticklabels(target_labels)
|
1108
|
-
ax1.set_yticklabels(source_labels
|
1181
|
+
ax1.set_yticklabels(source_labels)
|
1109
1182
|
|
1110
|
-
#
|
1183
|
+
# Improved font sizing based on matrix size
|
1184
|
+
label_font_size = max(8, min(14, 120 / max(num_source, num_target)))
|
1185
|
+
|
1186
|
+
# Style the tick labels
|
1187
|
+
ax1.tick_params(axis="y", labelsize=label_font_size, pad=5)
|
1111
1188
|
plt.setp(
|
1112
1189
|
ax1.get_xticklabels(),
|
1113
1190
|
rotation=45,
|
1114
1191
|
ha="right",
|
1115
1192
|
rotation_mode="anchor",
|
1116
|
-
|
1117
|
-
weight="semibold",
|
1193
|
+
fontsize=label_font_size,
|
1118
1194
|
)
|
1119
1195
|
|
1120
1196
|
# Dictionary to store connection information
|
1121
1197
|
graph_dict = {}
|
1122
1198
|
|
1199
|
+
# Improved text size calculation - more readable for larger matrices
|
1200
|
+
text_size = max(6, min(12, 80 / max(num_source, num_target)))
|
1201
|
+
|
1123
1202
|
# Loop over data dimensions and create text annotations
|
1124
1203
|
for i in range(num_source):
|
1125
1204
|
for j in range(num_target):
|
1126
|
-
|
1127
|
-
edge_info = text[i, j] if text[i, j] is not None else 0
|
1205
|
+
edge_info = text[i, j] if text[i, j] is not None else "0\n0"
|
1128
1206
|
|
1129
|
-
# Initialize the dictionary for the source node if not already done
|
1130
1207
|
if source_labels[i] not in graph_dict:
|
1131
1208
|
graph_dict[source_labels[i]] = {}
|
1132
|
-
|
1133
|
-
# Add edge info for the target node
|
1134
1209
|
graph_dict[source_labels[i]][target_labels[j]] = edge_info
|
1135
1210
|
|
1136
|
-
#
|
1137
|
-
if
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
fig_text = ax1.text(
|
1152
|
-
j,
|
1153
|
-
i,
|
1154
|
-
edge_info,
|
1155
|
-
ha="center",
|
1156
|
-
va="center",
|
1157
|
-
color="w",
|
1158
|
-
rotation=37.5,
|
1159
|
-
size=7,
|
1160
|
-
weight="semibold",
|
1161
|
-
)
|
1211
|
+
# Skip displaying text for NaN values to reduce clutter
|
1212
|
+
if edge_info == "nan\nnan":
|
1213
|
+
edge_info = "0\n±0"
|
1214
|
+
|
1215
|
+
# Format the text display
|
1216
|
+
if isinstance(edge_info, str) and "\n" in edge_info:
|
1217
|
+
# For mean/std format (e.g. "15.5\n4.0")
|
1218
|
+
parts = edge_info.split("\n")
|
1219
|
+
if len(parts) == 2:
|
1220
|
+
try:
|
1221
|
+
mean_val = float(parts[0])
|
1222
|
+
std_val = float(parts[1])
|
1223
|
+
display_text = f"{mean_val:.1f}\n±{std_val:.1f}"
|
1224
|
+
except ValueError:
|
1225
|
+
display_text = edge_info
|
1162
1226
|
else:
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
|
1227
|
+
display_text = edge_info
|
1228
|
+
else:
|
1229
|
+
display_text = str(edge_info)
|
1230
|
+
|
1231
|
+
# Add text to plot with better contrast
|
1232
|
+
text_color = "white" if num_clean[i, j] < (np.nanmax(num_clean) * 0.9) else "black"
|
1233
|
+
|
1234
|
+
if syn_info == "2" or syn_info == "3":
|
1235
|
+
ax1.text(
|
1236
|
+
j,
|
1237
|
+
i,
|
1238
|
+
display_text,
|
1239
|
+
ha="center",
|
1240
|
+
va="center",
|
1241
|
+
color=text_color,
|
1242
|
+
rotation=37.5,
|
1243
|
+
fontsize=text_size,
|
1244
|
+
weight="bold",
|
1245
|
+
)
|
1174
1246
|
else:
|
1175
|
-
|
1176
|
-
j,
|
1247
|
+
ax1.text(
|
1248
|
+
j,
|
1249
|
+
i,
|
1250
|
+
display_text,
|
1251
|
+
ha="center",
|
1252
|
+
va="center",
|
1253
|
+
color=text_color,
|
1254
|
+
fontsize=text_size,
|
1255
|
+
weight="bold",
|
1177
1256
|
)
|
1178
1257
|
|
1179
|
-
# Set labels and title
|
1180
|
-
|
1181
|
-
ax1.
|
1182
|
-
ax1.
|
1258
|
+
# Set labels and title
|
1259
|
+
title_font_size = max(12, min(18, label_font_size + 4))
|
1260
|
+
ax1.set_ylabel("Source", fontsize=title_font_size, weight="bold", labelpad=10)
|
1261
|
+
ax1.set_xlabel("Target", fontsize=title_font_size, weight="bold", labelpad=10)
|
1262
|
+
ax1.set_title(title, fontsize=title_font_size + 2, weight="bold", pad=20)
|
1263
|
+
|
1264
|
+
# Add colorbar
|
1265
|
+
cbar = plt.colorbar(im1, shrink=0.8)
|
1266
|
+
cbar.ax.tick_params(labelsize=label_font_size)
|
1267
|
+
|
1268
|
+
# Adjust layout to minimize whitespace and prevent stretching
|
1269
|
+
plt.tight_layout(pad=1.5)
|
1270
|
+
|
1271
|
+
# Force square cells by setting equal axis limits if needed
|
1272
|
+
ax1.set_xlim(-0.5, num_target - 0.5)
|
1273
|
+
ax1.set_ylim(num_source - 0.5, -0.5) # Inverted for proper matrix orientation
|
1274
|
+
|
1275
|
+
# Display or save the plot
|
1276
|
+
try:
|
1277
|
+
# Check if running in notebook
|
1278
|
+
from IPython import get_ipython
|
1279
|
+
|
1280
|
+
notebook = get_ipython() is not None
|
1281
|
+
except ImportError:
|
1282
|
+
notebook = False
|
1183
1283
|
|
1184
|
-
# Display the plot or save it based on the environment and arguments
|
1185
|
-
notebook = is_notebook() # Check if running in a Jupyter notebook
|
1186
1284
|
if not notebook:
|
1187
|
-
|
1285
|
+
plt.show()
|
1188
1286
|
|
1189
1287
|
if save_file:
|
1190
|
-
plt.savefig(save_file)
|
1288
|
+
plt.savefig(save_file, dpi=300, bbox_inches="tight", pad_inches=0.1)
|
1191
1289
|
|
1192
1290
|
if return_dict:
|
1193
1291
|
return graph_dict
|
1194
1292
|
else:
|
1195
|
-
return
|
1293
|
+
return fig1, ax1
|
1196
1294
|
|
1197
1295
|
|
1198
1296
|
def connector_percent_matrix(
|
@@ -1467,19 +1565,23 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
|
|
1467
1565
|
plt.title(title)
|
1468
1566
|
plt.legend(handles=handles)
|
1469
1567
|
|
1568
|
+
# Add axis labels
|
1569
|
+
ax.set_xlabel("X Position (μm)")
|
1570
|
+
ax.set_ylabel("Y Position (μm)")
|
1571
|
+
ax.set_zlabel("Z Position (μm)")
|
1572
|
+
|
1470
1573
|
# Draw the plot
|
1471
1574
|
plt.draw()
|
1575
|
+
plt.tight_layout()
|
1472
1576
|
|
1473
1577
|
# Save the plot if save_file is provided
|
1474
1578
|
if save_file:
|
1475
1579
|
plt.savefig(save_file)
|
1476
1580
|
|
1477
|
-
# Show
|
1478
|
-
if
|
1581
|
+
# Show if running in notebook
|
1582
|
+
if is_notebook:
|
1479
1583
|
plt.show()
|
1480
1584
|
|
1481
|
-
return ax
|
1482
|
-
|
1483
1585
|
|
1484
1586
|
def plot_3d_cell_rotation(
|
1485
1587
|
config=None,
|