bmtool 0.7.0.3__py3-none-any.whl → 0.7.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,58 +7,71 @@ from scipy import signal
7
7
  import numba
8
8
  from numba import cuda
9
9
  import pandas as pd
10
- import xarray as xr
11
- from .lfp import wavelet_filter,butter_bandpass_filter
12
- from typing import Dict, List
10
+ from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power, get_lfp_phase
11
+ from typing import Dict, List, Optional
13
12
  from tqdm.notebook import tqdm
14
13
  import scipy.stats as stats
15
- import seaborn as sns
16
- import matplotlib.pyplot as plt
17
14
 
18
15
 
19
- def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
20
- method: str = 'wavelet', lowcut: float = None, highcut: float = None,
16
+ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None,
17
+ filter_method: str = 'wavelet', lowcut: float = None, highcut: float = None,
21
18
  bandwidth: float = 2.0) -> np.ndarray:
22
19
  """
23
20
  Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
24
21
 
25
- Parameters:
26
- - x1, x2: Input signals (1D arrays, same length)
27
- - fs: Sampling frequency
28
- - freq_of_interest: Desired frequency for wavelet PLV calculation
29
- - method: 'wavelet' or 'hilbert' to choose the PLV calculation method
30
- - lowcut, highcut: Cutoff frequencies for the Hilbert method
31
- - bandwidth: Bandwidth parameter for the wavelet
32
-
33
- Returns:
34
- - plv: Phase Locking Value (1D array)
22
+ Parameters
23
+ ----------
24
+ signal1 : np.ndarray
25
+ First input signal (1D array)
26
+ signal2 : np.ndarray
27
+ Second input signal (1D array, same length as signal1)
28
+ fs : float
29
+ Sampling frequency in Hz
30
+ freq_of_interest : float, optional
31
+ Desired frequency for wavelet PLV calculation, required if filter_method='wavelet'
32
+ filter_method : str, optional
33
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
34
+ lowcut : float, optional
35
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
36
+ highcut : float, optional
37
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
38
+ bandwidth : float, optional
39
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
40
+
41
+ Returns
42
+ -------
43
+ np.ndarray
44
+ Phase Locking Value (1D array)
35
45
  """
36
- if len(x1) != len(x2):
46
+ if len(signal1) != len(signal2):
37
47
  raise ValueError("Input signals must have the same length.")
38
48
 
39
- if method == 'wavelet':
49
+ if filter_method == 'wavelet':
40
50
  if freq_of_interest is None:
41
51
  raise ValueError("freq_of_interest must be provided for the wavelet method.")
42
52
 
43
53
  # Apply CWT to both signals
44
- theta1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
45
- theta2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
54
+ theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
55
+ theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
46
56
 
47
- elif method == 'hilbert':
57
+ elif filter_method == 'butter':
48
58
  if lowcut is None or highcut is None:
49
- print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")
59
+ print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation")
50
60
 
51
61
  if lowcut and highcut:
52
62
  # Bandpass filter and get the analytic signal using the Hilbert transform
53
- x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
54
- x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)
55
-
56
- # Get phase using the Hilbert transform
57
- theta1 = signal.hilbert(x1)
58
- theta2 = signal.hilbert(x2)
63
+ filtered_signal1 = butter_bandpass_filter(data=signal1, lowcut=lowcut, highcut=highcut, fs=fs)
64
+ filtered_signal2 = butter_bandpass_filter(data=signal2, lowcut=lowcut, highcut=highcut, fs=fs)
65
+ # Get phase using the Hilbert transform
66
+ theta1 = signal.hilbert(filtered_signal1)
67
+ theta2 = signal.hilbert(filtered_signal2)
68
+ else:
69
+ # Get phase using the Hilbert transform without filtering
70
+ theta1 = signal.hilbert(signal1)
71
+ theta2 = signal.hilbert(signal2)
59
72
 
60
73
  else:
61
- raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
74
+ raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
62
75
 
63
76
  # Calculate phase difference
64
77
  phase_diff = np.angle(theta1) - np.angle(theta2)
@@ -69,29 +82,43 @@ def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_
69
82
  return plv
70
83
 
71
84
 
72
- def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs : float = None,
73
- lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
74
- lowcut: float = None, highcut: float = None,
75
- bandwidth: float = 2.0) -> tuple:
85
+ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
86
+ lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
87
+ lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
88
+ filtered_lfp_phase: np.ndarray = None) -> float:
76
89
  """
77
- Calculate spike-lfp phase locking value Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
78
-
79
- Parameters:
80
- - spike_times: Array of spike times
81
- - lfp_signal: Local field potential time series
82
- - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
83
- - lfp_fs : Sampling frequency in Hz of the LFP
84
- - method: 'wavelet' or 'hilbert' to choose the phase extraction method
85
- - freq_of_interest: Desired frequency for wavelet phase extraction
86
- - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
87
- - bandwidth: Bandwidth parameter for the wavelet
88
-
89
- Returns:
90
- - ppc1: Phase-Phase Coupling value
91
- - spike_phases: Phases at spike times
90
+ Calculate spike-lfp unbiased phase locking value
91
+
92
+ Parameters
93
+ ----------
94
+ spike_times : np.ndarray
95
+ Array of spike times
96
+ lfp_data : np.ndarray
97
+ Local field potential time series data. Not required if filtered_lfp_phase is provided.
98
+ spike_fs : float, optional
99
+ Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
100
+ lfp_fs : float
101
+ Sampling frequency in Hz of the LFP data
102
+ filter_method : str, optional
103
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
104
+ freq_of_interest : float, optional
105
+ Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
106
+ lowcut : float, optional
107
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
108
+ highcut : float, optional
109
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
110
+ bandwidth : float, optional
111
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
112
+ filtered_lfp_phase : np.ndarray, optional
113
+ Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
114
+
115
+ Returns
116
+ -------
117
+ float
118
+ Phase Locking Value (unbiased)
92
119
  """
93
120
 
94
- if spike_fs == None:
121
+ if spike_fs is None:
95
122
  spike_fs = lfp_fs
96
123
  # Convert spike times to sample indices
97
124
  spike_times_seconds = spike_times / spike_fs
@@ -100,50 +127,41 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarr
100
127
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
101
128
 
102
129
  # Filter indices to ensure they're within bounds of the LFP signal
103
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
130
+ if filtered_lfp_phase is not None:
131
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
132
+ else:
133
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
134
+
104
135
  if len(valid_indices) <= 1:
105
- return 0, np.array([])
136
+ return 0
106
137
 
107
- # Extract phase using the specified method
108
- if method == 'wavelet':
109
- if freq_of_interest is None:
110
- raise ValueError("freq_of_interest must be provided for the wavelet method.")
111
-
112
- # Apply CWT to extract phase at the frequency of interest
113
- lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
114
- instantaneous_phase = np.angle(lfp_complex)
115
-
116
- elif method == 'hilbert':
117
- if lowcut is None or highcut is None:
118
- print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
119
- filtered_lfp = lfp_signal
120
- else:
121
- # Bandpass filter the signal
122
- filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
123
-
124
- # Get phase using the Hilbert transform
125
- analytic_signal = signal.hilbert(filtered_lfp)
126
- instantaneous_phase = np.angle(analytic_signal)
127
-
138
+ # Get instantaneous phase
139
+ if filtered_lfp_phase is None:
140
+ instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
141
+ freq_of_interest=freq_of_interest, lowcut=lowcut,
142
+ highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
128
143
  else:
129
- raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
144
+ instantaneous_phase = filtered_lfp_phase
130
145
 
131
146
  # Get phases at spike times
132
147
  spike_phases = instantaneous_phase[valid_indices]
133
-
134
- # Calculate PPC1
135
- n = len(spike_phases)
148
+
149
+ # Number of spikes
150
+ N = len(spike_phases)
136
151
 
137
152
  # Convert phases to unit vectors in the complex plane
138
153
  unit_vectors = np.exp(1j * spike_phases)
139
154
 
140
- # Calculate the resultant vector
155
+ # Sum of all unit vectors (resultant vector)
141
156
  resultant_vector = np.sum(unit_vectors)
142
-
143
- # Plv is the squared length of the resultant vector divided by n²
144
- plv = (np.abs(resultant_vector) ** 2) / (n ** 2)
145
-
146
- return plv
157
+
158
+ # Calculate plv^2 * N
159
+ plv2n = (resultant_vector * resultant_vector.conjugate()).real / N # plv^2 * N
160
+ plv = (plv2n / N) ** 0.5
161
+ ppc = (plv2n - 1) / (N - 1) # ppc = (plv^2 * N - 1) / (N - 1)
162
+ plv_unbiased = np.fmax(ppc, 0.) ** 0.5 # ensure non-negative
163
+
164
+ return plv_unbiased
147
165
 
148
166
 
149
167
  @numba.njit(parallel=True, fastmath=True)
@@ -181,27 +199,43 @@ def _ppc_gpu(spike_phases):
181
199
  return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
182
200
 
183
201
 
184
- def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
185
- lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
186
- lowcut: float = None, highcut: float = None,
187
- bandwidth: float = 2.0,ppc_method: str = 'numpy') -> tuple:
202
+ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
203
+ lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
204
+ lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
205
+ ppc_method: str = 'numpy', filtered_lfp_phase: np.ndarray = None) -> float:
188
206
  """
189
207
  Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
190
208
  Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
191
209
 
192
- Parameters:
193
- - spike_times: Array of spike times
194
- - lfp_signal: Local field potential time series
195
- - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
196
- - lfp_fs: Sampling frequency in Hz of the LFP
197
- - method: 'wavelet' or 'hilbert' to choose the phase extraction method
198
- - freq_of_interest: Desired frequency for wavelet phase extraction
199
- - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
200
- - bandwidth: Bandwidth parameter for the wavelet
201
- - ppc_method: which algo to use for PPC calculate can be numpy, numba or gpu
202
-
203
- Returns:
204
- - ppc: Pairwise Phase Consistency value
210
+ Parameters
211
+ ----------
212
+ spike_times : np.ndarray
213
+ Array of spike times
214
+ lfp_data : np.ndarray
215
+ Local field potential time series data. Not required if filtered_lfp_phase is provided.
216
+ spike_fs : float, optional
217
+ Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
218
+ lfp_fs : float
219
+ Sampling frequency in Hz of the LFP data
220
+ filter_method : str, optional
221
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
222
+ freq_of_interest : float, optional
223
+ Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
224
+ lowcut : float, optional
225
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
226
+ highcut : float, optional
227
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
228
+ bandwidth : float, optional
229
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
230
+ ppc_method : str, optional
231
+ Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
232
+ filtered_lfp_phase : np.ndarray, optional
233
+ Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
234
+
235
+ Returns
236
+ -------
237
+ float
238
+ Pairwise Phase Consistency value
205
239
  """
206
240
  if spike_fs is None:
207
241
  spike_fs = lfp_fs
@@ -212,33 +246,21 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
212
246
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
213
247
 
214
248
  # Filter indices to ensure they're within bounds of the LFP signal
215
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
249
+ if filtered_lfp_phase is not None:
250
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
251
+ else:
252
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
253
+
216
254
  if len(valid_indices) <= 1:
217
- return 0, np.array([])
255
+ return 0
218
256
 
219
- # Extract phase using the specified method
220
- if method == 'wavelet':
221
- if freq_of_interest is None:
222
- raise ValueError("freq_of_interest must be provided for the wavelet method.")
223
-
224
- # Apply CWT to extract phase at the frequency of interest
225
- lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
226
- instantaneous_phase = np.angle(lfp_complex)
227
-
228
- elif method == 'hilbert':
229
- if lowcut is None or highcut is None:
230
- print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
231
- filtered_lfp = lfp_signal
232
- else:
233
- # Bandpass filter the signal
234
- filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
235
-
236
- # Get phase using the Hilbert transform
237
- analytic_signal = signal.hilbert(filtered_lfp)
238
- instantaneous_phase = np.angle(analytic_signal)
239
-
257
+ # Get instantaneous phase
258
+ if filtered_lfp_phase is None:
259
+ instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
260
+ freq_of_interest=freq_of_interest, lowcut=lowcut,
261
+ highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
240
262
  else:
241
- raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
263
+ instantaneous_phase = filtered_lfp_phase
242
264
 
243
265
  # Get phases at spike times
244
266
  spike_phases = instantaneous_phase[valid_indices]
@@ -247,28 +269,10 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
247
269
 
248
270
  # Calculate PPC (Pairwise Phase Consistency)
249
271
  if n_spikes <= 1:
250
- return 0, spike_phases
272
+ return 0
251
273
 
252
274
  # Explicit calculation of pairwise phase consistency
253
- sum_cos_diff = 0.0
254
-
255
- # # Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
256
- # for i in range(n_spikes - 1): # For each spike i
257
- # for j in range(i + 1, n_spikes): # For each spike j > i
258
- # # Calculate the phase difference between spikes i and j
259
- # phase_diff = spike_phases[i] - spike_phases[j]
260
-
261
- # #f(θᵢ, θⱼ) = cos(θᵢ)cos(θⱼ) + sin(θᵢ)sin(θⱼ) can become #f(θᵢ, θⱼ) = cos(θᵢ - θⱼ)
262
- # cos_diff = np.cos(phase_diff)
263
-
264
- # # Add to the sum
265
- # sum_cos_diff += cos_diff
266
-
267
- # # Calculate PPC according to the equation
268
- # # PPC = (2 / (N(N-1))) * Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
269
- # ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
270
-
271
- # same as above (i think) but with vectorized computation and memory fixes so it wont take forever to run.
275
+ # Vectorized computation for efficiency
272
276
  if ppc_method == 'numpy':
273
277
  i, j = np.triu_indices(n_spikes, k=1)
274
278
  phase_diff = spike_phases[i] - spike_phases[j]
@@ -279,14 +283,14 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
279
283
  elif ppc_method == 'gpu':
280
284
  ppc = _ppc_gpu(spike_phases)
281
285
  else:
282
- raise ExceptionType("Please use a supported ppc method currently that is numpy, numba or gpu")
286
+ raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")
283
287
  return ppc
284
288
 
285
289
 
286
- def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
287
- lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
288
- lowcut: float = None, highcut: float = None,
289
- bandwidth: float = 2.0) -> tuple:
290
+ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
291
+ lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
292
+ lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
293
+ filtered_lfp_phase: np.ndarray = None) -> float:
290
294
  """
291
295
  # -----------------------------------------------------------------------------
292
296
  # PPC2 Calculation (Vinck et al., 2010)
@@ -297,18 +301,33 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
297
301
  # PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
298
302
  # -----------------------------------------------------------------------------
299
303
 
300
- Parameters:
301
- - spike_times: Array of spike times
302
- - lfp_signal: Local field potential time series
303
- - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
304
- - lfp_fs: Sampling frequency in Hz of the LFP
305
- - method: 'wavelet' or 'hilbert' to choose the phase extraction method
306
- - freq_of_interest: Desired frequency for wavelet phase extraction
307
- - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
308
- - bandwidth: Bandwidth parameter for the wavelet
309
-
310
- Returns:
311
- - ppc2: Pairwise Phase Consistency 2 value
304
+ Parameters
305
+ ----------
306
+ spike_times : np.ndarray
307
+ Array of spike times
308
+ lfp_data : np.ndarray
309
+ Local field potential time series data. Not required if filtered_lfp_phase is provided.
310
+ spike_fs : float, optional
311
+ Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
312
+ lfp_fs : float
313
+ Sampling frequency in Hz of the LFP data
314
+ filter_method : str, optional
315
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
316
+ freq_of_interest : float, optional
317
+ Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
318
+ lowcut : float, optional
319
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
320
+ highcut : float, optional
321
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
322
+ bandwidth : float, optional
323
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
324
+ filtered_lfp_phase : np.ndarray, optional
325
+ Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
326
+
327
+ Returns
328
+ -------
329
+ float
330
+ Pairwise Phase Consistency 2 (PPC2) value
312
331
  """
313
332
 
314
333
  if spike_fs is None:
@@ -320,33 +339,21 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
320
339
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
321
340
 
322
341
  # Filter indices to ensure they're within bounds of the LFP signal
323
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
342
+ if filtered_lfp_phase is not None:
343
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
344
+ else:
345
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
346
+
324
347
  if len(valid_indices) <= 1:
325
- return 0, np.array([])
348
+ return 0
326
349
 
327
- # Extract phase using the specified method
328
- if method == 'wavelet':
329
- if freq_of_interest is None:
330
- raise ValueError("freq_of_interest must be provided for the wavelet method.")
331
-
332
- # Apply CWT to extract phase at the frequency of interest
333
- lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
334
- instantaneous_phase = np.angle(lfp_complex)
335
-
336
- elif method == 'hilbert':
337
- if lowcut is None or highcut is None:
338
- print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
339
- filtered_lfp = lfp_signal
340
- else:
341
- # Bandpass filter the signal
342
- filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
343
-
344
- # Get phase using the Hilbert transform
345
- analytic_signal = signal.hilbert(filtered_lfp)
346
- instantaneous_phase = np.angle(analytic_signal)
347
-
350
+ # Get instantaneous phase
351
+ if filtered_lfp_phase is None:
352
+ instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
353
+ freq_of_interest=freq_of_interest, lowcut=lowcut,
354
+ highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
348
355
  else:
349
- raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
356
+ instantaneous_phase = filtered_lfp_phase
350
357
 
351
358
  # Get phases at spike times
352
359
  spike_phases = instantaneous_phase[valid_indices]
@@ -355,7 +362,7 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
355
362
  n = len(spike_phases)
356
363
 
357
364
  if n <= 1:
358
- return 0, spike_phases
365
+ return 0
359
366
 
360
367
  # Convert phases to unit vectors in the complex plane
361
368
  unit_vectors = np.exp(1j * spike_phases)
@@ -369,40 +376,77 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
369
376
  return ppc2
370
377
 
371
378
 
372
- def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
373
- spike_fs: float, lfp_fs:float,
374
- pop_names: List[str],freqs: List[float]) -> Dict[str, Dict[int, Dict[float, float]]]:
379
+ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None, filter_method: str='wavelet', pop_names: List[str]=None,
380
+ entrainment_method: str='plv', lowcut: float=None, highcut: float=None,
381
+ spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
382
+ freqs: List[float]=None, ppc_method: str='numpy',) -> Dict[str, Dict[int, Dict[float, float]]]:
375
383
  """
376
- Calculate pairwise phase consistency (PPC) per neuron (cell) for specified frequencies across different populations.
384
+ Calculate neural entrainment (PPC, PLV) per neuron (cell) for specified frequencies across different populations.
377
385
 
378
- This function computes the PPC for each neuron within the specified populations based on their spike times
379
- and a single-channel local field potential (LFP) signal.
386
+ This function computes the entrainment metrics for each neuron within the specified populations based on their spike times
387
+ and the provided LFP signal. It returns a nested dictionary structure containing the entrainment values
388
+ organized by population, node ID, and frequency.
380
389
 
381
- Args:
382
- spike_df (pd.DataFrame): Spike dataframe use bmtool.analysis.load_spikes_to_df
383
- lfp (xr.DataArray): xarray DataArray representing the LFP use bmtool.analysis.ecp_to_lfp
384
- spike_fs (float): sampling rate of spikes BMTK default is 1000
385
- lfp_fs (float): sampling rate of lfp
386
- pop_names (List[str]): List of population names (as strings) to compute PPC for. pop_names should be in spike_df
387
- freqs (List[float]): List of frequencies (in Hz) at which to calculate PPC.
390
+ Parameters
391
+ ----------
392
+ spike_df : pd.DataFrame
393
+ DataFrame containing spike data with columns 'pop_name', 'node_ids', and 'timestamps'
394
+ lfp_data : np.ndarray
395
+ Local field potential (LFP) time series data
396
+ filter_method : str, optional
397
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
398
+ entrainment_method : str, optional
399
+ Method to use for entrainment calculation, either 'plv', 'ppc', or 'ppc2' (default: 'plv')
400
+ lowcut : float, optional
401
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
402
+ highcut : float, optional
403
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
404
+ spike_fs : float
405
+ Sampling frequency of the spike times in Hz
406
+ lfp_fs : float
407
+ Sampling frequency of the LFP signal in Hz
408
+ bandwidth : float, optional
409
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
410
+ ppc_method : str, optional
411
+ Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
412
+ pop_names : List[str]
413
+ List of population names to analyze
414
+ freqs : List[float]
415
+ List of frequencies (in Hz) at which to calculate entrainment
388
416
 
389
- Returns:
390
- Dict[str, Dict[int, Dict[float, float]]]: Nested dictionary where the structure is:
391
- {
392
- population_name: {
393
- node_id: {
394
- frequency: PPC value
395
- }
417
+ Returns
418
+ -------
419
+ Dict[str, Dict[int, Dict[float, float]]]
420
+ Nested dictionary where the structure is:
421
+ {
422
+ population_name: {
423
+ node_id: {
424
+ frequency: entrainment value
396
425
  }
397
426
  }
398
- PPC values are floats representing the pairwise phase consistency at each frequency.
427
+ }
428
+ Entrainment values are floats representing the metric (PPC, PLV) at each frequency
399
429
  """
400
- ppc_dict = {}
430
+ # pre filter lfp to speed up calculate of entrainment
431
+ filtered_lfp_phases = {}
432
+ for freq in range(len(freqs)):
433
+ phase = get_lfp_phase(
434
+ lfp_data=lfp_data,
435
+ freq_of_interest=freqs[freq],
436
+ fs=lfp_fs,
437
+ filter_method=filter_method,
438
+ lowcut=lowcut,
439
+ highcut=highcut,
440
+ bandwidth=bandwidth
441
+ )
442
+ filtered_lfp_phases[freqs[freq]] = phase
443
+
444
+ entrainment_dict = {}
401
445
  for pop in pop_names:
402
446
  skip_count = 0
403
447
  pop_spikes = spike_df[spike_df['pop_name'] == pop]
404
448
  nodes = pop_spikes['node_ids'].unique()
405
- ppc_dict[pop] = {}
449
+ entrainment_dict[pop] = {}
406
450
  print(f'Processing {pop} population')
407
451
  for node in tqdm(nodes):
408
452
  node_spikes = pop_spikes[pop_spikes['node_ids'] == node]
@@ -412,24 +456,58 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
412
456
  skip_count += 1
413
457
  continue
414
458
 
415
- ppc_dict[pop][node] = {}
459
+ entrainment_dict[pop][node] = {}
416
460
  for freq in freqs:
417
- ppc = calculate_ppc2(
418
- node_spikes['timestamps'].values,
419
- lfp_signal,
420
- spike_fs=spike_fs,
421
- lfp_fs=lfp_fs,
422
- freq_of_interest=freq,
423
- method='wavelet'
424
- )
425
- ppc_dict[pop][node][freq] = ppc
461
+ # Calculate entrainment based on the selected method using the pre-filtered phases
462
+ if entrainment_method == 'plv':
463
+ entrainment_dict[pop][node][freq] = calculate_spike_lfp_plv(
464
+ node_spikes['timestamps'].values,
465
+ lfp_data,
466
+ spike_fs=spike_fs,
467
+ lfp_fs=lfp_fs,
468
+ freq_of_interest=freq,
469
+ bandwidth=bandwidth,
470
+ lowcut=lowcut,
471
+ highcut=highcut,
472
+ filter_method=filter_method,
473
+ filtered_lfp_phase=filtered_lfp_phases[freq]
474
+ )
475
+ elif entrainment_method == 'ppc2':
476
+ entrainment_dict[pop][node][freq] = calculate_ppc2(
477
+ node_spikes['timestamps'].values,
478
+ lfp_data,
479
+ spike_fs=spike_fs,
480
+ lfp_fs=lfp_fs,
481
+ freq_of_interest=freq,
482
+ bandwidth=bandwidth,
483
+ lowcut=lowcut,
484
+ highcut=highcut,
485
+ filter_method=filter_method,
486
+ filtered_lfp_phase=filtered_lfp_phases[freq]
487
+ )
488
+ elif entrainment_method == 'ppc':
489
+ entrainment_dict[pop][node][freq] = calculate_ppc(
490
+ node_spikes['timestamps'].values,
491
+ lfp_data,
492
+ spike_fs=spike_fs,
493
+ lfp_fs=lfp_fs,
494
+ freq_of_interest=freq,
495
+ bandwidth=bandwidth,
496
+ lowcut=lowcut,
497
+ highcut=highcut,
498
+ filter_method=filter_method,
499
+ ppc_method=ppc_method,
500
+ filtered_lfp_phase=filtered_lfp_phases[freq]
501
+ )
426
502
 
427
- print(f'Calculated PPC for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')
503
+ print(f'Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')
428
504
 
429
- return ppc_dict
505
+ return entrainment_dict
430
506
 
431
507
 
432
- def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_range=(10, 100), freq_step=5):
508
+ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
509
+ bandwidth=2.0, lowcut=None, highcut=None,
510
+ freq_range=(10, 100), freq_step=5):
433
511
  """
434
512
  Calculate correlation between population spike rates and LFP power across frequencies
435
513
  using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
@@ -438,16 +516,24 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
438
516
  -----------
439
517
  spike_rate : DataFrame
440
518
  Pre-calculated population spike rates at the same fs as lfp
441
- lfp : np.array
519
+ lfp_data : np.array
442
520
  LFP data
443
521
  fs : float
444
522
  Sampling frequency
445
523
  pop_names : list
446
524
  List of population names to analyze
447
- freq_range : tuple
448
- Min and max frequency to analyze
449
- freq_step : float
450
- Step size for frequency analysis
525
+ filter_method : str, optional
526
+ Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
527
+ bandwidth : float, optional
528
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
529
+ lowcut : float, optional
530
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
531
+ highcut : float, optional
532
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
533
+ freq_range : tuple, optional
534
+ Min and max frequency to analyze (default: (10, 100))
535
+ freq_step : float, optional
536
+ Step size for frequency analysis (default: 5)
451
537
 
452
538
  Returns:
453
539
  --------
@@ -463,14 +549,11 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
463
549
  # Dictionary to store results
464
550
  correlation_results = {pop: {} for pop in pop_names}
465
551
 
466
- # Calculate power at each frequency band using wavelet filter
552
+ # Calculate power at each frequency band using specified filter
467
553
  power_by_freq = {}
468
554
  for freq in frequencies:
469
- # Use the wavelet_filter function from bmlfp
470
- filtered_signal = wavelet_filter(lfp, freq, fs)
471
- # Calculate power (magnitude squared of complex wavelet transform)
472
- power = np.abs(filtered_signal)**2
473
- power_by_freq[freq] = power
555
+ power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
556
+ lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
474
557
 
475
558
  # Calculate correlation for each population
476
559
  for pop in pop_names:
@@ -481,7 +564,7 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
481
564
  for freq in frequencies:
482
565
  # Make sure the lengths match
483
566
  if len(pop_rate) != len(power_by_freq[freq]):
484
- raise Exception(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
567
+ raise ValueError(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
485
568
  # use spearman for non-parametric correlation
486
569
  corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
487
570
  correlation_results[pop][freq] = {'correlation': corr, 'p_value': p_val}
bmtool/analysis/lfp.py CHANGED
@@ -273,10 +273,83 @@ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
273
273
  return normalized_power
274
274
 
275
275
 
276
- def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1) -> np.ndarray:
276
+ def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
277
+ """
278
+ Calculate the passband of a complex Morlet wavelet filter.
279
+
280
+ Parameters
281
+ ----------
282
+ center_freq : float
283
+ Center frequency (Hz) of the wavelet filter
284
+ bandwidth : float
285
+ Bandwidth parameter of the wavelet filter
286
+ threshold : float, optional
287
+ Power threshold to define the passband edges (default: 0.5 = -3dB point)
288
+
289
+ Returns
290
+ -------
291
+ tuple
292
+ (lower_bound, upper_bound, passband_width) of the frequency passband in Hz
293
+ """
294
+ # Create a high-resolution frequency axis around the center frequency
295
+ # Extend range to 3x the expected width to ensure we capture the full passband
296
+ expected_width = center_freq * bandwidth / 2
297
+ freq_min = max(0.1, center_freq - 3 * expected_width)
298
+ freq_max = center_freq + 3 * expected_width
299
+ freq_axis = np.linspace(freq_min, freq_max, 1000)
300
+
301
+ # Calculate the theoretical frequency response of the Morlet wavelet
302
+ # For a complex Morlet wavelet, the frequency response approximates a Gaussian
303
+ # centered at the center frequency with width related to the bandwidth parameter
304
+ sigma_f = bandwidth * center_freq / 8 # Approximate relationship for cmor wavelet
305
+ response = np.exp(-((freq_axis - center_freq)**2) / (2 * sigma_f**2))
306
+
307
+ # Find the passband edges (where response crosses the threshold)
308
+ above_threshold = response >= threshold
309
+ if not np.any(above_threshold):
310
+ return (center_freq, center_freq, 0) # No passband found
311
+
312
+ # Find the first and last indices where response is above threshold
313
+ indices = np.where(above_threshold)[0]
314
+ lower_idx = indices[0]
315
+ upper_idx = indices[-1]
316
+
317
+ # Get the corresponding frequencies
318
+ lower_bound = freq_axis[lower_idx]
319
+ upper_bound = freq_axis[upper_idx]
320
+ passband_width = upper_bound - lower_bound
321
+
322
+ return (lower_bound, upper_bound, passband_width)
323
+
324
+
325
+ def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1,show_passband: bool = False) -> np.ndarray:
277
326
  """
278
327
  Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.
328
+
329
+ Parameters
330
+ ----------
331
+ x : np.ndarray
332
+ Input signal
333
+ freq : float
334
+ Target frequency for the wavelet filter
335
+ fs : float
336
+ Sampling frequency of the signal
337
+ bandwidth : float, optional
338
+ Bandwidth parameter of the wavelet filter (default is 1.0)
339
+ axis : int, optional
340
+ Axis along which to compute the CWT (default is -1)
341
+ show_passband : bool, optional
342
+ If True, print the passband of the wavelet filter (default is False)
343
+
344
+ Returns
345
+ -------
346
+ np.ndarray
347
+ Continuous Wavelet Transform of the input signal
279
348
  """
349
+ if show_passband:
350
+ lower_bound, upper_bound, passband_width = calculate_wavelet_passband(freq, bandwidth, threshold=0.3) # kinda made up threshold gives the rough idea
351
+ print(f"Wavelet filter at {freq:.1f} Hz Bandwidth: {bandwidth:.1f} Hz:")
352
+ print(f" Passband: {lower_bound:.1f} - {upper_bound:.1f} Hz (width: {passband_width:.1f} Hz)")
280
353
  wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
281
354
  scale = pywt.scale2frequency(wavelet, 1) * fs / freq
282
355
  x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
@@ -292,6 +365,108 @@ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs:
292
365
  return x_a
293
366
 
294
367
 
368
+ def get_lfp_power(lfp_data: np.ndarray, freq: float, fs: float, filter_method: str = 'wavelet',
369
+ lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
370
+ """
371
+ Compute the power of the raw LFP signal in a specified frequency band.
372
+
373
+ Parameters
374
+ ----------
375
+ lfp_data : np.ndarray
376
+ Raw local field potential (LFP) time series data
377
+ freq : float
378
+ Center frequency (Hz) for wavelet filtering method
379
+ fs : float
380
+ Sampling frequency (Hz) of the input data
381
+ filter_method : str, optional
382
+ Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
383
+ lowcut : float, optional
384
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
385
+ highcut : float, optional
386
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
387
+ bandwidth : float, optional
388
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 1.0)
389
+
390
+ Returns
391
+ -------
392
+ np.ndarray
393
+ Power of the filtered signal (magnitude squared)
394
+
395
+ Notes
396
+ -----
397
+ - The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
398
+ - The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
399
+ - When using the 'butter' method, both lowcut and highcut must be provided
400
+ """
401
+ if filter_method == 'wavelet':
402
+ filtered_signal = wavelet_filter(lfp_data, freq, fs, bandwidth)
403
+ elif filter_method == 'butter':
404
+ if lowcut is None or highcut is None:
405
+ raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
406
+ filtered_signal = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
407
+ else:
408
+ raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
409
+
410
+ # Calculate power (magnitude squared of filtered signal)
411
+ power = np.abs(filtered_signal)**2
412
+ return power
413
+
414
+
415
+ def get_lfp_phase(lfp_data: np.ndarray, freq_of_interest: float, fs: float, filter_method: str = 'wavelet',
416
+ lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
417
+ """
418
+ Calculate the phase of the filtered signal.
419
+
420
+ Parameters
421
+ ----------
422
+ lfp_data : np.ndarray
423
+ Input LFP data
424
+ fs : float
425
+ Sampling frequency (Hz)
426
+ freq : float
427
+ Frequency of interest (Hz)
428
+ filter_method : str, optional
429
+ Method for filtering the signal ('wavelet' or 'butter')
430
+ bandwidth : float, optional
431
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 1.0)
432
+ lowcut : float, optional
433
+ Low cutoff frequency for Butterworth filter when method='butter'
434
+ highcut : float, optional
435
+ High cutoff frequency for Butterworth filter when method='butter'
436
+
437
+ Returns
438
+ -------
439
+ np.ndarray
440
+ Phase of the filtered signal
441
+
442
+ Notes
443
+ -----
444
+ - The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
445
+ - The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
446
+ followed by Hilbert transform to extract the phase
447
+ - When using the 'butter' method, both lowcut and highcut must be provided
448
+ """
449
+ if filter_method == 'wavelet':
450
+ if freq_of_interest is None:
451
+ raise ValueError("freq_of_interest must be provided for the wavelet method.")
452
+ # Wavelet filter returns complex values directly
453
+ filtered_signal = wavelet_filter(lfp_data, freq_of_interest, fs, bandwidth)
454
+ # Phase is the angle of the complex signal
455
+ phase = np.angle(filtered_signal)
456
+ elif filter_method == 'butter':
457
+ if lowcut is None or highcut is None:
458
+ raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
459
+ # Butterworth filter returns real values
460
+ filtered_signal = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
461
+ # Apply Hilbert transform to get analytic signal (complex)
462
+ analytic_signal = signal.hilbert(filtered_signal)
463
+ # Phase is the angle of the analytic signal
464
+ phase = np.angle(analytic_signal)
465
+ else:
466
+ raise ValueError(f"Invalid method {filter_method}. Choose 'wavelet' or 'butter'.")
467
+
468
+ return phase
469
+
295
470
  # windowing functions
296
471
  def windowed_xarray(da, windows, dim='time',
297
472
  new_coord_name='cycle', new_coord=None):
bmtool/analysis/spikes.py CHANGED
@@ -11,22 +11,33 @@ from scipy.stats import mannwhitneyu
11
11
  import os
12
12
 
13
13
 
14
- def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: str = 'pop_name') -> pd.DataFrame:
14
+ def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: Union[str, List[str]] = 'pop_name') -> pd.DataFrame:
15
15
  """
16
16
  Load spike data from an HDF5 file into a pandas DataFrame.
17
17
 
18
- Args:
19
- spike_file (str): Path to the HDF5 file containing spike data.
20
- network_name (str): The name of the network within the HDF5 file from which to load spike data.
21
- sort (bool, optional): Whether to sort the DataFrame by 'timestamps'. Defaults to True.
22
- config (str, optional): Will label the cell type of each spike.
23
- groupby (str or list of str, optional): The column(s) to group by. Defaults to 'pop_name'.
24
-
25
- Returns:
26
- pd.DataFrame: A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data.
18
+ Parameters
19
+ ----------
20
+ spike_file : str
21
+ Path to the HDF5 file containing spike data
22
+ network_name : str
23
+ The name of the network within the HDF5 file from which to load spike data
24
+ sort : bool, optional
25
+ Whether to sort the DataFrame by 'timestamps' (default: True)
26
+ config : str, optional
27
+ Path to configuration file to label the cell type of each spike (default: None)
28
+ groupby : Union[str, List[str]], optional
29
+ The column(s) to group by (default: 'pop_name')
30
+
31
+ Returns
32
+ -------
33
+ pd.DataFrame
34
+ A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data,
35
+ with additional columns if a config file is provided
27
36
 
28
- Example:
29
- df = load_spikes_to_df("spikes.h5", "cortex")
37
+ Examples
38
+ --------
39
+ >>> df = load_spikes_to_df("spikes.h5", "cortex")
40
+ >>> df = load_spikes_to_df("spikes.h5", "cortex", config="config.json", groupby=["pop_name", "model_type"])
30
41
  """
31
42
  with h5py.File(spike_file) as f:
32
43
  spikes_df = pd.DataFrame({
@@ -126,23 +137,31 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] =
126
137
 
127
138
 
128
139
  def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None,
129
- time_points: Optional[Union[np.ndarray, list]] = None, frequeny: bool = False) -> np.ndarray:
140
+ time_points: Optional[Union[np.ndarray, list]] = None, frequency: bool = False) -> np.ndarray:
130
141
  """
131
142
  Calculate the spike count or frequency histogram over specified time intervals.
132
143
 
133
- Args:
134
- spike_times (Union[np.ndarray, list]): Array or list of spike times in milliseconds.
135
- time (Optional[Tuple[float, float, float]], optional): Tuple specifying (start, stop, step) in milliseconds.
136
- Used to create evenly spaced time points if `time_points` is not provided. Default is None.
137
- time_points (Optional[Union[np.ndarray, list]], optional): Array or list of specific time points for binning.
138
- If provided, `time` is ignored. Default is None.
139
- frequeny (bool, optional): If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
140
-
141
- Returns:
142
- np.ndarray: Array of spike counts or frequencies, depending on the `frequeny` flag.
143
-
144
- Raises:
145
- ValueError: If both `time` and `time_points` are None.
144
+ Parameters
145
+ ----------
146
+ spike_times : Union[np.ndarray, list]
147
+ Array or list of spike times in milliseconds
148
+ time : Optional[Tuple[float, float, float]], optional
149
+ Tuple specifying (start, stop, step) in milliseconds. Used to create evenly spaced time points
150
+ if `time_points` is not provided. Default is None.
151
+ time_points : Optional[Union[np.ndarray, list]], optional
152
+ Array or list of specific time points for binning. If provided, `time` is ignored. Default is None.
153
+ frequency : bool, optional
154
+ If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
155
+
156
+ Returns
157
+ -------
158
+ np.ndarray
159
+ Array of spike counts or frequencies, depending on the `frequency` flag.
160
+
161
+ Raises
162
+ ------
163
+ ValueError
164
+ If both `time` and `time_points` are None.
146
165
  """
147
166
  if time_points is None:
148
167
  if time is None:
@@ -156,43 +175,57 @@ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[f
156
175
  bins = np.append(time_points, time_points[-1] + dt)
157
176
  spike_rate, _ = np.histogram(np.asarray(spike_times), bins)
158
177
 
159
- if frequeny:
178
+ if frequency:
160
179
  spike_rate = 1000 / dt * spike_rate
161
180
 
162
181
  return spike_rate
163
182
 
164
183
 
165
- def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
184
+ def get_population_spike_rate(spike_data: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
166
185
  config: Optional[str] = None, network_name: Optional[str] = None,
167
186
  save: bool = False, save_path: Optional[str] = None,
168
187
  normalize: bool = False) -> Dict[str, np.ndarray]:
169
188
  """
170
189
  Calculate the population spike rate for each population in the given spike data, with an option to normalize.
171
190
 
172
- Args:
173
- spikes (pd.DataFrame): A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'.
174
- fs (float, optional): Sampling frequency in Hz, which determines the time bin size for calculating the spike rate. Default is 400.
175
- t_start (float, optional): Start time (in milliseconds) for spike rate calculation. Default is 0.
176
- t_stop (Optional[float], optional): Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data.
177
- config (Optional[str], optional): Path to a configuration file containing node information, used to determine the correct number of nodes per population.
178
- If None, node count is estimated from unique node spikes. Default is None.
179
- network_name (Optional[str], optional): Name of the network used in the configuration file, allowing selection of nodes for that network.
180
- Required if `config` is provided. Default is None.
181
- save (bool, optional): Whether to save the calculated population spike rate to a file. Default is False.
182
- save_path (Optional[str], optional): Directory path where the file should be saved if `save` is True. If `save` is True and `save_path` is None, a ValueError is raised.
183
- normalize (bool, optional): Whether to normalize the spike rates for each population to a range of [0, 1]. Default is False.
184
-
185
- Returns:
186
- Dict[str, np.ndarray]: A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
187
- If `normalize` is True, each population's spike rate is scaled to [0, 1].
188
-
189
- Raises:
190
- ValueError: If `save` is True but `save_path` is not provided.
191
-
192
- Notes:
193
- - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
194
- - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
195
-
191
+ Parameters
192
+ ----------
193
+ spike_data : pd.DataFrame
194
+ A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'
195
+ fs : float, optional
196
+ Sampling frequency in Hz, which determines the time bin size for calculating the spike rate (default: 400.0)
197
+ t_start : float, optional
198
+ Start time (in milliseconds) for spike rate calculation (default: 0)
199
+ t_stop : Optional[float], optional
200
+ Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data
201
+ config : Optional[str], optional
202
+ Path to a configuration file containing node information, used to determine the correct number of nodes per population.
203
+ If None, node count is estimated from unique node spikes (default: None)
204
+ network_name : Optional[str], optional
205
+ Name of the network used in the configuration file, allowing selection of nodes for that network.
206
+ Required if `config` is provided (default: None)
207
+ save : bool, optional
208
+ Whether to save the calculated population spike rate to a file (default: False)
209
+ save_path : Optional[str], optional
210
+ Directory path where the file should be saved if `save` is True (default: None)
211
+ normalize : bool, optional
212
+ Whether to normalize the spike rates for each population to a range of [0, 1] (default: False)
213
+
214
+ Returns
215
+ -------
216
+ Dict[str, np.ndarray]
217
+ A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
218
+ If `normalize` is True, each population's spike rate is scaled to [0, 1].
219
+
220
+ Raises
221
+ ------
222
+ ValueError
223
+ If `save` is True but `save_path` is not provided.
224
+
225
+ Notes
226
+ -----
227
+ - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
228
+ - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
196
229
  """
197
230
  pop_spikes = {}
198
231
  node_number = {}
@@ -205,8 +238,8 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
205
238
  if not network_name:
206
239
  print("Grabbing first network; specify a network name to ensure correct node population is selected.")
207
240
 
208
- for pop_name in spikes['pop_name'].unique():
209
- ps = spikes[spikes['pop_name'] == pop_name]
241
+ for pop_name in spike_data['pop_name'].unique():
242
+ ps = spike_data[spike_data['pop_name'] == pop_name]
210
243
 
211
244
  if config:
212
245
  nodes = load_nodes_from_config(config)
@@ -220,12 +253,12 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
220
253
  node_number[pop_name] = ps['node_ids'].nunique()
221
254
 
222
255
  if t_stop is None:
223
- t_stop = spikes['timestamps'].max()
256
+ t_stop = spike_data['timestamps'].max()
224
257
 
225
- filtered_spikes = spikes[
226
- (spikes['pop_name'] == pop_name) &
227
- (spikes['timestamps'] > t_start) &
228
- (spikes['timestamps'] < t_stop)
258
+ filtered_spikes = spike_data[
259
+ (spike_data['pop_name'] == pop_name) &
260
+ (spike_data['timestamps'] > t_start) &
261
+ (spike_data['timestamps'] < t_stop)
229
262
  ]
230
263
  pop_spikes[pop_name] = filtered_spikes
231
264
 
@@ -254,11 +287,30 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
254
287
  return spike_rate
255
288
 
256
289
 
257
- def compare_firing_over_times(spike_df,group_by, time_window_1, time_window_2):
290
+ def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window_1: List[float], time_window_2: List[float]) -> None:
258
291
  """
259
- Compares the firing rates of a population during two different time windows
260
- time_window_1 and time_window_2 should be a list of [start, stop] in milliseconds
261
- Returns firing rates and results of a Mann-Whitney U test (non-parametric)
292
+ Compares the firing rates of a population during two different time windows and performs
293
+ a statistical test to determine if there is a significant difference.
294
+
295
+ Parameters
296
+ ----------
297
+ spike_df : pd.DataFrame
298
+ DataFrame containing spike data with columns for timestamps, node_ids, and grouping variable
299
+ group_by : str
300
+ Column name to group spikes by (e.g., 'pop_name')
301
+ time_window_1 : List[float]
302
+ First time window as [start, stop] in milliseconds
303
+ time_window_2 : List[float]
304
+ Second time window as [start, stop] in milliseconds
305
+
306
+ Returns
307
+ -------
308
+ None
309
+ Results are printed to the console
310
+
311
+ Notes
312
+ -----
313
+ Uses Mann-Whitney U test (non-parametric) to compare firing rates between the two windows
262
314
  """
263
315
  # Filter spikes for the population of interest
264
316
  for pop_name in spike_df[group_by].unique():
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.7.0.3
3
+ Version: 0.7.0.5
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -8,10 +8,10 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
8
8
  bmtool/singlecell.py,sha256=imcdxIzvYVkaOLSGDxYp8WGGssGwXXBCRhzhlqVp7hA,44267
9
9
  bmtool/synapses.py,sha256=Ow2fZavA_3_5BYCjcgPjW0YsyVOetn1wvLxL7hQvbZo,64556
10
10
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- bmtool/analysis/entrainment.py,sha256=IMhjbLYw-rL-MfRuFT3uCkyUFObNVJhcxmYV0R9Uh-M,20007
12
- bmtool/analysis/lfp.py,sha256=3dZkpyVqDtmssvxAqvX4zOLafhwmFnyMehSgVvnj5lM,16754
11
+ bmtool/analysis/entrainment.py,sha256=7lFlGMApL_2snwdvIPDLFW1KKPdyiuCnZ5ADa7ujx5o,24439
12
+ bmtool/analysis/lfp.py,sha256=6u9cHnac-5Fzpk9ecQew7MmXBAolzKZakRsnPn3-C2U,24109
13
13
  bmtool/analysis/netcon_reports.py,sha256=7moyoUC45Cl1_6sGqwZ5aKphK_8i4AimroePXcgUnIo,3057
14
- bmtool/analysis/spikes.py,sha256=x24kd0RUhumJkiunfHNEE7mM6JUqdWy1gqabmkMM4cU,14129
14
+ bmtool/analysis/spikes.py,sha256=kcJZQsvPVzQgcuiO-El_4OODW57hwNwdok_RsFMITCg,15097
15
15
  bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  bmtool/bmplot/connections.py,sha256=re6QZX_NfQnIaWayGt3EhMINhCeMMSQ6rFR2sJbFeWk,51385
17
17
  bmtool/bmplot/entrainment.py,sha256=3IBD6tfW7lvkuB6DTan7rAVAeznOOzmHLr1qA2rgtCY,1671
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
26
26
  bmtool/util/util.py,sha256=XR0qZnv_Q47jMBKQpFzCSkCuKe9u8L3YSGJAOpP2zT0,57630
27
27
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
29
- bmtool-0.7.0.3.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
- bmtool-0.7.0.3.dist-info/METADATA,sha256=3jYm8B3kDgHTyLyl5C5zdvxZ8tiGoC4hJvAiVu5QLSM,2768
31
- bmtool-0.7.0.3.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
32
- bmtool-0.7.0.3.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
- bmtool-0.7.0.3.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
- bmtool-0.7.0.3.dist-info/RECORD,,
29
+ bmtool-0.7.0.5.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
+ bmtool-0.7.0.5.dist-info/METADATA,sha256=5A-VT9HRvmYInIJg4FvfxYSYGL70RzSgOaAaRULRXYs,2768
31
+ bmtool-0.7.0.5.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
32
+ bmtool-0.7.0.5.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
+ bmtool-0.7.0.5.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
+ bmtool-0.7.0.5.dist-info/RECORD,,