bmtool 0.6.9.24__tar.gz → 0.6.9.25__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/PKG-INFO +1 -1
  2. bmtool-0.6.9.25/bmtool/analysis/entrainment.py +429 -0
  3. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/analysis/lfp.py +0 -356
  4. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/bmplot.py +56 -0
  5. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/synapses.py +1 -1
  6. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/PKG-INFO +1 -1
  7. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/SOURCES.txt +1 -0
  8. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/setup.py +1 -1
  9. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/LICENSE +0 -0
  10. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/README.md +0 -0
  11. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/SLURM.py +0 -0
  12. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/__init__.py +0 -0
  13. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/__main__.py +0 -0
  14. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/analysis/__init__.py +0 -0
  15. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/analysis/netcon_reports.py +0 -0
  16. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/analysis/spikes.py +0 -0
  17. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/connectors.py +0 -0
  18. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/debug/__init__.py +0 -0
  19. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/debug/commands.py +0 -0
  20. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/debug/debug.py +0 -0
  21. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/graphs.py +0 -0
  22. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/manage.py +0 -0
  23. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/plot_commands.py +0 -0
  24. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/singlecell.py +0 -0
  25. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/util/__init__.py +0 -0
  26. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/util/commands.py +0 -0
  27. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/util/neuron/__init__.py +0 -0
  28. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/util/neuron/celltuner.py +0 -0
  29. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool/util/util.py +0 -0
  30. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/dependency_links.txt +0 -0
  31. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/entry_points.txt +0 -0
  32. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/requires.txt +0 -0
  33. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/bmtool.egg-info/top_level.txt +0 -0
  34. {bmtool-0.6.9.24 → bmtool-0.6.9.25}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bmtool
- Version: 0.6.9.24
+ Version: 0.6.9.25
  Summary: BMTool
  Home-page: https://github.com/cyneuro/bmtool
  Download-URL:
@@ -0,0 +1,429 @@
+ """
+ Module for entrainment analysis
+ """
+
+ import numpy as np
+ from scipy import signal
+ import numba
+ from numba import cuda
+ import pandas as pd
+ import xarray as xr
+ from .lfp import wavelet_filter, butter_bandpass_filter
+ from typing import Dict, List
+ from tqdm.notebook import tqdm
+
+
+ def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
+                                 method: str = 'wavelet', lowcut: float = None, highcut: float = None,
+                                 bandwidth: float = 2.0) -> np.ndarray:
+     """
+     Calculate the Phase Locking Value (PLV) between two signals using the wavelet or Hilbert method.
+
+     Parameters:
+     - x1, x2: Input signals (1D arrays, same length)
+     - fs: Sampling frequency
+     - freq_of_interest: Desired frequency for wavelet PLV calculation
+     - method: 'wavelet' or 'hilbert' to choose the PLV calculation method
+     - lowcut, highcut: Cutoff frequencies for the Hilbert method
+     - bandwidth: Bandwidth parameter for the wavelet
+
+     Returns:
+     - plv: Phase Locking Value (1D array)
+     """
+     if len(x1) != len(x2):
+         raise ValueError("Input signals must have the same length.")
+
+     if method == 'wavelet':
+         if freq_of_interest is None:
+             raise ValueError("freq_of_interest must be provided for the wavelet method.")
+
+         # Apply CWT to both signals
+         theta1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
+         theta2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
+
+     elif method == 'hilbert':
+         if lowcut is None or highcut is None:
+             print("Lowcut and/or highcut were not defined; the unfiltered signals will be Hilbert-transformed for the PLV calculation")
+
+         if lowcut and highcut:
+             # Bandpass filter and get the analytic signal using the Hilbert transform
+             x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
+             x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)
+
+         # Get phase using the Hilbert transform
+         theta1 = signal.hilbert(x1)
+         theta2 = signal.hilbert(x2)
+
+     else:
+         raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
+
+     # Calculate phase difference
+     phase_diff = np.angle(theta1) - np.angle(theta2)
+
+     # Calculate PLV using the standard equation from Lachaux et al., "Measuring Phase Synchrony in Brain Signals" (1999)
+     plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))
+
+     return plv
+
+
+ def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
+                             lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
+                             lowcut: float = None, highcut: float = None,
+                             bandwidth: float = 2.0) -> float:
+     """
+     Calculate the spike-LFP phase locking value. Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
+
+     Parameters:
+     - spike_times: Array of spike times
+     - lfp_signal: Local field potential time series
+     - spike_fs: Sampling frequency in Hz of the spike times; only needed if the spike times and LFP use different rates
+     - lfp_fs: Sampling frequency in Hz of the LFP
+     - method: 'wavelet' or 'hilbert' to choose the phase extraction method
+     - freq_of_interest: Desired frequency for wavelet phase extraction
+     - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
+     - bandwidth: Bandwidth parameter for the wavelet
+
+     Returns:
+     - plv: Spike-LFP phase locking value (float), the squared length of the
+       resultant vector of the spike phases divided by n**2
+     """
+
+     if spike_fs is None:
+         spike_fs = lfp_fs
+     # Convert spike times to seconds
+     spike_times_seconds = spike_times / spike_fs
+
+     # Then convert from seconds to samples at the LFP sampling rate
+     spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
+
+     # Filter indices to ensure they're within bounds of the LFP signal
+     valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
+     if len(valid_indices) <= 1:
+         return 0
+
+     # Extract phase using the specified method
+     if method == 'wavelet':
+         if freq_of_interest is None:
+             raise ValueError("freq_of_interest must be provided for the wavelet method.")
+
+         # Apply CWT to extract phase at the frequency of interest
+         lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
+         instantaneous_phase = np.angle(lfp_complex)
+
+     elif method == 'hilbert':
+         if lowcut is None or highcut is None:
+             print("Lowcut and/or highcut were not defined; the unfiltered signal will be Hilbert-transformed for the PLV calculation")
+             filtered_lfp = lfp_signal
+         else:
+             # Bandpass filter the signal
+             filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
+
+         # Get phase using the Hilbert transform
+         analytic_signal = signal.hilbert(filtered_lfp)
+         instantaneous_phase = np.angle(analytic_signal)
+
+     else:
+         raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
+
+     # Get phases at spike times
+     spike_phases = instantaneous_phase[valid_indices]
+
+     # Calculate the PLV
+     n = len(spike_phases)
+
+     # Convert phases to unit vectors in the complex plane
+     unit_vectors = np.exp(1j * spike_phases)
+
+     # Calculate the resultant vector
+     resultant_vector = np.sum(unit_vectors)
+
+     # PLV is the squared length of the resultant vector divided by n**2
+     plv = (np.abs(resultant_vector) ** 2) / (n ** 2)
+
+     return plv
+
+
+ @numba.njit(parallel=True, fastmath=True)
+ def _ppc_parallel_numba(spike_phases):
+     """Numba-optimized parallel PPC calculation"""
+     n = len(spike_phases)
+     sum_cos = 0.0
+     for i in numba.prange(n):
+         phase_i = spike_phases[i]
+         for j in range(i + 1, n):
+             sum_cos += np.cos(phase_i - spike_phases[j])
+     return (2 / (n * (n - 1))) * sum_cos
+
+
+ @cuda.jit(fastmath=True)
+ def _ppc_cuda_kernel(spike_phases, out):
+     i = cuda.grid(1)
+     if i < len(spike_phases):
+         local_sum = 0.0
+         for j in range(i+1, len(spike_phases)):
+             local_sum += np.cos(spike_phases[i] - spike_phases[j])
+         out[i] = local_sum
+
+
+ def _ppc_gpu(spike_phases):
+     """GPU-accelerated implementation"""
+     d_phases = cuda.to_device(spike_phases)
+     d_out = cuda.device_array(len(spike_phases), dtype=np.float64)
+
+     threads = 256
+     blocks = (len(spike_phases) + threads - 1) // threads
+
+     _ppc_cuda_kernel[blocks, threads](d_phases, d_out)
+     total = d_out.copy_to_host().sum()
+     return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
+
+
+ def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
+                   lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
+                   lowcut: float = None, highcut: float = None,
+                   bandwidth: float = 2.0, ppc_method: str = 'numpy') -> float:
+     """
+     Calculate Pairwise Phase Consistency (PPC) between spike times and an LFP signal.
+     Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
+
+     Parameters:
+     - spike_times: Array of spike times
+     - lfp_signal: Local field potential time series
+     - spike_fs: Sampling frequency in Hz of the spike times; only needed if the spike times and LFP use different rates
+     - lfp_fs: Sampling frequency in Hz of the LFP
+     - method: 'wavelet' or 'hilbert' to choose the phase extraction method
+     - freq_of_interest: Desired frequency for wavelet phase extraction
+     - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
+     - bandwidth: Bandwidth parameter for the wavelet
+     - ppc_method: which implementation to use for the PPC calculation: 'numpy', 'numba', or 'gpu'
+
+     Returns:
+     - ppc: Pairwise Phase Consistency value
+     """
+     if spike_fs is None:
+         spike_fs = lfp_fs
+     # Convert spike times to seconds
+     spike_times_seconds = spike_times / spike_fs
+
+     # Then convert from seconds to samples at the LFP sampling rate
+     spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
+
+     # Filter indices to ensure they're within bounds of the LFP signal
+     valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
+     if len(valid_indices) <= 1:
+         return 0
+
+     # Extract phase using the specified method
+     if method == 'wavelet':
+         if freq_of_interest is None:
+             raise ValueError("freq_of_interest must be provided for the wavelet method.")
+
+         # Apply CWT to extract phase at the frequency of interest
+         lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
+         instantaneous_phase = np.angle(lfp_complex)
+
+     elif method == 'hilbert':
+         if lowcut is None or highcut is None:
+             print("Lowcut and/or highcut were not defined; the unfiltered signal will be Hilbert-transformed for the PPC calculation")
+             filtered_lfp = lfp_signal
+         else:
+             # Bandpass filter the signal
+             filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
+
+         # Get phase using the Hilbert transform
+         analytic_signal = signal.hilbert(filtered_lfp)
+         instantaneous_phase = np.angle(analytic_signal)
+
+     else:
+         raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
+
+     # Get phases at spike times
+     spike_phases = instantaneous_phase[valid_indices]
+
+     n_spikes = len(spike_phases)
+
+     # Calculate PPC (Pairwise Phase Consistency)
+     if n_spikes <= 1:
+         return 0
+
+     # Explicit calculation of pairwise phase consistency
+     sum_cos_diff = 0.0
+
+     # # Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
+     # for i in range(n_spikes - 1): # For each spike i
+     #     for j in range(i + 1, n_spikes): # For each spike j > i
+     #         # Calculate the phase difference between spikes i and j
+     #         phase_diff = spike_phases[i] - spike_phases[j]
+
+     #         # f(θᵢ, θⱼ) = cos(θᵢ)cos(θⱼ) + sin(θᵢ)sin(θⱼ), which simplifies to f(θᵢ, θⱼ) = cos(θᵢ - θⱼ)
+     #         cos_diff = np.cos(phase_diff)
+
+     #         # Add to the sum
+     #         sum_cos_diff += cos_diff
+
+     # # Calculate PPC according to the equation
+     # # PPC = (2 / (N(N-1))) * Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
+     # ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
+
+     # Same result as the explicit loop above, but vectorized so it runs in reasonable time and memory.
+     if ppc_method == 'numpy':
+         i, j = np.triu_indices(n_spikes, k=1)
+         phase_diff = spike_phases[i] - spike_phases[j]
+         sum_cos_diff = np.sum(np.cos(phase_diff))
+         ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
+     elif ppc_method == 'numba':
+         ppc = _ppc_parallel_numba(spike_phases)
+     elif ppc_method == 'gpu':
+         ppc = _ppc_gpu(spike_phases)
+     else:
+         raise ValueError("Unsupported ppc_method; choose 'numpy', 'numba', or 'gpu'.")
+     return ppc
+
+
+ def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
+                    lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
+                    lowcut: float = None, highcut: float = None,
+                    bandwidth: float = 2.0) -> float:
+     """
+     # -----------------------------------------------------------------------------
+     # PPC2 Calculation (Vinck et al., 2010)
+     # -----------------------------------------------------------------------------
+     # Equation (original):
+     #     PPC = (2 / (n * (n - 1))) * sum(cos(φ_i - φ_j) for all i < j)
+     # Optimized formula (algebraically equivalent):
+     #     PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
+     # -----------------------------------------------------------------------------
+
+     Parameters:
+     - spike_times: Array of spike times
+     - lfp_signal: Local field potential time series
+     - spike_fs: Sampling frequency in Hz of the spike times; only needed if the spike times and LFP use different rates
+     - lfp_fs: Sampling frequency in Hz of the LFP
+     - method: 'wavelet' or 'hilbert' to choose the phase extraction method
+     - freq_of_interest: Desired frequency for wavelet phase extraction
+     - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
+     - bandwidth: Bandwidth parameter for the wavelet
+
+     Returns:
+     - ppc2: Pairwise Phase Consistency 2 value
+     """
+
+     if spike_fs is None:
+         spike_fs = lfp_fs
+     # Convert spike times to seconds
+     spike_times_seconds = spike_times / spike_fs
+
+     # Then convert from seconds to samples at the LFP sampling rate
+     spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
+
+     # Filter indices to ensure they're within bounds of the LFP signal
+     valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
+     if len(valid_indices) <= 1:
+         return 0
+
+     # Extract phase using the specified method
+     if method == 'wavelet':
+         if freq_of_interest is None:
+             raise ValueError("freq_of_interest must be provided for the wavelet method.")
+
+         # Apply CWT to extract phase at the frequency of interest
+         lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
+         instantaneous_phase = np.angle(lfp_complex)
+
+     elif method == 'hilbert':
+         if lowcut is None or highcut is None:
+             print("Lowcut and/or highcut were not defined; the unfiltered signal will be Hilbert-transformed for the PPC2 calculation")
+             filtered_lfp = lfp_signal
+         else:
+             # Bandpass filter the signal
+             filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
+
+         # Get phase using the Hilbert transform
+         analytic_signal = signal.hilbert(filtered_lfp)
+         instantaneous_phase = np.angle(analytic_signal)
+
+     else:
+         raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
+
+     # Get phases at spike times
+     spike_phases = instantaneous_phase[valid_indices]
+
+     # Calculate PPC2 according to Vinck et al. (2010), Equation 6
+     n = len(spike_phases)
+
+     if n <= 1:
+         return 0
+
+     # Convert phases to unit vectors in the complex plane
+     unit_vectors = np.exp(1j * spike_phases)
+
+     # Calculate the resultant vector
+     resultant_vector = np.sum(unit_vectors)
+
+     # PPC2 = (|∑(e^(i*φ_j))|² - n) / (n * (n - 1))
+     ppc2 = (np.abs(resultant_vector)**2 - n) / (n * (n - 1))
+
+     return ppc2
+
+
+ def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
+                            spike_fs: float, lfp_fs: float,
+                            pop_names: List[str], freqs: List[float]) -> Dict[str, Dict[int, Dict[float, float]]]:
+     """
+     Calculate pairwise phase consistency (PPC) per neuron (cell) for specified frequencies across different populations.
+
+     This function computes the PPC for each neuron within the specified populations based on their spike times
+     and a single-channel local field potential (LFP) signal.
+
+     Args:
+         spike_df (pd.DataFrame): Spike dataframe; use bmtool.analysis.load_spikes_to_df
+         lfp_signal (np.ndarray): Single-channel LFP signal; use bmtool.analysis.ecp_to_lfp
+         spike_fs (float): Sampling rate of the spike times; the BMTK default is 1000
+         lfp_fs (float): Sampling rate of the LFP
+         pop_names (List[str]): List of population names (as strings) to compute PPC for; each should appear in spike_df
+         freqs (List[float]): List of frequencies (in Hz) at which to calculate PPC.
+
+     Returns:
+         Dict[str, Dict[int, Dict[float, float]]]: Nested dictionary where the structure is:
+             {
+                 population_name: {
+                     node_id: {
+                         frequency: PPC value
+                     }
+                 }
+             }
+         PPC values are floats representing the pairwise phase consistency at each frequency.
+     """
+     ppc_dict = {}
+     for pop in pop_names:
+         skip_count = 0
+         pop_spikes = spike_df[spike_df['pop_name'] == pop]
+         nodes = pop_spikes['node_ids'].unique()
+         ppc_dict[pop] = {}
+         print(f'Processing {pop} population')
+         for node in tqdm(nodes):
+             node_spikes = pop_spikes[pop_spikes['node_ids'] == node]
+
+             # Skip nodes with one spike or fewer
+             if len(node_spikes) <= 1:
+                 skip_count += 1
+                 continue
+
+             ppc_dict[pop][node] = {}
+             for freq in freqs:
+                 ppc = calculate_ppc2(
+                     node_spikes['timestamps'].values,
+                     lfp_signal,
+                     spike_fs=spike_fs,
+                     lfp_fs=lfp_fs,
+                     freq_of_interest=freq,
+                     method='wavelet'
+                 )
+                 ppc_dict[pop][node][freq] = ppc
+
+         print(f'Calculated PPC for {pop} population with {len(nodes)-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')
+
+     return ppc_dict
+
+
+
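The hunk above is the entire new bmtool/analysis/entrainment.py module. A minimal usage sketch of the spike-LFP entrainment metrics it exposes (the synthetic LFP, spike times, and sampling rates below are illustrative assumptions, not taken from the package):

import numpy as np
from bmtool.analysis.entrainment import calculate_spike_lfp_plv, calculate_ppc, calculate_ppc2

lfp_fs = 1000.0                                                    # LFP sampled at 1 kHz
t = np.arange(0, 10, 1 / lfp_fs)                                   # 10 s of data
lfp = np.sin(2 * np.pi * 8 * t) + 0.5 * np.random.randn(t.size)    # 8 Hz rhythm plus noise
spike_times = np.sort(np.random.uniform(0, 10_000, 200))           # spike times in ms (spike_fs = 1000)

plv = calculate_spike_lfp_plv(spike_times, lfp, spike_fs=1000, lfp_fs=lfp_fs,
                              method='wavelet', freq_of_interest=8)
ppc = calculate_ppc(spike_times, lfp, spike_fs=1000, lfp_fs=lfp_fs,
                    method='wavelet', freq_of_interest=8, ppc_method='numpy')
ppc2 = calculate_ppc2(spike_times, lfp, spike_fs=1000, lfp_fs=lfp_fs,
                      method='wavelet', freq_of_interest=8)

# PPC and PPC2 are algebraically equivalent estimators, so for the same spikes
# and LFP the two values should agree to floating-point precision.
print(plv, ppc, ppc2)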
@@ -11,8 +11,6 @@ import matplotlib.pyplot as plt
  from scipy import signal
  import pywt
  from bmtool.bmplot import is_notebook
- import numba
- from numba import cuda
  import pandas as pd
 
 
@@ -295,360 +293,6 @@ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs:
      return x_a
 
 
- def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
-                                 method: str = 'wavelet', lowcut: float = None, highcut: float = None,
-                                 bandwidth: float = 2.0) -> np.ndarray:
-     """
-     Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
-
-     Parameters:
-     - x1, x2: Input signals (1D arrays, same length)
-     - fs: Sampling frequency
-     - freq_of_interest: Desired frequency for wavelet PLV calculation
-     - method: 'wavelet' or 'hilbert' to choose the PLV calculation method
-     - lowcut, highcut: Cutoff frequencies for the Hilbert method
-     - bandwidth: Bandwidth parameter for the wavelet
-
-     Returns:
-     - plv: Phase Locking Value (1D array)
-     """
-     if len(x1) != len(x2):
-         raise ValueError("Input signals must have the same length.")
-
-     if method == 'wavelet':
-         if freq_of_interest is None:
-             raise ValueError("freq_of_interest must be provided for the wavelet method.")
-
-         # Apply CWT to both signals
-         theta1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
-         theta2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
-
-     elif method == 'hilbert':
-         if lowcut is None or highcut is None:
-             print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")
-
-         if lowcut and highcut:
-             # Bandpass filter and get the analytic signal using the Hilbert transform
-             x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
-             x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)
-
-         # Get phase using the Hilbert transform
-         theta1 = signal.hilbert(x1)
-         theta2 = signal.hilbert(x2)
-
-     else:
-         raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
-
-     # Calculate phase difference
-     phase_diff = np.angle(theta1) - np.angle(theta2)
-
-     # Calculate PLV from standard equation from Measuring phase synchrony in brain signals(1999)
-     plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))
-
-     return plv
-
-
- def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs : float = None,
-                             lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
-                             lowcut: float = None, highcut: float = None,
-                             bandwidth: float = 2.0) -> tuple:
-     """
-     Calculate spike-lfp phase locking value Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
-
-     Parameters:
-     - spike_times: Array of spike times
-     - lfp_signal: Local field potential time series
-     - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
-     - lfp_fs : Sampling frequency in Hz of the LFP
-     - method: 'wavelet' or 'hilbert' to choose the phase extraction method
-     - freq_of_interest: Desired frequency for wavelet phase extraction
-     - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
-     - bandwidth: Bandwidth parameter for the wavelet
-
-     Returns:
-     - ppc1: Phase-Phase Coupling value
-     - spike_phases: Phases at spike times
-     """
-
-     if spike_fs == None:
-         spike_fs = lfp_fs
-     # Convert spike times to sample indices
-     spike_times_seconds = spike_times / spike_fs
-
-     # Then convert from seconds to samples at the new sampling rate
-     spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
-
-     # Filter indices to ensure they're within bounds of the LFP signal
-     valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
-     if len(valid_indices) <= 1:
-         return 0, np.array([])
-
-     # Extract phase using the specified method
-     if method == 'wavelet':
-         if freq_of_interest is None:
-             raise ValueError("freq_of_interest must be provided for the wavelet method.")
-
-         # Apply CWT to extract phase at the frequency of interest
-         lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
-         instantaneous_phase = np.angle(lfp_complex)
-
-     elif method == 'hilbert':
-         if lowcut is None or highcut is None:
-             print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
-             filtered_lfp = lfp_signal
-         else:
-             # Bandpass filter the signal
-             filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
-
-         # Get phase using the Hilbert transform
-         analytic_signal = signal.hilbert(filtered_lfp)
-         instantaneous_phase = np.angle(analytic_signal)
-
-     else:
-         raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
-
-     # Get phases at spike times
-     spike_phases = instantaneous_phase[valid_indices]
-
-     # Calculate PPC1
-     n = len(spike_phases)
-
-     # Convert phases to unit vectors in the complex plane
-     unit_vectors = np.exp(1j * spike_phases)
-
-     # Calculate the resultant vector
-     resultant_vector = np.sum(unit_vectors)
-
-     # Plv is the squared length of the resultant vector divided by n²
-     plv = (np.abs(resultant_vector) ** 2) / (n ** 2)
-
-     return plv
-
-
- @numba.njit(parallel=True, fastmath=True)
- def _ppc_parallel_numba(spike_phases):
-     """Numba-optimized parallel PPC calculation"""
-     n = len(spike_phases)
-     sum_cos = 0.0
-     for i in numba.prange(n):
-         phase_i = spike_phases[i]
-         for j in range(i + 1, n):
-             sum_cos += np.cos(phase_i - spike_phases[j])
-     return (2 / (n * (n - 1))) * sum_cos
-
-
- @cuda.jit(fastmath=True)
- def _ppc_cuda_kernel(spike_phases, out):
-     i = cuda.grid(1)
-     if i < len(spike_phases):
-         local_sum = 0.0
-         for j in range(i+1, len(spike_phases)):
-             local_sum += np.cos(spike_phases[i] - spike_phases[j])
-         out[i] = local_sum
-
-
- def _ppc_gpu(spike_phases):
-     """GPU-accelerated implementation"""
-     d_phases = cuda.to_device(spike_phases)
-     d_out = cuda.device_array(len(spike_phases), dtype=np.float64)
-
-     threads = 256
-     blocks = (len(spike_phases) + threads - 1) // threads
-
-     _ppc_cuda_kernel[blocks, threads](d_phases, d_out)
-     total = d_out.copy_to_host().sum()
-     return (2/(len(spike_phases)*(len(spike_phases)-1))) * total
-
-
- def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
-                   lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
-                   lowcut: float = None, highcut: float = None,
-                   bandwidth: float = 2.0,ppc_method: str = 'numpy') -> tuple:
-     """
-     Calculate Pairwise Phase Consistency (PPC) between spike times and LFP signal.
-     Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959
-
-     Parameters:
-     - spike_times: Array of spike times
-     - lfp_signal: Local field potential time series
-     - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
-     - lfp_fs: Sampling frequency in Hz of the LFP
-     - method: 'wavelet' or 'hilbert' to choose the phase extraction method
-     - freq_of_interest: Desired frequency for wavelet phase extraction
-     - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
-     - bandwidth: Bandwidth parameter for the wavelet
-     - ppc_method: which algo to use for PPC calculate can be numpy, numba or gpu
-
-     Returns:
-     - ppc: Pairwise Phase Consistency value
-     """
-     if spike_fs is None:
-         spike_fs = lfp_fs
-     # Convert spike times to sample indices
-     spike_times_seconds = spike_times / spike_fs
-
-     # Then convert from seconds to samples at the new sampling rate
-     spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
-
-     # Filter indices to ensure they're within bounds of the LFP signal
-     valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
-     if len(valid_indices) <= 1:
-         return 0, np.array([])
-
-     # Extract phase using the specified method
-     if method == 'wavelet':
-         if freq_of_interest is None:
-             raise ValueError("freq_of_interest must be provided for the wavelet method.")
-
-         # Apply CWT to extract phase at the frequency of interest
-         lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
-         instantaneous_phase = np.angle(lfp_complex)
-
-     elif method == 'hilbert':
-         if lowcut is None or highcut is None:
-             print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
-             filtered_lfp = lfp_signal
-         else:
-             # Bandpass filter the signal
-             filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
-
-         # Get phase using the Hilbert transform
-         analytic_signal = signal.hilbert(filtered_lfp)
-         instantaneous_phase = np.angle(analytic_signal)
-
-     else:
-         raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
-
-     # Get phases at spike times
-     spike_phases = instantaneous_phase[valid_indices]
-
-     n_spikes = len(spike_phases)
-
-     # Calculate PPC (Pairwise Phase Consistency)
-     if n_spikes <= 1:
-         return 0, spike_phases
-
-     # Explicit calculation of pairwise phase consistency
-     sum_cos_diff = 0.0
-
-     # # Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
-     # for i in range(n_spikes - 1): # For each spike i
-     #     for j in range(i + 1, n_spikes): # For each spike j > i
-     #         # Calculate the phase difference between spikes i and j
-     #         phase_diff = spike_phases[i] - spike_phases[j]
-
-     #         #f(θᵢ, θⱼ) = cos(θᵢ)cos(θⱼ) + sin(θᵢ)sin(θⱼ) can become #f(θᵢ, θⱼ) = cos(θᵢ - θⱼ)
-     #         cos_diff = np.cos(phase_diff)
-
-     #         # Add to the sum
-     #         sum_cos_diff += cos_diff
-
-     # # Calculate PPC according to the equation
-     # # PPC = (2 / (N(N-1))) * Σᵢ Σⱼ₍ᵢ₊₁₎ f(θᵢ, θⱼ)
-     # ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
-
-     # same as above (i think) but with vectorized computation and memory fixes so it wont take forever to run.
-     if ppc_method == 'numpy':
-         i, j = np.triu_indices(n_spikes, k=1)
-         phase_diff = spike_phases[i] - spike_phases[j]
-         sum_cos_diff = np.sum(np.cos(phase_diff))
-         ppc = ((2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff)
-     elif ppc_method == 'numba':
-         ppc = _ppc_parallel_numba(spike_phases)
-     elif ppc_method == 'gpu':
-         ppc = _ppc_gpu(spike_phases)
-     else:
-         raise ExceptionType("Please use a supported ppc method currently that is numpy, numba or gpu")
-     return ppc
-
-
- def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
-                    lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
-                    lowcut: float = None, highcut: float = None,
-                    bandwidth: float = 2.0) -> tuple:
-     """
-     # -----------------------------------------------------------------------------
-     # PPC2 Calculation (Vinck et al., 2010)
-     # -----------------------------------------------------------------------------
-     # Equation(Original):
-     #     PPC = (2 / (n * (n - 1))) * sum(cos(φ_i - φ_j) for all i < j)
-     # Optimized Formula (Algebraically Equivalent):
-     #     PPC = (|sum(e^(i*φ_j))|^2 - n) / (n * (n - 1))
-     # -----------------------------------------------------------------------------
-
-     Parameters:
-     - spike_times: Array of spike times
-     - lfp_signal: Local field potential time series
-     - spike_fs: Sampling frequency in Hz of the spike times only needed if spikes times and lfp has different fs
-     - lfp_fs: Sampling frequency in Hz of the LFP
-     - method: 'wavelet' or 'hilbert' to choose the phase extraction method
-     - freq_of_interest: Desired frequency for wavelet phase extraction
-     - lowcut, highcut: Cutoff frequencies for bandpass filtering (Hilbert method)
-     - bandwidth: Bandwidth parameter for the wavelet
-
-     Returns:
-     - ppc2: Pairwise Phase Consistency 2 value
-     - spike_phases: Phases at spike times
-     """
-
-     if spike_fs is None:
-         spike_fs = lfp_fs
-     # Convert spike times to sample indices
-     spike_times_seconds = spike_times / spike_fs
-
-     # Then convert from seconds to samples at the new sampling rate
-     spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
-
-     # Filter indices to ensure they're within bounds of the LFP signal
-     valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_signal)]
-     if len(valid_indices) <= 1:
-         return 0, np.array([])
-
-     # Extract phase using the specified method
-     if method == 'wavelet':
-         if freq_of_interest is None:
-             raise ValueError("freq_of_interest must be provided for the wavelet method.")
-
-         # Apply CWT to extract phase at the frequency of interest
-         lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
-         instantaneous_phase = np.angle(lfp_complex)
-
-     elif method == 'hilbert':
-         if lowcut is None or highcut is None:
-             print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
-             filtered_lfp = lfp_signal
-         else:
-             # Bandpass filter the signal
-             filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)
-
-         # Get phase using the Hilbert transform
-         analytic_signal = signal.hilbert(filtered_lfp)
-         instantaneous_phase = np.angle(analytic_signal)
-
-     else:
-         raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
-
-     # Get phases at spike times
-     spike_phases = instantaneous_phase[valid_indices]
-
-     # Calculate PPC2 according to Vinck et al. (2010), Equation 6
-     n = len(spike_phases)
-
-     if n <= 1:
-         return 0, spike_phases
-
-     # Convert phases to unit vectors in the complex plane
-     unit_vectors = np.exp(1j * spike_phases)
-
-     # Calculate the resultant vector
-     resultant_vector = np.sum(unit_vectors)
-
-     # PPC2 = (|∑(e^(i*φ_j))|² - n) / (n * (n - 1))
-     ppc2 = (np.abs(resultant_vector)**2 - n) / (n * (n - 1))
-
-     return ppc2
-
-
  # windowing functions
  def windowed_xarray(da, windows, dim='time',
                      new_coord_name='cycle', new_coord=None):
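Since these PLV/PPC helpers were deleted from bmtool/analysis/lfp.py and now live in the new entrainment module (see the addition earlier in this diff), downstream code that imported them from lfp must update its import path. A small illustrative sketch of the change:

# bmtool 0.6.9.24 and earlier
# from bmtool.analysis.lfp import calculate_signal_signal_plv, calculate_ppc2

# bmtool 0.6.9.25
from bmtool.analysis.entrainment import calculate_signal_signal_plv, calculate_ppc2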
@@ -683,6 +683,62 @@ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids
  fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
  plt.draw()
 
+ def distance_delay_plot(simulation_config: str, source: str, target: str,
+                         group_by: str, sid: str, tid: str) -> None:
+     """
+     Plots the relationship between the distance and delay of connections between nodes in a neural network simulation.
+
+     This function loads the node and edge data from a simulation configuration file, filters nodes by population or group,
+     identifies connections (edges) between source and target node populations, calculates the Euclidean distance between
+     connected nodes, and plots the delay as a function of distance.
+
+     Args:
+         simulation_config (str): Path to the simulation config file
+         source (str): The name of the source population in the edge data.
+         target (str): The name of the target population in the edge data.
+         group_by (str): Column name to group nodes by (e.g., population name).
+         sid (str): Identifier for the source group (e.g., 'PN').
+         tid (str): Identifier for the target group (e.g., 'PN').
+
+     Returns:
+         None: The function creates and displays a scatter plot of distance vs delay.
+     """
+     nodes, edges = util.load_nodes_edges_from_config(simulation_config)
+     nodes = nodes[target]
+     # node id is the index of the nodes dataframe
+     node_id_source = nodes[nodes[group_by] == sid].index
+     node_id_target = nodes[nodes[group_by] == tid].index
+
+     edges = edges[f'{source}_to_{target}']
+     edges = edges[edges['source_node_id'].isin(node_id_source) & edges['target_node_id'].isin(node_id_target)]
+
+     stuff_to_plot = []
+     for index, row in edges.iterrows():
+         try:
+             source_node = row['source_node_id']
+             target_node = row['target_node_id']
+
+             source_pos = nodes.loc[[source_node], ['pos_x', 'pos_y', 'pos_z']]
+             target_pos = nodes.loc[[target_node], ['pos_x', 'pos_y', 'pos_z']]
+
+             distance = np.linalg.norm(source_pos.values - target_pos.values)
+
+             delay = row['delay']  # This line may raise KeyError
+             stuff_to_plot.append([distance, delay])
+
+         except KeyError as e:
+             print(f"KeyError: Missing key {e} in either edge properties or node positions.")
+         except IndexError as e:
+             print(f"IndexError: Node ID {source_node} or {target_node} not found in nodes.")
+         except Exception as e:
+             print(f"Unexpected error at edge index {index}: {e}")
+
+     plt.scatter([x[0] for x in stuff_to_plot], [x[1] for x in stuff_to_plot])
+     plt.xlabel('Distance')
+     plt.ylabel('Delay')
+     plt.title(f'Distance vs Delay for edge between {sid} and {tid}')
+     plt.show()
+
  def plot_synapse_location_histograms(config, target_model, source=None, target=None):
      """
      generates a histogram of the positions of the synapses on a cell broken down by section
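A minimal call sketch for the new distance_delay_plot helper added to bmtool.bmplot above (the config path, population names, and group labels below are placeholders, not values taken from the package):

from bmtool.bmplot import distance_delay_plot

distance_delay_plot(
    simulation_config="simulation_config.json",  # hypothetical config path
    source="cortex", target="cortex",            # edge population names as used in the network files
    group_by="pop_name",                         # node column used to select groups
    sid="PN", tid="PN",                          # source and target group labels
)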
@@ -647,7 +647,7 @@ class SynapseTuner:
          # Widgets setup (Sliders)
          freqs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 35, 50, 100, 200]
          delays = [125, 250, 500, 1000, 2000, 4000]
-         durations = [300, 500, 1000, 2000, 5000, 10000]
+         durations = [100, 300, 500, 1000, 2000, 5000, 10000]
          freq0 = 50
          delay0 = 250
          duration0 = 300
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bmtool
- Version: 0.6.9.24
+ Version: 0.6.9.25
  Summary: BMTool
  Home-page: https://github.com/cyneuro/bmtool
  Download-URL:
@@ -18,6 +18,7 @@ bmtool.egg-info/entry_points.txt
  bmtool.egg-info/requires.txt
  bmtool.egg-info/top_level.txt
  bmtool/analysis/__init__.py
+ bmtool/analysis/entrainment.py
  bmtool/analysis/lfp.py
  bmtool/analysis/netcon_reports.py
  bmtool/analysis/spikes.py
@@ -6,7 +6,7 @@ with open("README.md", "r") as fh:
 
  setup(
      name="bmtool",
-     version='0.6.9.24',
+     version='0.6.9.25',
      author="Neural Engineering Laboratory at the University of Missouri",
      author_email="gregglickert@mail.missouri.edu",
      description="BMTool",