bmtool 0.7.0.3__tar.gz → 0.7.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/PKG-INFO +1 -1
- bmtool-0.7.0.5/bmtool/analysis/entrainment.py +573 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/analysis/lfp.py +176 -1
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/analysis/spikes.py +115 -63
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/PKG-INFO +1 -1
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/setup.py +1 -1
- bmtool-0.7.0.3/bmtool/analysis/entrainment.py +0 -490
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/LICENSE +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/README.md +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/SLURM.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/__init__.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/__main__.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/analysis/__init__.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/analysis/netcon_reports.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/__init__.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/connections.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/entrainment.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/lfp.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/netcon_reports.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/spikes.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/connectors.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/debug/__init__.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/debug/commands.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/debug/debug.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/graphs.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/manage.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/plot_commands.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/singlecell.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/synapses.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/util/__init__.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/util/commands.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/util/neuron/__init__.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/util/neuron/celltuner.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/util/util.py +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/SOURCES.txt +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/dependency_links.txt +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/entry_points.txt +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/requires.txt +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/top_level.txt +0 -0
- {bmtool-0.7.0.3 → bmtool-0.7.0.5}/setup.cfg +0 -0
bmtool-0.7.0.5/bmtool/analysis/entrainment.py (new file)
@@ -0,0 +1,573 @@
+"""
+Module for entrainment analysis
+"""
+
+import numpy as np
+from scipy import signal
+import numba
+from numba import cuda
+import pandas as pd
+from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power, get_lfp_phase
+from typing import Dict, List, Optional
+from tqdm.notebook import tqdm
+import scipy.stats as stats
+
+
+def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None,
+                                filter_method: str = 'wavelet', lowcut: float = None, highcut: float = None,
+                                bandwidth: float = 2.0) -> np.ndarray:
+    """
+    Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
+
+    Parameters
+    ----------
+    signal1 : np.ndarray
+        First input signal (1D array)
+    signal2 : np.ndarray
+        Second input signal (1D array, same length as signal1)
+    fs : float
+        Sampling frequency in Hz
+    freq_of_interest : float, optional
+        Desired frequency for wavelet PLV calculation, required if filter_method='wavelet'
+    filter_method : str, optional
+        Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+
+    Returns
+    -------
+    np.ndarray
+        Phase Locking Value (1D array)
+    """
+    if len(signal1) != len(signal2):
+        raise ValueError("Input signals must have the same length.")
+
+    if filter_method == 'wavelet':
+        if freq_of_interest is None:
+            raise ValueError("freq_of_interest must be provided for the wavelet method.")
+
+        # Apply CWT to both signals
+        theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
+        theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
+
+    elif filter_method == 'butter':
+        if lowcut is None or highcut is None:
+            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation")
+
+        if lowcut and highcut:
+            # Bandpass filter and get the analytic signal using the Hilbert transform
+            filtered_signal1 = butter_bandpass_filter(data=signal1, lowcut=lowcut, highcut=highcut, fs=fs)
+            filtered_signal2 = butter_bandpass_filter(data=signal2, lowcut=lowcut, highcut=highcut, fs=fs)
+            # Get phase using the Hilbert transform
+            theta1 = signal.hilbert(filtered_signal1)
+            theta2 = signal.hilbert(filtered_signal2)
+        else:
+            # Get phase using the Hilbert transform without filtering
+            theta1 = signal.hilbert(signal1)
+            theta2 = signal.hilbert(signal2)
+
+    else:
+        raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
+
+    # Calculate phase difference
+    phase_diff = np.angle(theta1) - np.angle(theta2)
+
+    # Calculate PLV from standard equation from Measuring phase synchrony in brain signals(1999)
+    plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))
+
+    return plv
+
+
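For review context, here is a minimal usage sketch of the new signal-signal PLV function (not part of the diff). It assumes bmtool 0.7.0.5 is installed; the two synthetic 10 Hz signals, their noise level, and the fixed phase lag are invented for illustration.

# Sketch: signal-signal PLV at 10 Hz using the wavelet method (illustrative data).
import numpy as np
from bmtool.analysis.entrainment import calculate_signal_signal_plv

fs = 1000.0                               # sampling rate (Hz)
t = np.arange(0, 2.0, 1.0 / fs)           # 2 s of data
rng = np.random.default_rng(0)
sig1 = np.sin(2 * np.pi * 10 * t) + 0.5 * rng.standard_normal(t.size)
sig2 = np.sin(2 * np.pi * 10 * t + 0.8) + 0.5 * rng.standard_normal(t.size)  # constant phase lag

plv = calculate_signal_signal_plv(sig1, sig2, fs=fs, freq_of_interest=10.0,
                                  filter_method='wavelet', bandwidth=2.0)
print(plv)

With a constant phase lag the PLV should be close to 1; for two unrelated noise signals it falls toward 0.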
+def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
+                            lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
+                            lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
+                            filtered_lfp_phase: np.ndarray = None) -> float:
+    """
+    Calculate spike-lfp unbiased phase locking value
+
+    Parameters
+    ----------
+    spike_times : np.ndarray
+        Array of spike times
+    lfp_data : np.ndarray
+        Local field potential time series data. Not required if filtered_lfp_phase is provided.
+    spike_fs : float, optional
+        Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
+    lfp_fs : float
+        Sampling frequency in Hz of the LFP data
+    filter_method : str, optional
+        Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
+    freq_of_interest : float, optional
+        Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+    filtered_lfp_phase : np.ndarray, optional
+        Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
+
+    Returns
+    -------
+    float
+        Phase Locking Value (unbiased)
+    """
+
+    if spike_fs is None:
+        spike_fs = lfp_fs
+    # Convert spike times to sample indices
+    spike_times_seconds = spike_times / spike_fs
+
+    # Then convert from seconds to samples at the new sampling rate
+    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
+
+    # Filter indices to ensure they're within bounds of the LFP signal
+    if filtered_lfp_phase is not None:
+        valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
+    else:
+        valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
+
+    if len(valid_indices) <= 1:
+        return 0
+
+    # Get instantaneous phase
+    if filtered_lfp_phase is None:
+        instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
+                                            freq_of_interest=freq_of_interest, lowcut=lowcut,
+                                            highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
+    else:
+        instantaneous_phase = filtered_lfp_phase
+
+    # Get phases at spike times
+    spike_phases = instantaneous_phase[valid_indices]
+
+    # Number of spikes
+    N = len(spike_phases)
+
+    # Convert phases to unit vectors in the complex plane
+    unit_vectors = np.exp(1j * spike_phases)
+
+    # Sum of all unit vectors (resultant vector)
+    resultant_vector = np.sum(unit_vectors)
+
+    # Calculate plv^2 * N
+    plv2n = (resultant_vector * resultant_vector.conjugate()).real / N  # plv^2 * N
+    plv = (plv2n / N) ** 0.5
+    ppc = (plv2n - 1) / (N - 1)  # ppc = (plv^2 * N - 1) / (N - 1)
+    plv_unbiased = np.fmax(ppc, 0.) ** 0.5  # ensure non-negative
+
+    return plv_unbiased
+
+
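A companion sketch for the unbiased spike-LFP PLV (not part of the diff). Spike timestamps are assumed to be in milliseconds, so spike_fs=1000 converts them to seconds before they are mapped onto LFP sample indices; the LFP and the spike train are synthetic.

# Sketch: unbiased spike-LFP PLV for one cell around 8 Hz (illustrative data).
import numpy as np
from bmtool.analysis.entrainment import calculate_spike_lfp_plv

lfp_fs = 1000.0
t = np.arange(0, 5.0, 1.0 / lfp_fs)                  # 5 s of LFP
rng = np.random.default_rng(0)
lfp = np.sin(2 * np.pi * 8 * t) + 0.3 * rng.standard_normal(t.size)

# Spikes clustered near one point of the 8 Hz cycle, expressed in milliseconds
cycle_starts_s = np.arange(0.0, 5.0, 1.0 / 8)
spike_times_ms = 1000.0 * (cycle_starts_s + 0.005 * rng.standard_normal(cycle_starts_s.size))

plv = calculate_spike_lfp_plv(spike_times_ms, lfp, spike_fs=1000.0, lfp_fs=lfp_fs,
                              filter_method='butter', lowcut=6.0, highcut=10.0)
print(plv)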
+@numba.njit(parallel=True, fastmath=True)
+def _ppc_parallel_numba(spike_phases):
+    """Numba-optimized parallel PPC calculation"""
+    n = len(spike_phases)
+    sum_cos = 0.0
+    for i in numba.prange(n):
+        phase_i = spike_phases[i]
+        for j in range(i + 1, n):
+            sum_cos += np.cos(phase_i - spike_phases[j])
+    return (2 / (n * (n - 1))) * sum_cos
+
+
+@cuda.jit(fastmath=True)
+def _ppc_cuda_kernel(spike_phases, out):
+    i = cuda.grid(1)
+    if i < len(spike_phases):
+        local_sum = 0.0
+        for j in range(i+1, len(spike_phases)):
+            local_sum += np.cos(spike_phases[i] - spike_phases[j])
+        out[i] = local_sum
+
+
+def _ppc_gpu(spike_phases):
+    """GPU-accelerated implementation"""
+    d_phases = cuda.to_device(spike_phases)
+    d_out = cuda.device_array(len(spike_phases), dtype=np.float64)
+
+    threads = 256
+    blocks = (len(spike_phases) + threads - 1) // threads
+
+    _ppc_cuda_kernel[blocks, threads](d_phases, d_out)
+    total = d_out.copy_to_host().sum()
+    return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
+
+
+def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
+                  lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
+                  lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
+                  ppc_method: str = 'numpy', filtered_lfp_phase: np.ndarray = None) -> float:
+    """
+    Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
+    Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
+
+    Parameters
+    ----------
+    spike_times : np.ndarray
+        Array of spike times
+    lfp_data : np.ndarray
+        Local field potential time series data. Not required if filtered_lfp_phase is provided.
+    spike_fs : float, optional
+        Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
+    lfp_fs : float
+        Sampling frequency in Hz of the LFP data
+    filter_method : str, optional
+        Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
+    freq_of_interest : float, optional
+        Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+    ppc_method : str, optional
+        Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
+    filtered_lfp_phase : np.ndarray, optional
+        Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
+
+    Returns
+    -------
+    float
+        Pairwise Phase Consistency value
+    """
+    if spike_fs is None:
+        spike_fs = lfp_fs
+    # Convert spike times to sample indices
+    spike_times_seconds = spike_times / spike_fs
+
+    # Then convert from seconds to samples at the new sampling rate
+    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
+
+    # Filter indices to ensure they're within bounds of the LFP signal
+    if filtered_lfp_phase is not None:
+        valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
+    else:
+        valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
+
+    if len(valid_indices) <= 1:
+        return 0
+
+    # Get instantaneous phase
+    if filtered_lfp_phase is None:
+        instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
+                                            freq_of_interest=freq_of_interest, lowcut=lowcut,
+                                            highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
+    else:
+        instantaneous_phase = filtered_lfp_phase
+
+    # Get phases at spike times
+    spike_phases = instantaneous_phase[valid_indices]
+
+    n_spikes = len(spike_phases)
+
+    # Calculate PPC (Pairwise Phase Consistency)
+    if n_spikes <= 1:
+        return 0
+
+    # Explicit calculation of pairwise phase consistency
+    # Vectorized computation for efficiency
+    if ppc_method == 'numpy':
+        i, j = np.triu_indices(n_spikes, k=1)
+        phase_diff = spike_phases[i] - spike_phases[j]
+        sum_cos_diff = np.sum(np.cos(phase_diff))
+        ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
+    elif ppc_method == 'numba':
+        ppc = _ppc_parallel_numba(spike_phases)
+    elif ppc_method == 'gpu':
+        ppc = _ppc_gpu(spike_phases)
+    else:
+        raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")
+    return ppc
+
+
+def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
+                   lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
+                   lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
+                   filtered_lfp_phase: np.ndarray = None) -> float:
+    """
+    # -----------------------------------------------------------------------------
+    # PPC2 Calculation (Vinck et al., 2010)
+    # -----------------------------------------------------------------------------
+    # Equation(Original):
+    # PPC = (2 / (n * (n - 1))) * sum(cos(φ_i - φ_j) for all i < j)
+    # Optimized Formula (Algebraically Equivalent):
+    # PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
+    # -----------------------------------------------------------------------------
+
+    Parameters
+    ----------
+    spike_times : np.ndarray
+        Array of spike times
+    lfp_data : np.ndarray
+        Local field potential time series data. Not required if filtered_lfp_phase is provided.
+    spike_fs : float, optional
+        Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
+    lfp_fs : float
+        Sampling frequency in Hz of the LFP data
+    filter_method : str, optional
+        Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
+    freq_of_interest : float, optional
+        Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+    filtered_lfp_phase : np.ndarray, optional
+        Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
+
+    Returns
+    -------
+    float
+        Pairwise Phase Consistency 2 (PPC2) value
+    """
+
+    if spike_fs is None:
+        spike_fs = lfp_fs
+    # Convert spike times to sample indices
+    spike_times_seconds = spike_times / spike_fs
+
+    # Then convert from seconds to samples at the new sampling rate
+    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
+
+    # Filter indices to ensure they're within bounds of the LFP signal
+    if filtered_lfp_phase is not None:
+        valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
+    else:
+        valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
+
+    if len(valid_indices) <= 1:
+        return 0
+
+    # Get instantaneous phase
+    if filtered_lfp_phase is None:
+        instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
+                                            freq_of_interest=freq_of_interest, lowcut=lowcut,
+                                            highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
+    else:
+        instantaneous_phase = filtered_lfp_phase
+
+    # Get phases at spike times
+    spike_phases = instantaneous_phase[valid_indices]
+
+    # Calculate PPC2 according to Vinck et al. (2010), Equation 6
+    n = len(spike_phases)
+
+    if n <= 1:
+        return 0
+
+    # Convert phases to unit vectors in the complex plane
+    unit_vectors = np.exp(1j * spike_phases)
+
+    # Calculate the resultant vector
+    resultant_vector = np.sum(unit_vectors)
+
+    # PPC2 = (|∑(e^(i*φ_j))|² - n) / (n * (n - 1))
+    ppc2 = (np.abs(resultant_vector)**2 - n) / (n * (n - 1))
+
+    return ppc2
+
+
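The three calculate_ppc backends ('numpy', 'numba', 'gpu') evaluate the same O(n²) pairwise sum, while calculate_ppc2 uses the algebraically equivalent closed form: expanding |∑_j e^(iφ_j)|² = n + 2·∑_(j<k) cos(φ_j − φ_k) and substituting gives (|∑_j e^(iφ_j)|² − n) / (n(n − 1)) = (2 / (n(n − 1)))·∑_(j<k) cos(φ_j − φ_k). A quick self-contained check of that identity (a sketch, not part of the diff):

# Sketch: the pairwise PPC sum and the PPC2 closed form agree on random phases.
import numpy as np

rng = np.random.default_rng(1)
phases = rng.uniform(-np.pi, np.pi, size=500)
n = phases.size

# Pairwise definition, as in calculate_ppc with ppc_method='numpy'
i, j = np.triu_indices(n, k=1)
ppc_pairwise = (2.0 / (n * (n - 1))) * np.cos(phases[i] - phases[j]).sum()

# Closed form, as in calculate_ppc2
ppc_closed = (np.abs(np.exp(1j * phases).sum()) ** 2 - n) / (n * (n - 1))

assert np.isclose(ppc_pairwise, ppc_closed)  # identical up to floating-point error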
+def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None, filter_method: str='wavelet', pop_names: List[str]=None,
+                                   entrainment_method: str='plv', lowcut: float=None, highcut: float=None,
+                                   spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
+                                   freqs: List[float]=None, ppc_method: str='numpy',) -> Dict[str, Dict[int, Dict[float, float]]]:
+    """
+    Calculate neural entrainment (PPC, PLV) per neuron (cell) for specified frequencies across different populations.
+
+    This function computes the entrainment metrics for each neuron within the specified populations based on their spike times
+    and the provided LFP signal. It returns a nested dictionary structure containing the entrainment values
+    organized by population, node ID, and frequency.
+
+    Parameters
+    ----------
+    spike_df : pd.DataFrame
+        DataFrame containing spike data with columns 'pop_name', 'node_ids', and 'timestamps'
+    lfp_data : np.ndarray
+        Local field potential (LFP) time series data
+    filter_method : str, optional
+        Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
+    entrainment_method : str, optional
+        Method to use for entrainment calculation, either 'plv', 'ppc', or 'ppc2' (default: 'plv')
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    spike_fs : float
+        Sampling frequency of the spike times in Hz
+    lfp_fs : float
+        Sampling frequency of the LFP signal in Hz
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+    ppc_method : str, optional
+        Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
+    pop_names : List[str]
+        List of population names to analyze
+    freqs : List[float]
+        List of frequencies (in Hz) at which to calculate entrainment
+
+    Returns
+    -------
+    Dict[str, Dict[int, Dict[float, float]]]
+        Nested dictionary where the structure is:
+        {
+            population_name: {
+                node_id: {
+                    frequency: entrainment value
+                }
+            }
+        }
+        Entrainment values are floats representing the metric (PPC, PLV) at each frequency
+    """
+    # pre filter lfp to speed up calculate of entrainment
+    filtered_lfp_phases = {}
+    for freq in range(len(freqs)):
+        phase = get_lfp_phase(
+            lfp_data=lfp_data,
+            freq_of_interest=freqs[freq],
+            fs=lfp_fs,
+            filter_method=filter_method,
+            lowcut=lowcut,
+            highcut=highcut,
+            bandwidth=bandwidth
+        )
+        filtered_lfp_phases[freqs[freq]] = phase
+
+    entrainment_dict = {}
+    for pop in pop_names:
+        skip_count = 0
+        pop_spikes = spike_df[spike_df['pop_name'] == pop]
+        nodes = pop_spikes['node_ids'].unique()
+        entrainment_dict[pop] = {}
+        print(f'Processing {pop} population')
+        for node in tqdm(nodes):
+            node_spikes = pop_spikes[pop_spikes['node_ids'] == node]
+
+            # Skip nodes with less than or equal to 1 spike
+            if len(node_spikes) <= 1:
+                skip_count += 1
+                continue
+
+            entrainment_dict[pop][node] = {}
+            for freq in freqs:
+                # Calculate entrainment based on the selected method using the pre-filtered phases
+                if entrainment_method == 'plv':
+                    entrainment_dict[pop][node][freq] = calculate_spike_lfp_plv(
+                        node_spikes['timestamps'].values,
+                        lfp_data,
+                        spike_fs=spike_fs,
+                        lfp_fs=lfp_fs,
+                        freq_of_interest=freq,
+                        bandwidth=bandwidth,
+                        lowcut=lowcut,
+                        highcut=highcut,
+                        filter_method=filter_method,
+                        filtered_lfp_phase=filtered_lfp_phases[freq]
+                    )
+                elif entrainment_method == 'ppc2':
+                    entrainment_dict[pop][node][freq] = calculate_ppc2(
+                        node_spikes['timestamps'].values,
+                        lfp_data,
+                        spike_fs=spike_fs,
+                        lfp_fs=lfp_fs,
+                        freq_of_interest=freq,
+                        bandwidth=bandwidth,
+                        lowcut=lowcut,
+                        highcut=highcut,
+                        filter_method=filter_method,
+                        filtered_lfp_phase=filtered_lfp_phases[freq]
+                    )
+                elif entrainment_method == 'ppc':
+                    entrainment_dict[pop][node][freq] = calculate_ppc(
+                        node_spikes['timestamps'].values,
+                        lfp_data,
+                        spike_fs=spike_fs,
+                        lfp_fs=lfp_fs,
+                        freq_of_interest=freq,
+                        bandwidth=bandwidth,
+                        lowcut=lowcut,
+                        highcut=highcut,
+                        filter_method=filter_method,
+                        ppc_method=ppc_method,
+                        filtered_lfp_phase=filtered_lfp_phases[freq]
+                    )
+
+        print(f'Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')
+
+    return entrainment_dict
+
+
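A sketch of how the new per-cell driver might be called (not part of the diff). The DataFrame layout follows the docstring ('pop_name', 'node_ids', 'timestamps'); the population names, node IDs, spike times (assumed to be in ms, hence spike_fs=1000), and LFP below are invented for illustration.

# Sketch: PPC2 per cell at a few frequencies (illustrative data).
import numpy as np
import pandas as pd
from bmtool.analysis.entrainment import calculate_entrainment_per_cell

lfp_fs = 1000.0
t = np.arange(0, 10.0, 1.0 / lfp_fs)
rng = np.random.default_rng(0)
lfp = np.sin(2 * np.pi * 20 * t) + 0.5 * rng.standard_normal(t.size)

spike_df = pd.DataFrame({
    'pop_name':   ['PV'] * 6 + ['SOM'] * 4,
    'node_ids':   [0, 0, 0, 1, 1, 1, 2, 2, 2, 2],
    'timestamps': [105.0, 155.0, 205.0, 310.0, 360.0, 410.0, 120.0, 220.0, 320.0, 420.0],  # ms
})

result = calculate_entrainment_per_cell(
    spike_df=spike_df, lfp_data=lfp, pop_names=['PV', 'SOM'],
    entrainment_method='ppc2', filter_method='wavelet', bandwidth=2.0,
    spike_fs=1000.0, lfp_fs=lfp_fs, freqs=[8, 20, 40],
)
print(result['PV'][0][20])  # PPC2 of node 0 in the 'PV' population at 20 Hz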
+def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
+                                           bandwidth=2.0, lowcut=None, highcut=None,
+                                           freq_range=(10, 100), freq_step=5):
+    """
+    Calculate correlation between population spike rates and LFP power across frequencies
+    using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
+
+    Parameters:
+    -----------
+    spike_rate : DataFrame
+        Pre-calculated population spike rates at the same fs as lfp
+    lfp_data : np.array
+        LFP data
+    fs : float
+        Sampling frequency
+    pop_names : list
+        List of population names to analyze
+    filter_method : str, optional
+        Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    freq_range : tuple, optional
+        Min and max frequency to analyze (default: (10, 100))
+    freq_step : float, optional
+        Step size for frequency analysis (default: 5)
+
+    Returns:
+    --------
+    correlation_results : dict
+        Dictionary with correlation results for each population and frequency
+    frequencies : array
+        Array of frequencies analyzed
+    """
+
+    # Define frequency bands to analyze
+    frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
+
+    # Dictionary to store results
+    correlation_results = {pop: {} for pop in pop_names}
+
+    # Calculate power at each frequency band using specified filter
+    power_by_freq = {}
+    for freq in frequencies:
+        power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
+                                            lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
+
+    # Calculate correlation for each population
+    for pop in pop_names:
+        # Extract spike rate for this population
+        pop_rate = spike_rate[pop]
+
+        # Calculate correlation with power at each frequency
+        for freq in frequencies:
+            # Make sure the lengths match
+            if len(pop_rate) != len(power_by_freq[freq]):
+                raise ValueError(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
+            # use spearman for non-parametric correlation
+            corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
+            correlation_results[pop][freq] = {'correlation': corr, 'p_value': p_val}
+
+    return correlation_results, frequencies
+