brainstate-0.1.0-py2.py3-none-any.whl → brainstate-0.1.0.post20241122-py2.py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/augment/_autograd.py +9 -6
- brainstate/event/__init__.py +4 -2
- brainstate/event/_csr.py +26 -18
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_fixed_probability.py +589 -152
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +13 -10
- brainstate/event/_linear.py +267 -127
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +8 -3
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
- brainstate/nn/_dynamics/_projection_base.py +1 -1
- brainstate/nn/_exp_euler.py +1 -1
- brainstate/nn/_interaction/__init__.py +13 -4
- brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
- brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/optim/_lr_scheduler.py +1 -1
- brainstate/optim/_optax_optimizer.py +18 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/RECORD +30 -21
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
benchmark/COBA_2005.py
ADDED
@@ -0,0 +1,125 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+#
+# Implementation of the paper:
+#
+# - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
+#   Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
+#
+# which is based on the balanced network proposed by:
+#
+# - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
+#
+import os
+import sys
+
+sys.path.append('../')
+os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
+os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
+
+
+import jax
+import brainunit as u
+import time
+import brainstate as bst
+
+
+class EINet(bst.nn.DynamicsGroup):
+    def __init__(self, scale):
+        super().__init__()
+        self.n_exc = int(3200 * scale)
+        self.n_inh = int(800 * scale)
+        self.num = self.n_exc + self.n_inh
+        self.N = bst.nn.LIFRef(self.num, V_rest=-60. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
+                               tau=20. * u.ms, tau_ref=5. * u.ms,
+                               V_initializer=bst.init.Normal(-55., 2., unit=u.mV))
+        self.E = bst.nn.AlignPostProj(
+            comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=0.6 * u.mS),
+            syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
+            out=bst.nn.COBA.desc(E=0. * u.mV),
+            post=self.N
+        )
+        self.I = bst.nn.AlignPostProj(
+            comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=6.7 * u.mS),
+            syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
+            out=bst.nn.COBA.desc(E=-80. * u.mV),
+            post=self.N
+        )
+
+    def init_state(self, *args, **kwargs):
+        self.rate = bst.ShortTermState(u.math.zeros(self.num))
+
+    def update(self, t, inp):
+        with bst.environ.context(t=t):
+            spk = self.N.get_spike() != 0.
+            self.E(spk[:self.n_exc])
+            self.I(spk[self.n_exc:])
+            self.N(inp)
+            self.rate.value += self.N.get_spike()
+
+
+@bst.compile.jit(static_argnums=0)
+def run(scale: float):
+    # network
+    net = EINet(scale)
+    bst.nn.init_all_states(net)
+
+    duration = 1e4 * u.ms
+    # simulation
+    with bst.environ.context(dt=0.1 * u.ms):
+        times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
+        bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times)
+
+    return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
+
+
+for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
+    jax.block_until_ready(run(s))
+
+    t0 = time.time()
+    n, rate = jax.block_until_ready(run(s))
+    t1 = time.time()
+    print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
+
+
+# A6000 NVIDIA GPU
+
+# scale=1, size=4000, time = 2.659956455230713 s, firing rate = 50.62445068359375 Hz
+# scale=2, size=8000, time = 2.7318649291992188 s, firing rate = 50.613040924072266 Hz
+# scale=4, size=16000, time = 2.807222604751587 s, firing rate = 50.60573959350586 Hz
+# scale=6, size=24000, time = 3.026782512664795 s, firing rate = 50.60918045043945 Hz
+# scale=8, size=32000, time = 3.1258811950683594 s, firing rate = 50.607574462890625 Hz
+# scale=10, size=40000, time = 3.172346353530884 s, firing rate = 50.60942840576172 Hz
+# scale=20, size=80000, time = 3.751189947128296 s, firing rate = 50.612369537353516 Hz
+# scale=40, size=160000, time = 5.0217814445495605 s, firing rate = 50.617958068847656 Hz
+# scale=60, size=240000, time = 7.002646207809448 s, firing rate = 50.61948776245117 Hz
+# scale=80, size=320000, time = 9.384576320648193 s, firing rate = 50.618499755859375 Hz
+# scale=100, size=400000, time = 11.69654369354248 s, firing rate = 50.61605453491211 Hz
+
+
+# AMD Ryzen 7 7840HS
+
+# scale=1, size=4000, time = 4.436027526855469 s, firing rate = 50.6119270324707 Hz
+# scale=2, size=8000, time = 8.349745273590088 s, firing rate = 50.612266540527344 Hz
+# scale=4, size=16000, time = 16.39163303375244 s, firing rate = 50.61349105834961 Hz
+# scale=6, size=24000, time = 15.725558042526245 s, firing rate = 50.6125602722168 Hz
+# scale=8, size=32000, time = 21.31995177268982 s, firing rate = 50.61244583129883 Hz
+# scale=10, size=40000, time = 27.811061143875122 s, firing rate = 50.61423873901367 Hz
+# scale=20, size=80000, time = 45.54235219955444 s, firing rate = 50.61320877075195 Hz
+# scale=40, size=160000, time = 82.22228026390076 s, firing rate = 50.61309814453125 Hz
+# scale=60, size=240000, time = 125.44037556648254 s, firing rate = 50.613094329833984 Hz
+# scale=80, size=320000, time = 171.20458459854126 s, firing rate = 50.613365173339844 Hz
+# scale=100, size=400000, time = 215.4547393321991 s, firing rate = 50.6129150390625 Hz
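
Note on the projection pattern above: `bst.event.FixedProb` is the event-driven communication operator, designed so that each step only touches the outgoing connections of neurons whose spike entry is True, rather than the full (n_pre, n_post) weight matrix. A minimal sketch of the same operator in isolation (sizes and rates are illustrative, and it assumes `FixedProb` can be called directly outside a projection):

import brainunit as u
import brainstate as bst

# 100 pre-synaptic sources, 200 post-synaptic targets, ~10% connectivity,
# mirroring the positional (n_pre, n_post, prob, weight) usage above.
comm = bst.event.FixedProb(100, 200, prob=0.1, weight=0.6 * u.mS)
spikes = bst.random.rand(100) < 0.05   # boolean spike vector, ~5% active
psc = comm(spikes)                     # (200,) summed conductance increments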
benchmark/CUBA_2005.py
ADDED
@@ -0,0 +1,149 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+#
+# Implementation of the paper:
+#
+# - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
+#   Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
+#
+# which is based on the balanced network proposed by:
+#
+# - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
+#
+
+import os
+import sys
+
+sys.path.append('../')
+os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
+os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
+
+
+import jax
+import time
+
+import brainunit as u
+
+import brainstate as bst
+
+
+class FixedProb(bst.nn.Module):
+    def __init__(self, n_pre, n_post, prob, weight):
+        super().__init__()
+        self.prob = prob
+        self.weight = weight
+        self.n_pre = n_pre
+        self.n_post = n_post
+
+        self.mask = bst.random.rand(n_pre, n_post) < prob
+
+    def update(self, x):
+        return (x @ self.mask) * self.weight
+
+
+class EINet(bst.nn.DynamicsGroup):
+    def __init__(self, scale=1.0):
+        super().__init__()
+        self.n_exc = int(3200 * scale)
+        self.n_inh = int(800 * scale)
+        self.num = self.n_exc + self.n_inh
+        self.N = bst.nn.LIFRef(
+            self.num, V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
+            tau=20. * u.ms, tau_ref=5. * u.ms,
+            V_initializer=bst.init.Normal(-55., 2., unit=u.mV)
+        )
+        self.E = bst.nn.AlignPostProj(
+            comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
+            # comm=FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
+            syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
+            out=bst.nn.CUBA.desc(scale=u.volt),
+            post=self.N
+        )
+        self.I = bst.nn.AlignPostProj(
+            comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
+            # comm=FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
+            syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
+            out=bst.nn.CUBA.desc(scale=u.volt),
+            post=self.N
+        )
+
+    def init_state(self, *args, **kwargs):
+        self.rate = bst.ShortTermState(u.math.zeros(self.num))
+
+    def update(self, t, inp):
+        with bst.environ.context(t=t):
+            spk = self.N.get_spike()
+            self.E(spk[:self.n_exc])
+            self.I(spk[self.n_exc:])
+            self.N(inp)
+            self.rate.value += self.N.get_spike()
+
+
+@bst.compile.jit(static_argnums=0)
+def run(scale: float):
+    # network
+    net = EINet(scale)
+    bst.nn.init_all_states(net)
+
+    duration = 1e4 * u.ms
+    # simulation
+    with bst.environ.context(dt=0.1 * u.ms):
+        times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
+        bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times,
+                             # pbar=bst.compile.ProgressBar(100)
+                             )
+
+    return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
+
+
+for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
+    jax.block_until_ready(run(s))
+
+    t0 = time.time()
+    n, rate = jax.block_until_ready(run(s))
+    t1 = time.time()
+    print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
+
+
+# A6000 NVIDIA GPU
+
+# scale=1, size=4000, time = 2.6354849338531494 s, firing rate = 24.982027053833008 Hz
+# scale=2, size=8000, time = 2.6781561374664307 s, firing rate = 23.719463348388672 Hz
+# scale=4, size=16000, time = 2.7448785305023193 s, firing rate = 24.592931747436523 Hz
+# scale=6, size=24000, time = 2.8237478733062744 s, firing rate = 24.159996032714844 Hz
+# scale=8, size=32000, time = 2.9344418048858643 s, firing rate = 24.956790924072266 Hz
+# scale=10, size=40000, time = 3.042517900466919 s, firing rate = 23.644424438476562 Hz
+# scale=20, size=80000, time = 3.6727631092071533 s, firing rate = 24.226743698120117 Hz
+# scale=40, size=160000, time = 4.857396602630615 s, firing rate = 24.329742431640625 Hz
+# scale=60, size=240000, time = 6.812030792236328 s, firing rate = 24.370006561279297 Hz
+# scale=80, size=320000, time = 9.227966547012329 s, firing rate = 24.41067886352539 Hz
+# scale=100, size=400000, time = 11.405697584152222 s, firing rate = 24.32524871826172 Hz
+
+
+# AMD Ryzen 7 7840HS
+
+# scale=1, size=4000, time = 1.1661601066589355 s, firing rate = 22.438201904296875 Hz
+# scale=2, size=8000, time = 3.3255884647369385 s, firing rate = 23.868364334106445 Hz
+# scale=4, size=16000, time = 6.950139999389648 s, firing rate = 24.21693229675293 Hz
+# scale=6, size=24000, time = 10.011993169784546 s, firing rate = 24.240270614624023 Hz
+# scale=8, size=32000, time = 13.027734518051147 s, firing rate = 24.753198623657227 Hz
+# scale=10, size=40000, time = 16.449942350387573 s, firing rate = 24.7176570892334 Hz
+# scale=20, size=80000, time = 30.754598140716553 s, firing rate = 24.119956970214844 Hz
+# scale=40, size=160000, time = 63.6387836933136 s, firing rate = 24.72784996032715 Hz
+# scale=60, size=240000, time = 78.58532166481018 s, firing rate = 24.402742385864258 Hz
+# scale=80, size=320000, time = 102.4250214099884 s, firing rate = 24.59092140197754 Hz
+# scale=100, size=400000, time = 145.35173273086548 s, firing rate = 24.33751106262207 Hz
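
The commented-out `comm=FixedProb(...)` lines let you swap in the dense reference module defined at the top of the file, which materializes the whole (n_pre, n_post) boolean mask and computes `(x @ mask) * weight`. A quick back-of-the-envelope check of why the benchmark keeps the event-driven kernel enabled (assuming one byte per boolean entry, as in NumPy/JAX):

# Memory footprint of the dense excitatory mask alone:
for scale in (1, 10, 100):
    n_exc, num = int(3200 * scale), int(4000 * scale)
    print(f'scale={scale}: {n_exc * num / 2**30:.3f} GiB')
# scale=1:   ~0.012 GiB
# scale=10:  ~1.2 GiB
# scale=100: ~119 GiB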
brainstate/augment/_autograd.py
CHANGED
@@ -45,7 +45,7 @@ from brainstate.typing import PyTree, Missing
 from brainstate.util import PrettyType, PrettyAttr, PrettyRepr

 __all__ = [
-    'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
+    'GradientTransform', 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
 ]

 A = TypeVar('A')
@@ -159,6 +159,9 @@ def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False
     return jacfun


+TransformFn = Callable
+
+
 class GradientTransform(PrettyRepr):
     """
     Automatic Differentiation Transformations for the ``State`` system.
@@ -168,11 +171,11 @@ class GradientTransform(PrettyRepr):
     def __init__(
         self,
         target: Callable,
-        transform:
-        grad_states:
-        argnums: Optional[Union[int, Sequence[int]]],
-        return_value: bool,
-        has_aux: bool,
+        transform: TransformFn,
+        grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+        argnums: Optional[Union[int, Sequence[int]]] = None,
+        return_value: bool = False,
+        has_aux: bool = False,
         transform_params: Optional[Dict[str, Any]] = None,
     ):
         # gradient variables
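
With `GradientTransform` now exported and its `grad_states`, `argnums`, `return_value`, and `has_aux` parameters given defaults, the class can be referenced directly from `brainstate.augment`. A minimal sketch of the state-based differentiation it implements, via the `grad` wrapper exported alongside it (parameter names follow the constructor signature above; the return convention is an assumption):

import jax.numpy as jnp
import brainstate as bst

w = bst.ParamState(jnp.ones(3))

def loss():
    return (w.value ** 2).sum()

# Differentiate the loss with respect to the ParamState `w`.
grad_fn = bst.augment.grad(loss, grad_states=w)
print(grad_fn())   # expected: the gradient w.r.t. w, i.e. [2., 2., 2.]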
brainstate/event/__init__.py
CHANGED
@@ -19,7 +19,9 @@ from ._csr import __all__ as __all_csr
 from ._fixed_probability import *
 from ._fixed_probability import __all__ as __all_fixed_probability
 from ._linear import *
+from ._xla_custom_op import *
+from ._xla_custom_op import __all__ as __all_xla_custom_op
 from ._linear import __all__ as __all_linear

-__all__ = __all_fixed_probability + __all_linear + __all_csr
-del __all_fixed_probability, __all_linear, __all_csr
+__all__ = __all_fixed_probability + __all_linear + __all_csr + __all_xla_custom_op
+del __all_fixed_probability, __all_linear, __all_csr, __all_xla_custom_op
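
The effect is that whatever `_xla_custom_op.py` exports (a new 312-line module per the file list; its symbol names are not shown in this diff) now ships as part of the public `brainstate.event` namespace. A quick way to inspect the result:

import brainstate as bst

# __all__ is the concatenation assembled in the __init__ above, so the
# custom-op exports now appear alongside the fixed-probability, linear,
# and CSR symbols.
print(bst.event.__all__)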
brainstate/event/_csr.py
CHANGED
@@ -21,12 +21,11 @@ import jax
 import jax.numpy as jnp
 import numpy as np

-from brainstate._state import ParamState
+from brainstate._state import ParamState
 from brainstate._utils import set_module_as
 from brainstate.init import param
 from brainstate.nn._module import Module
-from brainstate.typing import ArrayLike
-from ._misc import IntScalar
+from brainstate.typing import ArrayLike, Size

 __all__ = [
     'CSRLinear',
@@ -39,12 +38,12 @@ class CSRLinear(Module):

     Parameters
     ----------
-    n_pre : int
-        Number of pre-synaptic neurons.
-    n_post : int
-        Number of post-synaptic neurons.
+    in_size : Size
+        Number of pre-synaptic neurons, i.e., input size.
+    out_size : Size
+        Number of post-synaptic neurons, i.e., output size.
     weight : float or callable or jax.Array or brainunit.Quantity
-        Maximum synaptic conductance.
+        Maximum synaptic conductance or a function that returns the maximum synaptic conductance.
     name : str, optional
         Name of the module.
     """
@@ -53,8 +52,8 @@ class CSRLinear(Module):

     def __init__(
         self,
-        n_pre: IntScalar,
-        n_post: IntScalar,
+        in_size: Size,
+        out_size: Size,
         indptr: ArrayLike,
         indices: ArrayLike,
         weight: Union[Callable, ArrayLike],
@@ -63,10 +62,11 @@ class CSRLinear(Module):
     ):
         super().__init__(name=name)

-
-        self.
-        self.
-        self.
+        # network size
+        self.in_size = in_size
+        self.out_size = out_size
+        self.n_pre = self.in_size[-1]
+        self.n_post = self.out_size[-1]

         # gradient mode
         assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
@@ -77,9 +77,10 @@ class CSRLinear(Module):
         indices = jnp.asarray(indices)
         assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
         assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
-        assert indptr.size == n_pre + 1, f"indptr must have size {n_pre + 1}. Got: {indptr.size}"
-        self.
-        self.
+        assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
+        with jax.ensure_compile_time_eval():
+            self.indptr = u.math.asarray(indptr)
+            self.indices = u.math.asarray(indices)

         # maximum synaptic conductance
         weight = param(weight, (len(indices),), allow_none=False)
@@ -88,7 +89,9 @@ class CSRLinear(Module):
         self.weight = ParamState(weight)

     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
-        weight = self.weight.value
+        weight = self.weight.value
+
+        # return zero if no pre-synaptic neurons
         if len(self.indices) == 0:
             r = u.math.zeros(spk.shape[:-1] + (self.n_post,),
                              dtype=weight.dtype,
@@ -96,6 +99,8 @@ class CSRLinear(Module):
             return u.maybe_decimal(r)

         device_kind = jax.devices()[0].platform  # spk.device.device_kind
+
+        # CPU implementation
         if device_kind == 'cpu':
             return cpu_event_csr(
                 u.math.asarray(spk),
@@ -104,8 +109,11 @@ class CSRLinear(Module):
                 u.math.asarray(weight),
                 n_post=self.n_post, grad_mode=self.grad_mode
             )
+
+        # GPU/TPU implementation
         elif device_kind in ['gpu', 'tpu']:
             raise NotImplementedError()
+
         else:
             raise ValueError(f"Unsupported device: {device_kind}")

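
A minimal sketch of the updated constructor (the CSR arrays are illustrative; sizes are passed as 1-tuples to match the `self.in_size[-1]` indexing above, and the forward pass currently runs on CPU only, per the `NotImplementedError` branch):

import numpy as np
import brainunit as u
import brainstate as bst

# 3 pre-synaptic neurons -> 4 post-synaptic neurons; the targets of
# row i are indices[indptr[i]:indptr[i+1]].
indptr = np.array([0, 2, 3, 5])
indices = np.array([0, 3, 1, 0, 2])

layer = bst.event.CSRLinear((3,), (4,), indptr, indices, weight=0.5 * u.mS)
spikes = np.array([1., 0., 1.])
out = layer.update(spikes)   # shape (4,): summed weights of the active rows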
brainstate/event/_csr_benchmark.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================