bmtool 0.7.1.6__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,13 +7,12 @@ from typing import Dict, List, Optional, Union
7
7
  import numba
8
8
  import numpy as np
9
9
  import pandas as pd
10
- import scipy.stats as stats
11
10
  import xarray as xr
12
11
  from numba import cuda
13
12
  from scipy import signal
14
13
  from tqdm.notebook import tqdm
15
14
 
16
- from .lfp import butter_bandpass_filter, get_lfp_phase, get_lfp_power, wavelet_filter
15
+ from .lfp import butter_bandpass_filter, get_lfp_phase, wavelet_filter
17
16
 
18
17
 
19
18
  def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.ndarray:
@@ -635,85 +634,6 @@ def calculate_entrainment_per_cell(
635
634
  return entrainment_dict
636
635
 
637
636
 
638
- def calculate_spike_rate_power_correlation(
639
- spike_rate,
640
- lfp_data,
641
- fs,
642
- pop_names,
643
- filter_method="wavelet",
644
- bandwidth=2.0,
645
- lowcut=None,
646
- highcut=None,
647
- freq_range=(10, 100),
648
- freq_step=5,
649
- ):
650
- """
651
- Calculate correlation between population spike rates and LFP power across frequencies
652
- using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
653
-
654
- Parameters:
655
- -----------
656
- spike_rate : DataFrame
657
- Pre-calculated population spike rates at the same fs as lfp
658
- lfp_data : np.array
659
- LFP data
660
- fs : float
661
- Sampling frequency
662
- pop_names : list
663
- List of population names to analyze
664
- filter_method : str, optional
665
- Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
666
- bandwidth : float, optional
667
- Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
668
- lowcut : float, optional
669
- Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
670
- highcut : float, optional
671
- Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
672
- freq_range : tuple, optional
673
- Min and max frequency to analyze (default: (10, 100))
674
- freq_step : float, optional
675
- Step size for frequency analysis (default: 5)
676
-
677
- Returns:
678
- --------
679
- correlation_results : dict
680
- Dictionary with correlation results for each population and frequency
681
- frequencies : array
682
- Array of frequencies analyzed
683
- """
684
-
685
- # Define frequency bands to analyze
686
- frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
687
-
688
- # Dictionary to store results
689
- correlation_results = {pop: {} for pop in pop_names}
690
-
691
- # Calculate power at each frequency band using specified filter
692
- power_by_freq = {}
693
- for freq in frequencies:
694
- power_by_freq[freq] = get_lfp_power(
695
- lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
696
- )
697
-
698
- # Calculate correlation for each population
699
- for pop in pop_names:
700
- # Extract spike rate for this population
701
- pop_rate = spike_rate[pop]
702
-
703
- # Calculate correlation with power at each frequency
704
- for freq in frequencies:
705
- # Make sure the lengths match
706
- if len(pop_rate) != len(power_by_freq[freq]):
707
- raise ValueError(
708
- f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}"
709
- )
710
- # use spearman for non-parametric correlation
711
- corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
712
- correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
713
-
714
- return correlation_results, frequencies
715
-
716
-
717
637
  def get_spikes_in_cycle(
718
638
  spike_df,
719
639
  lfp_data,
bmtool/analysis/spikes.py CHANGED
@@ -281,10 +281,13 @@ def get_population_spike_rate(
281
281
  node_number = {}
282
282
 
283
283
  if config is None:
284
- print(
285
- "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect."
286
- )
287
- print("You can provide a config to calculate the correct amount of nodes!")
284
+ pass
285
+ # print(
286
+ # "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, or not all cells fired,\nthen this count will not include all nodes so the firing rate will not be of the whole population!"
287
+ # )
288
+ # print(
289
+ # "You can provide a config to calculate the correct amount of nodes! for a true population rate."
290
+ # )
288
291
 
289
292
  if config:
290
293
  if not network_name:
@@ -600,3 +603,51 @@ def find_bursting_cells(
600
603
  )
601
604
 
602
605
  return burst_cells
606
+
607
+
608
def find_highest_firing_cells(
    df: pd.DataFrame, upper_quantile: float, groupby: str = "pop_name"
) -> pd.DataFrame:
    """
    Identifies and returns spikes from cells whose firing rate is at or above a
    per-population quantile threshold.

    The threshold is computed separately within each group of ``groupby``, so
    every population contributes its own highest-firing cells.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing spike data with at least the following columns:
        - 'timestamps': Time of each spike event
        - 'node_ids': Identifier for each neuron
        - groupby (e.g., 'pop_name'): Population labels or grouping identifiers for neurons
        Rows whose group label is missing (NaN/None) are not selected.

    upper_quantile : float
        The upper quantile threshold, in the closed interval [0, 1].
        Cells whose firing rate is >= the ``upper_quantile`` quantile of their
        population are selected, i.e. roughly the top (1 - upper_quantile)
        fraction (all cells tied at the threshold are kept).
        For example, upper_quantile=0.8 selects approximately the top 20% of
        high-firing cells. A value of 0 returns ``df`` itself unchanged.

    groupby : str, optional
        The column name used to group neurons by population. Default is 'pop_name'.

    Returns
    -------
    pd.DataFrame
        A DataFrame containing only the spikes from the high-firing cells across
        all groups, re-indexed 0..n-1. If there are no groups to process (e.g. an
        empty ``df``), an empty DataFrame with the same columns is returned.

    Raises
    ------
    ValueError
        If ``upper_quantile`` is not within [0, 1].
    """
    if not 0 <= upper_quantile <= 1:
        raise ValueError(
            f"upper_quantile must be between 0 and 1, got {upper_quantile!r}"
        )
    if upper_quantile == 0:
        return df

    df_list = []
    # A single groupby pass replaces one boolean mask per population.
    # sort=False keeps first-appearance order, matching df[groupby].unique().
    # observed=True avoids empty groups for unused categories of a categorical column.
    for _, pop_df in df.groupby(groupby, sort=False, observed=True):
        _, pop_fr = compute_firing_rate_stats(pop_df, groupby=groupby)

        # Identify high firing cells (inclusive threshold keeps ties)
        threshold = pop_fr["firing_rate"].quantile(upper_quantile)
        high_firing_cells = pop_fr.loc[pop_fr["firing_rate"] >= threshold, "node_ids"]

        # Filter spikes for high firing cells
        df_list.append(pop_df[pop_df["node_ids"].isin(high_firing_cells)])

    if not df_list:
        # pd.concat([]) raises "No objects to concatenate".
        # Return an empty frame that keeps the original columns instead.
        return df.iloc[0:0].reset_index(drop=True)

    # Combine all high firing spikes into one DataFrame
    return pd.concat(df_list, ignore_index=True)
@@ -1088,7 +1088,6 @@ def plot_connection_info(
1088
1088
  Function to plot connection information as a heatmap, including handling missing source and target values.
1089
1089
  If there is no source or target, set the value to 0.
1090
1090
  """
1091
-
1092
1091
  # Ensure text dimensions match num dimensions
1093
1092
  num_source = len(source_labels)
1094
1093
  num_target = len(target_labels)
@@ -1096,103 +1095,149 @@ def plot_connection_info(
1096
1095
  # Set color map
1097
1096
  matplotlib.rc("image", cmap="viridis")
1098
1097
 
1099
- # Create figure and axis for the plot
1100
- fig1, ax1 = plt.subplots(figsize=(num_source, num_target))
1101
- num = np.nan_to_num(num, nan=0) # replace NaN with 0
1102
- im1 = ax1.imshow(num)
1098
+ # Calculate square cell size to ensure proper aspect ratio
1099
+ base_cell_size = 0.6 # Base size per cell
1100
+
1101
+ # Calculate figure dimensions with proper aspect ratio
1102
+ # Make sure width and height are proportional to the matrix dimensions
1103
+ fig_width = max(8, num_target * base_cell_size + 4) # Width based on columns
1104
+ fig_height = max(6, num_source * base_cell_size + 3) # Height based on rows
1105
+
1106
+ # Ensure minimum readable size
1107
+ min_fig_size = 8
1108
+ if fig_width < min_fig_size or fig_height < min_fig_size:
1109
+ scale_factor = min_fig_size / min(fig_width, fig_height)
1110
+ fig_width *= scale_factor
1111
+ fig_height *= scale_factor
1112
+
1113
+ # Create figure and axis
1114
+ fig1, ax1 = plt.subplots(figsize=(fig_width, fig_height))
1103
1115
 
1104
- # Set ticks and labels for source and target
1116
+ # Replace NaN with 0 and create heatmap
1117
+ num_clean = np.nan_to_num(num, nan=0)
1118
+ # if string is nan\nnan make it 0
1119
+
1120
+ # Use 'auto' aspect ratio to let matplotlib handle it properly
1121
+ # This prevents the stretching issue
1122
+ im1 = ax1.imshow(num_clean, aspect="auto", interpolation="nearest")
1123
+
1124
+ # Set ticks and labels
1105
1125
  ax1.set_xticks(list(np.arange(len(target_labels))))
1106
1126
  ax1.set_yticks(list(np.arange(len(source_labels))))
1107
1127
  ax1.set_xticklabels(target_labels)
1108
- ax1.set_yticklabels(source_labels, size=12, weight="semibold")
1128
+ ax1.set_yticklabels(source_labels)
1109
1129
 
1110
- # Rotate the tick labels for better visibility
1130
+ # Improved font sizing based on matrix size
1131
+ label_font_size = max(8, min(14, 120 / max(num_source, num_target)))
1132
+
1133
+ # Style the tick labels
1134
+ ax1.tick_params(axis="y", labelsize=label_font_size, pad=5)
1111
1135
  plt.setp(
1112
1136
  ax1.get_xticklabels(),
1113
1137
  rotation=45,
1114
1138
  ha="right",
1115
1139
  rotation_mode="anchor",
1116
- size=12,
1117
- weight="semibold",
1140
+ fontsize=label_font_size,
1118
1141
  )
1119
1142
 
1120
1143
  # Dictionary to store connection information
1121
1144
  graph_dict = {}
1122
1145
 
1146
+ # Improved text size calculation - more readable for larger matrices
1147
+ text_size = max(6, min(12, 80 / max(num_source, num_target)))
1148
+
1123
1149
  # Loop over data dimensions and create text annotations
1124
1150
  for i in range(num_source):
1125
1151
  for j in range(num_target):
1126
- # Get the edge info, or set it to '0' if it's missing
1127
- edge_info = text[i, j] if text[i, j] is not None else 0
1152
+ edge_info = text[i, j] if text[i, j] is not None else "0\n0"
1128
1153
 
1129
- # Initialize the dictionary for the source node if not already done
1130
1154
  if source_labels[i] not in graph_dict:
1131
1155
  graph_dict[source_labels[i]] = {}
1132
-
1133
- # Add edge info for the target node
1134
1156
  graph_dict[source_labels[i]][target_labels[j]] = edge_info
1135
1157
 
1136
- # Set text annotations based on syn_info type
1137
- if syn_info == "2" or syn_info == "3":
1138
- if num_source > 8 and num_source < 20:
1139
- fig_text = ax1.text(
1140
- j,
1141
- i,
1142
- edge_info,
1143
- ha="center",
1144
- va="center",
1145
- color="w",
1146
- rotation=37.5,
1147
- size=8,
1148
- weight="semibold",
1149
- )
1150
- elif num_source > 20:
1151
- fig_text = ax1.text(
1152
- j,
1153
- i,
1154
- edge_info,
1155
- ha="center",
1156
- va="center",
1157
- color="w",
1158
- rotation=37.5,
1159
- size=7,
1160
- weight="semibold",
1161
- )
1158
+ # Skip displaying text for NaN values to reduce clutter
1159
+ if edge_info == "nan\nnan":
1160
+ edge_info = "0\n±0"
1161
+
1162
+ # Format the text display
1163
+ if isinstance(edge_info, str) and "\n" in edge_info:
1164
+ # For mean/std format (e.g. "15.5\n4.0")
1165
+ parts = edge_info.split("\n")
1166
+ if len(parts) == 2:
1167
+ try:
1168
+ mean_val = float(parts[0])
1169
+ std_val = float(parts[1])
1170
+ display_text = f"{mean_val:.1f}\n±{std_val:.1f}"
1171
+ except ValueError:
1172
+ display_text = edge_info
1162
1173
  else:
1163
- fig_text = ax1.text(
1164
- j,
1165
- i,
1166
- edge_info,
1167
- ha="center",
1168
- va="center",
1169
- color="w",
1170
- rotation=37.5,
1171
- size=11,
1172
- weight="semibold",
1173
- )
1174
+ display_text = edge_info
1175
+ else:
1176
+ display_text = str(edge_info)
1177
+
1178
+ # Add text to plot with better contrast
1179
+ text_color = "white" if num_clean[i, j] < (np.nanmax(num_clean) * 0.9) else "black"
1180
+
1181
+ if syn_info == "2" or syn_info == "3":
1182
+ ax1.text(
1183
+ j,
1184
+ i,
1185
+ display_text,
1186
+ ha="center",
1187
+ va="center",
1188
+ color=text_color,
1189
+ rotation=37.5,
1190
+ fontsize=text_size,
1191
+ weight="bold",
1192
+ )
1174
1193
  else:
1175
- fig_text = ax1.text(
1176
- j, i, edge_info, ha="center", va="center", color="w", size=11, weight="semibold"
1194
+ ax1.text(
1195
+ j,
1196
+ i,
1197
+ display_text,
1198
+ ha="center",
1199
+ va="center",
1200
+ color=text_color,
1201
+ fontsize=text_size,
1202
+ weight="bold",
1177
1203
  )
1178
1204
 
1179
- # Set labels and title for the plot
1180
- ax1.set_ylabel("Source", size=11, weight="semibold")
1181
- ax1.set_xlabel("Target", size=11, weight="semibold")
1182
- ax1.set_title(title, size=20, weight="semibold")
1205
+ # Set labels and title
1206
+ title_font_size = max(12, min(18, label_font_size + 4))
1207
+ ax1.set_ylabel("Source", fontsize=title_font_size, weight="bold", labelpad=10)
1208
+ ax1.set_xlabel("Target", fontsize=title_font_size, weight="bold", labelpad=10)
1209
+ ax1.set_title(title, fontsize=title_font_size + 2, weight="bold", pad=20)
1210
+
1211
+ # Add colorbar
1212
+ cbar = plt.colorbar(im1, shrink=0.8)
1213
+ cbar.ax.tick_params(labelsize=label_font_size)
1214
+
1215
+ # Adjust layout to minimize whitespace and prevent stretching
1216
+ plt.tight_layout(pad=1.5)
1217
+
1218
+ # Force square cells by setting equal axis limits if needed
1219
+ ax1.set_xlim(-0.5, num_target - 0.5)
1220
+ ax1.set_ylim(num_source - 0.5, -0.5) # Inverted for proper matrix orientation
1221
+
1222
+ # Display or save the plot
1223
+ try:
1224
+ # Check if running in notebook
1225
+ from IPython import get_ipython
1226
+
1227
+ notebook = get_ipython() is not None
1228
+ except ImportError:
1229
+ notebook = False
1183
1230
 
1184
- # Display the plot or save it based on the environment and arguments
1185
- notebook = is_notebook() # Check if running in a Jupyter notebook
1186
1231
  if not notebook:
1187
- fig1.show()
1232
+ plt.show()
1188
1233
 
1189
1234
  if save_file:
1190
- plt.savefig(save_file)
1235
+ plt.savefig(save_file, dpi=300, bbox_inches="tight", pad_inches=0.1)
1191
1236
 
1192
1237
  if return_dict:
1193
1238
  return graph_dict
1194
1239
  else:
1195
- return
1240
+ return fig1, ax1
1196
1241
 
1197
1242
 
1198
1243
  def connector_percent_matrix(
@@ -1467,19 +1512,23 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
1467
1512
  plt.title(title)
1468
1513
  plt.legend(handles=handles)
1469
1514
 
1515
+ # Add axis labels
1516
+ ax.set_xlabel("X Position (μm)")
1517
+ ax.set_ylabel("Y Position (μm)")
1518
+ ax.set_zlabel("Z Position (μm)")
1519
+
1470
1520
  # Draw the plot
1471
1521
  plt.draw()
1522
+ plt.tight_layout()
1472
1523
 
1473
1524
  # Save the plot if save_file is provided
1474
1525
  if save_file:
1475
1526
  plt.savefig(save_file)
1476
1527
 
1477
- # Show the plot if running outside of a notebook
1478
- if not is_notebook:
1528
+ # Show if running in notebook
1529
+ if is_notebook:
1479
1530
  plt.show()
1480
1531
 
1481
- return ax
1482
-
1483
1532
 
1484
1533
  def plot_3d_cell_rotation(
1485
1534
  config=None,