bmtool 0.7.0.3__py3-none-any.whl → 0.7.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +317 -234
- bmtool/analysis/lfp.py +176 -1
- bmtool/analysis/spikes.py +115 -63
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/METADATA +1 -1
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/RECORD +9 -9
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/WHEEL +0 -0
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/top_level.txt +0 -0
bmtool/analysis/entrainment.py
CHANGED
@@ -7,58 +7,71 @@ from scipy import signal
|
|
7
7
|
import numba
|
8
8
|
from numba import cuda
|
9
9
|
import pandas as pd
|
10
|
-
import
|
11
|
-
from
|
12
|
-
from typing import Dict, List
|
10
|
+
from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power, get_lfp_phase
|
11
|
+
from typing import Dict, List, Optional
|
13
12
|
from tqdm.notebook import tqdm
|
14
13
|
import scipy.stats as stats
|
15
|
-
import seaborn as sns
|
16
|
-
import matplotlib.pyplot as plt
|
17
14
|
|
18
15
|
|
19
|
-
def calculate_signal_signal_plv(
|
20
|
-
|
16
|
+
def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None,
|
17
|
+
filter_method: str = 'wavelet', lowcut: float = None, highcut: float = None,
|
21
18
|
bandwidth: float = 2.0) -> np.ndarray:
|
22
19
|
"""
|
23
20
|
Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
|
24
21
|
|
25
|
-
Parameters
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
22
|
+
Parameters
|
23
|
+
----------
|
24
|
+
signal1 : np.ndarray
|
25
|
+
First input signal (1D array)
|
26
|
+
signal2 : np.ndarray
|
27
|
+
Second input signal (1D array, same length as signal1)
|
28
|
+
fs : float
|
29
|
+
Sampling frequency in Hz
|
30
|
+
freq_of_interest : float, optional
|
31
|
+
Desired frequency for wavelet PLV calculation, required if filter_method='wavelet'
|
32
|
+
filter_method : str, optional
|
33
|
+
Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
|
34
|
+
lowcut : float, optional
|
35
|
+
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
36
|
+
highcut : float, optional
|
37
|
+
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
38
|
+
bandwidth : float, optional
|
39
|
+
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
40
|
+
|
41
|
+
Returns
|
42
|
+
-------
|
43
|
+
np.ndarray
|
44
|
+
Phase Locking Value (1D array)
|
35
45
|
"""
|
36
|
-
if len(
|
46
|
+
if len(signal1) != len(signal2):
|
37
47
|
raise ValueError("Input signals must have the same length.")
|
38
48
|
|
39
|
-
if
|
49
|
+
if filter_method == 'wavelet':
|
40
50
|
if freq_of_interest is None:
|
41
51
|
raise ValueError("freq_of_interest must be provided for the wavelet method.")
|
42
52
|
|
43
53
|
# Apply CWT to both signals
|
44
|
-
theta1 = wavelet_filter(x=
|
45
|
-
theta2 = wavelet_filter(x=
|
54
|
+
theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
|
55
|
+
theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
|
46
56
|
|
47
|
-
elif
|
57
|
+
elif filter_method == 'butter':
|
48
58
|
if lowcut is None or highcut is None:
|
49
|
-
print("Lowcut and
|
59
|
+
print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation")
|
50
60
|
|
51
61
|
if lowcut and highcut:
|
52
62
|
# Bandpass filter and get the analytic signal using the Hilbert transform
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
63
|
+
filtered_signal1 = butter_bandpass_filter(data=signal1, lowcut=lowcut, highcut=highcut, fs=fs)
|
64
|
+
filtered_signal2 = butter_bandpass_filter(data=signal2, lowcut=lowcut, highcut=highcut, fs=fs)
|
65
|
+
# Get phase using the Hilbert transform
|
66
|
+
theta1 = signal.hilbert(filtered_signal1)
|
67
|
+
theta2 = signal.hilbert(filtered_signal2)
|
68
|
+
else:
|
69
|
+
# Get phase using the Hilbert transform without filtering
|
70
|
+
theta1 = signal.hilbert(signal1)
|
71
|
+
theta2 = signal.hilbert(signal2)
|
59
72
|
|
60
73
|
else:
|
61
|
-
raise ValueError("Invalid method. Choose 'wavelet' or '
|
74
|
+
raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
|
62
75
|
|
63
76
|
# Calculate phase difference
|
64
77
|
phase_diff = np.angle(theta1) - np.angle(theta2)
|
@@ -69,29 +82,43 @@ def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_
|
|
69
82
|
return plv
|
70
83
|
|
71
84
|
|
72
|
-
def calculate_spike_lfp_plv(spike_times: np.ndarray = None,
|
73
|
-
lfp_fs: float = None,
|
74
|
-
lowcut: float = None, highcut: float = None,
|
75
|
-
|
85
|
+
def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
|
86
|
+
lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
|
87
|
+
lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
|
88
|
+
filtered_lfp_phase: np.ndarray = None) -> float:
|
76
89
|
"""
|
77
|
-
Calculate spike-lfp phase locking value
|
78
|
-
|
79
|
-
Parameters
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
90
|
+
Calculate spike-lfp unbiased phase locking value
|
91
|
+
|
92
|
+
Parameters
|
93
|
+
----------
|
94
|
+
spike_times : np.ndarray
|
95
|
+
Array of spike times
|
96
|
+
lfp_data : np.ndarray
|
97
|
+
Local field potential time series data. Not required if filtered_lfp_phase is provided.
|
98
|
+
spike_fs : float, optional
|
99
|
+
Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
|
100
|
+
lfp_fs : float
|
101
|
+
Sampling frequency in Hz of the LFP data
|
102
|
+
filter_method : str, optional
|
103
|
+
Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
|
104
|
+
freq_of_interest : float, optional
|
105
|
+
Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
|
106
|
+
lowcut : float, optional
|
107
|
+
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
108
|
+
highcut : float, optional
|
109
|
+
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
110
|
+
bandwidth : float, optional
|
111
|
+
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
112
|
+
filtered_lfp_phase : np.ndarray, optional
|
113
|
+
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
114
|
+
|
115
|
+
Returns
|
116
|
+
-------
|
117
|
+
float
|
118
|
+
Phase Locking Value (unbiased)
|
92
119
|
"""
|
93
120
|
|
94
|
-
if spike_fs
|
121
|
+
if spike_fs is None:
|
95
122
|
spike_fs = lfp_fs
|
96
123
|
# Convert spike times to sample indices
|
97
124
|
spike_times_seconds = spike_times / spike_fs
|
@@ -100,50 +127,41 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarr
|
|
100
127
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
101
128
|
|
102
129
|
# Filter indices to ensure they're within bounds of the LFP signal
|
103
|
-
|
130
|
+
if filtered_lfp_phase is not None:
|
131
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
|
132
|
+
else:
|
133
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
134
|
+
|
104
135
|
if len(valid_indices) <= 1:
|
105
|
-
return 0
|
136
|
+
return 0
|
106
137
|
|
107
|
-
#
|
108
|
-
if
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
# Apply CWT to extract phase at the frequency of interest
|
113
|
-
lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
|
114
|
-
instantaneous_phase = np.angle(lfp_complex)
|
115
|
-
|
116
|
-
elif method == 'hilbert':
|
117
|
-
if lowcut is None or highcut is None:
|
118
|
-
print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
|
119
|
-
filtered_lfp = lfp_signal
|
120
|
-
else:
|
121
|
-
# Bandpass filter the signal
|
122
|
-
filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
|
123
|
-
|
124
|
-
# Get phase using the Hilbert transform
|
125
|
-
analytic_signal = signal.hilbert(filtered_lfp)
|
126
|
-
instantaneous_phase = np.angle(analytic_signal)
|
127
|
-
|
138
|
+
# Get instantaneous phase
|
139
|
+
if filtered_lfp_phase is None:
|
140
|
+
instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
|
141
|
+
freq_of_interest=freq_of_interest, lowcut=lowcut,
|
142
|
+
highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
|
128
143
|
else:
|
129
|
-
|
144
|
+
instantaneous_phase = filtered_lfp_phase
|
130
145
|
|
131
146
|
# Get phases at spike times
|
132
147
|
spike_phases = instantaneous_phase[valid_indices]
|
133
|
-
|
134
|
-
#
|
135
|
-
|
148
|
+
|
149
|
+
# Number of spikes
|
150
|
+
N = len(spike_phases)
|
136
151
|
|
137
152
|
# Convert phases to unit vectors in the complex plane
|
138
153
|
unit_vectors = np.exp(1j * spike_phases)
|
139
154
|
|
140
|
-
#
|
155
|
+
# Sum of all unit vectors (resultant vector)
|
141
156
|
resultant_vector = np.sum(unit_vectors)
|
142
|
-
|
143
|
-
#
|
144
|
-
|
145
|
-
|
146
|
-
|
157
|
+
|
158
|
+
# Calculate plv^2 * N
|
159
|
+
plv2n = (resultant_vector * resultant_vector.conjugate()).real / N # plv^2 * N
|
160
|
+
plv = (plv2n / N) ** 0.5
|
161
|
+
ppc = (plv2n - 1) / (N - 1) # ppc = (plv^2 * N - 1) / (N - 1)
|
162
|
+
plv_unbiased = np.fmax(ppc, 0.) ** 0.5 # ensure non-negative
|
163
|
+
|
164
|
+
return plv_unbiased
|
147
165
|
|
148
166
|
|
149
167
|
@numba.njit(parallel=True, fastmath=True)
|
@@ -181,27 +199,43 @@ def _ppc_gpu(spike_phases):
|
|
181
199
|
return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
|
182
200
|
|
183
201
|
|
184
|
-
def calculate_ppc(spike_times: np.ndarray = None,
|
185
|
-
lfp_fs: float = None,
|
186
|
-
lowcut: float = None, highcut: float = None,
|
187
|
-
|
202
|
+
def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
|
203
|
+
lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
|
204
|
+
lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
|
205
|
+
ppc_method: str = 'numpy', filtered_lfp_phase: np.ndarray = None) -> float:
|
188
206
|
"""
|
189
207
|
Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
|
190
208
|
Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
|
191
209
|
|
192
|
-
Parameters
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
210
|
+
Parameters
|
211
|
+
----------
|
212
|
+
spike_times : np.ndarray
|
213
|
+
Array of spike times
|
214
|
+
lfp_data : np.ndarray
|
215
|
+
Local field potential time series data. Not required if filtered_lfp_phase is provided.
|
216
|
+
spike_fs : float, optional
|
217
|
+
Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
|
218
|
+
lfp_fs : float
|
219
|
+
Sampling frequency in Hz of the LFP data
|
220
|
+
filter_method : str, optional
|
221
|
+
Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
|
222
|
+
freq_of_interest : float, optional
|
223
|
+
Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
|
224
|
+
lowcut : float, optional
|
225
|
+
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
226
|
+
highcut : float, optional
|
227
|
+
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
228
|
+
bandwidth : float, optional
|
229
|
+
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
230
|
+
ppc_method : str, optional
|
231
|
+
Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
|
232
|
+
filtered_lfp_phase : np.ndarray, optional
|
233
|
+
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
234
|
+
|
235
|
+
Returns
|
236
|
+
-------
|
237
|
+
float
|
238
|
+
Pairwise Phase Consistency value
|
205
239
|
"""
|
206
240
|
if spike_fs is None:
|
207
241
|
spike_fs = lfp_fs
|
@@ -212,33 +246,21 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
|
|
212
246
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
213
247
|
|
214
248
|
# Filter indices to ensure they're within bounds of the LFP signal
|
215
|
-
|
249
|
+
if filtered_lfp_phase is not None:
|
250
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
|
251
|
+
else:
|
252
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
253
|
+
|
216
254
|
if len(valid_indices) <= 1:
|
217
|
-
return 0
|
255
|
+
return 0
|
218
256
|
|
219
|
-
#
|
220
|
-
if
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
# Apply CWT to extract phase at the frequency of interest
|
225
|
-
lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
|
226
|
-
instantaneous_phase = np.angle(lfp_complex)
|
227
|
-
|
228
|
-
elif method == 'hilbert':
|
229
|
-
if lowcut is None or highcut is None:
|
230
|
-
print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
|
231
|
-
filtered_lfp = lfp_signal
|
232
|
-
else:
|
233
|
-
# Bandpass filter the signal
|
234
|
-
filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
|
235
|
-
|
236
|
-
# Get phase using the Hilbert transform
|
237
|
-
analytic_signal = signal.hilbert(filtered_lfp)
|
238
|
-
instantaneous_phase = np.angle(analytic_signal)
|
239
|
-
|
257
|
+
# Get instantaneous phase
|
258
|
+
if filtered_lfp_phase is None:
|
259
|
+
instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
|
260
|
+
freq_of_interest=freq_of_interest, lowcut=lowcut,
|
261
|
+
highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
|
240
262
|
else:
|
241
|
-
|
263
|
+
instantaneous_phase = filtered_lfp_phase
|
242
264
|
|
243
265
|
# Get phases at spike times
|
244
266
|
spike_phases = instantaneous_phase[valid_indices]
|
@@ -247,28 +269,10 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
|
|
247
269
|
|
248
270
|
# Calculate PPC (Pairwise Phase Consistency)
|
249
271
|
if n_spikes <= 1:
|
250
|
-
return 0
|
272
|
+
return 0
|
251
273
|
|
252
274
|
# Explicit calculation of pairwise phase consistency
|
253
|
-
|
254
|
-
|
255
|
-
# # Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
|
256
|
-
# for i in range(n_spikes - 1): # For each spike i
|
257
|
-
# for j in range(i + 1, n_spikes): # For each spike j > i
|
258
|
-
# # Calculate the phase difference between spikes i and j
|
259
|
-
# phase_diff = spike_phases[i] - spike_phases[j]
|
260
|
-
|
261
|
-
# #f(θᵢ, θⱼ) = cos(θᵢ)cos(θⱼ) + sin(θᵢ)sin(θⱼ) can become #f(θᵢ, θⱼ) = cos(θᵢ - θⱼ)
|
262
|
-
# cos_diff = np.cos(phase_diff)
|
263
|
-
|
264
|
-
# # Add to the sum
|
265
|
-
# sum_cos_diff += cos_diff
|
266
|
-
|
267
|
-
# # Calculate PPC according to the equation
|
268
|
-
# # PPC = (2 / (N(N-1))) * Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
|
269
|
-
# ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
|
270
|
-
|
271
|
-
# same as above (i think) but with vectorized computation and memory fixes so it wont take forever to run.
|
275
|
+
# Vectorized computation for efficiency
|
272
276
|
if ppc_method == 'numpy':
|
273
277
|
i, j = np.triu_indices(n_spikes, k=1)
|
274
278
|
phase_diff = spike_phases[i] - spike_phases[j]
|
@@ -279,14 +283,14 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
|
|
279
283
|
elif ppc_method == 'gpu':
|
280
284
|
ppc = _ppc_gpu(spike_phases)
|
281
285
|
else:
|
282
|
-
raise
|
286
|
+
raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")
|
283
287
|
return ppc
|
284
288
|
|
285
289
|
|
286
|
-
def calculate_ppc2(spike_times: np.ndarray = None,
|
287
|
-
lfp_fs: float = None,
|
288
|
-
lowcut: float = None, highcut: float = None,
|
289
|
-
|
290
|
+
def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
|
291
|
+
lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
|
292
|
+
lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
|
293
|
+
filtered_lfp_phase: np.ndarray = None) -> float:
|
290
294
|
"""
|
291
295
|
# -----------------------------------------------------------------------------
|
292
296
|
# PPC2 Calculation (Vinck et al., 2010)
|
@@ -297,18 +301,33 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
|
|
297
301
|
# PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
|
298
302
|
# -----------------------------------------------------------------------------
|
299
303
|
|
300
|
-
Parameters
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
304
|
+
Parameters
|
305
|
+
----------
|
306
|
+
spike_times : np.ndarray
|
307
|
+
Array of spike times
|
308
|
+
lfp_data : np.ndarray
|
309
|
+
Local field potential time series data. Not required if filtered_lfp_phase is provided.
|
310
|
+
spike_fs : float, optional
|
311
|
+
Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
|
312
|
+
lfp_fs : float
|
313
|
+
Sampling frequency in Hz of the LFP data
|
314
|
+
filter_method : str, optional
|
315
|
+
Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
|
316
|
+
freq_of_interest : float, optional
|
317
|
+
Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
|
318
|
+
lowcut : float, optional
|
319
|
+
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
320
|
+
highcut : float, optional
|
321
|
+
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
322
|
+
bandwidth : float, optional
|
323
|
+
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
324
|
+
filtered_lfp_phase : np.ndarray, optional
|
325
|
+
Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
|
326
|
+
|
327
|
+
Returns
|
328
|
+
-------
|
329
|
+
float
|
330
|
+
Pairwise Phase Consistency 2 (PPC2) value
|
312
331
|
"""
|
313
332
|
|
314
333
|
if spike_fs is None:
|
@@ -320,33 +339,21 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
|
|
320
339
|
spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
|
321
340
|
|
322
341
|
# Filter indices to ensure they're within bounds of the LFP signal
|
323
|
-
|
342
|
+
if filtered_lfp_phase is not None:
|
343
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
|
344
|
+
else:
|
345
|
+
valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
|
346
|
+
|
324
347
|
if len(valid_indices) <= 1:
|
325
|
-
return 0
|
348
|
+
return 0
|
326
349
|
|
327
|
-
#
|
328
|
-
if
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
# Apply CWT to extract phase at the frequency of interest
|
333
|
-
lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
|
334
|
-
instantaneous_phase = np.angle(lfp_complex)
|
335
|
-
|
336
|
-
elif method == 'hilbert':
|
337
|
-
if lowcut is None or highcut is None:
|
338
|
-
print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
|
339
|
-
filtered_lfp = lfp_signal
|
340
|
-
else:
|
341
|
-
# Bandpass filter the signal
|
342
|
-
filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
|
343
|
-
|
344
|
-
# Get phase using the Hilbert transform
|
345
|
-
analytic_signal = signal.hilbert(filtered_lfp)
|
346
|
-
instantaneous_phase = np.angle(analytic_signal)
|
347
|
-
|
350
|
+
# Get instantaneous phase
|
351
|
+
if filtered_lfp_phase is None:
|
352
|
+
instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
|
353
|
+
freq_of_interest=freq_of_interest, lowcut=lowcut,
|
354
|
+
highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
|
348
355
|
else:
|
349
|
-
|
356
|
+
instantaneous_phase = filtered_lfp_phase
|
350
357
|
|
351
358
|
# Get phases at spike times
|
352
359
|
spike_phases = instantaneous_phase[valid_indices]
|
@@ -355,7 +362,7 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
|
|
355
362
|
n = len(spike_phases)
|
356
363
|
|
357
364
|
if n <= 1:
|
358
|
-
return 0
|
365
|
+
return 0
|
359
366
|
|
360
367
|
# Convert phases to unit vectors in the complex plane
|
361
368
|
unit_vectors = np.exp(1j * spike_phases)
|
@@ -369,40 +376,77 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
|
|
369
376
|
return ppc2
|
370
377
|
|
371
378
|
|
372
|
-
def
|
373
|
-
|
374
|
-
|
379
|
+
def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None, filter_method: str='wavelet', pop_names: List[str]=None,
|
380
|
+
entrainment_method: str='plv', lowcut: float=None, highcut: float=None,
|
381
|
+
spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
|
382
|
+
freqs: List[float]=None, ppc_method: str='numpy',) -> Dict[str, Dict[int, Dict[float, float]]]:
|
375
383
|
"""
|
376
|
-
Calculate
|
384
|
+
Calculate neural entrainment (PPC, PLV) per neuron (cell) for specified frequencies across different populations.
|
377
385
|
|
378
|
-
This function computes the
|
379
|
-
and a
|
386
|
+
This function computes the entrainment metrics for each neuron within the specified populations based on their spike times
|
387
|
+
and the provided LFP signal. It returns a nested dictionary structure containing the entrainment values
|
388
|
+
organized by population, node ID, and frequency.
|
380
389
|
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
390
|
+
Parameters
|
391
|
+
----------
|
392
|
+
spike_df : pd.DataFrame
|
393
|
+
DataFrame containing spike data with columns 'pop_name', 'node_ids', and 'timestamps'
|
394
|
+
lfp_data : np.ndarray
|
395
|
+
Local field potential (LFP) time series data
|
396
|
+
filter_method : str, optional
|
397
|
+
Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
|
398
|
+
entrainment_method : str, optional
|
399
|
+
Method to use for entrainment calculation, either 'plv', 'ppc', or 'ppc2' (default: 'plv')
|
400
|
+
lowcut : float, optional
|
401
|
+
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
402
|
+
highcut : float, optional
|
403
|
+
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
404
|
+
spike_fs : float
|
405
|
+
Sampling frequency of the spike times in Hz
|
406
|
+
lfp_fs : float
|
407
|
+
Sampling frequency of the LFP signal in Hz
|
408
|
+
bandwidth : float, optional
|
409
|
+
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
410
|
+
ppc_method : str, optional
|
411
|
+
Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
|
412
|
+
pop_names : List[str]
|
413
|
+
List of population names to analyze
|
414
|
+
freqs : List[float]
|
415
|
+
List of frequencies (in Hz) at which to calculate entrainment
|
388
416
|
|
389
|
-
Returns
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
417
|
+
Returns
|
418
|
+
-------
|
419
|
+
Dict[str, Dict[int, Dict[float, float]]]
|
420
|
+
Nested dictionary where the structure is:
|
421
|
+
{
|
422
|
+
population_name: {
|
423
|
+
node_id: {
|
424
|
+
frequency: entrainment value
|
396
425
|
}
|
397
426
|
}
|
398
|
-
|
427
|
+
}
|
428
|
+
Entrainment values are floats representing the metric (PPC, PLV) at each frequency
|
399
429
|
"""
|
400
|
-
|
430
|
+
# pre filter lfp to speed up calculate of entrainment
|
431
|
+
filtered_lfp_phases = {}
|
432
|
+
for freq in range(len(freqs)):
|
433
|
+
phase = get_lfp_phase(
|
434
|
+
lfp_data=lfp_data,
|
435
|
+
freq_of_interest=freqs[freq],
|
436
|
+
fs=lfp_fs,
|
437
|
+
filter_method=filter_method,
|
438
|
+
lowcut=lowcut,
|
439
|
+
highcut=highcut,
|
440
|
+
bandwidth=bandwidth
|
441
|
+
)
|
442
|
+
filtered_lfp_phases[freqs[freq]] = phase
|
443
|
+
|
444
|
+
entrainment_dict = {}
|
401
445
|
for pop in pop_names:
|
402
446
|
skip_count = 0
|
403
447
|
pop_spikes = spike_df[spike_df['pop_name'] == pop]
|
404
448
|
nodes = pop_spikes['node_ids'].unique()
|
405
|
-
|
449
|
+
entrainment_dict[pop] = {}
|
406
450
|
print(f'Processing {pop} population')
|
407
451
|
for node in tqdm(nodes):
|
408
452
|
node_spikes = pop_spikes[pop_spikes['node_ids'] == node]
|
@@ -412,24 +456,58 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
|
|
412
456
|
skip_count += 1
|
413
457
|
continue
|
414
458
|
|
415
|
-
|
459
|
+
entrainment_dict[pop][node] = {}
|
416
460
|
for freq in freqs:
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
461
|
+
# Calculate entrainment based on the selected method using the pre-filtered phases
|
462
|
+
if entrainment_method == 'plv':
|
463
|
+
entrainment_dict[pop][node][freq] = calculate_spike_lfp_plv(
|
464
|
+
node_spikes['timestamps'].values,
|
465
|
+
lfp_data,
|
466
|
+
spike_fs=spike_fs,
|
467
|
+
lfp_fs=lfp_fs,
|
468
|
+
freq_of_interest=freq,
|
469
|
+
bandwidth=bandwidth,
|
470
|
+
lowcut=lowcut,
|
471
|
+
highcut=highcut,
|
472
|
+
filter_method=filter_method,
|
473
|
+
filtered_lfp_phase=filtered_lfp_phases[freq]
|
474
|
+
)
|
475
|
+
elif entrainment_method == 'ppc2':
|
476
|
+
entrainment_dict[pop][node][freq] = calculate_ppc2(
|
477
|
+
node_spikes['timestamps'].values,
|
478
|
+
lfp_data,
|
479
|
+
spike_fs=spike_fs,
|
480
|
+
lfp_fs=lfp_fs,
|
481
|
+
freq_of_interest=freq,
|
482
|
+
bandwidth=bandwidth,
|
483
|
+
lowcut=lowcut,
|
484
|
+
highcut=highcut,
|
485
|
+
filter_method=filter_method,
|
486
|
+
filtered_lfp_phase=filtered_lfp_phases[freq]
|
487
|
+
)
|
488
|
+
elif entrainment_method == 'ppc':
|
489
|
+
entrainment_dict[pop][node][freq] = calculate_ppc(
|
490
|
+
node_spikes['timestamps'].values,
|
491
|
+
lfp_data,
|
492
|
+
spike_fs=spike_fs,
|
493
|
+
lfp_fs=lfp_fs,
|
494
|
+
freq_of_interest=freq,
|
495
|
+
bandwidth=bandwidth,
|
496
|
+
lowcut=lowcut,
|
497
|
+
highcut=highcut,
|
498
|
+
filter_method=filter_method,
|
499
|
+
ppc_method=ppc_method,
|
500
|
+
filtered_lfp_phase=filtered_lfp_phases[freq]
|
501
|
+
)
|
426
502
|
|
427
|
-
print(f'Calculated
|
503
|
+
print(f'Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')
|
428
504
|
|
429
|
-
return
|
505
|
+
return entrainment_dict
|
430
506
|
|
431
507
|
|
432
|
-
def calculate_spike_rate_power_correlation(spike_rate,
|
508
|
+
def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
|
509
|
+
bandwidth=2.0, lowcut=None, highcut=None,
|
510
|
+
freq_range=(10, 100), freq_step=5):
|
433
511
|
"""
|
434
512
|
Calculate correlation between population spike rates and LFP power across frequencies
|
435
513
|
using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
|
@@ -438,16 +516,24 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
|
|
438
516
|
-----------
|
439
517
|
spike_rate : DataFrame
|
440
518
|
Pre-calculated population spike rates at the same fs as lfp
|
441
|
-
|
519
|
+
lfp_data : np.array
|
442
520
|
LFP data
|
443
521
|
fs : float
|
444
522
|
Sampling frequency
|
445
523
|
pop_names : list
|
446
524
|
List of population names to analyze
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
525
|
+
filter_method : str, optional
|
526
|
+
Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
|
527
|
+
bandwidth : float, optional
|
528
|
+
Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
|
529
|
+
lowcut : float, optional
|
530
|
+
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
531
|
+
highcut : float, optional
|
532
|
+
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
533
|
+
freq_range : tuple, optional
|
534
|
+
Min and max frequency to analyze (default: (10, 100))
|
535
|
+
freq_step : float, optional
|
536
|
+
Step size for frequency analysis (default: 5)
|
451
537
|
|
452
538
|
Returns:
|
453
539
|
--------
|
@@ -463,14 +549,11 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
|
|
463
549
|
# Dictionary to store results
|
464
550
|
correlation_results = {pop: {} for pop in pop_names}
|
465
551
|
|
466
|
-
# Calculate power at each frequency band using
|
552
|
+
# Calculate power at each frequency band using specified filter
|
467
553
|
power_by_freq = {}
|
468
554
|
for freq in frequencies:
|
469
|
-
|
470
|
-
|
471
|
-
# Calculate power (magnitude squared of complex wavelet transform)
|
472
|
-
power = np.abs(filtered_signal)**2
|
473
|
-
power_by_freq[freq] = power
|
555
|
+
power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
|
556
|
+
lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
|
474
557
|
|
475
558
|
# Calculate correlation for each population
|
476
559
|
for pop in pop_names:
|
@@ -481,7 +564,7 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
|
|
481
564
|
for freq in frequencies:
|
482
565
|
# Make sure the lengths match
|
483
566
|
if len(pop_rate) != len(power_by_freq[freq]):
|
484
|
-
raise
|
567
|
+
raise ValueError(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
|
485
568
|
# use spearman for non-parametric correlation
|
486
569
|
corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
|
487
570
|
correlation_results[pop][freq] = {'correlation': corr, 'p_value': p_val}
|
bmtool/analysis/lfp.py
CHANGED
@@ -273,10 +273,83 @@ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
|
|
273
273
|
return normalized_power
|
274
274
|
|
275
275
|
|
276
|
-
def
|
276
|
+
def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
|
277
|
+
"""
|
278
|
+
Calculate the passband of a complex Morlet wavelet filter.
|
279
|
+
|
280
|
+
Parameters
|
281
|
+
----------
|
282
|
+
center_freq : float
|
283
|
+
Center frequency (Hz) of the wavelet filter
|
284
|
+
bandwidth : float
|
285
|
+
Bandwidth parameter of the wavelet filter
|
286
|
+
threshold : float, optional
|
287
|
+
Power threshold to define the passband edges (default: 0.5 = -3dB point)
|
288
|
+
|
289
|
+
Returns
|
290
|
+
-------
|
291
|
+
tuple
|
292
|
+
(lower_bound, upper_bound, passband_width) of the frequency passband in Hz
|
293
|
+
"""
|
294
|
+
# Create a high-resolution frequency axis around the center frequency
|
295
|
+
# Extend range to 3x the expected width to ensure we capture the full passband
|
296
|
+
expected_width = center_freq * bandwidth / 2
|
297
|
+
freq_min = max(0.1, center_freq - 3 * expected_width)
|
298
|
+
freq_max = center_freq + 3 * expected_width
|
299
|
+
freq_axis = np.linspace(freq_min, freq_max, 1000)
|
300
|
+
|
301
|
+
# Calculate the theoretical frequency response of the Morlet wavelet
|
302
|
+
# For a complex Morlet wavelet, the frequency response approximates a Gaussian
|
303
|
+
# centered at the center frequency with width related to the bandwidth parameter
|
304
|
+
sigma_f = bandwidth * center_freq / 8 # Approximate relationship for cmor wavelet
|
305
|
+
response = np.exp(-((freq_axis - center_freq)**2) / (2 * sigma_f**2))
|
306
|
+
|
307
|
+
# Find the passband edges (where response crosses the threshold)
|
308
|
+
above_threshold = response >= threshold
|
309
|
+
if not np.any(above_threshold):
|
310
|
+
return (center_freq, center_freq, 0) # No passband found
|
311
|
+
|
312
|
+
# Find the first and last indices where response is above threshold
|
313
|
+
indices = np.where(above_threshold)[0]
|
314
|
+
lower_idx = indices[0]
|
315
|
+
upper_idx = indices[-1]
|
316
|
+
|
317
|
+
# Get the corresponding frequencies
|
318
|
+
lower_bound = freq_axis[lower_idx]
|
319
|
+
upper_bound = freq_axis[upper_idx]
|
320
|
+
passband_width = upper_bound - lower_bound
|
321
|
+
|
322
|
+
return (lower_bound, upper_bound, passband_width)
|
323
|
+
|
324
|
+
|
325
|
+
def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1,show_passband: bool = False) -> np.ndarray:
|
277
326
|
"""
|
278
327
|
Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.
|
328
|
+
|
329
|
+
Parameters
|
330
|
+
----------
|
331
|
+
x : np.ndarray
|
332
|
+
Input signal
|
333
|
+
freq : float
|
334
|
+
Target frequency for the wavelet filter
|
335
|
+
fs : float
|
336
|
+
Sampling frequency of the signal
|
337
|
+
bandwidth : float, optional
|
338
|
+
Bandwidth parameter of the wavelet filter (default is 1.0)
|
339
|
+
axis : int, optional
|
340
|
+
Axis along which to compute the CWT (default is -1)
|
341
|
+
show_passband : bool, optional
|
342
|
+
If True, print the passband of the wavelet filter (default is False)
|
343
|
+
|
344
|
+
Returns
|
345
|
+
-------
|
346
|
+
np.ndarray
|
347
|
+
Continuous Wavelet Transform of the input signal
|
279
348
|
"""
|
349
|
+
if show_passband:
|
350
|
+
lower_bound, upper_bound, passband_width = calculate_wavelet_passband(freq, bandwidth, threshold=0.3) # kinda made up threshold gives the rough idea
|
351
|
+
print(f"Wavelet filter at {freq:.1f} Hz Bandwidth: {bandwidth:.1f} Hz:")
|
352
|
+
print(f" Passband: {lower_bound:.1f} - {upper_bound:.1f} Hz (width: {passband_width:.1f} Hz)")
|
280
353
|
wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
|
281
354
|
scale = pywt.scale2frequency(wavelet, 1) * fs / freq
|
282
355
|
x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
|
@@ -292,6 +365,108 @@ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs:
|
|
292
365
|
return x_a
|
293
366
|
|
294
367
|
|
368
|
+
def get_lfp_power(lfp_data: np.ndarray, freq: float, fs: float, filter_method: str = 'wavelet',
|
369
|
+
lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
|
370
|
+
"""
|
371
|
+
Compute the power of the raw LFP signal in a specified frequency band.
|
372
|
+
|
373
|
+
Parameters
|
374
|
+
----------
|
375
|
+
lfp_data : np.ndarray
|
376
|
+
Raw local field potential (LFP) time series data
|
377
|
+
freq : float
|
378
|
+
Center frequency (Hz) for wavelet filtering method
|
379
|
+
fs : float
|
380
|
+
Sampling frequency (Hz) of the input data
|
381
|
+
filter_method : str, optional
|
382
|
+
Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
|
383
|
+
lowcut : float, optional
|
384
|
+
Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
385
|
+
highcut : float, optional
|
386
|
+
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
387
|
+
bandwidth : float, optional
|
388
|
+
Bandwidth parameter for wavelet filter when method='wavelet' (default: 1.0)
|
389
|
+
|
390
|
+
Returns
|
391
|
+
-------
|
392
|
+
np.ndarray
|
393
|
+
Power of the filtered signal (magnitude squared)
|
394
|
+
|
395
|
+
Notes
|
396
|
+
-----
|
397
|
+
- The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
|
398
|
+
- The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
|
399
|
+
- When using the 'butter' method, both lowcut and highcut must be provided
|
400
|
+
"""
|
401
|
+
if filter_method == 'wavelet':
|
402
|
+
filtered_signal = wavelet_filter(lfp_data, freq, fs, bandwidth)
|
403
|
+
elif filter_method == 'butter':
|
404
|
+
if lowcut is None or highcut is None:
|
405
|
+
raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
|
406
|
+
filtered_signal = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
|
407
|
+
else:
|
408
|
+
raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
|
409
|
+
|
410
|
+
# Calculate power (magnitude squared of filtered signal)
|
411
|
+
power = np.abs(filtered_signal)**2
|
412
|
+
return power
|
413
|
+
|
414
|
+
|
415
|
+
def get_lfp_phase(lfp_data: np.ndarray, freq_of_interest: float, fs: float, filter_method: str = 'wavelet',
|
416
|
+
lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
|
417
|
+
"""
|
418
|
+
Calculate the phase of the filtered signal.
|
419
|
+
|
420
|
+
Parameters
|
421
|
+
----------
|
422
|
+
lfp_data : np.ndarray
|
423
|
+
Input LFP data
|
424
|
+
fs : float
|
425
|
+
Sampling frequency (Hz)
|
426
|
+
freq : float
|
427
|
+
Frequency of interest (Hz)
|
428
|
+
filter_method : str, optional
|
429
|
+
Method for filtering the signal ('wavelet' or 'butter')
|
430
|
+
bandwidth : float, optional
|
431
|
+
Bandwidth parameter for wavelet filter when method='wavelet' (default: 1.0)
|
432
|
+
lowcut : float, optional
|
433
|
+
Low cutoff frequency for Butterworth filter when method='butter'
|
434
|
+
highcut : float, optional
|
435
|
+
High cutoff frequency for Butterworth filter when method='butter'
|
436
|
+
|
437
|
+
Returns
|
438
|
+
-------
|
439
|
+
np.ndarray
|
440
|
+
Phase of the filtered signal
|
441
|
+
|
442
|
+
Notes
|
443
|
+
-----
|
444
|
+
- The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
|
445
|
+
- The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
|
446
|
+
followed by Hilbert transform to extract the phase
|
447
|
+
- When using the 'butter' method, both lowcut and highcut must be provided
|
448
|
+
"""
|
449
|
+
if filter_method == 'wavelet':
|
450
|
+
if freq_of_interest is None:
|
451
|
+
raise ValueError("freq_of_interest must be provided for the wavelet method.")
|
452
|
+
# Wavelet filter returns complex values directly
|
453
|
+
filtered_signal = wavelet_filter(lfp_data, freq_of_interest, fs, bandwidth)
|
454
|
+
# Phase is the angle of the complex signal
|
455
|
+
phase = np.angle(filtered_signal)
|
456
|
+
elif filter_method == 'butter':
|
457
|
+
if lowcut is None or highcut is None:
|
458
|
+
raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
|
459
|
+
# Butterworth filter returns real values
|
460
|
+
filtered_signal = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
|
461
|
+
# Apply Hilbert transform to get analytic signal (complex)
|
462
|
+
analytic_signal = signal.hilbert(filtered_signal)
|
463
|
+
# Phase is the angle of the analytic signal
|
464
|
+
phase = np.angle(analytic_signal)
|
465
|
+
else:
|
466
|
+
raise ValueError(f"Invalid method {filter_method}. Choose 'wavelet' or 'butter'.")
|
467
|
+
|
468
|
+
return phase
|
469
|
+
|
295
470
|
# windowing functions
|
296
471
|
def windowed_xarray(da, windows, dim='time',
|
297
472
|
new_coord_name='cycle', new_coord=None):
|
bmtool/analysis/spikes.py
CHANGED
@@ -11,22 +11,33 @@ from scipy.stats import mannwhitneyu
|
|
11
11
|
import os
|
12
12
|
|
13
13
|
|
14
|
-
def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: str = 'pop_name') -> pd.DataFrame:
|
14
|
+
def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: Union[str, List[str]] = 'pop_name') -> pd.DataFrame:
|
15
15
|
"""
|
16
16
|
Load spike data from an HDF5 file into a pandas DataFrame.
|
17
17
|
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
18
|
+
Parameters
|
19
|
+
----------
|
20
|
+
spike_file : str
|
21
|
+
Path to the HDF5 file containing spike data
|
22
|
+
network_name : str
|
23
|
+
The name of the network within the HDF5 file from which to load spike data
|
24
|
+
sort : bool, optional
|
25
|
+
Whether to sort the DataFrame by 'timestamps' (default: True)
|
26
|
+
config : str, optional
|
27
|
+
Path to configuration file to label the cell type of each spike (default: None)
|
28
|
+
groupby : Union[str, List[str]], optional
|
29
|
+
The column(s) to group by (default: 'pop_name')
|
30
|
+
|
31
|
+
Returns
|
32
|
+
-------
|
33
|
+
pd.DataFrame
|
34
|
+
A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data,
|
35
|
+
with additional columns if a config file is provided
|
27
36
|
|
28
|
-
|
29
|
-
|
37
|
+
Examples
|
38
|
+
--------
|
39
|
+
>>> df = load_spikes_to_df("spikes.h5", "cortex")
|
40
|
+
>>> df = load_spikes_to_df("spikes.h5", "cortex", config="config.json", groupby=["pop_name", "model_type"])
|
30
41
|
"""
|
31
42
|
with h5py.File(spike_file) as f:
|
32
43
|
spikes_df = pd.DataFrame({
|
@@ -126,23 +137,31 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] =
|
|
126
137
|
|
127
138
|
|
128
139
|
def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None,
|
129
|
-
time_points: Optional[Union[np.ndarray, list]] = None,
|
140
|
+
time_points: Optional[Union[np.ndarray, list]] = None, frequency: bool = False) -> np.ndarray:
|
130
141
|
"""
|
131
142
|
Calculate the spike count or frequency histogram over specified time intervals.
|
132
143
|
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
144
|
+
Parameters
|
145
|
+
----------
|
146
|
+
spike_times : Union[np.ndarray, list]
|
147
|
+
Array or list of spike times in milliseconds
|
148
|
+
time : Optional[Tuple[float, float, float]], optional
|
149
|
+
Tuple specifying (start, stop, step) in milliseconds. Used to create evenly spaced time points
|
150
|
+
if `time_points` is not provided. Default is None.
|
151
|
+
time_points : Optional[Union[np.ndarray, list]], optional
|
152
|
+
Array or list of specific time points for binning. If provided, `time` is ignored. Default is None.
|
153
|
+
frequency : bool, optional
|
154
|
+
If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
|
155
|
+
|
156
|
+
Returns
|
157
|
+
-------
|
158
|
+
np.ndarray
|
159
|
+
Array of spike counts or frequencies, depending on the `frequency` flag.
|
160
|
+
|
161
|
+
Raises
|
162
|
+
------
|
163
|
+
ValueError
|
164
|
+
If both `time` and `time_points` are None.
|
146
165
|
"""
|
147
166
|
if time_points is None:
|
148
167
|
if time is None:
|
@@ -156,43 +175,57 @@ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[f
|
|
156
175
|
bins = np.append(time_points, time_points[-1] + dt)
|
157
176
|
spike_rate, _ = np.histogram(np.asarray(spike_times), bins)
|
158
177
|
|
159
|
-
if
|
178
|
+
if frequency:
|
160
179
|
spike_rate = 1000 / dt * spike_rate
|
161
180
|
|
162
181
|
return spike_rate
|
163
182
|
|
164
183
|
|
165
|
-
def get_population_spike_rate(
|
184
|
+
def get_population_spike_rate(spike_data: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
|
166
185
|
config: Optional[str] = None, network_name: Optional[str] = None,
|
167
186
|
save: bool = False, save_path: Optional[str] = None,
|
168
187
|
normalize: bool = False) -> Dict[str, np.ndarray]:
|
169
188
|
"""
|
170
189
|
Calculate the population spike rate for each population in the given spike data, with an option to normalize.
|
171
190
|
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
191
|
+
Parameters
|
192
|
+
----------
|
193
|
+
spike_data : pd.DataFrame
|
194
|
+
A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'
|
195
|
+
fs : float, optional
|
196
|
+
Sampling frequency in Hz, which determines the time bin size for calculating the spike rate (default: 400.0)
|
197
|
+
t_start : float, optional
|
198
|
+
Start time (in milliseconds) for spike rate calculation (default: 0)
|
199
|
+
t_stop : Optional[float], optional
|
200
|
+
Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data
|
201
|
+
config : Optional[str], optional
|
202
|
+
Path to a configuration file containing node information, used to determine the correct number of nodes per population.
|
203
|
+
If None, node count is estimated from unique node spikes (default: None)
|
204
|
+
network_name : Optional[str], optional
|
205
|
+
Name of the network used in the configuration file, allowing selection of nodes for that network.
|
206
|
+
Required if `config` is provided (default: None)
|
207
|
+
save : bool, optional
|
208
|
+
Whether to save the calculated population spike rate to a file (default: False)
|
209
|
+
save_path : Optional[str], optional
|
210
|
+
Directory path where the file should be saved if `save` is True (default: None)
|
211
|
+
normalize : bool, optional
|
212
|
+
Whether to normalize the spike rates for each population to a range of [0, 1] (default: False)
|
213
|
+
|
214
|
+
Returns
|
215
|
+
-------
|
216
|
+
Dict[str, np.ndarray]
|
217
|
+
A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
|
218
|
+
If `normalize` is True, each population's spike rate is scaled to [0, 1].
|
219
|
+
|
220
|
+
Raises
|
221
|
+
------
|
222
|
+
ValueError
|
223
|
+
If `save` is True but `save_path` is not provided.
|
224
|
+
|
225
|
+
Notes
|
226
|
+
-----
|
227
|
+
- If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
|
228
|
+
- If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
|
196
229
|
"""
|
197
230
|
pop_spikes = {}
|
198
231
|
node_number = {}
|
@@ -205,8 +238,8 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
|
|
205
238
|
if not network_name:
|
206
239
|
print("Grabbing first network; specify a network name to ensure correct node population is selected.")
|
207
240
|
|
208
|
-
for pop_name in
|
209
|
-
ps =
|
241
|
+
for pop_name in spike_data['pop_name'].unique():
|
242
|
+
ps = spike_data[spike_data['pop_name'] == pop_name]
|
210
243
|
|
211
244
|
if config:
|
212
245
|
nodes = load_nodes_from_config(config)
|
@@ -220,12 +253,12 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
|
|
220
253
|
node_number[pop_name] = ps['node_ids'].nunique()
|
221
254
|
|
222
255
|
if t_stop is None:
|
223
|
-
t_stop =
|
256
|
+
t_stop = spike_data['timestamps'].max()
|
224
257
|
|
225
|
-
filtered_spikes =
|
226
|
-
(
|
227
|
-
(
|
228
|
-
(
|
258
|
+
filtered_spikes = spike_data[
|
259
|
+
(spike_data['pop_name'] == pop_name) &
|
260
|
+
(spike_data['timestamps'] > t_start) &
|
261
|
+
(spike_data['timestamps'] < t_stop)
|
229
262
|
]
|
230
263
|
pop_spikes[pop_name] = filtered_spikes
|
231
264
|
|
@@ -254,11 +287,30 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
|
|
254
287
|
return spike_rate
|
255
288
|
|
256
289
|
|
257
|
-
def compare_firing_over_times(spike_df,group_by, time_window_1, time_window_2):
|
290
|
+
def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window_1: List[float], time_window_2: List[float]) -> None:
|
258
291
|
"""
|
259
|
-
Compares the firing rates of a population during two different time windows
|
260
|
-
|
261
|
-
|
292
|
+
Compares the firing rates of a population during two different time windows and performs
|
293
|
+
a statistical test to determine if there is a significant difference.
|
294
|
+
|
295
|
+
Parameters
|
296
|
+
----------
|
297
|
+
spike_df : pd.DataFrame
|
298
|
+
DataFrame containing spike data with columns for timestamps, node_ids, and grouping variable
|
299
|
+
group_by : str
|
300
|
+
Column name to group spikes by (e.g., 'pop_name')
|
301
|
+
time_window_1 : List[float]
|
302
|
+
First time window as [start, stop] in milliseconds
|
303
|
+
time_window_2 : List[float]
|
304
|
+
Second time window as [start, stop] in milliseconds
|
305
|
+
|
306
|
+
Returns
|
307
|
+
-------
|
308
|
+
None
|
309
|
+
Results are printed to the console
|
310
|
+
|
311
|
+
Notes
|
312
|
+
-----
|
313
|
+
Uses Mann-Whitney U test (non-parametric) to compare firing rates between the two windows
|
262
314
|
"""
|
263
315
|
# Filter spikes for the population of interest
|
264
316
|
for pop_name in spike_df[group_by].unique():
|
@@ -8,10 +8,10 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
|
|
8
8
|
bmtool/singlecell.py,sha256=imcdxIzvYVkaOLSGDxYp8WGGssGwXXBCRhzhlqVp7hA,44267
|
9
9
|
bmtool/synapses.py,sha256=Ow2fZavA_3_5BYCjcgPjW0YsyVOetn1wvLxL7hQvbZo,64556
|
10
10
|
bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
bmtool/analysis/entrainment.py,sha256=
|
12
|
-
bmtool/analysis/lfp.py,sha256=
|
11
|
+
bmtool/analysis/entrainment.py,sha256=7lFlGMApL_2snwdvIPDLFW1KKPdyiuCnZ5ADa7ujx5o,24439
|
12
|
+
bmtool/analysis/lfp.py,sha256=6u9cHnac-5Fzpk9ecQew7MmXBAolzKZakRsnPn3-C2U,24109
|
13
13
|
bmtool/analysis/netcon_reports.py,sha256=7moyoUC45Cl1_6sGqwZ5aKphK_8i4AimroePXcgUnIo,3057
|
14
|
-
bmtool/analysis/spikes.py,sha256=
|
14
|
+
bmtool/analysis/spikes.py,sha256=kcJZQsvPVzQgcuiO-El_4OODW57hwNwdok_RsFMITCg,15097
|
15
15
|
bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
16
|
bmtool/bmplot/connections.py,sha256=re6QZX_NfQnIaWayGt3EhMINhCeMMSQ6rFR2sJbFeWk,51385
|
17
17
|
bmtool/bmplot/entrainment.py,sha256=3IBD6tfW7lvkuB6DTan7rAVAeznOOzmHLr1qA2rgtCY,1671
|
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
|
|
26
26
|
bmtool/util/util.py,sha256=XR0qZnv_Q47jMBKQpFzCSkCuKe9u8L3YSGJAOpP2zT0,57630
|
27
27
|
bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
28
28
|
bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
|
29
|
-
bmtool-0.7.0.
|
30
|
-
bmtool-0.7.0.
|
31
|
-
bmtool-0.7.0.
|
32
|
-
bmtool-0.7.0.
|
33
|
-
bmtool-0.7.0.
|
34
|
-
bmtool-0.7.0.
|
29
|
+
bmtool-0.7.0.5.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
|
30
|
+
bmtool-0.7.0.5.dist-info/METADATA,sha256=5A-VT9HRvmYInIJg4FvfxYSYGL70RzSgOaAaRULRXYs,2768
|
31
|
+
bmtool-0.7.0.5.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
|
32
|
+
bmtool-0.7.0.5.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
|
33
|
+
bmtool-0.7.0.5.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
|
34
|
+
bmtool-0.7.0.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|