bmtool 0.7.0.6.2__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,23 +2,62 @@
2
2
  Module for entrainment analysis
3
3
  """
4
4
 
5
- import numpy as np
6
- from scipy import signal
5
+ from typing import Dict, List
6
+
7
7
  import numba
8
- from numba import cuda
8
+ import numpy as np
9
9
  import pandas as pd
10
- from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power, get_lfp_phase
11
- from typing import Dict, List, Optional
12
- from tqdm.notebook import tqdm
13
10
  import scipy.stats as stats
11
+ import xarray as xr
12
+ from numba import cuda
13
+ from scipy import signal
14
+ from tqdm.notebook import tqdm
15
+
16
+ from .lfp import butter_bandpass_filter, get_lfp_phase, get_lfp_power, wavelet_filter
17
+
14
18
 
19
+ def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.ndarray:
20
+ """the lfp xarray should have a time axis. use that to align the spike times since the lfp can start at a
21
+ non-zero time after sliced. Both need to be on same fs for this to be correct.
15
22
 
16
- def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None,
17
- filter_method: str = 'wavelet', lowcut: float = None, highcut: float = None,
18
- bandwidth: float = 2.0) -> np.ndarray:
23
+ Parameters
24
+ ----------
25
+ lfp : xarray.DataArray
26
+ LFP data with time coordinates
27
+ timestamps : np.ndarray
28
+ Array of spike timestamps
29
+
30
+ Returns
31
+ -------
32
+ np.ndarray
33
+ Copy of timestamps with adjusted timestamps to align with lfp.
34
+ """
35
+ # print("Pairing LFP and Spike Times")
36
+ # print(lfp.time.values)
37
+ # print(f"LFP starts at {lfp.time.values[0]}ms")
38
+ # need to make sure lfp and spikes have the same time axis
39
+ # align spikes with lfp
40
+ timestamps = timestamps[
41
+ (timestamps >= lfp.time.values[0]) & (timestamps <= lfp.time.values[-1])
42
+ ].copy()
43
+ # set the time axis of the spikes to match the lfp
44
+ timestamps = timestamps - lfp.time.values[0]
45
+ return timestamps
46
+
47
+
48
+ def calculate_signal_signal_plv(
49
+ signal1: np.ndarray,
50
+ signal2: np.ndarray,
51
+ fs: float,
52
+ freq_of_interest: float = None,
53
+ filter_method: str = "wavelet",
54
+ lowcut: float = None,
55
+ highcut: float = None,
56
+ bandwidth: float = 2.0,
57
+ ) -> np.ndarray:
19
58
  """
20
59
  Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
21
-
60
+
22
61
  Parameters
23
62
  ----------
24
63
  signal1 : np.ndarray
@@ -37,7 +76,7 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
37
76
  Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
38
77
  bandwidth : float, optional
39
78
  Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
40
-
79
+
41
80
  Returns
42
81
  -------
43
82
  np.ndarray
@@ -45,23 +84,29 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
45
84
  """
46
85
  if len(signal1) != len(signal2):
47
86
  raise ValueError("Input signals must have the same length.")
48
-
49
- if filter_method == 'wavelet':
87
+
88
+ if filter_method == "wavelet":
50
89
  if freq_of_interest is None:
51
90
  raise ValueError("freq_of_interest must be provided for the wavelet method.")
52
-
91
+
53
92
  # Apply CWT to both signals
54
93
  theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
55
94
  theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
56
-
57
- elif filter_method == 'butter':
95
+
96
+ elif filter_method == "butter":
58
97
  if lowcut is None or highcut is None:
59
- print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation")
60
-
98
+ print(
99
+ "Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation"
100
+ )
101
+
61
102
  if lowcut and highcut:
62
103
  # Bandpass filter and get the analytic signal using the Hilbert transform
63
- filtered_signal1 = butter_bandpass_filter(data=signal1, lowcut=lowcut, highcut=highcut, fs=fs)
64
- filtered_signal2 = butter_bandpass_filter(data=signal2, lowcut=lowcut, highcut=highcut, fs=fs)
104
+ filtered_signal1 = butter_bandpass_filter(
105
+ data=signal1, lowcut=lowcut, highcut=highcut, fs=fs
106
+ )
107
+ filtered_signal2 = butter_bandpass_filter(
108
+ data=signal2, lowcut=lowcut, highcut=highcut, fs=fs
109
+ )
65
110
  # Get phase using the Hilbert transform
66
111
  theta1 = signal.hilbert(filtered_signal1)
67
112
  theta2 = signal.hilbert(filtered_signal2)
@@ -69,26 +114,34 @@ def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: fl
69
114
  # Get phase using the Hilbert transform without filtering
70
115
  theta1 = signal.hilbert(signal1)
71
116
  theta2 = signal.hilbert(signal2)
72
-
117
+
73
118
  else:
74
119
  raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
75
-
120
+
76
121
  # Calculate phase difference
77
122
  phase_diff = np.angle(theta1) - np.angle(theta2)
78
-
123
+
79
124
  # Calculate PLV from standard equation from Measuring phase synchrony in brain signals(1999)
80
125
  plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))
81
-
126
+
82
127
  return plv
83
128
 
84
129
 
85
- def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
86
- lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
87
- lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
88
- filtered_lfp_phase: np.ndarray = None) -> float:
130
+ def calculate_spike_lfp_plv(
131
+ spike_times: np.ndarray = None,
132
+ lfp_data=None,
133
+ spike_fs: float = None,
134
+ lfp_fs: float = None,
135
+ filter_method: str = "butter",
136
+ freq_of_interest: float = None,
137
+ lowcut: float = None,
138
+ highcut: float = None,
139
+ bandwidth: float = 2.0,
140
+ filtered_lfp_phase: np.ndarray = None,
141
+ ) -> float:
89
142
  """
90
- Calculate spike-lfp unbiased phase locking value
91
-
143
+ Calculate spike-lfp unbiased phase locking value
144
+
92
145
  Parameters
93
146
  ----------
94
147
  spike_times : np.ndarray
@@ -111,13 +164,13 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
111
164
  Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
112
165
  filtered_lfp_phase : np.ndarray, optional
113
166
  Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
114
-
167
+
115
168
  Returns
116
169
  -------
117
170
  float
118
171
  Phase Locking Value (unbiased)
119
172
  """
120
-
173
+
121
174
  if spike_fs is None:
122
175
  spike_fs = lfp_fs
123
176
  # Convert spike times to sample indices
@@ -125,33 +178,36 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
125
178
 
126
179
  # Then convert from seconds to samples at the new sampling rate
127
180
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
128
-
181
+
129
182
  # Filter indices to ensure they're within bounds of the LFP signal
130
- if filtered_lfp_phase is not None:
131
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
132
- else:
133
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
134
-
183
+ valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
184
+
135
185
  if len(valid_indices) <= 1:
136
186
  return 0
137
-
187
+
138
188
  # Get instantaneous phase
139
189
  if filtered_lfp_phase is None:
140
- instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
141
- freq_of_interest=freq_of_interest, lowcut=lowcut,
142
- highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
190
+ instantaneous_phase = get_lfp_phase(
191
+ lfp_data=lfp_data,
192
+ filter_method=filter_method,
193
+ freq_of_interest=freq_of_interest,
194
+ lowcut=lowcut,
195
+ highcut=highcut,
196
+ bandwidth=bandwidth,
197
+ fs=lfp_fs,
198
+ )
143
199
  else:
144
200
  instantaneous_phase = filtered_lfp_phase
145
-
201
+
146
202
  # Get phases at spike times
147
- spike_phases = instantaneous_phase[valid_indices]
203
+ spike_phases = instantaneous_phase.sel(time=valid_indices).values
148
204
 
149
205
  # Number of spikes
150
206
  N = len(spike_phases)
151
-
207
+
152
208
  # Convert phases to unit vectors in the complex plane
153
209
  unit_vectors = np.exp(1j * spike_phases)
154
-
210
+
155
211
  # Sum of all unit vectors (resultant vector)
156
212
  resultant_vector = np.sum(unit_vectors)
157
213
 
@@ -159,7 +215,7 @@ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray
159
215
  plv2n = (resultant_vector * resultant_vector.conjugate()).real / N # plv^2 * N
160
216
  plv = (plv2n / N) ** 0.5
161
217
  ppc = (plv2n - 1) / (N - 1) # ppc = (plv^2 * N - 1) / (N - 1)
162
- plv_unbiased = np.fmax(ppc, 0.) ** 0.5 # ensure non-negative
218
+ plv_unbiased = np.fmax(ppc, 0.0) ** 0.5 # ensure non-negative
163
219
 
164
220
  return plv_unbiased
165
221
 
@@ -181,7 +237,7 @@ def _ppc_cuda_kernel(spike_phases, out):
181
237
  i = cuda.grid(1)
182
238
  if i < len(spike_phases):
183
239
  local_sum = 0.0
184
- for j in range(i+1, len(spike_phases)):
240
+ for j in range(i + 1, len(spike_phases)):
185
241
  local_sum += np.cos(spike_phases[i] - spike_phases[j])
186
242
  out[i] = local_sum
187
243
 
@@ -190,23 +246,32 @@ def _ppc_gpu(spike_phases):
190
246
  """GPU-accelerated implementation"""
191
247
  d_phases = cuda.to_device(spike_phases)
192
248
  d_out = cuda.device_array(len(spike_phases), dtype=np.float64)
193
-
249
+
194
250
  threads = 256
195
251
  blocks = (len(spike_phases) + threads - 1) // threads
196
-
252
+
197
253
  _ppc_cuda_kernel[blocks, threads](d_phases, d_out)
198
254
  total = d_out.copy_to_host().sum()
199
- return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
200
-
201
-
202
- def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
203
- lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
204
- lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
205
- ppc_method: str = 'numpy', filtered_lfp_phase: np.ndarray = None) -> float:
255
+ return (2 / (len(spike_phases) * (len(spike_phases) - 1))) * total
256
+
257
+
258
+ def calculate_ppc(
259
+ spike_times: np.ndarray = None,
260
+ lfp_data=None,
261
+ spike_fs: float = None,
262
+ lfp_fs: float = None,
263
+ filter_method: str = "wavelet",
264
+ freq_of_interest: float = None,
265
+ lowcut: float = None,
266
+ highcut: float = None,
267
+ bandwidth: float = 2.0,
268
+ ppc_method: str = "numpy",
269
+ filtered_lfp_phase: np.ndarray = None,
270
+ ) -> float:
206
271
  """
207
272
  Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
208
273
  Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
209
-
274
+
210
275
  Parameters
211
276
  ----------
212
277
  spike_times : np.ndarray
@@ -231,7 +296,7 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
231
296
  Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy')
232
297
  filtered_lfp_phase : np.ndarray, optional
233
298
  Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
234
-
299
+
235
300
  Returns
236
301
  -------
237
302
  float
@@ -244,63 +309,77 @@ def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, s
244
309
 
245
310
  # Then convert from seconds to samples at the new sampling rate
246
311
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
247
-
312
+
248
313
  # Filter indices to ensure they're within bounds of the LFP signal
249
314
  if filtered_lfp_phase is not None:
250
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
315
+ valid_indices = align_spike_times_with_lfp(lfp=filtered_lfp_phase, timestamps=spike_indices)
251
316
  else:
252
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
253
-
317
+ valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
318
+
254
319
  if len(valid_indices) <= 1:
255
320
  return 0
256
-
321
+
257
322
  # Get instantaneous phase
258
323
  if filtered_lfp_phase is None:
259
- instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
260
- freq_of_interest=freq_of_interest, lowcut=lowcut,
261
- highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
324
+ instantaneous_phase = get_lfp_phase(
325
+ lfp_data=lfp_data,
326
+ filter_method=filter_method,
327
+ freq_of_interest=freq_of_interest,
328
+ lowcut=lowcut,
329
+ highcut=highcut,
330
+ bandwidth=bandwidth,
331
+ fs=lfp_fs,
332
+ )
262
333
  else:
263
334
  instantaneous_phase = filtered_lfp_phase
264
-
335
+
265
336
  # Get phases at spike times
266
- spike_phases = instantaneous_phase[valid_indices]
267
-
337
+ spike_phases = instantaneous_phase.sel(time=valid_indices).values
338
+
268
339
  n_spikes = len(spike_phases)
269
340
 
270
341
  # Calculate PPC (Pairwise Phase Consistency)
271
342
  if n_spikes <= 1:
272
343
  return 0
273
-
344
+
274
345
  # Explicit calculation of pairwise phase consistency
275
346
  # Vectorized computation for efficiency
276
- if ppc_method == 'numpy':
347
+ if ppc_method == "numpy":
277
348
  i, j = np.triu_indices(n_spikes, k=1)
278
349
  phase_diff = spike_phases[i] - spike_phases[j]
279
350
  sum_cos_diff = np.sum(np.cos(phase_diff))
280
- ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
281
- elif ppc_method == 'numba':
351
+ ppc = (2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff
352
+ elif ppc_method == "numba":
282
353
  ppc = _ppc_parallel_numba(spike_phases)
283
- elif ppc_method == 'gpu':
354
+ elif ppc_method == "gpu":
284
355
  ppc = _ppc_gpu(spike_phases)
285
356
  else:
286
357
  raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")
287
358
  return ppc
288
359
 
289
-
290
- def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
291
- lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
292
- lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
293
- filtered_lfp_phase: np.ndarray = None) -> float:
360
+
361
+ def calculate_ppc2(
362
+ spike_times: np.ndarray = None,
363
+ lfp_data=None,
364
+ spike_fs: float = None,
365
+ lfp_fs: float = None,
366
+ filter_method: str = "wavelet",
367
+ freq_of_interest: float = None,
368
+ lowcut: float = None,
369
+ highcut: float = None,
370
+ bandwidth: float = 2.0,
371
+ filtered_lfp_phase: np.ndarray = None,
372
+ ) -> float:
294
373
  """
295
374
  # -----------------------------------------------------------------------------
296
- # PPC2 Calculation (Vinck et al., 2010)
375
+ # PPC2 Calculation (Vinck et al., 2010)
297
376
  # -----------------------------------------------------------------------------
298
377
  # Equation(Original):
299
378
  # PPC = (2 / (n * (n - 1))) * sum(cos(φ_i - φ_j) for all i < j)
300
379
  # Optimized Formula (Algebraically Equivalent):
301
380
  # PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
302
381
  # -----------------------------------------------------------------------------
303
-
382
+
304
383
  Parameters
305
384
  ----------
306
385
  spike_times : np.ndarray
@@ -323,13 +402,13 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
323
402
  Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
324
403
  filtered_lfp_phase : np.ndarray, optional
325
404
  Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
326
-
405
+
327
406
  Returns
328
407
  -------
329
408
  float
330
409
  Pairwise Phase Consistency 2 (PPC2) value
331
410
  """
332
-
411
+
333
412
  if spike_fs is None:
334
413
  spike_fs = lfp_fs
335
414
  # Convert spike times to sample indices
@@ -337,49 +416,64 @@ def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None,
337
416
 
338
417
  # Then convert from seconds to samples at the new sampling rate
339
418
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
340
-
419
+
341
420
  # Filter indices to ensure they're within bounds of the LFP signal
342
421
  if filtered_lfp_phase is not None:
343
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
422
+ valid_indices = align_spike_times_with_lfp(lfp=filtered_lfp_phase, timestamps=spike_indices)
344
423
  else:
345
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
346
-
424
+ valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
425
+
347
426
  if len(valid_indices) <= 1:
348
427
  return 0
349
-
428
+
350
429
  # Get instantaneous phase
351
430
  if filtered_lfp_phase is None:
352
- instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
353
- freq_of_interest=freq_of_interest, lowcut=lowcut,
354
- highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
431
+ instantaneous_phase = get_lfp_phase(
432
+ lfp_data=lfp_data,
433
+ filter_method=filter_method,
434
+ freq_of_interest=freq_of_interest,
435
+ lowcut=lowcut,
436
+ highcut=highcut,
437
+ bandwidth=bandwidth,
438
+ fs=lfp_fs,
439
+ )
355
440
  else:
356
441
  instantaneous_phase = filtered_lfp_phase
357
-
442
+
358
443
  # Get phases at spike times
359
- spike_phases = instantaneous_phase[valid_indices]
360
-
444
+ spike_phases = instantaneous_phase.sel(time=valid_indices).values
361
445
  # Calculate PPC2 according to Vinck et al. (2010), Equation 6
362
446
  n = len(spike_phases)
363
-
447
+
364
448
  if n <= 1:
365
449
  return 0
366
-
450
+
367
451
  # Convert phases to unit vectors in the complex plane
368
452
  unit_vectors = np.exp(1j * spike_phases)
369
-
453
+
370
454
  # Calculate the resultant vector
371
455
  resultant_vector = np.sum(unit_vectors)
372
-
456
+
373
457
  # PPC2 = (|∑(e^(i*φ_j))|² - n) / (n * (n - 1))
374
- ppc2 = (np.abs(resultant_vector)**2 - n) / (n * (n - 1))
375
-
458
+ ppc2 = (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))
459
+
376
460
  return ppc2
377
461
 
378
462
 
379
- def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None, filter_method: str='wavelet', pop_names: List[str]=None,
380
- entrainment_method: str='plv', lowcut: float=None, highcut: float=None,
381
- spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
382
- freqs: List[float]=None, ppc_method: str='numpy',) -> Dict[str, Dict[int, Dict[float, float]]]:
463
+ def calculate_entrainment_per_cell(
464
+ spike_df: pd.DataFrame = None,
465
+ lfp_data: np.ndarray = None,
466
+ filter_method: str = "wavelet",
467
+ pop_names: List[str] = None,
468
+ entrainment_method: str = "plv",
469
+ lowcut: float = None,
470
+ highcut: float = None,
471
+ spike_fs: float = None,
472
+ lfp_fs: float = None,
473
+ bandwidth: float = 2,
474
+ freqs: List[float] = None,
475
+ ppc_method: str = "numpy",
476
+ ) -> Dict[str, Dict[int, Dict[float, float]]]:
383
477
  """
384
478
  Calculate neural entrainment (PPC, PLV) per neuron (cell) for specified frequencies across different populations.
385
479
 
@@ -431,26 +525,26 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
431
525
  filtered_lfp_phases = {}
432
526
  for freq in range(len(freqs)):
433
527
  phase = get_lfp_phase(
434
- lfp_data=lfp_data,
435
- freq_of_interest=freqs[freq],
436
- fs=lfp_fs,
528
+ lfp_data=lfp_data,
529
+ freq_of_interest=freqs[freq],
530
+ fs=lfp_fs,
437
531
  filter_method=filter_method,
438
- lowcut=lowcut,
439
- highcut=highcut,
440
- bandwidth=bandwidth
532
+ lowcut=lowcut,
533
+ highcut=highcut,
534
+ bandwidth=bandwidth,
441
535
  )
442
536
  filtered_lfp_phases[freqs[freq]] = phase
443
-
537
+
444
538
  entrainment_dict = {}
445
539
  for pop in pop_names:
446
540
  skip_count = 0
447
- pop_spikes = spike_df[spike_df['pop_name'] == pop]
448
- nodes = pop_spikes['node_ids'].unique()
541
+ pop_spikes = spike_df[spike_df["pop_name"] == pop]
542
+ nodes = pop_spikes["node_ids"].unique()
449
543
  entrainment_dict[pop] = {}
450
- print(f'Processing {pop} population')
544
+ print(f"Processing {pop} population")
451
545
  for node in tqdm(nodes):
452
- node_spikes = pop_spikes[pop_spikes['node_ids'] == node]
453
-
546
+ node_spikes = pop_spikes[pop_spikes["node_ids"] == node]
547
+
454
548
  # Skip nodes with less than or equal to 1 spike
455
549
  if len(node_spikes) <= 1:
456
550
  skip_count += 1
@@ -459,9 +553,9 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
459
553
  entrainment_dict[pop][node] = {}
460
554
  for freq in freqs:
461
555
  # Calculate entrainment based on the selected method using the pre-filtered phases
462
- if entrainment_method == 'plv':
556
+ if entrainment_method == "plv":
463
557
  entrainment_dict[pop][node][freq] = calculate_spike_lfp_plv(
464
- node_spikes['timestamps'].values,
558
+ node_spikes["timestamps"].values,
465
559
  lfp_data,
466
560
  spike_fs=spike_fs,
467
561
  lfp_fs=lfp_fs,
@@ -470,11 +564,11 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
470
564
  lowcut=lowcut,
471
565
  highcut=highcut,
472
566
  filter_method=filter_method,
473
- filtered_lfp_phase=filtered_lfp_phases[freq]
567
+ filtered_lfp_phase=filtered_lfp_phases[freq],
474
568
  )
475
- elif entrainment_method == 'ppc2':
569
+ elif entrainment_method == "ppc2":
476
570
  entrainment_dict[pop][node][freq] = calculate_ppc2(
477
- node_spikes['timestamps'].values,
571
+ node_spikes["timestamps"].values,
478
572
  lfp_data,
479
573
  spike_fs=spike_fs,
480
574
  lfp_fs=lfp_fs,
@@ -483,11 +577,11 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
483
577
  lowcut=lowcut,
484
578
  highcut=highcut,
485
579
  filter_method=filter_method,
486
- filtered_lfp_phase=filtered_lfp_phases[freq]
580
+ filtered_lfp_phase=filtered_lfp_phases[freq],
487
581
  )
488
- elif entrainment_method == 'ppc':
582
+ elif entrainment_method == "ppc":
489
583
  entrainment_dict[pop][node][freq] = calculate_ppc(
490
- node_spikes['timestamps'].values,
584
+ node_spikes["timestamps"].values,
491
585
  lfp_data,
492
586
  spike_fs=spike_fs,
493
587
  lfp_fs=lfp_fs,
@@ -497,21 +591,32 @@ def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.nda
497
591
  highcut=highcut,
498
592
  filter_method=filter_method,
499
593
  ppc_method=ppc_method,
500
- filtered_lfp_phase=filtered_lfp_phases[freq]
594
+ filtered_lfp_phase=filtered_lfp_phases[freq],
501
595
  )
502
596
 
503
- print(f'Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')
597
+ print(
598
+ f"Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes"
599
+ )
504
600
 
505
601
  return entrainment_dict
506
602
 
507
603
 
508
- def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
509
- bandwidth=2.0, lowcut=None, highcut=None,
510
- freq_range=(10, 100), freq_step=5):
604
+ def calculate_spike_rate_power_correlation(
605
+ spike_rate,
606
+ lfp_data,
607
+ fs,
608
+ pop_names,
609
+ filter_method="wavelet",
610
+ bandwidth=2.0,
611
+ lowcut=None,
612
+ highcut=None,
613
+ freq_range=(10, 100),
614
+ freq_step=5,
615
+ ):
511
616
  """
512
617
  Calculate correlation between population spike rates and LFP power across frequencies
513
618
  using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
514
-
619
+
515
620
  Parameters:
516
621
  -----------
517
622
  spike_rate : DataFrame
@@ -534,7 +639,7 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names,
534
639
  Min and max frequency to analyze (default: (10, 100))
535
640
  freq_step : float, optional
536
641
  Step size for frequency analysis (default: 5)
537
-
642
+
538
643
  Returns:
539
644
  --------
540
645
  correlation_results : dict
@@ -542,32 +647,34 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names,
542
647
  frequencies : array
543
648
  Array of frequencies analyzed
544
649
  """
545
-
650
+
546
651
  # Define frequency bands to analyze
547
652
  frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
548
-
653
+
549
654
  # Dictionary to store results
550
655
  correlation_results = {pop: {} for pop in pop_names}
551
-
656
+
552
657
  # Calculate power at each frequency band using specified filter
553
658
  power_by_freq = {}
554
659
  for freq in frequencies:
555
- power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
556
- lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
557
-
660
+ power_by_freq[freq] = get_lfp_power(
661
+ lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
662
+ )
663
+
558
664
  # Calculate correlation for each population
559
665
  for pop in pop_names:
560
666
  # Extract spike rate for this population
561
667
  pop_rate = spike_rate[pop]
562
-
668
+
563
669
  # Calculate correlation with power at each frequency
564
670
  for freq in frequencies:
565
671
  # Make sure the lengths match
566
672
  if len(pop_rate) != len(power_by_freq[freq]):
567
- raise ValueError(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
673
+ raise ValueError(
674
+ f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}"
675
+ )
568
676
  # use spearman for non-parametric correlation
569
677
  corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
570
- correlation_results[pop][freq] = {'correlation': corr, 'p_value': p_val}
571
-
572
- return correlation_results, frequencies
678
+ correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
573
679
 
680
+ return correlation_results, frequencies