bmtool 0.6.9.23__tar.gz → 0.6.9.25__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/PKG-INFO +1 -1
  2. bmtool-0.6.9.25/bmtool/analysis/entrainment.py +429 -0
  3. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/analysis/lfp.py +0 -356
  4. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/analysis/spikes.py +53 -0
  5. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/bmplot.py +151 -1
  6. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/synapses.py +1 -1
  7. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/util/util.py +28 -0
  8. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool.egg-info/PKG-INFO +1 -1
  9. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool.egg-info/SOURCES.txt +1 -0
  10. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/setup.py +1 -1
  11. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/LICENSE +0 -0
  12. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/README.md +0 -0
  13. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/SLURM.py +0 -0
  14. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/__init__.py +0 -0
  15. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/__main__.py +0 -0
  16. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/analysis/__init__.py +0 -0
  17. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/analysis/netcon_reports.py +0 -0
  18. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/connectors.py +0 -0
  19. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/debug/__init__.py +0 -0
  20. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/debug/commands.py +0 -0
  21. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/debug/debug.py +0 -0
  22. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/graphs.py +0 -0
  23. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/manage.py +0 -0
  24. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/plot_commands.py +0 -0
  25. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/singlecell.py +0 -0
  26. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/util/__init__.py +0 -0
  27. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/util/commands.py +0 -0
  28. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/util/neuron/__init__.py +0 -0
  29. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool/util/neuron/celltuner.py +0 -0
  30. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool.egg-info/dependency_links.txt +0 -0
  31. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool.egg-info/entry_points.txt +0 -0
  32. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool.egg-info/requires.txt +0 -0
  33. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/bmtool.egg-info/top_level.txt +0 -0
  34. {bmtool-0.6.9.23 → bmtool-0.6.9.25}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.6.9.23
3
+ Version: 0.6.9.25
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -0,0 +1,429 @@
1
+ """
2
+ Module for entrainment analysis
3
+ """
4
+
5
+ import numpy as np
6
+ from scipy import signal
7
+ import numba
8
+ from numba import cuda
9
+ import pandas as pd
10
+ import xarray as xr
11
+ from .lfp import wavelet_filter,butter_bandpass_filter
12
+ from typing import Dict, List
13
+ from tqdm.notebook import tqdm
14
+
15
+
16
def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
                                method: str = 'wavelet', lowcut: float = None, highcut: float = None,
                                bandwidth: float = 2.0) -> np.ndarray:
    """
    Calculate the Phase Locking Value (PLV) between two signals using the wavelet or Hilbert method.

    PLV = |mean(exp(i * (phi1 - phi2)))| taken over the last axis, following
    "Measuring phase synchrony in brain signals" (1999).

    Parameters:
    - x1, x2: Input signals (same length)
    - fs: Sampling frequency in Hz
    - freq_of_interest: Frequency (Hz) used by the wavelet method; required when method='wavelet'
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - lowcut, highcut: Band-pass edges (Hz) for the Hilbert method. Filtering is applied
      whenever both are given (not None); otherwise the raw signals are Hilbert-transformed.
    - bandwidth: Bandwidth parameter forwarded to wavelet_filter

    Returns:
    - plv: Phase Locking Value in [0, 1]. For 1-D inputs this is a scalar, because the
      mean is taken over the last axis.

    Raises:
    - ValueError: if the signals differ in length, freq_of_interest is missing for the
      wavelet method, or method is not recognised.
    """
    if len(x1) != len(x2):
        raise ValueError("Input signals must have the same length.")

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        # Complex wavelet coefficients at the frequency of interest
        analytic1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
        analytic2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")
        else:
            # Explicit None checks (as in the spike-LFP functions): a truthiness test would
            # treat a cutoff of 0 as "missing" and silently skip filtering without warning.
            x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
            x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)

        # Analytic signals; their angle is the instantaneous phase
        analytic1 = signal.hilbert(x1)
        analytic2 = signal.hilbert(x2)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    phase_diff = np.angle(analytic1) - np.angle(analytic2)

    # Length of the mean unit phase-difference vector
    plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))

    return plv
67
+
68
+
69
def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                            lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                            lowcut: float = None, highcut: float = None,
                            bandwidth: float = 2.0) -> float:
    """
    Calculate the spike-LFP phase locking value.
    Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

    The LFP phase is sampled at every spike. The returned value is the squared length of
    the mean resultant vector, |sum(exp(i*phi))|**2 / n**2, i.e. the square of the
    classical PLV. Algebraically this equals 1/n + (n-1)/n * PPC, so it is inflated for
    small spike counts (see calculate_ppc2 for the bias-free estimator).

    Parameters:
    - spike_times: Array of spike times, in units of 1/spike_fs
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times and
      LFP have different fs (defaults to lfp_fs)
    - lfp_fs: Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Frequency (Hz) for wavelet phase extraction; required when method='wavelet'
    - lowcut, highcut: Band-pass edges (Hz) for the Hilbert method; if either is None the
      LFP is Hilbert-transformed without filtering
    - bandwidth: Bandwidth parameter forwarded to wavelet_filter

    Returns:
    - plv: float in [0, 1]; 0.0 when fewer than two spikes fall inside the LFP.
      (Earlier versions returned a (0, array) tuple in that case, inconsistent with the
      normal float return.)

    Raises:
    - ValueError: if freq_of_interest is missing for the wavelet method or method is not recognised.
    """
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> nearest LFP sample index
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        return 0.0

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)

        instantaneous_phase = np.angle(signal.hilbert(filtered_lfp))

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    spike_phases = instantaneous_phase[valid_indices]
    n = len(spike_phases)

    # Squared resultant length divided by n**2
    resultant_vector = np.sum(np.exp(1j * spike_phases))
    plv = (np.abs(resultant_vector) ** 2) / (n ** 2)

    return plv
144
+
145
+
146
@numba.njit(parallel=True, fastmath=True)
def _ppc_parallel_numba(spike_phases):
    """
    Numba-parallel PPC: (2 / (n * (n - 1))) * sum over i < j of cos(phi_i - phi_j).

    Each prange iteration accumulates its row of pairs into a private local, then adds it
    to the shared total with a single ``+=`` directly in the prange body. This is the
    scalar-reduction form Numba recognises for prange loops; the update no longer happens
    inside the nested inner loop.
    Returns 0.0 for fewer than two phases instead of dividing by zero.
    """
    n = len(spike_phases)
    if n < 2:
        return 0.0
    sum_cos = 0.0
    for i in numba.prange(n):
        phase_i = spike_phases[i]
        row_sum = 0.0
        for j in range(i + 1, n):
            row_sum += np.cos(phase_i - spike_phases[j])
        sum_cos += row_sum
    return (2 / (n * (n - 1))) * sum_cos
156
+
157
+
158
@cuda.jit(fastmath=True)
def _ppc_cuda_kernel(spike_phases, out):
    """
    CUDA kernel: the thread with global index ``idx`` writes
    sum over j > idx of cos(spike_phases[idx] - spike_phases[j]) into out[idx].
    Threads whose index is past the end of spike_phases do nothing.
    """
    idx = cuda.grid(1)
    count = len(spike_phases)
    if idx >= count:
        return
    anchor = spike_phases[idx]
    partial = 0.0
    for other in range(idx + 1, count):
        partial += np.cos(anchor - spike_phases[other])
    out[idx] = partial
166
+
167
+
168
def _ppc_gpu(spike_phases):
    """
    GPU-accelerated PPC.

    Launches _ppc_cuda_kernel with one thread per phase (256 threads per block). Each
    thread produces the partial pair sum for its row. The partials are summed on the host
    and scaled by 2 / (n * (n - 1)).
    """
    count = len(spike_phases)
    threads_per_block = 256
    n_blocks = (count + threads_per_block - 1) // threads_per_block  # ceiling division

    device_phases = cuda.to_device(spike_phases)
    device_partials = cuda.device_array(count, dtype=np.float64)

    _ppc_cuda_kernel[n_blocks, threads_per_block](device_phases, device_partials)

    pair_sum = device_partials.copy_to_host().sum()
    scale = 2 / (count * (count - 1))
    return scale * pair_sum
179
+
180
+
181
def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                  lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                  lowcut: float = None, highcut: float = None,
                  bandwidth: float = 2.0, ppc_method: str = 'numpy') -> float:
    """
    Calculate Pairwise Phase Consistency (PPC) between spike times and an LFP signal.
    Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

    PPC = (2 / (N * (N - 1))) * sum over i < j of cos(phi_i - phi_j), where phi are the
    LFP phases at the spike times.

    Parameters:
    - spike_times: Array of spike times, in units of 1/spike_fs
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times and
      LFP have different fs (defaults to lfp_fs)
    - lfp_fs: Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Frequency (Hz) for wavelet phase extraction; required when method='wavelet'
    - lowcut, highcut: Band-pass edges (Hz) for the Hilbert method; if either is None the
      LFP is Hilbert-transformed without filtering
    - bandwidth: Bandwidth parameter forwarded to wavelet_filter
    - ppc_method: pair-sum backend:
        'numpy' builds every i < j pair explicitly (memory grows with N**2);
        'numba' uses _ppc_parallel_numba;
        'gpu' uses _ppc_gpu.

    Returns:
    - ppc: float; 0.0 when fewer than two spikes fall inside the LFP.

    Raises:
    - ValueError: for an unsupported ppc_method (previously this path raised NameError via an
      undefined ``ExceptionType``), a missing freq_of_interest with the wavelet method, or an
      unrecognised method.
    """
    # Validate the backend before doing any expensive phase extraction
    if ppc_method not in ('numpy', 'numba', 'gpu'):
        raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")

    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> nearest LFP sample index
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        return 0.0

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)

        instantaneous_phase = np.angle(signal.hilbert(filtered_lfp))

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    spike_phases = instantaneous_phase[valid_indices]
    n_spikes = len(spike_phases)

    if ppc_method == 'numpy':
        # cos(a)cos(b) + sin(a)sin(b) == cos(a - b), summed over all i < j pairs
        i, j = np.triu_indices(n_spikes, k=1)
        sum_cos_diff = np.sum(np.cos(spike_phases[i] - spike_phases[j]))
        ppc = (2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff
    elif ppc_method == 'numba':
        ppc = _ppc_parallel_numba(spike_phases)
    else:  # 'gpu' (validated above)
        ppc = _ppc_gpu(spike_phases)
    return ppc
281
+
282
+
283
def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                   lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                   lowcut: float = None, highcut: float = None,
                   bandwidth: float = 2.0) -> float:
    """
    Calculate Pairwise Phase Consistency with the closed-form estimator (Vinck et al., 2010).

    Equation (original):
        PPC = (2 / (n * (n - 1))) * sum(cos(phi_i - phi_j) for all i < j)
    Optimized formula (algebraically equivalent, O(n)):
        PPC = (|sum(e^(i*phi_j))|^2 - n) / (n * (n - 1))

    Parameters:
    - spike_times: Array of spike times, in units of 1/spike_fs
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times and
      LFP have different fs (defaults to lfp_fs)
    - lfp_fs: Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Frequency (Hz) for wavelet phase extraction; required when method='wavelet'
    - lowcut, highcut: Band-pass edges (Hz) for the Hilbert method; if either is None the
      LFP is Hilbert-transformed without filtering
    - bandwidth: Bandwidth parameter forwarded to wavelet_filter

    Returns:
    - ppc2: float in [-1/(n-1), 1]; 0.0 when fewer than two spikes fall inside the LFP.
      (Earlier versions returned a (0, array) tuple in that case, inconsistent with the
      normal float return and stored verbatim by calculate_ppc_per_cell.)

    Raises:
    - ValueError: if freq_of_interest is missing for the wavelet method or method is not recognised.
    """
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> nearest LFP sample index
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        return 0.0

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)

        instantaneous_phase = np.angle(signal.hilbert(filtered_lfp))

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    spike_phases = instantaneous_phase[valid_indices]
    n = len(spike_phases)  # >= 2, guaranteed by the early return above

    # PPC2 = (|sum(e^(i*phi_j))|^2 - n) / (n * (n - 1))
    resultant_vector = np.sum(np.exp(1j * spike_phases))
    ppc2 = (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))

    return ppc2
367
+
368
+
369
def _ppc2_from_phases(phases: np.ndarray) -> float:
    """Return PPC2 = (|sum(exp(i*phi))|^2 - n) / (n * (n - 1)) for 1-D phases; 0.0 when n < 2."""
    n = len(phases)
    if n <= 1:
        return 0.0
    resultant_vector = np.sum(np.exp(1j * phases))
    return (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))


def calculate_ppc_per_cell(spike_df: pd.DataFrame, lfp_signal: np.ndarray,
                           spike_fs: float, lfp_fs: float,
                           pop_names: List[str], freqs: List[float]) -> Dict[str, Dict[int, Dict[float, float]]]:
    """
    Calculate pairwise phase consistency (PPC2) per neuron for specified frequencies across populations.

    For each neuron in the requested populations, the LFP phase is taken from wavelet_filter
    (default bandwidth 2.0) at each frequency. PPC2 is computed from the phases at that
    neuron's spike times, giving the same values as
    calculate_ppc2(..., method='wavelet').
    The wavelet transform depends only on the LFP and the frequency, so it is computed once
    per frequency and reused for every neuron. This keeps one phase array per frequency
    (len(freqs) x len(lfp_signal) floats) in memory at once.

    Args:
        spike_df (pd.DataFrame): Spike dataframe (see bmtool.analysis.load_spikes_to_df); must have
            'pop_name', 'node_ids' and 'timestamps' columns.
        lfp_signal (np.ndarray): Single-channel LFP time series.
        spike_fs (float): Sampling rate of spike timestamps; BMTK default is 1000.
        lfp_fs (float): Sampling rate of the LFP.
        pop_names (List[str]): Populations (values of spike_df['pop_name']) to compute PPC for.
        freqs (List[float]): Frequencies (Hz) at which to calculate PPC.

    Returns:
        Dict[str, Dict[int, Dict[float, float]]]: Nested dictionary of the form
        {population_name: {node_id: {frequency: PPC value}}}.
        Nodes with one spike or fewer are omitted. Nodes with fewer than two spikes inside the
        LFP get 0.0 (previously a (0, array) tuple was stored).
    """
    if spike_fs is None:
        spike_fs = lfp_fs
    n_samples = len(lfp_signal)

    # Loop-invariant: the LFP phase at each frequency is shared by all cells
    phase_by_freq = {
        freq: np.angle(wavelet_filter(x=lfp_signal, freq=freq, fs=lfp_fs, bandwidth=2.0))
        for freq in freqs
    }

    ppc_dict = {}
    for pop in pop_names:
        skip_count = 0
        pop_spikes = spike_df[spike_df['pop_name'] == pop]
        ppc_dict[pop] = {}
        print(f'Processing {pop} population')

        # One pass grouping instead of re-filtering the population frame per node;
        # sort=False keeps nodes in order of first appearance (as .unique() did).
        grouped = pop_spikes.groupby('node_ids', sort=False)
        for node, node_spikes in tqdm(grouped, total=grouped.ngroups):
            # Skip nodes with less than or equal to 1 spike
            if len(node_spikes) <= 1:
                skip_count += 1
                continue

            # Spike times -> LFP sample indices, dropping spikes outside the recording
            spike_indices = np.round(node_spikes['timestamps'].to_numpy() / spike_fs * lfp_fs).astype(int)
            spike_indices = spike_indices[(spike_indices >= 0) & (spike_indices < n_samples)]

            ppc_dict[pop][node] = {
                freq: _ppc2_from_phases(phase_by_freq[freq][spike_indices]) for freq in freqs
            }

        print(f'Calculated PPC for {pop} population with {grouped.ngroups-skip_count} valid cells, skipped {skip_count} cells for lack of spikes')

    return ppc_dict
427
+
428
+
429
+
@@ -11,8 +11,6 @@ import matplotlib.pyplot as plt
11
11
  from scipy import signal
12
12
  import pywt
13
13
  from bmtool.bmplot import is_notebook
14
- import numba
15
- from numba import cuda
16
14
  import pandas as pd
17
15
 
18
16
 
@@ -295,360 +293,6 @@ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs:
295
293
  return x_a
296
294
 
297
295
 
298
def calculate_signal_signal_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
                                method: str = 'wavelet', lowcut: float = None, highcut: float = None,
                                bandwidth: float = 2.0) -> np.ndarray:
    """
    Calculate the Phase Locking Value (PLV) between two signals using the wavelet or Hilbert method.

    PLV = |mean(exp(i * (phi1 - phi2)))| taken over the last axis, following
    "Measuring phase synchrony in brain signals" (1999).

    Parameters:
    - x1, x2: Input signals (same length)
    - fs: Sampling frequency in Hz
    - freq_of_interest: Frequency (Hz) used by the wavelet method; required when method='wavelet'
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - lowcut, highcut: Band-pass edges (Hz) for the Hilbert method. Filtering is applied
      whenever both are given (not None); otherwise the raw signals are Hilbert-transformed.
    - bandwidth: Bandwidth parameter forwarded to wavelet_filter

    Returns:
    - plv: Phase Locking Value in [0, 1]. For 1-D inputs this is a scalar, because the
      mean is taken over the last axis.

    Raises:
    - ValueError: if the signals differ in length, freq_of_interest is missing for the
      wavelet method, or method is not recognised.
    """
    if len(x1) != len(x2):
        raise ValueError("Input signals must have the same length.")

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        # Complex wavelet coefficients at the frequency of interest
        analytic1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
        analytic2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")
        else:
            # Explicit None checks (as in the spike-LFP functions): a truthiness test would
            # treat a cutoff of 0 as "missing" and silently skip filtering without warning.
            x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
            x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)

        # Analytic signals; their angle is the instantaneous phase
        analytic1 = signal.hilbert(x1)
        analytic2 = signal.hilbert(x2)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    phase_diff = np.angle(analytic1) - np.angle(analytic2)

    # Length of the mean unit phase-difference vector
    plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))

    return plv
349
-
350
-
351
def calculate_spike_lfp_plv(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                            lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                            lowcut: float = None, highcut: float = None,
                            bandwidth: float = 2.0) -> float:
    """
    Calculate the spike-LFP phase locking value.
    Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

    The LFP phase is sampled at every spike. The returned value is the squared length of
    the mean resultant vector, |sum(exp(i*phi))|**2 / n**2, i.e. the square of the
    classical PLV. Algebraically this equals 1/n + (n-1)/n * PPC, so it is inflated for
    small spike counts (see calculate_ppc2 for the bias-free estimator).

    Parameters:
    - spike_times: Array of spike times, in units of 1/spike_fs
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times and
      LFP have different fs (defaults to lfp_fs)
    - lfp_fs: Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Frequency (Hz) for wavelet phase extraction; required when method='wavelet'
    - lowcut, highcut: Band-pass edges (Hz) for the Hilbert method; if either is None the
      LFP is Hilbert-transformed without filtering
    - bandwidth: Bandwidth parameter forwarded to wavelet_filter

    Returns:
    - plv: float in [0, 1]; 0.0 when fewer than two spikes fall inside the LFP.
      (Earlier versions returned a (0, array) tuple in that case, inconsistent with the
      normal float return.)

    Raises:
    - ValueError: if freq_of_interest is missing for the wavelet method or method is not recognised.
    """
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> nearest LFP sample index
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        return 0.0

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC1 calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)

        instantaneous_phase = np.angle(signal.hilbert(filtered_lfp))

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    spike_phases = instantaneous_phase[valid_indices]
    n = len(spike_phases)

    # Squared resultant length divided by n**2
    resultant_vector = np.sum(np.exp(1j * spike_phases))
    plv = (np.abs(resultant_vector) ** 2) / (n ** 2)

    return plv
426
-
427
-
428
@numba.njit(parallel=True, fastmath=True)
def _ppc_parallel_numba(spike_phases):
    """
    Numba-parallel PPC: (2 / (n * (n - 1))) * sum over i < j of cos(phi_i - phi_j).

    Each prange iteration accumulates its row of pairs into a private local, then adds it
    to the shared total with a single ``+=`` directly in the prange body. This is the
    scalar-reduction form Numba recognises for prange loops; the update no longer happens
    inside the nested inner loop.
    Returns 0.0 for fewer than two phases instead of dividing by zero.
    """
    n = len(spike_phases)
    if n < 2:
        return 0.0
    sum_cos = 0.0
    for i in numba.prange(n):
        phase_i = spike_phases[i]
        row_sum = 0.0
        for j in range(i + 1, n):
            row_sum += np.cos(phase_i - spike_phases[j])
        sum_cos += row_sum
    return (2 / (n * (n - 1))) * sum_cos
438
-
439
-
440
@cuda.jit(fastmath=True)
def _ppc_cuda_kernel(spike_phases, out):
    """
    CUDA kernel: the thread with global index ``idx`` writes
    sum over j > idx of cos(spike_phases[idx] - spike_phases[j]) into out[idx].
    Threads whose index is past the end of spike_phases do nothing.
    """
    idx = cuda.grid(1)
    count = len(spike_phases)
    if idx >= count:
        return
    anchor = spike_phases[idx]
    partial = 0.0
    for other in range(idx + 1, count):
        partial += np.cos(anchor - spike_phases[other])
    out[idx] = partial
448
-
449
-
450
def _ppc_gpu(spike_phases):
    """
    GPU-accelerated PPC.

    Launches _ppc_cuda_kernel with one thread per phase (256 threads per block). Each
    thread produces the partial pair sum for its row. The partials are summed on the host
    and scaled by 2 / (n * (n - 1)).
    """
    count = len(spike_phases)
    threads_per_block = 256
    n_blocks = (count + threads_per_block - 1) // threads_per_block  # ceiling division

    device_phases = cuda.to_device(spike_phases)
    device_partials = cuda.device_array(count, dtype=np.float64)

    _ppc_cuda_kernel[n_blocks, threads_per_block](device_phases, device_partials)

    pair_sum = device_partials.copy_to_host().sum()
    scale = 2 / (count * (count - 1))
    return scale * pair_sum
461
-
462
-
463
def calculate_ppc(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                  lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                  lowcut: float = None, highcut: float = None,
                  bandwidth: float = 2.0, ppc_method: str = 'numpy') -> float:
    """
    Calculate Pairwise Phase Consistency (PPC) between spike times and an LFP signal.
    Based on https://www.sciencedirect.com/science/article/pii/S1053811910000959

    PPC = (2 / (N * (N - 1))) * sum over i < j of cos(phi_i - phi_j), where phi are the
    LFP phases at the spike times.

    Parameters:
    - spike_times: Array of spike times, in units of 1/spike_fs
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times and
      LFP have different fs (defaults to lfp_fs)
    - lfp_fs: Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Frequency (Hz) for wavelet phase extraction; required when method='wavelet'
    - lowcut, highcut: Band-pass edges (Hz) for the Hilbert method; if either is None the
      LFP is Hilbert-transformed without filtering
    - bandwidth: Bandwidth parameter forwarded to wavelet_filter
    - ppc_method: pair-sum backend:
        'numpy' builds every i < j pair explicitly (memory grows with N**2);
        'numba' uses _ppc_parallel_numba;
        'gpu' uses _ppc_gpu.

    Returns:
    - ppc: float; 0.0 when fewer than two spikes fall inside the LFP.

    Raises:
    - ValueError: for an unsupported ppc_method (previously this path raised NameError via an
      undefined ``ExceptionType``), a missing freq_of_interest with the wavelet method, or an
      unrecognised method.
    """
    # Validate the backend before doing any expensive phase extraction
    if ppc_method not in ('numpy', 'numba', 'gpu'):
        raise ValueError("Please use a supported ppc method currently that is numpy, numba or gpu")

    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> nearest LFP sample index
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        return 0.0

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)

        instantaneous_phase = np.angle(signal.hilbert(filtered_lfp))

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    spike_phases = instantaneous_phase[valid_indices]
    n_spikes = len(spike_phases)

    if ppc_method == 'numpy':
        # cos(a)cos(b) + sin(a)sin(b) == cos(a - b), summed over all i < j pairs
        i, j = np.triu_indices(n_spikes, k=1)
        sum_cos_diff = np.sum(np.cos(spike_phases[i] - spike_phases[j]))
        ppc = (2 / (n_spikes * (n_spikes - 1))) * sum_cos_diff
    elif ppc_method == 'numba':
        ppc = _ppc_parallel_numba(spike_phases)
    else:  # 'gpu' (validated above)
        ppc = _ppc_gpu(spike_phases)
    return ppc
563
-
564
-
565
def calculate_ppc2(spike_times: np.ndarray = None, lfp_signal: np.ndarray = None, spike_fs: float = None,
                   lfp_fs: float = None, method: str = 'hilbert', freq_of_interest: float = None,
                   lowcut: float = None, highcut: float = None,
                   bandwidth: float = 2.0) -> float:
    """
    Calculate Pairwise Phase Consistency with the closed-form estimator (Vinck et al., 2010).

    Equation (original):
        PPC = (2 / (n * (n - 1))) * sum(cos(phi_i - phi_j) for all i < j)
    Optimized formula (algebraically equivalent, O(n)):
        PPC = (|sum(e^(i*phi_j))|^2 - n) / (n * (n - 1))

    Parameters:
    - spike_times: Array of spike times, in units of 1/spike_fs
    - lfp_signal: Local field potential time series
    - spike_fs: Sampling frequency in Hz of the spike times; only needed if spike times and
      LFP have different fs (defaults to lfp_fs)
    - lfp_fs: Sampling frequency in Hz of the LFP
    - method: 'wavelet' or 'hilbert' to choose the phase extraction method
    - freq_of_interest: Frequency (Hz) for wavelet phase extraction; required when method='wavelet'
    - lowcut, highcut: Band-pass edges (Hz) for the Hilbert method; if either is None the
      LFP is Hilbert-transformed without filtering
    - bandwidth: Bandwidth parameter forwarded to wavelet_filter

    Returns:
    - ppc2: float in [-1/(n-1), 1]; 0.0 when fewer than two spikes fall inside the LFP.
      (Earlier versions returned a (0, array) tuple in that case, inconsistent with the
      normal float return.)

    Raises:
    - ValueError: if freq_of_interest is missing for the wavelet method or method is not recognised.
    """
    if spike_fs is None:
        spike_fs = lfp_fs

    # Spike times -> seconds -> nearest LFP sample index
    spike_times_seconds = np.asarray(spike_times) / spike_fs
    spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)

    # Keep only spikes that fall inside the LFP recording
    in_bounds = (spike_indices >= 0) & (spike_indices < len(lfp_signal))
    valid_indices = spike_indices[in_bounds]
    if len(valid_indices) <= 1:
        return 0.0

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        lfp_complex = wavelet_filter(x=lfp_signal, freq=freq_of_interest, fs=lfp_fs, bandwidth=bandwidth)
        instantaneous_phase = np.angle(lfp_complex)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and/or highcut were not defined, signal will not be filtered and will just take Hilbert transform for PPC2 calculation")
            filtered_lfp = lfp_signal
        else:
            filtered_lfp = butter_bandpass_filter(lfp_signal, lowcut, highcut, lfp_fs)

        instantaneous_phase = np.angle(signal.hilbert(filtered_lfp))

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    spike_phases = instantaneous_phase[valid_indices]
    n = len(spike_phases)  # >= 2, guaranteed by the early return above

    # PPC2 = (|sum(e^(i*phi_j))|^2 - n) / (n * (n - 1))
    resultant_vector = np.sum(np.exp(1j * spike_phases))
    ppc2 = (np.abs(resultant_vector) ** 2 - n) / (n * (n - 1))

    return ppc2
650
-
651
-
652
296
  # windowing functions
653
297
  def windowed_xarray(da, windows, dim='time',
654
298
  new_coord_name='cycle', new_coord=None):
@@ -7,6 +7,7 @@ import pandas as pd
7
7
  from bmtool.util.util import load_nodes_from_config
8
8
  from typing import Dict, Optional,Tuple, Union, List
9
9
  import numpy as np
10
+ from scipy.stats import mannwhitneyu
10
11
  import os
11
12
 
12
13
 
@@ -252,3 +253,55 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
252
253
 
253
254
  return spike_rate
254
255
 
256
+
257
def compare_firing_over_times(spike_df, group_by, time_window_1, time_window_2):
    """
    Compare per-neuron firing rates of each population between two time windows.

    For every distinct value of ``group_by`` in ``spike_df``, each neuron of that
    population that appears in ``spike_df`` gets a firing rate in both windows.
    Spikes whose timestamp lies in the closed interval [start, stop] are counted.
    The two rate samples are then compared with a two-sided Mann-Whitney U test,
    and a summary is printed per population.

    Args:
        spike_df (pd.DataFrame): Spike table with a ``timestamps`` column (ms),
            a ``node_ids`` column and the column named by ``group_by``.
        group_by (str): Column used to split neurons into populations.
        time_window_1 (list): [start, stop] of the first window in milliseconds.
        time_window_2 (list): [start, stop] of the second window in milliseconds.

    Returns:
        dict: Maps each population name to a dict with keys:
            ``'rates_1'`` / ``'rates_2'``: np.ndarray of per-neuron rates in Hz,
                ordered by each neuron's first appearance in ``spike_df``;
            ``'avg_rate_1'`` / ``'avg_rate_2'``: mean rate in Hz (0 if the
                population has no neurons);
            ``'u_stat'`` / ``'p_value'``: Mann-Whitney U result (NaN when the
                test could not be run).

    Raises:
        ValueError: If either window does not satisfy stop > start.

    Note:
        Both samples contain the same neurons, so they are paired; Mann-Whitney U
        treats them as independent. A paired test (e.g. scipy.stats.wilcoxon) may
        be more appropriate depending on the analysis.
    """
    # Window lengths in seconds (timestamps are in ms).
    duration_1_s = (time_window_1[1] - time_window_1[0]) / 1000
    duration_2_s = (time_window_2[1] - time_window_2[0]) / 1000
    if duration_1_s <= 0 or duration_2_s <= 0:
        raise ValueError(
            "time_window_1 and time_window_2 must each be [start, stop] in ms with stop > start."
        )

    results = {}
    for pop_name in spike_df[group_by].unique():
        print(f"Population: {pop_name}")
        pop_spikes = spike_df[spike_df[group_by] == pop_name]
        timestamps = pop_spikes['timestamps']

        in_window_1 = (timestamps >= time_window_1[0]) & (timestamps <= time_window_1[1])
        in_window_2 = (timestamps >= time_window_2[0]) & (timestamps <= time_window_2[1])

        # Every neuron that spiked at any time is included, so neurons silent in a
        # window contribute a rate of 0 instead of being dropped.
        unique_neurons = pop_spikes['node_ids'].unique()

        # One value_counts per window replaces a per-neuron boolean filter.
        counts_1 = (pop_spikes.loc[in_window_1, 'node_ids']
                    .value_counts()
                    .reindex(unique_neurons, fill_value=0))
        counts_2 = (pop_spikes.loc[in_window_2, 'node_ids']
                    .value_counts()
                    .reindex(unique_neurons, fill_value=0))

        neuron_rates_1 = counts_1.to_numpy(dtype=float) / duration_1_s
        neuron_rates_2 = counts_2.to_numpy(dtype=float) / duration_2_s

        avg_firing_rate_1 = np.mean(neuron_rates_1) if len(neuron_rates_1) else 0
        avg_firing_rate_2 = np.mean(neuron_rates_2) if len(neuron_rates_2) else 0

        # The test needs at least one observation in each sample.
        if len(neuron_rates_1) > 0 and len(neuron_rates_2) > 0:
            u_stat, p_val = mannwhitneyu(neuron_rates_1, neuron_rates_2, alternative='two-sided')
        else:
            u_stat, p_val = np.nan, np.nan

        print(f"  Average firing rate in window 1: {avg_firing_rate_1:.2f} Hz")
        print(f"  Average firing rate in window 2: {avg_firing_rate_2:.2f} Hz")
        print(f"  U-statistic: {u_stat:.2f}")
        print(f"  p-value: {p_val}")
        print(f"  Significant difference (p<0.05): {'Yes' if p_val < 0.05 else 'No'}")

        results[pop_name] = {
            'rates_1': neuron_rates_1,
            'rates_2': neuron_rates_2,
            'avg_rate_1': avg_firing_rate_1,
            'avg_rate_2': avg_firing_rate_2,
            'u_stat': u_stat,
            'p_value': p_val,
        }
    return results
@@ -23,8 +23,9 @@ import sys
23
23
  import re
24
24
  from typing import Optional, Dict, Union, List
25
25
 
26
- from .util.util import CellVarsFile,load_nodes_from_config #, missing_units
26
+ from .util.util import CellVarsFile,load_nodes_from_config,load_templates_from_config #, missing_units
27
27
  from bmtk.analyzer.utils import listify
28
+ from neuron import h
28
29
 
29
30
  use_description = """
30
31
 
@@ -682,6 +683,155 @@ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids
682
683
  fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
683
684
  plt.draw()
684
685
 
686
def distance_delay_plot(simulation_config: str, source: str, target: str,
                        group_by: str, sid: str, tid: str) -> None:
    """
    Plot connection delay against the Euclidean distance between connected nodes.

    Loads nodes and edges from a simulation config and keeps the edges of
    ``f'{source}_to_{target}'`` whose source node has ``group_by == sid`` in the
    ``source`` population and whose target node has ``group_by == tid`` in the
    ``target`` population. For each kept edge it computes the distance between the
    nodes' (pos_x, pos_y, pos_z) and scatter-plots the edge's ``delay`` against it.

    Args:
        simulation_config (str): Path to the simulation config file.
        source (str): Name of the source population in the edge data.
        target (str): Name of the target population in the edge data.
        group_by (str): Node column to group nodes by (e.g. population name).
        sid (str): Value of ``group_by`` selecting source nodes (e.g. 'PN').
        tid (str): Value of ``group_by`` selecting target nodes (e.g. 'PN').

    Returns:
        None: The function creates and displays a scatter plot of distance vs delay.
        Edges whose delay or node positions cannot be read are reported with a
        printed message and skipped.
    """
    nodes, edges = util.load_nodes_edges_from_config(simulation_config)
    # Source and target can be different populations, so each endpoint is resolved
    # in its own node table. Node ids are the index of each nodes DataFrame.
    source_nodes = nodes[source]
    target_nodes = nodes[target]
    node_id_source = source_nodes[source_nodes[group_by] == sid].index
    node_id_target = target_nodes[target_nodes[group_by] == tid].index

    edges = edges[f'{source}_to_{target}']
    edges = edges[edges['source_node_id'].isin(node_id_source) & edges['target_node_id'].isin(node_id_target)]

    pos_cols = ['pos_x', 'pos_y', 'pos_z']
    stuff_to_plot = []
    for index, row in edges.iterrows():
        # These columns were already used for filtering above, so they exist.
        source_node = row['source_node_id']
        target_node = row['target_node_id']
        try:
            source_pos = source_nodes.loc[[source_node], pos_cols]
            target_pos = target_nodes.loc[[target_node], pos_cols]
            distance = np.linalg.norm(source_pos.values - target_pos.values)
            delay = row['delay']  # KeyError if the edge table has no delay column
            stuff_to_plot.append([distance, delay])
        except KeyError as e:
            print(f"KeyError: Missing key {e} in either edge properties or node positions.")
        except IndexError as e:
            print(f"IndexError: Node ID {source_node} or {target_node} not found in nodes.")
        except Exception as e:
            # Best-effort: report and keep plotting the remaining edges.
            print(f"Unexpected error at edge index {index}: {e}")

    plt.scatter([x[0] for x in stuff_to_plot], [x[1] for x in stuff_to_plot])
    plt.xlabel('Distance')
    plt.ylabel('Delay')
    plt.title(f'Distance vs Delay for edge between {sid} and {tid}')
    plt.show()
741
+
742
def plot_synapse_location_histograms(config, target_model, source=None, target=None):
    """
    Plot histograms of synapse positions along each section of a target cell model.

    Edges from the ``source`` network to the ``target`` network whose target node
    uses ``target_model`` are grouped by afferent section. Each section that
    receives at least one synapse gets a histogram of ``afferent_section_pos``,
    with one colour per source population (``pop_name``). A count table
    (section x source population) is printed afterwards.

    Args:
        config: A BMTK config, used to load mechanisms/templates and the node and
            edge tables.
        target_model (str): ``model_template`` of the target cells, e.g.
            'hoc:CellTemplate'. The part after the last ':' is instantiated from
            ``h`` to obtain the section names.
        source (str): The source BMTK network. Required.
        target (str): The target BMTK network. Required.

    Raises:
        ValueError: If ``source`` or ``target`` is not given.
    """
    if source is None or target is None:
        raise ValueError("Both 'source' and 'target' network names must be provided.")

    # Load mechanisms and hoc templates so the target template can be instantiated.
    util.load_templates_from_config(config)

    # Load node and edge data; node ids are the index of each nodes DataFrame.
    nodes, edges = util.load_nodes_edges_from_config(config)
    source_nodes = nodes[source]
    target_nodes = nodes[target]
    # Copy so the columns added below do not write into the loaded edge table.
    edges = edges[f'{source}_to_{target}'].copy()

    # The target's template comes from the target network; the presynaptic label
    # comes from the source network.
    edges['target_model_template'] = edges['target_node_id'].map(target_nodes['model_template'])
    edges['source_pop_name'] = edges['source_node_id'].map(source_nodes['pop_name'])

    # Independent copy of the filtered frame (avoids SettingWithCopyWarning below).
    edges = edges[edges['target_model_template'] == target_model].copy()

    # Instantiate one target cell to read its section names. "hoc:Name" -> "Name";
    # a bare "Name" is used as-is.
    cell = getattr(h, target_model.split(':')[-1])()

    # afferent_section_id is interpreted as the index of the section in cell.all.
    section_id_to_name = {idx: sec.name() for idx, sec in enumerate(cell.all)}
    edges['afferent_section_name'] = edges['afferent_section_id'].map(section_id_to_name)

    unique_pops = edges['source_pop_name'].unique()

    # value_counts only lists sections that received at least one synapse
    # (ordered by descending count).
    section_counts = edges['afferent_section_name'].value_counts()
    sections_with_data = section_counts[section_counts > 0].index.tolist()

    # One subplot per section, stacked vertically.
    plt.figure(figsize=(8, 12))

    # One colour per source population.
    color_map = plt.cm.tab10(np.linspace(0, 1, len(unique_pops)))
    pop_colors = {pop: color for pop, color in zip(unique_pops, color_map)}

    for i, section in enumerate(sections_with_data):
        ax = plt.subplot(len(sections_with_data), 1, i + 1)

        section_data = edges[edges['afferent_section_name'] == section]

        for pop_name, pop_group in section_data.groupby('source_pop_name'):
            if len(pop_group) > 0:
                ax.hist(pop_group['afferent_section_pos'], bins=15, alpha=0.7,
                        label=pop_name, color=pop_colors[pop_name])

        ax.set_title(f"{section}", fontsize=10)
        ax.set_xlabel('Section Position', fontsize=8)
        ax.set_ylabel('Frequency', fontsize=8)
        ax.tick_params(labelsize=7)
        ax.grid(True, alpha=0.3)

        # Only the first subplot carries the legend.
        if i == 0:
            ax.legend(fontsize=8)

    plt.tight_layout()
    plt.suptitle('Connection Distribution by Cell Section and Source Population', fontsize=16, y=1.02)
    # NOTE(review): if `is_notebook` is a function in this module, this tests the
    # function object (always truthy) rather than calling it -- confirm intent.
    if is_notebook:
        plt.show()

    # Count of synapses per (section, source population).
    print("Summary of connections by section and source population:")
    pivot_table = edges.pivot_table(
        values='afferent_section_id',
        index='afferent_section_name',
        columns='source_pop_name',
        aggfunc='count',
        fill_value=0
    )
    print(pivot_table)
834
+
685
835
  def plot_connection_info(text, num, source_labels, target_labels, title, syn_info='0', save_file=None, return_dict=None):
686
836
  """
687
837
  Function to plot connection information as a heatmap, including handling missing source and target values.
@@ -647,7 +647,7 @@ class SynapseTuner:
647
647
  # Widgets setup (Sliders)
648
648
  freqs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 35, 50, 100, 200]
649
649
  delays = [125, 250, 500, 1000, 2000, 4000]
650
- durations = [300, 500, 1000, 2000, 5000, 10000]
650
+ durations = [100, 300, 500, 1000, 2000, 5000, 10000]
651
651
  freq0 = 50
652
652
  delay0 = 250
653
653
  duration0 = 300
@@ -6,6 +6,8 @@ import numpy as np
6
6
  from numpy import genfromtxt
7
7
  import h5py
8
8
  import pandas as pd
9
+ import neuron
10
+ from neuron import h
9
11
 
10
12
  #from bmtk.utils.io.cell_vars import CellVarsFile
11
13
  #from bmtk.analyzer.cell_vars import _get_cell_report
@@ -392,6 +394,32 @@ def load_edges_from_paths(edge_paths):#network_dir='network'):
392
394
 
393
395
  return edges_dict
394
396
 
397
def load_mechanisms_from_config(config=None):
    """
    Load the compiled NEURON mechanisms referenced by a BMTK simulation config.

    Args:
        config: Config passed to ``load_config``; defaults to
            'simulation_config.json' when omitted.

    Returns:
        The value returned by ``neuron.load_mechanisms`` for the config's
        ``components -> mechanisms_dir`` entry.
    """
    parsed = load_config('simulation_config.json' if config is None else config)
    mechanisms_dir = parsed['components']['mechanisms_dir']
    return neuron.load_mechanisms(mechanisms_dir)
405
+
406
def load_templates_from_config(config=None):
    """
    Load NEURON mechanisms and hoc templates referenced by a BMTK simulation config.

    Args:
        config: Config passed to ``load_config``; defaults to
            'simulation_config.json' when omitted.

    Returns:
        The return value of ``load_templates_from_paths`` for the config's
        ``components -> templates_dir`` entry.
    """
    if config is None:
        config = 'simulation_config.json'
    config = load_config(config)
    # The config is already parsed here, so mechanisms are loaded straight from it.
    # Passing the parsed object to load_mechanisms_from_config would call
    # load_config on it a second time.
    neuron.load_mechanisms(config['components']['mechanisms_dir'])
    return load_templates_from_paths(config['components']['templates_dir'])
412
+
413
def load_templates_from_paths(template_paths):
    """
    Load every regular file in a directory into NEURON via ``h.load_file``.

    Args:
        template_paths (str): Path of a single directory holding template files
            (despite the plural name; ``load_templates_from_config`` passes
            ``components -> templates_dir`` here).

    Files are loaded in sorted name order, so the load sequence and the printed
    log are reproducible; ``os.listdir`` returns entries in an arbitrary,
    filesystem-dependent order. Subdirectories are skipped. Returns None.
    """
    for item in sorted(os.listdir(template_paths)):
        item_path = os.path.join(template_paths, item)
        if not os.path.isfile(item_path):
            continue
        print(f"loading {item_path}")
        h.load_file(item_path)
420
+
421
+
422
+
395
423
  def cell_positions_by_id(config=None, nodes=None, populations=[], popids=[], prepend_pop=True):
396
424
  """
397
425
  Returns a dictionary of arrays of arrays {"population_popid":[[1,2,3],[1,2,4]],...
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.6.9.23
3
+ Version: 0.6.9.25
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -18,6 +18,7 @@ bmtool.egg-info/entry_points.txt
18
18
  bmtool.egg-info/requires.txt
19
19
  bmtool.egg-info/top_level.txt
20
20
  bmtool/analysis/__init__.py
21
+ bmtool/analysis/entrainment.py
21
22
  bmtool/analysis/lfp.py
22
23
  bmtool/analysis/netcon_reports.py
23
24
  bmtool/analysis/spikes.py
@@ -6,7 +6,7 @@ with open("README.md", "r") as fh:
6
6
 
7
7
  setup(
8
8
  name="bmtool",
9
- version='0.6.9.23',
9
+ version='0.6.9.25',
10
10
  author="Neural Engineering Laboratory at the University of Missouri",
11
11
  author_email="gregglickert@mail.missouri.edu",
12
12
  description="BMTool",
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes