bmtool 0.6.8.6__py3-none-any.whl → 0.6.8.8__py3-none-any.whl

This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
bmtool/analysis/lfp.py CHANGED
@@ -6,11 +6,14 @@ import h5py
  import numpy as np
  import xarray as xr
  from fooof import FOOOF
- from fooof.sim.gen import gen_model
+ from fooof.sim.gen import gen_model, gen_aperiodic
  import matplotlib.pyplot as plt
  from scipy import signal
  import pywt
  from bmtool.bmplot import is_notebook
+ import numba
+ from numba import cuda
+ import pandas as pd


  def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
@@ -422,10 +425,45 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarr
      return plv


+ @numba.njit(parallel=True, fastmath=True)
+ def _ppc_parallel_numba(spike_phases):
+     """Numba-optimized parallel PPC calculation"""
+     n = len(spike_phases)
+     sum_cos = 0.0
+     for i in numba.prange(n):
+         phase_i = spike_phases[i]
+         for j in range(i + 1, n):
+             sum_cos += np.cos(phase_i - spike_phases[j])
+     return (2 / (n * (n - 1))) * sum_cos
+
+
+ @cuda.jit(fastmath=True)
+ def _ppc_cuda_kernel(spike_phases, out):
+     i = cuda.grid(1)
+     if i < len(spike_phases):
+         local_sum = 0.0
+         for j in range(i + 1, len(spike_phases)):
+             local_sum += np.cos(spike_phases[i] - spike_phases[j])
+         out[i] = local_sum
+
+
+ def _ppc_gpu(spike_phases):
+     """GPU-accelerated implementation"""
+     d_phases = cuda.to_device(spike_phases)
+     d_out = cuda.device_array(len(spike_phases), dtype=np.float64)
+
+     threads = 256
+     blocks = (len(spike_phases) + threads - 1) // threads
+
+     _ppc_cuda_kernel[blocks, threads](d_phases, d_out)
+     total = d_out.copy_to_host().sum()
+     return (2 / (len(spike_phases) * (len(spike_phases) - 1))) * total
+
+
  def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                    lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                    lowcut: float = None, highcut: float = None,
-                   bandwidth: float = 2.0) -> tuple:
+                   bandwidth: float = 2.0, ppc_method: str = 'numpy') -> tuple:
      """
      Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
      Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
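For reference, the quantity both new kernels accumulate is the pairwise term of PPC = 2 / (n * (n - 1)) * sum_{i<j} cos(phase_i - phase_j), which also has an O(n) closed form through the resultant vector. A small NumPy sketch of the equivalence (the helper names are illustrative, not part of bmtool):

    import numpy as np

    def ppc_pairwise(phases):
        # direct O(n^2) pairwise sum, mirroring the numba/CUDA kernels above
        n = len(phases)
        i, j = np.triu_indices(n, k=1)
        return (2 / (n * (n - 1))) * np.cos(phases[i] - phases[j]).sum()

    def ppc_resultant(phases):
        # identity: sum_{i<j} cos(p_i - p_j) = (|sum_k exp(1j * p_k)|**2 - n) / 2
        n = len(phases)
        resultant_sq = np.abs(np.exp(1j * phases).sum()) ** 2
        return (resultant_sq - n) / (n * (n - 1))

    phases = np.random.default_rng(0).uniform(-np.pi, np.pi, 500)
    assert np.isclose(ppc_pairwise(phases), ppc_resultant(phases))

The closed form is a handy cross-check for the O(n^2) kernels on small inputs.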
@@ -439,12 +477,11 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
      - freq_of_interest: Desired frequency for wavelet phase extraction
      - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
      - bandwidth: Bandwidth parameter for the wavelet
+     - ppc_method: algorithm to use for the PPC calculation; one of 'numpy', 'numba', or 'gpu'

      Returns:
      - ppc: Pairwise Phase Consistency value
-     - spike_phases: Phases at spike times
      """
-     print("Note this method will a very long time if there are a lot of spikes. If there are a lot of spikes consider using the PPC2 method if speed is an issue")
      if spike_fs is None:
          spike_fs = lfp_fs
      # Convert spike times to sample indices
@@ -511,11 +548,17 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
      # ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)

      # same as above (i think) but with vectorized computation and memory fixes so it wont take forever to run.
-     i, j = np.triu_indices(n_spikes, k=1)
-     phase_diff = spike_phases[i] - spike_phases[j]
-     sum_cos_diff = np.sum(np.cos(phase_diff))
-     ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
-
+     if ppc_method == 'numpy':
+         i, j = np.triu_indices(n_spikes, k=1)
+         phase_diff = spike_phases[i] - spike_phases[j]
+         sum_cos_diff = np.sum(np.cos(phase_diff))
+         ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
+     elif ppc_method == 'numba':
+         ppc = _ppc_parallel_numba(spike_phases)
+     elif ppc_method == 'gpu':
+         ppc = _ppc_gpu(spike_phases)
+     else:
+         raise ValueError("Unsupported ppc_method; please use 'numpy', 'numba', or 'gpu'")
      return ppc


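A minimal usage sketch for the new ppc_method argument, on synthetic data; the spike-time units are assumed to be milliseconds here, so verify against the docstring of your installed version:

    import numpy as np
    from bmtool.analysis.lfp import calculate_ppc

    fs = 1000.0                     # Hz, shared by the spike and LFP clocks here
    rng = np.random.default_rng(1)
    lfp = np.sin(2 * np.pi * 8 * np.arange(10_000) / fs)   # 10 s of an 8 Hz LFP
    spike_times = np.sort(rng.uniform(0, 10_000, 200))     # assumed ms; verify units

    # ppc_method='numpy' (default) vectorizes the pairwise sum; 'numba' runs it in
    # parallel on the CPU; 'gpu' requires a CUDA-capable device.
    # method='wavelet' is assumed here; 'hilbert' with lowcut/highcut is the
    # alternative named in the docstring.
    ppc = calculate_ppc(spike_times, lfp_signal=lfp, spike_fs=fs, lfp_fs=fs,
                        method='wavelet', freq_of_interest=8.0, ppc_method='numba')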
@@ -605,6 +648,196 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None

      return ppc2

-
 
+ # windowing functions
+ def windowed_xarray(da, windows, dim='time',
+                     new_coord_name='cycle', new_coord=None):
+     """Divide xarray into windows of equal size along an axis
+     da: input DataArray
+     windows: 2d-array of windows
+     dim: dimension along which to divide
+     new_coord_name: name of new dimension along which to concatenate windows
+     new_coord: pandas Index object of new coordinates. Defaults to integer index
+     """
+     win_da = [da.sel({dim: slice(*w)}) for w in windows]
+     n_win = min(x.coords[dim].size for x in win_da)
+     idx = {dim: slice(n_win)}
+     coords = da.coords[dim].isel(idx).coords
+     win_da = [x.isel(idx).assign_coords(coords) for x in win_da]
+     if new_coord is None:
+         new_coord = pd.Index(range(len(win_da)), name=new_coord_name)
+     win_da = xr.concat(win_da, dim=new_coord)
+     return win_da
+
+
+ def group_windows(win_da, win_grp_idx={}, win_dim='cycle'):
+     """Group windows into a dictionary of DataArrays
+     win_da: input windowed DataArrays
+     win_grp_idx: dictionary of {window group id: window indices}
+     win_dim: dimension for different windows
+     Return: dictionaries of {window group id: DataArray of grouped windows}
+         win_on / win_off for windows selected / not selected by `win_grp_idx`
+     """
+     win_on, win_off = {}, {}
+     for g, w in win_grp_idx.items():
+         win_on[g] = win_da.sel({win_dim: w})
+         win_off[g] = win_da.drop_sel({win_dim: w})
+     return win_on, win_off
+
+
+ def average_group_windows(win_da, win_dim='cycle', grp_dim='unique_cycle'):
+     """Average over windows in each group and stack groups in a DataArray
+     win_da: input dictionary of {window group id: DataArray of grouped windows}
+     win_dim: dimension for different windows
+     grp_dim: dimension along which to stack average of window groups
+     """
+     win_avg = {g: xr.concat([x.mean(dim=win_dim), x.std(dim=win_dim)],
+                             pd.Index(('mean_', 'std_'), name='stats'))
+                for g, x in win_da.items()}
+     win_avg = xr.concat(win_avg.values(), dim=pd.Index(win_avg.keys(), name=grp_dim))
+     win_avg = win_avg.to_dataset(dim='stats')
+     return win_avg
+
+ # used for averaging spectrograms across trials
+ def get_windowed_data(x, windows, win_grp_idx, dim='time',
+                       win_dim='cycle', win_coord=None, grp_dim='unique_cycle'):
+     """Apply functions of windowing to data
+     x: DataArray
+     windows: `windows` for `windowed_xarray`
+     win_grp_idx: `win_grp_idx` for `group_windows`
+     dim: dimension along which to divide
+     win_dim: dimension for different windows
+     win_coord: pandas Index object of `win_dim` coordinates
+     grp_dim: dimension along which to stack average of window groups.
+         If None or empty or False, do not calculate average.
+     Return: data returned by three functions,
+         `windowed_xarray`, `group_windows`, `average_group_windows`
+     """
+     x_win = windowed_xarray(x, windows, dim=dim,
+                             new_coord_name=win_dim, new_coord=win_coord)
+     x_win_onff = group_windows(x_win, win_grp_idx, win_dim=win_dim)
+     if grp_dim:
+         x_win_avg = [average_group_windows(x, win_dim=win_dim, grp_dim=grp_dim)
+                      for x in x_win_onff]
+     else:
+         x_win_avg = None
+     return x_win, x_win_onff, x_win_avg
+
+ # cone of influence in frequency for the cmorxx-1.0 wavelet; TODO: compute this inside the function
+ f0 = 2 * np.pi
+ CMOR_COI = 2 ** -0.5
+ CMOR_FLAMBDA = 4 * np.pi / (f0 + (2 + f0 ** 2) ** 0.5)
+ COI_FREQ = 1 / (CMOR_COI * CMOR_FLAMBDA)
+
+ def cwt_spectrogram(x, fs, nNotes=6, nOctaves=np.inf, freq_range=(0, np.inf),
+                     bandwidth=1.0, axis=-1, detrend=False, normalize=False):
+     """Calculate spectrogram using continuous wavelet transform"""
+     x = np.asarray(x)
+     N = x.shape[axis]
+     times = np.arange(N) / fs
+     # detrend and normalize
+     if detrend:
+         x = signal.detrend(x, axis=axis, type='linear')
+     if normalize:
+         x = x / x.std()
+     # Define some parameters of our wavelet analysis.
+     # range of scales (in time) that makes sense
+     # min = 2 (Nyquist frequency)
+     # max = np.floor(N/2)
+     nOctaves = min(nOctaves, np.log2(2 * np.floor(N / 2)))
+     scales = 2 ** np.arange(1, nOctaves, 1 / nNotes)
+     # cwt and the frequencies used.
+     # Use the complex Morlet with bw=2*bandwidth^2 and center frequency of 1.0
+     # bandwidth is sigma of the gaussian envelope
+     wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
+     frequencies = pywt.scale2frequency(wavelet, scales) * fs
+     scales = scales[(frequencies >= freq_range[0]) & (frequencies <= freq_range[1])]
+     coef, frequencies = pywt.cwt(x, scales[::-1], wavelet=wavelet, sampling_period=1 / fs, axis=axis)
+     power = np.real(coef * np.conj(coef))  # equivalent to power = np.abs(coef)**2
+     # cone of influence in terms of wavelength
+     coi = N / 2 - np.abs(np.arange(N) - (N - 1) / 2)
+     # cone of influence in terms of frequency
+     coif = COI_FREQ * fs / coi
+     return power, times, frequencies, coif
+
+
+ def cwt_spectrogram_xarray(x, fs, time=None, axis=-1, downsample_fs=None,
+                            channel_coords=None, **cwt_kwargs):
+     """Calculate spectrogram using continuous wavelet transform and return an xarray.Dataset
+     x: input array
+     fs: sampling frequency (Hz)
+     axis: dimension index of time axis in x
+     downsample_fs: downsample to the frequency if specified
+     time: time coordinates of the data in seconds; defaults to sample indices / fs
+     channel_coords: dictionary of {coordinate name: index} for channels
+     cwt_kwargs: keyword arguments for cwt_spectrogram()
+     """
+     x = np.asarray(x)
+     T = x.shape[axis]  # number of time points
+     t = np.arange(T) / fs if time is None else np.asarray(time)
+     if downsample_fs is None or downsample_fs >= fs:
+         downsample_fs = fs
+         downsampled = x
+     else:
+         num = int(T * downsample_fs / fs)
+         downsample_fs = num / T * fs
+         downsampled, t = signal.resample(x, num=num, t=t, axis=axis)
+     downsampled = np.moveaxis(downsampled, axis, -1)
+     sxx, _, f, coif = cwt_spectrogram(downsampled, downsample_fs, **cwt_kwargs)
+     sxx = np.moveaxis(sxx, 0, -2)  # shape (..., freq, time)
+     if channel_coords is None:
+         channel_coords = {f'dim_{i:d}': range(d) for i, d in enumerate(sxx.shape[:-2])}
+     sxx = xr.DataArray(sxx, coords={**channel_coords, 'frequency': f, 'time': t}).to_dataset(name='PSD')
+     sxx.update(dict(cone_of_influence_frequency=xr.DataArray(coif, coords={'time': t})))
+     return sxx
+
+
+ # will probably move to bmplot later
+ def plot_spectrogram(sxx_xarray, remove_aperiodic=None, log_power=False,
+                      plt_range=None, clr_freq_range=None, pad=0.03, ax=None):
+     """Plot spectrogram. Determine color limits using values in frequency band clr_freq_range"""
+     sxx = sxx_xarray.PSD.values.copy()
+     t = sxx_xarray.time.values.copy()
+     f = sxx_xarray.frequency.values.copy()
+
+     cbar_label = 'PSD' if remove_aperiodic is None else 'PSD Residual'
+     if log_power:
+         with np.errstate(divide='ignore'):
+             sxx = np.log10(sxx)
+         cbar_label += ' dB' if log_power == 'dB' else ' log(power)'
+
+     if remove_aperiodic is not None:
+         f1_idx = 0 if f[0] else 1
+         ap_fit = gen_aperiodic(f[f1_idx:], remove_aperiodic.aperiodic_params)
+         sxx[f1_idx:, :] -= (ap_fit if log_power else 10 ** ap_fit)[:, None]
+         sxx[:f1_idx, :] = 0.
+
+     if log_power == 'dB':
+         sxx *= 10
+
+     if ax is None:
+         _, ax = plt.subplots(1, 1)
+     plt_range = np.array(f[-1]) if plt_range is None else np.array(plt_range)
+     if plt_range.size == 1:
+         plt_range = [f[0 if f[0] else 1] if log_power else 0., plt_range.item()]
+     f_idx = (f >= plt_range[0]) & (f <= plt_range[1])
+     if clr_freq_range is None:
+         vmin, vmax = None, None
+     else:
+         c_idx = (f >= clr_freq_range[0]) & (f <= clr_freq_range[1])
+         vmin, vmax = sxx[c_idx, :].min(), sxx[c_idx, :].max()
+
+     f = f[f_idx]
+     pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading='gouraud', vmin=vmin, vmax=vmax)
+     if 'cone_of_influence_frequency' in sxx_xarray:
+         coif = sxx_xarray.cone_of_influence_frequency
+         ax.plot(t, coif)
+         ax.fill_between(t, coif, step='mid', alpha=0.2)
+     ax.set_xlim(t[0], t[-1])
+     # ax.set_xlim(t[0], 0.2)
+     ax.set_ylim(f[0], f[-1])
+     plt.colorbar(mappable=pcm, ax=ax, label=cbar_label, pad=pad)
+     ax.set_xlabel('Time (sec)')
+     ax.set_ylabel('Frequency (Hz)')
+     return sxx
 
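A minimal sketch of how the new windowing helpers compose, assuming an LFP DataArray with a 'time' coordinate in seconds; the window bounds, group labels, and variable names are illustrative, not part of bmtool:

    import numpy as np
    import xarray as xr
    from bmtool.analysis.lfp import get_windowed_data

    fs = 1000.0
    t = np.arange(0, 10, 1 / fs)
    lfp = xr.DataArray(np.sin(2 * np.pi * 8 * t), coords={'time': t}, dims='time')

    # three 1-second (start, stop) windows along 'time'; cycles 0 and 1 form group 'A'
    windows = np.array([[1.0, 2.0], [4.0, 5.0], [7.0, 8.0]])
    win_grp_idx = {'A': [0, 1], 'B': [2]}

    x_win, (win_on, win_off), x_win_avg = get_windowed_data(
        lfp, windows, win_grp_idx, dim='time', win_dim='cycle', grp_dim='unique_cycle')
    # x_win stacks the windows along 'cycle'; win_on/win_off hold the selected and
    # complementary windows per group; x_win_avg holds per-group mean_/std_ Datasets.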
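A companion sketch for the new CWT spectrogram path; the signal, downsample rate, note density, and frequency range are illustrative:

    import numpy as np
    import matplotlib.pyplot as plt
    from bmtool.analysis.lfp import cwt_spectrogram_xarray, plot_spectrogram

    fs = 1000.0
    rng = np.random.default_rng(0)
    x = np.sin(2 * np.pi * 8 * np.arange(10_000) / fs) + 0.5 * rng.standard_normal(10_000)

    # downsampling before the CWT keeps the scale sweep cheap
    sxx = cwt_spectrogram_xarray(x, fs, downsample_fs=250, nNotes=8, freq_range=(1, 100))
    fig, ax = plt.subplots()
    plot_spectrogram(sxx, log_power='dB', plt_range=(1, 100), ax=ax)
    plt.show()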
{bmtool-0.6.8.6.dist-info → bmtool-0.6.8.8.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bmtool
- Version: 0.6.8.6
+ Version: 0.6.8.8
  Summary: BMTool
  Home-page: https://github.com/cyneuro/bmtool
  Download-URL:
@@ -33,6 +33,7 @@ Requires-Dist: fooof
  Requires-Dist: requests
  Requires-Dist: pyyaml
  Requires-Dist: PyWavelets
+ Requires-Dist: numba
  Dynamic: author
  Dynamic: author-email
  Dynamic: classifier
{bmtool-0.6.8.6.dist-info → bmtool-0.6.8.8.dist-info}/RECORD CHANGED
@@ -9,7 +9,7 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
  bmtool/singlecell.py,sha256=XZAT_2n44EhwqVLnk3qur9aO7oJ-10axJZfwPBslM88,27219
  bmtool/synapses.py,sha256=gIkfLhKDG2dHHCVJJoKuQrFn_Qut843bfk_-s97wu6c,54553
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bmtool/analysis/lfp.py,sha256=Ei-l9aA13IOsdOEjmkqmdthKgPkEPnbiHdJ_-TB2twQ,23771
+ bmtool/analysis/lfp.py,sha256=TfPSLIh2LqbWTRFnYwBMBpfbj6yxyhC7EAQM1e2H4qs,33609
  bmtool/analysis/spikes.py,sha256=qqJ4zD8xfvSwltlWm_Bhicdngzl6uBqH6Kn5wOMKRc8,11507
  bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
@@ -19,9 +19,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
  bmtool/util/util.py,sha256=00vOAwTVIifCqouBoFoT0lBashl4fCalrk8fhg_Uq4c,56654
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
- bmtool-0.6.8.6.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
- bmtool-0.6.8.6.dist-info/METADATA,sha256=ndqm3Ph6LxmfatX3g9znK_GBNa88uOzVgCQ6Gcq88RY,20457
- bmtool-0.6.8.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- bmtool-0.6.8.6.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
- bmtool-0.6.8.6.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
- bmtool-0.6.8.6.dist-info/RECORD,,
+ bmtool-0.6.8.8.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
+ bmtool-0.6.8.8.dist-info/METADATA,sha256=IqLuFdtSCGW_gxwgU8vD543KUj19fP-dhTbeSG_6HNg,20478
+ bmtool-0.6.8.8.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ bmtool-0.6.8.8.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
+ bmtool-0.6.8.8.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
+ bmtool-0.6.8.8.dist-info/RECORD,,