brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250126__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. brainstate/__init__.py +1 -2
  2. brainstate/_state.py +77 -44
  3. brainstate/_state_test.py +0 -17
  4. brainstate/augment/__init__.py +10 -20
  5. brainstate/augment/_eval_shape.py +9 -10
  6. brainstate/augment/_eval_shape_test.py +1 -1
  7. brainstate/augment/_mapping.py +265 -277
  8. brainstate/augment/_mapping_test.py +147 -175
  9. brainstate/compile/__init__.py +18 -37
  10. brainstate/compile/_ad_checkpoint.py +6 -4
  11. brainstate/compile/_jit.py +37 -28
  12. brainstate/compile/_loop_collect_return.py +6 -3
  13. brainstate/compile/_loop_no_collection.py +2 -0
  14. brainstate/compile/_make_jaxpr.py +15 -4
  15. brainstate/compile/_make_jaxpr_test.py +10 -6
  16. brainstate/compile/_progress_bar.py +68 -40
  17. brainstate/compile/_unvmap.py +9 -6
  18. brainstate/graph/__init__.py +12 -16
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  24. brainstate/nn/_interaction/_conv.py +4 -2
  25. brainstate/nn/_interaction/_linear.py +84 -10
  26. brainstate/random/_rand_funs.py +9 -2
  27. brainstate/random/_rand_seed.py +12 -2
  28. brainstate/random/_rand_state.py +50 -179
  29. brainstate/surrogate.py +5 -1
  30. brainstate/util/__init__.py +0 -4
  31. brainstate/util/_caller.py +1 -1
  32. brainstate/util/_dict.py +4 -1
  33. brainstate/util/_filter.py +1 -1
  34. brainstate/util/_pretty_repr.py +1 -1
  35. brainstate/util/_struct.py +1 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/METADATA +2 -1
  37. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/RECORD +40 -60
  38. brainstate/event/__init__.py +0 -29
  39. brainstate/event/_csr.py +0 -906
  40. brainstate/event/_csr_mv.py +0 -303
  41. brainstate/event/_csr_mv_benchmark.py +0 -14
  42. brainstate/event/_csr_mv_test.py +0 -118
  43. brainstate/event/_csr_test.py +0 -90
  44. brainstate/event/_fixedprob_mv.py +0 -730
  45. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  46. brainstate/event/_fixedprob_mv_test.py +0 -132
  47. brainstate/event/_linear_mv.py +0 -359
  48. brainstate/event/_linear_mv_benckmark.py +0 -82
  49. brainstate/event/_linear_mv_test.py +0 -117
  50. brainstate/event/_misc.py +0 -34
  51. brainstate/event/_xla_custom_op.py +0 -313
  52. brainstate/event/_xla_custom_op_test.py +0 -55
  53. brainstate/graph/_graph_context.py +0 -443
  54. brainstate/graph/_graph_context_test.py +0 -65
  55. brainstate/graph/_graph_convert.py +0 -246
  56. brainstate/util/_tracers.py +0 -68
  57. brainstate/util/_visualization.py +0 -47
  58. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/LICENSE +0 -0
  59. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/WHEEL +0 -0
  60. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/top_level.txt +0 -0
@@ -1,730 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from __future__ import annotations
-
- from typing import Union, Callable, Optional
-
- import brainunit as u
- import jax
- import jax.experimental.pallas as pl
- import jax.numpy as jnp
- import numpy as np
- from jax.interpreters import ad
-
- from brainstate import environ
- from brainstate._state import ParamState
- from brainstate.augment import vmap
- from brainstate.init import param
- from brainstate.nn._module import Module
- from brainstate.random import RandomState
- from brainstate.typing import ArrayLike, Size
- from ._misc import FloatScalar
- from ._xla_custom_op import XLACustomOp
-
- __all__ = [
-     'FixedProb',
- ]
-
-
- class FixedProb(Module):
-     """
-     The FixedProb module implements a fixed probability connection with CSR sparse data structure.
-
-     Parameters
-     ----------
-     in_size : Size
-         Number of pre-synaptic neurons, i.e., input size.
-     out_size : Size
-         Number of post-synaptic neurons, i.e., output size.
-     prob : float
-         Probability of connection, i.e., connection probability.
-     weight : float or callable or jax.Array or brainunit.Quantity
-         Maximum synaptic conductance, i.e., synaptic weight.
-     allow_multi_conn : bool, optional
-         Whether multiple connections are allowed from a single pre-synaptic neuron.
-         Default is True, meaning that a value of ``a`` can be selected multiple times.
-     seed: int, optional
-         Random seed. Default is None. If None, the default random seed will be used.
-     float_as_event : bool, optional
-         Whether to treat float as event. Default is True.
-     block_size : int, optional
-         Block size for parallel computation. Default is 64. This is only used for GPU.
-     name : str, optional
-         Name of the module.
-     """
-
-     __module__ = 'brainstate.event'
-
-     def __init__(
-         self,
-         in_size: Size,
-         out_size: Size,
-         prob: FloatScalar,
-         weight: Union[Callable, ArrayLike],
-         allow_multi_conn: bool = True,
-         seed: Optional[int] = None,
-         float_as_event: bool = True,
-         block_size: Optional[int] = None,
-         name: Optional[str] = None,
-     ):
-         super().__init__(name=name)
-
-         # network parameters
-         self.in_size = in_size
-         self.out_size = out_size
-         self.n_conn = int(self.out_size[-1] * prob)
-         self.float_as_event = float_as_event
-         self.block_size = block_size
-
-         if self.n_conn > 1:
-             # indices of post connected neurons
-             with jax.ensure_compile_time_eval():
-                 if allow_multi_conn:
-                     rng = np.random.RandomState(seed)
-                     self.indices = rng.randint(0, self.out_size[-1], size=(self.in_size[-1], self.n_conn))
-                 else:
-                     rng = RandomState(seed)
-
-                     @vmap(rngs=rng)
-                     def rand_indices(key):
-                         rng.set_key(key)
-                         return rng.choice(self.out_size[-1], size=(self.n_conn,), replace=False)
-
-                     self.indices = rand_indices(rng.split_key(self.in_size[-1]))
-                 self.indices = u.math.asarray(self.indices)
-
-         # maximum synaptic conductance
-         weight = param(weight, (self.in_size[-1], self.n_conn), allow_none=False)
-         self.weight = ParamState(weight)
-
-     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
-         if self.n_conn > 1:
-             r = event_fixed_prob(
-                 spk,
-                 self.weight.value,
-                 self.indices,
-                 n_post=self.out_size[-1],
-                 block_size=self.block_size,
-                 float_as_event=self.float_as_event
-             )
-         else:
-             weight = self.weight.value
-             unit = u.get_unit(weight)
-             r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
-             r = u.maybe_decimal(u.Quantity(r, unit=unit))
-         return u.math.asarray(r, dtype=environ.dftype())
-
-
- def event_fixed_prob(
-     spk, weight, indices,
-     *,
-     n_post, block_size, float_as_event
- ):
-     """
-     The FixedProb module implements a fixed probability connection with CSR sparse data structure.
-
-     Parameters
-     ----------
-     weight : brainunit.Quantity or jax.Array
-         Maximum synaptic conductance.
-     spk : jax.Array
-         Spike events.
-
-     Returns
-     -------
-     post_data : brainunit.Quantity or jax.Array
-         Post synaptic data.
-     """
-     with jax.ensure_compile_time_eval():
-         weight = u.math.asarray(weight)
-         unit = u.get_unit(weight)
-         weight = u.get_mantissa(weight)
-         indices = jnp.asarray(indices)
-         spk = jnp.asarray(spk)
-
-     def mv(spk_vector):
-         assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
-         return event_ellmv_p_call(
-             spk,
-             weight,
-             indices,
-             n_post=n_post,
-             block_size=block_size,
-             float_as_event=float_as_event
-         )
-
-     assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
-     assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"
-     assert indices.ndim == 2, f"indices must be 2D. Got: {indices.ndim}"
-
-     if spk.ndim == 1:
-         [post_data] = mv(spk)
-     else:
-         [post_data] = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
-         post_data = u.math.reshape(post_data, spk.shape[:-1] + post_data.shape[-1:])
-     return u.maybe_decimal(u.Quantity(post_data, unit=unit))
-
-
- Kernel = Callable
-
-
- def cpu_kernel_generator(
-     float_as_event: bool,
-     weight_info: jax.ShapeDtypeStruct,
-     spike_info: jax.ShapeDtypeStruct,
-     **kwargs
- ):
-     import numba  # pylint: disable=import-outside-toplevel
-
-     if weight_info.size == 1:
-         if spike_info.dtype == jnp.bool_:
-             @numba.njit
-             def ell_mv(spikes, weights, indices, posts):
-                 posts[:] = 0.
-                 w = weights[()]
-                 for i in range(spikes.shape[0]):
-                     if spikes[i]:
-                         for j in range(indices.shape[1]):
-                             posts[indices[i, j]] += w
-
-         elif float_as_event:
-             @numba.njit
-             def ell_mv(spikes, weights, indices, posts):
-                 posts[:] = 0.
-                 w = weights[()]
-                 for i in range(spikes.shape[0]):
-                     if spikes[i] != 0.:
-                         for j in range(indices.shape[1]):
-                             posts[indices[i, j]] += w
-
-         else:
-             @numba.njit
-             def ell_mv(spikes, weights, indices, posts):
-                 posts[:] = 0.
-                 w = weights[()]
-                 for i in range(spikes.shape[0]):
-                     sp = spikes[i]
-                     if sp != 0.:
-                         wsp = w * sp
-                         for j in range(indices.shape[1]):
-                             posts[indices[i, j]] += wsp
-
-     else:
-         if spike_info.dtype == jnp.bool_:
-             @numba.njit
-             def ell_mv(spikes, weights, indices, posts):
-                 posts[:] = 0.
-                 for i in range(spikes.shape[0]):
-                     if spikes[i]:
-                         for j in range(indices.shape[1]):
-                             posts[indices[i, j]] += weights[i, j]
-
-         elif float_as_event:
-             @numba.njit
-             def ell_mv(spikes, weights, indices, posts):
-                 posts[:] = 0.
-                 for i in range(spikes.shape[0]):
-                     if spikes[i] != 0.:
-                         for j in range(indices.shape[1]):
-                             posts[indices[i, j]] += weights[i, j]
-
-         else:
-             @numba.njit
-             def ell_mv(spikes, weights, indices, posts):
-                 posts[:] = 0.
-                 for i in range(spikes.shape[0]):
-                     sp = spikes[i]
-                     if sp != 0.:
-                         for j in range(indices.shape[1]):
-                             posts[indices[i, j]] += weights[i, j] * sp
-
-     return ell_mv
-
-
- def gpu_kernel_generator(
-     n_pre: int,
-     n_conn: int,
-     n_post: int,
-     block_size: int,
-     float_as_event: bool,
-     weight_info: jax.ShapeDtypeStruct,
-     **kwargs
- ):
-     # For a spikes vector of shape [n_event], and indices and weights matrices of shape [n_event, n_conn],
-     # the computation logic of this operator is:
-     #
-     # - each block processes [block_size] events, each event corresponding to one pre-synaptic neuron
-     # - each block processes a [block_size, block_size] tile of indices and weights
-
-     if weight_info.size == 1:
-         def _ell_mv_kernel_homo(
-             sp_ref,  # [block_size]
-             ind_ref,  # [block_size, block_size]
-             _,
-             y_ref,  # [n_post]
-         ):
-             r_pid = pl.program_id(0)
-             c_start = pl.program_id(1) * block_size
-             row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
-             mask = jnp.arange(block_size) + c_start < n_conn
-
-             def body_fn(j, _):
-                 if sp_ref.dtype == jnp.bool_:
-                     def true_fn():
-                         ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                         pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask)
-                         # y_ref[ind] += 1.0
-                         # ind = ind_ref[j, ...]
-                         # pl.store(y_ref, ind, 1.0, mask=mask)
-
-                     jax.lax.cond(sp_ref[j], true_fn, lambda: None)
-
-
-                 else:
-                     def true_fn(sp):
-                         ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                         if float_as_event:
-                             pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask)
-                         else:
-                             pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype) * sp, mask=mask)
-
-                     sp_ = sp_ref[j]
-                     jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_)
-
-             jax.lax.fori_loop(0, row_length, body_fn, None)
-
-         # homogenous weights
-         kernel = pl.pallas_call(
-             _ell_mv_kernel_homo,
-             out_shape=[
-                 jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
-             ],
-             in_specs=[
-                 pl.BlockSpec((block_size,), lambda i, j: i),
-                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),
-                 pl.BlockSpec((n_post,), lambda i, j: 0)
-             ],
-             grid=(
-                 pl.cdiv(n_pre, block_size),
-                 pl.cdiv(n_conn, block_size),
-             ),
-             input_output_aliases={2: 0},
-             interpret=False
-         )
-         return (lambda spikes, weight, indices:
-                 [kernel(spikes, indices, jnp.zeros(n_post, dtype=weight.dtype))[0] * weight])
-
-     else:
-         def _ell_mv_kernel_heter(
-             sp_ref,  # [block_size]
-             ind_ref,  # [block_size, block_size]
-             w_ref,  # [block_size, block_size]
-             _,
-             y_ref,  # [n_post]
-         ):
-             r_pid = pl.program_id(0)
-             c_start = pl.program_id(1) * block_size
-             row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
-             mask = jnp.arange(block_size) + c_start < n_conn
-
-             def body_fn(j, _):
-                 if sp_ref.dtype == jnp.bool_:
-                     def true_fn():
-                         ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                         w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
-                         pl.atomic_add(y_ref, ind, w, mask=mask)
-
-                     jax.lax.cond(sp_ref[j], true_fn, lambda: None)
-                 else:
-                     def true_fn(spk):
-                         ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                         w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
-                         if not float_as_event:
-                             w = w * spk
-                         pl.atomic_add(y_ref, ind, w, mask=mask)
-
-                     sp_ = sp_ref[j]
-                     jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_)
-
-             jax.lax.fori_loop(0, row_length, body_fn, None)
-
-         # heterogeneous weights
-         kernel = pl.pallas_call(
-             _ell_mv_kernel_heter,
-             out_shape=[
-                 jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
-             ],
-             in_specs=[
-                 pl.BlockSpec((block_size,), lambda i, j: i),  # sp_ref
-                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
-                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # w_ref,
-                 pl.BlockSpec((n_post,), lambda i, j: 0)
-             ],
-             grid=(
-                 pl.cdiv(n_pre, block_size),
-                 pl.cdiv(n_conn, block_size),
-             ),
-             input_output_aliases={3: 0},
-             interpret=False
-         )
-         return (lambda spikes, weight, indices:
-                 kernel(spikes, indices, weight, jnp.zeros(n_post, dtype=weight_info.dtype)))
-
-
- def jvp_spikes(
-     spk_dot, spikes, weights, indices,
-     *,
-     n_post, block_size, **kwargs
- ):
-     return ellmv_p_call(
-         spk_dot,
-         weights,
-         indices,
-         n_post=n_post,
-         block_size=block_size,
-     )
-
-
- def jvp_weights(
-     w_dot, spikes, weights, indices,
-     *,
-     float_as_event, block_size, n_post, **kwargs
- ):
-     return event_ellmv_p_call(
-         spikes,
-         w_dot,
-         indices,
-         n_post=n_post,
-         block_size=block_size,
-         float_as_event=float_as_event
-     )
-
-
- def transpose_rule(
-     ct, spikes, weights, indices,
-     *,
-     float_as_event, n_post, n_conn, block_size, weight_info, **kwargs
- ):
-     if ad.is_undefined_primal(indices):
-         raise ValueError("Cannot transpose with respect to sparse indices.")
-
-     ct = ct[0]
-
-     # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
-     homo = weight_info.size == 1
-     if ad.is_undefined_primal(spikes):
-         if homo:
-             # homogeneous weight
-             ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weights))(indices)
-         else:
-             # heterogeneous weight
-             ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weights)
-         return (ad.Zero(spikes) if type(ct) is ad.Zero else ct_spk), weights, indices
-
-     else:
-         # ∂L/∂w = ∂L/∂y * ∂y/∂w
-         if homo:
-             # scalar
-             ct_gmax = event_ellmv_p_call(
-                 spikes,
-                 jnp.asarray(1., dtype=weight_info.dtype),
-                 indices,
-                 n_post=n_post,
-                 block_size=block_size,
-                 float_as_event=float_as_event
-             )
-             ct_gmax = jnp.inner(ct, ct_gmax[0])
-         else:
-             def map_fn(one_spk, one_ind):
-                 if spikes.dtype == jnp.bool_:
-                     return jax.lax.cond(
-                         one_spk,
-                         lambda: ct[one_ind],
-                         lambda: jnp.zeros([n_conn], weight_info.dtype)
-                     )
-                 else:
-                     if float_as_event:
-                         return jax.lax.cond(
-                             one_spk == 0.,
-                             lambda: jnp.zeros([n_conn], weight_info.dtype),
-                             lambda: ct[one_ind]
-                         )
-                     else:
-                         return jax.lax.cond(
-                             one_spk == 0.,
-                             lambda: jnp.zeros([n_conn], weight_info.dtype),
-                             lambda: ct[one_ind] * one_spk
-                         )
-
-             ct_gmax = jax.vmap(map_fn)(spikes, indices)
-         return spikes, (ad.Zero(weights) if type(ct) is ad.Zero else ct_gmax), indices
-
-
- event_ellmv_p = XLACustomOp(
-     'event_ell_mv',
-     cpu_kernel_or_generator=cpu_kernel_generator,
-     gpu_kernel_or_generator=gpu_kernel_generator,
- )
- event_ellmv_p.defjvp(jvp_spikes, jvp_weights, None)
- event_ellmv_p.def_transpose_rule(transpose_rule)
-
-
- def event_ellmv_p_call(
-     spikes, weights, indices,
-     *,
-     n_post, block_size, float_as_event
- ):
-     n_conn = indices.shape[1]
-     if block_size is None:
-         if n_conn <= 16:
-             block_size = 16
-         elif n_conn <= 32:
-             block_size = 32
-         elif n_conn <= 64:
-             block_size = 64
-         elif n_conn <= 128:
-             block_size = 128
-         elif n_conn <= 256:
-             block_size = 256
-         else:
-             block_size = 128
-     return event_ellmv_p(
-         spikes,
-         weights,
-         indices,
-         outs=[jax.ShapeDtypeStruct([n_post], weights.dtype)],
-         block_size=block_size,
-         float_as_event=float_as_event,
-         n_pre=spikes.shape[0],
-         n_conn=indices.shape[1],
-         n_post=n_post,
-         weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
-         spike_info=jax.ShapeDtypeStruct(spikes.shape, spikes.dtype),
-     )
-
-
- def ell_cpu_kernel_generator(
-     weight_info: jax.ShapeDtypeStruct,
-     **kwargs
- ):
-     import numba  # pylint: disable=import-outside-toplevel
-
-     if jnp.size(weight_info) == 1:
-         @numba.njit
-         def ell_mv(vector, weights, indices, posts):
-             posts[:] = 0.
-             w = weights[()]
-             for i in range(vector.shape[0]):
-                 wv = w * vector[i]
-                 for j in range(indices.shape[1]):
-                     posts[indices[i, j]] += wv
-
-     else:
-         @numba.njit
-         def ell_mv(vector, weights, indices, posts):
-             posts[:] = 0.
-             for i in range(vector.shape[0]):
-                 for j in range(indices.shape[1]):
-                     posts[indices[i, j]] += weights[i, j] * vector[i]
-
-     return ell_mv
-
-
- def ell_gpu_kernel_generator(
-     block_size: int,
-     n_pre: int,
-     n_conn: int,
-     n_post: int,
-     weight_info: jax.ShapeDtypeStruct,
-     **kwargs
- ):
-     homo = jnp.size(weight_info) == 1
-
-     if homo:
-         def _kernel(
-             vec_ref, ind_ref, _, out_ref,
-         ):
-             # each block processes a [block_size] slice of the vector
-             # each block processes a [block_size, block_size] tile of indices and weights
-
-             # -------------------------------
-             # vec_ref: [block_size]
-             # ind_ref: [block_size, block_size]
-             # out_ref: [n_post]
-
-             r_pid = pl.program_id(0)
-             c_start = pl.program_id(1) * block_size
-             mask = jnp.arange(block_size) + c_start
-             row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
-
-             def body_fn(j, _):
-                 y = vec_ref[j] * jnp.ones(block_size, dtype=weight_info.dtype)
-                 ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                 pl.atomic_add(out_ref, ind, y, mask=mask)
-
-             jax.lax.fori_loop(0, row_length, body_fn, None)
-
-         # heterogeneous weights
-         kernel = pl.pallas_call(
-             _kernel,
-             out_shape=[
-                 jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
-             ],
-             in_specs=[
-                 pl.BlockSpec((block_size,), lambda i, j: i),  # vec_ref
-                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
-                 pl.BlockSpec((n_post,), lambda i, j: 0)  # out_ref
-             ],
-             grid=(
-                 pl.cdiv(n_pre, block_size),
-                 pl.cdiv(n_conn, block_size),
-             ),
-             input_output_aliases={2: 0},
-             interpret=False
-         )
-         return lambda vector, weight, indices: kernel(vector, indices, jnp.zeros(n_post, dtype=weight.dtype)) * weight
-
-     else:
-         def _kernel(
-             vec_ref, ind_ref, w_ref, _, out_ref,
-         ):
-             # each block processes a [block_size] slice of the vector
-             # each block processes a [block_size, n_conn] tile of indices and weights
-
-             # -------------------------------
-             # vec_ref: [block_size]
-             # ind_ref: [block_size, block_size]
-             # w_ref: [block_size, block_size]
-             # out_ref: [n_post]
-
-             r_pid = pl.program_id(0)
-             c_start = pl.program_id(1) * block_size
-             mask = jnp.arange(block_size) + c_start
-             row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
-
-             def body_fn(j, _):
-                 w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
-                 y = w * vec_ref[j]
-                 ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
-                 pl.atomic_add(out_ref, ind, y, mask=mask)
-
-             jax.lax.fori_loop(0, row_length, body_fn, None)
-
-         # heterogeneous weights
-         kernel = pl.pallas_call(
-             _kernel,
-             out_shape=[
-                 jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
-             ],
-             in_specs=[
-                 pl.BlockSpec((block_size,), lambda i, j: i),  # vec_ref
-                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
-                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # w_ref
-                 pl.BlockSpec((n_post,), lambda i: 0)  # out_ref
-             ],
-             grid=(
-                 pl.cdiv(n_pre, block_size),
-                 pl.cdiv(n_conn, block_size),
-             ),
-             input_output_aliases={3: 0},
-             interpret=False
-         )
-         return lambda vector, weight, indices: kernel(vector, indices, weight, jnp.zeros(n_post, dtype=weight.dtype))
-
-
- def jvp_weights_no_spk(w_dot, vector, weights, indices, *, block_size, n_post, **kwargs):
-     return ellmv_p_call(
-         vector,
-         w_dot,
-         indices,
-         block_size=block_size,
-         n_post=n_post,
-     )
-
-
- def transpose_rule_no_spk(
-     ct, vector, weights, indices,
-     *,
-     n_post, block_size, weight_info, **kwargs
- ):
-     if ad.is_undefined_primal(indices):
-         raise ValueError("Cannot transpose with respect to sparse indices.")
-
-     ct = ct[0]
-
-     # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
-     homo = weight_info.size == 1
-     if ad.is_undefined_primal(vector):
-         if homo:
-             # homogeneous weight
-             ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weights))(indices)
-         else:
-             # heterogeneous weight
-             ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weights)
-         return (ad.Zero(vector) if type(ct) is ad.Zero else ct_spk), weights, indices
-
-     else:
-         # ∂L/∂w = ∂L/∂y * ∂y/∂w
-         if homo:
-             # scalar
-             ct_gmax = ellmv_p_call(
-                 vector,
-                 jnp.asarray(1., dtype=weight_info.dtype),
-                 indices,
-                 block_size=block_size,
-                 n_post=n_post,
-             )
-             ct_gmax = jnp.inner(ct, ct_gmax[0])
-         else:
-             ct_gmax = jax.vmap(lambda vec, one_ind: ct[one_ind] * vec)(vector, indices)
-         return vector, (ad.Zero(weights) if type(ct) is ad.Zero else ct_gmax), indices
-
-
- ellmv_p = XLACustomOp(
-     'ell_mv',
-     cpu_kernel_or_generator=ell_cpu_kernel_generator,
-     gpu_kernel_or_generator=ell_gpu_kernel_generator,
- )
- ellmv_p.defjvp(jvp_spikes, jvp_weights_no_spk, None)
- ellmv_p.def_transpose_rule(transpose_rule_no_spk)
-
-
- def ellmv_p_call(vector, weights, indices, *, n_post, block_size):
-     n_conn = indices.shape[1]
-     if block_size is None:
-         if n_conn <= 16:
-             block_size = 16
-         elif n_conn <= 32:
-             block_size = 32
-         elif n_conn <= 64:
-             block_size = 64
-         elif n_conn <= 128:
-             block_size = 128
-         elif n_conn <= 256:
-             block_size = 256
-         else:
-             block_size = 128
-     return ellmv_p(
-         vector,
-         weights,
-         indices,
-         n_post=n_post,
-         n_pre=indices.shape[0],
-         n_conn=indices.shape[1],
-         block_size=block_size,
-         weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
-         outs=[jax.ShapeDtypeStruct([n_post], weights.dtype)]
-     )
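
Note: the hunk above deletes brainstate/event/_fixedprob_mv.py, whose public entry point was the FixedProb module listed in __all__. The following is a minimal usage sketch of that now-removed API, reconstructed from the docstring and update() signature shown above; the import alias, parameter values, and shapes are illustrative assumptions, not taken from the package itself.

    # Hypothetical sketch of the removed brainstate.event.FixedProb API; only the
    # parameter names and the update() method come from the deleted file above.
    import brainstate as bst
    import jax.numpy as jnp

    proj = bst.event.FixedProb(
        in_size=100,           # number of pre-synaptic neurons
        out_size=200,          # number of post-synaptic neurons
        prob=0.1,              # connection probability
        weight=0.5,            # maximum synaptic conductance (homogeneous here)
        float_as_event=True,
        seed=0,
    )
    spikes = jnp.zeros(100, dtype=bool).at[::10].set(True)  # boolean spike events
    post = proj.update(spikes)                               # post-synaptic drive, shape (200,)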