brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/augment/_autograd.py +9 -6
- brainstate/event/__init__.py +4 -2
- brainstate/event/_csr.py +26 -18
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_fixed_probability.py +589 -152
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +13 -10
- brainstate/event/_linear.py +267 -127
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +8 -3
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
- brainstate/nn/_dynamics/_projection_base.py +1 -1
- brainstate/nn/_exp_euler.py +1 -1
- brainstate/nn/_interaction/__init__.py +13 -4
- brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
- brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/optim/_lr_scheduler.py +1 -1
- brainstate/optim/_optax_optimizer.py +18 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/RECORD +30 -21
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/event/_fixed_probability_benchmark.py
ADDED
@@ -0,0 +1,128 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+# n_pre: 1000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.004549980163574219 s
+# n_pre: 1000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 0.04318690299987793 s
+# Acceleration ratio: 8.491668413330538
+#
+# n_pre: 1000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.005620718002319336 s
+# n_pre: 1000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 1.3311548233032227 s
+# Acceleration ratio: 235.83003181336161
+#
+# n_pre: 10000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.015388727188110352 s
+# n_pre: 10000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 10.791011333465576 s
+# Acceleration ratio: 700.2283213262065
+#
+# n_pre: 10000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.01043844223022461 s
+# n_pre: 10000, n_post: 1000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 0.8944694995880127 s
+# Acceleration ratio: 84.68994107167329
+#
+# n_pre: 10000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.021282196044921875 s
+# n_pre: 10000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 21.388156414031982 s
+# Acceleration ratio: 1003.9788268506901
+#
+# n_pre: 20000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.025498151779174805 s
+# n_pre: 20000, n_post: 10000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 21.211663246154785 s
+# Acceleration ratio: 830.8902259997943
+#
+# n_pre: 20000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.044051408767700195 s
+# n_pre: 20000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 42.31502842903137 s
+# Acceleration ratio: 959.5828647200498
+#
+# n_pre: 20000, n_post: 30000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.06666803359985352 s
+# n_pre: 20000, n_post: 30000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 62.46805453300476 s
+# Acceleration ratio: 936.0016057162067
+#
+# n_pre: 30000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Linear: 0.08313393592834473 s
+# n_pre: 30000, n_post: 20000, conn_prob: 0.01, spk_prob: 0.01, Matmul: 63.61667847633362 s
+# Acceleration ratio: 764.231163013459
+#
+#
+
+
+import os
+
+# os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
+os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
+
+import jax
+#
+# jax.config.update('jax_cpu_enable_async_dispatch', False)
+
+import time
+import brainstate as bst
+
+
+def forward(n_pre, n_post, conn_prob, spk_prob, as_float: bool):
+    linear = bst.event.FixedProb(n_pre, n_post, prob=conn_prob, weight=bst.init.Normal())
+    spike = (bst.random.rand(n_pre) < spk_prob)
+
+    if as_float:
+        spike = spike.astype(float)
+
+    @jax.jit
+    def f1(spike):
+        return linear(spike)
+
+    weight = bst.init.Normal()([n_pre, n_post])
+
+    @jax.jit
+    def f2(spike):
+        return spike @ weight
+
+    y1 = jax.block_until_ready(f1(spike))
+    y2 = jax.block_until_ready(f2(spike))
+    # print('max difference:', jax.numpy.abs(y1 - y2).max())
+
+    n = 1000
+    t0 = time.time()
+    for _ in range(n):
+        jax.block_until_ready(f1(spike))
+    r1 = time.time() - t0
+    print(f"n_pre: {n_pre}, n_post: {n_post}, conn_prob: {conn_prob}, spk_prob: {spk_prob}, Linear: {r1} s")
+
+    t0 = time.time()
+    for _ in range(n):
+        jax.block_until_ready(f2(spike))
+    r2 = time.time() - t0
+    print(f"n_pre: {n_pre}, n_post: {n_post}, conn_prob: {conn_prob}, spk_prob: {spk_prob}, Matmul: {r2} s")
+    print('Acceleration ratio:', r2 / r1 - 1.)
+
+    print()
+    bst.util.clear_buffer_memory()
+
+
+def benchmark_forward():
+    for n_pre, n_post in [
+        (1000, 1000),
+        (1000, 10000),
+        (10000, 10000),
+        (10000, 1000),
+        (10000, 20000),
+        (20000, 10000),
+        (20000, 20000),
+        (20000, 30000),
+        (30000, 20000),
+    ]:
+        forward(n_pre, n_post, 0.01, 0.01, False)
+
+
+if __name__ == '__main__':
+    pass
+    # forward(1000, 6400, 0.01, 0.01, False)
+    # forward(10000, 12800, 0.01, 0.01, False)
+
+    benchmark_forward()
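Note: the "Acceleration ratio" figures in the header comments above are relative speed-ups of the event-driven layer over the dense matmul, computed as r2 / r1 - 1. exactly as in the script. A quick standalone check against the first entry (timings copied from the comments; no brainstate needed):

    # Recompute the first reported acceleration ratio from the timings above.
    r1 = 0.004549980163574219  # event-driven FixedProb forward, seconds
    r2 = 0.04318690299987793   # dense matmul forward, seconds
    print(r2 / r1 - 1.)        # ~8.4917, matching "Acceleration ratio: 8.4916..."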
brainstate/event/_fixed_probability_test.py
CHANGED
@@ -15,12 +15,12 @@
 
 from __future__ import annotations
 
+
 import jax.numpy
 import jax.numpy as jnp
 from absl.testing import parameterized
 
 import brainstate as bst
-from brainstate.event._fixed_probability import FixedProb
 
 
 class TestFixedProbCSR(parameterized.TestCase):
@@ -30,18 +30,18 @@ class TestFixedProbCSR(parameterized.TestCase):
     def test1(self, allow_multi_conn):
         x = bst.random.rand(20) < 0.1
         # x = bst.random.rand(20)
-        m = FixedProb(20, 40, 0.1, 1.0, seed=123, allow_multi_conn=allow_multi_conn)
+        m = bst.event.FixedProb(20, 40, 0.1, 1.0, seed=123, allow_multi_conn=allow_multi_conn)
         y = m(x)
         print(y)
 
-        m2 = FixedProb(20, 40, 0.1, bst.init.KaimingUniform(), seed=123)
+        m2 = bst.event.FixedProb(20, 40, 0.1, bst.init.KaimingUniform(), seed=123)
         print(m2(x))
 
     def test_grad_bool(self):
         n_in = 20
         n_out = 30
         x = bst.random.rand(n_in) < 0.3
-        fn = FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)
+        fn = bst.event.FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)
 
         def f(x):
             return fn(x).sum()
@@ -62,16 +62,16 @@ class TestFixedProbCSR(parameterized.TestCase):
         x = bst.random.rand(n_in)
 
         if homo_w:
-            fn = FixedProb(n_in, n_out, 0.1, 1.5, seed=123)
+            fn = bst.event.FixedProb(n_in, n_out, 0.1, 1.5, seed=123, float_as_event=bool_x)
         else:
-            fn = FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)
+            fn = bst.event.FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123, float_as_event=bool_x)
         w = fn.weight.value
 
         def f(x, w):
             fn.weight.value = w
             return fn(x).sum()
 
-        r = bst.
+        r = bst.augment.grad(f, argnums=(0, 1))(x, w)
 
         # -------------------
         # TRUE gradients
@@ -88,7 +88,6 @@ class TestFixedProbCSR(parameterized.TestCase):
         r2 = jax.grad(f2, argnums=(0, 1))(x, w)
         self.assertTrue(jnp.allclose(r[0], r2[0]))
         self.assertTrue(jnp.allclose(r[1], r2[1]))
-        print(r[1])
 
     @parameterized.product(
         bool_x=[True, False],
@@ -102,7 +101,11 @@ class TestFixedProbCSR(parameterized.TestCase):
         else:
             x = bst.random.rand(n_in)
 
-        fn =
+        fn = bst.event.FixedProb(
+            n_in, n_out, 0.1, 1.5 if homo_w else bst.init.KaimingUniform(),
+            seed=123,
+            float_as_event=bool_x
+        )
         w = fn.weight.value
 
         def f(x, w):
@@ -124,5 +127,5 @@ class TestFixedProbCSR(parameterized.TestCase):
             return true_fn(x, w, fn.indices, n_out)
 
         o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-        self.assertTrue(jnp.allclose(r1, r2))
         self.assertTrue(jnp.allclose(o1, o2))
+        self.assertTrue(jnp.allclose(r1, r2))
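The float_as_event flag exercised by these tests controls how non-boolean inputs are interpreted. A standalone NumPy sketch of the two modes, as implied by the CPU kernels added in brainstate/event/_linear.py below (illustration only, not the package implementation):

    import numpy as np

    spikes = np.array([0.0, 0.5, 2.0])
    weights = np.ones((3, 1))

    # float_as_event=True: any nonzero spike adds the full weight row once
    out_event = sum(weights[i] for i in range(3) if spikes[i] != 0.)               # -> [2.0]
    # float_as_event=False: nonzero spikes scale the corresponding weight row
    out_scaled = sum(weights[i] * spikes[i] for i in range(3) if spikes[i] != 0.)  # -> [2.5]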
brainstate/event/_linear.py
CHANGED
@@ -12,21 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
 from __future__ import annotations
 
 from typing import Union, Callable, Optional
 
 import brainunit as u
 import jax
+import jax.experimental.pallas as pl
 import jax.numpy as jnp
 import numpy as np
+from jax.interpreters import ad
 
 from brainstate._state import ParamState, State
-from brainstate._utils import set_module_as
 from brainstate.init import param
 from brainstate.nn._module import Module
-from brainstate.typing import ArrayLike
-from .
+from brainstate.typing import ArrayLike, Size
+from ._xla_custom_op import XLACustomOp
 
 __all__ = [
     'Linear',
@@ -39,12 +41,16 @@ class Linear(Module):
 
     Parameters
     ----------
-
-        Number of pre-synaptic neurons.
-
-        Number of post-synaptic neurons.
+    in_size : Size
+        Number of pre-synaptic neurons, i.e., input size.
+    out_size : Size
+        Number of post-synaptic neurons, i.e., output size.
     weight : float or callable or jax.Array or brainunit.Quantity
         Maximum synaptic conductance.
+    block_size : int, optional
+        Block size for parallel computation.
+    float_as_event : bool, optional
+        Whether to treat float as event.
     name : str, optional
         Name of the module.
     """
@@ -53,167 +59,301 @@ class Linear(Module):
 
     def __init__(
         self,
-
-
+        in_size: Size,
+        out_size: Size,
         weight: Union[Callable, ArrayLike],
+        float_as_event: bool = True,
+        block_size: int = 64,
         name: Optional[str] = None,
-        grad_mode: str = 'vjp'
     ):
         super().__init__(name=name)
-        self.n_pre = n_pre
-        self.n_post = n_post
-        self.in_size = n_pre
-        self.out_size = n_post
 
-
-        self.
+        # network parameters
+        self.in_size = in_size
+        self.out_size = out_size
+        self.float_as_event = float_as_event
+        self.block_size = block_size
 
         # maximum synaptic conductance
-        weight = param(weight, (self.
+        weight = param(weight, (self.in_size[-1], self.out_size[-1]), allow_none=False)
         self.weight = ParamState(weight)
 
     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
         weight = self.weight.value if isinstance(self.weight, State) else self.weight
         if u.math.size(weight) == 1:
-            return u.math.ones(self.
-
-
-
-
-
-                n_post=self.n_post,
-                grad_mode=self.grad_mode)
-        elif device_kind in ['gpu', 'tpu']:
-            raise NotImplementedError()
-        else:
-            raise ValueError(f"Unsupported device: {device_kind}")
-
-
-@set_module_as('brainstate.event')
-def cpu_event_linear(
-    g_max: Union[u.Quantity, jax.Array],
-    spk: jax.Array,
-    *,
-    n_post: int = None,
-    grad_mode: str = 'vjp'
-) -> Union[u.Quantity, jax.Array]:
+            return u.math.ones(self.out_size) * (u.math.sum(spk) * weight)
+
+        return event_linear(spk, weight, block_size=self.block_size, float_as_event=self.float_as_event)
+
+
+def event_linear(spk, weight, *, block_size, float_as_event) -> jax.Array | u.Quantity:
     """
-    The
+    The event-driven linear computation.
 
     Parameters
     ----------
-
-        Number of post-synaptic neurons.
-    g_max : brainunit.Quantity or jax.Array
+    weight : brainunit.Quantity or jax.Array
        Maximum synaptic conductance.
     spk : jax.Array
        Spike events.
-
-
+    block_size : int
+       Block size for parallel computation.
+    float_as_event : bool
+       Whether to treat float as event.
 
     Returns
     -------
     post_data : brainunit.Quantity or jax.Array
        Post synaptic data.
     """
-
-
-
+    with jax.ensure_compile_time_eval():
+        weight = u.math.asarray(weight)
+        unit = u.get_unit(weight)
+        weight = u.get_mantissa(weight)
+        spk = jnp.asarray(spk)
 
     def mv(spk_vector):
         assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
-
-
-
-
-
-
-            post = _cpu_event_linear_mv_vjp(g_max, spk_vector)
-        elif grad_mode == 'jvp':
-            post = _cpu_event_linear_mv_jvp(g_max, spk_vector)
-        else:
-            raise ValueError(f"Unsupported grad_mode: {grad_mode}")
-        return post
+        return event_liner_p_call(
+            spk,
+            weight,
+            block_size=block_size,
+            float_as_event=float_as_event,
+        )
 
     assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
-    assert
+    assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"
 
     if spk.ndim == 1:
-        post_data = mv(spk)
+        [post_data] = mv(spk)
     else:
-
-        post_data =
-        post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
+        [post_data] = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
+        post_data = u.math.reshape(post_data, spk.shape[:-1] + post_data.shape[-1:])
     return u.maybe_decimal(u.Quantity(post_data, unit=unit))
 
 
-
-# Implementation
-# --------------
-
-
-def _cpu_event_linear_mv(g_max, spk) -> jax.Array:
-    def scan_fn(post, i):
-        sp = spk[i]
-        if spk.dtype == jnp.bool_:
-            post = jax.lax.cond(sp, lambda: post + g_max[i], lambda: post)
-        else:
-            post = jax.lax.cond(sp == 0., lambda: post, lambda: post + g_max[i] * sp)
-        return post, None
-
-    return jax.lax.scan(scan_fn, jnp.zeros(g_max.shape[1], dtype=g_max.dtype), np.arange(len(spk)))[0]
-
-
-# --------------
-# VJP
-# --------------
-
-def _cpu_event_linear_mv_fwd(g_max, spk):
-    return _cpu_event_linear_mv(g_max, spk), (g_max, spk)
+Kernel = Callable
 
 
-def
-
+def cpu_kernel_generator(
+    float_as_event: bool,
+    spk_info: jax.ShapeDtypeStruct,
+    **kwargs
+) -> Kernel:
+    import numba  # pylint: disable=import-outside-toplevel
 
-
-    ct_spk = jnp.matmul(g_max, ct)
+    if spk_info.dtype == jnp.bool_:
 
-
-
-
-
-
-
-
+        @numba.njit
+        def _kernel(spikes, weights, posts):
+            r = np.zeros((weights.shape[1],), dtype=weights.dtype)
+            for i in range(spikes.shape[0]):
+                if spikes[i]:
+                    r = r + weights[i]
+            posts[:] = r
 
-
-
+    elif float_as_event:
+        @numba.njit
+        def _kernel(spikes, weights, posts):
+            r = np.zeros((weights.shape[1],), dtype=weights.dtype)
+            for i in range(spikes.shape[0]):
+                if spikes[i] != 0.:
+                    r = r + weights[i]
+            posts[:] = r
 
+    else:
+        @numba.njit
+        def _kernel(spikes, weights, posts):
+            r = np.zeros((weights.shape[1],), dtype=weights.dtype)
+            for i in range(spikes.shape[0]):
+                sp = spikes[i]
+                if sp != 0.:
+                    r = r + weights[i] * sp
+            posts[:] = r
+
+    return _kernel
+
+
+def gpu_kernel_generator(
+    block_size: int,
+    float_as_event: bool,
+    weight_info: jax.ShapeDtypeStruct,
+    **kwargs
+) -> Kernel:
+    # # each block handles one [block_size,] slice of the post
+    # # each block handles the whole [n_pre] pre
+    # # each block handles one [n_pre, block_size] slice of w
+    # def _mv_kernel(sp_ref, w_ref, post_ref):
+    #
+    #     pid = pl.program_id(0)
+    #
+    #     def scan_fn(i, post_):
+    #         if sp_ref.dtype == jnp.bool_:
+    #             post_ = jax.lax.cond(
+    #                 sp_ref[i],
+    #                 lambda: post_ + w_ref[i, ...],
+    #                 lambda: post_
+    #             )
+    #         else:
+    #             if float_as_event:
+    #                 post_ = jax.lax.cond(
+    #                     sp_ref[i] != 0.,
+    #                     lambda: post_ + w_ref[i, ...],
+    #                     lambda: post_
+    #                 )
+    #             else:
+    #                 sp = sp_ref[i]
+    #                 post_ = jax.lax.cond(
+    #                     sp != 0.,
+    #                     lambda: post_ + w_ref[i, ...] * sp,
+    #                     lambda: post_
+    #                 )
+    #         return post_
+    #
+    #     post = jax.lax.fori_loop(0, n_pre, scan_fn, jnp.zeros(post_ref.shape, dtype=post_ref.dtype))
+    #     mask = jnp.arange(block_size) + pid * block_size < n_pre
+    #     pl.store(post_ref, pl.dslice(None, None), post, mask=mask)
+    #
+    # n_pre = weight_info.shape[0]
+    # n_post = weight_info.shape[1]
+    # kernel = pl.pallas_call(
+    #     _mv_kernel,
+    #     out_shape=[
+    #         jax.ShapeDtypeStruct([weight_info.shape[1]], weight_info.dtype),
+    #     ],
+    #     out_specs=[
+    #         pl.BlockSpec((block_size,), lambda i: i),
+    #     ],
+    #     in_specs=[
+    #         pl.BlockSpec((n_pre,), lambda i: 0),
+    #         pl.BlockSpec((n_pre, block_size), lambda i: (0, i)),
+    #     ],
+    #     grid=(
+    #         pl.cdiv(n_post, block_size),
+    #     ),
+    #     interpret=False,
+    # )
+    # return kernel
+
+    # each block handles one [block_size,] slice of the post
+    # each block handles one [block_size] slice of the pre
+    # each block handles one [block_size, block_size] tile of w
+    def _mv_kernel(
+        sp_ref,  # [block_size]
+        w_ref,  # [block_size, block_size]
+        post_ref,  # [block_size]
+    ):
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        r_pid = pl.program_id(0)
+        c_start = pl.program_id(1) * block_size
+        row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+        mask = jnp.arange(block_size) + c_start < weight_info.shape[1]
+
+        def scan_fn(i, post_):
+            if sp_ref.dtype == jnp.bool_:
+                post_ = jax.lax.cond(
+                    sp_ref[i],
+                    lambda: post_ + w_ref[i, ...],
+                    lambda: post_
+                )
+            else:
+                if float_as_event:
+                    post_ = jax.lax.cond(
+                        sp_ref[i] != 0.,
+                        lambda: post_ + w_ref[i, ...],
+                        lambda: post_
+                    )
+                else:
+                    sp = sp_ref[i]
+                    post_ = jax.lax.cond(
+                        sp != 0.,
+                        lambda: post_ + w_ref[i, ...] * sp,
+                        lambda: post_
+                    )
+            return post_
+
+        post = jax.lax.fori_loop(0, row_length, scan_fn, jnp.zeros(post_ref.shape, dtype=post_ref.dtype))
+        pl.atomic_add(post_ref, pl.dslice(None, None), post, mask=mask)
+
+    n_pre = weight_info.shape[0]
+    n_post = weight_info.shape[1]
+    kernel = pl.pallas_call(
+        _mv_kernel,
+        out_shape=[
+            jax.ShapeDtypeStruct([weight_info.shape[1]], weight_info.dtype),
+        ],
+        out_specs=[
+            pl.BlockSpec((block_size,), lambda i, j: j),
+        ],
+        in_specs=[
+            pl.BlockSpec((block_size,), lambda i, j: i),
+            pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),
+        ],
+        grid=(
+            pl.cdiv(n_pre, block_size),
+            pl.cdiv(n_post, block_size),
+        ),
+        interpret=False,
+    )
+    return kernel
+
+
+def jvp_spikes(spk_dot, spikes, weights, **kwargs):
+    return [spk_dot @ weights]
+
+
+def jvp_weights(w_dot, spikes, weights, *, float_as_event, block_size, **kwargs):
+    return event_liner_p_call(
+        spikes,
+        w_dot,
+        block_size=block_size,
+        float_as_event=float_as_event,
+    )
+
+
+def transpose_rule(ct, spikes, weights, *, float_as_event, **kwargs):
+    if ad.is_undefined_primal(spikes):
+        ct_events = jnp.matmul(weights, ct[0])
+        return (ad.Zero(spikes) if type(ct[0]) is ad.Zero else ct_events), weights
 
-
-
+    else:
+        def map_fn(sp):
+            if spikes.dtype == jnp.bool_:
+                d_gmax = jnp.where(sp, ct[0], jnp.zeros_like(ct[0]))
+            else:
+                if float_as_event:
+                    d_gmax = jnp.where(sp == 0., jnp.zeros_like(ct[0]), ct[0])
+                else:
+                    d_gmax = jnp.where(sp == 0., jnp.zeros_like(ct[0]), ct[0] * sp)
+                    # d_gmax = jax.lax.cond(sp == 0., lambda: jnp.zeros_like(ct[0]), lambda: ct[0] * sp)
+            return d_gmax
+
+        ct_weights = jax.vmap(map_fn)(spikes)
+        return spikes, (ad.Zero(weights) if type(ct[0]) is ad.Zero else ct_weights)
+
+
+event_linear_p = XLACustomOp(
+    'event_linear',
+    cpu_kernel_generator=cpu_kernel_generator,
+    gpu_kernel_generator=gpu_kernel_generator,
+)
+event_linear_p.defjvp(jvp_spikes, jvp_weights)
+event_linear_p.def_transpose_rule(transpose_rule)
+
+
+def event_liner_p_call(
+    spikes,
+    weights,
+    *,
+    block_size,
+    float_as_event,
+):
+    return event_linear_p(
+        spikes,
+        weights,
+        outs=[jax.ShapeDtypeStruct([weights.shape[1]], weights.dtype)],
+        block_size=block_size,
+        float_as_event=float_as_event,
+        spk_info=jax.ShapeDtypeStruct(spikes.shape, spikes.dtype),
+        weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
+    )
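For orientation, a minimal usage sketch of the updated event-driven Linear layer, based only on the signature shown in this diff and assuming Linear is exported from brainstate.event the same way FixedProb is (sizes and initializer below are illustrative):

    import brainstate as bst

    layer = bst.event.Linear(
        1000, 2000,                        # in_size, out_size
        weight=bst.init.KaimingUniform(),  # callable initializer, array, or brainunit quantity
        block_size=64,                     # tile size used by the GPU (Pallas) kernel
        float_as_event=True,               # nonzero float inputs count as single events
    )
    spikes = bst.random.rand(1000) < 0.01  # boolean spike vector of length in_size
    post = layer(spikes)                   # event-driven matvec, shape (2000,)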