bmtool 0.6.8.5__py3-none-any.whl → 0.6.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/lfp.py +50 -8
- {bmtool-0.6.8.5.dist-info → bmtool-0.6.8.7.dist-info}/METADATA +3 -2
- {bmtool-0.6.8.5.dist-info → bmtool-0.6.8.7.dist-info}/RECORD +7 -7
- {bmtool-0.6.8.5.dist-info → bmtool-0.6.8.7.dist-info}/WHEEL +0 -0
- {bmtool-0.6.8.5.dist-info → bmtool-0.6.8.7.dist-info}/entry_points.txt +0 -0
- {bmtool-0.6.8.5.dist-info → bmtool-0.6.8.7.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.6.8.5.dist-info → bmtool-0.6.8.7.dist-info}/top_level.txt +0 -0
bmtool/analysis/lfp.py
CHANGED
@@ -11,6 +11,8 @@ import matplotlib.pyplot as plt
|
|
11
11
|
from scipy import signal
|
12
12
|
import pywt
|
13
13
|
from bmtool.bmplot import is_notebook
|
14
|
+
import numba
|
15
|
+
from numba import cuda
|
14
16
|
|
15
17
|
|
16
18
|
def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
|
@@ -422,10 +424,45 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarr
|
|
422
424
|
return plv
|
423
425
|
|
424
426
|
|
427
|
+
@numba.njit(parallel=True, fastmath=True)
def _ppc_parallel_numba(spike_phases):
    """Compute pairwise phase consistency (PPC) with a parallel Numba loop.

    Sums cos(phase_a - phase_b) over every unordered pair of spike phases
    (a < b) and scales the total by 2 / (n * (n - 1)), i.e. divides by the
    number of pairs.

    Parameters:
    - spike_phases: 1-D sequence of phases at spike times (presumably radians).

    Returns:
    - PPC value as a float.
    """
    count = len(spike_phases)
    pair_cos_total = 0.0
    # The outer loop is split across threads; Numba recognises `+=` on this
    # scalar accumulator as a parallel reduction.
    for a in numba.prange(count):
        anchor = spike_phases[a]
        for b in range(a + 1, count):
            pair_cos_total += np.cos(anchor - spike_phases[b])
    return (2 / (count * (count - 1))) * pair_cos_total
|
437
|
+
|
438
|
+
|
439
|
+
@cuda.jit(fastmath=True)
def _ppc_cuda_kernel(spike_phases, out):
    """CUDA kernel: per-spike partial cosine sums for pairwise phase consistency.

    The thread with global index ``idx`` writes
    ``out[idx] = sum(cos(spike_phases[idx] - spike_phases[j]) for j > idx)``.
    Threads whose index is past the end of ``spike_phases`` do nothing.
    ``out`` must have at least ``len(spike_phases)`` elements.
    """
    idx = cuda.grid(1)
    count = len(spike_phases)
    # The launch grid may be larger than the data; surplus threads exit early.
    if idx >= count:
        return
    anchor = spike_phases[idx]
    row_total = 0.0
    for other in range(idx + 1, count):
        row_total += np.cos(anchor - spike_phases[other])
    out[idx] = row_total
|
447
|
+
|
448
|
+
|
449
|
+
def _ppc_gpu(spike_phases):
    """Compute pairwise phase consistency (PPC) on a CUDA GPU.

    The function works in three steps:
    1. Copy ``spike_phases`` to the device.
    2. Launch ``_ppc_cuda_kernel``, which gives each spike one thread that
       writes its cosine sum against all later spikes.
    3. Add those partial sums on the host and scale by 2 / (n * (n - 1)).

    Parameters:
    - spike_phases: 1-D array of phases at spike times.

    Returns:
    - PPC value.
    """
    n_phases = len(spike_phases)
    threads_per_block = 256
    # Ceiling division so every phase is covered by at least one thread.
    n_blocks = (n_phases + threads_per_block - 1) // threads_per_block

    device_phases = cuda.to_device(spike_phases)
    device_partials = cuda.device_array(n_phases, dtype=np.float64)
    _ppc_cuda_kernel[n_blocks, threads_per_block](device_phases, device_partials)

    partial_sums = device_partials.copy_to_host()
    return (2 / (n_phases * (n_phases - 1))) * partial_sums.sum()
|
460
|
+
|
461
|
+
|
425
462
|
def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
|
426
463
|
lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
|
427
464
|
lowcut: float = None, highcut: float = None,
|
428
|
-
bandwidth: float = 2.0) -> tuple:
|
465
|
+
bandwidth: float = 2.0,ppc_method: str = 'numpy') -> tuple:
|
429
466
|
"""
|
430
467
|
Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
|
431
468
|
Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
|
@@ -439,12 +476,11 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
|
|
439
476
|
- freq_of_interest: Desired frequency for wavelet phase extraction
|
440
477
|
- lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
|
441
478
|
- bandwidth: Bandwidth parameter for the wavelet
|
479
|
+
- ppc_method: which algo to use for PPC calculate can be numpy, numba or gpu
|
442
480
|
|
443
481
|
Returns:
|
444
482
|
- ppc: Pairwise Phase Consistency value
|
445
|
-
- spike_phases: Phases at spike times
|
446
483
|
"""
|
447
|
-
print("Note this method will a very long time if there are a lot of spikes. If there are a lot of spikes consider using the PPC2 method if speed is an issue")
|
448
484
|
if spike_fs is None:
|
449
485
|
spike_fs = lfp_fs
|
450
486
|
# Convert spike times to sample indices
|
@@ -511,11 +547,17 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
|
|
511
547
|
# ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
|
512
548
|
|
513
549
|
# same as above (i think) but with vectorized computation and memory fixes so it wont take forever to run.
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
550
|
+
if ppc_method == 'numpy':
|
551
|
+
i, j = np.triu_indices(n_spikes, k=1)
|
552
|
+
phase_diff = spike_phases[i] - spike_phases[j]
|
553
|
+
sum_cos_diff = np.sum(np.cos(phase_diff))
|
554
|
+
ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
|
555
|
+
elif ppc_method == 'numba':
|
556
|
+
ppc = _ppc_parallel_numba(spike_phases)
|
557
|
+
elif ppc_method == 'gpu':
|
558
|
+
ppc = _ppc_gpu(spike_phases)
|
559
|
+
else:
|
560
|
+
raise ExceptionType("Please use a supported ppc method currently that is numpy, numba or gpu")
|
519
561
|
return ppc
|
520
562
|
|
521
563
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: bmtool
|
3
|
-
Version: 0.6.8.5
|
3
|
+
Version: 0.6.8.7
|
4
4
|
Summary: BMTool
|
5
5
|
Home-page: https://github.com/cyneuro/bmtool
|
6
6
|
Download-URL:
|
@@ -32,7 +32,8 @@ Requires-Dist: xarray
|
|
32
32
|
Requires-Dist: fooof
|
33
33
|
Requires-Dist: requests
|
34
34
|
Requires-Dist: pyyaml
|
35
|
-
Requires-Dist:
|
35
|
+
Requires-Dist: PyWavelets
|
36
|
+
Requires-Dist: numba
|
36
37
|
Dynamic: author
|
37
38
|
Dynamic: author-email
|
38
39
|
Dynamic: classifier
|
@@ -9,7 +9,7 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
|
|
9
9
|
bmtool/singlecell.py,sha256=XZAT_2n44EhwqVLnk3qur9aO7oJ-10axJZfwPBslM88,27219
|
10
10
|
bmtool/synapses.py,sha256=gIkfLhKDG2dHHCVJJoKuQrFn_Qut843bfk_-s97wu6c,54553
|
11
11
|
bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
bmtool/analysis/lfp.py,sha256=
|
12
|
+
bmtool/analysis/lfp.py,sha256=TmuetoKTZBCuZbdRFRus_cby5ugCBRNpOSSEIi6Kgac,25096
|
13
13
|
bmtool/analysis/spikes.py,sha256=qqJ4zD8xfvSwltlWm_Bhicdngzl6uBqH6Kn5wOMKRc8,11507
|
14
14
|
bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
15
|
bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
|
@@ -19,9 +19,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
|
|
19
19
|
bmtool/util/util.py,sha256=00vOAwTVIifCqouBoFoT0lBashl4fCalrk8fhg_Uq4c,56654
|
20
20
|
bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
21
21
|
bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
|
22
|
-
bmtool-0.6.8.
|
23
|
-
bmtool-0.6.8.
|
24
|
-
bmtool-0.6.8.
|
25
|
-
bmtool-0.6.8.
|
26
|
-
bmtool-0.6.8.
|
27
|
-
bmtool-0.6.8.
|
22
|
+
bmtool-0.6.8.7.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
|
23
|
+
bmtool-0.6.8.7.dist-info/METADATA,sha256=Gd7eoiKxTXA85sURHMmsQqNKZsXVK48uFxkQrDoEbsQ,20478
|
24
|
+
bmtool-0.6.8.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
25
|
+
bmtool-0.6.8.7.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
|
26
|
+
bmtool-0.6.8.7.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
|
27
|
+
bmtool-0.6.8.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|