bmtool 0.6.8.7__py3-none-any.whl → 0.6.8.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/analysis/lfp.py CHANGED
@@ -6,13 +6,14 @@ import h5py
6
6
  import numpy as np
7
7
  import xarray as xr
8
8
  from fooof import FOOOF
9
- from fooof.sim.gen import gen_model
9
+ from fooof.sim.gen import gen_model, gen_aperiodic
10
10
  import matplotlib.pyplot as plt
11
11
  from scipy import signal
12
12
  import pywt
13
13
  from bmtool.bmplot import is_notebook
14
14
  import numba
15
15
  from numba import cuda
16
+ import pandas as pd
16
17
 
17
18
 
18
19
  def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
@@ -647,6 +648,196 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
647
648
 
648
649
  return ppc2
649
650
 
650
-
651
651
 
652
# windowing functions
def windowed_xarray(da, windows, dim='time',
                    new_coord_name='cycle', new_coord=None):
    """Cut a DataArray into equal-length windows and stack them along a new dimension.

    da: input DataArray
    windows: sequence of (start, stop) label pairs (e.g. a 2-D array with one
        row per window); each pair is unpacked into `slice` and selected by
        label along `dim`
    dim: dimension along which to cut the windows
    new_coord_name: name of the new dimension along which windows are stacked
    new_coord: pandas Index of labels for the new dimension. Defaults to an
        integer index 0..n_windows-1. An unnamed Index is given `new_coord_name`.
    Return: DataArray with the extra window dimension. Every window is truncated
        to the length of the shortest window and relabeled with the first
        coordinates of `da` along `dim`, so all windows share one common axis.
    Raises: ValueError if `windows` is empty.
    """
    win_da = [da.sel({dim: slice(*w)}) for w in windows]
    if not win_da:
        raise ValueError('windows must contain at least one (start, stop) pair')
    # Truncate every window to the shortest one so they can be stacked.
    n_win = min(x.coords[dim].size for x in win_da)
    idx = {dim: slice(n_win)}
    # Coordinates (along `dim`) of the first n_win samples of `da`; assigning
    # them replaces each window's absolute positions with a shared axis.
    coords = da.coords[dim].isel(idx).coords
    win_da = [x.isel(idx).assign_coords(coords) for x in win_da]
    if new_coord is None:
        new_coord = pd.Index(range(len(win_da)), name=new_coord_name)
    elif isinstance(new_coord, pd.Index) and new_coord.name is None:
        # xr.concat names the new dimension after the Index; without a name it
        # would not use `new_coord_name`, breaking callers that select on it.
        new_coord = new_coord.rename(new_coord_name)
    return xr.concat(win_da, dim=new_coord)
671
+
672
+
673
def group_windows(win_da, win_grp_idx=None, win_dim='cycle'):
    """Group windows into dictionaries of selected / not-selected DataArrays.

    win_da: windowed DataArray (e.g. the output of `windowed_xarray`)
    win_grp_idx: dictionary of {window group id: window labels along `win_dim`}.
        Labels are selected with `.sel`, so they are coordinate values, not
        positions. None (the default) means no groups.
    win_dim: dimension that indexes the different windows
    Return: (win_on, win_off), two dictionaries of
        {window group id: DataArray of grouped windows}; `win_on` holds the
        windows selected by `win_grp_idx` and `win_off` all remaining windows.
    """
    # None instead of a shared mutable `{}` default argument.
    if win_grp_idx is None:
        win_grp_idx = {}
    win_on, win_off = {}, {}
    for g, w in win_grp_idx.items():
        win_on[g] = win_da.sel({win_dim: w})
        win_off[g] = win_da.drop_sel({win_dim: w})
    return win_on, win_off
686
+
687
+
688
def average_group_windows(win_da, win_dim='cycle', grp_dim='unique_cycle'):
    """Summarize each window group by its mean and std, stacked along `grp_dim`.

    win_da: dictionary of {window group id: DataArray of grouped windows}
    win_dim: dimension that indexes the windows inside each group
    grp_dim: name of the dimension along which the group summaries are stacked
    Return: Dataset with variables 'mean_' and 'std_' (taken over `win_dim`),
        indexed by group id along `grp_dim`.
    """
    stats_index = pd.Index(('mean_', 'std_'), name='stats')
    group_ids = []
    group_summaries = []
    for group_id, grouped in win_da.items():
        mean_da = grouped.mean(dim=win_dim)
        std_da = grouped.std(dim=win_dim)
        group_ids.append(group_id)
        group_summaries.append(xr.concat([mean_da, std_da], stats_index))
    stacked = xr.concat(group_summaries, dim=pd.Index(group_ids, name=grp_dim))
    return stacked.to_dataset(dim='stats')
700
+
701
# used for avg spectrogram across different trials
def get_windowed_data(x, windows, win_grp_idx, dim='time',
                      win_dim='cycle', win_coord=None, grp_dim='unique_cycle'):
    """Run the windowing pipeline: cut, group and (optionally) average windows.

    x: DataArray
    windows: `windows` argument for `windowed_xarray`
    win_grp_idx: `win_grp_idx` argument for `group_windows`
    dim: dimension along which to cut the windows
    win_dim: name of the dimension indexing the windows
    win_coord: pandas Index of `win_dim` coordinates
    grp_dim: dimension along which window-group averages are stacked.
        If falsy (None, '', False), no averages are computed.
    Return: (windowed, (win_on, win_off), averages) where `averages` is a
        two-element list [average of win_on, average of win_off] from
        `average_group_windows`, or None when `grp_dim` is falsy.
    """
    windowed = windowed_xarray(x, windows, dim=dim,
                               new_coord_name=win_dim, new_coord=win_coord)
    on_off = group_windows(windowed, win_grp_idx, win_dim=win_dim)
    averages = None
    if grp_dim:
        averages = [
            average_group_windows(groups, win_dim=win_dim, grp_dim=grp_dim)
            for groups in on_off
        ]
    return windowed, on_off, averages
725
+
726
# Cone-of-influence constants for pywt's complex Morlet wavelet with center
# frequency 1.0 (angular center frequency f0 = 2*pi), following Torrence &
# Compo (1998):
#   - the e-folding time of the wavelet at scale s is sqrt(2) * s, so
#     CMOR_COI * t gives the scale whose e-folding time equals distance t;
#   - its equivalent Fourier period is CMOR_FLAMBDA * s.
# COI_FREQ / t is therefore the lowest reliable frequency at distance t
# (seconds) from the signal edge.
# NOTE(review): both factors assume the standard Morlet envelope exp(-t**2/2),
# which cwt_spectrogram uses only for bandwidth=1.0 ('cmor2.0-1.0'). Other
# bandwidths change them. TODO: compute these inside cwt_spectrogram.
f0 = 2 * np.pi
CMOR_COI = 2 ** -0.5
CMOR_FLAMBDA = 4 * np.pi / (f0 + (2 + f0 ** 2) ** 0.5)
COI_FREQ = 1 / (CMOR_COI * CMOR_FLAMBDA)
731
+
732
def cwt_spectrogram(x, fs, nNotes=6, nOctaves=np.inf, freq_range=(0, np.inf),
                    bandwidth=1.0, axis=-1, detrend=False, normalize=False):
    """Compute a wavelet power spectrogram with a complex Morlet CWT.

    x: input signal array; the time axis is `axis`
    fs: sampling frequency (Hz)
    nNotes: number of scales per octave
    nOctaves: upper bound of the scale exponent range, clipped to what the
        signal length supports
    freq_range: (low, high) in Hz; scales whose frequency falls outside are dropped
    bandwidth: sigma of the Gaussian envelope; the pywt wavelet used is
        'cmor{2*bandwidth**2}-1.0'
    axis: time axis of x
    detrend: remove a linear trend along `axis` before the transform
    normalize: divide x by its standard deviation (over all elements)
    Return: (power, times, frequencies, coif), where power = |coef|**2 with
        the scale axis first and frequencies ascending, times in seconds, and
        coif the cone-of-influence frequency for each time sample.
    """
    data = np.asarray(x)
    n_samples = data.shape[axis]
    times = np.arange(n_samples) / fs

    if detrend:
        data = signal.detrend(data, axis=axis, type='linear')
    if normalize:
        data = data / data.std()

    # Scales start at 2 samples (Nyquist) and grow by 1/nNotes octave; the
    # exponent range is capped by log2 of the (even-rounded) signal length.
    octave_cap = np.log2(2 * np.floor(n_samples / 2))
    scales = 2 ** np.arange(1, min(nOctaves, octave_cap), 1 / nNotes)

    # Complex Morlet with B = 2 * bandwidth**2 and center frequency 1.0.
    wavelet = f'cmor{2 * bandwidth ** 2!s}-1.0'
    scale_freqs = pywt.scale2frequency(wavelet, scales) * fs
    in_band = (scale_freqs >= freq_range[0]) & (scale_freqs <= freq_range[1])
    scales = scales[in_band]

    # Reverse scales so the output frequencies come out in ascending order.
    coef, frequencies = pywt.cwt(data, np.flip(scales), wavelet=wavelet,
                                 sampling_period=1 / fs, axis=axis)
    power = (coef * coef.conj()).real

    # Distance (in samples) from the nearer signal edge, then the matching
    # cone-of-influence frequency.
    sample_idx = np.arange(n_samples)
    coi = n_samples / 2 - np.abs(sample_idx - (n_samples - 1) / 2)
    coif = COI_FREQ * fs / coi
    return power, times, frequencies, coif
762
+
763
+
764
def cwt_spectrogram_xarray(x, fs, time=None, axis=-1, downsample_fs=None,
                           channel_coords=None, **cwt_kwargs):
    """Calculate a CWT spectrogram and return it as an xarray.Dataset.

    x: input array
    fs: sampling frequency (Hz)
    time: optional sample times along `axis`; defaults to np.arange(T) / fs.
        When downsampling, it is resampled together with `x`.
    axis: index of the time axis in x
    downsample_fs: if given and lower than `fs`, x is resampled with
        scipy.signal.resample before the transform. The rate actually used is
        num / T * fs with num = int(T * downsample_fs / fs).
    channel_coords: dictionary of {coordinate name: index} for the non-time
        axes of x, in axis order; defaults to dim_0, dim_1, ... with integer
        ranges
    cwt_kwargs: keyword arguments for cwt_spectrogram()
    Return: Dataset with 'PSD' (dims: channel dims..., 'frequency', 'time')
        and 'cone_of_influence_frequency' (dim: 'time').
    """
    x = np.asarray(x)
    T = x.shape[axis]  # number of time points
    t = np.arange(T) / fs if time is None else np.asarray(time)
    if downsample_fs is None or downsample_fs >= fs:
        downsample_fs = fs
        downsampled = x
    else:
        num = int(T * downsample_fs / fs)
        downsample_fs = num / T * fs  # exact rate after integer sample count
        downsampled, t = signal.resample(x, num=num, t=t, axis=axis)
    # cwt_spectrogram is called with its default axis=-1, so move time last.
    downsampled = np.moveaxis(downsampled, axis, -1)
    sxx, _, f, coif = cwt_spectrogram(downsampled, downsample_fs, **cwt_kwargs)
    sxx = np.moveaxis(sxx, 0, -2)  # shape (..., freq, time)
    if channel_coords is None:
        channel_coords = {f'dim_{i:d}': range(d) for i, d in enumerate(sxx.shape[:-2])}
    # Pass dims explicitly instead of relying on inference from dict order.
    dims = (*channel_coords, 'frequency', 'time')
    sxx = xr.DataArray(sxx, coords={**channel_coords, 'frequency': f, 'time': t},
                       dims=dims).to_dataset(name='PSD')
    sxx['cone_of_influence_frequency'] = xr.DataArray(coif, coords={'time': t},
                                                      dims=('time',))
    return sxx
793
+
794
+
795
# TODO: may move to bmplot later
def plot_spectrogram(sxx_xarray, remove_aperiodic=None, log_power=False,
                     plt_range=None, clr_freq_range=None, pad=0.03, ax=None):
    """Plot a spectrogram; color limits can be set from a frequency band.

    sxx_xarray: Dataset with a 2-D 'PSD' variable ordered (frequency, time)
        and 'frequency' / 'time' coordinates (e.g. single-channel output of
        cwt_spectrogram_xarray). If it has 'cone_of_influence_frequency', that
        curve is overlaid and the region below it is shaded.
    remove_aperiodic: object with an `aperiodic_params` attribute, presumably
        fooof results such as FOOOF.get_results() (confirm). Its aperiodic fit
        is subtracted from the PSD.
    log_power: False for linear power, True for log10 power, 'dB' for decibels
    plt_range: upper frequency limit, or (low, high) frequency range to plot
    clr_freq_range: (low, high) frequency band whose finite values set the
        color limits. If None, or if the band has no finite values, matplotlib
        autoscales.
    pad: padding between the axes and the colorbar
    ax: axes to draw on; a new figure is created if None
    Return: the (possibly log-transformed / residual) PSD array for all
        frequencies, not only the plotted range.
    """
    sxx = sxx_xarray.PSD.values.copy()
    t = sxx_xarray.time.values.copy()
    f = sxx_xarray.frequency.values.copy()

    cbar_label = 'PSD' if remove_aperiodic is None else 'PSD Residual'
    if log_power:
        with np.errstate(divide='ignore'):
            sxx = np.log10(sxx)
        cbar_label += ' dB' if log_power == 'dB' else ' log(power)'

    if remove_aperiodic is not None:
        # Skip a 0 Hz bin, presumably because the aperiodic model is not
        # finite there; that bin is zeroed instead.
        f1_idx = 0 if f[0] else 1
        ap_fit = gen_aperiodic(f[f1_idx:], remove_aperiodic.aperiodic_params)
        # ap_fit is in log10 power; convert back when plotting linear power.
        sxx[f1_idx:, :] -= (ap_fit if log_power else 10 ** ap_fit)[:, None]
        sxx[:f1_idx, :] = 0.

    if log_power == 'dB':
        sxx *= 10

    if ax is None:
        _, ax = plt.subplots(1, 1)
    plt_range = np.array(f[-1]) if plt_range is None else np.array(plt_range)
    if plt_range.size == 1:
        # A single value is the upper limit; the lower limit skips 0 Hz on a log scale.
        plt_range = [f[0 if f[0] else 1] if log_power else 0., plt_range.item()]
    f_idx = (f >= plt_range[0]) & (f <= plt_range[1])
    if clr_freq_range is None:
        vmin, vmax = None, None
    else:
        c_idx = (f >= clr_freq_range[0]) & (f <= clr_freq_range[1])
        c_vals = sxx[c_idx, :]
        # log10(0) yields -inf; ignore non-finite values when setting limits.
        c_vals = c_vals[np.isfinite(c_vals)]
        vmin, vmax = (c_vals.min(), c_vals.max()) if c_vals.size else (None, None)

    f = f[f_idx]
    pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading='gouraud', vmin=vmin, vmax=vmax)
    if 'cone_of_influence_frequency' in sxx_xarray:
        coif = sxx_xarray.cone_of_influence_frequency
        ax.plot(t, coif)
        ax.fill_between(t, coif, step='mid', alpha=0.2)
    ax.set_xlim(t[0], t[-1])
    ax.set_ylim(f[0], f[-1])
    # Attach the colorbar to the figure owning `ax`, not the current figure.
    ax.figure.colorbar(pcm, ax=ax, label=cbar_label, pad=pad)
    ax.set_xlabel('Time (sec)')
    ax.set_ylabel('Frequency (Hz)')
    return sxx
652
843
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.6.8.7
3
+ Version: 0.6.8.8
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -9,7 +9,7 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
9
9
  bmtool/singlecell.py,sha256=XZAT_2n44EhwqVLnk3qur9aO7oJ-10axJZfwPBslM88,27219
10
10
  bmtool/synapses.py,sha256=gIkfLhKDG2dHHCVJJoKuQrFn_Qut843bfk_-s97wu6c,54553
11
11
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- bmtool/analysis/lfp.py,sha256=TmuetoKTZBCuZbdRFRus_cby5ugCBRNpOSSEIi6Kgac,25096
12
+ bmtool/analysis/lfp.py,sha256=TfPSLIh2LqbWTRFnYwBMBpfbj6yxyhC7EAQM1e2H4qs,33609
13
13
  bmtool/analysis/spikes.py,sha256=qqJ4zD8xfvSwltlWm_Bhicdngzl6uBqH6Kn5wOMKRc8,11507
14
14
  bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
15
  bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
@@ -19,9 +19,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
19
19
  bmtool/util/util.py,sha256=00vOAwTVIifCqouBoFoT0lBashl4fCalrk8fhg_Uq4c,56654
20
20
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
21
21
  bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
22
- bmtool-0.6.8.7.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
23
- bmtool-0.6.8.7.dist-info/METADATA,sha256=Gd7eoiKxTXA85sURHMmsQqNKZsXVK48uFxkQrDoEbsQ,20478
24
- bmtool-0.6.8.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
25
- bmtool-0.6.8.7.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
26
- bmtool-0.6.8.7.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
27
- bmtool-0.6.8.7.dist-info/RECORD,,
22
+ bmtool-0.6.8.8.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
23
+ bmtool-0.6.8.8.dist-info/METADATA,sha256=IqLuFdtSCGW_gxwgU8vD543KUj19fP-dhTbeSG_6HNg,20478
24
+ bmtool-0.6.8.8.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
25
+ bmtool-0.6.8.8.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
26
+ bmtool-0.6.8.8.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
27
+ bmtool-0.6.8.8.dist-info/RECORD,,