superneuroabm 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,86 @@
1
+ """
2
+ Izhikevich Neuron and weighted synapse step functions for spiking neural networks
3
+
4
+ """
5
+
6
+ import cupy as cp
7
+ from cupyx import jit
8
+
9
+
10
@jit.rawkernel(device="cuda")
def izh_soma_step_func(
    tick,
    agent_index,
    globals,
    agent_ids,
    breeds,
    locations,
    neuron_params,  # k, vthr, C, a, b, vpeak, vrest, d, vreset, I_in
    learning_params,
    internal_state,  # [0] v, [1] u
    internal_learning_state,
    synapse_history,  # Synapse delay
    input_spikes_tensor,  # input spikes
    output_spikes_tensor,
    internal_states_buffer,
    internal_learning_states_buffer,
):
    # Advance one Izhikevich soma by a single tick.
    #
    # Reads the soma's parameters and state (v, u), sums the currents of the
    # synapses listed in locations[agent_index], integrates v and u, applies
    # the spike/reset rule, then writes back:
    #   internal_state[agent_index][0..1]         <- v, u
    #   output_spikes_tensor[agent_index][tick]   <- spike flag (0 or 1)
    #   internal_states_buffer[agent_index][slot] <- v, u
    #
    # Model reference: https://www.izhikevich.org/publications/spikes.htm
    #   v' = 0.04v^2 + 5v + 140 - u + I
    #   u' = a(bv - u)
    #   if v = 30 mV: v = c, u = u + d, spike
    # The form integrated here is
    #   dv = (k * (v - vrest) * (v - vthr) - u + I) / C

    # ---- simulation-wide values -----------------------------------------
    step = int(tick)
    dt = globals[0]  # time step size
    I_bias = globals[1]  # bias current

    # ---- per-neuron parameters ------------------------------------------
    # NOTE: neuron_params must be as wide as the largest parameter set of
    # any spiking neuron model.
    params = neuron_params[agent_index]
    k = params[0]
    vthr = params[1]
    C = params[2]
    a = params[3]
    b = params[4]
    vpeak = params[5]
    vrest = params[6]
    d = params[7]
    vreset = params[8]
    I_in = params[9]

    # ---- incoming synaptic current --------------------------------------
    # locations[agent_index] holds pre-computed local synapse indices
    # (converted in SAGESim), so no search is needed. Entries that are NaN
    # or negative are skipped.
    incoming = locations[agent_index]
    I_synapse = 0.0
    for slot in range(len(incoming)):
        raw_index = incoming[slot]
        local_index = int(raw_index)
        if not cp.isnan(raw_index) and local_index >= 0:
            I_synapse += internal_state[local_index][0]

    # ---- state integration ----------------------------------------------
    # NOTE: internal_state must be as wide as the largest state vector of
    # any spiking neuron model.
    v = internal_state[agent_index][0]
    u = internal_state[agent_index][1]

    dv = (k * (v - vrest) * (v - vthr) - u + I_synapse + I_bias + I_in) / C
    # The 1e3 factor rescales dt; presumably seconds -> milliseconds
    # (TODO confirm units against the caller).
    v = v + dt * dv * 1e3

    # The recovery variable is advanced with the already-updated v.
    u += dt * 1e3 * (a * (b * (v - vrest) - u))

    # ---- spike and reset --------------------------------------------------
    fired = 1 * (v >= vpeak)  # 1 when the peak is reached, else 0
    u = u + d * fired  # bump the recovery variable on a spike
    v = v * (1 - fired) + vreset * fired  # reset the membrane on a spike

    # ---- write-back ---------------------------------------------------------
    internal_state[agent_index][0] = v
    internal_state[agent_index][1] = u

    output_spikes_tensor[agent_index][step] = fired

    # Wrap the tick with a modulo so the write always lands inside the
    # buffer; when tracking is disabled the buffer length is 1 and the slot
    # is always 0.
    history_slot = step % len(internal_states_buffer[agent_index])
    internal_states_buffer[agent_index][history_slot][0] = v
    internal_states_buffer[agent_index][history_slot][1] = u
@@ -0,0 +1,98 @@
1
+ """
2
+ LIF Neuron and weighted synapse step functions for spiking neural networks
3
+
4
+ """
5
+
6
+ from cupyx import jit
7
+ import cupy as cp
8
+
9
+
10
@jit.rawkernel(device="cuda")
def lif_soma_step_func(  # NOTE: update the name to soma_step_func from neuron_step_func
    tick,
    agent_index,
    globals,
    agent_ids,
    breeds,
    locations,
    neuron_params,  # C, R, vthr, tref, vrest, vreset, tref_allows_integration, I_in, scaling_factor
    learning_params,
    internal_state,  # [0] v, [1] tcount, [2] tlast
    internal_learning_state,
    synapse_history,  # Synapse delay
    input_spikes_tensor,  # input spikes
    output_spikes_tensor,
    internal_states_buffer,
    internal_learning_states_buffer,
):
    # Advance one leaky integrate-and-fire (LIF) soma by a single tick.
    #
    # Sums the currents of the synapses listed in locations[agent_index],
    # integrates the membrane potential (optionally frozen while refractory),
    # decides whether the soma spikes, then writes back:
    #   internal_state[agent_index][0..2]         <- v, tcount + 1, tlast
    #   output_spikes_tensor[agent_index][tick]   <- spike flag (0.0 or 1.0)
    #   internal_states_buffer[agent_index][slot] <- the same three values

    synapse_indices = locations[agent_index]  # local synapse indices, not IDs

    I_synapse = 0.0

    # synapse_indices contains pre-computed local indices (converted in
    # SAGESim), so no linear search is needed. NaN or negative entries are
    # skipped.
    for i in range(len(synapse_indices)):
        synapse_index = int(synapse_indices[i])
        if synapse_index >= 0 and not cp.isnan(synapse_indices[i]):
            I_synapse += internal_state[synapse_index][0]

    # Current time step value.
    # TODO: check whether tcount is needed or whether tick can be used directly.
    t_current = int(tick)
    dt = globals[0]  # time step size
    I_bias = globals[1]  # bias current

    # NOTE: neuron_params must be as wide as the largest parameter set of any
    # spiking neuron model.
    C = neuron_params[agent_index][0]  # membrane capacitance
    R = neuron_params[agent_index][1]  # leak resistance
    vthr = neuron_params[agent_index][2]  # spike threshold
    tref = neuron_params[agent_index][3]  # refractory period
    vrest = neuron_params[agent_index][4]  # resting potential
    vreset = neuron_params[agent_index][5]  # reset potential
    # Non-zero allows integration during the refractory period.
    tref_allows_integration = neuron_params[agent_index][6]
    I_in = neuron_params[agent_index][7]  # input current
    scaling_factor = neuron_params[agent_index][8]  # scales the synaptic current

    # NOTE: internal_state must be as wide as the largest state vector of any
    # spiking neuron model.
    v = internal_state[agent_index][0]  # membrane potential
    tcount = internal_state[agent_index][1]  # ticks since the simulation start
    tlast = internal_state[agent_index][2]  # last spike time

    # Membrane potential derivative: leak towards vrest plus injected currents.
    dv = (vrest - v) / (R * C) + (I_synapse * scaling_factor + I_bias + I_in) / C

    # Integrate only when the refractory window has elapsed, unless
    # integration during refractoriness is explicitly allowed.
    v += (
        (dv * dt)
        if ((dt * tcount) > (tlast + tref)) or tref_allows_integration
        else 0.0
    )

    # A spike needs v >= vthr. The refractory test is applied only when
    # tlast > 0 (presumably tlast == 0 means "has not spiked yet" - confirm).
    s = 1.0 * ((v >= vthr) and ((dt * tcount > tlast + tref) if tlast > 0 else True))

    tlast = tlast * (1 - s) + dt * tcount * s  # record the spike time
    v = v * (1 - s) + vreset * s  # If spiked, reset membrane potential

    internal_state[agent_index][0] = v
    internal_state[agent_index][1] += 1
    internal_state[agent_index][2] = tlast

    output_spikes_tensor[agent_index][t_current] = s

    # Safe buffer indexing: use modulo to prevent out-of-bounds access.
    # When tracking is disabled, buffer length is 1, so t_current % 1 = 0 always.
    buffer_idx = t_current % len(internal_states_buffer[agent_index])
    internal_states_buffer[agent_index][buffer_idx][0] = v
    # Record the tick counter exactly as persisted above. The previous code
    # added 1 again after the in-place increment, so the history was one
    # tick ahead of internal_state.
    internal_states_buffer[agent_index][buffer_idx][1] = internal_state[agent_index][1]
    internal_states_buffer[agent_index][buffer_idx][2] = tlast
98
+
@@ -0,0 +1,111 @@
1
+ """
2
+ LIF Neuron and weighted synapse step functions for spiking neural networks
3
+
4
+ """
5
+
6
+ from cupyx import jit
7
+ import cupy as cp
8
+
9
+
10
@jit.rawkernel(device="cuda")
def lif_soma_adaptive_thr_step_func(  # NOTE: update the name to soma_step_func from neuron_step_func
    tick,
    agent_index,
    globals,
    agent_ids,
    breeds,
    locations,
    neuron_params,  # C, R, vthr_initial, tref, vrest, vreset, tref_allows_integration, I_in, scaling_factor, delta_thr, tau_decay_thr
    learning_params,
    internal_state,  # [0] v, [1] tcount, [2] tlast, [3] vthr
    internal_learning_state,
    synapse_history,  # Synapse delay
    input_spikes_tensor,  # input spikes
    output_spikes_tensor,
    internal_states_buffer,
    internal_learning_states_buffer,
):
    # Advance one LIF soma with an adaptive spike threshold by a single tick.
    #
    # Same integration and spike rule as the plain LIF soma, plus a
    # threshold that is raised by delta_thr on every spike and otherwise
    # relaxes exponentially towards vthr_initial. Writes back:
    #   internal_state[agent_index][0..3]         <- v, tcount + 1, tlast, vthr
    #   output_spikes_tensor[agent_index][tick]   <- spike flag (0.0 or 1.0)
    #   internal_states_buffer[agent_index][slot] <- the same four values

    synapse_indices = locations[agent_index]  # local synapse indices, not IDs

    I_synapse = 0.0

    # synapse_indices contains pre-computed local indices (converted in
    # SAGESim), so no linear search is needed. NaN or negative entries are
    # skipped.
    for i in range(len(synapse_indices)):
        synapse_index = int(synapse_indices[i])
        if synapse_index >= 0 and not cp.isnan(synapse_indices[i]):
            I_synapse += internal_state[synapse_index][0]

    # Current time step value.
    # TODO: check whether tcount is needed or whether tick can be used directly.
    t_current = int(tick)
    dt = globals[0]  # time step size
    I_bias = globals[1]  # bias current

    # NOTE: neuron_params must be as wide as the largest parameter set of any
    # spiking neuron model.
    C = neuron_params[agent_index][0]  # membrane capacitance
    R = neuron_params[agent_index][1]  # leak resistance
    vthr_initial = neuron_params[agent_index][2]  # initial spike threshold
    tref = neuron_params[agent_index][3]  # refractory period
    vrest = neuron_params[agent_index][4]  # resting potential
    vreset = neuron_params[agent_index][5]  # reset potential
    # Non-zero allows integration during the refractory period.
    tref_allows_integration = neuron_params[agent_index][6]
    I_in = neuron_params[agent_index][7]  # input current
    scaling_factor = neuron_params[agent_index][8]  # scales the synaptic current
    # Adaptive threshold parameters
    delta_thr = neuron_params[agent_index][9]  # threshold increment per spike
    tau_decay_thr = neuron_params[agent_index][10]  # threshold decay constant

    # NOTE: internal_state must be as wide as the largest state vector of any
    # spiking neuron model.
    v = internal_state[agent_index][0]  # membrane potential
    tcount = internal_state[agent_index][1]  # ticks since the simulation start
    tlast = internal_state[agent_index][2]  # last spike time
    vthr = internal_state[agent_index][3]  # spike threshold (current value)

    # Membrane potential derivative: leak towards vrest plus injected currents.
    dv = (vrest - v) / (R * C) + (I_synapse * scaling_factor + I_bias + I_in) / C

    # Integrate only when the refractory window has elapsed, unless
    # integration during refractoriness is explicitly allowed.
    v += (
        (dv * dt)
        if ((dt * tcount) > (tlast + tref)) or tref_allows_integration
        else 0.0
    )

    # A spike needs v >= vthr. The refractory test is applied only when
    # tlast > 0 (presumably tlast == 0 means "has not spiked yet" - confirm).
    s = 1.0 * ((v >= vthr) and ((dt * tcount > tlast + tref) if tlast > 0 else True))

    # -----------------------------
    # ADAPTIVE THRESHOLD UPDATE
    # -----------------------------
    if s == 1.0:
        # Increase threshold after spike
        vthr += delta_thr
    else:
        # Exponential decay towards the initial threshold:
        # vthr = vthr_initial + (vthr - vthr_initial) * exp(-dt / tau)
        vthr = vthr_initial + (vthr - vthr_initial) * cp.exp(-dt / tau_decay_thr)

    tlast = tlast * (1 - s) + dt * tcount * s  # record the spike time
    v = v * (1 - s) + vreset * s  # If spiked, reset membrane potential

    internal_state[agent_index][0] = v
    internal_state[agent_index][1] += 1
    internal_state[agent_index][2] = tlast
    # Write back updated threshold state
    internal_state[agent_index][3] = vthr

    output_spikes_tensor[agent_index][t_current] = s

    # Safe buffer indexing: use modulo to prevent out-of-bounds access.
    # When tracking is disabled, buffer length is 1, so t_current % 1 = 0 always.
    buffer_idx = t_current % len(internal_states_buffer[agent_index])
    internal_states_buffer[agent_index][buffer_idx][0] = v
    # Record the tick counter exactly as persisted above. The previous code
    # added 1 again after the in-place increment, so the history was one
    # tick ahead of internal_state.
    internal_states_buffer[agent_index][buffer_idx][1] = internal_state[agent_index][1]
    internal_states_buffer[agent_index][buffer_idx][2] = tlast
    internal_states_buffer[agent_index][buffer_idx][3] = vthr
@@ -0,0 +1,71 @@
1
+ """
2
+ Single exponential synapse step functions for spiking neural networks
3
+
4
+ """
5
+
6
+ import cupy as cp
7
+ import numpy as np
8
+ from cupyx import jit
9
+
10
+ from superneuroabm.step_functions.synapse.util import get_soma_spike
11
+ from sagesim.utils import get_neighbor_data_from_tensor
12
+
13
+
14
@jit.rawkernel(device="cuda")
def synapse_single_exp_step_func(
    tick,
    agent_index,
    globals,
    agent_ids,
    breeds,
    locations,
    synapse_params,  # [0] weight, [1] synaptic delay, [2] scale, [3] tau_fall, [4] tau_rise
    learning_params,
    internal_state,  # [0] I_synapse
    internal_learning_state,
    synapse_history,  # delay
    input_spikes_tensor,  # input spikes
    output_spikes_tensor,
    internal_states_buffer,
    internal_learning_states_buffer,
):
    # Advance one single-exponential synapse by a single tick.
    #
    # The synaptic current is decayed by (1 - dt / tau_fall) and increased by
    # spike * scale * weight, where spike comes from get_soma_spike for the
    # pre-synaptic soma. Results go to internal_state[agent_index][0] and to
    # the history buffer as (I_synapse, spike, tick).
    #
    # The synaptic delay (slot 1), tau_rise (slot 4) and the post-synaptic
    # index (locations[agent_index][1]) are not used by this step function.

    step = int(tick)
    dt = globals[0]  # time step size

    weight = synapse_params[agent_index][0]
    scale = synapse_params[agent_index][2]
    tau_fall = synapse_params[agent_index][3]

    # locations[agent_index] = [pre_soma_index, post_soma_index]; SAGESim has
    # already converted agent IDs to local indices.
    pre_soma_index = locations[agent_index][0]

    spike = get_soma_spike(
        tick,
        agent_index,
        globals,
        agent_ids,
        pre_soma_index,
        step,
        input_spikes_tensor,
        output_spikes_tensor,
    )

    # Decay the stored current, then add this tick's weighted spike.
    I_synapse = internal_state[agent_index][0]
    I_synapse = I_synapse * (1 - dt / tau_fall) + spike * scale * weight
    internal_state[agent_index][0] = I_synapse

    # Wrap the tick with a modulo so the write always lands inside the
    # buffer; when tracking is disabled the buffer length is 1 and the slot
    # is always 0.
    history_slot = step % len(internal_states_buffer[agent_index])
    internal_states_buffer[agent_index][history_slot][0] = I_synapse
    internal_states_buffer[agent_index][history_slot][1] = spike
    internal_states_buffer[agent_index][history_slot][2] = step
@@ -0,0 +1,117 @@
1
import math

import cupy as cp
import numpy as np
from cupyx import jit

from superneuroabm.step_functions.synapse.util import get_pre_soma_spike
5
+
6
@jit.rawkernel(device="cuda")
def synapse_single_exp_step_func(
    tick,
    agent_index,
    globals,
    agent_ids,
    breeds,
    locations,
    synapse_params,  # [0] w_nom, [1] delay, [2] scale, [3] tau_fall, [4] tau_rise, [5] sigma_prog, [6] sigma_stdp
    learning_params,  # Adding sigmas for learning noise
    internal_state,  # [0] I_synapse
    internal_learning_state,  # [0] w_eff, [1] is_programmed, [2] rng state, [3] needs_reprog
    synapse_history,
    input_spikes_tensor,
    output_spikes_tensor,
    internal_states_buffer,
    internal_learning_states_buffer,
):
    # Advance one single-exponential synapse with weight-programming noise.
    #
    # The effective weight w_eff is the nominal weight plus Gaussian noise.
    # It is drawn once on first use (sigma_prog) and redrawn whenever the
    # needs_reprog flag is set (sigma_stdp); between those events it stays
    # constant. The current update itself is the plain single-exponential
    # rule, using w_eff instead of the nominal weight.

    # --- tiny device helpers (inline for rawkernel styles) -------------------
    # NOTE(review): confirm that cupyx.jit accepts nested helper definitions
    # and math.* calls inside a rawkernel; they are kept as in the original.
    def _xorshift32(state_u32):
        # returns (new_state, uint32)
        x = state_u32 & 0xFFFFFFFF
        x ^= (x << 13) & 0xFFFFFFFF
        x ^= (x >> 17) & 0xFFFFFFFF
        x ^= (x << 5) & 0xFFFFFFFF
        return x, x

    def _u01_from_u32(u32):
        # map to (0,1); avoid exact 0
        return (u32 + 1.0) * (1.0 / 4294967297.0)

    def _randn(state_u32):
        # Box-Muller using two uniforms
        s, u1_i = _xorshift32(state_u32)
        s, u2_i = _xorshift32(s)
        u1 = _u01_from_u32(u1_i)
        u2 = _u01_from_u32(u2_i)
        # sqrt(-2 ln u1) * cos(2 pi u2)
        r = math.sqrt(-2.0 * math.log(u1))
        z = r * math.cos(6.283185307179586 * u2)
        return s, z  # new_state, standard normal
    # -------------------------------------------------------------------------

    t_current = int(tick)
    dt = globals[0]

    # ---- read params ---------------------------------------------------------
    # The synaptic delay (slot 1) and tau_rise (slot 4) are not used here.
    w_nom = synapse_params[agent_index][0]
    scale = synapse_params[agent_index][2]
    tau_fall = synapse_params[agent_index][3]

    # per-synapse standard deviations for weight noise
    sigma_prog = synapse_params[agent_index][5]  # initial programming
    sigma_stdp = synapse_params[agent_index][6]  # re-programming after STDP

    # ---- persistent learning state ------------------------------------------
    w_eff = internal_learning_state[agent_index][0]
    is_programmed = int(internal_learning_state[agent_index][1])  # bool as int
    rng_state_u32 = int(internal_learning_state[agent_index][2])
    needs_reprog = int(internal_learning_state[agent_index][3])

    # seed rng at first use (deterministic per synapse)
    if rng_state_u32 == 0:
        # simple seed from agent_index (offset avoids zero)
        rng_state_u32 = (1664525 * (agent_index + 1) + 1013904223) & 0xFFFFFFFF
        if rng_state_u32 == 0:
            rng_state_u32 = 123456789

    # initial programming (once)
    if is_programmed == 0:
        rng_state_u32, z = _randn(rng_state_u32)
        # Two noise models are possible:
        #   (a) additive:       w_eff = w_nom + sigma_prog * z
        #   (b) multiplicative: w_eff = w_nom * (1.0 + sigma_prog * z)
        # Additive is used here; Katie's STDP is multiplicative in the
        # original code.
        w_eff = w_nom + sigma_prog * z
        is_programmed = 1
        needs_reprog = 0

    # reprogram after STDP changed nominal weight
    if needs_reprog == 1:
        rng_state_u32, z = _randn(rng_state_u32)
        # apply STDP programming noise (additive; a multiplicative form is
        # the alternative)
        w_eff = w_nom + sigma_stdp * z
        needs_reprog = 0

    # store back persistent learning state
    internal_learning_state[agent_index][0] = w_eff
    internal_learning_state[agent_index][1] = is_programmed
    internal_learning_state[agent_index][2] = rng_state_u32
    internal_learning_state[agent_index][3] = needs_reprog

    # ---- spikes and current update (unchanged except weight source) ----------
    # When the second location slot is NaN there is no pre-synaptic soma and
    # pre_soma_id is -1; otherwise the first slot is the pre-synaptic soma.
    location_data = locations[agent_index]
    pre_soma_id = -1 if cp.isnan(location_data[1]) else location_data[0]

    spike = get_pre_soma_spike(
        tick, agent_index, globals, agent_ids,
        pre_soma_id, t_current, input_spikes_tensor, output_spikes_tensor
    )

    I_synapse = internal_state[agent_index][0]
    # use w_eff (constant between programming events)
    I_synapse = I_synapse * (1 - dt / tau_fall) + spike * scale * w_eff

    internal_state[agent_index][0] = I_synapse

    # Safe buffer indexing: wrap the tick so the write stays in bounds when
    # the buffer is shorter than the simulated tick count (for example a
    # length-1 buffer when tracking is disabled). The previous code indexed
    # with the raw tick.
    buffer_idx = t_current % len(internal_states_buffer[agent_index])
    internal_states_buffer[agent_index][buffer_idx][0] = I_synapse
    internal_states_buffer[agent_index][buffer_idx][1] = spike
    internal_states_buffer[agent_index][buffer_idx][2] = t_current
    internal_states_buffer[agent_index][buffer_idx][3] = pre_soma_id
@@ -0,0 +1,130 @@
1
+ """
2
+ Exponential STDP (Spike-Timing Dependent Plasticity) step function for spiking neural networks
3
+
4
+ """
5
+
6
+ import cupy as cp
7
+ from cupyx import jit
8
+
9
+ from superneuroabm.step_functions.synapse.util import get_soma_spike
10
+
11
+
12
@jit.rawkernel(device="cuda")
def exp_pair_wise_stdp_quantized(
    tick,
    agent_index,
    globals,
    agent_ids,
    breeds,
    locations,
    connectivity,  # [0] pre soma ID, [1] post soma ID
    synapse_params,  # [0] weight, [1] synaptic delay
    learning_params,  # [1] tau_pre, [2] tau_post, [3] a_exp_pre, [4] a_exp_post, [5] history length
    internal_state,
    internal_learning_state,  # [0] pre trace, [1] post trace, [2] dW
    synapse_history,  # delay
    input_spikes_tensor,  # input spikes
    output_spikes_tensor,
    internal_states_buffer,
    internal_learning_states_buffer,
):
    # Pair-wise exponential STDP step with 3-bit weight quantization.
    #
    # Each tick the pre and post traces are decayed and bumped by the
    # corresponding soma spikes, the weight change is
    #   dW = pre_trace * post_spike - post_trace * pre_spike,
    # and the updated weight is clipped to [0, 14] and snapped to one of 8
    # evenly spaced levels before being written back to synapse_params.
    #
    # learning_params slot 0 (the STDP type) is parsed by the learning rule
    # selector; the synaptic delay and the STDP history length are not used
    # by this step function.

    step = int(tick)
    dt = globals[0]  # time step size

    # ---- parameters ----------------------------------------------------------
    w = synapse_params[agent_index][0]

    tau_pre = learning_params[agent_index][1]
    tau_post = learning_params[agent_index][2]
    amp_pre = learning_params[agent_index][3]
    amp_post = learning_params[agent_index][4]

    # ---- persistent learning state --------------------------------------------
    trace_pre = internal_learning_state[agent_index][0]
    trace_post = internal_learning_state[agent_index][1]
    dW = internal_learning_state[agent_index][2]  # overwritten below

    # ---- spikes of the two connected somas -------------------------------------
    # connectivity contains agent IDs (not converted by SAGESim).
    pre_id = connectivity[agent_index][0]
    post_id = connectivity[agent_index][1]

    spike_pre = get_soma_spike(
        tick,
        agent_index,
        globals,
        agent_ids,
        pre_id,
        step,
        input_spikes_tensor,
        output_spikes_tensor,
    )
    spike_post = get_soma_spike(
        tick,
        agent_index,
        globals,
        agent_ids,
        post_id,
        step,
        input_spikes_tensor,
        output_spikes_tensor,
    )

    # ---- trace and weight update ------------------------------------------------
    trace_pre = trace_pre * (1 - dt / tau_pre) + spike_pre * amp_pre
    trace_post = trace_post * (1 - dt / tau_post) + spike_post * amp_post
    dW = trace_pre * spike_post - trace_post * spike_pre

    w += dW

    # === 3-bit quantization ===
    # Bounds and level count are hard-coded for now; they could later come
    # from learning_params slots 6, 7 and 8.
    w_lo = 0
    w_hi = 14
    levels = 8  # 3 bits -> 8 quantization levels
    level_step = (w_hi - w_lo) / (levels - 1)
    w = cp.clip(w, w_lo, w_hi)
    w = cp.round((w - w_lo) / level_step) * level_step + w_lo
    # ==========================

    synapse_params[agent_index][0] = w  # store the quantized weight

    # ---- write-back: learning state and its history ------------------------------
    internal_learning_state[agent_index][0] = trace_pre
    internal_learning_state[agent_index][1] = trace_post
    internal_learning_state[agent_index][2] = dW

    # Wrap the tick with a modulo so the write always lands inside the
    # buffer; when tracking is disabled the buffer length is 1 and the slot
    # is always 0.
    learn_slot = step % len(internal_learning_states_buffer[agent_index])
    internal_learning_states_buffer[agent_index][learn_slot][0] = trace_pre
    internal_learning_states_buffer[agent_index][learn_slot][1] = trace_post
    internal_learning_states_buffer[agent_index][learn_slot][2] = dW

    # NOTE: a batched variant was sketched here previously: accumulate
    # spikes and traces for stdp_history_length ticks, apply
    # dW = dot(spike_post, trace_pre) - dot(trace_post, spike_pre) scaled by
    # the history length, clip it, then reset the accumulators. It is not
    # implemented.

    # ---- write-back: synapse state and its history --------------------------------
    internal_state[agent_index][2] = trace_pre
    internal_state[agent_index][3] = trace_post

    # The state buffer has its own length, so its slot is computed separately.
    # Slot 2 of the buffer receives the post-soma spike, whereas slot 2 of
    # internal_state receives the pre trace.
    state_slot = step % len(internal_states_buffer[agent_index])
    internal_states_buffer[agent_index][state_slot][2] = spike_post
    internal_states_buffer[agent_index][state_slot][3] = trace_post