bmtool 0.7.0.3__py3-none-any.whl → 0.7.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +221 -146
- bmtool/analysis/lfp.py +121 -1
- bmtool/analysis/spikes.py +115 -63
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/METADATA +1 -1
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/RECORD +9 -9
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/WHEEL +0 -0
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/top_level.txt +0 -0
bmtool/analysis/entrainment.py
CHANGED
@@ -8,7 +8,7 @@ import numba
 from numba import cuda
 import pandas as pd
 import xarray as xr
-from .lfp import wavelet_filter,butter_bandpass_filter
+from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power
 from typing import Dict, List
 from tqdm.notebook import tqdm
 import scipy.stats as stats
@@ -16,49 +16,65 @@ import seaborn as sns
 import matplotlib.pyplot as plt


-def calculate_signal_signal_plv(…
-… (old line 20 not shown in the diff view)
+def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None,
+                                filter_method: str = 'wavelet', lowcut: float = None, highcut: float = None,
                                 bandwidth: float = 2.0) -> np.ndarray:
     """
     Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.

-    Parameters
-… (old lines 26-34 not shown in the diff view)
+    Parameters
+    ----------
+    signal1 : np.ndarray
+        First input signal (1D array)
+    signal2 : np.ndarray
+        Second input signal (1D array, same length as signal1)
+    fs : float
+        Sampling frequency in Hz
+    freq_of_interest : float, optional
+        Desired frequency for wavelet PLV calculation, required if filter_method='wavelet'
+    filter_method : str, optional
+        Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+
+    Returns
+    -------
+    np.ndarray
+        Phase Locking Value (1D array)
     """
-    if len(…
+    if len(signal1) != len(signal2):
         raise ValueError("Input signals must have the same length.")

-    if …
+    if filter_method == 'wavelet':
         if freq_of_interest is None:
             raise ValueError("freq_of_interest must be provided for the wavelet method.")

         # Apply CWT to both signals
-        theta1 = wavelet_filter(x=…
-        theta2 = wavelet_filter(x=…
+        theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
+        theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)

-    elif …
+    elif filter_method == 'butter':
         if lowcut is None or highcut is None:
-            print("Lowcut and …
+            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation")

         if lowcut and highcut:
             # Bandpass filter and get the analytic signal using the Hilbert transform
-… (old lines 53-58 not shown in the diff view)
+            filtered_signal1 = butter_bandpass_filter(data=signal1, lowcut=lowcut, highcut=highcut, fs=fs)
+            filtered_signal2 = butter_bandpass_filter(data=signal2, lowcut=lowcut, highcut=highcut, fs=fs)
+            # Get phase using the Hilbert transform
+            theta1 = signal.hilbert(filtered_signal1)
+            theta2 = signal.hilbert(filtered_signal2)
+        else:
+            # Get phase using the Hilbert transform without filtering
+            theta1 = signal.hilbert(signal1)
+            theta2 = signal.hilbert(signal2)

     else:
-        raise ValueError("Invalid method. Choose 'wavelet' or '…
+        raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")

     # Calculate phase difference
     phase_diff = np.angle(theta1) - np.angle(theta2)
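To make the new calculate_signal_signal_plv signature concrete, here is an illustrative usage sketch. It is not taken from the bmtool documentation; the sampling rate, target frequency, and synthetic signals are assumptions.

# Illustrative sketch based on the 0.7.0.4 signature shown above; sample data is made up.
import numpy as np
from bmtool.analysis.entrainment import calculate_signal_signal_plv

fs = 1000.0                                   # assumed sampling rate in Hz
t = np.arange(0, 2.0, 1.0 / fs)
sig1 = np.sin(2 * np.pi * 8 * t) + 0.1 * np.random.randn(t.size)
sig2 = np.sin(2 * np.pi * 8 * t + 0.4) + 0.1 * np.random.randn(t.size)

# Wavelet path: freq_of_interest is required
plv_w = calculate_signal_signal_plv(sig1, sig2, fs=fs, freq_of_interest=8,
                                    filter_method='wavelet', bandwidth=2.0)

# Butterworth path: lowcut/highcut are required
plv_b = calculate_signal_signal_plv(sig1, sig2, fs=fs, filter_method='butter',
                                    lowcut=6.0, highcut=10.0)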
@@ -69,29 +85,43 @@ def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_
     return plv


-def calculate_spike_lfp_plv(spike_times: np.ndarray = None,…
-                            lfp_fs: float = None,…
+def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
+                            lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
                             lowcut: float = None, highcut: float = None,
                             bandwidth: float = 2.0) -> tuple:
     """
     Calculate spike-lfp phase locking value Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

-    Parameters
-… (old lines 80-91 not shown in the diff view)
+    Parameters
+    ----------
+    spike_times : np.ndarray
+        Array of spike times
+    lfp_data : np.ndarray
+        Local field potential time series data
+    spike_fs : float, optional
+        Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
+    lfp_fs : float
+        Sampling frequency in Hz of the LFP data
+    filter_method : str, optional
+        Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
+    freq_of_interest : float, optional
+        Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+
+    Returns
+    -------
+    tuple
+        (plv, spike_phases) where:
+        - plv: Phase Locking Value
+        - spike_phases: Phases at spike times
     """

-    if spike_fs…
+    if spike_fs is None:
         spike_fs = lfp_fs
     # Convert spike times to sample indices
     spike_times_seconds = spike_times / spike_fs
@@ -100,33 +130,29 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarr
     spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

     # Filter indices to ensure they're within bounds of the LFP signal
-    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(…
+    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
     if len(valid_indices) <= 1:
         return 0, np.array([])

-    # …
-    if …
+    # Filter the LFP signal to extract the phase
+    if filter_method == 'wavelet':
         if freq_of_interest is None:
             raise ValueError("freq_of_interest must be provided for the wavelet method.")

-        # Apply CWT to extract phase
-… (old lines 113-115 not shown in the diff view)
-    elif method == 'hilbert':
+        # Apply CWT to extract phase
+        filtered_lfp = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
+
+    elif filter_method == 'butter':
         if lowcut is None or highcut is None:
-… (old line 118 not shown in the diff view)
-            filtered_lfp = lfp_signal
-        else:
-            # Bandpass filter the signal
-            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
+            raise ValueError("Both lowcut and highcut must be specified for the butter method.")

-        # …
-… (old lines 125-126 not shown in the diff view)
+        # Bandpass filter the LFP signal
+        filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
+        filtered_lfp = signal.hilbert(filtered_lfp)  # Get analytic signal
+

     else:
-        raise ValueError("Invalid method. Choose 'wavelet' or '…
+        raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")

     # Get phases at spike times
     spike_phases = instantaneous_phase[valid_indices]
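An illustrative call of the updated calculate_spike_lfp_plv, based only on the signature and docstring shown above; the LFP array, spike times, and their units are assumed rather than taken from the package docs.

# Illustrative sketch of the renamed lfp_data/filter_method arguments; data and units are assumptions.
import numpy as np
from bmtool.analysis.entrainment import calculate_spike_lfp_plv

lfp_fs = 1000.0                                            # assumed LFP sampling rate (Hz)
lfp = np.random.randn(int(10 * lfp_fs))                    # 10 s of synthetic LFP
spikes = np.sort(np.random.uniform(0, 10_000, size=200))   # spike times, assumed milliseconds

plv, spike_phases = calculate_spike_lfp_plv(spike_times=spikes, lfp_data=lfp,
                                            spike_fs=1000.0, lfp_fs=lfp_fs,
                                            filter_method='wavelet', freq_of_interest=8,
                                            bandwidth=2.0)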
@@ -181,27 +207,43 @@ def _ppc_gpu(spike_phases):
     return (2/(len(spike_phases)*(len(spike_phases)-1))) * total


-def calculate_ppc(spike_times: np.ndarray = None,…
-                  lfp_fs: float = None,…
+def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
+                  lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
                   lowcut: float = None, highcut: float = None,
-                  bandwidth: float = 2.0,ppc_method: str = 'numpy') -> tuple:
+                  bandwidth: float = 2.0, ppc_method: str = 'numpy') -> tuple:
     """
     Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
     Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

-    Parameters
-… (old lines 193-204 not shown in the diff view)
+    Parameters
+    ----------
+    spike_times : np.ndarray
+        Array of spike times
+    lfp_data : np.ndarray
+        Local field potential time series data
+    spike_fs : float, optional
+        Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
+    lfp_fs : float
+        Sampling frequency in Hz of the LFP data
+    filter_method : str, optional
+        Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
+    freq_of_interest : float, optional
+        Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+    ppc_method : str, optional
+        Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
+
+    Returns
+    -------
+    tuple
+        (ppc, spike_phases) where:
+        - ppc: Pairwise Phase Consistency value
+        - spike_phases: Phases at spike times
     """
     if spike_fs is None:
         spike_fs = lfp_fs
@@ -212,33 +254,32 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
     spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

     # Filter indices to ensure they're within bounds of the LFP signal
-    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(…
+    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
     if len(valid_indices) <= 1:
         return 0, np.array([])

     # Extract phase using the specified method
-    if …
+    if filter_method == 'wavelet':
         if freq_of_interest is None:
             raise ValueError("freq_of_interest must be provided for the wavelet method.")

         # Apply CWT to extract phase at the frequency of interest
-        lfp_complex = wavelet_filter(x=…
+        lfp_complex = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
         instantaneous_phase = np.angle(lfp_complex)

-    elif …
+    elif filter_method == 'butter':
         if lowcut is None or highcut is None:
-… (old lines 230-233 not shown in the diff view)
-        filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
+            raise ValueError("Both lowcut and highcut must be specified for the butter method.")
+
+        # Bandpass filter the signal
+        filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)

         # Get phase using the Hilbert transform
         analytic_signal = signal.hilbert(filtered_lfp)
         instantaneous_phase = np.angle(analytic_signal)

     else:
-        raise ValueError("Invalid method. Choose 'wavelet' or '…
+        raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")

     # Get phases at spike times
     spike_phases = instantaneous_phase[valid_indices]
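A hedged usage sketch of calculate_ppc with the new filter_method and ppc_method arguments; the inputs below are synthetic stand-ins, not from the package documentation.

# Illustrative only; spike times are assumed to be in milliseconds.
import numpy as np
from bmtool.analysis.entrainment import calculate_ppc

lfp_fs = 1000.0
lfp = np.random.randn(int(10 * lfp_fs))
spikes = np.sort(np.random.uniform(0, 10_000, size=200))

ppc, phases = calculate_ppc(spike_times=spikes, lfp_data=lfp, spike_fs=1000.0, lfp_fs=lfp_fs,
                            filter_method='butter', lowcut=30.0, highcut=50.0,
                            ppc_method='numpy')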
@@ -283,10 +324,10 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
     return ppc


-def calculate_ppc2(spike_times: np.ndarray = None,…
-                   lfp_fs: float = None,…
+def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
+                   lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
                    lowcut: float = None, highcut: float = None,
-                   bandwidth: float = 2.0) ->…
+                   bandwidth: float = 2.0) -> float:
     """
     # -----------------------------------------------------------------------------
     # PPC2 Calculation (Vinck et al., 2010)
@@ -297,18 +338,31 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
     # PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
     # -----------------------------------------------------------------------------

-    Parameters
-… (old lines 301-311 not shown in the diff view)
+    Parameters
+    ----------
+    spike_times : np.ndarray
+        Array of spike times
+    lfp_data : np.ndarray
+        Local field potential time series data
+    spike_fs : float, optional
+        Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
+    lfp_fs : float
+        Sampling frequency in Hz of the LFP data
+    filter_method : str, optional
+        Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
+    freq_of_interest : float, optional
+        Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+
+    Returns
+    -------
+    float
+        Pairwise Phase Consistency 2 (PPC2) value
     """

     if spike_fs is None:
@@ -320,33 +374,32 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
     spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

     # Filter indices to ensure they're within bounds of the LFP signal
-    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(…
+    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
     if len(valid_indices) <= 1:
-        return 0
+        return 0

     # Extract phase using the specified method
-    if …
+    if filter_method == 'wavelet':
         if freq_of_interest is None:
             raise ValueError("freq_of_interest must be provided for the wavelet method.")

         # Apply CWT to extract phase at the frequency of interest
-        lfp_complex = wavelet_filter(x=…
+        lfp_complex = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
         instantaneous_phase = np.angle(lfp_complex)

-    elif …
+    elif filter_method == 'butter':
         if lowcut is None or highcut is None:
-… (old lines 338-341 not shown in the diff view)
-        filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
+            raise ValueError("Both lowcut and highcut must be specified for the butter method.")
+
+        # Bandpass filter the signal
+        filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)

         # Get phase using the Hilbert transform
         analytic_signal = signal.hilbert(filtered_lfp)
         instantaneous_phase = np.angle(analytic_signal)

     else:
-        raise ValueError("Invalid method. Choose 'wavelet' or '…
+        raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")

     # Get phases at spike times
     spike_phases = instantaneous_phase[valid_indices]
@@ -355,7 +408,7 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
     n = len(spike_phases)

     if n <= 1:
-        return 0
+        return 0

     # Convert phases to unit vectors in the complex plane
     unit_vectors = np.exp(1j * spike_phases)
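An illustrative call of calculate_ppc2, which per the updated annotation now returns a single float; the synthetic inputs are assumptions.

# Illustrative only; spike times are assumed to be in milliseconds.
import numpy as np
from bmtool.analysis.entrainment import calculate_ppc2

lfp_fs = 1000.0
lfp = np.random.randn(int(10 * lfp_fs))
spikes = np.sort(np.random.uniform(0, 10_000, size=200))

ppc2 = calculate_ppc2(spike_times=spikes, lfp_data=lfp, spike_fs=1000.0, lfp_fs=lfp_fs,
                      filter_method='wavelet', freq_of_interest=40, bandwidth=2.0)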
@@ -369,33 +422,43 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
     return ppc2


-def calculate_ppc_per_cell(spike_df: pd.DataFrame,…
-                           spike_fs: float, lfp_fs:float,…
-                           pop_names: List[str],freqs: List[float]) -> Dict[str, Dict[int, Dict[float, float]]]:
+def calculate_ppc_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None,
+                           spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
+                           pop_names: List[str]=None, freqs: List[float]=None) -> Dict[str, Dict[int, Dict[float, float]]]:
     """
     Calculate pairwise phase consistency (PPC) per neuron (cell) for specified frequencies across different populations.

     This function computes the PPC for each neuron within the specified populations based on their spike times
-    and a…
+    and the provided LFP signal. It returns a nested dictionary structure containing the PPC values
+    organized by population, node ID, and frequency.

-… (old lines 381-387 not shown in the diff view)
+    Parameters
+    ----------
+    spike_df : pd.DataFrame
+        DataFrame containing spike data with columns 'pop_name', 'node_ids', and 'timestamps'
+    lfp_data : np.ndarray
+        Local field potential (LFP) time series data
+    spike_fs : float
+        Sampling frequency of the spike times in Hz
+    lfp_fs : float
+        Sampling frequency of the LFP signal in Hz
+    pop_names : List[str]
+        List of population names to analyze
+    freqs : List[float]
+        List of frequencies (in Hz) at which to calculate PPC

-    Returns
-… (old lines 390-395 not shown in the diff view)
+    Returns
+    -------
+    Dict[str, Dict[int, Dict[float, float]]]
+        Nested dictionary where the structure is:
+        {
+            population_name: {
+                node_id: {
+                    frequency: PPC value
                 }
             }
-… (old line 398 not shown in the diff view)
+        }
+        PPC values are floats representing the pairwise phase consistency at each frequency
     """
     ppc_dict = {}
     for pop in pop_names:
@@ -416,11 +479,12 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
             for freq in freqs:
                 ppc = calculate_ppc2(
                     node_spikes['timestamps'].values,
-… (old line 419 not shown in the diff view)
+                    lfp_data,
                     spike_fs=spike_fs,
                     lfp_fs=lfp_fs,
                     freq_of_interest=freq,
-… (old line 423 not shown in the diff view)
+                    bandwidth=bandwidth,
+                    filter_method='wavelet'
                 )
                 ppc_dict[pop][node][freq] = ppc

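A hedged sketch of the per-cell PPC helper; the DataFrame below is a minimal synthetic stand-in for real spike output, with the column names taken from the docstring above.

# Illustrative only; population name, node IDs, and timestamps (assumed ms) are made up.
import numpy as np
import pandas as pd
from bmtool.analysis.entrainment import calculate_ppc_per_cell

lfp_fs = 1000.0
lfp = np.random.randn(int(10 * lfp_fs))
spike_df = pd.DataFrame({
    'pop_name':   ['PV'] * 100,
    'node_ids':   np.random.randint(0, 5, size=100),
    'timestamps': np.sort(np.random.uniform(0, 10_000, size=100)),
})

ppc_dict = calculate_ppc_per_cell(spike_df=spike_df, lfp_data=lfp,
                                  spike_fs=1000.0, lfp_fs=lfp_fs, bandwidth=2,
                                  pop_names=['PV'], freqs=[8, 20, 40])
# ppc_dict['PV'][node_id][freq] -> PPC value, per the Returns structure documented above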
@@ -429,7 +493,9 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
     return ppc_dict


-def calculate_spike_rate_power_correlation(spike_rate,…
+def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
+                                           bandwidth=2.0, lowcut=None, highcut=None,
+                                           freq_range=(10, 100), freq_step=5):
     """
     Calculate correlation between population spike rates and LFP power across frequencies
     using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
@@ -438,16 +504,24 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
     -----------
     spike_rate : DataFrame
         Pre-calculated population spike rates at the same fs as lfp
-… (old line 441 not shown in the diff view)
+    lfp_data : np.array
         LFP data
     fs : float
         Sampling frequency
     pop_names : list
         List of population names to analyze
-… (old lines 447-450 not shown in the diff view)
+    filter_method : str, optional
+        Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    freq_range : tuple, optional
+        Min and max frequency to analyze (default: (10, 100))
+    freq_step : float, optional
+        Step size for frequency analysis (default: 5)

     Returns:
     --------
@@ -463,14 +537,15 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
     # Dictionary to store results
     correlation_results = {pop: {} for pop in pop_names}

-    # Calculate power at each frequency band using…
+    # Calculate power at each frequency band using specified filter
    power_by_freq = {}
     for freq in frequencies:
-… (old lines 469-473 not shown in the diff view)
+        if filter_method == 'wavelet':
+            power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
+                                                lowcut=None, highcut=None, bandwidth=bandwidth)
+        elif filter_method == 'butter':
+            power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
+                                                lowcut=lowcut, highcut=highcut)

     # Calculate correlation for each population
     for pop in pop_names:
@@ -481,7 +556,7 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
         for freq in frequencies:
             # Make sure the lengths match
             if len(pop_rate) != len(power_by_freq[freq]):
-                raise…
+                raise ValueError(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
             # use spearman for non-parametric correlation
             corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
             correlation_results[pop][freq] = {'correlation': corr, 'p_value': p_val}
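A hedged sketch of calling calculate_spike_rate_power_correlation. The layout of spike_rate (assumed here to be one column per population, sampled at the same fs as the LFP) and the exact return structure are assumptions to verify against the function's Returns section in the source.

# Illustrative only; data, column layout, and return handling are assumptions.
import numpy as np
import pandas as pd
from bmtool.analysis.entrainment import calculate_spike_rate_power_correlation

fs = 1000.0
n = int(10 * fs)
lfp = np.random.randn(n)
spike_rate = pd.DataFrame({'PV': np.random.poisson(5, n), 'PN': np.random.poisson(3, n)})

results = calculate_spike_rate_power_correlation(spike_rate, lfp, fs, ['PV', 'PN'],
                                                 filter_method='wavelet', bandwidth=2.0,
                                                 freq_range=(10, 100), freq_step=5)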
bmtool/analysis/lfp.py
CHANGED
@@ -273,10 +273,83 @@ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
     return normalized_power


-def…
+def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
+    """
+    Calculate the passband of a complex Morlet wavelet filter.
+
+    Parameters
+    ----------
+    center_freq : float
+        Center frequency (Hz) of the wavelet filter
+    bandwidth : float
+        Bandwidth parameter of the wavelet filter
+    threshold : float, optional
+        Power threshold to define the passband edges (default: 0.5 = -3dB point)
+
+    Returns
+    -------
+    tuple
+        (lower_bound, upper_bound, passband_width) of the frequency passband in Hz
+    """
+    # Create a high-resolution frequency axis around the center frequency
+    # Extend range to 3x the expected width to ensure we capture the full passband
+    expected_width = center_freq * bandwidth / 2
+    freq_min = max(0.1, center_freq - 3 * expected_width)
+    freq_max = center_freq + 3 * expected_width
+    freq_axis = np.linspace(freq_min, freq_max, 1000)
+
+    # Calculate the theoretical frequency response of the Morlet wavelet
+    # For a complex Morlet wavelet, the frequency response approximates a Gaussian
+    # centered at the center frequency with width related to the bandwidth parameter
+    sigma_f = bandwidth * center_freq / 8  # Approximate relationship for cmor wavelet
+    response = np.exp(-((freq_axis - center_freq)**2) / (2 * sigma_f**2))
+
+    # Find the passband edges (where response crosses the threshold)
+    above_threshold = response >= threshold
+    if not np.any(above_threshold):
+        return (center_freq, center_freq, 0)  # No passband found
+
+    # Find the first and last indices where response is above threshold
+    indices = np.where(above_threshold)[0]
+    lower_idx = indices[0]
+    upper_idx = indices[-1]
+
+    # Get the corresponding frequencies
+    lower_bound = freq_axis[lower_idx]
+    upper_bound = freq_axis[upper_idx]
+    passband_width = upper_bound - lower_bound
+
+    return (lower_bound, upper_bound, passband_width)
+
+
+def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1,show_passband: bool = False) -> np.ndarray:
     """
     Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        Input signal
+    freq : float
+        Target frequency for the wavelet filter
+    fs : float
+        Sampling frequency of the signal
+    bandwidth : float, optional
+        Bandwidth parameter of the wavelet filter (default is 1.0)
+    axis : int, optional
+        Axis along which to compute the CWT (default is -1)
+    show_passband : bool, optional
+        If True, print the passband of the wavelet filter (default is False)
+
+    Returns
+    -------
+    np.ndarray
+        Continuous Wavelet Transform of the input signal
     """
+    if show_passband:
+        lower_bound, upper_bound, passband_width = calculate_wavelet_passband(freq, bandwidth, threshold=0.3) # kinda made up threshold gives the rough idea
+        print(f"Wavelet filter at {freq:.1f} Hz Bandwidth: {bandwidth:.1f} Hz:")
+        print(f"  Passband: {lower_bound:.1f} - {upper_bound:.1f} Hz (width: {passband_width:.1f} Hz)")
     wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
     scale = pywt.scale2frequency(wavelet, 1) * fs / freq
     x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
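The new show_passband flag and calculate_wavelet_passband helper can be exercised as in this sketch; the 80 Hz signal and sampling rate are made up for illustration.

# Illustrative only; signal and parameters are assumptions.
import numpy as np
from bmtool.analysis.lfp import wavelet_filter, calculate_wavelet_passband

fs = 1000.0
t = np.arange(0, 5.0, 1.0 / fs)
x = np.sin(2 * np.pi * 80 * t) + np.random.randn(t.size)

# Prints the approximate passband before filtering
filtered = wavelet_filter(x, freq=80, fs=fs, bandwidth=1.0, show_passband=True)

# Or query the passband directly
lo, hi, width = calculate_wavelet_passband(center_freq=80, bandwidth=1.0, threshold=0.3)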
@@ -292,6 +365,53 @@ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs:
     return x_a


+def get_lfp_power(lfp_data: np.ndarray, freq: float, fs: float, filter_method: str = 'wavelet',
+                  lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
+    """
+    Compute the power of the raw LFP signal in a specified frequency band.
+
+    Parameters
+    ----------
+    lfp_data : np.ndarray
+        Raw local field potential (LFP) time series data
+    freq : float
+        Center frequency (Hz) for wavelet filtering method
+    fs : float
+        Sampling frequency (Hz) of the input data
+    filter_method : str, optional
+        Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
+    lowcut : float, optional
+        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    highcut : float, optional
+        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
+    bandwidth : float, optional
+        Bandwidth parameter for wavelet filter when method='wavelet' (default: 1.0)
+
+    Returns
+    -------
+    np.ndarray
+        Power of the filtered signal (magnitude squared)
+
+    Notes
+    -----
+    - The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
+    - The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
+    - When using the 'butter' method, both lowcut and highcut must be provided
+    """
+    if filter_method == 'wavelet':
+        filtered_signal = wavelet_filter(lfp_data, freq, fs, bandwidth)
+    elif filter_method == 'butter':
+        if lowcut is None or highcut is None:
+            raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
+        filtered_signal = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
+    else:
+        raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
+
+    # Calculate power (magnitude squared of filtered signal)
+    power = np.abs(filtered_signal)**2
+    return power
+
+
 # windowing functions
 def windowed_xarray(da, windows, dim='time',
                     new_coord_name='cycle', new_coord=None):
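Based on the signature and Notes above, get_lfp_power can be called through either filtering path; the signal here is synthetic and the frequencies are arbitrary.

# Illustrative only; data and frequency choices are assumptions.
import numpy as np
from bmtool.analysis.lfp import get_lfp_power

fs = 1000.0
lfp = np.random.randn(int(10 * fs))

gamma_power_wavelet = get_lfp_power(lfp, freq=40, fs=fs, filter_method='wavelet', bandwidth=1.0)
gamma_power_butter = get_lfp_power(lfp, freq=40, fs=fs, filter_method='butter',
                                   lowcut=30, highcut=50)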
bmtool/analysis/spikes.py
CHANGED
@@ -11,22 +11,33 @@ from scipy.stats import mannwhitneyu
 import os


-def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: str = 'pop_name') -> pd.DataFrame:
+def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: Union[str, List[str]] = 'pop_name') -> pd.DataFrame:
     """
     Load spike data from an HDF5 file into a pandas DataFrame.

-… (old lines 18-26 not shown in the diff view)
+    Parameters
+    ----------
+    spike_file : str
+        Path to the HDF5 file containing spike data
+    network_name : str
+        The name of the network within the HDF5 file from which to load spike data
+    sort : bool, optional
+        Whether to sort the DataFrame by 'timestamps' (default: True)
+    config : str, optional
+        Path to configuration file to label the cell type of each spike (default: None)
+    groupby : Union[str, List[str]], optional
+        The column(s) to group by (default: 'pop_name')
+
+    Returns
+    -------
+    pd.DataFrame
+        A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data,
+        with additional columns if a config file is provided

-… (old lines 28-29 not shown in the diff view)
+    Examples
+    --------
+    >>> df = load_spikes_to_df("spikes.h5", "cortex")
+    >>> df = load_spikes_to_df("spikes.h5", "cortex", config="config.json", groupby=["pop_name", "model_type"])
     """
     with h5py.File(spike_file) as f:
         spikes_df = pd.DataFrame({
@@ -126,23 +137,31 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] =


 def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None,
-                    time_points: Optional[Union[np.ndarray, list]] = None,…
+                    time_points: Optional[Union[np.ndarray, list]] = None, frequency: bool = False) -> np.ndarray:
     """
     Calculate the spike count or frequency histogram over specified time intervals.

-… (old lines 133-145 not shown in the diff view)
+    Parameters
+    ----------
+    spike_times : Union[np.ndarray, list]
+        Array or list of spike times in milliseconds
+    time : Optional[Tuple[float, float, float]], optional
+        Tuple specifying (start, stop, step) in milliseconds. Used to create evenly spaced time points
+        if `time_points` is not provided. Default is None.
+    time_points : Optional[Union[np.ndarray, list]], optional
+        Array or list of specific time points for binning. If provided, `time` is ignored. Default is None.
+    frequency : bool, optional
+        If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
+
+    Returns
+    -------
+    np.ndarray
+        Array of spike counts or frequencies, depending on the `frequency` flag.
+
+    Raises
+    ------
+    ValueError
+        If both `time` and `time_points` are None.
     """
     if time_points is None:
         if time is None:
@@ -156,43 +175,57 @@ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[f
     bins = np.append(time_points, time_points[-1] + dt)
     spike_rate, _ = np.histogram(np.asarray(spike_times), bins)

-    if…
+    if frequency:
         spike_rate = 1000 / dt * spike_rate

     return spike_rate


-def get_population_spike_rate(…
+def get_population_spike_rate(spike_data: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
                               config: Optional[str] = None, network_name: Optional[str] = None,
                               save: bool = False, save_path: Optional[str] = None,
                               normalize: bool = False) -> Dict[str, np.ndarray]:
     """
     Calculate the population spike rate for each population in the given spike data, with an option to normalize.

-… (old lines 172-195 not shown in the diff view)
+    Parameters
+    ----------
+    spike_data : pd.DataFrame
+        A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'
+    fs : float, optional
+        Sampling frequency in Hz, which determines the time bin size for calculating the spike rate (default: 400.0)
+    t_start : float, optional
+        Start time (in milliseconds) for spike rate calculation (default: 0)
+    t_stop : Optional[float], optional
+        Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data
+    config : Optional[str], optional
+        Path to a configuration file containing node information, used to determine the correct number of nodes per population.
+        If None, node count is estimated from unique node spikes (default: None)
+    network_name : Optional[str], optional
+        Name of the network used in the configuration file, allowing selection of nodes for that network.
+        Required if `config` is provided (default: None)
+    save : bool, optional
+        Whether to save the calculated population spike rate to a file (default: False)
+    save_path : Optional[str], optional
+        Directory path where the file should be saved if `save` is True (default: None)
+    normalize : bool, optional
+        Whether to normalize the spike rates for each population to a range of [0, 1] (default: False)
+
+    Returns
+    -------
+    Dict[str, np.ndarray]
+        A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
+        If `normalize` is True, each population's spike rate is scaled to [0, 1].
+
+    Raises
+    ------
+    ValueError
+        If `save` is True but `save_path` is not provided.
+
+    Notes
+    -----
+    - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
+    - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
     """
     pop_spikes = {}
     node_number = {}
|
|
205
238
|
if not network_name:
|
206
239
|
print("Grabbing first network; specify a network name to ensure correct node population is selected.")
|
207
240
|
|
208
|
-
for pop_name in
|
209
|
-
ps =
|
241
|
+
for pop_name in spike_data['pop_name'].unique():
|
242
|
+
ps = spike_data[spike_data['pop_name'] == pop_name]
|
210
243
|
|
211
244
|
if config:
|
212
245
|
nodes = load_nodes_from_config(config)
|
@@ -220,12 +253,12 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
|
|
220
253
|
node_number[pop_name] = ps['node_ids'].nunique()
|
221
254
|
|
222
255
|
if t_stop is None:
|
223
|
-
t_stop =
|
256
|
+
t_stop = spike_data['timestamps'].max()
|
224
257
|
|
225
|
-
filtered_spikes =
|
226
|
-
(
|
227
|
-
(
|
228
|
-
(
|
258
|
+
filtered_spikes = spike_data[
|
259
|
+
(spike_data['pop_name'] == pop_name) &
|
260
|
+
(spike_data['timestamps'] > t_start) &
|
261
|
+
(spike_data['timestamps'] < t_stop)
|
229
262
|
]
|
230
263
|
pop_spikes[pop_name] = filtered_spikes
|
231
264
|
|
@@ -254,11 +287,30 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
|
|
254
287
|
return spike_rate
|
255
288
|
|
256
289
|
|
257
|
-
def compare_firing_over_times(spike_df,group_by, time_window_1, time_window_2):
|
290
|
+
def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window_1: List[float], time_window_2: List[float]) -> None:
|
258
291
|
"""
|
259
|
-
Compares the firing rates of a population during two different time windows
|
260
|
-
|
261
|
-
|
292
|
+
Compares the firing rates of a population during two different time windows and performs
|
293
|
+
a statistical test to determine if there is a significant difference.
|
294
|
+
|
295
|
+
Parameters
|
296
|
+
----------
|
297
|
+
spike_df : pd.DataFrame
|
298
|
+
DataFrame containing spike data with columns for timestamps, node_ids, and grouping variable
|
299
|
+
group_by : str
|
300
|
+
Column name to group spikes by (e.g., 'pop_name')
|
301
|
+
time_window_1 : List[float]
|
302
|
+
First time window as [start, stop] in milliseconds
|
303
|
+
time_window_2 : List[float]
|
304
|
+
Second time window as [start, stop] in milliseconds
|
305
|
+
|
306
|
+
Returns
|
307
|
+
-------
|
308
|
+
None
|
309
|
+
Results are printed to the console
|
310
|
+
|
311
|
+
Notes
|
312
|
+
-----
|
313
|
+
Uses Mann-Whitney U test (non-parametric) to compare firing rates between the two windows
|
262
314
|
"""
|
263
315
|
# Filter spikes for the population of interest
|
264
316
|
for pop_name in spike_df[group_by].unique():
|
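A usage sketch of compare_firing_over_times based on the new docstring; the spike file and time windows are placeholders.

# Illustrative only; file, network, and window boundaries are hypothetical.
from bmtool.analysis.spikes import load_spikes_to_df, compare_firing_over_times

spike_df = load_spikes_to_df("spikes.h5", "cortex")
compare_firing_over_times(spike_df, group_by='pop_name',
                          time_window_1=[0, 1000], time_window_2=[1000, 2000])
# Prints a Mann-Whitney U comparison of firing rates per population to the console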
{bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/RECORD
CHANGED
@@ -8,10 +8,10 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
 bmtool/singlecell.py,sha256=imcdxIzvYVkaOLSGDxYp8WGGssGwXXBCRhzhlqVp7hA,44267
 bmtool/synapses.py,sha256=Ow2fZavA_3_5BYCjcgPjW0YsyVOetn1wvLxL7hQvbZo,64556
 bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-bmtool/analysis/entrainment.py,sha256=…
-bmtool/analysis/lfp.py,sha256=…
+bmtool/analysis/entrainment.py,sha256=JRg9sQ7WrZMqHwMaDJtCN7kGgZHJ5msUSzP-JPltC8k,23158
+bmtool/analysis/lfp.py,sha256=hOqD4xcDEL0NrNIN2-Ler_mkvY5cEUhxr7VUdX5Gwh8,21737
 bmtool/analysis/netcon_reports.py,sha256=7moyoUC45Cl1_6sGqwZ5aKphK_8i4AimroePXcgUnIo,3057
-bmtool/analysis/spikes.py,sha256=…
+bmtool/analysis/spikes.py,sha256=kcJZQsvPVzQgcuiO-El_4OODW57hwNwdok_RsFMITCg,15097
 bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 bmtool/bmplot/connections.py,sha256=re6QZX_NfQnIaWayGt3EhMINhCeMMSQ6rFR2sJbFeWk,51385
 bmtool/bmplot/entrainment.py,sha256=3IBD6tfW7lvkuB6DTan7rAVAeznOOzmHLr1qA2rgtCY,1671
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
 bmtool/util/util.py,sha256=XR0qZnv_Q47jMBKQpFzCSkCuKe9u8L3YSGJAOpP2zT0,57630
 bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
-bmtool-0.7.0.…
-bmtool-0.7.0.…
-bmtool-0.7.0.…
-bmtool-0.7.0.…
-bmtool-0.7.0.…
-bmtool-0.7.0.…
+bmtool-0.7.0.4.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
+bmtool-0.7.0.4.dist-info/METADATA,sha256=I-D4fwZQIHcHvD6Ou8Az3Kclwl17kwpZH7YHrD0eEg4,2768
+bmtool-0.7.0.4.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+bmtool-0.7.0.4.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
+bmtool-0.7.0.4.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
+bmtool-0.7.0.4.dist-info/RECORD,,
{bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/WHEEL
File without changes
{bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/entry_points.txt
File without changes
{bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/licenses/LICENSE
File without changes
{bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/top_level.txt
File without changes