bmtool 0.7.0.3__py3-none-any.whl → 0.7.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +317 -234
 - bmtool/analysis/lfp.py +176 -1
 - bmtool/analysis/spikes.py +115 -63
 - {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/METADATA +1 -1
 - {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/RECORD +9 -9
 - {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/WHEEL +0 -0
 - {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/entry_points.txt +0 -0
 - {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/licenses/LICENSE +0 -0
 - {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.5.dist-info}/top_level.txt +0 -0
 
    
        bmtool/analysis/entrainment.py
    CHANGED
    
    | 
         @@ -7,58 +7,71 @@ from scipy import signal 
     | 
|
| 
       7 
7 
     | 
    
         
             
            import numba
         
     | 
| 
       8 
8 
     | 
    
         
             
            from numba import cuda
         
     | 
| 
       9 
9 
     | 
    
         
             
            import pandas as pd
         
     | 
| 
       10 
     | 
    
         
            -
            import  
     | 
| 
       11 
     | 
    
         
            -
            from  
     | 
| 
       12 
     | 
    
         
            -
            from typing import Dict, List
         
     | 
| 
      
 10 
     | 
    
         
            +
            from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power, get_lfp_phase
         
     | 
| 
      
 11 
     | 
    
         
            +
            from typing import Dict, List, Optional
         
     | 
| 
       13 
12 
     | 
    
         
             
            from tqdm.notebook import tqdm
         
     | 
| 
       14 
13 
     | 
    
         
             
            import scipy.stats as stats
         
     | 
| 
       15 
     | 
    
         
            -
            import seaborn as sns
         
     | 
| 
       16 
     | 
    
         
            -
            import matplotlib.pyplot as plt
         
     | 
| 
       17 
14 
     | 
    
         | 
| 
       18 
15 
     | 
    
         | 
| 
       19 
     | 
    
         
            -
            def calculate_signal_signal_plv( 
     | 
| 
       20 
     | 
    
         
            -
                               
     | 
| 
      
 16 
     | 
    
         
            +
            def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None, 
         
     | 
| 
      
 17 
     | 
    
         
            +
                              filter_method: str = 'wavelet', lowcut: float = None, highcut: float = None, 
         
     | 
| 
       21 
18 
     | 
    
         
             
                              bandwidth: float = 2.0) -> np.ndarray:
         
     | 
| 
       22 
19 
     | 
    
         
             
                """
         
     | 
| 
       23 
20 
     | 
    
         
             
                Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
         
     | 
| 
       24 
21 
     | 
    
         | 
| 
       25 
     | 
    
         
            -
                Parameters 
     | 
| 
       26 
     | 
    
         
            -
                 
     | 
| 
       27 
     | 
    
         
            -
                 
     | 
| 
       28 
     | 
    
         
            -
             
     | 
| 
       29 
     | 
    
         
            -
                 
     | 
| 
       30 
     | 
    
         
            -
             
     | 
| 
       31 
     | 
    
         
            -
                 
     | 
| 
       32 
     | 
    
         
            -
             
     | 
| 
       33 
     | 
    
         
            -
                 
     | 
| 
       34 
     | 
    
         
            -
             
     | 
| 
      
 22 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 23 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 24 
     | 
    
         
            +
                signal1 : np.ndarray
         
     | 
| 
      
 25 
     | 
    
         
            +
                    First input signal (1D array)
         
     | 
| 
      
 26 
     | 
    
         
            +
                signal2 : np.ndarray
         
     | 
| 
      
 27 
     | 
    
         
            +
                    Second input signal (1D array, same length as signal1)
         
     | 
| 
      
 28 
     | 
    
         
            +
                fs : float
         
     | 
| 
      
 29 
     | 
    
         
            +
                    Sampling frequency in Hz
         
     | 
| 
      
 30 
     | 
    
         
            +
                freq_of_interest : float, optional
         
     | 
| 
      
 31 
     | 
    
         
            +
                    Desired frequency for wavelet PLV calculation, required if filter_method='wavelet'
         
     | 
| 
      
 32 
     | 
    
         
            +
                filter_method : str, optional
         
     | 
| 
      
 33 
     | 
    
         
            +
                    Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
         
     | 
| 
      
 34 
     | 
    
         
            +
                lowcut : float, optional
         
     | 
| 
      
 35 
     | 
    
         
            +
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 36 
     | 
    
         
            +
                highcut : float, optional
         
     | 
| 
      
 37 
     | 
    
         
            +
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 38 
     | 
    
         
            +
                bandwidth : float, optional
         
     | 
| 
      
 39 
     | 
    
         
            +
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
         
     | 
| 
      
 40 
     | 
    
         
            +
                
         
     | 
| 
      
 41 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 42 
     | 
    
         
            +
                -------
         
     | 
| 
      
 43 
     | 
    
         
            +
                np.ndarray
         
     | 
| 
      
 44 
     | 
    
         
            +
                    Phase Locking Value (1D array)
         
     | 
| 
       35 
45 
     | 
    
         
             
                """
         
     | 
| 
       36 
     | 
    
         
            -
                if len( 
     | 
| 
      
 46 
     | 
    
         
            +
                if len(signal1) != len(signal2):
         
     | 
| 
       37 
47 
     | 
    
         
             
                    raise ValueError("Input signals must have the same length.")
         
     | 
| 
       38 
48 
     | 
    
         | 
| 
       39 
     | 
    
         
            -
                if  
     | 
| 
      
 49 
     | 
    
         
            +
                if filter_method == 'wavelet':
         
     | 
| 
       40 
50 
     | 
    
         
             
                    if freq_of_interest is None:
         
     | 
| 
       41 
51 
     | 
    
         
             
                        raise ValueError("freq_of_interest must be provided for the wavelet method.")
         
     | 
| 
       42 
52 
     | 
    
         | 
| 
       43 
53 
     | 
    
         
             
                    # Apply CWT to both signals
         
     | 
| 
       44 
     | 
    
         
            -
                    theta1 = wavelet_filter(x= 
     | 
| 
       45 
     | 
    
         
            -
                    theta2 = wavelet_filter(x= 
     | 
| 
      
 54 
     | 
    
         
            +
                    theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
         
     | 
| 
      
 55 
     | 
    
         
            +
                    theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
         
     | 
| 
       46 
56 
     | 
    
         | 
| 
       47 
     | 
    
         
            -
                elif  
     | 
| 
      
 57 
     | 
    
         
            +
                elif filter_method == 'butter':
         
     | 
| 
       48 
58 
     | 
    
         
             
                    if lowcut is None or highcut is None:
         
     | 
| 
       49 
     | 
    
         
            -
                        print("Lowcut and 
     | 
| 
      
 59 
     | 
    
         
            +
                        print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation")
         
     | 
| 
       50 
60 
     | 
    
         | 
| 
       51 
61 
     | 
    
         
             
                    if lowcut and highcut:
         
     | 
| 
       52 
62 
     | 
    
         
             
                        # Bandpass filter and get the analytic signal using the Hilbert transform
         
     | 
| 
       53 
     | 
    
         
            -
                         
     | 
| 
       54 
     | 
    
         
            -
                         
     | 
| 
       55 
     | 
    
         
            -
             
     | 
| 
       56 
     | 
    
         
            -
             
     | 
| 
       57 
     | 
    
         
            -
             
     | 
| 
       58 
     | 
    
         
            -
                     
     | 
| 
      
 63 
     | 
    
         
            +
                        filtered_signal1 = butter_bandpass_filter(data=signal1, lowcut=lowcut, highcut=highcut, fs=fs)
         
     | 
| 
      
 64 
     | 
    
         
            +
                        filtered_signal2 = butter_bandpass_filter(data=signal2, lowcut=lowcut, highcut=highcut, fs=fs)
         
     | 
| 
      
 65 
     | 
    
         
            +
                        # Get phase using the Hilbert transform
         
     | 
| 
      
 66 
     | 
    
         
            +
                        theta1 = signal.hilbert(filtered_signal1)
         
     | 
| 
      
 67 
     | 
    
         
            +
                        theta2 = signal.hilbert(filtered_signal2)
         
     | 
| 
      
 68 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 69 
     | 
    
         
            +
                        # Get phase using the Hilbert transform without filtering
         
     | 
| 
      
 70 
     | 
    
         
            +
                        theta1 = signal.hilbert(signal1)
         
     | 
| 
      
 71 
     | 
    
         
            +
                        theta2 = signal.hilbert(signal2)
         
     | 
| 
       59 
72 
     | 
    
         | 
| 
       60 
73 
     | 
    
         
             
                else:
         
     | 
| 
       61 
     | 
    
         
            -
                    raise ValueError("Invalid method. Choose 'wavelet' or ' 
     | 
| 
      
 74 
     | 
    
         
            +
                    raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
         
     | 
| 
       62 
75 
     | 
    
         | 
| 
       63 
76 
     | 
    
         
             
                # Calculate phase difference
         
     | 
| 
       64 
77 
     | 
    
         
             
                phase_diff = np.angle(theta1) - np.angle(theta2)
         
     | 
| 
         @@ -69,29 +82,43 @@ def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_ 
     | 
|
| 
       69 
82 
     | 
    
         
             
                return plv
         
     | 
| 
       70 
83 
     | 
    
         | 
| 
       71 
84 
     | 
    
         | 
| 
       72 
     | 
    
         
            -
            def calculate_spike_lfp_plv(spike_times: np.ndarray = None,  
     | 
| 
       73 
     | 
    
         
            -
                               lfp_fs: float = None,  
     | 
| 
       74 
     | 
    
         
            -
                               lowcut: float = None, highcut: float = None,
         
     | 
| 
       75 
     | 
    
         
            -
                                
     | 
| 
      
 85 
     | 
    
         
            +
            def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
         
     | 
| 
      
 86 
     | 
    
         
            +
                               lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
         
     | 
| 
      
 87 
     | 
    
         
            +
                               lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
         
     | 
| 
      
 88 
     | 
    
         
            +
                               filtered_lfp_phase: np.ndarray = None) -> float:
         
     | 
| 
       76 
89 
     | 
    
         
             
                """
         
     | 
| 
       77 
     | 
    
         
            -
                Calculate spike-lfp phase locking value  
     | 
| 
       78 
     | 
    
         
            -
                
         
     | 
| 
       79 
     | 
    
         
            -
                Parameters 
     | 
| 
       80 
     | 
    
         
            -
                 
     | 
| 
       81 
     | 
    
         
            -
                 
     | 
| 
       82 
     | 
    
         
            -
             
     | 
| 
       83 
     | 
    
         
            -
                 
     | 
| 
       84 
     | 
    
         
            -
             
     | 
| 
       85 
     | 
    
         
            -
                 
     | 
| 
       86 
     | 
    
         
            -
             
     | 
| 
       87 
     | 
    
         
            -
                 
     | 
| 
       88 
     | 
    
         
            -
             
     | 
| 
       89 
     | 
    
         
            -
                 
     | 
| 
       90 
     | 
    
         
            -
             
     | 
| 
       91 
     | 
    
         
            -
                 
     | 
| 
      
 90 
     | 
    
         
            +
                Calculate spike-lfp unbiased phase locking value 
         
     | 
| 
      
 91 
     | 
    
         
            +
                
         
     | 
| 
      
 92 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 93 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 94 
     | 
    
         
            +
                spike_times : np.ndarray
         
     | 
| 
      
 95 
     | 
    
         
            +
                    Array of spike times
         
     | 
| 
      
 96 
     | 
    
         
            +
                lfp_data : np.ndarray
         
     | 
| 
      
 97 
     | 
    
         
            +
                    Local field potential time series data. Not required if filtered_lfp_phase is provided.
         
     | 
| 
      
 98 
     | 
    
         
            +
                spike_fs : float, optional
         
     | 
| 
      
 99 
     | 
    
         
            +
                    Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
         
     | 
| 
      
 100 
     | 
    
         
            +
                lfp_fs : float
         
     | 
| 
      
 101 
     | 
    
         
            +
                    Sampling frequency in Hz of the LFP data
         
     | 
| 
      
 102 
     | 
    
         
            +
                filter_method : str, optional
         
     | 
| 
      
 103 
     | 
    
         
            +
                    Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
         
     | 
| 
      
 104 
     | 
    
         
            +
                freq_of_interest : float, optional
         
     | 
| 
      
 105 
     | 
    
         
            +
                    Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
         
     | 
| 
      
 106 
     | 
    
         
            +
                lowcut : float, optional
         
     | 
| 
      
 107 
     | 
    
         
            +
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 108 
     | 
    
         
            +
                highcut : float, optional
         
     | 
| 
      
 109 
     | 
    
         
            +
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 110 
     | 
    
         
            +
                bandwidth : float, optional
         
     | 
| 
      
 111 
     | 
    
         
            +
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
         
     | 
| 
      
 112 
     | 
    
         
            +
                filtered_lfp_phase : np.ndarray, optional
         
     | 
| 
      
 113 
     | 
    
         
            +
                    Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
         
     | 
| 
      
 114 
     | 
    
         
            +
                
         
     | 
| 
      
 115 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 116 
     | 
    
         
            +
                -------
         
     | 
| 
      
 117 
     | 
    
         
            +
                float
         
     | 
| 
      
 118 
     | 
    
         
            +
                    Phase Locking Value (unbiased)
         
     | 
| 
       92 
119 
     | 
    
         
             
                """
         
     | 
| 
       93 
120 
     | 
    
         | 
| 
       94 
     | 
    
         
            -
                if spike_fs  
     | 
| 
      
 121 
     | 
    
         
            +
                if spike_fs is None:
         
     | 
| 
       95 
122 
     | 
    
         
             
                    spike_fs = lfp_fs
         
     | 
| 
       96 
123 
     | 
    
         
             
                # Convert spike times to sample indices
         
     | 
| 
       97 
124 
     | 
    
         
             
                spike_times_seconds = spike_times / spike_fs
         
     | 
| 
         @@ -100,50 +127,41 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarr 
     | 
|
| 
       100 
127 
     | 
    
         
             
                spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
         
     | 
| 
       101 
128 
     | 
    
         | 
| 
       102 
129 
     | 
    
         
             
                # Filter indices to ensure they're within bounds of the LFP signal
         
     | 
| 
       103 
     | 
    
         
            -
                 
     | 
| 
      
 130 
     | 
    
         
            +
                if filtered_lfp_phase is not None:
         
     | 
| 
      
 131 
     | 
    
         
            +
                    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
         
     | 
| 
      
 132 
     | 
    
         
            +
                else:
         
     | 
| 
      
 133 
     | 
    
         
            +
                    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
         
     | 
| 
      
 134 
     | 
    
         
            +
                    
         
     | 
| 
       104 
135 
     | 
    
         
             
                if len(valid_indices) <= 1:
         
     | 
| 
       105 
     | 
    
         
            -
                    return 0 
     | 
| 
      
 136 
     | 
    
         
            +
                    return 0
         
     | 
| 
       106 
137 
     | 
    
         | 
| 
       107 
     | 
    
         
            -
                #  
     | 
| 
       108 
     | 
    
         
            -
                if  
     | 
| 
       109 
     | 
    
         
            -
                     
     | 
| 
       110 
     | 
    
         
            -
             
     | 
| 
       111 
     | 
    
         
            -
             
     | 
| 
       112 
     | 
    
         
            -
                    # Apply CWT to extract phase at the frequency of interest
         
     | 
| 
       113 
     | 
    
         
            -
                    lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
         
     | 
| 
       114 
     | 
    
         
            -
                    instantaneous_phase = np.angle(lfp_complex)
         
     | 
| 
       115 
     | 
    
         
            -
                    
         
     | 
| 
       116 
     | 
    
         
            -
                elif method == 'hilbert':
         
     | 
| 
       117 
     | 
    
         
            -
                    if lowcut is None or highcut is None:
         
     | 
| 
       118 
     | 
    
         
            -
                        print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
         
     | 
| 
       119 
     | 
    
         
            -
                        filtered_lfp = lfp_signal
         
     | 
| 
       120 
     | 
    
         
            -
                    else:
         
     | 
| 
       121 
     | 
    
         
            -
                        # Bandpass filter the signal
         
     | 
| 
       122 
     | 
    
         
            -
                        filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
         
     | 
| 
       123 
     | 
    
         
            -
                    
         
     | 
| 
       124 
     | 
    
         
            -
                    # Get phase using the Hilbert transform
         
     | 
| 
       125 
     | 
    
         
            -
                    analytic_signal = signal.hilbert(filtered_lfp)
         
     | 
| 
       126 
     | 
    
         
            -
                    instantaneous_phase = np.angle(analytic_signal)
         
     | 
| 
       127 
     | 
    
         
            -
                    
         
     | 
| 
      
 138 
     | 
    
         
            +
                # Get instantaneous phase
         
     | 
| 
      
 139 
     | 
    
         
            +
                if filtered_lfp_phase is None:
         
     | 
| 
      
 140 
     | 
    
         
            +
                    instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method, 
         
     | 
| 
      
 141 
     | 
    
         
            +
                                                       freq_of_interest=freq_of_interest, lowcut=lowcut, 
         
     | 
| 
      
 142 
     | 
    
         
            +
                                                       highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
         
     | 
| 
       128 
143 
     | 
    
         
             
                else:
         
     | 
| 
       129 
     | 
    
         
            -
                     
     | 
| 
      
 144 
     | 
    
         
            +
                    instantaneous_phase = filtered_lfp_phase
         
     | 
| 
       130 
145 
     | 
    
         | 
| 
       131 
146 
     | 
    
         
             
                # Get phases at spike times
         
     | 
| 
       132 
147 
     | 
    
         
             
                spike_phases = instantaneous_phase[valid_indices]
         
     | 
| 
       133 
     | 
    
         
            -
             
     | 
| 
       134 
     | 
    
         
            -
                #  
     | 
| 
       135 
     | 
    
         
            -
                 
     | 
| 
      
 148 
     | 
    
         
            +
             
     | 
| 
      
 149 
     | 
    
         
            +
                # Number of spikes
         
     | 
| 
      
 150 
     | 
    
         
            +
                N = len(spike_phases)
         
     | 
| 
       136 
151 
     | 
    
         | 
| 
       137 
152 
     | 
    
         
             
                # Convert phases to unit vectors in the complex plane
         
     | 
| 
       138 
153 
     | 
    
         
             
                unit_vectors = np.exp(1j * spike_phases)
         
     | 
| 
       139 
154 
     | 
    
         | 
| 
       140 
     | 
    
         
            -
                #  
     | 
| 
      
 155 
     | 
    
         
            +
                # Sum of all unit vectors (resultant vector)
         
     | 
| 
       141 
156 
     | 
    
         
             
                resultant_vector = np.sum(unit_vectors)
         
     | 
| 
       142 
     | 
    
         
            -
             
     | 
| 
       143 
     | 
    
         
            -
                #  
     | 
| 
       144 
     | 
    
         
            -
                 
     | 
| 
       145 
     | 
    
         
            -
                
         
     | 
| 
       146 
     | 
    
         
            -
                 
     | 
| 
      
 157 
     | 
    
         
            +
             
     | 
| 
      
 158 
     | 
    
         
            +
                # Calculate plv^2 * N
         
     | 
| 
      
 159 
     | 
    
         
            +
                plv2n = (resultant_vector * resultant_vector.conjugate()).real / N  # plv^2 * N
         
     | 
| 
      
 160 
     | 
    
         
            +
                plv = (plv2n / N) ** 0.5
         
     | 
| 
      
 161 
     | 
    
         
            +
                ppc = (plv2n - 1) / (N - 1)  # ppc = (plv^2 * N - 1) / (N - 1)
         
     | 
| 
      
 162 
     | 
    
         
            +
                plv_unbiased = np.fmax(ppc, 0.) ** 0.5  # ensure non-negative
         
     | 
| 
      
 163 
     | 
    
         
            +
             
     | 
| 
      
 164 
     | 
    
         
            +
                return plv_unbiased
         
     | 
| 
       147 
165 
     | 
    
         | 
| 
       148 
166 
     | 
    
         | 
| 
       149 
167 
     | 
    
         
             
            @numba.njit(parallel=True, fastmath=True)
         
     | 
| 
         @@ -181,27 +199,43 @@ def _ppc_gpu(spike_phases): 
     | 
|
| 
       181 
199 
     | 
    
         
             
                return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
         
     | 
| 
       182 
200 
     | 
    
         | 
| 
       183 
201 
     | 
    
         | 
| 
       184 
     | 
    
         
            -
            def calculate_ppc(spike_times: np.ndarray = None,  
     | 
| 
       185 
     | 
    
         
            -
                              lfp_fs: float = None,  
     | 
| 
       186 
     | 
    
         
            -
                              lowcut: float = None, highcut: float = None,
         
     | 
| 
       187 
     | 
    
         
            -
                               
     | 
| 
      
 202 
     | 
    
         
            +
            def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
         
     | 
| 
      
 203 
     | 
    
         
            +
                              lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
         
     | 
| 
      
 204 
     | 
    
         
            +
                              lowcut: float = None, highcut: float = None, bandwidth: float = 2.0, 
         
     | 
| 
      
 205 
     | 
    
         
            +
                              ppc_method: str = 'numpy', filtered_lfp_phase: np.ndarray = None) -> float:
         
     | 
| 
       188 
206 
     | 
    
         
             
                """
         
     | 
| 
       189 
207 
     | 
    
         
             
                Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
         
     | 
| 
       190 
208 
     | 
    
         
             
                Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
         
     | 
| 
       191 
209 
     | 
    
         | 
| 
       192 
     | 
    
         
            -
                Parameters 
     | 
| 
       193 
     | 
    
         
            -
                 
     | 
| 
       194 
     | 
    
         
            -
                 
     | 
| 
       195 
     | 
    
         
            -
             
     | 
| 
       196 
     | 
    
         
            -
                 
     | 
| 
       197 
     | 
    
         
            -
             
     | 
| 
       198 
     | 
    
         
            -
                 
     | 
| 
       199 
     | 
    
         
            -
             
     | 
| 
       200 
     | 
    
         
            -
                 
     | 
| 
       201 
     | 
    
         
            -
             
     | 
| 
       202 
     | 
    
         
            -
                
         
     | 
| 
       203 
     | 
    
         
            -
             
     | 
| 
       204 
     | 
    
         
            -
                 
     | 
| 
      
 210 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 211 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 212 
     | 
    
         
            +
                spike_times : np.ndarray
         
     | 
| 
      
 213 
     | 
    
         
            +
                    Array of spike times
         
     | 
| 
      
 214 
     | 
    
         
            +
                lfp_data : np.ndarray
         
     | 
| 
      
 215 
     | 
    
         
            +
                    Local field potential time series data. Not required if filtered_lfp_phase is provided.
         
     | 
| 
      
 216 
     | 
    
         
            +
                spike_fs : float, optional
         
     | 
| 
      
 217 
     | 
    
         
            +
                    Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
         
     | 
| 
      
 218 
     | 
    
         
            +
                lfp_fs : float
         
     | 
| 
      
 219 
     | 
    
         
            +
                    Sampling frequency in Hz of the LFP data
         
     | 
| 
      
 220 
     | 
    
         
            +
                filter_method : str, optional
         
     | 
| 
      
 221 
     | 
    
         
            +
                    Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
         
     | 
| 
      
 222 
     | 
    
         
            +
                freq_of_interest : float, optional
         
     | 
| 
      
 223 
     | 
    
         
            +
                    Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
         
     | 
| 
      
 224 
     | 
    
         
            +
                lowcut : float, optional
         
     | 
| 
      
 225 
     | 
    
         
            +
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 226 
     | 
    
         
            +
                highcut : float, optional
         
     | 
| 
      
 227 
     | 
    
         
            +
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 228 
     | 
    
         
            +
                bandwidth : float, optional
         
     | 
| 
      
 229 
     | 
    
         
            +
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
         
     | 
| 
      
 230 
     | 
    
         
            +
                ppc_method : str, optional
         
     | 
| 
      
 231 
     | 
    
         
            +
                    Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
         
     | 
| 
      
 232 
     | 
    
         
            +
                filtered_lfp_phase : np.ndarray, optional
         
     | 
| 
      
 233 
     | 
    
         
            +
                    Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
         
     | 
| 
      
 234 
     | 
    
         
            +
                
         
     | 
| 
      
 235 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 236 
     | 
    
         
            +
                -------
         
     | 
| 
      
 237 
     | 
    
         
            +
                float
         
     | 
| 
      
 238 
     | 
    
         
            +
                    Pairwise Phase Consistency value
         
     | 
| 
       205 
239 
     | 
    
         
             
                """
         
     | 
| 
       206 
240 
     | 
    
         
             
                if spike_fs is None:
         
     | 
| 
       207 
241 
     | 
    
         
             
                    spike_fs = lfp_fs
         
     | 
| 
         @@ -212,33 +246,21 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, 
     | 
|
| 
       212 
246 
     | 
    
         
             
                spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
         
     | 
| 
       213 
247 
     | 
    
         | 
| 
       214 
248 
     | 
    
         
             
                # Filter indices to ensure they're within bounds of the LFP signal
         
     | 
| 
       215 
     | 
    
         
            -
                 
     | 
| 
      
 249 
     | 
    
         
            +
                if filtered_lfp_phase is not None:
         
     | 
| 
      
 250 
     | 
    
         
            +
                    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
         
     | 
| 
      
 251 
     | 
    
         
            +
                else:
         
     | 
| 
      
 252 
     | 
    
         
            +
                    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
         
     | 
| 
      
 253 
     | 
    
         
            +
                    
         
     | 
| 
       216 
254 
     | 
    
         
             
                if len(valid_indices) <= 1:
         
     | 
| 
       217 
     | 
    
         
            -
                    return 0 
     | 
| 
      
 255 
     | 
    
         
            +
                    return 0
         
     | 
| 
       218 
256 
     | 
    
         | 
| 
       219 
     | 
    
         
            -
                #  
     | 
| 
       220 
     | 
    
         
            -
                if  
     | 
| 
       221 
     | 
    
         
            -
                     
     | 
| 
       222 
     | 
    
         
            -
             
     | 
| 
       223 
     | 
    
         
            -
             
     | 
| 
       224 
     | 
    
         
            -
                    # Apply CWT to extract phase at the frequency of interest
         
     | 
| 
       225 
     | 
    
         
            -
                    lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
         
     | 
| 
       226 
     | 
    
         
            -
                    instantaneous_phase = np.angle(lfp_complex)
         
     | 
| 
       227 
     | 
    
         
            -
                    
         
     | 
| 
       228 
     | 
    
         
            -
                elif method == 'hilbert':
         
     | 
| 
       229 
     | 
    
         
            -
                    if lowcut is None or highcut is None:
         
     | 
| 
       230 
     | 
    
         
            -
                        print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
         
     | 
| 
       231 
     | 
    
         
            -
                        filtered_lfp = lfp_signal
         
     | 
| 
       232 
     | 
    
         
            -
                    else:
         
     | 
| 
       233 
     | 
    
         
            -
                        # Bandpass filter the signal
         
     | 
| 
       234 
     | 
    
         
            -
                        filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
         
     | 
| 
       235 
     | 
    
         
            -
                    
         
     | 
| 
       236 
     | 
    
         
            -
                    # Get phase using the Hilbert transform
         
     | 
| 
       237 
     | 
    
         
            -
                    analytic_signal = signal.hilbert(filtered_lfp)
         
     | 
| 
       238 
     | 
    
         
            -
                    instantaneous_phase = np.angle(analytic_signal)
         
     | 
| 
       239 
     | 
    
         
            -
                    
         
     | 
| 
      
 257 
     | 
    
         
            +
                # Get instantaneous phase
         
     | 
| 
      
 258 
     | 
    
         
            +
                if filtered_lfp_phase is None:
         
     | 
| 
      
 259 
     | 
    
         
            +
                    instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method, 
         
     | 
| 
      
 260 
     | 
    
         
            +
                                                       freq_of_interest=freq_of_interest, lowcut=lowcut, 
         
     | 
| 
      
 261 
     | 
    
         
            +
                                                       highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
         
     | 
| 
       240 
262 
     | 
    
         
             
                else:
         
     | 
| 
       241 
     | 
    
         
            -
                     
     | 
| 
      
 263 
     | 
    
         
            +
                    instantaneous_phase = filtered_lfp_phase
         
     | 
| 
       242 
264 
     | 
    
         | 
| 
       243 
265 
     | 
    
         
             
                # Get phases at spike times
         
     | 
| 
       244 
266 
     | 
    
         
             
                spike_phases = instantaneous_phase[valid_indices]
         
     | 
| 
         @@ -247,28 +269,10 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, 
     | 
|
| 
       247 
269 
     | 
    
         | 
| 
       248 
270 
     | 
    
         
             
                # Calculate PPC (Pairwise Phase Consistency)
         
     | 
| 
       249 
271 
     | 
    
         
             
                if n_spikes <= 1:
         
     | 
| 
       250 
     | 
    
         
            -
                    return 0 
     | 
| 
      
 272 
     | 
    
         
            +
                    return 0
         
     | 
| 
       251 
273 
     | 
    
         | 
| 
       252 
274 
     | 
    
         
             
                # Explicit calculation of pairwise phase consistency
         
     | 
| 
       253 
     | 
    
         
            -
                 
     | 
| 
       254 
     | 
    
         
            -
                
         
     | 
| 
       255 
     | 
    
         
            -
                # # Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
         
     | 
| 
       256 
     | 
    
         
            -
                # for i in range(n_spikes - 1):  # For each spike i
         
     | 
| 
       257 
     | 
    
         
            -
                #     for j in range(i + 1, n_spikes):  # For each spike j > i
         
     | 
| 
       258 
     | 
    
         
            -
                #         # Calculate the phase difference between spikes i and j
         
     | 
| 
       259 
     | 
    
         
            -
                #         phase_diff = spike_phases[i] - spike_phases[j]
         
     | 
| 
       260 
     | 
    
         
            -
                        
         
     | 
| 
       261 
     | 
    
         
            -
                #         #f(θᵢ, θⱼ) = cos(θᵢ)cos(θⱼ) + sin(θᵢ)sin(θⱼ) can become #f(θᵢ, θⱼ) = cos(θᵢ - θⱼ)
         
     | 
| 
       262 
     | 
    
         
            -
                #         cos_diff = np.cos(phase_diff)
         
     | 
| 
       263 
     | 
    
         
            -
                        
         
     | 
| 
       264 
     | 
    
         
            -
                #         # Add to the sum
         
     | 
| 
       265 
     | 
    
         
            -
                #         sum_cos_diff += cos_diff
         
     | 
| 
       266 
     | 
    
         
            -
                
         
     | 
| 
       267 
     | 
    
         
            -
                # # Calculate PPC according to the equation
         
     | 
| 
       268 
     | 
    
         
            -
                # # PPC = (2 / (N(N-1))) * Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
         
     | 
| 
       269 
     | 
    
         
            -
                # ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
         
     | 
| 
       270 
     | 
    
         
            -
                
         
     | 
| 
       271 
     | 
    
         
            -
                # same as above (i think) but with vectorized computation and memory fixes so it wont take forever to run.
         
     | 
| 
      
 275 
     | 
    
         
            +
                # Vectorized computation for efficiency
         
     | 
| 
       272 
276 
     | 
    
         
             
                if ppc_method == 'numpy':
         
     | 
| 
       273 
277 
     | 
    
         
             
                    i, j = np.triu_indices(n_spikes, k=1)
         
     | 
| 
       274 
278 
     | 
    
         
             
                    phase_diff = spike_phases[i] - spike_phases[j]
         
     | 
| 
         @@ -279,14 +283,14 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, 
     | 
|
| 
       279 
283 
     | 
    
         
             
                elif ppc_method == 'gpu':
         
     | 
| 
       280 
284 
     | 
    
         
             
                    ppc = _ppc_gpu(spike_phases)
         
     | 
| 
       281 
285 
     | 
    
         
             
                else:
         
     | 
| 
       282 
     | 
    
         
            -
                    raise  
     | 
| 
      
 286 
     | 
    
         
            +
                    raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")
         
     | 
| 
       283 
287 
     | 
    
         
             
                return ppc
         
     | 
| 
       284 
288 
     | 
    
         | 
| 
       285 
289 
     | 
    
         | 
| 
       286 
     | 
    
         
            -
            def calculate_ppc2(spike_times: np.ndarray = None,  
     | 
| 
       287 
     | 
    
         
            -
                              lfp_fs: float = None,  
     | 
| 
       288 
     | 
    
         
            -
                              lowcut: float = None, highcut: float = None,
         
     | 
| 
       289 
     | 
    
         
            -
                               
     | 
| 
      
 290 
     | 
    
         
            +
            def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
         
     | 
| 
      
 291 
     | 
    
         
            +
                              lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
         
     | 
| 
      
 292 
     | 
    
         
            +
                              lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
         
     | 
| 
      
 293 
     | 
    
         
            +
                              filtered_lfp_phase: np.ndarray = None) -> float:
         
     | 
| 
       290 
294 
     | 
    
         
             
                """
         
     | 
| 
       291 
295 
     | 
    
         
             
                # -----------------------------------------------------------------------------
         
     | 
| 
       292 
296 
     | 
    
         
             
                # PPC2 Calculation (Vinck et al., 2010) 
         
     | 
| 
         @@ -297,18 +301,33 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None 
     | 
|
| 
       297 
301 
     | 
    
         
             
                #   PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
         
     | 
| 
       298 
302 
     | 
    
         
             
                # -----------------------------------------------------------------------------
         
     | 
| 
       299 
303 
     | 
    
         | 
| 
       300 
     | 
    
         
            -
                Parameters 
     | 
| 
       301 
     | 
    
         
            -
                 
     | 
| 
       302 
     | 
    
         
            -
                 
     | 
| 
       303 
     | 
    
         
            -
             
     | 
| 
       304 
     | 
    
         
            -
                 
     | 
| 
       305 
     | 
    
         
            -
             
     | 
| 
       306 
     | 
    
         
            -
                 
     | 
| 
       307 
     | 
    
         
            -
             
     | 
| 
       308 
     | 
    
         
            -
                 
     | 
| 
       309 
     | 
    
         
            -
             
     | 
| 
       310 
     | 
    
         
            -
                 
     | 
| 
       311 
     | 
    
         
            -
             
     | 
| 
      
 304 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 305 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 306 
     | 
    
         
            +
                spike_times : np.ndarray
         
     | 
| 
      
 307 
     | 
    
         
            +
                    Array of spike times
         
     | 
| 
      
 308 
     | 
    
         
            +
                lfp_data : np.ndarray
         
     | 
| 
      
 309 
     | 
    
         
            +
                    Local field potential time series data. Not required if filtered_lfp_phase is provided.
         
     | 
| 
      
 310 
     | 
    
         
            +
                spike_fs : float, optional
         
     | 
| 
      
 311 
     | 
    
         
            +
                    Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
         
     | 
| 
      
 312 
     | 
    
         
            +
                lfp_fs : float
         
     | 
| 
      
 313 
     | 
    
         
            +
                    Sampling frequency in Hz of the LFP data
         
     | 
| 
      
 314 
     | 
    
         
            +
                filter_method : str, optional
         
     | 
| 
      
 315 
     | 
    
         
            +
                    Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
         
     | 
| 
      
 316 
     | 
    
         
            +
                freq_of_interest : float, optional
         
     | 
| 
      
 317 
     | 
    
         
            +
                    Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
         
     | 
| 
      
 318 
     | 
    
         
            +
                lowcut : float, optional
         
     | 
| 
      
 319 
     | 
    
         
            +
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 320 
     | 
    
         
            +
                highcut : float, optional
         
     | 
| 
      
 321 
     | 
    
         
            +
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 322 
     | 
    
         
            +
                bandwidth : float, optional
         
     | 
| 
      
 323 
     | 
    
         
            +
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
         
     | 
| 
      
 324 
     | 
    
         
            +
                filtered_lfp_phase : np.ndarray, optional
         
     | 
| 
      
 325 
     | 
    
         
            +
                    Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
         
     | 
| 
      
 326 
     | 
    
         
            +
                
         
     | 
| 
      
 327 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 328 
     | 
    
         
            +
                -------
         
     | 
| 
      
 329 
     | 
    
         
            +
                float
         
     | 
| 
      
 330 
     | 
    
         
            +
                    Pairwise Phase Consistency 2 (PPC2) value
         
     | 
| 
       312 
331 
     | 
    
         
             
                """
         
     | 
| 
       313 
332 
     | 
    
         | 
| 
       314 
333 
     | 
    
         
             
                if spike_fs is None:
         
     | 
| 
         @@ -320,33 +339,21 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None 
     | 
|
| 
       320 
339 
     | 
    
         
             
                spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
         
     | 
| 
       321 
340 
     | 
    
         | 
| 
       322 
341 
     | 
    
         
             
                # Filter indices to ensure they're within bounds of the LFP signal
         
     | 
| 
       323 
     | 
    
         
            -
                 
     | 
| 
      
 342 
     | 
    
         
            +
                if filtered_lfp_phase is not None:
         
     | 
| 
      
 343 
     | 
    
         
            +
                    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
         
     | 
| 
      
 344 
     | 
    
         
            +
                else:
         
     | 
| 
      
 345 
     | 
    
         
            +
                    valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
         
     | 
| 
      
 346 
     | 
    
         
            +
                    
         
     | 
| 
       324 
347 
     | 
    
         
             
                if len(valid_indices) <= 1:
         
     | 
| 
       325 
     | 
    
         
            -
                    return 0 
     | 
| 
      
 348 
     | 
    
         
            +
                    return 0
         
     | 
| 
       326 
349 
     | 
    
         | 
| 
       327 
     | 
    
         
            -
                #  
     | 
| 
       328 
     | 
    
         
            -
                if  
     | 
| 
       329 
     | 
    
         
            -
                     
     | 
| 
       330 
     | 
    
         
            -
             
     | 
| 
       331 
     | 
    
         
            -
             
     | 
| 
       332 
     | 
    
         
            -
                    # Apply CWT to extract phase at the frequency of interest
         
     | 
| 
       333 
     | 
    
         
            -
                    lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
         
     | 
| 
       334 
     | 
    
         
            -
                    instantaneous_phase = np.angle(lfp_complex)
         
     | 
| 
       335 
     | 
    
         
            -
                    
         
     | 
| 
       336 
     | 
    
         
            -
                elif method == 'hilbert':
         
     | 
| 
       337 
     | 
    
         
            -
                    if lowcut is None or highcut is None:
         
     | 
| 
       338 
     | 
    
         
            -
                        print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
         
     | 
| 
       339 
     | 
    
         
            -
                        filtered_lfp = lfp_signal
         
     | 
| 
       340 
     | 
    
         
            -
                    else:
         
     | 
| 
       341 
     | 
    
         
            -
                        # Bandpass filter the signal
         
     | 
| 
       342 
     | 
    
         
            -
                        filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
         
     | 
| 
       343 
     | 
    
         
            -
                    
         
     | 
| 
       344 
     | 
    
         
            -
                    # Get phase using the Hilbert transform
         
     | 
| 
       345 
     | 
    
         
            -
                    analytic_signal = signal.hilbert(filtered_lfp)
         
     | 
| 
       346 
     | 
    
         
            -
                    instantaneous_phase = np.angle(analytic_signal)
         
     | 
| 
       347 
     | 
    
         
            -
                    
         
     | 
| 
      
 350 
     | 
    
         
            +
                # Get instantaneous phase
         
     | 
| 
      
 351 
     | 
    
         
            +
                if filtered_lfp_phase is None:
         
     | 
| 
      
 352 
     | 
    
         
            +
                    instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method, 
         
     | 
| 
      
 353 
     | 
    
         
            +
                                                       freq_of_interest=freq_of_interest, lowcut=lowcut, 
         
     | 
| 
      
 354 
     | 
    
         
            +
                                                       highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
         
     | 
| 
       348 
355 
     | 
    
         
             
                else:
         
     | 
| 
       349 
     | 
    
         
            -
                     
     | 
| 
      
 356 
     | 
    
         
            +
                    instantaneous_phase = filtered_lfp_phase
         
     | 
| 
       350 
357 
     | 
    
         | 
| 
       351 
358 
     | 
    
         
             
                # Get phases at spike times
         
     | 
| 
       352 
359 
     | 
    
         
             
                spike_phases = instantaneous_phase[valid_indices]
         
     | 
| 
         @@ -355,7 +362,7 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None 
     | 
|
| 
       355 
362 
     | 
    
         
             
                n = len(spike_phases)
         
     | 
| 
       356 
363 
     | 
    
         | 
| 
       357 
364 
     | 
    
         
             
                if n <= 1:
         
     | 
| 
       358 
     | 
    
         
            -
                    return 0 
     | 
| 
      
 365 
     | 
    
         
            +
                    return 0
         
     | 
| 
       359 
366 
     | 
    
         | 
| 
       360 
367 
     | 
    
         
             
                # Convert phases to unit vectors in the complex plane
         
     | 
| 
       361 
368 
     | 
    
         
             
                unit_vectors = np.exp(1j * spike_phases)
         
     | 
| 
         @@ -369,40 +376,77 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None 
     | 
|
| 
       369 
376 
     | 
    
         
             
                return ppc2
         
     | 
| 
       370 
377 
     | 
    
         | 
| 
       371 
378 
     | 
    
         | 
| 
       372 
     | 
    
         
            -
            def  
     | 
| 
       373 
     | 
    
         
            -
                                         
     | 
| 
       374 
     | 
    
         
            -
                                         
     | 
| 
      
 379 
     | 
    
         
            +
            def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None, filter_method: str='wavelet', pop_names: List[str]=None,
         
     | 
| 
      
 380 
     | 
    
         
            +
                                        entrainment_method: str='plv', lowcut: float=None, highcut: float=None,
         
     | 
| 
      
 381 
     | 
    
         
            +
                                        spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
         
     | 
| 
      
 382 
     | 
    
         
            +
                                        freqs: List[float]=None, ppc_method: str='numpy',) -> Dict[str, Dict[int, Dict[float, float]]]:
         
     | 
| 
       375 
383 
     | 
    
         
             
                """
         
     | 
| 
       376 
     | 
    
         
            -
                Calculate  
     | 
| 
      
 384 
     | 
    
         
            +
                Calculate neural entrainment (PPC, PLV) per neuron (cell) for specified frequencies across different populations.
         
     | 
| 
       377 
385 
     | 
    
         | 
| 
       378 
     | 
    
         
            -
                This function computes the  
     | 
| 
       379 
     | 
    
         
            -
                and a  
     | 
| 
      
 386 
     | 
    
         
            +
                This function computes the entrainment metrics for each neuron within the specified populations based on their spike times
         
     | 
| 
      
 387 
     | 
    
         
            +
                and the provided LFP signal. It returns a nested dictionary structure containing the entrainment values
         
     | 
| 
      
 388 
     | 
    
         
            +
                organized by population, node ID, and frequency.
         
     | 
| 
       380 
389 
     | 
    
         | 
| 
       381 
     | 
    
         
            -
                 
     | 
| 
       382 
     | 
    
         
            -
             
     | 
| 
       383 
     | 
    
         
            -
             
     | 
| 
       384 
     | 
    
         
            -
                     
     | 
| 
       385 
     | 
    
         
            -
             
     | 
| 
       386 
     | 
    
         
            -
                     
     | 
| 
       387 
     | 
    
         
            -
             
     | 
| 
      
 390 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 391 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 392 
     | 
    
         
            +
                spike_df : pd.DataFrame
         
     | 
| 
      
 393 
     | 
    
         
            +
                    DataFrame containing spike data with columns 'pop_name', 'node_ids', and 'timestamps'
         
     | 
| 
      
 394 
     | 
    
         
            +
                lfp_data : np.ndarray
         
     | 
| 
      
 395 
     | 
    
         
            +
                    Local field potential (LFP) time series data
         
     | 
| 
      
 396 
     | 
    
         
            +
                filter_method : str, optional
         
     | 
| 
      
 397 
     | 
    
         
            +
                    Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
         
     | 
| 
      
 398 
     | 
    
         
            +
                entrainment_method : str, optional
         
     | 
| 
      
 399 
     | 
    
         
            +
                    Method to use for entrainment calculation, either 'plv', 'ppc', or 'ppc2' (default: 'plv')
         
     | 
| 
      
 400 
     | 
    
         
            +
                lowcut : float, optional
         
     | 
| 
      
 401 
     | 
    
         
            +
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 402 
     | 
    
         
            +
                highcut : float, optional
         
     | 
| 
      
 403 
     | 
    
         
            +
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 404 
     | 
    
         
            +
                spike_fs : float
         
     | 
| 
      
 405 
     | 
    
         
            +
                    Sampling frequency of the spike times in Hz
         
     | 
| 
      
 406 
     | 
    
         
            +
                lfp_fs : float
         
     | 
| 
      
 407 
     | 
    
         
            +
                    Sampling frequency of the LFP signal in Hz
         
     | 
| 
      
 408 
     | 
    
         
            +
                bandwidth : float, optional
         
     | 
| 
      
 409 
     | 
    
         
            +
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
         
     | 
| 
      
 410 
     | 
    
         
            +
                ppc_method : str, optional
         
     | 
| 
      
 411 
     | 
    
         
            +
                    Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
         
     | 
| 
      
 412 
     | 
    
         
            +
                pop_names : List[str]
         
     | 
| 
      
 413 
     | 
    
         
            +
                    List of population names to analyze
         
     | 
| 
      
 414 
     | 
    
         
            +
                freqs : List[float]
         
     | 
| 
      
 415 
     | 
    
         
            +
                    List of frequencies (in Hz) at which to calculate entrainment
         
     | 
| 
       388 
416 
     | 
    
         | 
| 
       389 
     | 
    
         
            -
                Returns 
     | 
| 
       390 
     | 
    
         
            -
             
     | 
| 
       391 
     | 
    
         
            -
             
     | 
| 
       392 
     | 
    
         
            -
             
     | 
| 
       393 
     | 
    
         
            -
             
     | 
| 
       394 
     | 
    
         
            -
             
     | 
| 
       395 
     | 
    
         
            -
             
     | 
| 
      
 417 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 418 
     | 
    
         
            +
                -------
         
     | 
| 
      
 419 
     | 
    
         
            +
                Dict[str, Dict[int, Dict[float, float]]]
         
     | 
| 
      
 420 
     | 
    
         
            +
                    Nested dictionary where the structure is:
         
     | 
| 
      
 421 
     | 
    
         
            +
                    {
         
     | 
| 
      
 422 
     | 
    
         
            +
                        population_name: {
         
     | 
| 
      
 423 
     | 
    
         
            +
                            node_id: {
         
     | 
| 
      
 424 
     | 
    
         
            +
                                frequency: entrainment value
         
     | 
| 
       396 
425 
     | 
    
         
             
                            }
         
     | 
| 
       397 
426 
     | 
    
         
             
                        }
         
     | 
| 
       398 
     | 
    
         
            -
             
     | 
| 
      
 427 
     | 
    
         
            +
                    }
         
     | 
| 
      
 428 
     | 
    
         
            +
                    Entrainment values are floats representing the metric (PPC, PLV) at each frequency
         
     | 
| 
       399 
429 
     | 
    
         
             
                """
         
     | 
| 
       400 
     | 
    
         
            -
                 
     | 
| 
      
 430 
     | 
    
         
            +
                # pre filter lfp to speed up calculate of entrainment
         
     | 
| 
      
 431 
     | 
    
         
            +
                filtered_lfp_phases = {}
         
     | 
| 
      
 432 
     | 
    
         
            +
                for freq in range(len(freqs)):
         
     | 
| 
      
 433 
     | 
    
         
            +
                    phase = get_lfp_phase(
         
     | 
| 
      
 434 
     | 
    
         
            +
                        lfp_data=lfp_data, 
         
     | 
| 
      
 435 
     | 
    
         
            +
                        freq_of_interest=freqs[freq], 
         
     | 
| 
      
 436 
     | 
    
         
            +
                        fs=lfp_fs, 
         
     | 
| 
      
 437 
     | 
    
         
            +
                        filter_method=filter_method,
         
     | 
| 
      
 438 
     | 
    
         
            +
                        lowcut=lowcut, 
         
     | 
| 
      
 439 
     | 
    
         
            +
                        highcut=highcut, 
         
     | 
| 
      
 440 
     | 
    
         
            +
                        bandwidth=bandwidth
         
     | 
| 
      
 441 
     | 
    
         
            +
                    )
         
     | 
| 
      
 442 
     | 
    
         
            +
                    filtered_lfp_phases[freqs[freq]] = phase
         
     | 
| 
      
 443 
     | 
    
         
            +
                        
         
     | 
| 
      
 444 
     | 
    
         
            +
                entrainment_dict = {}
         
     | 
| 
       401 
445 
     | 
    
         
             
                for pop in pop_names:
         
     | 
| 
       402 
446 
     | 
    
         
             
                    skip_count = 0
         
     | 
| 
       403 
447 
     | 
    
         
             
                    pop_spikes = spike_df[spike_df['pop_name'] == pop]
         
     | 
| 
       404 
448 
     | 
    
         
             
                    nodes = pop_spikes['node_ids'].unique()
         
     | 
| 
       405 
     | 
    
         
            -
                     
     | 
| 
      
 449 
     | 
    
         
            +
                    entrainment_dict[pop] = {}
         
     | 
| 
       406 
450 
     | 
    
         
             
                    print(f'Processing {pop} population')
         
     | 
| 
       407 
451 
     | 
    
         
             
                    for node in tqdm(nodes):
         
     | 
| 
       408 
452 
     | 
    
         
             
                        node_spikes = pop_spikes[pop_spikes['node_ids'] == node]
         
     | 
| 
         @@ -412,24 +456,58 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray, 
     | 
|
| 
       412 
456 
     | 
    
         
             
                            skip_count += 1
         
     | 
| 
       413 
457 
     | 
    
         
             
                            continue
         
     | 
| 
       414 
458 
     | 
    
         | 
| 
       415 
     | 
    
         
            -
                         
     | 
| 
      
 459 
     | 
    
         
            +
                        entrainment_dict[pop][node] = {}
         
     | 
| 
       416 
460 
     | 
    
         
             
                        for freq in freqs:
         
     | 
| 
       417 
     | 
    
         
            -
                             
     | 
| 
       418 
     | 
    
         
            -
             
     | 
| 
       419 
     | 
    
         
            -
                                 
     | 
| 
       420 
     | 
    
         
            -
             
     | 
| 
       421 
     | 
    
         
            -
             
     | 
| 
       422 
     | 
    
         
            -
             
     | 
| 
       423 
     | 
    
         
            -
             
     | 
| 
       424 
     | 
    
         
            -
             
     | 
| 
       425 
     | 
    
         
            -
             
     | 
| 
      
 461 
     | 
    
         
            +
                            # Calculate entrainment based on the selected method using the pre-filtered phases
         
     | 
| 
      
 462 
     | 
    
         
            +
                            if entrainment_method == 'plv':
         
     | 
| 
      
 463 
     | 
    
         
            +
                                entrainment_dict[pop][node][freq] = calculate_spike_lfp_plv(
         
     | 
| 
      
 464 
     | 
    
         
            +
                                    node_spikes['timestamps'].values,
         
     | 
| 
      
 465 
     | 
    
         
            +
                                    lfp_data,
         
     | 
| 
      
 466 
     | 
    
         
            +
                                    spike_fs=spike_fs,
         
     | 
| 
      
 467 
     | 
    
         
            +
                                    lfp_fs=lfp_fs,
         
     | 
| 
      
 468 
     | 
    
         
            +
                                    freq_of_interest=freq,
         
     | 
| 
      
 469 
     | 
    
         
            +
                                    bandwidth=bandwidth,
         
     | 
| 
      
 470 
     | 
    
         
            +
                                    lowcut=lowcut,
         
     | 
| 
      
 471 
     | 
    
         
            +
                                    highcut=highcut,
         
     | 
| 
      
 472 
     | 
    
         
            +
                                    filter_method=filter_method,
         
     | 
| 
      
 473 
     | 
    
         
            +
                                    filtered_lfp_phase=filtered_lfp_phases[freq]
         
     | 
| 
      
 474 
     | 
    
         
            +
                                )
         
     | 
| 
      
 475 
     | 
    
         
            +
                            elif entrainment_method == 'ppc2':
         
     | 
| 
      
 476 
     | 
    
         
            +
                                entrainment_dict[pop][node][freq] = calculate_ppc2(
         
     | 
| 
      
 477 
     | 
    
         
            +
                                    node_spikes['timestamps'].values,
         
     | 
| 
      
 478 
     | 
    
         
            +
                                    lfp_data,
         
     | 
| 
      
 479 
     | 
    
         
            +
                                    spike_fs=spike_fs,
         
     | 
| 
      
 480 
     | 
    
         
            +
                                    lfp_fs=lfp_fs,
         
     | 
| 
      
 481 
     | 
    
         
            +
                                    freq_of_interest=freq,
         
     | 
| 
      
 482 
     | 
    
         
            +
                                    bandwidth=bandwidth,
         
     | 
| 
      
 483 
     | 
    
         
            +
                                    lowcut=lowcut,
         
     | 
| 
      
 484 
     | 
    
         
            +
                                    highcut=highcut,
         
     | 
| 
      
 485 
     | 
    
         
            +
                                    filter_method=filter_method,
         
     | 
| 
      
 486 
     | 
    
         
            +
                                    filtered_lfp_phase=filtered_lfp_phases[freq]
         
     | 
| 
      
 487 
     | 
    
         
            +
                                )
         
     | 
| 
      
 488 
     | 
    
         
            +
                            elif entrainment_method == 'ppc':
         
     | 
| 
      
 489 
     | 
    
         
            +
                                entrainment_dict[pop][node][freq] = calculate_ppc(
         
     | 
| 
      
 490 
     | 
    
         
            +
                                    node_spikes['timestamps'].values,
         
     | 
| 
      
 491 
     | 
    
         
            +
                                    lfp_data,
         
     | 
| 
      
 492 
     | 
    
         
            +
                                    spike_fs=spike_fs,
         
     | 
| 
      
 493 
     | 
    
         
            +
                                    lfp_fs=lfp_fs,
         
     | 
| 
      
 494 
     | 
    
         
            +
                                    freq_of_interest=freq,
         
     | 
| 
      
 495 
     | 
    
         
            +
                                    bandwidth=bandwidth,
         
     | 
| 
      
 496 
     | 
    
         
            +
                                    lowcut=lowcut,
         
     | 
| 
      
 497 
     | 
    
         
            +
                                    highcut=highcut,
         
     | 
| 
      
 498 
     | 
    
         
            +
                                    filter_method=filter_method,
         
     | 
| 
      
 499 
     | 
    
         
            +
                                    ppc_method=ppc_method,
         
     | 
| 
      
 500 
     | 
    
         
            +
                                    filtered_lfp_phase=filtered_lfp_phases[freq]
         
     | 
| 
      
 501 
     | 
    
         
            +
                                )
         
     | 
| 
       426 
502 
     | 
    
         | 
| 
       427 
     | 
    
         
            -
                    print(f'Calculated  
     | 
| 
      
 503 
     | 
    
         
            +
                    print(f'Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')
         
     | 
| 
       428 
504 
     | 
    
         | 
| 
       429 
     | 
    
         
            -
                return  
     | 
| 
      
 505 
     | 
    
         
            +
                return entrainment_dict
         
     | 
| 
       430 
506 
     | 
    
         | 
| 
       431 
507 
     | 
    
         | 
| 
       432 
     | 
    
         
            -
            def calculate_spike_rate_power_correlation(spike_rate,  
     | 
| 
      
 508 
     | 
    
         
            +
            def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
         
     | 
| 
      
 509 
     | 
    
         
            +
                                                      bandwidth=2.0, lowcut=None, highcut=None,
         
     | 
| 
      
 510 
     | 
    
         
            +
                                                      freq_range=(10, 100), freq_step=5):
         
     | 
| 
       433 
511 
     | 
    
         
             
                """
         
     | 
| 
       434 
512 
     | 
    
         
             
                Calculate correlation between population spike rates and LFP power across frequencies
         
     | 
| 
       435 
513 
     | 
    
         
             
                using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
         
     | 
| 
         @@ -438,16 +516,24 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_ 
     | 
|
| 
       438 
516 
     | 
    
         
             
                -----------
         
     | 
| 
       439 
517 
     | 
    
         
             
                spike_rate : DataFrame
         
     | 
| 
       440 
518 
     | 
    
         
             
                    Pre-calculated population spike rates at the same fs as lfp
         
     | 
| 
       441 
     | 
    
         
            -
                 
     | 
| 
      
 519 
     | 
    
         
            +
                lfp_data : np.array
         
     | 
| 
       442 
520 
     | 
    
         
             
                    LFP data
         
     | 
| 
       443 
521 
     | 
    
         
             
                fs : float
         
     | 
| 
       444 
522 
     | 
    
         
             
                    Sampling frequency
         
     | 
| 
       445 
523 
     | 
    
         
             
                pop_names : list
         
     | 
| 
       446 
524 
     | 
    
         
             
                    List of population names to analyze
         
     | 
| 
       447 
     | 
    
         
            -
                 
     | 
| 
       448 
     | 
    
         
            -
                     
     | 
| 
       449 
     | 
    
         
            -
                 
     | 
| 
       450 
     | 
    
         
            -
                     
     | 
| 
      
 525 
     | 
    
         
            +
                filter_method : str, optional
         
     | 
| 
      
 526 
     | 
    
         
            +
                    Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
         
     | 
| 
      
 527 
     | 
    
         
            +
                bandwidth : float, optional
         
     | 
| 
      
 528 
     | 
    
         
            +
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
         
     | 
| 
      
 529 
     | 
    
         
            +
                lowcut : float, optional
         
     | 
| 
      
 530 
     | 
    
         
            +
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 531 
     | 
    
         
            +
                highcut : float, optional
         
     | 
| 
      
 532 
     | 
    
         
            +
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         
     | 
| 
      
 533 
     | 
    
         
            +
                freq_range : tuple, optional
         
     | 
| 
      
 534 
     | 
    
         
            +
                    Min and max frequency to analyze (default: (10, 100))
         
     | 
| 
      
 535 
     | 
    
         
            +
                freq_step : float, optional
         
     | 
| 
      
 536 
     | 
    
         
            +
                    Step size for frequency analysis (default: 5)
         
     | 
| 
       451 
537 
     | 
    
         | 
| 
       452 
538 
     | 
    
         
             
                Returns:
         
     | 
| 
       453 
539 
     | 
    
         
             
                --------
         
     | 
| 
         @@ -463,14 +549,11 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_ 
     | 
|
| 
       463 
549 
     | 
    
         
             
                # Dictionary to store results
         
     | 
| 
       464 
550 
     | 
    
         
             
                correlation_results = {pop: {} for pop in pop_names}
         
     | 
| 
       465 
551 
     | 
    
         | 
| 
       466 
     | 
    
         
            -
                # Calculate power at each frequency band using  
     | 
| 
      
 552 
     | 
    
         
            +
                # Calculate power at each frequency band using specified filter
         
     | 
| 
       467 
553 
     | 
    
         
             
                power_by_freq = {}
         
     | 
| 
       468 
554 
     | 
    
         
             
                for freq in frequencies:
         
     | 
| 
       469 
     | 
    
         
            -
                     
     | 
| 
       470 
     | 
    
         
            -
             
     | 
| 
       471 
     | 
    
         
            -
                    # Calculate power (magnitude squared of complex wavelet transform)
         
     | 
| 
       472 
     | 
    
         
            -
                    power = np.abs(filtered_signal)**2
         
     | 
| 
       473 
     | 
    
         
            -
                    power_by_freq[freq] = power
         
     | 
| 
      
 555 
     | 
    
         
            +
                    power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method, 
         
     | 
| 
      
 556 
     | 
    
         
            +
                                                       lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
         
     | 
| 
       474 
557 
     | 
    
         | 
| 
       475 
558 
     | 
    
         
             
                # Calculate correlation for each population
         
     | 
| 
       476 
559 
     | 
    
         
             
                for pop in pop_names:
         
     | 
| 
         @@ -481,7 +564,7 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_ 
     | 
|
| 
       481 
564 
     | 
    
         
             
                    for freq in frequencies:
         
     | 
| 
       482 
565 
     | 
    
         
             
                        # Make sure the lengths match
         
     | 
| 
       483 
566 
     | 
    
         
             
                        if len(pop_rate) != len(power_by_freq[freq]):
         
     | 
| 
       484 
     | 
    
         
            -
                            raise  
     | 
| 
      
 567 
     | 
    
         
            +
                            raise ValueError(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
         
     | 
| 
       485 
568 
     | 
    
         
             
                        # use spearman for non-parametric correlation
         
     | 
| 
       486 
569 
     | 
    
         
             
                        corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
         
     | 
| 
       487 
570 
     | 
    
         
             
                        correlation_results[pop][freq] = {'correlation': corr, 'p_value': p_val}
         
     | 
    
        bmtool/analysis/lfp.py
    CHANGED
    
    | 
         @@ -273,10 +273,83 @@ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float: 
     | 
|
| 
       273 
273 
     | 
    
         
             
                return normalized_power
         
     | 
| 
       274 
274 
     | 
    
         | 
| 
       275 
275 
     | 
    
         | 
| 
       276 
     | 
    
         
            -
            def  
     | 
| 
      
 276 
     | 
    
         
            +
            def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
         
     | 
| 
      
 277 
     | 
    
         
            +
                """
         
     | 
| 
      
 278 
     | 
    
         
            +
                Calculate the passband of a complex Morlet wavelet filter.
         
     | 
| 
      
 279 
     | 
    
         
            +
                
         
     | 
| 
      
 280 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 281 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 282 
     | 
    
         
            +
                center_freq : float
         
     | 
| 
      
 283 
     | 
    
         
            +
                    Center frequency (Hz) of the wavelet filter
         
     | 
| 
      
 284 
     | 
    
         
            +
                bandwidth : float
         
     | 
| 
      
 285 
     | 
    
         
            +
                    Bandwidth parameter of the wavelet filter
         
     | 
| 
      
 286 
     | 
    
         
            +
                threshold : float, optional
         
     | 
| 
      
 287 
     | 
    
         
            +
                    Power threshold to define the passband edges (default: 0.5 = -3dB point)
         
     | 
| 
      
 288 
     | 
    
         
            +
                    
         
     | 
| 
      
 289 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 290 
     | 
    
         
            +
                -------
         
     | 
| 
      
 291 
     | 
    
         
            +
                tuple
         
     | 
| 
      
 292 
     | 
    
         
            +
                    (lower_bound, upper_bound, passband_width) of the frequency passband in Hz
         
     | 
| 
      
 293 
     | 
    
         
            +
                """
         
     | 
| 
      
 294 
     | 
    
         
            +
                # Create a high-resolution frequency axis around the center frequency
         
     | 
| 
      
 295 
     | 
    
         
            +
                # Extend range to 3x the expected width to ensure we capture the full passband
         
     | 
| 
      
 296 
     | 
    
         
            +
                expected_width = center_freq * bandwidth / 2
         
     | 
| 
      
 297 
     | 
    
         
            +
                freq_min = max(0.1, center_freq - 3 * expected_width)
         
     | 
| 
      
 298 
     | 
    
         
            +
                freq_max = center_freq + 3 * expected_width
         
     | 
| 
      
 299 
     | 
    
         
            +
                freq_axis = np.linspace(freq_min, freq_max, 1000)
         
     | 
| 
      
 300 
     | 
    
         
            +
                
         
     | 
| 
      
 301 
     | 
    
         
            +
                # Calculate the theoretical frequency response of the Morlet wavelet
         
     | 
| 
      
 302 
     | 
    
         
            +
                # For a complex Morlet wavelet, the frequency response approximates a Gaussian
         
     | 
| 
      
 303 
     | 
    
         
            +
                # centered at the center frequency with width related to the bandwidth parameter
         
     | 
| 
      
 304 
     | 
    
         
            +
                sigma_f = bandwidth * center_freq / 8  # Approximate relationship for cmor wavelet
         
     | 
| 
      
 305 
     | 
    
         
            +
                response = np.exp(-((freq_axis - center_freq)**2) / (2 * sigma_f**2))
         
     | 
| 
      
 306 
     | 
    
         
            +
                
         
     | 
| 
      
 307 
     | 
    
         
            +
                # Find the passband edges (where response crosses the threshold)
         
     | 
| 
      
 308 
     | 
    
         
            +
                above_threshold = response >= threshold
         
     | 
| 
      
 309 
     | 
    
         
            +
                if not np.any(above_threshold):
         
     | 
| 
      
 310 
     | 
    
         
            +
                    return (center_freq, center_freq, 0)  # No passband found
         
     | 
| 
      
 311 
     | 
    
         
            +
                
         
     | 
| 
      
 312 
     | 
    
         
            +
                # Find the first and last indices where response is above threshold
         
     | 
| 
      
 313 
     | 
    
         
            +
                indices = np.where(above_threshold)[0]
         
     | 
| 
      
 314 
     | 
    
         
            +
                lower_idx = indices[0]
         
     | 
| 
      
 315 
     | 
    
         
            +
                upper_idx = indices[-1]
         
     | 
| 
      
 316 
     | 
    
         
            +
                
         
     | 
| 
      
 317 
     | 
    
         
            +
                # Get the corresponding frequencies
         
     | 
| 
      
 318 
     | 
    
         
            +
                lower_bound = freq_axis[lower_idx]
         
     | 
| 
      
 319 
     | 
    
         
            +
                upper_bound = freq_axis[upper_idx]
         
     | 
| 
      
 320 
     | 
    
         
            +
                passband_width = upper_bound - lower_bound
         
     | 
| 
      
 321 
     | 
    
         
            +
                
         
     | 
| 
      
 322 
     | 
    
         
            +
                return (lower_bound, upper_bound, passband_width)
         
     | 
| 
      
 323 
     | 
    
         
            +
             
     | 
| 
      
 324 
     | 
    
         
            +
             
     | 
| 
      
 325 
     | 
    
         
            +
            def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1,show_passband: bool = False) -> np.ndarray:
         
     | 
| 
       277 
326 
     | 
    
         
             
                """
         
     | 
| 
       278 
327 
     | 
    
         
             
                Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.
         
     | 
| 
      
 328 
     | 
    
         
            +
             
     | 
| 
      
 329 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 330 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 331 
     | 
    
         
            +
                x : np.ndarray
         
     | 
| 
      
 332 
     | 
    
         
            +
                    Input signal
         
     | 
| 
      
 333 
     | 
    
         
            +
                freq : float
         
     | 
| 
      
 334 
     | 
    
         
            +
                    Target frequency for the wavelet filter
         
     | 
| 
      
 335 
     | 
    
         
            +
                fs : float
         
     | 
| 
      
 336 
     | 
    
         
            +
                    Sampling frequency of the signal
         
     | 
| 
      
 337 
     | 
    
         
            +
                bandwidth : float, optional
         
     | 
| 
      
 338 
     | 
    
         
            +
                    Bandwidth parameter of the wavelet filter (default is 1.0)
         
     | 
| 
      
 339 
     | 
    
         
            +
                axis : int, optional
         
     | 
| 
      
 340 
     | 
    
         
            +
                    Axis along which to compute the CWT (default is -1)
         
     | 
| 
      
 341 
     | 
    
         
            +
                show_passband : bool, optional
         
     | 
| 
      
 342 
     | 
    
         
            +
                    If True, print the passband of the wavelet filter (default is False)
         
     | 
| 
      
 343 
     | 
    
         
            +
                
         
     | 
| 
      
 344 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 345 
     | 
    
         
            +
                -------
         
     | 
| 
      
 346 
     | 
    
         
            +
                np.ndarray
         
     | 
| 
      
 347 
     | 
    
         
            +
                    Continuous Wavelet Transform of the input signal
         
     | 
| 
       279 
348 
     | 
    
         
             
                """
         
     | 
| 
      
 349 
     | 
    
         
            +
                if show_passband:
         
     | 
| 
      
 350 
     | 
    
         
            +
                    lower_bound, upper_bound, passband_width = calculate_wavelet_passband(freq, bandwidth, threshold=0.3) # kinda made up threshold gives the rough idea
         
     | 
| 
      
 351 
     | 
    
         
            +
                    print(f"Wavelet filter at {freq:.1f} Hz Bandwidth: {bandwidth:.1f} Hz:")
         
     | 
| 
      
 352 
     | 
    
         
            +
                    print(f"  Passband: {lower_bound:.1f} - {upper_bound:.1f} Hz (width: {passband_width:.1f} Hz)")
         
     | 
| 
       280 
353 
     | 
    
         
             
                wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
         
     | 
| 
       281 
354 
     | 
    
         
             
                scale = pywt.scale2frequency(wavelet, 1) * fs / freq
         
     | 
| 
       282 
355 
     | 
    
         
             
                x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
         
     | 
| 
         @@ -292,6 +365,108 @@ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs: 
     | 
|
| 
       292 
365 
     | 
    
         
             
                return x_a
         
     | 
| 
       293 
366 
     | 
    
         | 
| 
       294 
367 
     | 
    
         | 
| 
      
 368 
     | 
    
         
            +
            def get_lfp_power(lfp_data: np.ndarray, freq: float, fs: float, filter_method: str = 'wavelet',
                              lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
                """
                Compute the power of the raw LFP signal in a specified frequency band.

                Parameters
                ----------
                lfp_data : np.ndarray
                    Raw local field potential (LFP) time series data
                freq : float
                    Center frequency (Hz) for the wavelet filtering method; required if filter_method='wavelet'
                fs : float
                    Sampling frequency (Hz) of the input data
                filter_method : str, optional
                    Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
                lowcut : float, optional
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
                highcut : float, optional
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
                bandwidth : float, optional
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 1.0)

                Returns
                -------
                np.ndarray
                    Power of the filtered signal (magnitude squared)

                Raises
                ------
                ValueError
                    If `freq` is None with the 'wavelet' method, if `lowcut` or `highcut` is None
                    with the 'butter' method, or if `filter_method` is not 'wavelet' or 'butter'.

                Notes
                -----
                - The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency;
                  the result is the squared magnitude of the complex wavelet coefficients.
                - The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies.
                  The filtered signal is real-valued, so the result is simply the squared filtered signal
                  (no Hilbert envelope is taken here, unlike the Butterworth path of `get_lfp_phase`).
                - When using the 'butter' method, both lowcut and highcut must be provided
                """
                if filter_method == 'wavelet':
                    # Validate up front (mirrors get_lfp_phase) instead of failing deep inside the wavelet code.
                    if freq is None:
                        raise ValueError("freq must be provided for the wavelet method.")
                    filtered_signal = wavelet_filter(lfp_data, freq, fs, bandwidth)
                elif filter_method == 'butter':
                    if lowcut is None or highcut is None:
                        raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
                    filtered_signal = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
                else:
                    # Report the offending value, consistent with get_lfp_phase.
                    raise ValueError(f"Invalid method {filter_method}. Choose 'wavelet' or 'butter'.")

                # np.abs works for both the complex wavelet output and the real Butterworth output.
                return np.abs(filtered_signal) ** 2
         
     | 
| 
      
 413 
     | 
    
         
            +
             
     | 
| 
      
 414 
     | 
    
         
            +
             
     | 
| 
      
 415 
     | 
    
         
            +
            def get_lfp_phase(lfp_data: np.ndarray, freq_of_interest: float, fs: float, filter_method: str = 'wavelet',
                              lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
                """
                Calculate the instantaneous phase of the band-filtered LFP signal.

                Parameters
                ----------
                lfp_data : np.ndarray
                    Input LFP data
                freq_of_interest : float
                    Frequency of interest (Hz), used as the wavelet center frequency; required if filter_method='wavelet'
                fs : float
                    Sampling frequency (Hz)
                filter_method : str, optional
                    Method for filtering the signal, either 'wavelet' or 'butter' (default: 'wavelet')
                lowcut : float, optional
                    Low cutoff frequency (Hz) for Butterworth filter, required if filter_method='butter'
                highcut : float, optional
                    High cutoff frequency (Hz) for Butterworth filter, required if filter_method='butter'
                bandwidth : float, optional
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 1.0)

                Returns
                -------
                np.ndarray
                    Phase of the filtered signal in radians, in the range [-pi, pi] (as returned by np.angle)

                Raises
                ------
                ValueError
                    If `freq_of_interest` is None with the 'wavelet' method, if `lowcut` or `highcut`
                    is None with the 'butter' method, or if `filter_method` is not 'wavelet' or 'butter'.

                Notes
                -----
                - The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
                - The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
                  followed by Hilbert transform to extract the phase
                - When using the 'butter' method, both lowcut and highcut must be provided
                """
                if filter_method == 'wavelet':
                    if freq_of_interest is None:
                        raise ValueError("freq_of_interest must be provided for the wavelet method.")
                    # Wavelet filter returns complex values directly, so its angle is the phase.
                    analytic_signal = wavelet_filter(lfp_data, freq_of_interest, fs, bandwidth)
                elif filter_method == 'butter':
                    if lowcut is None or highcut is None:
                        raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
                    # Butterworth output is real; the Hilbert transform builds the complex analytic signal.
                    filtered_signal = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
                    analytic_signal = signal.hilbert(filtered_signal)
                else:
                    raise ValueError(f"Invalid method {filter_method}. Choose 'wavelet' or 'butter'.")

                # Phase is the angle of the complex (analytic) signal.
                return np.angle(analytic_signal)
         
     | 
| 
      
 469 
     | 
    
         
            +
             
     | 
| 
       295 
470 
     | 
    
         
             
            # windowing functions 
         
     | 
| 
       296 
471 
     | 
    
         
             
            def windowed_xarray(da, windows, dim='time',
         
     | 
| 
       297 
472 
     | 
    
         
             
                                new_coord_name='cycle', new_coord=None):
         
     | 
    
        bmtool/analysis/spikes.py
    CHANGED
    
    | 
         @@ -11,22 +11,33 @@ from scipy.stats import mannwhitneyu 
     | 
|
| 
       11 
11 
     | 
    
         
             
            import os
         
     | 
| 
       12 
12 
     | 
    
         | 
| 
       13 
13 
     | 
    
         | 
| 
       14 
     | 
    
         
            -
            def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: str = 'pop_name') -> pd.DataFrame:
         
     | 
| 
      
 14 
     | 
    
         
            +
            def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: Union[str, List[str]] = 'pop_name') -> pd.DataFrame:
         
     | 
| 
       15 
15 
     | 
    
         
             
                """
         
     | 
| 
       16 
16 
     | 
    
         
             
                Load spike data from an HDF5 file into a pandas DataFrame.
         
     | 
| 
       17 
17 
     | 
    
         | 
| 
       18 
     | 
    
         
            -
                 
     | 
| 
       19 
     | 
    
         
            -
             
     | 
| 
       20 
     | 
    
         
            -
             
     | 
| 
       21 
     | 
    
         
            -
                     
     | 
| 
       22 
     | 
    
         
            -
             
     | 
| 
       23 
     | 
    
         
            -
                     
     | 
| 
       24 
     | 
    
         
            -
             
     | 
| 
       25 
     | 
    
         
            -
             
     | 
| 
       26 
     | 
    
         
            -
             
     | 
| 
      
 18 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 19 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 20 
     | 
    
         
            +
                spike_file : str
         
     | 
| 
      
 21 
     | 
    
         
            +
                    Path to the HDF5 file containing spike data
         
     | 
| 
      
 22 
     | 
    
         
            +
                network_name : str
         
     | 
| 
      
 23 
     | 
    
         
            +
                    The name of the network within the HDF5 file from which to load spike data
         
     | 
| 
      
 24 
     | 
    
         
            +
                sort : bool, optional
         
     | 
| 
      
 25 
     | 
    
         
            +
                    Whether to sort the DataFrame by 'timestamps' (default: True)
         
     | 
| 
      
 26 
     | 
    
         
            +
                config : str, optional
         
     | 
| 
      
 27 
     | 
    
         
            +
                    Path to configuration file to label the cell type of each spike (default: None)
         
     | 
| 
      
 28 
     | 
    
         
            +
                groupby : Union[str, List[str]], optional
         
     | 
| 
      
 29 
     | 
    
         
            +
                    The column(s) to group by (default: 'pop_name')
         
     | 
| 
      
 30 
     | 
    
         
            +
             
     | 
| 
      
 31 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 32 
     | 
    
         
            +
                -------
         
     | 
| 
      
 33 
     | 
    
         
            +
                pd.DataFrame
         
     | 
| 
      
 34 
     | 
    
         
            +
                    A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data,
         
     | 
| 
      
 35 
     | 
    
         
            +
                    with additional columns if a config file is provided
         
     | 
| 
       27 
36 
     | 
    
         | 
| 
       28 
     | 
    
         
            -
                 
     | 
| 
       29 
     | 
    
         
            -
             
     | 
| 
      
 37 
     | 
    
         
            +
                Examples
         
     | 
| 
      
 38 
     | 
    
         
            +
                --------
         
     | 
| 
      
 39 
     | 
    
         
            +
                >>> df = load_spikes_to_df("spikes.h5", "cortex")
         
     | 
| 
      
 40 
     | 
    
         
            +
                >>> df = load_spikes_to_df("spikes.h5", "cortex", config="config.json", groupby=["pop_name", "model_type"])
         
     | 
| 
       30 
41 
     | 
    
         
             
                """
         
     | 
| 
       31 
42 
     | 
    
         
             
                with h5py.File(spike_file) as f:
         
     | 
| 
       32 
43 
     | 
    
         
             
                    spikes_df = pd.DataFrame({
         
     | 
| 
         @@ -126,23 +137,31 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] = 
     | 
|
| 
       126 
137 
     | 
    
         | 
| 
       127 
138 
     | 
    
         | 
| 
       128 
139 
     | 
    
         
             
            def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None, 
         
     | 
| 
       129 
     | 
    
         
            -
                               time_points: Optional[Union[np.ndarray, list]] = None,  
     | 
| 
      
 140 
     | 
    
         
            +
                               time_points: Optional[Union[np.ndarray, list]] = None, frequency: bool = False) -> np.ndarray:
         
     | 
| 
       130 
141 
     | 
    
         
             
                """
         
     | 
| 
       131 
142 
     | 
    
         
             
                Calculate the spike count or frequency histogram over specified time intervals.
         
     | 
| 
       132 
143 
     | 
    
         | 
| 
       133 
     | 
    
         
            -
                 
     | 
| 
       134 
     | 
    
         
            -
             
     | 
| 
       135 
     | 
    
         
            -
             
     | 
| 
       136 
     | 
    
         
            -
             
     | 
| 
       137 
     | 
    
         
            -
             
     | 
| 
       138 
     | 
    
         
            -
             
     | 
| 
       139 
     | 
    
         
            -
                     
     | 
| 
       140 
     | 
    
         
            -
             
     | 
| 
       141 
     | 
    
         
            -
             
     | 
| 
       142 
     | 
    
         
            -
             
     | 
| 
       143 
     | 
    
         
            -
             
     | 
| 
       144 
     | 
    
         
            -
             
     | 
| 
       145 
     | 
    
         
            -
             
     | 
| 
      
 144 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 145 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 146 
     | 
    
         
            +
                spike_times : Union[np.ndarray, list]
         
     | 
| 
      
 147 
     | 
    
         
            +
                    Array or list of spike times in milliseconds
         
     | 
| 
      
 148 
     | 
    
         
            +
                time : Optional[Tuple[float, float, float]], optional
         
     | 
| 
      
 149 
     | 
    
         
            +
                    Tuple specifying (start, stop, step) in milliseconds. Used to create evenly spaced time points 
         
     | 
| 
      
 150 
     | 
    
         
            +
                    if `time_points` is not provided. Default is None.
         
     | 
| 
      
 151 
     | 
    
         
            +
                time_points : Optional[Union[np.ndarray, list]], optional
         
     | 
| 
      
 152 
     | 
    
         
            +
                    Array or list of specific time points for binning. If provided, `time` is ignored. Default is None.
         
     | 
| 
      
 153 
     | 
    
         
            +
                frequency : bool, optional
         
     | 
| 
      
 154 
     | 
    
         
            +
                    If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
         
     | 
| 
      
 155 
     | 
    
         
            +
             
     | 
| 
      
 156 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 157 
     | 
    
         
            +
                -------
         
     | 
| 
      
 158 
     | 
    
         
            +
                np.ndarray
         
     | 
| 
      
 159 
     | 
    
         
            +
                    Array of spike counts or frequencies, depending on the `frequency` flag.
         
     | 
| 
      
 160 
     | 
    
         
            +
             
     | 
| 
      
 161 
     | 
    
         
            +
                Raises
         
     | 
| 
      
 162 
     | 
    
         
            +
                ------
         
     | 
| 
      
 163 
     | 
    
         
            +
                ValueError
         
     | 
| 
      
 164 
     | 
    
         
            +
                    If both `time` and `time_points` are None.
         
     | 
| 
       146 
165 
     | 
    
         
             
                """
         
     | 
| 
       147 
166 
     | 
    
         
             
                if time_points is None:
         
     | 
| 
       148 
167 
     | 
    
         
             
                    if time is None:
         
     | 
| 
         @@ -156,43 +175,57 @@ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[f 
     | 
|
| 
       156 
175 
     | 
    
         
             
                bins = np.append(time_points, time_points[-1] + dt)
         
     | 
| 
       157 
176 
     | 
    
         
             
                spike_rate, _ = np.histogram(np.asarray(spike_times), bins)
         
     | 
| 
       158 
177 
     | 
    
         | 
| 
       159 
     | 
    
         
            -
                if  
     | 
| 
      
 178 
     | 
    
         
            +
                if frequency:
         
     | 
| 
       160 
179 
     | 
    
         
             
                    spike_rate = 1000 / dt * spike_rate
         
     | 
| 
       161 
180 
     | 
    
         | 
| 
       162 
181 
     | 
    
         
             
                return spike_rate
         
     | 
| 
       163 
182 
     | 
    
         | 
| 
       164 
183 
     | 
    
         | 
| 
       165 
     | 
    
         
            -
            def get_population_spike_rate( 
     | 
| 
      
 184 
     | 
    
         
            +
            def get_population_spike_rate(spike_data: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None, 
         
     | 
| 
       166 
185 
     | 
    
         
             
                                          config: Optional[str] = None, network_name: Optional[str] = None,
         
     | 
| 
       167 
186 
     | 
    
         
             
                                          save: bool = False, save_path: Optional[str] = None,
         
     | 
| 
       168 
187 
     | 
    
         
             
                                          normalize: bool = False) -> Dict[str, np.ndarray]:
         
     | 
| 
       169 
188 
     | 
    
         
             
                """
         
     | 
| 
       170 
189 
     | 
    
         
             
                Calculate the population spike rate for each population in the given spike data, with an option to normalize.
         
     | 
| 
       171 
190 
     | 
    
         | 
| 
       172 
     | 
    
         
            -
                 
     | 
| 
       173 
     | 
    
         
            -
             
     | 
| 
       174 
     | 
    
         
            -
             
     | 
| 
       175 
     | 
    
         
            -
                     
     | 
| 
       176 
     | 
    
         
            -
             
     | 
| 
       177 
     | 
    
         
            -
                     
     | 
| 
       178 
     | 
    
         
            -
             
     | 
| 
       179 
     | 
    
         
            -
                     
     | 
| 
       180 
     | 
    
         
            -
             
     | 
| 
       181 
     | 
    
         
            -
                     
     | 
| 
       182 
     | 
    
         
            -
             
     | 
| 
       183 
     | 
    
         
            -
                     
     | 
| 
       184 
     | 
    
         
            -
             
     | 
| 
       185 
     | 
    
         
            -
                 
     | 
| 
       186 
     | 
    
         
            -
                     
     | 
| 
       187 
     | 
    
         
            -
             
     | 
| 
       188 
     | 
    
         
            -
             
     | 
| 
       189 
     | 
    
         
            -
             
     | 
| 
       190 
     | 
    
         
            -
             
     | 
| 
       191 
     | 
    
         
            -
             
     | 
| 
       192 
     | 
    
         
            -
                 
     | 
| 
       193 
     | 
    
         
            -
                     
     | 
| 
       194 
     | 
    
         
            -
             
     | 
| 
       195 
     | 
    
         
            -
             
     | 
| 
      
 191 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 192 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 193 
     | 
    
         
            +
                spike_data : pd.DataFrame
         
     | 
| 
      
 194 
     | 
    
         
            +
                    A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'
         
     | 
| 
      
 195 
     | 
    
         
            +
                fs : float, optional
         
     | 
| 
      
 196 
     | 
    
         
            +
                    Sampling frequency in Hz, which determines the time bin size for calculating the spike rate (default: 400.0)
         
     | 
| 
      
 197 
     | 
    
         
            +
                t_start : float, optional
         
     | 
| 
      
 198 
     | 
    
         
            +
                    Start time (in milliseconds) for spike rate calculation (default: 0)
         
     | 
| 
      
 199 
     | 
    
         
            +
                t_stop : Optional[float], optional
         
     | 
| 
      
 200 
     | 
    
         
            +
                    Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data
         
     | 
| 
      
 201 
     | 
    
         
            +
                config : Optional[str], optional
         
     | 
| 
      
 202 
     | 
    
         
            +
                    Path to a configuration file containing node information, used to determine the correct number of nodes per population.
         
     | 
| 
      
 203 
     | 
    
         
            +
                    If None, node count is estimated from unique node spikes (default: None)
         
     | 
| 
      
 204 
     | 
    
         
            +
                network_name : Optional[str], optional
         
     | 
| 
      
 205 
     | 
    
         
            +
                    Name of the network used in the configuration file, allowing selection of nodes for that network.
         
     | 
| 
      
 206 
     | 
    
         
            +
                    Required if `config` is provided (default: None)
         
     | 
| 
      
 207 
     | 
    
         
            +
                save : bool, optional
         
     | 
| 
      
 208 
     | 
    
         
            +
                    Whether to save the calculated population spike rate to a file (default: False)
         
     | 
| 
      
 209 
     | 
    
         
            +
                save_path : Optional[str], optional
         
     | 
| 
      
 210 
     | 
    
         
            +
                    Directory path where the file should be saved if `save` is True (default: None)
         
     | 
| 
      
 211 
     | 
    
         
            +
                normalize : bool, optional
         
     | 
| 
      
 212 
     | 
    
         
            +
                    Whether to normalize the spike rates for each population to a range of [0, 1] (default: False)
         
     | 
| 
      
 213 
     | 
    
         
            +
             
     | 
| 
      
 214 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 215 
     | 
    
         
            +
                -------
         
     | 
| 
      
 216 
     | 
    
         
            +
                Dict[str, np.ndarray]
         
     | 
| 
      
 217 
     | 
    
         
            +
                    A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
         
     | 
| 
      
 218 
     | 
    
         
            +
                    If `normalize` is True, each population's spike rate is scaled to [0, 1].
         
     | 
| 
      
 219 
     | 
    
         
            +
             
     | 
| 
      
 220 
     | 
    
         
            +
                Raises
         
     | 
| 
      
 221 
     | 
    
         
            +
                ------
         
     | 
| 
      
 222 
     | 
    
         
            +
                ValueError
         
     | 
| 
      
 223 
     | 
    
         
            +
                    If `save` is True but `save_path` is not provided.
         
     | 
| 
      
 224 
     | 
    
         
            +
             
     | 
| 
      
 225 
     | 
    
         
            +
                Notes
         
     | 
| 
      
 226 
     | 
    
         
            +
                -----
         
     | 
| 
      
 227 
     | 
    
         
            +
                - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
         
     | 
| 
      
 228 
     | 
    
         
            +
                - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
         
     | 
| 
       196 
229 
     | 
    
         
             
                """
         
     | 
| 
       197 
230 
     | 
    
         
             
                pop_spikes = {}
         
     | 
| 
       198 
231 
     | 
    
         
             
                node_number = {}
         
     | 
| 
         @@ -205,8 +238,8 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: 
     | 
|
| 
       205 
238 
     | 
    
         
             
                    if not network_name:
         
     | 
| 
       206 
239 
     | 
    
         
             
                        print("Grabbing first network; specify a network name to ensure correct node population is selected.")
         
     | 
| 
       207 
240 
     | 
    
         | 
| 
       208 
     | 
    
         
            -
                for pop_name in  
     | 
| 
       209 
     | 
    
         
            -
                    ps =  
     | 
| 
      
 241 
     | 
    
         
            +
                for pop_name in spike_data['pop_name'].unique():
         
     | 
| 
      
 242 
     | 
    
         
            +
                    ps = spike_data[spike_data['pop_name'] == pop_name]
         
     | 
| 
       210 
243 
     | 
    
         | 
| 
       211 
244 
     | 
    
         
             
                    if config:
         
     | 
| 
       212 
245 
     | 
    
         
             
                        nodes = load_nodes_from_config(config)
         
     | 
| 
         @@ -220,12 +253,12 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: 
     | 
|
| 
       220 
253 
     | 
    
         
             
                        node_number[pop_name] = ps['node_ids'].nunique()
         
     | 
| 
       221 
254 
     | 
    
         | 
| 
       222 
255 
     | 
    
         
             
                    if t_stop is None:
         
     | 
| 
       223 
     | 
    
         
            -
                        t_stop =  
     | 
| 
      
 256 
     | 
    
         
            +
                        t_stop = spike_data['timestamps'].max()
         
     | 
| 
       224 
257 
     | 
    
         | 
| 
       225 
     | 
    
         
            -
                    filtered_spikes =  
     | 
| 
       226 
     | 
    
         
            -
                        ( 
     | 
| 
       227 
     | 
    
         
            -
                        ( 
     | 
| 
       228 
     | 
    
         
            -
                        ( 
     | 
| 
      
 258 
     | 
    
         
            +
                    filtered_spikes = spike_data[
         
     | 
| 
      
 259 
     | 
    
         
            +
                        (spike_data['pop_name'] == pop_name) & 
         
     | 
| 
      
 260 
     | 
    
         
            +
                        (spike_data['timestamps'] > t_start) & 
         
     | 
| 
      
 261 
     | 
    
         
            +
                        (spike_data['timestamps'] < t_stop)
         
     | 
| 
       229 
262 
     | 
    
         
             
                    ]
         
     | 
| 
       230 
263 
     | 
    
         
             
                    pop_spikes[pop_name] = filtered_spikes
         
     | 
| 
       231 
264 
     | 
    
         | 
| 
         @@ -254,11 +287,30 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: 
     | 
|
| 
       254 
287 
     | 
    
         
             
                return spike_rate
         
     | 
| 
       255 
288 
     | 
    
         | 
| 
       256 
289 
     | 
    
         | 
| 
       257 
     | 
    
         
            -
            def compare_firing_over_times(spike_df,group_by, time_window_1, time_window_2):
         
     | 
| 
      
 290 
     | 
    
         
            +
            def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window_1: List[float], time_window_2: List[float]) -> None:
         
     | 
| 
       258 
291 
     | 
    
         
             
                """
         
     | 
| 
       259 
     | 
    
         
            -
                Compares the firing rates of a population during two different time windows
         
     | 
| 
       260 
     | 
    
         
            -
                 
     | 
| 
       261 
     | 
    
         
            -
                 
     | 
| 
      
 292 
     | 
    
         
            +
                Compares the firing rates of a population during two different time windows and performs
         
     | 
| 
      
 293 
     | 
    
         
            +
                a statistical test to determine if there is a significant difference.
         
     | 
| 
      
 294 
     | 
    
         
            +
                
         
     | 
| 
      
 295 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 296 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 297 
     | 
    
         
            +
                spike_df : pd.DataFrame
         
     | 
| 
      
 298 
     | 
    
         
            +
                    DataFrame containing spike data with columns for timestamps, node_ids, and grouping variable
         
     | 
| 
      
 299 
     | 
    
         
            +
                group_by : str
         
     | 
| 
      
 300 
     | 
    
         
            +
                    Column name to group spikes by (e.g., 'pop_name')
         
     | 
| 
      
 301 
     | 
    
         
            +
                time_window_1 : List[float]
         
     | 
| 
      
 302 
     | 
    
         
            +
                    First time window as [start, stop] in milliseconds
         
     | 
| 
      
 303 
     | 
    
         
            +
                time_window_2 : List[float]
         
     | 
| 
      
 304 
     | 
    
         
            +
                    Second time window as [start, stop] in milliseconds
         
     | 
| 
      
 305 
     | 
    
         
            +
                
         
     | 
| 
      
 306 
     | 
    
         
            +
                Returns
         
     | 
| 
      
 307 
     | 
    
         
            +
                -------
         
     | 
| 
      
 308 
     | 
    
         
            +
                None
         
     | 
| 
      
 309 
     | 
    
         
            +
                    Results are printed to the console
         
     | 
| 
      
 310 
     | 
    
         
            +
                
         
     | 
| 
      
 311 
     | 
    
         
            +
                Notes
         
     | 
| 
      
 312 
     | 
    
         
            +
                -----
         
     | 
| 
      
 313 
     | 
    
         
            +
                Uses Mann-Whitney U test (non-parametric) to compare firing rates between the two windows
         
     | 
| 
       262 
314 
     | 
    
         
             
                """
         
     | 
| 
       263 
315 
     | 
    
         
             
                # Filter spikes for the population of interest
         
     | 
| 
       264 
316 
     | 
    
         
             
                for pop_name in spike_df[group_by].unique():
         
     | 
| 
         @@ -8,10 +8,10 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298 
     | 
|
| 
       8 
8 
     | 
    
         
             
            bmtool/singlecell.py,sha256=imcdxIzvYVkaOLSGDxYp8WGGssGwXXBCRhzhlqVp7hA,44267
         
     | 
| 
       9 
9 
     | 
    
         
             
            bmtool/synapses.py,sha256=Ow2fZavA_3_5BYCjcgPjW0YsyVOetn1wvLxL7hQvbZo,64556
         
     | 
| 
       10 
10 
     | 
    
         
             
            bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       11 
     | 
    
         
            -
            bmtool/analysis/entrainment.py,sha256= 
     | 
| 
       12 
     | 
    
         
            -
            bmtool/analysis/lfp.py,sha256= 
     | 
| 
      
 11 
     | 
    
         
            +
            bmtool/analysis/entrainment.py,sha256=7lFlGMApL_2snwdvIPDLFW1KKPdyiuCnZ5ADa7ujx5o,24439
         
     | 
| 
      
 12 
     | 
    
         
            +
            bmtool/analysis/lfp.py,sha256=6u9cHnac-5Fzpk9ecQew7MmXBAolzKZakRsnPn3-C2U,24109
         
     | 
| 
       13 
13 
     | 
    
         
             
            bmtool/analysis/netcon_reports.py,sha256=7moyoUC45Cl1_6sGqwZ5aKphK_8i4AimroePXcgUnIo,3057
         
     | 
| 
       14 
     | 
    
         
            -
            bmtool/analysis/spikes.py,sha256= 
     | 
| 
      
 14 
     | 
    
         
            +
            bmtool/analysis/spikes.py,sha256=kcJZQsvPVzQgcuiO-El_4OODW57hwNwdok_RsFMITCg,15097
         
     | 
| 
       15 
15 
     | 
    
         
             
            bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       16 
16 
     | 
    
         
             
            bmtool/bmplot/connections.py,sha256=re6QZX_NfQnIaWayGt3EhMINhCeMMSQ6rFR2sJbFeWk,51385
         
     | 
| 
       17 
17 
     | 
    
         
             
            bmtool/bmplot/entrainment.py,sha256=3IBD6tfW7lvkuB6DTan7rAVAeznOOzmHLr1qA2rgtCY,1671
         
     | 
| 
         @@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396 
     | 
|
| 
       26 
26 
     | 
    
         
             
            bmtool/util/util.py,sha256=XR0qZnv_Q47jMBKQpFzCSkCuKe9u8L3YSGJAOpP2zT0,57630
         
     | 
| 
       27 
27 
     | 
    
         
             
            bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       28 
28 
     | 
    
         
             
            bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
         
     | 
| 
       29 
     | 
    
         
            -
            bmtool-0.7.0. 
     | 
| 
       30 
     | 
    
         
            -
            bmtool-0.7.0. 
     | 
| 
       31 
     | 
    
         
            -
            bmtool-0.7.0. 
     | 
| 
       32 
     | 
    
         
            -
            bmtool-0.7.0. 
     | 
| 
       33 
     | 
    
         
            -
            bmtool-0.7.0. 
     | 
| 
       34 
     | 
    
         
            -
            bmtool-0.7.0. 
     | 
| 
      
 29 
     | 
    
         
            +
            bmtool-0.7.0.5.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
         
     | 
| 
      
 30 
     | 
    
         
            +
            bmtool-0.7.0.5.dist-info/METADATA,sha256=5A-VT9HRvmYInIJg4FvfxYSYGL70RzSgOaAaRULRXYs,2768
         
     | 
| 
      
 31 
     | 
    
         
            +
            bmtool-0.7.0.5.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
         
     | 
| 
      
 32 
     | 
    
         
            +
            bmtool-0.7.0.5.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
         
     | 
| 
      
 33 
     | 
    
         
            +
            bmtool-0.7.0.5.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
         
     | 
| 
      
 34 
     | 
    
         
            +
            bmtool-0.7.0.5.dist-info/RECORD,,
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     |