bmtool 0.7.1.7__py3-none-any.whl → 0.7.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,13 +7,12 @@ from typing import Dict, List, Optional, Union
7
7
  import numba
8
8
  import numpy as np
9
9
  import pandas as pd
10
- import scipy.stats as stats
11
10
  import xarray as xr
12
11
  from numba import cuda
13
12
  from scipy import signal
14
13
  from tqdm.notebook import tqdm
15
14
 
16
- from .lfp import butter_bandpass_filter, get_lfp_phase, get_lfp_power, wavelet_filter
15
+ from .lfp import butter_bandpass_filter, get_lfp_phase, wavelet_filter
17
16
 
18
17
 
19
18
  def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.ndarray:
@@ -635,85 +634,6 @@ def calculate_entrainment_per_cell(
635
634
  return entrainment_dict
636
635
 
637
636
 
638
- def calculate_spike_rate_power_correlation(
639
- spike_rate: xr.DataArray,
640
- lfp_data: np.ndarray,
641
- fs: float,
642
- pop_names: list,
643
- filter_method: str = "wavelet",
644
- bandwidth: float = 2.0,
645
- lowcut: float = None,
646
- highcut: float = None,
647
- freq_range: tuple = (10, 100),
648
- freq_step: float = 5,
649
- type_name: str = "raw", # 'raw' or 'smoothed'
650
- ):
651
- """
652
- Calculate correlation between population spike rates (xarray) and LFP power across frequencies.
653
-
654
- Parameters
655
- ----------
656
- spike_rate : xr.DataArray
657
- Population spike rates with dimensions (time, population[, type])
658
- lfp_data : np.ndarray
659
- LFP data
660
- fs : float
661
- Sampling frequency
662
- pop_names : list
663
- List of population names to analyze
664
- filter_method : str, optional
665
- Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
666
- bandwidth : float, optional
667
- Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
668
- lowcut : float, optional
669
- Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
670
- highcut : float, optional
671
- Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
672
- freq_range : tuple, optional
673
- Min and max frequency to analyze (default: (10, 100))
674
- freq_step : float, optional
675
- Step size for frequency analysis (default: 5)
676
- type_name : str, optional
677
- Which type of spike rate to use if 'type' dimension exists (default: 'raw')
678
-
679
- Returns
680
- -------
681
- correlation_results : dict
682
- Dictionary with correlation results for each population and frequency
683
- frequencies : array
684
- Array of frequencies analyzed
685
- """
686
- frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
687
- correlation_results = {pop: {} for pop in pop_names}
688
-
689
- # Calculate power at each frequency band using specified filter
690
- power_by_freq = {}
691
- for freq in frequencies:
692
- power_by_freq[freq] = get_lfp_power(
693
- lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
694
- )
695
-
696
- # For each population, extract the correct spike rate
697
- for pop in pop_names:
698
- # If 'type' dimension exists, select the type
699
- if "type" in spike_rate.dims:
700
- pop_rate = spike_rate.sel(population=pop, type=type_name).values
701
- else:
702
- pop_rate = spike_rate.sel(population=pop).values
703
-
704
- # Calculate correlation with power at each frequency
705
- for freq in frequencies:
706
- lfp_power = power_by_freq[freq]
707
- # Ensure lengths match
708
- min_len = min(len(pop_rate), len(lfp_power))
709
- if len(pop_rate) != len(lfp_power):
710
- print(f"Warning: Length mismatch for {pop} at {freq} Hz, truncating to {min_len}")
711
- corr, p_val = stats.spearmanr(pop_rate[:min_len], lfp_power[:min_len])
712
- correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
713
-
714
- return correlation_results, frequencies
715
-
716
-
717
637
  def get_spikes_in_cycle(
718
638
  spike_df,
719
639
  lfp_data,
bmtool/analysis/spikes.py CHANGED
@@ -281,12 +281,13 @@ def get_population_spike_rate(
281
281
  node_number = {}
282
282
 
283
283
  if config is None:
284
- print(
285
- "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, or not all cells fired,\nthen this count will not include all nodes so the firing rate will not be of the whole population!"
286
- )
287
- print(
288
- "You can provide a config to calculate the correct amount of nodes! for a true population rate."
289
- )
284
+ pass
285
+ # print(
286
+ # "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, or not all cells fired,\nthen this count will not include all nodes so the firing rate will not be of the whole population!"
287
+ # )
288
+ # print(
289
+ # "You can provide a config to calculate the correct amount of nodes! for a true population rate."
290
+ # )
290
291
 
291
292
  if config:
292
293
  if not network_name:
@@ -602,3 +603,51 @@ def find_bursting_cells(
602
603
  )
603
604
 
604
605
  return burst_cells
606
+
607
+
608
+ def find_highest_firing_cells(
609
+ df: pd.DataFrame, upper_quantile: float, groupby: str = "pop_name"
610
+ ) -> pd.DataFrame:
611
+ """
612
+ Identifies and returns spikes from cells with firing rates above a specified upper quantile,
613
+ grouped by a population label.
614
+
615
+ Parameters
616
+ ----------
617
+ df : pd.DataFrame
618
+ DataFrame containing spike data with at least the following columns:
619
+ - 'timestamps': Time of each spike event
620
+ - 'node_ids': Identifier for each neuron
621
+ - groupby (e.g., 'pop_name'): Population labels or grouping identifiers for neurons
622
+
623
+ upper_quantile : float
624
+ The upper quantile threshold (between 0 and 1).
625
+ Cells with firing rates in the top (1 - upper_quantile) fraction are selected.
626
+ For example, upper_quantile=0.8 selects the top 20% of high-firing cells.
627
+
628
+ groupby : str, optional
629
+ The column name used to group neurons by population. Default is 'pop_name'.
630
+
631
+ Returns
632
+ -------
633
+ pd.DataFrame
634
+ A DataFrame containing only the spikes from the high-firing cells across all groupbys.
635
+ """
636
+ if upper_quantile == 0:
637
+ return df
638
+ df_list = []
639
+ for pop in df[groupby].unique():
640
+ pop_df = df[df[groupby] == pop]
641
+ _, pop_fr = compute_firing_rate_stats(pop_df, groupby=groupby)
642
+
643
+ # Identify high firing cells
644
+ threshold = pop_fr["firing_rate"].quantile(upper_quantile)
645
+ high_firing_cells = pop_fr[pop_fr["firing_rate"] >= threshold]["node_ids"]
646
+
647
+ # Filter spikes for high firing cells
648
+ pop_spikes = pop_df[pop_df["node_ids"].isin(high_firing_cells)]
649
+ df_list.append(pop_spikes)
650
+
651
+ # Combine all high firing spikes into one DataFrame
652
+ result_df = pd.concat(df_list, ignore_index=True)
653
+ return result_df
@@ -12,7 +12,6 @@ import matplotlib.pyplot as plt
12
12
  import numpy as np
13
13
  import pandas as pd
14
14
  from IPython import get_ipython
15
- from neuron import h
16
15
 
17
16
  from ..util import util
18
17
 
@@ -981,104 +980,158 @@ def distance_delay_plot(
981
980
  plt.show()
982
981
 
983
982
 
984
- def plot_synapse_location_histograms(config, target_model, source=None, target=None):
983
+ def plot_synapse_location(config: str, source: str, target: str, sids: str, tids: str) -> tuple:
985
984
  """
986
- generates a histogram of the positions of the synapses on a cell broken down by section
987
- config: a BMTK config
988
- target_model: the name of the model_template used when building the BMTK node
989
- source: The source BMTK network
990
- target: The target BMTK network
991
- """
992
- # Load mechanisms and template
993
-
994
- util.load_templates_from_config(config)
995
-
996
- # Load node and edge data
997
- nodes, edges = util.load_nodes_edges_from_config(config)
998
- nodes = nodes[source]
999
- edges = edges[f"{source}_to_{target}"]
1000
-
1001
- # Map target_node_id to model_template
1002
- edges["target_model_template"] = edges["target_node_id"].map(nodes["model_template"])
1003
-
1004
- # Map source_node_id to pop_name
1005
- edges["source_pop_name"] = edges["source_node_id"].map(nodes["pop_name"])
985
+ Generates a connectivity matrix showing synaptic distribution across different cell sections.
1006
986
 
1007
- edges = edges[edges["target_model_template"] == target_model]
1008
-
1009
- # Create the cell model from target model
1010
- cell = getattr(h, target_model.split(":")[1])()
1011
-
1012
- # Create a mapping from section index to section name
1013
- section_id_to_name = {}
1014
- for idx, sec in enumerate(cell.all):
1015
- section_id_to_name[idx] = sec.name()
1016
-
1017
- # Add a new column with section names based on afferent_section_id
1018
- edges["afferent_section_name"] = edges["afferent_section_id"].map(section_id_to_name)
1019
-
1020
- # Get unique sections and source populations
1021
- unique_pops = edges["source_pop_name"].unique()
1022
-
1023
- # Filter to only include sections with data
1024
- section_counts = edges["afferent_section_name"].value_counts()
1025
- sections_with_data = section_counts[section_counts > 0].index.tolist()
987
+ Parameters
988
+ ----------
989
+ config : str
990
+ Path to BMTK config file
991
+ source : str
992
+ The source BMTK network name
993
+ target : str
994
+ The target BMTK network name
995
+ sids : str
996
+ Column name in nodes file containing source population identifiers
997
+ tids : str
998
+ Column name in nodes file containing target population identifiers
999
+
1000
+ Returns
1001
+ -------
1002
+ tuple
1003
+ (matplotlib.figure.Figure, matplotlib.axes.Axes) containing the plot
1004
+
1005
+ Raises
1006
+ ------
1007
+ ValueError
1008
+ If required parameters are missing or invalid
1009
+ RuntimeError
1010
+ If template loading or cell instantiation fails
1011
+ """
1012
+ import matplotlib.pyplot as plt
1013
+ import numpy as np
1014
+ from neuron import h
1015
+
1016
+ # Validate inputs
1017
+ if not all([config, source, target, sids, tids]):
1018
+ raise ValueError(
1019
+ "Missing required parameters: config, source, target, sids, and tids must be provided"
1020
+ )
1026
1021
 
1027
- # Create a figure with subplots for each section
1028
- plt.figure(figsize=(8, 12))
1022
+ try:
1023
+ # Load mechanisms and template
1024
+ util.load_templates_from_config(config)
1025
+ except Exception as e:
1026
+ raise RuntimeError(f"Failed to load templates from config: {str(e)}")
1029
1027
 
1030
- # Color map for source populations
1031
- color_map = plt.cm.tab10(np.linspace(0, 1, len(unique_pops)))
1032
- pop_colors = {pop: color for pop, color in zip(unique_pops, color_map)}
1028
+ try:
1029
+ # Load node and edge data
1030
+ nodes, edges = util.load_nodes_edges_from_config(config)
1031
+ if source not in nodes or f"{source}_to_{target}" not in edges:
1032
+ raise ValueError(f"Source '{source}' or target '{target}' networks not found in data")
1033
1033
 
1034
- # Create a histogram for each section
1035
- for i, section in enumerate(sections_with_data):
1036
- ax = plt.subplot(len(sections_with_data), 1, i + 1)
1034
+ nodes = nodes[source]
1035
+ edges = edges[f"{source}_to_{target}"]
1036
+ except Exception as e:
1037
+ raise RuntimeError(f"Failed to load nodes and edges: {str(e)}")
1037
1038
 
1038
- # Get data for this section
1039
- section_data = edges[edges["afferent_section_name"] == section]
1039
+ # Map identifiers while checking for missing values
1040
+ edges["target_model_template"] = edges["target_node_id"].map(nodes["model_template"])
1041
+ edges["target_pop_name"] = edges["target_node_id"].map(nodes[tids])
1042
+ edges["source_pop_name"] = edges["source_node_id"].map(nodes[sids])
1043
+
1044
+ if edges["target_model_template"].isnull().any():
1045
+ print("Warning: Some target nodes missing model template")
1046
+ if edges["target_pop_name"].isnull().any():
1047
+ print("Warning: Some target nodes missing population name")
1048
+ if edges["source_pop_name"].isnull().any():
1049
+ print("Warning: Some source nodes missing population name")
1050
+
1051
+ # Get unique populations
1052
+ source_pops = edges["source_pop_name"].unique()
1053
+ target_pops = edges["target_pop_name"].unique()
1054
+
1055
+ # Initialize matrices
1056
+ num_connections = np.zeros((len(source_pops), len(target_pops)))
1057
+ text_data = np.empty((len(source_pops), len(target_pops)), dtype=object)
1058
+
1059
+ # Create mappings for indices
1060
+ source_pop_to_idx = {pop: idx for idx, pop in enumerate(source_pops)}
1061
+ target_pop_to_idx = {pop: idx for idx, pop in enumerate(target_pops)}
1062
+
1063
+ # Cache for section mappings to avoid recreating cells
1064
+ section_mappings = {}
1065
+
1066
+ # Calculate connectivity statistics
1067
+ for source_pop in source_pops:
1068
+ for target_pop in target_pops:
1069
+ # Filter edges for this source-target pair
1070
+ filtered_edges = edges[
1071
+ (edges["source_pop_name"] == source_pop) & (edges["target_pop_name"] == target_pop)
1072
+ ]
1040
1073
 
1041
- # Group by source population
1042
- for pop_name, pop_group in section_data.groupby("source_pop_name"):
1043
- if len(pop_group) > 0:
1044
- ax.hist(
1045
- pop_group["afferent_section_pos"],
1046
- bins=15,
1047
- alpha=0.7,
1048
- label=pop_name,
1049
- color=pop_colors[pop_name],
1050
- )
1074
+ source_idx = source_pop_to_idx[source_pop]
1075
+ target_idx = target_pop_to_idx[target_pop]
1051
1076
 
1052
- # Set title and labels
1053
- ax.set_title(f"{section}", fontsize=10)
1054
- ax.set_xlabel("Section Position", fontsize=8)
1055
- ax.set_ylabel("Frequency", fontsize=8)
1056
- ax.tick_params(labelsize=7)
1057
- ax.grid(True, alpha=0.3)
1077
+ if len(filtered_edges) == 0:
1078
+ num_connections[source_idx, target_idx] = 0
1079
+ text_data[source_idx, target_idx] = "No connections"
1080
+ continue
1058
1081
 
1059
- # Only add legend to the first plot
1060
- if i == 0:
1061
- ax.legend(fontsize=8)
1082
+ total_connections = len(filtered_edges)
1083
+ target_model_template = filtered_edges["target_model_template"].iloc[0]
1062
1084
 
1063
- plt.tight_layout()
1064
- plt.suptitle(
1065
- "Connection Distribution by Cell Section and Source Population", fontsize=16, y=1.02
1085
+ try:
1086
+ # Get or create section mapping for this model
1087
+ if target_model_template not in section_mappings:
1088
+ cell_class_name = (
1089
+ target_model_template.split(":")[1]
1090
+ if ":" in target_model_template
1091
+ else target_model_template
1092
+ )
1093
+ cell = getattr(h, cell_class_name)()
1094
+
1095
+ # Create section mapping
1096
+ section_mapping = {}
1097
+ for idx, sec in enumerate(cell.all):
1098
+ section_mapping[idx] = sec.name().split(".")[-1] # Clean name
1099
+ section_mappings[target_model_template] = section_mapping
1100
+
1101
+ section_mapping = section_mappings[target_model_template]
1102
+
1103
+ # Calculate section distribution
1104
+ section_counts = filtered_edges["afferent_section_id"].value_counts()
1105
+ section_percentages = (section_counts / total_connections * 100).round(1)
1106
+
1107
+ # Format section distribution text - show all sections
1108
+ section_display = []
1109
+ for section_id, percentage in section_percentages.items():
1110
+ section_name = section_mapping.get(section_id, f"sec_{section_id}")
1111
+ section_display.append(f"{section_name}:{percentage}%")
1112
+
1113
+ num_connections[source_idx, target_idx] = total_connections
1114
+ text_data[source_idx, target_idx] = "\n".join(section_display)
1115
+
1116
+ except Exception as e:
1117
+ print(f"Warning: Error processing {target_model_template}: {str(e)}")
1118
+ num_connections[source_idx, target_idx] = total_connections
1119
+ text_data[source_idx, target_idx] = "Section info N/A"
1120
+
1121
+ # Create the plot
1122
+ title = f"Synaptic Distribution by Section: {source} to {target}"
1123
+ fig, ax = plot_connection_info(
1124
+ text=text_data,
1125
+ num=num_connections,
1126
+ source_labels=list(source_pops),
1127
+ target_labels=list(target_pops),
1128
+ title=title,
1129
+ syn_info="1",
1066
1130
  )
1067
- if is_notebook:
1131
+ if is_notebook():
1068
1132
  plt.show()
1069
1133
  else:
1070
- pass
1071
-
1072
- # Create a summary table
1073
- print("Summary of connections by section and source population:")
1074
- pivot_table = edges.pivot_table(
1075
- values="afferent_section_id",
1076
- index="afferent_section_name",
1077
- columns="source_pop_name",
1078
- aggfunc="count",
1079
- fill_value=0,
1080
- )
1081
- print(pivot_table)
1134
+ return fig, ax
1082
1135
 
1083
1136
 
1084
1137
  def plot_connection_info(
@@ -1088,7 +1141,6 @@ def plot_connection_info(
1088
1141
  Function to plot connection information as a heatmap, including handling missing source and target values.
1089
1142
  If there is no source or target, set the value to 0.
1090
1143
  """
1091
-
1092
1144
  # Ensure text dimensions match num dimensions
1093
1145
  num_source = len(source_labels)
1094
1146
  num_target = len(target_labels)
@@ -1096,103 +1148,149 @@ def plot_connection_info(
1096
1148
  # Set color map
1097
1149
  matplotlib.rc("image", cmap="viridis")
1098
1150
 
1099
- # Create figure and axis for the plot
1100
- fig1, ax1 = plt.subplots(figsize=(num_source, num_target))
1101
- num = np.nan_to_num(num, nan=0) # replace NaN with 0
1102
- im1 = ax1.imshow(num)
1151
+ # Calculate square cell size to ensure proper aspect ratio
1152
+ base_cell_size = 0.6 # Base size per cell
1153
+
1154
+ # Calculate figure dimensions with proper aspect ratio
1155
+ # Make sure width and height are proportional to the matrix dimensions
1156
+ fig_width = max(8, num_target * base_cell_size + 4) # Width based on columns
1157
+ fig_height = max(6, num_source * base_cell_size + 3) # Height based on rows
1103
1158
 
1104
- # Set ticks and labels for source and target
1159
+ # Ensure minimum readable size
1160
+ min_fig_size = 8
1161
+ if fig_width < min_fig_size or fig_height < min_fig_size:
1162
+ scale_factor = min_fig_size / min(fig_width, fig_height)
1163
+ fig_width *= scale_factor
1164
+ fig_height *= scale_factor
1165
+
1166
+ # Create figure and axis
1167
+ fig1, ax1 = plt.subplots(figsize=(fig_width, fig_height))
1168
+
1169
+ # Replace NaN with 0 and create heatmap
1170
+ num_clean = np.nan_to_num(num, nan=0)
1171
+ # if string is nan\nnan make it 0
1172
+
1173
+ # Use 'auto' aspect ratio to let matplotlib handle it properly
1174
+ # This prevents the stretching issue
1175
+ im1 = ax1.imshow(num_clean, aspect="auto", interpolation="nearest")
1176
+
1177
+ # Set ticks and labels
1105
1178
  ax1.set_xticks(list(np.arange(len(target_labels))))
1106
1179
  ax1.set_yticks(list(np.arange(len(source_labels))))
1107
1180
  ax1.set_xticklabels(target_labels)
1108
- ax1.set_yticklabels(source_labels, size=12, weight="semibold")
1181
+ ax1.set_yticklabels(source_labels)
1109
1182
 
1110
- # Rotate the tick labels for better visibility
1183
+ # Improved font sizing based on matrix size
1184
+ label_font_size = max(8, min(14, 120 / max(num_source, num_target)))
1185
+
1186
+ # Style the tick labels
1187
+ ax1.tick_params(axis="y", labelsize=label_font_size, pad=5)
1111
1188
  plt.setp(
1112
1189
  ax1.get_xticklabels(),
1113
1190
  rotation=45,
1114
1191
  ha="right",
1115
1192
  rotation_mode="anchor",
1116
- size=12,
1117
- weight="semibold",
1193
+ fontsize=label_font_size,
1118
1194
  )
1119
1195
 
1120
1196
  # Dictionary to store connection information
1121
1197
  graph_dict = {}
1122
1198
 
1199
+ # Improved text size calculation - more readable for larger matrices
1200
+ text_size = max(6, min(12, 80 / max(num_source, num_target)))
1201
+
1123
1202
  # Loop over data dimensions and create text annotations
1124
1203
  for i in range(num_source):
1125
1204
  for j in range(num_target):
1126
- # Get the edge info, or set it to '0' if it's missing
1127
- edge_info = text[i, j] if text[i, j] is not None else 0
1205
+ edge_info = text[i, j] if text[i, j] is not None else "0\n0"
1128
1206
 
1129
- # Initialize the dictionary for the source node if not already done
1130
1207
  if source_labels[i] not in graph_dict:
1131
1208
  graph_dict[source_labels[i]] = {}
1132
-
1133
- # Add edge info for the target node
1134
1209
  graph_dict[source_labels[i]][target_labels[j]] = edge_info
1135
1210
 
1136
- # Set text annotations based on syn_info type
1137
- if syn_info == "2" or syn_info == "3":
1138
- if num_source > 8 and num_source < 20:
1139
- fig_text = ax1.text(
1140
- j,
1141
- i,
1142
- edge_info,
1143
- ha="center",
1144
- va="center",
1145
- color="w",
1146
- rotation=37.5,
1147
- size=8,
1148
- weight="semibold",
1149
- )
1150
- elif num_source > 20:
1151
- fig_text = ax1.text(
1152
- j,
1153
- i,
1154
- edge_info,
1155
- ha="center",
1156
- va="center",
1157
- color="w",
1158
- rotation=37.5,
1159
- size=7,
1160
- weight="semibold",
1161
- )
1211
+ # Skip displaying text for NaN values to reduce clutter
1212
+ if edge_info == "nan\nnan":
1213
+ edge_info = "0\n±0"
1214
+
1215
+ # Format the text display
1216
+ if isinstance(edge_info, str) and "\n" in edge_info:
1217
+ # For mean/std format (e.g. "15.5\n4.0")
1218
+ parts = edge_info.split("\n")
1219
+ if len(parts) == 2:
1220
+ try:
1221
+ mean_val = float(parts[0])
1222
+ std_val = float(parts[1])
1223
+ display_text = f"{mean_val:.1f}\n±{std_val:.1f}"
1224
+ except ValueError:
1225
+ display_text = edge_info
1162
1226
  else:
1163
- fig_text = ax1.text(
1164
- j,
1165
- i,
1166
- edge_info,
1167
- ha="center",
1168
- va="center",
1169
- color="w",
1170
- rotation=37.5,
1171
- size=11,
1172
- weight="semibold",
1173
- )
1227
+ display_text = edge_info
1228
+ else:
1229
+ display_text = str(edge_info)
1230
+
1231
+ # Add text to plot with better contrast
1232
+ text_color = "white" if num_clean[i, j] < (np.nanmax(num_clean) * 0.9) else "black"
1233
+
1234
+ if syn_info == "2" or syn_info == "3":
1235
+ ax1.text(
1236
+ j,
1237
+ i,
1238
+ display_text,
1239
+ ha="center",
1240
+ va="center",
1241
+ color=text_color,
1242
+ rotation=37.5,
1243
+ fontsize=text_size,
1244
+ weight="bold",
1245
+ )
1174
1246
  else:
1175
- fig_text = ax1.text(
1176
- j, i, edge_info, ha="center", va="center", color="w", size=11, weight="semibold"
1247
+ ax1.text(
1248
+ j,
1249
+ i,
1250
+ display_text,
1251
+ ha="center",
1252
+ va="center",
1253
+ color=text_color,
1254
+ fontsize=text_size,
1255
+ weight="bold",
1177
1256
  )
1178
1257
 
1179
- # Set labels and title for the plot
1180
- ax1.set_ylabel("Source", size=11, weight="semibold")
1181
- ax1.set_xlabel("Target", size=11, weight="semibold")
1182
- ax1.set_title(title, size=20, weight="semibold")
1258
+ # Set labels and title
1259
+ title_font_size = max(12, min(18, label_font_size + 4))
1260
+ ax1.set_ylabel("Source", fontsize=title_font_size, weight="bold", labelpad=10)
1261
+ ax1.set_xlabel("Target", fontsize=title_font_size, weight="bold", labelpad=10)
1262
+ ax1.set_title(title, fontsize=title_font_size + 2, weight="bold", pad=20)
1263
+
1264
+ # Add colorbar
1265
+ cbar = plt.colorbar(im1, shrink=0.8)
1266
+ cbar.ax.tick_params(labelsize=label_font_size)
1267
+
1268
+ # Adjust layout to minimize whitespace and prevent stretching
1269
+ plt.tight_layout(pad=1.5)
1270
+
1271
+ # Force square cells by setting equal axis limits if needed
1272
+ ax1.set_xlim(-0.5, num_target - 0.5)
1273
+ ax1.set_ylim(num_source - 0.5, -0.5) # Inverted for proper matrix orientation
1274
+
1275
+ # Display or save the plot
1276
+ try:
1277
+ # Check if running in notebook
1278
+ from IPython import get_ipython
1279
+
1280
+ notebook = get_ipython() is not None
1281
+ except ImportError:
1282
+ notebook = False
1183
1283
 
1184
- # Display the plot or save it based on the environment and arguments
1185
- notebook = is_notebook() # Check if running in a Jupyter notebook
1186
1284
  if not notebook:
1187
- fig1.show()
1285
+ plt.show()
1188
1286
 
1189
1287
  if save_file:
1190
- plt.savefig(save_file)
1288
+ plt.savefig(save_file, dpi=300, bbox_inches="tight", pad_inches=0.1)
1191
1289
 
1192
1290
  if return_dict:
1193
1291
  return graph_dict
1194
1292
  else:
1195
- return
1293
+ return fig1, ax1
1196
1294
 
1197
1295
 
1198
1296
  def connector_percent_matrix(
@@ -1467,19 +1565,23 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
1467
1565
  plt.title(title)
1468
1566
  plt.legend(handles=handles)
1469
1567
 
1568
+ # Add axis labels
1569
+ ax.set_xlabel("X Position (μm)")
1570
+ ax.set_ylabel("Y Position (μm)")
1571
+ ax.set_zlabel("Z Position (μm)")
1572
+
1470
1573
  # Draw the plot
1471
1574
  plt.draw()
1575
+ plt.tight_layout()
1472
1576
 
1473
1577
  # Save the plot if save_file is provided
1474
1578
  if save_file:
1475
1579
  plt.savefig(save_file)
1476
1580
 
1477
- # Show the plot if running outside of a notebook
1478
- if not is_notebook:
1581
+ # Show if running in notebook
1582
+ if is_notebook:
1479
1583
  plt.show()
1480
1584
 
1481
- return ax
1482
-
1483
1585
 
1484
1586
  def plot_3d_cell_rotation(
1485
1587
  config=None,