bmtool 0.7.1.5__py3-none-any.whl → 0.7.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/analysis/entrainment.py CHANGED
@@ -636,26 +636,26 @@ def calculate_entrainment_per_cell(
 
 
 def calculate_spike_rate_power_correlation(
-    spike_rate,
-    lfp_data,
-    fs,
-    pop_names,
-    filter_method="wavelet",
-    bandwidth=2.0,
-    lowcut=None,
-    highcut=None,
-    freq_range=(10, 100),
-    freq_step=5,
+    spike_rate: xr.DataArray,
+    lfp_data: np.ndarray,
+    fs: float,
+    pop_names: list,
+    filter_method: str = "wavelet",
+    bandwidth: float = 2.0,
+    lowcut: float = None,
+    highcut: float = None,
+    freq_range: tuple = (10, 100),
+    freq_step: float = 5,
+    type_name: str = "raw",  # 'raw' or 'smoothed'
 ):
     """
-    Calculate correlation between population spike rates and LFP power across frequencies
-    using wavelet filtering. This function assumes the fs of the spike_rate and lfp are the same.
+    Calculate correlation between population spike rates (xarray) and LFP power across frequencies.
 
-    Parameters:
-    -----------
-    spike_rate : DataFrame
-        Pre-calculated population spike rates at the same fs as lfp
-    lfp_data : np.array
+    Parameters
+    ----------
+    spike_rate : xr.DataArray
+        Population spike rates with dimensions (time, population[, type])
+    lfp_data : np.ndarray
         LFP data
     fs : float
         Sampling frequency
@@ -673,19 +673,17 @@ def calculate_spike_rate_power_correlation(
         Min and max frequency to analyze (default: (10, 100))
     freq_step : float, optional
         Step size for frequency analysis (default: 5)
+    type_name : str, optional
+        Which type of spike rate to use if 'type' dimension exists (default: 'raw')
 
-    Returns:
-    --------
+    Returns
+    -------
     correlation_results : dict
         Dictionary with correlation results for each population and frequency
     frequencies : array
         Array of frequencies analyzed
     """
-
-    # Define frequency bands to analyze
     frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)
-
-    # Dictionary to store results
     correlation_results = {pop: {} for pop in pop_names}
 
     # Calculate power at each frequency band using specified filter
@@ -695,20 +693,22 @@
             lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
         )
 
-    # Calculate correlation for each population
+    # For each population, extract the correct spike rate
     for pop in pop_names:
-        # Extract spike rate for this population
-        pop_rate = spike_rate[pop]
+        # If 'type' dimension exists, select the type
+        if "type" in spike_rate.dims:
+            pop_rate = spike_rate.sel(population=pop, type=type_name).values
+        else:
+            pop_rate = spike_rate.sel(population=pop).values
 
         # Calculate correlation with power at each frequency
         for freq in frequencies:
-            # Make sure the lengths match
-            if len(pop_rate) != len(power_by_freq[freq]):
-                raise ValueError(
-                    f"Mismatched lengths for {pop} at {freq} Hz len(pop_rate): {len(pop_rate)}, len(power_by_freq): {len(power_by_freq[freq])}"
-                )
-            # use spearman for non-parametric correlation
-            corr, p_val = stats.spearmanr(pop_rate, power_by_freq[freq])
+            lfp_power = power_by_freq[freq]
+            # Ensure lengths match
+            min_len = min(len(pop_rate), len(lfp_power))
+            if len(pop_rate) != len(lfp_power):
+                print(f"Warning: Length mismatch for {pop} at {freq} Hz, truncating to {min_len}")
+            corr, p_val = stats.spearmanr(pop_rate[:min_len], lfp_power[:min_len])
             correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
 
     return correlation_results, frequencies
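The updated function now expects the xarray output of `get_population_spike_rate` (see bmtool/analysis/spikes.py below) rather than a DataFrame. A minimal usage sketch, with a synthetic spike table and a random stand-in LFP as placeholder inputs; the call follows the signature shown in this hunk, and only `fs` is assumed to be a keyword of `get_population_spike_rate` (consistent with the `fs` attribute it stores):

```python
import numpy as np
import pandas as pd

from bmtool.analysis.entrainment import calculate_spike_rate_power_correlation
from bmtool.analysis.spikes import get_population_spike_rate

# Toy spike table with the columns get_population_spike_rate reads
rng = np.random.default_rng(0)
spike_df = pd.DataFrame(
    {
        "timestamps": rng.uniform(0, 1000, 2000),  # spike times in ms
        "node_ids": rng.integers(0, 50, 2000),
        "pop_name": rng.choice(["PV", "SOM"], 2000),
    }
)
fs = 400  # Hz; the spike-rate array and the LFP must share this rate
rates = get_population_spike_rate(spike_df, fs=fs)  # xr.DataArray (time, population, type)
lfp = rng.standard_normal(rates.sizes["time"])  # stand-in LFP at the same fs

corr, freqs = calculate_spike_rate_power_correlation(
    rates, lfp, fs=fs, pop_names=["PV", "SOM"], type_name="smoothed"
)
print(corr["PV"][freqs[0]])  # {'correlation': ..., 'p_value': ...}
```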
bmtool/analysis/spikes.py CHANGED
@@ -3,11 +3,12 @@ Module for processing BMTK spikes output.
 """
 
 import os
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import h5py
 import numpy as np
 import pandas as pd
+import xarray as xr
 from scipy.stats import mannwhitneyu
 
 from bmtool.util.util import load_nodes_from_config
@@ -213,9 +214,11 @@ def get_population_spike_rate(
     save: bool = False,
     save_path: Optional[str] = None,
     normalize: bool = False,
-) -> Dict[str, np.ndarray]:
+    smooth_window: int = 50,  # Window size for smoothing (in time bins)
+    smooth_method: str = "gaussian",  # Smoothing method: 'gaussian', 'boxcar', or 'exponential'
+) -> xr.DataArray:
     """
-    Calculate the population spike rate for each population in the given spike data, with an option to normalize.
+    Calculate the population spike rate for each population in the given spike data.
 
     Parameters
     ----------
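Because the return type changes from `Dict[str, np.ndarray]` to a single `xr.DataArray`, callers that previously indexed `spike_rate[pop]` need a small update. A sketch of the new access pattern, reusing the toy `spike_df` from the sketch in the entrainment section above (parameter values are illustrative):

```python
rates = get_population_spike_rate(
    spike_df,
    fs=400,
    smooth_window=100,       # number of time bins
    smooth_method="boxcar",  # or 'gaussian' (default) / 'exponential'
)
raw = rates.sel(population="PV", type="raw").values  # replaces old spike_rate["PV"]
smoothed = rates.sel(population="PV", type="smoothed").values
print(rates.attrs["fs"], rates.attrs["smooth_method"])  # metadata now travels with the data
```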
@@ -239,31 +242,51 @@
         Directory path where the file should be saved if `save` is True (default: None)
     normalize : bool, optional
         Whether to normalize the spike rates for each population to a range of [0, 1] (default: False)
+    smooth_window : int, optional
+        Window size for smoothing in number of time bins (default: 50)
+    smooth_method : str, optional
+        Smoothing method to use: 'gaussian', 'boxcar', or 'exponential' (default: 'gaussian')
 
     Returns
     -------
-    Dict[str, np.ndarray]
-        A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
-        If `normalize` is True, each population's spike rate is scaled to [0, 1].
+    xr.DataArray
+        An xarray DataArray containing the spike rates with dimensions of time, population, and type.
+        The 'type' dimension includes 'raw' and 'smoothed' values.
+        The DataArray includes sampling frequency (fs) as an attribute.
+        If normalize is True, each population's spike rate is scaled to [0, 1].
 
     Raises
     ------
     ValueError
         If `save` is True but `save_path` is not provided.
+        If an invalid smooth_method is specified.
 
     Notes
     -----
-    - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
-    - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
+    - If `config` is None, the function assumes all cells in each population have fired at least once;
+      otherwise, the node count may be inaccurate.
+    - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization.
+    - Smoothing is applied using scipy.ndimage's filters based on the specified method.
     """
+    import numpy as np
+    from scipy import ndimage
+
+    # Validate smoothing method
+    if smooth_method not in ["gaussian", "boxcar", "exponential"]:
+        raise ValueError(
+            f"Invalid smooth_method: {smooth_method}. Choose from 'gaussian', 'boxcar', or 'exponential'."
+        )
+
     pop_spikes = {}
     node_number = {}
 
     if config is None:
         print(
-            "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect."
+            "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, or not all cells fired,\nthen this count will not include all nodes so the firing rate will not be of the whole population!"
+        )
+        print(
+            "You can provide a config to calculate the correct amount of nodes! for a true population rate."
         )
-        print("You can provide a config to calculate the correct amount of nodes!")
 
     if config:
         if not network_name:
@@ -271,7 +294,13 @@ def get_population_spike_rate(
                 "Grabbing first network; specify a network name to ensure correct node population is selected."
             )
 
-    for pop_name in spike_data["pop_name"].unique():
+    # Get t_stop if not provided
+    if t_stop is None:
+        t_stop = spike_data["timestamps"].max()
+
+    # Get population names and prepare data
+    populations = spike_data["pop_name"].unique()
+    for pop_name in populations:
         ps = spike_data[spike_data["pop_name"] == pop_name]
 
         if config:
@@ -282,12 +311,10 @@
             nodes = list(nodes.values())[0] if nodes else {}
             nodes = nodes[nodes["pop_name"] == pop_name]
             node_number[pop_name] = nodes.index.nunique()
+
         else:
             node_number[pop_name] = ps["node_ids"].nunique()
 
-        if t_stop is None:
-            t_stop = spike_data["timestamps"].max()
-
         filtered_spikes = spike_data[
             (spike_data["pop_name"] == pop_name)
             & (spike_data["timestamps"] > t_start)
@@ -295,29 +322,153 @@
         ]
         pop_spikes[pop_name] = filtered_spikes
 
-    time = np.array([t_start, t_stop, 1000 / fs])
-    pop_rspk = {p: _pop_spike_rate(spk["timestamps"], time) for p, spk in pop_spikes.items()}
-    spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk}
+    # Calculate time points
+    time = np.arange(t_start, t_stop, 1000 / fs)  # Convert sampling frequency to time steps
+
+    # Calculate spike rates for each population
+    spike_rates = []
+    for p in populations:
+        raw_rate = _pop_spike_rate(pop_spikes[p]["timestamps"], (t_start, t_stop, 1000 / fs))
+        rate = fs / node_number[p] * raw_rate
+        spike_rates.append(rate)
+
+    spike_rates_array = np.array(spike_rates).T  # Transpose to have time as first dimension
+
+    # Calculate smoothed version for each population
+    smoothed_rates = []
+
+    for i in range(spike_rates_array.shape[1]):
+        pop_rate = spike_rates_array[:, i]
+
+        if smooth_method == "gaussian":
+            # Gaussian smoothing (sigma is approximately window/6 for a Gaussian filter)
+            sigma = smooth_window / 6
+            smoothed_pop_rate = ndimage.gaussian_filter1d(pop_rate, sigma=sigma)
+        elif smooth_method == "boxcar":
+            # Boxcar/uniform smoothing
+            kernel = np.ones(smooth_window) / smooth_window
+            smoothed_pop_rate = ndimage.convolve1d(pop_rate, kernel, mode="nearest")
+        elif smooth_method == "exponential":
+            # Exponential smoothing
+            alpha = 2 / (smooth_window + 1)  # Equivalent to window size in exponential smoothing
+            smoothed_pop_rate = np.zeros_like(pop_rate)
+            smoothed_pop_rate[0] = pop_rate[0]
+            for t in range(1, len(pop_rate)):
+                smoothed_pop_rate[t] = alpha * pop_rate[t] + (1 - alpha) * smoothed_pop_rate[t - 1]
+
+        smoothed_rates.append(smoothed_pop_rate)
+
+    smoothed_rates_array = np.array(smoothed_rates).T  # Transpose to have time as first dimension
+
+    # Stack raw and smoothed data
+    combined_data = np.stack([spike_rates_array, smoothed_rates_array], axis=2)
+
+    # Create DataArray with the additional 'type' dimension
+    spike_rate_array = xr.DataArray(
+        combined_data,
+        coords={"time": time, "population": populations, "type": ["raw", "smoothed"]},
+        dims=["time", "population", "type"],
+        attrs={
+            "fs": fs,
+            "normalized": False,
+            "smooth_method": smooth_method,
+            "smooth_window": smooth_window,
+        },
+    )
 
-    # Normalize each spike rate series if normalize=True
+    # Normalize if requested
     if normalize:
-        spike_rate = {p: (sr - sr.min()) / (sr.max() - sr.min()) for p, sr in spike_rate.items()}
-
+        # Apply normalization for each population and each type (raw/smoothed)
+        for pop_idx in range(len(populations)):
+            for type_idx, type_name in enumerate(["raw", "smoothed"]):
+                pop_data = spike_rate_array.sel(population=populations[pop_idx], type=type_name)
+                min_val = pop_data.min(dim="time")
+                max_val = pop_data.max(dim="time")
+
+                # Handle case where min == max (constant signal)
+                if max_val != min_val:
+                    spike_rate_array.loc[:, populations[pop_idx], type_name] = (
+                        pop_data - min_val
+                    ) / (max_val - min_val)
+
+        spike_rate_array.attrs["normalized"] = True
+
+    # Save if requested
     if save:
         if save_path is None:
            raise ValueError("save_path must be provided if save is True.")
 
         os.makedirs(save_path, exist_ok=True)
-
         save_file = os.path.join(save_path, "spike_rate.h5")
-        with h5py.File(save_file, "w") as f:
-            f.create_dataset("time", data=time)
-            grp = f.create_group("populations")
-            for p, rspk in spike_rate.items():
-                pop_grp = grp.create_group(p)
-                pop_grp.create_dataset("data", data=rspk)
+        spike_rate_array.to_netcdf(save_file)
 
-    return spike_rate
+    return spike_rate_array
+
+
+def average_spike_rate_over_windows(
+    spike_rate: xr.DataArray, windows: List[Tuple[float, float]]
+) -> xr.DataArray:
+    """
+    Calculate the average spike rate over multiple time windows.
+
+    Parameters
+    ----------
+    spike_rate : xr.DataArray
+        The spike rate data array with dimensions (time, population, type)
+        where 'type' can be 'raw' or 'smoothed'
+    windows : List[Tuple[float, float]]
+        List of (start, end) times in milliseconds defining the windows to average over
+
+    Returns
+    -------
+    xr.DataArray
+        Averaged spike rate with time normalized to start at 0,
+        preserving all original dimensions (time, population, type)
+    """
+    # Check if the DataArray has a 'type' dimension (compatible with new format)
+    has_type_dim = "type" in spike_rate.dims
+
+    # Initialize list to store data from each window
+    window_data = []
+
+    # Get data for each window
+    for start, end in windows:
+        # Select data points within the window
+        window = spike_rate.sel(time=slice(start, end))
+
+        # Normalize time to start at 0 for this window
+        window = window.assign_coords(time=window.time - start)
+        window_data.append(window)
+
+    # Align and average windows
+    # First window determines the time coordinates
+    aligned_data = xr.concat(window_data, dim="window")
+    averaged_data = aligned_data.mean(dim="window")
+
+    # Create new DataArray with the averaged data
+    if has_type_dim:
+        # Create result with time, population, and type dimensions
+        result = xr.DataArray(
+            averaged_data.values,
+            coords={
+                "time": averaged_data.time.values,
+                "population": averaged_data.population,
+                "type": averaged_data.type,
+            },
+            dims=["time", "population", "type"],
+        )
+    else:
+        # Handle older format without 'type' dimension (for backward compatibility)
+        result = xr.DataArray(
+            averaged_data.values,
+            coords={"time": averaged_data.time.values, "population": averaged_data.population},
+            dims=["time", "population"],
+        )
+
+    # Preserve attributes
+    result.attrs = spike_rate.attrs
+
+    return result
 
 
 def compare_firing_over_times(
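A short sketch of the new window-averaging helper; the window boundaries (in ms) are made up, and windows should span equal durations so their samples align along the temporary `window` dimension before averaging. `rates` is the DataArray from the earlier sketch:

```python
# Average the rate over three repeated 200 ms trial windows
windows = [(0.0, 200.0), (500.0, 700.0), (1000.0, 1200.0)]
avg = average_spike_rate_over_windows(rates, windows)
print(avg.dims)  # ('time', 'population', 'type'); time is re-zeroed to each window's start
print(avg.sel(population="PV", type="smoothed").values[:5])
```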
bmtool/bmplot/entrainment.py CHANGED
@@ -1,5 +1,6 @@
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 import seaborn as sns
 from matplotlib.gridspec import GridSpec
 from scipy import stats
@@ -55,7 +56,7 @@ def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
 
 def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
     """
-    Plot an idealized gamma cycle with spike histograms for different neuron populations.
+    Plot an idealized cycle with spike histograms for different neuron populations.
 
     Parameters:
     -----------
@@ -120,7 +121,7 @@ def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
     plt.show()
 
 
-def plot_ppc_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=None):
+def plot_entrainment_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=None):
     """
     Plot PPC for all node populations on one graph with mean and standard error.
 
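For downstream code this is a pure rename; only the call site changes (the arguments here are the same placeholders used in the docstring):

```python
# 0.7.1.5: plot_ppc_by_population(ppc_dict, pop_names, freqs)
# 0.7.1.7:
plot_entrainment_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title="PPC by population")
```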
@@ -200,3 +201,168 @@ def plot_ppc_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=None):
     # Adjust layout and save
     plt.tight_layout()
     plt.show()
+
+
+def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title=None):
+    """
+    Plot a swarm plot of the entrainment for different populations at a single frequency.
+
+    Parameters:
+    -----------
+    ppc_dict : dict
+        Dictionary containing PPC values organized by population, node, and frequency
+    pop_names : list
+        List of population names to include in the plot
+    freq : float or int
+        The specific frequency to plot
+    save_path : str, optional
+        Path to save the figure. If None, figure is just displayed.
+
+    Returns:
+    --------
+    matplotlib.figure.Figure
+        The figure object for further customization if needed
+    """
+    # Set the style
+    sns.set_style("whitegrid")
+
+    # Prepare data for the swarm plot
+    data_list = []
+
+    for pop in pop_names:
+        values = []
+        node_ids = []
+
+        for node in ppc_dict[pop]:
+            if freq in ppc_dict[pop][node] and ppc_dict[pop][node][freq] is not None:
+                data_list.append(
+                    {"Population": pop, "Node": node, "PPC Difference": ppc_dict[pop][node][freq]}
+                )
+
+    # Create DataFrame in long format
+    df = pd.DataFrame(data_list)
+
+    if df.empty:
+        print(f"No data available for frequency {freq}.")
+        return None
+
+    # Print mean PPC change for each population
+    for pop in pop_names:
+        subset = df[df["Population"] == pop]
+        if not subset.empty:
+            mean_val = subset["PPC Difference"].mean()
+            std_val = subset["PPC Difference"].std()
+            n = len(subset)
+            sem_val = std_val / np.sqrt(n)  # Standard error of the mean
+            print(f"{pop}: {mean_val:.4f} ± {sem_val:.4f} (n={n})")
+
+    # Create figure
+    plt.figure(figsize=(max(8, len(pop_names) * 1.5), 8))
+
+    # Create swarm plot
+    ax = sns.swarmplot(
+        x="Population",
+        y="PPC Difference",
+        data=df,
+        size=3,
+        # palette='Set2'
+    )
+
+    # Add sample size annotations
+    for i, pop in enumerate(pop_names):
+        subset = df[df["Population"] == pop]
+        if not subset.empty:
+            n = len(subset)
+            y_min = subset["PPC Difference"].min()
+            y_max = subset["PPC Difference"].max()
+
+            # Position annotation below the lowest point
+            plt.annotate(
+                f"n={n}", (i, y_min - 0.05 * (y_max - y_min) - 0.05), ha="center", fontsize=10
+            )
+
+    # Add reference line at y=0
+    plt.axhline(y=0, color="black", linestyle="-", linewidth=0.5, alpha=0.7)
+
+    # Add horizontal lines for mean values
+    for i, pop in enumerate(pop_names):
+        subset = df[df["Population"] == pop]
+        if not subset.empty:
+            mean_val = subset["PPC Difference"].mean()
+            plt.plot([i - 0.25, i + 0.25], [mean_val, mean_val], "r-", linewidth=2)
+
+    # Calculate and display statistics
+    if len(pop_names) > 1:
+        # Print statistical test results
+        print(f"\nMann-Whitney U Test Results at {freq} Hz:")
+        print("-" * 60)
+
+        # Add p-values for pairwise comparisons
+        y_max = df["PPC Difference"].max()
+        y_min = df["PPC Difference"].min()
+        y_range = y_max - y_min
+
+        # Perform pairwise Mann-Whitney U tests between populations if there are at least 2
+        for i in range(len(pop_names)):
+            for j in range(i + 1, len(pop_names)):
+                pop1 = pop_names[i]
+                pop2 = pop_names[j]
+
+                vals1 = df[df["Population"] == pop1]["PPC Difference"].values
+                vals2 = df[df["Population"] == pop2]["PPC Difference"].values
+
+                if len(vals1) > 1 and len(vals2) > 1:
+                    # Perform Mann-Whitney U test (non-parametric)
+                    u_stat, p_val = stats.mannwhitneyu(vals1, vals2, alternative="two-sided")
+
+                    # Add significance markers
+                    sig_str = "ns"
+                    if p_val < 0.05:
+                        sig_str = "*"
+                    if p_val < 0.01:
+                        sig_str = "**"
+                    if p_val < 0.001:
+                        sig_str = "***"
+
+                    # Position the significance bar
+                    bar_height = y_max + 0.1 * y_range * (1 + (j - i - 1) * 0.5)
+
+                    # Draw the bar
+                    plt.plot([i, j], [bar_height, bar_height], "k-")
+                    plt.plot([i, i], [bar_height - 0.02 * y_range, bar_height], "k-")
+                    plt.plot([j, j], [bar_height - 0.02 * y_range, bar_height], "k-")
+
+                    # Add significance marker
+                    plt.text(
+                        (i + j) / 2,
+                        bar_height + 0.01 * y_range,
+                        sig_str,
+                        ha="center",
+                        va="bottom",
+                        fontsize=12,
+                    )
+
+                    # Print the statistical comparison
+                    print(f"{pop1} vs {pop2}: U={u_stat:.1f}, p={p_val:.4f} {sig_str}")
+
+    # Add labels and title
+    plt.xlabel("Population", fontsize=14)
+    plt.ylabel("PPC", fontsize=14)
+    if title:
+        plt.title(title, fontsize=16)
+
+    # Adjust y-axis limits to make room for annotations
+    y_min, y_max = plt.ylim()
+    plt.ylim(y_min - 0.15 * (y_max - y_min), y_max + 0.25 * (y_max - y_min))
+
+    # Add gridlines
+    plt.grid(True, linestyle="--", alpha=0.7, axis="y")
+
+    # Adjust layout
+    plt.tight_layout()
+
+    # Save figure if path is provided
+    if save_path:
+        plt.savefig(f"{save_path}/ppc_change_swarm_plot_{freq}Hz.png", dpi=300, bbox_inches="tight")
+
+    plt.show()
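A sketch of calling the new swarm plot; the nested `ppc_dict[pop][node][freq]` layout follows the docstring above, and the values are invented:

```python
ppc_dict = {
    "PV": {0: {40: 0.12}, 1: {40: 0.05}, 2: {40: -0.02}},
    "SOM": {0: {40: 0.30}, 1: {40: 0.22}, 2: {40: 0.18}},
}
plot_entrainment_swarm_plot(ppc_dict, ["PV", "SOM"], freq=40, title="PPC at 40 Hz")
# Prints per-population mean ± SEM and pairwise Mann-Whitney U results, then shows the plot
```

Note that although the docstring advertises a returned Figure, the body as added ends in `plt.show()` with no return statement (and returns None early when the DataFrame is empty).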
bmtool/bmplot/spikes.py CHANGED
@@ -1,3 +1,5 @@
+"""Plotting functions for neural spikes and firing rates."""
+
 from typing import Dict, List, Optional, Union
 
 import matplotlib.pyplot as plt
bmtool/synapses.py CHANGED
@@ -18,15 +18,18 @@ from scipy.optimize import curve_fit, minimize, minimize_scalar
 from scipy.signal import find_peaks
 from tqdm.notebook import tqdm
 
+from bmtool.util.util import load_mechanisms_from_config, load_templates_from_config
+
 
 class SynapseTuner:
     def __init__(
         self,
-        mechanisms_dir: str,
-        templates_dir: str,
-        conn_type_settings: dict,
-        connection: str,
-        general_settings: dict,
+        mechanisms_dir: str = None,
+        templates_dir: str = None,
+        config: str = None,
+        conn_type_settings: dict = None,
+        connection: str = None,
+        general_settings: dict = None,
         json_folder_path: str = None,
         current_name: str = "i",
         other_vars_to_record: list = None,
@@ -57,8 +60,18 @@ class SynapseTuner:
             List of synaptic variables you would like sliders set up for the STP sliders method by default will use all parameters in spec_syn_param.
 
         """
-        neuron.load_mechanisms(mechanisms_dir)
-        h.load_file(templates_dir)
+        if config is None and (mechanisms_dir is None or templates_dir is None):
+            raise ValueError(
+                "Either a config file or both mechanisms_dir and templates_dir must be provided."
+            )
+
+        if config is None:
+            neuron.load_mechanisms(mechanisms_dir)
+            h.load_file(templates_dir)
+        else:
+            load_mechanisms_from_config(config)
+            load_templates_from_config(config)
+
         self.conn_type_settings = conn_type_settings
         if json_folder_path:
             print(f"updating settings from json path {json_folder_path}")
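With the new `config` path, the tuner can bootstrap NEURON from a BMTK simulation config instead of explicit directories. A hedged sketch; the paths and connection name are placeholders, and `conn_type_settings`/`general_settings` are the same dicts required in 0.7.1.5:

```python
from bmtool.synapses import SynapseTuner

# Old style: explicit directories
tuner = SynapseTuner(
    mechanisms_dir="components/mechanisms",                 # placeholder path
    templates_dir="components/templates/templates.hoc",     # placeholder path
    conn_type_settings=conn_type_settings,
    connection="PN2PN",                                     # placeholder connection name
    general_settings=general_settings,
)

# New in 0.7.1.7: let a BMTK config locate mechanisms and templates
tuner = SynapseTuner(
    config="simulation_config.json",                        # placeholder path
    conn_type_settings=conn_type_settings,
    connection="PN2PN",
    general_settings=general_settings,
)
# Passing neither config nor both directories raises the ValueError added above
```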
@@ -939,10 +952,11 @@
 class GapJunctionTuner:
     def __init__(
         self,
-        mechanisms_dir: str,
-        templates_dir: str,
-        general_settings: dict,
-        conn_type_settings: dict,
+        mechanisms_dir: str = None,
+        templates_dir: str = None,
+        config: str = None,
+        general_settings: dict = None,
+        conn_type_settings: dict = None,
     ):
         """
         Initialize the GapJunctionTuner class.
@@ -953,13 +967,24 @@
             Directory path containing the compiled mod files needed for NEURON mechanisms.
         templates_dir : str
             Directory path containing cell template files (.hoc or .py) loaded into NEURON.
+        config : str
+            Path to a BMTK config.json file. Can be used to load mechanisms, templates, and other settings.
         general_settings : dict
             General settings dictionary including parameters like simulation time step, duration, and temperature.
         conn_type_settings : dict
             A dictionary containing connection-specific settings for gap junctions.
         """
-        neuron.load_mechanisms(mechanisms_dir)
-        h.load_file(templates_dir)
+        if config is None and (mechanisms_dir is None or templates_dir is None):
+            raise ValueError(
+                "Either a config file or both mechanisms_dir and templates_dir must be provided."
+            )
+
+        if config is None:
+            neuron.load_mechanisms(mechanisms_dir)
+            h.load_file(templates_dir)
+        else:
+            load_mechanisms_from_config(config)
+            load_templates_from_config(config)
 
         self.general_settings = general_settings
         self.conn_type_settings = conn_type_settings
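GapJunctionTuner gets the same either/or contract: pass `config`, or pass both directories, otherwise the ValueError above is raised. A sketch with placeholder inputs:

```python
from bmtool.synapses import GapJunctionTuner

gj = GapJunctionTuner(
    config="simulation_config.json",  # placeholder; or mechanisms_dir= and templates_dir=
    general_settings=general_settings,
    conn_type_settings=conn_type_settings,
)
gj.InteractiveTuner()  # resistance is now a FloatLogSlider spanning 1e-4 to 1e-1
```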
@@ -1049,7 +1074,6 @@
         plt.xlabel("Time (ms)")
         plt.ylabel("Membrane Voltage (mV)")
         plt.legend()
-        plt.show()
 
     def coupling_coefficient(self, t, v1, v2, t_start, t_end, dt=h.dt):
         """
@@ -1085,28 +1109,55 @@
 
     def InteractiveTuner(self):
         w_run = widgets.Button(description="Run", icon="history", button_style="primary")
-        values = [i * 10**-4 for i in range(1, 101)]  # From 1e-4 to 1e-2
+        values = [i * 10**-4 for i in range(1, 1001)]  # From 1e-4 to 1e-1
 
         # Create the SelectionSlider widget with appropriate formatting
-        resistance = widgets.SelectionSlider(
-            options=[("%g" % i, i) for i in values],  # Use scientific notation for display
-            value=10**-3,  # Default value
+        resistance = widgets.FloatLogSlider(
+            value=0.001,
+            base=10,
+            min=-4,  # minimum exponent of base
+            max=-1,  # maximum exponent of base
+            step=0.1,  # exponent step
             description="Resistance: ",
             continuous_update=True,
         )
 
         ui = VBox([w_run, resistance])
+
+        # Create an output widget to control what gets cleared
+        output = widgets.Output()
+
         display(ui)
+        display(output)
 
         def on_button(*args):
-            clear_output()
-            display(ui)
-            resistance_for_gap = resistance.value
-            self.model(resistance_for_gap)
-            self.plot_model()
-            cc = self.coupling_coefficient(self.t_vec, self.soma_v_1, self.soma_v_2, 500, 1000)
-            print(f"coupling_coefficient is {cc:0.4f}")
+            with output:
+                # Clear only the output widget, not the entire cell
+                output.clear_output(wait=True)
+
+                resistance_for_gap = resistance.value
+                print(f"Running simulation with resistance: {resistance_for_gap}")
+
+                try:
+                    self.model(resistance_for_gap)
+                    self.plot_model()
+
+                    # Convert NEURON vectors to numpy arrays
+                    t_array = np.array(self.t_vec)
+                    v1_array = np.array(self.soma_v_1)
+                    v2_array = np.array(self.soma_v_2)
+
+                    cc = self.coupling_coefficient(t_array, v1_array, v2_array, 500, 1000)
+                    print(f"coupling_coefficient is {cc:0.4f}")
+                    plt.show()
+
+                except Exception as e:
+                    print(f"Error during simulation or analysis: {e}")
+                    import traceback
+
+                    traceback.print_exc()
 
+        # Run once initially
         on_button()
         w_run.on_click(on_button)
 
bmtool-0.7.1.7.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bmtool
-Version: 0.7.1.5
+Version: 0.7.1.7
 Summary: BMTool
 Home-page: https://github.com/cyneuro/bmtool
 Download-URL:
bmtool-0.7.1.7.dist-info/RECORD CHANGED
@@ -6,18 +6,18 @@ bmtool/graphs.py,sha256=gBTzI6c2BBK49dWGcfWh9c56TAooyn-KaiEy0Im1HcI,6717
 bmtool/manage.py,sha256=lsgRejp02P-x6QpA7SXcyXdalPhRmypoviIA2uAitQs,608
 bmtool/plot_commands.py,sha256=Dxm_RaT4CtHnfsltTtUopJ4KVbfhxtktEB_b7bFEXII,12716
 bmtool/singlecell.py,sha256=I2yolbAnNC8qpnRkNdnDCLidNW7CktmBuRrcowMZJ3A,45041
-bmtool/synapses.py,sha256=wlRY7IixefPzafqG6k2sPIK4s6PLG9Kct-oCaVR29wA,64269
+bmtool/synapses.py,sha256=hRuxRCXVpu0_0egi183qyp343tT-_gZNSxjk9rT5J8Q,66175
 bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-bmtool/analysis/entrainment.py,sha256=PM4Do8Cl248Y2kIXLRFLPmUB_mH38Yhl8CUDDcunGq0,28241
+bmtool/analysis/entrainment.py,sha256=TUeV-WfCfVPqilVTjg6Cv1WKOz_zSX7LeE7k2Wuceug,28449
 bmtool/analysis/lfp.py,sha256=S2JvxkjcK3-EH93wCrhqNSFY6cX7fOq74pz64ibHKrc,26556
 bmtool/analysis/netcon_reports.py,sha256=VnPZNKPaQA7oh1q9cIatsqQudm4cOtzNtbGPXoiDCD0,2909
-bmtool/analysis/spikes.py,sha256=IHxV7_X8ojh4NDVBjzHCzfHF8muPPef2UtH3yqYre78,17091
+bmtool/analysis/spikes.py,sha256=u7Qu0NVGPDAH5jlgNLv32H1hDAkOlG6P4nKEFeAOkdE,22833
 bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 bmtool/bmplot/connections.py,sha256=P1JBG4xCbLVq4sfQuUE6c3dO949qajrjdQcrazdmDS4,53861
-bmtool/bmplot/entrainment.py,sha256=VSlZvcSeXLr5OxGvmWcGU4s7JS7vOL38lq1XC69O_AE,6926
+bmtool/bmplot/entrainment.py,sha256=YBHTJ-nK0OgM8CNssM8IyqPNYez9ss9bQi-C5HW4kGw,12593
 bmtool/bmplot/lfp.py,sha256=SNpbWGOUnYEgnkeBw5S--aPN5mIGD22Gw2Pwus0_lvY,2034
 bmtool/bmplot/netcon_reports.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-bmtool/bmplot/spikes.py,sha256=Lg8V3ynYCqk-QJvq-BOInjZMHYHrxHgXjtDOX67df-A,11148
+bmtool/bmplot/spikes.py,sha256=RJOOtmgWhTvyVi1CghoKTtxvt7MF9cJCrJVm5hV5wA4,11210
 bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 bmtool/debug/commands.py,sha256=VV00f6q5gzZI503vUPeG40ABLLen0bw_k4-EX-H5WZE,580
 bmtool/debug/debug.py,sha256=9yUFvA4_Bl-x9s29quIEG3pY-S8hNJF3RKBfRBHCl28,208
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=Nn-R-4e9g8ZhSPZvTkr38xeKRPfEMANB9Lugppj82UI,68564
 bmtool/util/util.py,sha256=owce5BEusZO_8T5x05N2_B583G26vWAy7QX29V0Pj0Y,62818
 bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 bmtool/util/neuron/celltuner.py,sha256=lokRLUM1rsdSYBYrNbLBBo39j14mm8TBNVNRnSlhHCk,94868
-bmtool-0.7.1.5.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
-bmtool-0.7.1.5.dist-info/METADATA,sha256=pvpABD7P2ytzO08EYUC8HrbUTg_fk8pc67mBaEvi7-M,3577
-bmtool-0.7.1.5.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-bmtool-0.7.1.5.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
-bmtool-0.7.1.5.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
-bmtool-0.7.1.5.dist-info/RECORD,,
+bmtool-0.7.1.7.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
+bmtool-0.7.1.7.dist-info/METADATA,sha256=-2fuMCtlaM_YoVuYHcuhNAGe-Cw-5Yfb3kqejIV7S6c,3577
+bmtool-0.7.1.7.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
+bmtool-0.7.1.7.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
+bmtool-0.7.1.7.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
+bmtool-0.7.1.7.dist-info/RECORD,,
bmtool-0.7.1.7.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.7.1)
+Generator: setuptools (80.8.0)
 Root-Is-Purelib: true
 Tag: py3-none-any