bmtool 0.6.8.5__py3-none-any.whl → 0.6.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/analysis/lfp.py CHANGED
@@ -11,6 +11,8 @@ import matplotlib.pyplot as plt
11
11
  from scipy import signal
12
12
  import pywt
13
13
  from bmtool.bmplot import is_notebook
14
+ import numba
15
+ from numba import cuda
14
16
 
15
17
 
16
18
  def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
@@ -422,10 +424,45 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarr
422
424
  return plv
423
425
 
424
426
 
427
@numba.njit(parallel=True, fastmath=True)
def _ppc_parallel_numba(spike_phases):
    """Numba-optimized parallel PPC calculation in O(n) time.

    PPC is defined as

        PPC = 2 / (n * (n - 1)) * sum_{i<j} cos(phi_i - phi_j)

    Expanding |sum_k exp(1j * phi_k)|**2 yields n diagonal terms, each equal
    to 1, plus twice the sum over all distinct pairs i < j of
    cos(phi_i - phi_j). Therefore, with C = sum_k cos(phi_k) and
    S = sum_k sin(phi_k):

        PPC = (C**2 + S**2 - n) / (n * (n - 1))

    This is mathematically identical to the explicit pairwise double loop. It
    differs only by floating-point rounding. It needs a single pass over the
    phases instead of n*(n-1)/2 cosine evaluations, and the work splits
    evenly across parallel workers.

    Parameters
    ----------
    spike_phases : 1-D array
        Phase value at each spike. Values are passed directly to cos/sin, so
        they are interpreted as radians.

    Returns
    -------
    float
        Pairwise phase consistency.

    Notes
    -----
    With fewer than two phases, the normalisation term n * (n - 1) is zero
    and the final division fails, just as in the pairwise formulation.
    """
    n = len(spike_phases)
    sum_cos = 0.0
    sum_sin = 0.0
    # Scalar `+=` updates inside prange are treated by numba as parallel
    # reductions, so each worker accumulates privately and results are combined.
    for k in numba.prange(n):
        phase = spike_phases[k]
        sum_cos += np.cos(phase)
        sum_sin += np.sin(phase)
    # |resultant|**2 minus the n self-pairs leaves twice the distinct-pair sum.
    return (sum_cos * sum_cos + sum_sin * sum_sin - n) / (n * (n - 1))
437
+
438
+
439
@cuda.jit(fastmath=True)
def _ppc_cuda_kernel(spike_phases, out):
    """CUDA kernel computing per-index partial pair sums for PPC.

    Thread ``idx`` writes into ``out[idx]`` the sum of
    cos(spike_phases[idx] - spike_phases[k]) over every later index k > idx.
    Summing ``out`` across all indices therefore gives the sum over all
    distinct pairs i < j. Threads whose global index is at or past the end
    of ``spike_phases`` exit without writing.
    """
    n_phases = len(spike_phases)
    idx = cuda.grid(1)
    # Guard clause: the launch grid may be larger than the data.
    if idx >= n_phases:
        return
    anchor = spike_phases[idx]
    acc = 0.0
    for k in range(idx + 1, n_phases):
        acc += np.cos(anchor - spike_phases[k])
    out[idx] = acc
447
+
448
+
449
def _ppc_gpu(spike_phases, threads_per_block=256):
    """GPU-accelerated PPC implementation built on ``_ppc_cuda_kernel``.

    Parameters
    ----------
    spike_phases : array-like, 1-D
        Phase value at each spike, used directly as the argument of cos.
    threads_per_block : int, optional
        CUDA block size used for the kernel launch (default 256).

    Returns
    -------
    float
        Pairwise phase consistency:
        2 / (n * (n - 1)) * sum_{i<j} cos(phi_i - phi_j).

    Raises
    ------
    ZeroDivisionError
        If fewer than two phases are given. In that case the normalisation
        term n * (n - 1) is zero and PPC is undefined.
    """
    # cuda.to_device needs a single contiguous buffer. Converting up front also
    # accepts lists and strided views and keeps the kernel in float64.
    phases = np.ascontiguousarray(spike_phases, dtype=np.float64)
    n = len(phases)
    if n < 2:
        # Checked before launch so empty input never requests a zero-block grid.
        raise ZeroDivisionError("PPC is undefined for fewer than two spikes")

    d_phases = cuda.to_device(phases)
    d_out = cuda.device_array(n, dtype=np.float64)

    # Ceiling division so every index gets a thread.
    blocks = (n + threads_per_block - 1) // threads_per_block
    _ppc_cuda_kernel[blocks, threads_per_block](d_phases, d_out)

    # out[i] holds the partial sum over j > i; their total is the full pair sum.
    total = d_out.copy_to_host().sum()
    return (2 / (n * (n - 1))) * total
460
+
461
+
425
462
  def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
426
463
  lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
427
464
  lowcut: float = None, highcut: float = None,
428
- bandwidth: float = 2.0) -> tuple:
465
+ bandwidth: float = 2.0,ppc_method: str = 'numpy') -> tuple:
429
466
  """
430
467
  Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
431
468
  Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
@@ -439,12 +476,11 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
439
476
  - freq_of_interest: Desired frequency for wavelet phase extraction
440
477
  - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
441
478
  - bandwidth: Bandwidth parameter for the wavelet
479
+ - ppc_method: which algo to use for PPC calculate can be numpy, numba or gpu
442
480
 
443
481
  Returns:
444
482
  - ppc: Pairwise Phase Consistency value
445
- - spike_phases: Phases at spike times
446
483
  """
447
- print("Note this method will a very long time if there are a lot of spikes. If there are a lot of spikes consider using the PPC2 method if speed is an issue")
448
484
  if spike_fs is None:
449
485
  spike_fs = lfp_fs
450
486
  # Convert spike times to sample indices
@@ -511,11 +547,17 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None,
511
547
  # ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
512
548
 
513
549
  # same as above (i think) but with vectorized computation and memory fixes so it wont take forever to run.
514
- i, j = np.triu_indices(n_spikes, k=1)
515
- phase_diff = spike_phases[i] - spike_phases[j]
516
- sum_cos_diff = np.sum(np.cos(phase_diff))
517
- ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
518
-
550
+ if ppc_method == 'numpy':
551
+ i, j = np.triu_indices(n_spikes, k=1)
552
+ phase_diff = spike_phases[i] - spike_phases[j]
553
+ sum_cos_diff = np.sum(np.cos(phase_diff))
554
+ ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
555
+ elif ppc_method == 'numba':
556
+ ppc = _ppc_parallel_numba(spike_phases)
557
+ elif ppc_method == 'gpu':
558
+ ppc = _ppc_gpu(spike_phases)
559
+ else:
560
+ raise ExceptionType("Please use a supported ppc method currently that is numpy, numba or gpu")
519
561
  return ppc
520
562
 
521
563
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.6.8.5
3
+ Version: 0.6.8.7
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -32,7 +32,8 @@ Requires-Dist: xarray
32
32
  Requires-Dist: fooof
33
33
  Requires-Dist: requests
34
34
  Requires-Dist: pyyaml
35
- Requires-Dist: pywt
35
+ Requires-Dist: PyWavelets
36
+ Requires-Dist: numba
36
37
  Dynamic: author
37
38
  Dynamic: author-email
38
39
  Dynamic: classifier
@@ -9,7 +9,7 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
9
9
  bmtool/singlecell.py,sha256=XZAT_2n44EhwqVLnk3qur9aO7oJ-10axJZfwPBslM88,27219
10
10
  bmtool/synapses.py,sha256=gIkfLhKDG2dHHCVJJoKuQrFn_Qut843bfk_-s97wu6c,54553
11
11
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- bmtool/analysis/lfp.py,sha256=Ei-l9aA13IOsdOEjmkqmdthKgPkEPnbiHdJ_-TB2twQ,23771
12
+ bmtool/analysis/lfp.py,sha256=TmuetoKTZBCuZbdRFRus_cby5ugCBRNpOSSEIi6Kgac,25096
13
13
  bmtool/analysis/spikes.py,sha256=qqJ4zD8xfvSwltlWm_Bhicdngzl6uBqH6Kn5wOMKRc8,11507
14
14
  bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
15
  bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
@@ -19,9 +19,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
19
19
  bmtool/util/util.py,sha256=00vOAwTVIifCqouBoFoT0lBashl4fCalrk8fhg_Uq4c,56654
20
20
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
21
21
  bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
22
- bmtool-0.6.8.5.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
23
- bmtool-0.6.8.5.dist-info/METADATA,sha256=x5YjcoEnp1Cc0KDJgRyv11GFHH4zResjgcWGAmc9fuM,20451
24
- bmtool-0.6.8.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
25
- bmtool-0.6.8.5.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
26
- bmtool-0.6.8.5.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
27
- bmtool-0.6.8.5.dist-info/RECORD,,
22
+ bmtool-0.6.8.7.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
23
+ bmtool-0.6.8.7.dist-info/METADATA,sha256=Gd7eoiKxTXA85sURHMmsQqNKZsXVK48uFxkQrDoEbsQ,20478
24
+ bmtool-0.6.8.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
25
+ bmtool-0.6.8.7.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
26
+ bmtool-0.6.8.7.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
27
+ bmtool-0.6.8.7.dist-info/RECORD,,