bmtool 0.7.1.1__py3-none-any.whl → 0.7.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +166 -116
- bmtool/analysis/spikes.py +55 -0
- bmtool/bmplot/entrainment.py +152 -0
- {bmtool-0.7.1.1.dist-info → bmtool-0.7.1.3.dist-info}/METADATA +1 -2
- {bmtool-0.7.1.1.dist-info → bmtool-0.7.1.3.dist-info}/RECORD +9 -9
- {bmtool-0.7.1.1.dist-info → bmtool-0.7.1.3.dist-info}/WHEEL +0 -0
- {bmtool-0.7.1.1.dist-info → bmtool-0.7.1.3.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.1.1.dist-info → bmtool-0.7.1.3.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.1.1.dist-info → bmtool-0.7.1.3.dist-info}/top_level.txt +0 -0
bmtool/analysis/entrainment.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
Module for entrainment analysis
|
3
3
|
"""
|
4
4
|
|
5
|
-
from typing import Dict, List
|
5
|
+
from typing import Dict, List, Optional, Union
|
6
6
|
|
7
7
|
import numba
|
8
8
|
import numpy as np
|
@@ -41,7 +41,7 @@ def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.
|
|
41
41
|
(timestamps >= lfp.time.values[0]) & (timestamps <= lfp.time.values[-1])
|
42
42
|
].copy()
|
43
43
|
# set the time axis of the spikes to match the lfp
|
44
|
-
timestamps = timestamps - lfp.time.values[0]
|
44
|
+
# timestamps = timestamps - lfp.time.values[0]
|
45
45
|
return timestamps
|
46
46
|
|
47
47
|
|
@@ -127,33 +127,33 @@ def calculate_signal_signal_plv(
|
|
127
127
|
return plv
|
128
128
|
|
129
129
|
|
130
|
-
def
|
131
|
-
spike_times: np.ndarray
|
132
|
-
lfp_data
|
133
|
-
spike_fs: float
|
134
|
-
lfp_fs: float
|
135
|
-
filter_method: str = "
|
136
|
-
freq_of_interest: float = None,
|
137
|
-
lowcut: float = None,
|
138
|
-
highcut: float = None,
|
130
|
+
def _get_spike_phases(
|
131
|
+
spike_times: np.ndarray,
|
132
|
+
lfp_data: Union[np.ndarray, xr.DataArray],
|
133
|
+
spike_fs: float,
|
134
|
+
lfp_fs: float,
|
135
|
+
filter_method: str = "wavelet",
|
136
|
+
freq_of_interest: Optional[float] = None,
|
137
|
+
lowcut: Optional[float] = None,
|
138
|
+
highcut: Optional[float] = None,
|
139
139
|
bandwidth: float = 2.0,
|
140
|
-
filtered_lfp_phase: np.ndarray = None,
|
141
|
-
) ->
|
140
|
+
filtered_lfp_phase: Optional[Union[np.ndarray, xr.DataArray]] = None,
|
141
|
+
) -> np.ndarray:
|
142
142
|
"""
|
143
|
-
|
143
|
+
Helper function to get spike phases from LFP data.
|
144
144
|
|
145
145
|
Parameters
|
146
146
|
----------
|
147
147
|
spike_times : np.ndarray
|
148
148
|
Array of spike times
|
149
|
-
lfp_data : np.ndarray
|
149
|
+
lfp_data : Union[np.ndarray, xr.DataArray]
|
150
150
|
Local field potential time series data. Not required if filtered_lfp_phase is provided.
|
151
|
-
spike_fs : float
|
152
|
-
Sampling frequency in Hz of the spike times
|
151
|
+
spike_fs : float
|
152
|
+
Sampling frequency in Hz of the spike times
|
153
153
|
lfp_fs : float
|
154
154
|
Sampling frequency in Hz of the LFP data
|
155
155
|
filter_method : str, optional
|
156
|
-
Method to use for filtering, either 'wavelet' or 'butter' (default: '
|
156
|
+
Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
|
157
157
|
freq_of_interest : float, optional
|
158
158
|
Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
|
159
159
|
lowcut : float, optional
|
@@ -167,12 +167,9 @@ def calculate_spike_lfp_plv(
|
|
167
167
|
|
168
168
|
Returns
|
169
169
|
-------
|
170
|
-
|
171
|
-
|
170
|
+
np.ndarray
|
171
|
+
Array of phases at spike times
|
172
172
|
"""
|
173
|
-
|
174
|
-
if spike_fs is None:
|
175
|
-
spike_fs = lfp_fs
|
176
173
|
# Convert spike times to sample indices
|
177
174
|
spike_times_seconds = spike_times / spike_fs
|
178
175
|
|
@@ -194,7 +191,7 @@ def calculate_spike_lfp_plv(
|
|
194
191
|
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
195
192
|
|
196
193
|
if len(valid_indices) <= 1:
|
197
|
-
return
|
194
|
+
return np.array([])
|
198
195
|
|
199
196
|
# Get instantaneous phase
|
200
197
|
if filtered_lfp_phase is None:
|
@@ -212,10 +209,73 @@ def calculate_spike_lfp_plv(
|
|
212
209
|
|
213
210
|
# Get phases at spike times
|
214
211
|
if isinstance(instantaneous_phase, xr.DataArray):
|
215
|
-
spike_phases = instantaneous_phase.sel(time=valid_indices).values
|
212
|
+
spike_phases = instantaneous_phase.sel(time=valid_indices, method="nearest").values
|
216
213
|
else:
|
217
214
|
spike_phases = instantaneous_phase[valid_indices]
|
218
215
|
|
216
|
+
return spike_phases
|
217
|
+
|
218
|
+
|
219
|
+
def calculate_spike_lfp_plv(
|
220
|
+
spike_times: np.ndarray = None,
|
221
|
+
lfp_data=None,
|
222
|
+
spike_fs: float = None,
|
223
|
+
lfp_fs: float = None,
|
224
|
+
filter_method: str = "butter",
|
225
|
+
freq_of_interest: float = None,
|
226
|
+
lowcut: float = None,
|
227
|
+
highcut: float = None,
|
228
|
+
bandwidth: float = 2.0,
|
229
|
+
filtered_lfp_phase: np.ndarray = None,
|
230
|
+
) -> float:
|
231
|
+
"""
|
232
|
+
Calculate spike-lfp unbiased phase locking value
|
233
|
+
|
234
|
+
Parameters
|
235
|
+
----------
|
236
|
+
spike_times : np.ndarray
|
237
|
+
Array of spike times
|
238
|
+
lfp_data : np.ndarray
|
239
|
+
Local field potential time series data. Not required if filtered_lfp_phase is provided.
|
240
|
+
spike_fs : float, optional
|
241
|
+
Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
|
242
|
+
lfp_fs : float
|
243
|
+
Sampling frequency in Hz of the LFP data
|
244
|
+
filter_method : str, optional
|
245
|
+
Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
|
246
|
+
freq_of_interest : float, optional
|
247
|
+
Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
|
248
|
+
lowcut : float, optional
|
249
|
+
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
250
|
+
highcut : float, optional
|
251
|
+
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
252
|
+
bandwidth : float, optional
|
253
|
+
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
254
|
+
filtered_lfp_phase : np.ndarray, optional
|
255
|
+
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
256
|
+
|
257
|
+
Returns
|
258
|
+
-------
|
259
|
+
float
|
260
|
+
Phase Locking Value (unbiased)
|
261
|
+
"""
|
262
|
+
|
263
|
+
spike_phases = _get_spike_phases(
|
264
|
+
spike_times=spike_times,
|
265
|
+
lfp_data=lfp_data,
|
266
|
+
spike_fs=spike_fs,
|
267
|
+
lfp_fs=lfp_fs,
|
268
|
+
filter_method=filter_method,
|
269
|
+
freq_of_interest=freq_of_interest,
|
270
|
+
lowcut=lowcut,
|
271
|
+
highcut=highcut,
|
272
|
+
bandwidth=bandwidth,
|
273
|
+
filtered_lfp_phase=filtered_lfp_phase,
|
274
|
+
)
|
275
|
+
|
276
|
+
if len(spike_phases) <= 1:
|
277
|
+
return 0
|
278
|
+
|
219
279
|
# Number of spikes
|
220
280
|
N = len(spike_phases)
|
221
281
|
|
@@ -316,57 +376,26 @@ def calculate_ppc(
|
|
316
376
|
float
|
317
377
|
Pairwise Phase Consistency value
|
318
378
|
"""
|
319
|
-
if spike_fs is None:
|
320
|
-
spike_fs = lfp_fs
|
321
|
-
# Convert spike times to sample indices
|
322
|
-
spike_times_seconds = spike_times / spike_fs
|
323
|
-
|
324
|
-
# Then convert from seconds to samples at the new sampling rate
|
325
|
-
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
326
379
|
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
if len(valid_indices) <= 1:
|
380
|
+
spike_phases = _get_spike_phases(
|
381
|
+
spike_times=spike_times,
|
382
|
+
lfp_data=lfp_data,
|
383
|
+
spike_fs=spike_fs,
|
384
|
+
lfp_fs=lfp_fs,
|
385
|
+
filter_method=filter_method,
|
386
|
+
freq_of_interest=freq_of_interest,
|
387
|
+
lowcut=lowcut,
|
388
|
+
highcut=highcut,
|
389
|
+
bandwidth=bandwidth,
|
390
|
+
filtered_lfp_phase=filtered_lfp_phase,
|
391
|
+
)
|
392
|
+
|
393
|
+
if len(spike_phases) <= 1:
|
342
394
|
return 0
|
343
395
|
|
344
|
-
# Get instantaneous phase
|
345
|
-
if filtered_lfp_phase is None:
|
346
|
-
instantaneous_phase = get_lfp_phase(
|
347
|
-
lfp_data=lfp_data,
|
348
|
-
filter_method=filter_method,
|
349
|
-
freq_of_interest=freq_of_interest,
|
350
|
-
lowcut=lowcut,
|
351
|
-
highcut=highcut,
|
352
|
-
bandwidth=bandwidth,
|
353
|
-
fs=lfp_fs,
|
354
|
-
)
|
355
|
-
else:
|
356
|
-
instantaneous_phase = filtered_lfp_phase
|
357
|
-
|
358
|
-
# Get phases at spike times
|
359
|
-
if isinstance(instantaneous_phase, xr.DataArray):
|
360
|
-
spike_phases = instantaneous_phase.sel(time=valid_indices).values
|
361
|
-
else:
|
362
|
-
spike_phases = instantaneous_phase[valid_indices]
|
363
|
-
|
364
396
|
n_spikes = len(spike_phases)
|
365
397
|
|
366
398
|
# Calculate PPC (Pairwise Phase Consistency)
|
367
|
-
if n_spikes <= 1:
|
368
|
-
return 0
|
369
|
-
|
370
399
|
# Explicit calculation of pairwise phase consistency
|
371
400
|
# Vectorized computation for efficiency
|
372
401
|
if ppc_method == "numpy":
|
@@ -434,56 +463,25 @@ def calculate_ppc2(
|
|
434
463
|
Pairwise Phase Consistency 2 (PPC2) value
|
435
464
|
"""
|
436
465
|
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
else:
|
452
|
-
valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
|
453
|
-
elif isinstance(lfp_data, np.ndarray):
|
454
|
-
if filtered_lfp_phase is not None:
|
455
|
-
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
|
456
|
-
else:
|
457
|
-
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
458
|
-
|
459
|
-
if len(valid_indices) <= 1:
|
466
|
+
spike_phases = _get_spike_phases(
|
467
|
+
spike_times=spike_times,
|
468
|
+
lfp_data=lfp_data,
|
469
|
+
spike_fs=spike_fs,
|
470
|
+
lfp_fs=lfp_fs,
|
471
|
+
filter_method=filter_method,
|
472
|
+
freq_of_interest=freq_of_interest,
|
473
|
+
lowcut=lowcut,
|
474
|
+
highcut=highcut,
|
475
|
+
bandwidth=bandwidth,
|
476
|
+
filtered_lfp_phase=filtered_lfp_phase,
|
477
|
+
)
|
478
|
+
|
479
|
+
if len(spike_phases) <= 1:
|
460
480
|
return 0
|
461
481
|
|
462
|
-
# Get instantaneous phase
|
463
|
-
if filtered_lfp_phase is None:
|
464
|
-
instantaneous_phase = get_lfp_phase(
|
465
|
-
lfp_data=lfp_data,
|
466
|
-
filter_method=filter_method,
|
467
|
-
freq_of_interest=freq_of_interest,
|
468
|
-
lowcut=lowcut,
|
469
|
-
highcut=highcut,
|
470
|
-
bandwidth=bandwidth,
|
471
|
-
fs=lfp_fs,
|
472
|
-
)
|
473
|
-
else:
|
474
|
-
instantaneous_phase = filtered_lfp_phase
|
475
|
-
|
476
|
-
# Get phases at spike times
|
477
|
-
if isinstance(instantaneous_phase, xr.DataArray):
|
478
|
-
spike_phases = instantaneous_phase.sel(time=valid_indices).values
|
479
|
-
else:
|
480
|
-
spike_phases = instantaneous_phase[valid_indices]
|
481
482
|
# Calculate PPC2 according to Vinck et al. (2010), Equation 6
|
482
483
|
n = len(spike_phases)
|
483
484
|
|
484
|
-
if n <= 1:
|
485
|
-
return 0
|
486
|
-
|
487
485
|
# Convert phases to unit vectors in the complex plane
|
488
486
|
unit_vectors = np.exp(1j * spike_phases)
|
489
487
|
|
@@ -575,7 +573,7 @@ def calculate_entrainment_per_cell(
|
|
575
573
|
for pop in pop_names:
|
576
574
|
skip_count = 0
|
577
575
|
pop_spikes = spike_df[spike_df["pop_name"] == pop]
|
578
|
-
nodes = pop_spikes["node_ids"].unique()
|
576
|
+
nodes = sorted(pop_spikes["node_ids"].unique()) # sort so all nodes are processed in order
|
579
577
|
entrainment_dict[pop] = {}
|
580
578
|
print(f"Processing {pop} population")
|
581
579
|
for node in tqdm(nodes):
|
@@ -714,3 +712,55 @@ def calculate_spike_rate_power_correlation(
|
|
714
712
|
correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
|
715
713
|
|
716
714
|
return correlation_results, frequencies
|
715
|
+
|
716
|
+
|
717
|
+
def get_spikes_in_cycle(spike_df, lfp_data, spike_fs=1000, lfp_fs=400, band=(30, 80)):
    """
    Analyze spike timing relative to oscillation phases.

    The LFP is passed through ``butter_bandpass_filter`` with the edges given by
    ``band``; phase and amplitude are then taken from the Hilbert analytic signal
    of the filtered trace. Each spike is mapped to its nearest LFP sample and the
    phase at that sample is collected per population.

    Parameters
    ----------
    spike_df : pd.DataFrame
        Spike table; the "pop_name" and "timestamps" columns are read.
    lfp_data : np.array
        Raw LFP signal (indexed along its first axis; presumably 1-D - confirm
        against callers).
    spike_fs : float
        Rate used to convert spike timestamps to seconds
        (``timestamps / spike_fs``). The default of 1000 treats timestamps as
        milliseconds.
    lfp_fs : float
        Sampling frequency of the LFP in Hz.
    band : tuple
        ``(low, high)`` bounds in Hz of the frequency band to filter
        (default 30-80 Hz).

    Returns
    -------
    phase_data : dict
        Maps population name to the array of phases (radians, from ``np.angle``)
        at that population's spikes. Populations with no spike inside the LFP
        sample range are omitted from the dictionary.
    filtered_lfp : np.ndarray
        Band-pass filtered LFP.
    phase : np.ndarray
        Instantaneous phase of the filtered LFP.
    amplitude : np.ndarray
        Instantaneous amplitude (envelope) of the filtered LFP.
    """
    lowcut, highcut = band[0], band[1]
    filtered_lfp = butter_bandpass_filter(lfp_data, lowcut, highcut, lfp_fs)

    # Phase and envelope come from the same Hilbert analytic signal
    analytic_signal = signal.hilbert(filtered_lfp)
    phase = np.angle(analytic_signal)
    amplitude = np.abs(analytic_signal)

    phase_data = {}
    for pop in spike_df["pop_name"].unique():
        pop_spikes = spike_df[spike_df["pop_name"] == pop]["timestamps"].values

        # timestamps -> seconds -> nearest LFP sample index
        spike_indices = np.round(pop_spikes / spike_fs * lfp_fs).astype(int)

        # Drop spikes that fall before the first or after the last LFP sample
        in_range = (spike_indices >= 0) & (spike_indices < len(phase))
        if np.any(in_range):
            phase_data[pop] = phase[spike_indices[in_range]]

    return phase_data, filtered_lfp, phase, amplitude
|
bmtool/analysis/spikes.py
CHANGED
@@ -398,3 +398,58 @@ def compare_firing_over_times(
|
|
398
398
|
print(f" p-value: {p_val}")
|
399
399
|
print(f" Significant difference (p<0.05): {'Yes' if p_val < 0.05 else 'No'}")
|
400
400
|
return
|
401
|
+
|
402
|
+
|
403
|
+
def find_bursting_cells(
    df: pd.DataFrame, burst_threshold: float = 10, rename_bursting_cells: bool = False
) -> pd.DataFrame:
    """
    Finds bursting cells in a population based on a time difference threshold.

    A cell is flagged as bursting when at least one of its inter-spike intervals
    is shorter than ``burst_threshold``.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing spike data with columns for timestamps, node_ids, and pop_name
    burst_threshold : float, optional
        Time difference threshold to identify bursts, in the same units as the
        timestamps column (documented upstream as milliseconds)
    rename_bursting_cells : bool, optional
        If True, "_bursters" is appended to the pop_name of every spike row that
        belongs to a bursting cell (names already containing "_bursters" are left alone)

    Returns
    -------
    pd.DataFrame
        One row per spike of ``df`` with columns node_ids, is_burst and the
        remaining columns of ``df``; pop_name is renamed only when
        ``rename_bursting_cells`` is True
    """
    # Inter-spike intervals are only meaningful on time-ordered spikes; on unsorted
    # input a negative difference would otherwise pass the "< threshold" test.
    diff_df = df.sort_values(["node_ids", "timestamps"])
    diff_df["time_diff"] = diff_df.groupby("node_ids")["timestamps"].diff()

    # The first spike of each cell has a NaN interval, which compares False here
    diff_df["is_burst_instance"] = diff_df["time_diff"] < burst_threshold

    # A cell is a burster if any of its intervals is below the threshold
    burst_summary = diff_df.groupby("node_ids")["is_burst_instance"].any()
    burst_cells = burst_summary.reset_index(name="is_burst")

    # Merge with the original df to get the per-spike rows back
    burst_cells = pd.merge(burst_cells, df, on="node_ids")

    if rename_bursting_cells:
        # Element-wise mask: `series is True` is an identity test that is always False
        already_named = burst_cells["pop_name"].str.contains("_bursters", regex=False, na=False)
        burst_mask = burst_cells["is_burst"] & ~already_named
        burst_cells.loc[burst_mask, "pop_name"] = (
            burst_cells.loc[burst_mask, "pop_name"] + "_bursters"
        )

    # Reports the number of distinct cells under each (possibly renamed) population label
    for pop in burst_cells["pop_name"].unique():
        print(
            f"Number of bursters in {pop}: {burst_cells[burst_cells['pop_name'] == pop]['node_ids'].nunique()}"
        )

    return burst_cells
|
bmtool/bmplot/entrainment.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1
1
|
import matplotlib.pyplot as plt
|
2
|
+
import numpy as np
|
2
3
|
import seaborn as sns
|
4
|
+
from matplotlib.gridspec import GridSpec
|
5
|
+
from scipy import stats
|
3
6
|
|
4
7
|
|
5
8
|
def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
|
@@ -48,3 +51,152 @@ def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
|
|
48
51
|
plt.tight_layout()
|
49
52
|
|
50
53
|
plt.show()
|
54
|
+
|
55
|
+
|
56
|
+
def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
    """
    Plot an idealized oscillation cycle with spike-phase histograms for neuron populations.

    The top panel shows one sine cycle over [-pi, pi]; below it, one histogram per
    population shows the percentage of that population's spikes in each phase bin.

    Parameters
    ----------
    phase_data : dict
        Dictionary mapping population name to the phase (radians) of each spike
    bins : int
        Number of bins for the phase histogram (default 36 gives 10-degree bins)
    pop_name : list or str, optional
        Population name(s) to be plotted. A single string is treated as one
        population; None (default) plots every population in ``phase_data``.
    """
    # Normalise the selection once instead of shadowing the parameter in the loop
    if pop_name is None:
        populations = list(phase_data)
    elif isinstance(pop_name, str):
        populations = [pop_name]
    else:
        populations = list(pop_name)
    n_pops = len(populations)

    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(12, 8))
    gs = GridSpec(n_pops + 1, 1, height_ratios=[1.5] + [1] * n_pops)

    # Top subplot: idealized cycle
    ax_gamma = fig.add_subplot(gs[0])
    x = np.linspace(-np.pi, np.pi, 1000)
    y = np.sin(x)

    ax_gamma.plot(x, y, "b-", linewidth=2)
    ax_gamma.set_title("Cycle with Neuron Population Spike Distributions", fontsize=14)
    ax_gamma.set_ylabel("Amplitude", fontsize=12)
    ax_gamma.set_xlim(-np.pi, np.pi)
    ax_gamma.set_xticks(np.linspace(-np.pi, np.pi, 9))
    ax_gamma.set_xticklabels(["-180°", "-135°", "-90°", "-45°", "0°", "45°", "90°", "135°", "180°"])
    ax_gamma.grid(True)
    ax_gamma.axhline(y=0, color="k", linestyle="-", alpha=0.3)
    ax_gamma.axvline(x=0, color="k", linestyle="--", alpha=0.3)

    # One colour per population
    colors = plt.cm.tab10(np.linspace(0, 1, n_pops))

    # Falls back to the cycle panel when no population is plotted
    ax_last = ax_gamma
    for i, pop in enumerate(populations):
        ax_hist = fig.add_subplot(gs[i + 1], sharex=ax_gamma)

        hist, bin_edges = np.histogram(phase_data[pop], bins=bins, range=(-np.pi, np.pi))
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

        # Convert counts to percentages; an all-zero histogram is left as is
        total = np.sum(hist)
        if total > 0:
            hist = hist / total * 100

        ax_hist.bar(bin_centers, hist, width=2 * np.pi / bins, alpha=0.7, color=colors[i])
        ax_hist.set_ylabel(f"{pop}\nSpikes (%)", fontsize=10)

        # Grid aligned with the cycle panel
        ax_hist.grid(True, alpha=0.3)

        # Add some headroom; skipped for an empty histogram, where the limits would be (0, 0)
        peak = max(hist)
        if peak > 0:
            ax_hist.set_ylim(0, peak * 1.2)

        ax_last = ax_hist

    # Set x-label for the last subplot
    ax_last.set_xlabel("Phase (degrees)", fontsize=12)

    plt.tight_layout()
    plt.show()
|
121
|
+
|
122
|
+
|
123
|
+
def plot_ppc_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=None):
    """
    Draw a grouped bar chart of PPC per frequency, one bar per population.

    Each bar is the mean PPC across the nodes of a population at one frequency,
    with the standard error of the mean as its error bar.

    Parameters
    ----------
    ppc_dict : dict
        PPC values indexed as ``ppc_dict[population][node][frequency]``
    pop_names : list
        Populations to include, in plotting order
    freqs : list
        Frequencies to include; each becomes one group on the x-axis
    figsize : tuple
        Size of the matplotlib figure
    title : str, optional
        Figure title; omitted when falsy
    """
    sns.set_style("whitegrid")
    plt.figure(figsize=figsize)

    # Every frequency group is 0.8 wide and shared equally by the populations
    n_populations = len(pop_names)
    bar_width = 0.8 / n_populations
    pop_colors = sns.color_palette(n_colors=n_populations)

    x_centers = np.arange(len(freqs))
    tick_labels = [str(freq) for freq in freqs]

    for i, pop in enumerate(pop_names):
        means, errors, valid_freqs_idx = [], [], []

        for freq_idx, freq in enumerate(freqs):
            # Gather this frequency's value from every node that has one
            node_values = []
            for node in ppc_dict[pop]:
                try:
                    node_values.append(ppc_dict[pop][node][freq])
                except KeyError:
                    continue

            # Frequencies with no data for this population get no bar
            if not node_values:
                continue

            means.append(np.mean(node_values))
            errors.append(stats.sem(node_values))
            valid_freqs_idx.append(freq_idx)

        # Shift this population's bars sideways within each frequency group
        x_positions = x_centers[valid_freqs_idx] + (i - n_populations / 2 + 0.5) * bar_width

        plt.bar(
            x_positions, means, width=bar_width * 0.9, color=pop_colors[i], alpha=0.7, label=pop
        )
        plt.errorbar(x_positions, means, yerr=errors, fmt="none", ecolor="black", capsize=4)

    plt.xlabel("Frequency")
    plt.ylabel("PPC Value")
    if title:
        plt.title(title)
    plt.xticks(x_centers, tick_labels)
    plt.legend(title="Population")

    # Tidy the layout and display the figure
    plt.tight_layout()
    plt.show()
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: bmtool
|
3
|
-
Version: 0.7.1.1
|
3
|
+
Version: 0.7.1.3
|
4
4
|
Summary: BMTool
|
5
5
|
Home-page: https://github.com/cyneuro/bmtool
|
6
6
|
Download-URL:
|
@@ -38,7 +38,6 @@ Requires-Dist: PyWavelets
|
|
38
38
|
Requires-Dist: numba
|
39
39
|
Provides-Extra: dev
|
40
40
|
Requires-Dist: ruff>=0.1.0; extra == "dev"
|
41
|
-
Requires-Dist: pyright>=1.1.0; extra == "dev"
|
42
41
|
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
43
42
|
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
44
43
|
Dynamic: author
|
@@ -8,13 +8,13 @@ bmtool/plot_commands.py,sha256=Dxm_RaT4CtHnfsltTtUopJ4KVbfhxtktEB_b7bFEXII,12716
|
|
8
8
|
bmtool/singlecell.py,sha256=I2yolbAnNC8qpnRkNdnDCLidNW7CktmBuRrcowMZJ3A,45041
|
9
9
|
bmtool/synapses.py,sha256=wlRY7IixefPzafqG6k2sPIK4s6PLG9Kct-oCaVR29wA,64269
|
10
10
|
bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
bmtool/analysis/entrainment.py,sha256=
|
11
|
+
bmtool/analysis/entrainment.py,sha256=uG2TWbeYJEg_VQB6pKEWlrVBzQ6M4h6FSAZR4GMKp-E,28178
|
12
12
|
bmtool/analysis/lfp.py,sha256=S2JvxkjcK3-EH93wCrhqNSFY6cX7fOq74pz64ibHKrc,26556
|
13
13
|
bmtool/analysis/netcon_reports.py,sha256=VnPZNKPaQA7oh1q9cIatsqQudm4cOtzNtbGPXoiDCD0,2909
|
14
|
-
bmtool/analysis/spikes.py,sha256=
|
14
|
+
bmtool/analysis/spikes.py,sha256=Mz0e7XOey_6eWQDZAU0ePjDzDMDTFkMbWSA5YWDooYk,17122
|
15
15
|
bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
16
|
bmtool/bmplot/connections.py,sha256=P1JBG4xCbLVq4sfQuUE6c3dO949qajrjdQcrazdmDS4,53861
|
17
|
-
bmtool/bmplot/entrainment.py,sha256=
|
17
|
+
bmtool/bmplot/entrainment.py,sha256=VSlZvcSeXLr5OxGvmWcGU4s7JS7vOL38lq1XC69O_AE,6926
|
18
18
|
bmtool/bmplot/lfp.py,sha256=SNpbWGOUnYEgnkeBw5S--aPN5mIGD22Gw2Pwus0_lvY,2034
|
19
19
|
bmtool/bmplot/netcon_reports.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
20
|
bmtool/bmplot/spikes.py,sha256=Lg8V3ynYCqk-QJvq-BOInjZMHYHrxHgXjtDOX67df-A,11148
|
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=Nn-R-4e9g8ZhSPZvTkr38xeKRPfEMANB9Lugppj82UI,68564
|
|
26
26
|
bmtool/util/util.py,sha256=owce5BEusZO_8T5x05N2_B583G26vWAy7QX29V0Pj0Y,62818
|
27
27
|
bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
28
28
|
bmtool/util/neuron/celltuner.py,sha256=lokRLUM1rsdSYBYrNbLBBo39j14mm8TBNVNRnSlhHCk,94868
|
29
|
-
bmtool-0.7.1.
|
30
|
-
bmtool-0.7.1.
|
31
|
-
bmtool-0.7.1.
|
32
|
-
bmtool-0.7.1.
|
33
|
-
bmtool-0.7.1.
|
34
|
-
bmtool-0.7.1.
|
29
|
+
bmtool-0.7.1.3.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
|
30
|
+
bmtool-0.7.1.3.dist-info/METADATA,sha256=GpvzWhlfyNgHHCfhakOBeeablgTAOCHeQFvhnExZ-X8,3577
|
31
|
+
bmtool-0.7.1.3.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
|
32
|
+
bmtool-0.7.1.3.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
|
33
|
+
bmtool-0.7.1.3.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
|
34
|
+
bmtool-0.7.1.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|