bmtool 0.7.0.6.2__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/SLURM.py +162 -109
- bmtool/__init__.py +1 -1
- bmtool/__main__.py +8 -7
- bmtool/analysis/entrainment.py +250 -143
- bmtool/analysis/lfp.py +279 -134
- bmtool/analysis/netcon_reports.py +41 -44
- bmtool/analysis/spikes.py +114 -73
- bmtool/bmplot/connections.py +658 -325
- bmtool/bmplot/entrainment.py +17 -18
- bmtool/bmplot/lfp.py +24 -17
- bmtool/bmplot/netcon_reports.py +0 -4
- bmtool/bmplot/spikes.py +97 -48
- bmtool/connectors.py +394 -251
- bmtool/debug/commands.py +13 -7
- bmtool/debug/debug.py +2 -2
- bmtool/graphs.py +26 -19
- bmtool/manage.py +6 -11
- bmtool/plot_commands.py +350 -151
- bmtool/singlecell.py +357 -195
- bmtool/synapses.py +564 -470
- bmtool/util/commands.py +1079 -627
- bmtool/util/neuron/celltuner.py +989 -609
- bmtool/util/util.py +992 -588
- {bmtool-0.7.0.6.2.dist-info → bmtool-0.7.1.dist-info}/METADATA +41 -3
- bmtool-0.7.1.dist-info/RECORD +34 -0
- {bmtool-0.7.0.6.2.dist-info → bmtool-0.7.1.dist-info}/WHEEL +1 -1
- bmtool-0.7.0.6.2.dist-info/RECORD +0 -34
- {bmtool-0.7.0.6.2.dist-info → bmtool-0.7.1.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.0.6.2.dist-info → bmtool-0.7.1.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.0.6.2.dist-info → bmtool-0.7.1.dist-info}/top_level.txt +0 -0
bmtool/analysis/entrainment.py
CHANGED
@@ -2,23 +2,62 @@
|
|
2
2
|
Module for entrainment analysis
|
3
3
|
"""
|
4
4
|
|
5
|
-
import
|
6
|
-
|
5
|
+
from typing import Dict, List
|
6
|
+
|
7
7
|
import numba
|
8
|
-
|
8
|
+
import numpy as np
|
9
9
|
import pandas as pd
|
10
|
-
from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power, get_lfp_phase
|
11
|
-
from typing import Dict, List, Optional
|
12
|
-
from tqdm.notebook import tqdm
|
13
10
|
import scipy.stats as stats
|
11
|
+
import xarray as xr
|
12
|
+
from numba import cuda
|
13
|
+
from scipy import signal
|
14
|
+
from tqdm.notebook import tqdm
|
15
|
+
|
16
|
+
from .lfp import butter_bandpass_filter, get_lfp_phase, get_lfp_power, wavelet_filter
|
17
|
+
|
14
18
|
|
19
|
+
def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.ndarray:
|
20
|
+
"""the lfp xarray should have a time axis. use that to align the spike times since the lfp can start at a
|
21
|
+
non-zero time after sliced. Both need to be on same fs for this to be correct.
|
15
22
|
|
16
|
-
|
17
|
-
|
18
|
-
|
23
|
+
Parameters
|
24
|
+
----------
|
25
|
+
lfp : xarray.DataArray
|
26
|
+
LFP data with time coordinates
|
27
|
+
timestamps : np.ndarray
|
28
|
+
Array of spike timestamps
|
29
|
+
|
30
|
+
Returns
|
31
|
+
-------
|
32
|
+
np.ndarray
|
33
|
+
Copy of timestamps with adjusted timestamps to align with lfp.
|
34
|
+
"""
|
35
|
+
# print("Pairing LFP and Spike Times")
|
36
|
+
# print(lfp.time.values)
|
37
|
+
# print(f"LFP starts at {lfp.time.values[0]}ms")
|
38
|
+
# need to make sure lfp and spikes have the same time axis
|
39
|
+
# align spikes with lfp
|
40
|
+
timestamps = timestamps[
|
41
|
+
(timestamps >= lfp.time.values[0]) & (timestamps <= lfp.time.values[-1])
|
42
|
+
].copy()
|
43
|
+
# set the time axis of the spikes to match the lfp
|
44
|
+
timestamps = timestamps - lfp.time.values[0]
|
45
|
+
return timestamps
|
46
|
+
|
47
|
+
|
48
|
+
def calculate_signal_signal_plv(
|
49
|
+
signal1: np.ndarray,
|
50
|
+
signal2: np.ndarray,
|
51
|
+
fs: float,
|
52
|
+
freq_of_interest: float = None,
|
53
|
+
filter_method: str = "wavelet",
|
54
|
+
lowcut: float = None,
|
55
|
+
highcut: float = None,
|
56
|
+
bandwidth: float = 2.0,
|
57
|
+
) -> np.ndarray:
|
19
58
|
"""
|
20
59
|
Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
|
21
|
-
|
60
|
+
|
22
61
|
Parameters
|
23
62
|
----------
|
24
63
|
signal1 : np.ndarray
|
@@ -37,7 +76,7 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
|
|
37
76
|
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
38
77
|
bandwidth : float, optional
|
39
78
|
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
40
|
-
|
79
|
+
|
41
80
|
Returns
|
42
81
|
-------
|
43
82
|
np.ndarray
|
@@ -45,23 +84,29 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
|
|
45
84
|
"""
|
46
85
|
if len(signal1) != len(signal2):
|
47
86
|
raise ValueError("Input signals must have the same length.")
|
48
|
-
|
49
|
-
if filter_method ==
|
87
|
+
|
88
|
+
if filter_method == "wavelet":
|
50
89
|
if freq_of_interest is None:
|
51
90
|
raise ValueError("freq_of_interest must be provided for the wavelet method.")
|
52
|
-
|
91
|
+
|
53
92
|
# Apply CWT to both signals
|
54
93
|
theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
|
55
94
|
theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
|
56
|
-
|
57
|
-
elif filter_method ==
|
95
|
+
|
96
|
+
elif filter_method == "butter":
|
58
97
|
if lowcut is None or highcut is None:
|
59
|
-
print(
|
60
|
-
|
98
|
+
print(
|
99
|
+
"Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation"
|
100
|
+
)
|
101
|
+
|
61
102
|
if lowcut and highcut:
|
62
103
|
# Bandpass filter and get the analytic signal using the Hilbert transform
|
63
|
-
filtered_signal1 = butter_bandpass_filter(
|
64
|
-
|
104
|
+
filtered_signal1 = butter_bandpass_filter(
|
105
|
+
data=signal1, lowcut=lowcut, highcut=highcut, fs=fs
|
106
|
+
)
|
107
|
+
filtered_signal2 = butter_bandpass_filter(
|
108
|
+
data=signal2, lowcut=lowcut, highcut=highcut, fs=fs
|
109
|
+
)
|
65
110
|
# Get phase using the Hilbert transform
|
66
111
|
theta1 = signal.hilbert(filtered_signal1)
|
67
112
|
theta2 = signal.hilbert(filtered_signal2)
|
@@ -69,26 +114,34 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
|
|
69
114
|
# Get phase using the Hilbert transform without filtering
|
70
115
|
theta1 = signal.hilbert(signal1)
|
71
116
|
theta2 = signal.hilbert(signal2)
|
72
|
-
|
117
|
+
|
73
118
|
else:
|
74
119
|
raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
|
75
|
-
|
120
|
+
|
76
121
|
# Calculate phase difference
|
77
122
|
phase_diff = np.angle(theta1) - np.angle(theta2)
|
78
|
-
|
123
|
+
|
79
124
|
# Calculate PLV from standard equation from Measuring phase synchrony in brain signals(1999)
|
80
125
|
plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))
|
81
|
-
|
126
|
+
|
82
127
|
return plv
|
83
128
|
|
84
129
|
|
85
|
-
def calculate_spike_lfp_plv(
|
86
|
-
|
87
|
-
|
88
|
-
|
130
|
+
def calculate_spike_lfp_plv(
|
131
|
+
spike_times: np.ndarray = None,
|
132
|
+
lfp_data=None,
|
133
|
+
spike_fs: float = None,
|
134
|
+
lfp_fs: float = None,
|
135
|
+
filter_method: str = "butter",
|
136
|
+
freq_of_interest: float = None,
|
137
|
+
lowcut: float = None,
|
138
|
+
highcut: float = None,
|
139
|
+
bandwidth: float = 2.0,
|
140
|
+
filtered_lfp_phase: np.ndarray = None,
|
141
|
+
) -> float:
|
89
142
|
"""
|
90
|
-
Calculate spike-lfp unbiased phase locking value
|
91
|
-
|
143
|
+
Calculate spike-lfp unbiased phase locking value
|
144
|
+
|
92
145
|
Parameters
|
93
146
|
----------
|
94
147
|
spike_times : np.ndarray
|
@@ -111,13 +164,13 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
|
|
111
164
|
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
112
165
|
filtered_lfp_phase : np.ndarray, optional
|
113
166
|
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
114
|
-
|
167
|
+
|
115
168
|
Returns
|
116
169
|
-------
|
117
170
|
float
|
118
171
|
Phase Locking Value (unbiased)
|
119
172
|
"""
|
120
|
-
|
173
|
+
|
121
174
|
if spike_fs is None:
|
122
175
|
spike_fs = lfp_fs
|
123
176
|
# Convert spike times to sample indices
|
@@ -125,33 +178,36 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
|
|
125
178
|
|
126
179
|
# Then convert from seconds to samples at the new sampling rate
|
127
180
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
128
|
-
|
181
|
+
|
129
182
|
# Filter indices to ensure they're within bounds of the LFP signal
|
130
|
-
|
131
|
-
|
132
|
-
else:
|
133
|
-
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
134
|
-
|
183
|
+
valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
|
184
|
+
|
135
185
|
if len(valid_indices) <= 1:
|
136
186
|
return 0
|
137
|
-
|
187
|
+
|
138
188
|
# Get instantaneous phase
|
139
189
|
if filtered_lfp_phase is None:
|
140
|
-
instantaneous_phase = get_lfp_phase(
|
141
|
-
|
142
|
-
|
190
|
+
instantaneous_phase = get_lfp_phase(
|
191
|
+
lfp_data=lfp_data,
|
192
|
+
filter_method=filter_method,
|
193
|
+
freq_of_interest=freq_of_interest,
|
194
|
+
lowcut=lowcut,
|
195
|
+
highcut=highcut,
|
196
|
+
bandwidth=bandwidth,
|
197
|
+
fs=lfp_fs,
|
198
|
+
)
|
143
199
|
else:
|
144
200
|
instantaneous_phase = filtered_lfp_phase
|
145
|
-
|
201
|
+
|
146
202
|
# Get phases at spike times
|
147
|
-
spike_phases = instantaneous_phase
|
203
|
+
spike_phases = instantaneous_phase.sel(time=valid_indices).values
|
148
204
|
|
149
205
|
# Number of spikes
|
150
206
|
N = len(spike_phases)
|
151
|
-
|
207
|
+
|
152
208
|
# Convert phases to unit vectors in the complex plane
|
153
209
|
unit_vectors = np.exp(1j * spike_phases)
|
154
|
-
|
210
|
+
|
155
211
|
# Sum of all unit vectors (resultant vector)
|
156
212
|
resultant_vector = np.sum(unit_vectors)
|
157
213
|
|
@@ -159,7 +215,7 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
|
|
159
215
|
plv2n = (resultant_vector * resultant_vector.conjugate()).real / N # plv^2 * N
|
160
216
|
plv = (plv2n / N) ** 0.5
|
161
217
|
ppc = (plv2n - 1) / (N - 1) # ppc = (plv^2 * N - 1) / (N - 1)
|
162
|
-
plv_unbiased = np.fmax(ppc, 0.) ** 0.5 # ensure non-negative
|
218
|
+
plv_unbiased = np.fmax(ppc, 0.0) ** 0.5 # ensure non-negative
|
163
219
|
|
164
220
|
return plv_unbiased
|
165
221
|
|
@@ -181,7 +237,7 @@ def _ppc_cuda_kernel(spike_phases, out):
|
|
181
237
|
i = cuda.grid(1)
|
182
238
|
if i < len(spike_phases):
|
183
239
|
local_sum = 0.0
|
184
|
-
for j in range(i+1, len(spike_phases)):
|
240
|
+
for j in range(i + 1, len(spike_phases)):
|
185
241
|
local_sum += np.cos(spike_phases[i] - spike_phases[j])
|
186
242
|
out[i] = local_sum
|
187
243
|
|
@@ -190,23 +246,32 @@ def _ppc_gpu(spike_phases):
|
|
190
246
|
"""GPU-accelerated implementation"""
|
191
247
|
d_phases = cuda.to_device(spike_phases)
|
192
248
|
d_out = cuda.device_array(len(spike_phases), dtype=np.float64)
|
193
|
-
|
249
|
+
|
194
250
|
threads = 256
|
195
251
|
blocks = (len(spike_phases) + threads - 1) // threads
|
196
|
-
|
252
|
+
|
197
253
|
_ppc_cuda_kernel[blocks, threads](d_phases, d_out)
|
198
254
|
total = d_out.copy_to_host().sum()
|
199
|
-
return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
|
200
|
-
|
201
|
-
|
202
|
-
def calculate_ppc(
|
203
|
-
|
204
|
-
|
205
|
-
|
255
|
+
return (2 / (len(spike_phases) * (len(spike_phases) - 1))) * total
|
256
|
+
|
257
|
+
|
258
|
+
def calculate_ppc(
|
259
|
+
spike_times: np.ndarray = None,
|
260
|
+
lfp_data=None,
|
261
|
+
spike_fs: float = None,
|
262
|
+
lfp_fs: float = None,
|
263
|
+
filter_method: str = "wavelet",
|
264
|
+
freq_of_interest: float = None,
|
265
|
+
lowcut: float = None,
|
266
|
+
highcut: float = None,
|
267
|
+
bandwidth: float = 2.0,
|
268
|
+
ppc_method: str = "numpy",
|
269
|
+
filtered_lfp_phase: np.ndarray = None,
|
270
|
+
) -> float:
|
206
271
|
"""
|
207
272
|
Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
|
208
273
|
Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
|
209
|
-
|
274
|
+
|
210
275
|
Parameters
|
211
276
|
----------
|
212
277
|
spike_times : np.ndarray
|
@@ -231,7 +296,7 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
|
|
231
296
|
Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
|
232
297
|
filtered_lfp_phase : np.ndarray, optional
|
233
298
|
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
234
|
-
|
299
|
+
|
235
300
|
Returns
|
236
301
|
-------
|
237
302
|
float
|
@@ -244,63 +309,77 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
|
|
244
309
|
|
245
310
|
# Then convert from seconds to samples at the new sampling rate
|
246
311
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
247
|
-
|
312
|
+
|
248
313
|
# Filter indices to ensure they're within bounds of the LFP signal
|
249
314
|
if filtered_lfp_phase is not None:
|
250
|
-
valid_indices =
|
315
|
+
valid_indices = align_spike_times_with_lfp(lfp=filtered_lfp_phase, timestamps=spike_indices)
|
251
316
|
else:
|
252
|
-
valid_indices =
|
253
|
-
|
317
|
+
valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
|
318
|
+
|
254
319
|
if len(valid_indices) <= 1:
|
255
320
|
return 0
|
256
|
-
|
321
|
+
|
257
322
|
# Get instantaneous phase
|
258
323
|
if filtered_lfp_phase is None:
|
259
|
-
instantaneous_phase = get_lfp_phase(
|
260
|
-
|
261
|
-
|
324
|
+
instantaneous_phase = get_lfp_phase(
|
325
|
+
lfp_data=lfp_data,
|
326
|
+
filter_method=filter_method,
|
327
|
+
freq_of_interest=freq_of_interest,
|
328
|
+
lowcut=lowcut,
|
329
|
+
highcut=highcut,
|
330
|
+
bandwidth=bandwidth,
|
331
|
+
fs=lfp_fs,
|
332
|
+
)
|
262
333
|
else:
|
263
334
|
instantaneous_phase = filtered_lfp_phase
|
264
|
-
|
335
|
+
|
265
336
|
# Get phases at spike times
|
266
|
-
spike_phases = instantaneous_phase
|
267
|
-
|
337
|
+
spike_phases = instantaneous_phase.sel(time=valid_indices).values
|
338
|
+
|
268
339
|
n_spikes = len(spike_phases)
|
269
340
|
|
270
341
|
# Calculate PPC (Pairwise Phase Consistency)
|
271
342
|
if n_spikes <= 1:
|
272
343
|
return 0
|
273
|
-
|
344
|
+
|
274
345
|
# Explicit calculation of pairwise phase consistency
|
275
346
|
# Vectorized computation for efficiency
|
276
|
-
if ppc_method ==
|
347
|
+
if ppc_method == "numpy":
|
277
348
|
i, j = np.triu_indices(n_spikes, k=1)
|
278
349
|
phase_diff = spike_phases[i] - spike_phases[j]
|
279
350
|
sum_cos_diff = np.sum(np.cos(phase_diff))
|
280
|
-
ppc = (
|
281
|
-
elif ppc_method ==
|
351
|
+
ppc = (2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff
|
352
|
+
elif ppc_method == "numba":
|
282
353
|
ppc = _ppc_parallel_numba(spike_phases)
|
283
|
-
elif ppc_method ==
|
354
|
+
elif ppc_method == "gpu":
|
284
355
|
ppc = _ppc_gpu(spike_phases)
|
285
356
|
else:
|
286
357
|
raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")
|
287
358
|
return ppc
|
288
359
|
|
289
|
-
|
290
|
-
def calculate_ppc2(
|
291
|
-
|
292
|
-
|
293
|
-
|
360
|
+
|
361
|
+
def calculate_ppc2(
|
362
|
+
spike_times: np.ndarray = None,
|
363
|
+
lfp_data=None,
|
364
|
+
spike_fs: float = None,
|
365
|
+
lfp_fs: float = None,
|
366
|
+
filter_method: str = "wavelet",
|
367
|
+
freq_of_interest: float = None,
|
368
|
+
lowcut: float = None,
|
369
|
+
highcut: float = None,
|
370
|
+
bandwidth: float = 2.0,
|
371
|
+
filtered_lfp_phase: np.ndarray = None,
|
372
|
+
) -> float:
|
294
373
|
"""
|
295
374
|
# -----------------------------------------------------------------------------
|
296
|
-
# PPC2 Calculation (Vinck et al., 2010)
|
375
|
+
# PPC2 Calculation (Vinck et al., 2010)
|
297
376
|
# -----------------------------------------------------------------------------
|
298
377
|
# Equation(Original):
|
299
378
|
# PPC = (2 / (n * (n - 1))) * sum(cos(φ_i - φ_j) for all i < j)
|
300
379
|
# Optimized Formula (Algebraically Equivalent):
|
301
380
|
# PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
|
302
381
|
# -----------------------------------------------------------------------------
|
303
|
-
|
382
|
+
|
304
383
|
Parameters
|
305
384
|
----------
|
306
385
|
spike_times : np.ndarray
|
@@ -323,13 +402,13 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
|
|
323
402
|
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
324
403
|
filtered_lfp_phase : np.ndarray, optional
|
325
404
|
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
326
|
-
|
405
|
+
|
327
406
|
Returns
|
328
407
|
-------
|
329
408
|
float
|
330
409
|
Pairwise Phase Consistency 2 (PPC2) value
|
331
410
|
"""
|
332
|
-
|
411
|
+
|
333
412
|
if spike_fs is None:
|
334
413
|
spike_fs = lfp_fs
|
335
414
|
# Convert spike times to sample indices
|
@@ -337,49 +416,64 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
|
|
337
416
|
|
338
417
|
# Then convert from seconds to samples at the new sampling rate
|
339
418
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
340
|
-
|
419
|
+
|
341
420
|
# Filter indices to ensure they're within bounds of the LFP signal
|
342
421
|
if filtered_lfp_phase is not None:
|
343
|
-
valid_indices =
|
422
|
+
valid_indices = align_spike_times_with_lfp(lfp=filtered_lfp_phase, timestamps=spike_indices)
|
344
423
|
else:
|
345
|
-
valid_indices =
|
346
|
-
|
424
|
+
valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
|
425
|
+
|
347
426
|
if len(valid_indices) <= 1:
|
348
427
|
return 0
|
349
|
-
|
428
|
+
|
350
429
|
# Get instantaneous phase
|
351
430
|
if filtered_lfp_phase is None:
|
352
|
-
instantaneous_phase = get_lfp_phase(
|
353
|
-
|
354
|
-
|
431
|
+
instantaneous_phase = get_lfp_phase(
|
432
|
+
lfp_data=lfp_data,
|
433
|
+
filter_method=filter_method,
|
434
|
+
freq_of_interest=freq_of_interest,
|
435
|
+
lowcut=lowcut,
|
436
|
+
highcut=highcut,
|
437
|
+
bandwidth=bandwidth,
|
438
|
+
fs=lfp_fs,
|
439
|
+
)
|
355
440
|
else:
|
356
441
|
instantaneous_phase = filtered_lfp_phase
|
357
|
-
|
442
|
+
|
358
443
|
# Get phases at spike times
|
359
|
-
spike_phases = instantaneous_phase
|
360
|
-
|
444
|
+
spike_phases = instantaneous_phase.sel(time=valid_indices).values
|
361
445
|
# Calculate PPC2 according to Vinck et al. (2010), Equation 6
|
362
446
|
n = len(spike_phases)
|
363
|
-
|
447
|
+
|
364
448
|
if n <= 1:
|
365
449
|
return 0
|
366
|
-
|
450
|
+
|
367
451
|
# Convert phases to unit vectors in the complex plane
|
368
452
|
unit_vectors = np.exp(1j * spike_phases)
|
369
|
-
|
453
|
+
|
370
454
|
# Calculate the resultant vector
|
371
455
|
resultant_vector = np.sum(unit_vectors)
|
372
|
-
|
456
|
+
|
373
457
|
# PPC2 = (|∑(e^(i*φ_j))|² - n) / (n * (n - 1))
|
374
|
-
ppc2 = (np.abs(resultant_vector)**2 - n) / (n * (n - 1))
|
375
|
-
|
458
|
+
ppc2 = (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))
|
459
|
+
|
376
460
|
return ppc2
|
377
461
|
|
378
462
|
|
379
|
-
def calculate_entrainment_per_cell(
|
380
|
-
|
381
|
-
|
382
|
-
|
463
|
+
def calculate_entrainment_per_cell(
|
464
|
+
spike_df: pd.DataFrame = None,
|
465
|
+
lfp_data: np.ndarray = None,
|
466
|
+
filter_method: str = "wavelet",
|
467
|
+
pop_names: List[str] = None,
|
468
|
+
entrainment_method: str = "plv",
|
469
|
+
lowcut: float = None,
|
470
|
+
highcut: float = None,
|
471
|
+
spike_fs: float = None,
|
472
|
+
lfp_fs: float = None,
|
473
|
+
bandwidth: float = 2,
|
474
|
+
freqs: List[float] = None,
|
475
|
+
ppc_method: str = "numpy",
|
476
|
+
) -> Dict[str, Dict[int, Dict[float, float]]]:
|
383
477
|
"""
|
384
478
|
Calculate neural entrainment (PPC, PLV) per neuron (cell) for specified frequencies across different populations.
|
385
479
|
|
@@ -431,26 +525,26 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
|
|
431
525
|
filtered_lfp_phases = {}
|
432
526
|
for freq in range(len(freqs)):
|
433
527
|
phase = get_lfp_phase(
|
434
|
-
lfp_data=lfp_data,
|
435
|
-
freq_of_interest=freqs[freq],
|
436
|
-
fs=lfp_fs,
|
528
|
+
lfp_data=lfp_data,
|
529
|
+
freq_of_interest=freqs[freq],
|
530
|
+
fs=lfp_fs,
|
437
531
|
filter_method=filter_method,
|
438
|
-
lowcut=lowcut,
|
439
|
-
highcut=highcut,
|
440
|
-
bandwidth=bandwidth
|
532
|
+
lowcut=lowcut,
|
533
|
+
highcut=highcut,
|
534
|
+
bandwidth=bandwidth,
|
441
535
|
)
|
442
536
|
filtered_lfp_phases[freqs[freq]] = phase
|
443
|
-
|
537
|
+
|
444
538
|
entrainment_dict = {}
|
445
539
|
for pop in pop_names:
|
446
540
|
skip_count = 0
|
447
|
-
pop_spikes = spike_df[spike_df[
|
448
|
-
nodes = pop_spikes[
|
541
|
+
pop_spikes = spike_df[spike_df["pop_name"] == pop]
|
542
|
+
nodes = pop_spikes["node_ids"].unique()
|
449
543
|
entrainment_dict[pop] = {}
|
450
|
-
print(f
|
544
|
+
print(f"Processing {pop} population")
|
451
545
|
for node in tqdm(nodes):
|
452
|
-
node_spikes = pop_spikes[pop_spikes[
|
453
|
-
|
546
|
+
node_spikes = pop_spikes[pop_spikes["node_ids"] == node]
|
547
|
+
|
454
548
|
# Skip nodes with less than or equal to 1 spike
|
455
549
|
if len(node_spikes) <= 1:
|
456
550
|
skip_count += 1
|
@@ -459,9 +553,9 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
|
|
459
553
|
entrainment_dict[pop][node] = {}
|
460
554
|
for freq in freqs:
|
461
555
|
# Calculate entrainment based on the selected method using the pre-filtered phases
|
462
|
-
if entrainment_method ==
|
556
|
+
if entrainment_method == "plv":
|
463
557
|
entrainment_dict[pop][node][freq] = calculate_spike_lfp_plv(
|
464
|
-
node_spikes[
|
558
|
+
node_spikes["timestamps"].values,
|
465
559
|
lfp_data,
|
466
560
|
spike_fs=spike_fs,
|
467
561
|
lfp_fs=lfp_fs,
|
@@ -470,11 +564,11 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
|
|
470
564
|
lowcut=lowcut,
|
471
565
|
highcut=highcut,
|
472
566
|
filter_method=filter_method,
|
473
|
-
filtered_lfp_phase=filtered_lfp_phases[freq]
|
567
|
+
filtered_lfp_phase=filtered_lfp_phases[freq],
|
474
568
|
)
|
475
|
-
elif entrainment_method ==
|
569
|
+
elif entrainment_method == "ppc2":
|
476
570
|
entrainment_dict[pop][node][freq] = calculate_ppc2(
|
477
|
-
node_spikes[
|
571
|
+
node_spikes["timestamps"].values,
|
478
572
|
lfp_data,
|
479
573
|
spike_fs=spike_fs,
|
480
574
|
lfp_fs=lfp_fs,
|
@@ -483,11 +577,11 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
|
|
483
577
|
lowcut=lowcut,
|
484
578
|
highcut=highcut,
|
485
579
|
filter_method=filter_method,
|
486
|
-
filtered_lfp_phase=filtered_lfp_phases[freq]
|
580
|
+
filtered_lfp_phase=filtered_lfp_phases[freq],
|
487
581
|
)
|
488
|
-
elif entrainment_method ==
|
582
|
+
elif entrainment_method == "ppc":
|
489
583
|
entrainment_dict[pop][node][freq] = calculate_ppc(
|
490
|
-
node_spikes[
|
584
|
+
node_spikes["timestamps"].values,
|
491
585
|
lfp_data,
|
492
586
|
spike_fs=spike_fs,
|
493
587
|
lfp_fs=lfp_fs,
|
@@ -497,21 +591,32 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
|
|
497
591
|
highcut=highcut,
|
498
592
|
filter_method=filter_method,
|
499
593
|
ppc_method=ppc_method,
|
500
|
-
filtered_lfp_phase=filtered_lfp_phases[freq]
|
594
|
+
filtered_lfp_phase=filtered_lfp_phases[freq],
|
501
595
|
)
|
502
596
|
|
503
|
-
print(
|
597
|
+
print(
|
598
|
+
f"Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes"
|
599
|
+
)
|
504
600
|
|
505
601
|
return entrainment_dict
|
506
602
|
|
507
603
|
|
508
|
-
def calculate_spike_rate_power_correlation(
|
509
|
-
|
510
|
-
|
604
|
+
def calculate_spike_rate_power_correlation(
|
605
|
+
spike_rate,
|
606
|
+
lfp_data,
|
607
|
+
fs,
|
608
|
+
pop_names,
|
609
|
+
filter_method="wavelet",
|
610
|
+
bandwidth=2.0,
|
611
|
+
lowcut=None,
|
612
|
+
highcut=None,
|
613
|
+
freq_range=(10, 100),
|
614
|
+
freq_step=5,
|
615
|
+
):
|
511
616
|
"""
|
512
617
|
Calculate correlation between population spike rates and LFP power across frequencies
|
513
618
|
using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
|
514
|
-
|
619
|
+
|
515
620
|
Parameters:
|
516
621
|
-----------
|
517
622
|
spike_rate : DataFrame
|
@@ -534,7 +639,7 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names,
|
|
534
639
|
Min and max frequency to analyze (default: (10, 100))
|
535
640
|
freq_step : float, optional
|
536
641
|
Step size for frequency analysis (default: 5)
|
537
|
-
|
642
|
+
|
538
643
|
Returns:
|
539
644
|
--------
|
540
645
|
correlation_results : dict
|
@@ -542,32 +647,34 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names,
|
|
542
647
|
frequencies : array
|
543
648
|
Array of frequencies analyzed
|
544
649
|
"""
|
545
|
-
|
650
|
+
|
546
651
|
# Define frequency bands to analyze
|
547
652
|
frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
|
548
|
-
|
653
|
+
|
549
654
|
# Dictionary to store results
|
550
655
|
correlation_results = {pop: {} for pop in pop_names}
|
551
|
-
|
656
|
+
|
552
657
|
# Calculate power at each frequency band using specified filter
|
553
658
|
power_by_freq = {}
|
554
659
|
for freq in frequencies:
|
555
|
-
power_by_freq[freq] = get_lfp_power(
|
556
|
-
|
557
|
-
|
660
|
+
power_by_freq[freq] = get_lfp_power(
|
661
|
+
lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
|
662
|
+
)
|
663
|
+
|
558
664
|
# Calculate correlation for each population
|
559
665
|
for pop in pop_names:
|
560
666
|
# Extract spike rate for this population
|
561
667
|
pop_rate = spike_rate[pop]
|
562
|
-
|
668
|
+
|
563
669
|
# Calculate correlation with power at each frequency
|
564
670
|
for freq in frequencies:
|
565
671
|
# Make sure the lengths match
|
566
672
|
if len(pop_rate) != len(power_by_freq[freq]):
|
567
|
-
raise ValueError(
|
673
|
+
raise ValueError(
|
674
|
+
f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}"
|
675
|
+
)
|
568
676
|
# use spearman for non-parametric correlation
|
569
677
|
corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
|
570
|
-
correlation_results[pop][freq] = {
|
571
|
-
|
572
|
-
return correlation_results, frequencies
|
678
|
+
correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
|
573
679
|
|
680
|
+
return correlation_results, frequencies
|