bmtool 0.7.0.4__tar.gz → 0.7.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/PKG-INFO +1 -1
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/analysis/entrainment.py +155 -147
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/analysis/lfp.py +55 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool.egg-info/PKG-INFO +1 -1
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/setup.py +1 -1
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/LICENSE +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/README.md +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/SLURM.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/__init__.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/__main__.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/analysis/__init__.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/analysis/netcon_reports.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/analysis/spikes.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/bmplot/__init__.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/bmplot/connections.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/bmplot/entrainment.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/bmplot/lfp.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/bmplot/netcon_reports.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/bmplot/spikes.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/connectors.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/debug/__init__.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/debug/commands.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/debug/debug.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/graphs.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/manage.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/plot_commands.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/singlecell.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/synapses.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/util/__init__.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/util/commands.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/util/neuron/__init__.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/util/neuron/celltuner.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool/util/util.py +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool.egg-info/SOURCES.txt +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool.egg-info/dependency_links.txt +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool.egg-info/entry_points.txt +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool.egg-info/requires.txt +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/bmtool.egg-info/top_level.txt +0 -0
- {bmtool-0.7.0.4 → bmtool-0.7.0.5}/setup.cfg +0 -0
@@ -7,13 +7,10 @@ from scipy import signal
|
|
7
7
|
import numba
|
8
8
|
from numba import cuda
|
9
9
|
import pandas as pd
|
10
|
-
import
|
11
|
-
from
|
12
|
-
from typing import Dict, List
|
10
|
+
from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power, get_lfp_phase
|
11
|
+
from typing import Dict, List, Optional
|
13
12
|
from tqdm.notebook import tqdm
|
14
13
|
import scipy.stats as stats
|
15
|
-
import seaborn as sns
|
16
|
-
import matplotlib.pyplot as plt
|
17
14
|
|
18
15
|
|
19
16
|
def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None,
|
@@ -87,17 +84,17 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
|
|
87
84
|
|
88
85
|
def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
|
89
86
|
lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
|
90
|
-
lowcut: float = None, highcut: float = None,
|
91
|
-
|
87
|
+
lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
|
88
|
+
filtered_lfp_phase: np.ndarray = None) -> float:
|
92
89
|
"""
|
93
|
-
Calculate spike-lfp phase locking value
|
90
|
+
Calculate spike-lfp unbiased phase locking value
|
94
91
|
|
95
92
|
Parameters
|
96
93
|
----------
|
97
94
|
spike_times : np.ndarray
|
98
95
|
Array of spike times
|
99
96
|
lfp_data : np.ndarray
|
100
|
-
Local field potential time series data
|
97
|
+
Local field potential time series data. Not required if filtered_lfp_phase is provided.
|
101
98
|
spike_fs : float, optional
|
102
99
|
Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
|
103
100
|
lfp_fs : float
|
@@ -112,13 +109,13 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
|
|
112
109
|
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
113
110
|
bandwidth : float, optional
|
114
111
|
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
112
|
+
filtered_lfp_phase : np.ndarray, optional
|
113
|
+
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
115
114
|
|
116
115
|
Returns
|
117
116
|
-------
|
118
|
-
|
119
|
-
(
|
120
|
-
- plv: Phase Locking Value
|
121
|
-
- spike_phases: Phases at spike times
|
117
|
+
float
|
118
|
+
Phase Locking Value (unbiased)
|
122
119
|
"""
|
123
120
|
|
124
121
|
if spike_fs is None:
|
@@ -130,46 +127,41 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
|
|
130
127
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
131
128
|
|
132
129
|
# Filter indices to ensure they're within bounds of the LFP signal
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
# Filter the LFP signal to extract the phase
|
138
|
-
if filter_method == 'wavelet':
|
139
|
-
if freq_of_interest is None:
|
140
|
-
raise ValueError("freq_of_interest must be provided for the wavelet method.")
|
141
|
-
|
142
|
-
# Apply CWT to extract phase
|
143
|
-
filtered_lfp = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
|
144
|
-
|
145
|
-
elif filter_method == 'butter':
|
146
|
-
if lowcut is None or highcut is None:
|
147
|
-
raise ValueError("Both lowcut and highcut must be specified for the butter method.")
|
130
|
+
if filtered_lfp_phase is not None:
|
131
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
|
132
|
+
else:
|
133
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
148
134
|
|
149
|
-
|
150
|
-
|
151
|
-
filtered_lfp = signal.hilbert(filtered_lfp) # Get analytic signal
|
135
|
+
if len(valid_indices) <= 1:
|
136
|
+
return 0
|
152
137
|
|
153
|
-
|
138
|
+
# Get instantaneous phase
|
139
|
+
if filtered_lfp_phase is None:
|
140
|
+
instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
|
141
|
+
freq_of_interest=freq_of_interest, lowcut=lowcut,
|
142
|
+
highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
|
154
143
|
else:
|
155
|
-
|
144
|
+
instantaneous_phase = filtered_lfp_phase
|
156
145
|
|
157
146
|
# Get phases at spike times
|
158
147
|
spike_phases = instantaneous_phase[valid_indices]
|
159
|
-
|
160
|
-
#
|
161
|
-
|
148
|
+
|
149
|
+
# Number of spikes
|
150
|
+
N = len(spike_phases)
|
162
151
|
|
163
152
|
# Convert phases to unit vectors in the complex plane
|
164
153
|
unit_vectors = np.exp(1j * spike_phases)
|
165
154
|
|
166
|
-
#
|
155
|
+
# Sum of all unit vectors (resultant vector)
|
167
156
|
resultant_vector = np.sum(unit_vectors)
|
168
|
-
|
169
|
-
#
|
170
|
-
|
171
|
-
|
172
|
-
|
157
|
+
|
158
|
+
# Calculate plv^2 * N
|
159
|
+
plv2n = (resultant_vector * resultant_vector.conjugate()).real / N # plv^2 * N
|
160
|
+
plv = (plv2n / N) ** 0.5
|
161
|
+
ppc = (plv2n - 1) / (N - 1) # ppc = (plv^2 * N - 1) / (N - 1)
|
162
|
+
plv_unbiased = np.fmax(ppc, 0.) ** 0.5 # ensure non-negative
|
163
|
+
|
164
|
+
return plv_unbiased
|
173
165
|
|
174
166
|
|
175
167
|
@numba.njit(parallel=True, fastmath=True)
|
@@ -209,8 +201,8 @@ def _ppc_gpu(spike_phases):
|
|
209
201
|
|
210
202
|
def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
|
211
203
|
lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
|
212
|
-
lowcut: float = None, highcut: float = None,
|
213
|
-
|
204
|
+
lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
|
205
|
+
ppc_method: str = 'numpy', filtered_lfp_phase: np.ndarray = None) -> float:
|
214
206
|
"""
|
215
207
|
Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
|
216
208
|
Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
|
@@ -220,7 +212,7 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
|
|
220
212
|
spike_times : np.ndarray
|
221
213
|
Array of spike times
|
222
214
|
lfp_data : np.ndarray
|
223
|
-
Local field potential time series data
|
215
|
+
Local field potential time series data. Not required if filtered_lfp_phase is provided.
|
224
216
|
spike_fs : float, optional
|
225
217
|
Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
|
226
218
|
lfp_fs : float
|
@@ -237,13 +229,13 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
|
|
237
229
|
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
238
230
|
ppc_method : str, optional
|
239
231
|
Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
|
232
|
+
filtered_lfp_phase : np.ndarray, optional
|
233
|
+
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
240
234
|
|
241
235
|
Returns
|
242
236
|
-------
|
243
|
-
|
244
|
-
|
245
|
-
- ppc: Pairwise Phase Consistency value
|
246
|
-
- spike_phases: Phases at spike times
|
237
|
+
float
|
238
|
+
Pairwise Phase Consistency value
|
247
239
|
"""
|
248
240
|
if spike_fs is None:
|
249
241
|
spike_fs = lfp_fs
|
@@ -254,32 +246,21 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
|
|
254
246
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
255
247
|
|
256
248
|
# Filter indices to ensure they're within bounds of the LFP signal
|
257
|
-
|
249
|
+
if filtered_lfp_phase is not None:
|
250
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
|
251
|
+
else:
|
252
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
253
|
+
|
258
254
|
if len(valid_indices) <= 1:
|
259
|
-
return 0
|
255
|
+
return 0
|
260
256
|
|
261
|
-
#
|
262
|
-
if
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
# Apply CWT to extract phase at the frequency of interest
|
267
|
-
lfp_complex = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
|
268
|
-
instantaneous_phase = np.angle(lfp_complex)
|
269
|
-
|
270
|
-
elif filter_method == 'butter':
|
271
|
-
if lowcut is None or highcut is None:
|
272
|
-
raise ValueError("Both lowcut and highcut must be specified for the butter method.")
|
273
|
-
|
274
|
-
# Bandpass filter the signal
|
275
|
-
filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
|
276
|
-
|
277
|
-
# Get phase using the Hilbert transform
|
278
|
-
analytic_signal = signal.hilbert(filtered_lfp)
|
279
|
-
instantaneous_phase = np.angle(analytic_signal)
|
280
|
-
|
257
|
+
# Get instantaneous phase
|
258
|
+
if filtered_lfp_phase is None:
|
259
|
+
instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
|
260
|
+
freq_of_interest=freq_of_interest, lowcut=lowcut,
|
261
|
+
highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
|
281
262
|
else:
|
282
|
-
|
263
|
+
instantaneous_phase = filtered_lfp_phase
|
283
264
|
|
284
265
|
# Get phases at spike times
|
285
266
|
spike_phases = instantaneous_phase[valid_indices]
|
@@ -288,28 +269,10 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
|
|
288
269
|
|
289
270
|
# Calculate PPC (Pairwise Phase Consistency)
|
290
271
|
if n_spikes <= 1:
|
291
|
-
return 0
|
272
|
+
return 0
|
292
273
|
|
293
274
|
# Explicit calculation of pairwise phase consistency
|
294
|
-
|
295
|
-
|
296
|
-
# # Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
|
297
|
-
# for i in range(n_spikes - 1): # For each spike i
|
298
|
-
# for j in range(i + 1, n_spikes): # For each spike j > i
|
299
|
-
# # Calculate the phase difference between spikes i and j
|
300
|
-
# phase_diff = spike_phases[i] - spike_phases[j]
|
301
|
-
|
302
|
-
# #f(θᵢ, θⱼ) = cos(θᵢ)cos(θⱼ) + sin(θᵢ)sin(θⱼ) can become #f(θᵢ, θⱼ) = cos(θᵢ - θⱼ)
|
303
|
-
# cos_diff = np.cos(phase_diff)
|
304
|
-
|
305
|
-
# # Add to the sum
|
306
|
-
# sum_cos_diff += cos_diff
|
307
|
-
|
308
|
-
# # Calculate PPC according to the equation
|
309
|
-
# # PPC = (2 / (N(N-1))) * Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
|
310
|
-
# ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
|
311
|
-
|
312
|
-
# same as above (i think) but with vectorized computation and memory fixes so it wont take forever to run.
|
275
|
+
# Vectorized computation for efficiency
|
313
276
|
if ppc_method == 'numpy':
|
314
277
|
i, j = np.triu_indices(n_spikes, k=1)
|
315
278
|
phase_diff = spike_phases[i] - spike_phases[j]
|
@@ -320,14 +283,14 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
|
|
320
283
|
elif ppc_method == 'gpu':
|
321
284
|
ppc = _ppc_gpu(spike_phases)
|
322
285
|
else:
|
323
|
-
raise
|
286
|
+
raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")
|
324
287
|
return ppc
|
325
288
|
|
326
289
|
|
327
290
|
def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
|
328
291
|
lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
|
329
|
-
lowcut: float = None, highcut: float = None,
|
330
|
-
|
292
|
+
lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
|
293
|
+
filtered_lfp_phase: np.ndarray = None) -> float:
|
331
294
|
"""
|
332
295
|
# -----------------------------------------------------------------------------
|
333
296
|
# PPC2 Calculation (Vinck et al., 2010)
|
@@ -343,7 +306,7 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
|
|
343
306
|
spike_times : np.ndarray
|
344
307
|
Array of spike times
|
345
308
|
lfp_data : np.ndarray
|
346
|
-
Local field potential time series data
|
309
|
+
Local field potential time series data. Not required if filtered_lfp_phase is provided.
|
347
310
|
spike_fs : float, optional
|
348
311
|
Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
|
349
312
|
lfp_fs : float
|
@@ -358,6 +321,8 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
|
|
358
321
|
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
359
322
|
bandwidth : float, optional
|
360
323
|
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
324
|
+
filtered_lfp_phase : np.ndarray, optional
|
325
|
+
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
361
326
|
|
362
327
|
Returns
|
363
328
|
-------
|
@@ -374,32 +339,21 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
|
|
374
339
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
375
340
|
|
376
341
|
# Filter indices to ensure they're within bounds of the LFP signal
|
377
|
-
|
342
|
+
if filtered_lfp_phase is not None:
|
343
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
|
344
|
+
else:
|
345
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
346
|
+
|
378
347
|
if len(valid_indices) <= 1:
|
379
348
|
return 0
|
380
349
|
|
381
|
-
#
|
382
|
-
if
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
# Apply CWT to extract phase at the frequency of interest
|
387
|
-
lfp_complex = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
|
388
|
-
instantaneous_phase = np.angle(lfp_complex)
|
389
|
-
|
390
|
-
elif filter_method == 'butter':
|
391
|
-
if lowcut is None or highcut is None:
|
392
|
-
raise ValueError("Both lowcut and highcut must be specified for the butter method.")
|
393
|
-
|
394
|
-
# Bandpass filter the signal
|
395
|
-
filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
|
396
|
-
|
397
|
-
# Get phase using the Hilbert transform
|
398
|
-
analytic_signal = signal.hilbert(filtered_lfp)
|
399
|
-
instantaneous_phase = np.angle(analytic_signal)
|
400
|
-
|
350
|
+
# Get instantaneous phase
|
351
|
+
if filtered_lfp_phase is None:
|
352
|
+
instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
|
353
|
+
freq_of_interest=freq_of_interest, lowcut=lowcut,
|
354
|
+
highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
|
401
355
|
else:
|
402
|
-
|
356
|
+
instantaneous_phase = filtered_lfp_phase
|
403
357
|
|
404
358
|
# Get phases at spike times
|
405
359
|
spike_phases = instantaneous_phase[valid_indices]
|
@@ -422,14 +376,15 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
|
|
422
376
|
return ppc2
|
423
377
|
|
424
378
|
|
425
|
-
def
|
379
|
+
def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None, filter_method: str='wavelet', pop_names: List[str]=None,
|
380
|
+
entrainment_method: str='plv', lowcut: float=None, highcut: float=None,
|
426
381
|
spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
|
427
|
-
|
382
|
+
freqs: List[float]=None, ppc_method: str='numpy',) -> Dict[str, Dict[int, Dict[float, float]]]:
|
428
383
|
"""
|
429
|
-
Calculate
|
384
|
+
Calculate neural entrainment (PPC, PLV) per neuron (cell) for specified frequencies across different populations.
|
430
385
|
|
431
|
-
This function computes the
|
432
|
-
and the provided LFP signal. It returns a nested dictionary structure containing the
|
386
|
+
This function computes the entrainment metrics for each neuron within the specified populations based on their spike times
|
387
|
+
and the provided LFP signal. It returns a nested dictionary structure containing the entrainment values
|
433
388
|
organized by population, node ID, and frequency.
|
434
389
|
|
435
390
|
Parameters
|
@@ -438,14 +393,26 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=Non
|
|
438
393
|
DataFrame containing spike data with columns 'pop_name', 'node_ids', and 'timestamps'
|
439
394
|
lfp_data : np.ndarray
|
440
395
|
Local field potential (LFP) time series data
|
396
|
+
filter_method : str, optional
|
397
|
+
Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
|
398
|
+
entrainment_method : str, optional
|
399
|
+
Method to use for entrainment calculation, either 'plv', 'ppc', or 'ppc2' (default: 'plv')
|
400
|
+
lowcut : float, optional
|
401
|
+
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
402
|
+
highcut : float, optional
|
403
|
+
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
441
404
|
spike_fs : float
|
442
405
|
Sampling frequency of the spike times in Hz
|
443
406
|
lfp_fs : float
|
444
407
|
Sampling frequency of the LFP signal in Hz
|
408
|
+
bandwidth : float, optional
|
409
|
+
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
410
|
+
ppc_method : str, optional
|
411
|
+
Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
|
445
412
|
pop_names : List[str]
|
446
413
|
List of population names to analyze
|
447
414
|
freqs : List[float]
|
448
|
-
List of frequencies (in Hz) at which to calculate
|
415
|
+
List of frequencies (in Hz) at which to calculate entrainment
|
449
416
|
|
450
417
|
Returns
|
451
418
|
-------
|
@@ -454,18 +421,32 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=Non
|
|
454
421
|
{
|
455
422
|
population_name: {
|
456
423
|
node_id: {
|
457
|
-
frequency:
|
424
|
+
frequency: entrainment value
|
458
425
|
}
|
459
426
|
}
|
460
427
|
}
|
461
|
-
|
428
|
+
Entrainment values are floats representing the metric (PPC, PLV) at each frequency
|
462
429
|
"""
|
463
|
-
|
430
|
+
# pre filter lfp to speed up calculate of entrainment
|
431
|
+
filtered_lfp_phases = {}
|
432
|
+
for freq in range(len(freqs)):
|
433
|
+
phase = get_lfp_phase(
|
434
|
+
lfp_data=lfp_data,
|
435
|
+
freq_of_interest=freqs[freq],
|
436
|
+
fs=lfp_fs,
|
437
|
+
filter_method=filter_method,
|
438
|
+
lowcut=lowcut,
|
439
|
+
highcut=highcut,
|
440
|
+
bandwidth=bandwidth
|
441
|
+
)
|
442
|
+
filtered_lfp_phases[freqs[freq]] = phase
|
443
|
+
|
444
|
+
entrainment_dict = {}
|
464
445
|
for pop in pop_names:
|
465
446
|
skip_count = 0
|
466
447
|
pop_spikes = spike_df[spike_df['pop_name'] == pop]
|
467
448
|
nodes = pop_spikes['node_ids'].unique()
|
468
|
-
|
449
|
+
entrainment_dict[pop] = {}
|
469
450
|
print(f'Processing {pop} population')
|
470
451
|
for node in tqdm(nodes):
|
471
452
|
node_spikes = pop_spikes[pop_spikes['node_ids'] == node]
|
@@ -475,22 +456,53 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=Non
|
|
475
456
|
skip_count += 1
|
476
457
|
continue
|
477
458
|
|
478
|
-
|
459
|
+
entrainment_dict[pop][node] = {}
|
479
460
|
for freq in freqs:
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
461
|
+
# Calculate entrainment based on the selected method using the pre-filtered phases
|
462
|
+
if entrainment_method == 'plv':
|
463
|
+
entrainment_dict[pop][node][freq] = calculate_spike_lfp_plv(
|
464
|
+
node_spikes['timestamps'].values,
|
465
|
+
lfp_data,
|
466
|
+
spike_fs=spike_fs,
|
467
|
+
lfp_fs=lfp_fs,
|
468
|
+
freq_of_interest=freq,
|
469
|
+
bandwidth=bandwidth,
|
470
|
+
lowcut=lowcut,
|
471
|
+
highcut=highcut,
|
472
|
+
filter_method=filter_method,
|
473
|
+
filtered_lfp_phase=filtered_lfp_phases[freq]
|
474
|
+
)
|
475
|
+
elif entrainment_method == 'ppc2':
|
476
|
+
entrainment_dict[pop][node][freq] = calculate_ppc2(
|
477
|
+
node_spikes['timestamps'].values,
|
478
|
+
lfp_data,
|
479
|
+
spike_fs=spike_fs,
|
480
|
+
lfp_fs=lfp_fs,
|
481
|
+
freq_of_interest=freq,
|
482
|
+
bandwidth=bandwidth,
|
483
|
+
lowcut=lowcut,
|
484
|
+
highcut=highcut,
|
485
|
+
filter_method=filter_method,
|
486
|
+
filtered_lfp_phase=filtered_lfp_phases[freq]
|
487
|
+
)
|
488
|
+
elif entrainment_method == 'ppc':
|
489
|
+
entrainment_dict[pop][node][freq] = calculate_ppc(
|
490
|
+
node_spikes['timestamps'].values,
|
491
|
+
lfp_data,
|
492
|
+
spike_fs=spike_fs,
|
493
|
+
lfp_fs=lfp_fs,
|
494
|
+
freq_of_interest=freq,
|
495
|
+
bandwidth=bandwidth,
|
496
|
+
lowcut=lowcut,
|
497
|
+
highcut=highcut,
|
498
|
+
filter_method=filter_method,
|
499
|
+
ppc_method=ppc_method,
|
500
|
+
filtered_lfp_phase=filtered_lfp_phases[freq]
|
501
|
+
)
|
490
502
|
|
491
|
-
print(f'Calculated
|
503
|
+
print(f'Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')
|
492
504
|
|
493
|
-
return
|
505
|
+
return entrainment_dict
|
494
506
|
|
495
507
|
|
496
508
|
def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
|
@@ -540,12 +552,8 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names,
|
|
540
552
|
# Calculate power at each frequency band using specified filter
|
541
553
|
power_by_freq = {}
|
542
554
|
for freq in frequencies:
|
543
|
-
|
544
|
-
|
545
|
-
lowcut=None, highcut=None, bandwidth=bandwidth)
|
546
|
-
elif filter_method == 'butter':
|
547
|
-
power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
|
548
|
-
lowcut=lowcut, highcut=highcut)
|
555
|
+
power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
|
556
|
+
lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
|
549
557
|
|
550
558
|
# Calculate correlation for each population
|
551
559
|
for pop in pop_names:
|
@@ -412,6 +412,61 @@ def get_lfp_power(lfp_data: np.ndarray, freq: float, fs: float, filter_method: s
|
|
412
412
|
return power
|
413
413
|
|
414
414
|
|
415
|
+
def get_lfp_phase(lfp_data: np.ndarray, freq_of_interest: float, fs: float, filter_method: str = 'wavelet',
|
416
|
+
lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
|
417
|
+
"""
|
418
|
+
Calculate the phase of the filtered signal.
|
419
|
+
|
420
|
+
Parameters
|
421
|
+
----------
|
422
|
+
lfp_data : np.ndarray
|
423
|
+
Input LFP data
|
424
|
+
fs : float
|
425
|
+
Sampling frequency (Hz)
|
426
|
+
freq : float
|
427
|
+
Frequency of interest (Hz)
|
428
|
+
filter_method : str, optional
|
429
|
+
Method for filtering the signal ('wavelet' or 'butter')
|
430
|
+
bandwidth : float, optional
|
431
|
+
Bandwidth parameter for wavelet filter when method='wavelet' (default: 1.0)
|
432
|
+
lowcut : float, optional
|
433
|
+
Low cutoff frequency for Butterworth filter when method='butter'
|
434
|
+
highcut : float, optional
|
435
|
+
High cutoff frequency for Butterworth filter when method='butter'
|
436
|
+
|
437
|
+
Returns
|
438
|
+
-------
|
439
|
+
np.ndarray
|
440
|
+
Phase of the filtered signal
|
441
|
+
|
442
|
+
Notes
|
443
|
+
-----
|
444
|
+
- The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
|
445
|
+
- The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
|
446
|
+
followed by Hilbert transform to extract the phase
|
447
|
+
- When using the 'butter' method, both lowcut and highcut must be provided
|
448
|
+
"""
|
449
|
+
if filter_method == 'wavelet':
|
450
|
+
if freq_of_interest is None:
|
451
|
+
raise ValueError("freq_of_interest must be provided for the wavelet method.")
|
452
|
+
# Wavelet filter returns complex values directly
|
453
|
+
filtered_signal = wavelet_filter(lfp_data, freq_of_interest, fs, bandwidth)
|
454
|
+
# Phase is the angle of the complex signal
|
455
|
+
phase = np.angle(filtered_signal)
|
456
|
+
elif filter_method == 'butter':
|
457
|
+
if lowcut is None or highcut is None:
|
458
|
+
raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
|
459
|
+
# Butterworth filter returns real values
|
460
|
+
filtered_signal = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
|
461
|
+
# Apply Hilbert transform to get analytic signal (complex)
|
462
|
+
analytic_signal = signal.hilbert(filtered_signal)
|
463
|
+
# Phase is the angle of the analytic signal
|
464
|
+
phase = np.angle(analytic_signal)
|
465
|
+
else:
|
466
|
+
raise ValueError(f"Invalid method {filter_method}. Choose 'wavelet' or 'butter'.")
|
467
|
+
|
468
|
+
return phase
|
469
|
+
|
415
470
|
# windowing functions
|
416
471
|
def windowed_xarray(da, windows, dim='time',
|
417
472
|
new_coord_name='cycle', new_coord=None):
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|