bmtool 0.6.9.24__tar.gz → 0.6.9.25__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/PKG-INFO +1 -1
- bmtool-0.6.9.25/bmtool/analysis/entrainment.py +429 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/analysis/lfp.py +0 -356
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/bmplot.py +56 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/synapses.py +1 -1
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/PKG-INFO +1 -1
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/SOURCES.txt +1 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/setup.py +1 -1
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/LICENSE +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/README.md +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/SLURM.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/__init__.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/__main__.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/analysis/__init__.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/analysis/netcon_reports.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/analysis/spikes.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/connectors.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/debug/__init__.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/debug/commands.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/debug/debug.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/graphs.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/manage.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/plot_commands.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/singlecell.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/util/__init__.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/util/commands.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/util/neuron/__init__.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/util/neuron/celltuner.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/util/util.py +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/dependency_links.txt +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/entry_points.txt +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/requires.txt +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/top_level.txt +0 -0
- {bmtool-0.6.9.24 → bmtool-0.6.9.25}/setup.cfg +0 -0
@@ -0,0 +1,429 @@
|
|
1
|
+
"""
|
2
|
+
Module for entrainment analysis
|
3
|
+
"""
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
from scipy import signal
|
7
|
+
import numba
|
8
|
+
from numba import cuda
|
9
|
+
import pandas as pd
|
10
|
+
import xarray as xr
|
11
|
+
from .lfp import wavelet_filter,butter_bandpass_filter
|
12
|
+
from typing import Dict, List
|
13
|
+
from tqdm.notebook import tqdm
|
14
|
+
|
15
|
+
|
16
|
+
def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
                                method: str = 'wavelet', lowcut: float = None, highcut: float = None,
                                bandwidth: float = 2.0) -> np.ndarray:
    """
    Compute the Phase Locking Value (PLV) between two equally long signals.

    The instantaneous phase of each signal is taken either from a complex wavelet
    filter centred on ``freq_of_interest`` or from the Hilbert transform (optionally
    after a Butterworth band-pass). The PLV is the magnitude of the mean unit phasor
    of the phase difference (Lachaux et al., 1999, "Measuring phase synchrony in
    brain signals").

    Parameters:
    - x1, x2: Input signals of identical length.
    - fs: Sampling frequency in Hz.
    - freq_of_interest: Centre frequency for the wavelet method (required when method='wavelet').
    - method: 'wavelet' or 'hilbert'.
    - lowcut, highcut: Band-pass cutoffs for the Hilbert method. The signals are only
      filtered when both values are truthy; otherwise the raw signals go straight to
      the Hilbert transform.
    - bandwidth: Bandwidth parameter forwarded to the wavelet filter.

    Returns:
    - plv: PLV averaged over the last axis (a NumPy scalar for 1D inputs).

    Raises:
    - ValueError: if the signal lengths differ, freq_of_interest is missing for the
      wavelet method, or method is not recognised.
    """
    if len(x1) != len(x2):
        raise ValueError("Input signals must have the same length.")

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")
        analytic_1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
        analytic_2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")

        sig_a, sig_b = x1, x2
        # Band-pass only when both cutoffs are truthy.
        if lowcut and highcut:
            sig_a = butter_bandpass_filter(sig_a, lowcut, highcut, fs)
            sig_b = butter_bandpass_filter(sig_b, lowcut, highcut, fs)

        analytic_1 = signal.hilbert(sig_a)
        analytic_2 = signal.hilbert(sig_b)
    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # Unit phasors of the instantaneous phase difference; their mean length is the PLV.
    phasors = np.exp(1j * (np.angle(analytic_1) - np.angle(analytic_2)))
    return np.abs(np.mean(phasors, axis=-1))
|
67
|
+
|
68
|
+
|
69
|
+
def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                            lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                            lowcut: float = None, highcut: float = None,
                            bandwidth: float = 2.0) -> float:
    """
    Calculate the spike-LFP phase locking value.

    Based on Vinck et al. (2010), https://www.sciencedirect.com/science/article/pii/S1053811910000959

    The LFP phase is sampled at every spike that falls inside the LFP record. The
    returned value is the squared length of the mean resultant vector,
    |sum(exp(i * phi))|^2 / n^2, i.e. the squared PLV. This is the biased
    estimator that PPC corrects.

    Parameters:
    - spike_times: Spike times in units of 1 / spike_fs (e.g. ms when spike_fs=1000).
    - lfp_signal: Local field potential time series.
    - spike_fs: Sampling frequency in Hz of the spike times. Defaults to lfp_fs; only
      needed if the spike times and the LFP use different rates.
    - lfp_fs: Sampling frequency in Hz of the LFP.
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method.
    - freq_of_interest: Centre frequency for wavelet phase extraction (required for 'wavelet').
    - lowcut, highcut: Band-pass cutoffs for the Hilbert method. If either is None
      the LFP is not filtered.
    - bandwidth: Bandwidth parameter for the wavelet.

    Returns:
    - plv: Squared PLV as a float. Returns 0.0 when fewer than two spikes fall
      inside the LFP (previously a tuple was returned in that case only).

    Raises:
    - ValueError: if freq_of_interest is missing for the wavelet method or the
      method is not recognised.
    """
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> nearest LFP sample index.
    spike_times_seconds = np.asarray(spike_times, dtype=float) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that land inside the LFP record.
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        # Too few spikes to measure locking; return a scalar like the main path.
        return 0.0

    # Extract the LFP's instantaneous phase.
    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")
        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
        instantaneous_phase = np.angle(signal.hilbert(filtered_lfp))

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    spike_phases = instantaneous_phase[valid_indices]
    n = len(spike_phases)

    # Resultant of the unit phase vectors; squared length / n^2 is the squared PLV.
    resultant_vector = np.sum(np.exp(1j * spike_phases))
    plv = (np.abs(resultant_vector) ** 2) / (n ** 2)

    return plv
|
144
|
+
|
145
|
+
|
146
|
+
@numba.njit(parallel=True, fastmath=True)
def _ppc_parallel_numba(spike_phases):
    """Numba-optimized parallel PPC calculation.

    Computes PPC = 2 / (n * (n - 1)) * sum_{i<j} cos(phi_i - phi_j) by brute force
    over all spike pairs (O(n^2) work), parallelising the outer loop with numba.prange.

    Parameters:
    - spike_phases: LFP phases (radians) sampled at each spike.

    Returns:
    - PPC value. n must be >= 2; n <= 1 makes the normalisation divide by zero.
    """
    n = len(spike_phases)
    # Scalar updated with += inside the prange body; relies on Numba's automatic
    # reduction inference (summation order may vary between runs with parallel/fastmath).
    sum_cos = 0.0
    for i in numba.prange(n):
        phase_i = spike_phases[i]
        # j > i only, so each unordered pair is counted once; outer iterations with
        # small i do more work than those with large i.
        for j in range(i + 1, n):
            sum_cos += np.cos(phase_i - spike_phases[j])
    return (2 / (n * (n - 1))) * sum_cos
|
156
|
+
|
157
|
+
|
158
|
+
@cuda.jit(fastmath=True)
def _ppc_cuda_kernel(spike_phases, out):
    """CUDA kernel producing per-thread partial sums for PPC.

    Thread i writes out[i] = sum_{j>i} cos(spike_phases[i] - spike_phases[j]).
    Summing ``out`` afterwards gives the pairwise cosine sum used by PPC.
    """
    # Absolute 1D thread index across the whole launch grid.
    i = cuda.grid(1)
    # Guard: the launch can contain more threads than phases.
    if i < len(spike_phases):
        local_sum = 0.0
        # Thread i loops n - i - 1 times, so the workload is unbalanced across threads.
        for j in range(i+1, len(spike_phases)):
            # NOTE(review): relies on np.cos being supported inside Numba CUDA kernels;
            # math.cos is the conservative alternative — confirm for the installed Numba.
            local_sum += np.cos(spike_phases[i] - spike_phases[j])
        out[i] = local_sum
|
166
|
+
|
167
|
+
|
168
|
+
def _ppc_gpu(spike_phases):
    """
    GPU-accelerated PPC calculation.

    Launches ``_ppc_cuda_kernel`` with one thread per spike phase. Each thread
    produces the cosine sum over its j > i partners. The partial sums are copied
    back to the host, added up, and normalised by 2 / (n * (n - 1)).
    Requires n >= 2 (n <= 1 divides by zero).
    """
    n = len(spike_phases)

    device_phases = cuda.to_device(spike_phases)
    device_partials = cuda.device_array(n, dtype=np.float64)

    threads_per_block = 256
    # Ceiling division so every phase is covered by a thread.
    n_blocks = (n + threads_per_block - 1) // threads_per_block

    _ppc_cuda_kernel[n_blocks, threads_per_block](device_phases, device_partials)
    pair_sum = device_partials.copy_to_host().sum()

    normaliser = 2 / (n * (n - 1))
    return normaliser * pair_sum
|
179
|
+
|
180
|
+
|
181
|
+
def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                  lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                  lowcut: float = None, highcut: float = None,
                  bandwidth: float = 2.0, ppc_method: str = 'numpy') -> float:
    """
    Calculate Pairwise Phase Consistency (PPC) between spike times and an LFP signal.

    Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
    PPC = (2 / (N * (N - 1))) * sum_{i<j} cos(phi_i - phi_j), evaluated explicitly over
    all spike pairs.

    Parameters:
    - spike_times: Spike times in units of 1 / spike_fs (e.g. ms when spike_fs=1000).
    - lfp_signal: Local field potential time series.
    - spike_fs: Sampling frequency in Hz of the spike times. Defaults to lfp_fs; only
      needed if the spike times and the LFP use different rates.
    - lfp_fs: Sampling frequency in Hz of the LFP.
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method.
    - freq_of_interest: Centre frequency for wavelet phase extraction (required for 'wavelet').
    - lowcut, highcut: Band-pass cutoffs for the Hilbert method. If either is None
      the LFP is not filtered.
    - bandwidth: Bandwidth parameter for the wavelet.
    - ppc_method: Backend for the pairwise sum: 'numpy' (vectorised, but memory grows
      as N^2), 'numba' (parallel CPU) or 'gpu' (CUDA).

    Returns:
    - ppc: PPC value as a float. Returns 0.0 when fewer than two spikes fall inside
      the LFP (previously a tuple was returned in that case only).

    Raises:
    - ValueError: for an unknown ppc_method, a missing freq_of_interest with the
      wavelet method, or an unknown phase extraction method.
    """
    # Fail fast on an unknown backend instead of after the (expensive) phase extraction.
    # Previously this raised an undefined name (ExceptionType), i.e. a NameError.
    if ppc_method not in ('numpy', 'numba', 'gpu'):
        raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")

    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> nearest LFP sample index.
    spike_times_seconds = np.asarray(spike_times, dtype=float) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that land inside the LFP record.
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        # PPC is undefined for fewer than two spikes; return a scalar like the main path.
        return 0.0

    # Extract the LFP's instantaneous phase.
    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")
        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
        instantaneous_phase = np.angle(signal.hilbert(filtered_lfp))

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    spike_phases = instantaneous_phase[valid_indices]
    n_spikes = len(spike_phases)

    if ppc_method == 'numpy':
        # Explicit pairwise form over all i < j; materialises N(N-1)/2 pairs.
        i, j = np.triu_indices(n_spikes, k=1)
        sum_cos_diff = np.sum(np.cos(spike_phases[i] - spike_phases[j]))
        ppc = (2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff
    elif ppc_method == 'numba':
        ppc = _ppc_parallel_numba(spike_phases)
    else:  # 'gpu' (validated above)
        ppc = _ppc_gpu(spike_phases)

    return ppc
|
281
|
+
|
282
|
+
|
283
|
+
def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                   lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                   lowcut: float = None, highcut: float = None,
                   bandwidth: float = 2.0) -> float:
    """
    Calculate Pairwise Phase Consistency with an O(n) closed form (Vinck et al., 2010).

    The explicit pairwise definition

        PPC = (2 / (n * (n - 1))) * sum_{i<j} cos(phi_i - phi_j)

    is algebraically equal to

        PPC = (|sum_j exp(i * phi_j)|^2 - n) / (n * (n - 1)),

    which is what this function evaluates.

    NOTE(review): this is the single-trial PPC of Vinck et al. (2010). The PPC1/PPC2
    estimators of Vinck et al. (2012) also exclude same-trial spike pairs, so the
    "2" in the name may be misleading.

    Parameters:
    - spike_times: Spike times in units of 1 / spike_fs (e.g. ms when spike_fs=1000).
    - lfp_signal: Local field potential time series.
    - spike_fs: Sampling frequency in Hz of the spike times. Defaults to lfp_fs; only
      needed if the spike times and the LFP use different rates.
    - lfp_fs: Sampling frequency in Hz of the LFP.
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method.
    - freq_of_interest: Centre frequency for wavelet phase extraction (required for 'wavelet').
    - lowcut, highcut: Band-pass cutoffs for the Hilbert method. If either is None
      the LFP is not filtered.
    - bandwidth: Bandwidth parameter for the wavelet.

    Returns:
    - ppc2: PPC value as a float. Returns 0.0 when fewer than two spikes fall inside
      the LFP (previously a tuple was returned in that case only).

    Raises:
    - ValueError: if freq_of_interest is missing for the wavelet method or the
      method is not recognised.
    """
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> nearest LFP sample index.
    spike_times_seconds = np.asarray(spike_times, dtype=float) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that land inside the LFP record.
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        # PPC is undefined for fewer than two spikes; return a scalar like the main path
        # so callers (e.g. per-cell dictionaries) never receive a tuple.
        return 0.0

    # Extract the LFP's instantaneous phase.
    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")
        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
        instantaneous_phase = np.angle(signal.hilbert(filtered_lfp))

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    spike_phases = instantaneous_phase[valid_indices]
    n = len(spike_phases)

    # Resultant of the unit phase vectors; subtracting n removes the i == j self-pairs.
    resultant_vector = np.sum(np.exp(1j * spike_phases))
    ppc2 = (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))

    return ppc2
|
367
|
+
|
368
|
+
|
369
|
+
def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
                           spike_fs: float, lfp_fs: float,
                           pop_names: List[str], freqs: List[float]) -> Dict[str, Dict[int, Dict[float, float]]]:
    """
    Calculate pairwise phase consistency (PPC) per neuron (cell) for specified frequencies across populations.

    For every node of each requested population, PPC is computed with ``calculate_ppc2``
    using wavelet phase extraction at each frequency in ``freqs``, against a single-channel
    LFP signal. Nodes with one spike or fewer are skipped and counted in the summary message.

    Args:
        spike_df (pd.DataFrame): Spike dataframe with 'pop_name', 'node_ids' and 'timestamps'
            columns (e.g. from bmtool.analysis.load_spikes_to_df).
        lfp_signal (np.ndarray): Single-channel LFP time series. The original docstring
            referenced an xr.DataArray from bmtool.analysis.ecp_to_lfp — presumably a single
            channel of it is expected here (TODO confirm).
        spike_fs (float): Sampling rate of the spike timestamps (BMTK default is 1000).
        lfp_fs (float): Sampling rate of the LFP.
        pop_names (List[str]): Population names to compute PPC for; must appear in spike_df['pop_name'].
        freqs (List[float]): Frequencies (in Hz) at which to calculate PPC.

    Returns:
        Dict[str, Dict[int, Dict[float, float]]]: Nested dictionary
            {population_name: {node_id: {frequency: PPC value}}}.
    """
    ppc_dict = {}
    for pop in pop_names:
        pop_spikes = spike_df[spike_df['pop_name'] == pop]
        # Group once per population instead of re-filtering the population frame for
        # every node (O(n_spikes * n_nodes)). sort=False keeps nodes in order of first
        # appearance, matching the previous Series.unique() iteration order.
        node_groups = pop_spikes.groupby('node_ids', sort=False)
        n_nodes = node_groups.ngroups
        skip_count = 0
        ppc_dict[pop] = {}
        print(f'Processing {pop} population')
        for node, node_spikes in tqdm(node_groups, total=n_nodes):
            # Skip nodes with less than or equal to 1 spike
            if len(node_spikes) <= 1:
                skip_count += 1
                continue

            node_times = node_spikes['timestamps'].values
            ppc_dict[pop][node] = {
                freq: calculate_ppc2(
                    node_times,
                    lfp_signal,
                    spike_fs=spike_fs,
                    lfp_fs=lfp_fs,
                    freq_of_interest=freq,
                    method='wavelet'
                )
                for freq in freqs
            }

        print(f'Calculated PPC for {pop} population with {n_nodes-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')

    return ppc_dict
|
427
|
+
|
428
|
+
|
429
|
+
|
@@ -11,8 +11,6 @@ import matplotlib.pyplot as plt
|
|
11
11
|
from scipy import signal
|
12
12
|
import pywt
|
13
13
|
from bmtool.bmplot import is_notebook
|
14
|
-
import numba
|
15
|
-
from numba import cuda
|
16
14
|
import pandas as pd
|
17
15
|
|
18
16
|
|
@@ -295,360 +293,6 @@ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs:
|
|
295
293
|
return x_a
|
296
294
|
|
297
295
|
|
298
|
-
def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
                                method: str = 'wavelet', lowcut: float = None, highcut: float = None,
                                bandwidth: float = 2.0) -> np.ndarray:
    """
    Compute the Phase Locking Value (PLV) between two equally long signals.

    The instantaneous phase of each signal is taken either from a complex wavelet
    filter centred on ``freq_of_interest`` or from the Hilbert transform (optionally
    after a Butterworth band-pass). The PLV is the magnitude of the mean unit phasor
    of the phase difference (Lachaux et al., 1999, "Measuring phase synchrony in
    brain signals").

    Parameters:
    - x1, x2: Input signals of identical length.
    - fs: Sampling frequency in Hz.
    - freq_of_interest: Centre frequency for the wavelet method (required when method='wavelet').
    - method: 'wavelet' or 'hilbert'.
    - lowcut, highcut: Band-pass cutoffs for the Hilbert method. The signals are only
      filtered when both values are truthy; otherwise the raw signals go straight to
      the Hilbert transform.
    - bandwidth: Bandwidth parameter forwarded to the wavelet filter.

    Returns:
    - plv: PLV averaged over the last axis (a NumPy scalar for 1D inputs).

    Raises:
    - ValueError: if the signal lengths differ, freq_of_interest is missing for the
      wavelet method, or method is not recognised.
    """
    if len(x1) != len(x2):
        raise ValueError("Input signals must have the same length.")

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")
        analytic_1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
        analytic_2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")

        sig_a, sig_b = x1, x2
        # Band-pass only when both cutoffs are truthy.
        if lowcut and highcut:
            sig_a = butter_bandpass_filter(sig_a, lowcut, highcut, fs)
            sig_b = butter_bandpass_filter(sig_b, lowcut, highcut, fs)

        analytic_1 = signal.hilbert(sig_a)
        analytic_2 = signal.hilbert(sig_b)
    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # Unit phasors of the instantaneous phase difference; their mean length is the PLV.
    phasors = np.exp(1j * (np.angle(analytic_1) - np.angle(analytic_2)))
    return np.abs(np.mean(phasors, axis=-1))
|
349
|
-
|
350
|
-
|
351
|
-
def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                            lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                            lowcut: float = None, highcut: float = None,
                            bandwidth: float = 2.0) -> float:
    """
    Calculate the spike-LFP phase locking value.

    Based on Vinck et al. (2010), https://www.sciencedirect.com/science/article/pii/S1053811910000959

    The LFP phase is sampled at every spike that falls inside the LFP record. The
    returned value is the squared length of the mean resultant vector,
    |sum(exp(i * phi))|^2 / n^2, i.e. the squared PLV. This is the biased
    estimator that PPC corrects.

    Parameters:
    - spike_times: Spike times in units of 1 / spike_fs (e.g. ms when spike_fs=1000).
    - lfp_signal: Local field potential time series.
    - spike_fs: Sampling frequency in Hz of the spike times. Defaults to lfp_fs; only
      needed if the spike times and the LFP use different rates.
    - lfp_fs: Sampling frequency in Hz of the LFP.
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method.
    - freq_of_interest: Centre frequency for wavelet phase extraction (required for 'wavelet').
    - lowcut, highcut: Band-pass cutoffs for the Hilbert method. If either is None
      the LFP is not filtered.
    - bandwidth: Bandwidth parameter for the wavelet.

    Returns:
    - plv: Squared PLV as a float. Returns 0.0 when fewer than two spikes fall
      inside the LFP (previously a tuple was returned in that case only).

    Raises:
    - ValueError: if freq_of_interest is missing for the wavelet method or the
      method is not recognised.
    """
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> nearest LFP sample index.
    spike_times_seconds = np.asarray(spike_times, dtype=float) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that land inside the LFP record.
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        # Too few spikes to measure locking; return a scalar like the main path.
        return 0.0

    # Extract the LFP's instantaneous phase.
    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")
        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
        instantaneous_phase = np.angle(signal.hilbert(filtered_lfp))

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    spike_phases = instantaneous_phase[valid_indices]
    n = len(spike_phases)

    # Resultant of the unit phase vectors; squared length / n^2 is the squared PLV.
    resultant_vector = np.sum(np.exp(1j * spike_phases))
    plv = (np.abs(resultant_vector) ** 2) / (n ** 2)

    return plv
|
426
|
-
|
427
|
-
|
428
|
-
@numba.njit(parallel=True, fastmath=True)
def _ppc_parallel_numba(spike_phases):
    """Numba-optimized parallel PPC calculation.

    Computes PPC = 2 / (n * (n - 1)) * sum_{i<j} cos(phi_i - phi_j) by brute force
    over all spike pairs (O(n^2) work), parallelising the outer loop with numba.prange.

    Parameters:
    - spike_phases: LFP phases (radians) sampled at each spike.

    Returns:
    - PPC value. n must be >= 2; n <= 1 makes the normalisation divide by zero.
    """
    n = len(spike_phases)
    # Scalar updated with += inside the prange body; relies on Numba's automatic
    # reduction inference (summation order may vary between runs with parallel/fastmath).
    sum_cos = 0.0
    for i in numba.prange(n):
        phase_i = spike_phases[i]
        # j > i only, so each unordered pair is counted once; outer iterations with
        # small i do more work than those with large i.
        for j in range(i + 1, n):
            sum_cos += np.cos(phase_i - spike_phases[j])
    return (2 / (n * (n - 1))) * sum_cos
|
438
|
-
|
439
|
-
|
440
|
-
@cuda.jit(fastmath=True)
def _ppc_cuda_kernel(spike_phases, out):
    """CUDA kernel producing per-thread partial sums for PPC.

    Thread i writes out[i] = sum_{j>i} cos(spike_phases[i] - spike_phases[j]).
    Summing ``out`` afterwards gives the pairwise cosine sum used by PPC.
    """
    # Absolute 1D thread index across the whole launch grid.
    i = cuda.grid(1)
    # Guard: the launch can contain more threads than phases.
    if i < len(spike_phases):
        local_sum = 0.0
        # Thread i loops n - i - 1 times, so the workload is unbalanced across threads.
        for j in range(i+1, len(spike_phases)):
            # NOTE(review): relies on np.cos being supported inside Numba CUDA kernels;
            # math.cos is the conservative alternative — confirm for the installed Numba.
            local_sum += np.cos(spike_phases[i] - spike_phases[j])
        out[i] = local_sum
|
448
|
-
|
449
|
-
|
450
|
-
def _ppc_gpu(spike_phases):
    """
    GPU-accelerated PPC calculation.

    Launches ``_ppc_cuda_kernel`` with one thread per spike phase. Each thread
    produces the cosine sum over its j > i partners. The partial sums are copied
    back to the host, added up, and normalised by 2 / (n * (n - 1)).
    Requires n >= 2 (n <= 1 divides by zero).
    """
    n = len(spike_phases)

    device_phases = cuda.to_device(spike_phases)
    device_partials = cuda.device_array(n, dtype=np.float64)

    threads_per_block = 256
    # Ceiling division so every phase is covered by a thread.
    n_blocks = (n + threads_per_block - 1) // threads_per_block

    _ppc_cuda_kernel[n_blocks, threads_per_block](device_phases, device_partials)
    pair_sum = device_partials.copy_to_host().sum()

    normaliser = 2 / (n * (n - 1))
    return normaliser * pair_sum
|
461
|
-
|
462
|
-
|
463
|
-
def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                  lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                  lowcut: float = None, highcut: float = None,
                  bandwidth: float = 2.0, ppc_method: str = 'numpy') -> float:
    """
    Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
    Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

        PPC = (2 / (N * (N - 1))) * sum_{i<j} cos(theta_i - theta_j)

    where theta_k is the LFP phase at the k-th spike.

    Parameters:
    - spike_times: Array of spike times in units of 1/spike_fs (they are divided by
      spike_fs to obtain seconds)
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times
      and LFP have different sampling rates (defaults to lfp_fs)
    - lfp_fs: Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Desired frequency for wavelet phase extraction
    - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method); if
      either is None the raw signal is Hilbert-transformed
    - bandwidth: Bandwidth parameter for the wavelet
    - ppc_method: algorithm for the pairwise sum: 'numpy' (vectorised, O(N^2) memory),
      'numba' (parallel CPU loop) or 'gpu' (CUDA kernel)

    Returns:
    - ppc: Pairwise Phase Consistency value (float). 0.0 when fewer than two spikes
      fall inside the LFP signal.

    Raises:
    - ValueError: if `method` or `ppc_method` is not one of the supported values, or if
      `freq_of_interest` is missing for the wavelet method.
    """
    # Fail fast on an unsupported algorithm before doing any filtering work.
    if ppc_method not in ('numpy', 'numba', 'gpu'):
        raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")

    if spike_fs is None:
        spike_fs = lfp_fs
    # Convert spike times to seconds, then to sample indices at the LFP sampling rate.
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording.
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    # PPC is defined over pairs, so at least two spikes are required.
    if len(valid_indices) <= 1:
        return 0.0

    # Extract phase using the specified method
    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")
        # Apply CWT to extract phase at the frequency of interest
        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
        # Get phase using the Hilbert transform
        analytic_signal = signal.hilbert(filtered_lfp)
        instantaneous_phase = np.angle(analytic_signal)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # Get phases at spike times
    spike_phases = instantaneous_phase[valid_indices]
    n_spikes = len(spike_phases)

    if ppc_method == 'numpy':
        # Vectorised sum over all unique pairs i < j; allocates O(n_spikes**2) memory.
        i, j = np.triu_indices(n_spikes, k=1)
        sum_cos_diff = np.sum(np.cos(spike_phases[i] - spike_phases[j]))
        ppc = (2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff
    elif ppc_method == 'numba':
        ppc = _ppc_parallel_numba(spike_phases)
    else:  # 'gpu' (validated above)
        ppc = _ppc_gpu(spike_phases)
    return ppc
|
563
|
-
|
564
|
-
|
565
|
-
def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                   lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                   lowcut: float = None, highcut: float = None,
                   bandwidth: float = 2.0) -> float:
    """
    Calculate Pairwise Phase Consistency between spike times and LFP using the O(n)
    closed form (Vinck et al., 2010).

    Original definition:
        PPC = (2 / (n * (n - 1))) * sum(cos(phi_i - phi_j) for all i < j)
    Algebraically equivalent form used here:
        PPC = (|sum(e^(i*phi_j))|^2 - n) / (n * (n - 1))

    This is the same quantity that calculate_ppc computes pairwise, but in O(n) time
    and memory.  NOTE(review): despite the name, this is not the trial-corrected
    "PPC2" estimator of Vinck et al. (2012), which excludes same-trial pairs.

    Parameters:
    - spike_times: Array of spike times in units of 1/spike_fs (they are divided by
      spike_fs to obtain seconds)
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times
      and LFP have different sampling rates (defaults to lfp_fs)
    - lfp_fs: Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Desired frequency for wavelet phase extraction
    - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method); if
      either is None the raw signal is Hilbert-transformed
    - bandwidth: Bandwidth parameter for the wavelet

    Returns:
    - ppc2: Pairwise Phase Consistency value (float). 0.0 when fewer than two spikes
      fall inside the LFP signal.

    Raises:
    - ValueError: if `method` is unsupported or `freq_of_interest` is missing for the
      wavelet method.
    """
    if spike_fs is None:
        spike_fs = lfp_fs
    # Convert spike times to seconds, then to sample indices at the LFP sampling rate.
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording.
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    # PPC is defined over pairs, so at least two spikes are required.
    if len(valid_indices) <= 1:
        return 0.0

    # Extract phase using the specified method
    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")
        # Apply CWT to extract phase at the frequency of interest
        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
        # Get phase using the Hilbert transform
        analytic_signal = signal.hilbert(filtered_lfp)
        instantaneous_phase = np.angle(analytic_signal)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # Get phases at spike times
    spike_phases = instantaneous_phase[valid_indices]
    n = len(spike_phases)

    # Resultant of the spike phases as unit vectors in the complex plane.
    resultant_vector = np.sum(np.exp(1j * spike_phases))

    # |R|^2 counts every ordered pair plus the n self-pairs; removing the self-pairs
    # and normalising by the number of ordered pairs gives PPC.
    ppc2 = (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))
    return ppc2
|
650
|
-
|
651
|
-
|
652
296
|
# windowing functions
|
653
297
|
def windowed_xarray(da, windows, dim='time',
|
654
298
|
new_coord_name='cycle', new_coord=None):
|
@@ -683,6 +683,62 @@ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids
|
|
683
683
|
fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
|
684
684
|
plt.draw()
|
685
685
|
|
686
|
+
def distance_delay_plot(simulation_config: str,source: str,target: str,
                        group_by: str,sid: str,tid: str) -> None:
    """
    Plots the relationship between the distance and delay of connections between nodes in a neural network simulation.

    Loads node and edge data from the simulation config and selects edges of the
    ``f'{source}_to_{target}'`` edge population. An edge is kept when its source node
    has ``group_by == sid`` in the source population and its target node has
    ``group_by == tid`` in the target population. For each kept edge, the edge
    ``delay`` is plotted against the Euclidean distance between the endpoints'
    ``pos_x``/``pos_y``/``pos_z``.

    Source-node groups and positions are read from the source population's node table,
    and target-node ones from the target population's table. Edges between two
    different populations therefore use the correct coordinates for each endpoint.

    Args:
        simulation_config (str): Path to the simulation config file
        source (str): The name of the source population in the edge data.
        target (str): The name of the target population in the edge data.
        group_by (str): Column name to group nodes by (e.g., population name).
        sid (str): Identifier for the source group (e.g., 'PN').
        tid (str): Identifier for the target group (e.g., 'PN').

    Returns:
        None: The function creates and displays a scatter plot of distance vs delay.
        If position or delay columns are missing, a message is printed and an
        empty plot is shown.
    """
    nodes, edges = util.load_nodes_edges_from_config(simulation_config)
    source_nodes = nodes[source]
    target_nodes = nodes[target]

    # Node id is the index of each population's nodes DataFrame.
    node_id_source = source_nodes[source_nodes[group_by] == sid].index
    node_id_target = target_nodes[target_nodes[group_by] == tid].index

    edges = edges[f'{source}_to_{target}']
    edges = edges[edges['source_node_id'].isin(node_id_source) & edges['target_node_id'].isin(node_id_target)]

    pos_cols = ['pos_x', 'pos_y', 'pos_z']
    try:
        # Vectorised lookup: one row of coordinates per edge, in edge order.
        source_pos = source_nodes.loc[edges['source_node_id'].to_numpy(), pos_cols].to_numpy(dtype=float)
        target_pos = target_nodes.loc[edges['target_node_id'].to_numpy(), pos_cols].to_numpy(dtype=float)
        delays = edges['delay'].to_numpy()
    except KeyError as e:
        # Best-effort: report the missing column/id instead of crashing the plot call.
        print(f"KeyError: Missing key {e} in either edge properties or node positions.")
        distances = np.array([])
        delays = np.array([])
    else:
        distances = np.linalg.norm(source_pos - target_pos, axis=1)

    plt.scatter(distances, delays)
    plt.xlabel('Distance')
    plt.ylabel('Delay')
    plt.title(f'Distance vs Delay for edge between {sid} and {tid}')
    plt.show()
|
741
|
+
|
686
742
|
def plot_synapse_location_histograms(config, target_model, source=None, target=None):
|
687
743
|
"""
|
688
744
|
generates a histogram of the positions of the synapses on a cell broken down by section
|
@@ -647,7 +647,7 @@ class SynapseTuner:
|
|
647
647
|
# Widgets setup (Sliders)
|
648
648
|
freqs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 35, 50, 100, 200]
|
649
649
|
delays = [125, 250, 500, 1000, 2000, 4000]
|
650
|
-
durations = [300, 500, 1000, 2000, 5000, 10000]
|
650
|
+
durations = [100, 300, 500, 1000, 2000, 5000, 10000]
|
651
651
|
freq0 = 50
|
652
652
|
delay0 = 250
|
653
653
|
duration0 = 300
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|