bmtool 0.7.1.7__py3-none-any.whl → 0.7.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,56 +1,299 @@
+ from typing import List, Tuple, Union
+
  import matplotlib.pyplot as plt
  import numpy as np
  import pandas as pd
  import seaborn as sns
+ import xarray as xr
  from matplotlib.gridspec import GridSpec
  from scipy import stats

-
- def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
+ from bmtool.analysis import entrainment as bmentr
+ from bmtool.analysis import spikes as bmspikes
+ from bmtool.analysis.lfp import get_lfp_power
+
+
+ def plot_spike_power_correlation(
+     spike_df: pd.DataFrame,
+     lfp_data: xr.DataArray,
+     firing_quantile: float,
+     fs: float,
+     pop_names: list,
+     filter_method: str = "wavelet",
+     bandwidth: float = 2.0,
+     lowcut: float = None,
+     highcut: float = None,
+     freq_range: tuple = (10, 100),
+     freq_step: float = 5,
+     type_name: str = "raw",
+     time_windows: list = None,
+     error_type: str = "ci",  # "ci" for confidence interval, "sem" for standard error, "std" for standard deviation
+ ):
  """
11
- Plot the correlation between population spike rates and LFP power.
12
-
13
- Parameters:
14
- -----------
15
- correlation_results : dict
16
- Dictionary with correlation results for calculate_spike_rate_power_correlation
17
- frequencies : array
18
- Array of frequencies analyzed
33
+ Calculate and plot correlation between population spike rates and LFP power across frequencies.
34
+ Supports both single-signal and trial-based analysis with error bars.
35
+
36
+ Parameters
37
+ ----------
38
+ spike_rate : xr.DataArray
39
+ Population spike rates with dimensions (time, population[, type])
40
+ lfp_data : xr.DataArray
41
+ LFP data
42
+ fs : float
43
+ Sampling frequency
19
44
  pop_names : list
20
- List of population names
45
+ List of population names to analyze
46
+ filter_method : str, optional
47
+ Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
48
+ bandwidth : float, optional
49
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
50
+ lowcut : float, optional
51
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
52
+ highcut : float, optional
53
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
54
+ freq_range : tuple, optional
55
+ Min and max frequency to analyze (default: (10, 100))
56
+ freq_step : float, optional
57
+ Step size for frequency analysis (default: 5)
58
+ type_name : str, optional
59
+ Which type of spike rate to use if 'type' dimension exists (default: 'raw')
60
+ time_windows : list, optional
61
+ List of (start, end) time tuples for trial-based analysis. If None, analyze entire signal
62
+ error_type : str, optional
63
+ Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
21
64
  """
22
-     sns.set_style("whitegrid")
-     plt.figure(figsize=(10, 6))

+     if not (0 <= firing_quantile < 1):
+         raise ValueError("firing_quantile must be in [0, 1)")
+
+     if error_type not in ["ci", "sem", "std"]:
+         raise ValueError(
+             "error_type must be 'ci' for confidence interval, 'sem' for standard error, or 'std' for standard deviation"
+         )
+
+     # Setup
+     is_trial_based = time_windows is not None
+
+     # Convert spike_df to spike rate with trial-based filtering of high firing cells
+     if is_trial_based:
+         # Initialize storage for trial-based spike rates
+         trial_rates = []
+
+         for start_time, end_time in time_windows:
+             # Get spikes for this trial
+             trial_spikes = spike_df[
+                 (spike_df["timestamps"] >= start_time) & (spike_df["timestamps"] <= end_time)
+             ].copy()
+
+             # Filter for high firing cells within this trial
+             trial_spikes = bmspikes.find_highest_firing_cells(
+                 trial_spikes, upper_quantile=firing_quantile
+             )
+             # Calculate rate for this trial's filtered spikes
+             trial_rate = bmspikes.get_population_spike_rate(
+                 trial_spikes, fs=fs, t_start=start_time, t_stop=end_time
+             )
+             trial_rates.append(trial_rate)
+
+         # Combine all trial rates
+         spike_rate = xr.concat(trial_rates, dim="trial")
+     else:
+         # For non-trial analysis, proceed as before
+         spike_df = bmspikes.find_highest_firing_cells(spike_df, upper_quantile=firing_quantile)
+         spike_rate = bmspikes.get_population_spike_rate(spike_df)
+
+     # Set up frequencies for analysis
+     frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
+
+     # Pre-calculate LFP power for all frequencies
+     power_by_freq = {}
+     for freq in frequencies:
+         power_by_freq[freq] = get_lfp_power(
+             lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
+         )
+
+     # Calculate correlations
+     results = {}
      for pop in pop_names:
-         # Extract correlation values for each frequency
-         corr_values = []
-         valid_freqs = []
+         pop_spike_rate = spike_rate.sel(population=pop, type=type_name)
+         results[pop] = {}

          for freq in frequencies:
-             if freq in correlation_results[pop]:
-                 corr_values.append(correlation_results[pop][freq]["correlation"])
-                 valid_freqs.append(freq)
+             lfp_power = power_by_freq[freq]

-         # Plot correlation line
-         plt.plot(valid_freqs, corr_values, marker="o", label=pop, linewidth=2, markersize=6)
+             if not is_trial_based:
+                 # Single signal analysis
+                 if len(pop_spike_rate) != len(lfp_power):
+                     print(f"Warning: Length mismatch for {pop} at {freq} Hz")
+                     continue
+
+                 corr, p_val = stats.spearmanr(pop_spike_rate, lfp_power)
+                 results[pop][freq] = {
+                     "correlation": corr,
+                     "p_value": p_val,
+                 }
+             else:
+                 # Trial-based analysis using pre-filtered trial rates
+                 trial_correlations = []
+
+                 for trial_idx in range(len(time_windows)):
+                     # Get time window first
+                     start_time, end_time = time_windows[trial_idx]
+
+                     # Get the pre-filtered spike rate for this trial
+                     # (positional index, since the concatenated trial dim has no coordinate labels)
+                     trial_spike_rate = pop_spike_rate.isel(trial=trial_idx)
+
+                     # Get corresponding LFP power for this trial window
+                     trial_lfp_power = lfp_power.sel(time=slice(start_time, end_time))
+
+                     # Ensure both signals have same time points
+                     common_times = np.intersect1d(trial_spike_rate.time, trial_lfp_power.time)
+
+                     if len(common_times) > 0:
+                         trial_sr = trial_spike_rate.sel(time=common_times).values
+                         trial_lfp = trial_lfp_power.sel(time=common_times).values
+
+                         if (
+                             len(trial_sr) > 1 and len(trial_lfp) > 1
+                         ):  # Need at least 2 points for correlation
+                             corr, _ = stats.spearmanr(trial_sr, trial_lfp)
+                             if not np.isnan(corr):
+                                 trial_correlations.append(corr)
+
+                 # Calculate trial statistics
+                 if len(trial_correlations) > 0:
+                     trial_correlations = np.array(trial_correlations)
+                     mean_corr = np.mean(trial_correlations)
+
+                     if len(trial_correlations) > 1:
+                         if error_type == "ci":
+                             # Calculate 95% confidence interval using the t-distribution
+                             df = len(trial_correlations) - 1
+                             sem = stats.sem(trial_correlations)
+                             t_critical = stats.t.ppf(0.975, df)  # 95% CI, two-tailed
+                             error_val = t_critical * sem
+                             error_lower = mean_corr - error_val
+                             error_upper = mean_corr + error_val
+                         elif error_type == "sem":
+                             # Calculate standard error of the mean
+                             sem = stats.sem(trial_correlations)
+                             error_lower = mean_corr - sem
+                             error_upper = mean_corr + sem
+                         elif error_type == "std":
+                             # Calculate standard deviation
+                             std = np.std(trial_correlations, ddof=1)
+                             error_lower = mean_corr - std
+                             error_upper = mean_corr + std
+                     else:
+                         error_lower = error_upper = mean_corr
+
+                     results[pop][freq] = {
+                         "correlation": mean_corr,
+                         "error_lower": error_lower,
+                         "error_upper": error_upper,
+                         "n_trials": len(trial_correlations),
+                         "trial_correlations": trial_correlations,
+                     }
+                 else:
+                     # No valid trials
+                     results[pop][freq] = {
+                         "correlation": np.nan,
+                         "error_lower": np.nan,
+                         "error_upper": np.nan,
+                         "n_trials": 0,
+                         "trial_correlations": np.array([]),
+                     }
+
+     # Plotting
+     sns.set_style("whitegrid")
+     plt.figure(figsize=(12, 8))

+     for i, pop in enumerate(pop_names):
+         # Extract data for plotting
+         plot_freqs = []
+         plot_corrs = []
+         plot_ci_lower = []
+         plot_ci_upper = []
+
+         for freq in frequencies:
+             if freq in results[pop] and not np.isnan(results[pop][freq]["correlation"]):
+                 plot_freqs.append(freq)
+                 plot_corrs.append(results[pop][freq]["correlation"])
+
+                 if is_trial_based:
+                     plot_ci_lower.append(results[pop][freq]["error_lower"])
+                     plot_ci_upper.append(results[pop][freq]["error_upper"])
+
+         if len(plot_freqs) == 0:
+             continue
+
+         # Convert to arrays
+         plot_freqs = np.array(plot_freqs)
+         plot_corrs = np.array(plot_corrs)
+
+         # Get color for this population
+         colors = plt.get_cmap("tab10")
+         color = colors(i)
+
+         # Plot main line
+         plt.plot(
+             plot_freqs, plot_corrs, marker="o", label=pop, linewidth=2, markersize=6, color=color
+         )
+
+         # Plot error bands for trial-based analysis
+         if is_trial_based and len(plot_ci_lower) > 0:
+             plot_ci_lower = np.array(plot_ci_lower)
+             plot_ci_upper = np.array(plot_ci_upper)
+             plt.fill_between(plot_freqs, plot_ci_lower, plot_ci_upper, alpha=0.2, color=color)
+
+     # Formatting
      plt.xlabel("Frequency (Hz)", fontsize=12)
      plt.ylabel("Spike Rate-Power Correlation", fontsize=12)
-     plt.title("Spike rate LFP power correlation during stimulus", fontsize=14)
-     plt.grid(True, alpha=0.3)
-     plt.legend(fontsize=12)
-     plt.xticks(frequencies[::2])  # Display every other frequency on x-axis

-     # Add horizontal line at zero for reference
+     # Calculate percentage and error label for title
+     firing_percentage = round(float((1 - firing_quantile) * 100), 1)
+     error_labels = {"ci": "95% CI", "sem": "±SEM", "std": "±1 SD"}
+     if is_trial_based:
+         title = f"Trial-averaged Spike Rate-LFP Power Correlation\nTop {firing_percentage}% Firing Cells ({error_labels[error_type]})"
+     else:
+         title = f"Spike Rate-LFP Power Correlation\nTop {firing_percentage}% Firing Cells"
+
+     plt.title(title, fontsize=14)
+     plt.grid(True, alpha=0.3)
      plt.axhline(y=0, color="gray", linestyle="-", alpha=0.5)

-     # Set y-axis limits to make zero visible
+     # Legend
+     # Create legend elements for each population
+     from matplotlib.lines import Line2D
+
+     colors = plt.get_cmap("tab10")
+     legend_elements = [
+         Line2D([0], [0], color=colors(i), marker="o", linestyle="-", label=pop)
+         for i, pop in enumerate(pop_names)
+     ]
+
+     # Add error band legend element for trial-based analysis
+     if is_trial_based:
+         legend_elements.append(
+             Line2D([0], [0], color="gray", alpha=0.3, linewidth=10, label=error_labels[error_type])
+         )
+
+     plt.legend(handles=legend_elements, fontsize=10, loc="best")
+
+     # Axis formatting
+     if len(frequencies) > 10:
+         plt.xticks(frequencies[::2])
+     else:
+         plt.xticks(frequencies)
+     plt.xlim(frequencies[0], frequencies[-1])
+
      y_min, y_max = plt.ylim()
      plt.ylim(min(y_min, -0.1), max(y_max, 0.1))

      plt.tight_layout()
-
      plt.show()


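A minimal usage sketch for the new plot_spike_power_correlation above, with synthetic inputs. The column names (timestamps, node_ids, pop_name), the millisecond time units, and the time coordinate on the LFP DataArray are assumptions inferred from how the function consumes its arguments, not part of the diff:

    import numpy as np
    import pandas as pd
    import xarray as xr

    rng = np.random.default_rng(0)
    fs = 1000.0  # Hz

    # Hypothetical spike table: one row per spike
    spike_df = pd.DataFrame(
        {
            "timestamps": rng.uniform(0, 10_000, 5_000),  # ms (assumed unit)
            "node_ids": rng.integers(0, 100, 5_000),
            "pop_name": rng.choice(["FSI", "LTS"], 5_000),
        }
    )

    # Hypothetical single-channel LFP sampled at fs, with a time coordinate in ms
    t = np.arange(10_000)
    lfp_data = xr.DataArray(rng.standard_normal(t.size), coords={"time": t}, dims="time")

    plot_spike_power_correlation(
        spike_df,
        lfp_data,
        firing_quantile=0.8,  # keep the top 20% of firing cells
        fs=fs,
        pop_names=["FSI", "LTS"],
        time_windows=[(0, 2_000), (2_000, 4_000), (4_000, 6_000)],  # trial-based branch
        error_type="sem",
    )

Omitting time_windows runs the single-signal branch instead, plotting one correlation curve per population without error bands.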
@@ -366,3 +609,257 @@ def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title
      plt.savefig(f"{save_path}/ppc_change_swarm_plot_{freq}Hz.png", dpi=300, bbox_inches="tight")

      plt.show()
+
+
+ def plot_trial_avg_entrainment(
+     spike_df: pd.DataFrame,
+     lfp: xr.DataArray,
+     time_windows: List[Tuple[float, float]],
+     entrainment_method: str,
+     pop_names: List[str],
+     freqs: Union[List[float], np.ndarray],
+     firing_quantile: float,
+     spike_fs: float = 1000,
+     error_type: str = "ci",  # "ci" for confidence interval, "sem" for standard error, "std" for standard deviation
+ ) -> None:
+     """
+     Plot trial-averaged entrainment for the specified populations. Currently only the wavelet
+     filter is supported; other filters could easily be added.
+
+     Parameters:
+     -----------
+     spike_df : pd.DataFrame
+         Spike data containing timestamps, node_ids, and pop_name columns
+     spike_fs : float
+         Sampling frequency of the spike data (default: 1000)
+     lfp : xr.DataArray
+         One channel of LFP data, with its sampling rate available as lfp.fs
+     time_windows : List[Tuple[float, float]]
+         Windows to analyze, [(start_time, end_time), ...], one per trial
+     entrainment_method : str
+         Method for entrainment calculation ('ppc', 'ppc2' or 'plv')
+     pop_names : List[str]
+         List of population names to process (e.g., ['FSI', 'LTS'])
+     freqs : Union[List[float], np.ndarray]
+         Array of frequencies to analyze (Hz)
+     firing_quantile : float
+         Upper quantile threshold for selecting high-firing cells (e.g., 0.8 for top 20%)
+     error_type : str
+         Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
+
+     Raises:
+     -------
+     ValueError
+         If entrainment_method is not 'ppc', 'ppc2' or 'plv'
+         If error_type is not 'ci', 'sem', or 'std'
+         If firing_quantile is outside [0, 1)
+
+     Returns:
+     --------
+     None
+         Displays the plot; populations with no (high-firing) spikes in a trial are skipped with a printed warning
+     """
+     sns.set_style("whitegrid")
+     # Validate inputs
+     if entrainment_method not in ["ppc", "plv", "ppc2"]:
+         raise ValueError("entrainment_method must be 'ppc', 'ppc2', or 'plv'")
+
+     if error_type not in ["ci", "sem", "std"]:
+         raise ValueError(
+             "error_type must be 'ci' for confidence interval, 'sem' for standard error, or 'std' for standard deviation"
+         )
+
+     if not (0 <= firing_quantile < 1):
+         raise ValueError("firing_quantile must be in [0, 1)")
+
+     # Convert freqs to numpy array for easier indexing
+     freqs = np.array(freqs)
+
+     # Collect all PPC/PLV values across trials for each population
+     all_plv_data = {}  # Dictionary to store results for each population
+
+     # Initialize storage for each population
+     for pop_name in pop_names:
+         all_plv_data[pop_name] = []  # Will be shape (n_trials, n_freqs)
+
+     # Loop through all trials to collect data
+     for trial_idx in range(len(time_windows)):
+         plv_lists = {}  # Store PLV lists for this trial
+
+         # Initialize PLV lists for each population
+         for pop_name in pop_names:
+             plv_lists[pop_name] = []
+
+         # Filter spikes for this trial
+         network_spikes = spike_df[
+             (spike_df["timestamps"] >= time_windows[trial_idx][0])
+             & (spike_df["timestamps"] <= time_windows[trial_idx][1])
+         ].copy()
+
+         # Process each population
+         pop_spike_data = {}
+         for pop_name in pop_names:
+             # Get spikes for this population
+             pop_spikes = network_spikes[network_spikes["pop_name"] == pop_name]
+
+             if len(pop_spikes) == 0:
+                 print(f"Warning: No spikes found for population {pop_name} in trial {trial_idx}")
+                 # Add NaN values for this trial/population
+                 plv_lists[pop_name] = [np.nan] * len(freqs)
+                 continue
+
+             # Filter to get the top firing cells
+             # firing_quantile of 0.8 keeps the top 20% of firing cells
+             pop_spikes = bmspikes.find_highest_firing_cells(
+                 pop_spikes, upper_quantile=firing_quantile
+             )
+
+             if len(pop_spikes) == 0:
+                 print(
+                     f"Warning: No high-firing spikes found for population {pop_name} in trial {trial_idx}"
+                 )
+                 plv_lists[pop_name] = [np.nan] * len(freqs)
+                 continue
+
+             pop_spike_data[pop_name] = pop_spikes
+
+         # Calculate PPC/PLV for each frequency and each population
+         for freq in freqs:
+             for pop_name in pop_names:
+                 if pop_name not in pop_spike_data:
+                     continue  # Skip if no data for this population
+
+                 pop_spikes = pop_spike_data[pop_name]
+
+                 try:
+                     if entrainment_method == "ppc":
+                         result = bmentr.calculate_ppc(
+                             pop_spikes["timestamps"].values,
+                             lfp,
+                             spike_fs=spike_fs,
+                             lfp_fs=lfp.fs,
+                             freq_of_interest=freq,
+                             filter_method="wavelet",
+                             ppc_method="gpu",
+                         )
+                     elif entrainment_method == "plv":
+                         result = bmentr.calculate_spike_lfp_plv(
+                             pop_spikes["timestamps"].values,
+                             lfp,
+                             spike_fs=spike_fs,
+                             lfp_fs=lfp.fs,
+                             freq_of_interest=freq,
+                             filter_method="wavelet",
+                         )
+                     elif entrainment_method == "ppc2":
+                         result = bmentr.calculate_ppc2(
+                             pop_spikes["timestamps"].values,
+                             lfp,
+                             spike_fs=spike_fs,
+                             lfp_fs=lfp.fs,
+                             freq_of_interest=freq,
+                             filter_method="wavelet",
+                         )
+
+                     plv_lists[pop_name].append(result)
+
+                 except Exception as e:
+                     print(
+                         f"Warning: Error calculating {entrainment_method} for {pop_name} at {freq}Hz in trial {trial_idx}: {e}"
+                     )
+                     plv_lists[pop_name].append(np.nan)
+
+         # Store this trial's results for each population
+         for pop_name in pop_names:
+             if pop_name in plv_lists and len(plv_lists[pop_name]) == len(freqs):
+                 all_plv_data[pop_name].append(plv_lists[pop_name])
+             else:
+                 # Fill with NaNs if data is missing
+                 all_plv_data[pop_name].append([np.nan] * len(freqs))
+
+     # Convert to numpy arrays and calculate statistics
+     mean_plv = {}
+     error_plv = {}
+
+     for pop_name in pop_names:
+         all_plv_data[pop_name] = np.array(all_plv_data[pop_name])  # Shape: (n_trials, n_freqs)
+
+         # Calculate statistics across trials, ignoring NaN values
+         with np.errstate(invalid="ignore"):  # Suppress warnings for all-NaN slices
+             mean_plv[pop_name] = np.nanmean(all_plv_data[pop_name], axis=0)
+
+             if error_type == "ci":
+                 # Calculate 95% confidence intervals using SEM
+                 valid_counts = np.sum(~np.isnan(all_plv_data[pop_name]), axis=0)
+                 sem_plv = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1) / np.sqrt(valid_counts)
+
+                 # For the 95% CI, multiply SEM by the appropriate t-value;
+                 # use the minimum valid count across frequencies for a conservative t-value
+                 counts_above_one = valid_counts[valid_counts > 1]
+                 if counts_above_one.size > 0:
+                     t_value = stats.t.ppf(0.975, np.min(counts_above_one) - 1)  # 95% CI, two-tailed
+                     error_plv[pop_name] = t_value * sem_plv
+                 else:
+                     error_plv[pop_name] = np.full_like(sem_plv, np.nan)
+
+             elif error_type == "sem":
+                 # Calculate standard error of the mean
+                 valid_counts = np.sum(~np.isnan(all_plv_data[pop_name]), axis=0)
+                 error_plv[pop_name] = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1) / np.sqrt(
+                     valid_counts
+                 )
+
+             elif error_type == "std":
+                 # Calculate standard deviation
+                 error_plv[pop_name] = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1)
+
+     # Create the combined plot
+     plt.figure(figsize=(12, 8))
+
+     # Define markers and colors for different populations
+     markers = ["o-", "s-", "^-", "D-", "v-", "<-", ">-", "p-"]
+     colors = sns.color_palette(n_colors=len(pop_names))
+
+     # Plot each population
+     for i, pop_name in enumerate(pop_names):
+         marker = markers[i % len(markers)]  # Cycle through markers if more populations than markers
+         color = colors[i]
+
+         # Only plot if we have valid data
+         valid_mask = ~np.isnan(mean_plv[pop_name])
+         if np.any(valid_mask):
+             plt.plot(
+                 freqs[valid_mask],
+                 mean_plv[pop_name][valid_mask],
+                 marker,
+                 linewidth=2,
+                 label=pop_name,
+                 color=color,
+                 markersize=6,
+             )
+
+             # Add error bars/shading if available
+             if not np.all(np.isnan(error_plv[pop_name])):
+                 plt.fill_between(
+                     freqs[valid_mask],
+                     (mean_plv[pop_name] - error_plv[pop_name])[valid_mask],
+                     (mean_plv[pop_name] + error_plv[pop_name])[valid_mask],
+                     alpha=0.3,
+                     color=color,
+                 )
+
+     plt.xlabel("Frequency (Hz)", fontsize=12)
+     plt.ylabel(entrainment_method.upper(), fontsize=12)
+
+     # Calculate percentage for title and update title based on error type
+     firing_percentage = round(float((1 - firing_quantile) * 100), 1)
+     error_labels = {"ci": "95% CI", "sem": "±SEM", "std": "±1 SD"}
+     error_label = error_labels[error_type]
+     plt.title(
+         f"{entrainment_method.upper()} Across Trials for Top {firing_percentage}% Firing Cells ({error_label})",
+         fontsize=14,
+     )
+
+     plt.legend(fontsize=10)
+     plt.grid(True, alpha=0.3)
+     plt.tight_layout()
+     plt.show()
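
A matching sketch for plot_trial_avg_entrainment, reusing spike_df and rng from the example above. The fs attribute on the LFP DataArray is an assumption based on the lfp.fs reads inside the function; note that the 'ppc' path requests ppc_method="gpu" internally, so 'plv' or 'ppc2' may be the safer choice on CPU-only machines:

    # Hypothetical LFP channel carrying its sampling rate as an attribute
    lfp = xr.DataArray(
        rng.standard_normal(10_000),
        coords={"time": np.arange(10_000)},  # ms (assumed unit)
        dims="time",
        attrs={"fs": 1000.0},  # read internally as lfp.fs
    )

    plot_trial_avg_entrainment(
        spike_df,
        lfp=lfp,
        time_windows=[(0, 2_000), (2_000, 4_000)],
        entrainment_method="plv",  # 'ppc' and 'ppc2' are also supported
        pop_names=["FSI", "LTS"],
        freqs=np.arange(10, 101, 10),
        firing_quantile=0.8,
        spike_fs=1000,
        error_type="ci",
    )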