bmtool 0.7.0.4__py3-none-any.whl → 0.7.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,13 +7,10 @@ from scipy import signal
7
7
  import numba
8
8
  from numba import cuda
9
9
  import pandas as pd
10
- import xarray as xr
11
- from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power
12
- from typing import Dict, List
10
+ from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power, get_lfp_phase
11
+ from typing import Dict, List, Optional
13
12
  from tqdm.notebook import tqdm
14
13
  import scipy.stats as stats
15
- import seaborn as sns
16
- import matplotlib.pyplot as plt
17
14
 
18
15
 
19
16
  def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None,
@@ -87,17 +84,17 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
87
84
 
88
85
  def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
89
86
  lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
90
- lowcut: float = None, highcut: float = None,
91
- bandwidth: float = 2.0) -> tuple:
87
+ lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
88
+ filtered_lfp_phase: np.ndarray = None) -> float:
92
89
  """
93
- Calculate spike-lfp phase locking value Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
90
+ Calculate spike-lfp unbiased phase locking value
94
91
 
95
92
  Parameters
96
93
  ----------
97
94
  spike_times : np.ndarray
98
95
  Array of spike times
99
96
  lfp_data : np.ndarray
100
- Local field potential time series data
97
+ Local field potential time series data. Not required if filtered_lfp_phase is provided.
101
98
  spike_fs : float, optional
102
99
  Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
103
100
  lfp_fs : float
@@ -112,13 +109,13 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
112
109
  Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
113
110
  bandwidth : float, optional
114
111
  Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
112
+ filtered_lfp_phase : np.ndarray, optional
113
+ Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
115
114
 
116
115
  Returns
117
116
  -------
118
- tuple
119
- (plv, spike_phases) where:
120
- - plv: Phase Locking Value
121
- - spike_phases: Phases at spike times
117
+ float
118
+ Phase Locking Value (unbiased)
122
119
  """
123
120
 
124
121
  if spike_fs is None:
@@ -130,46 +127,41 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
130
127
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
131
128
 
132
129
  # Filter indices to ensure they're within bounds of the LFP signal
133
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
134
- if len(valid_indices) <= 1:
135
- return 0, np.array([])
136
-
137
- # Filter the LFP signal to extract the phase
138
- if filter_method == 'wavelet':
139
- if freq_of_interest is None:
140
- raise ValueError("freq_of_interest must be provided for the wavelet method.")
141
-
142
- # Apply CWT to extract phase
143
- filtered_lfp = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
144
-
145
- elif filter_method == 'butter':
146
- if lowcut is None or highcut is None:
147
- raise ValueError("Both lowcut and highcut must be specified for the butter method.")
130
+ if filtered_lfp_phase is not None:
131
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
132
+ else:
133
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
148
134
 
149
- # Bandpass filter the LFP signal
150
- filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
151
- filtered_lfp = signal.hilbert(filtered_lfp) # Get analytic signal
135
+ if len(valid_indices) <= 1:
136
+ return 0
152
137
 
153
-
138
+ # Get instantaneous phase
139
+ if filtered_lfp_phase is None:
140
+ instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
141
+ freq_of_interest=freq_of_interest, lowcut=lowcut,
142
+ highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
154
143
  else:
155
- raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
144
+ instantaneous_phase = filtered_lfp_phase
156
145
 
157
146
  # Get phases at spike times
158
147
  spike_phases = instantaneous_phase[valid_indices]
159
-
160
- # Calculate PPC1
161
- n = len(spike_phases)
148
+
149
+ # Number of spikes
150
+ N = len(spike_phases)
162
151
 
163
152
  # Convert phases to unit vectors in the complex plane
164
153
  unit_vectors = np.exp(1j * spike_phases)
165
154
 
166
- # Calculate the resultant vector
155
+ # Sum of all unit vectors (resultant vector)
167
156
  resultant_vector = np.sum(unit_vectors)
168
-
169
- # Plv is the squared length of the resultant vector divided by n²
170
- plv = (np.abs(resultant_vector) ** 2) / (n ** 2)
171
-
172
- return plv
157
+
158
+ # Calculate plv^2 * N
159
+ plv2n = (resultant_vector * resultant_vector.conjugate()).real / N # plv^2 * N
160
+ plv = (plv2n / N) ** 0.5
161
+ ppc = (plv2n - 1) / (N - 1) # ppc = (plv^2 * N - 1) / (N - 1)
162
+ plv_unbiased = np.fmax(ppc, 0.) ** 0.5 # ensure non-negative
163
+
164
+ return plv_unbiased
173
165
 
174
166
 
175
167
  @numba.njit(parallel=True, fastmath=True)
@@ -209,8 +201,8 @@ def _ppc_gpu(spike_phases):
209
201
 
210
202
  def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
211
203
  lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
212
- lowcut: float = None, highcut: float = None,
213
- bandwidth: float = 2.0, ppc_method: str = 'numpy') -> tuple:
204
+ lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
205
+ ppc_method: str = 'numpy', filtered_lfp_phase: np.ndarray = None) -> float:
214
206
  """
215
207
  Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
216
208
  Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
@@ -220,7 +212,7 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
220
212
  spike_times : np.ndarray
221
213
  Array of spike times
222
214
  lfp_data : np.ndarray
223
- Local field potential time series data
215
+ Local field potential time series data. Not required if filtered_lfp_phase is provided.
224
216
  spike_fs : float, optional
225
217
  Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
226
218
  lfp_fs : float
@@ -237,13 +229,13 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
237
229
  Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
238
230
  ppc_method : str, optional
239
231
  Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
232
+ filtered_lfp_phase : np.ndarray, optional
233
+ Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
240
234
 
241
235
  Returns
242
236
  -------
243
- tuple
244
- (ppc, spike_phases) where:
245
- - ppc: Pairwise Phase Consistency value
246
- - spike_phases: Phases at spike times
237
+ float
238
+ Pairwise Phase Consistency value
247
239
  """
248
240
  if spike_fs is None:
249
241
  spike_fs = lfp_fs
@@ -254,32 +246,21 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
254
246
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
255
247
 
256
248
  # Filter indices to ensure they're within bounds of the LFP signal
257
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
249
+ if filtered_lfp_phase is not None:
250
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
251
+ else:
252
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
253
+
258
254
  if len(valid_indices) <= 1:
259
- return 0, np.array([])
255
+ return 0
260
256
 
261
- # Extract phase using the specified method
262
- if filter_method == 'wavelet':
263
- if freq_of_interest is None:
264
- raise ValueError("freq_of_interest must be provided for the wavelet method.")
265
-
266
- # Apply CWT to extract phase at the frequency of interest
267
- lfp_complex = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
268
- instantaneous_phase = np.angle(lfp_complex)
269
-
270
- elif filter_method == 'butter':
271
- if lowcut is None or highcut is None:
272
- raise ValueError("Both lowcut and highcut must be specified for the butter method.")
273
-
274
- # Bandpass filter the signal
275
- filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
276
-
277
- # Get phase using the Hilbert transform
278
- analytic_signal = signal.hilbert(filtered_lfp)
279
- instantaneous_phase = np.angle(analytic_signal)
280
-
257
+ # Get instantaneous phase
258
+ if filtered_lfp_phase is None:
259
+ instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
260
+ freq_of_interest=freq_of_interest, lowcut=lowcut,
261
+ highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
281
262
  else:
282
- raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
263
+ instantaneous_phase = filtered_lfp_phase
283
264
 
284
265
  # Get phases at spike times
285
266
  spike_phases = instantaneous_phase[valid_indices]
@@ -288,28 +269,10 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
288
269
 
289
270
  # Calculate PPC (Pairwise Phase Consistency)
290
271
  if n_spikes <= 1:
291
- return 0, spike_phases
272
+ return 0
292
273
 
293
274
  # Explicit calculation of pairwise phase consistency
294
- sum_cos_diff = 0.0
295
-
296
- # # Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
297
- # for i in range(n_spikes - 1): # For each spike i
298
- # for j in range(i + 1, n_spikes): # For each spike j > i
299
- # # Calculate the phase difference between spikes i and j
300
- # phase_diff = spike_phases[i] - spike_phases[j]
301
-
302
- # #f(θᵢ, θⱼ) = cos(θᵢ)cos(θⱼ) + sin(θᵢ)sin(θⱼ) can become #f(θᵢ, θⱼ) = cos(θᵢ - θⱼ)
303
- # cos_diff = np.cos(phase_diff)
304
-
305
- # # Add to the sum
306
- # sum_cos_diff += cos_diff
307
-
308
- # # Calculate PPC according to the equation
309
- # # PPC = (2 / (N(N-1))) * Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
310
- # ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
311
-
312
- # same as above (i think) but with vectorized computation and memory fixes so it wont take forever to run.
275
+ # Vectorized computation for efficiency
313
276
  if ppc_method == 'numpy':
314
277
  i, j = np.triu_indices(n_spikes, k=1)
315
278
  phase_diff = spike_phases[i] - spike_phases[j]
@@ -320,14 +283,14 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
320
283
  elif ppc_method == 'gpu':
321
284
  ppc = _ppc_gpu(spike_phases)
322
285
  else:
323
- raise ExceptionType("Please use a supported ppc method currently that is numpy, numba or gpu")
286
+ raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")
324
287
  return ppc
325
288
 
326
289
 
327
290
  def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
328
291
  lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
329
- lowcut: float = None, highcut: float = None,
330
- bandwidth: float = 2.0) -> float:
292
+ lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
293
+ filtered_lfp_phase: np.ndarray = None) -> float:
331
294
  """
332
295
  # -----------------------------------------------------------------------------
333
296
  # PPC2 Calculation (Vinck et al., 2010)
@@ -343,7 +306,7 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
343
306
  spike_times : np.ndarray
344
307
  Array of spike times
345
308
  lfp_data : np.ndarray
346
- Local field potential time series data
309
+ Local field potential time series data. Not required if filtered_lfp_phase is provided.
347
310
  spike_fs : float, optional
348
311
  Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
349
312
  lfp_fs : float
@@ -358,6 +321,8 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
358
321
  Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
359
322
  bandwidth : float, optional
360
323
  Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
324
+ filtered_lfp_phase : np.ndarray, optional
325
+ Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
361
326
 
362
327
  Returns
363
328
  -------
@@ -374,32 +339,21 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
374
339
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
375
340
 
376
341
  # Filter indices to ensure they're within bounds of the LFP signal
377
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
342
+ if filtered_lfp_phase is not None:
343
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
344
+ else:
345
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
346
+
378
347
  if len(valid_indices) <= 1:
379
348
  return 0
380
349
 
381
- # Extract phase using the specified method
382
- if filter_method == 'wavelet':
383
- if freq_of_interest is None:
384
- raise ValueError("freq_of_interest must be provided for the wavelet method.")
385
-
386
- # Apply CWT to extract phase at the frequency of interest
387
- lfp_complex = wavelet_filter(x=lfp_data, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
388
- instantaneous_phase = np.angle(lfp_complex)
389
-
390
- elif filter_method == 'butter':
391
- if lowcut is None or highcut is None:
392
- raise ValueError("Both lowcut and highcut must be specified for the butter method.")
393
-
394
- # Bandpass filter the signal
395
- filtered_lfp = butter_bandpass_filter(data=lfp_data, lowcut=lowcut, highcut=highcut, fs=lfp_fs)
396
-
397
- # Get phase using the Hilbert transform
398
- analytic_signal = signal.hilbert(filtered_lfp)
399
- instantaneous_phase = np.angle(analytic_signal)
400
-
350
+ # Get instantaneous phase
351
+ if filtered_lfp_phase is None:
352
+ instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
353
+ freq_of_interest=freq_of_interest, lowcut=lowcut,
354
+ highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
401
355
  else:
402
- raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
356
+ instantaneous_phase = filtered_lfp_phase
403
357
 
404
358
  # Get phases at spike times
405
359
  spike_phases = instantaneous_phase[valid_indices]
@@ -422,14 +376,15 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
422
376
  return ppc2
423
377
 
424
378
 
425
- def calculate_ppc_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None,
379
+ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None, filter_method: str='wavelet', pop_names: List[str]=None,
380
+ entrainment_method: str='plv', lowcut: float=None, highcut: float=None,
426
381
  spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
427
- pop_names: List[str]=None, freqs: List[float]=None) -> Dict[str, Dict[int, Dict[float, float]]]:
382
+ freqs: List[float]=None, ppc_method: str='numpy',) -> Dict[str, Dict[int, Dict[float, float]]]:
428
383
  """
429
- Calculate pairwise phase consistency (PPC) per neuron (cell) for specified frequencies across different populations.
384
+ Calculate neural entrainment (PPC, PLV) per neuron (cell) for specified frequencies across different populations.
430
385
 
431
- This function computes the PPC for each neuron within the specified populations based on their spike times
432
- and the provided LFP signal. It returns a nested dictionary structure containing the PPC values
386
+ This function computes the entrainment metrics for each neuron within the specified populations based on their spike times
387
+ and the provided LFP signal. It returns a nested dictionary structure containing the entrainment values
433
388
  organized by population, node ID, and frequency.
434
389
 
435
390
  Parameters
@@ -438,14 +393,26 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=Non
438
393
  DataFrame containing spike data with columns 'pop_name', 'node_ids', and 'timestamps'
439
394
  lfp_data : np.ndarray
440
395
  Local field potential (LFP) time series data
396
+ filter_method : str, optional
397
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
398
+ entrainment_method : str, optional
399
+ Method to use for entrainment calculation, either 'plv', 'ppc', or 'ppc2' (default: 'plv')
400
+ lowcut : float, optional
401
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
402
+ highcut : float, optional
403
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
441
404
  spike_fs : float
442
405
  Sampling frequency of the spike times in Hz
443
406
  lfp_fs : float
444
407
  Sampling frequency of the LFP signal in Hz
408
+ bandwidth : float, optional
409
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
410
+ ppc_method : str, optional
411
+ Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
445
412
  pop_names : List[str]
446
413
  List of population names to analyze
447
414
  freqs : List[float]
448
- List of frequencies (in Hz) at which to calculate PPC
415
+ List of frequencies (in Hz) at which to calculate entrainment
449
416
 
450
417
  Returns
451
418
  -------
@@ -454,18 +421,32 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=Non
454
421
  {
455
422
  population_name: {
456
423
  node_id: {
457
- frequency: PPC value
424
+ frequency: entrainment value
458
425
  }
459
426
  }
460
427
  }
461
- PPC values are floats representing the pairwise phase consistency at each frequency
428
+ Entrainment values are floats representing the metric (PPC, PLV) at each frequency
462
429
  """
463
- ppc_dict = {}
430
+ # pre filter lfp to speed up calculate of entrainment
431
+ filtered_lfp_phases = {}
432
+ for freq in range(len(freqs)):
433
+ phase = get_lfp_phase(
434
+ lfp_data=lfp_data,
435
+ freq_of_interest=freqs[freq],
436
+ fs=lfp_fs,
437
+ filter_method=filter_method,
438
+ lowcut=lowcut,
439
+ highcut=highcut,
440
+ bandwidth=bandwidth
441
+ )
442
+ filtered_lfp_phases[freqs[freq]] = phase
443
+
444
+ entrainment_dict = {}
464
445
  for pop in pop_names:
465
446
  skip_count = 0
466
447
  pop_spikes = spike_df[spike_df['pop_name'] == pop]
467
448
  nodes = pop_spikes['node_ids'].unique()
468
- ppc_dict[pop] = {}
449
+ entrainment_dict[pop] = {}
469
450
  print(f'Processing {pop} population')
470
451
  for node in tqdm(nodes):
471
452
  node_spikes = pop_spikes[pop_spikes['node_ids'] == node]
@@ -475,22 +456,53 @@ def calculate_ppc_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=Non
475
456
  skip_count += 1
476
457
  continue
477
458
 
478
- ppc_dict[pop][node] = {}
459
+ entrainment_dict[pop][node] = {}
479
460
  for freq in freqs:
480
- ppc = calculate_ppc2(
481
- node_spikes['timestamps'].values,
482
- lfp_data,
483
- spike_fs=spike_fs,
484
- lfp_fs=lfp_fs,
485
- freq_of_interest=freq,
486
- bandwidth=bandwidth,
487
- filter_method='wavelet'
488
- )
489
- ppc_dict[pop][node][freq] = ppc
461
+ # Calculate entrainment based on the selected method using the pre-filtered phases
462
+ if entrainment_method == 'plv':
463
+ entrainment_dict[pop][node][freq] = calculate_spike_lfp_plv(
464
+ node_spikes['timestamps'].values,
465
+ lfp_data,
466
+ spike_fs=spike_fs,
467
+ lfp_fs=lfp_fs,
468
+ freq_of_interest=freq,
469
+ bandwidth=bandwidth,
470
+ lowcut=lowcut,
471
+ highcut=highcut,
472
+ filter_method=filter_method,
473
+ filtered_lfp_phase=filtered_lfp_phases[freq]
474
+ )
475
+ elif entrainment_method == 'ppc2':
476
+ entrainment_dict[pop][node][freq] = calculate_ppc2(
477
+ node_spikes['timestamps'].values,
478
+ lfp_data,
479
+ spike_fs=spike_fs,
480
+ lfp_fs=lfp_fs,
481
+ freq_of_interest=freq,
482
+ bandwidth=bandwidth,
483
+ lowcut=lowcut,
484
+ highcut=highcut,
485
+ filter_method=filter_method,
486
+ filtered_lfp_phase=filtered_lfp_phases[freq]
487
+ )
488
+ elif entrainment_method == 'ppc':
489
+ entrainment_dict[pop][node][freq] = calculate_ppc(
490
+ node_spikes['timestamps'].values,
491
+ lfp_data,
492
+ spike_fs=spike_fs,
493
+ lfp_fs=lfp_fs,
494
+ freq_of_interest=freq,
495
+ bandwidth=bandwidth,
496
+ lowcut=lowcut,
497
+ highcut=highcut,
498
+ filter_method=filter_method,
499
+ ppc_method=ppc_method,
500
+ filtered_lfp_phase=filtered_lfp_phases[freq]
501
+ )
490
502
 
491
- print(f'Calculated PPC for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')
503
+ print(f'Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')
492
504
 
493
- return ppc_dict
505
+ return entrainment_dict
494
506
 
495
507
 
496
508
  def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
@@ -540,12 +552,8 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names,
540
552
  # Calculate power at each frequency band using specified filter
541
553
  power_by_freq = {}
542
554
  for freq in frequencies:
543
- if filter_method == 'wavelet':
544
- power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
545
- lowcut=None, highcut=None, bandwidth=bandwidth)
546
- elif filter_method == 'butter':
547
- power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
548
- lowcut=lowcut, highcut=highcut)
555
+ power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
556
+ lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
549
557
 
550
558
  # Calculate correlation for each population
551
559
  for pop in pop_names:
bmtool/analysis/lfp.py CHANGED
@@ -412,6 +412,61 @@ def get_lfp_power(lfp_data: np.ndarray, freq: float, fs: float, filter_method: s
412
412
  return power
413
413
 
414
414
 
415
+ def get_lfp_phase(lfp_data: np.ndarray, freq_of_interest: float, fs: float, filter_method: str = 'wavelet',
416
+ lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
417
+ """
418
+ Calculate the phase of the filtered signal.
419
+
420
+ Parameters
421
+ ----------
422
+ lfp_data : np.ndarray
423
+ Input LFP data
424
+ fs : float
425
+ Sampling frequency (Hz)
426
+ freq : float
427
+ Frequency of interest (Hz)
428
+ filter_method : str, optional
429
+ Method for filtering the signal ('wavelet' or 'butter')
430
+ bandwidth : float, optional
431
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 1.0)
432
+ lowcut : float, optional
433
+ Low cutoff frequency for Butterworth filter when method='butter'
434
+ highcut : float, optional
435
+ High cutoff frequency for Butterworth filter when method='butter'
436
+
437
+ Returns
438
+ -------
439
+ np.ndarray
440
+ Phase of the filtered signal
441
+
442
+ Notes
443
+ -----
444
+ - The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
445
+ - The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
446
+ followed by Hilbert transform to extract the phase
447
+ - When using the 'butter' method, both lowcut and highcut must be provided
448
+ """
449
+ if filter_method == 'wavelet':
450
+ if freq_of_interest is None:
451
+ raise ValueError("freq_of_interest must be provided for the wavelet method.")
452
+ # Wavelet filter returns complex values directly
453
+ filtered_signal = wavelet_filter(lfp_data, freq_of_interest, fs, bandwidth)
454
+ # Phase is the angle of the complex signal
455
+ phase = np.angle(filtered_signal)
456
+ elif filter_method == 'butter':
457
+ if lowcut is None or highcut is None:
458
+ raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
459
+ # Butterworth filter returns real values
460
+ filtered_signal = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
461
+ # Apply Hilbert transform to get analytic signal (complex)
462
+ analytic_signal = signal.hilbert(filtered_signal)
463
+ # Phase is the angle of the analytic signal
464
+ phase = np.angle(analytic_signal)
465
+ else:
466
+ raise ValueError(f"Invalid method {filter_method}. Choose 'wavelet' or 'butter'.")
467
+
468
+ return phase
469
+
415
470
  # windowing functions
416
471
  def windowed_xarray(da, windows, dim='time',
417
472
  new_coord_name='cycle', new_coord=None):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.7.0.4
3
+ Version: 0.7.0.6
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -8,8 +8,8 @@ bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
8
8
  bmtool/singlecell.py,sha256=imcdxIzvYVkaOLSGDxYp8WGGssGwXXBCRhzhlqVp7hA,44267
9
9
  bmtool/synapses.py,sha256=Ow2fZavA_3_5BYCjcgPjW0YsyVOetn1wvLxL7hQvbZo,64556
10
10
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- bmtool/analysis/entrainment.py,sha256=JRg9sQ7WrZMqHwMaDJtCN7kGgZHJ5msUSzP-JPltC8k,23158
12
- bmtool/analysis/lfp.py,sha256=hOqD4xcDEL0NrNIN2-Ler_mkvY5cEUhxr7VUdX5Gwh8,21737
11
+ bmtool/analysis/entrainment.py,sha256=7lFlGMApL_2snwdvIPDLFW1KKPdyiuCnZ5ADa7ujx5o,24439
12
+ bmtool/analysis/lfp.py,sha256=6u9cHnac-5Fzpk9ecQew7MmXBAolzKZakRsnPn3-C2U,24109
13
13
  bmtool/analysis/netcon_reports.py,sha256=7moyoUC45Cl1_6sGqwZ5aKphK_8i4AimroePXcgUnIo,3057
14
14
  bmtool/analysis/spikes.py,sha256=kcJZQsvPVzQgcuiO-El_4OODW57hwNwdok_RsFMITCg,15097
15
15
  bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
26
26
  bmtool/util/util.py,sha256=XR0qZnv_Q47jMBKQpFzCSkCuKe9u8L3YSGJAOpP2zT0,57630
27
27
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
29
- bmtool-0.7.0.4.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
- bmtool-0.7.0.4.dist-info/METADATA,sha256=I-D4fwZQIHcHvD6Ou8Az3Kclwl17kwpZH7YHrD0eEg4,2768
31
- bmtool-0.7.0.4.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
32
- bmtool-0.7.0.4.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
- bmtool-0.7.0.4.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
- bmtool-0.7.0.4.dist-info/RECORD,,
29
+ bmtool-0.7.0.6.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
+ bmtool-0.7.0.6.dist-info/METADATA,sha256=OiLqYFEe7N56N_6X1j89oy5zofo768sltjlCyqj4OxQ,2768
31
+ bmtool-0.7.0.6.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
32
+ bmtool-0.7.0.6.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
+ bmtool-0.7.0.6.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
+ bmtool-0.7.0.6.dist-info/RECORD,,