bmtool 0.7.0.6.4__py3-none-any.whl → 0.7.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,23 +2,62 @@
2
2
  Module for entrainment analysis
3
3
  """
4
4
 
5
- import numpy as np
6
- from scipy import signal
5
+ from typing import Dict, List
6
+
7
7
  import numba
8
- from numba import cuda
8
+ import numpy as np
9
9
  import pandas as pd
10
- from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power, get_lfp_phase
11
- from typing import Dict, List, Optional
12
- from tqdm.notebook import tqdm
13
10
  import scipy.stats as stats
11
+ import xarray as xr
12
+ from numba import cuda
13
+ from scipy import signal
14
+ from tqdm.notebook import tqdm
15
+
16
+ from .lfp import butter_bandpass_filter, get_lfp_phase, get_lfp_power, wavelet_filter
17
+
18
+
19
+ def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.ndarray:
20
+ """the lfp xarray should have a time axis. use that to align the spike times since the lfp can start at a
21
+ non-zero time after sliced. Both need to be on same fs for this to be correct.
14
22
 
23
+ Parameters
24
+ ----------
25
+ lfp : xarray.DataArray
26
+ LFP data with time coordinates
27
+ timestamps : np.ndarray
28
+ Array of spike timestamps
15
29
 
16
- def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None,
17
- filter_method: str = 'wavelet', lowcut: float = None, highcut: float = None,
18
- bandwidth: float = 2.0) -> np.ndarray:
30
+ Returns
31
+ -------
32
+ np.ndarray
33
+ Copy of timestamps with adjusted timestamps to align with lfp.
34
+ """
35
+ # print("Pairing LFP and Spike Times")
36
+ # print(lfp.time.values)
37
+ # print(f"LFP starts at {lfp.time.values[0]}ms")
38
+ # need to make sure lfp and spikes have the same time axis
39
+ # align spikes with lfp
40
+ timestamps = timestamps[
41
+ (timestamps >= lfp.time.values[0]) & (timestamps <= lfp.time.values[-1])
42
+ ].copy()
43
+ # set the time axis of the spikes to match the lfp
44
+ timestamps = timestamps - lfp.time.values[0]
45
+ return timestamps
46
+
47
+
48
+ def calculate_signal_signal_plv(
49
+ signal1: np.ndarray,
50
+ signal2: np.ndarray,
51
+ fs: float,
52
+ freq_of_interest: float = None,
53
+ filter_method: str = "wavelet",
54
+ lowcut: float = None,
55
+ highcut: float = None,
56
+ bandwidth: float = 2.0,
57
+ ) -> np.ndarray:
19
58
  """
20
59
  Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
21
-
60
+
22
61
  Parameters
23
62
  ----------
24
63
  signal1 : np.ndarray
@@ -37,7 +76,7 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
37
76
  Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
38
77
  bandwidth : float, optional
39
78
  Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
40
-
79
+
41
80
  Returns
42
81
  -------
43
82
  np.ndarray
@@ -45,23 +84,29 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
45
84
  """
46
85
  if len(signal1) != len(signal2):
47
86
  raise ValueError("Input signals must have the same length.")
48
-
49
- if filter_method == 'wavelet':
87
+
88
+ if filter_method == "wavelet":
50
89
  if freq_of_interest is None:
51
90
  raise ValueError("freq_of_interest must be provided for the wavelet method.")
52
-
91
+
53
92
  # Apply CWT to both signals
54
93
  theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
55
94
  theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
56
-
57
- elif filter_method == 'butter':
95
+
96
+ elif filter_method == "butter":
58
97
  if lowcut is None or highcut is None:
59
- print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation")
60
-
98
+ print(
99
+ "Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation"
100
+ )
101
+
61
102
  if lowcut and highcut:
62
103
  # Bandpass filter and get the analytic signal using the Hilbert transform
63
- filtered_signal1 = butter_bandpass_filter(data=signal1, lowcut=lowcut, highcut=highcut, fs=fs)
64
- filtered_signal2 = butter_bandpass_filter(data=signal2, lowcut=lowcut, highcut=highcut, fs=fs)
104
+ filtered_signal1 = butter_bandpass_filter(
105
+ data=signal1, lowcut=lowcut, highcut=highcut, fs=fs
106
+ )
107
+ filtered_signal2 = butter_bandpass_filter(
108
+ data=signal2, lowcut=lowcut, highcut=highcut, fs=fs
109
+ )
65
110
  # Get phase using the Hilbert transform
66
111
  theta1 = signal.hilbert(filtered_signal1)
67
112
  theta2 = signal.hilbert(filtered_signal2)
@@ -69,26 +114,34 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
69
114
  # Get phase using the Hilbert transform without filtering
70
115
  theta1 = signal.hilbert(signal1)
71
116
  theta2 = signal.hilbert(signal2)
72
-
117
+
73
118
  else:
74
119
  raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
75
-
120
+
76
121
  # Calculate phase difference
77
122
  phase_diff = np.angle(theta1) - np.angle(theta2)
78
-
123
+
79
124
  # Calculate PLV from standard equation from Measuring phase synchrony in brain signals(1999)
80
125
  plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))
81
-
126
+
82
127
  return plv
83
128
 
84
129
 
85
- def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
86
- lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
87
- lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
88
- filtered_lfp_phase: np.ndarray = None) -> float:
130
+ def calculate_spike_lfp_plv(
131
+ spike_times: np.ndarray = None,
132
+ lfp_data=None,
133
+ spike_fs: float = None,
134
+ lfp_fs: float = None,
135
+ filter_method: str = "butter",
136
+ freq_of_interest: float = None,
137
+ lowcut: float = None,
138
+ highcut: float = None,
139
+ bandwidth: float = 2.0,
140
+ filtered_lfp_phase: np.ndarray = None,
141
+ ) -> float:
89
142
  """
90
- Calculate spike-lfp unbiased phase locking value
91
-
143
+ Calculate spike-lfp unbiased phase locking value
144
+
92
145
  Parameters
93
146
  ----------
94
147
  spike_times : np.ndarray
@@ -111,13 +164,13 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
111
164
  Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
112
165
  filtered_lfp_phase : np.ndarray, optional
113
166
  Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
114
-
167
+
115
168
  Returns
116
169
  -------
117
170
  float
118
171
  Phase Locking Value (unbiased)
119
172
  """
120
-
173
+
121
174
  if spike_fs is None:
122
175
  spike_fs = lfp_fs
123
176
  # Convert spike times to sample indices
@@ -125,33 +178,50 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
125
178
 
126
179
  # Then convert from seconds to samples at the new sampling rate
127
180
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
128
-
181
+
129
182
  # Filter indices to ensure they're within bounds of the LFP signal
130
- if filtered_lfp_phase is not None:
131
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
132
- else:
133
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
134
-
183
+ if isinstance(lfp_data, xr.DataArray):
184
+ if filtered_lfp_phase is not None:
185
+ valid_indices = align_spike_times_with_lfp(
186
+ lfp=filtered_lfp_phase, timestamps=spike_indices
187
+ )
188
+ else:
189
+ valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
190
+ elif isinstance(lfp_data, np.ndarray):
191
+ if filtered_lfp_phase is not None:
192
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
193
+ else:
194
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
195
+
135
196
  if len(valid_indices) <= 1:
136
197
  return 0
137
-
198
+
138
199
  # Get instantaneous phase
139
200
  if filtered_lfp_phase is None:
140
- instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
141
- freq_of_interest=freq_of_interest, lowcut=lowcut,
142
- highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
201
+ instantaneous_phase = get_lfp_phase(
202
+ lfp_data=lfp_data,
203
+ filter_method=filter_method,
204
+ freq_of_interest=freq_of_interest,
205
+ lowcut=lowcut,
206
+ highcut=highcut,
207
+ bandwidth=bandwidth,
208
+ fs=lfp_fs,
209
+ )
143
210
  else:
144
211
  instantaneous_phase = filtered_lfp_phase
145
-
212
+
146
213
  # Get phases at spike times
147
- spike_phases = instantaneous_phase[valid_indices]
214
+ if isinstance(instantaneous_phase, xr.DataArray):
215
+ spike_phases = instantaneous_phase.sel(time=valid_indices).values
216
+ else:
217
+ spike_phases = instantaneous_phase[valid_indices]
148
218
 
149
219
  # Number of spikes
150
220
  N = len(spike_phases)
151
-
221
+
152
222
  # Convert phases to unit vectors in the complex plane
153
223
  unit_vectors = np.exp(1j * spike_phases)
154
-
224
+
155
225
  # Sum of all unit vectors (resultant vector)
156
226
  resultant_vector = np.sum(unit_vectors)
157
227
 
@@ -159,7 +229,7 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
159
229
  plv2n = (resultant_vector * resultant_vector.conjugate()).real / N # plv^2 * N
160
230
  plv = (plv2n / N) ** 0.5
161
231
  ppc = (plv2n - 1) / (N - 1) # ppc = (plv^2 * N - 1) / (N - 1)
162
- plv_unbiased = np.fmax(ppc, 0.) ** 0.5 # ensure non-negative
232
+ plv_unbiased = np.fmax(ppc, 0.0) ** 0.5 # ensure non-negative
163
233
 
164
234
  return plv_unbiased
165
235
 
@@ -181,7 +251,7 @@ def _ppc_cuda_kernel(spike_phases, out):
181
251
  i = cuda.grid(1)
182
252
  if i < len(spike_phases):
183
253
  local_sum = 0.0
184
- for j in range(i+1, len(spike_phases)):
254
+ for j in range(i + 1, len(spike_phases)):
185
255
  local_sum += np.cos(spike_phases[i] - spike_phases[j])
186
256
  out[i] = local_sum
187
257
 
@@ -190,23 +260,32 @@ def _ppc_gpu(spike_phases):
190
260
  """GPU-accelerated implementation"""
191
261
  d_phases = cuda.to_device(spike_phases)
192
262
  d_out = cuda.device_array(len(spike_phases), dtype=np.float64)
193
-
263
+
194
264
  threads = 256
195
265
  blocks = (len(spike_phases) + threads - 1) // threads
196
-
266
+
197
267
  _ppc_cuda_kernel[blocks, threads](d_phases, d_out)
198
268
  total = d_out.copy_to_host().sum()
199
- return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
200
-
201
-
202
- def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
203
- lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
204
- lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
205
- ppc_method: str = 'numpy', filtered_lfp_phase: np.ndarray = None) -> float:
269
+ return (2 / (len(spike_phases) * (len(spike_phases) - 1))) * total
270
+
271
+
272
+ def calculate_ppc(
273
+ spike_times: np.ndarray = None,
274
+ lfp_data=None,
275
+ spike_fs: float = None,
276
+ lfp_fs: float = None,
277
+ filter_method: str = "wavelet",
278
+ freq_of_interest: float = None,
279
+ lowcut: float = None,
280
+ highcut: float = None,
281
+ bandwidth: float = 2.0,
282
+ ppc_method: str = "numpy",
283
+ filtered_lfp_phase: np.ndarray = None,
284
+ ) -> float:
206
285
  """
207
286
  Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
208
287
  Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
209
-
288
+
210
289
  Parameters
211
290
  ----------
212
291
  spike_times : np.ndarray
@@ -231,7 +310,7 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
231
310
  Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
232
311
  filtered_lfp_phase : np.ndarray, optional
233
312
  Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
234
-
313
+
235
314
  Returns
236
315
  -------
237
316
  float
@@ -244,63 +323,88 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
244
323
 
245
324
  # Then convert from seconds to samples at the new sampling rate
246
325
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
247
-
326
+
248
327
  # Filter indices to ensure they're within bounds of the LFP signal
249
- if filtered_lfp_phase is not None:
250
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
251
- else:
252
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
253
-
328
+ if isinstance(lfp_data, xr.DataArray):
329
+ if filtered_lfp_phase is not None:
330
+ valid_indices = align_spike_times_with_lfp(
331
+ lfp=filtered_lfp_phase, timestamps=spike_indices
332
+ )
333
+ else:
334
+ valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
335
+ elif isinstance(lfp_data, np.ndarray):
336
+ if filtered_lfp_phase is not None:
337
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
338
+ else:
339
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
340
+
254
341
  if len(valid_indices) <= 1:
255
342
  return 0
256
-
343
+
257
344
  # Get instantaneous phase
258
345
  if filtered_lfp_phase is None:
259
- instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
260
- freq_of_interest=freq_of_interest, lowcut=lowcut,
261
- highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
346
+ instantaneous_phase = get_lfp_phase(
347
+ lfp_data=lfp_data,
348
+ filter_method=filter_method,
349
+ freq_of_interest=freq_of_interest,
350
+ lowcut=lowcut,
351
+ highcut=highcut,
352
+ bandwidth=bandwidth,
353
+ fs=lfp_fs,
354
+ )
262
355
  else:
263
356
  instantaneous_phase = filtered_lfp_phase
264
-
357
+
265
358
  # Get phases at spike times
266
- spike_phases = instantaneous_phase[valid_indices]
267
-
359
+ if isinstance(instantaneous_phase, xr.DataArray):
360
+ spike_phases = instantaneous_phase.sel(time=valid_indices).values
361
+ else:
362
+ spike_phases = instantaneous_phase[valid_indices]
363
+
268
364
  n_spikes = len(spike_phases)
269
365
 
270
366
  # Calculate PPC (Pairwise Phase Consistency)
271
367
  if n_spikes <= 1:
272
368
  return 0
273
-
369
+
274
370
  # Explicit calculation of pairwise phase consistency
275
371
  # Vectorized computation for efficiency
276
- if ppc_method == 'numpy':
372
+ if ppc_method == "numpy":
277
373
  i, j = np.triu_indices(n_spikes, k=1)
278
374
  phase_diff = spike_phases[i] - spike_phases[j]
279
375
  sum_cos_diff = np.sum(np.cos(phase_diff))
280
- ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
281
- elif ppc_method == 'numba':
376
+ ppc = (2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff
377
+ elif ppc_method == "numba":
282
378
  ppc = _ppc_parallel_numba(spike_phases)
283
- elif ppc_method == 'gpu':
379
+ elif ppc_method == "gpu":
284
380
  ppc = _ppc_gpu(spike_phases)
285
381
  else:
286
382
  raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")
287
383
  return ppc
288
384
 
289
-
290
- def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
291
- lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
292
- lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
293
- filtered_lfp_phase: np.ndarray = None) -> float:
385
+
386
+ def calculate_ppc2(
387
+ spike_times: np.ndarray = None,
388
+ lfp_data=None,
389
+ spike_fs: float = None,
390
+ lfp_fs: float = None,
391
+ filter_method: str = "wavelet",
392
+ freq_of_interest: float = None,
393
+ lowcut: float = None,
394
+ highcut: float = None,
395
+ bandwidth: float = 2.0,
396
+ filtered_lfp_phase: np.ndarray = None,
397
+ ) -> float:
294
398
  """
295
399
  # -----------------------------------------------------------------------------
296
- # PPC2 Calculation (Vinck et al., 2010)
400
+ # PPC2 Calculation (Vinck et al., 2010)
297
401
  # -----------------------------------------------------------------------------
298
402
  # Equation(Original):
299
403
  # PPC = (2 / (n * (n - 1))) * sum(cos(φ_i - φ_j) for all i < j)
300
404
  # Optimized Formula (Algebraically Equivalent):
301
405
  # PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
302
406
  # -----------------------------------------------------------------------------
303
-
407
+
304
408
  Parameters
305
409
  ----------
306
410
  spike_times : np.ndarray
@@ -323,13 +427,13 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
323
427
  Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
324
428
  filtered_lfp_phase : np.ndarray, optional
325
429
  Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
326
-
430
+
327
431
  Returns
328
432
  -------
329
433
  float
330
434
  Pairwise Phase Consistency 2 (PPC2) value
331
435
  """
332
-
436
+
333
437
  if spike_fs is None:
334
438
  spike_fs = lfp_fs
335
439
  # Convert spike times to sample indices
@@ -337,49 +441,75 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
337
441
 
338
442
  # Then convert from seconds to samples at the new sampling rate
339
443
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
340
-
444
+
341
445
  # Filter indices to ensure they're within bounds of the LFP signal
342
- if filtered_lfp_phase is not None:
343
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
344
- else:
345
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
346
-
446
+ if isinstance(lfp_data, xr.DataArray):
447
+ if filtered_lfp_phase is not None:
448
+ valid_indices = align_spike_times_with_lfp(
449
+ lfp=filtered_lfp_phase, timestamps=spike_indices
450
+ )
451
+ else:
452
+ valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
453
+ elif isinstance(lfp_data, np.ndarray):
454
+ if filtered_lfp_phase is not None:
455
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
456
+ else:
457
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
458
+
347
459
  if len(valid_indices) <= 1:
348
460
  return 0
349
-
461
+
350
462
  # Get instantaneous phase
351
463
  if filtered_lfp_phase is None:
352
- instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
353
- freq_of_interest=freq_of_interest, lowcut=lowcut,
354
- highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
464
+ instantaneous_phase = get_lfp_phase(
465
+ lfp_data=lfp_data,
466
+ filter_method=filter_method,
467
+ freq_of_interest=freq_of_interest,
468
+ lowcut=lowcut,
469
+ highcut=highcut,
470
+ bandwidth=bandwidth,
471
+ fs=lfp_fs,
472
+ )
355
473
  else:
356
474
  instantaneous_phase = filtered_lfp_phase
357
-
475
+
358
476
  # Get phases at spike times
359
- spike_phases = instantaneous_phase[valid_indices]
360
-
477
+ if isinstance(instantaneous_phase, xr.DataArray):
478
+ spike_phases = instantaneous_phase.sel(time=valid_indices).values
479
+ else:
480
+ spike_phases = instantaneous_phase[valid_indices]
361
481
  # Calculate PPC2 according to Vinck et al. (2010), Equation 6
362
482
  n = len(spike_phases)
363
-
483
+
364
484
  if n <= 1:
365
485
  return 0
366
-
486
+
367
487
  # Convert phases to unit vectors in the complex plane
368
488
  unit_vectors = np.exp(1j * spike_phases)
369
-
489
+
370
490
  # Calculate the resultant vector
371
491
  resultant_vector = np.sum(unit_vectors)
372
-
492
+
373
493
  # PPC2 = (|∑(e^(i*φ_j))|² - n) / (n * (n - 1))
374
- ppc2 = (np.abs(resultant_vector)**2 - n) / (n * (n - 1))
375
-
494
+ ppc2 = (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))
495
+
376
496
  return ppc2
377
497
 
378
498
 
379
- def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None, filter_method: str='wavelet', pop_names: List[str]=None,
380
- entrainment_method: str='plv', lowcut: float=None, highcut: float=None,
381
- spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
382
- freqs: List[float]=None, ppc_method: str='numpy',) -> Dict[str, Dict[int, Dict[float, float]]]:
499
+ def calculate_entrainment_per_cell(
500
+ spike_df: pd.DataFrame = None,
501
+ lfp_data: np.ndarray = None,
502
+ filter_method: str = "wavelet",
503
+ pop_names: List[str] = None,
504
+ entrainment_method: str = "plv",
505
+ lowcut: float = None,
506
+ highcut: float = None,
507
+ spike_fs: float = None,
508
+ lfp_fs: float = None,
509
+ bandwidth: float = 2,
510
+ freqs: List[float] = None,
511
+ ppc_method: str = "numpy",
512
+ ) -> Dict[str, Dict[int, Dict[float, float]]]:
383
513
  """
384
514
  Calculate neural entrainment (PPC, PLV) per neuron (cell) for specified frequencies across different populations.
385
515
 
@@ -431,26 +561,26 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
431
561
  filtered_lfp_phases = {}
432
562
  for freq in range(len(freqs)):
433
563
  phase = get_lfp_phase(
434
- lfp_data=lfp_data,
435
- freq_of_interest=freqs[freq],
436
- fs=lfp_fs,
564
+ lfp_data=lfp_data,
565
+ freq_of_interest=freqs[freq],
566
+ fs=lfp_fs,
437
567
  filter_method=filter_method,
438
- lowcut=lowcut,
439
- highcut=highcut,
440
- bandwidth=bandwidth
568
+ lowcut=lowcut,
569
+ highcut=highcut,
570
+ bandwidth=bandwidth,
441
571
  )
442
572
  filtered_lfp_phases[freqs[freq]] = phase
443
-
573
+
444
574
  entrainment_dict = {}
445
575
  for pop in pop_names:
446
576
  skip_count = 0
447
- pop_spikes = spike_df[spike_df['pop_name'] == pop]
448
- nodes = pop_spikes['node_ids'].unique()
577
+ pop_spikes = spike_df[spike_df["pop_name"] == pop]
578
+ nodes = pop_spikes["node_ids"].unique()
449
579
  entrainment_dict[pop] = {}
450
- print(f'Processing {pop} population')
580
+ print(f"Processing {pop} population")
451
581
  for node in tqdm(nodes):
452
- node_spikes = pop_spikes[pop_spikes['node_ids'] == node]
453
-
582
+ node_spikes = pop_spikes[pop_spikes["node_ids"] == node]
583
+
454
584
  # Skip nodes with less than or equal to 1 spike
455
585
  if len(node_spikes) <= 1:
456
586
  skip_count += 1
@@ -459,9 +589,9 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
459
589
  entrainment_dict[pop][node] = {}
460
590
  for freq in freqs:
461
591
  # Calculate entrainment based on the selected method using the pre-filtered phases
462
- if entrainment_method == 'plv':
592
+ if entrainment_method == "plv":
463
593
  entrainment_dict[pop][node][freq] = calculate_spike_lfp_plv(
464
- node_spikes['timestamps'].values,
594
+ node_spikes["timestamps"].values,
465
595
  lfp_data,
466
596
  spike_fs=spike_fs,
467
597
  lfp_fs=lfp_fs,
@@ -470,11 +600,11 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
470
600
  lowcut=lowcut,
471
601
  highcut=highcut,
472
602
  filter_method=filter_method,
473
- filtered_lfp_phase=filtered_lfp_phases[freq]
603
+ filtered_lfp_phase=filtered_lfp_phases[freq],
474
604
  )
475
- elif entrainment_method == 'ppc2':
605
+ elif entrainment_method == "ppc2":
476
606
  entrainment_dict[pop][node][freq] = calculate_ppc2(
477
- node_spikes['timestamps'].values,
607
+ node_spikes["timestamps"].values,
478
608
  lfp_data,
479
609
  spike_fs=spike_fs,
480
610
  lfp_fs=lfp_fs,
@@ -483,11 +613,11 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
483
613
  lowcut=lowcut,
484
614
  highcut=highcut,
485
615
  filter_method=filter_method,
486
- filtered_lfp_phase=filtered_lfp_phases[freq]
616
+ filtered_lfp_phase=filtered_lfp_phases[freq],
487
617
  )
488
- elif entrainment_method == 'ppc':
618
+ elif entrainment_method == "ppc":
489
619
  entrainment_dict[pop][node][freq] = calculate_ppc(
490
- node_spikes['timestamps'].values,
620
+ node_spikes["timestamps"].values,
491
621
  lfp_data,
492
622
  spike_fs=spike_fs,
493
623
  lfp_fs=lfp_fs,
@@ -497,21 +627,32 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
497
627
  highcut=highcut,
498
628
  filter_method=filter_method,
499
629
  ppc_method=ppc_method,
500
- filtered_lfp_phase=filtered_lfp_phases[freq]
630
+ filtered_lfp_phase=filtered_lfp_phases[freq],
501
631
  )
502
632
 
503
- print(f'Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')
633
+ print(
634
+ f"Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes"
635
+ )
504
636
 
505
637
  return entrainment_dict
506
638
 
507
639
 
508
- def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
509
- bandwidth=2.0, lowcut=None, highcut=None,
510
- freq_range=(10, 100), freq_step=5):
640
+ def calculate_spike_rate_power_correlation(
641
+ spike_rate,
642
+ lfp_data,
643
+ fs,
644
+ pop_names,
645
+ filter_method="wavelet",
646
+ bandwidth=2.0,
647
+ lowcut=None,
648
+ highcut=None,
649
+ freq_range=(10, 100),
650
+ freq_step=5,
651
+ ):
511
652
  """
512
653
  Calculate correlation between population spike rates and LFP power across frequencies
513
654
  using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
514
-
655
+
515
656
  Parameters:
516
657
  -----------
517
658
  spike_rate : DataFrame
@@ -534,7 +675,7 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names,
534
675
  Min and max frequency to analyze (default: (10, 100))
535
676
  freq_step : float, optional
536
677
  Step size for frequency analysis (default: 5)
537
-
678
+
538
679
  Returns:
539
680
  --------
540
681
  correlation_results : dict
@@ -542,32 +683,34 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names,
542
683
  frequencies : array
543
684
  Array of frequencies analyzed
544
685
  """
545
-
686
+
546
687
  # Define frequency bands to analyze
547
688
  frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
548
-
689
+
549
690
  # Dictionary to store results
550
691
  correlation_results = {pop: {} for pop in pop_names}
551
-
692
+
552
693
  # Calculate power at each frequency band using specified filter
553
694
  power_by_freq = {}
554
695
  for freq in frequencies:
555
- power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
556
- lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
557
-
696
+ power_by_freq[freq] = get_lfp_power(
697
+ lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
698
+ )
699
+
558
700
  # Calculate correlation for each population
559
701
  for pop in pop_names:
560
702
  # Extract spike rate for this population
561
703
  pop_rate = spike_rate[pop]
562
-
704
+
563
705
  # Calculate correlation with power at each frequency
564
706
  for freq in frequencies:
565
707
  # Make sure the lengths match
566
708
  if len(pop_rate) != len(power_by_freq[freq]):
567
- raise ValueError(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
709
+ raise ValueError(
710
+ f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}"
711
+ )
568
712
  # use spearman for non-parametric correlation
569
713
  corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
570
- correlation_results[pop][freq] = {'correlation': corr, 'p_value': p_val}
571
-
572
- return correlation_results, frequencies
714
+ correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
573
715
 
716
+ return correlation_results, frequencies