brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/augment/_autograd.py +9 -6
  4. brainstate/event/__init__.py +4 -2
  5. brainstate/event/_csr.py +26 -18
  6. brainstate/event/_csr_benchmark.py +14 -0
  7. brainstate/event/_fixed_probability.py +589 -152
  8. brainstate/event/_fixed_probability_benchmark.py +128 -0
  9. brainstate/event/_fixed_probability_test.py +13 -10
  10. brainstate/event/_linear.py +267 -127
  11. brainstate/event/_linear_benckmark.py +82 -0
  12. brainstate/event/_linear_test.py +8 -3
  13. brainstate/event/_xla_custom_op.py +312 -0
  14. brainstate/event/_xla_custom_op_test.py +55 -0
  15. brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
  16. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
  17. brainstate/nn/_dynamics/_projection_base.py +1 -1
  18. brainstate/nn/_exp_euler.py +1 -1
  19. brainstate/nn/_interaction/__init__.py +13 -4
  20. brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
  21. brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
  22. brainstate/nn/_interaction/_linear.py +582 -0
  23. brainstate/nn/_interaction/_linear_test.py +42 -0
  24. brainstate/optim/_lr_scheduler.py +1 -1
  25. brainstate/optim/_optax_optimizer.py +18 -0
  26. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +1 -1
  27. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/RECORD +30 -21
  28. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  29. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  30. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
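The headline change in this release is the rewrite of brainstate/event/_fixed_probability.py: FixedProb now takes in_size/out_size instead of n_pre/n_post, gains float_as_event and block_size options, drops grad_mode, and routes update() through a new event_fixed_prob custom operator backed by Numba CPU kernels and Pallas GPU kernels. A minimal usage sketch follows, assuming FixedProb is re-exported from brainstate.event as the package layout suggests; the sizes and probability are made up for illustration only:

import numpy as np
import brainstate as bs

# Hypothetical sizes; the constructor requires int(out_size * prob) >= 1,
# as enforced in the diff below.
proj = bs.event.FixedProb(
    in_size=1000,           # number of pre-synaptic neurons
    out_size=800,           # number of post-synaptic neurons
    prob=0.1,               # connection probability -> n_conn = int(800 * 0.1) = 80
    weight=0.5,             # homogeneous maximum synaptic conductance
    float_as_event=True,    # non-zero float spikes count as binary events
    block_size=None,        # None lets the op choose a block size from n_conn
    seed=42,
)

spikes = np.random.rand(1000) < 0.05    # boolean spike vector of length in_size
post = proj.update(spikes)              # post-synaptic drive, shape (out_size,)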
@@ -12,23 +12,26 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+
  from __future__ import annotations

  from typing import Union, Callable, Optional

  import brainunit as u
  import jax
+ import jax.experimental.pallas as pl
  import jax.numpy as jnp
  import numpy as np
+ from jax.interpreters import ad

  from brainstate._state import ParamState
- from brainstate._utils import set_module_as
- from brainstate.compile import for_loop
+ from brainstate.augment import vmap
  from brainstate.init import param
  from brainstate.nn._module import Module
  from brainstate.random import RandomState
- from brainstate.typing import ArrayLike
- from ._misc import FloatScalar, IntScalar
+ from brainstate.typing import ArrayLike, Size
+ from ._misc import FloatScalar
+ from ._xla_custom_op import XLACustomOp

  __all__ = [
  'FixedProb',
@@ -41,19 +44,23 @@ class FixedProb(Module):

  Parameters
  ----------
- n_pre : int
- Number of pre-synaptic neurons.
- n_post : int
- Number of post-synaptic neurons.
+ in_size : Size
+ Number of pre-synaptic neurons, i.e., input size.
+ out_size : Size
+ Number of post-synaptic neurons, i.e., output size.
  prob : float
- Probability of connection.
+ Probability of connection, i.e., connection probability.
  weight : float or callable or jax.Array or brainunit.Quantity
- Maximum synaptic conductance.
+ Maximum synaptic conductance, i.e., synaptic weight.
  allow_multi_conn : bool, optional
  Whether multiple connections are allowed from a single pre-synaptic neuron.
  Default is True, meaning that a value of ``a`` can be selected multiple times.
- prob : float
- Probability of connection.
+ seed: int, optional
+ Random seed. Default is None. If None, the default random seed will be used.
+ float_as_event : bool, optional
+ Whether to treat float as event. Default is True.
+ block_size : int, optional
+ Block size for parallel computation. Default is 64. This is only used for GPU.
  name : str, optional
  Name of the module.
  """
@@ -62,210 +69,640 @@ class FixedProb(Module):

  def __init__(
  self,
- n_pre: IntScalar,
- n_post: IntScalar,
+ in_size: Size,
+ out_size: Size,
  prob: FloatScalar,
  weight: Union[Callable, ArrayLike],
  allow_multi_conn: bool = True,
  seed: Optional[int] = None,
+ float_as_event: bool = True,
+ block_size: Optional[int] = None,
  name: Optional[str] = None,
- grad_mode: str = 'vjp'
  ):
  super().__init__(name=name)
- self.n_pre = n_pre
- self.n_post = n_post
- self.in_size = n_pre
- self.out_size = n_post

- self.n_conn = int(n_post * prob)
+ # network parameters
+ self.in_size = in_size
+ self.out_size = out_size
+ self.n_conn = int(self.out_size[-1] * prob)
  if self.n_conn < 1:
- raise ValueError(
- f"The number of connections must be at least 1. Got: int({n_post} * {prob}) = {self.n_conn}")
-
- assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
- self.grad_mode = grad_mode
+ raise ValueError(f"The number of connections must be at least 1. "
+ f"Got: int({self.out_size[-1]} * {prob}) = {self.n_conn}")
+ self.float_as_event = float_as_event
+ self.block_size = block_size

  # indices of post connected neurons
- if allow_multi_conn:
- self.indices = np.random.RandomState(seed).randint(0, n_post, size=(self.n_pre, self.n_conn))
- else:
- rng = RandomState(seed)
- self.indices = for_loop(lambda i: rng.choice(n_post, size=(self.n_conn,), replace=False), np.arange(n_pre))
- self.indices = u.math.asarray(self.indices)
+ with jax.ensure_compile_time_eval():
+ if allow_multi_conn:
+ rng = np.random.RandomState(seed)
+ self.indices = rng.randint(0, self.out_size[-1], size=(self.in_size[-1], self.n_conn))
+ else:
+ rng = RandomState(seed)
+
+ @vmap(rngs=rng)
+ def rand_indices(key):
+ rng.set_key(key)
+ return rng.choice(self.out_size[-1], size=(self.n_conn,), replace=False)
+
+ self.indices = rand_indices(rng.split_key(self.in_size[-1]))
+ self.indices = u.math.asarray(self.indices)

  # maximum synaptic conductance
- weight = param(weight, (self.n_pre, self.n_conn), allow_none=False)
+ weight = param(weight, (self.in_size[-1], self.n_conn), allow_none=False)
  self.weight = ParamState(weight)

  def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
- device_kind = jax.devices()[0].platform # spk.device.device_kind
- if device_kind == 'cpu':
- return cpu_fixed_prob(self.indices,
- u.math.asarray(self.weight.value),
- u.math.asarray(spk),
- n_post=self.n_post,
- grad_mode=self.grad_mode)
- elif device_kind in ['gpu', 'tpu']:
- raise NotImplementedError()
- else:
- raise ValueError(f"Unsupported device: {device_kind}")
+ return event_fixed_prob(
+ spk,
+ self.weight.value,
+ self.indices,
+ n_post=self.out_size[-1],
+ block_size=self.block_size,
+ float_as_event=self.float_as_event
+ )


- @set_module_as('brainstate.event')
- def cpu_fixed_prob(
- indices: jax.Array,
- weight: Union[u.Quantity, jax.Array],
- spk: jax.Array,
- *,
- n_post: int,
- grad_mode: str = 'vjp'
- ) -> Union[u.Quantity, jax.Array]:
+ def event_fixed_prob(spk, weight, indices, *, n_post, block_size, float_as_event):
  """
  The FixedProb module implements a fixed probability connection with CSR sparse data structure.

  Parameters
  ----------
- n_post : int
- Number of post-synaptic neurons.
  weight : brainunit.Quantity or jax.Array
  Maximum synaptic conductance.
  spk : jax.Array
  Spike events.
- indices : jax.Array
- Indices of post connected neurons.
- grad_mode : str, optional
- Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.

  Returns
  -------
  post_data : brainunit.Quantity or jax.Array
  Post synaptic data.
  """
- unit = u.get_unit(weight)
- weight = u.get_mantissa(weight)
- indices = jnp.asarray(indices)
- spk = jnp.asarray(spk)
+ with jax.ensure_compile_time_eval():
+ weight = u.math.asarray(weight)
+ unit = u.get_unit(weight)
+ weight = u.get_mantissa(weight)
+ indices = jnp.asarray(indices)
+ spk = jnp.asarray(spk)

  def mv(spk_vector):
  assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
- if grad_mode == 'vjp':
- post_data = _cpu_event_fixed_prob_mv_vjp(indices, weight, spk_vector, n_post)
- elif grad_mode == 'jvp':
- post_data = _cpu_event_fixed_prob_mv_jvp(indices, weight, spk_vector, n_post)
- else:
- raise ValueError(f"Unsupported grad_mode: {grad_mode}")
- return post_data
+ return event_ellmv_p_call(
+ spk,
+ weight,
+ indices,
+ n_post=n_post,
+ block_size=block_size,
+ float_as_event=float_as_event
+ )

  assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
  assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"
  assert indices.ndim == 2, f"indices must be 2D. Got: {indices.ndim}"

  if spk.ndim == 1:
- post_data = mv(spk)
+ [post_data] = mv(spk)
  else:
- shape = spk.shape[:-1]
- post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
- post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
+ [post_data] = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
+ post_data = u.math.reshape(post_data, spk.shape[:-1] + post_data.shape[-1:])
  return u.maybe_decimal(u.Quantity(post_data, unit=unit))


- # -------------------
- # CPU Implementation
- # -------------------
-
+ Kernel = Callable
+
+
+ def cpu_kernel_generator(
+ float_as_event: bool,
+ weight_info: jax.ShapeDtypeStruct,
+ spike_info: jax.ShapeDtypeStruct,
+ **kwargs
+ ):
+ import numba # pylint: disable=import-outside-toplevel
+
+ if weight_info.size == 1:
+ if spike_info.dtype == jnp.bool_:
+ @numba.njit
+ def ell_mv(spikes, weights, indices, posts):
+ posts[:] = 0.
+ w = weights[()]
+ for i in range(spikes.shape[0]):
+ if spikes[i]:
+ for j in range(indices.shape[1]):
+ posts[indices[i, j]] += w
+
+ elif float_as_event:
+ @numba.njit
+ def ell_mv(spikes, weights, indices, posts):
+ posts[:] = 0.
+ w = weights[()]
+ for i in range(spikes.shape[0]):
+ if spikes[i] != 0.:
+ for j in range(indices.shape[1]):
+ posts[indices[i, j]] += w

- def _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post: int) -> jax.Array:
- def scan_fn(post, i):
- w = g_max if jnp.size(g_max) == 1 else g_max[i]
- ids = indices[i]
- sp = spk[i]
- if spk.dtype == jnp.bool_:
- post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
  else:
- post = jax.lax.cond(sp == 0., lambda: post, lambda: post.at[ids].add(w * sp))
- return post, None
-
- return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=g_max.dtype), np.arange(len(spk)))[0]
-
-
- # --------------
- # VJP
- # --------------
-
- def _cpu_event_fixed_prob_mv_fwd(indices, g_max, spk, n_post):
- return _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post=n_post), (g_max, spk)
-
-
- def _cpu_event_fixed_prob_mv_bwd(indices, n_post, res, ct):
- weight, spk = res
+ @numba.njit
+ def ell_mv(spikes, weights, indices, posts):
+ posts[:] = 0.
+ w = weights[()]
+ for i in range(spikes.shape[0]):
+ sp = spikes[i]
+ if sp != 0.:
+ wsp = w * sp
+ for j in range(indices.shape[1]):
+ posts[indices[i, j]] += wsp

- # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
- homo = jnp.size(weight) == 1
- if homo: # homogeneous weight
- ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weight))(indices)
- else: # heterogeneous weight
- ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weight)
-
- # ∂L/∂w = ∂L/∂y * ∂y/∂w
- if homo: # scalar
- ct_gmax = _cpu_event_fixed_prob_mv(indices, jnp.asarray(1.), spk, n_post=n_post)
- ct_gmax = jnp.inner(ct, ct_gmax)

- def scan_fn(d_gmax, i):
- if spk.dtype == jnp.bool_:
- d_gmax = jax.lax.cond(spk[i], lambda: d_gmax.at[i].add(ct[indices[i]]), lambda: d_gmax)
- else:
- d_gmax = jax.lax.cond(spk[i] == 0., lambda: d_gmax, lambda: d_gmax.at[i].add(ct[indices[i]] * spk[i]))
- return d_gmax, None
+ if spike_info.dtype == jnp.bool_:
+ @numba.njit
+ def ell_mv(spikes, weights, indices, posts):
+ posts[:] = 0.
+ for i in range(spikes.shape[0]):
+ if spikes[i]:
+ for j in range(indices.shape[1]):
+ posts[indices[i, j]] += weights[i, j]
+
+ elif float_as_event:
+ @numba.njit
+ def ell_mv(spikes, weights, indices, posts):
+ posts[:] = 0.
+ for i in range(spikes.shape[0]):
+ if spikes[i] != 0.:
+ for j in range(indices.shape[1]):
+ posts[indices[i, j]] += weights[i, j]

- ct_gmax = jax.lax.scan(scan_fn, jnp.zeros_like(weight), np.arange(len(spk)))[0]
- return ct_gmax, ct_spk
+ else:
+ @numba.njit
+ def ell_mv(spikes, weights, indices, posts):
+ posts[:] = 0.
+ for i in range(spikes.shape[0]):
+ sp = spikes[i]
+ if sp != 0.:
+ for j in range(indices.shape[1]):
+ posts[indices[i, j]] += weights[i, j] * sp

+ return ell_mv

- _cpu_event_fixed_prob_mv_vjp = jax.custom_vjp(_cpu_event_fixed_prob_mv, nondiff_argnums=(0, 3))
- _cpu_event_fixed_prob_mv_vjp.defvjp(_cpu_event_fixed_prob_mv_fwd, _cpu_event_fixed_prob_mv_bwd)

+ def gpu_kernel_generator(
+ n_pre: int,
+ n_conn: int,
+ n_post: int,
+ block_size: int,
+ float_as_event: bool,
+ weight_info: jax.ShapeDtypeStruct,
+ **kwargs
+ ):
+ # For a spikes vector of shape [n_event] and indices/weights matrices of shape [n_event, n_conn],
+ # the operator is computed as follows:
+ #
+ # - each block processes [block_size] events, one event per pre-synaptic neuron
+ # - each block processes a [block_size, block_size] tile of indices and weights
+
+ if weight_info.size == 1:
+ def _ell_mv_kernel_homo(
+ sp_ref, # [block_size]
+ ind_ref, # [block_size, block_size]
+ _,
+ y_ref, # [n_post]
+ ):
+ r_pid = pl.program_id(0)
+ c_start = pl.program_id(1) * block_size
+ row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+ mask = jnp.arange(block_size) + c_start < n_conn
+
+ def body_fn(j, _):
+ if sp_ref.dtype == jnp.bool_:
+ def true_fn():
+ ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+ pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask)
+ # y_ref[ind] += 1.0
+ # ind = ind_ref[j, ...]
+ # pl.store(y_ref, ind, 1.0, mask=mask)
+
+ jax.lax.cond(sp_ref[j], true_fn, lambda: None)
+
+
+ else:
+ def true_fn(sp):
+ ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+ if float_as_event:
+ pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask)
+ else:
+ pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype) * sp, mask=mask)
+
+ sp_ = sp_ref[j]
+ jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_)
+
+ jax.lax.fori_loop(0, row_length, body_fn, None)
+
+ # homogeneous weights
+ kernel = pl.pallas_call(
+ _ell_mv_kernel_homo,
+ out_shape=[
+ jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+ ],
+ in_specs=[
+ pl.BlockSpec((block_size,), lambda i, j: i),
+ pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),
+ pl.BlockSpec((n_post,), lambda i, j: 0)
+ ],
+ grid=(
+ pl.cdiv(n_pre, block_size),
+ pl.cdiv(n_conn, block_size),
+ ),
+ input_output_aliases={2: 0},
+ interpret=False
+ )
+ return (lambda spikes, weight, indices:
+ [kernel(spikes, indices, jnp.zeros(n_post, dtype=weight.dtype))[0] * weight])

- # --------------
- # JVP
- # --------------
+ else:
+ def _ell_mv_kernel_heter(
+ sp_ref, # [block_size]
+ ind_ref, # [block_size, block_size]
+ w_ref, # [block_size, block_size]
+ _,
+ y_ref, # [n_post]
+ ):
+ r_pid = pl.program_id(0)
+ c_start = pl.program_id(1) * block_size
+ row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+ mask = jnp.arange(block_size) + c_start < n_conn
+
+ def body_fn(j, _):
+ if sp_ref.dtype == jnp.bool_:
+ def true_fn():
+ ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+ w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
+ pl.atomic_add(y_ref, ind, w, mask=mask)
+
+ jax.lax.cond(sp_ref[j], true_fn, lambda: None)
+ else:
+ def true_fn(spk):
+ ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+ w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
+ if not float_as_event:
+ w = w * spk
+ pl.atomic_add(y_ref, ind, w, mask=mask)
+
+ sp_ = sp_ref[j]
+ jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_)
+
+ jax.lax.fori_loop(0, row_length, body_fn, None)
+
+ # heterogeneous weights
+ kernel = pl.pallas_call(
+ _ell_mv_kernel_heter,
+ out_shape=[
+ jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+ ],
+ in_specs=[
+ pl.BlockSpec((block_size,), lambda i, j: i), # sp_ref
+ pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)), # ind_ref
+ pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)), # w_ref,
+ pl.BlockSpec((n_post,), lambda i, j: 0)
+ ],
+ grid=(
+ pl.cdiv(n_pre, block_size),
+ pl.cdiv(n_conn, block_size),
+ ),
+ input_output_aliases={3: 0},
+ interpret=False
+ )
+ return (lambda spikes, weight, indices:
+ kernel(spikes, indices, weight, jnp.zeros(n_post, dtype=weight_info.dtype)))
+
+
+ def jvp_spikes(spk_dot, spikes, weights, indices, *, n_post, block_size, **kwargs):
+ return ellmv_p_call(
+ spk_dot,
+ weights,
+ indices,
+ n_post=n_post,
+ block_size=block_size,
+ )
+
+
+ def jvp_weights(w_dot, spikes, weights, indices, *, float_as_event, block_size, n_post, **kwargs):
+ return event_ellmv_p_call(
+ spikes,
+ w_dot,
+ indices,
+ n_post=n_post,
+ block_size=block_size,
+ float_as_event=float_as_event
+ )
+
+
+ def transpose_rule(
+ ct, spikes, weights, indices,
+ *,
+ float_as_event, n_post, n_conn, block_size, weight_info, **kwargs
+ ):
+ if ad.is_undefined_primal(indices):
+ raise ValueError("Cannot transpose with respect to sparse indices.")

+ ct = ct[0]

- def _cpu_event_fixed_prob_mv_jvp_rule(indices, n_post, primals, tangents):
- # forward pass
- weight, spk = primals
- y = _cpu_event_fixed_prob_mv(indices, weight, spk, n_post=n_post)
+ # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
+ homo = weight_info.size == 1
+ if ad.is_undefined_primal(spikes):
+ if homo:
+ # homogeneous weight
+ ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weights))(indices)
+ else:
+ # heterogeneous weight
+ ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weights)
+ return (ad.Zero(spikes) if type(ct) is ad.Zero else ct_spk), weights, indices

- # forward gradients
- gmax_dot, spk_dot = tangents
+ else:
+ # ∂L/∂w = ∂L/∂y * ∂y/∂w
+ if homo:
+ # scalar
+ ct_gmax = event_ellmv_p_call(
+ spikes,
+ jnp.asarray(1., dtype=weight_info.dtype),
+ indices,
+ n_post=n_post,
+ block_size=block_size,
+ float_as_event=float_as_event
+ )
+ ct_gmax = jnp.inner(ct, ct_gmax[0])
+ else:
+ def map_fn(one_spk, one_ind):
+ if spikes.dtype == jnp.bool_:
+ return jax.lax.cond(
+ one_spk,
+ lambda: ct[one_ind],
+ lambda: jnp.zeros([n_conn], weight_info.dtype)
+ )
+ else:
+ if float_as_event:
+ return jax.lax.cond(
+ one_spk == 0.,
+ lambda: jnp.zeros([n_conn], weight_info.dtype),
+ lambda: ct[one_ind]
+ )
+ else:
+ return jax.lax.cond(
+ one_spk == 0.,
+ lambda: jnp.zeros([n_conn], weight_info.dtype),
+ lambda: ct[one_ind] * one_spk
+ )
+
+ ct_gmax = jax.vmap(map_fn)(spikes, indices)
+ return spikes, (ad.Zero(weights) if type(ct) is ad.Zero else ct_gmax), indices
+
+
+ event_ellmv_p = XLACustomOp(
+ 'event_ell_mv',
+ cpu_kernel_generator=cpu_kernel_generator,
+ gpu_kernel_generator=gpu_kernel_generator,
+ )
+ event_ellmv_p.defjvp(jvp_spikes, jvp_weights, None)
+ event_ellmv_p.def_transpose_rule(transpose_rule)
+
+
+ def event_ellmv_p_call(spikes, weights, indices, *, n_post, block_size, float_as_event):
+ n_conn = indices.shape[1]
+ if block_size is None:
+ if n_conn <= 16:
+ block_size = 16
+ elif n_conn <= 32:
+ block_size = 32
+ elif n_conn <= 64:
+ block_size = 64
+ elif n_conn <= 128:
+ block_size = 128
+ elif n_conn <= 256:
+ block_size = 256
+ else:
+ block_size = 128
+ return event_ellmv_p(
+ spikes,
+ weights,
+ indices,
+ outs=[jax.ShapeDtypeStruct([n_post], weights.dtype)],
+ block_size=block_size,
+ float_as_event=float_as_event,
+ n_pre=spikes.shape[0],
+ n_conn=indices.shape[1],
+ n_post=n_post,
+ weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
+ spike_info=jax.ShapeDtypeStruct(spikes.shape, spikes.dtype),
+ )
+
+
+ def ell_cpu_kernel_generator(
+ weight_info: jax.ShapeDtypeStruct,
+ **kwargs
+ ):
+ import numba # pylint: disable=import-outside-toplevel
+
+ if jnp.size(weight_info) == 1:
+ @numba.njit
+ def ell_mv(vector, weights, indices, posts):
+ posts[:] = 0.
+ w = weights[()]
+ for i in range(vector.shape[0]):
+ wv = w * vector[i]
+ for j in range(indices.shape[1]):
+ posts[indices[i, j]] += wv

- # ∂y/∂gmax
- dgmax = _cpu_event_fixed_prob_mv(indices, gmax_dot, spk, n_post=n_post)
+ else:
+ @numba.njit
+ def ell_mv(vector, weights, indices, posts):
+ posts[:] = 0.
+ for i in range(vector.shape[0]):
+ for j in range(indices.shape[1]):
+ posts[indices[i, j]] += weights[i, j] * vector[i]

- def scan_fn(post, i):
- ids = indices[i]
- w = weight if jnp.size(weight) == 1 else weight[i]
- post = post.at[ids].add(w * spk_dot[i])
- return post, None
+ return ell_mv

- # ∂y/∂gspk
- dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
- return y, dgmax + dspk

+ def ell_gpu_kernel_generator(
+ block_size: int,
+ n_pre: int,
+ n_conn: int,
+ n_post: int,
+ weight_info: jax.ShapeDtypeStruct,
+ **kwargs
+ ):
+ homo = jnp.size(weight_info) == 1
+
+ if homo:
+ def _kernel(
+ vec_ref, ind_ref, _, out_ref,
+ ):
+ # each block processes a [block_size] slice of the vector
+ # each block processes a [block_size, block_size] tile of indices and weights
+
+ # -------------------------------
+ # vec_ref: [block_size]
+ # ind_ref: [block_size, block_size]
+ # out_ref: [n_post]
+
+ r_pid = pl.program_id(0)
+ c_start = pl.program_id(1) * block_size
+ mask = jnp.arange(block_size) + c_start
+ row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+
+ def body_fn(j, _):
+ y = vec_ref[j] * jnp.ones(block_size, dtype=weight_info.dtype)
+ ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+ pl.atomic_add(out_ref, ind, y, mask=mask)
+
+ jax.lax.fori_loop(0, row_length, body_fn, None)
+
+ # homogeneous weights
+ kernel = pl.pallas_call(
+ _kernel,
+ out_shape=[
+ jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+ ],
+ in_specs=[
+ pl.BlockSpec((block_size,), lambda i, j: i), # vec_ref
+ pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)), # ind_ref
+ pl.BlockSpec((n_post,), lambda i, j: 0) # out_ref
+ ],
+ grid=(
+ pl.cdiv(n_pre, block_size),
+ pl.cdiv(n_conn, block_size),
+ ),
+ input_output_aliases={2: 0},
+ interpret=False
+ )
+ return lambda vector, weight, indices: kernel(vector, indices, jnp.zeros(n_post, dtype=weight.dtype)) * weight

- _cpu_event_fixed_prob_mv_jvp = jax.custom_jvp(_cpu_event_fixed_prob_mv, nondiff_argnums=(0, 3))
- _cpu_event_fixed_prob_mv_jvp.defjvp(_cpu_event_fixed_prob_mv_jvp_rule)
+ else:
+ def _kernel(
+ vec_ref, ind_ref, w_ref, _, out_ref,
+ ):
+ # each block processes a [block_size] slice of the vector
+ # each block processes a [block_size, n_conn] tile of indices and weights
+
+ # -------------------------------
+ # vec_ref: [block_size]
+ # ind_ref: [block_size, block_size]
+ # w_ref: [block_size, block_size]
+ # out_ref: [n_post]
+
+ r_pid = pl.program_id(0)
+ c_start = pl.program_id(1) * block_size
+ mask = jnp.arange(block_size) + c_start
+ row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+
+ def body_fn(j, _):
+ w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
+ y = w * vec_ref[j]
+ ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+ pl.atomic_add(out_ref, ind, y, mask=mask)
+
+ jax.lax.fori_loop(0, row_length, body_fn, None)
+
+ # heterogeneous weights
+ kernel = pl.pallas_call(
+ _kernel,
+ out_shape=[
+ jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+ ],
+ in_specs=[
+ pl.BlockSpec((block_size,), lambda i, j: i), # vec_ref
+ pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)), # ind_ref
+ pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)), # w_ref
+ pl.BlockSpec((n_post,), lambda i: 0) # out_ref
+ ],
+ grid=(
+ pl.cdiv(n_pre, block_size),
+ pl.cdiv(n_conn, block_size),
+ ),
+ input_output_aliases={3: 0},
+ interpret=False
+ )
+ return lambda vector, weight, indices: kernel(vector, indices, weight, jnp.zeros(n_post, dtype=weight.dtype))
+
+
+ def jvp_weights_no_spk(w_dot, vector, weights, indices, *, block_size, n_post, **kwargs):
+ return ellmv_p_call(
+ vector,
+ w_dot,
+ indices,
+ block_size=block_size,
+ n_post=n_post,
+ )
+
+
+ def transpose_rule_no_spk(
+ ct, vector, weights, indices,
+ *,
+ n_post, block_size, weight_info, **kwargs
+ ):
+ if ad.is_undefined_primal(indices):
+ raise ValueError("Cannot transpose with respect to sparse indices.")

+ ct = ct[0]

- def _gpu_event_fixed_prob_mv(indices, g_max, spk, n_post: int) -> jax.Array:
- def scan_fn(post, i):
- w = g_max if jnp.size(g_max) == 1 else g_max[i]
- ids = indices[i]
- sp = spk[i]
- if spk.dtype == jnp.bool_:
- post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
+ # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
+ homo = weight_info.size == 1
+ if ad.is_undefined_primal(vector):
+ if homo:
+ # homogeneous weight
+ ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weights))(indices)
  else:
- post = jax.lax.cond(sp == 0., lambda: post, lambda: post.at[ids].add(w * sp))
- return post, None
+ # heterogeneous weight
+ ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weights)
+ return (ad.Zero(vector) if type(ct) is ad.Zero else ct_spk), weights, indices

- return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=g_max.dtype), np.arange(len(spk)))[0]
+ else:
+ # ∂L/∂w = ∂L/∂y * ∂y/∂w
+ if homo:
+ # scalar
+ ct_gmax = ellmv_p_call(
+ vector,
+ jnp.asarray(1., dtype=weight_info.dtype),
+ indices,
+ block_size=block_size,
+ n_post=n_post,
+ )
+ ct_gmax = jnp.inner(ct, ct_gmax[0])
+ else:
+ ct_gmax = jax.vmap(lambda vec, one_ind: ct[one_ind] * vec)(vector, indices)
+ return vector, (ad.Zero(weights) if type(ct) is ad.Zero else ct_gmax), indices
+
+
+ ellmv_p = XLACustomOp(
+ 'ell_mv',
+ cpu_kernel_generator=ell_cpu_kernel_generator,
+ gpu_kernel_generator=ell_gpu_kernel_generator,
+ )
+ ellmv_p.defjvp(jvp_spikes, jvp_weights_no_spk, None)
+ ellmv_p.def_transpose_rule(transpose_rule_no_spk)
+
+
+ def ellmv_p_call(vector, weights, indices, *, n_post, block_size):
+ n_conn = indices.shape[1]
+ if block_size is None:
+ if n_conn <= 16:
+ block_size = 16
+ elif n_conn <= 32:
+ block_size = 32
+ elif n_conn <= 64:
+ block_size = 64
+ elif n_conn <= 128:
+ block_size = 128
+ elif n_conn <= 256:
+ block_size = 256
+ else:
+ block_size = 128
+ return ellmv_p(
+ vector,
+ weights,
+ indices,
+ n_post=n_post,
+ n_pre=indices.shape[0],
+ n_conn=indices.shape[1],
+ block_size=block_size,
+ weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
+ outs=[jax.ShapeDtypeStruct([n_post], weights.dtype)]
+ )
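Taken together, the two new primitives replace the old custom_vjp/custom_jvp pair: event_ellmv_p runs the event-driven forward pass, ellmv_p supplies the dense (non-event) product used for the spike JVP, and the transpose rules above provide the VJPs, so gradients flow on either backend. A hedged sketch of differentiating through the new path; the import from the private module is an assumption made for illustration, and all shapes are arbitrary:

import jax
import jax.numpy as jnp
import numpy as np
# Illustrative import only; the diff does not show how event_fixed_prob is exported.
from brainstate.event._fixed_probability import event_fixed_prob

n_pre, n_post, n_conn = 100, 80, 8
indices = np.random.randint(0, n_post, size=(n_pre, n_conn))
spikes = (np.random.rand(n_pre) < 0.1).astype(jnp.float32)
weights = jnp.full((n_pre, n_conn), 0.5, dtype=jnp.float32)

def loss(w):
    out = event_fixed_prob(spikes, w, indices,
                           n_post=n_post, block_size=None, float_as_event=True)
    return jnp.sum(out ** 2)

# Routed through jvp_weights / transpose_rule registered on event_ellmv_p above.
grads = jax.grad(loss)(weights)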