bmtool 0.7.0.3__tar.gz → 0.7.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/PKG-INFO +1 -1
  2. bmtool-0.7.0.5/bmtool/analysis/entrainment.py +573 -0
  3. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/analysis/lfp.py +176 -1
  4. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/analysis/spikes.py +115 -63
  5. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/PKG-INFO +1 -1
  6. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/setup.py +1 -1
  7. bmtool-0.7.0.3/bmtool/analysis/entrainment.py +0 -490
  8. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/LICENSE +0 -0
  9. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/README.md +0 -0
  10. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/SLURM.py +0 -0
  11. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/__init__.py +0 -0
  12. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/__main__.py +0 -0
  13. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/analysis/__init__.py +0 -0
  14. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/analysis/netcon_reports.py +0 -0
  15. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/__init__.py +0 -0
  16. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/connections.py +0 -0
  17. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/entrainment.py +0 -0
  18. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/lfp.py +0 -0
  19. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/netcon_reports.py +0 -0
  20. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/bmplot/spikes.py +0 -0
  21. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/connectors.py +0 -0
  22. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/debug/__init__.py +0 -0
  23. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/debug/commands.py +0 -0
  24. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/debug/debug.py +0 -0
  25. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/graphs.py +0 -0
  26. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/manage.py +0 -0
  27. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/plot_commands.py +0 -0
  28. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/singlecell.py +0 -0
  29. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/synapses.py +0 -0
  30. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/util/__init__.py +0 -0
  31. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/util/commands.py +0 -0
  32. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/util/neuron/__init__.py +0 -0
  33. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/util/neuron/celltuner.py +0 -0
  34. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool/util/util.py +0 -0
  35. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/SOURCES.txt +0 -0
  36. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/dependency_links.txt +0 -0
  37. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/entry_points.txt +0 -0
  38. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/requires.txt +0 -0
  39. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/bmtool.egg-info/top_level.txt +0 -0
  40. {bmtool-0.7.0.3 → bmtool-0.7.0.5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.7.0.3
3
+ Version: 0.7.0.5
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -0,0 +1,573 @@
1
+ """
2
+ Module for entrainment analysis
3
+ """
4
+
5
+ import numpy as np
6
+ from scipy import signal
7
+ import numba
8
+ from numba import cuda
9
+ import pandas as pd
10
+ from .lfp import wavelet_filter,butter_bandpass_filter,get_lfp_power, get_lfp_phase
11
+ from typing import Dict, List, Optional
12
+ from tqdm.notebook import tqdm
13
+ import scipy.stats as stats
14
+
15
+
16
def calculate_signal_signal_plv(signal1: np.ndarray, signal2: np.ndarray, fs: float, freq_of_interest: float = None,
                                filter_method: str = 'wavelet', lowcut: float = None, highcut: float = None,
                                bandwidth: float = 2.0) -> np.ndarray:
    """
    Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.

    PLV = |mean(exp(1j * (phase1 - phase2)))| over the last axis (Lachaux et al., 1999).

    Parameters
    ----------
    signal1 : np.ndarray
        First input signal (1D array)
    signal2 : np.ndarray
        Second input signal (1D array, same length as signal1)
    fs : float
        Sampling frequency in Hz
    freq_of_interest : float, optional
        Desired frequency for wavelet PLV calculation, required if filter_method='wavelet'
    filter_method : str, optional
        Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
    lowcut : float, optional
        Lower frequency bound (Hz) for butterworth bandpass filter. If lowcut or highcut is
        None, the signals are not filtered and the Hilbert transform is applied directly.
    highcut : float, optional
        Upper frequency bound (Hz) for butterworth bandpass filter (see lowcut)
    bandwidth : float, optional
        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)

    Returns
    -------
    np.ndarray
        Phase Locking Value in [0, 1]. For 1-D signals this is a scalar (numpy float);
        the mean is taken over the last axis.

    Raises
    ------
    ValueError
        If the signals differ in length, freq_of_interest is missing for the wavelet
        method, or filter_method is not 'wavelet' or 'butter'.
    """
    if len(signal1) != len(signal2):
        raise ValueError("Input signals must have the same length.")

    if filter_method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")
        # Complex wavelet coefficients; their angle is the instantaneous phase.
        theta1 = wavelet_filter(x=signal1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
        theta2 = wavelet_filter(x=signal2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)

    elif filter_method == 'butter':
        # Test for None explicitly (not truthiness) so a cutoff of 0 is not silently
        # treated as "missing" without any notice.
        band_defined = lowcut is not None and highcut is not None
        if band_defined:
            signal1 = butter_bandpass_filter(data=signal1, lowcut=lowcut, highcut=highcut, fs=fs)
            signal2 = butter_bandpass_filter(data=signal2, lowcut=lowcut, highcut=highcut, fs=fs)
        else:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PLV calculation")
        # Analytic signals; their angle is the instantaneous phase.
        theta1 = signal.hilbert(signal1)
        theta2 = signal.hilbert(signal2)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")

    phase_diff = np.angle(theta1) - np.angle(theta2)

    # Standard PLV from "Measuring phase synchrony in brain signals" (1999); wrapping of
    # phase_diff is irrelevant because it only enters through exp(1j * ...).
    plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))

    return plv
83
+
84
+
85
def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
                            lfp_fs: float = None, filter_method: str = 'butter', freq_of_interest: float = None,
                            lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
                            filtered_lfp_phase: np.ndarray = None) -> float:
    """
    Calculate spike-lfp unbiased phase locking value.

    The unbiased PLV is sqrt(max(PPC, 0)) where PPC = (|R|^2 - N) / (N * (N - 1)),
    R is the sum of exp(1j * phase) over the N spikes that fall inside the LFP record.

    Parameters
    ----------
    spike_times : np.ndarray
        Array (or sequence) of spike times
    lfp_data : np.ndarray
        Local field potential time series data. Not required if filtered_lfp_phase is provided.
    spike_fs : float, optional
        Sampling frequency in Hz of the spike times, only needed if spike times and LFP have
        different sampling rates (defaults to lfp_fs)
    lfp_fs : float
        Sampling frequency in Hz of the LFP data
    filter_method : str, optional
        Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
    freq_of_interest : float, optional
        Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
    lowcut : float, optional
        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    highcut : float, optional
        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    bandwidth : float, optional
        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
    filtered_lfp_phase : np.ndarray, optional
        Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.

    Returns
    -------
    float
        Phase Locking Value (unbiased); 0 when fewer than two spikes fall inside the LFP.

    Raises
    ------
    ValueError
        If spike_times or lfp_fs is missing, or neither lfp_data nor filtered_lfp_phase is given.
    """
    if spike_times is None:
        raise ValueError("spike_times must be provided.")
    if lfp_fs is None:
        raise ValueError("lfp_fs must be provided.")
    if lfp_data is None and filtered_lfp_phase is None:
        raise ValueError("Either lfp_data or filtered_lfp_phase must be provided.")
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> sample indices at the LFP sampling rate.
    spike_times_seconds = np.asarray(spike_times, dtype=float) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes whose sample index lies inside the LFP/phase signal.
    n_samples = len(lfp_data) if filtered_lfp_phase is None else len(filtered_lfp_phase)
    valid_indices = spike_indices[(spike_indices >= 0) & (spike_indices < n_samples)]

    if len(valid_indices) <= 1:
        return 0

    if filtered_lfp_phase is None:
        instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
                                            freq_of_interest=freq_of_interest, lowcut=lowcut,
                                            highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
    else:
        instantaneous_phase = filtered_lfp_phase

    spike_phases = np.asarray(instantaneous_phase)[valid_indices]
    n_spikes = len(spike_phases)

    # Resultant vector of the spike phases as unit vectors in the complex plane.
    resultant_vector = np.sum(np.exp(1j * spike_phases))

    # |R|^2 / N equals N * PLV^2, since PLV = |R| / N.
    plv2n = (resultant_vector * resultant_vector.conjugate()).real / n_spikes
    ppc = (plv2n - 1) / (n_spikes - 1)  # ppc = (plv^2 * N - 1) / (N - 1)
    plv_unbiased = np.fmax(ppc, 0.) ** 0.5  # clamp: PPC can be negative for anti-clustered phases

    return plv_unbiased
165
+
166
+
167
@numba.njit(parallel=True, fastmath=True)
def _ppc_parallel_numba(spike_phases):
    """
    Numba-optimized parallel PPC calculation.

    Computes (2 / (n * (n - 1))) * sum(cos(phi_i - phi_j) for i < j). Each parallel
    iteration accumulates its own row into a local and contributes to the shared sum
    exactly once, so the reduction is a single `+=` per prange iteration.

    Returns 0.0 when fewer than two phases are given.
    """
    n = len(spike_phases)
    if n < 2:
        return 0.0
    sum_cos = 0.0
    for i in numba.prange(n):
        phase_i = spike_phases[i]
        row_sum = 0.0
        for j in range(i + 1, n):
            row_sum += np.cos(phase_i - spike_phases[j])
        sum_cos += row_sum
    return (2 / (n * (n - 1))) * sum_cos
177
+
178
+
179
@cuda.jit(fastmath=True)
def _ppc_cuda_kernel(spike_phases, out):
    """
    CUDA kernel computing one row of the pairwise-cosine upper triangle per thread.

    Thread ``idx`` stores in ``out[idx]`` the sum of cos(phase[idx] - phase[k]) over all
    k > idx. Summing ``out`` on the host gives the total over all pairs i < j.
    """
    idx = cuda.grid(1)
    n_phases = len(spike_phases)
    if idx >= n_phases:
        # Surplus threads of the last block have no spike to handle.
        return
    anchor = spike_phases[idx]
    row_total = 0.0
    for k in range(idx + 1, n_phases):
        row_total += np.cos(anchor - spike_phases[k])
    out[idx] = row_total
187
+
188
+
189
def _ppc_gpu(spike_phases):
    """
    GPU-accelerated PPC implementation.

    Launches `_ppc_cuda_kernel` (one thread per spike) and reduces the per-row sums on
    the host: PPC = (2 / (n * (n - 1))) * sum(cos(phi_i - phi_j) for i < j).

    Returns 0.0 when fewer than two phases are given (avoids division by zero).
    """
    # Contiguous float64 copy so device precision matches the float64 output buffer.
    phases = np.ascontiguousarray(spike_phases, dtype=np.float64)
    n = len(phases)
    if n < 2:
        return 0.0

    d_phases = cuda.to_device(phases)
    d_out = cuda.device_array(n, dtype=np.float64)

    threads_per_block = 256
    blocks = (n + threads_per_block - 1) // threads_per_block  # ceil(n / threads_per_block)

    _ppc_cuda_kernel[blocks, threads_per_block](d_phases, d_out)
    total = d_out.copy_to_host().sum()
    return (2 / (n * (n - 1))) * total
200
+
201
+
202
def calculate_ppc(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
                  lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
                  lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
                  ppc_method: str = 'numpy', filtered_lfp_phase: np.ndarray = None) -> float:
    """
    Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
    Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

    PPC = (2 / (n * (n - 1))) * sum(cos(phi_i - phi_j) for all i < j), where phi_k is the
    LFP phase at the k-th spike that falls inside the LFP record.

    Parameters
    ----------
    spike_times : np.ndarray
        Array (or sequence) of spike times
    lfp_data : np.ndarray
        Local field potential time series data. Not required if filtered_lfp_phase is provided.
    spike_fs : float, optional
        Sampling frequency in Hz of the spike times, only needed if spike times and LFP have
        different sampling rates (defaults to lfp_fs)
    lfp_fs : float
        Sampling frequency in Hz of the LFP data
    filter_method : str, optional
        Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
    freq_of_interest : float, optional
        Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
    lowcut : float, optional
        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    highcut : float, optional
        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    bandwidth : float, optional
        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
    ppc_method : str, optional
        Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy').
        The 'numpy' method materialises all n * (n - 1) / 2 spike pairs, so its memory use
        grows quadratically with the number of spikes.
    filtered_lfp_phase : np.ndarray, optional
        Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.

    Returns
    -------
    float
        Pairwise Phase Consistency value; 0 when fewer than two spikes fall inside the LFP.

    Raises
    ------
    ValueError
        If ppc_method is unsupported, spike_times or lfp_fs is missing, or neither
        lfp_data nor filtered_lfp_phase is provided.
    """
    # Validate up front so a bad method is reported before any (expensive) filtering.
    if ppc_method not in ('numpy', 'numba', 'gpu'):
        raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")
    if spike_times is None:
        raise ValueError("spike_times must be provided.")
    if lfp_fs is None:
        raise ValueError("lfp_fs must be provided.")
    if lfp_data is None and filtered_lfp_phase is None:
        raise ValueError("Either lfp_data or filtered_lfp_phase must be provided.")
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> sample indices at the LFP sampling rate.
    spike_times_seconds = np.asarray(spike_times, dtype=float) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes whose sample index lies inside the LFP/phase signal.
    n_samples = len(lfp_data) if filtered_lfp_phase is None else len(filtered_lfp_phase)
    valid_indices = spike_indices[(spike_indices >= 0) & (spike_indices < n_samples)]

    n_spikes = len(valid_indices)
    if n_spikes <= 1:
        return 0

    if filtered_lfp_phase is None:
        instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
                                            freq_of_interest=freq_of_interest, lowcut=lowcut,
                                            highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
    else:
        instantaneous_phase = filtered_lfp_phase

    spike_phases = np.asarray(instantaneous_phase)[valid_indices]

    if ppc_method == 'numpy':
        # Explicit pairwise computation over the upper triangle (i < j).
        i, j = np.triu_indices(n_spikes, k=1)
        sum_cos_diff = np.sum(np.cos(spike_phases[i] - spike_phases[j]))
        ppc = (2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff
    elif ppc_method == 'numba':
        ppc = _ppc_parallel_numba(spike_phases)
    else:  # 'gpu' (validated above)
        ppc = _ppc_gpu(spike_phases)
    return ppc
288
+
289
+
290
def calculate_ppc2(spike_times: np.ndarray = None, lfp_data: np.ndarray = None, spike_fs: float = None,
                   lfp_fs: float = None, filter_method: str = 'wavelet', freq_of_interest: float = None,
                   lowcut: float = None, highcut: float = None, bandwidth: float = 2.0,
                   filtered_lfp_phase: np.ndarray = None) -> float:
    """
    Calculate pairwise phase consistency with the O(n) closed-form formula.

    Original pairwise definition (Vinck et al., 2010):
        PPC = (2 / (n * (n - 1))) * sum(cos(phi_i - phi_j) for all i < j)
    Algebraically equivalent form used here:
        PPC = (|sum(exp(1j * phi_j))|^2 - n) / (n * (n - 1))

    This therefore returns the same quantity as ``calculate_ppc`` (up to floating-point
    rounding) without materialising all spike pairs.
    NOTE(review): the trial-based "PPC2" estimator of Vinck et al. (2012) is a different
    quantity; confirm which is intended by the name.

    Parameters
    ----------
    spike_times : np.ndarray
        Array (or sequence) of spike times
    lfp_data : np.ndarray
        Local field potential time series data. Not required if filtered_lfp_phase is provided.
    spike_fs : float, optional
        Sampling frequency in Hz of the spike times, only needed if spike times and LFP have
        different sampling rates (defaults to lfp_fs)
    lfp_fs : float
        Sampling frequency in Hz of the LFP data
    filter_method : str, optional
        Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
    freq_of_interest : float, optional
        Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
    lowcut : float, optional
        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    highcut : float, optional
        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    bandwidth : float, optional
        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
    filtered_lfp_phase : np.ndarray, optional
        Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.

    Returns
    -------
    float
        Pairwise Phase Consistency 2 (PPC2) value; 0 when fewer than two spikes fall inside the LFP.

    Raises
    ------
    ValueError
        If spike_times or lfp_fs is missing, or neither lfp_data nor filtered_lfp_phase is given.
    """
    if spike_times is None:
        raise ValueError("spike_times must be provided.")
    if lfp_fs is None:
        raise ValueError("lfp_fs must be provided.")
    if lfp_data is None and filtered_lfp_phase is None:
        raise ValueError("Either lfp_data or filtered_lfp_phase must be provided.")
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> sample indices at the LFP sampling rate.
    spike_times_seconds = np.asarray(spike_times, dtype=float) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes whose sample index lies inside the LFP/phase signal.
    n_samples = len(lfp_data) if filtered_lfp_phase is None else len(filtered_lfp_phase)
    valid_indices = spike_indices[(spike_indices >= 0) & (spike_indices < n_samples)]

    n = len(valid_indices)
    if n <= 1:
        return 0

    if filtered_lfp_phase is None:
        instantaneous_phase = get_lfp_phase(lfp_data=lfp_data, filter_method=filter_method,
                                            freq_of_interest=freq_of_interest, lowcut=lowcut,
                                            highcut=highcut, bandwidth=bandwidth, fs=lfp_fs)
    else:
        instantaneous_phase = filtered_lfp_phase

    spike_phases = np.asarray(instantaneous_phase)[valid_indices]

    # Resultant vector of the spike phases as unit vectors in the complex plane.
    resultant_vector = np.sum(np.exp(1j * spike_phases))

    # PPC = (|R|^2 - n) / (n * (n - 1))
    ppc2 = (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))

    return ppc2
377
+
378
+
379
def calculate_entrainment_per_cell(spike_df: pd.DataFrame=None, lfp_data: np.ndarray=None, filter_method: str='wavelet', pop_names: List[str]=None,
                                   entrainment_method: str='plv', lowcut: float=None, highcut: float=None,
                                   spike_fs: float=None, lfp_fs: float=None, bandwidth: float=2,
                                   freqs: List[float]=None, ppc_method: str='numpy',) -> Dict[str, Dict[int, Dict[float, float]]]:
    """
    Calculate neural entrainment (PPC, PLV) per neuron (cell) for specified frequencies across different populations.

    This function computes the entrainment metrics for each neuron within the specified populations based on their spike times
    and the provided LFP signal. The LFP phase is computed once per frequency and reused for every cell.
    Cells with one spike or fewer are skipped (they get no entry).

    Parameters
    ----------
    spike_df : pd.DataFrame
        DataFrame containing spike data with columns 'pop_name', 'node_ids', and 'timestamps'
    lfp_data : np.ndarray
        Local field potential (LFP) time series data
    filter_method : str, optional
        Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
    pop_names : List[str]
        List of population names to analyze
    entrainment_method : str, optional
        Method to use for entrainment calculation, either 'plv', 'ppc', or 'ppc2' (default: 'plv')
    lowcut : float, optional
        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    highcut : float, optional
        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    spike_fs : float
        Sampling frequency of the spike times in Hz
    lfp_fs : float
        Sampling frequency of the LFP signal in Hz
    bandwidth : float, optional
        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
    freqs : List[float]
        List of frequencies (in Hz) at which to calculate entrainment
    ppc_method : str, optional
        Algorithm to use for PPC calculation: 'numpy', 'numba', or 'gpu' (default: 'numpy');
        only used when entrainment_method='ppc'

    Returns
    -------
    Dict[str, Dict[int, Dict[float, float]]]
        Nested dictionary where the structure is:
        {
            population_name: {
                node_id: {
                    frequency: entrainment value
                }
            }
        }
        Entrainment values are floats representing the metric (PPC, PLV) at each frequency

    Raises
    ------
    ValueError
        If entrainment_method is not 'plv', 'ppc' or 'ppc2', or pop_names/freqs are missing.
    """
    metric_functions = {
        'plv': calculate_spike_lfp_plv,
        'ppc': calculate_ppc,
        'ppc2': calculate_ppc2,
    }
    # Validate before the (expensive) LFP filtering; an unknown method used to yield
    # silently empty per-cell results.
    if entrainment_method not in metric_functions:
        raise ValueError(f"Invalid entrainment_method '{entrainment_method}'. Choose 'plv', 'ppc', or 'ppc2'.")
    if pop_names is None or freqs is None:
        raise ValueError("Both pop_names and freqs must be provided.")

    metric_func = metric_functions[entrainment_method]
    # Only calculate_ppc accepts the ppc_method argument.
    extra_kwargs = {'ppc_method': ppc_method} if entrainment_method == 'ppc' else {}

    # Pre-filter the LFP once per frequency to speed up the per-cell calculations.
    filtered_lfp_phases = {
        freq: get_lfp_phase(
            lfp_data=lfp_data,
            freq_of_interest=freq,
            fs=lfp_fs,
            filter_method=filter_method,
            lowcut=lowcut,
            highcut=highcut,
            bandwidth=bandwidth
        )
        for freq in freqs
    }

    entrainment_dict = {}
    for pop in pop_names:
        skip_count = 0
        pop_spikes = spike_df[spike_df['pop_name'] == pop]
        nodes = pop_spikes['node_ids'].unique()
        entrainment_dict[pop] = {}
        print(f'Processing {pop} population')

        # Group once instead of re-scanning pop_spikes for every node; sort=False keeps
        # the first-appearance order that unique() produces.
        spikes_by_node = pop_spikes.groupby('node_ids', sort=False, observed=True)['timestamps']
        for node, node_timestamps in tqdm(spikes_by_node, total=len(nodes)):
            # Skip nodes with less than or equal to 1 spike
            if len(node_timestamps) <= 1:
                skip_count += 1
                continue

            spike_times = node_timestamps.values
            entrainment_dict[pop][node] = {
                freq: metric_func(
                    spike_times,
                    lfp_data,
                    spike_fs=spike_fs,
                    lfp_fs=lfp_fs,
                    freq_of_interest=freq,
                    bandwidth=bandwidth,
                    lowcut=lowcut,
                    highcut=highcut,
                    filter_method=filter_method,
                    filtered_lfp_phase=filtered_lfp_phases[freq],
                    **extra_kwargs
                )
                for freq in freqs
            }

        print(f'Calculated {entrainment_method.upper()} for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')

    return entrainment_dict
506
+
507
+
508
def calculate_spike_rate_power_correlation(spike_rate, lfp_data, fs, pop_names, filter_method='wavelet',
                                           bandwidth=2.0, lowcut=None, highcut=None,
                                           freq_range=(10, 100), freq_step=5):
    """
    Calculate the Spearman correlation between population spike rates and LFP power
    across a grid of frequencies. This function assumes the fs of the spike_rate and lfp are the same.

    Parameters
    ----------
    spike_rate : DataFrame
        Pre-calculated population spike rates at the same fs as lfp; indexed by population name
    lfp_data : np.array
        LFP data
    fs : float
        Sampling frequency
    pop_names : list
        List of population names to analyze
    filter_method : str, optional
        Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
    bandwidth : float, optional
        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
    lowcut : float, optional
        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    highcut : float, optional
        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    freq_range : tuple, optional
        Min and max frequency to analyze; the max is included when it lies on the
        freq_step grid and is never exceeded (default: (10, 100))
    freq_step : float, optional
        Step size for frequency analysis, must be positive (default: 5)

    Returns
    -------
    correlation_results : dict
        ``{pop: {freq: {'correlation': rho, 'p_value': p}}}`` for each population and frequency
    frequencies : np.ndarray
        Array of frequencies analyzed

    Raises
    ------
    ValueError
        If freq_step is not positive, or a population's rate length differs from the
        LFP power length at some frequency.

    Notes
    -----
    NOTE(review): with filter_method='butter' the same lowcut/highcut are passed for every
    frequency; whether get_lfp_power then returns the same power trace for all frequencies
    depends on its implementation - confirm.
    """
    if freq_step <= 0:
        raise ValueError("freq_step must be positive.")

    # Inclusive frequency grid. Counting the points explicitly avoids the overshoot of
    # np.arange(lo, hi + 1, step), which goes past `hi` for non-integer steps; for
    # integer inputs the result (values and dtype) is unchanged.
    f_min, f_max = freq_range
    n_freqs = max(int(np.floor((f_max - f_min) / freq_step + 1e-9)) + 1, 0)
    frequencies = f_min + freq_step * np.arange(n_freqs)

    correlation_results = {pop: {} for pop in pop_names}

    # Power at each frequency, computed once and shared by all populations.
    power_by_freq = {}
    for freq in frequencies:
        power_by_freq[freq] = get_lfp_power(lfp_data, freq, fs, filter_method,
                                            lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)

    for pop in pop_names:
        pop_rate = spike_rate[pop]

        for freq in frequencies:
            if len(pop_rate) != len(power_by_freq[freq]):
                raise ValueError(f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}")
            # Spearman: non-parametric, rank-based correlation.
            corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
            correlation_results[pop][freq] = {'correlation': corr, 'p_value': p_val}

    return correlation_results, frequencies
573
+