bmtool 0.7.0.6.4__py3-none-any.whl → 0.7.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/SLURM.py +162 -109
- bmtool/__init__.py +1 -1
- bmtool/__main__.py +8 -7
- bmtool/analysis/entrainment.py +290 -147
- bmtool/analysis/lfp.py +279 -134
- bmtool/analysis/netcon_reports.py +41 -44
- bmtool/analysis/spikes.py +114 -73
- bmtool/bmplot/connections.py +658 -325
- bmtool/bmplot/entrainment.py +17 -18
- bmtool/bmplot/lfp.py +24 -17
- bmtool/bmplot/netcon_reports.py +0 -4
- bmtool/bmplot/spikes.py +97 -48
- bmtool/connectors.py +394 -251
- bmtool/debug/commands.py +13 -7
- bmtool/debug/debug.py +2 -2
- bmtool/graphs.py +26 -19
- bmtool/manage.py +6 -11
- bmtool/plot_commands.py +350 -151
- bmtool/singlecell.py +357 -195
- bmtool/synapses.py +564 -470
- bmtool/util/commands.py +1079 -627
- bmtool/util/neuron/celltuner.py +989 -609
- bmtool/util/util.py +992 -588
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/METADATA +40 -2
- bmtool-0.7.1.1.dist-info/RECORD +34 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/WHEEL +1 -1
- bmtool-0.7.0.6.4.dist-info/RECORD +0 -34
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/top_level.txt +0 -0
bmtool/analysis/entrainment.py
CHANGED
@@ -2,23 +2,62 @@
|
|
2
2
|
Module for entrainment analysis
|
3
3
|
"""
|
4
4
|
|
5
|
-
import
|
6
|
-
|
5
|
+
from typing import Dict, List
|
6
|
+
|
7
7
|
import numba
|
8
|
-
|
8
|
+
import numpy as np
|
9
9
|
import pandas as pd
|
10
|
-
from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power, get_lfp_phase
|
11
|
-
from typing import Dict, List, Optional
|
12
|
-
from tqdm.notebook import tqdm
|
13
10
|
import scipy.stats as stats
|
11
|
+
import xarray as xr
|
12
|
+
from numba import cuda
|
13
|
+
from scipy import signal
|
14
|
+
from tqdm.notebook import tqdm
|
15
|
+
|
16
|
+
from .lfp import butter_bandpass_filter, get_lfp_phase, get_lfp_power, wavelet_filter
|
17
|
+
|
18
|
+
|
19
|
+
def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.ndarray:
|
20
|
+
"""the lfp xarray should have a time axis. use that to align the spike times since the lfp can start at a
|
21
|
+
non-zero time after sliced. Both need to be on same fs for this to be correct.
|
14
22
|
|
23
|
+
Parameters
|
24
|
+
----------
|
25
|
+
lfp : xarray.DataArray
|
26
|
+
LFP data with time coordinates
|
27
|
+
timestamps : np.ndarray
|
28
|
+
Array of spike timestamps
|
15
29
|
|
16
|
-
|
17
|
-
|
18
|
-
|
30
|
+
Returns
|
31
|
+
-------
|
32
|
+
np.ndarray
|
33
|
+
Copy of timestamps with adjusted timestamps to align with lfp.
|
34
|
+
"""
|
35
|
+
# print("Pairing LFP and Spike Times")
|
36
|
+
# print(lfp.time.values)
|
37
|
+
# print(f"LFP starts at {lfp.time.values[0]}ms")
|
38
|
+
# need to make sure lfp and spikes have the same time axis
|
39
|
+
# align spikes with lfp
|
40
|
+
timestamps = timestamps[
|
41
|
+
(timestamps >= lfp.time.values[0]) & (timestamps <= lfp.time.values[-1])
|
42
|
+
].copy()
|
43
|
+
# set the time axis of the spikes to match the lfp
|
44
|
+
timestamps = timestamps - lfp.time.values[0]
|
45
|
+
return timestamps
|
46
|
+
|
47
|
+
|
48
|
+
def calculate_signal_signal_plv(
|
49
|
+
signal1: np.ndarray,
|
50
|
+
signal2: np.ndarray,
|
51
|
+
fs: float,
|
52
|
+
freq_of_interest: float = None,
|
53
|
+
filter_method: str = "wavelet",
|
54
|
+
lowcut: float = None,
|
55
|
+
highcut: float = None,
|
56
|
+
bandwidth: float = 2.0,
|
57
|
+
) -> np.ndarray:
|
19
58
|
"""
|
20
59
|
Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
|
21
|
-
|
60
|
+
|
22
61
|
Parameters
|
23
62
|
----------
|
24
63
|
signal1 : np.ndarray
|
@@ -37,7 +76,7 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
|
|
37
76
|
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
38
77
|
bandwidth : float, optional
|
39
78
|
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
40
|
-
|
79
|
+
|
41
80
|
Returns
|
42
81
|
-------
|
43
82
|
np.ndarray
|
@@ -45,23 +84,29 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
|
|
45
84
|
"""
|
46
85
|
if len(signal1) != len(signal2):
|
47
86
|
raise ValueError("Input signals must have the same length.")
|
48
|
-
|
49
|
-
if filter_method ==
|
87
|
+
|
88
|
+
if filter_method == "wavelet":
|
50
89
|
if freq_of_interest is None:
|
51
90
|
raise ValueError("freq_of_interest must be provided for the wavelet method.")
|
52
|
-
|
91
|
+
|
53
92
|
# Apply CWT to both signals
|
54
93
|
theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
|
55
94
|
theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
|
56
|
-
|
57
|
-
elif filter_method ==
|
95
|
+
|
96
|
+
elif filter_method == "butter":
|
58
97
|
if lowcut is None or highcut is None:
|
59
|
-
print(
|
60
|
-
|
98
|
+
print(
|
99
|
+
"Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation"
|
100
|
+
)
|
101
|
+
|
61
102
|
if lowcut and highcut:
|
62
103
|
# Bandpass filter and get the analytic signal using the Hilbert transform
|
63
|
-
filtered_signal1 = butter_bandpass_filter(
|
64
|
-
|
104
|
+
filtered_signal1 = butter_bandpass_filter(
|
105
|
+
data=signal1, lowcut=lowcut, highcut=highcut, fs=fs
|
106
|
+
)
|
107
|
+
filtered_signal2 = butter_bandpass_filter(
|
108
|
+
data=signal2, lowcut=lowcut, highcut=highcut, fs=fs
|
109
|
+
)
|
65
110
|
# Get phase using the Hilbert transform
|
66
111
|
theta1 = signal.hilbert(filtered_signal1)
|
67
112
|
theta2 = signal.hilbert(filtered_signal2)
|
@@ -69,26 +114,34 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
|
|
69
114
|
# Get phase using the Hilbert transform without filtering
|
70
115
|
theta1 = signal.hilbert(signal1)
|
71
116
|
theta2 = signal.hilbert(signal2)
|
72
|
-
|
117
|
+
|
73
118
|
else:
|
74
119
|
raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
|
75
|
-
|
120
|
+
|
76
121
|
# Calculate phase difference
|
77
122
|
phase_diff = np.angle(theta1) - np.angle(theta2)
|
78
|
-
|
123
|
+
|
79
124
|
# Calculate PLV from standard equation from Measuring phase synchrony in brain signals(1999)
|
80
125
|
plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))
|
81
|
-
|
126
|
+
|
82
127
|
return plv
|
83
128
|
|
84
129
|
|
85
|
-
def calculate_spike_lfp_plv(
|
86
|
-
|
87
|
-
|
88
|
-
|
130
|
+
def calculate_spike_lfp_plv(
|
131
|
+
spike_times: np.ndarray = None,
|
132
|
+
lfp_data=None,
|
133
|
+
spike_fs: float = None,
|
134
|
+
lfp_fs: float = None,
|
135
|
+
filter_method: str = "butter",
|
136
|
+
freq_of_interest: float = None,
|
137
|
+
lowcut: float = None,
|
138
|
+
highcut: float = None,
|
139
|
+
bandwidth: float = 2.0,
|
140
|
+
filtered_lfp_phase: np.ndarray = None,
|
141
|
+
) -> float:
|
89
142
|
"""
|
90
|
-
Calculate spike-lfp unbiased phase locking value
|
91
|
-
|
143
|
+
Calculate spike-lfp unbiased phase locking value
|
144
|
+
|
92
145
|
Parameters
|
93
146
|
----------
|
94
147
|
spike_times : np.ndarray
|
@@ -111,13 +164,13 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
|
|
111
164
|
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
112
165
|
filtered_lfp_phase : np.ndarray, optional
|
113
166
|
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
114
|
-
|
167
|
+
|
115
168
|
Returns
|
116
169
|
-------
|
117
170
|
float
|
118
171
|
Phase Locking Value (unbiased)
|
119
172
|
"""
|
120
|
-
|
173
|
+
|
121
174
|
if spike_fs is None:
|
122
175
|
spike_fs = lfp_fs
|
123
176
|
# Convert spike times to sample indices
|
@@ -125,33 +178,50 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
|
|
125
178
|
|
126
179
|
# Then convert from seconds to samples at the new sampling rate
|
127
180
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
128
|
-
|
181
|
+
|
129
182
|
# Filter indices to ensure they're within bounds of the LFP signal
|
130
|
-
if
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
183
|
+
if isinstance(lfp_data, xr.DataArray):
|
184
|
+
if filtered_lfp_phase is not None:
|
185
|
+
valid_indices = align_spike_times_with_lfp(
|
186
|
+
lfp=filtered_lfp_phase, timestamps=spike_indices
|
187
|
+
)
|
188
|
+
else:
|
189
|
+
valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
|
190
|
+
elif isinstance(lfp_data, np.ndarray):
|
191
|
+
if filtered_lfp_phase is not None:
|
192
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
|
193
|
+
else:
|
194
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
195
|
+
|
135
196
|
if len(valid_indices) <= 1:
|
136
197
|
return 0
|
137
|
-
|
198
|
+
|
138
199
|
# Get instantaneous phase
|
139
200
|
if filtered_lfp_phase is None:
|
140
|
-
instantaneous_phase = get_lfp_phase(
|
141
|
-
|
142
|
-
|
201
|
+
instantaneous_phase = get_lfp_phase(
|
202
|
+
lfp_data=lfp_data,
|
203
|
+
filter_method=filter_method,
|
204
|
+
freq_of_interest=freq_of_interest,
|
205
|
+
lowcut=lowcut,
|
206
|
+
highcut=highcut,
|
207
|
+
bandwidth=bandwidth,
|
208
|
+
fs=lfp_fs,
|
209
|
+
)
|
143
210
|
else:
|
144
211
|
instantaneous_phase = filtered_lfp_phase
|
145
|
-
|
212
|
+
|
146
213
|
# Get phases at spike times
|
147
|
-
|
214
|
+
if isinstance(instantaneous_phase, xr.DataArray):
|
215
|
+
spike_phases = instantaneous_phase.sel(time=valid_indices).values
|
216
|
+
else:
|
217
|
+
spike_phases = instantaneous_phase[valid_indices]
|
148
218
|
|
149
219
|
# Number of spikes
|
150
220
|
N = len(spike_phases)
|
151
|
-
|
221
|
+
|
152
222
|
# Convert phases to unit vectors in the complex plane
|
153
223
|
unit_vectors = np.exp(1j * spike_phases)
|
154
|
-
|
224
|
+
|
155
225
|
# Sum of all unit vectors (resultant vector)
|
156
226
|
resultant_vector = np.sum(unit_vectors)
|
157
227
|
|
@@ -159,7 +229,7 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
|
|
159
229
|
plv2n = (resultant_vector * resultant_vector.conjugate()).real / N # plv^2 * N
|
160
230
|
plv = (plv2n / N) ** 0.5
|
161
231
|
ppc = (plv2n - 1) / (N - 1) # ppc = (plv^2 * N - 1) / (N - 1)
|
162
|
-
plv_unbiased = np.fmax(ppc, 0.) ** 0.5 # ensure non-negative
|
232
|
+
plv_unbiased = np.fmax(ppc, 0.0) ** 0.5 # ensure non-negative
|
163
233
|
|
164
234
|
return plv_unbiased
|
165
235
|
|
@@ -181,7 +251,7 @@ def _ppc_cuda_kernel(spike_phases, out):
|
|
181
251
|
i = cuda.grid(1)
|
182
252
|
if i < len(spike_phases):
|
183
253
|
local_sum = 0.0
|
184
|
-
for j in range(i+1, len(spike_phases)):
|
254
|
+
for j in range(i + 1, len(spike_phases)):
|
185
255
|
local_sum += np.cos(spike_phases[i] - spike_phases[j])
|
186
256
|
out[i] = local_sum
|
187
257
|
|
@@ -190,23 +260,32 @@ def _ppc_gpu(spike_phases):
|
|
190
260
|
"""GPU-accelerated implementation"""
|
191
261
|
d_phases = cuda.to_device(spike_phases)
|
192
262
|
d_out = cuda.device_array(len(spike_phases), dtype=np.float64)
|
193
|
-
|
263
|
+
|
194
264
|
threads = 256
|
195
265
|
blocks = (len(spike_phases) + threads - 1) // threads
|
196
|
-
|
266
|
+
|
197
267
|
_ppc_cuda_kernel[blocks, threads](d_phases, d_out)
|
198
268
|
total = d_out.copy_to_host().sum()
|
199
|
-
return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
|
200
|
-
|
201
|
-
|
202
|
-
def calculate_ppc(
|
203
|
-
|
204
|
-
|
205
|
-
|
269
|
+
return (2 / (len(spike_phases) * (len(spike_phases) - 1))) * total
|
270
|
+
|
271
|
+
|
272
|
+
def calculate_ppc(
|
273
|
+
spike_times: np.ndarray = None,
|
274
|
+
lfp_data=None,
|
275
|
+
spike_fs: float = None,
|
276
|
+
lfp_fs: float = None,
|
277
|
+
filter_method: str = "wavelet",
|
278
|
+
freq_of_interest: float = None,
|
279
|
+
lowcut: float = None,
|
280
|
+
highcut: float = None,
|
281
|
+
bandwidth: float = 2.0,
|
282
|
+
ppc_method: str = "numpy",
|
283
|
+
filtered_lfp_phase: np.ndarray = None,
|
284
|
+
) -> float:
|
206
285
|
"""
|
207
286
|
Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
|
208
287
|
Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
|
209
|
-
|
288
|
+
|
210
289
|
Parameters
|
211
290
|
----------
|
212
291
|
spike_times : np.ndarray
|
@@ -231,7 +310,7 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
|
|
231
310
|
Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
|
232
311
|
filtered_lfp_phase : np.ndarray, optional
|
233
312
|
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
234
|
-
|
313
|
+
|
235
314
|
Returns
|
236
315
|
-------
|
237
316
|
float
|
@@ -244,63 +323,88 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
|
|
244
323
|
|
245
324
|
# Then convert from seconds to samples at the new sampling rate
|
246
325
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
247
|
-
|
326
|
+
|
248
327
|
# Filter indices to ensure they're within bounds of the LFP signal
|
249
|
-
if
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
328
|
+
if isinstance(lfp_data, xr.DataArray):
|
329
|
+
if filtered_lfp_phase is not None:
|
330
|
+
valid_indices = align_spike_times_with_lfp(
|
331
|
+
lfp=filtered_lfp_phase, timestamps=spike_indices
|
332
|
+
)
|
333
|
+
else:
|
334
|
+
valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
|
335
|
+
elif isinstance(lfp_data, np.ndarray):
|
336
|
+
if filtered_lfp_phase is not None:
|
337
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
|
338
|
+
else:
|
339
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
340
|
+
|
254
341
|
if len(valid_indices) <= 1:
|
255
342
|
return 0
|
256
|
-
|
343
|
+
|
257
344
|
# Get instantaneous phase
|
258
345
|
if filtered_lfp_phase is None:
|
259
|
-
instantaneous_phase = get_lfp_phase(
|
260
|
-
|
261
|
-
|
346
|
+
instantaneous_phase = get_lfp_phase(
|
347
|
+
lfp_data=lfp_data,
|
348
|
+
filter_method=filter_method,
|
349
|
+
freq_of_interest=freq_of_interest,
|
350
|
+
lowcut=lowcut,
|
351
|
+
highcut=highcut,
|
352
|
+
bandwidth=bandwidth,
|
353
|
+
fs=lfp_fs,
|
354
|
+
)
|
262
355
|
else:
|
263
356
|
instantaneous_phase = filtered_lfp_phase
|
264
|
-
|
357
|
+
|
265
358
|
# Get phases at spike times
|
266
|
-
|
267
|
-
|
359
|
+
if isinstance(instantaneous_phase, xr.DataArray):
|
360
|
+
spike_phases = instantaneous_phase.sel(time=valid_indices).values
|
361
|
+
else:
|
362
|
+
spike_phases = instantaneous_phase[valid_indices]
|
363
|
+
|
268
364
|
n_spikes = len(spike_phases)
|
269
365
|
|
270
366
|
# Calculate PPC (Pairwise Phase Consistency)
|
271
367
|
if n_spikes <= 1:
|
272
368
|
return 0
|
273
|
-
|
369
|
+
|
274
370
|
# Explicit calculation of pairwise phase consistency
|
275
371
|
# Vectorized computation for efficiency
|
276
|
-
if ppc_method ==
|
372
|
+
if ppc_method == "numpy":
|
277
373
|
i, j = np.triu_indices(n_spikes, k=1)
|
278
374
|
phase_diff = spike_phases[i] - spike_phases[j]
|
279
375
|
sum_cos_diff = np.sum(np.cos(phase_diff))
|
280
|
-
ppc = (
|
281
|
-
elif ppc_method ==
|
376
|
+
ppc = (2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff
|
377
|
+
elif ppc_method == "numba":
|
282
378
|
ppc = _ppc_parallel_numba(spike_phases)
|
283
|
-
elif ppc_method ==
|
379
|
+
elif ppc_method == "gpu":
|
284
380
|
ppc = _ppc_gpu(spike_phases)
|
285
381
|
else:
|
286
382
|
raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")
|
287
383
|
return ppc
|
288
384
|
|
289
|
-
|
290
|
-
def calculate_ppc2(
|
291
|
-
|
292
|
-
|
293
|
-
|
385
|
+
|
386
|
+
def calculate_ppc2(
|
387
|
+
spike_times: np.ndarray = None,
|
388
|
+
lfp_data=None,
|
389
|
+
spike_fs: float = None,
|
390
|
+
lfp_fs: float = None,
|
391
|
+
filter_method: str = "wavelet",
|
392
|
+
freq_of_interest: float = None,
|
393
|
+
lowcut: float = None,
|
394
|
+
highcut: float = None,
|
395
|
+
bandwidth: float = 2.0,
|
396
|
+
filtered_lfp_phase: np.ndarray = None,
|
397
|
+
) -> float:
|
294
398
|
"""
|
295
399
|
# -----------------------------------------------------------------------------
|
296
|
-
# PPC2 Calculation (Vinck et al., 2010)
|
400
|
+
# PPC2 Calculation (Vinck et al., 2010)
|
297
401
|
# -----------------------------------------------------------------------------
|
298
402
|
# Equation(Original):
|
299
403
|
# PPC = (2 / (n * (n - 1))) * sum(cos(φ_i - φ_j) for all i < j)
|
300
404
|
# Optimized Formula (Algebraically Equivalent):
|
301
405
|
# PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
|
302
406
|
# -----------------------------------------------------------------------------
|
303
|
-
|
407
|
+
|
304
408
|
Parameters
|
305
409
|
----------
|
306
410
|
spike_times : np.ndarray
|
@@ -323,13 +427,13 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
|
|
323
427
|
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
324
428
|
filtered_lfp_phase : np.ndarray, optional
|
325
429
|
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
326
|
-
|
430
|
+
|
327
431
|
Returns
|
328
432
|
-------
|
329
433
|
float
|
330
434
|
Pairwise Phase Consistency 2 (PPC2) value
|
331
435
|
"""
|
332
|
-
|
436
|
+
|
333
437
|
if spike_fs is None:
|
334
438
|
spike_fs = lfp_fs
|
335
439
|
# Convert spike times to sample indices
|
@@ -337,49 +441,75 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
|
|
337
441
|
|
338
442
|
# Then convert from seconds to samples at the new sampling rate
|
339
443
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
340
|
-
|
444
|
+
|
341
445
|
# Filter indices to ensure they're within bounds of the LFP signal
|
342
|
-
if
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
446
|
+
if isinstance(lfp_data, xr.DataArray):
|
447
|
+
if filtered_lfp_phase is not None:
|
448
|
+
valid_indices = align_spike_times_with_lfp(
|
449
|
+
lfp=filtered_lfp_phase, timestamps=spike_indices
|
450
|
+
)
|
451
|
+
else:
|
452
|
+
valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
|
453
|
+
elif isinstance(lfp_data, np.ndarray):
|
454
|
+
if filtered_lfp_phase is not None:
|
455
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
|
456
|
+
else:
|
457
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
458
|
+
|
347
459
|
if len(valid_indices) <= 1:
|
348
460
|
return 0
|
349
|
-
|
461
|
+
|
350
462
|
# Get instantaneous phase
|
351
463
|
if filtered_lfp_phase is None:
|
352
|
-
instantaneous_phase = get_lfp_phase(
|
353
|
-
|
354
|
-
|
464
|
+
instantaneous_phase = get_lfp_phase(
|
465
|
+
lfp_data=lfp_data,
|
466
|
+
filter_method=filter_method,
|
467
|
+
freq_of_interest=freq_of_interest,
|
468
|
+
lowcut=lowcut,
|
469
|
+
highcut=highcut,
|
470
|
+
bandwidth=bandwidth,
|
471
|
+
fs=lfp_fs,
|
472
|
+
)
|
355
473
|
else:
|
356
474
|
instantaneous_phase = filtered_lfp_phase
|
357
|
-
|
475
|
+
|
358
476
|
# Get phases at spike times
|
359
|
-
|
360
|
-
|
477
|
+
if isinstance(instantaneous_phase, xr.DataArray):
|
478
|
+
spike_phases = instantaneous_phase.sel(time=valid_indices).values
|
479
|
+
else:
|
480
|
+
spike_phases = instantaneous_phase[valid_indices]
|
361
481
|
# Calculate PPC2 according to Vinck et al. (2010), Equation 6
|
362
482
|
n = len(spike_phases)
|
363
|
-
|
483
|
+
|
364
484
|
if n <= 1:
|
365
485
|
return 0
|
366
|
-
|
486
|
+
|
367
487
|
# Convert phases to unit vectors in the complex plane
|
368
488
|
unit_vectors = np.exp(1j * spike_phases)
|
369
|
-
|
489
|
+
|
370
490
|
# Calculate the resultant vector
|
371
491
|
resultant_vector = np.sum(unit_vectors)
|
372
|
-
|
492
|
+
|
373
493
|
# PPC2 = (|∑(e^(i*φ_j))|² - n) / (n * (n - 1))
|
374
|
-
ppc2 = (np.abs(resultant_vector)**2 - n) / (n * (n - 1))
|
375
|
-
|
494
|
+
ppc2 = (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))
|
495
|
+
|
376
496
|
return ppc2
|
377
497
|
|
378
498
|
|
379
|
-
def calculate_entrainment_per_cell(
|
380
|
-
|
381
|
-
|
382
|
-
|
499
|
+
def calculate_entrainment_per_cell(
|
500
|
+
spike_df: pd.DataFrame = None,
|
501
|
+
lfp_data: np.ndarray = None,
|
502
|
+
filter_method: str = "wavelet",
|
503
|
+
pop_names: List[str] = None,
|
504
|
+
entrainment_method: str = "plv",
|
505
|
+
lowcut: float = None,
|
506
|
+
highcut: float = None,
|
507
|
+
spike_fs: float = None,
|
508
|
+
lfp_fs: float = None,
|
509
|
+
bandwidth: float = 2,
|
510
|
+
freqs: List[float] = None,
|
511
|
+
ppc_method: str = "numpy",
|
512
|
+
) -> Dict[str, Dict[int, Dict[float, float]]]:
|
383
513
|
"""
|
384
514
|
Calculate neural entrainment (PPC, PLV) per neuron (cell) for specified frequencies across different populations.
|
385
515
|
|
@@ -431,26 +561,26 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
|
|
431
561
|
filtered_lfp_phases = {}
|
432
562
|
for freq in range(len(freqs)):
|
433
563
|
phase = get_lfp_phase(
|
434
|
-
lfp_data=lfp_data,
|
435
|
-
freq_of_interest=freqs[freq],
|
436
|
-
fs=lfp_fs,
|
564
|
+
lfp_data=lfp_data,
|
565
|
+
freq_of_interest=freqs[freq],
|
566
|
+
fs=lfp_fs,
|
437
567
|
filter_method=filter_method,
|
438
|
-
lowcut=lowcut,
|
439
|
-
highcut=highcut,
|
440
|
-
bandwidth=bandwidth
|
568
|
+
lowcut=lowcut,
|
569
|
+
highcut=highcut,
|
570
|
+
bandwidth=bandwidth,
|
441
571
|
)
|
442
572
|
filtered_lfp_phases[freqs[freq]] = phase
|
443
|
-
|
573
|
+
|
444
574
|
entrainment_dict = {}
|
445
575
|
for pop in pop_names:
|
446
576
|
skip_count = 0
|
447
|
-
pop_spikes = spike_df[spike_df[
|
448
|
-
nodes = pop_spikes[
|
577
|
+
pop_spikes = spike_df[spike_df["pop_name"] == pop]
|
578
|
+
nodes = pop_spikes["node_ids"].unique()
|
449
579
|
entrainment_dict[pop] = {}
|
450
|
-
print(f
|
580
|
+
print(f"Processing {pop} population")
|
451
581
|
for node in tqdm(nodes):
|
452
|
-
node_spikes = pop_spikes[pop_spikes[
|
453
|
-
|
582
|
+
node_spikes = pop_spikes[pop_spikes["node_ids"] == node]
|
583
|
+
|
454
584
|
# Skip nodes with less than or equal to 1 spike
|
455
585
|
if len(node_spikes) <= 1:
|
456
586
|
skip_count += 1
|
@@ -459,9 +589,9 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
|
|
459
589
|
entrainment_dict[pop][node] = {}
|
460
590
|
for freq in freqs:
|
461
591
|
# Calculate entrainment based on the selected method using the pre-filtered phases
|
462
|
-
if entrainment_method ==
|
592
|
+
if entrainment_method == "plv":
|
463
593
|
entrainment_dict[pop][node][freq] = calculate_spike_lfp_plv(
|
464
|
-
node_spikes[
|
594
|
+
node_spikes["timestamps"].values,
|
465
595
|
lfp_data,
|
466
596
|
spike_fs=spike_fs,
|
467
597
|
lfp_fs=lfp_fs,
|
@@ -470,11 +600,11 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
|
|
470
600
|
lowcut=lowcut,
|
471
601
|
highcut=highcut,
|
472
602
|
filter_method=filter_method,
|
473
|
-
filtered_lfp_phase=filtered_lfp_phases[freq]
|
603
|
+
filtered_lfp_phase=filtered_lfp_phases[freq],
|
474
604
|
)
|
475
|
-
elif entrainment_method ==
|
605
|
+
elif entrainment_method == "ppc2":
|
476
606
|
entrainment_dict[pop][node][freq] = calculate_ppc2(
|
477
|
-
node_spikes[
|
607
|
+
node_spikes["timestamps"].values,
|
478
608
|
lfp_data,
|
479
609
|
spike_fs=spike_fs,
|
480
610
|
lfp_fs=lfp_fs,
|
@@ -483,11 +613,11 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
|
|
483
613
|
lowcut=lowcut,
|
484
614
|
highcut=highcut,
|
485
615
|
filter_method=filter_method,
|
486
|
-
filtered_lfp_phase=filtered_lfp_phases[freq]
|
616
|
+
filtered_lfp_phase=filtered_lfp_phases[freq],
|
487
617
|
)
|
488
|
-
elif entrainment_method ==
|
618
|
+
elif entrainment_method == "ppc":
|
489
619
|
entrainment_dict[pop][node][freq] = calculate_ppc(
|
490
|
-
node_spikes[
|
620
|
+
node_spikes["timestamps"].values,
|
491
621
|
lfp_data,
|
492
622
|
spike_fs=spike_fs,
|
493
623
|
lfp_fs=lfp_fs,
|
@@ -497,21 +627,32 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
|
|
497
627
|
highcut=highcut,
|
498
628
|
filter_method=filter_method,
|
499
629
|
ppc_method=ppc_method,
|
500
|
-
filtered_lfp_phase=filtered_lfp_phases[freq]
|
630
|
+
filtered_lfp_phase=filtered_lfp_phases[freq],
|
501
631
|
)
|
502
632
|
|
503
|
-
print(
|
633
|
+
print(
|
634
|
+
f"Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes"
|
635
|
+
)
|
504
636
|
|
505
637
|
return entrainment_dict
|
506
638
|
|
507
639
|
|
508
|
-
def calculate_spike_rate_power_correlation(
|
509
|
-
|
510
|
-
|
640
|
+
def calculate_spike_rate_power_correlation(
|
641
|
+
spike_rate,
|
642
|
+
lfp_data,
|
643
|
+
fs,
|
644
|
+
pop_names,
|
645
|
+
filter_method="wavelet",
|
646
|
+
bandwidth=2.0,
|
647
|
+
lowcut=None,
|
648
|
+
highcut=None,
|
649
|
+
freq_range=(10, 100),
|
650
|
+
freq_step=5,
|
651
|
+
):
|
511
652
|
"""
|
512
653
|
Calculate correlation between population spike rates and LFP power across frequencies
|
513
654
|
using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
|
514
|
-
|
655
|
+
|
515
656
|
Parameters:
|
516
657
|
-----------
|
517
658
|
spike_rate : DataFrame
|
@@ -534,7 +675,7 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names,
|
|
534
675
|
Min and max frequency to analyze (default: (10, 100))
|
535
676
|
freq_step : float, optional
|
536
677
|
Step size for frequency analysis (default: 5)
|
537
|
-
|
678
|
+
|
538
679
|
Returns:
|
539
680
|
--------
|
540
681
|
correlation_results : dict
|
@@ -542,32 +683,34 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names,
|
|
542
683
|
frequencies : array
|
543
684
|
Array of frequencies analyzed
|
544
685
|
"""
|
545
|
-
|
686
|
+
|
546
687
|
# Define frequency bands to analyze
|
547
688
|
frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
|
548
|
-
|
689
|
+
|
549
690
|
# Dictionary to store results
|
550
691
|
correlation_results = {pop: {} for pop in pop_names}
|
551
|
-
|
692
|
+
|
552
693
|
# Calculate power at each frequency band using specified filter
|
553
694
|
power_by_freq = {}
|
554
695
|
for freq in frequencies:
|
555
|
-
power_by_freq[freq] = get_lfp_power(
|
556
|
-
|
557
|
-
|
696
|
+
power_by_freq[freq] = get_lfp_power(
|
697
|
+
lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
|
698
|
+
)
|
699
|
+
|
558
700
|
# Calculate correlation for each population
|
559
701
|
for pop in pop_names:
|
560
702
|
# Extract spike rate for this population
|
561
703
|
pop_rate = spike_rate[pop]
|
562
|
-
|
704
|
+
|
563
705
|
# Calculate correlation with power at each frequency
|
564
706
|
for freq in frequencies:
|
565
707
|
# Make sure the lengths match
|
566
708
|
if len(pop_rate) != len(power_by_freq[freq]):
|
567
|
-
raise ValueError(
|
709
|
+
raise ValueError(
|
710
|
+
f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}"
|
711
|
+
)
|
568
712
|
# use spearman for non-parametric correlation
|
569
713
|
corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
|
570
|
-
correlation_results[pop][freq] = {
|
571
|
-
|
572
|
-
return correlation_results, frequencies
|
714
|
+
correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
|
573
715
|
|
716
|
+
return correlation_results, frequencies
|