bmtool 0.6.8.3__py3-none-any.whl → 0.6.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/analysis/lfp.py CHANGED
@@ -292,7 +292,7 @@ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs:
292
292
  return x_a
293
293
 
294
294
 
295
- def calculate_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
295
+ def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
296
296
  method: str = 'wavelet', lowcut: float = None, highcut: float = None,
297
297
  bandwidth: float = 2.0) -> np.ndarray:
298
298
  """
@@ -345,90 +345,210 @@ def calculate_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: f
345
345
  return plv
346
346
 
347
347
 
348
- def calculate_plv_over_time(x1: np.ndarray, x2: np.ndarray, fs: float,
349
- window_size: float, step_size: float,
350
- method: str = 'wavelet', freq_of_interest: float = None,
351
- lowcut: float = None, highcut: float = None,
352
- bandwidth: float = 2.0):
348
+ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs : float = None,
349
+ lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
350
+ lowcut: float = None, highcut: float = None,
351
+ bandwidth: float = 2.0) -> tuple:
353
352
  """
354
- Calculate the time-resolved Phase Locking Value (PLV) between two signals using a sliding window approach.
353
+ Calculate spike-lfp phase locking value Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
355
354
 
356
355
  Parameters:
357
- ----------
358
- x1, x2 : array-like
359
- Input signals (1D arrays, same length).
360
- fs : float
361
- Sampling frequency of the input signals.
362
- window_size : float
363
- Length of the window in seconds for PLV calculation.
364
- step_size : float
365
- Step size in seconds to slide the window across the signals.
366
- method : str, optional
367
- Method to calculate PLV ('wavelet' or 'hilbert'). Defaults to 'wavelet'.
368
- freq_of_interest : float, optional
369
- Frequency of interest for the wavelet method. Required if method is 'wavelet'.
370
- lowcut, highcut : float, optional
371
- Cutoff frequencies for the Hilbert method. Required if method is 'hilbert'.
372
- bandwidth : float, optional
373
- Bandwidth parameter for the wavelet. Defaults to 2.0.
374
-
356
+ - spike_times: Array of spike times
357
+ - lfp_signal: Local field potential time series
358
+ - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
359
+ - lfp_fs : Sampling frequency in Hz of the LFP
360
+ - method: 'wavelet' or 'hilbert' to choose the phase extraction method
361
+ - freq_of_interest: Desired frequency for wavelet phase extraction
362
+ - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
363
+ - bandwidth: Bandwidth parameter for the wavelet
364
+
375
365
  Returns:
376
- -------
377
- plv_over_time : 1D array
378
- Array of PLV values calculated over each window.
379
- times : 1D array
380
- The center times of each window where the PLV was calculated.
366
+ - ppc1: Phase-Phase Coupling value
367
+ - spike_phases: Phases at spike times
381
368
  """
382
- # Convert window and step size from seconds to samples
383
- window_samples = int(window_size * fs)
384
- step_samples = int(step_size * fs)
385
-
386
- # Initialize results
387
- plv_over_time = []
388
- times = []
389
-
390
- # Iterate over the signal with a sliding window
391
- for start in range(0, len(x1) - window_samples + 1, step_samples):
392
- end = start + window_samples
393
- window_x1 = x1[start:end]
394
- window_x2 = x2[start:end]
369
+
370
+ if spike_fs == None:
371
+ spike_fs = lfp_fs
372
+ # Convert spike times to sample indices
373
+ spike_times_seconds = spike_times / spike_fs
374
+
375
+ # Then convert from seconds to samples at the new sampling rate
376
+ spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
377
+
378
+ # Filter indices to ensure they're within bounds of the LFP signal
379
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
380
+ if len(valid_indices) <= 1:
381
+ return 0, np.array([])
382
+
383
+ # Extract phase using the specified method
384
+ if method == 'wavelet':
385
+ if freq_of_interest is None:
386
+ raise ValueError("freq_of_interest must be provided for the wavelet method.")
395
387
 
396
- # Use the updated calculate_plv function within each window
397
- plv = calculate_plv(x1=window_x1, x2=window_x2, fs=fs,
398
- method=method, freq_of_interest=freq_of_interest,
399
- lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
400
- plv_over_time.append(plv)
388
+ # Apply CWT to extract phase at the frequency of interest
389
+ lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
390
+ instantaneous_phase = np.angle(lfp_complex)
391
+
392
+ elif method == 'hilbert':
393
+ if lowcut is None or highcut is None:
394
+ print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
395
+ filtered_lfp = lfp_signal
396
+ else:
397
+ # Bandpass filter the signal
398
+ filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
401
399
 
402
- # Store the time at the center of the window
403
- center_time = (start + end) / 2 / fs
404
- times.append(center_time)
400
+ # Get phase using the Hilbert transform
401
+ analytic_signal = signal.hilbert(filtered_lfp)
402
+ instantaneous_phase = np.angle(analytic_signal)
403
+
404
+ else:
405
+ raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
405
406
 
406
- return np.array(plv_over_time), np.array(times)
407
+ # Get phases at spike times
408
+ spike_phases = instantaneous_phase[valid_indices]
409
+
410
+ # Calculate PPC1
411
+ n = len(spike_phases)
412
+
413
+ # Convert phases to unit vectors in the complex plane
414
+ unit_vectors = np.exp(1j * spike_phases)
415
+
416
+ # Calculate the resultant vector
417
+ resultant_vector = np.sum(unit_vectors)
418
+
419
+ # Plv is the squared length of the resultant vector divided by n²
420
+ plv = (np.abs(resultant_vector) ** 2) / (n ** 2)
421
+
422
+ return plv
407
423
 
408
424
 
409
- def calculate_ppc1(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs : float = None,
410
- lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
411
- lowcut: float = None, highcut: float = None,
412
- bandwidth: float = 2.0) -> tuple:
425
+ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
426
+ lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
427
+ lowcut: float = None, highcut: float = None,
428
+ bandwidth: float = 2.0) -> tuple:
413
429
  """
414
- Calculate Phase-Phase Coupling (PPC1) between spike times and LFP signal. Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
430
+ Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
431
+ Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
415
432
 
416
433
  Parameters:
417
434
  - spike_times: Array of spike times
418
435
  - lfp_signal: Local field potential time series
419
436
  - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
420
- - lfp_fs : Sampling frequency in Hz of the LFP
437
+ - lfp_fs: Sampling frequency in Hz of the LFP
421
438
  - method: 'wavelet' or 'hilbert' to choose the phase extraction method
422
439
  - freq_of_interest: Desired frequency for wavelet phase extraction
423
440
  - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
424
441
  - bandwidth: Bandwidth parameter for the wavelet
425
442
 
426
443
  Returns:
427
- - ppc1: Phase-Phase Coupling value
444
+ - ppc: Pairwise Phase Consistency value
428
445
  - spike_phases: Phases at spike times
429
446
  """
447
+ print("Note this method will a very long time if there are a lot of spikes. If there are a lot of spikes consider using the PPC2 method if speed is an issue")
448
+ if spike_fs is None:
449
+ spike_fs = lfp_fs
450
+ # Convert spike times to sample indices
451
+ spike_times_seconds = spike_times / spike_fs
452
+
453
+ # Then convert from seconds to samples at the new sampling rate
454
+ spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
430
455
 
431
- if spike_fs == None:
456
+ # Filter indices to ensure they're within bounds of the LFP signal
457
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
458
+ if len(valid_indices) <= 1:
459
+ return 0, np.array([])
460
+
461
+ # Extract phase using the specified method
462
+ if method == 'wavelet':
463
+ if freq_of_interest is None:
464
+ raise ValueError("freq_of_interest must be provided for the wavelet method.")
465
+
466
+ # Apply CWT to extract phase at the frequency of interest
467
+ lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
468
+ instantaneous_phase = np.angle(lfp_complex)
469
+
470
+ elif method == 'hilbert':
471
+ if lowcut is None or highcut is None:
472
+ print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
473
+ filtered_lfp = lfp_signal
474
+ else:
475
+ # Bandpass filter the signal
476
+ filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
477
+
478
+ # Get phase using the Hilbert transform
479
+ analytic_signal = signal.hilbert(filtered_lfp)
480
+ instantaneous_phase = np.angle(analytic_signal)
481
+
482
+ else:
483
+ raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
484
+
485
+ # Get phases at spike times
486
+ spike_phases = instantaneous_phase[valid_indices]
487
+
488
+ n_spikes = len(spike_phases)
489
+
490
+ # Calculate PPC (Pairwise Phase Consistency)
491
+ if n_spikes <= 1:
492
+ return 0, spike_phases
493
+
494
+ # Explicit calculation of pairwise phase consistency
495
+ sum_cos_diff = 0.0
496
+
497
+ # # Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
498
+ # for i in range(n_spikes - 1): # For each spike i
499
+ # for j in range(i + 1, n_spikes): # For each spike j > i
500
+ # # Calculate the phase difference between spikes i and j
501
+ # phase_diff = spike_phases[i] - spike_phases[j]
502
+
503
+ # #f(θᵢ, θⱼ) = cos(θᵢ)cos(θⱼ) + sin(θᵢ)sin(θⱼ) can become #f(θᵢ, θⱼ) = cos(θᵢ - θⱼ)
504
+ # cos_diff = np.cos(phase_diff)
505
+
506
+ # # Add to the sum
507
+ # sum_cos_diff += cos_diff
508
+
509
+ # # Calculate PPC according to the equation
510
+ # # PPC = (2 / (N(N-1))) * Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
511
+ # ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
512
+
513
+ # same as above (i think) but with vectorized computation and memory fixes so it wont take forever to run.
514
+ i, j = np.triu_indices(n_spikes, k=1)
515
+ phase_diff = spike_phases[i] - spike_phases[j]
516
+ sum_cos_diff = np.sum(np.cos(phase_diff))
517
+ ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
518
+
519
+ return ppc
520
+
521
+
522
+ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
523
+ lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
524
+ lowcut: float = None, highcut: float = None,
525
+ bandwidth: float = 2.0) -> tuple:
526
+ """
527
+ # -----------------------------------------------------------------------------
528
+ # PPC2 Calculation (Vinck et al., 2010)
529
+ # -----------------------------------------------------------------------------
530
+ # Equation(Original):
531
+ # PPC = (2 / (n * (n - 1))) * sum(cos(φ_i - φ_j) for all i < j)
532
+ # Optimized Formula (Algebraically Equivalent):
533
+ # PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
534
+ # -----------------------------------------------------------------------------
535
+
536
+ Parameters:
537
+ - spike_times: Array of spike times
538
+ - lfp_signal: Local field potential time series
539
+ - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
540
+ - lfp_fs: Sampling frequency in Hz of the LFP
541
+ - method: 'wavelet' or 'hilbert' to choose the phase extraction method
542
+ - freq_of_interest: Desired frequency for wavelet phase extraction
543
+ - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
544
+ - bandwidth: Bandwidth parameter for the wavelet
545
+
546
+ Returns:
547
+ - ppc2: Pairwise Phase Consistency 2 value
548
+ - spike_phases: Phases at spike times
549
+ """
550
+
551
+ if spike_fs is None:
432
552
  spike_fs = lfp_fs
433
553
  # Convert spike times to sample indices
434
554
  spike_times_seconds = spike_times / spike_fs
@@ -452,7 +572,7 @@ def calculate_ppc1(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
452
572
 
453
573
  elif method == 'hilbert':
454
574
  if lowcut is None or highcut is None:
455
- print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
575
+ print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
456
576
  filtered_lfp = lfp_signal
457
577
  else:
458
578
  # Bandpass filter the signal
@@ -468,16 +588,23 @@ def calculate_ppc1(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None
468
588
  # Get phases at spike times
469
589
  spike_phases = instantaneous_phase[valid_indices]
470
590
 
471
- # Calculate PPC1
591
+ # Calculate PPC2 according to Vinck et al. (2010), Equation 6
472
592
  n = len(spike_phases)
473
593
 
594
+ if n <= 1:
595
+ return 0, spike_phases
596
+
474
597
  # Convert phases to unit vectors in the complex plane
475
598
  unit_vectors = np.exp(1j * spike_phases)
476
599
 
477
600
  # Calculate the resultant vector
478
601
  resultant_vector = np.sum(unit_vectors)
479
602
 
480
- # PPC1 is the squared length of the resultant vector divided by n²
481
- ppc1 = (np.abs(resultant_vector) ** 2) / (n ** 2)
603
+ # PPC2 = (|∑(e^(i*φ_j))|² - n) / (n * (n - 1))
604
+ ppc2 = (np.abs(resultant_vector)**2 - n) / (n * (n - 1))
482
605
 
483
- return ppc1, spike_phases
606
+ return ppc2
607
+
608
+
609
+
610
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.6.8.3
3
+ Version: 0.6.8.5
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -32,6 +32,7 @@ Requires-Dist: xarray
32
32
  Requires-Dist: fooof
33
33
  Requires-Dist: requests
34
34
  Requires-Dist: pyyaml
35
+ Requires-Dist: pywt
35
36
  Dynamic: author
36
37
  Dynamic: author-email
37
38
  Dynamic: classifier
@@ -9,7 +9,7 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
9
9
  bmtool/singlecell.py,sha256=XZAT_2n44EhwqVLnk3qur9aO7oJ-10axJZfwPBslM88,27219
10
10
  bmtool/synapses.py,sha256=gIkfLhKDG2dHHCVJJoKuQrFn_Qut843bfk_-s97wu6c,54553
11
11
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- bmtool/analysis/lfp.py,sha256=bCpxqhdH6r71yXyqAv_M7BMq4x75lO7bctyyQi6pqdU,18186
12
+ bmtool/analysis/lfp.py,sha256=Ei-l9aA13IOsdOEjmkqmdthKgPkEPnbiHdJ_-TB2twQ,23771
13
13
  bmtool/analysis/spikes.py,sha256=qqJ4zD8xfvSwltlWm_Bhicdngzl6uBqH6Kn5wOMKRc8,11507
14
14
  bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
15
  bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
@@ -19,9 +19,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
19
19
  bmtool/util/util.py,sha256=00vOAwTVIifCqouBoFoT0lBashl4fCalrk8fhg_Uq4c,56654
20
20
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
21
21
  bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
22
- bmtool-0.6.8.3.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
23
- bmtool-0.6.8.3.dist-info/METADATA,sha256=jT598Nn_w-_OrCzo7L0cW1wtrv-OKW2aIDhxAqXA2-Q,20431
24
- bmtool-0.6.8.3.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
25
- bmtool-0.6.8.3.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
26
- bmtool-0.6.8.3.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
27
- bmtool-0.6.8.3.dist-info/RECORD,,
22
+ bmtool-0.6.8.5.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
23
+ bmtool-0.6.8.5.dist-info/METADATA,sha256=x5YjcoEnp1Cc0KDJgRyv11GFHH4zResjgcWGAmc9fuM,20451
24
+ bmtool-0.6.8.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
25
+ bmtool-0.6.8.5.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
26
+ bmtool-0.6.8.5.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
27
+ bmtool-0.6.8.5.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (77.0.3)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5