bmtool 0.7.0.3__py3-none-any.whl → 0.7.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +221 -146
- bmtool/analysis/lfp.py +121 -1
- bmtool/analysis/spikes.py +115 -63
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/METADATA +1 -1
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/RECORD +9 -9
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/WHEEL +0 -0
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.0.3.dist-info → bmtool-0.7.0.4.dist-info}/top_level.txt +0 -0
    
        bmtool/analysis/entrainment.py
    CHANGED
    
    | @@ -8,7 +8,7 @@ import numba | |
| 8 8 | 
             
            from numba import cuda
         | 
| 9 9 | 
             
            import pandas as pd
         | 
| 10 10 | 
             
            import xarray as xr
         | 
| 11 | 
            -
            from .lfp import wavelet_filter,butter_bandpass_filter
         | 
| 11 | 
            +
            from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power
         | 
| 12 12 | 
             
            from typing import Dict, List
         | 
| 13 13 | 
             
            from tqdm.notebook import tqdm
         | 
| 14 14 | 
             
            import scipy.stats as stats
         | 
| @@ -16,49 +16,65 @@ import seaborn as sns | |
| 16 16 | 
             
            import matplotlib.pyplot as plt
         | 
| 17 17 |  | 
| 18 18 |  | 
| 19 | 
            -
            def calculate_signal_signal_plv( | 
| 20 | 
            -
                               | 
| 19 | 
            +
            def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None, 
         | 
| 20 | 
            +
                              filter_method: str = 'wavelet', lowcut: float = None, highcut: float = None, 
         | 
| 21 21 | 
             
                              bandwidth: float = 2.0) -> np.ndarray:
         | 
| 22 22 | 
             
                """
         | 
| 23 23 | 
             
                Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
         | 
| 24 24 |  | 
| 25 | 
            -
                Parameters | 
| 26 | 
            -
                 | 
| 27 | 
            -
                 | 
| 28 | 
            -
             | 
| 29 | 
            -
                 | 
| 30 | 
            -
             | 
| 31 | 
            -
                 | 
| 32 | 
            -
             | 
| 33 | 
            -
                 | 
| 34 | 
            -
             | 
| 25 | 
            +
                Parameters
         | 
| 26 | 
            +
                ----------
         | 
| 27 | 
            +
                signal1 : np.ndarray
         | 
| 28 | 
            +
                    First input signal (1D array)
         | 
| 29 | 
            +
                signal2 : np.ndarray
         | 
| 30 | 
            +
                    Second input signal (1D array, same length as signal1)
         | 
| 31 | 
            +
                fs : float
         | 
| 32 | 
            +
                    Sampling frequency in Hz
         | 
| 33 | 
            +
                freq_of_interest : float, optional
         | 
| 34 | 
            +
                    Desired frequency for wavelet PLV calculation, required if filter_method='wavelet'
         | 
| 35 | 
            +
                filter_method : str, optional
         | 
| 36 | 
            +
                    Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
         | 
| 37 | 
            +
                lowcut : float, optional
         | 
| 38 | 
            +
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         | 
| 39 | 
            +
                highcut : float, optional
         | 
| 40 | 
            +
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         | 
| 41 | 
            +
                bandwidth : float, optional
         | 
| 42 | 
            +
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
         | 
| 43 | 
            +
                
         | 
| 44 | 
            +
                Returns
         | 
| 45 | 
            +
                -------
         | 
| 46 | 
            +
                np.ndarray
         | 
| 47 | 
            +
                    Phase Locking Value (1D array)
         | 
| 35 48 | 
             
                """
         | 
| 36 | 
            -
                if len( | 
| 49 | 
            +
                if len(signal1) != len(signal2):
         | 
| 37 50 | 
             
                    raise ValueError("Input signals must have the same length.")
         | 
| 38 51 |  | 
| 39 | 
            -
                if  | 
| 52 | 
            +
                if filter_method == 'wavelet':
         | 
| 40 53 | 
             
                    if freq_of_interest is None:
         | 
| 41 54 | 
             
                        raise ValueError("freq_of_interest must be provided for the wavelet method.")
         | 
| 42 55 |  | 
| 43 56 | 
             
                    # Apply CWT to both signals
         | 
| 44 | 
            -
                    theta1 = wavelet_filter(x= | 
| 45 | 
            -
                    theta2 = wavelet_filter(x= | 
| 57 | 
            +
                    theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
         | 
| 58 | 
            +
                    theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
         | 
| 46 59 |  | 
| 47 | 
            -
                elif  | 
| 60 | 
            +
                elif filter_method == 'butter':
         | 
| 48 61 | 
             
                    if lowcut is None or highcut is None:
         | 
| 49 | 
            -
                        print("Lowcut and | 
| 62 | 
            +
                        print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation")
         | 
| 50 63 |  | 
| 51 64 | 
             
                    if lowcut and highcut:
         | 
| 52 65 | 
             
                        # Bandpass filter and get the analytic signal using the Hilbert transform
         | 
| 53 | 
            -
                         | 
| 54 | 
            -
                         | 
| 55 | 
            -
             | 
| 56 | 
            -
             | 
| 57 | 
            -
             | 
| 58 | 
            -
                     | 
| 66 | 
            +
                        filtered_signal1 = butter_bandpass_filter(data=signal1, lowcut=lowcut, highcut=highcut, fs=fs)
         | 
| 67 | 
            +
                        filtered_signal2 = butter_bandpass_filter(data=signal2, lowcut=lowcut, highcut=highcut, fs=fs)
         | 
| 68 | 
            +
                        # Get phase using the Hilbert transform
         | 
| 69 | 
            +
                        theta1 = signal.hilbert(filtered_signal1)
         | 
| 70 | 
            +
                        theta2 = signal.hilbert(filtered_signal2)
         | 
| 71 | 
            +
                    else:
         | 
| 72 | 
            +
                        # Get phase using the Hilbert transform without filtering
         | 
| 73 | 
            +
                        theta1 = signal.hilbert(signal1)
         | 
| 74 | 
            +
                        theta2 = signal.hilbert(signal2)
         | 
| 59 75 |  | 
| 60 76 | 
             
                else:
         | 
| 61 | 
            -
                    raise ValueError("Invalid method. Choose 'wavelet' or ' | 
| 77 | 
            +
                    raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
         | 
| 62 78 |  | 
| 63 79 | 
             
                # Calculate phase difference
         | 
| 64 80 | 
             
                phase_diff = np.angle(theta1) - np.angle(theta2)
         | 
| @@ -69,29 +85,43 @@ def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_ | |
| 69 85 | 
             
                return plv
         | 
| 70 86 |  | 
| 71 87 |  | 
| 72 | 
            -
            def calculate_spike_lfp_plv(spike_times: np.ndarray = None,  | 
| 73 | 
            -
                               lfp_fs: float = None,  | 
| 88 | 
            +
            def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
         | 
| 89 | 
            +
                               lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
         | 
| 74 90 | 
             
                               lowcut: float = None, highcut: float = None,
         | 
| 75 91 | 
             
                               bandwidth: float = 2.0) -> tuple:
         | 
| 76 92 | 
             
                """
         | 
| 77 93 | 
             
                Calculate spike-lfp phase locking value Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
         | 
| 78 94 |  | 
| 79 | 
            -
                Parameters | 
| 80 | 
            -
                 | 
| 81 | 
            -
                 | 
| 82 | 
            -
             | 
| 83 | 
            -
                 | 
| 84 | 
            -
             | 
| 85 | 
            -
                 | 
| 86 | 
            -
             | 
| 87 | 
            -
                 | 
| 88 | 
            -
             | 
| 89 | 
            -
                 | 
| 90 | 
            -
             | 
| 91 | 
            -
                 | 
| 95 | 
            +
                Parameters
         | 
| 96 | 
            +
                ----------
         | 
| 97 | 
            +
                spike_times : np.ndarray
         | 
| 98 | 
            +
                    Array of spike times
         | 
| 99 | 
            +
                lfp_data : np.ndarray
         | 
| 100 | 
            +
                    Local field potential time series data
         | 
| 101 | 
            +
                spike_fs : float, optional
         | 
| 102 | 
            +
                    Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
         | 
| 103 | 
            +
                lfp_fs : float
         | 
| 104 | 
            +
                    Sampling frequency in Hz of the LFP data
         | 
| 105 | 
            +
                filter_method : str, optional
         | 
| 106 | 
            +
                    Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
         | 
| 107 | 
            +
                freq_of_interest : float, optional
         | 
| 108 | 
            +
                    Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
         | 
| 109 | 
            +
                lowcut : float, optional
         | 
| 110 | 
            +
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         | 
| 111 | 
            +
                highcut : float, optional
         | 
| 112 | 
            +
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         | 
| 113 | 
            +
                bandwidth : float, optional
         | 
| 114 | 
            +
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
         | 
| 115 | 
            +
                
         | 
| 116 | 
            +
                Returns
         | 
| 117 | 
            +
                -------
         | 
| 118 | 
            +
                tuple
         | 
| 119 | 
            +
                    (plv, spike_phases) where:
         | 
| 120 | 
            +
                    - plv: Phase Locking Value
         | 
| 121 | 
            +
                    - spike_phases: Phases at spike times
         | 
| 92 122 | 
             
                """
         | 
| 93 123 |  | 
| 94 | 
            -
                if spike_fs  | 
| 124 | 
            +
                if spike_fs is None:
         | 
| 95 125 | 
             
                    spike_fs = lfp_fs
         | 
| 96 126 | 
             
                # Convert spike times to sample indices
         | 
| 97 127 | 
             
                spike_times_seconds = spike_times / spike_fs
         | 
| @@ -100,33 +130,29 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarr | |
| 100 130 | 
             
                spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
         | 
| 101 131 |  | 
| 102 132 | 
             
                # Filter indices to ensure they're within bounds of the LFP signal
         | 
| 103 | 
            -
                valid_indices = [idx for idx in spike_indices if 0 <= idx < len( | 
| 133 | 
            +
                valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
         | 
| 104 134 | 
             
                if len(valid_indices) <= 1:
         | 
| 105 135 | 
             
                    return 0, np.array([])
         | 
| 106 136 |  | 
| 107 | 
            -
                #  | 
| 108 | 
            -
                if  | 
| 137 | 
            +
                # Filter the LFP signal to extract the phase
         | 
| 138 | 
            +
                if filter_method == 'wavelet':
         | 
| 109 139 | 
             
                    if freq_of_interest is None:
         | 
| 110 140 | 
             
                        raise ValueError("freq_of_interest must be provided for the wavelet method.")
         | 
| 111 141 |  | 
| 112 | 
            -
                    # Apply CWT to extract phase | 
| 113 | 
            -
                     | 
| 114 | 
            -
             | 
| 115 | 
            -
             | 
| 116 | 
            -
                elif method == 'hilbert':
         | 
| 142 | 
            +
                    # Apply CWT to extract phase
         | 
| 143 | 
            +
                    filtered_lfp = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
         | 
| 144 | 
            +
                
         | 
| 145 | 
            +
                elif filter_method == 'butter':
         | 
| 117 146 | 
             
                    if lowcut is None or highcut is None:
         | 
| 118 | 
            -
                         | 
| 119 | 
            -
                        filtered_lfp = lfp_signal
         | 
| 120 | 
            -
                    else:
         | 
| 121 | 
            -
                        # Bandpass filter the signal
         | 
| 122 | 
            -
                        filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
         | 
| 147 | 
            +
                        raise ValueError("Both lowcut and highcut must be specified for the butter method.")
         | 
| 123 148 |  | 
| 124 | 
            -
                    #  | 
| 125 | 
            -
                     | 
| 126 | 
            -
                     | 
| 149 | 
            +
                    # Bandpass filter the LFP signal
         | 
| 150 | 
            +
                    filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
         | 
| 151 | 
            +
                    filtered_lfp = signal.hilbert(filtered_lfp)  # Get analytic signal
         | 
| 152 | 
            +
                
         | 
| 127 153 |  | 
| 128 154 | 
             
                else:
         | 
| 129 | 
            -
                    raise ValueError("Invalid method. Choose 'wavelet' or ' | 
| 155 | 
            +
                    raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
         | 
| 130 156 |  | 
| 131 157 | 
             
                # Get phases at spike times
         | 
| 132 158 | 
             
                spike_phases = instantaneous_phase[valid_indices]
         | 
| @@ -181,27 +207,43 @@ def _ppc_gpu(spike_phases): | |
| 181 207 | 
             
                return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
         | 
| 182 208 |  | 
| 183 209 |  | 
| 184 | 
            -
            def calculate_ppc(spike_times: np.ndarray = None,  | 
| 185 | 
            -
                              lfp_fs: float = None,  | 
| 210 | 
            +
            def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
         | 
| 211 | 
            +
                              lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
         | 
| 186 212 | 
             
                              lowcut: float = None, highcut: float = None,
         | 
| 187 | 
            -
                              bandwidth: float = 2.0,ppc_method: str = 'numpy') -> tuple:
         | 
| 213 | 
            +
                              bandwidth: float = 2.0, ppc_method: str = 'numpy') -> tuple:
         | 
| 188 214 | 
             
                """
         | 
| 189 215 | 
             
                Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
         | 
| 190 216 | 
             
                Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
         | 
| 191 217 |  | 
| 192 | 
            -
                Parameters | 
| 193 | 
            -
                 | 
| 194 | 
            -
                 | 
| 195 | 
            -
             | 
| 196 | 
            -
                 | 
| 197 | 
            -
             | 
| 198 | 
            -
                 | 
| 199 | 
            -
             | 
| 200 | 
            -
                 | 
| 201 | 
            -
             | 
| 202 | 
            -
                
         | 
| 203 | 
            -
             | 
| 204 | 
            -
                 | 
| 218 | 
            +
                Parameters
         | 
| 219 | 
            +
                ----------
         | 
| 220 | 
            +
                spike_times : np.ndarray
         | 
| 221 | 
            +
                    Array of spike times
         | 
| 222 | 
            +
                lfp_data : np.ndarray
         | 
| 223 | 
            +
                    Local field potential time series data
         | 
| 224 | 
            +
                spike_fs : float, optional
         | 
| 225 | 
            +
                    Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
         | 
| 226 | 
            +
                lfp_fs : float
         | 
| 227 | 
            +
                    Sampling frequency in Hz of the LFP data
         | 
| 228 | 
            +
                filter_method : str, optional
         | 
| 229 | 
            +
                    Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
         | 
| 230 | 
            +
                freq_of_interest : float, optional
         | 
| 231 | 
            +
                    Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
         | 
| 232 | 
            +
                lowcut : float, optional
         | 
| 233 | 
            +
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         | 
| 234 | 
            +
                highcut : float, optional
         | 
| 235 | 
            +
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         | 
| 236 | 
            +
                bandwidth : float, optional
         | 
| 237 | 
            +
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
         | 
| 238 | 
            +
                ppc_method : str, optional
         | 
| 239 | 
            +
                    Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
         | 
| 240 | 
            +
                
         | 
| 241 | 
            +
                Returns
         | 
| 242 | 
            +
                -------
         | 
| 243 | 
            +
                tuple
         | 
| 244 | 
            +
                    (ppc, spike_phases) where:
         | 
| 245 | 
            +
                    - ppc: Pairwise Phase Consistency value
         | 
| 246 | 
            +
                    - spike_phases: Phases at spike times
         | 
| 205 247 | 
             
                """
         | 
| 206 248 | 
             
                if spike_fs is None:
         | 
| 207 249 | 
             
                    spike_fs = lfp_fs
         | 
| @@ -212,33 +254,32 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, | |
| 212 254 | 
             
                spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
         | 
| 213 255 |  | 
| 214 256 | 
             
                # Filter indices to ensure they're within bounds of the LFP signal
         | 
| 215 | 
            -
                valid_indices = [idx for idx in spike_indices if 0 <= idx < len( | 
| 257 | 
            +
                valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
         | 
| 216 258 | 
             
                if len(valid_indices) <= 1:
         | 
| 217 259 | 
             
                    return 0, np.array([])
         | 
| 218 260 |  | 
| 219 261 | 
             
                # Extract phase using the specified method
         | 
| 220 | 
            -
                if  | 
| 262 | 
            +
                if filter_method == 'wavelet':
         | 
| 221 263 | 
             
                    if freq_of_interest is None:
         | 
| 222 264 | 
             
                        raise ValueError("freq_of_interest must be provided for the wavelet method.")
         | 
| 223 265 |  | 
| 224 266 | 
             
                    # Apply CWT to extract phase at the frequency of interest
         | 
| 225 | 
            -
                    lfp_complex = wavelet_filter(x= | 
| 267 | 
            +
                    lfp_complex = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
         | 
| 226 268 | 
             
                    instantaneous_phase = np.angle(lfp_complex)
         | 
| 227 269 |  | 
| 228 | 
            -
                elif  | 
| 270 | 
            +
                elif filter_method == 'butter':
         | 
| 229 271 | 
             
                    if lowcut is None or highcut is None:
         | 
| 230 | 
            -
                         | 
| 231 | 
            -
             | 
| 232 | 
            -
                     | 
| 233 | 
            -
             | 
| 234 | 
            -
                        filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
         | 
| 272 | 
            +
                        raise ValueError("Both lowcut and highcut must be specified for the butter method.")
         | 
| 273 | 
            +
                    
         | 
| 274 | 
            +
                    # Bandpass filter the signal
         | 
| 275 | 
            +
                    filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
         | 
| 235 276 |  | 
| 236 277 | 
             
                    # Get phase using the Hilbert transform
         | 
| 237 278 | 
             
                    analytic_signal = signal.hilbert(filtered_lfp)
         | 
| 238 279 | 
             
                    instantaneous_phase = np.angle(analytic_signal)
         | 
| 239 280 |  | 
| 240 281 | 
             
                else:
         | 
| 241 | 
            -
                    raise ValueError("Invalid method. Choose 'wavelet' or ' | 
| 282 | 
            +
                    raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
         | 
| 242 283 |  | 
| 243 284 | 
             
                # Get phases at spike times
         | 
| 244 285 | 
             
                spike_phases = instantaneous_phase[valid_indices]
         | 
| @@ -283,10 +324,10 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, | |
| 283 324 | 
             
                return ppc
         | 
| 284 325 |  | 
| 285 326 |  | 
| 286 | 
            -
            def calculate_ppc2(spike_times: np.ndarray = None,  | 
| 287 | 
            -
                              lfp_fs: float = None,  | 
| 327 | 
            +
            def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
         | 
| 328 | 
            +
                              lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
         | 
| 288 329 | 
             
                              lowcut: float = None, highcut: float = None,
         | 
| 289 | 
            -
                              bandwidth: float = 2.0) ->  | 
| 330 | 
            +
                              bandwidth: float = 2.0) -> float:
         | 
| 290 331 | 
             
                """
         | 
| 291 332 | 
             
                # -----------------------------------------------------------------------------
         | 
| 292 333 | 
             
                # PPC2 Calculation (Vinck et al., 2010) 
         | 
| @@ -297,18 +338,31 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None | |
| 297 338 | 
             
                #   PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
         | 
| 298 339 | 
             
                # -----------------------------------------------------------------------------
         | 
| 299 340 |  | 
| 300 | 
            -
                Parameters | 
| 301 | 
            -
                 | 
| 302 | 
            -
                 | 
| 303 | 
            -
             | 
| 304 | 
            -
                 | 
| 305 | 
            -
             | 
| 306 | 
            -
                 | 
| 307 | 
            -
             | 
| 308 | 
            -
                 | 
| 309 | 
            -
             | 
| 310 | 
            -
                 | 
| 311 | 
            -
             | 
| 341 | 
            +
                Parameters
         | 
| 342 | 
            +
                ----------
         | 
| 343 | 
            +
                spike_times : np.ndarray
         | 
| 344 | 
            +
                    Array of spike times
         | 
| 345 | 
            +
                lfp_data : np.ndarray
         | 
| 346 | 
            +
                    Local field potential time series data
         | 
| 347 | 
            +
                spike_fs : float, optional
         | 
| 348 | 
            +
                    Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
         | 
| 349 | 
            +
                lfp_fs : float
         | 
| 350 | 
            +
                    Sampling frequency in Hz of the LFP data
         | 
| 351 | 
            +
                filter_method : str, optional
         | 
| 352 | 
            +
                    Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
         | 
| 353 | 
            +
                freq_of_interest : float, optional
         | 
| 354 | 
            +
                    Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
         | 
| 355 | 
            +
                lowcut : float, optional
         | 
| 356 | 
            +
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         | 
| 357 | 
            +
                highcut : float, optional
         | 
| 358 | 
            +
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         | 
| 359 | 
            +
                bandwidth : float, optional
         | 
| 360 | 
            +
                    Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
         | 
| 361 | 
            +
                
         | 
| 362 | 
            +
                Returns
         | 
| 363 | 
            +
                -------
         | 
| 364 | 
            +
                float
         | 
| 365 | 
            +
                    Pairwise Phase Consistency 2 (PPC2) value
         | 
| 312 366 | 
             
                """
         | 
| 313 367 |  | 
| 314 368 | 
             
                if spike_fs is None:
         | 
| @@ -320,33 +374,32 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None | |
| 320 374 | 
             
                spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
         | 
| 321 375 |  | 
| 322 376 | 
             
                # Filter indices to ensure they're within bounds of the LFP signal
         | 
| 323 | 
            -
                valid_indices = [idx for idx in spike_indices if 0 <= idx < len( | 
| 377 | 
            +
                valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
         | 
| 324 378 | 
             
                if len(valid_indices) <= 1:
         | 
| 325 | 
            -
                    return 0 | 
| 379 | 
            +
                    return 0
         | 
| 326 380 |  | 
| 327 381 | 
             
                # Extract phase using the specified method
         | 
| 328 | 
            -
                if  | 
| 382 | 
            +
                if filter_method == 'wavelet':
         | 
| 329 383 | 
             
                    if freq_of_interest is None:
         | 
| 330 384 | 
             
                        raise ValueError("freq_of_interest must be provided for the wavelet method.")
         | 
| 331 385 |  | 
| 332 386 | 
             
                    # Apply CWT to extract phase at the frequency of interest
         | 
| 333 | 
            -
                    lfp_complex = wavelet_filter(x= | 
| 387 | 
            +
                    lfp_complex = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
         | 
| 334 388 | 
             
                    instantaneous_phase = np.angle(lfp_complex)
         | 
| 335 389 |  | 
| 336 | 
            -
                elif  | 
| 390 | 
            +
                elif filter_method == 'butter':
         | 
| 337 391 | 
             
                    if lowcut is None or highcut is None:
         | 
| 338 | 
            -
                         | 
| 339 | 
            -
             | 
| 340 | 
            -
                     | 
| 341 | 
            -
             | 
| 342 | 
            -
                        filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
         | 
| 392 | 
            +
                        raise ValueError("Both lowcut and highcut must be specified for the butter method.")
         | 
| 393 | 
            +
                    
         | 
| 394 | 
            +
                    # Bandpass filter the signal
         | 
| 395 | 
            +
                    filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
         | 
| 343 396 |  | 
| 344 397 | 
             
                    # Get phase using the Hilbert transform
         | 
| 345 398 | 
             
                    analytic_signal = signal.hilbert(filtered_lfp)
         | 
| 346 399 | 
             
                    instantaneous_phase = np.angle(analytic_signal)
         | 
| 347 400 |  | 
| 348 401 | 
             
                else:
         | 
| 349 | 
            -
                    raise ValueError("Invalid method. Choose 'wavelet' or ' | 
| 402 | 
            +
                    raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
         | 
| 350 403 |  | 
| 351 404 | 
             
                # Get phases at spike times
         | 
| 352 405 | 
             
                spike_phases = instantaneous_phase[valid_indices]
         | 
| @@ -355,7 +408,7 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None | |
| 355 408 | 
             
                n = len(spike_phases)
         | 
| 356 409 |  | 
| 357 410 | 
             
                if n <= 1:
         | 
| 358 | 
            -
                    return 0 | 
| 411 | 
            +
                    return 0
         | 
| 359 412 |  | 
| 360 413 | 
             
                # Convert phases to unit vectors in the complex plane
         | 
| 361 414 | 
             
                unit_vectors = np.exp(1j * spike_phases)
         | 
| @@ -369,33 +422,43 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None | |
| 369 422 | 
             
                return ppc2
         | 
| 370 423 |  | 
| 371 424 |  | 
| 372 | 
            -
            def calculate_ppc_per_cell(spike_df: pd.DataFrame,  | 
| 373 | 
            -
                                        spike_fs: float, lfp_fs:float,
         | 
| 374 | 
            -
                                        pop_names: List[str],freqs: List[float]) -> Dict[str, Dict[int, Dict[float, float]]]:
         | 
| 425 | 
            +
            def calculate_ppc_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None,
         | 
| 426 | 
            +
                                        spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
         | 
| 427 | 
            +
                                        pop_names: List[str]=None, freqs: List[float]=None) -> Dict[str, Dict[int, Dict[float, float]]]:
         | 
| 375 428 | 
             
                """
         | 
| 376 429 | 
             
                Calculate pairwise phase consistency (PPC) per neuron (cell) for specified frequencies across different populations.
         | 
| 377 430 |  | 
| 378 431 | 
             
                This function computes the PPC for each neuron within the specified populations based on their spike times
         | 
| 379 | 
            -
                and a  | 
| 432 | 
            +
                and the provided LFP signal. It returns a nested dictionary structure containing the PPC values
         | 
| 433 | 
            +
                organized by population, node ID, and frequency.
         | 
| 380 434 |  | 
| 381 | 
            -
                 | 
| 382 | 
            -
             | 
| 383 | 
            -
             | 
| 384 | 
            -
                     | 
| 385 | 
            -
             | 
| 386 | 
            -
                     | 
| 387 | 
            -
             | 
| 435 | 
            +
                Parameters
         | 
| 436 | 
            +
                ----------
         | 
| 437 | 
            +
                spike_df : pd.DataFrame
         | 
| 438 | 
            +
                    DataFrame containing spike data with columns 'pop_name', 'node_ids', and 'timestamps'
         | 
| 439 | 
            +
                lfp_data : np.ndarray
         | 
| 440 | 
            +
                    Local field potential (LFP) time series data
         | 
| 441 | 
            +
                spike_fs : float
         | 
| 442 | 
            +
                    Sampling frequency of the spike times in Hz
         | 
| 443 | 
            +
                lfp_fs : float
         | 
| 444 | 
            +
                    Sampling frequency of the LFP signal in Hz
         | 
| 445 | 
            +
                pop_names : List[str]
         | 
| 446 | 
            +
                    List of population names to analyze
         | 
| 447 | 
            +
                freqs : List[float]
         | 
| 448 | 
            +
                    List of frequencies (in Hz) at which to calculate PPC
         | 
| 388 449 |  | 
| 389 | 
            -
                Returns | 
| 390 | 
            -
             | 
| 391 | 
            -
             | 
| 392 | 
            -
             | 
| 393 | 
            -
             | 
| 394 | 
            -
             | 
| 395 | 
            -
             | 
| 450 | 
            +
                Returns
         | 
| 451 | 
            +
                -------
         | 
| 452 | 
            +
                Dict[str, Dict[int, Dict[float, float]]]
         | 
| 453 | 
            +
                    Nested dictionary where the structure is:
         | 
| 454 | 
            +
                    {
         | 
| 455 | 
            +
                        population_name: {
         | 
| 456 | 
            +
                            node_id: {
         | 
| 457 | 
            +
                                frequency: PPC value
         | 
| 396 458 | 
             
                            }
         | 
| 397 459 | 
             
                        }
         | 
| 398 | 
            -
             | 
| 460 | 
            +
                    }
         | 
| 461 | 
            +
                    PPC values are floats representing the pairwise phase consistency at each frequency
         | 
| 399 462 | 
             
                """
         | 
| 400 463 | 
             
                ppc_dict = {}
         | 
| 401 464 | 
             
                for pop in pop_names:
         | 
| @@ -416,11 +479,12 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray, | |
| 416 479 | 
             
                        for freq in freqs:
         | 
| 417 480 | 
             
                            ppc = calculate_ppc2(
         | 
| 418 481 | 
             
                                node_spikes['timestamps'].values,
         | 
| 419 | 
            -
                                 | 
| 482 | 
            +
                                lfp_data,
         | 
| 420 483 | 
             
                                spike_fs=spike_fs,
         | 
| 421 484 | 
             
                                lfp_fs=lfp_fs,
         | 
| 422 485 | 
             
                                freq_of_interest=freq,
         | 
| 423 | 
            -
                                 | 
| 486 | 
            +
                                bandwidth=bandwidth,
         | 
| 487 | 
            +
                                filter_method='wavelet'
         | 
| 424 488 | 
             
                            )
         | 
| 425 489 | 
             
                            ppc_dict[pop][node][freq] = ppc
         | 
| 426 490 |  | 
| @@ -429,7 +493,9 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray, | |
| 429 493 | 
             
                return ppc_dict
         | 
| 430 494 |  | 
| 431 495 |  | 
| 432 | 
            -
            def calculate_spike_rate_power_correlation(spike_rate,  | 
| 496 | 
            +
            def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
         | 
| 497 | 
            +
                                                      bandwidth=2.0, lowcut=None, highcut=None,
         | 
| 498 | 
            +
                                                      freq_range=(10, 100), freq_step=5):
         | 
| 433 499 | 
             
                """
         | 
| 434 500 | 
             
                Calculate correlation between population spike rates and LFP power across frequencies
         | 
| 435 501 | 
             
                using the selected filter ('wavelet' or 'butter'). This function assumes the fs of the spike_rate and lfp_data are the same.
         | 
| @@ -438,16 +504,24 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_ | |
| 438 504 | 
             
                -----------
         | 
| 439 505 | 
             
                spike_rate : DataFrame
         | 
| 440 506 | 
             
                    Pre-calculated population spike rates at the same fs as lfp
         | 
| 441 | 
            -
                 | 
| 507 | 
            +
                lfp_data : np.array
         | 
| 442 508 | 
             
                    LFP data
         | 
| 443 509 | 
             
                fs : float
         | 
| 444 510 | 
             
                    Sampling frequency
         | 
| 445 511 | 
             
                pop_names : list
         | 
| 446 512 | 
             
                    List of population names to analyze
         | 
| 447 | 
            -
                 | 
| 448 | 
            -
                     | 
| 449 | 
            -
                 | 
| 450 | 
            -
                     | 
| 513 | 
            +
                filter_method : str, optional
         | 
| 514 | 
            +
                    Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
         | 
| 515 | 
            +
                bandwidth : float, optional
         | 
| 516 | 
            +
                    Bandwidth parameter for wavelet filter when filter_method='wavelet' (default: 2.0)
         | 
| 517 | 
            +
                lowcut : float, optional
         | 
| 518 | 
            +
                    Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         | 
| 519 | 
            +
                highcut : float, optional
         | 
| 520 | 
            +
                    Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
         | 
| 521 | 
            +
                freq_range : tuple, optional
         | 
| 522 | 
            +
                    Min and max frequency to analyze (default: (10, 100))
         | 
| 523 | 
            +
                freq_step : float, optional
         | 
| 524 | 
            +
                    Step size for frequency analysis (default: 5)
         | 
| 451 525 |  | 
| 452 526 | 
             
                Returns:
         | 
| 453 527 | 
             
                --------
         | 
| @@ -463,14 +537,15 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_ | |
| 463 537 | 
             
                # Dictionary to store results
         | 
| 464 538 | 
             
                correlation_results = {pop: {} for pop in pop_names}
         | 
| 465 539 |  | 
| 466 | 
            -
                # Calculate power at each frequency band using  | 
| 540 | 
            +
                # Calculate power at each frequency band using specified filter
         | 
| 467 541 | 
             
                power_by_freq = {}
         | 
| 468 542 | 
             
                for freq in frequencies:
         | 
| 469 | 
            -
                     | 
| 470 | 
            -
             | 
| 471 | 
            -
             | 
| 472 | 
            -
                     | 
| 473 | 
            -
             | 
| 543 | 
            +
                    if filter_method == 'wavelet':
         | 
| 544 | 
            +
                        power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method, 
         | 
| 545 | 
            +
                                                           lowcut=None, highcut=None, bandwidth=bandwidth)
         | 
| 546 | 
            +
                    elif filter_method == 'butter':
         | 
| 547 | 
            +
                        power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method, 
         | 
| 548 | 
            +
                                                           lowcut=lowcut, highcut=highcut)
         | 
| 474 549 |  | 
| 475 550 | 
             
                # Calculate correlation for each population
         | 
| 476 551 | 
             
                for pop in pop_names:
         | 
| @@ -481,7 +556,7 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_ | |
| 481 556 | 
             
                    for freq in frequencies:
         | 
| 482 557 | 
             
                        # Make sure the lengths match
         | 
| 483 558 | 
             
                        if len(pop_rate) != len(power_by_freq[freq]):
         | 
| 484 | 
            -
                            raise  | 
| 559 | 
            +
                            raise ValueError(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
         | 
| 485 560 | 
             
                        # use spearman for non-parametric correlation
         | 
| 486 561 | 
             
                        corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
         | 
| 487 562 | 
             
                        correlation_results[pop][freq] = {'correlation': corr, 'p_value': p_val}
         | 
    
        bmtool/analysis/lfp.py
    CHANGED
    
    | @@ -273,10 +273,83 @@ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float: | |
| 273 273 | 
             
                return normalized_power
         | 
| 274 274 |  | 
| 275 275 |  | 
| 276 | 
            -
            def  | 
| 276 | 
            +
            def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
                """
                Estimate the passband of a complex Morlet wavelet filter.

                The filter's frequency response is approximated by a Gaussian centered at
                ``center_freq`` with standard deviation ``bandwidth * center_freq / 8``.
                The Gaussian is sampled on a 1000-point frequency grid and the passband is
                the span between the first and last grid points whose response is at or
                above ``threshold``.

                Parameters
                ----------
                center_freq : float
                    Center frequency (Hz) of the wavelet filter
                bandwidth : float
                    Bandwidth parameter of the wavelet filter
                threshold : float, optional
                    Response level that defines the passband edges (default: 0.3).
                    A value of 0.5 corresponds to the half-response point.

                Returns
                -------
                tuple
                    (lower_bound, upper_bound, passband_width) of the frequency passband in Hz.
                    The bounds are grid values, so they are accurate to the grid spacing.
                    ``(center_freq, center_freq, 0)`` is returned when the Gaussian has zero
                    width or when no grid point reaches ``threshold``.
                """
                # Approximate relationship for the cmor wavelet: the response is modelled
                # as a Gaussian with this standard deviation.
                sigma_f = bandwidth * center_freq / 8
                if sigma_f == 0:
                    # A zero-width Gaussian has no passband. Returning here avoids the
                    # division by zero (NaN response plus a RuntimeWarning) that would
                    # otherwise be evaluated before reaching the same "no passband" result.
                    return (center_freq, center_freq, 0)

                # High-resolution frequency axis around the center frequency, extended to
                # 3x the expected width so the full passband falls inside the grid.
                expected_width = center_freq * bandwidth / 2
                freq_min = max(0.1, center_freq - 3 * expected_width)
                freq_max = center_freq + 3 * expected_width
                freq_axis = np.linspace(freq_min, freq_max, 1000)

                # Theoretical (Gaussian) frequency response of the wavelet
                response = np.exp(-((freq_axis - center_freq) ** 2) / (2 * sigma_f ** 2))

                # Grid indices where the response is at or above the threshold
                indices = np.where(response >= threshold)[0]
                if indices.size == 0:
                    return (center_freq, center_freq, 0)  # No passband found

                lower_bound = freq_axis[indices[0]]
                upper_bound = freq_axis[indices[-1]]
                passband_width = upper_bound - lower_bound

                return (lower_bound, upper_bound, passband_width)
         | 
| 323 | 
            +
             | 
| 324 | 
            +
             | 
| 325 | 
            +
            def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1,show_passband: bool = False) -> np.ndarray:
         | 
| 277 326 | 
             
                """
         | 
| 278 327 | 
             
                Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                Parameters
         | 
| 330 | 
            +
                ----------
         | 
| 331 | 
            +
                x : np.ndarray
         | 
| 332 | 
            +
                    Input signal
         | 
| 333 | 
            +
                freq : float
         | 
| 334 | 
            +
                    Target frequency for the wavelet filter
         | 
| 335 | 
            +
                fs : float
         | 
| 336 | 
            +
                    Sampling frequency of the signal
         | 
| 337 | 
            +
                bandwidth : float, optional
         | 
| 338 | 
            +
                    Bandwidth parameter of the wavelet filter (default is 1.0)
         | 
| 339 | 
            +
                axis : int, optional
         | 
| 340 | 
            +
                    Axis along which to compute the CWT (default is -1)
         | 
| 341 | 
            +
                show_passband : bool, optional
         | 
| 342 | 
            +
                    If True, print the passband of the wavelet filter (default is False)
         | 
| 343 | 
            +
                
         | 
| 344 | 
            +
                Returns
         | 
| 345 | 
            +
                -------
         | 
| 346 | 
            +
                np.ndarray
         | 
| 347 | 
            +
                    Continuous Wavelet Transform of the input signal
         | 
| 279 348 | 
             
                """
         | 
| 349 | 
            +
                if show_passband:
         | 
| 350 | 
            +
                    lower_bound, upper_bound, passband_width = calculate_wavelet_passband(freq, bandwidth, threshold=0.3) # kinda made up threshold gives the rough idea
         | 
| 351 | 
            +
                    print(f"Wavelet filter at {freq:.1f} Hz Bandwidth: {bandwidth:.1f} Hz:")
         | 
| 352 | 
            +
                    print(f"  Passband: {lower_bound:.1f} - {upper_bound:.1f} Hz (width: {passband_width:.1f} Hz)")
         | 
| 280 353 | 
             
                wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
         | 
| 281 354 | 
             
                scale = pywt.scale2frequency(wavelet, 1) * fs / freq
         | 
| 282 355 | 
             
                x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
         | 
| @@ -292,6 +365,53 @@ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs: | |
| 292 365 | 
             
                return x_a
         | 
| 293 366 |  | 
| 294 367 |  | 
| 368 | 
            +
            def get_lfp_power(lfp_data: np.ndarray, freq: float, fs: float, filter_method: str = 'wavelet',
                              lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
                """
                Return the squared magnitude of the LFP after band-limiting it.

                Parameters
                ----------
                lfp_data : np.ndarray
                    Raw local field potential (LFP) time series data
                freq : float
                    Center frequency (Hz); passed to the wavelet filter only
                fs : float
                    Sampling frequency (Hz) of the input data
                filter_method : str, optional
                    Either 'wavelet' or 'butter' (default: 'wavelet')
                lowcut : float, optional
                    Lower cutoff (Hz) of the Butterworth bandpass; required for 'butter'
                highcut : float, optional
                    Upper cutoff (Hz) of the Butterworth bandpass; required for 'butter'
                bandwidth : float, optional
                    Bandwidth parameter passed to the wavelet filter only (default: 1.0)

                Returns
                -------
                np.ndarray
                    ``abs(filtered_signal) ** 2`` for the chosen filter

                Raises
                ------
                ValueError
                    If ``filter_method`` is neither 'wavelet' nor 'butter', or if 'butter'
                    is selected and ``lowcut`` or ``highcut`` is None.
                """
                # Wavelet path: filter at `freq` and return the squared magnitude directly
                if filter_method == 'wavelet':
                    band_limited = wavelet_filter(lfp_data, freq, fs, bandwidth)
                    return np.abs(band_limited) ** 2

                # Anything other than 'butter' from here on is unsupported
                if filter_method != 'butter':
                    raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")

                # Butterworth path needs both band edges
                if lowcut is None or highcut is None:
                    raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")

                band_limited = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
                return np.abs(band_limited) ** 2
         | 
| 413 | 
            +
             | 
| 414 | 
            +
             | 
| 295 415 | 
             
            # windowing functions 
         | 
| 296 416 | 
             
            def windowed_xarray(da, windows, dim='time',
         | 
| 297 417 | 
             
                                new_coord_name='cycle', new_coord=None):
         | 
    
        bmtool/analysis/spikes.py
    CHANGED
    
    | @@ -11,22 +11,33 @@ from scipy.stats import mannwhitneyu | |
| 11 11 | 
             
            import os
         | 
| 12 12 |  | 
| 13 13 |  | 
| 14 | 
            -
            def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: str = 'pop_name') -> pd.DataFrame:
         | 
| 14 | 
            +
            def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: Union[str, List[str]] = 'pop_name') -> pd.DataFrame:
         | 
| 15 15 | 
             
                """
         | 
| 16 16 | 
             
                Load spike data from an HDF5 file into a pandas DataFrame.
         | 
| 17 17 |  | 
| 18 | 
            -
                 | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
                     | 
| 22 | 
            -
             | 
| 23 | 
            -
                     | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 18 | 
            +
                Parameters
         | 
| 19 | 
            +
                ----------
         | 
| 20 | 
            +
                spike_file : str
         | 
| 21 | 
            +
                    Path to the HDF5 file containing spike data
         | 
| 22 | 
            +
                network_name : str
         | 
| 23 | 
            +
                    The name of the network within the HDF5 file from which to load spike data
         | 
| 24 | 
            +
                sort : bool, optional
         | 
| 25 | 
            +
                    Whether to sort the DataFrame by 'timestamps' (default: True)
         | 
| 26 | 
            +
                config : str, optional
         | 
| 27 | 
            +
                    Path to configuration file to label the cell type of each spike (default: None)
         | 
| 28 | 
            +
                groupby : Union[str, List[str]], optional
         | 
| 29 | 
            +
                    The column(s) to group by (default: 'pop_name')
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                Returns
         | 
| 32 | 
            +
                -------
         | 
| 33 | 
            +
                pd.DataFrame
         | 
| 34 | 
            +
                    A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data,
         | 
| 35 | 
            +
                    with additional columns if a config file is provided
         | 
| 27 36 |  | 
| 28 | 
            -
                 | 
| 29 | 
            -
             | 
| 37 | 
            +
                Examples
         | 
| 38 | 
            +
                --------
         | 
| 39 | 
            +
                >>> df = load_spikes_to_df("spikes.h5", "cortex")
         | 
| 40 | 
            +
                >>> df = load_spikes_to_df("spikes.h5", "cortex", config="config.json", groupby=["pop_name", "model_type"])
         | 
| 30 41 | 
             
                """
         | 
| 31 42 | 
             
                with h5py.File(spike_file) as f:
         | 
| 32 43 | 
             
                    spikes_df = pd.DataFrame({
         | 
| @@ -126,23 +137,31 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] = | |
| 126 137 |  | 
| 127 138 |  | 
| 128 139 | 
             
            def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None, 
         | 
| 129 | 
            -
                               time_points: Optional[Union[np.ndarray, list]] = None,  | 
| 140 | 
            +
                               time_points: Optional[Union[np.ndarray, list]] = None, frequency: bool = False) -> np.ndarray:
         | 
| 130 141 | 
             
                """
         | 
| 131 142 | 
             
                Calculate the spike count or frequency histogram over specified time intervals.
         | 
| 132 143 |  | 
| 133 | 
            -
                 | 
| 134 | 
            -
             | 
| 135 | 
            -
             | 
| 136 | 
            -
             | 
| 137 | 
            -
             | 
| 138 | 
            -
             | 
| 139 | 
            -
                     | 
| 140 | 
            -
             | 
| 141 | 
            -
             | 
| 142 | 
            -
             | 
| 143 | 
            -
             | 
| 144 | 
            -
             | 
| 145 | 
            -
             | 
| 144 | 
            +
                Parameters
         | 
| 145 | 
            +
                ----------
         | 
| 146 | 
            +
                spike_times : Union[np.ndarray, list]
         | 
| 147 | 
            +
                    Array or list of spike times in milliseconds
         | 
| 148 | 
            +
                time : Optional[Tuple[float, float, float]], optional
         | 
| 149 | 
            +
                    Tuple specifying (start, stop, step) in milliseconds. Used to create evenly spaced time points 
         | 
| 150 | 
            +
                    if `time_points` is not provided. Default is None.
         | 
| 151 | 
            +
                time_points : Optional[Union[np.ndarray, list]], optional
         | 
| 152 | 
            +
                    Array or list of specific time points for binning. If provided, `time` is ignored. Default is None.
         | 
| 153 | 
            +
                frequency : bool, optional
         | 
| 154 | 
            +
                    If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                Returns
         | 
| 157 | 
            +
                -------
         | 
| 158 | 
            +
                np.ndarray
         | 
| 159 | 
            +
                    Array of spike counts or frequencies, depending on the `frequency` flag.
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                Raises
         | 
| 162 | 
            +
                ------
         | 
| 163 | 
            +
                ValueError
         | 
| 164 | 
            +
                    If both `time` and `time_points` are None.
         | 
| 146 165 | 
             
                """
         | 
| 147 166 | 
             
                if time_points is None:
         | 
| 148 167 | 
             
                    if time is None:
         | 
| @@ -156,43 +175,57 @@ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[f | |
| 156 175 | 
             
                bins = np.append(time_points, time_points[-1] + dt)
         | 
| 157 176 | 
             
                spike_rate, _ = np.histogram(np.asarray(spike_times), bins)
         | 
| 158 177 |  | 
| 159 | 
            -
                if  | 
| 178 | 
            +
                if frequency:
         | 
| 160 179 | 
             
                    spike_rate = 1000 / dt * spike_rate
         | 
| 161 180 |  | 
| 162 181 | 
             
                return spike_rate
         | 
| 163 182 |  | 
| 164 183 |  | 
| 165 | 
            -
            def get_population_spike_rate( | 
| 184 | 
            +
            def get_population_spike_rate(spike_data: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None, 
         | 
| 166 185 | 
             
                                          config: Optional[str] = None, network_name: Optional[str] = None,
         | 
| 167 186 | 
             
                                          save: bool = False, save_path: Optional[str] = None,
         | 
| 168 187 | 
             
                                          normalize: bool = False) -> Dict[str, np.ndarray]:
         | 
| 169 188 | 
             
                """
         | 
| 170 189 | 
             
                Calculate the population spike rate for each population in the given spike data, with an option to normalize.
         | 
| 171 190 |  | 
| 172 | 
            -
                 | 
| 173 | 
            -
             | 
| 174 | 
            -
             | 
| 175 | 
            -
                     | 
| 176 | 
            -
             | 
| 177 | 
            -
                     | 
| 178 | 
            -
             | 
| 179 | 
            -
                     | 
| 180 | 
            -
             | 
| 181 | 
            -
                     | 
| 182 | 
            -
             | 
| 183 | 
            -
                     | 
| 184 | 
            -
             | 
| 185 | 
            -
                 | 
| 186 | 
            -
                     | 
| 187 | 
            -
             | 
| 188 | 
            -
             | 
| 189 | 
            -
             | 
| 190 | 
            -
             | 
| 191 | 
            -
             | 
| 192 | 
            -
                 | 
| 193 | 
            -
                     | 
| 194 | 
            -
             | 
| 195 | 
            -
             | 
| 191 | 
            +
                Parameters
         | 
| 192 | 
            +
                ----------
         | 
| 193 | 
            +
                spike_data : pd.DataFrame
         | 
| 194 | 
            +
                    A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'
         | 
| 195 | 
            +
                fs : float, optional
         | 
| 196 | 
            +
                    Sampling frequency in Hz, which determines the time bin size for calculating the spike rate (default: 400.0)
         | 
| 197 | 
            +
                t_start : float, optional
         | 
| 198 | 
            +
                    Start time (in milliseconds) for spike rate calculation (default: 0)
         | 
| 199 | 
            +
                t_stop : Optional[float], optional
         | 
| 200 | 
            +
                    Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data
         | 
| 201 | 
            +
                config : Optional[str], optional
         | 
| 202 | 
            +
                    Path to a configuration file containing node information, used to determine the correct number of nodes per population.
         | 
| 203 | 
            +
                    If None, node count is estimated from unique node spikes (default: None)
         | 
| 204 | 
            +
                network_name : Optional[str], optional
         | 
| 205 | 
            +
                    Name of the network used in the configuration file, allowing selection of nodes for that network.
         | 
| 206 | 
            +
                    Required if `config` is provided (default: None)
         | 
| 207 | 
            +
                save : bool, optional
         | 
| 208 | 
            +
                    Whether to save the calculated population spike rate to a file (default: False)
         | 
| 209 | 
            +
                save_path : Optional[str], optional
         | 
| 210 | 
            +
                    Directory path where the file should be saved if `save` is True (default: None)
         | 
| 211 | 
            +
                normalize : bool, optional
         | 
| 212 | 
            +
                    Whether to normalize the spike rates for each population to a range of [0, 1] (default: False)
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                Returns
         | 
| 215 | 
            +
                -------
         | 
| 216 | 
            +
                Dict[str, np.ndarray]
         | 
| 217 | 
            +
                    A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
         | 
| 218 | 
            +
                    If `normalize` is True, each population's spike rate is scaled to [0, 1].
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                Raises
         | 
| 221 | 
            +
                ------
         | 
| 222 | 
            +
                ValueError
         | 
| 223 | 
            +
                    If `save` is True but `save_path` is not provided.
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                Notes
         | 
| 226 | 
            +
                -----
         | 
| 227 | 
            +
                - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
         | 
| 228 | 
            +
                - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
         | 
| 196 229 | 
             
                """
         | 
| 197 230 | 
             
                pop_spikes = {}
         | 
| 198 231 | 
             
                node_number = {}
         | 
| @@ -205,8 +238,8 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: | |
| 205 238 | 
             
                    if not network_name:
         | 
| 206 239 | 
             
                        print("Grabbing first network; specify a network name to ensure correct node population is selected.")
         | 
| 207 240 |  | 
| 208 | 
            -
                for pop_name in  | 
| 209 | 
            -
                    ps =  | 
| 241 | 
            +
                for pop_name in spike_data['pop_name'].unique():
         | 
| 242 | 
            +
                    ps = spike_data[spike_data['pop_name'] == pop_name]
         | 
| 210 243 |  | 
| 211 244 | 
             
                    if config:
         | 
| 212 245 | 
             
                        nodes = load_nodes_from_config(config)
         | 
| @@ -220,12 +253,12 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: | |
| 220 253 | 
             
                        node_number[pop_name] = ps['node_ids'].nunique()
         | 
| 221 254 |  | 
| 222 255 | 
             
                    if t_stop is None:
         | 
| 223 | 
            -
                        t_stop =  | 
| 256 | 
            +
                        t_stop = spike_data['timestamps'].max()
         | 
| 224 257 |  | 
| 225 | 
            -
                    filtered_spikes =  | 
| 226 | 
            -
                        ( | 
| 227 | 
            -
                        ( | 
| 228 | 
            -
                        ( | 
| 258 | 
            +
                    filtered_spikes = spike_data[
         | 
| 259 | 
            +
                        (spike_data['pop_name'] == pop_name) & 
         | 
| 260 | 
            +
                        (spike_data['timestamps'] > t_start) & 
         | 
| 261 | 
            +
                        (spike_data['timestamps'] < t_stop)
         | 
| 229 262 | 
             
                    ]
         | 
| 230 263 | 
             
                    pop_spikes[pop_name] = filtered_spikes
         | 
| 231 264 |  | 
| @@ -254,11 +287,30 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: | |
| 254 287 | 
             
                return spike_rate
         | 
| 255 288 |  | 
| 256 289 |  | 
| 257 | 
            -
            def compare_firing_over_times(spike_df,group_by, time_window_1, time_window_2):
         | 
| 290 | 
            +
            def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window_1: List[float], time_window_2: List[float]) -> None:
         | 
| 258 291 | 
             
                """
         | 
| 259 | 
            -
                Compares the firing rates of a population during two different time windows
         | 
| 260 | 
            -
                 | 
| 261 | 
            -
                 | 
| 292 | 
            +
                Compares the firing rates of a population during two different time windows and performs
         | 
| 293 | 
            +
                a statistical test to determine if there is a significant difference.
         | 
| 294 | 
            +
                
         | 
| 295 | 
            +
                Parameters
         | 
| 296 | 
            +
                ----------
         | 
| 297 | 
            +
                spike_df : pd.DataFrame
         | 
| 298 | 
            +
                    DataFrame containing spike data with columns for timestamps, node_ids, and grouping variable
         | 
| 299 | 
            +
                group_by : str
         | 
| 300 | 
            +
                    Column name to group spikes by (e.g., 'pop_name')
         | 
| 301 | 
            +
                time_window_1 : List[float]
         | 
| 302 | 
            +
                    First time window as [start, stop] in milliseconds
         | 
| 303 | 
            +
                time_window_2 : List[float]
         | 
| 304 | 
            +
                    Second time window as [start, stop] in milliseconds
         | 
| 305 | 
            +
                
         | 
| 306 | 
            +
                Returns
         | 
| 307 | 
            +
                -------
         | 
| 308 | 
            +
                None
         | 
| 309 | 
            +
                    Results are printed to the console
         | 
| 310 | 
            +
                
         | 
| 311 | 
            +
                Notes
         | 
| 312 | 
            +
                -----
         | 
| 313 | 
            +
                Uses Mann-Whitney U test (non-parametric) to compare firing rates between the two windows
         | 
| 262 314 | 
             
                """
         | 
| 263 315 | 
             
                # Filter spikes for the population of interest
         | 
| 264 316 | 
             
                for pop_name in spike_df[group_by].unique():
         | 
| @@ -8,10 +8,10 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298 | |
| 8 8 | 
             
            bmtool/singlecell.py,sha256=imcdxIzvYVkaOLSGDxYp8WGGssGwXXBCRhzhlqVp7hA,44267
         | 
| 9 9 | 
             
            bmtool/synapses.py,sha256=Ow2fZavA_3_5BYCjcgPjW0YsyVOetn1wvLxL7hQvbZo,64556
         | 
| 10 10 | 
             
            bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 11 | 
            -
            bmtool/analysis/entrainment.py,sha256= | 
| 12 | 
            -
            bmtool/analysis/lfp.py,sha256= | 
| 11 | 
            +
            bmtool/analysis/entrainment.py,sha256=JRg9sQ7WrZMqHwMaDJtCN7kGgZHJ5msUSzP-JPltC8k,23158
         | 
| 12 | 
            +
            bmtool/analysis/lfp.py,sha256=hOqD4xcDEL0NrNIN2-Ler_mkvY5cEUhxr7VUdX5Gwh8,21737
         | 
| 13 13 | 
             
            bmtool/analysis/netcon_reports.py,sha256=7moyoUC45Cl1_6sGqwZ5aKphK_8i4AimroePXcgUnIo,3057
         | 
| 14 | 
            -
            bmtool/analysis/spikes.py,sha256= | 
| 14 | 
            +
            bmtool/analysis/spikes.py,sha256=kcJZQsvPVzQgcuiO-El_4OODW57hwNwdok_RsFMITCg,15097
         | 
| 15 15 | 
             
            bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 16 16 | 
             
            bmtool/bmplot/connections.py,sha256=re6QZX_NfQnIaWayGt3EhMINhCeMMSQ6rFR2sJbFeWk,51385
         | 
| 17 17 | 
             
            bmtool/bmplot/entrainment.py,sha256=3IBD6tfW7lvkuB6DTan7rAVAeznOOzmHLr1qA2rgtCY,1671
         | 
| @@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396 | |
| 26 26 | 
             
            bmtool/util/util.py,sha256=XR0qZnv_Q47jMBKQpFzCSkCuKe9u8L3YSGJAOpP2zT0,57630
         | 
| 27 27 | 
             
            bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 28 28 | 
             
            bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
         | 
| 29 | 
            -
            bmtool-0.7.0. | 
| 30 | 
            -
            bmtool-0.7.0. | 
| 31 | 
            -
            bmtool-0.7.0. | 
| 32 | 
            -
            bmtool-0.7.0. | 
| 33 | 
            -
            bmtool-0.7.0. | 
| 34 | 
            -
            bmtool-0.7.0. | 
| 29 | 
            +
            bmtool-0.7.0.4.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
         | 
| 30 | 
            +
            bmtool-0.7.0.4.dist-info/METADATA,sha256=I-D4fwZQIHcHvD6Ou8Az3Kclwl17kwpZH7YHrD0eEg4,2768
         | 
| 31 | 
            +
            bmtool-0.7.0.4.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
         | 
| 32 | 
            +
            bmtool-0.7.0.4.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
         | 
| 33 | 
            +
            bmtool-0.7.0.4.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
         | 
| 34 | 
            +
            bmtool-0.7.0.4.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         |