bmtool 0.7.7__py3-none-any.whl → 0.7.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bmtool might be problematic; refer to the package's registry page for more details.

bmtool/bmplot/spikes.py CHANGED
@@ -1,6 +1,6 @@
1
1
  """Plotting functions for neural spikes and firing rates."""
2
2
 
3
- from typing import Dict, List, Optional, Union
3
+ from typing import Dict, List, Optional, Tuple, Union
4
4
 
5
5
  import matplotlib.pyplot as plt
6
6
  import numpy as np
@@ -15,12 +15,13 @@ def raster(
15
15
  spikes_df: Optional[pd.DataFrame] = None,
16
16
  config: Optional[str] = None,
17
17
  network_name: Optional[str] = None,
18
- groupby: Optional[str] = "pop_name",
18
+ groupby: str = "pop_name",
19
+ sortby: Optional[str] = None,
19
20
  ax: Optional[Axes] = None,
20
21
  tstart: Optional[float] = None,
21
22
  tstop: Optional[float] = None,
22
23
  color_map: Optional[Dict[str, str]] = None,
23
- dot_size: Optional[float] = 0.3,
24
+ dot_size: float = 0.3,
24
25
  ) -> Axes:
25
26
  """
26
27
  Plots a raster plot of neural spikes, with different colors for each population.
@@ -33,6 +34,10 @@ def raster(
33
34
  Path to the configuration file used to load node data.
34
35
  network_name : str, optional
35
36
  Specific network name to select from the configuration; if not provided, uses the first network.
37
+ groupby : str, optional
38
+ Column name to group spikes by for coloring. Default is 'pop_name'.
39
+ sortby : str, optional
40
+ Column name to sort node_ids within each group. If provided, nodes within each population will be sorted by this column.
36
41
  ax : matplotlib.axes.Axes, optional
37
42
  Axes on which to plot the raster; if None, a new figure and axes are created.
38
43
  tstart : float, optional
@@ -107,11 +112,32 @@ def raster(
107
112
 
108
113
  # Plot each population with its specified or generated color
109
114
  legend_handles = []
115
+ y_offset = 0 # Track y-position offset for stacking populations
116
+
110
117
  for pop_name, group in spikes_df.groupby(groupby):
111
- ax.scatter(group["timestamps"], group["node_ids"], color=color_map[pop_name], s=dot_size)
118
+ if sortby:
119
+ # Sort by the specified column, putting NaN values at the end
120
+ group_sorted = group.sort_values(by=sortby, na_position='last')
121
+ # Create a mapping from node_ids to consecutive y-positions based on sorted order
122
+ # Use the sorted order to maintain the same sequence for all spikes from same node
123
+ unique_nodes_sorted = group_sorted['node_ids'].drop_duplicates()
124
+ node_to_y = {node_id: y_offset + i for i, node_id in enumerate(unique_nodes_sorted)}
125
+ # Map node_ids to new y-positions for ALL spikes (not just the sorted group)
126
+ y_positions = group['node_ids'].map(node_to_y)
127
+ # Verify no data was lost
128
+ assert len(y_positions) == len(group), f"Data loss detected in population {pop_name}"
129
+ assert y_positions.isna().sum() == 0, f"Unmapped node_ids found in population {pop_name}"
130
+ else:
131
+ y_positions = group['node_ids']
132
+
133
+ ax.scatter(group["timestamps"], y_positions, color=color_map[pop_name], s=dot_size)
112
134
  # Dummy scatter for consistent legend appearance
113
135
  handle = ax.scatter([], [], color=color_map[pop_name], label=pop_name, s=20)
114
136
  legend_handles.append(handle)
137
+
138
+ # Update y_offset for next population if sortby is used
139
+ if sortby:
140
+ y_offset += len(unique_nodes_sorted)
115
141
 
116
142
  # Label axes
117
143
  ax.set_xlabel("Time")
@@ -211,11 +237,12 @@ def plot_firing_rate_pop_stats(
211
237
  # uses df from bmtool.analysis.spikes compute_firing_rate_stats
212
238
  def plot_firing_rate_distribution(
213
239
  individual_stats: pd.DataFrame,
214
- groupby: Union[str, list],
240
+ groupby: Union[str, List[str]],
215
241
  ax: Optional[Axes] = None,
216
242
  color_map: Optional[Dict[str, str]] = None,
217
- plot_type: Union[str, list] = "box",
243
+ plot_type: Union[str, List[str]] = "box",
218
244
  swarm_alpha: float = 0.6,
245
+ logscale: bool = False,
219
246
  ) -> Axes:
220
247
  """
221
248
  Plots a distribution of individual firing rates using one or more plot types
@@ -235,6 +262,8 @@ def plot_firing_rate_distribution(
235
262
  List of plot types to generate. Options: "box", "violin", "swarm". Default is "box".
236
263
  swarm_alpha : float, optional
237
264
  Transparency of swarm plot points. Default is 0.6.
265
+ logscale : bool, optional
266
+ If True, use logarithmic scale for the y-axis (default is False).
238
267
 
239
268
  Returns:
240
269
  -------
@@ -316,40 +345,46 @@ def plot_firing_rate_distribution(
316
345
  ax.set_title("Firing Rate Distribution for individual cells")
317
346
  ax.grid(axis="y", linestyle="--", alpha=0.7)
318
347
 
348
+ if logscale:
349
+ ax.set_yscale('log')
350
+
319
351
  return ax
320
352
 
321
353
 
322
354
  def plot_firing_rate_vs_node_attribute(
323
- individual_stats: Optional[pd.DataFrame] = None,
355
+ individual_stats: pd.DataFrame,
356
+ groupby: str,
357
+ attribute: str,
324
358
  config: Optional[str] = None,
325
359
  nodes: Optional[pd.DataFrame] = None,
326
- groupby: Optional[str] = None,
327
360
  network_name: Optional[str] = None,
328
- attribute: Optional[str] = None,
329
- figsize=(12, 8),
361
+ figsize: Tuple[float, float] = (12, 8),
330
362
  dot_size: float = 3,
363
+ color_map: Optional[Dict[str, str]] = None,
331
364
  ) -> plt.Figure:
332
365
  """
333
366
  Plot firing rate vs node attribute for each group in separate subplots.
334
367
 
335
368
  Parameters
336
369
  ----------
337
- individual_stats : pd.DataFrame, optional
370
+ individual_stats : pd.DataFrame
338
371
  DataFrame containing individual cell firing rates from compute_firing_rate_stats
372
+ groupby : str
373
+ Column name in individual_stats to group plots by
374
+ attribute : str
375
+ Node attribute column name to plot against firing rate
339
376
  config : str, optional
340
377
  Path to configuration file for loading node data
341
378
  nodes : pd.DataFrame, optional
342
379
  Pre-loaded node data as alternative to loading from config
343
- groupby : str, optional
344
- Column name in individual_stats to group plots by
345
380
  network_name : str, optional
346
381
  Name of network to load from config file
347
- attribute : str, optional
348
- Node attribute column name to plot against firing rate
349
- figsize : tuple[int, int], optional
382
+ figsize : Tuple[float, float], optional
350
383
  Figure dimensions (width, height) in inches
351
384
  dot_size : float, optional
352
385
  Size of scatter plot points
386
+ color_map : dict, optional
387
+ Dictionary specifying colors for each group. Keys should be group names, and values should be color values.
353
388
 
354
389
  Returns
355
390
  -------
@@ -407,12 +442,26 @@ def plot_firing_rate_vs_node_attribute(
407
442
  axes = np.array([axes])
408
443
  axes = axes.flatten()
409
444
 
445
+ # Generate colors if no color_map is provided
446
+ if color_map is None:
447
+ cmap = plt.get_cmap("tab10")
448
+ color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
449
+ else:
450
+ # Ensure color_map contains all groups
451
+ missing_colors = [group for group in unique_groups if group not in color_map]
452
+ if missing_colors:
453
+ raise ValueError(f"color_map is missing colors for groups: {missing_colors}")
454
+
410
455
  # Plot each group
411
456
  for i, group in enumerate(unique_groups):
412
457
  group_df = merged_df[merged_df[groupby] == group]
413
- axes[i].scatter(group_df["firing_rate"], group_df[attribute], s=dot_size)
458
+ axes[i].scatter(group_df["firing_rate"], group_df[attribute], s=dot_size, color=color_map[group])
414
459
  axes[i].set_xlabel("Firing Rate (Hz)")
415
460
  axes[i].set_ylabel(attribute)
461
+
462
+ # Calculate and display mean firing rate in legend
463
+ mean_fr = group_df["firing_rate"].mean()
464
+ axes[i].legend([f"Mean FR: {mean_fr:.2f} Hz"], loc="upper right")
416
465
  axes[i].set_title(f"{groupby}: {group}")
417
466
 
418
467
  # Hide unused subplots
@@ -420,4 +469,112 @@ def plot_firing_rate_vs_node_attribute(
420
469
  axes[j].set_visible(False)
421
470
 
422
471
  plt.tight_layout()
423
- plt.show()
472
+ return fig
473
+
474
+
475
def plot_firing_rate_histogram(
    individual_stats: pd.DataFrame,
    groupby: str = "pop_name",
    ax: Optional[Axes] = None,
    color_map: Optional[Dict[str, str]] = None,
    bins: int = 30,
    alpha: float = 0.7,
    figsize: Tuple[float, float] = (12, 8),
    stacked: bool = False,
    logscale: bool = False,
    min_fr: Optional[float] = None,
) -> plt.Figure:
    """
    Plot histograms of firing rates for each population group on a single axes.

    Parameters:
    ----------
    individual_stats : pd.DataFrame
        DataFrame containing individual firing rates (column "firing_rate") with group labels.
    groupby : str, optional
        Column name to group by (default is "pop_name").
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot; if None, a new figure and axes are created.
    color_map : dict, optional
        Dictionary specifying colors for each group. Keys should be group names, and values
        should be color values. If None, colors are drawn from the "tab10" colormap.
    bins : int, optional
        Number of bins for the histogram (default is 30).
    alpha : float, optional
        Transparency level for the overlaid (non-stacked) histograms (default is 0.7).
    figsize : Tuple[float, float], optional
        Figure size, used only when a new figure is created (default is (12, 8)).
    stacked : bool, optional
        If True, stack the group histograms on top of each other; otherwise overlay them
        (default is False).
    logscale : bool, optional
        If True, use logarithmic scale and log-spaced bins for the x-axis (default is False).
    min_fr : float, optional
        With logscale, firing rates below this value are clamped to it and the corresponding
        tick is labelled '0' (default is None).

    Returns:
    -------
    matplotlib.figure.Figure
        Figure containing the histogram axes.

    Raises:
    ------
    ValueError
        If color_map lacks a color for some group, or if there are no firing rates to bin
        (with logscale: no strictly positive firing rates).
    """
    sns.set_style("whitegrid")

    # Get unique groups
    unique_groups = individual_stats[groupby].unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap("tab10")
        color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
    else:
        # Ensure color_map contains all groups
        missing_colors = [group for group in unique_groups if group not in color_map]
        if missing_colors:
            raise ValueError(f"color_map is missing colors for groups: {missing_colors}")

    # Group firing rates by population
    pop_fr = {
        group: individual_stats[individual_stats[groupby] == group]["firing_rate"].values
        for group in unique_groups
    }
    if not pop_fr:
        raise ValueError("individual_stats contains no firing rates to plot")

    # Clamp small/zero rates so they remain representable on a log axis
    if logscale and min_fr is not None:
        pop_fr = {p: np.fmax(fr, min_fr) for p, fr in pop_fr.items()}

    # Shared bin edges across all groups so the histograms are directly comparable
    all_fr = np.concatenate(list(pop_fr.values()))
    if logscale:
        all_fr = all_fr[all_fr > 0]  # geomspace requires strictly positive endpoints
    if all_fr.size == 0:
        if logscale:
            raise ValueError(
                "No positive firing rates to bin on a log scale; pass min_fr to clamp zero rates"
            )
        raise ValueError("individual_stats contains no firing rates to plot")
    edge_fn = np.geomspace if logscale else np.linspace
    bins_array = edge_fn(all_fr.min(), all_fr.max(), bins + 1)

    # Honor a caller-supplied axes; only create a figure when none was given
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    if stacked:
        ax.hist(
            list(pop_fr.values()),
            bins=bins_array,
            label=list(pop_fr.keys()),
            color=[color_map[p] for p in pop_fr],
            stacked=True,
        )
    else:
        for p, fr_vals in pop_fr.items():
            ax.hist(fr_vals, bins=bins_array, label=p, color=color_map[p], alpha=alpha)

    if logscale:
        ax.set_xscale('log')
        # Redraw the figure that owns `ax` (not pyplot's current figure) before reading ticks
        fig.canvas.draw_idle()
        xt = ax.get_xticks()
        xtl = [f'{x:g}' for x in xt]
        if min_fr is not None:
            # Rates clamped to min_fr stand in for zero, so label that tick '0'
            xt = np.append(xt, min_fr)
            xtl.append('0')
        ax.set_xticks(xt)
        ax.set_xticklabels(xtl)

    ax.set_xlim(bins_array[0], bins_array[-1])
    ax.legend(loc='upper right')
    ax.set_title('Firing Rate Histogram')
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Count')
    return fig
579
+
580
+
bmtool/singlecell.py CHANGED
@@ -1042,6 +1042,7 @@ class Profiler:
1042
1042
  self.mechanism_dir = None
1043
1043
  self.templates = None # Initialize templates attribute
1044
1044
  self.config = config # Store config path
1045
+ self.last_figure = None # Store reference to last generated figure
1045
1046
 
1046
1047
  # If a BMTK config is provided, load mechanisms/templates from it
1047
1048
  if config is not None:
@@ -1202,6 +1203,7 @@ class Profiler:
1202
1203
  plt.title("Passive Cell Current Injection")
1203
1204
  plt.xlabel("Time (ms)")
1204
1205
  plt.ylabel("Membrane Potential (mV)")
1206
+ self.last_figure = plt.gcf()
1205
1207
  plt.show()
1206
1208
 
1207
1209
  return time, amp
@@ -1230,6 +1232,8 @@ class Profiler:
1230
1232
  plt.title("Current Injection")
1231
1233
  plt.xlabel("Time (ms)")
1232
1234
  plt.ylabel("Membrane Potential (mV)")
1235
+ plt.xlim(ccl.inj_delay - 10, ccl.inj_delay + ccl.inj_dur + 10)
1236
+ self.last_figure = plt.gcf()
1233
1237
  plt.show()
1234
1238
 
1235
1239
  return time, amp
@@ -1278,6 +1282,7 @@ class Profiler:
1278
1282
  plt.title("FI Curve")
1279
1283
  plt.xlabel("Injection (pA)")
1280
1284
  plt.ylabel("# Spikes")
1285
+ self.last_figure = plt.gcf()
1281
1286
  plt.show()
1282
1287
 
1283
1288
  return amp, nspk
@@ -1317,18 +1322,22 @@ class Profiler:
1317
1322
  plt.title("ZAP Response")
1318
1323
  plt.xlabel("Time (ms)")
1319
1324
  plt.ylabel("Membrane Potential (mV)")
1325
+ self.last_figure = plt.gcf()
1320
1326
 
1321
1327
  plt.figure()
1322
1328
  plt.plot(time, zap.zap_vec)
1323
1329
  plt.title("ZAP Current")
1324
1330
  plt.xlabel("Time (ms)")
1325
1331
  plt.ylabel("Current Injection (nA)")
1332
+ # Note: This will overwrite last_figure with the current plot
1333
+ self.last_figure = plt.gcf()
1326
1334
 
1327
1335
  plt.figure()
1328
1336
  plt.plot(*zap.get_impedance(smooth=smooth))
1329
1337
  plt.title("Impedance Amplitude Profile")
1330
1338
  plt.xlabel("Frequency (Hz)")
1331
1339
  plt.ylabel("Impedance (MOhms)")
1340
+ self.last_figure = plt.gcf()
1332
1341
  plt.show()
1333
1342
 
1334
1343
  return time, amp
@@ -1461,6 +1470,15 @@ class Profiler:
1461
1470
  layout=widgets.Layout(width='150px')
1462
1471
  )
1463
1472
 
1473
+ save_path_text = widgets.Text(value='', description='Save Path:', placeholder='e.g., plot.png', style={'description_width': 'initial'}, layout=widgets.Layout(width='300px'))
1474
+
1475
+ save_button = widgets.Button(
1476
+ description='Save Plot',
1477
+ button_style='success',
1478
+ icon='save',
1479
+ layout=widgets.Layout(width='120px')
1480
+ )
1481
+
1464
1482
  output_area = widgets.Output(
1465
1483
  layout=widgets.Layout(border='1px solid #ccc', padding='10px', margin='10px 0 0 0')
1466
1484
  )
@@ -1475,7 +1493,8 @@ class Profiler:
1475
1493
  # Button row - main controls
1476
1494
  button_row = widgets.HBox([
1477
1495
  run_button,
1478
- reset_button
1496
+ reset_button,
1497
+ save_button
1479
1498
  ], layout=widgets.Layout(margin='0 0 10px 0'))
1480
1499
 
1481
1500
  # Section row - recording and injection sections
@@ -1485,6 +1504,11 @@ class Profiler:
1485
1504
  post_init_text
1486
1505
  ], layout=widgets.Layout(margin='0 0 10px 0'))
1487
1506
 
1507
+ # Save row
1508
+ save_row = widgets.HBox([
1509
+ save_path_text
1510
+ ], layout=widgets.Layout(margin='0 0 10px 0'))
1511
+
1488
1512
  # Parameter columns - organized in columns like synapse tuner
1489
1513
  injection_params_col1 = widgets.VBox([
1490
1514
  inj_amp_slider,
@@ -1681,15 +1705,36 @@ class Profiler:
1681
1705
  update_slider_values(method)
1682
1706
  print(f"Reset all parameters to defaults for {method}")
1683
1707
 
1708
+ # Save function
1709
def save_plot(b):
    """Save-button callback: write the last generated figure to the entered path.

    Reads the destination from ``save_path_text`` and saves ``self.last_figure``
    there. Every outcome (missing path, no figure yet, success, failure) is
    reported by printing into ``output_area``. The ``b`` argument is unused.
    """
    def report(message):
        # All user feedback goes to the shared output widget
        with output_area:
            print(message)

    target = save_path_text.value
    if not target:
        report("Please enter a save path")
    elif self.last_figure is None:
        report("No plot to save. Run an analysis first.")
    else:
        try:
            self.last_figure.savefig(target)
            report(f"Plot saved to {target}")
        except Exception as e:
            # UI boundary: surface any failure to the user instead of raising
            report(f"Error saving plot: {e}")
1726
+
1684
1727
  run_button.on_click(run_analysis)
1685
1728
  reset_button.on_click(reset_to_defaults)
1729
+ save_button.on_click(save_plot)
1686
1730
 
1687
1731
  # Create main UI layout - matching synapse tuner structure
1688
1732
  ui = widgets.VBox([
1689
1733
  selection_row,
1690
1734
  button_row,
1691
1735
  section_row,
1692
- param_columns
1736
+ param_columns,
1737
+ save_row
1693
1738
  ], layout=widgets.Layout(padding='10px'))
1694
1739
 
1695
1740
  # Display the interface - UI on top, output below (like synapse tuner)