superneuroabm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- superneuroabm/__init__.py +3 -0
- superneuroabm/component_base_config.yaml +129 -0
- superneuroabm/io/__init__.py +3 -0
- superneuroabm/io/nx.py +425 -0
- superneuroabm/io/synthetic_networks.py +770 -0
- superneuroabm/model.py +689 -0
- superneuroabm/step_functions/soma/izh.py +86 -0
- superneuroabm/step_functions/soma/lif.py +98 -0
- superneuroabm/step_functions/soma/lif_soma_adaptive_thr.py +111 -0
- superneuroabm/step_functions/synapse/single_exp.py +71 -0
- superneuroabm/step_functions/synapse/stdp/Low_resolution_synapse.py +117 -0
- superneuroabm/step_functions/synapse/stdp/Three-bit_exp_pair_wise.py +130 -0
- superneuroabm/step_functions/synapse/stdp/Three_bit_exp_pair_wise.py +133 -0
- superneuroabm/step_functions/synapse/stdp/exp_pair_wise_stdp.py +119 -0
- superneuroabm/step_functions/synapse/stdp/learning_rule_selector.py +72 -0
- superneuroabm/step_functions/synapse/util.py +49 -0
- superneuroabm/util.py +38 -0
- superneuroabm-1.0.0.dist-info/METADATA +100 -0
- superneuroabm-1.0.0.dist-info/RECORD +22 -0
- superneuroabm-1.0.0.dist-info/WHEEL +5 -0
- superneuroabm-1.0.0.dist-info/licenses/LICENSE +28 -0
- superneuroabm-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
soma:
|
|
2
|
+
lif_soma:
|
|
3
|
+
config_0:
|
|
4
|
+
hyperparameters:
|
|
5
|
+
C: 10e-9 # Membrane capacitance in Farads (10 nF)
|
|
6
|
+
R: 1e6 # Membrane resistance in Ohms (1 MΩ)
|
|
7
|
+
vthr: -45 # Spike threshold voltage (mV)
|
|
8
|
+
tref: 5e-3 # Refractory period (5 ms)
|
|
9
|
+
vrest: -60 # Resting potential (mV)
|
|
10
|
+
vreset: -60 # Reset potential after spike (mV)
|
|
11
|
+
tref_allows_integration: 1 # Whether to allow integration during refractory period
|
|
12
|
+
I_in: 0 # Constant injected input current (default 0, i.e. none)
|
|
13
|
+
scaling_factor: 1e-5 # Scaling factor for synaptic current
|
|
14
|
+
internal_state:
|
|
15
|
+
v: -60.0 # Initial membrane voltage
|
|
16
|
+
tcount: 0.0 # Time counter
|
|
17
|
+
tlast: 0.0 # Last spike time
|
|
18
|
+
izh_soma:
|
|
19
|
+
config_0: #intrinsic bursting
|
|
20
|
+
hyperparameters:
|
|
21
|
+
k: 1.2
|
|
22
|
+
vthr: -45
|
|
23
|
+
C: 150
|
|
24
|
+
a: 0.01
|
|
25
|
+
b: 5
|
|
26
|
+
vpeak: 50
|
|
27
|
+
vrest: -75
|
|
28
|
+
d: 130
|
|
29
|
+
vreset: -56
|
|
30
|
+
I_in: 420
|
|
31
|
+
internal_state:
|
|
32
|
+
v: -75 # Initial membrane voltage
|
|
33
|
+
u: 0 # Initial recovery variable
|
|
34
|
+
|
|
35
|
+
config_1: # regular spiking
|
|
36
|
+
hyperparameters:
|
|
37
|
+
k: 0.7
|
|
38
|
+
vthr: -40
|
|
39
|
+
C: 100
|
|
40
|
+
a: 0.03
|
|
41
|
+
b: -2
|
|
42
|
+
vpeak: 35
|
|
43
|
+
vrest: -60
|
|
44
|
+
d: 100
|
|
45
|
+
vreset: -50
|
|
46
|
+
I_in: 100
|
|
47
|
+
internal_state:
|
|
48
|
+
v: -60 # Initial membrane voltage
|
|
49
|
+
u: 0 # Initial recovery variable
|
|
50
|
+
|
|
51
|
+
lif_soma_adaptive_thr:
|
|
52
|
+
config_0:
|
|
53
|
+
hyperparameters:
|
|
54
|
+
C: 10e-9 # Membrane capacitance in Farads (10 nF)
|
|
55
|
+
R: 1e6 # Membrane resistance in Ohms (1 MΩ)
|
|
56
|
+
vthr_initial: -45 # Spike threshold voltage (mV)
|
|
57
|
+
tref: 5e-3 # Refractory period (5 ms)
|
|
58
|
+
vrest: -60 # Resting potential (mV)
|
|
59
|
+
vreset: -60 # Reset potential after spike (mV)
|
|
60
|
+
tref_allows_integration: 1 # Whether to allow integration during refractory period
|
|
61
|
+
I_in: 0 # Constant injected input current (default 0, i.e. none)
|
|
62
|
+
scaling_factor: 1e-5 # Scaling factor for synaptic current
|
|
63
|
+
delta_thr: 1.0 # Threshold increase after spike (mV)
|
|
64
|
+
tau_decay_thr: 30e-3 # Time constant for threshold decay (30 ms)
|
|
65
|
+
internal_state:
|
|
66
|
+
v: -60.0 # Initial membrane voltage
|
|
67
|
+
tcount: 0.0 # Time counter
|
|
68
|
+
tlast: 0.0 # Last spike time
|
|
69
|
+
vthr: -45.0 # Initial spike threshold voltage
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
synapse:
|
|
73
|
+
single_exp_synapse:
|
|
74
|
+
no_learning_config_0:
|
|
75
|
+
hyperparameters:
|
|
76
|
+
weight: 14.0 # Synaptic weight (strength)
|
|
77
|
+
synaptic_delay: 1.0 # Transmission delay (ms)
|
|
78
|
+
scale: 1.0 # Scaling factor
|
|
79
|
+
tau_fall: 1e-2 # Decay time constant (10 ms)
|
|
80
|
+
tau_rise: 0 # Rise time constant (instantaneous)
|
|
81
|
+
internal_state:
|
|
82
|
+
I_synapse: 0.0 # Initial synaptic current
|
|
83
|
+
learning_hyperparameters:
|
|
84
|
+
stdp_type: -1 # No learning
|
|
85
|
+
exp_pair_wise_stdp_config_0:
|
|
86
|
+
hyperparameters:
|
|
87
|
+
weight: 14.0 # Synaptic weight (strength)
|
|
88
|
+
synaptic_delay: 1.0 # Transmission delay (ms)
|
|
89
|
+
scale: 1.0 # Scaling factor
|
|
90
|
+
tau_fall: 1e-2 # Decay time constant (10 ms)
|
|
91
|
+
tau_rise: 0 # Rise time constant (instantaneous)
|
|
92
|
+
internal_state:
|
|
93
|
+
I_synapse: 0.0 # Initial synaptic current
|
|
94
|
+
learning_hyperparameters:
|
|
95
|
+
stdp_type: 0.0 # Exp pair-wise STDP
|
|
96
|
+
|
|
97
|
+
tau_pre_stdp: 10e-3 # Pre-synaptic STDP time constant (10 ms)
|
|
98
|
+
tau_post_stdp: 10e-3 # Post-synaptic STDP time constant (10 ms)
|
|
99
|
+
a_exp_pre: 0.005 # Pre-synaptic STDP learning rate
|
|
100
|
+
a_exp_post: 0.005 # Post-synaptic STDP learning rate
|
|
101
|
+
stdp_history_length: 100 # Length of STDP history buffer
|
|
102
|
+
internal_learning_state:
|
|
103
|
+
pre_trace: 0 # Pre-synaptic trace
|
|
104
|
+
post_trace: 0 # Post-synaptic trace
|
|
105
|
+
dW: 0 # Weight change accumulator
|
|
106
|
+
|
|
107
|
+
three_bit_exp_pair_wise_stdp_config_0:
|
|
108
|
+
hyperparameters:
|
|
109
|
+
weight: 14.0 # Synaptic weight (strength)
|
|
110
|
+
synaptic_delay: 1.0 # Transmission delay (ms)
|
|
111
|
+
scale: 1.0 # Scaling factor
|
|
112
|
+
tau_fall: 1e-2 # Decay time constant (10 ms)
|
|
113
|
+
tau_rise: 0 # Rise time constant (instantaneous)
|
|
114
|
+
internal_state:
|
|
115
|
+
I_synapse: 0.0 # Initial synaptic current
|
|
116
|
+
learning_hyperparameters:
|
|
117
|
+
stdp_type: 1.0 # Three-bit (quantized) exp pair-wise STDP
|
|
118
|
+
tau_pre_stdp: 10e-3 # Pre-synaptic STDP time constant (10 ms)
|
|
119
|
+
tau_post_stdp: 10e-3 # Post-synaptic STDP time constant (10 ms)
|
|
120
|
+
a_exp_pre: 0.005 # Pre-synaptic STDP learning rate
|
|
121
|
+
a_exp_post: 0.005 # Post-synaptic STDP learning rate
|
|
122
|
+
stdp_history_length: 100 # Length of STDP history buffer
|
|
123
|
+
wmin: 0.0 # Minimum synaptic weight
|
|
124
|
+
wmax: 24.0 # Maximum synaptic weight
|
|
125
|
+
num_levels: 8 # Number of quantization levels (3 bits)
|
|
126
|
+
internal_learning_state:
|
|
127
|
+
pre_trace: 0 # Pre-synaptic trace
|
|
128
|
+
post_trace: 0 # Post-synaptic trace
|
|
129
|
+
dW: 0 # Weight change accumulator
|
superneuroabm/io/nx.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Converts NetworkX graphs to and from SuperNeuroABM NeuromorphicModel networks, optionally partitioning neurons across MPI workers with METIS.
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import warnings
|
|
7
|
+
from typing import Dict, Optional
|
|
8
|
+
|
|
9
|
+
import networkx as nx
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from superneuroabm.model import NeuromorphicModel
|
|
13
|
+
from superneuroabm.util import load_component_configurations
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def generate_metis_partition(graph: nx.DiGraph, num_workers: int) -> Dict[int, int]:
    """
    Generate network partition using METIS for optimal agent-to-worker assignment.

    The partition groups connected nodes onto the same worker to reduce
    cross-worker communication, which can significantly improve multi-worker
    performance.

    Before calling METIS the reserved external-input node ``-1`` is removed,
    edge direction is ignored, and self-loops are dropped (a self-loop can
    never be cut, and METIS expects a graph without them). Partition ids
    returned by METIS are renumbered to a contiguous range starting at 0.

    Args:
        graph: NetworkX graph to partition. Node labels may be any hashable.
        num_workers: Number of MPI workers (passed to METIS as ``nparts``).

    Returns:
        Dictionary mapping node -> rank. Empty when the graph has no nodes
        other than ``-1``.

    Raises:
        ImportError: If metis is not installed
    """
    try:
        import metis
    except ImportError as err:
        raise ImportError(
            "METIS not installed."
        ) from err

    # Partition only real neurons; -1 is the reserved external-input node.
    filtered = graph.copy()
    if -1 in filtered:
        filtered.remove_node(-1)

    # METIS partitions undirected graphs.
    undirected = filtered.to_undirected() if filtered.is_directed() else filtered

    node_list = list(undirected.nodes())
    if not node_list:
        # Nothing to partition; do not hand METIS an empty graph.
        return {}
    node_to_idx = {node: idx for idx, node in enumerate(node_list)}

    # METIS adjacency-list input: one list of neighbour indices per vertex,
    # with self-loops removed.
    adjacency = [
        [node_to_idx[nbr] for nbr in undirected.neighbors(node) if nbr != node]
        for node in node_list
    ]

    print(f"[SuperNeuroABM] Running METIS partition with {num_workers} partitions...")
    _, parts = metis.part_graph(adjacency, nparts=num_workers, recursive=True)

    # Renumber partition ids to a contiguous 0-based range (the original
    # code notes METIS may start numbering at 1 when nparts=1).
    remap = {old_id: new_id for new_id, old_id in enumerate(sorted(set(parts)))}
    partition_dict = {node: int(remap[p]) for node, p in zip(node_list, parts)}

    # Partition quality, measured on the original (directed) edges between
    # partitioned nodes.
    total_edges = 0
    cross_worker_edges = 0
    for u, v in graph.edges():
        if u in partition_dict and v in partition_dict:
            total_edges += 1
            if partition_dict[u] != partition_dict[v]:
                cross_worker_edges += 1

    edge_cut_ratio = cross_worker_edges / total_edges if total_edges else 0

    print("[SuperNeuroABM] Partition quality:")
    print(f" - Edge cut ratio (P_cross): {edge_cut_ratio:.4f}")
    print(f" - Total edges: {total_edges}, Cross-worker edges: {cross_worker_edges}")

    return partition_dict
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def model_from_nx_graph(
    graph: nx.DiGraph,
    enable_internal_state_tracking: bool = True,
    partition_method: Optional[str] = None,
    partition_dict: Optional[Dict[int, int]] = None,
) -> NeuromorphicModel:
    """
    Load a NetworkX graph and create a NeuromorphicModel.

    Args:
        graph: A NetworkX DiGraph object.
            Vertices should have 'soma_breed' and 'config' attributes, and
            optionally 'overrides' and 'tags' attributes. The node -1 is
            reserved for external input: no soma is created for it, and an
            edge (-1, v) becomes a synapse with pre_soma_id=-1.
            Edges should have 'synapse_breed' and 'config' attributes, and
            optionally 'overrides' and 'tags' attributes. Edge 'overrides'
            may contain 'hyperparameters', 'internal_state',
            'learning_hyperparameters' and 'internal_learning_state' (the
            legacy key 'default_internal_learning_state' is still accepted).
        enable_internal_state_tracking: If True (default), tracks internal states history
            during simulation. If False, disables tracking to save memory and improve performance.
        partition_method: Partition method to use. Options:
            - None: No partitioning (default, round-robin assignment)
            - 'metis': Generate METIS partition (requires multiple MPI workers)
        partition_dict: Pre-computed partition dictionary mapping node_id -> rank.
            If provided, this overrides partition_method.

    Returns:
        A NeuromorphicModel object constructed from the graph.

    Raises:
        ValueError: If partition_method is neither None nor 'metis' and no
            partition_dict is given.
        KeyError: If an edge ends at a node that is not a soma (e.g. (u, -1)).
    """
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # NOTE(review): the loaded configurations are not used in this function;
    # the call is kept so configuration-loading errors still surface here.
    load_component_configurations()

    model = NeuromorphicModel(enable_internal_state_tracking=enable_internal_state_tracking)

    node_to_rank = _resolve_node_partition(
        graph, partition_method, partition_dict, comm, rank, size
    )

    if node_to_rank:
        # Partitioned mode: every rank must create agents in the same order so
        # that agent ids (handed out in creation order) match the
        # agent_id -> rank table computed here.
        soma_nodes = _sorted_soma_nodes(graph)
        agent_id_to_rank = _map_agents_to_ranks(graph, soma_nodes, node_to_rank, size)

        # Load this mapping directly into the model (no file needed)
        model._agent_factory._partition_mapping = agent_id_to_rank
        model._agent_factory._partition_loaded = True

        if rank == 0:
            print("[SuperNeuroABM] Converted node partition to agent_id partition")
            print(f"[SuperNeuroABM] Assigned {len(soma_nodes)}/{len(soma_nodes)} neurons, {graph.number_of_edges()} synapses")
            print(f"[SuperNeuroABM] Total agents with partition: {len(agent_id_to_rank)}")
    else:
        # No partition: graph iteration order, skipping the external node -1.
        soma_nodes = [node for node in graph.nodes() if node != -1]

    name2id = {}
    for node in soma_nodes:
        name2id[node] = _create_soma_from_node(model, node, graph.nodes[node])

    # Synapses are created in graph.edges() order, which is the order
    # _map_agents_to_ranks assumes when numbering synapse agents.
    for u, v, data in graph.edges(data=True):
        _create_synapse_from_edge(model, u, v, data, name2id)

    return model


def _resolve_node_partition(graph, partition_method, partition_dict, comm, rank, size):
    """Return the node -> rank partition to use, or None for round-robin assignment.

    Progress messages are printed on rank 0 only. For 'metis', rank 0 computes
    the partition and broadcasts it (or the exception it raised) to all ranks.
    """
    if partition_dict is not None:
        if rank == 0:
            print("[SuperNeuroABM] Using provided partition dictionary")
        return partition_dict

    if partition_method == "metis":
        if size == 1:
            if rank == 0:
                print("[SuperNeuroABM] Warning: METIS partition requested but running with single worker. Skipping partition.")
            return None

        result = None
        if rank == 0:
            print(f"\n{'='*60}")
            print(f"Multi-worker mode: {size} workers")
            print("Generating METIS partition for optimal performance...")
            print(f"{'='*60}\n")
            try:
                # Generate METIS partition (only rank 0)
                result = generate_metis_partition(graph, size)
            except Exception as err:
                # Forward the failure to the other ranks; otherwise they would
                # block forever in comm.bcast below.
                result = err

        # Broadcast partition (or failure) to all ranks
        result = comm.bcast(result, root=0)
        if isinstance(result, BaseException):
            raise result

        if rank == 0:
            print("[SuperNeuroABM] Partition generated and broadcast to all workers")
        return result

    if partition_method is None:
        if size > 1 and rank == 0:
            print(f"[SuperNeuroABM] Multi-worker mode ({size} workers) - using round-robin assignment")
            print("[SuperNeuroABM] Tip: Use partition_method='metis' for better performance")
        return None

    raise ValueError(f"Unknown partition_method: {partition_method}. Use None or 'metis'")


def _sorted_soma_nodes(graph):
    """Return the non-external nodes of ``graph`` in a deterministic order."""
    nodes = [node for node in graph.nodes() if node != -1]
    try:
        return sorted(nodes)
    except TypeError:
        # Mutually unorderable labels (e.g. ints mixed with strs): order by
        # repr so every MPI rank still agrees on the creation order.
        return sorted(nodes, key=repr)


def _map_agents_to_ranks(graph, soma_nodes, node_to_rank, size):
    """Translate a node -> rank partition into an agent_id -> rank table.

    Agent ids are numbered in creation order: one per entry of ``soma_nodes``,
    then one per edge in ``graph.edges()`` order (the order
    model_from_nx_graph creates agents in). Every agent id gets a rank.
    """
    agent_id_to_rank = {}

    # Neurons: use the partition, falling back to round-robin for any node
    # missing from it.
    for agent_id, node in enumerate(soma_nodes):
        agent_id_to_rank[agent_id] = node_to_rank.get(node, agent_id % size)

    # Synapses: keep each one with its pre-synaptic neuron; external inputs
    # (u == -1) follow their post-synaptic neuron instead. Comparing against
    # the -1 sentinel (not `u >= 0`) keeps non-integer labels working.
    for agent_id, (u, v) in enumerate(graph.edges(), start=len(soma_nodes)):
        if u != -1 and u in node_to_rank:
            agent_id_to_rank[agent_id] = node_to_rank[u]
        elif v in node_to_rank:
            agent_id_to_rank[agent_id] = node_to_rank[v]
        else:
            # Fallback to round-robin (shouldn't happen with proper partition)
            agent_id_to_rank[agent_id] = agent_id % size

    return agent_id_to_rank


def _create_soma_from_node(model, node, data):
    """Create the soma described by node ``node``'s attributes ``data``; return its id."""
    overrides = data.get("overrides") or {}
    tags = set(data.get("tags") or [])
    tags.add(f"nx_node:{node}")
    return model.create_soma(
        breed=data.get("soma_breed"),
        config_name=data.get("config", "config_0"),
        hyperparameters_overrides=overrides.get("hyperparameters"),
        default_internal_state_overrides=overrides.get("internal_state"),
        tags=tags,
    )


def _create_synapse_from_edge(model, u, v, data, name2id):
    """Create the synapse described by edge (u, v) with attributes ``data``.

    A ``u`` missing from ``name2id`` (e.g. the external node -1) becomes
    pre_soma_id=-1; ``v`` must be a created soma, otherwise KeyError is raised.
    """
    overrides = data.get("overrides") or {}
    tags = set(data.get("tags") or [])
    tags.add(f"nx_edge:{u}_to_{v}")

    # nx_graph_from_model and the YAML component configs use the key
    # 'internal_learning_state'; the older 'default_internal_learning_state'
    # spelling is still honoured for existing graphs.
    learning_state = overrides.get("internal_learning_state")
    if learning_state is None:
        learning_state = overrides.get("default_internal_learning_state")

    model.create_synapse(
        breed=data.get("synapse_breed"),
        pre_soma_id=name2id.get(u, -1),  # External input if not found
        post_soma_id=name2id[v],
        config_name=data.get("config", "config_0"),
        hyperparameters_overrides=overrides.get("hyperparameters"),
        default_internal_state_overrides=overrides.get("internal_state"),
        learning_hyperparameters_overrides=overrides.get("learning_hyperparameters"),
        default_internal_learning_state_overrides=learning_state,
        tags=tags,
    )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def nx_graph_from_model(
    model: NeuromorphicModel, override_internal_state: bool = True
) -> nx.DiGraph:
    """
    Convert a NeuromorphicModel to a NetworkX graph.

    Each soma becomes a node keyed by its soma id, with 'soma_breed', 'config'
    and 'overrides' attributes. Each synapse becomes an edge
    (pre_soma_id, post_soma_id) with 'synapse_breed', 'config' and 'overrides'
    attributes. This is the attribute layout model_from_nx_graph reads.

    Args:
        model: A NeuromorphicModel object.
        override_internal_state: If True, adds overrides of internal_state
            and internal_learning_state with post simulation internal_state
            and internal_learning_state. If False, those two keys are removed
            from every agent's overrides.

    Returns:
        A NetworkX DiGraph representing the model.

    Warns:
        RuntimeWarning: If several synapses share the same (pre, post) pair.
            A DiGraph holds one edge per pair, so only the last such synapse
            survives in the returned graph.
    """
    graph = nx.DiGraph()

    # Add nodes for somas
    for soma_id in model.get_agents_with_tag("soma"):
        graph.add_node(
            soma_id,
            soma_breed=model.get_agent_breed(soma_id).name,
            config=model.get_agent_config_name(soma_id),
            overrides=_export_overrides(model, soma_id, override_internal_state),
        )

    # Add edges for synapses
    for synapse_id in model.get_agents_with_tag("synapse"):
        pre_soma_id, post_soma_id = model.get_synapse_connectivity(synapse_id)
        if graph.has_edge(pre_soma_id, post_soma_id):
            # add_edge would silently overwrite the earlier synapse's attributes.
            warnings.warn(
                f"Multiple synapses connect {pre_soma_id} -> {post_soma_id}; "
                f"nx.DiGraph keeps one edge per pair, so synapse {synapse_id} "
                "replaces the previously exported one.",
                RuntimeWarning,
                stacklevel=2,
            )
        graph.add_edge(
            pre_soma_id,
            post_soma_id,
            synapse_breed=model.get_agent_breed(synapse_id).name,
            config=model.get_agent_config_name(synapse_id),
            overrides=_export_overrides(model, synapse_id, override_internal_state),
        )

    return graph


def _export_overrides(model, agent_id, include_internal_state):
    """Return a copy of the agent's config diff, optionally without internal-state keys."""
    diff = model.get_agent_config_diff(agent_id)
    # Work on a shallow copy so stripping keys never mutates a dict the model
    # may still hold a reference to.
    overrides = dict(diff) if diff else {}
    if not include_internal_state:
        overrides.pop("internal_state", None)
        overrides.pop("internal_learning_state", None)
    return overrides
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
if __name__ == "__main__":
    # Example usage: an external input (-1) drives LIF soma "A", which in turn
    # drives Izhikevich soma "B". The model is simulated and then exported
    # back to NetworkX twice (with and without internal-state overrides).
    demo = nx.DiGraph()
    demo.add_nodes_from(
        [
            (
                "A",
                {
                    "soma_breed": "lif_soma",
                    "config": "config_0",
                    "overrides": {
                        "hyperparameters": {"R": 1.1e6},
                        "internal_state": {"v": -60.01},
                    },
                },
            ),
            (
                "B",
                {
                    "soma_breed": "izh_soma",
                    "config": "config_0",
                    "overrides": {
                        "hyperparameters": {"a": 0.0102, "b": 5.001},
                        "internal_state": {"v": -75.002},
                    },
                },
            ),
        ]
    )

    # -1 indicates external synapse; both synapses share the same settings.
    for pre, post in ((-1, "A"), ("A", "B")):
        demo.add_edge(
            pre,
            post,
            synapse_breed="single_exp_synapse",
            config="no_learning_config_0",
            overrides={"hyperparameters": {"weight": 13.5}},
        )

    model = model_from_nx_graph(demo)
    model.setup(use_gpu=True)

    # Inject a spike at tick 10 into the (single) external input synapse.
    model.add_spike(
        synapse_id=model.get_agents_with_tag("input_synapse").pop(),
        tick=10,
        value=1.0,
    )

    model.simulate(ticks=200, update_data_ticks=200)

    # Report spike times per soma.
    for soma in model.get_agents_with_tag("soma"):
        print(f"Soma {soma} spikes: {model.get_spike_times(soma)}")

    rule = "---------------------------------------------------------------"

    def _dump(exported_graph, heading):
        """Print a heading followed by every node and edge with its attributes."""
        print(heading)
        for n, attrs in exported_graph.nodes(data=True):
            print(f"Node {n}: {attrs}")
        for src, dst, attrs in exported_graph.edges(data=True):
            print(f"Edge {src} -> {dst}: {attrs}")

    exported = nx_graph_from_model(model)
    print(rule)
    _dump(exported, "Graph structure (override internal states):")
    print(rule)

    print("\n")

    print(rule)
    exported = nx_graph_from_model(model, override_internal_state=False)
    _dump(exported, "Graph structure (do not override internal states):")
    print(rule)
|