bmtool 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bmtool might be problematic. Click here for more details.

bmtool/bmplot/spikes.py CHANGED
@@ -1,6 +1,6 @@
1
1
  """Plotting functions for neural spikes and firing rates."""
2
2
 
3
- from typing import Dict, List, Optional, Union
3
+ from typing import Dict, List, Optional, Tuple, Union
4
4
 
5
5
  import matplotlib.pyplot as plt
6
6
  import numpy as np
@@ -15,12 +15,13 @@ def raster(
15
15
  spikes_df: Optional[pd.DataFrame] = None,
16
16
  config: Optional[str] = None,
17
17
  network_name: Optional[str] = None,
18
- groupby: Optional[str] = "pop_name",
18
+ groupby: str = "pop_name",
19
+ sortby: Optional[str] = None,
19
20
  ax: Optional[Axes] = None,
20
21
  tstart: Optional[float] = None,
21
22
  tstop: Optional[float] = None,
22
23
  color_map: Optional[Dict[str, str]] = None,
23
- dot_size: Optional[float] = 0.3,
24
+ dot_size: float = 0.3,
24
25
  ) -> Axes:
25
26
  """
26
27
  Plots a raster plot of neural spikes, with different colors for each population.
@@ -33,6 +34,10 @@ def raster(
33
34
  Path to the configuration file used to load node data.
34
35
  network_name : str, optional
35
36
  Specific network name to select from the configuration; if not provided, uses the first network.
37
+ groupby : str, optional
38
+ Column name to group spikes by for coloring. Default is 'pop_name'.
39
+ sortby : str, optional
40
+ Column name to sort node_ids within each group. If provided, nodes within each population will be sorted by this column.
36
41
  ax : matplotlib.axes.Axes, optional
37
42
  Axes on which to plot the raster; if None, a new figure and axes are created.
38
43
  tstart : float, optional
@@ -107,11 +112,32 @@ def raster(
107
112
 
108
113
  # Plot each population with its specified or generated color
109
114
  legend_handles = []
115
+ y_offset = 0 # Track y-position offset for stacking populations
116
+
110
117
  for pop_name, group in spikes_df.groupby(groupby):
111
- ax.scatter(group["timestamps"], group["node_ids"], color=color_map[pop_name], s=dot_size)
118
+ if sortby:
119
+ # Sort by the specified column, putting NaN values at the end
120
+ group_sorted = group.sort_values(by=sortby, na_position='last')
121
+ # Create a mapping from node_ids to consecutive y-positions based on sorted order
122
+ # Use the sorted order to maintain the same sequence for all spikes from same node
123
+ unique_nodes_sorted = group_sorted['node_ids'].drop_duplicates()
124
+ node_to_y = {node_id: y_offset + i for i, node_id in enumerate(unique_nodes_sorted)}
125
+ # Map node_ids to new y-positions for ALL spikes (not just the sorted group)
126
+ y_positions = group['node_ids'].map(node_to_y)
127
+ # Verify no data was lost
128
+ assert len(y_positions) == len(group), f"Data loss detected in population {pop_name}"
129
+ assert y_positions.isna().sum() == 0, f"Unmapped node_ids found in population {pop_name}"
130
+ else:
131
+ y_positions = group['node_ids']
132
+
133
+ ax.scatter(group["timestamps"], y_positions, color=color_map[pop_name], s=dot_size)
112
134
  # Dummy scatter for consistent legend appearance
113
135
  handle = ax.scatter([], [], color=color_map[pop_name], label=pop_name, s=20)
114
136
  legend_handles.append(handle)
137
+
138
+ # Update y_offset for next population if sortby is used
139
+ if sortby:
140
+ y_offset += len(unique_nodes_sorted)
115
141
 
116
142
  # Label axes
117
143
  ax.set_xlabel("Time")
@@ -211,11 +237,12 @@ def plot_firing_rate_pop_stats(
211
237
  # uses df from bmtool.analysis.spikes compute_firing_rate_stats
212
238
  def plot_firing_rate_distribution(
213
239
  individual_stats: pd.DataFrame,
214
- groupby: Union[str, list],
240
+ groupby: Union[str, List[str]],
215
241
  ax: Optional[Axes] = None,
216
242
  color_map: Optional[Dict[str, str]] = None,
217
- plot_type: Union[str, list] = "box",
243
+ plot_type: Union[str, List[str]] = "box",
218
244
  swarm_alpha: float = 0.6,
245
+ logscale: bool = False,
219
246
  ) -> Axes:
220
247
  """
221
248
  Plots a distribution of individual firing rates using one or more plot types
@@ -235,6 +262,8 @@ def plot_firing_rate_distribution(
235
262
  List of plot types to generate. Options: "box", "violin", "swarm". Default is "box".
236
263
  swarm_alpha : float, optional
237
264
  Transparency of swarm plot points. Default is 0.6.
265
+ logscale : bool, optional
266
+ If True, use logarithmic scale for the y-axis (default is False).
238
267
 
239
268
  Returns:
240
269
  -------
@@ -316,40 +345,46 @@ def plot_firing_rate_distribution(
316
345
  ax.set_title("Firing Rate Distribution for individual cells")
317
346
  ax.grid(axis="y", linestyle="--", alpha=0.7)
318
347
 
348
+ if logscale:
349
+ ax.set_yscale('log')
350
+
319
351
  return ax
320
352
 
321
353
 
322
354
  def plot_firing_rate_vs_node_attribute(
323
- individual_stats: Optional[pd.DataFrame] = None,
355
+ individual_stats: pd.DataFrame,
356
+ groupby: str,
357
+ attribute: str,
324
358
  config: Optional[str] = None,
325
359
  nodes: Optional[pd.DataFrame] = None,
326
- groupby: Optional[str] = None,
327
360
  network_name: Optional[str] = None,
328
- attribute: Optional[str] = None,
329
- figsize=(12, 8),
361
+ figsize: Tuple[float, float] = (12, 8),
330
362
  dot_size: float = 3,
363
+ color_map: Optional[Dict[str, str]] = None,
331
364
  ) -> plt.Figure:
332
365
  """
333
366
  Plot firing rate vs node attribute for each group in separate subplots.
334
367
 
335
368
  Parameters
336
369
  ----------
337
- individual_stats : pd.DataFrame, optional
370
+ individual_stats : pd.DataFrame
338
371
  DataFrame containing individual cell firing rates from compute_firing_rate_stats
372
+ groupby : str
373
+ Column name in individual_stats to group plots by
374
+ attribute : str
375
+ Node attribute column name to plot against firing rate
339
376
  config : str, optional
340
377
  Path to configuration file for loading node data
341
378
  nodes : pd.DataFrame, optional
342
379
  Pre-loaded node data as alternative to loading from config
343
- groupby : str, optional
344
- Column name in individual_stats to group plots by
345
380
  network_name : str, optional
346
381
  Name of network to load from config file
347
- attribute : str, optional
348
- Node attribute column name to plot against firing rate
349
- figsize : tuple[int, int], optional
382
+ figsize : Tuple[float, float], optional
350
383
  Figure dimensions (width, height) in inches
351
384
  dot_size : float, optional
352
385
  Size of scatter plot points
386
+ color_map : dict, optional
387
+ Dictionary specifying colors for each group. Keys should be group names, and values should be color values.
353
388
 
354
389
  Returns
355
390
  -------
@@ -407,12 +442,26 @@ def plot_firing_rate_vs_node_attribute(
407
442
  axes = np.array([axes])
408
443
  axes = axes.flatten()
409
444
 
445
+ # Generate colors if no color_map is provided
446
+ if color_map is None:
447
+ cmap = plt.get_cmap("tab10")
448
+ color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
449
+ else:
450
+ # Ensure color_map contains all groups
451
+ missing_colors = [group for group in unique_groups if group not in color_map]
452
+ if missing_colors:
453
+ raise ValueError(f"color_map is missing colors for groups: {missing_colors}")
454
+
410
455
  # Plot each group
411
456
  for i, group in enumerate(unique_groups):
412
457
  group_df = merged_df[merged_df[groupby] == group]
413
- axes[i].scatter(group_df["firing_rate"], group_df[attribute], s=dot_size)
458
+ axes[i].scatter(group_df["firing_rate"], group_df[attribute], s=dot_size, color=color_map[group])
414
459
  axes[i].set_xlabel("Firing Rate (Hz)")
415
460
  axes[i].set_ylabel(attribute)
461
+
462
+ # Calculate and display mean firing rate in legend
463
+ mean_fr = group_df["firing_rate"].mean()
464
+ axes[i].legend([f"Mean FR: {mean_fr:.2f} Hz"], loc="upper right")
416
465
  axes[i].set_title(f"{groupby}: {group}")
417
466
 
418
467
  # Hide unused subplots
@@ -420,4 +469,112 @@ def plot_firing_rate_vs_node_attribute(
420
469
  axes[j].set_visible(False)
421
470
 
422
471
  plt.tight_layout()
423
- plt.show()
472
+ return fig
473
+
474
+
475
def plot_firing_rate_histogram(
    individual_stats: pd.DataFrame,
    groupby: str = "pop_name",
    ax: Optional[Axes] = None,
    color_map: Optional[Dict[str, str]] = None,
    bins: int = 30,
    alpha: float = 0.7,
    figsize: Tuple[float, float] = (12, 8),
    stacked: bool = False,
    logscale: bool = False,
    min_fr: Optional[float] = None,
) -> plt.Figure:
    """
    Plot histograms of firing rates for each population group.

    All groups share one set of bin edges and are drawn on a single axes,
    either overlaid (``stacked=False``) or stacked (``stacked=True``).

    Parameters
    ----------
    individual_stats : pd.DataFrame
        DataFrame containing individual firing rates with group labels. Must
        contain a ``"firing_rate"`` column and the ``groupby`` column.
    groupby : str, optional
        Column name to group by (default is "pop_name").
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot; if None, a new figure and axes are created.
    color_map : dict, optional
        Dictionary specifying colors for each group. Keys should be group names,
        and values should be color values.
    bins : int, optional
        Number of bins for the histogram (default is 30).
    alpha : float, optional
        Transparency of the overlaid histograms (default is 0.7). Not applied
        when ``stacked=True``.
    figsize : Tuple[float, float], optional
        Figure size, used only when a new figure is created (default is (12, 8)).
    stacked : bool, optional
        If True, stack the group histograms on top of each other instead of
        overlaying them (default is False).
    logscale : bool, optional
        If True, use logarithmically spaced bins and a log x-axis (default is False).
    min_fr : float, optional
        Only used with ``logscale=True``. Firing rates below ``min_fr`` (e.g.
        zeros) are raised to ``min_fr`` so they remain visible on the log axis,
        and a tick at ``min_fr`` is labelled "0". If None, non-positive rates
        fall outside the bins and are not drawn.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the histogram (``ax.figure`` when ``ax`` is given).

    Raises
    ------
    ValueError
        If ``color_map`` lacks a color for some group, if there is no finite
        firing rate to plot, or if ``logscale=True`` and no rate is positive.
    """
    sns.set_style("whitegrid")

    # Get unique groups
    unique_groups = individual_stats[groupby].unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap("tab10")
        color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
    else:
        # Ensure color_map contains all groups
        missing_colors = [group for group in unique_groups if group not in color_map]
        if missing_colors:
            raise ValueError(f"color_map is missing colors for groups: {missing_colors}")

    # Firing rates of each group
    pop_fr = {
        group: individual_stats[individual_stats[groupby] == group]["firing_rate"].values
        for group in unique_groups
    }

    if logscale and min_fr is not None:
        # Raise zero / tiny rates to min_fr so they can be shown on a log axis.
        pop_fr = {p: np.fmax(fr, min_fr) for p, fr in pop_fr.items()}

    # Shared bin edges computed over every group; NaN/inf would poison the range.
    if pop_fr:
        all_fr = np.concatenate([np.asarray(fr, dtype=float) for fr in pop_fr.values()])
    else:
        all_fr = np.empty(0)
    all_fr = all_fr[np.isfinite(all_fr)]
    if logscale:
        # Log-spaced edges must be strictly positive.
        all_fr = all_fr[all_fr > 0]
    if all_fr.size == 0:
        if logscale:
            raise ValueError(
                "No positive firing rates to plot on a log scale; pass min_fr to show zero rates."
            )
        raise ValueError("No firing rates to plot.")

    lo, hi = all_fr.min(), all_fr.max()
    if logscale:
        if lo == hi:
            # Degenerate range: widen so the bin edges strictly increase.
            lo, hi = lo / 2, hi * 2
        bins_array = np.geomspace(lo, hi, bins + 1)
    else:
        if lo == hi:
            lo, hi = lo - 0.5, hi + 0.5
        bins_array = np.linspace(lo, hi, bins + 1)

    # Honour a caller-supplied axes; otherwise create a fresh figure.
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    groups = list(pop_fr)
    if stacked:
        ax.hist([pop_fr[g] for g in groups], bins=bins_array, label=groups,
                color=[color_map[g] for g in groups], stacked=True)
    else:
        for g in groups:
            ax.hist(pop_fr[g], bins=bins_array, label=g, color=color_map[g], alpha=alpha)

    xlim = (bins_array[0], bins_array[-1])
    if logscale:
        ax.set_xscale('log')
        # Fix the view first so the log locator picks ticks for the final range
        # (no need to draw the whole figure just to read the ticks).
        ax.set_xlim(*xlim)
        xt = [t for t in ax.get_xticks() if xlim[0] <= t <= xlim[1]]
        xtl = [f'{x:g}' for x in xt]
        if min_fr is not None:
            # Values clipped up to min_fr represent zero firing rates.
            xt.append(min_fr)
            xtl.append('0')
        ax.set_xticks(xt)
        ax.set_xticklabels(xtl)

    # set_xticks may expand the view limits, so restore them last.
    ax.set_xlim(*xlim)
    ax.legend(loc='upper right')
    ax.set_title('Firing Rate Histogram')
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Count')
    return fig
579
+
580
+
bmtool/connectors.py CHANGED
@@ -1604,6 +1604,392 @@ class CorrelatedGapJunction(GapJunction):
1604
1604
  if self.save_report:
1605
1605
  self.save_connection_report()
1606
1606
  return nsyns
1607
+
1608
class GapJunctionConditionalReciprocalConnector(AbstractConnector):
    """
    Object for building reciprocal chemical synapses in BMTK network model with
    probabilities that depend on the presence of gap junctions between cell pairs.

    This connector creates chemical synapses where the connection probabilities
    (forward, backward, and reciprocal) differ based on whether a pair of cells
    shares a gap junction. This allows modeling of the experimentally observed
    correlation between electrical and chemical coupling in neural populations.

    Algorithm:
        For each potential connection pair, first determine if a gap junction exists
        using the provided gap_connector. Then apply different bivariate Bernoulli
        probability distributions for chemical synapses based on electrical coupling
        status:
        - Electrically coupled pairs: Use p0_elec, p1_elec, pr_elec probabilities
        - Non-electrically coupled pairs: Use p0_nonelec, p1_nonelec, pr_nonelec probabilities

        For each pair, the forward synapse is drawn with probability p0, then the
        backward synapse is drawn with the conditional probability that makes the
        joint distribution match (p0, p1, pr), as ReciprocalConnector does.

    Use with BMTK:
        1. First create and set up a gap junction connector:

            gap_connector = GapJunction(p=0.08, verbose=True)
            gap_connector.setup_nodes(source_population, target_population)
            net.add_edges(is_gap_junction=True, **gap_connector.edge_params())

        2. Create the conditional reciprocal connector with different probabilities
           for electrically coupled vs non-coupled pairs:

            chemical_connector = GapJunctionConditionalReciprocalConnector(
                gap_connector=gap_connector,
                p0_elec=0.50, p1_elec=0.50, pr_elec=0.25,  # High reciprocity for coupled pairs
                p0_nonelec=0.125, p1_nonelec=0.125, pr_nonelec=0.03,  # Low reciprocity for non-coupled pairs
                verbose=True
            )

        3. Set up nodes and add chemical edges:

            chemical_connector.setup_nodes(source_population, target_population)
            net.add_edges(**chemical_connector.edge_params(),
                          **chemical_synapse_properties)

           When source and target populations differ, call edge_params() a
           second time and add those edges as well; the second call describes
           the backward (target -> source) connections.

        4. Build the network:

            net.build()

        Gap junction presence is read from ``gap_connector.conn_prop`` when this
        connector initializes (at its first connection-rule call during build).
        That dict is presumably filled while the gap junction edges are built,
        so add the gap junction edges first, as in the example.

    Parameters:
        gap_connector: GapJunction connector object that has been set up with nodes.
            Used to determine which cell pairs have electrical coupling.
        p0_elec, p1_elec: Forward and backward connection probabilities for
            electrically coupled pairs. Can be constants or functions within [0, 1].
        pr_elec: Reciprocal connection probability for electrically coupled pairs.
            Can be a constant or function accepting (pr_arg, p0, p1) arguments.
        p0_nonelec, p1_nonelec: Forward and backward connection probabilities for
            non-electrically coupled pairs. Can be constants or functions within [0, 1].
        pr_nonelec: Reciprocal connection probability for non-electrically coupled pairs.
            Can be a constant or function accepting (pr_arg, p0, p1) arguments.
        p0_elec_arg, p1_elec_arg, pr_elec_arg: Input arguments for electrically coupled
            probability functions. Can be constants or functions of the node pair
            (forward ones take (source, target) nodes; p1 ones take (target, source)).
            Set to None if functions don't need arguments.
        p0_nonelec_arg, p1_nonelec_arg, pr_nonelec_arg: Input arguments for
            non-electrically coupled probability functions, similar to above.
        n_syn0, n_syn1: Number of synapses for forward/backward connections if
            established. Can be constants or functions of node properties
            (n_syn0 takes (source, target) nodes, n_syn1 takes (target, source)).
            Limited to 255 due to uint8 storage.
        verbose: Whether to print detailed connection statistics and progress.
        save_report: Whether to save a connection report. Currently has no
            effect: save_connection_report() is not implemented yet.
        report_name: Filename for connection report (default: "conn.csv").

    Returns:
        An object that works with BMTK to build conditional reciprocal chemical
        edges in a network, with probabilities dependent on gap junction presence.

    Important attributes:
        gap_connector: Reference to the GapJunction connector used for coupling detection.
        vars: Dictionary storing original input parameters.
        source, target: NodePool objects for source and target populations.
        recurrent: Whether source and target populations are the same.
        conn_mat: Connection matrix storing synapse counts, shape
            (1 if recurrent else 2, n_source, n_target).
        conn_prop: List of dictionaries storing connection properties for forward
            and backward connections. Format: [{src_id: {tgt_id: prop}, ...}, ...]
            (prop is currently always None).
        gap_decisions: Dictionary caching gap junction presence for each pair.
        connection_stats: Statistics tracking connections by electrical coupling status.
            Format: {'elec': {'pairs': int, 'uni': int, 'recp': int},
                     'nonelec': {'pairs': int, 'uni': int, 'recp': int}}
    """

    def __init__(self, gap_connector, p0_elec, p1_elec, pr_elec, p0_nonelec, p1_nonelec, pr_nonelec,
                 p0_elec_arg=None, p1_elec_arg=None, pr_elec_arg=None,
                 p0_nonelec_arg=None, p1_nonelec_arg=None, pr_nonelec_arg=None,
                 n_syn0=1, n_syn1=1, verbose=True, save_report=True, report_name=None):
        # Store the raw parameters (constants or callables), like ReciprocalConnector;
        # setup_variables() turns them into callables at initialization time.
        args = locals()
        var_set = ("p0_elec", "p0_elec_arg", "p1_elec", "p1_elec_arg", "pr_elec", "pr_elec_arg",
                   "p0_nonelec", "p0_nonelec_arg", "p1_nonelec", "p1_nonelec_arg", "pr_nonelec", "pr_nonelec_arg",
                   "n_syn0", "n_syn1")
        self.vars = {key: args[key] for key in var_set}

        self.gap_connector = gap_connector
        self.verbose = verbose
        self.save_report = save_report
        self.report_name = report_name or "conn.csv"
        self.conn_prop = [{}, {}]
        self.stage = 0
        # Track gap junction decisions and connections for detailed reporting
        self.gap_decisions = {}
        self.connection_stats = self._empty_stats()

    @staticmethod
    def _empty_stats():
        """Return a fresh, zeroed connection_stats dictionary."""
        return {'elec': {'pairs': 0, 'uni': 0, 'recp': 0},
                'nonelec': {'pairs': 0, 'uni': 0, 'recp': 0}}

    def setup_variables(self):
        """Turn every stored parameter into a callable attribute of the same name.

        Constants are wrapped with constant_function(). Everything except the six
        probability functions (which receive pre-computed ``*_arg`` values) is
        further wrapped with node_2_idx_input() so it can be called with node
        indices. Backward-direction variables (p1_*_arg and n_syn1) are wrapped
        in reverse and must be called as (target_index, source_index).
        """
        prob_funcs = {"p0_elec", "p1_elec", "pr_elec", "p0_nonelec", "p1_nonelec", "pr_nonelec"}
        reverse_funcs = {"p1_elec_arg", "p1_nonelec_arg", "n_syn1"}
        callable_set = set()
        for name, var in self.vars.items():
            func = var if callable(var) else self.constant_function(var)
            callable_set.add(name)  # constants converted to functions are also callable
            if name not in prob_funcs:
                func = self.node_2_idx_input(func, name in reverse_funcs)
            setattr(self, name, func)
        self.callable_set = callable_set

    @staticmethod
    def constant_function(val):
        """Convert a constant to a function that ignores its arguments and returns it."""
        def constant(*arg):
            return val
        return constant

    def node_2_idx_input(self, var_func, reverse=False):
        """Convert a function of nodes into a function of node indices.

        With ``reverse=False`` the result is called as (i, j) and evaluates
        ``var_func(source_list[i], target_list[j])``; with ``reverse=True`` it is
        called as (j, i) and evaluates ``var_func(target_list[j], source_list[i])``.
        """
        if reverse:
            def idx_2_var(j, i):
                return var_func(self.target_list[j], self.source_list[i])
        else:
            def idx_2_var(i, j):
                return var_func(self.source_list[i], self.target_list[j])
        return idx_2_var

    def setup_nodes(self, source, target):
        """Record the source/target node pools and cache their node ids and lists."""
        self.source = source
        self.target = target
        self.recurrent = is_same_pop(self.source, self.target)
        self.source_ids = [s.node_id for s in self.source]
        self.n_source = len(self.source_ids)
        self.source_list = list(self.source)
        if self.recurrent:
            self.target_ids = self.source_ids
            self.n_target = self.n_source
            self.target_list = self.source_list
        else:
            self.target_ids = [t.node_id for t in self.target]
            self.n_target = len(self.target_ids)
            self.target_list = list(self.target)

    def edge_params(self):
        """Return keyword arguments for BMTK ``add_edges()``.

        The first call describes forward (source -> target) edges; later calls
        describe backward (target -> source) edges.
        """
        if self.stage == 0:
            params = {
                "source": self.source,
                "target": self.target,
                "iterator": "one_to_all",
                "connection_rule": self.make_forward_connection,
            }
        else:
            params = {
                "source": self.target,
                "target": self.source,
                "iterator": "all_to_one",
                "connection_rule": self.make_backward_connection,
            }
        self.stage += 1
        return params

    def has_gap(self, source_node, target_node):
        """Return True if gap_connector recorded a gap junction between the two nodes.

        Checks both directions, since a gap junction couples the pair symmetrically.
        NOTE(review): assumes gap_connector.conn_prop maps
        source node_id -> {target node_id: prop} for every gap junction it
        created — confirm against GapJunction.
        """
        sid = source_node.node_id
        tid = target_node.node_id
        conn_prop = self.gap_connector.conn_prop
        return (sid in conn_prop and tid in conn_prop[sid]) or \
               (tid in conn_prop and sid in conn_prop[tid])

    def cond_backward_prob(self, forward, p0, p1, pr):
        """Probability of the backward connection given the forward draw.

        ``pr`` is first clamped to its feasible range
        [max(0, p0 + p1 - 1), min(p0, p1)]. Then P(backward | forward) = pr / p0
        and P(backward | no forward) = (p1 - pr) / (1 - p0). If p0 <= 0 the two
        draws are independent and p1 is returned.
        """
        if p0 <= 0:
            return p1
        pr_min = max(0, p0 + p1 - 1)
        pr_max = min(p0, p1)
        pr = max(pr_min, min(pr_max, pr))
        if forward:
            return pr / p0
        # Guard p0 >= 1: the forward draw cannot fail then, and 1 - p0 would be 0.
        if p0 >= 1 or p1 <= pr:
            return 0.0
        return (p1 - pr) / (1 - p0)

    def _next_row(self, stage_idx):
        """Hand out the next row of ``conn_mat[stage_idx]`` to a BMTK connection rule.

        Builds the whole connection matrix on the first call. Rows are returned
        in call order. After ``n_source`` calls the counter is rewound for the
        next stage, and statistics are printed once the last stage is done.
        """
        if not hasattr(self, 'conn_mat'):
            self.initialize()
        nsyns = self.conn_mat[stage_idx, self.iter_count, :]
        self.iter_count += 1
        if self.iter_count == self.n_source:
            self.iter_count = 0
            if stage_idx == self.end_stage and self.verbose:
                self.connection_number_info()
        return nsyns

    def make_forward_connection(self, source, targets, *args, **kwargs):
        """BMTK connection rule for source -> target edges (``one_to_all``).

        Returns synapse counts from the current source node to every target.
        The node arguments are not inspected: rows of ``conn_mat[0]`` are served
        in call order, which assumes BMTK visits sources in ``source_list`` order.
        The forward stage is always index 0, even if edge_params() was already
        called again for the backward edges before the network is built.
        """
        return self._next_row(0)

    def make_backward_connection(self, targets, source, *args, **kwargs):
        """BMTK connection rule for target -> source edges (``all_to_one``).

        Only used when source and target populations differ. Returns synapse
        counts from every target-population node onto the current
        source-population node, taken from ``conn_mat[1]``.
        """
        self.stage = 2
        return self._next_row(1)

    def calc_pair(self, i, j, is_elec):
        """Return (p0, p1, pr) for source index i and target index j.

        Uses the electrical-coupling parameter set when ``is_elec`` is true and
        the non-electrical one otherwise.
        """
        if is_elec:
            p0_arg = self.p0_elec_arg(i, j)
            p1_arg = self.p1_elec_arg(j, i)
            p0 = self.p0_elec(p0_arg)
            p1 = self.p1_elec(p1_arg)
            pr = self.pr_elec(self.pr_elec_arg(i, j), p0, p1)
        else:
            p0_arg = self.p0_nonelec_arg(i, j)
            p1_arg = self.p1_nonelec_arg(j, i)
            p0 = self.p0_nonelec(p0_arg)
            p1 = self.p1_nonelec(p1_arg)
            pr = self.pr_nonelec(self.pr_nonelec_arg(i, j), p0, p1)
        return p0, p1, pr

    def initialize(self):
        """Draw every chemical connection up front and fill ``conn_mat``.

        For each pair, gap junction presence selects the elec or nonelec
        parameter set. The forward synapse is drawn with p0 and the backward one
        with cond_backward_prob(). Synapse counts come from n_syn0/n_syn1.
        """
        self.setup_variables()  # Set up variables like ReciprocalConnector
        self.end_stage = 0 if self.recurrent else 1
        shape = (self.end_stage + 1, self.n_source, self.n_target)
        self.conn_mat = np.zeros(shape, dtype=np.uint8)
        self.iter_count = 0
        # Start from a clean slate so repeated initialization does not double count.
        self.gap_decisions = {}
        self.connection_stats = self._empty_stats()

        if self.verbose:
            self.timer = Timer()

        for i in range(self.n_source):
            # Recurrent: visit each unordered pair once (i < j), which also skips
            # self-connections; the backward draw then covers j -> i.
            j_start = i + 1 if self.recurrent else 0
            for j in range(j_start, self.n_target):
                sid = self.source_ids[i]
                tid = self.target_ids[j]

                # Cache gap junction presence under an order-independent key.
                pair_key = (min(sid, tid), max(sid, tid))
                if pair_key not in self.gap_decisions:
                    self.gap_decisions[pair_key] = self.has_gap(self.source_list[i], self.target_list[j])
                has_gap = self.gap_decisions[pair_key]

                stats = self.connection_stats['elec' if has_gap else 'nonelec']
                stats['pairs'] += 1
                p0, p1, pr = self.calc_pair(i, j, has_gap)

                # Forward first, then backward conditioned on the forward result.
                forward = decision(p0)
                backward = decision(self.cond_backward_prob(forward, p0, p1, pr))

                if forward and backward:
                    stats['recp'] += 1
                elif forward or backward:
                    stats['uni'] += 1

                if forward:
                    self.conn_mat[0, i, j] = self.n_syn0(i, j)
                    self.add_conn_prop(i, j, None, 0)

                if backward:
                    if self.recurrent:
                        # j -> i lives in the same (stage 0) matrix.
                        self.conn_mat[0, j, i] = self.n_syn1(j, i)
                        self.add_conn_prop(j, i, None, 0)
                    else:
                        # n_syn1 is wrapped in reverse: call as (target_idx, source_idx).
                        self.conn_mat[1, i, j] = self.n_syn1(j, i)
                        self.add_conn_prop(i, j, None, 1)

        if self.verbose:
            self.timer.report("Time for creating connection matrix")
        if self.save_report:
            self.save_connection_report()

    def add_conn_prop(self, src, trg, prop, stage=0):
        """Store ``prop`` for a connection given source/target indices.

        For stage 1 (backward), the ids are swapped so the key is the
        presynaptic (target-population) node.
        """
        sid = self.source_ids[src]
        tid = self.target_ids[trg]
        if stage:
            sid, tid = tid, sid
        trg_dict = self.conn_prop[stage].setdefault(sid, {})
        trg_dict[tid] = prop

    @staticmethod
    def _fraction(num, den):
        """Return num / den, or 0.0 when den is 0 (keeps reports from crashing)."""
        return num / den if den else 0.0

    def _print_coupling_group(self, key, title, short, n_pairs):
        """Print pair / unidirectional / bidirectional counts for one coupling class."""
        frac = self._fraction
        stats = self.connection_stats[key]
        pairs, uni, recp = stats['pairs'], stats['uni'], stats['recp']
        none = pairs - uni - recp
        print(f"{title} ({pairs} pairs, {frac(pairs, n_pairs):.1%} of total):")
        print(f" Unidirectional: {uni} ({frac(uni, pairs):.1%} of {short} pairs)")
        print(f" Bidirectional: {recp} ({frac(recp, pairs):.1%} of {short} pairs)")
        print(f" No connection: {none} ({frac(none, pairs):.1%} of {short} pairs)")

    def connection_number_info(self):
        """Print connection statistics, broken down by electrical coupling."""
        frac = self._fraction
        conn_mat = self.conn_mat.astype(bool)

        if self.recurrent:
            # Unordered pairs (upper triangle only to avoid double counting)
            n_pairs = (self.n_source * (self.n_source - 1)) // 2
        else:
            n_pairs = self.n_source * self.n_target

        print("Detailed connection statistics by electrical coupling:")
        print("=" * 60)
        self._print_coupling_group('elec', "Electrically coupled pairs", "elec", n_pairs)
        self._print_coupling_group('nonelec', "\nNon-electrically coupled pairs", "nonelec", n_pairs)

        if self.recurrent:
            # Reciprocal pairs appear twice in conn & conn.T
            n_recp = np.count_nonzero(conn_mat[0] & conn_mat[0].T) // 2
            n_total = np.count_nonzero(conn_mat[0])
            n_uni = n_total - 2 * n_recp

            print("\nOverall chemical connectivity:")
            print(" Numbers of connections: unidirectional, reciprocal")
            print(f" Number of connected pairs: ({n_uni}, {n_recp})")
            print(f" Fraction of connected pairs: ({frac(n_uni, n_pairs):.2%}, {frac(n_recp, n_pairs):.2%})")
            print(f" Total chemical connectivity: {frac(n_uni + n_recp, n_pairs):.2%}")
        else:
            n_forward = np.count_nonzero(conn_mat[0])
            n_backward = np.count_nonzero(conn_mat[1])
            n_recp = np.count_nonzero(conn_mat[0] & conn_mat[1])

            print("\nNumbers of connections: forward, backward, reciprocal")
            print(f"Number of connected pairs: ({n_forward}, {n_backward}, {n_recp})")
            print(f"Fraction of connected pairs: ({frac(n_forward, n_pairs):.2%}, "
                  f"{frac(n_backward, n_pairs):.2%}, {frac(n_recp, n_pairs):.2%})")

    def save_connection_report(self):
        """Placeholder: currently writes nothing.

        TODO: implement a CSV report similar to ReciprocalConnector's so that
        ``save_report`` / ``report_name`` take effect.
        """
        pass
1992
+
1607
1993
 
1608
1994
 
1609
1995
  class OneToOneSequentialConnector(AbstractConnector):