bmtool 0.7.1.7__py3-none-any.whl → 0.7.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +1 -81
- bmtool/analysis/spikes.py +55 -6
- bmtool/bmplot/connections.py +255 -153
- bmtool/bmplot/entrainment.py +525 -28
- bmtool/bmplot/spikes.py +118 -5
- bmtool/synapses.py +3 -3
- bmtool/util/util.py +3 -0
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/METADATA +1 -1
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/RECORD +13 -13
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/WHEEL +1 -1
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.1.7.dist-info → bmtool-0.7.2.1.dist-info}/top_level.txt +0 -0
bmtool/bmplot/entrainment.py
CHANGED
@@ -1,56 +1,299 @@
|
|
1
|
+
from typing import List, Tuple, Union
|
2
|
+
|
1
3
|
import matplotlib.pyplot as plt
|
2
4
|
import numpy as np
|
3
5
|
import pandas as pd
|
4
6
|
import seaborn as sns
|
7
|
+
import xarray as xr
|
5
8
|
from matplotlib.gridspec import GridSpec
|
6
9
|
from scipy import stats
|
7
10
|
|
8
|
-
|
9
|
-
|
11
|
+
from bmtool.analysis import entrainment as bmentr
|
12
|
+
from bmtool.analysis import spikes as bmspikes
|
13
|
+
from bmtool.analysis.lfp import get_lfp_power
|
14
|
+
|
15
|
+
|
16
|
+
def _correlation_error_bounds(trial_correlations, mean_corr, error_type):
    """Return ``(lower, upper)`` error bounds around ``mean_corr``.

    With fewer than two trial correlations no spread can be estimated, so
    both bounds collapse onto the mean.
    """
    n_trials = len(trial_correlations)
    if n_trials <= 1:
        return mean_corr, mean_corr

    if error_type == "ci":
        # 95% two-tailed confidence interval from the t-distribution.
        half_width = stats.t.ppf(0.975, n_trials - 1) * stats.sem(trial_correlations)
    elif error_type == "sem":
        half_width = stats.sem(trial_correlations)
    else:  # "std" (error_type is validated by the caller)
        half_width = np.std(trial_correlations, ddof=1)

    return mean_corr - half_width, mean_corr + half_width


def plot_spike_power_correlation(
    spike_df: pd.DataFrame,
    lfp_data: "xr.DataArray",
    firing_quantile: float,
    fs: float,
    pop_names: list,
    filter_method: str = "wavelet",
    bandwidth: float = 2.0,
    lowcut: float = None,
    highcut: float = None,
    freq_range: tuple = (10, 100),
    freq_step: float = 5,
    type_name: str = "raw",
    time_windows: list = None,
    error_type: str = "ci",
):
    """
    Calculate and plot correlation between population spike rates and LFP power across frequencies.
    Supports both single-signal and trial-based analysis with error bars.

    Parameters
    ----------
    spike_df : pd.DataFrame
        Spike data. A "timestamps" column is used to select the spikes of each
        trial window; the frame is otherwise handed to the bmtool spike helpers.
    lfp_data : xr.DataArray
        LFP data
    firing_quantile : float
        Value in [0, 1) passed as ``upper_quantile`` when selecting the highest
        firing cells. The plot title reports ``(1 - firing_quantile) * 100`` as
        the "top %" of cells.
    fs : float
        Sampling frequency
    pop_names : list
        List of population names to analyze
    filter_method : str, optional
        Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
    bandwidth : float, optional
        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
    lowcut : float, optional
        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    highcut : float, optional
        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    freq_range : tuple, optional
        Min and max frequency to analyze (default: (10, 100))
    freq_step : float, optional
        Step size for frequency analysis (default: 5)
    type_name : str, optional
        Which entry of the spike rate 'type' dimension to use (default: 'raw')
    time_windows : list, optional
        List of (start, end) time tuples for trial-based analysis. If None, analyze entire signal
    error_type : str, optional
        Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation

    Raises
    ------
    ValueError
        If firing_quantile is outside [0, 1), error_type is not one of
        'ci', 'sem', 'std', freq_step is not positive, freq_range yields no
        frequencies, or time_windows is given but empty.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # Validate everything up front so bad input fails before any heavy work.
    if not (0 <= firing_quantile < 1):
        raise ValueError("firing_quantile must be between 0 and 1")

    if error_type not in ["ci", "sem", "std"]:
        raise ValueError(
            "error_type must be 'ci' for confidence interval, 'sem' for standard error, or 'std' for standard deviation"
        )

    if freq_step <= 0:
        raise ValueError("freq_step must be positive")

    # Setup frequencies for analysis (upper bound inclusive for integer steps)
    frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
    if len(frequencies) == 0:
        raise ValueError("freq_range must contain at least one frequency")

    is_trial_based = time_windows is not None
    if is_trial_based and len(time_windows) == 0:
        raise ValueError("time_windows must contain at least one (start, end) window")

    # Convert spike_df to spike rate with trial-based filtering of high firing cells
    if is_trial_based:
        trial_rates = []
        for start_time, end_time in time_windows:
            # Spikes falling inside this trial (both ends inclusive)
            trial_spikes = spike_df[
                (spike_df["timestamps"] >= start_time) & (spike_df["timestamps"] <= end_time)
            ].copy()

            # High firing cells are selected per trial, not over the whole recording
            trial_spikes = bmspikes.find_highest_firing_cells(
                trial_spikes, upper_quantile=firing_quantile
            )
            trial_rates.append(
                bmspikes.get_population_spike_rate(
                    trial_spikes, fs=fs, t_start=start_time, t_stop=end_time
                )
            )

        # Stack the per-trial rates along a new "trial" dimension
        spike_rate = xr.concat(trial_rates, dim="trial")
    else:
        spike_df = bmspikes.find_highest_firing_cells(spike_df, upper_quantile=firing_quantile)
        # NOTE(review): unlike the trial branch, `fs` is not forwarded here, so the
        # helper's default sampling rate is used -- confirm this is intended.
        spike_rate = bmspikes.get_population_spike_rate(spike_df)

    # Pre-calculate LFP power once per frequency; it is shared by all populations
    power_by_freq = {}
    for freq in frequencies:
        power_by_freq[freq] = get_lfp_power(
            lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
        )

    # Calculate correlations
    results = {}
    for pop in pop_names:
        pop_spike_rate = spike_rate.sel(population=pop, type=type_name)
        results[pop] = {}

        for freq in frequencies:
            lfp_power = power_by_freq[freq]

            if not is_trial_based:
                # Single signal analysis: both series must be sample-aligned
                if len(pop_spike_rate) != len(lfp_power):
                    print(f"Warning: Length mismatch for {pop} at {freq} Hz")
                    continue

                corr, p_val = stats.spearmanr(pop_spike_rate, lfp_power)
                results[pop][freq] = {
                    "correlation": corr,
                    "p_value": p_val,
                }
                continue

            # Trial-based analysis using pre-filtered trial rates
            trial_correlations = []
            for trial_idx, (start_time, end_time) in enumerate(time_windows):
                trial_spike_rate = pop_spike_rate.sel(trial=trial_idx)
                trial_lfp_power = lfp_power.sel(time=slice(start_time, end_time))

                # Correlate only over time points present in both signals
                common_times = np.intersect1d(trial_spike_rate.time, trial_lfp_power.time)
                if len(common_times) == 0:
                    continue

                trial_sr = trial_spike_rate.sel(time=common_times).values
                trial_lfp = trial_lfp_power.sel(time=common_times).values

                # Need at least 2 points for correlation
                if len(trial_sr) > 1 and len(trial_lfp) > 1:
                    corr, _ = stats.spearmanr(trial_sr, trial_lfp)
                    if not np.isnan(corr):
                        trial_correlations.append(corr)

            if len(trial_correlations) > 0:
                trial_correlations = np.array(trial_correlations)
                mean_corr = np.mean(trial_correlations)
                error_lower, error_upper = _correlation_error_bounds(
                    trial_correlations, mean_corr, error_type
                )
                results[pop][freq] = {
                    "correlation": mean_corr,
                    "error_lower": error_lower,
                    "error_upper": error_upper,
                    "n_trials": len(trial_correlations),
                    "trial_correlations": trial_correlations,
                }
            else:
                # No valid trials
                results[pop][freq] = {
                    "correlation": np.nan,
                    "error_lower": np.nan,
                    "error_upper": np.nan,
                    "n_trials": 0,
                    "trial_correlations": np.array([]),
                }

    # Plotting
    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 8))

    # One colormap lookup serves both the curves and the legend handles
    colors = plt.get_cmap("tab10")
    error_labels = {"ci": "95% CI", "sem": "±SEM", "std": "±1 SD"}
    error_label = error_labels[error_type]

    for i, pop in enumerate(pop_names):
        # Keep only frequencies that produced a finite correlation
        plot_freqs = []
        plot_corrs = []
        plot_ci_lower = []
        plot_ci_upper = []

        for freq in frequencies:
            if freq in results[pop] and not np.isnan(results[pop][freq]["correlation"]):
                plot_freqs.append(freq)
                plot_corrs.append(results[pop][freq]["correlation"])

                if is_trial_based:
                    plot_ci_lower.append(results[pop][freq]["error_lower"])
                    plot_ci_upper.append(results[pop][freq]["error_upper"])

        if len(plot_freqs) == 0:
            continue

        plot_freqs = np.array(plot_freqs)
        plot_corrs = np.array(plot_corrs)
        color = colors(i)

        # Plot main line
        plt.plot(
            plot_freqs, plot_corrs, marker="o", label=pop, linewidth=2, markersize=6, color=color
        )

        # Plot error bands for trial-based analysis
        if is_trial_based and len(plot_ci_lower) > 0:
            plt.fill_between(
                plot_freqs,
                np.array(plot_ci_lower),
                np.array(plot_ci_upper),
                alpha=0.2,
                color=color,
            )

    # Formatting
    plt.xlabel("Frequency (Hz)", fontsize=12)
    plt.ylabel("Spike Rate-Power Correlation", fontsize=12)

    # Calculate percentage for title
    firing_percentage = round(float((1 - firing_quantile) * 100), 1)
    if is_trial_based:
        # The title names the error measure that was actually plotted
        title = (
            f"Trial-averaged Spike Rate-LFP Power Correlation\n"
            f"Top {firing_percentage}% Firing Cells ({error_label})"
        )
    else:
        title = f"Spike Rate-LFP Power Correlation\nTop {firing_percentage}% Firing Cells"

    plt.title(title, fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.axhline(y=0, color="gray", linestyle="-", alpha=0.5)

    # Legend: one handle per requested population
    from matplotlib.lines import Line2D

    legend_elements = [
        Line2D([0], [0], color=colors(i), marker="o", linestyle="-", label=pop)
        for i, pop in enumerate(pop_names)
    ]

    # Add error band legend element for trial-based analysis
    if is_trial_based:
        legend_elements.append(
            Line2D([0], [0], color="gray", alpha=0.3, linewidth=10, label=error_label)
        )

    plt.legend(handles=legend_elements, fontsize=10, loc="best")

    # Axis formatting: thin out tick labels when there are many frequencies
    if len(frequencies) > 10:
        plt.xticks(frequencies[::2])
    else:
        plt.xticks(frequencies)
    plt.xlim(frequencies[0], frequencies[-1])

    # Always keep at least [-0.1, 0.1] visible so the zero line is in view
    y_min, y_max = plt.ylim()
    plt.ylim(min(y_min, -0.1), max(y_max, 0.1))

    plt.tight_layout()
    plt.show()
|
55
298
|
|
56
299
|
|
@@ -366,3 +609,257 @@ def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title
|
|
366
609
|
plt.savefig(f"{save_path}/ppc_change_swarm_plot_{freq}Hz.png", dpi=300, bbox_inches="tight")
|
367
610
|
|
368
611
|
plt.show()
|
612
|
+
|
613
|
+
|
614
|
+
def plot_trial_avg_entrainment(
|
615
|
+
spike_df: pd.DataFrame,
|
616
|
+
lfp: np.ndarray,
|
617
|
+
time_windows: List[Tuple[float, float]],
|
618
|
+
entrainment_method: str,
|
619
|
+
pop_names: List[str],
|
620
|
+
freqs: Union[List[float], np.ndarray],
|
621
|
+
firing_quantile: float,
|
622
|
+
spike_fs: float = 1000,
|
623
|
+
error_type: str = "ci", # New parameter: "ci" for confidence interval, "sem" for standard error, "std" for standard deviation
|
624
|
+
) -> None:
|
625
|
+
"""
|
626
|
+
Plot trial-averaged entrainment for specified population names. Only supports wavelet filter current, could easily add other support
|
627
|
+
|
628
|
+
Parameters:
|
629
|
+
-----------
|
630
|
+
spike_df : pd.DataFrame
|
631
|
+
Spike data containing timestamps, node_ids, and pop_name columns
|
632
|
+
spike_fs : float
|
633
|
+
fs for spike data. Default is 1000
|
634
|
+
lfp : xarray
|
635
|
+
Xarray for a channel of the lfp data
|
636
|
+
time_windows : List[Tuple[float, float]]
|
637
|
+
List of windows to analysis with start and stp time [(start_time, end_time), ...] for each trial
|
638
|
+
entrainment_method : str
|
639
|
+
Method for entrainment calculation ('ppc', 'ppc2' or 'plv')
|
640
|
+
pop_names : List[str]
|
641
|
+
List of population names to process (e.g., ['FSI', 'LTS'])
|
642
|
+
freqs : Union[List[float], np.ndarray]
|
643
|
+
Array of frequencies to analyze (Hz)
|
644
|
+
firing_quantile : float
|
645
|
+
Upper quantile threshold for selecting high-firing cells (e.g., 0.8 for top 20%)
|
646
|
+
error_type : str
|
647
|
+
Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
|
648
|
+
|
649
|
+
Raises:
|
650
|
+
-------
|
651
|
+
ValueError
|
652
|
+
If entrainment_method is not 'ppc', 'ppc2' or 'plv'
|
653
|
+
If error_type is not 'ci', 'sem', or 'std'
|
654
|
+
If no spikes found for a population in a trial
|
655
|
+
|
656
|
+
Returns:
|
657
|
+
--------
|
658
|
+
None
|
659
|
+
Displays plot and prints summary statistics
|
660
|
+
"""
|
661
|
+
sns.set_style("whitegrid")
|
662
|
+
# Validate inputs
|
663
|
+
if entrainment_method not in ["ppc", "plv", "ppc2"]:
|
664
|
+
raise ValueError("entrainment_method must be 'ppc', ppc2 or 'plv'")
|
665
|
+
|
666
|
+
if error_type not in ["ci", "sem", "std"]:
|
667
|
+
raise ValueError(
|
668
|
+
"error_type must be 'ci' for confidence interval, 'sem' for standard error, or 'std' for standard deviation"
|
669
|
+
)
|
670
|
+
|
671
|
+
if not (0 <= firing_quantile < 1):
|
672
|
+
raise ValueError("firing_quantile must be between 0 and 1")
|
673
|
+
|
674
|
+
# Convert freqs to numpy array for easier indexing
|
675
|
+
freqs = np.array(freqs)
|
676
|
+
|
677
|
+
# Collect all PPC/PLV values across trials for each population
|
678
|
+
all_plv_data = {} # Dictionary to store results for each population
|
679
|
+
|
680
|
+
# Initialize storage for each population
|
681
|
+
for pop_name in pop_names:
|
682
|
+
all_plv_data[pop_name] = [] # Will be shape (n_trials, n_freqs)
|
683
|
+
|
684
|
+
# Loop through all pulse groups to collect data
|
685
|
+
for trial_idx in range(len(time_windows)):
|
686
|
+
plv_lists = {} # Store PLV lists for this trial
|
687
|
+
|
688
|
+
# Initialize PLV lists for each population
|
689
|
+
for pop_name in pop_names:
|
690
|
+
plv_lists[pop_name] = []
|
691
|
+
|
692
|
+
# Filter spikes for this trial
|
693
|
+
network_spikes = spike_df[
|
694
|
+
(spike_df["timestamps"] >= time_windows[trial_idx][0])
|
695
|
+
& (spike_df["timestamps"] <= time_windows[trial_idx][1])
|
696
|
+
].copy()
|
697
|
+
|
698
|
+
# Process each population
|
699
|
+
pop_spike_data = {}
|
700
|
+
for pop_name in pop_names:
|
701
|
+
# Get spikes for this population
|
702
|
+
pop_spikes = network_spikes[network_spikes["pop_name"] == pop_name]
|
703
|
+
|
704
|
+
if len(pop_spikes) == 0:
|
705
|
+
print(f"Warning: No spikes found for population {pop_name} in trial {trial_idx}")
|
706
|
+
# Add NaN values for this trial/population
|
707
|
+
plv_lists[pop_name] = [np.nan] * len(freqs)
|
708
|
+
continue
|
709
|
+
|
710
|
+
# Filter to get the top firing cells
|
711
|
+
# firing_quantile of 0.8 gets the top 20% of firing cells to use
|
712
|
+
pop_spikes = bmspikes.find_highest_firing_cells(
|
713
|
+
pop_spikes, upper_quantile=firing_quantile
|
714
|
+
)
|
715
|
+
|
716
|
+
if len(pop_spikes) == 0:
|
717
|
+
print(
|
718
|
+
f"Warning: No high-firing spikes found for population {pop_name} in trial {trial_idx}"
|
719
|
+
)
|
720
|
+
plv_lists[pop_name] = [np.nan] * len(freqs)
|
721
|
+
continue
|
722
|
+
|
723
|
+
pop_spike_data[pop_name] = pop_spikes
|
724
|
+
|
725
|
+
# Calculate PPC/PLV for each frequency and each population
|
726
|
+
for freq_idx, freq in enumerate(freqs):
|
727
|
+
for pop_name in pop_names:
|
728
|
+
if pop_name not in pop_spike_data:
|
729
|
+
continue # Skip if no data for this population
|
730
|
+
|
731
|
+
pop_spikes = pop_spike_data[pop_name]
|
732
|
+
|
733
|
+
try:
|
734
|
+
if entrainment_method == "ppc":
|
735
|
+
result = bmentr.calculate_ppc(
|
736
|
+
pop_spikes["timestamps"].values,
|
737
|
+
lfp,
|
738
|
+
spike_fs=spike_fs,
|
739
|
+
lfp_fs=lfp.fs,
|
740
|
+
freq_of_interest=freq,
|
741
|
+
filter_method="wavelet",
|
742
|
+
ppc_method="gpu",
|
743
|
+
)
|
744
|
+
elif entrainment_method == "plv":
|
745
|
+
result = bmentr.calculate_spike_lfp_plv(
|
746
|
+
pop_spikes["timestamps"].values,
|
747
|
+
lfp,
|
748
|
+
spike_fs=spike_fs,
|
749
|
+
lfp_fs=lfp.fs,
|
750
|
+
freq_of_interest=freq,
|
751
|
+
filter_method="wavelet",
|
752
|
+
)
|
753
|
+
elif entrainment_method == "ppc2":
|
754
|
+
result = bmentr.calculate_ppc2(
|
755
|
+
pop_spikes["timestamps"].values,
|
756
|
+
lfp,
|
757
|
+
spike_fs=spike_fs,
|
758
|
+
lfp_fs=lfp.fs,
|
759
|
+
freq_of_interest=freq,
|
760
|
+
filter_method="wavelet",
|
761
|
+
)
|
762
|
+
|
763
|
+
plv_lists[pop_name].append(result)
|
764
|
+
|
765
|
+
except Exception as e:
|
766
|
+
print(
|
767
|
+
f"Warning: Error calculating {entrainment_method} for {pop_name} at {freq}Hz in trial {trial_idx}: {e}"
|
768
|
+
)
|
769
|
+
plv_lists[pop_name].append(np.nan)
|
770
|
+
|
771
|
+
# Store this trial's results for each population
|
772
|
+
for pop_name in pop_names:
|
773
|
+
if pop_name in plv_lists and len(plv_lists[pop_name]) == len(freqs):
|
774
|
+
all_plv_data[pop_name].append(plv_lists[pop_name])
|
775
|
+
else:
|
776
|
+
# Fill with NaNs if data is missing
|
777
|
+
all_plv_data[pop_name].append([np.nan] * len(freqs))
|
778
|
+
|
779
|
+
# Convert to numpy arrays and calculate statistics
|
780
|
+
mean_plv = {}
|
781
|
+
error_plv = {}
|
782
|
+
|
783
|
+
for pop_name in pop_names:
|
784
|
+
all_plv_data[pop_name] = np.array(all_plv_data[pop_name]) # Shape: (n_trials, n_freqs)
|
785
|
+
|
786
|
+
# Calculate statistics across trials, ignoring NaN values
|
787
|
+
with np.errstate(invalid="ignore"): # Suppress warnings for all-NaN slices
|
788
|
+
mean_plv[pop_name] = np.nanmean(all_plv_data[pop_name], axis=0)
|
789
|
+
|
790
|
+
if error_type == "ci":
|
791
|
+
# Calculate 95% confidence intervals using SEM
|
792
|
+
valid_counts = np.sum(~np.isnan(all_plv_data[pop_name]), axis=0)
|
793
|
+
sem_plv = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1) / np.sqrt(valid_counts)
|
794
|
+
|
795
|
+
# For 95% CI, multiply SEM by appropriate t-value
|
796
|
+
# Use minimum valid count across frequencies for conservative t-value
|
797
|
+
min_valid_trials = np.min(valid_counts[valid_counts > 1]) # Avoid division by zero
|
798
|
+
if min_valid_trials > 1:
|
799
|
+
t_value = stats.t.ppf(0.975, min_valid_trials - 1) # 95% CI, two-tailed
|
800
|
+
error_plv[pop_name] = t_value * sem_plv
|
801
|
+
else:
|
802
|
+
error_plv[pop_name] = np.full_like(sem_plv, np.nan)
|
803
|
+
|
804
|
+
elif error_type == "sem":
|
805
|
+
# Calculate standard error of the mean
|
806
|
+
valid_counts = np.sum(~np.isnan(all_plv_data[pop_name]), axis=0)
|
807
|
+
error_plv[pop_name] = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1) / np.sqrt(
|
808
|
+
valid_counts
|
809
|
+
)
|
810
|
+
|
811
|
+
elif error_type == "std":
|
812
|
+
# Calculate standard deviation
|
813
|
+
error_plv[pop_name] = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1)
|
814
|
+
|
815
|
+
# Create the combined plot
|
816
|
+
plt.figure(figsize=(12, 8))
|
817
|
+
|
818
|
+
# Define markers and colors for different populations
|
819
|
+
markers = ["o-", "s-", "^-", "D-", "v-", "<-", ">-", "p-"]
|
820
|
+
colors = sns.color_palette(n_colors=len(pop_names))
|
821
|
+
|
822
|
+
# Plot each population
|
823
|
+
for i, pop_name in enumerate(pop_names):
|
824
|
+
marker = markers[i % len(markers)] # Cycle through markers if more populations than markers
|
825
|
+
color = colors[i]
|
826
|
+
|
827
|
+
# Only plot if we have valid data
|
828
|
+
valid_mask = ~np.isnan(mean_plv[pop_name])
|
829
|
+
if np.any(valid_mask):
|
830
|
+
plt.plot(
|
831
|
+
freqs[valid_mask],
|
832
|
+
mean_plv[pop_name][valid_mask],
|
833
|
+
marker,
|
834
|
+
linewidth=2,
|
835
|
+
label=pop_name,
|
836
|
+
color=color,
|
837
|
+
markersize=6,
|
838
|
+
)
|
839
|
+
|
840
|
+
# Add error bars/shading if available
|
841
|
+
if not np.all(np.isnan(error_plv[pop_name])):
|
842
|
+
plt.fill_between(
|
843
|
+
freqs[valid_mask],
|
844
|
+
(mean_plv[pop_name] - error_plv[pop_name])[valid_mask],
|
845
|
+
(mean_plv[pop_name] + error_plv[pop_name])[valid_mask],
|
846
|
+
alpha=0.3,
|
847
|
+
color=color,
|
848
|
+
)
|
849
|
+
|
850
|
+
plt.xlabel("Frequency (Hz)", fontsize=12)
|
851
|
+
plt.ylabel(f"{entrainment_method.upper()}", fontsize=12)
|
852
|
+
|
853
|
+
# Calculate percentage for title and update title based on error type
|
854
|
+
firing_percentage = round(float((1 - firing_quantile) * 100), 1)
|
855
|
+
error_labels = {"ci": "95% CI", "sem": "±SEM", "std": "±1 SD"}
|
856
|
+
error_label = error_labels[error_type]
|
857
|
+
plt.title(
|
858
|
+
f"{entrainment_method.upper()} Across Trials for Top {firing_percentage}% Firing Cells ({error_label})",
|
859
|
+
fontsize=14,
|
860
|
+
)
|
861
|
+
|
862
|
+
plt.legend(fontsize=10)
|
863
|
+
plt.grid(True, alpha=0.3)
|
864
|
+
plt.tight_layout()
|
865
|
+
plt.show()
|