bmtool 0.6.9.29__tar.gz → 0.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. {bmtool-0.6.9.29 → bmtool-0.7.0}/PKG-INFO +1 -1
  2. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/analysis/entrainment.py +0 -50
  3. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/analysis/lfp.py +0 -51
  4. bmtool-0.6.9.29/bmtool/bmplot.py → bmtool-0.7.0/bmtool/bmplot/connections.py +18 -444
  5. bmtool-0.7.0/bmtool/bmplot/entrainment.py +51 -0
  6. bmtool-0.7.0/bmtool/bmplot/lfp.py +53 -0
  7. bmtool-0.7.0/bmtool/bmplot/netcon_reports.py +4 -0
  8. bmtool-0.7.0/bmtool/bmplot/spikes.py +259 -0
  9. bmtool-0.7.0/bmtool/util/neuron/__init__.py +0 -0
  10. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool.egg-info/PKG-INFO +1 -1
  11. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool.egg-info/SOURCES.txt +6 -1
  12. {bmtool-0.6.9.29 → bmtool-0.7.0}/setup.py +1 -1
  13. {bmtool-0.6.9.29 → bmtool-0.7.0}/LICENSE +0 -0
  14. {bmtool-0.6.9.29 → bmtool-0.7.0}/README.md +0 -0
  15. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/SLURM.py +0 -0
  16. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/__init__.py +0 -0
  17. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/__main__.py +0 -0
  18. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/analysis/__init__.py +0 -0
  19. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/analysis/netcon_reports.py +0 -0
  20. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/analysis/spikes.py +0 -0
  21. {bmtool-0.6.9.29/bmtool/debug → bmtool-0.7.0/bmtool/bmplot}/__init__.py +0 -0
  22. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/connectors.py +0 -0
  23. {bmtool-0.6.9.29/bmtool/util → bmtool-0.7.0/bmtool/debug}/__init__.py +0 -0
  24. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/debug/commands.py +0 -0
  25. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/debug/debug.py +0 -0
  26. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/graphs.py +0 -0
  27. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/manage.py +0 -0
  28. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/plot_commands.py +0 -0
  29. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/singlecell.py +0 -0
  30. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/synapses.py +0 -0
  31. {bmtool-0.6.9.29/bmtool/util/neuron → bmtool-0.7.0/bmtool/util}/__init__.py +0 -0
  32. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/util/commands.py +0 -0
  33. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/util/neuron/celltuner.py +0 -0
  34. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool/util/util.py +0 -0
  35. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool.egg-info/dependency_links.txt +0 -0
  36. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool.egg-info/entry_points.txt +0 -0
  37. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool.egg-info/requires.txt +0 -0
  38. {bmtool-0.6.9.29 → bmtool-0.7.0}/bmtool.egg-info/top_level.txt +0 -0
  39. {bmtool-0.6.9.29 → bmtool-0.7.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.6.9.29
3
+ Version: 0.7.0
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -488,53 +488,3 @@ def calculate_spike_rate_power_correlation(spike_rate, lfp, fs, pop_names, freq_
488
488
 
489
489
  return correlation_results, frequencies
490
490
 
491
-
492
- def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
493
- """
494
- Plot the correlation between population spike rates and LFP power.
495
-
496
- Parameters:
497
- -----------
498
- correlation_results : dict
499
- Dictionary with correlation results for calculate_spike_rate_power_correlation
500
- frequencies : array
501
- Array of frequencies analyzed
502
- pop_names : list
503
- List of population names
504
- """
505
- sns.set_style("whitegrid")
506
- plt.figure(figsize=(10, 6))
507
-
508
- for pop in pop_names:
509
- # Extract correlation values for each frequency
510
- corr_values = []
511
- valid_freqs = []
512
-
513
- for freq in frequencies:
514
- if freq in correlation_results[pop]:
515
- corr_values.append(correlation_results[pop][freq]['correlation'])
516
- valid_freqs.append(freq)
517
-
518
- # Plot correlation line
519
- plt.plot(valid_freqs, corr_values, marker='o', label=pop,
520
- linewidth=2, markersize=6)
521
-
522
- plt.xlabel('Frequency (Hz)', fontsize=12)
523
- plt.ylabel('Spike Rate-Power Correlation', fontsize=12)
524
- plt.title('Spike rate LFP power correlation during stimulus', fontsize=14)
525
- plt.grid(True, alpha=0.3)
526
- plt.legend(fontsize=12)
527
- plt.xticks(frequencies[::2]) # Display every other frequency on x-axis
528
-
529
- # Add horizontal line at zero for reference
530
- plt.axhline(y=0, color='gray', linestyle='-', alpha=0.5)
531
-
532
- # Set y-axis limits to make zero visible
533
- y_min, y_max = plt.ylim()
534
- plt.ylim(min(y_min, -0.1), max(y_max, 0.1))
535
-
536
- plt.tight_layout()
537
-
538
- plt.show()
539
-
540
-
@@ -10,7 +10,6 @@ from fooof.sim.gen import gen_model, gen_aperiodic
10
10
  import matplotlib.pyplot as plt
11
11
  from scipy import signal
12
12
  import pywt
13
- from bmtool.bmplot import is_notebook
14
13
  import pandas as pd
15
14
 
16
15
 
@@ -434,53 +433,3 @@ def cwt_spectrogram_xarray(x, fs, time=None, axis=-1, downsample_fs=None,
434
433
  sxx.update(dict(cone_of_influence_frequency=xr.DataArray(coif, coords={'time': t})))
435
434
  return sxx
436
435
 
437
-
438
- # will probs move to bmplot later
439
- def plot_spectrogram(sxx_xarray, remove_aperiodic=None, log_power=False,
440
- plt_range=None, clr_freq_range=None, pad=0.03, ax=None):
441
- """Plot spectrogram. Determine color limits using value in frequency band clr_freq_range"""
442
- sxx = sxx_xarray.PSD.values.copy()
443
- t = sxx_xarray.time.values.copy()
444
- f = sxx_xarray.frequency.values.copy()
445
-
446
- cbar_label = 'PSD' if remove_aperiodic is None else 'PSD Residual'
447
- if log_power:
448
- with np.errstate(divide='ignore'):
449
- sxx = np.log10(sxx)
450
- cbar_label += ' dB' if log_power == 'dB' else ' log(power)'
451
-
452
- if remove_aperiodic is not None:
453
- f1_idx = 0 if f[0] else 1
454
- ap_fit = gen_aperiodic(f[f1_idx:], remove_aperiodic.aperiodic_params)
455
- sxx[f1_idx:, :] -= (ap_fit if log_power else 10 ** ap_fit)[:, None]
456
- sxx[:f1_idx, :] = 0.
457
-
458
- if log_power == 'dB':
459
- sxx *= 10
460
-
461
- if ax is None:
462
- _, ax = plt.subplots(1, 1)
463
- plt_range = np.array(f[-1]) if plt_range is None else np.array(plt_range)
464
- if plt_range.size == 1:
465
- plt_range = [f[0 if f[0] else 1] if log_power else 0., plt_range.item()]
466
- f_idx = (f >= plt_range[0]) & (f <= plt_range[1])
467
- if clr_freq_range is None:
468
- vmin, vmax = None, None
469
- else:
470
- c_idx = (f >= clr_freq_range[0]) & (f <= clr_freq_range[1])
471
- vmin, vmax = sxx[c_idx, :].min(), sxx[c_idx, :].max()
472
-
473
- f = f[f_idx]
474
- pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading='gouraud', vmin=vmin, vmax=vmax)
475
- if 'cone_of_influence_frequency' in sxx_xarray:
476
- coif = sxx_xarray.cone_of_influence_frequency
477
- ax.plot(t, coif)
478
- ax.fill_between(t, coif, step='mid', alpha=0.2)
479
- ax.set_xlim(t[0], t[-1])
480
- #ax.set_xlim(t[0],0.2)
481
- ax.set_ylim(f[0], f[-1])
482
- plt.colorbar(mappable=pcm, ax=ax, label=cbar_label, pad=pad)
483
- ax.set_xlabel('Time (sec)')
484
- ax.set_ylabel('Frequency (Hz)')
485
- return sxx
486
-
@@ -2,18 +2,12 @@
2
2
  Want to be able to take multiple plot names in and plot them all at the same time, to save time
3
3
  https://stackoverflow.com/questions/458209/is-there-a-way-to-detach-matplotlib-plots-so-that-the-computation-can-continue
4
4
  """
5
- from .util import util
6
- import argparse,os,sys
7
-
5
+ from ..util import util
8
6
  import numpy as np
9
7
  import matplotlib
10
8
  import matplotlib.pyplot as plt
11
9
  import matplotlib.cm as cmx
12
10
  import matplotlib.colors as colors
13
- import matplotlib.gridspec as gridspec
14
- from mpl_toolkits.mplot3d import Axes3D
15
- from matplotlib.axes import Axes
16
- import seaborn as sns
17
11
  from IPython import get_ipython
18
12
  from IPython.display import display, HTML
19
13
  import statistics
@@ -21,9 +15,8 @@ import pandas as pd
21
15
  import os
22
16
  import sys
23
17
  import re
24
- from typing import Optional, Dict, Union, List
25
18
 
26
- from .util.util import CellVarsFile,load_nodes_from_config,load_templates_from_config #, missing_units
19
+ from ..util.util import CellVarsFile,load_nodes_from_config,load_templates_from_config #, missing_units
27
20
  from bmtk.analyzer.utils import listify
28
21
  from neuron import h
29
22
 
@@ -59,6 +52,7 @@ def is_notebook() -> bool:
59
52
  except NameError:
60
53
  return False # Probably standard Python interpreter
61
54
 
55
+
62
56
  def total_connection_matrix(config=None, title=None, sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False, save_file=None, synaptic_info='0', include_gap=True):
63
57
  """
64
58
  Generate a plot displaying total connections or other synaptic statistics.
@@ -122,7 +116,8 @@ def total_connection_matrix(config=None, title=None, sources=None, targets=None,
122
116
  title = "All Synapse .json Files Used"
123
117
  plot_connection_info(text,num,source_labels,target_labels,title, syn_info=synaptic_info, save_file=save_file)
124
118
  return
125
-
119
+
120
+
126
121
  def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,method = 'total',include_gap=True):
127
122
  """
128
123
  Generates a plot showing the percent connectivity of a network
@@ -159,6 +154,7 @@ def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sourc
159
154
  plot_connection_info(text,num,source_labels,target_labels,title, save_file=save_file)
160
155
  return
161
156
 
157
+
162
158
  def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None,
163
159
  no_prepend_pop=False,save_file=None, dist_X=True,dist_Y=True,dist_Z=True,bins=8,line_plot=False,verbose=False,include_gap=True):
164
160
  """
@@ -235,6 +231,7 @@ def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,s
235
231
 
236
232
  return
237
233
 
234
+
238
235
  def convergence_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=True,method='mean+std',include_gap=True,return_dict=None):
239
236
  """
240
237
  Generates connection plot displaying convergence data
@@ -253,6 +250,7 @@ def convergence_connection_matrix(config=None,title=None,sources=None, targets=N
253
250
  raise Exception("Sources or targets not defined")
254
251
  return divergence_connection_matrix(config,title ,sources, targets, sids, tids, no_prepend_pop, save_file ,convergence, method,include_gap=include_gap,return_dict=return_dict)
255
252
 
253
+
256
254
  def divergence_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=False,method='mean+std',include_gap=True,return_dict=None):
257
255
  """
258
256
  Generates connection plot displaying divergence data
@@ -309,6 +307,7 @@ def divergence_connection_matrix(config=None,title=None,sources=None, targets=No
309
307
  plot_connection_info(syn_info,data,source_labels,target_labels,title, save_file=save_file)
310
308
  return
311
309
 
310
+
312
311
  def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=None,tids=None, no_prepend_pop=False,save_file=None,method='convergence'):
313
312
  """
314
313
  Generates connection plot displaying gap junction data.
@@ -431,7 +430,8 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
431
430
  title+=' Percent Connectivity'
432
431
  plot_connection_info(syn_info,data,source_labels,target_labels,title, save_file=save_file)
433
432
  return
434
-
433
+
434
+
435
435
  def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],no_prepend_pop=True,synaptic_info='0',
436
436
  source_cell = None,target_cell = None,include_gap=True):
437
437
  """
@@ -520,6 +520,7 @@ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[]
520
520
  tids = []
521
521
  util.relation_matrix(config,nodes,edges,sources,targets,sids,tids,not no_prepend_pop,relation_func=connection_pair_histogram,synaptic_info=synaptic_info)
522
522
 
523
+
523
524
  def connection_distance(config: str,sources: str,targets: str,
524
525
  source_cell_id: int,target_id_type: str,ignore_z:bool=False) -> None:
525
526
  """
@@ -600,6 +601,7 @@ def connection_distance(config: str,sources: str,targets: str,
600
601
  plt.grid(True)
601
602
  plt.show()
602
603
 
604
+
603
605
  def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids=None,no_prepend_pop=None,edge_property = None,time = None,time_compare = None,report=None,title=None,save_file=None):
604
606
  """
605
607
  Generates a matrix of histograms showing the distribution of edge properties between different populations.
@@ -683,6 +685,7 @@ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids
683
685
  fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
684
686
  plt.draw()
685
687
 
688
+
686
689
  def distance_delay_plot(simulation_config: str,source: str,target: str,
687
690
  group_by: str,sid: str,tid: str) -> None:
688
691
  """
@@ -739,6 +742,7 @@ def distance_delay_plot(simulation_config: str,source: str,target: str,
739
742
  plt.title(f'Distance vs Delay for edge between {sid} and {tid}')
740
743
  plt.show()
741
744
 
745
+
742
746
  def plot_synapse_location_histograms(config, target_model, source=None, target=None):
743
747
  """
744
748
  generates a histogram of the positions of the synapses on a cell broken down by section
@@ -832,6 +836,7 @@ def plot_synapse_location_histograms(config, target_model, source=None, target=N
832
836
  )
833
837
  print(pivot_table)
834
838
 
839
+
835
840
  def plot_connection_info(text, num, source_labels, target_labels, title, syn_info='0', save_file=None, return_dict=None):
836
841
  """
837
842
  Function to plot connection information as a heatmap, including handling missing source and target values.
@@ -909,6 +914,7 @@ def plot_connection_info(text, num, source_labels, target_labels, title, syn_inf
909
914
  else:
910
915
  return
911
916
 
917
+
912
918
  def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_key=None, title: str = 'Percent connection matrix', pop_order=None) -> None:
913
919
  """
914
920
  Generates and plots a connection matrix based on connection probabilities from a CSV file produced by bmtool.connector.
@@ -1063,275 +1069,7 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
1063
1069
  plt.tight_layout()
1064
1070
  plt.show()
1065
1071
 
1066
- def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None, groupby:Optional[str] = 'pop_name',
1067
- ax: Optional[Axes] = None,tstart: Optional[float] = None,tstop: Optional[float] = None,
1068
- color_map: Optional[Dict[str, str]] = None) -> Axes:
1069
- """
1070
- Plots a raster plot of neural spikes, with different colors for each population.
1071
-
1072
- Parameters:
1073
- ----------
1074
- spikes_df : pd.DataFrame, optional
1075
- DataFrame containing spike data with columns 'timestamps', 'node_ids', and optional 'pop_name'.
1076
- config : str, optional
1077
- Path to the configuration file used to load node data.
1078
- network_name : str, optional
1079
- Specific network name to select from the configuration; if not provided, uses the first network.
1080
- ax : matplotlib.axes.Axes, optional
1081
- Axes on which to plot the raster; if None, a new figure and axes are created.
1082
- tstart : float, optional
1083
- Start time for filtering spikes; only spikes with timestamps greater than `tstart` will be plotted.
1084
- tstop : float, optional
1085
- Stop time for filtering spikes; only spikes with timestamps less than `tstop` will be plotted.
1086
- color_map : dict, optional
1087
- Dictionary specifying colors for each population. Keys should be population names, and values should be color values.
1088
-
1089
- Returns:
1090
- -------
1091
- matplotlib.axes.Axes
1092
- Axes with the raster plot.
1093
-
1094
- Notes:
1095
- -----
1096
- - If `config` is provided, the function merges population names from the node data with `spikes_df`.
1097
- - Each unique population from groupby in `spikes_df` will be represented by a different color if `color_map` is not specified.
1098
- - If `color_map` is provided, it should contain colors for all unique `pop_name` values in `spikes_df`.
1099
- """
1100
- # Initialize axes if none provided
1101
- if ax is None:
1102
- _, ax = plt.subplots(1, 1)
1103
-
1104
- # Filter spikes by time range if specified
1105
- if tstart is not None:
1106
- spikes_df = spikes_df[spikes_df['timestamps'] > tstart]
1107
- if tstop is not None:
1108
- spikes_df = spikes_df[spikes_df['timestamps'] < tstop]
1109
-
1110
- # Load and merge node population data if config is provided
1111
- if config:
1112
- nodes = load_nodes_from_config(config)
1113
- if network_name:
1114
- nodes = nodes.get(network_name, {})
1115
- else:
1116
- nodes = list(nodes.values())[0] if nodes else {}
1117
- print("Grabbing first network; specify a network name to ensure correct node population is selected.")
1118
-
1119
- # Find common columns, but exclude the join key from the list
1120
- common_columns = spikes_df.columns.intersection(nodes.columns).tolist()
1121
- common_columns = [col for col in common_columns if col != 'node_ids'] # Remove our join key from the common list
1122
-
1123
- # Drop all intersecting columns except the join key column from df2
1124
- spikes_df = spikes_df.drop(columns=common_columns)
1125
- # merge nodes and spikes df
1126
- spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')
1127
-
1128
1072
 
1129
- # Get unique population names
1130
- unique_pop_names = spikes_df[groupby].unique()
1131
-
1132
- # Generate colors if no color_map is provided
1133
- if color_map is None:
1134
- cmap = plt.get_cmap('tab10') # Default colormap
1135
- color_map = {pop_name: cmap(i / len(unique_pop_names)) for i, pop_name in enumerate(unique_pop_names)}
1136
- else:
1137
- # Ensure color_map contains all population names
1138
- missing_colors = [pop for pop in unique_pop_names if pop not in color_map]
1139
- if missing_colors:
1140
- raise ValueError(f"color_map is missing colors for populations: {missing_colors}")
1141
-
1142
- # Plot each population with its specified or generated color
1143
- for pop_name, group in spikes_df.groupby(groupby):
1144
- ax.scatter(group['timestamps'], group['node_ids'], label=pop_name, color=color_map[pop_name], s=0.5)
1145
-
1146
- # Label axes
1147
- ax.set_xlabel("Time")
1148
- ax.set_ylabel("Node ID")
1149
- ax.legend(title="Population", loc='upper right', framealpha=0.9, markerfirst=False)
1150
-
1151
- return ax
1152
-
1153
- # uses df from bmtool.analysis.spikes compute_firing_rate_stats
1154
- def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, List[str]], ax: Optional[Axes] = None,
1155
- color_map: Optional[Dict[str, str]] = None) -> Axes:
1156
- """
1157
- Plots a bar graph of mean firing rates with error bars (standard deviation).
1158
-
1159
- Parameters:
1160
- ----------
1161
- firing_stats : pd.DataFrame
1162
- Dataframe containing 'firing_rate_mean' and 'firing_rate_std'.
1163
- groupby : str or list of str
1164
- Column(s) used for grouping.
1165
- ax : matplotlib.axes.Axes, optional
1166
- Axes on which to plot the bar chart; if None, a new figure and axes are created.
1167
- color_map : dict, optional
1168
- Dictionary specifying colors for each group. Keys should be group names, and values should be color values.
1169
-
1170
- Returns:
1171
- -------
1172
- matplotlib.axes.Axes
1173
- Axes with the bar plot.
1174
- """
1175
- # Ensure groupby is a list for consistent handling
1176
- if isinstance(groupby, str):
1177
- groupby = [groupby]
1178
-
1179
- # Create a categorical column for grouping
1180
- firing_stats["group"] = firing_stats[groupby].astype(str).agg("_".join, axis=1)
1181
-
1182
- # Get unique group names
1183
- unique_groups = firing_stats["group"].unique()
1184
-
1185
- # Generate colors if no color_map is provided
1186
- if color_map is None:
1187
- cmap = plt.get_cmap('viridis')
1188
- color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
1189
- else:
1190
- # Ensure color_map contains all groups
1191
- missing_colors = [group for group in unique_groups if group not in color_map]
1192
- if missing_colors:
1193
- raise ValueError(f"color_map is missing colors for groups: {missing_colors}")
1194
-
1195
- # Create new figure and axes if ax is not provided
1196
- if ax is None:
1197
- fig, ax = plt.subplots(figsize=(10, 6))
1198
-
1199
- # Sort data for consistent plotting
1200
- firing_stats = firing_stats.sort_values(by="group")
1201
-
1202
- # Extract values for plotting
1203
- x_labels = firing_stats["group"]
1204
- means = firing_stats["firing_rate_mean"]
1205
- std_devs = firing_stats["firing_rate_std"]
1206
-
1207
- # Get colors for each group
1208
- colors = [color_map[group] for group in x_labels]
1209
-
1210
- # Create bar plot
1211
- bars = ax.bar(x_labels, means, yerr=std_devs, capsize=5, color=colors, edgecolor="black")
1212
-
1213
- # Add error bars manually with caps
1214
- _, caps, _ = ax.errorbar(
1215
- x=np.arange(len(x_labels)),
1216
- y=means,
1217
- yerr=std_devs,
1218
- fmt='none',
1219
- capsize=5,
1220
- capthick=2,
1221
- color="black"
1222
- )
1223
-
1224
- # Formatting
1225
- ax.set_xticks(np.arange(len(x_labels)))
1226
- ax.set_xticklabels(x_labels, rotation=45, ha="right")
1227
- ax.set_xlabel("Population Group")
1228
- ax.set_ylabel("Mean Firing Rate (spikes/s)")
1229
- ax.set_title("Firing Rate Statistics by Population")
1230
- ax.grid(axis='y', linestyle='--', alpha=0.7)
1231
-
1232
- return ax
1233
-
1234
- # uses df from bmtool.analysis.spikes compute_firing_rate_stats
1235
- def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union[str, list], ax: Optional[Axes] = None,
1236
- color_map: Optional[Dict[str, str]] = None,
1237
- plot_type: Union[str, list] = "box", swarm_alpha: float = 0.6) -> Axes:
1238
- """
1239
- Plots a distribution of individual firing rates using one or more plot types
1240
- (box plot, violin plot, or swarm plot), overlaying them on top of each other.
1241
-
1242
- Parameters:
1243
- ----------
1244
- individual_stats : pd.DataFrame
1245
- Dataframe containing individual firing rates and corresponding group labels.
1246
- groupby : str or list of str
1247
- Column(s) used for grouping.
1248
- ax : matplotlib.axes.Axes, optional
1249
- Axes on which to plot the graph; if None, a new figure and axes are created.
1250
- color_map : dict, optional
1251
- Dictionary specifying colors for each group. Keys should be group names, and values should be color values.
1252
- plot_type : str or list of str, optional
1253
- List of plot types to generate. Options: "box", "violin", "swarm". Default is "box".
1254
- swarm_alpha : float, optional
1255
- Transparency of swarm plot points. Default is 0.6.
1256
-
1257
- Returns:
1258
- -------
1259
- matplotlib.axes.Axes
1260
- Axes with the selected plot type(s) overlayed.
1261
- """
1262
- # Ensure groupby is a list for consistent handling
1263
- if isinstance(groupby, str):
1264
- groupby = [groupby]
1265
-
1266
- # Create a categorical column for grouping
1267
- individual_stats["group"] = individual_stats[groupby].astype(str).agg("_".join, axis=1)
1268
-
1269
- # Validate plot_type (it can be a list or a single type)
1270
- if isinstance(plot_type, str):
1271
- plot_type = [plot_type]
1272
-
1273
- for pt in plot_type:
1274
- if pt not in ["box", "violin", "swarm"]:
1275
- raise ValueError("plot_type must be one of: 'box', 'violin', 'swarm'.")
1276
-
1277
- # Get unique groups for coloring
1278
- unique_groups = individual_stats["group"].unique()
1279
-
1280
- # Generate colors if no color_map is provided
1281
- if color_map is None:
1282
- cmap = plt.get_cmap('viridis')
1283
- color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
1284
-
1285
- # Ensure color_map contains all groups
1286
- missing_colors = [group for group in unique_groups if group not in color_map]
1287
- if missing_colors:
1288
- raise ValueError(f"color_map is missing colors for groups: {missing_colors}")
1289
-
1290
- # Create new figure and axes if ax is not provided
1291
- if ax is None:
1292
- fig, ax = plt.subplots(figsize=(10, 6))
1293
-
1294
- # Sort data for consistent plotting
1295
- individual_stats = individual_stats.sort_values(by="group")
1296
-
1297
- # Loop over each plot type and overlay them
1298
- for pt in plot_type:
1299
- if pt == "box":
1300
- sns.boxplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, width=0.5)
1301
- elif pt == "violin":
1302
- sns.violinplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, inner="quartile", alpha=0.4)
1303
- elif pt == "swarm":
1304
- sns.swarmplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, alpha=swarm_alpha)
1305
-
1306
- # Formatting
1307
- ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
1308
- ax.set_xlabel("Population Group")
1309
- ax.set_ylabel("Firing Rate (spikes/s)")
1310
- ax.set_title("Firing Rate Distribution for individual cells")
1311
- ax.grid(axis='y', linestyle='--', alpha=0.7)
1312
-
1313
- return ax
1314
-
1315
- def plot_entrainment():
1316
- """
1317
- Plots entrainment analysis for oscillatory network activity.
1318
-
1319
- This function analyzes and visualizes how well neural populations entrain to rhythmic
1320
- input or how synchronized they become during oscillatory activity. It can show phase
1321
- locking, coherence, or other entrainment metrics.
1322
-
1323
- Note: This is currently a placeholder function and not yet implemented.
1324
-
1325
- Parameters:
1326
- -----------
1327
- None
1328
-
1329
- Returns:
1330
- --------
1331
- None
1332
- """
1333
- pass
1334
-
1335
1073
  def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file=None, subset=None):
1336
1074
  """
1337
1075
  Plots a 3D graph of all cells with x, y, z location.
@@ -1431,6 +1169,7 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
1431
1169
 
1432
1170
  return ax
1433
1171
 
1172
+
1434
1173
  def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save_file=None, quiver_length=None, arrow_length_ratio=None, group=None, subset=None):
1435
1174
  from scipy.spatial.transform import Rotation as R
1436
1175
  if not config:
@@ -1531,168 +1270,3 @@ def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save
1531
1270
  notebook = is_notebook
1532
1271
  if notebook == False:
1533
1272
  plt.show()
1534
-
1535
- def plot_network_graph(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,edge_property='model_template'):
1536
- """
1537
- Creates a directed graph visualization of the network connectivity using NetworkX.
1538
-
1539
- This function generates a network diagram showing the connections between different
1540
- cell populations, with edge labels indicating the connection types based on the specified
1541
- edge property.
1542
-
1543
- Parameters:
1544
- -----------
1545
- config : str
1546
- Path to a BMTK simulation configuration file.
1547
- nodes : dict, optional
1548
- Dictionary of node information (if already loaded).
1549
- edges : dict, optional
1550
- Dictionary of edge information (if already loaded).
1551
- title : str, optional
1552
- Custom title for the plot. If None, defaults to "Network Graph".
1553
- sources : str
1554
- Comma-separated list of source network names.
1555
- targets : str
1556
- Comma-separated list of target network names.
1557
- sids : str, optional
1558
- Comma-separated list of source node identifiers to filter by.
1559
- tids : str, optional
1560
- Comma-separated list of target node identifiers to filter by.
1561
- no_prepend_pop : bool, default=False
1562
- If True, population names are not prepended to node identifiers in the display.
1563
- save_file : str, optional
1564
- Path to save the generated plot.
1565
- edge_property : str, default='model_template'
1566
- The edge property to use for labeling connections in the graph.
1567
-
1568
- Returns:
1569
- --------
1570
- None
1571
- Displays a network graph visualization.
1572
- """
1573
- if not config:
1574
- raise Exception("config not defined")
1575
- if not sources or not targets:
1576
- raise Exception("Sources or targets not defined")
1577
- sources = sources.split(",")
1578
- targets = targets.split(",")
1579
- if sids:
1580
- sids = sids.split(",")
1581
- else:
1582
- sids = []
1583
- if tids:
1584
- tids = tids.split(",")
1585
- else:
1586
- tids = []
1587
- throw_away, data, source_labels, target_labels = util.connection_graph_edge_types(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,edge_property=edge_property)
1588
-
1589
- if title == None or title=="":
1590
- title = "Network Graph"
1591
-
1592
- import networkx as nx
1593
-
1594
- net_graph = nx.MultiDiGraph() #or G = nx.MultiDiGraph()
1595
-
1596
- edges = []
1597
- edge_labels = {}
1598
- for node in list(set(source_labels+target_labels)):
1599
- net_graph.add_node(node)
1600
-
1601
- for s, source in enumerate(source_labels):
1602
- for t, target in enumerate(target_labels):
1603
- relationship = data[s][t]
1604
- for i, relation in enumerate(relationship):
1605
- edge_labels[(source,target)]=relation
1606
- edges.append([source,target])
1607
-
1608
- net_graph.add_edges_from(edges)
1609
- #pos = nx.spring_layout(net_graph,k=0.50,iterations=20)
1610
- pos = nx.shell_layout(net_graph)
1611
- plt.figure()
1612
- nx.draw(net_graph,pos,edge_color='black', width=1,linewidths=1,\
1613
- node_size=500,node_color='white',arrowstyle='->',alpha=0.9,\
1614
- labels={node:node for node in net_graph.nodes()})
1615
-
1616
- nx.draw_networkx_edge_labels(net_graph,pos,edge_labels=edge_labels,font_color='red')
1617
- plt.show()
1618
-
1619
- return
1620
-
1621
- def plot_report(config_file=None, report_file=None, report_name=None, variables=None, gids=None):
1622
- if report_file is None:
1623
- report_name, report_file = _get_cell_report(config_file, report_name)
1624
-
1625
- var_report = CellVarsFile(report_file)
1626
- variables = listify(variables) if variables is not None else var_report.variables
1627
- gids = listify(gids) if gids is not None else var_report.gids
1628
- time_steps = var_report.time_trace
1629
-
1630
- def __units_str(var):
1631
- units = var_report.units(var)
1632
- if units == CellVarsFile.UNITS_UNKNOWN:
1633
- units = missing_units.get(var, '')
1634
- return '({})'.format(units) if units else ''
1635
-
1636
- n_plots = len(variables)
1637
- if n_plots > 1:
1638
- # If more than one variale to plot do so in different subplots
1639
- f, axarr = plt.subplots(n_plots, 1)
1640
- for i, var in enumerate(variables):
1641
- for gid in gids:
1642
- axarr[i].plot(time_steps, var_report.data(gid=gid, var_name=var), label='gid {}'.format(gid))
1643
-
1644
- axarr[i].legend()
1645
- axarr[i].set_ylabel('{} {}'.format(var, __units_str(var)))
1646
- if i < n_plots - 1:
1647
- axarr[i].set_xticklabels([])
1648
-
1649
- axarr[i].set_xlabel('time (ms)')
1650
-
1651
- elif n_plots == 1:
1652
- # For plotting a single variable
1653
- plt.figure()
1654
- for gid in gids:
1655
- plt.plot(time_steps, var_report.data(gid=gid, var_name=variables[0]), label='gid {}'.format(gid))
1656
- plt.ylabel('{} {}'.format(variables[0], __units_str(variables[0])))
1657
- plt.xlabel('time (ms)')
1658
- plt.legend()
1659
- else:
1660
- return
1661
-
1662
- plt.show()
1663
-
1664
- def plot_report_default(config, report_name, variables, gids):
1665
- """
1666
- A simplified interface for plotting cell report variables from BMTK simulations.
1667
-
1668
- This function handles the common case of plotting specific variables for specific cells
1669
- from a BMTK report file, with minimal parameter requirements.
1670
-
1671
- Parameters:
1672
- -----------
1673
- config : str
1674
- Path to a BMTK simulation configuration file.
1675
- report_name : str
1676
- Name of the report to plot (without file extension).
1677
- variables : str
1678
- Comma-separated list of variable names to plot (e.g., 'v,i_na,i_k').
1679
- gids : str
1680
- Comma-separated list of cell IDs (gids) to plot data for.
1681
-
1682
- Returns:
1683
- --------
1684
- None
1685
- Displays plots of the specified variables for the specified cells.
1686
- """
1687
-
1688
- if variables:
1689
- variables = variables.split(',')
1690
- if gids:
1691
- gids = [int(i) for i in gids.split(',')]
1692
-
1693
- if report_name:
1694
- cfg = util.load_config(config)
1695
- report_file = os.path.join(cfg['output']['output_dir'],report_name+'.h5')
1696
- plot_report(config_file=config, report_file=report_file, report_name=report_name, variables=variables, gids=gids);
1697
-
1698
- return
@@ -0,0 +1,51 @@
1
+
2
+ import matplotlib.pyplot as plt
3
+ import seaborn as sns
4
+
5
def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
    """
    Plot the spike rate vs. LFP power correlation as a function of frequency.

    One line is drawn per population. A frequency that has no entry in a
    population's results is skipped for that population only.

    Parameters:
    -----------
    correlation_results : dict
        Nested mapping ``{pop_name: {freq: {'correlation': value, ...}}}``, i.e. the
        output of calculate_spike_rate_power_correlation.
    frequencies : array
        Array of frequencies analyzed; every other one is used as an x tick.
    pop_names : list
        List of population names to plot (each must be a key of correlation_results).
    """
    sns.set_style("whitegrid")
    plt.figure(figsize=(10, 6))

    for pop in pop_names:
        # Frequencies that actually have a result for this population
        freqs_present = [freq for freq in frequencies if freq in correlation_results[pop]]
        correlations = [correlation_results[pop][freq]['correlation'] for freq in freqs_present]
        plt.plot(freqs_present, correlations, marker='o', label=pop,
                 linewidth=2, markersize=6)

    plt.xlabel('Frequency (Hz)', fontsize=12)
    plt.ylabel('Spike Rate-Power Correlation', fontsize=12)
    plt.title('Spike rate LFP power correlation during stimulus', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12)
    # Label only every other frequency to keep the x-axis readable
    plt.xticks(frequencies[::2])

    # Reference line at zero correlation
    plt.axhline(y=0, color='gray', linestyle='-', alpha=0.5)

    # Widen the y-range to at least [-0.1, 0.1] so the zero line is always visible
    lower, upper = plt.ylim()
    plt.ylim(min(lower, -0.1), max(upper, 0.1))

    plt.tight_layout()
    plt.show()
@@ -0,0 +1,53 @@
1
+ import numpy as np
2
+ from bmtool.analysis.lfp import gen_aperiodic
3
+ import matplotlib.pyplot as plt
4
+
5
+
6
def plot_spectrogram(sxx_xarray, remove_aperiodic=None, log_power=False,
                     plt_range=None, clr_freq_range=None, pad=0.03, ax=None):
    """Plot a spectrogram stored in an xarray object.

    Color limits are taken from the finite values inside the frequency band
    ``clr_freq_range`` (if given).

    Parameters
    ----------
    sxx_xarray : xarray object
        Must provide ``PSD`` (indexed as [frequency, time]), ``time`` and
        ``frequency``. May optionally provide ``cone_of_influence_frequency``,
        which is drawn as a line with a shaded area beneath it.
    remove_aperiodic : object, optional
        Object with an ``aperiodic_params`` attribute. When given, the fit
        produced by ``gen_aperiodic`` is subtracted and the colorbar is
        labelled as a residual. A 0 Hz bin, if present, is set to 0.
    log_power : bool or 'dB', optional
        False plots linear power. Any truthy value plots log10 power, and
        'dB' plots 10*log10 power.
    plt_range : float or (float, float), optional
        Frequency range to display. A scalar is treated as the upper limit.
        Defaults to the full frequency range.
    clr_freq_range : (float, float), optional
        Frequency band whose finite values set the color limits. If None, or
        if the band holds no finite values, matplotlib autoscales.
    pad : float, optional
        Padding passed to ``plt.colorbar``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new figure is created if None.

    Returns
    -------
    numpy.ndarray
        The processed spectrogram (log-scaled and/or residual) over all
        frequencies, not only the plotted range.
    """
    sxx = sxx_xarray.PSD.values.copy()
    t = sxx_xarray.time.values.copy()
    f = sxx_xarray.frequency.values.copy()

    cbar_label = 'PSD' if remove_aperiodic is None else 'PSD Residual'
    if log_power:
        # log10(0) -> -inf is tolerated here; silence the divide warning.
        with np.errstate(divide='ignore'):
            sxx = np.log10(sxx)
        cbar_label += ' dB' if log_power == 'dB' else ' log(power)'

    if remove_aperiodic is not None:
        # Evaluate the aperiodic fit only on non-zero frequencies; skip a leading 0 Hz bin.
        f1_idx = 0 if f[0] else 1
        ap_fit = gen_aperiodic(f[f1_idx:], remove_aperiodic.aperiodic_params)
        sxx[f1_idx:, :] -= (ap_fit if log_power else 10 ** ap_fit)[:, None]
        sxx[:f1_idx, :] = 0.

    if log_power == 'dB':
        sxx *= 10

    if ax is None:
        _, ax = plt.subplots(1, 1)

    plt_range = np.array(f[-1]) if plt_range is None else np.array(plt_range)
    if plt_range.size == 1:
        # Scalar -> upper bound. With log power, start at the first non-zero frequency.
        plt_range = [f[0 if f[0] else 1] if log_power else 0., plt_range.item()]
    f_idx = (f >= plt_range[0]) & (f <= plt_range[1])

    if clr_freq_range is None:
        vmin, vmax = None, None
    else:
        c_idx = (f >= clr_freq_range[0]) & (f <= clr_freq_range[1])
        band = sxx[c_idx, :]
        # Ignore -inf/nan (e.g. log10 of zero power) so they cannot collapse the color scale.
        band = band[np.isfinite(band)]
        if band.size:
            vmin, vmax = band.min(), band.max()
        else:
            vmin, vmax = None, None

    f = f[f_idx]
    pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading='gouraud', vmin=vmin, vmax=vmax)
    if 'cone_of_influence_frequency' in sxx_xarray:
        coif = sxx_xarray.cone_of_influence_frequency
        ax.plot(t, coif)
        ax.fill_between(t, coif, step='mid', alpha=0.2)
    ax.set_xlim(t[0], t[-1])
    ax.set_ylim(f[0], f[-1])
    plt.colorbar(mappable=pcm, ax=ax, label=cbar_label, pad=pad)
    ax.set_xlabel('Time (sec)')
    ax.set_ylabel('Frequency (Hz)')
    return sxx
53
+
@@ -0,0 +1,259 @@
1
+ from ..util.util import load_nodes_from_config
2
+ from typing import Optional, Dict, List, Union
3
+ import pandas as pd
4
+ from matplotlib.axes import Axes
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ import numpy as np
8
+
9
+
10
+
11
def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None, groupby: Optional[str] = 'pop_name',
           ax: Optional[Axes] = None, tstart: Optional[float] = None, tstop: Optional[float] = None,
           color_map: Optional[Dict[str, str]] = None) -> Axes:
    """
    Plots a raster plot of neural spikes, with different colors for each population.

    Parameters:
    ----------
    spikes_df : pd.DataFrame
        DataFrame containing spike data with columns 'timestamps' and 'node_ids', plus the
        `groupby` column unless `config` is given. Required despite the ``None`` default.
    config : str, optional
        Path to the configuration file used to load node data. When given, the `groupby`
        column is taken from the node table, joined on 'node_ids' against the node table's index.
    network_name : str, optional
        Specific network name to select from the configuration; if not provided, uses the first network.
    groupby : str, optional
        Column whose values define the colored groups. Default is 'pop_name'.
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot the raster; if None, a new figure and axes are created.
    tstart : float, optional
        Start time for filtering spikes; only spikes with timestamps greater than `tstart` will be plotted.
    tstop : float, optional
        Stop time for filtering spikes; only spikes with timestamps less than `tstop` will be plotted.
    color_map : dict, optional
        Dictionary specifying colors for each group. Keys should be group names, and values should be color values.

    Returns:
    -------
    matplotlib.axes.Axes
        Axes with the raster plot.

    Raises:
    ------
    ValueError
        If `spikes_df` is None, `network_name` is not in the config, the config contains no
        networks, the `groupby` column is unavailable, or `color_map` lacks a group.

    Notes:
    -----
    - If `config` is provided, spike columns that also exist in the node table (other than
      'node_ids') are dropped, and the `groupby` column is merged in from the node data.
      Spikes whose node is absent from the node table get a missing group and are not plotted.
    - Each unique group in `spikes_df` is drawn in a different color if `color_map` is not specified.
    """
    if spikes_df is None:
        raise ValueError("spikes_df must be provided.")

    # Initialize axes if none provided
    if ax is None:
        _, ax = plt.subplots(1, 1)

    # Filter spikes by time range if specified (bounds are exclusive)
    if tstart is not None:
        spikes_df = spikes_df[spikes_df['timestamps'] > tstart]
    if tstop is not None:
        spikes_df = spikes_df[spikes_df['timestamps'] < tstop]

    # Load and merge node population data if config is provided
    if config:
        networks = load_nodes_from_config(config)
        if network_name:
            if network_name not in networks:
                raise ValueError(
                    f"Network '{network_name}' not found in config; available networks: {list(networks)}"
                )
            nodes = networks[network_name]
        else:
            if not networks:
                raise ValueError("No node networks found in config.")
            nodes = next(iter(networks.values()))
            print("Grabbing first network; specify a network name to ensure correct node population is selected.")

        if groupby not in nodes.columns:
            raise ValueError(f"Column '{groupby}' not found in node data.")

        # Drop spike columns that clash with node columns, keeping the join key
        overlapping = [col for col in spikes_df.columns.intersection(nodes.columns) if col != 'node_ids']
        spikes_df = spikes_df.drop(columns=overlapping)
        spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')

    if groupby not in spikes_df.columns:
        raise ValueError(f"Column '{groupby}' not found in spikes_df; provide `config` or add the column.")

    # Get unique population names
    unique_pop_names = spikes_df[groupby].unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap('tab10')  # Default colormap
        color_map = {pop_name: cmap(i / len(unique_pop_names)) for i, pop_name in enumerate(unique_pop_names)}
    else:
        # Ensure color_map contains all population names
        missing_colors = [pop for pop in unique_pop_names if pop not in color_map]
        if missing_colors:
            raise ValueError(f"color_map is missing colors for populations: {missing_colors}")

    # Plot each population with its specified or generated color
    for pop_name, group in spikes_df.groupby(groupby):
        ax.scatter(group['timestamps'], group['node_ids'], label=pop_name, color=color_map[pop_name], s=0.5)

    # Label axes
    ax.set_xlabel("Time")
    ax.set_ylabel("Node ID")
    ax.legend(title="Population", loc='upper right', framealpha=0.9, markerfirst=False)

    return ax
97
+
98
+ # uses df from bmtool.analysis.spikes compute_firing_rate_stats
99
def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, List[str]], ax: Optional[Axes] = None,
                               color_map: Optional[Dict[str, str]] = None) -> Axes:
    """
    Plots a bar graph of mean firing rates with error bars (standard deviation).

    Parameters:
    ----------
    firing_stats : pd.DataFrame
        Dataframe containing 'firing_rate_mean', 'firing_rate_std' and the `groupby` column(s),
        e.g. the DataFrame produced by bmtool.analysis.spikes.compute_firing_rate_stats.
        It is not modified.
    groupby : str or list of str
        Column(s) used for grouping. Values of multiple columns are joined with '_' to form
        the bar labels.
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot the bar chart; if None, a new figure and axes are created.
    color_map : dict, optional
        Dictionary specifying colors for each group. Keys should be the bar labels (group
        values joined with '_'), and values should be color values.

    Returns:
    -------
    matplotlib.axes.Axes
        Axes with the bar plot.

    Raises:
    ------
    ValueError
        If `color_map` is given but lacks a color for some group.
    """
    # Ensure groupby is a list for consistent handling
    if isinstance(groupby, str):
        groupby = [groupby]

    # Work on a copy so the caller's DataFrame does not gain a 'group' column
    stats = firing_stats.copy()
    stats["group"] = stats[groupby].astype(str).agg("_".join, axis=1)

    unique_groups = stats["group"].unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap('viridis')
        color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
    else:
        missing_colors = [group for group in unique_groups if group not in color_map]
        if missing_colors:
            raise ValueError(f"color_map is missing colors for groups: {missing_colors}")

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Sort data for consistent plotting
    stats = stats.sort_values(by="group")

    x_labels = stats["group"].tolist()
    means = stats["firing_rate_mean"].to_numpy()
    std_devs = stats["firing_rate_std"].to_numpy()
    colors = [color_map[group] for group in x_labels]

    # Explicit numeric positions keep bars, error bars and tick labels aligned
    positions = np.arange(len(x_labels))

    ax.bar(positions, means, color=colors, edgecolor="black")
    # Error bars are drawn only here (ax.bar gets no yerr) so they are not duplicated
    ax.errorbar(
        x=positions,
        y=means,
        yerr=std_devs,
        fmt='none',
        capsize=5,
        capthick=2,
        color="black"
    )

    # Formatting
    ax.set_xticks(positions)
    ax.set_xticklabels(x_labels, rotation=45, ha="right")
    ax.set_xlabel("Population Group")
    ax.set_ylabel("Mean Firing Rate (spikes/s)")
    ax.set_title("Firing Rate Statistics by Population")
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    return ax
178
+
179
+ # uses df from bmtool.analysis.spikes compute_firing_rate_stats
180
def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union[str, list], ax: Optional[Axes] = None,
                                  color_map: Optional[Dict[str, str]] = None,
                                  plot_type: Union[str, list] = "box", swarm_alpha: float = 0.6) -> Axes:
    """
    Plots a distribution of individual firing rates using one or more plot types
    (box plot, violin plot, or swarm plot), overlaying them on top of each other.

    Parameters:
    ----------
    individual_stats : pd.DataFrame
        Dataframe containing a 'firing_rate' column and the `groupby` column(s), one row per
        cell. It is not modified.
    groupby : str or list of str
        Column(s) used for grouping. Values of multiple columns are joined with '_' to form
        the group labels.
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot the graph; if None, a new figure and axes are created.
    color_map : dict, optional
        Dictionary specifying colors for each group. Keys should be the group labels, and values should be color values.
    plot_type : str or list of str, optional
        Plot type(s) to draw, in order. Options: "box", "violin", "swarm". Default is "box".
    swarm_alpha : float, optional
        Transparency of swarm plot points. Default is 0.6.

    Returns:
    -------
    matplotlib.axes.Axes
        Axes with the selected plot type(s) overlayed.

    Raises:
    ------
    ValueError
        If `plot_type` contains an unknown type or `color_map` lacks a color for some group.
    """
    # Ensure groupby is a list for consistent handling
    if isinstance(groupby, str):
        groupby = [groupby]

    # Normalize and validate plot_type before doing any work
    if isinstance(plot_type, str):
        plot_type = [plot_type]
    for pt in plot_type:
        if pt not in ("box", "violin", "swarm"):
            raise ValueError("plot_type must be one of: 'box', 'violin', 'swarm'.")

    # Work on a copy so the caller's DataFrame does not gain a 'group' column
    stats = individual_stats.copy()
    stats["group"] = stats[groupby].astype(str).agg("_".join, axis=1)

    unique_groups = stats["group"].unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap('viridis')
        color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}

    missing_colors = [group for group in unique_groups if group not in color_map]
    if missing_colors:
        raise ValueError(f"color_map is missing colors for groups: {missing_colors}")

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Sort data for consistent group order on the x-axis
    stats = stats.sort_values(by="group")

    # Overlay each requested plot type on the same axes
    for pt in plot_type:
        if pt == "box":
            sns.boxplot(data=stats, x="group", y="firing_rate", ax=ax, palette=color_map, width=0.5)
        elif pt == "violin":
            sns.violinplot(data=stats, x="group", y="firing_rate", ax=ax, palette=color_map, inner="quartile", alpha=0.4)
        elif pt == "swarm":
            sns.swarmplot(data=stats, x="group", y="firing_rate", ax=ax, palette=color_map, alpha=swarm_alpha)

    # Rotate existing tick labels in place; re-setting them via set_xticklabels
    # warns when the axis does not use a FixedLocator.
    for label in ax.get_xticklabels():
        label.set_rotation(45)
        label.set_ha("right")
    ax.set_xlabel("Population Group")
    ax.set_ylabel("Firing Rate (spikes/s)")
    ax.set_title("Firing Rate Distribution for individual cells")
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    return ax
259
+
File without changes
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.6.9.29
3
+ Version: 0.7.0
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -4,7 +4,6 @@ setup.py
4
4
  bmtool/SLURM.py
5
5
  bmtool/__init__.py
6
6
  bmtool/__main__.py
7
- bmtool/bmplot.py
8
7
  bmtool/connectors.py
9
8
  bmtool/graphs.py
10
9
  bmtool/manage.py
@@ -22,6 +21,12 @@ bmtool/analysis/entrainment.py
22
21
  bmtool/analysis/lfp.py
23
22
  bmtool/analysis/netcon_reports.py
24
23
  bmtool/analysis/spikes.py
24
+ bmtool/bmplot/__init__.py
25
+ bmtool/bmplot/connections.py
26
+ bmtool/bmplot/entrainment.py
27
+ bmtool/bmplot/lfp.py
28
+ bmtool/bmplot/netcon_reports.py
29
+ bmtool/bmplot/spikes.py
25
30
  bmtool/debug/__init__.py
26
31
  bmtool/debug/commands.py
27
32
  bmtool/debug/debug.py
@@ -6,7 +6,7 @@ with open("README.md", "r") as fh:
6
6
 
7
7
  setup(
8
8
  name="bmtool",
9
- version='0.6.9.29',
9
+ version='0.7.0',
10
10
  author="Neural Engineering Laboratory at the University of Missouri",
11
11
  author_email="gregglickert@mail.missouri.edu",
12
12
  description="BMTool",
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes