bmtool 0.7.0.3__py3-none-any.whl → 0.7.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,7 @@ import numba
8
8
  from numba import cuda
9
9
  import pandas as pd
10
10
  import xarray as xr
11
- from .lfp import wavelet_filter,butter_bandpass_filter
11
+ from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power
12
12
  from typing import Dict, List
13
13
  from tqdm.notebook import tqdm
14
14
  import scipy.stats as stats
@@ -16,49 +16,65 @@ import seaborn as sns
16
16
  import matplotlib.pyplot as plt
17
17
 
18
18
 
19
- def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
20
- method: str = 'wavelet', lowcut: float = None, highcut: float = None,
19
+ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None,
20
+ filter_method: str = 'wavelet', lowcut: float = None, highcut: float = None,
21
21
  bandwidth: float = 2.0) -> np.ndarray:
22
22
  """
23
23
  Calculate Phase Locking Value (PLV) between two signals using a wavelet filter or the Hilbert transform (optionally preceded by a Butterworth bandpass filter).
24
24
 
25
- Parameters:
26
- - x1, x2: Input signals (1D arrays, same length)
27
- - fs: Sampling frequency
28
- - freq_of_interest: Desired frequency for wavelet PLV calculation
29
- - method: 'wavelet' or 'hilbert' to choose the PLV calculation method
30
- - lowcut, highcut: Cutoff frequencies for the Hilbert method
31
- - bandwidth: Bandwidth parameter for the wavelet
32
-
33
- Returns:
34
- - plv: Phase Locking Value (1D array)
25
+ Parameters
26
+ ----------
27
+ signal1 : np.ndarray
28
+ First input signal (1D array)
29
+ signal2 : np.ndarray
30
+ Second input signal (1D array, same length as signal1)
31
+ fs : float
32
+ Sampling frequency in Hz
33
+ freq_of_interest : float, optional
34
+ Desired frequency for wavelet PLV calculation, required if filter_method='wavelet'
35
+ filter_method : str, optional
36
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
37
+ lowcut : float, optional
38
+ Lower frequency bound (Hz) for butterworth bandpass filter when filter_method='butter'; if lowcut or highcut is not given, the signals are Hilbert-transformed without filtering
39
+ highcut : float, optional
40
+ Upper frequency bound (Hz) for butterworth bandpass filter when filter_method='butter'; if lowcut or highcut is not given, the signals are Hilbert-transformed without filtering
41
+ bandwidth : float, optional
42
+ Bandwidth parameter for wavelet filter when filter_method='wavelet' (default: 2.0)
43
+
44
+ Returns
45
+ -------
46
+ np.ndarray
47
+ Phase Locking Value (1D array)
35
48
  """
36
- if len(x1) != len(x2):
49
+ if len(signal1) != len(signal2):
37
50
  raise ValueError("Input signals must have the same length.")
38
51
 
39
- if method == 'wavelet':
52
+ if filter_method == 'wavelet':
40
53
  if freq_of_interest is None:
41
54
  raise ValueError("freq_of_interest must be provided for the wavelet method.")
42
55
 
43
56
  # Apply CWT to both signals
44
- theta1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
45
- theta2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
57
+ theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
58
+ theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
46
59
 
47
- elif method == 'hilbert':
60
+ elif filter_method == 'butter':
48
61
  if lowcut is None or highcut is None:
49
- print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")
62
+ print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation")
50
63
 
51
64
  if lowcut and highcut:
52
65
  # Bandpass filter and get the analytic signal using the Hilbert transform
53
- x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
54
- x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)
55
-
56
- # Get phase using the Hilbert transform
57
- theta1 = signal.hilbert(x1)
58
- theta2 = signal.hilbert(x2)
66
+ filtered_signal1 = butter_bandpass_filter(data=signal1, lowcut=lowcut, highcut=highcut, fs=fs)
67
+ filtered_signal2 = butter_bandpass_filter(data=signal2, lowcut=lowcut, highcut=highcut, fs=fs)
68
+ # Get phase using the Hilbert transform
69
+ theta1 = signal.hilbert(filtered_signal1)
70
+ theta2 = signal.hilbert(filtered_signal2)
71
+ else:
72
+ # Get phase using the Hilbert transform without filtering
73
+ theta1 = signal.hilbert(signal1)
74
+ theta2 = signal.hilbert(signal2)
59
75
 
60
76
  else:
61
- raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
77
+ raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
62
78
 
63
79
  # Calculate phase difference
64
80
  phase_diff = np.angle(theta1) - np.angle(theta2)
@@ -69,29 +85,43 @@ def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_
69
85
  return plv
70
86
 
71
87
 
72
- def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs : float = None,
73
- lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
88
+ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
89
+ lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
74
90
  lowcut: float = None, highcut: float = None,
75
91
  bandwidth: float = 2.0) -> tuple:
76
92
  """
77
93
  Calculate the spike-LFP phase locking value. Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
78
94
 
79
- Parameters:
80
- - spike_times: Array of spike times
81
- - lfp_signal: Local field potential time series
82
- - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
83
- - lfp_fs : Sampling frequency in Hz of the LFP
84
- - method: 'wavelet' or 'hilbert' to choose the phase extraction method
85
- - freq_of_interest: Desired frequency for wavelet phase extraction
86
- - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
87
- - bandwidth: Bandwidth parameter for the wavelet
88
-
89
- Returns:
90
- - ppc1: Phase-Phase Coupling value
91
- - spike_phases: Phases at spike times
95
+ Parameters
96
+ ----------
97
+ spike_times : np.ndarray
98
+ Array of spike times
99
+ lfp_data : np.ndarray
100
+ Local field potential time series data
101
+ spike_fs : float, optional
102
+ Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
103
+ lfp_fs : float
104
+ Sampling frequency in Hz of the LFP data
105
+ filter_method : str, optional
106
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
107
+ freq_of_interest : float, optional
108
+ Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
109
+ lowcut : float, optional
110
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
111
+ highcut : float, optional
112
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
113
+ bandwidth : float, optional
114
+ Bandwidth parameter for wavelet filter when filter_method='wavelet' (default: 2.0)
115
+
116
+ Returns
117
+ -------
118
+ tuple
119
+ (plv, spike_phases) where:
120
+ - plv: Phase Locking Value
121
+ - spike_phases: Phases at spike times
92
122
  """
93
123
 
94
- if spike_fs == None:
124
+ if spike_fs is None:
95
125
  spike_fs = lfp_fs
96
126
  # Convert spike times to sample indices
97
127
  spike_times_seconds = spike_times / spike_fs
@@ -100,33 +130,29 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarr
100
130
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
101
131
 
102
132
  # Filter indices to ensure they're within bounds of the LFP signal
103
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
133
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
104
134
  if len(valid_indices) <= 1:
105
135
  return 0, np.array([])
106
136
 
107
- # Extract phase using the specified method
108
- if method == 'wavelet':
137
+ # Filter the LFP signal to extract the phase
138
+ if filter_method == 'wavelet':
109
139
  if freq_of_interest is None:
110
140
  raise ValueError("freq_of_interest must be provided for the wavelet method.")
111
141
 
112
- # Apply CWT to extract phase at the frequency of interest
113
- lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
114
- instantaneous_phase = np.angle(lfp_complex)
115
-
116
- elif method == 'hilbert':
142
+ # Apply CWT to extract phase
143
+ filtered_lfp = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
144
+
145
+ elif filter_method == 'butter':
117
146
  if lowcut is None or highcut is None:
118
- print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
119
- filtered_lfp = lfp_signal
120
- else:
121
- # Bandpass filter the signal
122
- filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
147
+ raise ValueError("Both lowcut and highcut must be specified for the butter method.")
123
148
 
124
- # Get phase using the Hilbert transform
125
- analytic_signal = signal.hilbert(filtered_lfp)
126
- instantaneous_phase = np.angle(analytic_signal)
149
+ # Bandpass filter the LFP signal
150
+ filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
151
+ filtered_lfp = signal.hilbert(filtered_lfp) # Get analytic signal
152
+
127
153
 
128
154
  else:
129
- raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
155
+ raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
130
156
 
131
157
  # Get phases at spike times
132
158
  spike_phases = instantaneous_phase[valid_indices]
@@ -181,27 +207,43 @@ def _ppc_gpu(spike_phases):
181
207
  return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
182
208
 
183
209
 
184
- def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
185
- lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
210
+ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
211
+ lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
186
212
  lowcut: float = None, highcut: float = None,
187
- bandwidth: float = 2.0,ppc_method: str = 'numpy') -> tuple:
213
+ bandwidth: float = 2.0, ppc_method: str = 'numpy') -> tuple:
188
214
  """
189
215
  Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
190
216
  Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
191
217
 
192
- Parameters:
193
- - spike_times: Array of spike times
194
- - lfp_signal: Local field potential time series
195
- - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
196
- - lfp_fs: Sampling frequency in Hz of the LFP
197
- - method: 'wavelet' or 'hilbert' to choose the phase extraction method
198
- - freq_of_interest: Desired frequency for wavelet phase extraction
199
- - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
200
- - bandwidth: Bandwidth parameter for the wavelet
201
- - ppc_method: which algo to use for PPC calculate can be numpy, numba or gpu
202
-
203
- Returns:
204
- - ppc: Pairwise Phase Consistency value
218
+ Parameters
219
+ ----------
220
+ spike_times : np.ndarray
221
+ Array of spike times
222
+ lfp_data : np.ndarray
223
+ Local field potential time series data
224
+ spike_fs : float, optional
225
+ Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
226
+ lfp_fs : float
227
+ Sampling frequency in Hz of the LFP data
228
+ filter_method : str, optional
229
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
230
+ freq_of_interest : float, optional
231
+ Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
232
+ lowcut : float, optional
233
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
234
+ highcut : float, optional
235
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
236
+ bandwidth : float, optional
237
+ Bandwidth parameter for wavelet filter when filter_method='wavelet' (default: 2.0)
238
+ ppc_method : str, optional
239
+ Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
240
+
241
+ Returns
242
+ -------
243
+ tuple
244
+ (ppc, spike_phases) where:
245
+ - ppc: Pairwise Phase Consistency value
246
+ - spike_phases: Phases at spike times
205
247
  """
206
248
  if spike_fs is None:
207
249
  spike_fs = lfp_fs
@@ -212,33 +254,32 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
212
254
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
213
255
 
214
256
  # Filter indices to ensure they're within bounds of the LFP signal
215
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
257
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
216
258
  if len(valid_indices) <= 1:
217
259
  return 0, np.array([])
218
260
 
219
261
  # Extract phase using the specified method
220
- if method == 'wavelet':
262
+ if filter_method == 'wavelet':
221
263
  if freq_of_interest is None:
222
264
  raise ValueError("freq_of_interest must be provided for the wavelet method.")
223
265
 
224
266
  # Apply CWT to extract phase at the frequency of interest
225
- lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
267
+ lfp_complex = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
226
268
  instantaneous_phase = np.angle(lfp_complex)
227
269
 
228
- elif method == 'hilbert':
270
+ elif filter_method == 'butter':
229
271
  if lowcut is None or highcut is None:
230
- print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
231
- filtered_lfp = lfp_signal
232
- else:
233
- # Bandpass filter the signal
234
- filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
272
+ raise ValueError("Both lowcut and highcut must be specified for the butter method.")
273
+
274
+ # Bandpass filter the signal
275
+ filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
235
276
 
236
277
  # Get phase using the Hilbert transform
237
278
  analytic_signal = signal.hilbert(filtered_lfp)
238
279
  instantaneous_phase = np.angle(analytic_signal)
239
280
 
240
281
  else:
241
- raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
282
+ raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
242
283
 
243
284
  # Get phases at spike times
244
285
  spike_phases = instantaneous_phase[valid_indices]
@@ -283,10 +324,10 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
283
324
  return ppc
284
325
 
285
326
 
286
- def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
287
- lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
327
+ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
328
+ lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
288
329
  lowcut: float = None, highcut: float = None,
289
- bandwidth: float = 2.0) -> tuple:
330
+ bandwidth: float = 2.0) -> float:
290
331
  """
291
332
  # -----------------------------------------------------------------------------
292
333
  # PPC2 Calculation (Vinck et al., 2010)
@@ -297,18 +338,31 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
297
338
  # PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
298
339
  # -----------------------------------------------------------------------------
299
340
 
300
- Parameters:
301
- - spike_times: Array of spike times
302
- - lfp_signal: Local field potential time series
303
- - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
304
- - lfp_fs: Sampling frequency in Hz of the LFP
305
- - method: 'wavelet' or 'hilbert' to choose the phase extraction method
306
- - freq_of_interest: Desired frequency for wavelet phase extraction
307
- - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
308
- - bandwidth: Bandwidth parameter for the wavelet
309
-
310
- Returns:
311
- - ppc2: Pairwise Phase Consistency 2 value
341
+ Parameters
342
+ ----------
343
+ spike_times : np.ndarray
344
+ Array of spike times
345
+ lfp_data : np.ndarray
346
+ Local field potential time series data
347
+ spike_fs : float, optional
348
+ Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
349
+ lfp_fs : float
350
+ Sampling frequency in Hz of the LFP data
351
+ filter_method : str, optional
352
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
353
+ freq_of_interest : float, optional
354
+ Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
355
+ lowcut : float, optional
356
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
357
+ highcut : float, optional
358
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
359
+ bandwidth : float, optional
360
+ Bandwidth parameter for wavelet filter when filter_method='wavelet' (default: 2.0)
361
+
362
+ Returns
363
+ -------
364
+ float
365
+ Pairwise Phase Consistency 2 (PPC2) value
312
366
  """
313
367
 
314
368
  if spike_fs is None:
@@ -320,33 +374,32 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
320
374
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
321
375
 
322
376
  # Filter indices to ensure they're within bounds of the LFP signal
323
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
377
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
324
378
  if len(valid_indices) <= 1:
325
- return 0, np.array([])
379
+ return 0
326
380
 
327
381
  # Extract phase using the specified method
328
- if method == 'wavelet':
382
+ if filter_method == 'wavelet':
329
383
  if freq_of_interest is None:
330
384
  raise ValueError("freq_of_interest must be provided for the wavelet method.")
331
385
 
332
386
  # Apply CWT to extract phase at the frequency of interest
333
- lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
387
+ lfp_complex = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
334
388
  instantaneous_phase = np.angle(lfp_complex)
335
389
 
336
- elif method == 'hilbert':
390
+ elif filter_method == 'butter':
337
391
  if lowcut is None or highcut is None:
338
- print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
339
- filtered_lfp = lfp_signal
340
- else:
341
- # Bandpass filter the signal
342
- filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
392
+ raise ValueError("Both lowcut and highcut must be specified for the butter method.")
393
+
394
+ # Bandpass filter the signal
395
+ filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
343
396
 
344
397
  # Get phase using the Hilbert transform
345
398
  analytic_signal = signal.hilbert(filtered_lfp)
346
399
  instantaneous_phase = np.angle(analytic_signal)
347
400
 
348
401
  else:
349
- raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
402
+ raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
350
403
 
351
404
  # Get phases at spike times
352
405
  spike_phases = instantaneous_phase[valid_indices]
@@ -355,7 +408,7 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
355
408
  n = len(spike_phases)
356
409
 
357
410
  if n <= 1:
358
- return 0, spike_phases
411
+ return 0
359
412
 
360
413
  # Convert phases to unit vectors in the complex plane
361
414
  unit_vectors = np.exp(1j * spike_phases)
@@ -369,33 +422,43 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
369
422
  return ppc2
370
423
 
371
424
 
372
- def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
373
- spike_fs: float, lfp_fs:float,
374
- pop_names: List[str],freqs: List[float]) -> Dict[str, Dict[int, Dict[float, float]]]:
425
+ def calculate_ppc_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None,
426
+ spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
427
+ pop_names: List[str]=None, freqs: List[float]=None) -> Dict[str, Dict[int, Dict[float, float]]]:
375
428
  """
376
429
  Calculate pairwise phase consistency (PPC) per neuron (cell) for specified frequencies across different populations.
377
430
 
378
431
  This function computes the PPC for each neuron within the specified populations based on their spike times
379
- and a single-channel local field potential (LFP) signal.
432
+ and the provided LFP signal. It returns a nested dictionary structure containing the PPC values
433
+ organized by population, node ID, and frequency.
380
434
 
381
- Args:
382
- spike_df (pd.DataFrame): Spike dataframe use bmtool.analysis.load_spikes_to_df
383
- lfp (xr.DataArray): xarray DataArray representing the LFP use bmtool.analysis.ecp_to_lfp
384
- spike_fs (float): sampling rate of spikes BMTK default is 1000
385
- lfp_fs (float): sampling rate of lfp
386
- pop_names (List[str]): List of population names (as strings) to compute PPC for. pop_names should be in spike_df
387
- freqs (List[float]): List of frequencies (in Hz) at which to calculate PPC.
435
+ Parameters
436
+ ----------
437
+ spike_df : pd.DataFrame
438
+ DataFrame containing spike data with columns 'pop_name', 'node_ids', and 'timestamps'
439
+ lfp_data : np.ndarray
440
+ Local field potential (LFP) time series data
441
+ spike_fs : float
442
+ Sampling frequency of the spike times in Hz
443
+ lfp_fs : float
444
+ Sampling frequency of the LFP signal in Hz
445
+ pop_names : List[str]
446
+ List of population names to analyze
447
+ freqs : List[float]
448
+ List of frequencies (in Hz) at which to calculate PPC
388
449
 
389
- Returns:
390
- Dict[str, Dict[int, Dict[float, float]]]: Nested dictionary where the structure is:
391
- {
392
- population_name: {
393
- node_id: {
394
- frequency: PPC value
395
- }
450
+ Returns
451
+ -------
452
+ Dict[str, Dict[int, Dict[float, float]]]
453
+ Nested dictionary where the structure is:
454
+ {
455
+ population_name: {
456
+ node_id: {
457
+ frequency: PPC value
396
458
  }
397
459
  }
398
- PPC values are floats representing the pairwise phase consistency at each frequency.
460
+ }
461
+ PPC values are floats representing the pairwise phase consistency at each frequency
399
462
  """
400
463
  ppc_dict = {}
401
464
  for pop in pop_names:
@@ -416,11 +479,12 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
416
479
  for freq in freqs:
417
480
  ppc = calculate_ppc2(
418
481
  node_spikes['timestamps'].values,
419
- lfp_signal,
482
+ lfp_data,
420
483
  spike_fs=spike_fs,
421
484
  lfp_fs=lfp_fs,
422
485
  freq_of_interest=freq,
423
- method='wavelet'
486
+ bandwidth=bandwidth,
487
+ filter_method='wavelet'
424
488
  )
425
489
  ppc_dict[pop][node][freq] = ppc
426
490
 
@@ -429,7 +493,9 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
429
493
  return ppc_dict
430
494
 
431
495
 
432
- def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_range=(10, 100), freq_step=5):
496
+ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
497
+ bandwidth=2.0, lowcut=None, highcut=None,
498
+ freq_range=(10, 100), freq_step=5):
433
499
  """
434
500
  Calculate correlation between population spike rates and LFP power across frequencies
435
501
  using wavelet or Butterworth bandpass filtering. This function assumes the sampling frequency (fs) of spike_rate and lfp_data is the same.
@@ -438,16 +504,24 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
438
504
  -----------
439
505
  spike_rate : DataFrame
440
506
  Pre-calculated population spike rates at the same fs as lfp
441
- lfp : np.array
507
+ lfp_data : np.array
442
508
  LFP data
443
509
  fs : float
444
510
  Sampling frequency
445
511
  pop_names : list
446
512
  List of population names to analyze
447
- freq_range : tuple
448
- Min and max frequency to analyze
449
- freq_step : float
450
- Step size for frequency analysis
513
+ filter_method : str, optional
514
+ Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
515
+ bandwidth : float, optional
516
+ Bandwidth parameter for wavelet filter when filter_method='wavelet' (default: 2.0)
517
+ lowcut : float, optional
518
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
519
+ highcut : float, optional
520
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
521
+ freq_range : tuple, optional
522
+ Min and max frequency to analyze (default: (10, 100))
523
+ freq_step : float, optional
524
+ Step size for frequency analysis (default: 5)
451
525
 
452
526
  Returns:
453
527
  --------
@@ -463,14 +537,15 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
463
537
  # Dictionary to store results
464
538
  correlation_results = {pop: {} for pop in pop_names}
465
539
 
466
- # Calculate power at each frequency band using wavelet filter
540
+ # Calculate power at each frequency band using specified filter
467
541
  power_by_freq = {}
468
542
  for freq in frequencies:
469
- # Use the wavelet_filter function from bmlfp
470
- filtered_signal = wavelet_filter(lfp, freq, fs)
471
- # Calculate power (magnitude squared of complex wavelet transform)
472
- power = np.abs(filtered_signal)**2
473
- power_by_freq[freq] = power
543
+ if filter_method == 'wavelet':
544
+ power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
545
+ lowcut=None, highcut=None, bandwidth=bandwidth)
546
+ elif filter_method == 'butter':
547
+ power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
548
+ lowcut=lowcut, highcut=highcut)
474
549
 
475
550
  # Calculate correlation for each population
476
551
  for pop in pop_names:
@@ -481,7 +556,7 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
481
556
  for freq in frequencies:
482
557
  # Make sure the lengths match
483
558
  if len(pop_rate) != len(power_by_freq[freq]):
484
- raise Exception(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
559
+ raise ValueError(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
485
560
  # use spearman for non-parametric correlation
486
561
  corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
487
562
  correlation_results[pop][freq] = {'correlation': corr, 'p_value': p_val}
bmtool/analysis/lfp.py CHANGED
@@ -273,10 +273,83 @@ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
273
273
  return normalized_power
274
274
 
275
275
 
276
- def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1) -> np.ndarray:
276
+ def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
277
+ """
278
+ Calculate the passband of a complex Morlet wavelet filter.
279
+
280
+ Parameters
281
+ ----------
282
+ center_freq : float
283
+ Center frequency (Hz) of the wavelet filter
284
+ bandwidth : float
285
+ Bandwidth parameter of the wavelet filter
286
+ threshold : float, optional
287
+ Power threshold to define the passband edges (default: 0.3)
288
+
289
+ Returns
290
+ -------
291
+ tuple
292
+ (lower_bound, upper_bound, passband_width) of the frequency passband in Hz
293
+ """
294
+ # Create a high-resolution frequency axis around the center frequency
295
+ # Extend range to 3x the expected width to ensure we capture the full passband
296
+ expected_width = center_freq * bandwidth / 2
297
+ freq_min = max(0.1, center_freq - 3 * expected_width)
298
+ freq_max = center_freq + 3 * expected_width
299
+ freq_axis = np.linspace(freq_min, freq_max, 1000)
300
+
301
+ # Calculate the theoretical frequency response of the Morlet wavelet
302
+ # For a complex Morlet wavelet, the frequency response approximates a Gaussian
303
+ # centered at the center frequency with width related to the bandwidth parameter
304
+ sigma_f = bandwidth * center_freq / 8 # Approximate relationship for cmor wavelet
305
+ response = np.exp(-((freq_axis - center_freq)**2) / (2 * sigma_f**2))
306
+
307
+ # Find the passband edges (where response crosses the threshold)
308
+ above_threshold = response >= threshold
309
+ if not np.any(above_threshold):
310
+ return (center_freq, center_freq, 0) # No passband found
311
+
312
+ # Find the first and last indices where response is above threshold
313
+ indices = np.where(above_threshold)[0]
314
+ lower_idx = indices[0]
315
+ upper_idx = indices[-1]
316
+
317
+ # Get the corresponding frequencies
318
+ lower_bound = freq_axis[lower_idx]
319
+ upper_bound = freq_axis[upper_idx]
320
+ passband_width = upper_bound - lower_bound
321
+
322
+ return (lower_bound, upper_bound, passband_width)
323
+
324
+
325
+ def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1,show_passband: bool = False) -> np.ndarray:
277
326
  """
278
327
  Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.
328
+
329
+ Parameters
330
+ ----------
331
+ x : np.ndarray
332
+ Input signal
333
+ freq : float
334
+ Target frequency for the wavelet filter
335
+ fs : float
336
+ Sampling frequency of the signal
337
+ bandwidth : float, optional
338
+ Bandwidth parameter of the wavelet filter (default is 1.0)
339
+ axis : int, optional
340
+ Axis along which to compute the CWT (default is -1)
341
+ show_passband : bool, optional
342
+ If True, print the passband of the wavelet filter (default is False)
343
+
344
+ Returns
345
+ -------
346
+ np.ndarray
347
+ Continuous Wavelet Transform of the input signal
279
348
  """
349
+ if show_passband:
350
+ lower_bound, upper_bound, passband_width = calculate_wavelet_passband(freq, bandwidth, threshold=0.3) # kinda made up threshold gives the rough idea
351
+ print(f"Wavelet filter at {freq:.1f} Hz Bandwidth: {bandwidth:.1f} Hz:")
352
+ print(f" Passband: {lower_bound:.1f} - {upper_bound:.1f} Hz (width: {passband_width:.1f} Hz)")
280
353
  wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
281
354
  scale = pywt.scale2frequency(wavelet, 1) * fs / freq
282
355
  x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
@@ -292,6 +365,53 @@ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs:
292
365
  return x_a
293
366
 
294
367
 
368
def get_lfp_power(lfp_data: np.ndarray, freq: float, fs: float, filter_method: str = 'wavelet',
                  lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
    """
    Band-limit an LFP trace and return its squared magnitude.

    Parameters
    ----------
    lfp_data : np.ndarray
        Raw local field potential (LFP) time series
    freq : float
        Center frequency (Hz) handed to ``wavelet_filter``; ignored when
        filter_method='butter'
    fs : float
        Sampling frequency (Hz) of ``lfp_data``
    filter_method : str, optional
        Which filter to apply, 'wavelet' or 'butter' (default: 'wavelet')
    lowcut : float, optional
        Lower cutoff (Hz) handed to ``butter_bandpass_filter``; required when
        filter_method='butter', ignored otherwise
    highcut : float, optional
        Upper cutoff (Hz) handed to ``butter_bandpass_filter``; required when
        filter_method='butter', ignored otherwise
    bandwidth : float, optional
        Bandwidth parameter handed to ``wavelet_filter``; ignored when
        filter_method='butter' (default: 1.0)

    Returns
    -------
    np.ndarray
        ``np.abs(filtered) ** 2`` where ``filtered`` is the output of the selected filter

    Raises
    ------
    ValueError
        If filter_method='butter' and either ``lowcut`` or ``highcut`` is None,
        or if ``filter_method`` is neither 'wavelet' nor 'butter'.

    Notes
    -----
    - Both branches share the same final step: the squared absolute value of
      whatever the chosen filter returned. No additional envelope extraction
      is performed here for the 'butter' branch.
    - Cutoff validation only happens on the 'butter' branch, so passing
      ``lowcut``/``highcut`` as None together with 'wavelet' is fine.
    """
    if filter_method == 'wavelet':
        band_limited = wavelet_filter(lfp_data, freq, fs, bandwidth)
    elif filter_method == 'butter':
        # The bandpass needs both edges; refuse to guess a missing one.
        cutoff_missing = lowcut is None or highcut is None
        if cutoff_missing:
            raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
        band_limited = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")

    # Power is the squared magnitude of the filtered signal.
    return np.abs(band_limited) ** 2
413
+
414
+
295
415
  # windowing functions
296
416
  def windowed_xarray(da, windows, dim='time',
297
417
  new_coord_name='cycle', new_coord=None):
bmtool/analysis/spikes.py CHANGED
@@ -11,22 +11,33 @@ from scipy.stats import mannwhitneyu
11
11
  import os
12
12
 
13
13
 
14
- def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: str = 'pop_name') -> pd.DataFrame:
14
+ def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: Union[str, List[str]] = 'pop_name') -> pd.DataFrame:
15
15
  """
16
16
  Load spike data from an HDF5 file into a pandas DataFrame.
17
17
 
18
- Args:
19
- spike_file (str): Path to the HDF5 file containing spike data.
20
- network_name (str): The name of the network within the HDF5 file from which to load spike data.
21
- sort (bool, optional): Whether to sort the DataFrame by 'timestamps'. Defaults to True.
22
- config (str, optional): Will label the cell type of each spike.
23
- groupby (str or list of str, optional): The column(s) to group by. Defaults to 'pop_name'.
24
-
25
- Returns:
26
- pd.DataFrame: A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data.
18
+ Parameters
19
+ ----------
20
+ spike_file : str
21
+ Path to the HDF5 file containing spike data
22
+ network_name : str
23
+ The name of the network within the HDF5 file from which to load spike data
24
+ sort : bool, optional
25
+ Whether to sort the DataFrame by 'timestamps' (default: True)
26
+ config : str, optional
27
+ Path to configuration file to label the cell type of each spike (default: None)
28
+ groupby : Union[str, List[str]], optional
29
+ The column(s) to group by (default: 'pop_name')
30
+
31
+ Returns
32
+ -------
33
+ pd.DataFrame
34
+ A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data,
35
+ with additional columns if a config file is provided
27
36
 
28
- Example:
29
- df = load_spikes_to_df("spikes.h5", "cortex")
37
+ Examples
38
+ --------
39
+ >>> df = load_spikes_to_df("spikes.h5", "cortex")
40
+ >>> df = load_spikes_to_df("spikes.h5", "cortex", config="config.json", groupby=["pop_name", "model_type"])
30
41
  """
31
42
  with h5py.File(spike_file) as f:
32
43
  spikes_df = pd.DataFrame({
@@ -126,23 +137,31 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] =
126
137
 
127
138
 
128
139
  def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None,
129
- time_points: Optional[Union[np.ndarray, list]] = None, frequeny: bool = False) -> np.ndarray:
140
+ time_points: Optional[Union[np.ndarray, list]] = None, frequency: bool = False) -> np.ndarray:
130
141
  """
131
142
  Calculate the spike count or frequency histogram over specified time intervals.
132
143
 
133
- Args:
134
- spike_times (Union[np.ndarray, list]): Array or list of spike times in milliseconds.
135
- time (Optional[Tuple[float, float, float]], optional): Tuple specifying (start, stop, step) in milliseconds.
136
- Used to create evenly spaced time points if `time_points` is not provided. Default is None.
137
- time_points (Optional[Union[np.ndarray, list]], optional): Array or list of specific time points for binning.
138
- If provided, `time` is ignored. Default is None.
139
- frequeny (bool, optional): If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
140
-
141
- Returns:
142
- np.ndarray: Array of spike counts or frequencies, depending on the `frequeny` flag.
143
-
144
- Raises:
145
- ValueError: If both `time` and `time_points` are None.
144
+ Parameters
145
+ ----------
146
+ spike_times : Union[np.ndarray, list]
147
+ Array or list of spike times in milliseconds
148
+ time : Optional[Tuple[float, float, float]], optional
149
+ Tuple specifying (start, stop, step) in milliseconds. Used to create evenly spaced time points
150
+ if `time_points` is not provided. Default is None.
151
+ time_points : Optional[Union[np.ndarray, list]], optional
152
+ Array or list of specific time points for binning. If provided, `time` is ignored. Default is None.
153
+ frequency : bool, optional
154
+ If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
155
+
156
+ Returns
157
+ -------
158
+ np.ndarray
159
+ Array of spike counts or frequencies, depending on the `frequency` flag.
160
+
161
+ Raises
162
+ ------
163
+ ValueError
164
+ If both `time` and `time_points` are None.
146
165
  """
147
166
  if time_points is None:
148
167
  if time is None:
@@ -156,43 +175,57 @@ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[f
156
175
  bins = np.append(time_points, time_points[-1] + dt)
157
176
  spike_rate, _ = np.histogram(np.asarray(spike_times), bins)
158
177
 
159
- if frequeny:
178
+ if frequency:
160
179
  spike_rate = 1000 / dt * spike_rate
161
180
 
162
181
  return spike_rate
163
182
 
164
183
 
165
- def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
184
+ def get_population_spike_rate(spike_data: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
166
185
  config: Optional[str] = None, network_name: Optional[str] = None,
167
186
  save: bool = False, save_path: Optional[str] = None,
168
187
  normalize: bool = False) -> Dict[str, np.ndarray]:
169
188
  """
170
189
  Calculate the population spike rate for each population in the given spike data, with an option to normalize.
171
190
 
172
- Args:
173
- spikes (pd.DataFrame): A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'.
174
- fs (float, optional): Sampling frequency in Hz, which determines the time bin size for calculating the spike rate. Default is 400.
175
- t_start (float, optional): Start time (in milliseconds) for spike rate calculation. Default is 0.
176
- t_stop (Optional[float], optional): Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data.
177
- config (Optional[str], optional): Path to a configuration file containing node information, used to determine the correct number of nodes per population.
178
- If None, node count is estimated from unique node spikes. Default is None.
179
- network_name (Optional[str], optional): Name of the network used in the configuration file, allowing selection of nodes for that network.
180
- Required if `config` is provided. Default is None.
181
- save (bool, optional): Whether to save the calculated population spike rate to a file. Default is False.
182
- save_path (Optional[str], optional): Directory path where the file should be saved if `save` is True. If `save` is True and `save_path` is None, a ValueError is raised.
183
- normalize (bool, optional): Whether to normalize the spike rates for each population to a range of [0, 1]. Default is False.
184
-
185
- Returns:
186
- Dict[str, np.ndarray]: A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
187
- If `normalize` is True, each population's spike rate is scaled to [0, 1].
188
-
189
- Raises:
190
- ValueError: If `save` is True but `save_path` is not provided.
191
-
192
- Notes:
193
- - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
194
- - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
195
-
191
+ Parameters
192
+ ----------
193
+ spike_data : pd.DataFrame
194
+ A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'
195
+ fs : float, optional
196
+ Sampling frequency in Hz, which determines the time bin size for calculating the spike rate (default: 400.0)
197
+ t_start : float, optional
198
+ Start time (in milliseconds) for spike rate calculation (default: 0)
199
+ t_stop : Optional[float], optional
200
+ Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data
201
+ config : Optional[str], optional
202
+ Path to a configuration file containing node information, used to determine the correct number of nodes per population.
203
+ If None, node count is estimated from unique node spikes (default: None)
204
+ network_name : Optional[str], optional
205
+ Name of the network used in the configuration file, allowing selection of nodes for that network.
206
+ Required if `config` is provided (default: None)
207
+ save : bool, optional
208
+ Whether to save the calculated population spike rate to a file (default: False)
209
+ save_path : Optional[str], optional
210
+ Directory path where the file should be saved if `save` is True (default: None)
211
+ normalize : bool, optional
212
+ Whether to normalize the spike rates for each population to a range of [0, 1] (default: False)
213
+
214
+ Returns
215
+ -------
216
+ Dict[str, np.ndarray]
217
+ A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
218
+ If `normalize` is True, each population's spike rate is scaled to [0, 1].
219
+
220
+ Raises
221
+ ------
222
+ ValueError
223
+ If `save` is True but `save_path` is not provided.
224
+
225
+ Notes
226
+ -----
227
+ - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
228
+ - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
196
229
  """
197
230
  pop_spikes = {}
198
231
  node_number = {}
@@ -205,8 +238,8 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
205
238
  if not network_name:
206
239
  print("Grabbing first network; specify a network name to ensure correct node population is selected.")
207
240
 
208
- for pop_name in spikes['pop_name'].unique():
209
- ps = spikes[spikes['pop_name'] == pop_name]
241
+ for pop_name in spike_data['pop_name'].unique():
242
+ ps = spike_data[spike_data['pop_name'] == pop_name]
210
243
 
211
244
  if config:
212
245
  nodes = load_nodes_from_config(config)
@@ -220,12 +253,12 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
220
253
  node_number[pop_name] = ps['node_ids'].nunique()
221
254
 
222
255
  if t_stop is None:
223
- t_stop = spikes['timestamps'].max()
256
+ t_stop = spike_data['timestamps'].max()
224
257
 
225
- filtered_spikes = spikes[
226
- (spikes['pop_name'] == pop_name) &
227
- (spikes['timestamps'] > t_start) &
228
- (spikes['timestamps'] < t_stop)
258
+ filtered_spikes = spike_data[
259
+ (spike_data['pop_name'] == pop_name) &
260
+ (spike_data['timestamps'] > t_start) &
261
+ (spike_data['timestamps'] < t_stop)
229
262
  ]
230
263
  pop_spikes[pop_name] = filtered_spikes
231
264
 
@@ -254,11 +287,30 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
254
287
  return spike_rate
255
288
 
256
289
 
257
- def compare_firing_over_times(spike_df,group_by, time_window_1, time_window_2):
290
+ def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window_1: List[float], time_window_2: List[float]) -> None:
258
291
  """
259
- Compares the firing rates of a population during two different time windows
260
- time_window_1 and time_window_2 should be a list of [start, stop] in milliseconds
261
- Returns firing rates and results of a Mann-Whitney U test (non-parametric)
292
+ Compares the firing rates of a population during two different time windows and performs
293
+ a statistical test to determine if there is a significant difference.
294
+
295
+ Parameters
296
+ ----------
297
+ spike_df : pd.DataFrame
298
+ DataFrame containing spike data with columns for timestamps, node_ids, and grouping variable
299
+ group_by : str
300
+ Column name to group spikes by (e.g., 'pop_name')
301
+ time_window_1 : List[float]
302
+ First time window as [start, stop] in milliseconds
303
+ time_window_2 : List[float]
304
+ Second time window as [start, stop] in milliseconds
305
+
306
+ Returns
307
+ -------
308
+ None
309
+ Results are printed to the console
310
+
311
+ Notes
312
+ -----
313
+ Uses Mann-Whitney U test (non-parametric) to compare firing rates between the two windows
262
314
  """
263
315
  # Filter spikes for the population of interest
264
316
  for pop_name in spike_df[group_by].unique():
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.7.0.3
3
+ Version: 0.7.0.4
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -8,10 +8,10 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
8
8
  bmtool/singlecell.py,sha256=imcdxIzvYVkaOLSGDxYp8WGGssGwXXBCRhzhlqVp7hA,44267
9
9
  bmtool/synapses.py,sha256=Ow2fZavA_3_5BYCjcgPjW0YsyVOetn1wvLxL7hQvbZo,64556
10
10
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- bmtool/analysis/entrainment.py,sha256=IMhjbLYw-rL-MfRuFT3uCkyUFObNVJhcxmYV0R9Uh-M,20007
12
- bmtool/analysis/lfp.py,sha256=3dZkpyVqDtmssvxAqvX4zOLafhwmFnyMehSgVvnj5lM,16754
11
+ bmtool/analysis/entrainment.py,sha256=JRg9sQ7WrZMqHwMaDJtCN7kGgZHJ5msUSzP-JPltC8k,23158
12
+ bmtool/analysis/lfp.py,sha256=hOqD4xcDEL0NrNIN2-Ler_mkvY5cEUhxr7VUdX5Gwh8,21737
13
13
  bmtool/analysis/netcon_reports.py,sha256=7moyoUC45Cl1_6sGqwZ5aKphK_8i4AimroePXcgUnIo,3057
14
- bmtool/analysis/spikes.py,sha256=x24kd0RUhumJkiunfHNEE7mM6JUqdWy1gqabmkMM4cU,14129
14
+ bmtool/analysis/spikes.py,sha256=kcJZQsvPVzQgcuiO-El_4OODW57hwNwdok_RsFMITCg,15097
15
15
  bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  bmtool/bmplot/connections.py,sha256=re6QZX_NfQnIaWayGt3EhMINhCeMMSQ6rFR2sJbFeWk,51385
17
17
  bmtool/bmplot/entrainment.py,sha256=3IBD6tfW7lvkuB6DTan7rAVAeznOOzmHLr1qA2rgtCY,1671
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
26
26
  bmtool/util/util.py,sha256=XR0qZnv_Q47jMBKQpFzCSkCuKe9u8L3YSGJAOpP2zT0,57630
27
27
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
29
- bmtool-0.7.0.3.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
- bmtool-0.7.0.3.dist-info/METADATA,sha256=3jYm8B3kDgHTyLyl5C5zdvxZ8tiGoC4hJvAiVu5QLSM,2768
31
- bmtool-0.7.0.3.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
32
- bmtool-0.7.0.3.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
- bmtool-0.7.0.3.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
- bmtool-0.7.0.3.dist-info/RECORD,,
29
+ bmtool-0.7.0.4.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
+ bmtool-0.7.0.4.dist-info/METADATA,sha256=I-D4fwZQIHcHvD6Ou8Az3Kclwl17kwpZH7YHrD0eEg4,2768
31
+ bmtool-0.7.0.4.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
32
+ bmtool-0.7.0.4.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
+ bmtool-0.7.0.4.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
+ bmtool-0.7.0.4.dist-info/RECORD,,