bmtool 0.7.1.6__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,61 +1,305 @@
+from typing import List, Tuple, Union
+
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 import seaborn as sns
+import xarray as xr
 from matplotlib.gridspec import GridSpec
 from scipy import stats
 
-
-def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
+from bmtool.analysis import entrainment as bmentr
+from bmtool.analysis import spikes as bmspikes
+from bmtool.analysis.lfp import get_lfp_power
+
+
+def plot_spike_power_correlation(
+    spike_df: pd.DataFrame,
+    lfp_data: xr.DataArray,
+    firing_quantile: float,
+    fs: float,
+    pop_names: list,
+    filter_method: str = "wavelet",
+    bandwidth: float = 2.0,
+    lowcut: float = None,
+    highcut: float = None,
+    freq_range: tuple = (10, 100),
+    freq_step: float = 5,
+    type_name: str = "raw",
+    time_windows: list = None,
+    error_type: str = "ci",  # "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
+):
     """
-    Plot the correlation between population spike rates and LFP power.
-
-    Parameters:
-    -----------
-    correlation_results : dict
-        Dictionary with correlation results for calculate_spike_rate_power_correlation
-    frequencies : array
-        Array of frequencies analyzed
+    Calculate and plot correlation between population spike rates and LFP power across frequencies.
+    Supports both single-signal and trial-based analysis with error bars.
+
+    Parameters
+    ----------
+    spike_df : pd.DataFrame
+        Spike data containing timestamps, node_ids, and pop_name columns
+    lfp_data : xr.DataArray
+        LFP data
+    firing_quantile : float
+        Upper quantile threshold for selecting high-firing cells (e.g., 0.8 keeps the top 20%)
+    fs : float
+        Sampling frequency
     pop_names : list
-        List of population names
+        List of population names to analyze
+    filter_method : str, optional
+        Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    freq_range : tuple, optional
+        Min and max frequency to analyze (default: (10, 100))
+    freq_step : float, optional
+        Step size for frequency analysis (default: 5)
+    type_name : str, optional
+        Which type of spike rate to use if 'type' dimension exists (default: 'raw')
+    time_windows : list, optional
+        List of (start, end) time tuples for trial-based analysis. If None, analyze entire signal
+    error_type : str, optional
+        Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
     """
-    sns.set_style("whitegrid")
-    plt.figure(figsize=(10, 6))
 
+    if not (0 <= firing_quantile < 1):
+        raise ValueError("firing_quantile must be in the interval [0, 1)")
+
+    if error_type not in ["ci", "sem", "std"]:
+        raise ValueError(
+            "error_type must be 'ci' for confidence interval, 'sem' for standard error, or 'std' for standard deviation"
+        )
+
+    # Setup
+    is_trial_based = time_windows is not None
+
+    # Convert spike_df to a spike rate, filtering for high-firing cells within each trial
+    if is_trial_based:
+        # Initialize storage for trial-based spike rates
+        trial_rates = []
+
+        for start_time, end_time in time_windows:
+            # Get spikes for this trial
+            trial_spikes = spike_df[
+                (spike_df["timestamps"] >= start_time) & (spike_df["timestamps"] <= end_time)
+            ].copy()
+
+            # Filter for high firing cells within this trial
+            trial_spikes = bmspikes.find_highest_firing_cells(
+                trial_spikes, upper_quantile=firing_quantile
+            )
+            # Calculate rate for this trial's filtered spikes
+            trial_rate = bmspikes.get_population_spike_rate(
+                trial_spikes, fs=fs, t_start=start_time, t_stop=end_time
+            )
+            trial_rates.append(trial_rate)
+
+        # Combine all trial rates
+        spike_rate = xr.concat(trial_rates, dim="trial")
+    else:
+        # For non-trial analysis, filter once over the whole recording
+        spike_df = bmspikes.find_highest_firing_cells(spike_df, upper_quantile=firing_quantile)
+        spike_rate = bmspikes.get_population_spike_rate(spike_df)
+
+    # Set up frequencies for analysis
+    frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
+
+    # Pre-calculate LFP power for all frequencies
+    power_by_freq = {}
+    for freq in frequencies:
+        power_by_freq[freq] = get_lfp_power(
+            lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
+        )
+
+    # Calculate correlations
+    results = {}
     for pop in pop_names:
-        # Extract correlation values for each frequency
-        corr_values = []
-        valid_freqs = []
+        pop_spike_rate = spike_rate.sel(population=pop, type=type_name)
+        results[pop] = {}
 
         for freq in frequencies:
-            if freq in correlation_results[pop]:
-                corr_values.append(correlation_results[pop][freq]["correlation"])
-                valid_freqs.append(freq)
+            lfp_power = power_by_freq[freq]
+
+            if not is_trial_based:
+                # Single-signal analysis
+                if len(pop_spike_rate) != len(lfp_power):
+                    print(f"Warning: Length mismatch for {pop} at {freq} Hz")
+                    continue
+
+                corr, p_val = stats.spearmanr(pop_spike_rate, lfp_power)
+                results[pop][freq] = {
+                    "correlation": corr,
+                    "p_value": p_val,
+                }
+            else:
+                # Trial-based analysis using pre-filtered trial rates
+                trial_correlations = []
+
+                for trial_idx in range(len(time_windows)):
+                    # Get this trial's time window
+                    start_time, end_time = time_windows[trial_idx]
+
+                    # Get the pre-filtered spike rate for this trial
+                    trial_spike_rate = pop_spike_rate.sel(trial=trial_idx)
+
+                    # Get corresponding LFP power for this trial window
+                    trial_lfp_power = lfp_power.sel(time=slice(start_time, end_time))
+
+                    # Ensure both signals have the same time points
+                    common_times = np.intersect1d(trial_spike_rate.time, trial_lfp_power.time)
+
+                    if len(common_times) > 0:
+                        trial_sr = trial_spike_rate.sel(time=common_times).values
+                        trial_lfp = trial_lfp_power.sel(time=common_times).values
+
+                        if (
+                            len(trial_sr) > 1 and len(trial_lfp) > 1
+                        ):  # Need at least 2 points for correlation
+                            corr, _ = stats.spearmanr(trial_sr, trial_lfp)
+                            if not np.isnan(corr):
+                                trial_correlations.append(corr)
+
+                # Calculate trial statistics
+                if len(trial_correlations) > 0:
+                    trial_correlations = np.array(trial_correlations)
+                    mean_corr = np.mean(trial_correlations)
+
+                    if len(trial_correlations) > 1:
+                        if error_type == "ci":
+                            # Calculate 95% confidence interval using the t-distribution
+                            df = len(trial_correlations) - 1
+                            sem = stats.sem(trial_correlations)
+                            t_critical = stats.t.ppf(0.975, df)  # 95% CI, two-tailed
+                            error_val = t_critical * sem
+                            error_lower = mean_corr - error_val
+                            error_upper = mean_corr + error_val
+                        elif error_type == "sem":
+                            # Calculate standard error of the mean
+                            sem = stats.sem(trial_correlations)
+                            error_lower = mean_corr - sem
+                            error_upper = mean_corr + sem
+                        elif error_type == "std":
+                            # Calculate standard deviation
+                            std = np.std(trial_correlations, ddof=1)
+                            error_lower = mean_corr - std
+                            error_upper = mean_corr + std
+                    else:
+                        error_lower = error_upper = mean_corr
+
+                    results[pop][freq] = {
+                        "correlation": mean_corr,
+                        "error_lower": error_lower,
+                        "error_upper": error_upper,
+                        "n_trials": len(trial_correlations),
+                        "trial_correlations": trial_correlations,
+                    }
+                else:
+                    # No valid trials
+                    results[pop][freq] = {
+                        "correlation": np.nan,
+                        "error_lower": np.nan,
+                        "error_upper": np.nan,
+                        "n_trials": 0,
+                        "trial_correlations": np.array([]),
+                    }
+
+    # Plotting
+    sns.set_style("whitegrid")
+    plt.figure(figsize=(12, 8))
+
+    for i, pop in enumerate(pop_names):
+        # Extract data for plotting
+        plot_freqs = []
+        plot_corrs = []
+        plot_ci_lower = []
+        plot_ci_upper = []
+
+        for freq in frequencies:
+            if freq in results[pop] and not np.isnan(results[pop][freq]["correlation"]):
+                plot_freqs.append(freq)
+                plot_corrs.append(results[pop][freq]["correlation"])
+
+                if is_trial_based:
+                    plot_ci_lower.append(results[pop][freq]["error_lower"])
+                    plot_ci_upper.append(results[pop][freq]["error_upper"])
+
+        if len(plot_freqs) == 0:
+            continue
+
+        # Convert to arrays
+        plot_freqs = np.array(plot_freqs)
+        plot_corrs = np.array(plot_corrs)
 
-        # Plot correlation line
-        plt.plot(valid_freqs, corr_values, marker="o", label=pop, linewidth=2, markersize=6)
+        # Get color for this population
+        colors = plt.get_cmap("tab10")
+        color = colors(i)
 
+        # Plot main line
+        plt.plot(
+            plot_freqs, plot_corrs, marker="o", label=pop, linewidth=2, markersize=6, color=color
+        )
+
+        # Plot error bands for trial-based analysis
+        if is_trial_based and len(plot_ci_lower) > 0:
+            plot_ci_lower = np.array(plot_ci_lower)
+            plot_ci_upper = np.array(plot_ci_upper)
+            plt.fill_between(plot_freqs, plot_ci_lower, plot_ci_upper, alpha=0.2, color=color)
+
+    # Formatting
     plt.xlabel("Frequency (Hz)", fontsize=12)
     plt.ylabel("Spike Rate-Power Correlation", fontsize=12)
-    plt.title("Spike rate LFP power correlation during stimulus", fontsize=14)
-    plt.grid(True, alpha=0.3)
-    plt.legend(fontsize=12)
-    plt.xticks(frequencies[::2])  # Display every other frequency on x-axis
 
-    # Add horizontal line at zero for reference
+    # Calculate the percentage of cells kept for the title, and map the error
+    # type to a human-readable label (reused in the legend below)
+    firing_percentage = round(float((1 - firing_quantile) * 100), 1)
+    error_labels = {"ci": "95% CI", "sem": "±SEM", "std": "±1 SD"}
+    error_label = error_labels[error_type]
+    if is_trial_based:
+        title = f"Trial-averaged Spike Rate-LFP Power Correlation\nTop {firing_percentage}% Firing Cells ({error_label})"
+    else:
+        title = f"Spike Rate-LFP Power Correlation\nTop {firing_percentage}% Firing Cells"
+
+    plt.title(title, fontsize=14)
+    plt.grid(True, alpha=0.3)
     plt.axhline(y=0, color="gray", linestyle="-", alpha=0.5)
 
-    # Set y-axis limits to make zero visible
+    # Create legend elements for each population
+    from matplotlib.lines import Line2D
+
+    colors = plt.get_cmap("tab10")
+    legend_elements = [
+        Line2D([0], [0], color=colors(i), marker="o", linestyle="-", label=pop)
+        for i, pop in enumerate(pop_names)
+    ]
+
+    # Add an error band legend element for trial-based analysis
+    if is_trial_based:
+        legend_elements.append(
+            Line2D([0], [0], color="gray", alpha=0.3, linewidth=10, label=error_label)
+        )
+
+    plt.legend(handles=legend_elements, fontsize=10, loc="best")
+
+    # Axis formatting
+    if len(frequencies) > 10:
+        plt.xticks(frequencies[::2])
+    else:
+        plt.xticks(frequencies)
+    plt.xlim(frequencies[0], frequencies[-1])
+
     y_min, y_max = plt.ylim()
     plt.ylim(min(y_min, -0.1), max(y_max, 0.1))
 
     plt.tight_layout()
-
     plt.show()
 
 
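For orientation, a minimal usage sketch of the new signature (an editor's illustration, not part of the package diff; `spikes` and `lfp` are hypothetical inputs following the layout the docstrings describe, a DataFrame with timestamps/node_ids/pop_name columns and an LFP channel as an xr.DataArray):

# Hypothetical inputs: `spikes` (pd.DataFrame) and `lfp` (xr.DataArray)
trial_windows = [(0.0, 1000.0), (1500.0, 2500.0)]  # one (start, stop) pair per trial
plot_spike_power_correlation(
    spike_df=spikes,
    lfp_data=lfp,
    firing_quantile=0.8,         # keep the top 20% of cells by firing rate
    fs=1000.0,
    pop_names=["FSI", "LTS"],
    time_windows=trial_windows,  # omit for whole-signal analysis
    error_type="sem",
)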
 def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
     """
-    Plot an idealized gamma cycle with spike histograms for different neuron populations.
+    Plot an idealized cycle with spike histograms for different neuron populations.
 
     Parameters:
     -----------
@@ -120,7 +364,7 @@ def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
     plt.show()
 
 
-def plot_ppc_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=None):
+def plot_entrainment_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=None):
     """
     Plot PPC for all node populations on one graph with mean and standard error.
 
@@ -200,3 +444,422 @@ def plot_ppc_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=None):
     # Adjust layout and save
     plt.tight_layout()
     plt.show()
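A hedged sketch of calling the renamed function (editor's illustration; the nested `ppc_dict[pop][node][freq]` layout is inferred from the swarm-plot docstring below):

import numpy as np

freqs = np.arange(10, 101, 10)
# Hypothetical entrainment values keyed as ppc_dict[population][node][frequency]
ppc_dict = {
    "FSI": {0: {f: 0.10 for f in freqs}, 1: {f: 0.15 for f in freqs}},
    "LTS": {2: {f: 0.05 for f in freqs}, 3: {f: 0.08 for f in freqs}},
}
plot_entrainment_by_population(ppc_dict, ["FSI", "LTS"], freqs)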
+
+
+def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title=None):
+    """
+    Plot a swarm plot of the entrainment for different populations at a single frequency.
+
+    Parameters:
+    -----------
+    ppc_dict : dict
+        Dictionary containing PPC values organized by population, node, and frequency
+    pop_names : list
+        List of population names to include in the plot
+    freq : float or int
+        The specific frequency to plot
+    save_path : str, optional
+        Path to save the figure. If None, the figure is just displayed.
+    title : str, optional
+        Title for the plot
+
+    Returns:
+    --------
+    None
+        Displays the plot (and saves it if save_path is given); returns early if no data is available
+    """
+    # Set the style
+    sns.set_style("whitegrid")
+
+    # Prepare data for the swarm plot
+    data_list = []
+
+    for pop in pop_names:
+        for node in ppc_dict[pop]:
+            if freq in ppc_dict[pop][node] and ppc_dict[pop][node][freq] is not None:
+                data_list.append(
+                    {"Population": pop, "Node": node, "PPC Difference": ppc_dict[pop][node][freq]}
+                )
+
+    # Create DataFrame in long format
+    df = pd.DataFrame(data_list)
+
+    if df.empty:
+        print(f"No data available for frequency {freq}.")
+        return None
+
+    # Print mean PPC change for each population
+    for pop in pop_names:
+        subset = df[df["Population"] == pop]
+        if not subset.empty:
+            mean_val = subset["PPC Difference"].mean()
+            std_val = subset["PPC Difference"].std()
+            n = len(subset)
+            sem_val = std_val / np.sqrt(n)  # Standard error of the mean
+            print(f"{pop}: {mean_val:.4f} ± {sem_val:.4f} (n={n})")
+
+    # Create figure
+    plt.figure(figsize=(max(8, len(pop_names) * 1.5), 8))
+
+    # Create swarm plot
+    sns.swarmplot(
+        x="Population",
+        y="PPC Difference",
+        data=df,
+        size=3,
+    )
+
+    # Add sample size annotations
+    for i, pop in enumerate(pop_names):
+        subset = df[df["Population"] == pop]
+        if not subset.empty:
+            n = len(subset)
+            y_min = subset["PPC Difference"].min()
+            y_max = subset["PPC Difference"].max()
+
+            # Position annotation below the lowest point
+            plt.annotate(
+                f"n={n}", (i, y_min - 0.05 * (y_max - y_min) - 0.05), ha="center", fontsize=10
+            )
+
+    # Add reference line at y=0
+    plt.axhline(y=0, color="black", linestyle="-", linewidth=0.5, alpha=0.7)
+
+    # Add horizontal lines for population means
+    for i, pop in enumerate(pop_names):
+        subset = df[df["Population"] == pop]
+        if not subset.empty:
+            mean_val = subset["PPC Difference"].mean()
+            plt.plot([i - 0.25, i + 0.25], [mean_val, mean_val], "r-", linewidth=2)
+
+    # Calculate and display statistics
+    if len(pop_names) > 1:
+        # Print statistical test results
+        print(f"\nMann-Whitney U Test Results at {freq} Hz:")
+        print("-" * 60)
+
+        # Vertical extent used to position the significance bars
+        y_max = df["PPC Difference"].max()
+        y_min = df["PPC Difference"].min()
+        y_range = y_max - y_min
+
+        # Perform pairwise Mann-Whitney U tests between populations
+        for i in range(len(pop_names)):
+            for j in range(i + 1, len(pop_names)):
+                pop1 = pop_names[i]
+                pop2 = pop_names[j]
+
+                vals1 = df[df["Population"] == pop1]["PPC Difference"].values
+                vals2 = df[df["Population"] == pop2]["PPC Difference"].values
+
+                if len(vals1) > 1 and len(vals2) > 1:
+                    # Mann-Whitney U test (non-parametric)
+                    u_stat, p_val = stats.mannwhitneyu(vals1, vals2, alternative="two-sided")
+
+                    # Choose significance marker
+                    sig_str = "ns"
+                    if p_val < 0.05:
+                        sig_str = "*"
+                    if p_val < 0.01:
+                        sig_str = "**"
+                    if p_val < 0.001:
+                        sig_str = "***"
+
+                    # Position the significance bar
+                    bar_height = y_max + 0.1 * y_range * (1 + (j - i - 1) * 0.5)
+
+                    # Draw the bar
+                    plt.plot([i, j], [bar_height, bar_height], "k-")
+                    plt.plot([i, i], [bar_height - 0.02 * y_range, bar_height], "k-")
+                    plt.plot([j, j], [bar_height - 0.02 * y_range, bar_height], "k-")
+
+                    # Add significance marker
+                    plt.text(
+                        (i + j) / 2,
+                        bar_height + 0.01 * y_range,
+                        sig_str,
+                        ha="center",
+                        va="bottom",
+                        fontsize=12,
+                    )
+
+                    # Print the statistical comparison
+                    print(f"{pop1} vs {pop2}: U={u_stat:.1f}, p={p_val:.4f} {sig_str}")
+
+    # Add labels and title
+    plt.xlabel("Population", fontsize=14)
+    plt.ylabel("PPC", fontsize=14)
+    if title:
+        plt.title(title, fontsize=16)
+
+    # Adjust y-axis limits to make room for annotations
+    y_min, y_max = plt.ylim()
+    plt.ylim(y_min - 0.15 * (y_max - y_min), y_max + 0.25 * (y_max - y_min))
+
+    # Add gridlines
+    plt.grid(True, linestyle="--", alpha=0.7, axis="y")
+
+    # Adjust layout
+    plt.tight_layout()
+
+    # Save figure if path is provided
+    if save_path:
+        plt.savefig(f"{save_path}/ppc_change_swarm_plot_{freq}Hz.png", dpi=300, bbox_inches="tight")
+
+    plt.show()
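Both trial-averaged plots in this module reduce per-trial values to a mean plus an error band selected by `error_type`. A standalone sketch of that arithmetic on synthetic numbers, assuming nothing beyond numpy and scipy:

import numpy as np
from scipy import stats

trial_vals = np.array([0.12, 0.18, 0.09, 0.15, 0.20])  # hypothetical per-trial values
mean_val = trial_vals.mean()
sem = stats.sem(trial_vals)                       # error_type="sem": mean ± SEM
t_crit = stats.t.ppf(0.975, len(trial_vals) - 1)  # two-tailed 95% t critical value
ci_half = t_crit * sem                            # error_type="ci": mean ± t * SEM
sd = trial_vals.std(ddof=1)                       # error_type="std": mean ± 1 SD
print(f"mean={mean_val:.3f}  ±SEM={sem:.3f}  ±95% CI={ci_half:.3f}  ±SD={sd:.3f}")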
+
+
+def plot_trial_avg_entrainment(
+    spike_df: pd.DataFrame,
+    lfp: xr.DataArray,
+    time_windows: List[Tuple[float, float]],
+    entrainment_method: str,
+    pop_names: List[str],
+    freqs: Union[List[float], np.ndarray],
+    firing_quantile: float,
+    spike_fs: float = 1000,
+    error_type: str = "ci",  # "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
+) -> None:
+    """
+    Plot trial-averaged entrainment for the specified populations. Currently only the wavelet
+    filter method is supported; other filters could easily be added.
+
+    Parameters:
+    -----------
+    spike_df : pd.DataFrame
+        Spike data containing timestamps, node_ids, and pop_name columns
+    lfp : xr.DataArray
+        LFP data for a single channel
+    time_windows : List[Tuple[float, float]]
+        List of (start_time, end_time) windows to analyze, one per trial
+    entrainment_method : str
+        Method for entrainment calculation ('ppc', 'ppc2', or 'plv')
+    pop_names : List[str]
+        List of population names to process (e.g., ['FSI', 'LTS'])
+    freqs : Union[List[float], np.ndarray]
+        Array of frequencies to analyze (Hz)
+    firing_quantile : float
+        Upper quantile threshold for selecting high-firing cells (e.g., 0.8 for top 20%)
+    spike_fs : float
+        Sampling frequency of the spike data. Default is 1000
+    error_type : str
+        Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
+
+    Raises:
+    -------
+    ValueError
+        If entrainment_method is not 'ppc', 'ppc2', or 'plv'
+        If error_type is not 'ci', 'sem', or 'std'
+
+    Returns:
+    --------
+    None
+        Displays the plot; prints warnings for populations with no spikes in a trial
+    """
+    sns.set_style("whitegrid")
+    # Validate inputs
+    if entrainment_method not in ["ppc", "plv", "ppc2"]:
+        raise ValueError("entrainment_method must be 'ppc', 'ppc2', or 'plv'")
+
+    if error_type not in ["ci", "sem", "std"]:
+        raise ValueError(
+            "error_type must be 'ci' for confidence interval, 'sem' for standard error, or 'std' for standard deviation"
+        )
+
+    if not (0 <= firing_quantile < 1):
+        raise ValueError("firing_quantile must be in the interval [0, 1)")
+
+    # Convert freqs to numpy array for easier indexing
+    freqs = np.array(freqs)
+
+    # Collect all PPC/PLV values across trials for each population
+    all_plv_data = {}  # Dictionary to store results for each population
+
+    # Initialize storage for each population
+    for pop_name in pop_names:
+        all_plv_data[pop_name] = []  # Will have shape (n_trials, n_freqs)
+
+    # Loop through all trials to collect data
+    for trial_idx in range(len(time_windows)):
+        plv_lists = {}  # Store PLV lists for this trial
+
+        # Initialize PLV lists for each population
+        for pop_name in pop_names:
+            plv_lists[pop_name] = []
+
+        # Filter spikes for this trial
+        network_spikes = spike_df[
+            (spike_df["timestamps"] >= time_windows[trial_idx][0])
+            & (spike_df["timestamps"] <= time_windows[trial_idx][1])
+        ].copy()
+
+        # Process each population
+        pop_spike_data = {}
+        for pop_name in pop_names:
+            # Get spikes for this population
+            pop_spikes = network_spikes[network_spikes["pop_name"] == pop_name]
+
+            if len(pop_spikes) == 0:
+                print(f"Warning: No spikes found for population {pop_name} in trial {trial_idx}")
+                # Add NaN values for this trial/population
+                plv_lists[pop_name] = [np.nan] * len(freqs)
+                continue
+
+            # Filter to get the top firing cells
+            # (a firing_quantile of 0.8 keeps the top 20% of firing cells)
+            pop_spikes = bmspikes.find_highest_firing_cells(
+                pop_spikes, upper_quantile=firing_quantile
+            )
+
+            if len(pop_spikes) == 0:
+                print(
+                    f"Warning: No high-firing spikes found for population {pop_name} in trial {trial_idx}"
+                )
+                plv_lists[pop_name] = [np.nan] * len(freqs)
+                continue
+
+            pop_spike_data[pop_name] = pop_spikes
+
+        # Calculate PPC/PLV for each frequency and each population
+        for freq in freqs:
+            for pop_name in pop_names:
+                if pop_name not in pop_spike_data:
+                    continue  # Skip if no data for this population
+
+                pop_spikes = pop_spike_data[pop_name]
+
+                try:
+                    if entrainment_method == "ppc":
+                        result = bmentr.calculate_ppc(
+                            pop_spikes["timestamps"].values,
+                            lfp,
+                            spike_fs=spike_fs,
+                            lfp_fs=lfp.fs,
+                            freq_of_interest=freq,
+                            filter_method="wavelet",
+                            ppc_method="gpu",
+                        )
+                    elif entrainment_method == "plv":
+                        result = bmentr.calculate_spike_lfp_plv(
+                            pop_spikes["timestamps"].values,
+                            lfp,
+                            spike_fs=spike_fs,
+                            lfp_fs=lfp.fs,
+                            freq_of_interest=freq,
+                            filter_method="wavelet",
+                        )
+                    elif entrainment_method == "ppc2":
+                        result = bmentr.calculate_ppc2(
+                            pop_spikes["timestamps"].values,
+                            lfp,
+                            spike_fs=spike_fs,
+                            lfp_fs=lfp.fs,
+                            freq_of_interest=freq,
+                            filter_method="wavelet",
+                        )
+
+                    plv_lists[pop_name].append(result)
+
+                except Exception as e:
+                    print(
+                        f"Warning: Error calculating {entrainment_method} for {pop_name} at {freq}Hz in trial {trial_idx}: {e}"
+                    )
+                    plv_lists[pop_name].append(np.nan)
+
+        # Store this trial's results for each population
+        for pop_name in pop_names:
+            if pop_name in plv_lists and len(plv_lists[pop_name]) == len(freqs):
+                all_plv_data[pop_name].append(plv_lists[pop_name])
+            else:
+                # Fill with NaNs if data is missing
+                all_plv_data[pop_name].append([np.nan] * len(freqs))
+
+    # Convert to numpy arrays and calculate statistics
+    mean_plv = {}
+    error_plv = {}
+
+    for pop_name in pop_names:
+        all_plv_data[pop_name] = np.array(all_plv_data[pop_name])  # Shape: (n_trials, n_freqs)
+
+        # Calculate statistics across trials, ignoring NaN values
+        with np.errstate(invalid="ignore"):  # Suppress warnings for all-NaN slices
+            mean_plv[pop_name] = np.nanmean(all_plv_data[pop_name], axis=0)
+
+            if error_type == "ci":
+                # Calculate 95% confidence intervals using the SEM
+                valid_counts = np.sum(~np.isnan(all_plv_data[pop_name]), axis=0)
+                sem_plv = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1) / np.sqrt(valid_counts)
+
+                # For the 95% CI, multiply the SEM by the appropriate t-value;
+                # use the minimum valid count across frequencies for a conservative t-value
+                min_valid_trials = np.min(valid_counts[valid_counts > 1])  # Avoid division by zero
+                if min_valid_trials > 1:
+                    t_value = stats.t.ppf(0.975, min_valid_trials - 1)  # 95% CI, two-tailed
+                    error_plv[pop_name] = t_value * sem_plv
+                else:
+                    error_plv[pop_name] = np.full_like(sem_plv, np.nan)
+
+            elif error_type == "sem":
+                # Calculate standard error of the mean
+                valid_counts = np.sum(~np.isnan(all_plv_data[pop_name]), axis=0)
+                error_plv[pop_name] = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1) / np.sqrt(
+                    valid_counts
+                )
+
+            elif error_type == "std":
+                # Calculate standard deviation
+                error_plv[pop_name] = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1)
+
+    # Create the combined plot
+    plt.figure(figsize=(12, 8))
+
+    # Define markers and colors for different populations
+    markers = ["o-", "s-", "^-", "D-", "v-", "<-", ">-", "p-"]
+    colors = sns.color_palette(n_colors=len(pop_names))
+
+    # Plot each population
+    for i, pop_name in enumerate(pop_names):
+        marker = markers[i % len(markers)]  # Cycle through markers if more populations than markers
+        color = colors[i]
+
+        # Only plot if we have valid data
+        valid_mask = ~np.isnan(mean_plv[pop_name])
+        if np.any(valid_mask):
+            plt.plot(
+                freqs[valid_mask],
+                mean_plv[pop_name][valid_mask],
+                marker,
+                linewidth=2,
+                label=pop_name,
+                color=color,
+                markersize=6,
+            )
+
+            # Add error shading if available
+            if not np.all(np.isnan(error_plv[pop_name])):
+                plt.fill_between(
+                    freqs[valid_mask],
+                    (mean_plv[pop_name] - error_plv[pop_name])[valid_mask],
+                    (mean_plv[pop_name] + error_plv[pop_name])[valid_mask],
+                    alpha=0.3,
+                    color=color,
+                )
+
+    plt.xlabel("Frequency (Hz)", fontsize=12)
+    plt.ylabel(f"{entrainment_method.upper()}", fontsize=12)
+
+    # Calculate percentage for the title and label the error bands by type
+    firing_percentage = round(float((1 - firing_quantile) * 100), 1)
+    error_labels = {"ci": "95% CI", "sem": "±SEM", "std": "±1 SD"}
+    error_label = error_labels[error_type]
+    plt.title(
+        f"{entrainment_method.upper()} Across Trials for Top {firing_percentage}% Firing Cells ({error_label})",
+        fontsize=14,
+    )
+
+    plt.legend(fontsize=10)
+    plt.grid(True, alpha=0.3)
+    plt.tight_layout()
+    plt.show()
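
A hedged usage sketch for the new trial-averaged entrainment plot (editor's illustration, not from the package; `spikes` and `lfp` are hypothetical inputs, with `lfp` carrying an `fs` attribute since the implementation reads `lfp.fs`):

import numpy as np

trial_windows = [(0.0, 1000.0), (1500.0, 2500.0), (3000.0, 4000.0)]
plot_trial_avg_entrainment(
    spike_df=spikes,            # hypothetical spike DataFrame
    lfp=lfp,                    # hypothetical LFP channel (xr.DataArray with .fs)
    time_windows=trial_windows,
    entrainment_method="ppc2",  # or "ppc" / "plv"
    pop_names=["FSI", "LTS"],
    freqs=np.arange(10, 101, 10),
    firing_quantile=0.8,
    error_type="ci",
)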