bmtool 0.7.1.5__py3-none-any.whl → 0.7.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/analysis/entrainment.py +33 -33
- bmtool/analysis/spikes.py +179 -28
- bmtool/bmplot/entrainment.py +168 -2
- bmtool/bmplot/spikes.py +2 -0
- bmtool/synapses.py +76 -25
- {bmtool-0.7.1.5.dist-info → bmtool-0.7.1.7.dist-info}/METADATA +1 -1
- {bmtool-0.7.1.5.dist-info → bmtool-0.7.1.7.dist-info}/RECORD +11 -11
- {bmtool-0.7.1.5.dist-info → bmtool-0.7.1.7.dist-info}/WHEEL +1 -1
- {bmtool-0.7.1.5.dist-info → bmtool-0.7.1.7.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.1.5.dist-info → bmtool-0.7.1.7.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.1.5.dist-info → bmtool-0.7.1.7.dist-info}/top_level.txt +0 -0
bmtool/analysis/entrainment.py
CHANGED
@@ -636,26 +636,26 @@ def calculate_entrainment_per_cell(
|
|
636
636
|
|
637
637
|
|
638
638
|
def calculate_spike_rate_power_correlation(
|
639
|
-
spike_rate,
|
640
|
-
lfp_data,
|
641
|
-
fs,
|
642
|
-
pop_names,
|
643
|
-
filter_method="wavelet",
|
644
|
-
bandwidth=2.0,
|
645
|
-
lowcut=None,
|
646
|
-
highcut=None,
|
647
|
-
freq_range=(10, 100),
|
648
|
-
freq_step=5,
|
639
|
+
spike_rate: xr.DataArray,
|
640
|
+
lfp_data: np.ndarray,
|
641
|
+
fs: float,
|
642
|
+
pop_names: list,
|
643
|
+
filter_method: str = "wavelet",
|
644
|
+
bandwidth: float = 2.0,
|
645
|
+
lowcut: float = None,
|
646
|
+
highcut: float = None,
|
647
|
+
freq_range: tuple = (10, 100),
|
648
|
+
freq_step: float = 5,
|
649
|
+
type_name: str = "raw", # 'raw' or 'smoothed'
|
649
650
|
):
|
650
651
|
"""
|
651
|
-
Calculate correlation between population spike rates and LFP power across frequencies
|
652
|
-
using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
|
652
|
+
Calculate correlation between population spike rates (xarray) and LFP power across frequencies.
|
653
653
|
|
654
|
-
Parameters
|
655
|
-
|
656
|
-
spike_rate :
|
657
|
-
|
658
|
-
lfp_data : np.
|
654
|
+
Parameters
|
655
|
+
----------
|
656
|
+
spike_rate : xr.DataArray
|
657
|
+
Population spike rates with dimensions (time, population[, type])
|
658
|
+
lfp_data : np.ndarray
|
659
659
|
LFP data
|
660
660
|
fs : float
|
661
661
|
Sampling frequency
|
@@ -673,19 +673,17 @@ def calculate_spike_rate_power_correlation(
|
|
673
673
|
Min and max frequency to analyze (default: (10, 100))
|
674
674
|
freq_step : float, optional
|
675
675
|
Step size for frequency analysis (default: 5)
|
676
|
+
type_name : str, optional
|
677
|
+
Which type of spike rate to use if 'type' dimension exists (default: 'raw')
|
676
678
|
|
677
|
-
Returns
|
678
|
-
|
679
|
+
Returns
|
680
|
+
-------
|
679
681
|
correlation_results : dict
|
680
682
|
Dictionary with correlation results for each population and frequency
|
681
683
|
frequencies : array
|
682
684
|
Array of frequencies analyzed
|
683
685
|
"""
|
684
|
-
|
685
|
-
# Define frequency bands to analyze
|
686
686
|
frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
|
687
|
-
|
688
|
-
# Dictionary to store results
|
689
687
|
correlation_results = {pop: {} for pop in pop_names}
|
690
688
|
|
691
689
|
# Calculate power at each frequency band using specified filter
|
@@ -695,20 +693,22 @@ def calculate_spike_rate_power_correlation(
|
|
695
693
|
lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
|
696
694
|
)
|
697
695
|
|
698
|
-
#
|
696
|
+
# For each population, extract the correct spike rate
|
699
697
|
for pop in pop_names:
|
700
|
-
#
|
701
|
-
|
698
|
+
# If 'type' dimension exists, select the type
|
699
|
+
if "type" in spike_rate.dims:
|
700
|
+
pop_rate = spike_rate.sel(population=pop, type=type_name).values
|
701
|
+
else:
|
702
|
+
pop_rate = spike_rate.sel(population=pop).values
|
702
703
|
|
703
704
|
# Calculate correlation with power at each frequency
|
704
705
|
for freq in frequencies:
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
)
|
710
|
-
|
711
|
-
corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
|
706
|
+
lfp_power = power_by_freq[freq]
|
707
|
+
# Ensure lengths match
|
708
|
+
min_len = min(len(pop_rate), len(lfp_power))
|
709
|
+
if len(pop_rate) != len(lfp_power):
|
710
|
+
print(f"Warning: Length mismatch for {pop} at {freq} Hz, truncating to {min_len}")
|
711
|
+
corr, p_val = stats.spearmanr(pop_rate[:min_len], lfp_power[:min_len])
|
712
712
|
correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
|
713
713
|
|
714
714
|
return correlation_results, frequencies
|
bmtool/analysis/spikes.py
CHANGED
@@ -3,11 +3,12 @@ Module for processing BMTK spikes output.
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
import os
|
6
|
-
from typing import
|
6
|
+
from typing import List, Optional, Tuple, Union
|
7
7
|
|
8
8
|
import h5py
|
9
9
|
import numpy as np
|
10
10
|
import pandas as pd
|
11
|
+
import xarray as xr
|
11
12
|
from scipy.stats import mannwhitneyu
|
12
13
|
|
13
14
|
from bmtool.util.util import load_nodes_from_config
|
@@ -213,9 +214,11 @@ def get_population_spike_rate(
|
|
213
214
|
save: bool = False,
|
214
215
|
save_path: Optional[str] = None,
|
215
216
|
normalize: bool = False,
|
216
|
-
|
217
|
+
smooth_window: int = 50, # Window size for smoothing (in time bins)
|
218
|
+
smooth_method: str = "gaussian", # Smoothing method: 'gaussian', 'boxcar', or 'exponential'
|
219
|
+
) -> xr.DataArray:
|
217
220
|
"""
|
218
|
-
Calculate the population spike rate for each population in the given spike data
|
221
|
+
Calculate the population spike rate for each population in the given spike data.
|
219
222
|
|
220
223
|
Parameters
|
221
224
|
----------
|
@@ -239,31 +242,51 @@ def get_population_spike_rate(
|
|
239
242
|
Directory path where the file should be saved if `save` is True (default: None)
|
240
243
|
normalize : bool, optional
|
241
244
|
Whether to normalize the spike rates for each population to a range of [0, 1] (default: False)
|
245
|
+
smooth_window : int, optional
|
246
|
+
Window size for smoothing in number of time bins (default: 50)
|
247
|
+
smooth_method : str, optional
|
248
|
+
Smoothing method to use: 'gaussian', 'boxcar', or 'exponential' (default: 'gaussian')
|
242
249
|
|
243
250
|
Returns
|
244
251
|
-------
|
245
|
-
|
246
|
-
|
247
|
-
|
252
|
+
xr.DataArray
|
253
|
+
An xarray DataArray containing the spike rates with dimensions of time, population, and type.
|
254
|
+
The 'type' dimension includes 'raw' and 'smoothed' values.
|
255
|
+
The DataArray includes sampling frequency (fs) as an attribute.
|
256
|
+
If normalize is True, each population's spike rate is scaled to [0, 1].
|
248
257
|
|
249
258
|
Raises
|
250
259
|
------
|
251
260
|
ValueError
|
252
261
|
If `save` is True but `save_path` is not provided.
|
262
|
+
If an invalid smooth_method is specified.
|
253
263
|
|
254
264
|
Notes
|
255
265
|
-----
|
256
|
-
- If `config` is None, the function assumes all cells in each population have fired at least once;
|
257
|
-
|
266
|
+
- If `config` is None, the function assumes all cells in each population have fired at least once;
|
267
|
+
otherwise, the node count may be inaccurate.
|
268
|
+
- If normalization is enabled, each population's spike rate is scaled using Min-Max normalization.
|
269
|
+
- Smoothing is applied using scipy.ndimage's filters based on the specified method.
|
258
270
|
"""
|
271
|
+
import numpy as np
|
272
|
+
from scipy import ndimage
|
273
|
+
|
274
|
+
# Validate smoothing method
|
275
|
+
if smooth_method not in ["gaussian", "boxcar", "exponential"]:
|
276
|
+
raise ValueError(
|
277
|
+
f"Invalid smooth_method: {smooth_method}. Choose from 'gaussian', 'boxcar', or 'exponential'."
|
278
|
+
)
|
279
|
+
|
259
280
|
pop_spikes = {}
|
260
281
|
node_number = {}
|
261
282
|
|
262
283
|
if config is None:
|
263
284
|
print(
|
264
|
-
"Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration,
|
285
|
+
"Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, or not all cells fired,\nthen this count will not include all nodes so the firing rate will not be of the whole population!"
|
286
|
+
)
|
287
|
+
print(
|
288
|
+
"You can provide a config to calculate the correct amount of nodes! for a true population rate."
|
265
289
|
)
|
266
|
-
print("You can provide a config to calculate the correct amount of nodes!")
|
267
290
|
|
268
291
|
if config:
|
269
292
|
if not network_name:
|
@@ -271,7 +294,13 @@ def get_population_spike_rate(
|
|
271
294
|
"Grabbing first network; specify a network name to ensure correct node population is selected."
|
272
295
|
)
|
273
296
|
|
274
|
-
|
297
|
+
# Get t_stop if not provided
|
298
|
+
if t_stop is None:
|
299
|
+
t_stop = spike_data["timestamps"].max()
|
300
|
+
|
301
|
+
# Get population names and prepare data
|
302
|
+
populations = spike_data["pop_name"].unique()
|
303
|
+
for pop_name in populations:
|
275
304
|
ps = spike_data[spike_data["pop_name"] == pop_name]
|
276
305
|
|
277
306
|
if config:
|
@@ -282,12 +311,10 @@ def get_population_spike_rate(
|
|
282
311
|
nodes = list(nodes.values())[0] if nodes else {}
|
283
312
|
nodes = nodes[nodes["pop_name"] == pop_name]
|
284
313
|
node_number[pop_name] = nodes.index.nunique()
|
314
|
+
|
285
315
|
else:
|
286
316
|
node_number[pop_name] = ps["node_ids"].nunique()
|
287
317
|
|
288
|
-
if t_stop is None:
|
289
|
-
t_stop = spike_data["timestamps"].max()
|
290
|
-
|
291
318
|
filtered_spikes = spike_data[
|
292
319
|
(spike_data["pop_name"] == pop_name)
|
293
320
|
& (spike_data["timestamps"] > t_start)
|
@@ -295,29 +322,153 @@ def get_population_spike_rate(
|
|
295
322
|
]
|
296
323
|
pop_spikes[pop_name] = filtered_spikes
|
297
324
|
|
298
|
-
|
299
|
-
|
300
|
-
|
325
|
+
# Calculate time points
|
326
|
+
time = np.arange(t_start, t_stop, 1000 / fs) # Convert sampling frequency to time steps
|
327
|
+
|
328
|
+
# Calculate spike rates for each population
|
329
|
+
spike_rates = []
|
330
|
+
for p in populations:
|
331
|
+
raw_rate = _pop_spike_rate(pop_spikes[p]["timestamps"], (t_start, t_stop, 1000 / fs))
|
332
|
+
rate = fs / node_number[p] * raw_rate
|
333
|
+
spike_rates.append(rate)
|
334
|
+
|
335
|
+
spike_rates_array = np.array(spike_rates).T # Transpose to have time as first dimension
|
336
|
+
|
337
|
+
# Calculate smoothed version for each population
|
338
|
+
smoothed_rates = []
|
339
|
+
|
340
|
+
for i in range(spike_rates_array.shape[1]):
|
341
|
+
pop_rate = spike_rates_array[:, i]
|
342
|
+
|
343
|
+
if smooth_method == "gaussian":
|
344
|
+
# Gaussian smoothing (sigma is approximately window/6 for a Gaussian filter)
|
345
|
+
sigma = smooth_window / 6
|
346
|
+
smoothed_pop_rate = ndimage.gaussian_filter1d(pop_rate, sigma=sigma)
|
347
|
+
elif smooth_method == "boxcar":
|
348
|
+
# Boxcar/uniform smoothing
|
349
|
+
kernel = np.ones(smooth_window) / smooth_window
|
350
|
+
smoothed_pop_rate = ndimage.convolve1d(pop_rate, kernel, mode="nearest")
|
351
|
+
elif smooth_method == "exponential":
|
352
|
+
# Exponential smoothing
|
353
|
+
alpha = 2 / (smooth_window + 1) # Equivalent to window size in exponential smoothing
|
354
|
+
smoothed_pop_rate = np.zeros_like(pop_rate)
|
355
|
+
smoothed_pop_rate[0] = pop_rate[0]
|
356
|
+
for t in range(1, len(pop_rate)):
|
357
|
+
smoothed_pop_rate[t] = alpha * pop_rate[t] + (1 - alpha) * smoothed_pop_rate[t - 1]
|
358
|
+
|
359
|
+
smoothed_rates.append(smoothed_pop_rate)
|
360
|
+
|
361
|
+
smoothed_rates_array = np.array(smoothed_rates).T # Transpose to have time as first dimension
|
362
|
+
|
363
|
+
# Stack raw and smoothed data
|
364
|
+
combined_data = np.stack([spike_rates_array, smoothed_rates_array], axis=2)
|
365
|
+
|
366
|
+
# Create DataArray with the additional 'type' dimension
|
367
|
+
spike_rate_array = xr.DataArray(
|
368
|
+
combined_data,
|
369
|
+
coords={"time": time, "population": populations, "type": ["raw", "smoothed"]},
|
370
|
+
dims=["time", "population", "type"],
|
371
|
+
attrs={
|
372
|
+
"fs": fs,
|
373
|
+
"normalized": False,
|
374
|
+
"smooth_method": smooth_method,
|
375
|
+
"smooth_window": smooth_window,
|
376
|
+
},
|
377
|
+
)
|
301
378
|
|
302
|
-
# Normalize
|
379
|
+
# Normalize if requested
|
303
380
|
if normalize:
|
304
|
-
|
305
|
-
|
381
|
+
# Apply normalization for each population and each type (raw/smoothed)
|
382
|
+
for pop_idx in range(len(populations)):
|
383
|
+
for type_idx, type_name in enumerate(["raw", "smoothed"]):
|
384
|
+
pop_data = spike_rate_array.sel(population=populations[pop_idx], type=type_name)
|
385
|
+
min_val = pop_data.min(dim="time")
|
386
|
+
max_val = pop_data.max(dim="time")
|
387
|
+
|
388
|
+
# Handle case where min == max (constant signal)
|
389
|
+
if max_val != min_val:
|
390
|
+
spike_rate_array.loc[:, populations[pop_idx], type_name] = (
|
391
|
+
pop_data - min_val
|
392
|
+
) / (max_val - min_val)
|
393
|
+
|
394
|
+
spike_rate_array.attrs["normalized"] = True
|
395
|
+
|
396
|
+
# Save if requested
|
306
397
|
if save:
|
307
398
|
if save_path is None:
|
308
399
|
raise ValueError("save_path must be provided if save is True.")
|
309
400
|
|
310
401
|
os.makedirs(save_path, exist_ok=True)
|
311
|
-
|
312
402
|
save_file = os.path.join(save_path, "spike_rate.h5")
|
313
|
-
|
314
|
-
f.create_dataset("time", data=time)
|
315
|
-
grp = f.create_group("populations")
|
316
|
-
for p, rspk in spike_rate.items():
|
317
|
-
pop_grp = grp.create_group(p)
|
318
|
-
pop_grp.create_dataset("data", data=rspk)
|
403
|
+
spike_rate_array.to_netcdf(save_file)
|
319
404
|
|
320
|
-
return
|
405
|
+
return spike_rate_array
|
406
|
+
|
407
|
+
|
408
|
+
def average_spike_rate_over_windows(
    spike_rate: xr.DataArray, windows: List[Tuple[float, float]]
) -> xr.DataArray:
    """
    Calculate the average spike rate over multiple time windows.

    Each window is cut out of ``spike_rate`` with ``sel(time=slice(start, end))`` and the
    windows are averaged sample-by-sample. Samples are aligned by their position inside the
    window (0th sample with 0th sample, 1st with 1st, ...) rather than by their floating-point
    time labels. Windows whose relative times differ only by rounding error (e.g. a time step
    of 1000/3 ms) therefore still line up. When windows contain different numbers of samples,
    the trailing samples are averaged over the windows that reach that far (NaN-skipping mean).

    Parameters
    ----------
    spike_rate : xr.DataArray
        The spike rate data array with a 'time' dimension, typically (time, population, type)
        where 'type' can be 'raw' or 'smoothed'. Arrays without a 'type' dimension
        (time, population) are also supported.
    windows : List[Tuple[float, float]]
        List of (start, end) times in milliseconds defining the windows to average over.
        Windows that contain no time points are skipped with a printed warning.

    Returns
    -------
    xr.DataArray
        Averaged spike rate with the same dimensions as ``spike_rate``. The time coordinate is
        taken from the longest window (the first one among equally long windows), expressed
        relative to that window's start. Attributes are copied from ``spike_rate``.

    Raises
    ------
    ValueError
        If ``windows`` is empty or none of the windows contains any time points.
    """
    # Cut out every window, remembering its start so time can be made relative later.
    segments = []
    for start, end in windows:
        segment = spike_rate.sel(time=slice(start, end))
        if segment.sizes["time"] == 0:
            print(f"Warning: window ({start}, {end}) contains no time points and is skipped.")
            continue
        segments.append((start, segment))

    if not segments:
        raise ValueError(
            "No data to average: provide at least one window that contains time points."
        )

    # Re-label each window's time axis with its sample index so the concatenation aligns
    # sample-by-sample; float time labels from different windows rarely compare equal, and an
    # outer join on them would interleave NaN-padded rows instead of averaging.
    indexed = [seg.assign_coords(time=np.arange(seg.sizes["time"])) for _, seg in segments]
    averaged = xr.concat(indexed, dim="window").mean(dim="window")

    # Restore a real (relative) time axis; the longest window covers every sample index.
    ref_start, ref_segment = max(segments, key=lambda item: item[1].sizes["time"])
    averaged = averaged.assign_coords(time=ref_segment["time"].values - ref_start)

    # Build a clean DataArray that keeps the input's dimension order and copies its attrs.
    return xr.DataArray(
        averaged.values,
        coords={dim: averaged[dim].values for dim in averaged.dims if dim in averaged.coords},
        dims=averaged.dims,
        attrs=dict(spike_rate.attrs),
    )
|
321
472
|
|
322
473
|
|
323
474
|
def compare_firing_over_times(
|
bmtool/bmplot/entrainment.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import matplotlib.pyplot as plt
|
2
2
|
import numpy as np
|
3
|
+
import pandas as pd
|
3
4
|
import seaborn as sns
|
4
5
|
from matplotlib.gridspec import GridSpec
|
5
6
|
from scipy import stats
|
@@ -55,7 +56,7 @@ def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
|
|
55
56
|
|
56
57
|
def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
|
57
58
|
"""
|
58
|
-
Plot an idealized
|
59
|
+
Plot an idealized cycle with spike histograms for different neuron populations.
|
59
60
|
|
60
61
|
Parameters:
|
61
62
|
-----------
|
@@ -120,7 +121,7 @@ def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
|
|
120
121
|
plt.show()
|
121
122
|
|
122
123
|
|
123
|
-
def
|
124
|
+
def plot_entrainment_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=None):
|
124
125
|
"""
|
125
126
|
Plot PPC for all node populations on one graph with mean and standard error.
|
126
127
|
|
@@ -200,3 +201,168 @@ def plot_ppc_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=No
|
|
200
201
|
# Adjust layout and save
|
201
202
|
plt.tight_layout()
|
202
203
|
plt.show()
|
204
|
+
|
205
|
+
|
206
|
+
def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title=None):
    """
    Plot a swarm plot of the entrainment for different populations at a single frequency.

    One point is drawn per node and a red bar marks each population's mean. When more than
    one population is given, pairwise two-sided Mann-Whitney U tests are printed and drawn
    as significance bars labelled "ns", "*", "**" or "***" (p < 0.05 / 0.01 / 0.001).

    Parameters:
    -----------
    ppc_dict : dict
        Nested dictionary indexed as ``ppc_dict[pop][node][freq]``. Nodes that lack ``freq``
        or whose value is None are skipped.
    pop_names : list
        List of population names to include in the plot (each must be a key of ``ppc_dict``).
    freq : float or int
        The specific frequency to plot; must match the frequency keys used in ``ppc_dict``.
    save_path : str, optional
        Directory in which the figure is saved as ``ppc_change_swarm_plot_{freq}Hz.png``.
        If None, the figure is only displayed.
    title : str, optional
        Title for the plot (default: no title).

    Returns:
    --------
    None
        The figure is shown with ``plt.show()``. Per-population mean ± SEM and the pairwise
        test results are printed to stdout. If no population has data at ``freq``, a message
        is printed and None is returned without plotting.
    """
    import os

    sns.set_style("whitegrid")

    # Long-format table: one row per node that has a PPC value at this frequency.
    data_list = [
        {"Population": pop, "Node": node, "PPC Difference": node_ppc[freq]}
        for pop in pop_names
        for node, node_ppc in ppc_dict[pop].items()
        if freq in node_ppc and node_ppc[freq] is not None
    ]
    df = pd.DataFrame(data_list)

    if df.empty:
        print(f"No data available for frequency {freq}.")
        return None

    # Values per population, plus the populations that actually have data. Passing this list
    # as seaborn's `order` guarantees that category k is drawn at x == k, so the annotations,
    # mean bars and significance bars below line up with the right swarm even when some
    # requested populations have no data at this frequency.
    groups = {pop: df.loc[df["Population"] == pop, "PPC Difference"] for pop in pop_names}
    plotted_pops = [pop for pop in pop_names if not groups[pop].empty]

    # Print mean ± standard error of the mean for each population.
    for pop in plotted_pops:
        vals = groups[pop]
        n = len(vals)
        sem_val = vals.std() / np.sqrt(n)
        print(f"{pop}: {vals.mean():.4f} ± {sem_val:.4f} (n={n})")

    plt.figure(figsize=(max(8, len(pop_names) * 1.5), 8))
    ax = sns.swarmplot(
        x="Population",
        y="PPC Difference",
        data=df,
        order=plotted_pops,
        size=3,
    )

    # Reference line at y=0.
    ax.axhline(y=0, color="black", linestyle="-", linewidth=0.5, alpha=0.7)

    # Sample-size annotation below each swarm and a red bar at each population mean.
    for x_pos, pop in enumerate(plotted_pops):
        vals = groups[pop]
        low, high = vals.min(), vals.max()
        ax.annotate(
            f"n={len(vals)}", (x_pos, low - 0.05 * (high - low) - 0.05), ha="center", fontsize=10
        )
        mean_val = vals.mean()
        ax.plot([x_pos - 0.25, x_pos + 0.25], [mean_val, mean_val], "r-", linewidth=2)

    if len(pop_names) > 1:
        print(f"\nMann-Whitney U Test Results at {freq} Hz:")
        print("-" * 60)

        y_max = df["PPC Difference"].max()
        y_min = df["PPC Difference"].min()
        y_range = y_max - y_min
        if y_range == 0:
            # All points share one value; use a non-zero scale so bars and ticks stay visible.
            y_range = abs(y_max) or 1.0

        level = 0  # stacking index: every drawn bar gets its own height, so none overlap
        for i in range(len(plotted_pops)):
            for j in range(i + 1, len(plotted_pops)):
                pop1, pop2 = plotted_pops[i], plotted_pops[j]
                vals1 = groups[pop1].values
                vals2 = groups[pop2].values
                if len(vals1) < 2 or len(vals2) < 2:
                    continue

                # Non-parametric comparison of the two PPC distributions.
                u_stat, p_val = stats.mannwhitneyu(vals1, vals2, alternative="two-sided")

                if p_val < 0.001:
                    sig_str = "***"
                elif p_val < 0.01:
                    sig_str = "**"
                elif p_val < 0.05:
                    sig_str = "*"
                else:
                    sig_str = "ns"

                bar_height = y_max + 0.1 * y_range * (1 + 0.5 * level)
                level += 1

                # Horizontal bar with short downward ticks at both ends.
                ax.plot([i, j], [bar_height, bar_height], "k-")
                ax.plot([i, i], [bar_height - 0.02 * y_range, bar_height], "k-")
                ax.plot([j, j], [bar_height - 0.02 * y_range, bar_height], "k-")
                ax.text(
                    (i + j) / 2,
                    bar_height + 0.01 * y_range,
                    sig_str,
                    ha="center",
                    va="bottom",
                    fontsize=12,
                )

                print(f"{pop1} vs {pop2}: U={u_stat:.1f}, p={p_val:.4f} {sig_str}")

    ax.set_xlabel("Population", fontsize=14)
    ax.set_ylabel("PPC", fontsize=14)
    if title:
        ax.set_title(title, fontsize=16)

    # Extend the y-axis so the n= annotations and significance bars are not clipped.
    low, high = ax.get_ylim()
    ax.set_ylim(low - 0.15 * (high - low), high + 0.25 * (high - low))

    ax.grid(True, linestyle="--", alpha=0.7, axis="y")
    plt.tight_layout()

    if save_path:
        plt.savefig(
            os.path.join(save_path, f"ppc_change_swarm_plot_{freq}Hz.png"),
            dpi=300,
            bbox_inches="tight",
        )

    plt.show()
|
bmtool/bmplot/spikes.py
CHANGED
bmtool/synapses.py
CHANGED
@@ -18,15 +18,18 @@ from scipy.optimize import curve_fit, minimize, minimize_scalar
|
|
18
18
|
from scipy.signal import find_peaks
|
19
19
|
from tqdm.notebook import tqdm
|
20
20
|
|
21
|
+
from bmtool.util.util import load_mechanisms_from_config, load_templates_from_config
|
22
|
+
|
21
23
|
|
22
24
|
class SynapseTuner:
|
23
25
|
def __init__(
|
24
26
|
self,
|
25
|
-
mechanisms_dir: str,
|
26
|
-
templates_dir: str,
|
27
|
-
|
28
|
-
|
29
|
-
|
27
|
+
mechanisms_dir: str = None,
|
28
|
+
templates_dir: str = None,
|
29
|
+
config: str = None,
|
30
|
+
conn_type_settings: dict = None,
|
31
|
+
connection: str = None,
|
32
|
+
general_settings: dict = None,
|
30
33
|
json_folder_path: str = None,
|
31
34
|
current_name: str = "i",
|
32
35
|
other_vars_to_record: list = None,
|
@@ -57,8 +60,18 @@ class SynapseTuner:
|
|
57
60
|
List of synaptic variables you would like sliders set up for the STP sliders method by default will use all parameters in spec_syn_param.
|
58
61
|
|
59
62
|
"""
|
60
|
-
|
61
|
-
|
63
|
+
if config is None and (mechanisms_dir is None or templates_dir is None):
|
64
|
+
raise ValueError(
|
65
|
+
"Either a config file or both mechanisms_dir and templates_dir must be provided."
|
66
|
+
)
|
67
|
+
|
68
|
+
if config is None:
|
69
|
+
neuron.load_mechanisms(mechanisms_dir)
|
70
|
+
h.load_file(templates_dir)
|
71
|
+
else:
|
72
|
+
load_mechanisms_from_config(config)
|
73
|
+
load_templates_from_config(config)
|
74
|
+
|
62
75
|
self.conn_type_settings = conn_type_settings
|
63
76
|
if json_folder_path:
|
64
77
|
print(f"updating settings from json path {json_folder_path}")
|
@@ -939,10 +952,11 @@ class SynapseTuner:
|
|
939
952
|
class GapJunctionTuner:
|
940
953
|
def __init__(
|
941
954
|
self,
|
942
|
-
mechanisms_dir: str,
|
943
|
-
templates_dir: str,
|
944
|
-
|
945
|
-
|
955
|
+
mechanisms_dir: str = None,
|
956
|
+
templates_dir: str = None,
|
957
|
+
config: str = None,
|
958
|
+
general_settings: dict = None,
|
959
|
+
conn_type_settings: dict = None,
|
946
960
|
):
|
947
961
|
"""
|
948
962
|
Initialize the GapJunctionTuner class.
|
@@ -953,13 +967,24 @@ class GapJunctionTuner:
|
|
953
967
|
Directory path containing the compiled mod files needed for NEURON mechanisms.
|
954
968
|
templates_dir : str
|
955
969
|
Directory path containing cell template files (.hoc or .py) loaded into NEURON.
|
970
|
+
config : str
|
971
|
+
Path to a BMTK config.json file. Can be used to load mechanisms, templates, and other settings.
|
956
972
|
general_settings : dict
|
957
973
|
General settings dictionary including parameters like simulation time step, duration, and temperature.
|
958
974
|
conn_type_settings : dict
|
959
975
|
A dictionary containing connection-specific settings for gap junctions.
|
960
976
|
"""
|
961
|
-
|
962
|
-
|
977
|
+
if config is None and (mechanisms_dir is None or templates_dir is None):
|
978
|
+
raise ValueError(
|
979
|
+
"Either a config file or both mechanisms_dir and templates_dir must be provided."
|
980
|
+
)
|
981
|
+
|
982
|
+
if config is None:
|
983
|
+
neuron.load_mechanisms(mechanisms_dir)
|
984
|
+
h.load_file(templates_dir)
|
985
|
+
else:
|
986
|
+
load_mechanisms_from_config(config)
|
987
|
+
load_templates_from_config(config)
|
963
988
|
|
964
989
|
self.general_settings = general_settings
|
965
990
|
self.conn_type_settings = conn_type_settings
|
@@ -1049,7 +1074,6 @@ class GapJunctionTuner:
|
|
1049
1074
|
plt.xlabel("Time (ms)")
|
1050
1075
|
plt.ylabel("Membrane Voltage (mV)")
|
1051
1076
|
plt.legend()
|
1052
|
-
plt.show()
|
1053
1077
|
|
1054
1078
|
def coupling_coefficient(self, t, v1, v2, t_start, t_end, dt=h.dt):
|
1055
1079
|
"""
|
@@ -1085,28 +1109,55 @@ class GapJunctionTuner:
|
|
1085
1109
|
|
1086
1110
|
def InteractiveTuner(self):
|
1087
1111
|
w_run = widgets.Button(description="Run", icon="history", button_style="primary")
|
1088
|
-
values = [i * 10**-4 for i in range(1,
|
1112
|
+
values = [i * 10**-4 for i in range(1, 1001)] # From 1e-4 to 1e-1
|
1089
1113
|
|
1090
1114
|
# Create the SelectionSlider widget with appropriate formatting
|
1091
|
-
resistance = widgets.
|
1092
|
-
|
1093
|
-
|
1115
|
+
resistance = widgets.FloatLogSlider(
|
1116
|
+
value=0.001,
|
1117
|
+
base=10,
|
1118
|
+
min=-4, # max exponent of base
|
1119
|
+
max=-1, # min exponent of base
|
1120
|
+
step=0.1, # exponent step
|
1094
1121
|
description="Resistance: ",
|
1095
1122
|
continuous_update=True,
|
1096
1123
|
)
|
1097
1124
|
|
1098
1125
|
ui = VBox([w_run, resistance])
|
1126
|
+
|
1127
|
+
# Create an output widget to control what gets cleared
|
1128
|
+
output = widgets.Output()
|
1129
|
+
|
1099
1130
|
display(ui)
|
1131
|
+
display(output)
|
1100
1132
|
|
1101
1133
|
def on_button(*args):
|
1102
|
-
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1134
|
+
with output:
|
1135
|
+
# Clear only the output widget, not the entire cell
|
1136
|
+
output.clear_output(wait=True)
|
1137
|
+
|
1138
|
+
resistance_for_gap = resistance.value
|
1139
|
+
print(f"Running simulation with resistance: {resistance_for_gap}")
|
1140
|
+
|
1141
|
+
try:
|
1142
|
+
self.model(resistance_for_gap)
|
1143
|
+
self.plot_model()
|
1144
|
+
|
1145
|
+
# Convert NEURON vectors to numpy arrays
|
1146
|
+
t_array = np.array(self.t_vec)
|
1147
|
+
v1_array = np.array(self.soma_v_1)
|
1148
|
+
v2_array = np.array(self.soma_v_2)
|
1149
|
+
|
1150
|
+
cc = self.coupling_coefficient(t_array, v1_array, v2_array, 500, 1000)
|
1151
|
+
print(f"coupling_coefficient is {cc:0.4f}")
|
1152
|
+
plt.show()
|
1153
|
+
|
1154
|
+
except Exception as e:
|
1155
|
+
print(f"Error during simulation or analysis: {e}")
|
1156
|
+
import traceback
|
1157
|
+
|
1158
|
+
traceback.print_exc()
|
1109
1159
|
|
1160
|
+
# Run once initially
|
1110
1161
|
on_button()
|
1111
1162
|
w_run.on_click(on_button)
|
1112
1163
|
|
@@ -6,18 +6,18 @@ bmtool/graphs.py,sha256=gBTzI6c2BBK49dWGcfWh9c56TAooyn-KaiEy0Im1HcI,6717
|
|
6
6
|
bmtool/manage.py,sha256=lsgRejp02P-x6QpA7SXcyXdalPhRmypoviIA2uAitQs,608
|
7
7
|
bmtool/plot_commands.py,sha256=Dxm_RaT4CtHnfsltTtUopJ4KVbfhxtktEB_b7bFEXII,12716
|
8
8
|
bmtool/singlecell.py,sha256=I2yolbAnNC8qpnRkNdnDCLidNW7CktmBuRrcowMZJ3A,45041
|
9
|
-
bmtool/synapses.py,sha256=
|
9
|
+
bmtool/synapses.py,sha256=hRuxRCXVpu0_0egi183qyp343tT-_gZNSxjk9rT5J8Q,66175
|
10
10
|
bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
bmtool/analysis/entrainment.py,sha256=
|
11
|
+
bmtool/analysis/entrainment.py,sha256=TUeV-WfCfVPqilVTjg6Cv1WKOz_zSX7LeE7k2Wuceug,28449
|
12
12
|
bmtool/analysis/lfp.py,sha256=S2JvxkjcK3-EH93wCrhqNSFY6cX7fOq74pz64ibHKrc,26556
|
13
13
|
bmtool/analysis/netcon_reports.py,sha256=VnPZNKPaQA7oh1q9cIatsqQudm4cOtzNtbGPXoiDCD0,2909
|
14
|
-
bmtool/analysis/spikes.py,sha256=
|
14
|
+
bmtool/analysis/spikes.py,sha256=u7Qu0NVGPDAH5jlgNLv32H1hDAkOlG6P4nKEFeAOkdE,22833
|
15
15
|
bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
16
|
bmtool/bmplot/connections.py,sha256=P1JBG4xCbLVq4sfQuUE6c3dO949qajrjdQcrazdmDS4,53861
|
17
|
-
bmtool/bmplot/entrainment.py,sha256=
|
17
|
+
bmtool/bmplot/entrainment.py,sha256=YBHTJ-nK0OgM8CNssM8IyqPNYez9ss9bQi-C5HW4kGw,12593
|
18
18
|
bmtool/bmplot/lfp.py,sha256=SNpbWGOUnYEgnkeBw5S--aPN5mIGD22Gw2Pwus0_lvY,2034
|
19
19
|
bmtool/bmplot/netcon_reports.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
|
-
bmtool/bmplot/spikes.py,sha256=
|
20
|
+
bmtool/bmplot/spikes.py,sha256=RJOOtmgWhTvyVi1CghoKTtxvt7MF9cJCrJVm5hV5wA4,11210
|
21
21
|
bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
22
|
bmtool/debug/commands.py,sha256=VV00f6q5gzZI503vUPeG40ABLLen0bw_k4-EX-H5WZE,580
|
23
23
|
bmtool/debug/debug.py,sha256=9yUFvA4_Bl-x9s29quIEG3pY-S8hNJF3RKBfRBHCl28,208
|
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=Nn-R-4e9g8ZhSPZvTkr38xeKRPfEMANB9Lugppj82UI,68564
|
|
26
26
|
bmtool/util/util.py,sha256=owce5BEusZO_8T5x05N2_B583G26vWAy7QX29V0Pj0Y,62818
|
27
27
|
bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
28
28
|
bmtool/util/neuron/celltuner.py,sha256=lokRLUM1rsdSYBYrNbLBBo39j14mm8TBNVNRnSlhHCk,94868
|
29
|
-
bmtool-0.7.1.
|
30
|
-
bmtool-0.7.1.
|
31
|
-
bmtool-0.7.1.
|
32
|
-
bmtool-0.7.1.
|
33
|
-
bmtool-0.7.1.
|
34
|
-
bmtool-0.7.1.
|
29
|
+
bmtool-0.7.1.7.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
|
30
|
+
bmtool-0.7.1.7.dist-info/METADATA,sha256=-2fuMCtlaM_YoVuYHcuhNAGe-Cw-5Yfb3kqejIV7S6c,3577
|
31
|
+
bmtool-0.7.1.7.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
|
32
|
+
bmtool-0.7.1.7.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
|
33
|
+
bmtool-0.7.1.7.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
|
34
|
+
bmtool-0.7.1.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|