superneuroabm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- superneuroabm/__init__.py +3 -0
- superneuroabm/component_base_config.yaml +129 -0
- superneuroabm/io/__init__.py +3 -0
- superneuroabm/io/nx.py +425 -0
- superneuroabm/io/synthetic_networks.py +770 -0
- superneuroabm/model.py +689 -0
- superneuroabm/step_functions/soma/izh.py +86 -0
- superneuroabm/step_functions/soma/lif.py +98 -0
- superneuroabm/step_functions/soma/lif_soma_adaptive_thr.py +111 -0
- superneuroabm/step_functions/synapse/single_exp.py +71 -0
- superneuroabm/step_functions/synapse/stdp/Low_resolution_synapse.py +117 -0
- superneuroabm/step_functions/synapse/stdp/Three-bit_exp_pair_wise.py +130 -0
- superneuroabm/step_functions/synapse/stdp/Three_bit_exp_pair_wise.py +133 -0
- superneuroabm/step_functions/synapse/stdp/exp_pair_wise_stdp.py +119 -0
- superneuroabm/step_functions/synapse/stdp/learning_rule_selector.py +72 -0
- superneuroabm/step_functions/synapse/util.py +49 -0
- superneuroabm/util.py +38 -0
- superneuroabm-1.0.0.dist-info/METADATA +100 -0
- superneuroabm-1.0.0.dist-info/RECORD +22 -0
- superneuroabm-1.0.0.dist-info/WHEEL +5 -0
- superneuroabm-1.0.0.dist-info/licenses/LICENSE +28 -0
- superneuroabm-1.0.0.dist-info/top_level.txt +1 -0
superneuroabm/step_functions/soma/izh.py

@@ -0,0 +1,86 @@
+"""
+Izhikevich neuron and weighted synapse step functions for spiking neural networks
+
+"""
+
+import cupy as cp
+from cupyx import jit
+
+
+@jit.rawkernel(device="cuda")
+def izh_soma_step_func(
+    tick,
+    agent_index,
+    globals,
+    agent_ids,
+    breeds,
+    locations,
+    neuron_params,  # k, vthr, C, a, b, vpeak, vrest, d, vreset, I_in
+    learning_params,
+    internal_state,  # v, u
+    internal_learning_state,
+    synapse_history,  # synapse delay
+    input_spikes_tensor,  # input spikes
+    output_spikes_tensor,
+    internal_states_buffer,
+    internal_learning_states_buffer,
+):
+    synapse_indices = locations[agent_index]  # network location is defined by neighbors
+
+    I_synapse = 0.0
+
+    # synapse_indices contains pre-computed local indices (converted in SAGESim),
+    # so no linear search is needed.
+    for i in range(len(synapse_indices)):
+        synapse_index = int(synapse_indices[i])
+        if synapse_index >= 0 and not cp.isnan(synapse_indices[i]):
+            I_synapse += internal_state[synapse_index][0]
+
+    # Get the current time step value:
+    t_current = int(tick)
+    dt = globals[0]  # time step size
+    I_bias = globals[1]  # bias current
+
+    # NOTE: neuron_params would need to be as long as the max number of params in any spiking neuron model
+    k = neuron_params[agent_index][0]
+    vthr = neuron_params[agent_index][1]
+    C = neuron_params[agent_index][2]
+    a = neuron_params[agent_index][3]
+    b = neuron_params[agent_index][4]
+    vpeak = neuron_params[agent_index][5]
+    vrest = neuron_params[agent_index][6]
+    d = neuron_params[agent_index][7]
+    vreset = neuron_params[agent_index][8]
+    I_in = neuron_params[agent_index][9]
+
+    # From https://www.izhikevich.org/publications/spikes.htm
+    # v' = 0.04v^2 + 5v + 140 - u + I
+    # u' = a(bv - u)
+    # if v >= 30 mV: v = c, u = u + d, spike
+
+    # Quadratic form used here: dv = (k*(v - vrest)*(v - vthr) - u + I) / C
+    # internal_state: [0] - v, [1] - u
+    # NOTE: size of internal_state would need to be set as the maximum possible state variables of any spiking neuron
+
+    v = internal_state[agent_index][0]
+    u = internal_state[agent_index][1]
+
+    dv = (k * (v - vrest) * (v - vthr) - u + I_synapse + I_bias + I_in) / C
+    v = v + dt * dv * 1e3
+
+    u += dt * 1e3 * (a * (b * (v - vrest) - u))
+    # s = 1 * (v >= vthr)  # output spike
+    s = 1 * (v >= vpeak)  # output spike
+    u = u + d * s  # if spiked, update recovery variable
+    v = v * (1 - s) + vreset * s  # if spiked, reset membrane potential
+
+    internal_state[agent_index][0] = v
+    internal_state[agent_index][1] = u
+
+    output_spikes_tensor[agent_index][t_current] = s
+
+    # Safe buffer indexing: use modulo to prevent out-of-bounds access.
+    # When tracking is disabled, buffer length is 1, so t_current % 1 == 0 always.
+    buffer_idx = t_current % len(internal_states_buffer[agent_index])
+    internal_states_buffer[agent_index][buffer_idx][0] = v
+    internal_states_buffer[agent_index][buffer_idx][1] = u
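
The kernel above is a forward-Euler discretization of Izhikevich's quadratic model, with `dt` in seconds and the `1e3` factors converting the update into the model's millisecond time base. As a host-side reference (a minimal sketch, not part of the package; the constants are the commonly cited regular-spiking values for the quadratic formulation, chosen only for illustration), the same update can be run in plain NumPy:

import numpy as np

# Illustrative regular-spiking constants for the quadratic Izhikevich model;
# units follow the kernel, where dt is in seconds and updates are scaled by 1e3.
k, vthr, C = 0.7, -40.0, 100.0
a, b, d = 0.03, -2.0, 100.0
vpeak, vrest, vreset = 35.0, -60.0, -50.0
dt = 1e-4  # 0.1 ms

def izh_step(v, u, I):
    """One forward-Euler step mirroring izh_soma_step_func."""
    dv = (k * (v - vrest) * (v - vthr) - u + I) / C
    v = v + dt * dv * 1e3
    u = u + dt * 1e3 * (a * (b * (v - vrest) - u))
    s = 1 * (v >= vpeak)          # spike when membrane potential crosses vpeak
    u = u + d * s                 # bump recovery variable on spike
    v = v * (1 - s) + vreset * s  # reset membrane potential on spike
    return v, u, s

v, u = vrest, 0.0
spikes = []
for _ in range(20000):            # 2 s of simulated time under constant drive
    v, u, s = izh_step(v, u, I=70.0)
    spikes.append(s)
print("spike count:", int(np.sum(spikes)))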
superneuroabm/step_functions/soma/lif.py

@@ -0,0 +1,98 @@
+"""
+LIF neuron and weighted synapse step functions for spiking neural networks
+
+"""
+
+from cupyx import jit
+import cupy as cp
+
+
+@jit.rawkernel(device="cuda")
+def lif_soma_step_func(  # NOTE: rename from neuron_step_func to soma_step_func
+    tick,
+    agent_index,
+    globals,
+    agent_ids,
+    breeds,
+    locations,
+    neuron_params,  # C, R, vthr, tref, vrest, vreset, ...
+    learning_params,
+    internal_state,  # v, tcount, tlast
+    internal_learning_state,
+    synapse_history,  # synapse delay
+    input_spikes_tensor,  # input spikes
+    output_spikes_tensor,
+    internal_states_buffer,
+    internal_learning_states_buffer,
+):
+    synapse_indices = locations[agent_index]  # now contains local indices instead of IDs
+
+    I_synapse = 0.0
+
+    # synapse_indices contains pre-computed local indices (converted in SAGESim),
+    # so no linear search is needed.
+    for i in range(len(synapse_indices)):
+        synapse_index = int(synapse_indices[i])
+        if synapse_index >= 0 and not cp.isnan(synapse_indices[i]):
+            I_synapse += internal_state[synapse_index][0]
+
+    # Get the current time step value:
+    t_current = int(tick)  # TODO: check whether tcount is needed or whether tick can be used directly
+    dt = globals[0]  # time step size
+    I_bias = globals[1]  # bias current
+
+    # NOTE: neuron_params would need to be as long as the max number of params in any spiking neuron model
+    # Neuron parameters
+    C = neuron_params[agent_index][0]  # membrane capacitance
+    R = neuron_params[agent_index][1]  # leak resistance
+    vthr = neuron_params[agent_index][2]  # spike threshold
+    tref = neuron_params[agent_index][3]  # refractory period
+    vrest = neuron_params[agent_index][4]  # resting potential
+    vreset = neuron_params[agent_index][5]  # reset potential
+    tref_allows_integration = neuron_params[agent_index][6]  # whether to allow integration during the refractory period
+    I_in = neuron_params[agent_index][7]  # input current
+    scaling_factor = neuron_params[agent_index][8]  # scaling factor for synaptic current
+
+    # NOTE: size of internal_state would need to be set as the maximum possible state variables of any spiking neuron
+    # Internal state variables
+    v = internal_state[agent_index][0]  # membrane potential
+    tcount = internal_state[agent_index][1]  # time step count from the start of the simulation
+    tlast = internal_state[agent_index][2]  # last spike time
+
+    # Calculate the membrane potential update
+    dv = (vrest - v) / (R * C) + (I_synapse * scaling_factor + I_bias + I_in) / C
+
+    v += (
+        (dv * dt)
+        if ((dt * tcount) > (tlast + tref)) or tref_allows_integration
+        else 0.0
+    )
+
+    # An output spike happens only if the membrane potential exceeds the threshold
+    # and the neuron is not in its refractory period.
+    s = 1.0 * ((v >= vthr) and ((dt * tcount > tlast + tref) if tlast > 0 else True))
+
+    tlast = tlast * (1 - s) + dt * tcount * s
+    v = v * (1 - s) + vreset * s  # if spiked, reset membrane potential
+
+    internal_state[agent_index][0] = v
+    internal_state[agent_index][1] += 1
+    internal_state[agent_index][2] = tlast
+
+    output_spikes_tensor[agent_index][t_current] = s
+
+    # Safe buffer indexing: use modulo to prevent out-of-bounds access.
+    # When tracking is disabled, buffer length is 1, so t_current % 1 == 0 always.
+    buffer_idx = t_current % len(internal_states_buffer[agent_index])
+    internal_states_buffer[agent_index][buffer_idx][0] = v
+    internal_states_buffer[agent_index][buffer_idx][1] = internal_state[agent_index][1] + 1
+    internal_states_buffer[agent_index][buffer_idx][2] = tlast
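
The LIF kernel integrates `dv = (vrest - v)/(R*C) + I/C`, freezes integration during the refractory window unless `tref_allows_integration` is set, and emits a spike only when `v >= vthr` outside that window. A minimal host-side sketch of the same logic (illustrative constants, not package defaults):

import numpy as np

# Illustrative constants (not from the package): R*C gives a 10 ms membrane
# time constant; times are in seconds to match dt * tcount in the kernel.
C, R = 1.0, 0.01             # capacitance and leak resistance (tau_m = R*C = 10 ms)
vrest, vthr, vreset = 0.0, 1.0, 0.0
tref = 2e-3                  # 2 ms refractory period
dt = 1e-4

def lif_step(v, tcount, tlast, I, tref_allows_integration=False):
    """One step mirroring lif_soma_step_func."""
    dv = (vrest - v) / (R * C) + I / C
    out_of_refractory = (dt * tcount) > (tlast + tref)
    if out_of_refractory or tref_allows_integration:
        v += dv * dt
    # spike only above threshold and outside the refractory window
    s = 1.0 * ((v >= vthr) and (out_of_refractory if tlast > 0 else True))
    tlast = tlast * (1 - s) + dt * tcount * s
    v = v * (1 - s) + vreset * s
    return v, tcount + 1, tlast, s

v, tcount, tlast = vrest, 0, 0.0
n_spikes = 0
for _ in range(10000):       # 1 s of constant suprathreshold drive
    v, tcount, tlast, s = lif_step(v, tcount, tlast, I=150.0)
    n_spikes += int(s)
print("spikes in 1 s:", n_spikes)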
superneuroabm/step_functions/soma/lif_soma_adaptive_thr.py

@@ -0,0 +1,111 @@
+"""
+Adaptive-threshold LIF neuron and weighted synapse step functions for spiking neural networks
+
+"""
+
+from cupyx import jit
+import cupy as cp
+
+
+@jit.rawkernel(device="cuda")
+def lif_soma_adaptive_thr_step_func(  # NOTE: rename from neuron_step_func to soma_step_func
+    tick,
+    agent_index,
+    globals,
+    agent_ids,
+    breeds,
+    locations,
+    neuron_params,  # C, R, vthr_initial, tref, vrest, vreset, ...
+    learning_params,
+    internal_state,  # v, tcount, tlast, vthr
+    internal_learning_state,
+    synapse_history,  # synapse delay
+    input_spikes_tensor,  # input spikes
+    output_spikes_tensor,
+    internal_states_buffer,
+    internal_learning_states_buffer,
+):
+    synapse_indices = locations[agent_index]  # now contains local indices instead of IDs
+
+    I_synapse = 0.0
+
+    # synapse_indices contains pre-computed local indices (converted in SAGESim),
+    # so no linear search is needed.
+    for i in range(len(synapse_indices)):
+        synapse_index = int(synapse_indices[i])
+        if synapse_index >= 0 and not cp.isnan(synapse_indices[i]):
+            I_synapse += internal_state[synapse_index][0]
+
+    # Get the current time step value:
+    t_current = int(tick)  # TODO: check whether tcount is needed or whether tick can be used directly
+    dt = globals[0]  # time step size
+    I_bias = globals[1]  # bias current
+
+    # NOTE: neuron_params would need to be as long as the max number of params in any spiking neuron model
+    # Neuron parameters
+    C = neuron_params[agent_index][0]  # membrane capacitance
+    R = neuron_params[agent_index][1]  # leak resistance
+    vthr_initial = neuron_params[agent_index][2]  # initial spike threshold
+    tref = neuron_params[agent_index][3]  # refractory period
+    vrest = neuron_params[agent_index][4]  # resting potential
+    vreset = neuron_params[agent_index][5]  # reset potential
+    tref_allows_integration = neuron_params[agent_index][6]  # whether to allow integration during the refractory period
+    I_in = neuron_params[agent_index][7]  # input current
+    scaling_factor = neuron_params[agent_index][8]  # scaling factor for synaptic current
+    # Adaptive threshold parameters
+    delta_thr = neuron_params[agent_index][9]  # threshold increment
+    tau_decay_thr = neuron_params[agent_index][10]  # threshold decay time constant
+
+    # NOTE: size of internal_state would need to be set as the maximum possible state variables of any spiking neuron
+    # Internal state variables
+    v = internal_state[agent_index][0]  # membrane potential
+    tcount = internal_state[agent_index][1]  # time step count from the start of the simulation
+    tlast = internal_state[agent_index][2]  # last spike time
+    vthr = internal_state[agent_index][3]  # spike threshold (updated value)
+
+    # Calculate the membrane potential update
+    dv = (vrest - v) / (R * C) + (I_synapse * scaling_factor + I_bias + I_in) / C
+
+    v += (
+        (dv * dt)
+        if ((dt * tcount) > (tlast + tref)) or tref_allows_integration
+        else 0.0
+    )
+
+    # An output spike happens only if the membrane potential exceeds the threshold
+    # and the neuron is not in its refractory period.
+    s = 1.0 * ((v >= vthr) and ((dt * tcount > tlast + tref) if tlast > 0 else True))
+
+    # -----------------------------
+    # ADAPTIVE THRESHOLD UPDATE
+    # -----------------------------
+    if s == 1.0:
+        # Increase threshold after a spike
+        vthr += delta_thr
+    else:
+        # Exponential decay: vthr = vthr_initial + (vthr - vthr_initial) * exp(-dt / tau_decay_thr)
+        vthr = vthr_initial + (vthr - vthr_initial) * cp.exp(-dt / tau_decay_thr)
+
+    tlast = tlast * (1 - s) + dt * tcount * s
+    v = v * (1 - s) + vreset * s  # if spiked, reset membrane potential
+
+    internal_state[agent_index][0] = v
+    internal_state[agent_index][1] += 1
+    internal_state[agent_index][2] = tlast
+    # Write back the updated threshold state
+    internal_state[agent_index][3] = vthr
+
+    output_spikes_tensor[agent_index][t_current] = s
+
+    # Safe buffer indexing: use modulo to prevent out-of-bounds access.
+    # When tracking is disabled, buffer length is 1, so t_current % 1 == 0 always.
+    buffer_idx = t_current % len(internal_states_buffer[agent_index])
+    internal_states_buffer[agent_index][buffer_idx][0] = v
+    internal_states_buffer[agent_index][buffer_idx][1] = internal_state[agent_index][1] + 1
+    internal_states_buffer[agent_index][buffer_idx][2] = tlast
+    internal_states_buffer[agent_index][buffer_idx][3] = vthr
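
The only difference from the plain LIF kernel is the threshold update: a spike raises `vthr` by `delta_thr`, and otherwise `vthr` relaxes exponentially back toward `vthr_initial` with time constant `tau_decay_thr`. That update rule in isolation (a minimal sketch with illustrative values, not package defaults):

import numpy as np

# Illustrative values (not from the package): the threshold jumps by delta_thr
# on a spike and relaxes back to vthr_initial with time constant tau_decay_thr.
vthr_initial, delta_thr, tau_decay_thr = 1.0, 0.5, 0.05
dt = 1e-3

def adaptive_thr_step(vthr, spiked):
    if spiked:
        return vthr + delta_thr
    # exponential decay toward the initial threshold:
    # vthr(t+dt) = vthr_initial + (vthr(t) - vthr_initial) * exp(-dt / tau)
    return vthr_initial + (vthr - vthr_initial) * np.exp(-dt / tau_decay_thr)

vthr = vthr_initial
trace = []
for t in range(200):
    vthr = adaptive_thr_step(vthr, spiked=(t in (10, 20, 30)))  # three rapid spikes
    trace.append(vthr)
print(f"peak threshold {max(trace):.3f}, final {trace[-1]:.3f}")  # rises, then decays back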
superneuroabm/step_functions/synapse/single_exp.py

@@ -0,0 +1,71 @@
+"""
+Single exponential synapse step functions for spiking neural networks
+
+"""
+
+import cupy as cp
+import numpy as np
+from cupyx import jit
+
+from superneuroabm.step_functions.synapse.util import get_soma_spike
+from sagesim.utils import get_neighbor_data_from_tensor
+
+
+@jit.rawkernel(device="cuda")
+def synapse_single_exp_step_func(
+    tick,
+    agent_index,
+    globals,
+    agent_ids,
+    breeds,
+    locations,
+    synapse_params,  # weight, delay, scale, time constants (tau_fall and tau_rise)
+    learning_params,
+    internal_state,  # I_synapse
+    internal_learning_state,
+    synapse_history,  # delay
+    input_spikes_tensor,  # input spikes
+    output_spikes_tensor,
+    internal_states_buffer,
+    internal_learning_states_buffer,
+):
+    t_current = int(tick)
+
+    dt = globals[0]  # time step size
+
+    weight = synapse_params[agent_index][0]
+    synaptic_delay = synapse_params[agent_index][1]
+    scale = synapse_params[agent_index][2]
+    tau_fall = synapse_params[agent_index][3]
+    tau_rise = synapse_params[agent_index][4]
+
+    # locations[agent_index] = [pre_soma_index, post_soma_index]
+    # SAGESim has already converted agent IDs to local indices
+    pre_soma_index, post_soma_index = (
+        locations[agent_index][0],
+        locations[agent_index][1],
+    )
+
+    spike = get_soma_spike(
+        tick,
+        agent_index,
+        globals,
+        agent_ids,
+        pre_soma_index,
+        t_current,
+        input_spikes_tensor,
+        output_spikes_tensor,
+    )
+
+    I_synapse = internal_state[agent_index][0]
+
+    I_synapse = I_synapse * (1 - dt / tau_fall) + spike * scale * weight
+
+    internal_state[agent_index][0] = I_synapse
+
+    # Safe buffer indexing: use modulo to prevent out-of-bounds access.
+    # When tracking is disabled, buffer length is 1, so t_current % 1 == 0 always.
+    buffer_idx = t_current % len(internal_states_buffer[agent_index])
+    internal_states_buffer[agent_index][buffer_idx][0] = I_synapse
+    internal_states_buffer[agent_index][buffer_idx][1] = spike
+    internal_states_buffer[agent_index][buffer_idx][2] = t_current
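
Each tick the synaptic current decays by a factor `(1 - dt/tau_fall)` and jumps by `spike * scale * weight`, i.e. a forward-Euler step of a single-exponential kernel (`tau_rise` is read but unused here). A minimal sketch of the resulting current trace (illustrative values, not package defaults):

# Illustrative values (not from the package): forward-Euler update of
#   dI/dt = -I / tau_fall + impulse on each presynaptic spike,
# which is what I = I*(1 - dt/tau_fall) + spike*scale*weight implements per tick.
dt, tau_fall = 1e-3, 10e-3
scale, weight = 1.0, 0.8

I = 0.0
trace = []
for t in range(100):
    spike = 1.0 if t == 5 else 0.0  # a single presynaptic spike at tick 5
    I = I * (1 - dt / tau_fall) + spike * scale * weight
    trace.append(I)

# After the spike, I jumps to scale*weight and decays by (1 - dt/tau_fall) per tick.
print(f"peak {max(trace):.3f}, ten ticks later {trace[15]:.3f}")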
superneuroabm/step_functions/synapse/stdp/Low_resolution_synapse.py

@@ -0,0 +1,117 @@
+import math
+
+import cupy as cp
+import numpy as np
+from cupyx import jit
+
+from superneuroabm.step_functions.synapse.util import get_pre_soma_spike
+
+
+@jit.rawkernel(device="cuda")
+def synapse_single_exp_step_func(
+    tick,
+    agent_index,
+    globals,
+    agent_ids,
+    breeds,
+    locations,
+    synapse_params,  # layout modified to add learning sigmas
+    learning_params,  # adds sigmas for learning noise
+    internal_state,  # I_synapse
+    internal_learning_state,  # adds w_eff, flags, rng state
+    synapse_history,
+    input_spikes_tensor,
+    output_spikes_tensor,
+    internal_states_buffer,
+    internal_learning_states_buffer,
+):
+    # --- tiny device helpers (inlined for rawkernel style) --------------------
+    def _xorshift32(state_u32):
+        # returns (new_state, uint32)
+        x = state_u32 & 0xFFFFFFFF
+        x ^= (x << 13) & 0xFFFFFFFF
+        x ^= (x >> 17) & 0xFFFFFFFF
+        x ^= (x << 5) & 0xFFFFFFFF
+        return x, x
+
+    def _u01_from_u32(u32):
+        # map to (0, 1); avoid exact 0
+        return (u32 + 1.0) * (1.0 / 4294967297.0)
+
+    def _randn(state_u32):
+        # Box-Muller using two uniforms
+        s, u1_i = _xorshift32(state_u32)
+        s, u2_i = _xorshift32(s)
+        u1 = _u01_from_u32(u1_i)
+        u2 = _u01_from_u32(u2_i)
+        # sqrt(-2 ln u1) * cos(2 pi u2)
+        r = math.sqrt(-2.0 * math.log(u1))
+        z = r * math.cos(6.283185307179586 * u2)
+        return s, z  # new_state, standard normal
+    # --------------------------------------------------------------------------
+
+    t_current = int(tick)
+    dt = globals[0]
+
+    # ---- read params ---------------------------------------------------------
+    w_nom = synapse_params[agent_index][0]
+    synaptic_delay = synapse_params[agent_index][1]
+    scale = synapse_params[agent_index][2]
+    tau_fall = synapse_params[agent_index][3]
+    tau_rise = synapse_params[agent_index][4]
+
+    # per-synapse standard deviations for weight noise
+    sigma_prog = synapse_params[agent_index][5]  # standard deviation for initial programming
+    sigma_stdp = synapse_params[agent_index][6]  # standard deviation for STDP updates
+
+    # ---- persistent learning state -------------------------------------------
+    w_eff = internal_learning_state[agent_index][0]
+    is_programmed = int(internal_learning_state[agent_index][1])  # bool as int
+    rng_state_u32 = int(internal_learning_state[agent_index][2])
+    needs_reprog = int(internal_learning_state[agent_index][3])
+
+    # seed rng at first use (deterministic per synapse)
+    if rng_state_u32 == 0:
+        # simple seed from agent_index (offset avoids zero)
+        rng_state_u32 = (1664525 * (agent_index + 1) + 1013904223) & 0xFFFFFFFF
+        if rng_state_u32 == 0:
+            rng_state_u32 = 123456789
+
+    # initial programming (once)
+    if is_programmed == 0:
+        rng_state_u32, z = _randn(rng_state_u32)
+        # choose an additive or multiplicative noise model:
+        # (a) additive: w_eff = w_nom + sigma_prog * z
+        # (b) multiplicative: w_eff = w_nom * (1.0 + sigma_prog * z)
+        w_eff = w_nom + sigma_prog * z  # additive; Katie's STDP is multiplicative in the original code
+        is_programmed = 1
+        needs_reprog = 0
+
+    # reprogram after STDP changed the nominal weight
+    if needs_reprog == 1:
+        rng_state_u32, z = _randn(rng_state_u32)
+        # apply STDP programming noise
+        w_eff = w_nom + sigma_stdp * z  # or multiplicative form
+        needs_reprog = 0
+
+    # store back persistent learning state
+    internal_learning_state[agent_index][0] = w_eff
+    internal_learning_state[agent_index][1] = is_programmed
+    internal_learning_state[agent_index][2] = rng_state_u32
+    internal_learning_state[agent_index][3] = needs_reprog
+
+    # ---- spikes and current update (unchanged except the weight source) -------
+    location_data = locations[agent_index]
+    pre_soma_id = -1 if cp.isnan(location_data[1]) else location_data[0]
+    post_soma_id = location_data[0] if cp.isnan(location_data[1]) else location_data[1]
+
+    spike = get_pre_soma_spike(
+        tick, agent_index, globals, agent_ids,
+        pre_soma_id, t_current, input_spikes_tensor, output_spikes_tensor
+    )
+
+    I_synapse = internal_state[agent_index][0]
+    # use w_eff (constant between programming events)
+    I_synapse = I_synapse * (1 - dt / tau_fall) + spike * scale * w_eff
+
+    internal_state[agent_index][0] = I_synapse
+    # Safe buffer indexing: use modulo to prevent out-of-bounds access,
+    # matching the other step functions (buffer length is 1 when tracking is disabled).
+    buffer_idx = t_current % len(internal_states_buffer[agent_index])
+    internal_states_buffer[agent_index][buffer_idx][0] = I_synapse
+    internal_states_buffer[agent_index][buffer_idx][1] = spike
+    internal_states_buffer[agent_index][buffer_idx][2] = t_current
+    internal_states_buffer[agent_index][buffer_idx][3] = pre_soma_id
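
The inlined `_xorshift32`/`_randn` helpers implement a deterministic per-synapse normal generator via Box-Muller. A host-side check (a sketch, not part of the package) that the construction behaves like a standard normal source:

import math

def xorshift32(x):
    """One xorshift32 step; mirrors the _xorshift32 device helper."""
    x &= 0xFFFFFFFF
    x ^= (x << 13) & 0xFFFFFFFF
    x ^= (x >> 17) & 0xFFFFFFFF
    x ^= (x << 5) & 0xFFFFFFFF
    return x

def randn(state):
    """Box-Muller from two xorshift32 uniforms; mirrors the _randn helper."""
    state = xorshift32(state)
    u1 = (state + 1.0) / 4294967297.0  # map uint32 into (0, 1), avoiding 0
    state = xorshift32(state)
    u2 = (state + 1.0) / 4294967297.0
    z = math.sqrt(-2.0 * math.log(u1)) * math.cos(2.0 * math.pi * u2)
    return state, z

# Seed exactly as the kernel does for agent_index 0, then sample.
state = (1664525 * 1 + 1013904223) & 0xFFFFFFFF
samples = []
for _ in range(100000):
    state, z = randn(state)
    samples.append(z)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean {mean:.3f}, var {var:.3f}")  # ~0 and ~1 for a standard normal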
superneuroabm/step_functions/synapse/stdp/Three-bit_exp_pair_wise.py

@@ -0,0 +1,130 @@
+"""
+Exponential STDP (spike-timing-dependent plasticity) step function for spiking neural networks
+
+"""
+
+import cupy as cp
+from cupyx import jit
+
+from superneuroabm.step_functions.synapse.util import get_soma_spike
+
+
+@jit.rawkernel(device="cuda")
+def exp_pair_wise_stdp_quantized(
+    tick,
+    agent_index,
+    globals,
+    agent_ids,
+    breeds,
+    locations,
+    connectivity,
+    synapse_params,  # weight, delay, scale, time constants (tau_rise and tau_fall)
+    learning_params,
+    internal_state,
+    internal_learning_state,  # learning state variables
+    synapse_history,  # delay
+    input_spikes_tensor,  # input spikes
+    output_spikes_tensor,
+    internal_states_buffer,
+    internal_learning_states_buffer,
+):
+    t_current = int(tick)
+
+    dt = globals[0]  # time step size
+
+    # Get the synapse parameters:
+    weight = synapse_params[agent_index][0]
+    synaptic_delay = synapse_params[agent_index][1]
+
+    # Get the learning parameters:
+    # stdpType = 0  # parsed in the learning rule selector
+    tau_pre_stdp = learning_params[agent_index][1]
+    tau_post_stdp = learning_params[agent_index][2]
+    a_exp_pre = learning_params[agent_index][3]
+    a_exp_post = learning_params[agent_index][4]
+    stdp_history_length = learning_params[agent_index][5]
+    # Wmax, Wmin
+
+    pre_trace = internal_learning_state[agent_index][0]
+    post_trace = internal_learning_state[agent_index][1]
+    dW = internal_learning_state[agent_index][2]
+
+    # Get pre- and post-soma IDs from connectivity (contains agent IDs, not converted by SAGESim)
+    pre_soma_id = connectivity[agent_index][0]
+    post_soma_id = connectivity[agent_index][1]
+
+    # Get the pre-soma spike
+    pre_soma_spike = get_soma_spike(
+        tick,
+        agent_index,
+        globals,
+        agent_ids,
+        pre_soma_id,
+        t_current,
+        input_spikes_tensor,
+        output_spikes_tensor,
+    )
+
+    post_soma_spike = get_soma_spike(
+        tick,
+        agent_index,
+        globals,
+        agent_ids,
+        post_soma_id,
+        t_current,
+        input_spikes_tensor,
+        output_spikes_tensor,
+    )
+
+    pre_trace = pre_trace * (1 - dt / tau_pre_stdp) + pre_soma_spike * a_exp_pre
+    post_trace = post_trace * (1 - dt / tau_post_stdp) + post_soma_spike * a_exp_post
+    dW = pre_trace * post_soma_spike - post_trace * pre_soma_spike
+
+    weight += dW  # update the weight
+
+    # === 3-bit quantization ===
+    wmin = 0  # learning_params[agent_index][6], assuming it is stored in learning_params
+    wmax = 14  # learning_params[agent_index][7]
+    num_levels = 8  # 3 bits -> 8 quantization levels; learning_params[agent_index][8]
+    delta = (wmax - wmin) / (num_levels - 1)
+    weight = cp.clip(weight, wmin, wmax)
+    quantized_weight = cp.round((weight - wmin) / delta) * delta + wmin
+    weight = quantized_weight
+    # ==========================
+
+    synapse_params[agent_index][0] = weight  # update the quantized weight
+
+    internal_learning_state[agent_index][0] = pre_trace
+    internal_learning_state[agent_index][1] = post_trace
+    internal_learning_state[agent_index][2] = dW
+
+    # Safe buffer indexing: use modulo to prevent out-of-bounds access.
+    # When tracking is disabled, buffer length is 1, so t_current % 1 == 0 always.
+    buffer_idx = t_current % len(internal_learning_states_buffer[agent_index])
+    internal_learning_states_buffer[agent_index][buffer_idx][0] = pre_trace
+    internal_learning_states_buffer[agent_index][buffer_idx][1] = post_trace
+    internal_learning_states_buffer[agent_index][buffer_idx][2] = dW
+
+    # spike_pre_[t_current] = pre_soma_spike  # spike_pre_ has shape (stdp_history_length, n_inputs); pre_soma_spike has shape (n_inputs,)
+    # spike_post_[:, t_current] = post_soma_spike  # spike_post_ has shape (n_outputs, stdp_history_length); post_soma_spike has shape (n_outputs,)
+    # trace_pre_[t_current] = pre_trace  # corresponding traces, shape (stdp_history_length, n_inputs); pre_trace has shape (n_inputs,)
+    # trace_post_[:, t_current] = post_trace  # corresponding traces, shape (n_outputs, stdp_history_length)
+
+    # if t_current == stdp_history_length:
+    #     dW = cp.dot(spike_post_, trace_pre_)  # (1, stdp_history_length) dot (stdp_history_length, 1); may need a learning rate, and *(wmax - W) for multiplicative STDP
+    #     dW -= cp.dot(trace_post_, spike_pre_)  # (1, stdp_history_length) dot (stdp_history_length, 1); add learning rate * W for multiplicative STDP
+    #     clipped_dW = cp.clip(dW / stdp_history_length, dw_min, dw_max)  # clip the weight change if needed
+    #     weight = cp.clip(weight + clipped_dW, wmin, wmax)  # update the weight
+    #     # reset the trace and spike buffers
+    #     spike_pre_ = cp.zeros((stdp_history_length, n_inputs), dtype=cp.float32)
+    #     spike_post_ = cp.zeros((n_outputs, stdp_history_length), dtype=cp.float32)
+    #     trace_pre_ = cp.zeros((stdp_history_length, n_inputs), dtype=cp.float32)
+    #     trace_post_ = cp.zeros((n_outputs, stdp_history_length), dtype=cp.float32)
+
+    internal_state[agent_index][2] = pre_trace
+    internal_state[agent_index][3] = post_trace
+
+    # Safe buffer indexing for internal_states_buffer (may differ in length from the learning buffer)
+    state_buffer_idx = t_current % len(internal_states_buffer[agent_index])
+    internal_states_buffer[agent_index][state_buffer_idx][2] = post_soma_spike
+    internal_states_buffer[agent_index][state_buffer_idx][3] = post_trace
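
The learning rule is pair-wise exponential STDP: each trace decays by `(1 - dt/tau)` and is bumped by `a_exp_*` on a spike, `dW` potentiates on post-after-pre pairings and depresses on pre-after-post, and the updated weight is snapped to one of 8 levels spanning the hard-coded `[0, 14]` range. A minimal host-side sketch (illustrative constants, not package defaults):

import numpy as np

# Illustrative pair-wise STDP with 3-bit weight quantization. Traces decay
# exponentially and are bumped on spikes; the weight is then snapped to one
# of 8 levels in [0, 14], exactly as in exp_pair_wise_stdp_quantized.
dt = 1e-3
tau_pre, tau_post = 20e-3, 20e-3
a_pre, a_post = 2.0, 2.0
wmin, wmax, num_levels = 0.0, 14.0, 8
delta = (wmax - wmin) / (num_levels - 1)  # 2.0 between quantization levels

def stdp_step(weight, pre_trace, post_trace, pre_spike, post_spike):
    pre_trace = pre_trace * (1 - dt / tau_pre) + pre_spike * a_pre
    post_trace = post_trace * (1 - dt / tau_post) + post_spike * a_post
    dW = pre_trace * post_spike - post_trace * pre_spike
    weight = np.clip(weight + dW, wmin, wmax)
    weight = np.round((weight - wmin) / delta) * delta + wmin  # 3-bit snap
    return weight, pre_trace, post_trace

w, pre_tr, post_tr = 6.0, 0.0, 0.0
for t in range(100):
    pre = 1.0 if t == 10 else 0.0
    post = 1.0 if t == 15 else 0.0  # post fires 5 ms after pre -> potentiation
    w, pre_tr, post_tr = stdp_step(w, pre_tr, post_tr, pre, post)
print("weight after causal pairing:", w)  # 6.0 moves up one quantized level to 8.0

Note the dead-zone this quantization creates: because the stored weight always sits exactly on a level, any single-tick `dW` smaller than half the level spacing (here 1.0) is rounded away, which is why the sketch uses trace amplitudes large enough to cross a level boundary.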