bmtool 0.6.8.6__py3-none-any.whl → 0.6.8.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/lfp.py +243 -10
- {bmtool-0.6.8.6.dist-info → bmtool-0.6.8.8.dist-info}/METADATA +2 -1
- {bmtool-0.6.8.6.dist-info → bmtool-0.6.8.8.dist-info}/RECORD +7 -7
- {bmtool-0.6.8.6.dist-info → bmtool-0.6.8.8.dist-info}/WHEEL +0 -0
- {bmtool-0.6.8.6.dist-info → bmtool-0.6.8.8.dist-info}/entry_points.txt +0 -0
- {bmtool-0.6.8.6.dist-info → bmtool-0.6.8.8.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.6.8.6.dist-info → bmtool-0.6.8.8.dist-info}/top_level.txt +0 -0
bmtool/analysis/lfp.py
CHANGED
@@ -6,11 +6,14 @@ import h5py
|
|
6
6
|
import numpy as np
|
7
7
|
import xarray as xr
|
8
8
|
from fooof import FOOOF
|
9
|
-
from fooof.sim.gen import gen_model
|
9
|
+
from fooof.sim.gen import gen_model, gen_aperiodic
|
10
10
|
import matplotlib.pyplot as plt
|
11
11
|
from scipy import signal
|
12
12
|
import pywt
|
13
13
|
from bmtool.bmplot import is_notebook
|
14
|
+
import numba
|
15
|
+
from numba import cuda
|
16
|
+
import pandas as pd
|
14
17
|
|
15
18
|
|
16
19
|
def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
|
@@ -422,10 +425,45 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarr
|
|
422
425
|
return plv
|
423
426
|
|
424
427
|
|
428
|
+
@numba.njit(parallel=True, fastmath=True)
def _ppc_parallel_numba(spike_phases):
    """Numba-compiled, multi-threaded Pairwise Phase Consistency (PPC).

    Computes the same quantity as the NumPy branch of ``calculate_ppc``::

        PPC = 2 / (n * (n - 1)) * sum_{i < j} cos(phase_i - phase_j)

    Parameters
    ----------
    spike_phases : 1-D array of float
        LFP phase (radians, as consumed by ``np.cos``) at each spike.

    Returns
    -------
    float
        The PPC value. Fewer than two spikes is not supported: the
        normalisation ``2 / (n * (n - 1))`` divides by zero.
    """
    n = len(spike_phases)
    sum_cos = 0.0
    for i in numba.prange(n):
        phase_i = spike_phases[i]
        # Sum row i of the upper triangle into a private local first and fold
        # it into the shared total once per outer iteration. The only
        # cross-iteration update is then a single scalar `+=`, which is the
        # reduction form Numba's prange documentation describes.
        row_sum = 0.0
        for j in range(i + 1, n):
            row_sum += np.cos(phase_i - spike_phases[j])
        sum_cos += row_sum
    return (2 / (n * (n - 1))) * sum_cos
|
438
|
+
|
439
|
+
|
440
|
+
@cuda.jit(fastmath=True)
def _ppc_cuda_kernel(spike_phases, out):
    """CUDA kernel with one thread per spike for the PPC pair sum.

    Thread ``k`` writes ``sum_{m > k} cos(spike_phases[k] - spike_phases[m])``
    into ``out[k]``. The host then sums ``out`` to get the total over all
    unordered spike pairs. Threads whose global index is past the end of
    ``spike_phases`` write nothing.
    """
    k = cuda.grid(1)
    n_phases = len(spike_phases)
    if k >= n_phases:
        return
    ref_phase = spike_phases[k]
    acc = 0.0
    for m in range(k + 1, n_phases):
        acc += np.cos(ref_phase - spike_phases[m])
    out[k] = acc
|
448
|
+
|
449
|
+
|
450
|
+
def _ppc_gpu(spike_phases):
    """GPU (CUDA) implementation of Pairwise Phase Consistency.

    Steps:
    1. Copy the phases to the device.
    2. Launch ``_ppc_cuda_kernel`` with one thread per spike, in blocks of
       256 threads.
    3. Sum the per-thread partial sums on the host.
    4. Scale the result by the PPC normalisation ``2 / (n * (n - 1))``.
    """
    n_spikes = len(spike_phases)
    device_phases = cuda.to_device(spike_phases)
    device_partials = cuda.device_array(n_spikes, dtype=np.float64)

    threads_per_block = 256
    # Ceiling division so every spike is covered by a thread.
    n_blocks = (n_spikes + threads_per_block - 1) // threads_per_block

    _ppc_cuda_kernel[n_blocks, threads_per_block](device_phases, device_partials)
    pair_sum = device_partials.copy_to_host().sum()
    norm = 2 / (n_spikes * (n_spikes - 1))
    return norm * pair_sum
|
461
|
+
|
462
|
+
|
425
463
|
def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
|
426
464
|
lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
|
427
465
|
lowcut: float = None, highcut: float = None,
|
428
|
-
bandwidth: float = 2.0) -> tuple:
|
466
|
+
bandwidth: float = 2.0,ppc_method: str = 'numpy') -> tuple:
|
429
467
|
"""
|
430
468
|
Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
|
431
469
|
Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
|
@@ -439,12 +477,11 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
|
|
439
477
|
- freq_of_interest: Desired frequency for wavelet phase extraction
|
440
478
|
- lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
|
441
479
|
- bandwidth: Bandwidth parameter for the wavelet
|
480
|
+
- ppc_method: which algo to use for PPC calculate can be numpy, numba or gpu
|
442
481
|
|
443
482
|
Returns:
|
444
483
|
- ppc: Pairwise Phase Consistency value
|
445
|
-
- spike_phases: Phases at spike times
|
446
484
|
"""
|
447
|
-
print("Note this method will a very long time if there are a lot of spikes. If there are a lot of spikes consider using the PPC2 method if speed is an issue")
|
448
485
|
if spike_fs is None:
|
449
486
|
spike_fs = lfp_fs
|
450
487
|
# Convert spike times to sample indices
|
@@ -511,11 +548,17 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
|
|
511
548
|
# ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
|
512
549
|
|
513
550
|
# same as above (i think) but with vectorized computation and memory fixes so it wont take forever to run.
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
551
|
+
if ppc_method == 'numpy':
|
552
|
+
i, j = np.triu_indices(n_spikes, k=1)
|
553
|
+
phase_diff = spike_phases[i] - spike_phases[j]
|
554
|
+
sum_cos_diff = np.sum(np.cos(phase_diff))
|
555
|
+
ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
|
556
|
+
elif ppc_method == 'numba':
|
557
|
+
ppc = _ppc_parallel_numba(spike_phases)
|
558
|
+
elif ppc_method == 'gpu':
|
559
|
+
ppc = _ppc_gpu(spike_phases)
|
560
|
+
else:
|
561
|
+
raise ExceptionType("Please use a supported ppc method currently that is numpy, numba or gpu")
|
519
562
|
return ppc
|
520
563
|
|
521
564
|
|
@@ -605,6 +648,196 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
|
|
605
648
|
|
606
649
|
return ppc2
|
607
650
|
|
608
|
-
|
609
651
|
|
652
|
+
# windowing functions
def windowed_xarray(da, windows, dim='time',
                    new_coord_name='cycle', new_coord=None):
    """Cut a DataArray into equal-length windows stacked along a new dimension.

    Parameters
    ----------
    da : xarray.DataArray
        Input data.
    windows : iterable
        Each row is unpacked into ``slice(*row)`` and used for label-based
        selection along ``dim``; typically ``(start, stop)`` pairs
        (e.g. a 2-D array of windows).
    dim : str
        Dimension along which the windows are cut.
    new_coord_name : str
        Name of the new dimension along which windows are concatenated.
        Only used when ``new_coord`` is None.
    new_coord : pandas.Index, optional
        Coordinates of the new dimension. Defaults to an integer index
        ``0 .. len(windows) - 1`` named ``new_coord_name``.

    Returns
    -------
    xarray.DataArray
        Windows truncated to the length of the shortest one. Every window is
        relabelled with the first ``n`` coordinates of ``da`` along ``dim``,
        so all windows share one axis when concatenated.
    """
    pieces = []
    for bounds in windows:
        pieces.append(da.sel({dim: slice(*bounds)}))

    shortest = min(piece.coords[dim].size for piece in pieces)
    head = {dim: slice(shortest)}
    shared_coords = da.coords[dim].isel(head).coords
    pieces = [piece.isel(head).assign_coords(shared_coords) for piece in pieces]

    if new_coord is None:
        new_coord = pd.Index(range(len(pieces)), name=new_coord_name)
    return xr.concat(pieces, dim=new_coord)
|
671
|
+
|
672
|
+
|
673
|
+
def group_windows(win_da, win_grp_idx=None, win_dim='cycle'):
    """Split windows into per-group selections and their complements.

    Parameters
    ----------
    win_da : xarray.DataArray
        Windowed data, e.g. the output of ``windowed_xarray``.
    win_grp_idx : dict, optional
        ``{window group id: window labels}``. The labels are selected along
        ``win_dim`` with ``.sel``. ``None`` (the default) means no groups.
    win_dim : str
        Dimension that indexes the windows.

    Returns
    -------
    tuple of (dict, dict)
        ``win_on``: ``{group id: windows selected by win_grp_idx}``.
        ``win_off``: ``{group id: all remaining windows}``.
    """
    if win_grp_idx is None:
        # None sentinel instead of a shared mutable `{}` default argument.
        win_grp_idx = {}
    win_on, win_off = {}, {}
    for g, w in win_grp_idx.items():
        win_on[g] = win_da.sel({win_dim: w})
        win_off[g] = win_da.drop_sel({win_dim: w})
    return win_on, win_off
|
686
|
+
|
687
|
+
|
688
|
+
def average_group_windows(win_da, win_dim='cycle', grp_dim='unique_cycle'):
    """Summarise each window group by its mean and standard deviation.

    Parameters
    ----------
    win_da : dict
        ``{window group id: DataArray of grouped windows}``, e.g. one of the
        dictionaries returned by ``group_windows``.
    win_dim : str
        Dimension that indexes the windows; the statistics reduce over it.
    grp_dim : str
        Name of the new dimension along which the per-group statistics are
        stacked. Its labels are the group ids.

    Returns
    -------
    xarray.Dataset
        Variables ``mean_`` and ``std_`` (xarray's default ``ddof=0``),
        each indexed by ``grp_dim``.
    """
    stats_index = pd.Index(('mean_', 'std_'), name='stats')

    summaries = {}
    for group_id, windows in win_da.items():
        mean_da = windows.mean(dim=win_dim)
        std_da = windows.std(dim=win_dim)
        summaries[group_id] = xr.concat([mean_da, std_da], stats_index)

    group_index = pd.Index(list(summaries), name=grp_dim)
    stacked = xr.concat(list(summaries.values()), dim=group_index)
    return stacked.to_dataset(dim='stats')
|
700
|
+
|
701
|
+
# used for avg spectrogram across different trials
def get_windowed_data(x, windows, win_grp_idx, dim='time',
                      win_dim='cycle', win_coord=None, grp_dim='unique_cycle'):
    """Window, group and (optionally) average a DataArray in one call.

    Parameters
    ----------
    x : xarray.DataArray
        Input data.
    windows :
        Forwarded to ``windowed_xarray`` as ``windows``.
    win_grp_idx : dict
        Forwarded to ``group_windows`` as ``win_grp_idx``.
    dim : str
        Dimension along which the windows are cut.
    win_dim : str
        Name of the dimension that indexes the windows.
    win_coord : pandas.Index, optional
        Coordinates for ``win_dim``.
    grp_dim : str
        Dimension along which group averages are stacked. A falsy value
        (None, empty string, False) skips averaging.

    Returns
    -------
    tuple
        ``(x_win, (win_on, win_off), x_win_avg)``. Here ``x_win_avg`` is
        ``[average_group_windows(win_on), average_group_windows(win_off)]``,
        or None when averaging is skipped.
    """
    windowed = windowed_xarray(x, windows, dim=dim,
                               new_coord_name=win_dim, new_coord=win_coord)
    on_off = group_windows(windowed, win_grp_idx, win_dim=win_dim)
    if not grp_dim:
        return windowed, on_off, None
    averages = [average_group_windows(groups, win_dim=win_dim, grp_dim=grp_dim)
                for groups in on_off]
    return windowed, on_off, averages
|
725
|
+
|
726
|
+
# Cone-of-influence constants for the complex Morlet wavelet with centre
# frequency 1.0 ('cmor<B>-1.0' in PyWavelets), following the Morlet formulas
# of Torrence & Compo (1998). They appear to be exact only for B = 2, i.e.
# bandwidth=1.0 in cwt_spectrogram; other bandwidths are approximate.
# TODO: derive these inside cwt_spectrogram from the wavelet actually used.
f0 = np.pi * 2  # angular centre frequency: omega_0 = 2*pi * 1.0
CMOR_COI = 2 ** -0.5  # 1/sqrt(2), Morlet e-folding time factor
# Ratio of Fourier period to wavelet scale: 4*pi / (omega_0 + sqrt(2 + omega_0^2))
CMOR_FLAMBDA = (4 * np.pi) / ((f0 ** 2 + 2) ** 0.5 + f0)
# Multiply by fs and divide by the distance to the signal edge (in samples)
# to get the cone-of-influence frequency.
COI_FREQ = 1 / (CMOR_FLAMBDA * CMOR_COI)
|
731
|
+
|
732
|
+
def cwt_spectrogram(x, fs, nNotes=6, nOctaves=np.inf, freq_range=(0, np.inf),
                    bandwidth=1.0, axis=-1, detrend=False, normalize=False):
    """Compute a wavelet power spectrogram with a complex Morlet wavelet.

    Parameters
    ----------
    x : array_like
        Input signal(s); time runs along ``axis``.
    fs : float
        Sampling frequency (Hz).
    nNotes : int
        Number of scales per octave.
    nOctaves : float
        Maximum number of octaves spanned by the scales, which start at 2
        samples. Capped at ``log2(2 * floor(N / 2))``, where ``N`` is the
        signal length.
    freq_range : (float, float)
        Only scales whose frequency lies in ``[freq_range[0], freq_range[1]]``
        Hz are kept.
    bandwidth : float
        Sigma of the Gaussian envelope. The PyWavelets wavelet used is
        ``'cmor{2 * bandwidth**2}-1.0'``.
    axis : int
        Time axis of ``x``.
    detrend : bool
        Remove a linear trend along ``axis`` before the transform.
    normalize : bool
        Divide by the standard deviation of the whole array (all axes).

    Returns
    -------
    power : ndarray
        ``|coef|**2`` from ``pywt.cwt``; the first axis indexes frequency.
    times : ndarray
        ``arange(N) / fs``.
    frequencies : ndarray
        Frequencies (Hz) along the first axis of ``power``, ascending.
    coif : ndarray
        Cone-of-influence frequency at each sample, computed with the module
        constant ``COI_FREQ``. Lower frequencies lie inside the cone.
    """
    data = np.asarray(x)
    n_samples = data.shape[axis]
    times = np.arange(n_samples) / fs

    if detrend:
        data = signal.detrend(data, axis=axis, type='linear')
    if normalize:
        data = data / data.std()

    # Scales run from 2 samples (Nyquist) up to about N/2 samples, with
    # nNotes steps per octave.
    octave_cap = np.log2(2 * np.floor(n_samples / 2))
    n_oct = min(nOctaves, octave_cap)
    all_scales = 2 ** np.arange(1, n_oct, 1 / nNotes)

    # Complex Morlet: bandwidth parameter B = 2 * bandwidth**2, centre frequency 1.0.
    wavelet = f'cmor{2 * bandwidth ** 2}-1.0'
    scale_freqs = pywt.scale2frequency(wavelet, all_scales) * fs
    in_band = (scale_freqs >= freq_range[0]) & (scale_freqs <= freq_range[1])
    kept_scales = all_scales[in_band]

    # Passing scales in descending order yields ascending frequencies.
    coef, frequencies = pywt.cwt(data, kept_scales[::-1], wavelet=wavelet,
                                 sampling_period=1 / fs, axis=axis)
    power = np.real(coef * np.conj(coef))  # equivalent to np.abs(coef) ** 2

    # min(n, N - 1 - n) + 0.5: distance to the nearer edge, in samples.
    edge_dist = n_samples / 2 - np.abs(np.arange(n_samples) - (n_samples - 1) / 2)
    coif = COI_FREQ * fs / edge_dist
    return power, times, frequencies, coif
|
762
|
+
|
763
|
+
|
764
|
+
def cwt_spectrogram_xarray(x, fs, time=None, axis=-1, downsample_fs=None,
                           channel_coords=None, **cwt_kwargs):
    """Compute a continuous-wavelet spectrogram and return it as an ``xarray.Dataset``.

    Parameters
    ----------
    x : array_like
        Input signal(s); time runs along ``axis``.
    fs : float
        Sampling frequency (Hz).
    time : array_like, optional
        Time stamps of the samples along ``axis``. Defaults to
        ``arange(T) / fs``.
    axis : int
        Index of the time axis in ``x``.
    downsample_fs : float, optional
        If given and lower than ``fs``, ``x`` is Fourier-resampled with
        ``scipy.signal.resample`` before the transform. The effective rate is
        ``int(T * downsample_fs / fs) / T * fs``.
    channel_coords : dict, optional
        ``{coordinate name: index}`` for each non-time axis of ``x``, in
        order. Defaults to ``dim_0``, ``dim_1``, ... with integer ranges.
    **cwt_kwargs
        Forwarded to ``cwt_spectrogram`` (e.g. ``nNotes``, ``freq_range``,
        ``bandwidth``, ``detrend``, ``normalize``). ``axis`` is this
        function's own parameter and cannot be forwarded; the time axis is
        moved last before the call.

    Returns
    -------
    xarray.Dataset
        ``PSD`` with dims ``(*channel dims, 'frequency', 'time')``, and
        ``cone_of_influence_frequency`` along ``time``.
    """
    x = np.asarray(x)
    T = x.shape[axis]  # number of time points
    t = np.arange(T) / fs if time is None else np.asarray(time)
    if downsample_fs is None or downsample_fs >= fs:
        downsample_fs = fs
        downsampled = x
    else:
        num = int(T * downsample_fs / fs)
        downsample_fs = num / T * fs  # actual rate after integer sample count
        downsampled, t = signal.resample(x, num=num, t=t, axis=axis)
    downsampled = np.moveaxis(downsampled, axis, -1)
    sxx, _, f, coif = cwt_spectrogram(downsampled, downsample_fs, **cwt_kwargs)
    sxx = np.moveaxis(sxx, 0, -2)  # shape (..., freq, time)
    if channel_coords is None:
        channel_coords = {f'dim_{i:d}': range(d) for i, d in enumerate(sxx.shape[:-2])}
    coords = {**channel_coords, 'frequency': f, 'time': t}
    # Give dims explicitly (in coords order) rather than relying on xarray to
    # infer them from a dict of coords; some xarray releases reject that
    # inference.
    sxx = xr.DataArray(sxx, coords=coords, dims=list(coords)).to_dataset(name='PSD')
    sxx.update(dict(cone_of_influence_frequency=xr.DataArray(
        coif, coords={'time': t}, dims=['time'])))
    return sxx
|
793
|
+
|
794
|
+
|
795
|
+
# TODO: consider moving this plotting helper to bmplot.
def plot_spectrogram(sxx_xarray, remove_aperiodic=None, log_power=False,
                     plt_range=None, clr_freq_range=None, pad=0.03, ax=None):
    """Plot a spectrogram, optionally with aperiodic removal and log scaling.

    Parameters
    ----------
    sxx_xarray : xarray.Dataset
        Must contain a ``PSD`` variable and ``time`` / ``frequency``
        coordinates, e.g. from ``cwt_spectrogram_xarray``. ``PSD`` is indexed
        here as ``[frequency, time]``, so it is assumed to be 2-D; select a
        single channel first. If a ``cone_of_influence_frequency`` variable is
        present, it is drawn over the image.
    remove_aperiodic : fitted FOOOF model, optional
        Its ``aperiodic_params`` are evaluated with ``gen_aperiodic`` and
        subtracted: in log10 space when ``log_power`` is set, linearly
        otherwise. A leading 0 Hz row is set to zero.
    log_power : bool or 'dB'
        If truthy, plot log10 power. If ``'dB'``, plot 10 * log10 power.
    plt_range : float or (float, float), optional
        Frequency range to show. A scalar is the upper limit; the lower limit
        is then 0 Hz, or the first non-zero frequency when ``log_power`` is
        set. Defaults to the highest frequency.
    clr_freq_range : (float, float), optional
        Frequency band whose values set the colour limits. Non-finite values
        (e.g. ``-inf`` from log10 of zero power) are ignored.
    pad : float
        Colorbar padding.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new figure is created when None.

    Returns
    -------
    numpy.ndarray
        The processed power array (log-scaled and/or residual) over all
        frequencies.
    """
    sxx = sxx_xarray.PSD.values.copy()
    t = sxx_xarray.time.values.copy()
    f = sxx_xarray.frequency.values.copy()

    cbar_label = 'PSD' if remove_aperiodic is None else 'PSD Residual'
    if log_power:
        with np.errstate(divide='ignore'):
            sxx = np.log10(sxx)
        cbar_label += ' dB' if log_power == 'dB' else ' log(power)'

    if remove_aperiodic is not None:
        # Skip a leading 0 Hz row; presumably the aperiodic fit is not
        # meaningful there.
        f1_idx = 0 if f[0] else 1
        ap_fit = gen_aperiodic(f[f1_idx:], remove_aperiodic.aperiodic_params)
        sxx[f1_idx:, :] -= (ap_fit if log_power else 10 ** ap_fit)[:, None]
        sxx[:f1_idx, :] = 0.

    if log_power == 'dB':
        sxx *= 10

    if ax is None:
        _, ax = plt.subplots(1, 1)
    plt_range = np.array(f[-1]) if plt_range is None else np.array(plt_range)
    if plt_range.size == 1:
        plt_range = [f[0 if f[0] else 1] if log_power else 0., plt_range.item()]
    f_idx = (f >= plt_range[0]) & (f <= plt_range[1])

    vmin, vmax = None, None
    if clr_freq_range is not None:
        c_idx = (f >= clr_freq_range[0]) & (f <= clr_freq_range[1])
        clr_vals = sxx[c_idx, :]
        # Drop -inf/nan (e.g. log10 of zero power at 0 Hz) so the colour
        # limits stay finite. Fall back to autoscaling if nothing is left.
        clr_vals = clr_vals[np.isfinite(clr_vals)]
        if clr_vals.size:
            vmin, vmax = clr_vals.min(), clr_vals.max()

    f = f[f_idx]
    pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading='gouraud', vmin=vmin, vmax=vmax)
    if 'cone_of_influence_frequency' in sxx_xarray:
        coif = sxx_xarray.cone_of_influence_frequency
        ax.plot(t, coif)
        ax.fill_between(t, coif, step='mid', alpha=0.2)
    ax.set_xlim(t[0], t[-1])
    ax.set_ylim(f[0], f[-1])
    plt.colorbar(mappable=pcm, ax=ax, label=cbar_label, pad=pad)
    ax.set_xlabel('Time (sec)')
    ax.set_ylabel('Frequency (Hz)')
    return sxx
|
610
843
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: bmtool
|
3
|
-
Version: 0.6.8.6
|
3
|
+
Version: 0.6.8.8
|
4
4
|
Summary: BMTool
|
5
5
|
Home-page: https://github.com/cyneuro/bmtool
|
6
6
|
Download-URL:
|
@@ -33,6 +33,7 @@ Requires-Dist: fooof
|
|
33
33
|
Requires-Dist: requests
|
34
34
|
Requires-Dist: pyyaml
|
35
35
|
Requires-Dist: PyWavelets
|
36
|
+
Requires-Dist: numba
|
36
37
|
Dynamic: author
|
37
38
|
Dynamic: author-email
|
38
39
|
Dynamic: classifier
|
@@ -9,7 +9,7 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
|
|
9
9
|
bmtool/singlecell.py,sha256=XZAT_2n44EhwqVLnk3qur9aO7oJ-10axJZfwPBslM88,27219
|
10
10
|
bmtool/synapses.py,sha256=gIkfLhKDG2dHHCVJJoKuQrFn_Qut843bfk_-s97wu6c,54553
|
11
11
|
bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
bmtool/analysis/lfp.py,sha256=
|
12
|
+
bmtool/analysis/lfp.py,sha256=TfPSLIh2LqbWTRFnYwBMBpfbj6yxyhC7EAQM1e2H4qs,33609
|
13
13
|
bmtool/analysis/spikes.py,sha256=qqJ4zD8xfvSwltlWm_Bhicdngzl6uBqH6Kn5wOMKRc8,11507
|
14
14
|
bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
15
|
bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
|
@@ -19,9 +19,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
|
|
19
19
|
bmtool/util/util.py,sha256=00vOAwTVIifCqouBoFoT0lBashl4fCalrk8fhg_Uq4c,56654
|
20
20
|
bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
21
21
|
bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
|
22
|
-
bmtool-0.6.8.
|
23
|
-
bmtool-0.6.8.
|
24
|
-
bmtool-0.6.8.
|
25
|
-
bmtool-0.6.8.
|
26
|
-
bmtool-0.6.8.
|
27
|
-
bmtool-0.6.8.
|
22
|
+
bmtool-0.6.8.8.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
|
23
|
+
bmtool-0.6.8.8.dist-info/METADATA,sha256=IqLuFdtSCGW_gxwgU8vD543KUj19fP-dhTbeSG_6HNg,20478
|
24
|
+
bmtool-0.6.8.8.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
25
|
+
bmtool-0.6.8.8.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
|
26
|
+
bmtool-0.6.8.8.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
|
27
|
+
bmtool-0.6.8.8.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|