brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/augment/_autograd.py +9 -6
  4. brainstate/event/__init__.py +4 -2
  5. brainstate/event/_csr.py +26 -18
  6. brainstate/event/_csr_benchmark.py +14 -0
  7. brainstate/event/_fixed_probability.py +589 -152
  8. brainstate/event/_fixed_probability_benchmark.py +128 -0
  9. brainstate/event/_fixed_probability_test.py +13 -10
  10. brainstate/event/_linear.py +267 -127
  11. brainstate/event/_linear_benckmark.py +82 -0
  12. brainstate/event/_linear_test.py +8 -3
  13. brainstate/event/_xla_custom_op.py +312 -0
  14. brainstate/event/_xla_custom_op_test.py +55 -0
  15. brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
  16. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
  17. brainstate/nn/_dynamics/_projection_base.py +1 -1
  18. brainstate/nn/_exp_euler.py +1 -1
  19. brainstate/nn/_interaction/__init__.py +13 -4
  20. brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
  21. brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
  22. brainstate/nn/_interaction/_linear.py +582 -0
  23. brainstate/nn/_interaction/_linear_test.py +42 -0
  24. brainstate/optim/_lr_scheduler.py +1 -1
  25. brainstate/optim/_optax_optimizer.py +18 -0
  26. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +1 -1
  27. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/RECORD +30 -21
  28. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  29. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  30. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
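The headline change in this post-release is the event-driven rewrite of `brainstate.event` (a new `_xla_custom_op.py`, reworked `_fixed_probability.py` and `_linear.py`, plus the benchmark scripts listed above). As orientation before the hunks below, here is a minimal sketch of the `FixedProb` API as it is exercised by the benchmark and test code in this diff; the concrete sizes and the seed value are illustrative, not taken from the package:

    import brainstate as bst

    # Sparse, event-driven projection: 1000 pre -> 2000 post at 1% connection probability.
    layer = bst.event.FixedProb(1000, 2000, prob=0.01, weight=bst.init.Normal(), seed=0)
    spikes = bst.random.rand(1000) < 0.01   # boolean spike vector
    post = layer(spikes)                    # event-driven matvec, expected shape (2000,)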
@@ -0,0 +1,128 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ # n_pre: 1000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.004549980163574219 s
+ # n_pre: 1000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 0.04318690299987793 s
+ # Acceleration ratio: 8.491668413330538
+ #
+ # n_pre: 1000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.005620718002319336 s
+ # n_pre: 1000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 1.3311548233032227 s
+ # Acceleration ratio: 235.83003181336161
+ #
+ # n_pre: 10000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.015388727188110352 s
+ # n_pre: 10000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 10.791011333465576 s
+ # Acceleration ratio: 700.2283213262065
+ #
+ # n_pre: 10000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.01043844223022461 s
+ # n_pre: 10000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 0.8944694995880127 s
+ # Acceleration ratio: 84.68994107167329
+ #
+ # n_pre: 10000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.021282196044921875 s
+ # n_pre: 10000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 21.388156414031982 s
+ # Acceleration ratio: 1003.9788268506901
+ #
+ # n_pre: 20000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.025498151779174805 s
+ # n_pre: 20000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 21.211663246154785 s
+ # Acceleration ratio: 830.8902259997943
+ #
+ # n_pre: 20000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.044051408767700195 s
+ # n_pre: 20000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 42.31502842903137 s
+ # Acceleration ratio: 959.5828647200498
+ #
+ # n_pre: 20000, n_post: 30000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.06666803359985352 s
+ # n_pre: 20000, n_post: 30000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 62.46805453300476 s
+ # Acceleration ratio: 936.0016057162067
+ #
+ # n_pre: 30000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.08313393592834473 s
+ # n_pre: 30000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 63.61667847633362 s
+ # Acceleration ratio: 764.231163013459
+ #
+ #
+
+
+ import os
+
+ # os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
+ os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
+
+ import jax
+ #
+ # jax.config.update('jax_cpu_enable_async_dispatch', False)
+
+ import time
+ import brainstate as bst
+
+
+ def forward(n_pre, n_post, conn_prob, spk_prob, as_float: bool):
+     linear = bst.event.FixedProb(n_pre, n_post, prob=conn_prob, weight=bst.init.Normal())
+     spike = (bst.random.rand(n_pre) < spk_prob)
+
+     if as_float:
+         spike = spike.astype(float)
+
+     @jax.jit
+     def f1(spike):
+         return linear(spike)
+
+     weight = bst.init.Normal()([n_pre, n_post])
+
+     @jax.jit
+     def f2(spike):
+         return spike @ weight
+
+     y1 = jax.block_until_ready(f1(spike))
+     y2 = jax.block_until_ready(f2(spike))
+     # print('max difference:', jax.numpy.abs(y1 - y2).max())
+
+     n = 1000
+     t0 = time.time()
+     for _ in range(n):
+         jax.block_until_ready(f1(spike))
+     r1 = time.time() - t0
+     print(f"n_pre: {n_pre}, n_post: {n_post}, conn_prob: {conn_prob}, spk_prob: {spk_prob}, Linear: {r1} s")
+
+     t0 = time.time()
+     for _ in range(n):
+         jax.block_until_ready(f2(spike))
+     r2 = time.time() - t0
+     print(f"n_pre: {n_pre}, n_post: {n_post}, conn_prob: {conn_prob}, spk_prob: {spk_prob}, Matmul: {r2} s")
+     print('Acceleration ratio:', r2 / r1 - 1.)
+
+     print()
+     bst.util.clear_buffer_memory()
+
+
+ def benchmark_forward():
+     for n_pre, n_post in [
+         (1000, 1000),
+         (1000, 10000),
+         (10000, 10000),
+         (10000, 1000),
+         (10000, 20000),
+         (20000, 10000),
+         (20000, 20000),
+         (20000, 30000),
+         (30000, 20000),
+     ]:
+         forward(n_pre, n_post, 0.01, 0.01, False)
+
+
+ if __name__ == '__main__':
+     pass
+     # forward(1000, 6400, 0.01, 0.01, False)
+     # forward(10000, 12800, 0.01, 0.01, False)
+
+     benchmark_forward()
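One detail worth keeping when reusing this benchmark is its timing discipline for JIT-compiled JAX code: each function is called once before timing (so compilation is excluded) and every timed call is wrapped in `jax.block_until_ready`, so JAX's asynchronous dispatch does not make the kernels look faster than they are. A stripped-down version of that pattern, with a plain matmul as a stand-in workload and a helper name (`time_jitted`) that is ours, not brainstate's:

    import time

    import jax
    import jax.numpy as jnp

    def time_jitted(fn, *args, n=1000):
        jax.block_until_ready(fn(*args))       # warm-up call: triggers compilation
        t0 = time.time()
        for _ in range(n):
            jax.block_until_ready(fn(*args))   # block so async dispatch is not hidden
        return time.time() - t0

    w = jnp.ones((1000, 2000))
    f = jax.jit(lambda v: v @ w)
    print('matmul:', time_jitted(f, jnp.ones((1000,))), 's')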
@@ -15,12 +15,12 @@

  from __future__ import annotations

+
  import jax.numpy
  import jax.numpy as jnp
  from absl.testing import parameterized

  import brainstate as bst
- from brainstate.event._fixed_probability import FixedProb


  class TestFixedProbCSR(parameterized.TestCase):
@@ -30,18 +30,18 @@ class TestFixedProbCSR(parameterized.TestCase):
      def test1(self, allow_multi_conn):
          x = bst.random.rand(20) < 0.1
          # x = bst.random.rand(20)
-         m = FixedProb(20, 40, 0.1, 1.0, seed=123, allow_multi_conn=allow_multi_conn)
+         m = bst.event.FixedProb(20, 40, 0.1, 1.0, seed=123, allow_multi_conn=allow_multi_conn)
          y = m(x)
          print(y)

-         m2 = FixedProb(20, 40, 0.1, bst.init.KaimingUniform(), seed=123)
+         m2 = bst.event.FixedProb(20, 40, 0.1, bst.init.KaimingUniform(), seed=123)
          print(m2(x))

      def test_grad_bool(self):
          n_in = 20
          n_out = 30
          x = bst.random.rand(n_in) < 0.3
-         fn = FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)
+         fn = bst.event.FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)

          def f(x):
              return fn(x).sum()
@@ -62,16 +62,16 @@ class TestFixedProbCSR(parameterized.TestCase):
          x = bst.random.rand(n_in)

          if homo_w:
-             fn = FixedProb(n_in, n_out, 0.1, 1.5, seed=123)
+             fn = bst.event.FixedProb(n_in, n_out, 0.1, 1.5, seed=123, float_as_event=bool_x)
          else:
-             fn = FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)
+             fn = bst.event.FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123, float_as_event=bool_x)
          w = fn.weight.value

          def f(x, w):
              fn.weight.value = w
              return fn(x).sum()

-         r = bst.transform.grad(f, argnums=(0, 1))(x, w)
+         r = bst.augment.grad(f, argnums=(0, 1))(x, w)

          # -------------------
          # TRUE gradients
@@ -88,7 +88,6 @@ class TestFixedProbCSR(parameterized.TestCase):
          r2 = jax.grad(f2, argnums=(0, 1))(x, w)
          self.assertTrue(jnp.allclose(r[0], r2[0]))
          self.assertTrue(jnp.allclose(r[1], r2[1]))
-         print(r[1])

      @parameterized.product(
          bool_x=[True, False],
@@ -102,7 +101,11 @@ class TestFixedProbCSR(parameterized.TestCase):
          else:
              x = bst.random.rand(n_in)

-         fn = FixedProb(n_in, n_out, 0.1, 1.5 if homo_w else bst.init.KaimingUniform(), seed=123, grad_mode='jvp')
+         fn = bst.event.FixedProb(
+             n_in, n_out, 0.1, 1.5 if homo_w else bst.init.KaimingUniform(),
+             seed=123,
+             float_as_event=bool_x
+         )
          w = fn.weight.value

          def f(x, w):
@@ -124,5 +127,5 @@ class TestFixedProbCSR(parameterized.TestCase):
              return true_fn(x, w, fn.indices, n_out)

          o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-         self.assertTrue(jnp.allclose(r1, r2))
          self.assertTrue(jnp.allclose(o1, o2))
+         self.assertTrue(jnp.allclose(r1, r2))
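The test updates keep the original cross-checking strategy: outputs and gradients of the event-driven `FixedProb` layer are compared against an explicit dense reference, via `bst.augment.grad` for reverse mode and `jax.jvp` for forward mode. The same `jax.jvp` consistency check in isolation, with two toy implementations of one function standing in for the event-driven path and the dense reference (everything below is illustrative, not brainstate API):

    import jax
    import jax.numpy as jnp

    w = jnp.arange(12.0).reshape(3, 4)

    def f_event(x):   # stand-in for the event-driven implementation
        return (x @ w).sum()

    def f_dense(x):   # stand-in for the dense reference implementation
        return jnp.einsum('i,ij->', x, w)

    x = jnp.array([1.0, 0.0, 2.0])
    o1, t1 = jax.jvp(f_event, (x,), (jnp.ones_like(x),))
    o2, t2 = jax.jvp(f_dense, (x,), (jnp.ones_like(x),))
    assert jnp.allclose(o1, o2) and jnp.allclose(t1, t2)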
@@ -12,21 +12,23 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+
  from __future__ import annotations

  from typing import Union, Callable, Optional

  import brainunit as u
  import jax
+ import jax.experimental.pallas as pl
  import jax.numpy as jnp
  import numpy as np
+ from jax.interpreters import ad

  from brainstate._state import ParamState, State
- from brainstate._utils import set_module_as
  from brainstate.init import param
  from brainstate.nn._module import Module
- from brainstate.typing import ArrayLike
- from ._misc import IntScalar
+ from brainstate.typing import ArrayLike, Size
+ from ._xla_custom_op import XLACustomOp

  __all__ = [
      'Linear',
@@ -39,12 +41,16 @@ class Linear(Module):

      Parameters
      ----------
-     n_pre : int
-         Number of pre-synaptic neurons.
-     n_post : int
-         Number of post-synaptic neurons.
+     in_size : Size
+         Number of pre-synaptic neurons, i.e., input size.
+     out_size : Size
+         Number of post-synaptic neurons, i.e., output size.
      weight : float or callable or jax.Array or brainunit.Quantity
          Maximum synaptic conductance.
+     block_size : int, optional
+         Block size for parallel computation.
+     float_as_event : bool, optional
+         Whether to treat float as event.
      name : str, optional
          Name of the module.
      """
@@ -53,167 +59,301 @@ class Linear(Module):

      def __init__(
          self,
-         n_pre: IntScalar,
-         n_post: IntScalar,
+         in_size: Size,
+         out_size: Size,
          weight: Union[Callable, ArrayLike],
+         float_as_event: bool = True,
+         block_size: int = 64,
          name: Optional[str] = None,
-         grad_mode: str = 'vjp'
      ):
          super().__init__(name=name)
-         self.n_pre = n_pre
-         self.n_post = n_post
-         self.in_size = n_pre
-         self.out_size = n_post

-         assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
-         self.grad_mode = grad_mode
+         # network parameters
+         self.in_size = in_size
+         self.out_size = out_size
+         self.float_as_event = float_as_event
+         self.block_size = block_size

          # maximum synaptic conductance
-         weight = param(weight, (self.n_pre, self.n_post), allow_none=False)
+         weight = param(weight, (self.in_size[-1], self.out_size[-1]), allow_none=False)
          self.weight = ParamState(weight)

      def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
          weight = self.weight.value if isinstance(self.weight, State) else self.weight
          if u.math.size(weight) == 1:
-             return u.math.ones(self.n_post) * (u.math.sum(spk) * weight)
-
-         device_kind = jax.devices()[0].platform  # spk.device.device_kind
-         if device_kind == 'cpu':
-             return cpu_event_linear(u.math.asarray(weight),
-                                     u.math.asarray(spk),
-                                     n_post=self.n_post,
-                                     grad_mode=self.grad_mode)
-         elif device_kind in ['gpu', 'tpu']:
-             raise NotImplementedError()
-         else:
-             raise ValueError(f"Unsupported device: {device_kind}")
-
-
- @set_module_as('brainstate.event')
- def cpu_event_linear(
-     g_max: Union[u.Quantity, jax.Array],
-     spk: jax.Array,
-     *,
-     n_post: int = None,
-     grad_mode: str = 'vjp'
- ) -> Union[u.Quantity, jax.Array]:
+             return u.math.ones(self.out_size) * (u.math.sum(spk) * weight)
+
+         return event_linear(spk, weight, block_size=self.block_size, float_as_event=self.float_as_event)
+
+
+ def event_linear(spk, weight, *, block_size, float_as_event) -> jax.Array | u.Quantity:
      """
-     The FixedProb module implements a fixed probability connection with CSR sparse data structure.
+     The event-driven linear computation.

      Parameters
      ----------
-     n_post : int
-         Number of post-synaptic neurons.
-     g_max : brainunit.Quantity or jax.Array
+     weight : brainunit.Quantity or jax.Array
          Maximum synaptic conductance.
      spk : jax.Array
          Spike events.
-     grad_mode : str, optional
-         Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.
+     block_size : int
+         Block size for parallel computation.
+     float_as_event : bool
+         Whether to treat float as event.

      Returns
      -------
      post_data : brainunit.Quantity or jax.Array
          Post synaptic data.
      """
-     unit = u.get_unit(g_max)
-     g_max = u.get_mantissa(g_max)
-     spk = jnp.asarray(spk)
+     with jax.ensure_compile_time_eval():
+         weight = u.math.asarray(weight)
+         unit = u.get_unit(weight)
+         weight = u.get_mantissa(weight)
+         spk = jnp.asarray(spk)

      def mv(spk_vector):
          assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
-         if jnp.size(g_max) == 1:
-             assert isinstance(n_post, int), f"n_post must be an integer when weight is homogenous. Got: {n_post}"
-             # return jnp.full((n_post,), fill_value=jnp.sum(spk_vector) * weight)
-             return jnp.ones((n_post,), dtype=g_max.dtype) * (jnp.sum(spk_vector) * g_max)
-
-         if grad_mode == 'vjp':
-             post = _cpu_event_linear_mv_vjp(g_max, spk_vector)
-         elif grad_mode == 'jvp':
-             post = _cpu_event_linear_mv_jvp(g_max, spk_vector)
-         else:
-             raise ValueError(f"Unsupported grad_mode: {grad_mode}")
-         return post
+         return event_liner_p_call(
+             spk,
+             weight,
+             block_size=block_size,
+             float_as_event=float_as_event,
+         )

      assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
-     assert g_max.ndim in [2, 0], f"weight must be 2D or 0D. Got: {g_max.ndim}"
+     assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"

      if spk.ndim == 1:
-         post_data = mv(spk)
+         [post_data] = mv(spk)
      else:
-         shape = spk.shape[:-1]
-         post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
-         post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
+         [post_data] = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
+         post_data = u.math.reshape(post_data, spk.shape[:-1] + post_data.shape[-1:])
      return u.maybe_decimal(u.Quantity(post_data, unit=unit))


- # --------------
- # Implementation
- # --------------
-
-
- def _cpu_event_linear_mv(g_max, spk) -> jax.Array:
-     def scan_fn(post, i):
-         sp = spk[i]
-         if spk.dtype == jnp.bool_:
-             post = jax.lax.cond(sp, lambda: post + g_max[i], lambda: post)
-         else:
-             post = jax.lax.cond(sp == 0., lambda: post, lambda: post + g_max[i] * sp)
-         return post, None
-
-     return jax.lax.scan(scan_fn, jnp.zeros(g_max.shape[1], dtype=g_max.dtype), np.arange(len(spk)))[0]
-
-
- # --------------
- # VJP
- # --------------
-
- def _cpu_event_linear_mv_fwd(g_max, spk):
-     return _cpu_event_linear_mv(g_max, spk), (g_max, spk)
+ Kernel = Callable


- def _cpu_event_linear_mv_bwd(res, ct):
-     g_max, spk = res
+ def cpu_kernel_generator(
+     float_as_event: bool,
+     spk_info: jax.ShapeDtypeStruct,
+     **kwargs
+ ) -> Kernel:
+     import numba  # pylint: disable=import-outside-toplevel

-     # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
-     ct_spk = jnp.matmul(g_max, ct)
+     if spk_info.dtype == jnp.bool_:

-     # ∂L/∂w = ∂L/∂y * ∂y/∂w
-     def map_fn(sp):
-         if spk.dtype == jnp.bool_:
-             d_gmax = jax.lax.cond(sp, lambda: ct, lambda: jnp.zeros_like(ct))
-         else:
-             d_gmax = jax.lax.cond(sp == 0., lambda: jnp.zeros_like(ct), lambda: ct * sp)
-         return d_gmax
+         @numba.njit
+         def _kernel(spikes, weights, posts):
+             r = np.zeros((weights.shape[1],), dtype=weights.dtype)
+             for i in range(spikes.shape[0]):
+                 if spikes[i]:
+                     r = r + weights[i]
+             posts[:] = r

-     ct_gmax = jax.vmap(map_fn)(spk)
-     return ct_gmax, ct_spk
+     elif float_as_event:
+         @numba.njit
+         def _kernel(spikes, weights, posts):
+             r = np.zeros((weights.shape[1],), dtype=weights.dtype)
+             for i in range(spikes.shape[0]):
+                 if spikes[i] != 0.:
+                     r = r + weights[i]
+             posts[:] = r

+     else:
+         @numba.njit
+         def _kernel(spikes, weights, posts):
+             r = np.zeros((weights.shape[1],), dtype=weights.dtype)
+             for i in range(spikes.shape[0]):
+                 sp = spikes[i]
+                 if sp != 0.:
+                     r = r + weights[i] * sp
+             posts[:] = r
+
+     return _kernel
+
+
+ def gpu_kernel_generator(
+     block_size: int,
+     float_as_event: bool,
+     weight_info: jax.ShapeDtypeStruct,
+     **kwargs
+ ) -> Kernel:
+     # # Each block handles one [block_size,] slice of post
+     # # Each block handles one [n_pre] slice of pre
+     # # Each block handles one [n_pre, block_size] slice of w
+     # def _mv_kernel(sp_ref, w_ref, post_ref):
+     #
+     #     pid = pl.program_id(0)
+     #
+     #     def scan_fn(i, post_):
+     #         if sp_ref.dtype == jnp.bool_:
+     #             post_ = jax.lax.cond(
+     #                 sp_ref[i],
+     #                 lambda: post_ + w_ref[i, ...],
+     #                 lambda: post_
+     #             )
+     #         else:
+     #             if float_as_event:
+     #                 post_ = jax.lax.cond(
+     #                     sp_ref[i] != 0.,
+     #                     lambda: post_ + w_ref[i, ...],
+     #                     lambda: post_
+     #                 )
+     #             else:
+     #                 sp = sp_ref[i]
+     #                 post_ = jax.lax.cond(
+     #                     sp != 0.,
+     #                     lambda: post_ + w_ref[i, ...] * sp,
+     #                     lambda: post_
+     #                 )
+     #         return post_
+     #
+     #     post = jax.lax.fori_loop(0, n_pre, scan_fn, jnp.zeros(post_ref.shape, dtype=post_ref.dtype))
+     #     mask = jnp.arange(block_size) + pid * block_size < n_pre
+     #     pl.store(post_ref, pl.dslice(None, None), post, mask=mask)
+     #
+     # n_pre = weight_info.shape[0]
+     # n_post = weight_info.shape[1]
+     # kernel = pl.pallas_call(
+     #     _mv_kernel,
+     #     out_shape=[
+     #         jax.ShapeDtypeStruct([weight_info.shape[1]], weight_info.dtype),
+     #     ],
+     #     out_specs=[
+     #         pl.BlockSpec((block_size,), lambda i: i),
+     #     ],
+     #     in_specs=[
+     #         pl.BlockSpec((n_pre,), lambda i: 0),
+     #         pl.BlockSpec((n_pre, block_size), lambda i: (0, i)),
+     #     ],
+     #     grid=(
+     #         pl.cdiv(n_post, block_size),
+     #     ),
+     #     interpret=False,
+     # )
+     # return kernel
+
+     # Each block handles one [block_size,] slice of post
+     # Each block handles one [block_size] slice of pre
+     # Each block handles one [block_size, block_size] slice of w
+     def _mv_kernel(
+         sp_ref,  # [block_size]
+         w_ref,  # [block_size, block_size]
+         post_ref,  # [block_size]
+     ):

- _cpu_event_linear_mv_vjp = jax.custom_vjp(_cpu_event_linear_mv)
- _cpu_event_linear_mv_vjp.defvjp(_cpu_event_linear_mv_fwd, _cpu_event_linear_mv_bwd)
-
-
- # --------------
- # JVP
- # --------------
-
-
- def _cpu_event_linear_mv_jvp_rule(primals, tangents):
-     # forward pass
-     g_max, spk = primals
-     y = _cpu_event_linear_mv(g_max, spk)
-
-     # forward gradients
-     gmax_dot, spk_dot = tangents
-
-     # ∂y/∂gmax
-     dgmax = _cpu_event_linear_mv(gmax_dot, spk)
-
-     # ∂y/∂gspk
-     dspk = spk_dot @ g_max
-     return y, dgmax + dspk
-
+         r_pid = pl.program_id(0)
+         c_start = pl.program_id(1) * block_size
+         row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+         mask = jnp.arange(block_size) + c_start < weight_info.shape[1]
+
+         def scan_fn(i, post_):
+             if sp_ref.dtype == jnp.bool_:
+                 post_ = jax.lax.cond(
+                     sp_ref[i],
+                     lambda: post_ + w_ref[i, ...],
+                     lambda: post_
+                 )
+             else:
+                 if float_as_event:
+                     post_ = jax.lax.cond(
+                         sp_ref[i] != 0.,
+                         lambda: post_ + w_ref[i, ...],
+                         lambda: post_
+                     )
+                 else:
+                     sp = sp_ref[i]
+                     post_ = jax.lax.cond(
+                         sp != 0.,
+                         lambda: post_ + w_ref[i, ...] * sp,
+                         lambda: post_
+                     )
+             return post_
+
+         post = jax.lax.fori_loop(0, row_length, scan_fn, jnp.zeros(post_ref.shape, dtype=post_ref.dtype))
+         pl.atomic_add(post_ref, pl.dslice(None, None), post, mask=mask)
+
+     n_pre = weight_info.shape[0]
+     n_post = weight_info.shape[1]
+     kernel = pl.pallas_call(
+         _mv_kernel,
+         out_shape=[
+             jax.ShapeDtypeStruct([weight_info.shape[1]], weight_info.dtype),
+         ],
+         out_specs=[
+             pl.BlockSpec((block_size,), lambda i, j: j),
+         ],
+         in_specs=[
+             pl.BlockSpec((block_size,), lambda i, j: i),
+             pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),
+         ],
+         grid=(
+             pl.cdiv(n_pre, block_size),
+             pl.cdiv(n_post, block_size),
+         ),
+         interpret=False,
+     )
+     return kernel
+
+
+ def jvp_spikes(spk_dot, spikes, weights, **kwargs):
+     return [spk_dot @ weights]
+
+
+ def jvp_weights(w_dot, spikes, weights, *, float_as_event, block_size, **kwargs):
+     return event_liner_p_call(
+         spikes,
+         w_dot,
+         block_size=block_size,
+         float_as_event=float_as_event,
+     )
+
+
+ def transpose_rule(ct, spikes, weights, *, float_as_event, **kwargs):
+     if ad.is_undefined_primal(spikes):
+         ct_events = jnp.matmul(weights, ct[0])
+         return (ad.Zero(spikes) if type(ct[0]) is ad.Zero else ct_events), weights

- _cpu_event_linear_mv_jvp = jax.custom_jvp(_cpu_event_linear_mv)
- _cpu_event_linear_mv_jvp.defjvp(_cpu_event_linear_mv_jvp_rule)
+     else:
+         def map_fn(sp):
+             if spikes.dtype == jnp.bool_:
+                 d_gmax = jnp.where(sp, ct[0], jnp.zeros_like(ct[0]))
+             else:
+                 if float_as_event:
+                     d_gmax = jnp.where(sp == 0., jnp.zeros_like(ct[0]), ct[0])
+                 else:
+                     d_gmax = jnp.where(sp == 0., jnp.zeros_like(ct[0]), ct[0] * sp)
+                     # d_gmax = jax.lax.cond(sp == 0., lambda: jnp.zeros_like(ct[0]), lambda: ct[0] * sp)
+             return d_gmax
+
+         ct_weights = jax.vmap(map_fn)(spikes)
+         return spikes, (ad.Zero(weights) if type(ct[0]) is ad.Zero else ct_weights)
+
+
+ event_linear_p = XLACustomOp(
+     'event_linear',
+     cpu_kernel_generator=cpu_kernel_generator,
+     gpu_kernel_generator=gpu_kernel_generator,
+ )
+ event_linear_p.defjvp(jvp_spikes, jvp_weights)
+ event_linear_p.def_transpose_rule(transpose_rule)
+
+
+ def event_liner_p_call(
+     spikes,
+     weights,
+     *,
+     block_size,
+     float_as_event,
+ ):
+     return event_linear_p(
+         spikes,
+         weights,
+         outs=[jax.ShapeDtypeStruct([weights.shape[1]], weights.dtype)],
+         block_size=block_size,
+         float_as_event=float_as_event,
+         spk_info=jax.ShapeDtypeStruct(spikes.shape, spikes.dtype),
+         weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
+     )
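All three Numba CPU kernels above (and the Pallas GPU kernel) compute the same thing: for each active pre-synaptic index i, the row weights[i] is accumulated into the post-synaptic vector, scaled by the spike value only when float inputs are not treated as events; for 0/1 inputs this coincides with the dense spk @ weight. A small self-contained check of that equivalence, written against plain jax.numpy rather than the event_linear_p primitive:

    import jax.numpy as jnp

    n_pre, n_post = 8, 5
    weight = jnp.arange(n_pre * n_post, dtype=jnp.float32).reshape(n_pre, n_post)
    spk_bool = jnp.array([1, 0, 0, 1, 0, 1, 0, 0], dtype=bool)
    spk_val = spk_bool.astype(jnp.float32) * 2.0   # graded "spikes"

    # bool input (or float_as_event=True): sum the rows of active pre-neurons
    post_event = jnp.where(spk_bool[:, None], weight, 0.0).sum(axis=0)
    assert jnp.allclose(post_event, spk_bool.astype(jnp.float32) @ weight)

    # float_as_event=False: scale each active row by its spike value
    post_scaled = (spk_val[:, None] * weight).sum(axis=0)
    assert jnp.allclose(post_scaled, spk_val @ weight)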