bmtool 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bmtool might be problematic. Click here for more details.
- bmtool/analysis/entrainment.py +113 -0
- bmtool/analysis/lfp.py +1 -1
- bmtool/bmplot/connections.py +759 -337
- bmtool/bmplot/entrainment.py +169 -49
- bmtool/bmplot/lfp.py +146 -11
- bmtool/bmplot/netcon_reports.py +1 -0
- bmtool/bmplot/spikes.py +175 -18
- bmtool/connectors.py +386 -0
- bmtool/singlecell.py +474 -31
- bmtool/synapses.py +1684 -651
- bmtool/util/util.py +40 -5
- {bmtool-0.7.6.dist-info → bmtool-0.7.8.dist-info}/METADATA +1 -1
- {bmtool-0.7.6.dist-info → bmtool-0.7.8.dist-info}/RECORD +17 -17
- {bmtool-0.7.6.dist-info → bmtool-0.7.8.dist-info}/WHEEL +0 -0
- {bmtool-0.7.6.dist-info → bmtool-0.7.8.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.6.dist-info → bmtool-0.7.8.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.6.dist-info → bmtool-0.7.8.dist-info}/top_level.txt +0 -0
bmtool/bmplot/spikes.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Plotting functions for neural spikes and firing rates."""
|
|
2
2
|
|
|
3
|
-
from typing import Dict, List, Optional, Union
|
|
3
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
4
4
|
|
|
5
5
|
import matplotlib.pyplot as plt
|
|
6
6
|
import numpy as np
|
|
@@ -15,12 +15,13 @@ def raster(
|
|
|
15
15
|
spikes_df: Optional[pd.DataFrame] = None,
|
|
16
16
|
config: Optional[str] = None,
|
|
17
17
|
network_name: Optional[str] = None,
|
|
18
|
-
groupby:
|
|
18
|
+
groupby: str = "pop_name",
|
|
19
|
+
sortby: Optional[str] = None,
|
|
19
20
|
ax: Optional[Axes] = None,
|
|
20
21
|
tstart: Optional[float] = None,
|
|
21
22
|
tstop: Optional[float] = None,
|
|
22
23
|
color_map: Optional[Dict[str, str]] = None,
|
|
23
|
-
dot_size:
|
|
24
|
+
dot_size: float = 0.3,
|
|
24
25
|
) -> Axes:
|
|
25
26
|
"""
|
|
26
27
|
Plots a raster plot of neural spikes, with different colors for each population.
|
|
@@ -33,6 +34,10 @@ def raster(
|
|
|
33
34
|
Path to the configuration file used to load node data.
|
|
34
35
|
network_name : str, optional
|
|
35
36
|
Specific network name to select from the configuration; if not provided, uses the first network.
|
|
37
|
+
groupby : str, optional
|
|
38
|
+
Column name to group spikes by for coloring. Default is 'pop_name'.
|
|
39
|
+
sortby : str, optional
|
|
40
|
+
Column name to sort node_ids within each group. If provided, nodes within each population will be sorted by this column.
|
|
36
41
|
ax : matplotlib.axes.Axes, optional
|
|
37
42
|
Axes on which to plot the raster; if None, a new figure and axes are created.
|
|
38
43
|
tstart : float, optional
|
|
@@ -107,11 +112,32 @@ def raster(
|
|
|
107
112
|
|
|
108
113
|
# Plot each population with its specified or generated color
|
|
109
114
|
legend_handles = []
|
|
115
|
+
y_offset = 0 # Track y-position offset for stacking populations
|
|
116
|
+
|
|
110
117
|
for pop_name, group in spikes_df.groupby(groupby):
|
|
111
|
-
|
|
118
|
+
if sortby:
|
|
119
|
+
# Sort by the specified column, putting NaN values at the end
|
|
120
|
+
group_sorted = group.sort_values(by=sortby, na_position='last')
|
|
121
|
+
# Create a mapping from node_ids to consecutive y-positions based on sorted order
|
|
122
|
+
# Use the sorted order to maintain the same sequence for all spikes from same node
|
|
123
|
+
unique_nodes_sorted = group_sorted['node_ids'].drop_duplicates()
|
|
124
|
+
node_to_y = {node_id: y_offset + i for i, node_id in enumerate(unique_nodes_sorted)}
|
|
125
|
+
# Map node_ids to new y-positions for ALL spikes (not just the sorted group)
|
|
126
|
+
y_positions = group['node_ids'].map(node_to_y)
|
|
127
|
+
# Verify no data was lost
|
|
128
|
+
assert len(y_positions) == len(group), f"Data loss detected in population {pop_name}"
|
|
129
|
+
assert y_positions.isna().sum() == 0, f"Unmapped node_ids found in population {pop_name}"
|
|
130
|
+
else:
|
|
131
|
+
y_positions = group['node_ids']
|
|
132
|
+
|
|
133
|
+
ax.scatter(group["timestamps"], y_positions, color=color_map[pop_name], s=dot_size)
|
|
112
134
|
# Dummy scatter for consistent legend appearance
|
|
113
135
|
handle = ax.scatter([], [], color=color_map[pop_name], label=pop_name, s=20)
|
|
114
136
|
legend_handles.append(handle)
|
|
137
|
+
|
|
138
|
+
# Update y_offset for next population if sortby is used
|
|
139
|
+
if sortby:
|
|
140
|
+
y_offset += len(unique_nodes_sorted)
|
|
115
141
|
|
|
116
142
|
# Label axes
|
|
117
143
|
ax.set_xlabel("Time")
|
|
@@ -211,11 +237,12 @@ def plot_firing_rate_pop_stats(
|
|
|
211
237
|
# uses df from bmtool.analysis.spikes compute_firing_rate_stats
|
|
212
238
|
def plot_firing_rate_distribution(
|
|
213
239
|
individual_stats: pd.DataFrame,
|
|
214
|
-
groupby: Union[str,
|
|
240
|
+
groupby: Union[str, List[str]],
|
|
215
241
|
ax: Optional[Axes] = None,
|
|
216
242
|
color_map: Optional[Dict[str, str]] = None,
|
|
217
|
-
plot_type: Union[str,
|
|
243
|
+
plot_type: Union[str, List[str]] = "box",
|
|
218
244
|
swarm_alpha: float = 0.6,
|
|
245
|
+
logscale: bool = False,
|
|
219
246
|
) -> Axes:
|
|
220
247
|
"""
|
|
221
248
|
Plots a distribution of individual firing rates using one or more plot types
|
|
@@ -235,6 +262,8 @@ def plot_firing_rate_distribution(
|
|
|
235
262
|
List of plot types to generate. Options: "box", "violin", "swarm". Default is "box".
|
|
236
263
|
swarm_alpha : float, optional
|
|
237
264
|
Transparency of swarm plot points. Default is 0.6.
|
|
265
|
+
logscale : bool, optional
|
|
266
|
+
If True, use logarithmic scale for the y-axis (default is False).
|
|
238
267
|
|
|
239
268
|
Returns:
|
|
240
269
|
-------
|
|
@@ -316,40 +345,46 @@ def plot_firing_rate_distribution(
|
|
|
316
345
|
ax.set_title("Firing Rate Distribution for individual cells")
|
|
317
346
|
ax.grid(axis="y", linestyle="--", alpha=0.7)
|
|
318
347
|
|
|
348
|
+
if logscale:
|
|
349
|
+
ax.set_yscale('log')
|
|
350
|
+
|
|
319
351
|
return ax
|
|
320
352
|
|
|
321
353
|
|
|
322
354
|
def plot_firing_rate_vs_node_attribute(
|
|
323
|
-
individual_stats:
|
|
355
|
+
individual_stats: pd.DataFrame,
|
|
356
|
+
groupby: str,
|
|
357
|
+
attribute: str,
|
|
324
358
|
config: Optional[str] = None,
|
|
325
359
|
nodes: Optional[pd.DataFrame] = None,
|
|
326
|
-
groupby: Optional[str] = None,
|
|
327
360
|
network_name: Optional[str] = None,
|
|
328
|
-
|
|
329
|
-
figsize=(12, 8),
|
|
361
|
+
figsize: Tuple[float, float] = (12, 8),
|
|
330
362
|
dot_size: float = 3,
|
|
363
|
+
color_map: Optional[Dict[str, str]] = None,
|
|
331
364
|
) -> plt.Figure:
|
|
332
365
|
"""
|
|
333
366
|
Plot firing rate vs node attribute for each group in separate subplots.
|
|
334
367
|
|
|
335
368
|
Parameters
|
|
336
369
|
----------
|
|
337
|
-
individual_stats : pd.DataFrame
|
|
370
|
+
individual_stats : pd.DataFrame
|
|
338
371
|
DataFrame containing individual cell firing rates from compute_firing_rate_stats
|
|
372
|
+
groupby : str
|
|
373
|
+
Column name in individual_stats to group plots by
|
|
374
|
+
attribute : str
|
|
375
|
+
Node attribute column name to plot against firing rate
|
|
339
376
|
config : str, optional
|
|
340
377
|
Path to configuration file for loading node data
|
|
341
378
|
nodes : pd.DataFrame, optional
|
|
342
379
|
Pre-loaded node data as alternative to loading from config
|
|
343
|
-
groupby : str, optional
|
|
344
|
-
Column name in individual_stats to group plots by
|
|
345
380
|
network_name : str, optional
|
|
346
381
|
Name of network to load from config file
|
|
347
|
-
|
|
348
|
-
Node attribute column name to plot against firing rate
|
|
349
|
-
figsize : tuple[int, int], optional
|
|
382
|
+
figsize : Tuple[float, float], optional
|
|
350
383
|
Figure dimensions (width, height) in inches
|
|
351
384
|
dot_size : float, optional
|
|
352
385
|
Size of scatter plot points
|
|
386
|
+
color_map : dict, optional
|
|
387
|
+
Dictionary specifying colors for each group. Keys should be group names, and values should be color values.
|
|
353
388
|
|
|
354
389
|
Returns
|
|
355
390
|
-------
|
|
@@ -407,12 +442,26 @@ def plot_firing_rate_vs_node_attribute(
|
|
|
407
442
|
axes = np.array([axes])
|
|
408
443
|
axes = axes.flatten()
|
|
409
444
|
|
|
445
|
+
# Generate colors if no color_map is provided
|
|
446
|
+
if color_map is None:
|
|
447
|
+
cmap = plt.get_cmap("tab10")
|
|
448
|
+
color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
|
|
449
|
+
else:
|
|
450
|
+
# Ensure color_map contains all groups
|
|
451
|
+
missing_colors = [group for group in unique_groups if group not in color_map]
|
|
452
|
+
if missing_colors:
|
|
453
|
+
raise ValueError(f"color_map is missing colors for groups: {missing_colors}")
|
|
454
|
+
|
|
410
455
|
# Plot each group
|
|
411
456
|
for i, group in enumerate(unique_groups):
|
|
412
457
|
group_df = merged_df[merged_df[groupby] == group]
|
|
413
|
-
axes[i].scatter(group_df["firing_rate"], group_df[attribute], s=dot_size)
|
|
458
|
+
axes[i].scatter(group_df["firing_rate"], group_df[attribute], s=dot_size, color=color_map[group])
|
|
414
459
|
axes[i].set_xlabel("Firing Rate (Hz)")
|
|
415
460
|
axes[i].set_ylabel(attribute)
|
|
461
|
+
|
|
462
|
+
# Calculate and display mean firing rate in legend
|
|
463
|
+
mean_fr = group_df["firing_rate"].mean()
|
|
464
|
+
axes[i].legend([f"Mean FR: {mean_fr:.2f} Hz"], loc="upper right")
|
|
416
465
|
axes[i].set_title(f"{groupby}: {group}")
|
|
417
466
|
|
|
418
467
|
# Hide unused subplots
|
|
@@ -420,4 +469,112 @@ def plot_firing_rate_vs_node_attribute(
|
|
|
420
469
|
axes[j].set_visible(False)
|
|
421
470
|
|
|
422
471
|
plt.tight_layout()
|
|
423
|
-
|
|
472
|
+
return fig
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def plot_firing_rate_histogram(
    individual_stats: pd.DataFrame,
    groupby: str = "pop_name",
    ax: Optional[Axes] = None,
    color_map: Optional[Dict[str, str]] = None,
    bins: int = 30,
    alpha: float = 0.7,
    figsize: Tuple[float, float] = (12, 8),
    stacked: bool = False,
    logscale: bool = False,
    min_fr: Optional[float] = None,
) -> plt.Figure:
    """
    Plot histograms of firing rates for each population group.

    All groups share one axes and one set of bin edges computed from the pooled
    firing rates. Groups are either overlaid (``stacked=False``) or stacked on top
    of each other (``stacked=True``).

    Parameters:
    ----------
    individual_stats : pd.DataFrame
        DataFrame containing individual firing rates with group labels; must have a
        ``"firing_rate"`` column and the ``groupby`` column.
    groupby : str, optional
        Column name to group by (default is "pop_name").
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot; if None, a new figure of size ``figsize`` is created.
    color_map : dict, optional
        Dictionary specifying colors for each group. Keys should be group names, and
        values should be color values. If None, colors are taken from "tab10".
    bins : int, optional
        Number of bins for the histogram (default is 30).
    alpha : float, optional
        Transparency level for the histograms (default is 0.7).
    figsize : Tuple[float, float], optional
        Figure size, only used when ``ax`` is None (default is (12, 8)).
    stacked : bool, optional
        If True, stack the group histograms; otherwise overlay them (default is False).
    logscale : bool, optional
        If True, use logarithmically spaced bins and a log-scaled x-axis
        (default is False).
    min_fr : float, optional
        Only used with ``logscale=True``. Rates below ``min_fr`` (e.g. silent cells
        with a rate of 0) are raised to ``min_fr`` so they land in the lowest bin,
        and an extra x-tick at ``min_fr`` is labelled "0" (default is None).

    Returns:
    -------
    matplotlib.figure.Figure
        Figure containing the histogram (the figure owning ``ax`` if one was given).

    Raises:
    ------
    ValueError
        If ``color_map`` lacks a color for some group, or if there are no finite
        firing rates to bin (with ``logscale=True`` only positive rates count).
    """
    sns.set_style("whitegrid")

    unique_groups = individual_stats[groupby].unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap("tab10")
        color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
    else:
        # Ensure color_map contains all groups
        missing_colors = [group for group in unique_groups if group not in color_map]
        if missing_colors:
            raise ValueError(f"color_map is missing colors for groups: {missing_colors}")

    # Firing rates per group, in the order the groups first appear
    pop_fr = {
        group: individual_stats[individual_stats[groupby] == group]["firing_rate"].values
        for group in unique_groups
    }

    if logscale and min_fr is not None:
        # Clip tiny/zero rates up to min_fr so they remain visible on a log axis
        pop_fr = {p: np.fmax(fr, min_fr) for p, fr in pop_fr.items()}

    # Pool all rates to compute one shared set of bin edges
    all_fr = np.concatenate(list(pop_fr.values())) if pop_fr else np.empty(0)
    all_fr = all_fr[np.isfinite(all_fr)]
    if logscale:
        # Non-positive values cannot be placed on a log axis
        all_fr = all_fr[all_fr > 0]
    if all_fr.size == 0:
        hint = " (log scale needs positive rates; consider setting min_fr)" if logscale else ""
        raise ValueError(f"No finite firing rates available to plot{hint}")

    lo, hi = all_fr.min(), all_fr.max()
    if lo == hi:
        # Single distinct value: widen the range so the bins have non-zero width
        lo, hi = (lo / 2, hi * 2) if logscale else (lo - 0.5, hi + 0.5)
    bins_array = (np.geomspace if logscale else np.linspace)(lo, hi, bins + 1)

    # Honour a caller-supplied axes; only create a figure when none was given
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    groups = list(pop_fr)
    if stacked:
        ax.hist(
            [pop_fr[p] for p in groups],
            bins=bins_array,
            label=groups,
            color=[color_map[p] for p in groups],
            alpha=alpha,
            stacked=True,
        )
    else:
        for p in groups:
            ax.hist(pop_fr[p], bins=bins_array, label=p, color=color_map[p], alpha=alpha)

    if logscale:
        ax.set_xscale('log')
        # Draw this figure (not merely the "current" one) so the tick locations
        # reflect the autoscaled limits before they are frozen below.
        fig.canvas.draw()
        xt = ax.get_xticks()
        if min_fr is not None:
            # Drop any existing tick at min_fr so the "0" label does not overlap it
            xt = xt[~np.isclose(xt, min_fr)]
        xtl = [f'{x:g}' for x in xt]
        if min_fr is not None:
            xt = np.append(xt, min_fr)
            xtl.append('0')
        ax.set_xticks(xt)
        ax.set_xticklabels(xtl)

    ax.set_xlim(bins_array[0], bins_array[-1])
    ax.legend(loc='upper right')
    ax.set_title('Firing Rate Histogram')
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Count')
    return fig
|
|
579
|
+
|
|
580
|
+
|
bmtool/connectors.py
CHANGED
|
@@ -1604,6 +1604,392 @@ class CorrelatedGapJunction(GapJunction):
|
|
|
1604
1604
|
if self.save_report:
|
|
1605
1605
|
self.save_connection_report()
|
|
1606
1606
|
return nsyns
|
|
1607
|
+
|
|
1608
|
+
class GapJunctionConditionalReciprocalConnector(AbstractConnector):
    """
    Object for building reciprocal chemical synapses in BMTK network model with
    probabilities that depend on the presence of gap junctions between cell pairs.

    This connector creates chemical synapses where the connection probabilities
    (forward, backward, and reciprocal) differ based on whether a pair of cells
    shares a gap junction. This allows modeling of the experimentally observed
    correlation between electrical and chemical coupling in neural populations.

    Algorithm:
        For each potential connection pair, first determine if a gap junction exists
        using the provided gap_connector. Then apply different bivariate Bernoulli
        probability distributions for chemical synapses based on electrical coupling
        status:
        - Electrically coupled pairs: Use p0_elec, p1_elec, pr_elec probabilities
        - Non-electrically coupled pairs: Use p0_nonelec, p1_nonelec, pr_nonelec probabilities

        For each pair the forward connection is drawn with probability p0. The
        backward connection is then drawn with the conditional probability that
        yields a joint (reciprocal) probability of pr. Before use, pr is clipped
        to its feasible range [max(0, p0 + p1 - 1), min(p0, p1)].

    Use with BMTK:
        1. First create and set up a gap junction connector:

            gap_connector = GapJunction(p=0.08, verbose=True)
            gap_connector.setup_nodes(source_population, target_population)
            net.add_edges(is_gap_junction=True, **gap_connector.edge_params())

        2. Create the conditional reciprocal connector with different probabilities
           for electrically coupled vs non-coupled pairs:

            chemical_connector = GapJunctionConditionalReciprocalConnector(
                gap_connector=gap_connector,
                p0_elec=0.50, p1_elec=0.50, pr_elec=0.25,  # High reciprocity for coupled pairs
                p0_nonelec=0.125, p1_nonelec=0.125, pr_nonelec=0.03,  # Low reciprocity for non-coupled pairs
                verbose=True
            )

        3. Set up nodes and add chemical edges:

            chemical_connector.setup_nodes(source_population, target_population)
            net.add_edges(**chemical_connector.edge_params(),
                          **chemical_synapse_properties)

           If the source and target populations differ, call ``edge_params()`` a
           second time and add those backward (target -> source) edges as well.

        4. Build the network:

            net.build()

        The gap junction decisions are read from ``gap_connector.conn_prop`` when
        the chemical connection matrix is first built (during ``net.build()``).
        Presumably this requires the gap-junction edges to be built first.
        NOTE(review): confirm that BMTK builds edges in the order they were added.

    Parameters:
        gap_connector: GapJunction connector object that has been set up with nodes.
            Used to determine which cell pairs have electrical coupling.
        p0_elec, p1_elec: Forward and backward connection probabilities for
            electrically coupled pairs. Can be constants or functions within [0, 1].
        pr_elec: Reciprocal connection probability for electrically coupled pairs.
            Can be a constant or function accepting (pr_arg, p0, p1) arguments.
        p0_nonelec, p1_nonelec: Forward and backward connection probabilities for
            non-electrically coupled pairs. Can be constants or functions within [0, 1].
        pr_nonelec: Reciprocal connection probability for non-electrically coupled pairs.
            Can be a constant or function accepting (pr_arg, p0, p1) arguments.
        p0_elec_arg, p1_elec_arg, pr_elec_arg: Input arguments for electrically coupled
            probability functions. Can be constants, or functions of (source node,
            target node). p1 arguments are functions of (target node, source node).
            Set to None if functions don't need arguments.
        p0_nonelec_arg, p1_nonelec_arg, pr_nonelec_arg: Input arguments for
            non-electrically coupled probability functions, similar to above.
        n_syn0, n_syn1: Number of synapses for forward/backward connections if
            established. Can be constants or functions of node properties
            (n_syn1 receives (target node, source node)).
            Limited to 255 due to uint8 storage.
        verbose: Whether to print detailed connection statistics and progress.
        save_report: Whether to save connection report to CSV file
            (currently a no-op, see save_connection_report).
        report_name: Filename for connection report (default: "conn.csv").

    Returns:
        An object that works with BMTK to build conditional reciprocal chemical
        edges in a network, with probabilities dependent on gap junction presence.

    Important attributes:
        gap_connector: Reference to the GapJunction connector used for coupling detection.
        vars: Dictionary storing original input parameters.
        source, target: NodePool objects for source and target populations.
        recurrent: Whether source and target populations are the same.
        conn_mat: Connection matrix storing synapse counts.
        conn_prop: List of dictionaries storing connection properties for forward
            and backward connections. Format: [{src_id: {tgt_id: prop}, ...}, ...]
        gap_decisions: Dictionary caching gap junction presence for each pair.
        connection_stats: Statistics tracking connections by electrical coupling status.
            Format: {'elec': {'pairs': int, 'uni': int, 'recp': int},
                     'nonelec': {'pairs': int, 'uni': int, 'recp': int}}
    """

    # Probability functions are called with already-evaluated *_arg values, so
    # they are NOT wrapped to accept node indices.
    _PROB_VARS = frozenset({"p0_elec", "p1_elec", "pr_elec",
                            "p0_nonelec", "p1_nonelec", "pr_nonelec"})
    # Variables describing the backward (target -> source) direction; their
    # index-wrapped form takes (target_index, source_index).
    _REVERSED_VARS = frozenset({"p1_elec_arg", "p1_nonelec_arg", "n_syn1"})

    def __init__(self, gap_connector, p0_elec, p1_elec, pr_elec, p0_nonelec, p1_nonelec, pr_nonelec,
                 p0_elec_arg=None, p1_elec_arg=None, pr_elec_arg=None,
                 p0_nonelec_arg=None, p1_nonelec_arg=None, pr_nonelec_arg=None,
                 n_syn0=1, n_syn1=1, verbose=True, save_report=True, report_name=None):
        # Store original parameters like ReciprocalConnector; setup_variables()
        # later converts each one into an index-based callable.
        args = locals()
        var_set = ("p0_elec", "p0_elec_arg", "p1_elec", "p1_elec_arg", "pr_elec", "pr_elec_arg",
                   "p0_nonelec", "p0_nonelec_arg", "p1_nonelec", "p1_nonelec_arg", "pr_nonelec", "pr_nonelec_arg",
                   "n_syn0", "n_syn1")
        self.vars = {key: args[key] for key in var_set}

        self.gap_connector = gap_connector
        self.verbose = verbose
        self.save_report = save_report
        self.report_name = report_name or "conn.csv"
        self.conn_prop = [{}, {}]
        # Number of edge_params() calls made so far (0 -> forward, else backward)
        self.stage = 0
        # Track gap junction decisions and connections for detailed reporting
        self.gap_decisions = {}
        self.connection_stats = {'elec': {'pairs': 0, 'uni': 0, 'recp': 0},
                                 'nonelec': {'pairs': 0, 'uni': 0, 'recp': 0}}

    def setup_variables(self):
        """Turn every stored parameter into a callable attribute of the same name.

        Constants become constant functions. All non-probability callables are
        then wrapped to accept node indices instead of nodes; backward-direction
        variables (see ``_REVERSED_VARS``) take (target_index, source_index).
        """
        callable_set = set()
        for name, var in self.vars.items():
            setattr(self, name, var if callable(var) else self.constant_function(var))
            callable_set.add(name)
        self.callable_set = callable_set

        for name in callable_set - self._PROB_VARS:
            setattr(self, name,
                    self.node_2_idx_input(getattr(self, name), name in self._REVERSED_VARS))

    @staticmethod
    def constant_function(val):
        """Convert a constant to a constant function"""
        def constant(*arg):
            return val
        return constant

    def node_2_idx_input(self, var_func, reverse=False):
        """Convert a function that accept nodes as input to accept indices as input.

        With ``reverse=True`` the wrapper is called as ``f(target_idx, source_idx)``.
        """
        if reverse:
            def idx_2_var(j, i):
                return var_func(self.target_list[j], self.source_list[i])
        else:
            def idx_2_var(i, j):
                return var_func(self.source_list[i], self.target_list[j])
        return idx_2_var

    def setup_nodes(self, source, target):
        """Record the source/target node pools and their node ids in iteration order."""
        self.source = source
        self.target = target
        self.recurrent = is_same_pop(self.source, self.target)
        self.source_ids = [s.node_id for s in self.source]
        self.n_source = len(self.source_ids)
        self.source_list = list(self.source)
        if self.recurrent:
            self.target_ids = self.source_ids
            self.n_target = self.n_source
            self.target_list = self.source_list
        else:
            self.target_ids = [t.node_id for t in self.target]
            self.n_target = len(self.target_ids)
            self.target_list = list(self.target)

    def edge_params(self):
        """Return BMTK ``add_edges`` keyword arguments.

        The first call returns the forward (source -> target) edges. Later calls
        return the backward (target -> source) edges, which are only needed when
        the populations are not recurrent.
        """
        if self.stage == 0:
            params = {
                "source": self.source,
                "target": self.target,
                "iterator": "one_to_all",
                "connection_rule": self.make_forward_connection,
            }
        else:
            params = {
                "source": self.target,
                "target": self.source,
                "iterator": "all_to_one",
                "connection_rule": self.make_backward_connection,
            }
        self.stage += 1
        return params

    def has_gap(self, source_node, target_node):
        """Check if a gap junction exists between two cells by checking the gap_connector's connections"""
        sid = source_node.node_id
        tid = target_node.node_id

        # Assumes gap_connector.conn_prop maps node_id -> {node_id: prop}
        # (NOTE(review): confirm against GapJunction). Gap junctions are
        # bidirectional, so check both directions.
        return (sid in self.gap_connector.conn_prop and tid in self.gap_connector.conn_prop[sid]) or \
               (tid in self.gap_connector.conn_prop and sid in self.gap_connector.conn_prop[tid])

    def cond_backward_prob(self, forward, p0, p1, pr):
        """Calculate conditional probability of backward connection given forward result.

        pr is first clipped to its feasible range [max(0, p0 + p1 - 1), min(p0, p1)].
        The returned value is P(backward | forward) = pr / p0 when ``forward`` is
        true, or P(backward | no forward) = (p1 - pr) / (1 - p0) otherwise.
        When p0 <= 0, the forward draw carries no information and p1 is returned.
        """
        if p0 > 0:
            pr_min = max(0, p0 + p1 - 1)
            pr_max = min(p0, p1)
            pr = max(pr_min, min(pr_max, pr))

            if forward:
                return pr / p0
            else:
                return (p1 - pr) / (1 - p0) if p1 > pr else 0.0
        else:
            return p1

    def make_forward_connection(self, source, targets, *args, **kwargs):
        """BMTK ``one_to_all`` rule: synapse counts from the next source node to every target."""
        return self._next_row(0)

    def make_backward_connection(self, targets, source, *args, **kwargs):
        """BMTK ``all_to_one`` rule: synapse counts from every target node to the next source node."""
        return self._next_row(1)

    def _next_row(self, stage_idx):
        """Return the next row of ``conn_mat[stage_idx]`` and advance the row counter.

        The connection matrix is built on first use. The counter restarts after
        ``n_source`` rows so that the backward stage starts again at row 0.
        Statistics are printed once the final stage has been exhausted.
        NOTE(review): assumes BMTK visits nodes in the same order as
        ``list(self.source)``.
        """
        if not hasattr(self, 'conn_mat'):
            self.initialize()
        nsyns = self.conn_mat[stage_idx, self.iter_count, :]
        self.iter_count += 1
        if self.iter_count == self.n_source:
            self.iter_count = 0
            if stage_idx == self.end_stage and self.verbose:
                self.connection_number_info()
        return nsyns

    def calc_pair(self, i, j, is_elec):
        """Return (p0, p1, pr) for the pair (source index i, target index j)."""
        if is_elec:
            p0_arg = self.p0_elec_arg(i, j)
            p1_arg = self.p1_elec_arg(j, i)
            p0 = self.p0_elec(p0_arg)
            p1 = self.p1_elec(p1_arg)
            pr = self.pr_elec(self.pr_elec_arg(i, j), p0, p1)
        else:
            p0_arg = self.p0_nonelec_arg(i, j)
            p1_arg = self.p1_nonelec_arg(j, i)
            p0 = self.p0_nonelec(p0_arg)
            p1 = self.p1_nonelec(p1_arg)
            pr = self.pr_nonelec(self.pr_nonelec_arg(i, j), p0, p1)
        return p0, p1, pr

    def initialize(self):
        """Draw every chemical connection up front and store synapse counts in ``conn_mat``.

        ``conn_mat[0, i, j]`` holds synapses from source i to target j. For a
        recurrent population, the backward draw for pair (i, j) is stored in
        ``conn_mat[0, j, i]``. Otherwise it is stored in ``conn_mat[1, i, j]``,
        which counts synapses from target j to source i.
        """
        self.setup_variables()
        self.end_stage = 0 if self.recurrent else 1
        shape = (self.end_stage + 1, self.n_source, self.n_target)
        self.conn_mat = np.zeros(shape, dtype=np.uint8)
        self.iter_count = 0

        if self.verbose:
            self.timer = Timer()

        # Gap decisions keyed by an order-independent node-id pair
        self.gap_decisions = {}
        self.connection_stats = {'elec': {'pairs': 0, 'uni': 0, 'recp': 0},
                                 'nonelec': {'pairs': 0, 'uni': 0, 'recp': 0}}

        for i in range(self.n_source):
            # Recurrent: visit each unordered pair once (i < j) and skip self-pairs
            j_start = i + 1 if self.recurrent else 0
            for j in range(j_start, self.n_target):
                sid = self.source_ids[i]
                tid = self.target_ids[j]
                pair_key = (min(sid, tid), max(sid, tid))
                if pair_key not in self.gap_decisions:
                    self.gap_decisions[pair_key] = self.has_gap(self.source_list[i], self.target_list[j])
                is_elec = self.gap_decisions[pair_key]

                stats = self.connection_stats['elec' if is_elec else 'nonelec']
                stats['pairs'] += 1
                p0, p1, pr = self.calc_pair(i, j, is_elec)

                # Bivariate Bernoulli draw: forward first, then backward conditioned on it
                forward = decision(p0)
                backward = decision(self.cond_backward_prob(forward, p0, p1, pr))

                if forward and backward:
                    stats['recp'] += 1
                elif forward or backward:
                    stats['uni'] += 1

                if forward:
                    self.conn_mat[0, i, j] = self.n_syn0(i, j)
                    self.add_conn_prop(i, j, None, 0)

                if backward:
                    if self.recurrent:
                        self.conn_mat[0, j, i] = self.n_syn1(j, i)
                        self.add_conn_prop(j, i, None, 0)
                    else:
                        # n_syn1 is index-wrapped as (target_idx, source_idx)
                        self.conn_mat[1, i, j] = self.n_syn1(j, i)
                        self.add_conn_prop(i, j, None, 1)

        if self.verbose:
            self.timer.report("Time for creating connection matrix")
        if self.save_report:
            self.save_connection_report()

    def add_conn_prop(self, src, trg, prop, stage=0):
        """Record ``prop`` for a connection; backward stage keys are stored as (target id, source id)."""
        sid = self.source_ids[src]
        tid = self.target_ids[trg]
        if stage:
            sid, tid = tid, sid
        trg_dict = self.conn_prop[stage].setdefault(sid, {})
        trg_dict[tid] = prop

    @staticmethod
    def _ratio(num, den):
        """Return ``num / den``, or 0.0 when ``den`` is zero (keeps reports from crashing)."""
        return num / den if den else 0.0

    def connection_number_info(self):
        """Print connection statistics once all stages have been built."""
        conn_mat = self.conn_mat.astype(bool)
        ratio = self._ratio

        if self.recurrent:
            # Calculate total pairs (upper triangle only to avoid double counting)
            n_pairs = (self.n_source * (self.n_source - 1)) // 2
            # Reciprocal connections: both i->j and j->i exist
            n_recp = np.count_nonzero(conn_mat[0] & conn_mat[0].T) // 2
            n_total = np.count_nonzero(conn_mat[0])
            # Unidirectional = total - 2*reciprocal
            n_uni = n_total - 2 * n_recp

            print("Detailed connection statistics by electrical coupling:")
            print("=" * 60)

            for key, heading, lead in (("elec", "Electrically coupled pairs", ""),
                                       ("nonelec", "Non-electrically coupled pairs", "\n")):
                stats = self.connection_stats[key]
                pairs, uni, recp = stats['pairs'], stats['uni'], stats['recp']
                none = pairs - uni - recp
                print(f"{lead}{heading} ({pairs} pairs, {ratio(pairs, n_pairs):.1%} of total):")
                print(f" Unidirectional: {uni} ({ratio(uni, pairs):.1%} of {key} pairs)")
                print(f" Bidirectional: {recp} ({ratio(recp, pairs):.1%} of {key} pairs)")
                print(f" No connection: {none} ({ratio(none, pairs):.1%} of {key} pairs)")

            print("\nOverall chemical connectivity:")
            print(" Numbers of connections: unidirectional, reciprocal")
            print(f" Number of connected pairs: ({n_uni}, {n_recp})")
            print(f" Fraction of connected pairs: ({ratio(n_uni, n_pairs):.2%}, {ratio(n_recp, n_pairs):.2%})")
            print(f" Total chemical connectivity: {ratio(n_uni + n_recp, n_pairs):.2%}")

        else:
            # For non-recurrent networks
            n_pairs = self.n_source * self.n_target
            n_forward = np.count_nonzero(conn_mat[0])
            n_backward = np.count_nonzero(conn_mat[1])
            n_recp = np.count_nonzero(conn_mat[0] & conn_mat[1])

            print("Numbers of connections: forward, backward, reciprocal")
            print(f"Number of connected pairs: ({n_forward}, {n_backward}, {n_recp})")
            print(f"Fraction of connected pairs: ({ratio(n_forward, n_pairs):.2%}, "
                  f"{ratio(n_backward, n_pairs):.2%}, {ratio(n_recp, n_pairs):.2%})")

    def save_connection_report(self):
        """Placeholder: no report is written yet (``save_report``/``report_name`` are unused)."""
        # Implement similar to ReciprocalConnector if needed
        pass
|
|
1992
|
+
|
|
1607
1993
|
|
|
1608
1994
|
|
|
1609
1995
|
class OneToOneSequentialConnector(AbstractConnector):
|