brainstate 0.1.0.post20250120__py2.py3-none-any.whl → 0.1.0.post20250126__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -2
- brainstate/augment/__init__.py +10 -20
- brainstate/compile/__init__.py +18 -37
- brainstate/compile/_make_jaxpr.py +8 -1
- brainstate/compile/_make_jaxpr_test.py +10 -6
- brainstate/compile/_unvmap.py +3 -3
- brainstate/graph/__init__.py +12 -12
- brainstate/nn/_elementwise/_dropout_test.py +1 -1
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250126.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250126.dist-info}/RECORD +13 -27
- brainstate/event/__init__.py +0 -27
- brainstate/event/_csr.py +0 -1149
- brainstate/event/_csr_benchmark.py +0 -14
- brainstate/event/_csr_mv.py +0 -303
- brainstate/event/_csr_test.py +0 -277
- brainstate/event/_fixedprob_mv.py +0 -730
- brainstate/event/_fixedprob_mv_benchmark.py +0 -128
- brainstate/event/_fixedprob_mv_test.py +0 -132
- brainstate/event/_linear_mv.py +0 -359
- brainstate/event/_linear_mv_benckmark.py +0 -82
- brainstate/event/_linear_mv_test.py +0 -117
- brainstate/event/_misc.py +0 -34
- brainstate/event/_xla_custom_op.py +0 -317
- brainstate/event/_xla_custom_op_test.py +0 -55
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250126.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250126.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250126.dist-info}/top_level.txt +0 -0
@@ -1,730 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import annotations
-
-from typing import Union, Callable, Optional
-
-import brainunit as u
-import jax
-import jax.experimental.pallas as pl
-import jax.numpy as jnp
-import numpy as np
-from jax.interpreters import ad
-
-from brainstate import environ
-from brainstate._state import ParamState
-from brainstate.augment import vmap
-from brainstate.init import param
-from brainstate.nn._module import Module
-from brainstate.random import RandomState
-from brainstate.typing import ArrayLike, Size
-from ._misc import FloatScalar
-from ._xla_custom_op import XLACustomOp
-
-__all__ = [
-    'FixedProb',
-]
-
-
-class FixedProb(Module):
-    """
-    The FixedProb module implements a fixed probability connection with CSR sparse data structure.
-
-    Parameters
-    ----------
-    in_size : Size
-        Number of pre-synaptic neurons, i.e., input size.
-    out_size : Size
-        Number of post-synaptic neurons, i.e., output size.
-    prob : float
-        Probability of connection, i.e., connection probability.
-    weight : float or callable or jax.Array or brainunit.Quantity
-        Maximum synaptic conductance, i.e., synaptic weight.
-    allow_multi_conn : bool, optional
-        Whether multiple connections are allowed from a single pre-synaptic neuron.
-        Default is True, meaning that a value of ``a`` can be selected multiple times.
-    seed : int, optional
-        Random seed. Default is None. If None, the default random seed will be used.
-    float_as_event : bool, optional
-        Whether to treat float as event. Default is True.
-    block_size : int, optional
-        Block size for parallel computation. Default is 64. This is only used for GPU.
-    name : str, optional
-        Name of the module.
-    """
-
-    __module__ = 'brainstate.event'
-
-    def __init__(
-        self,
-        in_size: Size,
-        out_size: Size,
-        prob: FloatScalar,
-        weight: Union[Callable, ArrayLike],
-        allow_multi_conn: bool = True,
-        seed: Optional[int] = None,
-        float_as_event: bool = True,
-        block_size: Optional[int] = None,
-        name: Optional[str] = None,
-    ):
-        super().__init__(name=name)
-
-        # network parameters
-        self.in_size = in_size
-        self.out_size = out_size
-        self.n_conn = int(self.out_size[-1] * prob)
-        self.float_as_event = float_as_event
-        self.block_size = block_size
-
-        if self.n_conn > 1:
-            # indices of post connected neurons
-            with jax.ensure_compile_time_eval():
-                if allow_multi_conn:
-                    rng = np.random.RandomState(seed)
-                    self.indices = rng.randint(0, self.out_size[-1], size=(self.in_size[-1], self.n_conn))
-                else:
-                    rng = RandomState(seed)
-
-                    @vmap(rngs=rng)
-                    def rand_indices(key):
-                        rng.set_key(key)
-                        return rng.choice(self.out_size[-1], size=(self.n_conn,), replace=False)
-
-                    self.indices = rand_indices(rng.split_key(self.in_size[-1]))
-                self.indices = u.math.asarray(self.indices)
-
-        # maximum synaptic conductance
-        weight = param(weight, (self.in_size[-1], self.n_conn), allow_none=False)
-        self.weight = ParamState(weight)
-
-    def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
-        if self.n_conn > 1:
-            r = event_fixed_prob(
-                spk,
-                self.weight.value,
-                self.indices,
-                n_post=self.out_size[-1],
-                block_size=self.block_size,
-                float_as_event=self.float_as_event
-            )
-        else:
-            weight = self.weight.value
-            unit = u.get_unit(weight)
-            r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
-            r = u.maybe_decimal(u.Quantity(r, unit=unit))
-        return u.math.asarray(r, dtype=environ.dftype())
-
-
-def event_fixed_prob(
-    spk, weight, indices,
-    *,
-    n_post, block_size, float_as_event
-):
-    """
-    The FixedProb module implements a fixed probability connection with CSR sparse data structure.
-
-    Parameters
-    ----------
-    weight : brainunit.Quantity or jax.Array
-        Maximum synaptic conductance.
-    spk : jax.Array
-        Spike events.
-
-    Returns
-    -------
-    post_data : brainunit.Quantity or jax.Array
-        Post synaptic data.
-    """
-    with jax.ensure_compile_time_eval():
-        weight = u.math.asarray(weight)
-        unit = u.get_unit(weight)
-        weight = u.get_mantissa(weight)
-        indices = jnp.asarray(indices)
-        spk = jnp.asarray(spk)
-
-    def mv(spk_vector):
-        assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk_vector.ndim}"
-        return event_ellmv_p_call(
-            spk_vector,
-            weight,
-            indices,
-            n_post=n_post,
-            block_size=block_size,
-            float_as_event=float_as_event
-        )
-
-    assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
-    assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"
-    assert indices.ndim == 2, f"indices must be 2D. Got: {indices.ndim}"
-
-    if spk.ndim == 1:
-        [post_data] = mv(spk)
-    else:
-        [post_data] = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
-        post_data = u.math.reshape(post_data, spk.shape[:-1] + post_data.shape[-1:])
-    return u.maybe_decimal(u.Quantity(post_data, unit=unit))
-
-
-Kernel = Callable
-
-
-def cpu_kernel_generator(
-    float_as_event: bool,
-    weight_info: jax.ShapeDtypeStruct,
-    spike_info: jax.ShapeDtypeStruct,
-    **kwargs
-):
-    import numba  # pylint: disable=import-outside-toplevel
-
-    if weight_info.size == 1:
-        if spike_info.dtype == jnp.bool_:
-            @numba.njit
-            def ell_mv(spikes, weights, indices, posts):
-                posts[:] = 0.
-                w = weights[()]
-                for i in range(spikes.shape[0]):
-                    if spikes[i]:
-                        for j in range(indices.shape[1]):
-                            posts[indices[i, j]] += w
-
-        elif float_as_event:
-            @numba.njit
-            def ell_mv(spikes, weights, indices, posts):
-                posts[:] = 0.
-                w = weights[()]
-                for i in range(spikes.shape[0]):
-                    if spikes[i] != 0.:
-                        for j in range(indices.shape[1]):
-                            posts[indices[i, j]] += w
-
-        else:
-            @numba.njit
-            def ell_mv(spikes, weights, indices, posts):
-                posts[:] = 0.
-                w = weights[()]
-                for i in range(spikes.shape[0]):
-                    sp = spikes[i]
-                    if sp != 0.:
-                        wsp = w * sp
-                        for j in range(indices.shape[1]):
-                            posts[indices[i, j]] += wsp
-
-    else:
-        if spike_info.dtype == jnp.bool_:
-            @numba.njit
-            def ell_mv(spikes, weights, indices, posts):
-                posts[:] = 0.
-                for i in range(spikes.shape[0]):
-                    if spikes[i]:
-                        for j in range(indices.shape[1]):
-                            posts[indices[i, j]] += weights[i, j]
-
-        elif float_as_event:
-            @numba.njit
-            def ell_mv(spikes, weights, indices, posts):
-                posts[:] = 0.
-                for i in range(spikes.shape[0]):
-                    if spikes[i] != 0.:
-                        for j in range(indices.shape[1]):
-                            posts[indices[i, j]] += weights[i, j]
-
-        else:
-            @numba.njit
-            def ell_mv(spikes, weights, indices, posts):
-                posts[:] = 0.
-                for i in range(spikes.shape[0]):
-                    sp = spikes[i]
-                    if sp != 0.:
-                        for j in range(indices.shape[1]):
-                            posts[indices[i, j]] += weights[i, j] * sp
-
-    return ell_mv
-
-
-def gpu_kernel_generator(
-    n_pre: int,
-    n_conn: int,
-    n_post: int,
-    block_size: int,
-    float_as_event: bool,
-    weight_info: jax.ShapeDtypeStruct,
-    **kwargs
-):
-    # For a spikes vector of shape [n_event], and indices and weights matrices of shape [n_event, n_conn],
-    # the computation logic of this operator is:
-    #
-    # - each block processes [block_size] events, one event per pre-synaptic neuron
-    # - each block processes a [block_size, block_size] tile of indices and weights
-
-    if weight_info.size == 1:
-        def _ell_mv_kernel_homo(
-            sp_ref,  # [block_size]
-            ind_ref,  # [block_size, block_size]
-            _,
-            y_ref,  # [n_post]
-        ):
-            r_pid = pl.program_id(0)
-            c_start = pl.program_id(1) * block_size
-            row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
-            mask = jnp.arange(block_size) + c_start < n_conn
-
-            def body_fn(j, _):
-                if sp_ref.dtype == jnp.bool_:
-                    def true_fn():
-                        ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                        pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask)
-                        # y_ref[ind] += 1.0
-                        # ind = ind_ref[j, ...]
-                        # pl.store(y_ref, ind, 1.0, mask=mask)
-
-                    jax.lax.cond(sp_ref[j], true_fn, lambda: None)
-
-
-                else:
-                    def true_fn(sp):
-                        ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                        if float_as_event:
-                            pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask)
-                        else:
-                            pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype) * sp, mask=mask)
-
-                    sp_ = sp_ref[j]
-                    jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_)
-
-            jax.lax.fori_loop(0, row_length, body_fn, None)
-
-        # homogeneous weights
-        kernel = pl.pallas_call(
-            _ell_mv_kernel_homo,
-            out_shape=[
-                jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
-            ],
-            in_specs=[
-                pl.BlockSpec((block_size,), lambda i, j: i),
-                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),
-                pl.BlockSpec((n_post,), lambda i, j: 0)
-            ],
-            grid=(
-                pl.cdiv(n_pre, block_size),
-                pl.cdiv(n_conn, block_size),
-            ),
-            input_output_aliases={2: 0},
-            interpret=False
-        )
-        return (lambda spikes, weight, indices:
-                [kernel(spikes, indices, jnp.zeros(n_post, dtype=weight.dtype))[0] * weight])
-
-    else:
-        def _ell_mv_kernel_heter(
-            sp_ref,  # [block_size]
-            ind_ref,  # [block_size, block_size]
-            w_ref,  # [block_size, block_size]
-            _,
-            y_ref,  # [n_post]
-        ):
-            r_pid = pl.program_id(0)
-            c_start = pl.program_id(1) * block_size
-            row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
-            mask = jnp.arange(block_size) + c_start < n_conn
-
-            def body_fn(j, _):
-                if sp_ref.dtype == jnp.bool_:
-                    def true_fn():
-                        ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                        w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
-                        pl.atomic_add(y_ref, ind, w, mask=mask)
-
-                    jax.lax.cond(sp_ref[j], true_fn, lambda: None)
-                else:
-                    def true_fn(spk):
-                        ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                        w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
-                        if not float_as_event:
-                            w = w * spk
-                        pl.atomic_add(y_ref, ind, w, mask=mask)
-
-                    sp_ = sp_ref[j]
-                    jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_)
-
-            jax.lax.fori_loop(0, row_length, body_fn, None)
-
-        # heterogeneous weights
-        kernel = pl.pallas_call(
-            _ell_mv_kernel_heter,
-            out_shape=[
-                jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
-            ],
-            in_specs=[
-                pl.BlockSpec((block_size,), lambda i, j: i),  # sp_ref
-                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
-                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # w_ref,
-                pl.BlockSpec((n_post,), lambda i, j: 0)
-            ],
-            grid=(
-                pl.cdiv(n_pre, block_size),
-                pl.cdiv(n_conn, block_size),
-            ),
-            input_output_aliases={3: 0},
-            interpret=False
-        )
-        return (lambda spikes, weight, indices:
-                kernel(spikes, indices, weight, jnp.zeros(n_post, dtype=weight_info.dtype)))
-
-
-def jvp_spikes(
-    spk_dot, spikes, weights, indices,
-    *,
-    n_post, block_size, **kwargs
-):
-    return ellmv_p_call(
-        spk_dot,
-        weights,
-        indices,
-        n_post=n_post,
-        block_size=block_size,
-    )
-
-
-def jvp_weights(
-    w_dot, spikes, weights, indices,
-    *,
-    float_as_event, block_size, n_post, **kwargs
-):
-    return event_ellmv_p_call(
-        spikes,
-        w_dot,
-        indices,
-        n_post=n_post,
-        block_size=block_size,
-        float_as_event=float_as_event
-    )
-
-
-def transpose_rule(
-    ct, spikes, weights, indices,
-    *,
-    float_as_event, n_post, n_conn, block_size, weight_info, **kwargs
-):
-    if ad.is_undefined_primal(indices):
-        raise ValueError("Cannot transpose with respect to sparse indices.")
-
-    ct = ct[0]
-
-    # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
-    homo = weight_info.size == 1
-    if ad.is_undefined_primal(spikes):
-        if homo:
-            # homogeneous weight
-            ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weights))(indices)
-        else:
-            # heterogeneous weight
-            ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weights)
-        return (ad.Zero(spikes) if type(ct) is ad.Zero else ct_spk), weights, indices
-
-    else:
-        # ∂L/∂w = ∂L/∂y * ∂y/∂w
-        if homo:
-            # scalar
-            ct_gmax = event_ellmv_p_call(
-                spikes,
-                jnp.asarray(1., dtype=weight_info.dtype),
-                indices,
-                n_post=n_post,
-                block_size=block_size,
-                float_as_event=float_as_event
-            )
-            ct_gmax = jnp.inner(ct, ct_gmax[0])
-        else:
-            def map_fn(one_spk, one_ind):
-                if spikes.dtype == jnp.bool_:
-                    return jax.lax.cond(
-                        one_spk,
-                        lambda: ct[one_ind],
-                        lambda: jnp.zeros([n_conn], weight_info.dtype)
-                    )
-                else:
-                    if float_as_event:
-                        return jax.lax.cond(
-                            one_spk == 0.,
-                            lambda: jnp.zeros([n_conn], weight_info.dtype),
-                            lambda: ct[one_ind]
-                        )
-                    else:
-                        return jax.lax.cond(
-                            one_spk == 0.,
-                            lambda: jnp.zeros([n_conn], weight_info.dtype),
-                            lambda: ct[one_ind] * one_spk
-                        )
-
-            ct_gmax = jax.vmap(map_fn)(spikes, indices)
-        return spikes, (ad.Zero(weights) if type(ct) is ad.Zero else ct_gmax), indices
-
-
-event_ellmv_p = XLACustomOp(
-    'event_ell_mv',
-    cpu_kernel_or_generator=cpu_kernel_generator,
-    gpu_kernel_or_generator=gpu_kernel_generator,
-)
-event_ellmv_p.defjvp(jvp_spikes, jvp_weights, None)
-event_ellmv_p.def_transpose_rule(transpose_rule)
-
-
-def event_ellmv_p_call(
-    spikes, weights, indices,
-    *,
-    n_post, block_size, float_as_event
-):
-    n_conn = indices.shape[1]
-    if block_size is None:
-        if n_conn <= 16:
-            block_size = 16
-        elif n_conn <= 32:
-            block_size = 32
-        elif n_conn <= 64:
-            block_size = 64
-        elif n_conn <= 128:
-            block_size = 128
-        elif n_conn <= 256:
-            block_size = 256
-        else:
-            block_size = 128
-    return event_ellmv_p(
-        spikes,
-        weights,
-        indices,
-        outs=[jax.ShapeDtypeStruct([n_post], weights.dtype)],
-        block_size=block_size,
-        float_as_event=float_as_event,
-        n_pre=spikes.shape[0],
-        n_conn=indices.shape[1],
-        n_post=n_post,
-        weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
-        spike_info=jax.ShapeDtypeStruct(spikes.shape, spikes.dtype),
-    )
-
-
-def ell_cpu_kernel_generator(
-    weight_info: jax.ShapeDtypeStruct,
-    **kwargs
-):
-    import numba  # pylint: disable=import-outside-toplevel
-
-    if jnp.size(weight_info) == 1:
-        @numba.njit
-        def ell_mv(vector, weights, indices, posts):
-            posts[:] = 0.
-            w = weights[()]
-            for i in range(vector.shape[0]):
-                wv = w * vector[i]
-                for j in range(indices.shape[1]):
-                    posts[indices[i, j]] += wv
-
-    else:
-        @numba.njit
-        def ell_mv(vector, weights, indices, posts):
-            posts[:] = 0.
-            for i in range(vector.shape[0]):
-                for j in range(indices.shape[1]):
-                    posts[indices[i, j]] += weights[i, j] * vector[i]
-
-    return ell_mv
-
-
-def ell_gpu_kernel_generator(
-    block_size: int,
-    n_pre: int,
-    n_conn: int,
-    n_post: int,
-    weight_info: jax.ShapeDtypeStruct,
-    **kwargs
-):
-    homo = jnp.size(weight_info) == 1
-
-    if homo:
-        def _kernel(
-            vec_ref, ind_ref, _, out_ref,
-        ):
-            # each block processes a [block_size]-sized chunk of the vector
-            # each block processes a [block_size, block_size] tile of indices and weights
-
-            # -------------------------------
-            # vec_ref: [block_size]
-            # ind_ref: [block_size, block_size]
-            # out_ref: [n_post]
-
-            r_pid = pl.program_id(0)
-            c_start = pl.program_id(1) * block_size
-            mask = jnp.arange(block_size) + c_start < n_conn
-            row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
-
-            def body_fn(j, _):
-                y = vec_ref[j] * jnp.ones(block_size, dtype=weight_info.dtype)
-                ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                pl.atomic_add(out_ref, ind, y, mask=mask)
-
-            jax.lax.fori_loop(0, row_length, body_fn, None)
-
-        # homogeneous weights
-        kernel = pl.pallas_call(
-            _kernel,
-            out_shape=[
-                jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
-            ],
-            in_specs=[
-                pl.BlockSpec((block_size,), lambda i, j: i),  # vec_ref
-                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
-                pl.BlockSpec((n_post,), lambda i, j: 0)  # out_ref
-            ],
-            grid=(
-                pl.cdiv(n_pre, block_size),
-                pl.cdiv(n_conn, block_size),
-            ),
-            input_output_aliases={2: 0},
-            interpret=False
-        )
-        return lambda vector, weight, indices: kernel(vector, indices, jnp.zeros(n_post, dtype=weight.dtype)) * weight
-
-    else:
-        def _kernel(
-            vec_ref, ind_ref, w_ref, _, out_ref,
-        ):
-            # each block processes a [block_size]-sized chunk of the vector
-            # each block processes a [block_size, n_conn] tile of indices and weights
-
-            # -------------------------------
-            # vec_ref: [block_size]
-            # ind_ref: [block_size, block_size]
-            # w_ref: [block_size, block_size]
-            # out_ref: [n_post]
-
-            r_pid = pl.program_id(0)
-            c_start = pl.program_id(1) * block_size
-            mask = jnp.arange(block_size) + c_start < n_conn
-            row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
-
-            def body_fn(j, _):
-                w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
-                y = w * vec_ref[j]
-                ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                pl.atomic_add(out_ref, ind, y, mask=mask)
-
-            jax.lax.fori_loop(0, row_length, body_fn, None)
-
-        # heterogeneous weights
-        kernel = pl.pallas_call(
-            _kernel,
-            out_shape=[
-                jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
-            ],
-            in_specs=[
-                pl.BlockSpec((block_size,), lambda i, j: i),  # vec_ref
-                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
-                pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # w_ref
-                pl.BlockSpec((n_post,), lambda i, j: 0)  # out_ref
-            ],
-            grid=(
-                pl.cdiv(n_pre, block_size),
-                pl.cdiv(n_conn, block_size),
-            ),
-            input_output_aliases={3: 0},
-            interpret=False
-        )
-        return lambda vector, weight, indices: kernel(vector, indices, weight, jnp.zeros(n_post, dtype=weight.dtype))
-
-
-def jvp_weights_no_spk(w_dot, vector, weights, indices, *, block_size, n_post, **kwargs):
-    return ellmv_p_call(
-        vector,
-        w_dot,
-        indices,
-        block_size=block_size,
-        n_post=n_post,
-    )
-
-
-def transpose_rule_no_spk(
-    ct, vector, weights, indices,
-    *,
-    n_post, block_size, weight_info, **kwargs
-):
-    if ad.is_undefined_primal(indices):
-        raise ValueError("Cannot transpose with respect to sparse indices.")
-
-    ct = ct[0]
-
-    # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
-    homo = weight_info.size == 1
-    if ad.is_undefined_primal(vector):
-        if homo:
-            # homogeneous weight
-            ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weights))(indices)
-        else:
-            # heterogeneous weight
-            ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weights)
-        return (ad.Zero(vector) if type(ct) is ad.Zero else ct_spk), weights, indices
-
-    else:
-        # ∂L/∂w = ∂L/∂y * ∂y/∂w
-        if homo:
-            # scalar
-            ct_gmax = ellmv_p_call(
-                vector,
-                jnp.asarray(1., dtype=weight_info.dtype),
-                indices,
-                block_size=block_size,
-                n_post=n_post,
-            )
-            ct_gmax = jnp.inner(ct, ct_gmax[0])
-        else:
-            ct_gmax = jax.vmap(lambda vec, one_ind: ct[one_ind] * vec)(vector, indices)
-        return vector, (ad.Zero(weights) if type(ct) is ad.Zero else ct_gmax), indices
-
-
-ellmv_p = XLACustomOp(
-    'ell_mv',
-    cpu_kernel_or_generator=ell_cpu_kernel_generator,
-    gpu_kernel_or_generator=ell_gpu_kernel_generator,
-)
-ellmv_p.defjvp(jvp_spikes, jvp_weights_no_spk, None)
-ellmv_p.def_transpose_rule(transpose_rule_no_spk)
-
-
-def ellmv_p_call(vector, weights, indices, *, n_post, block_size):
-    n_conn = indices.shape[1]
-    if block_size is None:
-        if n_conn <= 16:
-            block_size = 16
-        elif n_conn <= 32:
-            block_size = 32
-        elif n_conn <= 64:
-            block_size = 64
-        elif n_conn <= 128:
-            block_size = 128
-        elif n_conn <= 256:
-            block_size = 256
-        else:
-            block_size = 128
-    return ellmv_p(
-        vector,
-        weights,
-        indices,
-        n_post=n_post,
-        n_pre=indices.shape[0],
-        n_conn=indices.shape[1],
-        block_size=block_size,
-        weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
-        outs=[jax.ShapeDtypeStruct([n_post], weights.dtype)]
-    )