brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/augment/_autograd.py +9 -6
- brainstate/event/__init__.py +4 -2
- brainstate/event/_csr.py +26 -18
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_fixed_probability.py +589 -152
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +13 -10
- brainstate/event/_linear.py +267 -127
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +8 -3
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
- brainstate/nn/_dynamics/_projection_base.py +1 -1
- brainstate/nn/_exp_euler.py +1 -1
- brainstate/nn/_interaction/__init__.py +13 -4
- brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
- brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/optim/_lr_scheduler.py +1 -1
- brainstate/optim/_optax_optimizer.py +18 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/RECORD +30 -21
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
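The hunks below come from brainstate/event/_fixed_probability.py (the +589/-152 entry above). For orientation, here is a minimal usage sketch of the reworked FixedProb interface; it assumes the class is re-exported as brainstate.event.FixedProb, and the sizes and values are purely illustrative, not taken from the package:

    import jax.numpy as jnp
    import brainstate as bst

    proj = bst.event.FixedProb(
        in_size=1000,         # number of pre-synaptic neurons
        out_size=2000,        # number of post-synaptic neurons
        prob=0.02,            # each pre neuron reaches int(2000 * 0.02) = 40 post neurons
        weight=0.5,           # homogeneous maximum conductance; a callable or array also works
        seed=0,
        float_as_event=True,
        block_size=None,      # let the GPU kernel choose a block size
    )
    spikes = jnp.zeros(1000, dtype=bool).at[jnp.array([1, 5, 42])].set(True)
    post_current = proj.update(spikes)   # shape (2000,)

As the diff shows, update(spk) now dispatches to the new event_fixed_prob operator defined later in this file.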
@@ -12,23 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
 from __future__ import annotations

 from typing import Union, Callable, Optional

 import brainunit as u
 import jax
+import jax.experimental.pallas as pl
 import jax.numpy as jnp
 import numpy as np
+from jax.interpreters import ad

 from brainstate._state import ParamState
-from brainstate.
-from brainstate.compile import for_loop
+from brainstate.augment import vmap
 from brainstate.init import param
 from brainstate.nn._module import Module
 from brainstate.random import RandomState
-from brainstate.typing import ArrayLike
-from ._misc import FloatScalar
+from brainstate.typing import ArrayLike, Size
+from ._misc import FloatScalar
+from ._xla_custom_op import XLACustomOp

 __all__ = [
     'FixedProb',
@@ -41,19 +44,23 @@ class FixedProb(Module):

     Parameters
     ----------
-
-        Number of pre-synaptic neurons.
-
-        Number of post-synaptic neurons.
+    in_size : Size
+        Number of pre-synaptic neurons, i.e., input size.
+    out_size : Size
+        Number of post-synaptic neurons, i.e., output size.
     prob : float
-        Probability of connection.
+        Probability of connection, i.e., connection probability.
     weight : float or callable or jax.Array or brainunit.Quantity
-        Maximum synaptic conductance.
+        Maximum synaptic conductance, i.e., synaptic weight.
     allow_multi_conn : bool, optional
         Whether multiple connections are allowed from a single pre-synaptic neuron.
         Default is True, meaning that a value of ``a`` can be selected multiple times.
-
-
+    seed: int, optional
+        Random seed. Default is None. If None, the default random seed will be used.
+    float_as_event : bool, optional
+        Whether to treat float as event. Default is True.
+    block_size : int, optional
+        Block size for parallel computation. Default is 64. This is only used for GPU.
     name : str, optional
         Name of the module.
     """
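The next hunk replaces the old device-specific implementation with XLACustomOp-based kernels. For intuition, the connectivity the constructor builds is an ELL-style index table with a fixed number of post indices per pre neuron; a small sketch of the allow_multi_conn=True branch with made-up numbers (not from the package):

    import numpy as np

    n_pre, n_post, prob = 3, 8, 0.5
    n_conn = int(n_post * prob)                              # int(8 * 0.5) = 4 connections per pre neuron
    rng = np.random.RandomState(0)
    indices = rng.randint(0, n_post, size=(n_pre, n_conn))   # one row of post indices per pre neuron
    print(indices.shape)                                     # (3, 4)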
@@ -62,210 +69,640 @@ class FixedProb(Module):

     def __init__(
         self,
-
-
+        in_size: Size,
+        out_size: Size,
         prob: FloatScalar,
         weight: Union[Callable, ArrayLike],
         allow_multi_conn: bool = True,
         seed: Optional[int] = None,
+        float_as_event: bool = True,
+        block_size: Optional[int] = None,
         name: Optional[str] = None,
-        grad_mode: str = 'vjp'
     ):
         super().__init__(name=name)
-        self.n_pre = n_pre
-        self.n_post = n_post
-        self.in_size = n_pre
-        self.out_size = n_post

-
+        # network parameters
+        self.in_size = in_size
+        self.out_size = out_size
+        self.n_conn = int(self.out_size[-1] * prob)
         if self.n_conn < 1:
-            raise ValueError(
-
-
-
-        self.grad_mode = grad_mode
+            raise ValueError(f"The number of connections must be at least 1. "
+                             f"Got: int({self.out_size[-1]} * {prob}) = {self.n_conn}")
+        self.float_as_event = float_as_event
+        self.block_size = block_size

         # indices of post connected neurons
-
-
-
-
-
-
+        with jax.ensure_compile_time_eval():
+            if allow_multi_conn:
+                rng = np.random.RandomState(seed)
+                self.indices = rng.randint(0, self.out_size[-1], size=(self.in_size[-1], self.n_conn))
+            else:
+                rng = RandomState(seed)
+
+                @vmap(rngs=rng)
+                def rand_indices(key):
+                    rng.set_key(key)
+                    return rng.choice(self.out_size[-1], size=(self.n_conn,), replace=False)
+
+                self.indices = rand_indices(rng.split_key(self.in_size[-1]))
+            self.indices = u.math.asarray(self.indices)

         # maximum synaptic conductance
-        weight = param(weight, (self.
+        weight = param(weight, (self.in_size[-1], self.n_conn), allow_none=False)
         self.weight = ParamState(weight)

     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
-
-
-
-
-
-
-
-
-            raise NotImplementedError()
-        else:
-            raise ValueError(f"Unsupported device: {device_kind}")
+        return event_fixed_prob(
+            spk,
+            self.weight.value,
+            self.indices,
+            n_post=self.out_size[-1],
+            block_size=self.block_size,
+            float_as_event=self.float_as_event
+        )


-
-def cpu_fixed_prob(
-    indices: jax.Array,
-    weight: Union[u.Quantity, jax.Array],
-    spk: jax.Array,
-    *,
-    n_post: int,
-    grad_mode: str = 'vjp'
-) -> Union[u.Quantity, jax.Array]:
+def event_fixed_prob(spk, weight, indices, *, n_post, block_size, float_as_event):
     """
     The FixedProb module implements a fixed probability connection with CSR sparse data structure.

     Parameters
     ----------
-    n_post : int
-        Number of post-synaptic neurons.
     weight : brainunit.Quantity or jax.Array
         Maximum synaptic conductance.
     spk : jax.Array
         Spike events.
-    indices : jax.Array
-        Indices of post connected neurons.
-    grad_mode : str, optional
-        Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.

     Returns
     -------
     post_data : brainunit.Quantity or jax.Array
         Post synaptic data.
     """
-
-
-
-
+    with jax.ensure_compile_time_eval():
+        weight = u.math.asarray(weight)
+        unit = u.get_unit(weight)
+        weight = u.get_mantissa(weight)
+        indices = jnp.asarray(indices)
+        spk = jnp.asarray(spk)

     def mv(spk_vector):
         assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
-
-
-
-
-
-
-
+        return event_ellmv_p_call(
+            spk,
+            weight,
+            indices,
+            n_post=n_post,
+            block_size=block_size,
+            float_as_event=float_as_event
+        )

     assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
     assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"
     assert indices.ndim == 2, f"indices must be 2D. Got: {indices.ndim}"

     if spk.ndim == 1:
-        post_data = mv(spk)
+        [post_data] = mv(spk)
     else:
-
-        post_data =
-        post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
+        [post_data] = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
+        post_data = u.math.reshape(post_data, spk.shape[:-1] + post_data.shape[-1:])
     return u.maybe_decimal(u.Quantity(post_data, unit=unit))


-
-
-
-
+Kernel = Callable
+
+
+def cpu_kernel_generator(
+    float_as_event: bool,
+    weight_info: jax.ShapeDtypeStruct,
+    spike_info: jax.ShapeDtypeStruct,
+    **kwargs
+):
+    import numba  # pylint: disable=import-outside-toplevel
+
+    if weight_info.size == 1:
+        if spike_info.dtype == jnp.bool_:
+            @numba.njit
+            def ell_mv(spikes, weights, indices, posts):
+                posts[:] = 0.
+                w = weights[()]
+                for i in range(spikes.shape[0]):
+                    if spikes[i]:
+                        for j in range(indices.shape[1]):
+                            posts[indices[i, j]] += w
+
+        elif float_as_event:
+            @numba.njit
+            def ell_mv(spikes, weights, indices, posts):
+                posts[:] = 0.
+                w = weights[()]
+                for i in range(spikes.shape[0]):
+                    if spikes[i] != 0.:
+                        for j in range(indices.shape[1]):
+                            posts[indices[i, j]] += w

-def _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post: int) -> jax.Array:
-    def scan_fn(post, i):
-        w = g_max if jnp.size(g_max) == 1 else g_max[i]
-        ids = indices[i]
-        sp = spk[i]
-        if spk.dtype == jnp.bool_:
-            post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
         else:
-
-
-
-
-
-
-
-
-
-
-def _cpu_event_fixed_prob_mv_fwd(indices, g_max, spk, n_post):
-    return _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post=n_post), (g_max, spk)
-
-
-def _cpu_event_fixed_prob_mv_bwd(indices, n_post, res, ct):
-    weight, spk = res
+            @numba.njit
+            def ell_mv(spikes, weights, indices, posts):
+                posts[:] = 0.
+                w = weights[()]
+                for i in range(spikes.shape[0]):
+                    sp = spikes[i]
+                    if sp != 0.:
+                        wsp = w * sp
+                        for j in range(indices.shape[1]):
+                            posts[indices[i, j]] += wsp

-    # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
-    homo = jnp.size(weight) == 1
-    if homo:  # homogeneous weight
-        ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weight))(indices)
-    else:  # heterogeneous weight
-        ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weight)
-
-    # ∂L/∂w = ∂L/∂y * ∂y/∂w
-    if homo:  # scalar
-        ct_gmax = _cpu_event_fixed_prob_mv(indices, jnp.asarray(1.), spk, n_post=n_post)
-        ct_gmax = jnp.inner(ct, ct_gmax)
     else:
-
-
-
-
-
-
+        if spike_info.dtype == jnp.bool_:
+            @numba.njit
+            def ell_mv(spikes, weights, indices, posts):
+                posts[:] = 0.
+                for i in range(spikes.shape[0]):
+                    if spikes[i]:
+                        for j in range(indices.shape[1]):
+                            posts[indices[i, j]] += weights[i, j]
+
+        elif float_as_event:
+            @numba.njit
+            def ell_mv(spikes, weights, indices, posts):
+                posts[:] = 0.
+                for i in range(spikes.shape[0]):
+                    if spikes[i] != 0.:
+                        for j in range(indices.shape[1]):
+                            posts[indices[i, j]] += weights[i, j]

-
-
+        else:
+            @numba.njit
+            def ell_mv(spikes, weights, indices, posts):
+                posts[:] = 0.
+                for i in range(spikes.shape[0]):
+                    sp = spikes[i]
+                    if sp != 0.:
+                        for j in range(indices.shape[1]):
+                            posts[indices[i, j]] += weights[i, j] * sp

+    return ell_mv

-_cpu_event_fixed_prob_mv_vjp = jax.custom_vjp(_cpu_event_fixed_prob_mv, nondiff_argnums=(0, 3))
-_cpu_event_fixed_prob_mv_vjp.defvjp(_cpu_event_fixed_prob_mv_fwd, _cpu_event_fixed_prob_mv_bwd)

+def gpu_kernel_generator(
+    n_pre: int,
+    n_conn: int,
+    n_post: int,
+    block_size: int,
+    float_as_event: bool,
+    weight_info: jax.ShapeDtypeStruct,
+    **kwargs
+):
+    # For a spikes vector of shape [n_event] and indices/weights matrices of shape [n_event, n_conn],
+    # the computation performed by this operator is:
+    #
+    # - each block processes [block_size] events, each event corresponding to one pre-synaptic neuron
+    # - each block processes a [block_size, block_size] tile of indices and weights
+
+    if weight_info.size == 1:
+        def _ell_mv_kernel_homo(
+            sp_ref,  # [block_size]
+            ind_ref,  # [block_size, block_size]
+            _,
+            y_ref,  # [n_post]
+        ):
+            r_pid = pl.program_id(0)
+            c_start = pl.program_id(1) * block_size
+            row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+            mask = jnp.arange(block_size) + c_start < n_conn
+
+            def body_fn(j, _):
+                if sp_ref.dtype == jnp.bool_:
+                    def true_fn():
+                        ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                        pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask)
+                        # y_ref[ind] += 1.0
+                        # ind = ind_ref[j, ...]
+                        # pl.store(y_ref, ind, 1.0, mask=mask)
+
+                    jax.lax.cond(sp_ref[j], true_fn, lambda: None)
+
+
+                else:
+                    def true_fn(sp):
+                        ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                        if float_as_event:
+                            pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask)
+                        else:
+                            pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype) * sp, mask=mask)
+
+                    sp_ = sp_ref[j]
+                    jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_)
+
+            jax.lax.fori_loop(0, row_length, body_fn, None)
+
+        # homogenous weights
+        kernel = pl.pallas_call(
+            _ell_mv_kernel_homo,
+            out_shape=[
+                jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+            ],
+            in_specs=[
+                pl.BlockSpec((block_size,), lambda i, j: i),
+                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),
+                pl.BlockSpec((n_post,), lambda i, j: 0)
+            ],
+            grid=(
+                pl.cdiv(n_pre, block_size),
+                pl.cdiv(n_conn, block_size),
+            ),
+            input_output_aliases={2: 0},
+            interpret=False
+        )
+        return (lambda spikes, weight, indices:
+                [kernel(spikes, indices, jnp.zeros(n_post, dtype=weight.dtype))[0] * weight])

-
-
-#
+    else:
+        def _ell_mv_kernel_heter(
+            sp_ref,  # [block_size]
+            ind_ref,  # [block_size, block_size]
+            w_ref,  # [block_size, block_size]
+            _,
+            y_ref,  # [n_post]
+        ):
+            r_pid = pl.program_id(0)
+            c_start = pl.program_id(1) * block_size
+            row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+            mask = jnp.arange(block_size) + c_start < n_conn
+
+            def body_fn(j, _):
+                if sp_ref.dtype == jnp.bool_:
+                    def true_fn():
+                        ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                        w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
+                        pl.atomic_add(y_ref, ind, w, mask=mask)
+
+                    jax.lax.cond(sp_ref[j], true_fn, lambda: None)
+                else:
+                    def true_fn(spk):
+                        ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                        w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
+                        if not float_as_event:
+                            w = w * spk
+                        pl.atomic_add(y_ref, ind, w, mask=mask)
+
+                    sp_ = sp_ref[j]
+                    jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_)
+
+            jax.lax.fori_loop(0, row_length, body_fn, None)
+
+        # heterogeneous weights
+        kernel = pl.pallas_call(
+            _ell_mv_kernel_heter,
+            out_shape=[
+                jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+            ],
+            in_specs=[
+                pl.BlockSpec((block_size,), lambda i, j: i),  # sp_ref
+                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
+                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # w_ref,
+                pl.BlockSpec((n_post,), lambda i, j: 0)
+            ],
+            grid=(
+                pl.cdiv(n_pre, block_size),
+                pl.cdiv(n_conn, block_size),
+            ),
+            input_output_aliases={3: 0},
+            interpret=False
+        )
+        return (lambda spikes, weight, indices:
+                kernel(spikes, indices, weight, jnp.zeros(n_post, dtype=weight_info.dtype)))
+
+
+def jvp_spikes(spk_dot, spikes, weights, indices, *, n_post, block_size, **kwargs):
+    return ellmv_p_call(
+        spk_dot,
+        weights,
+        indices,
+        n_post=n_post,
+        block_size=block_size,
+    )
+
+
+def jvp_weights(w_dot, spikes, weights, indices, *, float_as_event, block_size, n_post, **kwargs):
+    return event_ellmv_p_call(
+        spikes,
+        w_dot,
+        indices,
+        n_post=n_post,
+        block_size=block_size,
+        float_as_event=float_as_event
+    )
+
+
+def transpose_rule(
+    ct, spikes, weights, indices,
+    *,
+    float_as_event, n_post, n_conn, block_size, weight_info, **kwargs
+):
+    if ad.is_undefined_primal(indices):
+        raise ValueError("Cannot transpose with respect to sparse indices.")

+    ct = ct[0]

-
-
-
-
+    # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
+    homo = weight_info.size == 1
+    if ad.is_undefined_primal(spikes):
+        if homo:
+            # homogeneous weight
+            ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weights))(indices)
+        else:
+            # heterogeneous weight
+            ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weights)
+        return (ad.Zero(spikes) if type(ct) is ad.Zero else ct_spk), weights, indices

-
-
+    else:
+        # ∂L/∂w = ∂L/∂y * ∂y/∂w
+        if homo:
+            # scalar
+            ct_gmax = event_ellmv_p_call(
+                spikes,
+                jnp.asarray(1., dtype=weight_info.dtype),
+                indices,
+                n_post=n_post,
+                block_size=block_size,
+                float_as_event=float_as_event
+            )
+            ct_gmax = jnp.inner(ct, ct_gmax[0])
+        else:
+            def map_fn(one_spk, one_ind):
+                if spikes.dtype == jnp.bool_:
+                    return jax.lax.cond(
+                        one_spk,
+                        lambda: ct[one_ind],
+                        lambda: jnp.zeros([n_conn], weight_info.dtype)
+                    )
+                else:
+                    if float_as_event:
+                        return jax.lax.cond(
+                            one_spk == 0.,
+                            lambda: jnp.zeros([n_conn], weight_info.dtype),
+                            lambda: ct[one_ind]
+                        )
+                    else:
+                        return jax.lax.cond(
+                            one_spk == 0.,
+                            lambda: jnp.zeros([n_conn], weight_info.dtype),
+                            lambda: ct[one_ind] * one_spk
+                        )
+
+            ct_gmax = jax.vmap(map_fn)(spikes, indices)
+        return spikes, (ad.Zero(weights) if type(ct) is ad.Zero else ct_gmax), indices
+
+
+event_ellmv_p = XLACustomOp(
+    'event_ell_mv',
+    cpu_kernel_generator=cpu_kernel_generator,
+    gpu_kernel_generator=gpu_kernel_generator,
+)
+event_ellmv_p.defjvp(jvp_spikes, jvp_weights, None)
+event_ellmv_p.def_transpose_rule(transpose_rule)
+
+
+def event_ellmv_p_call(spikes, weights, indices, *, n_post, block_size, float_as_event):
+    n_conn = indices.shape[1]
+    if block_size is None:
+        if n_conn <= 16:
+            block_size = 16
+        elif n_conn <= 32:
+            block_size = 32
+        elif n_conn <= 64:
+            block_size = 64
+        elif n_conn <= 128:
+            block_size = 128
+        elif n_conn <= 256:
+            block_size = 256
+        else:
+            block_size = 128
+    return event_ellmv_p(
+        spikes,
+        weights,
+        indices,
+        outs=[jax.ShapeDtypeStruct([n_post], weights.dtype)],
+        block_size=block_size,
+        float_as_event=float_as_event,
+        n_pre=spikes.shape[0],
+        n_conn=indices.shape[1],
+        n_post=n_post,
+        weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
+        spike_info=jax.ShapeDtypeStruct(spikes.shape, spikes.dtype),
+    )
+
+
+def ell_cpu_kernel_generator(
+    weight_info: jax.ShapeDtypeStruct,
+    **kwargs
+):
+    import numba  # pylint: disable=import-outside-toplevel
+
+    if jnp.size(weight_info) == 1:
+        @numba.njit
+        def ell_mv(vector, weights, indices, posts):
+            posts[:] = 0.
+            w = weights[()]
+            for i in range(vector.shape[0]):
+                wv = w * vector[i]
+                for j in range(indices.shape[1]):
+                    posts[indices[i, j]] += wv

-
-
+    else:
+        @numba.njit
+        def ell_mv(vector, weights, indices, posts):
+            posts[:] = 0.
+            for i in range(vector.shape[0]):
+                for j in range(indices.shape[1]):
+                    posts[indices[i, j]] += weights[i, j] * vector[i]

-
-        ids = indices[i]
-        w = weight if jnp.size(weight) == 1 else weight[i]
-        post = post.at[ids].add(w * spk_dot[i])
-        return post, None
+    return ell_mv

-    # ∂y/∂gspk
-    dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
-    return y, dgmax + dspk

+def ell_gpu_kernel_generator(
+    block_size: int,
+    n_pre: int,
+    n_conn: int,
+    n_post: int,
+    weight_info: jax.ShapeDtypeStruct,
+    **kwargs
+):
+    homo = jnp.size(weight_info) == 1
+
+    if homo:
+        def _kernel(
+            vec_ref, ind_ref, _, out_ref,
+        ):
+            # each block processes a [block_size] slice of the vector
+            # each block processes a [block_size, block_size] tile of indices and weights
+
+            # -------------------------------
+            # vec_ref: [block_size]
+            # ind_ref: [block_size, block_size]
+            # out_ref: [n_post]
+
+            r_pid = pl.program_id(0)
+            c_start = pl.program_id(1) * block_size
+            mask = jnp.arange(block_size) + c_start
+            row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+
+            def body_fn(j, _):
+                y = vec_ref[j] * jnp.ones(block_size, dtype=weight_info.dtype)
+                ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                pl.atomic_add(out_ref, ind, y, mask=mask)
+
+            jax.lax.fori_loop(0, row_length, body_fn, None)
+
+        # heterogeneous weights
+        kernel = pl.pallas_call(
+            _kernel,
+            out_shape=[
+                jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+            ],
+            in_specs=[
+                pl.BlockSpec((block_size,), lambda i, j: i),  # vec_ref
+                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
+                pl.BlockSpec((n_post,), lambda i, j: 0)  # out_ref
+            ],
+            grid=(
+                pl.cdiv(n_pre, block_size),
+                pl.cdiv(n_conn, block_size),
+            ),
+            input_output_aliases={2: 0},
+            interpret=False
+        )
+        return lambda vector, weight, indices: kernel(vector, indices, jnp.zeros(n_post, dtype=weight.dtype)) * weight

-
-
+    else:
+        def _kernel(
+            vec_ref, ind_ref, w_ref, _, out_ref,
+        ):
+            # each block processes a [block_size] slice of the vector
+            # each block processes a [block_size, n_conn] tile of indices and weights
+
+            # -------------------------------
+            # vec_ref: [block_size]
+            # ind_ref: [block_size, block_size]
+            # w_ref: [block_size, block_size]
+            # out_ref: [n_post]
+
+            r_pid = pl.program_id(0)
+            c_start = pl.program_id(1) * block_size
+            mask = jnp.arange(block_size) + c_start
+            row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+
+            def body_fn(j, _):
+                w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
+                y = w * vec_ref[j]
+                ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                pl.atomic_add(out_ref, ind, y, mask=mask)
+
+            jax.lax.fori_loop(0, row_length, body_fn, None)
+
+        # heterogeneous weights
+        kernel = pl.pallas_call(
+            _kernel,
+            out_shape=[
+                jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+            ],
+            in_specs=[
+                pl.BlockSpec((block_size,), lambda i, j: i),  # vec_ref
+                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
+                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # w_ref
+                pl.BlockSpec((n_post,), lambda i: 0)  # out_ref
+            ],
+            grid=(
+                pl.cdiv(n_pre, block_size),
+                pl.cdiv(n_conn, block_size),
+            ),
+            input_output_aliases={3: 0},
+            interpret=False
+        )
+        return lambda vector, weight, indices: kernel(vector, indices, weight, jnp.zeros(n_post, dtype=weight.dtype))
+
+
+def jvp_weights_no_spk(w_dot, vector, weights, indices, *, block_size, n_post, **kwargs):
+    return ellmv_p_call(
+        vector,
+        w_dot,
+        indices,
+        block_size=block_size,
+        n_post=n_post,
+    )
+
+
+def transpose_rule_no_spk(
+    ct, vector, weights, indices,
+    *,
+    n_post, block_size, weight_info, **kwargs
+):
+    if ad.is_undefined_primal(indices):
+        raise ValueError("Cannot transpose with respect to sparse indices.")

+    ct = ct[0]

-
-
-
-
-
-
-            post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
+    # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
+    homo = weight_info.size == 1
+    if ad.is_undefined_primal(vector):
+        if homo:
+            # homogeneous weight
+            ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weights))(indices)
         else:
-
-
+            # heterogeneous weight
+            ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weights)
+        return (ad.Zero(vector) if type(ct) is ad.Zero else ct_spk), weights, indices

-
+    else:
+        # ∂L/∂w = ∂L/∂y * ∂y/∂w
+        if homo:
+            # scalar
+            ct_gmax = ellmv_p_call(
+                vector,
+                jnp.asarray(1., dtype=weight_info.dtype),
+                indices,
+                block_size=block_size,
+                n_post=n_post,
+            )
+            ct_gmax = jnp.inner(ct, ct_gmax[0])
+        else:
+            ct_gmax = jax.vmap(lambda vec, one_ind: ct[one_ind] * vec)(vector, indices)
+        return vector, (ad.Zero(weights) if type(ct) is ad.Zero else ct_gmax), indices
+
+
+ellmv_p = XLACustomOp(
+    'ell_mv',
+    cpu_kernel_generator=ell_cpu_kernel_generator,
+    gpu_kernel_generator=ell_gpu_kernel_generator,
+)
+ellmv_p.defjvp(jvp_spikes, jvp_weights_no_spk, None)
+ellmv_p.def_transpose_rule(transpose_rule_no_spk)
+
+
+def ellmv_p_call(vector, weights, indices, *, n_post, block_size):
+    n_conn = indices.shape[1]
+    if block_size is None:
+        if n_conn <= 16:
+            block_size = 16
+        elif n_conn <= 32:
+            block_size = 32
+        elif n_conn <= 64:
+            block_size = 64
+        elif n_conn <= 128:
+            block_size = 128
+        elif n_conn <= 256:
+            block_size = 256
+        else:
+            block_size = 128
+    return ellmv_p(
+        vector,
+        weights,
+        indices,
+        n_post=n_post,
+        n_pre=indices.shape[0],
+        n_conn=indices.shape[1],
+        block_size=block_size,
+        weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
+        outs=[jax.ShapeDtypeStruct([n_post], weights.dtype)]
+    )
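The Numba and Pallas kernels added above all compute the same ELL-style event scatter-add. As a mental model (or a test oracle), here is a dense JAX reference of that semantics; the helper name and shapes are ours, not part of the package:

    import jax.numpy as jnp

    def event_ellmv_reference(spikes, weights, indices, n_post, float_as_event=True):
        # spikes: [n_pre] (bool or float); indices: [n_pre, n_conn]; weights: scalar or [n_pre, n_conn]
        w = jnp.broadcast_to(jnp.asarray(weights), indices.shape)
        if spikes.dtype == jnp.bool_:
            gate = spikes.astype(w.dtype)                 # boolean events contribute 0 or 1
        elif float_as_event:
            gate = (spikes != 0.).astype(w.dtype)         # non-zero floats count as unit events
        else:
            gate = spikes.astype(w.dtype)                 # floats scale the weight
        contrib = w * gate[:, None]                       # zero out rows without events
        posts = jnp.zeros(n_post, dtype=w.dtype)
        return posts.at[indices].add(contrib)             # scatter-add; duplicate indices accumulate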