bmtool 0.6.9.23__py3-none-any.whl → 0.6.9.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +429 -0
- bmtool/analysis/lfp.py +0 -356
- bmtool/analysis/spikes.py +53 -0
- bmtool/bmplot.py +151 -1
- bmtool/synapses.py +1 -1
- bmtool/util/util.py +28 -0
- {bmtool-0.6.9.23.dist-info → bmtool-0.6.9.25.dist-info}/METADATA +1 -1
- {bmtool-0.6.9.23.dist-info → bmtool-0.6.9.25.dist-info}/RECORD +12 -11
- {bmtool-0.6.9.23.dist-info → bmtool-0.6.9.25.dist-info}/WHEEL +1 -1
- {bmtool-0.6.9.23.dist-info → bmtool-0.6.9.25.dist-info}/entry_points.txt +0 -0
- {bmtool-0.6.9.23.dist-info → bmtool-0.6.9.25.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.6.9.23.dist-info → bmtool-0.6.9.25.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,429 @@
|
|
1
|
+
"""
|
2
|
+
Module for entrainment analysis
|
3
|
+
"""
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
from scipy import signal
|
7
|
+
import numba
|
8
|
+
from numba import cuda
|
9
|
+
import pandas as pd
|
10
|
+
import xarray as xr
|
11
|
+
from .lfp import wavelet_filter,butter_bandpass_filter
|
12
|
+
from typing import Dict, List
|
13
|
+
from tqdm.notebook import tqdm
|
14
|
+
|
15
|
+
|
16
|
+
def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
                                method: str = 'wavelet', lowcut: float = None, highcut: float = None,
                                bandwidth: float = 2.0) -> np.ndarray:
    """
    Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.

    PLV = |mean(exp(i * (phi1 - phi2)))| ("Measuring phase synchrony in brain signals", 1999),
    where phi1/phi2 are the instantaneous phases of x1/x2. The mean is taken over the last axis.

    Parameters:
    - x1, x2: Input signals (same length)
    - fs: Sampling frequency in Hz
    - freq_of_interest: Desired frequency for wavelet PLV calculation (required for 'wavelet')
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - lowcut, highcut: Bandpass cutoff frequencies for the Hilbert method. If either is None
      the signals are not filtered before the Hilbert transform.
    - bandwidth: Bandwidth parameter for the wavelet

    Returns:
    - plv: Phase Locking Value in [0, 1], reduced over the last axis (a scalar for 1D phases)

    Raises:
    - ValueError: if the signals differ in length, freq_of_interest is missing for the
      wavelet method, or method is not recognised.
    """
    if len(x1) != len(x2):
        raise ValueError("Input signals must have the same length.")

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        # Complex wavelet coefficients of both signals at the frequency of interest
        theta1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
        theta2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")
        else:
            # Explicit None checks (as elsewhere in this module): a falsy cutoff such as 0
            # must not silently disable filtering without the warning above.
            x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
            x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)

        # Analytic signals; their angle is the instantaneous phase
        theta1 = signal.hilbert(x1)
        theta2 = signal.hilbert(x2)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # Instantaneous phase difference between the two signals
    phase_diff = np.angle(theta1) - np.angle(theta2)

    # Length of the mean phase-difference vector
    plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))

    return plv
|
67
|
+
|
68
|
+
|
69
|
+
def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs : float = None,
                            lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                            lowcut: float = None, highcut: float = None,
                            bandwidth: float = 2.0) -> float:
    """
    Calculate spike-LFP phase locking. Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

    Parameters:
    - spike_times: Array of spike times, in units of 1/spike_fs seconds
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times and
      LFP have different fs (defaults to lfp_fs)
    - lfp_fs : Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Desired frequency for wavelet phase extraction (required for 'wavelet')
    - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method); if either
      is None the LFP is not filtered
    - bandwidth: Bandwidth parameter for the wavelet

    Returns:
    - plv: the *squared* phase locking value |sum(exp(i*phi))|^2 / n^2 over the LFP phases at
      the n spikes inside the LFP, or 0.0 when fewer than two spikes fall inside the LFP.
      Always a scalar.

    Raises:
    - ValueError: for an unknown method or a missing freq_of_interest with 'wavelet'
    """
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times are in units of 1/spike_fs: convert to seconds, then to LFP sample indices
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        # Scalar, matching the type of the normal return path
        return 0.0

    # Extract phase using the specified method
    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        # CWT coefficients at the frequency of interest
        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)

        analytic_signal = signal.hilbert(filtered_lfp)
        instantaneous_phase = np.angle(analytic_signal)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # LFP phases at the spike samples
    spike_phases = instantaneous_phase[valid_indices]
    n = len(spike_phases)

    # Resultant of the spike phases as unit vectors in the complex plane
    resultant_vector = np.sum(np.exp(1j * spike_phases))

    # Squared length of the resultant divided by n^2 (kept as-is for backward compatibility)
    plv = (np.abs(resultant_vector) ** 2) / (n ** 2)

    return plv
|
144
|
+
|
145
|
+
|
146
|
+
@numba.njit(parallel=True, fastmath=True)
def _ppc_parallel_numba(spike_phases):
    """
    Numba-optimized parallel PPC calculation.

    Computes PPC = 2 / (n * (n - 1)) * sum_{i<j} cos(phi_i - phi_j) over a 1D array of phases,
    parallelising the outer loop with prange. Returns 0.0 when fewer than two phases are
    given, where the normalisation would divide by zero.
    """
    n = len(spike_phases)
    if n < 2:
        return 0.0
    sum_cos = 0.0
    for i in numba.prange(n):
        phase_i = spike_phases[i]
        # Accumulate each row locally, then fold it into the reduction variable once
        row_sum = 0.0
        for j in range(i + 1, n):
            row_sum += np.cos(phase_i - spike_phases[j])
        sum_cos += row_sum
    return (2 / (n * (n - 1))) * sum_cos
|
156
|
+
|
157
|
+
|
158
|
+
@cuda.jit(fastmath=True)
def _ppc_cuda_kernel(spike_phases, out):
    """
    CUDA kernel launched with one thread per spike.

    Thread ``idx`` stores ``sum_{k > idx} cos(phase[idx] - phase[k])`` in ``out[idx]``, i.e. one
    row of the upper-triangular pairwise-cosine sum; the host adds up ``out`` afterwards.
    Threads whose index is past the end of ``spike_phases`` write nothing.
    """
    idx = cuda.grid(1)
    total = len(spike_phases)
    if idx >= total:
        return
    ref_phase = spike_phases[idx]
    acc = 0.0
    for k in range(idx + 1, total):
        acc += np.cos(ref_phase - spike_phases[k])
    out[idx] = acc
|
166
|
+
|
167
|
+
|
168
|
+
def _ppc_gpu(spike_phases):
|
169
|
+
"""GPU-accelerated implementation"""
|
170
|
+
d_phases = cuda.to_device(spike_phases)
|
171
|
+
d_out = cuda.device_array(len(spike_phases), dtype=np.float64)
|
172
|
+
|
173
|
+
threads = 256
|
174
|
+
blocks = (len(spike_phases) + threads - 1) // threads
|
175
|
+
|
176
|
+
_ppc_cuda_kernel[blocks, threads](d_phases, d_out)
|
177
|
+
total = d_out.copy_to_host().sum()
|
178
|
+
return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
|
179
|
+
|
180
|
+
|
181
|
+
def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                  lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                  lowcut: float = None, highcut: float = None,
                  bandwidth: float = 2.0, ppc_method: str = 'numpy') -> float:
    """
    Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
    Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

    PPC = 2 / (n * (n - 1)) * sum_{i<j} cos(phi_i - phi_j), where phi are the LFP phases at
    the n spikes that fall inside the LFP.

    Parameters:
    - spike_times: Array of spike times, in units of 1/spike_fs seconds
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times and
      LFP have different fs (defaults to lfp_fs)
    - lfp_fs: Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Desired frequency for wavelet phase extraction (required for 'wavelet')
    - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method); if either
      is None the LFP is not filtered
    - bandwidth: Bandwidth parameter for the wavelet
    - ppc_method: algorithm for the PPC sum: 'numpy' (closed form, O(n) time and memory),
      'numba' (parallel CPU pairwise loop) or 'gpu' (CUDA pairwise kernel)

    Returns:
    - ppc: Pairwise Phase Consistency value (a scalar), or 0.0 when fewer than two spikes
      fall inside the LFP

    Raises:
    - ValueError: for an unknown method or ppc_method, or a missing freq_of_interest with
      'wavelet'
    """
    # Fail fast, before any expensive phase extraction
    if ppc_method not in ('numpy', 'numba', 'gpu'):
        raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")

    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times are in units of 1/spike_fs: convert to seconds, then to LFP sample indices
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        # Scalar, matching the type of the normal return path
        return 0.0

    # Extract phase using the specified method
    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)

        analytic_signal = signal.hilbert(filtered_lfp)
        instantaneous_phase = np.angle(analytic_signal)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # LFP phases at the spike samples (at least two, guaranteed by the check above)
    spike_phases = instantaneous_phase[valid_indices]
    n_spikes = len(spike_phases)

    if ppc_method == 'numpy':
        # sum_{i<j} cos(phi_i - phi_j) == (|sum_k exp(i*phi_k)|^2 - n) / 2, so the pairwise
        # sum needs O(n) time and memory instead of materialising all n*(n-1)/2 pairs.
        resultant = np.sum(np.exp(1j * spike_phases))
        sum_cos_diff = (np.abs(resultant) ** 2 - n_spikes) / 2
        ppc = (2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff
    elif ppc_method == 'numba':
        ppc = _ppc_parallel_numba(spike_phases)
    else:  # 'gpu' (validated at the top)
        ppc = _ppc_gpu(spike_phases)
    return ppc
|
281
|
+
|
282
|
+
|
283
|
+
def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                   lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                   lowcut: float = None, highcut: float = None,
                   bandwidth: float = 2.0) -> float:
    """
    Calculate Pairwise Phase Consistency (PPC2 Calculation, Vinck et al., 2010) between
    spike times and an LFP signal.

    Original equation:
        PPC = (2 / (n * (n - 1))) * sum(cos(phi_i - phi_j) for all i < j)
    Algebraically equivalent O(n) form used here:
        PPC = (|sum(e^(i*phi_j))|^2 - n) / (n * (n - 1))

    Parameters:
    - spike_times: Array of spike times, in units of 1/spike_fs seconds
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times and
      LFP have different fs (defaults to lfp_fs)
    - lfp_fs: Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Desired frequency for wavelet phase extraction (required for 'wavelet')
    - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method); if either
      is None the LFP is not filtered
    - bandwidth: Bandwidth parameter for the wavelet

    Returns:
    - ppc2: Pairwise Phase Consistency value (a scalar), or 0.0 when fewer than two spikes
      fall inside the LFP

    Raises:
    - ValueError: for an unknown method or a missing freq_of_interest with 'wavelet'
    """
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times are in units of 1/spike_fs: convert to seconds, then to LFP sample indices
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        # Scalar, so callers (e.g. per-cell tables) never receive a tuple
        return 0.0

    # Extract phase using the specified method
    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)

        analytic_signal = signal.hilbert(filtered_lfp)
        instantaneous_phase = np.angle(analytic_signal)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # LFP phases at the spike samples (at least two, guaranteed by the check above)
    spike_phases = instantaneous_phase[valid_indices]
    n = len(spike_phases)

    # Resultant of the spike phases as unit vectors in the complex plane
    resultant_vector = np.sum(np.exp(1j * spike_phases))

    # PPC2 = (|sum(e^(i*phi_j))|^2 - n) / (n * (n - 1))
    ppc2 = (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))

    return ppc2
|
367
|
+
|
368
|
+
|
369
|
+
def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
                           spike_fs: float, lfp_fs: float,
                           pop_names: List[str], freqs: List[float]) -> Dict[str, Dict[int, Dict[float, float]]]:
    """
    Calculate pairwise phase consistency (PPC) per neuron (cell) for specified frequencies
    across different populations.

    This computes the same value as calculate_ppc2(..., method='wavelet') for each neuron in
    the specified populations, based on their spike times and a single-channel LFP signal.
    The wavelet phase of the LFP at each frequency does not depend on the cell. It is
    therefore computed once per frequency and reused, instead of being recomputed for every
    cell. This keeps one phase array per frequency in memory.

    Args:
        spike_df (pd.DataFrame): Spike dataframe with 'pop_name', 'node_ids' and 'timestamps'
            columns; use bmtool.analysis.load_spikes_to_df
        lfp_signal (np.ndarray): single-channel LFP time series; use bmtool.analysis.ecp_to_lfp
        spike_fs (float): sampling rate of spikes; BMTK default is 1000
        lfp_fs (float): sampling rate of lfp
        pop_names (List[str]): population names to compute PPC for; should be in spike_df
        freqs (List[float]): frequencies (in Hz) at which to calculate PPC

    Returns:
        Dict[str, Dict[int, Dict[float, float]]]: nested dictionary of the form
            {population_name: {node_id: {frequency: PPC value}}}
        Cells with one spike or fewer are omitted. A cell with fewer than two spikes inside
        the LFP gets a PPC value of 0.0.
    """
    # Loop-invariant: LFP phase per frequency (same wavelet call as calculate_ppc2's default)
    phase_by_freq = {
        freq: np.angle(wavelet_filter(x=lfp_signal, freq=freq, fs=lfp_fs, bandwidth=2.0))
        for freq in freqs
    }
    n_lfp = len(lfp_signal)

    ppc_dict = {}
    for pop in pop_names:
        skip_count = 0
        pop_spikes = spike_df[spike_df['pop_name'] == pop]
        nodes = pop_spikes['node_ids'].unique()
        # Group once rather than re-filtering the population frame for every node
        times_by_node = {node: grp['timestamps'].values
                         for node, grp in pop_spikes.groupby('node_ids', sort=False)}
        ppc_dict[pop] = {}
        print(f'Processing {pop} population')
        for node in tqdm(nodes):
            node_times = times_by_node.get(node, np.array([]))

            # Skip nodes with less than or equal to 1 spike
            if len(node_times) <= 1:
                skip_count += 1
                continue

            # Same conversion as calculate_ppc2: spike units -> seconds -> LFP sample indices
            spike_indices = np.round((node_times / spike_fs) * lfp_fs).astype(int)
            spike_indices = spike_indices[(spike_indices >= 0) & (spike_indices < n_lfp)]
            n = len(spike_indices)

            ppc_dict[pop][node] = {}
            for freq in freqs:
                if n <= 1:
                    ppc_dict[pop][node][freq] = 0.0
                    continue
                spike_phases = phase_by_freq[freq][spike_indices]
                resultant_vector = np.sum(np.exp(1j * spike_phases))
                # PPC2 = (|sum(e^(i*phi_j))|^2 - n) / (n * (n - 1))
                ppc_dict[pop][node][freq] = (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))

        print(f'Calculated PPC for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')

    return ppc_dict
|
427
|
+
|
428
|
+
|
429
|
+
|
bmtool/analysis/lfp.py
CHANGED
@@ -11,8 +11,6 @@ import matplotlib.pyplot as plt
|
|
11
11
|
from scipy import signal
|
12
12
|
import pywt
|
13
13
|
from bmtool.bmplot import is_notebook
|
14
|
-
import numba
|
15
|
-
from numba import cuda
|
16
14
|
import pandas as pd
|
17
15
|
|
18
16
|
|
@@ -295,360 +293,6 @@ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs:
|
|
295
293
|
return x_a
|
296
294
|
|
297
295
|
|
298
|
-
def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
                                method: str = 'wavelet', lowcut: float = None, highcut: float = None,
                                bandwidth: float = 2.0) -> np.ndarray:
    """
    Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.

    PLV = |mean(exp(i * (phi1 - phi2)))| ("Measuring phase synchrony in brain signals", 1999),
    where phi1/phi2 are the instantaneous phases of x1/x2. The mean is taken over the last axis.

    Parameters:
    - x1, x2: Input signals (same length)
    - fs: Sampling frequency in Hz
    - freq_of_interest: Desired frequency for wavelet PLV calculation (required for 'wavelet')
    - method: 'wavelet' or 'hilbert' to choose the PLV calculation method
    - lowcut, highcut: Bandpass cutoff frequencies for the Hilbert method. If either is None
      the signals are not filtered before the Hilbert transform.
    - bandwidth: Bandwidth parameter for the wavelet

    Returns:
    - plv: Phase Locking Value in [0, 1], reduced over the last axis (a scalar for 1D phases)

    Raises:
    - ValueError: if the signals differ in length, freq_of_interest is missing for the
      wavelet method, or method is not recognised.
    """
    if len(x1) != len(x2):
        raise ValueError("Input signals must have the same length.")

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        # Complex wavelet coefficients of both signals at the frequency of interest
        theta1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
        theta2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")
        else:
            # Explicit None checks: a falsy cutoff such as 0 must not silently disable
            # filtering without the warning above.
            x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
            x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)

        # Analytic signals; their angle is the instantaneous phase
        theta1 = signal.hilbert(x1)
        theta2 = signal.hilbert(x2)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # Instantaneous phase difference between the two signals
    phase_diff = np.angle(theta1) - np.angle(theta2)

    # Length of the mean phase-difference vector
    plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))

    return plv
|
349
|
-
|
350
|
-
|
351
|
-
def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs : float = None,
                            lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                            lowcut: float = None, highcut: float = None,
                            bandwidth: float = 2.0) -> float:
    """
    Calculate spike-LFP phase locking. Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

    Parameters:
    - spike_times: Array of spike times, in units of 1/spike_fs seconds
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times and
      LFP have different fs (defaults to lfp_fs)
    - lfp_fs : Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Desired frequency for wavelet phase extraction (required for 'wavelet')
    - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method); if either
      is None the LFP is not filtered
    - bandwidth: Bandwidth parameter for the wavelet

    Returns:
    - plv: the *squared* phase locking value |sum(exp(i*phi))|^2 / n^2 over the LFP phases at
      the n spikes inside the LFP, or 0.0 when fewer than two spikes fall inside the LFP.
      Always a scalar.

    Raises:
    - ValueError: for an unknown method or a missing freq_of_interest with 'wavelet'
    """
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times are in units of 1/spike_fs: convert to seconds, then to LFP sample indices
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        # Scalar, matching the type of the normal return path
        return 0.0

    # Extract phase using the specified method
    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)

        analytic_signal = signal.hilbert(filtered_lfp)
        instantaneous_phase = np.angle(analytic_signal)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # LFP phases at the spike samples
    spike_phases = instantaneous_phase[valid_indices]
    n = len(spike_phases)

    # Resultant of the spike phases as unit vectors in the complex plane
    resultant_vector = np.sum(np.exp(1j * spike_phases))

    # Squared length of the resultant divided by n^2 (kept as-is for backward compatibility)
    plv = (np.abs(resultant_vector) ** 2) / (n ** 2)

    return plv
|
426
|
-
|
427
|
-
|
428
|
-
@numba.njit(parallel=True, fastmath=True)
def _ppc_parallel_numba(spike_phases):
    """
    Numba-optimized parallel PPC calculation.

    Computes PPC = 2 / (n * (n - 1)) * sum_{i<j} cos(phi_i - phi_j) over a 1D array of phases,
    parallelising the outer loop with prange. Returns 0.0 when fewer than two phases are
    given, where the normalisation would divide by zero.
    """
    n = len(spike_phases)
    if n < 2:
        return 0.0
    sum_cos = 0.0
    for i in numba.prange(n):
        phase_i = spike_phases[i]
        # Accumulate each row locally, then fold it into the reduction variable once
        row_sum = 0.0
        for j in range(i + 1, n):
            row_sum += np.cos(phase_i - spike_phases[j])
        sum_cos += row_sum
    return (2 / (n * (n - 1))) * sum_cos
|
438
|
-
|
439
|
-
|
440
|
-
@cuda.jit(fastmath=True)
def _ppc_cuda_kernel(spike_phases, out):
    """
    CUDA kernel launched with one thread per spike.

    Thread ``idx`` stores ``sum_{k > idx} cos(phase[idx] - phase[k])`` in ``out[idx]``, i.e. one
    row of the upper-triangular pairwise-cosine sum; the host adds up ``out`` afterwards.
    Threads whose index is past the end of ``spike_phases`` write nothing.
    """
    idx = cuda.grid(1)
    total = len(spike_phases)
    if idx >= total:
        return
    ref_phase = spike_phases[idx]
    acc = 0.0
    for k in range(idx + 1, total):
        acc += np.cos(ref_phase - spike_phases[k])
    out[idx] = acc
|
448
|
-
|
449
|
-
|
450
|
-
def _ppc_gpu(spike_phases):
|
451
|
-
"""GPU-accelerated implementation"""
|
452
|
-
d_phases = cuda.to_device(spike_phases)
|
453
|
-
d_out = cuda.device_array(len(spike_phases), dtype=np.float64)
|
454
|
-
|
455
|
-
threads = 256
|
456
|
-
blocks = (len(spike_phases) + threads - 1) // threads
|
457
|
-
|
458
|
-
_ppc_cuda_kernel[blocks, threads](d_phases, d_out)
|
459
|
-
total = d_out.copy_to_host().sum()
|
460
|
-
return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
|
461
|
-
|
462
|
-
|
463
|
-
def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                  lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                  lowcut: float = None, highcut: float = None,
                  bandwidth: float = 2.0, ppc_method: str = 'numpy') -> float:
    """
    Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
    Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

    PPC = 2 / (n * (n - 1)) * sum_{i<j} cos(phi_i - phi_j), where phi are the LFP phases at
    the n spikes that fall inside the LFP.

    Parameters:
    - spike_times: Array of spike times, in units of 1/spike_fs seconds
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times and
      LFP have different fs (defaults to lfp_fs)
    - lfp_fs: Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Desired frequency for wavelet phase extraction (required for 'wavelet')
    - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method); if either
      is None the LFP is not filtered
    - bandwidth: Bandwidth parameter for the wavelet
    - ppc_method: algorithm for the PPC sum: 'numpy' (closed form, O(n) time and memory),
      'numba' (parallel CPU pairwise loop) or 'gpu' (CUDA pairwise kernel)

    Returns:
    - ppc: Pairwise Phase Consistency value (a scalar), or 0.0 when fewer than two spikes
      fall inside the LFP

    Raises:
    - ValueError: for an unknown method or ppc_method, or a missing freq_of_interest with
      'wavelet'
    """
    # Fail fast, before any expensive phase extraction
    if ppc_method not in ('numpy', 'numba', 'gpu'):
        raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")

    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times are in units of 1/spike_fs: convert to seconds, then to LFP sample indices
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        # Scalar, matching the type of the normal return path
        return 0.0

    # Extract phase using the specified method
    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)

        analytic_signal = signal.hilbert(filtered_lfp)
        instantaneous_phase = np.angle(analytic_signal)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # LFP phases at the spike samples (at least two, guaranteed by the check above)
    spike_phases = instantaneous_phase[valid_indices]
    n_spikes = len(spike_phases)

    if ppc_method == 'numpy':
        # sum_{i<j} cos(phi_i - phi_j) == (|sum_k exp(i*phi_k)|^2 - n) / 2, so the pairwise
        # sum needs O(n) time and memory instead of materialising all n*(n-1)/2 pairs.
        resultant = np.sum(np.exp(1j * spike_phases))
        sum_cos_diff = (np.abs(resultant) ** 2 - n_spikes) / 2
        ppc = (2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff
    elif ppc_method == 'numba':
        ppc = _ppc_parallel_numba(spike_phases)
    else:  # 'gpu' (validated at the top)
        ppc = _ppc_gpu(spike_phases)
    return ppc
|
563
|
-
|
564
|
-
|
565
|
-
def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
|
566
|
-
lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
|
567
|
-
lowcut: float = None, highcut: float = None,
|
568
|
-
bandwidth: float = 2.0) -> tuple:
|
569
|
-
"""
|
570
|
-
# -----------------------------------------------------------------------------
|
571
|
-
# PPC2 Calculation (Vinck et al., 2010)
|
572
|
-
# -----------------------------------------------------------------------------
|
573
|
-
# Equation(Original):
|
574
|
-
# PPC = (2 / (n * (n - 1))) * sum(cos(φ_i - φ_j) for all i < j)
|
575
|
-
# Optimized Formula (Algebraically Equivalent):
|
576
|
-
# PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
|
577
|
-
# -----------------------------------------------------------------------------
|
578
|
-
|
579
|
-
Parameters:
|
580
|
-
- spike_times: Array of spike times
|
581
|
-
- lfp_signal: Local field potential time series
|
582
|
-
- spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
|
583
|
-
- lfp_fs: Sampling frequency in Hz of the LFP
|
584
|
-
- method: 'wavelet' or 'hilbert' to choose the phase extraction method
|
585
|
-
- freq_of_interest: Desired frequency for wavelet phase extraction
|
586
|
-
- lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
|
587
|
-
- bandwidth: Bandwidth parameter for the wavelet
|
588
|
-
|
589
|
-
Returns:
|
590
|
-
- ppc2: Pairwise Phase Consistency 2 value
|
591
|
-
- spike_phases: Phases at spike times
|
592
|
-
"""
|
593
|
-
|
594
|
-
if spike_fs is None:
|
595
|
-
spike_fs = lfp_fs
|
596
|
-
# Convert spike times to sample indices
|
597
|
-
spike_times_seconds = spike_times / spike_fs
|
598
|
-
|
599
|
-
# Then convert from seconds to samples at the new sampling rate
|
600
|
-
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
601
|
-
|
602
|
-
# Filter indices to ensure they're within bounds of the LFP signal
|
603
|
-
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
|
604
|
-
if len(valid_indices) <= 1:
|
605
|
-
return 0, np.array([])
|
606
|
-
|
607
|
-
# Extract phase using the specified method
|
608
|
-
if method == 'wavelet':
|
609
|
-
if freq_of_interest is None:
|
610
|
-
raise ValueError("freq_of_interest must be provided for the wavelet method.")
|
611
|
-
|
612
|
-
# Apply CWT to extract phase at the frequency of interest
|
613
|
-
lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
|
614
|
-
instantaneous_phase = np.angle(lfp_complex)
|
615
|
-
|
616
|
-
elif method == 'hilbert':
|
617
|
-
if lowcut is None or highcut is None:
|
618
|
-
print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
|
619
|
-
filtered_lfp = lfp_signal
|
620
|
-
else:
|
621
|
-
# Bandpass filter the signal
|
622
|
-
filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
|
623
|
-
|
624
|
-
# Get phase using the Hilbert transform
|
625
|
-
analytic_signal = signal.hilbert(filtered_lfp)
|
626
|
-
instantaneous_phase = np.angle(analytic_signal)
|
627
|
-
|
628
|
-
else:
|
629
|
-
raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
|
630
|
-
|
631
|
-
# Get phases at spike times
|
632
|
-
spike_phases = instantaneous_phase[valid_indices]
|
633
|
-
|
634
|
-
# Calculate PPC2 according to Vinck et al. (2010), Equation 6
|
635
|
-
n = len(spike_phases)
|
636
|
-
|
637
|
-
if n <= 1:
|
638
|
-
return 0, spike_phases
|
639
|
-
|
640
|
-
# Convert phases to unit vectors in the complex plane
|
641
|
-
unit_vectors = np.exp(1j * spike_phases)
|
642
|
-
|
643
|
-
# Calculate the resultant vector
|
644
|
-
resultant_vector = np.sum(unit_vectors)
|
645
|
-
|
646
|
-
# PPC2 = (|∑(e^(i*φ_j))|² - n) / (n * (n - 1))
|
647
|
-
ppc2 = (np.abs(resultant_vector)**2 - n) / (n * (n - 1))
|
648
|
-
|
649
|
-
return ppc2
|
650
|
-
|
651
|
-
|
652
296
|
# windowing functions
|
653
297
|
def windowed_xarray(da, windows, dim='time',
|
654
298
|
new_coord_name='cycle', new_coord=None):
|
bmtool/analysis/spikes.py
CHANGED
@@ -7,6 +7,7 @@ import pandas as pd
|
|
7
7
|
from bmtool.util.util import load_nodes_from_config
|
8
8
|
from typing import Dict, Optional,Tuple, Union, List
|
9
9
|
import numpy as np
|
10
|
+
from scipy.stats import mannwhitneyu
|
10
11
|
import os
|
11
12
|
|
12
13
|
|
@@ -252,3 +253,55 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
|
|
252
253
|
|
253
254
|
return spike_rate
|
254
255
|
|
256
|
+
|
257
|
+
def compare_firing_over_times(spike_df,group_by, time_window_1, time_window_2):
|
258
|
+
"""
|
259
|
+
Compares the firing rates of a population during two different time windows
|
260
|
+
time_window_1 and time_window_2 should be a list of [start, stop] in milliseconds
|
261
|
+
Returns firing rates and results of a Mann-Whitney U test (non-parametric)
|
262
|
+
"""
|
263
|
+
# Filter spikes for the population of interest
|
264
|
+
for pop_name in spike_df[group_by].unique():
|
265
|
+
print(f"Population: {pop_name}")
|
266
|
+
pop_spikes = spike_df[spike_df[group_by] == pop_name]
|
267
|
+
|
268
|
+
# Filter by time windows
|
269
|
+
pop_spikes_1 = pop_spikes[(pop_spikes['timestamps'] >= time_window_1[0]) & (pop_spikes['timestamps'] <= time_window_1[1])]
|
270
|
+
pop_spikes_2 = pop_spikes[(pop_spikes['timestamps'] >= time_window_2[0]) & (pop_spikes['timestamps'] <= time_window_2[1])]
|
271
|
+
|
272
|
+
# Get unique neuron IDs
|
273
|
+
unique_neurons = pop_spikes['node_ids'].unique()
|
274
|
+
|
275
|
+
# Calculate firing rates per neuron for each time window in Hz
|
276
|
+
neuron_rates_1 = []
|
277
|
+
neuron_rates_2 = []
|
278
|
+
|
279
|
+
for neuron in unique_neurons:
|
280
|
+
# Count spikes for this neuron in each window
|
281
|
+
n_spikes_1 = len(pop_spikes_1[pop_spikes_1['node_ids'] == neuron])
|
282
|
+
n_spikes_2 = len(pop_spikes_2[pop_spikes_2['node_ids'] == neuron])
|
283
|
+
|
284
|
+
# Calculate firing rate in Hz (convert ms to seconds by dividing by 1000)
|
285
|
+
rate_1 = n_spikes_1 / ((time_window_1[1] - time_window_1[0]) / 1000)
|
286
|
+
rate_2 = n_spikes_2 / ((time_window_2[1] - time_window_2[0]) / 1000)
|
287
|
+
|
288
|
+
neuron_rates_1.append(rate_1)
|
289
|
+
neuron_rates_2.append(rate_2)
|
290
|
+
|
291
|
+
# Calculate average firing rates
|
292
|
+
avg_firing_rate_1 = np.mean(neuron_rates_1) if neuron_rates_1 else 0
|
293
|
+
avg_firing_rate_2 = np.mean(neuron_rates_2) if neuron_rates_2 else 0
|
294
|
+
|
295
|
+
# Perform Mann-Whitney U test
|
296
|
+
# Handle the case when one or both arrays are empty
|
297
|
+
if len(neuron_rates_1) > 0 and len(neuron_rates_2) > 0:
|
298
|
+
u_stat, p_val = mannwhitneyu(neuron_rates_1, neuron_rates_2, alternative='two-sided')
|
299
|
+
else:
|
300
|
+
u_stat, p_val = np.nan, np.nan
|
301
|
+
|
302
|
+
print(f" Average firing rate in window 1: {avg_firing_rate_1:.2f} Hz")
|
303
|
+
print(f" Average firing rate in window 2: {avg_firing_rate_2:.2f} Hz")
|
304
|
+
print(f" U-statistic: {u_stat:.2f}")
|
305
|
+
print(f" p-value: {p_val}")
|
306
|
+
print(f" Significant difference (p<0.05): {'Yes' if p_val < 0.05 else 'No'}")
|
307
|
+
return
|
bmtool/bmplot.py
CHANGED
@@ -23,8 +23,9 @@ import sys
|
|
23
23
|
import re
|
24
24
|
from typing import Optional, Dict, Union, List
|
25
25
|
|
26
|
-
from .util.util import CellVarsFile,load_nodes_from_config #, missing_units
|
26
|
+
from .util.util import CellVarsFile,load_nodes_from_config,load_templates_from_config #, missing_units
|
27
27
|
from bmtk.analyzer.utils import listify
|
28
|
+
from neuron import h
|
28
29
|
|
29
30
|
use_description = """
|
30
31
|
|
@@ -682,6 +683,155 @@ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids
|
|
682
683
|
fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
|
683
684
|
plt.draw()
|
684
685
|
|
686
|
+
def distance_delay_plot(simulation_config: str,source: str,target: str,
|
687
|
+
group_by: str,sid: str,tid: str) -> None:
|
688
|
+
"""
|
689
|
+
Plots the relationship between the distance and delay of connections between nodes in a neural network simulation.
|
690
|
+
|
691
|
+
This function loads the node and edge data from a simulation configuration file, filters nodes by population or group,
|
692
|
+
identifies connections (edges) between source and target node populations, calculates the Euclidean distance between
|
693
|
+
connected nodes, and plots the delay as a function of distance.
|
694
|
+
|
695
|
+
Args:
|
696
|
+
simulation_config (str): Path to the simulation config file
|
697
|
+
source (str): The name of the source population in the edge data.
|
698
|
+
target (str): The name of the target population in the edge data.
|
699
|
+
group_by (str): Column name to group nodes by (e.g., population name).
|
700
|
+
sid (str): Identifier for the source group (e.g., 'PN').
|
701
|
+
tid (str): Identifier for the target group (e.g., 'PN').
|
702
|
+
|
703
|
+
Returns:
|
704
|
+
None: The function creates and displays a scatter plot of distance vs delay.
|
705
|
+
"""
|
706
|
+
nodes, edges = util.load_nodes_edges_from_config(simulation_config)
|
707
|
+
nodes = nodes[target]
|
708
|
+
#node id is index of nodes df
|
709
|
+
node_id_source = nodes[nodes[group_by] == sid].index
|
710
|
+
node_id_target = nodes[nodes[group_by] == tid].index
|
711
|
+
|
712
|
+
edges = edges[f'{source}_to_{target}']
|
713
|
+
edges = edges[edges['source_node_id'].isin(node_id_source) & edges['target_node_id'].isin(node_id_target)]
|
714
|
+
|
715
|
+
stuff_to_plot = []
|
716
|
+
for index, row in edges.iterrows():
|
717
|
+
try:
|
718
|
+
source_node = row['source_node_id']
|
719
|
+
target_node = row['target_node_id']
|
720
|
+
|
721
|
+
source_pos = nodes.loc[[source_node], ['pos_x', 'pos_y', 'pos_z']]
|
722
|
+
target_pos = nodes.loc[[target_node], ['pos_x', 'pos_y', 'pos_z']]
|
723
|
+
|
724
|
+
distance = np.linalg.norm(source_pos.values - target_pos.values)
|
725
|
+
|
726
|
+
delay = row['delay'] # This line may raise KeyError
|
727
|
+
stuff_to_plot.append([distance, delay])
|
728
|
+
|
729
|
+
except KeyError as e:
|
730
|
+
print(f"KeyError: Missing key {e} in either edge properties or node positions.")
|
731
|
+
except IndexError as e:
|
732
|
+
print(f"IndexError: Node ID {source_node} or {target_node} not found in nodes.")
|
733
|
+
except Exception as e:
|
734
|
+
print(f"Unexpected error at edge index {index}: {e}")
|
735
|
+
|
736
|
+
plt.scatter([x[0] for x in stuff_to_plot], [x[1] for x in stuff_to_plot])
|
737
|
+
plt.xlabel('Distance')
|
738
|
+
plt.ylabel('Delay')
|
739
|
+
plt.title(f'Distance vs Delay for edge between {sid} and {tid}')
|
740
|
+
plt.show()
|
741
|
+
|
742
|
+
def plot_synapse_location_histograms(config, target_model, source=None, target=None):
|
743
|
+
"""
|
744
|
+
generates a histogram of the positions of the synapses on a cell broken down by section
|
745
|
+
config: a BMTK config
|
746
|
+
target_model: the name of the model_template used when building the BMTK node
|
747
|
+
source: The source BMTK network
|
748
|
+
target: The target BMTK network
|
749
|
+
"""
|
750
|
+
# Load mechanisms and template
|
751
|
+
|
752
|
+
util.load_templates_from_config(config)
|
753
|
+
|
754
|
+
# Load node and edge data
|
755
|
+
nodes, edges = util.load_nodes_edges_from_config(config)
|
756
|
+
nodes = nodes[source]
|
757
|
+
edges = edges[f'{source}_to_{target}']
|
758
|
+
|
759
|
+
# Map target_node_id to model_template
|
760
|
+
edges['target_model_template'] = edges['target_node_id'].map(nodes['model_template'])
|
761
|
+
|
762
|
+
# Map source_node_id to pop_name
|
763
|
+
edges['source_pop_name'] = edges['source_node_id'].map(nodes['pop_name'])
|
764
|
+
|
765
|
+
edges = edges[edges['target_model_template'] == target_model]
|
766
|
+
|
767
|
+
# Create the cell model from target model
|
768
|
+
cell = getattr(h, target_model.split(':')[1])()
|
769
|
+
|
770
|
+
# Create a mapping from section index to section name
|
771
|
+
section_id_to_name = {}
|
772
|
+
for idx, sec in enumerate(cell.all):
|
773
|
+
section_id_to_name[idx] = sec.name()
|
774
|
+
|
775
|
+
# Add a new column with section names based on afferent_section_id
|
776
|
+
edges['afferent_section_name'] = edges['afferent_section_id'].map(section_id_to_name)
|
777
|
+
|
778
|
+
# Get unique sections and source populations
|
779
|
+
unique_pops = edges['source_pop_name'].unique()
|
780
|
+
|
781
|
+
# Filter to only include sections with data
|
782
|
+
section_counts = edges['afferent_section_name'].value_counts()
|
783
|
+
sections_with_data = section_counts[section_counts > 0].index.tolist()
|
784
|
+
|
785
|
+
|
786
|
+
# Create a figure with subplots for each section
|
787
|
+
plt.figure(figsize=(8,12))
|
788
|
+
|
789
|
+
# Color map for source populations
|
790
|
+
color_map = plt.cm.tab10(np.linspace(0, 1, len(unique_pops)))
|
791
|
+
pop_colors = {pop: color for pop, color in zip(unique_pops, color_map)}
|
792
|
+
|
793
|
+
# Create a histogram for each section
|
794
|
+
for i, section in enumerate(sections_with_data):
|
795
|
+
ax = plt.subplot(len(sections_with_data), 1, i+1)
|
796
|
+
|
797
|
+
# Get data for this section
|
798
|
+
section_data = edges[edges['afferent_section_name'] == section]
|
799
|
+
|
800
|
+
# Group by source population
|
801
|
+
for pop_name, pop_group in section_data.groupby('source_pop_name'):
|
802
|
+
if len(pop_group) > 0:
|
803
|
+
ax.hist(pop_group['afferent_section_pos'], bins=15, alpha=0.7,
|
804
|
+
label=pop_name, color=pop_colors[pop_name])
|
805
|
+
|
806
|
+
# Set title and labels
|
807
|
+
ax.set_title(f"{section}", fontsize=10)
|
808
|
+
ax.set_xlabel('Section Position', fontsize=8)
|
809
|
+
ax.set_ylabel('Frequency', fontsize=8)
|
810
|
+
ax.tick_params(labelsize=7)
|
811
|
+
ax.grid(True, alpha=0.3)
|
812
|
+
|
813
|
+
# Only add legend to the first plot
|
814
|
+
if i == 0:
|
815
|
+
ax.legend(fontsize=8)
|
816
|
+
|
817
|
+
plt.tight_layout()
|
818
|
+
plt.suptitle('Connection Distribution by Cell Section and Source Population', fontsize=16, y=1.02)
|
819
|
+
if is_notebook:
|
820
|
+
plt.show()
|
821
|
+
else:
|
822
|
+
pass
|
823
|
+
|
824
|
+
# Create a summary table
|
825
|
+
print("Summary of connections by section and source population:")
|
826
|
+
pivot_table = edges.pivot_table(
|
827
|
+
values='afferent_section_id',
|
828
|
+
index='afferent_section_name',
|
829
|
+
columns='source_pop_name',
|
830
|
+
aggfunc='count',
|
831
|
+
fill_value=0
|
832
|
+
)
|
833
|
+
print(pivot_table)
|
834
|
+
|
685
835
|
def plot_connection_info(text, num, source_labels, target_labels, title, syn_info='0', save_file=None, return_dict=None):
|
686
836
|
"""
|
687
837
|
Function to plot connection information as a heatmap, including handling missing source and target values.
|
bmtool/synapses.py
CHANGED
@@ -647,7 +647,7 @@ class SynapseTuner:
|
|
647
647
|
# Widgets setup (Sliders)
|
648
648
|
freqs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 35, 50, 100, 200]
|
649
649
|
delays = [125, 250, 500, 1000, 2000, 4000]
|
650
|
-
durations = [300, 500, 1000, 2000, 5000, 10000]
|
650
|
+
durations = [100, 300, 500, 1000, 2000, 5000, 10000]
|
651
651
|
freq0 = 50
|
652
652
|
delay0 = 250
|
653
653
|
duration0 = 300
|
bmtool/util/util.py
CHANGED
@@ -6,6 +6,8 @@ import numpy as np
|
|
6
6
|
from numpy import genfromtxt
|
7
7
|
import h5py
|
8
8
|
import pandas as pd
|
9
|
+
import neuron
|
10
|
+
from neuron import h
|
9
11
|
|
10
12
|
#from bmtk.utils.io.cell_vars import CellVarsFile
|
11
13
|
#from bmtk.analyzer.cell_vars import _get_cell_report
|
@@ -392,6 +394,32 @@ def load_edges_from_paths(edge_paths):#network_dir='network'):
|
|
392
394
|
|
393
395
|
return edges_dict
|
394
396
|
|
397
|
+
def load_mechanisms_from_config(config=None):
|
398
|
+
"""
|
399
|
+
loads neuron mechanisms from BMTK config
|
400
|
+
"""
|
401
|
+
if config is None:
|
402
|
+
config = 'simulation_config.json'
|
403
|
+
config = load_config(config)
|
404
|
+
return neuron.load_mechanisms(config['components']['mechanisms_dir'])
|
405
|
+
|
406
|
+
def load_templates_from_config(config=None):
|
407
|
+
if config is None:
|
408
|
+
config = 'simulation_config.json'
|
409
|
+
config = load_config(config)
|
410
|
+
load_mechanisms_from_config(config)
|
411
|
+
return load_templates_from_paths(config['components']['templates_dir'])
|
412
|
+
|
413
|
+
def load_templates_from_paths(template_paths):
|
414
|
+
# load all the files in the templates dir
|
415
|
+
for item in os.listdir(template_paths):
|
416
|
+
item_path = os.path.join(template_paths, item)
|
417
|
+
if os.path.isfile(item_path):
|
418
|
+
print(f"loading {item_path}")
|
419
|
+
h.load_file(item_path)
|
420
|
+
|
421
|
+
|
422
|
+
|
395
423
|
def cell_positions_by_id(config=None, nodes=None, populations=[], popids=[], prepend_pop=True):
|
396
424
|
"""
|
397
425
|
Returns a dictionary of arrays of arrays {"population_popid":[[1,2,3],[1,2,4]],...
|
@@ -1,28 +1,29 @@
|
|
1
1
|
bmtool/SLURM.py,sha256=PST_jOD5ZmwbJj15Tgq3UIvdq4FYN4EkPuDt66P8OXU,20136
|
2
2
|
bmtool/__init__.py,sha256=ZStTNkAJHJxG7Pwiy5UgCzC4KlhMS5pUNPtUJZVwL_Y,136
|
3
3
|
bmtool/__main__.py,sha256=TmFkmDxjZ6250nYD4cgGhn-tbJeEm0u-EMz2ajAN9vE,650
|
4
|
-
bmtool/bmplot.py,sha256=
|
4
|
+
bmtool/bmplot.py,sha256=GmXn4qAlgkPwhM9fwUcVKSbJDMRJBWiH6U90oE03ZPE,68757
|
5
5
|
bmtool/connectors.py,sha256=uLhZIjur0_jWOtSZ9w6-PHftB9Xj6FFXWL5tndEMDYY,73570
|
6
6
|
bmtool/graphs.py,sha256=ShBgJr1iZrM3ugU2wT6hbhmBAkc3mmf7yZQfPuPEqPM,6691
|
7
7
|
bmtool/manage.py,sha256=_lCU0qBQZ4jSxjzAJUd09JEetb--cud7KZgxQFbLGSY,657
|
8
8
|
bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
|
9
9
|
bmtool/singlecell.py,sha256=imcdxIzvYVkaOLSGDxYp8WGGssGwXXBCRhzhlqVp7hA,44267
|
10
|
-
bmtool/synapses.py,sha256=
|
10
|
+
bmtool/synapses.py,sha256=Ow2fZavA_3_5BYCjcgPjW0YsyVOetn1wvLxL7hQvbZo,64556
|
11
11
|
bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
bmtool/analysis/
|
12
|
+
bmtool/analysis/entrainment.py,sha256=GCytdEdCbHwF5yJ4vfK947YLxbBfA88-zSnf2n26fjk,17618
|
13
|
+
bmtool/analysis/lfp.py,sha256=pMXhqWO5TB-B1cIkR9ZhPgJZi7zpiqrGdRb-JNjjI2Y,18707
|
13
14
|
bmtool/analysis/netcon_reports.py,sha256=WWh12H9gjEZXhI_q7RErgGQ9iSPoTvCUnUjwNGxRwsY,3071
|
14
|
-
bmtool/analysis/spikes.py,sha256=
|
15
|
+
bmtool/analysis/spikes.py,sha256=x24kd0RUhumJkiunfHNEE7mM6JUqdWy1gqabmkMM4cU,14129
|
15
16
|
bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
17
|
bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
|
17
18
|
bmtool/debug/debug.py,sha256=xqnkzLiH3s-tS26Y5lZZL62qR2evJdi46Gud-HzxEN4,207
|
18
19
|
bmtool/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
20
|
bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
|
20
|
-
bmtool/util/util.py,sha256=
|
21
|
+
bmtool/util/util.py,sha256=XR0qZnv_Q47jMBKQpFzCSkCuKe9u8L3YSGJAOpP2zT0,57630
|
21
22
|
bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
23
|
bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
|
23
|
-
bmtool-0.6.9.
|
24
|
-
bmtool-0.6.9.
|
25
|
-
bmtool-0.6.9.
|
26
|
-
bmtool-0.6.9.
|
27
|
-
bmtool-0.6.9.
|
28
|
-
bmtool-0.6.9.
|
24
|
+
bmtool-0.6.9.25.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
|
25
|
+
bmtool-0.6.9.25.dist-info/METADATA,sha256=VsaoJQkvujK-1aHvsPHRoxQ49_BFRDWMjQtw_LiskB8,2769
|
26
|
+
bmtool-0.6.9.25.dist-info/WHEEL,sha256=ooBFpIzZCPdw3uqIQsOo4qqbA4ZRPxHnOH7peeONza0,91
|
27
|
+
bmtool-0.6.9.25.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
|
28
|
+
bmtool-0.6.9.25.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
|
29
|
+
bmtool-0.6.9.25.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|