bmtool 0.7.1.7__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/analysis/entrainment.py CHANGED
@@ -7,13 +7,12 @@ from typing import Dict, List, Optional, Union
  import numba
  import numpy as np
  import pandas as pd
- import scipy.stats as stats
  import xarray as xr
  from numba import cuda
  from scipy import signal
  from tqdm.notebook import tqdm

- from .lfp import butter_bandpass_filter, get_lfp_phase, get_lfp_power, wavelet_filter
+ from .lfp import butter_bandpass_filter, get_lfp_phase, wavelet_filter


  def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.ndarray:
@@ -635,85 +634,6 @@ def calculate_entrainment_per_cell(
      return entrainment_dict


- def calculate_spike_rate_power_correlation(
-     spike_rate: xr.DataArray,
-     lfp_data: np.ndarray,
-     fs: float,
-     pop_names: list,
-     filter_method: str = "wavelet",
-     bandwidth: float = 2.0,
-     lowcut: float = None,
-     highcut: float = None,
-     freq_range: tuple = (10, 100),
-     freq_step: float = 5,
-     type_name: str = "raw",  # 'raw' or 'smoothed'
- ):
-     """
-     Calculate correlation between population spike rates (xarray) and LFP power across frequencies.
-
-     Parameters
-     ----------
-     spike_rate : xr.DataArray
-         Population spike rates with dimensions (time, population[, type])
-     lfp_data : np.ndarray
-         LFP data
-     fs : float
-         Sampling frequency
-     pop_names : list
-         List of population names to analyze
-     filter_method : str, optional
-         Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
-     bandwidth : float, optional
-         Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
-     lowcut : float, optional
-         Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
-     highcut : float, optional
-         Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
-     freq_range : tuple, optional
-         Min and max frequency to analyze (default: (10, 100))
-     freq_step : float, optional
-         Step size for frequency analysis (default: 5)
-     type_name : str, optional
-         Which type of spike rate to use if 'type' dimension exists (default: 'raw')
-
-     Returns
-     -------
-     correlation_results : dict
-         Dictionary with correlation results for each population and frequency
-     frequencies : array
-         Array of frequencies analyzed
-     """
-     frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
-     correlation_results = {pop: {} for pop in pop_names}
-
-     # Calculate power at each frequency band using specified filter
-     power_by_freq = {}
-     for freq in frequencies:
-         power_by_freq[freq] = get_lfp_power(
-             lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
-         )
-
-     # For each population, extract the correct spike rate
-     for pop in pop_names:
-         # If 'type' dimension exists, select the type
-         if "type" in spike_rate.dims:
-             pop_rate = spike_rate.sel(population=pop, type=type_name).values
-         else:
-             pop_rate = spike_rate.sel(population=pop).values
-
-         # Calculate correlation with power at each frequency
-         for freq in frequencies:
-             lfp_power = power_by_freq[freq]
-             # Ensure lengths match
-             min_len = min(len(pop_rate), len(lfp_power))
-             if len(pop_rate) != len(lfp_power):
-                 print(f"Warning: Length mismatch for {pop} at {freq} Hz, truncating to {min_len}")
-             corr, p_val = stats.spearmanr(pop_rate[:min_len], lfp_power[:min_len])
-             correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
-
-     return correlation_results, frequencies
-
-
  def get_spikes_in_cycle(
      spike_df,
      lfp_data,
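The removed helper is superseded by the reworked plot_spike_power_correlation in bmtool/bmplot/entrainment.py (see that section below), which now computes the correlations directly from raw spikes. A minimal usage sketch of the replacement, assuming a spike DataFrame with 'timestamps', 'node_ids', and 'pop_name' columns and an xarray LFP channel; the population names and sampling rate here are illustrative:

    from bmtool.bmplot.entrainment import plot_spike_power_correlation

    plot_spike_power_correlation(
        spike_df,             # spike table with 'timestamps', 'node_ids', 'pop_name'
        lfp,                  # xr.DataArray LFP channel
        firing_quantile=0.8,  # keep the top 20% of cells by firing rate
        fs=1000,
        pop_names=["FSI", "LTS"],
        freq_range=(10, 100),
        freq_step=5,
    )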
bmtool/analysis/spikes.py CHANGED
@@ -281,12 +281,13 @@ def get_population_spike_rate(
      node_number = {}

      if config is None:
-         print(
-             "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, or not all cells fired,\nthen this count will not include all nodes so the firing rate will not be of the whole population!"
-         )
-         print(
-             "You can provide a config to calculate the correct amount of nodes! for a true population rate."
-         )
+         pass
+         # print(
+         #     "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, or not all cells fired,\nthen this count will not include all nodes so the firing rate will not be of the whole population!"
+         # )
+         # print(
+         #     "You can provide a config to calculate the correct amount of nodes! for a true population rate."
+         # )

      if config:
          if not network_name:
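The silenced note still applies: without a config, node counts come only from cells that actually fired. A hedged sketch of the config-backed call (the network name is hypothetical):

    # Pass a config so the node count covers the whole population, not just cells that spiked.
    rate = get_population_spike_rate(
        spikes_df, config="simulation_config.json", network_name="cortex"
    )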
@@ -602,3 +603,51 @@ def find_bursting_cells(
      )

      return burst_cells
+
+
+ def find_highest_firing_cells(
+     df: pd.DataFrame, upper_quantile: float, groupby: str = "pop_name"
+ ) -> pd.DataFrame:
+     """
+     Identifies and returns spikes from cells with firing rates above a specified upper quantile,
+     grouped by a population label.
+
+     Parameters
+     ----------
+     df : pd.DataFrame
+         DataFrame containing spike data with at least the following columns:
+         - 'timestamps': Time of each spike event
+         - 'node_ids': Identifier for each neuron
+         - groupby (e.g., 'pop_name'): Population labels or grouping identifiers for neurons
+
+     upper_quantile : float
+         The upper quantile threshold (between 0 and 1).
+         Cells with firing rates in the top (1 - upper_quantile) fraction are selected.
+         For example, upper_quantile=0.8 selects the top 20% of high-firing cells.
+
+     groupby : str, optional
+         The column name used to group neurons by population. Default is 'pop_name'.
+
+     Returns
+     -------
+     pd.DataFrame
+         A DataFrame containing only the spikes from the high-firing cells across all groups.
+     """
+     if upper_quantile == 0:
+         return df
+     df_list = []
+     for pop in df[groupby].unique():
+         pop_df = df[df[groupby] == pop]
+         _, pop_fr = compute_firing_rate_stats(pop_df, groupby=groupby)
+
+         # Identify high firing cells
+         threshold = pop_fr["firing_rate"].quantile(upper_quantile)
+         high_firing_cells = pop_fr[pop_fr["firing_rate"] >= threshold]["node_ids"]
+
+         # Filter spikes for high firing cells
+         pop_spikes = pop_df[pop_df["node_ids"].isin(high_firing_cells)]
+         df_list.append(pop_spikes)
+
+     # Combine all high firing spikes into one DataFrame
+     result_df = pd.concat(df_list, ignore_index=True)
+     return result_df
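A minimal usage sketch for the new helper, assuming a spike table with the columns listed in the docstring above:

    from bmtool.analysis.spikes import find_highest_firing_cells

    # Keep only spikes from the top 20% of cells (per population) by firing rate.
    top_spikes = find_highest_firing_cells(spikes_df, upper_quantile=0.8, groupby="pop_name")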
bmtool/bmplot/connections.py CHANGED
@@ -1088,7 +1088,6 @@ def plot_connection_info(
      Function to plot connection information as a heatmap, including handling missing source and target values.
      If there is no source or target, set the value to 0.
      """
-
      # Ensure text dimensions match num dimensions
      num_source = len(source_labels)
      num_target = len(target_labels)
@@ -1096,103 +1095,149 @@
      # Set color map
      matplotlib.rc("image", cmap="viridis")

-     # Create figure and axis for the plot
-     fig1, ax1 = plt.subplots(figsize=(num_source, num_target))
-     num = np.nan_to_num(num, nan=0)  # replace NaN with 0
-     im1 = ax1.imshow(num)
+     # Calculate square cell size to ensure proper aspect ratio
+     base_cell_size = 0.6  # Base size per cell
+
+     # Calculate figure dimensions with proper aspect ratio
+     # Make sure width and height are proportional to the matrix dimensions
+     fig_width = max(8, num_target * base_cell_size + 4)  # Width based on columns
+     fig_height = max(6, num_source * base_cell_size + 3)  # Height based on rows
+
+     # Ensure minimum readable size
+     min_fig_size = 8
+     if fig_width < min_fig_size or fig_height < min_fig_size:
+         scale_factor = min_fig_size / min(fig_width, fig_height)
+         fig_width *= scale_factor
+         fig_height *= scale_factor
+
+     # Create figure and axis
+     fig1, ax1 = plt.subplots(figsize=(fig_width, fig_height))

-     # Set ticks and labels for source and target
+     # Replace NaN with 0 and create heatmap
+     num_clean = np.nan_to_num(num, nan=0)
+     # if string is nan\nnan make it 0
+
+     # Use 'auto' aspect ratio to let matplotlib handle it properly
+     # This prevents the stretching issue
+     im1 = ax1.imshow(num_clean, aspect="auto", interpolation="nearest")
+
+     # Set ticks and labels
      ax1.set_xticks(list(np.arange(len(target_labels))))
      ax1.set_yticks(list(np.arange(len(source_labels))))
      ax1.set_xticklabels(target_labels)
-     ax1.set_yticklabels(source_labels, size=12, weight="semibold")
+     ax1.set_yticklabels(source_labels)

-     # Rotate the tick labels for better visibility
+     # Improved font sizing based on matrix size
+     label_font_size = max(8, min(14, 120 / max(num_source, num_target)))
+
+     # Style the tick labels
+     ax1.tick_params(axis="y", labelsize=label_font_size, pad=5)
      plt.setp(
          ax1.get_xticklabels(),
          rotation=45,
          ha="right",
          rotation_mode="anchor",
-         size=12,
-         weight="semibold",
+         fontsize=label_font_size,
      )

      # Dictionary to store connection information
      graph_dict = {}

+     # Improved text size calculation - more readable for larger matrices
+     text_size = max(6, min(12, 80 / max(num_source, num_target)))
+
      # Loop over data dimensions and create text annotations
      for i in range(num_source):
          for j in range(num_target):
-             # Get the edge info, or set it to '0' if it's missing
-             edge_info = text[i, j] if text[i, j] is not None else 0
+             edge_info = text[i, j] if text[i, j] is not None else "0\n0"

-             # Initialize the dictionary for the source node if not already done
              if source_labels[i] not in graph_dict:
                  graph_dict[source_labels[i]] = {}
-
-             # Add edge info for the target node
              graph_dict[source_labels[i]][target_labels[j]] = edge_info

-             # Set text annotations based on syn_info type
-             if syn_info == "2" or syn_info == "3":
-                 if num_source > 8 and num_source < 20:
-                     fig_text = ax1.text(
-                         j,
-                         i,
-                         edge_info,
-                         ha="center",
-                         va="center",
-                         color="w",
-                         rotation=37.5,
-                         size=8,
-                         weight="semibold",
-                     )
-                 elif num_source > 20:
-                     fig_text = ax1.text(
-                         j,
-                         i,
-                         edge_info,
-                         ha="center",
-                         va="center",
-                         color="w",
-                         rotation=37.5,
-                         size=7,
-                         weight="semibold",
-                     )
+             # Skip displaying text for NaN values to reduce clutter
+             if edge_info == "nan\nnan":
+                 edge_info = "0\n±0"
+
+             # Format the text display
+             if isinstance(edge_info, str) and "\n" in edge_info:
+                 # For mean/std format (e.g. "15.5\n4.0")
+                 parts = edge_info.split("\n")
+                 if len(parts) == 2:
+                     try:
+                         mean_val = float(parts[0])
+                         std_val = float(parts[1])
+                         display_text = f"{mean_val:.1f}\n±{std_val:.1f}"
+                     except ValueError:
+                         display_text = edge_info
                  else:
-                     fig_text = ax1.text(
-                         j,
-                         i,
-                         edge_info,
-                         ha="center",
-                         va="center",
-                         color="w",
-                         rotation=37.5,
-                         size=11,
-                         weight="semibold",
-                     )
+                     display_text = edge_info
+             else:
+                 display_text = str(edge_info)
+
+             # Add text to plot with better contrast
+             text_color = "white" if num_clean[i, j] < (np.nanmax(num_clean) * 0.9) else "black"
+
+             if syn_info == "2" or syn_info == "3":
+                 ax1.text(
+                     j,
+                     i,
+                     display_text,
+                     ha="center",
+                     va="center",
+                     color=text_color,
+                     rotation=37.5,
+                     fontsize=text_size,
+                     weight="bold",
+                 )
              else:
-                 fig_text = ax1.text(
-                     j, i, edge_info, ha="center", va="center", color="w", size=11, weight="semibold"
+                 ax1.text(
+                     j,
+                     i,
+                     display_text,
+                     ha="center",
+                     va="center",
+                     color=text_color,
+                     fontsize=text_size,
+                     weight="bold",
                  )

-     # Set labels and title for the plot
-     ax1.set_ylabel("Source", size=11, weight="semibold")
-     ax1.set_xlabel("Target", size=11, weight="semibold")
-     ax1.set_title(title, size=20, weight="semibold")
+     # Set labels and title
+     title_font_size = max(12, min(18, label_font_size + 4))
+     ax1.set_ylabel("Source", fontsize=title_font_size, weight="bold", labelpad=10)
+     ax1.set_xlabel("Target", fontsize=title_font_size, weight="bold", labelpad=10)
+     ax1.set_title(title, fontsize=title_font_size + 2, weight="bold", pad=20)
+
+     # Add colorbar
+     cbar = plt.colorbar(im1, shrink=0.8)
+     cbar.ax.tick_params(labelsize=label_font_size)
+
+     # Adjust layout to minimize whitespace and prevent stretching
+     plt.tight_layout(pad=1.5)
+
+     # Force square cells by setting equal axis limits if needed
+     ax1.set_xlim(-0.5, num_target - 0.5)
+     ax1.set_ylim(num_source - 0.5, -0.5)  # Inverted for proper matrix orientation
+
+     # Display or save the plot
+     try:
+         # Check if running in notebook
+         from IPython import get_ipython
+
+         notebook = get_ipython() is not None
+     except ImportError:
+         notebook = False

-     # Display the plot or save it based on the environment and arguments
-     notebook = is_notebook()  # Check if running in a Jupyter notebook
      if not notebook:
-         fig1.show()
+         plt.show()

      if save_file:
-         plt.savefig(save_file)
+         plt.savefig(save_file, dpi=300, bbox_inches="tight", pad_inches=0.1)

      if return_dict:
          return graph_dict
      else:
-         return
+         return fig1, ax1


  def connector_percent_matrix(
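plot_connection_info now returns its figure and axes instead of None. A sketch of the new contract; the positional argument order is assumed from the function body above:

    fig, ax = plot_connection_info(text, num, source_labels, target_labels, title)
    fig.savefig("connection_matrix.png", dpi=300, bbox_inches="tight")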
@@ -1467,19 +1512,23 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
      plt.title(title)
      plt.legend(handles=handles)

+     # Add axis labels
+     ax.set_xlabel("X Position (μm)")
+     ax.set_ylabel("Y Position (μm)")
+     ax.set_zlabel("Z Position (μm)")
+
      # Draw the plot
      plt.draw()
+     plt.tight_layout()

      # Save the plot if save_file is provided
      if save_file:
          plt.savefig(save_file)

-     # Show the plot if running outside of a notebook
-     if not is_notebook:
+     # Show if running in notebook
+     if is_notebook:
          plt.show()

-     return ax
-


  def plot_3d_cell_rotation(
      config=None,
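Note the behavior change above: plot_3d_positions no longer returns the axes and now shows the plot only inside notebooks. A hedged sketch of the adjusted call, using the save_file parameter shown in the hunk:

    # plot_3d_positions() no longer returns ax; persist the figure via save_file
    # instead of post-processing a returned axes object.
    plot_3d_positions(config="simulation_config.json", title="cell positions", save_file="positions.png")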
bmtool/bmplot/entrainment.py CHANGED
@@ -1,56 +1,299 @@
+ from typing import List, Tuple, Union
+
  import matplotlib.pyplot as plt
  import numpy as np
  import pandas as pd
  import seaborn as sns
+ import xarray as xr
  from matplotlib.gridspec import GridSpec
  from scipy import stats

-
- def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
+ from bmtool.analysis import entrainment as bmentr
+ from bmtool.analysis import spikes as bmspikes
+ from bmtool.analysis.lfp import get_lfp_power
+
+
+ def plot_spike_power_correlation(
+     spike_df: pd.DataFrame,
+     lfp_data: xr.DataArray,
+     firing_quantile: float,
+     fs: float,
+     pop_names: list,
+     filter_method: str = "wavelet",
+     bandwidth: float = 2.0,
+     lowcut: float = None,
+     highcut: float = None,
+     freq_range: tuple = (10, 100),
+     freq_step: float = 5,
+     type_name: str = "raw",
+     time_windows: list = None,
+     error_type: str = "ci",  # New parameter: "ci" for confidence interval, "sem" for standard error, "std" for standard deviation
+ ):
      """
-     Plot the correlation between population spike rates and LFP power.
-
-     Parameters:
-     -----------
-     correlation_results : dict
-         Dictionary with correlation results for calculate_spike_rate_power_correlation
-     frequencies : array
-         Array of frequencies analyzed
+     Calculate and plot correlation between population spike rates and LFP power across frequencies.
+     Supports both single-signal and trial-based analysis with error bars.
+
+     Parameters
+     ----------
+     spike_df : pd.DataFrame
+         Spike data containing timestamps, node_ids, and pop_name columns
+     lfp_data : xr.DataArray
+         LFP data
+     fs : float
+         Sampling frequency
      pop_names : list
-         List of population names
+         List of population names to analyze
+     filter_method : str, optional
+         Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
+     bandwidth : float, optional
+         Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+     lowcut : float, optional
+         Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+     highcut : float, optional
+         Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+     freq_range : tuple, optional
+         Min and max frequency to analyze (default: (10, 100))
+     freq_step : float, optional
+         Step size for frequency analysis (default: 5)
+     type_name : str, optional
+         Which type of spike rate to use if 'type' dimension exists (default: 'raw')
+     time_windows : list, optional
+         List of (start, end) time tuples for trial-based analysis. If None, analyze entire signal
+     error_type : str, optional
+         Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
      """
-     sns.set_style("whitegrid")
-     plt.figure(figsize=(10, 6))

+     if not (0 <= firing_quantile < 1):
+         raise ValueError("firing_quantile must be between 0 and 1")
+
+     if error_type not in ["ci", "sem", "std"]:
+         raise ValueError(
+             "error_type must be 'ci' for confidence interval, 'sem' for standard error, or 'std' for standard deviation"
+         )
+
+     # Setup
+     is_trial_based = time_windows is not None
+
+     # Convert spike_df to spike rate with trial-based filtering of high firing cells
+     if is_trial_based:
+         # Initialize storage for trial-based spike rates
+         trial_rates = []
+
+         for start_time, end_time in time_windows:
+             # Get spikes for this trial
+             trial_spikes = spike_df[
+                 (spike_df["timestamps"] >= start_time) & (spike_df["timestamps"] <= end_time)
+             ].copy()
+
+             # Filter for high firing cells within this trial
+             trial_spikes = bmspikes.find_highest_firing_cells(
+                 trial_spikes, upper_quantile=firing_quantile
+             )
+             # Calculate rate for this trial's filtered spikes
+             trial_rate = bmspikes.get_population_spike_rate(
+                 trial_spikes, fs=fs, t_start=start_time, t_stop=end_time
+             )
+             trial_rates.append(trial_rate)
+
+         # Combine all trial rates
+         spike_rate = xr.concat(trial_rates, dim="trial")
+     else:
+         # For non-trial analysis, proceed as before
+         spike_df = bmspikes.find_highest_firing_cells(spike_df, upper_quantile=firing_quantile)
+         spike_rate = bmspikes.get_population_spike_rate(spike_df)
+
+     # Setup frequencies for analysis
+     frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
+
+     # Pre-calculate LFP power for all frequencies
+     power_by_freq = {}
+     for freq in frequencies:
+         power_by_freq[freq] = get_lfp_power(
+             lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
+         )
+
+     # Calculate correlations
+     results = {}
      for pop in pop_names:
-         # Extract correlation values for each frequency
-         corr_values = []
-         valid_freqs = []
+         pop_spike_rate = spike_rate.sel(population=pop, type=type_name)
+         results[pop] = {}

          for freq in frequencies:
-             if freq in correlation_results[pop]:
-                 corr_values.append(correlation_results[pop][freq]["correlation"])
-                 valid_freqs.append(freq)
+             lfp_power = power_by_freq[freq]

-         # Plot correlation line
-         plt.plot(valid_freqs, corr_values, marker="o", label=pop, linewidth=2, markersize=6)
+             if not is_trial_based:
+                 # Single signal analysis
+                 if len(pop_spike_rate) != len(lfp_power):
+                     print(f"Warning: Length mismatch for {pop} at {freq} Hz")
+                     continue
+
+                 corr, p_val = stats.spearmanr(pop_spike_rate, lfp_power)
+                 results[pop][freq] = {
+                     "correlation": corr,
+                     "p_value": p_val,
+                 }
+             else:
+                 # Trial-based analysis using pre-filtered trial rates
+                 trial_correlations = []
+
+                 for trial_idx in range(len(time_windows)):
+                     # Get time window first
+                     start_time, end_time = time_windows[trial_idx]
+
+                     # Get the pre-filtered spike rate for this trial
+                     trial_spike_rate = pop_spike_rate.sel(trial=trial_idx)
+
+                     # Get corresponding LFP power for this trial window
+                     trial_lfp_power = lfp_power.sel(time=slice(start_time, end_time))
+
+                     # Ensure both signals have same time points
+                     common_times = np.intersect1d(trial_spike_rate.time, trial_lfp_power.time)
+
+                     if len(common_times) > 0:
+                         trial_sr = trial_spike_rate.sel(time=common_times).values
+                         trial_lfp = trial_lfp_power.sel(time=common_times).values
+
+                         if (
+                             len(trial_sr) > 1 and len(trial_lfp) > 1
+                         ):  # Need at least 2 points for correlation
+                             corr, _ = stats.spearmanr(trial_sr, trial_lfp)
+                             if not np.isnan(corr):
+                                 trial_correlations.append(corr)
+
+                 # Calculate trial statistics
+                 if len(trial_correlations) > 0:
+                     trial_correlations = np.array(trial_correlations)
+                     mean_corr = np.mean(trial_correlations)
+
+                     if len(trial_correlations) > 1:
+                         if error_type == "ci":
+                             # Calculate 95% confidence interval using t-distribution
+                             df = len(trial_correlations) - 1
+                             sem = stats.sem(trial_correlations)
+                             t_critical = stats.t.ppf(0.975, df)  # 95% CI, two-tailed
+                             error_val = t_critical * sem
+                             error_lower = mean_corr - error_val
+                             error_upper = mean_corr + error_val
+                         elif error_type == "sem":
+                             # Calculate standard error of the mean
+                             sem = stats.sem(trial_correlations)
+                             error_lower = mean_corr - sem
+                             error_upper = mean_corr + sem
+                         elif error_type == "std":
+                             # Calculate standard deviation
+                             std = np.std(trial_correlations, ddof=1)
+                             error_lower = mean_corr - std
+                             error_upper = mean_corr + std
+                     else:
+                         error_lower = error_upper = mean_corr
+
+                     results[pop][freq] = {
+                         "correlation": mean_corr,
+                         "error_lower": error_lower,
+                         "error_upper": error_upper,
+                         "n_trials": len(trial_correlations),
+                         "trial_correlations": trial_correlations,
+                     }
+                 else:
+                     # No valid trials
+                     results[pop][freq] = {
+                         "correlation": np.nan,
+                         "error_lower": np.nan,
+                         "error_upper": np.nan,
+                         "n_trials": 0,
+                         "trial_correlations": np.array([]),
+                     }
+
+     # Plotting
+     sns.set_style("whitegrid")
+     plt.figure(figsize=(12, 8))

+     for i, pop in enumerate(pop_names):
+         # Extract data for plotting
+         plot_freqs = []
+         plot_corrs = []
+         plot_ci_lower = []
+         plot_ci_upper = []
+
+         for freq in frequencies:
+             if freq in results[pop] and not np.isnan(results[pop][freq]["correlation"]):
+                 plot_freqs.append(freq)
+                 plot_corrs.append(results[pop][freq]["correlation"])
+
+                 if is_trial_based:
+                     plot_ci_lower.append(results[pop][freq]["error_lower"])
+                     plot_ci_upper.append(results[pop][freq]["error_upper"])
+
+         if len(plot_freqs) == 0:
+             continue
+
+         # Convert to arrays
+         plot_freqs = np.array(plot_freqs)
+         plot_corrs = np.array(plot_corrs)
+
+         # Get color for this population
+         colors = plt.get_cmap("tab10")
+         color = colors(i)
+
+         # Plot main line
+         plt.plot(
+             plot_freqs, plot_corrs, marker="o", label=pop, linewidth=2, markersize=6, color=color
+         )
+
+         # Plot error bands for trial-based analysis
+         if is_trial_based and len(plot_ci_lower) > 0:
+             plot_ci_lower = np.array(plot_ci_lower)
+             plot_ci_upper = np.array(plot_ci_upper)
+             plt.fill_between(plot_freqs, plot_ci_lower, plot_ci_upper, alpha=0.2, color=color)
+
+     # Formatting
      plt.xlabel("Frequency (Hz)", fontsize=12)
      plt.ylabel("Spike Rate-Power Correlation", fontsize=12)
-     plt.title("Spike rate LFP power correlation during stimulus", fontsize=14)
-     plt.grid(True, alpha=0.3)
-     plt.legend(fontsize=12)
-     plt.xticks(frequencies[::2])  # Display every other frequency on x-axis

-     # Add horizontal line at zero for reference
+     # Calculate percentage for title
+     firing_percentage = round(float((1 - firing_quantile) * 100), 1)
+     if is_trial_based:
+         title = f"Trial-averaged Spike Rate-LFP Power Correlation\nTop {firing_percentage}% Firing Cells (95% CI)"
+     else:
+         title = f"Spike Rate-LFP Power Correlation\nTop {firing_percentage}% Firing Cells"
+
+     plt.title(title, fontsize=14)
+     plt.grid(True, alpha=0.3)
      plt.axhline(y=0, color="gray", linestyle="-", alpha=0.5)

-     # Set y-axis limits to make zero visible
+     # Legend
+     # Create legend elements for each population
+     from matplotlib.lines import Line2D
+
+     colors = plt.get_cmap("tab10")
+     legend_elements = [
+         Line2D([0], [0], color=colors(i), marker="o", linestyle="-", label=pop)
+         for i, pop in enumerate(pop_names)
+     ]
+
+     # Add error band legend element for trial-based analysis
+     if is_trial_based:
+         # Map error type to legend label
+         error_labels = {"ci": "95% CI", "sem": "±SEM", "std": "±1 SD"}
+         error_label = error_labels[error_type]
+
+         legend_elements.append(
+             Line2D([0], [0], color="gray", alpha=0.3, linewidth=10, label=error_label)
+         )
+
+     plt.legend(handles=legend_elements, fontsize=10, loc="best")
+
+     # Axis formatting
+     if len(frequencies) > 10:
+         plt.xticks(frequencies[::2])
+     else:
+         plt.xticks(frequencies)
+     plt.xlim(frequencies[0], frequencies[-1])
+
      y_min, y_max = plt.ylim()
      plt.ylim(min(y_min, -0.1), max(y_max, 0.1))

      plt.tight_layout()
-
      plt.show()

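When time_windows is given, the function averages Spearman correlations across trials and shades the chosen error band. A hedged sketch of a trial-based call; the window times and populations are illustrative:

    plot_spike_power_correlation(
        spike_df,
        lfp,
        firing_quantile=0.8,
        fs=1000,
        pop_names=["FSI", "LTS"],
        time_windows=[(500, 1000), (1500, 2000), (2500, 3000)],
        error_type="sem",
    )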
@@ -366,3 +609,257 @@ def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title
          plt.savefig(f"{save_path}/ppc_change_swarm_plot_{freq}Hz.png", dpi=300, bbox_inches="tight")

      plt.show()
+
+
+ def plot_trial_avg_entrainment(
+     spike_df: pd.DataFrame,
+     lfp: np.ndarray,
+     time_windows: List[Tuple[float, float]],
+     entrainment_method: str,
+     pop_names: List[str],
+     freqs: Union[List[float], np.ndarray],
+     firing_quantile: float,
+     spike_fs: float = 1000,
+     error_type: str = "ci",  # New parameter: "ci" for confidence interval, "sem" for standard error, "std" for standard deviation
+ ) -> None:
+     """
+     Plot trial-averaged entrainment for specified population names. Currently only supports the wavelet filter; other filters could easily be added.
+
+     Parameters:
+     -----------
+     spike_df : pd.DataFrame
+         Spike data containing timestamps, node_ids, and pop_name columns
+     spike_fs : float
+         fs for spike data. Default is 1000
+     lfp : xarray
+         Xarray for a channel of the lfp data
+     time_windows : List[Tuple[float, float]]
+         List of windows to analyze, with start and stop times [(start_time, end_time), ...] for each trial
+     entrainment_method : str
+         Method for entrainment calculation ('ppc', 'ppc2' or 'plv')
+     pop_names : List[str]
+         List of population names to process (e.g., ['FSI', 'LTS'])
+     freqs : Union[List[float], np.ndarray]
+         Array of frequencies to analyze (Hz)
+     firing_quantile : float
+         Upper quantile threshold for selecting high-firing cells (e.g., 0.8 for top 20%)
+     error_type : str
+         Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
+
+     Raises:
+     -------
+     ValueError
+         If entrainment_method is not 'ppc', 'ppc2' or 'plv'
+         If error_type is not 'ci', 'sem', or 'std'
+         If no spikes found for a population in a trial
+
+     Returns:
+     --------
+     None
+         Displays plot and prints summary statistics
+     """
+     sns.set_style("whitegrid")
+     # Validate inputs
+     if entrainment_method not in ["ppc", "plv", "ppc2"]:
+         raise ValueError("entrainment_method must be 'ppc', ppc2 or 'plv'")
+
+     if error_type not in ["ci", "sem", "std"]:
+         raise ValueError(
+             "error_type must be 'ci' for confidence interval, 'sem' for standard error, or 'std' for standard deviation"
+         )
+
+     if not (0 <= firing_quantile < 1):
+         raise ValueError("firing_quantile must be between 0 and 1")
+
+     # Convert freqs to numpy array for easier indexing
+     freqs = np.array(freqs)
+
+     # Collect all PPC/PLV values across trials for each population
+     all_plv_data = {}  # Dictionary to store results for each population
+
+     # Initialize storage for each population
+     for pop_name in pop_names:
+         all_plv_data[pop_name] = []  # Will be shape (n_trials, n_freqs)
+
+     # Loop through all pulse groups to collect data
+     for trial_idx in range(len(time_windows)):
+         plv_lists = {}  # Store PLV lists for this trial
+
+         # Initialize PLV lists for each population
+         for pop_name in pop_names:
+             plv_lists[pop_name] = []
+
+         # Filter spikes for this trial
+         network_spikes = spike_df[
+             (spike_df["timestamps"] >= time_windows[trial_idx][0])
+             & (spike_df["timestamps"] <= time_windows[trial_idx][1])
+         ].copy()
+
+         # Process each population
+         pop_spike_data = {}
+         for pop_name in pop_names:
+             # Get spikes for this population
+             pop_spikes = network_spikes[network_spikes["pop_name"] == pop_name]
+
+             if len(pop_spikes) == 0:
+                 print(f"Warning: No spikes found for population {pop_name} in trial {trial_idx}")
+                 # Add NaN values for this trial/population
+                 plv_lists[pop_name] = [np.nan] * len(freqs)
+                 continue
+
+             # Filter to get the top firing cells
+             # firing_quantile of 0.8 gets the top 20% of firing cells to use
+             pop_spikes = bmspikes.find_highest_firing_cells(
+                 pop_spikes, upper_quantile=firing_quantile
+             )
+
+             if len(pop_spikes) == 0:
+                 print(
+                     f"Warning: No high-firing spikes found for population {pop_name} in trial {trial_idx}"
+                 )
+                 plv_lists[pop_name] = [np.nan] * len(freqs)
+                 continue
+
+             pop_spike_data[pop_name] = pop_spikes
+
+         # Calculate PPC/PLV for each frequency and each population
+         for freq_idx, freq in enumerate(freqs):
+             for pop_name in pop_names:
+                 if pop_name not in pop_spike_data:
+                     continue  # Skip if no data for this population
+
+                 pop_spikes = pop_spike_data[pop_name]
+
+                 try:
+                     if entrainment_method == "ppc":
+                         result = bmentr.calculate_ppc(
+                             pop_spikes["timestamps"].values,
+                             lfp,
+                             spike_fs=spike_fs,
+                             lfp_fs=lfp.fs,
+                             freq_of_interest=freq,
+                             filter_method="wavelet",
+                             ppc_method="gpu",
+                         )
+                     elif entrainment_method == "plv":
+                         result = bmentr.calculate_spike_lfp_plv(
+                             pop_spikes["timestamps"].values,
+                             lfp,
+                             spike_fs=spike_fs,
+                             lfp_fs=lfp.fs,
+                             freq_of_interest=freq,
+                             filter_method="wavelet",
+                         )
+                     elif entrainment_method == "ppc2":
+                         result = bmentr.calculate_ppc2(
+                             pop_spikes["timestamps"].values,
+                             lfp,
+                             spike_fs=spike_fs,
+                             lfp_fs=lfp.fs,
+                             freq_of_interest=freq,
+                             filter_method="wavelet",
+                         )
+
+                     plv_lists[pop_name].append(result)
+
+                 except Exception as e:
+                     print(
+                         f"Warning: Error calculating {entrainment_method} for {pop_name} at {freq}Hz in trial {trial_idx}: {e}"
+                     )
+                     plv_lists[pop_name].append(np.nan)
+
+         # Store this trial's results for each population
+         for pop_name in pop_names:
+             if pop_name in plv_lists and len(plv_lists[pop_name]) == len(freqs):
+                 all_plv_data[pop_name].append(plv_lists[pop_name])
+             else:
+                 # Fill with NaNs if data is missing
+                 all_plv_data[pop_name].append([np.nan] * len(freqs))
+
+     # Convert to numpy arrays and calculate statistics
+     mean_plv = {}
+     error_plv = {}
+
+     for pop_name in pop_names:
+         all_plv_data[pop_name] = np.array(all_plv_data[pop_name])  # Shape: (n_trials, n_freqs)
+
+         # Calculate statistics across trials, ignoring NaN values
+         with np.errstate(invalid="ignore"):  # Suppress warnings for all-NaN slices
+             mean_plv[pop_name] = np.nanmean(all_plv_data[pop_name], axis=0)
+
+             if error_type == "ci":
+                 # Calculate 95% confidence intervals using SEM
+                 valid_counts = np.sum(~np.isnan(all_plv_data[pop_name]), axis=0)
+                 sem_plv = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1) / np.sqrt(valid_counts)
+
+                 # For 95% CI, multiply SEM by appropriate t-value
+                 # Use minimum valid count across frequencies for conservative t-value
+                 min_valid_trials = np.min(valid_counts[valid_counts > 1])  # Avoid division by zero
+                 if min_valid_trials > 1:
+                     t_value = stats.t.ppf(0.975, min_valid_trials - 1)  # 95% CI, two-tailed
+                     error_plv[pop_name] = t_value * sem_plv
+                 else:
+                     error_plv[pop_name] = np.full_like(sem_plv, np.nan)
+
+             elif error_type == "sem":
+                 # Calculate standard error of the mean
+                 valid_counts = np.sum(~np.isnan(all_plv_data[pop_name]), axis=0)
+                 error_plv[pop_name] = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1) / np.sqrt(
+                     valid_counts
+                 )
+
+             elif error_type == "std":
+                 # Calculate standard deviation
+                 error_plv[pop_name] = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1)
+
+     # Create the combined plot
+     plt.figure(figsize=(12, 8))
+
+     # Define markers and colors for different populations
+     markers = ["o-", "s-", "^-", "D-", "v-", "<-", ">-", "p-"]
+     colors = sns.color_palette(n_colors=len(pop_names))
+
+     # Plot each population
+     for i, pop_name in enumerate(pop_names):
+         marker = markers[i % len(markers)]  # Cycle through markers if more populations than markers
+         color = colors[i]
+
+         # Only plot if we have valid data
+         valid_mask = ~np.isnan(mean_plv[pop_name])
+         if np.any(valid_mask):
+             plt.plot(
+                 freqs[valid_mask],
+                 mean_plv[pop_name][valid_mask],
+                 marker,
+                 linewidth=2,
+                 label=pop_name,
+                 color=color,
+                 markersize=6,
+             )
+
+             # Add error bars/shading if available
+             if not np.all(np.isnan(error_plv[pop_name])):
+                 plt.fill_between(
+                     freqs[valid_mask],
+                     (mean_plv[pop_name] - error_plv[pop_name])[valid_mask],
+                     (mean_plv[pop_name] + error_plv[pop_name])[valid_mask],
+                     alpha=0.3,
+                     color=color,
+                 )
+
+     plt.xlabel("Frequency (Hz)", fontsize=12)
+     plt.ylabel(f"{entrainment_method.upper()}", fontsize=12)
+
+     # Calculate percentage for title and update title based on error type
+     firing_percentage = round(float((1 - firing_quantile) * 100), 1)
+     error_labels = {"ci": "95% CI", "sem": "±SEM", "std": "±1 SD"}
+     error_label = error_labels[error_type]
+     plt.title(
+         f"{entrainment_method.upper()} Across Trials for Top {firing_percentage}% Firing Cells ({error_label})",
+         fontsize=14,
+     )
+
+     plt.legend(fontsize=10)
+     plt.grid(True, alpha=0.3)
+     plt.tight_layout()
+     plt.show()
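A minimal usage sketch for the new trial-averaged entrainment plot, matching the signature above; the window times and population names are illustrative:

    import numpy as np

    plot_trial_avg_entrainment(
        spike_df=spike_df,
        lfp=lfp,                        # one LFP channel as an xarray with an .fs attribute
        time_windows=[(500, 1000), (1500, 2000)],
        entrainment_method="ppc",
        pop_names=["FSI", "LTS"],
        freqs=np.arange(20, 101, 10),
        firing_quantile=0.8,
        error_type="ci",
    )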
bmtool/bmplot/spikes.py CHANGED
@@ -20,6 +20,7 @@ def raster(
      tstart: Optional[float] = None,
      tstop: Optional[float] = None,
      color_map: Optional[Dict[str, str]] = None,
+     dot_size: Optional[float] = 0.3,
  ) -> Axes:
      """
      Plots a raster plot of neural spikes, with different colors for each population.
@@ -40,6 +41,8 @@
          Stop time for filtering spikes; only spikes with timestamps less than `tstop` will be plotted.
      color_map : dict, optional
          Dictionary specifying colors for each population. Keys should be population names, and values should be color values.
+     dot_size : float, optional
+         Size of the dot to display on the scatterplot

      Returns:
      -------
@@ -53,6 +56,7 @@
      - If `color_map` is provided, it should contain colors for all unique `pop_name` values in `spikes_df`.
      """
      # Initialize axes if none provided
+     sns.set_style("whitegrid")
      if ax is None:
          _, ax = plt.subplots(1, 1)

@@ -102,15 +106,17 @@
          raise ValueError(f"color_map is missing colors for populations: {missing_colors}")

      # Plot each population with its specified or generated color
+     legend_handles = []
      for pop_name, group in spikes_df.groupby(groupby):
-         ax.scatter(
-             group["timestamps"], group["node_ids"], label=pop_name, color=color_map[pop_name], s=0.5
-         )
+         ax.scatter(group["timestamps"], group["node_ids"], color=color_map[pop_name], s=dot_size)
+         # Dummy scatter for consistent legend appearance
+         handle = ax.scatter([], [], color=color_map[pop_name], label=pop_name, s=20)
+         legend_handles.append(handle)

      # Label axes
      ax.set_xlabel("Time")
      ax.set_ylabel("Node ID")
-     ax.legend(title="Population", loc="upper right", framealpha=0.9, markerfirst=False)
+     ax.legend(handles=legend_handles, title="Population", loc="upper right", framealpha=0.9)

      return ax

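A hedged sketch of the new dot_size knob; the leading spikes_df and groupby arguments are assumed from the rest of the signature, which this hunk does not show:

    # Larger dots for sparse rasters; the legend now uses fixed-size dummy handles.
    ax = raster(spikes_df, groupby="pop_name", tstart=0, tstop=2000, dot_size=1.0)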
@@ -142,6 +148,7 @@ def plot_firing_rate_pop_stats(
          Axes with the bar plot.
      """
      # Ensure groupby is a list for consistent handling
+     sns.set_style("whitegrid")
      if isinstance(groupby, str):
          groupby = [groupby]

@@ -234,6 +241,7 @@ def plot_firing_rate_distribution(
      matplotlib.axes.Axes
          Axes with the selected plot type(s) overlayed.
      """
+     sns.set_style("whitegrid")
      # Ensure groupby is a list for consistent handling
      if isinstance(groupby, str):
          groupby = [groupby]
@@ -287,8 +295,9 @@ def plot_firing_rate_distribution(
                  y="firing_rate",
                  ax=ax,
                  palette=color_map,
-                 inner="quartile",
+                 inner="box",
                  alpha=0.4,
+                 cut=0,  # This prevents the KDE from extending beyond the data range
              )
          elif pt == "swarm":
              sns.swarmplot(
@@ -308,3 +317,107 @@ def plot_firing_rate_distribution(
      ax.grid(axis="y", linestyle="--", alpha=0.7)

      return ax
+
+
+ def plot_firing_rate_vs_node_attribute(
+     individual_stats: Optional[pd.DataFrame] = None,
+     config: Optional[str] = None,
+     nodes: Optional[pd.DataFrame] = None,
+     groupby: Optional[str] = None,
+     network_name: Optional[str] = None,
+     attribute: Optional[str] = None,
+     figsize=(12, 8),
+     dot_size: float = 3,
+ ) -> plt.Figure:
+     """
+     Plot firing rate vs node attribute for each group in separate subplots.
+
+     Parameters
+     ----------
+     individual_stats : pd.DataFrame, optional
+         DataFrame containing individual cell firing rates from compute_firing_rate_stats
+     config : str, optional
+         Path to configuration file for loading node data
+     nodes : pd.DataFrame, optional
+         Pre-loaded node data as alternative to loading from config
+     groupby : str, optional
+         Column name in individual_stats to group plots by
+     network_name : str, optional
+         Name of network to load from config file
+     attribute : str, optional
+         Node attribute column name to plot against firing rate
+     figsize : tuple[int, int], optional
+         Figure dimensions (width, height) in inches
+     dot_size : float, optional
+         Size of scatter plot points
+
+     Returns
+     -------
+     matplotlib.figure.Figure
+         Figure containing the subplots
+
+     Raises
+     ------
+     ValueError
+         If neither config nor nodes is provided
+         If network_name is missing when using config
+         If attribute is not found in nodes DataFrame
+         If node_ids column is missing
+         If nodes index is not unique
+     """
+     # Input validation
+     if config is None and nodes is None:
+         raise ValueError("Must provide either config or nodes")
+     if config is not None and nodes is None:
+         if network_name is None:
+             raise ValueError("network_name required when using config")
+         nodes = load_nodes_from_config(config)
+     if attribute not in nodes.columns:
+         raise ValueError(f"Attribute '{attribute}' not found in nodes DataFrame")
+
+     # Extract node attribute data
+     node_attribute = nodes[attribute]
+
+     # Validate data structure
+     if "node_ids" not in individual_stats.columns:
+         raise ValueError("individual_stats missing required 'node_ids' column")
+     if not nodes.index.is_unique:
+         raise ValueError("nodes DataFrame must have unique index for merging")
+
+     # Merge firing rate data with node attributes
+     merged_df = individual_stats.merge(
+         node_attribute, left_on="node_ids", right_index=True, how="left"
+     )
+
+     # Setup subplot layout
+     max_groups = 15  # Maximum number of subplots to avoid overcrowding
+     unique_groups = merged_df[groupby].unique()
+     n_groups = min(len(unique_groups), max_groups)
+
+     if len(unique_groups) > max_groups:
+         print(f"Warning: Limiting display to {max_groups} groups out of {len(unique_groups)}")
+         unique_groups = unique_groups[:max_groups]
+
+     n_cols = min(3, n_groups)
+     n_rows = (n_groups + n_cols - 1) // n_cols
+
+     # Create subplots
+     fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
+     if n_groups == 1:
+         axes = np.array([axes])
+     axes = axes.flatten()
+
+     # Plot each group
+     for i, group in enumerate(unique_groups):
+         group_df = merged_df[merged_df[groupby] == group]
+         axes[i].scatter(group_df["firing_rate"], group_df[attribute], s=dot_size)
+         axes[i].set_xlabel("Firing Rate (Hz)")
+         axes[i].set_ylabel(attribute)
+         axes[i].set_title(f"{groupby}: {group}")
+
+     # Hide unused subplots
+     for j in range(i + 1, len(axes)):
+         axes[j].set_visible(False)
+
+     plt.tight_layout()
+     plt.show()
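A usage sketch for the new attribute plot, assuming individual_stats from compute_firing_rate_stats and a pre-loaded nodes DataFrame; the 'depth' column is hypothetical:

    plot_firing_rate_vs_node_attribute(
        individual_stats=individual_stats,
        nodes=nodes_df,
        groupby="pop_name",
        attribute="depth",  # hypothetical node column
        dot_size=5,
    )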
bmtool/synapses.py CHANGED
@@ -18,7 +18,7 @@ from scipy.optimize import curve_fit, minimize, minimize_scalar
  from scipy.signal import find_peaks
  from tqdm.notebook import tqdm

- from bmtool.util.util import load_mechanisms_from_config, load_templates_from_config
+ from bmtool.util.util import load_templates_from_config


  class SynapseTuner:
@@ -69,7 +69,7 @@ class SynapseTuner:
              neuron.load_mechanisms(mechanisms_dir)
              h.load_file(templates_dir)
          else:
-             load_mechanisms_from_config(config)
+             # loads both mech and templates
              load_templates_from_config(config)

          self.conn_type_settings = conn_type_settings
@@ -983,7 +983,7 @@ class GapJunctionTuner:
              neuron.load_mechanisms(mechanisms_dir)
              h.load_file(templates_dir)
          else:
-             load_mechanisms_from_config(config)
+             # this will load both mechs and templates
              load_templates_from_config(config)

          self.general_settings = general_settings
bmtool/util/util.py CHANGED
@@ -447,6 +447,9 @@ def load_mechanisms_from_config(config=None):


  def load_templates_from_config(config=None):
+     """
+     Loads the NEURON mechanisms and templates provided by a BMTK config.
+     """
      if config is None:
          config = "simulation_config.json"
      config = load_config(config)
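Since load_templates_from_config also loads mechanisms (the reason the tuners above dropped load_mechanisms_from_config), a single call now suffices:

    from bmtool.util.util import load_templates_from_config

    # One call loads both NEURON mechanisms and templates from the BMTK config.
    load_templates_from_config("simulation_config.json")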
bmtool-0.7.1.7.dist-info/METADATA → bmtool-0.7.2.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bmtool
- Version: 0.7.1.7
+ Version: 0.7.2
  Summary: BMTool
  Home-page: https://github.com/cyneuro/bmtool
  Download-URL:
bmtool-0.7.1.7.dist-info/RECORD → bmtool-0.7.2.dist-info/RECORD
@@ -6,29 +6,29 @@ bmtool/graphs.py,sha256=gBTzI6c2BBK49dWGcfWh9c56TAooyn-KaiEy0Im1HcI,6717
  bmtool/manage.py,sha256=lsgRejp02P-x6QpA7SXcyXdalPhRmypoviIA2uAitQs,608
  bmtool/plot_commands.py,sha256=Dxm_RaT4CtHnfsltTtUopJ4KVbfhxtktEB_b7bFEXII,12716
  bmtool/singlecell.py,sha256=I2yolbAnNC8qpnRkNdnDCLidNW7CktmBuRrcowMZJ3A,45041
- bmtool/synapses.py,sha256=hRuxRCXVpu0_0egi183qyp343tT-_gZNSxjk9rT5J8Q,66175
+ bmtool/synapses.py,sha256=y8UJAqO1jpZY-mY9gVVMN8Dj1r9jD2fI1nAaNQeQfz4,66148
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bmtool/analysis/entrainment.py,sha256=TUeV-WfCfVPqilVTjg6Cv1WKOz_zSX7LeE7k2Wuceug,28449
+ bmtool/analysis/entrainment.py,sha256=NQloQtVpEWjDzmkZwMWVcm3hSjErHBZfQl1mrBVoIE8,25321
  bmtool/analysis/lfp.py,sha256=S2JvxkjcK3-EH93wCrhqNSFY6cX7fOq74pz64ibHKrc,26556
  bmtool/analysis/netcon_reports.py,sha256=VnPZNKPaQA7oh1q9cIatsqQudm4cOtzNtbGPXoiDCD0,2909
- bmtool/analysis/spikes.py,sha256=u7Qu0NVGPDAH5jlgNLv32H1hDAkOlG6P4nKEFeAOkdE,22833
+ bmtool/analysis/spikes.py,sha256=3n-xmyEZ7w6CKEND7-aKOAvdDg0lwDuPI5sMdOuPwa0,24637
  bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bmtool/bmplot/connections.py,sha256=P1JBG4xCbLVq4sfQuUE6c3dO949qajrjdQcrazdmDS4,53861
- bmtool/bmplot/entrainment.py,sha256=YBHTJ-nK0OgM8CNssM8IyqPNYez9ss9bQi-C5HW4kGw,12593
+ bmtool/bmplot/connections.py,sha256=KSORZ43v1B5xfiBN6AnAD7tJySVTkLIY3j_zb2r-YPA,55696
+ bmtool/bmplot/entrainment.py,sha256=BrBMerqyiG2YWAO_OEFv7OJf3yeFz3l9jUt4NamluLc,32837
  bmtool/bmplot/lfp.py,sha256=SNpbWGOUnYEgnkeBw5S--aPN5mIGD22Gw2Pwus0_lvY,2034
  bmtool/bmplot/netcon_reports.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bmtool/bmplot/spikes.py,sha256=RJOOtmgWhTvyVi1CghoKTtxvt7MF9cJCrJVm5hV5wA4,11210
+ bmtool/bmplot/spikes.py,sha256=odzCSMbFRHp9qthSGQ0WzMWUwNQ7R1Z6gLT6VPF_o5Q,15326
  bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bmtool/debug/commands.py,sha256=VV00f6q5gzZI503vUPeG40ABLLen0bw_k4-EX-H5WZE,580
  bmtool/debug/debug.py,sha256=9yUFvA4_Bl-x9s29quIEG3pY-S8hNJF3RKBfRBHCl28,208
  bmtool/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bmtool/util/commands.py,sha256=Nn-R-4e9g8ZhSPZvTkr38xeKRPfEMANB9Lugppj82UI,68564
- bmtool/util/util.py,sha256=owce5BEusZO_8T5x05N2_B583G26vWAy7QX29V0Pj0Y,62818
+ bmtool/util/util.py,sha256=S8sAXwDiISGAqnSXRIgFqxqCRzL5YcxAqP1UGxGA5Z4,62906
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bmtool/util/neuron/celltuner.py,sha256=lokRLUM1rsdSYBYrNbLBBo39j14mm8TBNVNRnSlhHCk,94868
- bmtool-0.7.1.7.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
- bmtool-0.7.1.7.dist-info/METADATA,sha256=-2fuMCtlaM_YoVuYHcuhNAGe-Cw-5Yfb3kqejIV7S6c,3577
- bmtool-0.7.1.7.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
- bmtool-0.7.1.7.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
- bmtool-0.7.1.7.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
- bmtool-0.7.1.7.dist-info/RECORD,,
+ bmtool-0.7.2.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
+ bmtool-0.7.2.dist-info/METADATA,sha256=JacxbP2RvvbSuuveedTUJjl8KeqOCKf4FlW-UrRmfCk,3575
+ bmtool-0.7.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ bmtool-0.7.2.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
+ bmtool-0.7.2.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
+ bmtool-0.7.2.dist-info/RECORD,,
bmtool-0.7.1.7.dist-info/WHEEL → bmtool-0.7.2.dist-info/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.8.0)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any