brainstate 0.1.0.post20250120__py2.py3-none-any.whl → 0.1.0.post20250127__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. brainstate/__init__.py +1 -2
  2. brainstate/augment/__init__.py +10 -20
  3. brainstate/compile/__init__.py +18 -37
  4. brainstate/compile/_make_jaxpr.py +9 -2
  5. brainstate/compile/_make_jaxpr_test.py +10 -6
  6. brainstate/compile/_progress_bar.py +49 -6
  7. brainstate/compile/_unvmap.py +3 -3
  8. brainstate/graph/__init__.py +12 -12
  9. brainstate/nn/_dyn_impl/_inputs.py +4 -2
  10. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  11. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/METADATA +1 -1
  12. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/RECORD +15 -29
  13. brainstate/event/__init__.py +0 -27
  14. brainstate/event/_csr.py +0 -1149
  15. brainstate/event/_csr_benchmark.py +0 -14
  16. brainstate/event/_csr_mv.py +0 -303
  17. brainstate/event/_csr_test.py +0 -277
  18. brainstate/event/_fixedprob_mv.py +0 -730
  19. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  20. brainstate/event/_fixedprob_mv_test.py +0 -132
  21. brainstate/event/_linear_mv.py +0 -359
  22. brainstate/event/_linear_mv_benckmark.py +0 -82
  23. brainstate/event/_linear_mv_test.py +0 -117
  24. brainstate/event/_misc.py +0 -34
  25. brainstate/event/_xla_custom_op.py +0 -317
  26. brainstate/event/_xla_custom_op_test.py +0 -55
  27. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/LICENSE +0 -0
  28. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/WHEEL +0 -0
  29. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/top_level.txt +0 -0
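Items 13-26 above show that the entire brainstate.event subpackage (CSR, fixed-probability, and dense event-driven operators, plus their tests and benchmarks) appears to be removed in 0.1.0.post20250127. A minimal, hypothetical guard that downstream code could use to stay compatible across the two wheels (importlib is standard library; the fallback branch is only illustrative):

import importlib.util

# brainstate.event exists in 0.1.0.post20250120 but is absent in 0.1.0.post20250127
if importlib.util.find_spec("brainstate.event") is not None:
    from brainstate import event as bst_event  # older wheel: event-driven operators available
else:
    bst_event = None  # newer wheel: fall back to dense operations or pin the older release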
brainstate/event/_fixedprob_mv_benchmark.py
@@ -1,128 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- # n_pre: 1000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.004549980163574219 s
- # n_pre: 1000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 0.04318690299987793 s
- # Acceleration ratio: 8.491668413330538
- #
- # n_pre: 1000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.005620718002319336 s
- # n_pre: 1000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 1.3311548233032227 s
- # Acceleration ratio: 235.83003181336161
- #
- # n_pre: 10000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.015388727188110352 s
- # n_pre: 10000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 10.791011333465576 s
- # Acceleration ratio: 700.2283213262065
- #
- # n_pre: 10000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.01043844223022461 s
- # n_pre: 10000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 0.8944694995880127 s
- # Acceleration ratio: 84.68994107167329
- #
- # n_pre: 10000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.021282196044921875 s
- # n_pre: 10000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 21.388156414031982 s
- # Acceleration ratio: 1003.9788268506901
- #
- # n_pre: 20000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.025498151779174805 s
- # n_pre: 20000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 21.211663246154785 s
- # Acceleration ratio: 830.8902259997943
- #
- # n_pre: 20000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.044051408767700195 s
- # n_pre: 20000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 42.31502842903137 s
- # Acceleration ratio: 959.5828647200498
- #
- # n_pre: 20000, n_post: 30000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.06666803359985352 s
- # n_pre: 20000, n_post: 30000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 62.46805453300476 s
- # Acceleration ratio: 936.0016057162067
- #
- # n_pre: 30000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.08313393592834473 s
- # n_pre: 30000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 63.61667847633362 s
- # Acceleration ratio: 764.231163013459
- #
- #
-
-
- import os
-
- # os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
- os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
-
- import jax
- #
- # jax.config.update('jax_cpu_enable_async_dispatch', False)
-
- import time
- import brainstate as bst
-
-
- def forward(n_pre, n_post, conn_prob, spk_prob, as_float: bool):
-     linear = bst.event.FixedProb(n_pre, n_post, prob=conn_prob, weight=bst.init.Normal())
-     spike = (bst.random.rand(n_pre) < spk_prob)
-
-     if as_float:
-         spike = spike.astype(float)
-
-     @jax.jit
-     def f1(spike):
-         return linear(spike)
-
-     weight = bst.init.Normal()([n_pre, n_post])
-
-     @jax.jit
-     def f2(spike):
-         return spike @ weight
-
-     y1 = jax.block_until_ready(f1(spike))
-     y2 = jax.block_until_ready(f2(spike))
-     # print('max difference:', jax.numpy.abs(y1 - y2).max())
-
-     n = 1000
-     t0 = time.time()
-     for _ in range(n):
-         jax.block_until_ready(f1(spike))
-     r1 = time.time() - t0
-     print(f"n_pre: {n_pre}, n_post: {n_post}, conn_prob: {conn_prob}, spk_prob: {spk_prob}, Linear: {r1} s")
-
-     t0 = time.time()
-     for _ in range(n):
-         jax.block_until_ready(f2(spike))
-     r2 = time.time() - t0
-     print(f"n_pre: {n_pre}, n_post: {n_post}, conn_prob: {conn_prob}, spk_prob: {spk_prob}, Matmul: {r2} s")
-     print('Acceleration ratio:', r2 / r1 - 1.)
-
-     print()
-     bst.util.clear_buffer_memory()
-
-
- def benchmark_forward():
-     for n_pre, n_post in [
-         (1000, 1000),
-         (1000, 10000),
-         (10000, 10000),
-         (10000, 1000),
-         (10000, 20000),
-         (20000, 10000),
-         (20000, 20000),
-         (20000, 30000),
-         (30000, 20000),
-     ]:
-         forward(n_pre, n_post, 0.01, 0.01, False)
-
-
- if __name__ == '__main__':
-     pass
-     # forward(1000, 6400, 0.01, 0.01, False)
-     # forward(10000, 12800, 0.01, 0.01, False)
-
-     benchmark_forward()
brainstate/event/_fixedprob_mv_test.py
@@ -1,132 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from __future__ import annotations
-
-
- import jax.numpy
- import jax.numpy as jnp
- from absl.testing import parameterized
-
- import brainstate as bst
-
-
- class TestFixedProbCSR(parameterized.TestCase):
-     @parameterized.product(
-         allow_multi_conn=[True, False]
-     )
-     def test1(self, allow_multi_conn):
-         x = bst.random.rand(20) < 0.1
-         # x = bst.random.rand(20)
-         m = bst.event.FixedProb(20, 40, 0.1, 1.0, seed=123, allow_multi_conn=allow_multi_conn)
-         y = m(x)
-         print(y)
-
-         m2 = bst.event.FixedProb(20, 40, 0.1, bst.init.KaimingUniform(), seed=123)
-         print(m2(x))
-
-     def test_grad_bool(self):
-         n_in = 20
-         n_out = 30
-         x = bst.random.rand(n_in) < 0.3
-         fn = bst.event.FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)
-
-         def f(x):
-             return fn(x).sum()
-
-         with self.assertRaises(TypeError):
-             print(jax.grad(f)(x))
-
-     @parameterized.product(
-         bool_x=[True, False],
-         homo_w=[True, False]
-     )
-     def test_vjp(self, bool_x, homo_w):
-         n_in = 20
-         n_out = 30
-         if bool_x:
-             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
-         else:
-             x = bst.random.rand(n_in)
-
-         if homo_w:
-             fn = bst.event.FixedProb(n_in, n_out, 0.1, 1.5, seed=123, float_as_event=bool_x)
-         else:
-             fn = bst.event.FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123, float_as_event=bool_x)
-         w = fn.weight.value
-
-         def f(x, w):
-             fn.weight.value = w
-             return fn(x).sum()
-
-         r = bst.augment.grad(f, argnums=(0, 1))(x, w)
-
-         # -------------------
-         # TRUE gradients
-
-         def true_fn(x, w, indices, n_post):
-             post = jnp.zeros((n_post,))
-             for i in range(n_in):
-                 post = post.at[indices[i]].add(w * x[i] if homo_w else w[i] * x[i])
-             return post
-
-         def f2(x, w):
-             return true_fn(x, w, fn.indices, n_out).sum()
-
-         r2 = jax.grad(f2, argnums=(0, 1))(x, w)
-         self.assertTrue(jnp.allclose(r[0], r2[0]))
-         self.assertTrue(jnp.allclose(r[1], r2[1]))
-
-     @parameterized.product(
-         bool_x=[True, False],
-         homo_w=[True, False]
-     )
-     def test_jvp(self, bool_x, homo_w):
-         n_in = 20
-         n_out = 30
-         if bool_x:
-             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
-         else:
-             x = bst.random.rand(n_in)
-
-         fn = bst.event.FixedProb(
-             n_in, n_out, 0.1, 1.5 if homo_w else bst.init.KaimingUniform(),
-             seed=123,
-             float_as_event=bool_x
-         )
-         w = fn.weight.value
-
-         def f(x, w):
-             fn.weight.value = w
-             return fn(x)
-
-         o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-
-         # -------------------
-         # TRUE gradients
-
-         def true_fn(x, w, indices, n_post):
-             post = jnp.zeros((n_post,))
-             for i in range(n_in):
-                 post = post.at[indices[i]].add(w * x[i] if homo_w else w[i] * x[i])
-             return post
-
-         def f2(x, w):
-             return true_fn(x, w, fn.indices, n_out)
-
-         o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-         self.assertTrue(jnp.allclose(o1, o2))
-         # assert jnp.allclose(r1, r2), f'r1={r1}, r2={r2}'
-         self.assertTrue(jnp.allclose(r1, r2, rtol=1e-4, atol=1e-4))
brainstate/event/_linear_mv.py
@@ -1,359 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from __future__ import annotations
-
- from typing import Union, Callable, Optional
-
- import brainunit as u
- import jax
- import jax.experimental.pallas as pl
- import jax.numpy as jnp
- import numpy as np
- from jax.interpreters import ad
-
- from brainstate._state import ParamState, State
- from brainstate.init import param
- from brainstate.nn._module import Module
- from brainstate.typing import ArrayLike, Size
- from ._xla_custom_op import XLACustomOp
-
- __all__ = [
-     'Linear',
- ]
-
-
- class Linear(Module):
-     """
-     The FixedProb module implements a fixed probability connection with CSR sparse data structure.
-
-     Parameters
-     ----------
-     in_size : Size
-         Number of pre-synaptic neurons, i.e., input size.
-     out_size : Size
-         Number of post-synaptic neurons, i.e., output size.
-     weight : float or callable or jax.Array or brainunit.Quantity
-         Maximum synaptic conductance.
-     block_size : int, optional
-         Block size for parallel computation.
-     float_as_event : bool, optional
-         Whether to treat float as event.
-     name : str, optional
-         Name of the module.
-     """
-
-     __module__ = 'brainstate.event'
-
-     def __init__(
-         self,
-         in_size: Size,
-         out_size: Size,
-         weight: Union[Callable, ArrayLike],
-         float_as_event: bool = True,
-         block_size: int = 64,
-         name: Optional[str] = None,
-     ):
-         super().__init__(name=name)
-
-         # network parameters
-         self.in_size = in_size
-         self.out_size = out_size
-         self.float_as_event = float_as_event
-         self.block_size = block_size
-
-         # maximum synaptic conductance
-         weight = param(weight, (self.in_size[-1], self.out_size[-1]), allow_none=False)
-         self.weight = ParamState(weight)
-
-     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
-         weight = self.weight.value if isinstance(self.weight, State) else self.weight
-         if u.math.size(weight) == 1:
-             return u.math.ones(self.out_size) * (u.math.sum(spk) * weight)
-
-         return event_linear(spk, weight, block_size=self.block_size, float_as_event=self.float_as_event)
-
-
- def event_linear(spk, weight, *, block_size, float_as_event) -> jax.Array | u.Quantity:
-     """
-     The event-driven linear computation.
-
-     Parameters
-     ----------
-     weight : brainunit.Quantity or jax.Array
-         Maximum synaptic conductance.
-     spk : jax.Array
-         Spike events.
-     block_size : int
-         Block size for parallel computation.
-     float_as_event : bool
-         Whether to treat float as event.
-
-     Returns
-     -------
-     post_data : brainunit.Quantity or jax.Array
-         Post synaptic data.
-     """
-     with jax.ensure_compile_time_eval():
-         weight = u.math.asarray(weight)
-         unit = u.get_unit(weight)
-         weight = u.get_mantissa(weight)
-         spk = jnp.asarray(spk)
-
-     def mv(spk_vector):
-         assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
-         return event_liner_p_call(
-             spk,
-             weight,
-             block_size=block_size,
-             float_as_event=float_as_event,
-         )
-
-     assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
-     assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"
-
-     if spk.ndim == 1:
-         [post_data] = mv(spk)
-     else:
-         [post_data] = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
-         post_data = u.math.reshape(post_data, spk.shape[:-1] + post_data.shape[-1:])
-     return u.maybe_decimal(u.Quantity(post_data, unit=unit))
-
-
- Kernel = Callable
-
-
- def cpu_kernel_generator(
-     float_as_event: bool,
-     spk_info: jax.ShapeDtypeStruct,
-     **kwargs
- ) -> Kernel:
-     import numba  # pylint: disable=import-outside-toplevel
-
-     if spk_info.dtype == jnp.bool_:
-
-         @numba.njit
-         def _kernel(spikes, weights, posts):
-             r = np.zeros((weights.shape[1],), dtype=weights.dtype)
-             for i in range(spikes.shape[0]):
-                 if spikes[i]:
-                     r = r + weights[i]
-             posts[:] = r
-
-     elif float_as_event:
-         @numba.njit
-         def _kernel(spikes, weights, posts):
-             r = np.zeros((weights.shape[1],), dtype=weights.dtype)
-             for i in range(spikes.shape[0]):
-                 if spikes[i] != 0.:
-                     r = r + weights[i]
-             posts[:] = r
-
-     else:
-         @numba.njit
-         def _kernel(spikes, weights, posts):
-             r = np.zeros((weights.shape[1],), dtype=weights.dtype)
-             for i in range(spikes.shape[0]):
-                 sp = spikes[i]
-                 if sp != 0.:
-                     r = r + weights[i] * sp
-             posts[:] = r
-
-     return _kernel
-
-
- def gpu_kernel_generator(
-     block_size: int,
-     float_as_event: bool,
-     weight_info: jax.ShapeDtypeStruct,
-     **kwargs
- ) -> Kernel:
-     # # Each block handles one [block_size,] slice of the post vector
-     # # Each block handles the full [n_pre] pre vector
-     # # Each block handles one [n_pre, block_size] slice of w
-     # def _mv_kernel(sp_ref, w_ref, post_ref):
-     #
-     #     pid = pl.program_id(0)
-     #
-     #     def scan_fn(i, post_):
-     #         if sp_ref.dtype == jnp.bool_:
-     #             post_ = jax.lax.cond(
-     #                 sp_ref[i],
-     #                 lambda: post_ + w_ref[i, ...],
-     #                 lambda: post_
-     #             )
-     #         else:
-     #             if float_as_event:
-     #                 post_ = jax.lax.cond(
-     #                     sp_ref[i] != 0.,
-     #                     lambda: post_ + w_ref[i, ...],
-     #                     lambda: post_
-     #                 )
-     #             else:
-     #                 sp = sp_ref[i]
-     #                 post_ = jax.lax.cond(
-     #                     sp != 0.,
-     #                     lambda: post_ + w_ref[i, ...] * sp,
-     #                     lambda: post_
-     #                 )
-     #         return post_
-     #
-     #     post = jax.lax.fori_loop(0, n_pre, scan_fn, jnp.zeros(post_ref.shape, dtype=post_ref.dtype))
-     #     mask = jnp.arange(block_size) + pid * block_size < n_pre
-     #     pl.store(post_ref, pl.dslice(None, None), post, mask=mask)
-     #
-     # n_pre = weight_info.shape[0]
-     # n_post = weight_info.shape[1]
-     # kernel = pl.pallas_call(
-     #     _mv_kernel,
-     #     out_shape=[
-     #         jax.ShapeDtypeStruct([weight_info.shape[1]], weight_info.dtype),
-     #     ],
-     #     out_specs=[
-     #         pl.BlockSpec((block_size,), lambda i: i),
-     #     ],
-     #     in_specs=[
-     #         pl.BlockSpec((n_pre,), lambda i: 0),
-     #         pl.BlockSpec((n_pre, block_size), lambda i: (0, i)),
-     #     ],
-     #     grid=(
-     #         pl.cdiv(n_post, block_size),
-     #     ),
-     #     interpret=False,
-     # )
-     # return kernel
-
-     # Each block handles one [block_size,] slice of the post vector
-     # Each block handles one [block_size] slice of the pre vector
-     # Each block handles one [block_size, block_size] tile of w
-     def _mv_kernel(
-         sp_ref,  # [block_size]
-         w_ref,  # [block_size, block_size]
-         post_ref,  # [block_size]
-     ):
-
-         r_pid = pl.program_id(0)
-         c_start = pl.program_id(1) * block_size
-         row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
-         mask = jnp.arange(block_size) + c_start < weight_info.shape[1]
-
-         def scan_fn(i, post_):
-             if sp_ref.dtype == jnp.bool_:
-                 post_ = jax.lax.cond(
-                     sp_ref[i],
-                     lambda: post_ + w_ref[i, ...],
-                     lambda: post_
-                 )
-             else:
-                 if float_as_event:
-                     post_ = jax.lax.cond(
-                         sp_ref[i] != 0.,
-                         lambda: post_ + w_ref[i, ...],
-                         lambda: post_
-                     )
-                 else:
-                     sp = sp_ref[i]
-                     post_ = jax.lax.cond(
-                         sp != 0.,
-                         lambda: post_ + w_ref[i, ...] * sp,
-                         lambda: post_
-                     )
-             return post_
-
-         post = jax.lax.fori_loop(0, row_length, scan_fn, jnp.zeros(post_ref.shape, dtype=post_ref.dtype))
-         pl.atomic_add(post_ref, pl.dslice(None, None), post, mask=mask)
-
-     n_pre = weight_info.shape[0]
-     n_post = weight_info.shape[1]
-     kernel = pl.pallas_call(
-         _mv_kernel,
-         out_shape=[
-             jax.ShapeDtypeStruct([weight_info.shape[1]], weight_info.dtype),
-         ],
-         out_specs=[
-             pl.BlockSpec((block_size,), lambda i, j: j),
-         ],
-         in_specs=[
-             pl.BlockSpec((block_size,), lambda i, j: i),
-             pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),
-         ],
-         grid=(
-             pl.cdiv(n_pre, block_size),
-             pl.cdiv(n_post, block_size),
-         ),
-         interpret=False,
-     )
-     return kernel
-
-
- def jvp_spikes(spk_dot, spikes, weights, **kwargs):
-     return [spk_dot @ weights]
-
-
- def jvp_weights(w_dot, spikes, weights, *, float_as_event, block_size, **kwargs):
-     return event_liner_p_call(
-         spikes,
-         w_dot,
-         block_size=block_size,
-         float_as_event=float_as_event,
-     )
-
-
- def transpose_rule(ct, spikes, weights, *, float_as_event, **kwargs):
-     if ad.is_undefined_primal(spikes):
-         ct_events = jnp.matmul(weights, ct[0])
-         return (ad.Zero(spikes) if type(ct[0]) is ad.Zero else ct_events), weights
-
-     else:
-         def map_fn(sp):
-             if spikes.dtype == jnp.bool_:
-                 d_gmax = jnp.where(sp, ct[0], jnp.zeros_like(ct[0]))
-             else:
-                 if float_as_event:
-                     d_gmax = jnp.where(sp == 0., jnp.zeros_like(ct[0]), ct[0])
-                 else:
-                     d_gmax = jnp.where(sp == 0., jnp.zeros_like(ct[0]), ct[0] * sp)
-                     # d_gmax = jax.lax.cond(sp == 0., lambda: jnp.zeros_like(ct[0]), lambda: ct[0] * sp)
-             return d_gmax
-
-         ct_weights = jax.vmap(map_fn)(spikes)
-         return spikes, (ad.Zero(weights) if type(ct[0]) is ad.Zero else ct_weights)
-
-
- event_linear_p = XLACustomOp(
-     'event_linear',
-     cpu_kernel_or_generator=cpu_kernel_generator,
-     gpu_kernel_or_generator=gpu_kernel_generator,
- )
- event_linear_p.defjvp(jvp_spikes, jvp_weights)
- event_linear_p.def_transpose_rule(transpose_rule)
-
-
- def event_liner_p_call(
-     spikes,
-     weights,
-     *,
-     block_size,
-     float_as_event,
- ):
-     return event_linear_p(
-         spikes,
-         weights,
-         outs=[jax.ShapeDtypeStruct([weights.shape[1]], weights.dtype)],
-         block_size=block_size,
-         float_as_event=float_as_event,
-         spk_info=jax.ShapeDtypeStruct(spikes.shape, spikes.dtype),
-         weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
-     )
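The removed _linear_mv.py above wires an event-driven dense layer through XLACustomOp: a Numba CPU kernel, a Pallas GPU kernel, and custom JVP/transpose rules. For orientation, every one of those kernels computes the same result as the following dense JAX sketch (illustrative only; event_linear_reference is not part of either wheel):

import jax.numpy as jnp

def event_linear_reference(spikes, weights, float_as_event=True):
    # Each active presynaptic unit adds its weight row to the output;
    # when float_as_event is False, the spike value scales that row instead.
    if spikes.dtype == jnp.bool_ or float_as_event:
        gate = (spikes != 0).astype(weights.dtype)
    else:
        gate = spikes.astype(weights.dtype)
    return gate @ weights  # shape: (n_post,)

The removed kernels differ from this reference only in that they skip the multiply for inactive units entirely, which is what the removed benchmark files measure against a plain dense matmul.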
brainstate/event/_linear_mv_benckmark.py
@@ -1,82 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import os
-
- os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
-
- import jax
-
- import time
- import brainstate as bst
-
-
- def forward(n_pre, n_post, spk_prob, as_float: bool):
-     linear = bst.event.Linear(n_pre, n_post, weight=bst.init.KaimingUniform(), block_size=256)
-     spike = (bst.random.rand(n_pre) < spk_prob)
-
-     if as_float:
-         spike = spike.astype(float)
-
-     @jax.jit
-     def f1(spike):
-         return linear(spike)
-
-     @jax.jit
-     def f2(spike):
-         return spike @ linear.weight.value
-
-     y1 = jax.block_until_ready(f1(spike))
-     y2 = jax.block_until_ready(f2(spike))
-     print('max difference:', jax.numpy.abs(y1 - y2).max())
-
-     n = 100
-     t0 = time.time()
-     for _ in range(n):
-         jax.block_until_ready(f1(spike))
-     r1 = time.time() - t0
-     print(f"n_pre: {n_pre}, n_post: {n_post}, spike probability: {spk_prob}, Linear: {r1} s")
-
-     t0 = time.time()
-     for _ in range(n):
-         jax.block_until_ready(f2(spike))
-     r2 = time.time() - t0
-     print(f"n_pre: {n_pre}, n_post: {n_post}, spike probability: {spk_prob}, Matmul: {r2} s")
-     print('Acceleration ratio:', r2 / r1 - 1.)
-
-     print()
-
-
- def benchmark_forward():
-     for n_pre, n_post in [
-         (1000, 1000),
-         (1000, 10000),
-         (10000, 10000),
-         (10000, 1000),
-         (20000, 10000),
-         (20000, 20000),
-         # (10000, 100000),
-     ]:
-         forward(n_pre, n_post, 0.01, True)
-         forward(n_pre, n_post, 0.1, True)
-         print()
-         print()
-
-
- if __name__ == '__main__':
-     # forward(1000, 2000, 0.01, True)
-     # forward(2000, 4000, 0.01, True)
-     # forward(10000, 20000, 0.01, True)
-     benchmark_forward()