brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/event/_fixed_probability.py
@@ -0,0 +1,708 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ from typing import Union, Callable, Optional
+
+ import brainunit as u
+ import jax
+ import jax.experimental.pallas as pl
+ import jax.numpy as jnp
+ import numpy as np
+ from jax.interpreters import ad
+
+ from brainstate._state import ParamState
+ from brainstate.augment import vmap
+ from brainstate.init import param
+ from brainstate.nn._module import Module
+ from brainstate.random import RandomState
+ from brainstate.typing import ArrayLike, Size
+ from ._misc import FloatScalar
+ from ._xla_custom_op import XLACustomOp
+
+ __all__ = [
+     'FixedProb',
+ ]
+
+
+ class FixedProb(Module):
+     """
+     The FixedProb module implements a fixed-probability connection with a CSR sparse data structure.
+
+     Parameters
+     ----------
+     in_size : Size
+         Number of pre-synaptic neurons, i.e., the input size.
+     out_size : Size
+         Number of post-synaptic neurons, i.e., the output size.
+     prob : float
+         Connection probability.
+     weight : float or callable or jax.Array or brainunit.Quantity
+         Maximum synaptic conductance, i.e., the synaptic weight.
+     allow_multi_conn : bool, optional
+         Whether a single pre-synaptic neuron may form multiple connections to the
+         same post-synaptic neuron. Default is True, meaning the same post-synaptic
+         index can be sampled more than once.
+     seed : int, optional
+         Random seed. Default is None, in which case the default random seed is used.
+     float_as_event : bool, optional
+         Whether to treat any non-zero float input as a spike event. Default is True.
+     block_size : int, optional
+         Block size for parallel computation on GPU. Default is None, in which case
+         a block size is chosen automatically from the number of connections.
+     name : str, optional
+         Name of the module.
+     """
+
+     __module__ = 'brainstate.event'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         prob: FloatScalar,
+         weight: Union[Callable, ArrayLike],
+         allow_multi_conn: bool = True,
+         seed: Optional[int] = None,
+         float_as_event: bool = True,
+         block_size: Optional[int] = None,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # network parameters
+         self.in_size = in_size
+         self.out_size = out_size
+         self.n_conn = int(self.out_size[-1] * prob)
+         if self.n_conn < 1:
+             raise ValueError(f"The number of connections must be at least 1. "
+                              f"Got: int({self.out_size[-1]} * {prob}) = {self.n_conn}")
+         self.float_as_event = float_as_event
+         self.block_size = block_size
+
+         # indices of connected post-synaptic neurons
+         with jax.ensure_compile_time_eval():
+             if allow_multi_conn:
+                 rng = np.random.RandomState(seed)
+                 self.indices = rng.randint(0, self.out_size[-1], size=(self.in_size[-1], self.n_conn))
+             else:
+                 rng = RandomState(seed)
+
+                 @vmap(rngs=rng)
+                 def rand_indices(key):
+                     rng.set_key(key)
+                     return rng.choice(self.out_size[-1], size=(self.n_conn,), replace=False)
+
+                 self.indices = rand_indices(rng.split_key(self.in_size[-1]))
+             self.indices = u.math.asarray(self.indices)
+
+         # maximum synaptic conductance
+         weight = param(weight, (self.in_size[-1], self.n_conn), allow_none=False)
+         self.weight = ParamState(weight)
+
+     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
+         return event_fixed_prob(
+             spk,
+             self.weight.value,
+             self.indices,
+             n_post=self.out_size[-1],
+             block_size=self.block_size,
+             float_as_event=self.float_as_event
+         )
+
+
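A minimal usage sketch of the module above (not part of the diff; the sizes, the scalar weight, and the `brainstate as bst` alias are illustrative assumptions):

    import brainstate as bst
    import jax.numpy as jnp

    # 100 pre-synaptic neurons, 200 post-synaptic neurons, 10% connectivity
    proj = bst.event.FixedProb(in_size=(100,), out_size=(200,), prob=0.1, weight=0.5)

    spikes = jnp.zeros(100).at[jnp.array([3, 17, 42])].set(1.0)  # three events
    post = proj.update(spikes)  # shape (200,): accumulated weight per post-synaptic neuron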
+ def event_fixed_prob(spk, weight, indices, *, n_post, block_size, float_as_event):
+     """
+     Perform the event-driven sparse matrix-vector multiplication for the fixed-probability connection.
+
+     Parameters
+     ----------
+     weight : brainunit.Quantity or jax.Array
+         Maximum synaptic conductance.
+     spk : jax.Array
+         Spike events.
+
+     Returns
+     -------
+     post_data : brainunit.Quantity or jax.Array
+         Post-synaptic data.
+     """
+     with jax.ensure_compile_time_eval():
+         weight = u.math.asarray(weight)
+         unit = u.get_unit(weight)
+         weight = u.get_mantissa(weight)
+         indices = jnp.asarray(indices)
+         spk = jnp.asarray(spk)
+
+     def mv(spk_vector):
+         assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk_vector.ndim}"
+         return event_ellmv_p_call(
+             spk_vector,
+             weight,
+             indices,
+             n_post=n_post,
+             block_size=block_size,
+             float_as_event=float_as_event
+         )
+
+     assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
+     assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"
+     assert indices.ndim == 2, f"indices must be 2D. Got: {indices.ndim}"
+
+     if spk.ndim == 1:
+         [post_data] = mv(spk)
+     else:
+         [post_data] = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
+         post_data = u.math.reshape(post_data, spk.shape[:-1] + post_data.shape[-1:])
+     return u.maybe_decimal(u.Quantity(post_data, unit=unit))
+
+
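The mantissa/unit split above is what lets the kernels run on raw arrays while callers keep physical units. A round-trip sketch, assuming brainunit exposes millisiemens as `u.mS`:

    import brainunit as u
    import jax.numpy as jnp

    w = jnp.full((4, 8), 0.5) * u.mS  # weights carrying a unit
    mantissa = u.get_mantissa(w)      # plain array handed to the kernel
    unit = u.get_unit(w)              # remembered unit
    y = u.maybe_decimal(u.Quantity(mantissa.sum(axis=1), unit=unit))  # unit restored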
+ Kernel = Callable
+
+
+ def cpu_kernel_generator(
+     float_as_event: bool,
+     weight_info: jax.ShapeDtypeStruct,
+     spike_info: jax.ShapeDtypeStruct,
+     **kwargs
+ ):
+     import numba  # pylint: disable=import-outside-toplevel
+
+     if weight_info.size == 1:
+         if spike_info.dtype == jnp.bool_:
+             @numba.njit
+             def ell_mv(spikes, weights, indices, posts):
+                 posts[:] = 0.
+                 w = weights[()]
+                 for i in range(spikes.shape[0]):
+                     if spikes[i]:
+                         for j in range(indices.shape[1]):
+                             posts[indices[i, j]] += w
+
+         elif float_as_event:
+             @numba.njit
+             def ell_mv(spikes, weights, indices, posts):
+                 posts[:] = 0.
+                 w = weights[()]
+                 for i in range(spikes.shape[0]):
+                     if spikes[i] != 0.:
+                         for j in range(indices.shape[1]):
+                             posts[indices[i, j]] += w
+
+         else:
+             @numba.njit
+             def ell_mv(spikes, weights, indices, posts):
+                 posts[:] = 0.
+                 w = weights[()]
+                 for i in range(spikes.shape[0]):
+                     sp = spikes[i]
+                     if sp != 0.:
+                         wsp = w * sp
+                         for j in range(indices.shape[1]):
+                             posts[indices[i, j]] += wsp
+
+     else:
+         if spike_info.dtype == jnp.bool_:
+             @numba.njit
+             def ell_mv(spikes, weights, indices, posts):
+                 posts[:] = 0.
+                 for i in range(spikes.shape[0]):
+                     if spikes[i]:
+                         for j in range(indices.shape[1]):
+                             posts[indices[i, j]] += weights[i, j]
+
+         elif float_as_event:
+             @numba.njit
+             def ell_mv(spikes, weights, indices, posts):
+                 posts[:] = 0.
+                 for i in range(spikes.shape[0]):
+                     if spikes[i] != 0.:
+                         for j in range(indices.shape[1]):
+                             posts[indices[i, j]] += weights[i, j]
+
+         else:
+             @numba.njit
+             def ell_mv(spikes, weights, indices, posts):
+                 posts[:] = 0.
+                 for i in range(spikes.shape[0]):
+                     sp = spikes[i]
+                     if sp != 0.:
+                         for j in range(indices.shape[1]):
+                             posts[indices[i, j]] += weights[i, j] * sp
+
+     return ell_mv
+
+
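As a reference point, the heterogeneous `float_as_event` branch above is equivalent to this dense scatter-add in plain JAX (a sketch, not part of the package):

    import jax.numpy as jnp

    def ell_mv_reference(spikes, weights, indices, n_post):
        # rows with spikes[i] != 0 scatter weights[i, :] onto posts[indices[i, :]]
        contrib = jnp.where(spikes[:, None] != 0, weights, 0.0)
        return jnp.zeros(n_post).at[indices].add(contrib)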
+ def gpu_kernel_generator(
+     n_pre: int,
+     n_conn: int,
+     n_post: int,
+     block_size: int,
+     float_as_event: bool,
+     weight_info: jax.ShapeDtypeStruct,
+     **kwargs
+ ):
+     # For a spikes vector of shape [n_event], and indices and weights matrices of
+     # shape [n_event, n_conn], this operator computes as follows:
+     #
+     # - each block handles [block_size] events, each corresponding to one pre-synaptic neuron
+     # - each block handles a [block_size, block_size] tile of indices and weights
+
+     if weight_info.size == 1:
+         def _ell_mv_kernel_homo(
+             sp_ref,   # [block_size]
+             ind_ref,  # [block_size, block_size]
+             _,
+             y_ref,    # [n_post]
+         ):
+             r_pid = pl.program_id(0)
+             c_start = pl.program_id(1) * block_size
+             row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+             mask = jnp.arange(block_size) + c_start < n_conn
+
+             def body_fn(j, _):
+                 if sp_ref.dtype == jnp.bool_:
+                     def true_fn():
+                         ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                         pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask)
+
+                     jax.lax.cond(sp_ref[j], true_fn, lambda: None)
+
+                 else:
+                     def true_fn(sp):
+                         ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                         if float_as_event:
+                             pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask)
+                         else:
+                             pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype) * sp, mask=mask)
+
+                     sp_ = sp_ref[j]
+                     jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_)
+
+             jax.lax.fori_loop(0, row_length, body_fn, None)
+
+         # homogeneous weights
+         kernel = pl.pallas_call(
+             _ell_mv_kernel_homo,
+             out_shape=[
+                 jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+             ],
+             in_specs=[
+                 pl.BlockSpec((block_size,), lambda i, j: i),
+                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),
+                 pl.BlockSpec((n_post,), lambda i, j: 0)
+             ],
+             grid=(
+                 pl.cdiv(n_pre, block_size),
+                 pl.cdiv(n_conn, block_size),
+             ),
+             input_output_aliases={2: 0},
+             interpret=False
+         )
+         return (lambda spikes, weight, indices:
+                 [kernel(spikes, indices, jnp.zeros(n_post, dtype=weight.dtype))[0] * weight])
+
+     else:
+         def _ell_mv_kernel_heter(
+             sp_ref,   # [block_size]
+             ind_ref,  # [block_size, block_size]
+             w_ref,    # [block_size, block_size]
+             _,
+             y_ref,    # [n_post]
+         ):
+             r_pid = pl.program_id(0)
+             c_start = pl.program_id(1) * block_size
+             row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+             mask = jnp.arange(block_size) + c_start < n_conn
+
+             def body_fn(j, _):
+                 if sp_ref.dtype == jnp.bool_:
+                     def true_fn():
+                         ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                         w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
+                         pl.atomic_add(y_ref, ind, w, mask=mask)
+
+                     jax.lax.cond(sp_ref[j], true_fn, lambda: None)
+                 else:
+                     def true_fn(spk):
+                         ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                         w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
+                         if not float_as_event:
+                             w = w * spk
+                         pl.atomic_add(y_ref, ind, w, mask=mask)
+
+                     sp_ = sp_ref[j]
+                     jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_)
+
+             jax.lax.fori_loop(0, row_length, body_fn, None)
+
+         # heterogeneous weights
+         kernel = pl.pallas_call(
+             _ell_mv_kernel_heter,
+             out_shape=[
+                 jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+             ],
+             in_specs=[
+                 pl.BlockSpec((block_size,), lambda i, j: i),  # sp_ref
+                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
+                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # w_ref
+                 pl.BlockSpec((n_post,), lambda i, j: 0)
+             ],
+             grid=(
+                 pl.cdiv(n_pre, block_size),
+                 pl.cdiv(n_conn, block_size),
+             ),
+             input_output_aliases={3: 0},
+             interpret=False
+         )
+         return (lambda spikes, weight, indices:
+                 kernel(spikes, indices, weight, jnp.zeros(n_post, dtype=weight_info.dtype)))
+
+
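The grid tiles the [n_pre, n_conn] index matrix into block_size x block_size tiles, and the column mask keeps the last tile from reading past n_conn. A worked example of the grid arithmetic:

    import jax.experimental.pallas as pl

    # n_pre=1000, n_conn=100, block_size=64 -> 16 row blocks x 2 column blocks;
    # the second column block covers connections 64..99 and masks lanes >= 100.
    grid = (pl.cdiv(1000, 64), pl.cdiv(100, 64))  # (16, 2)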
+ def jvp_spikes(spk_dot, spikes, weights, indices, *, n_post, block_size, **kwargs):
+     return ellmv_p_call(
+         spk_dot,
+         weights,
+         indices,
+         n_post=n_post,
+         block_size=block_size,
+     )
+
+
+ def jvp_weights(w_dot, spikes, weights, indices, *, float_as_event, block_size, n_post, **kwargs):
+     return event_ellmv_p_call(
+         spikes,
+         w_dot,
+         indices,
+         n_post=n_post,
+         block_size=block_size,
+         float_as_event=float_as_event
+     )
+
+
+ def transpose_rule(
+     ct, spikes, weights, indices,
+     *,
+     float_as_event, n_post, n_conn, block_size, weight_info, **kwargs
+ ):
+     if ad.is_undefined_primal(indices):
+         raise ValueError("Cannot transpose with respect to sparse indices.")
+
+     ct = ct[0]
+
+     # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
+     homo = weight_info.size == 1
+     if ad.is_undefined_primal(spikes):
+         if homo:
+             # homogeneous weight
+             ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weights))(indices)
+         else:
+             # heterogeneous weight
+             ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weights)
+         return (ad.Zero(spikes) if type(ct) is ad.Zero else ct_spk), weights, indices
+
+     else:
+         # ∂L/∂w = ∂L/∂y * ∂y/∂w
+         if homo:
+             # scalar
+             ct_gmax = event_ellmv_p_call(
+                 spikes,
+                 jnp.asarray(1., dtype=weight_info.dtype),
+                 indices,
+                 n_post=n_post,
+                 block_size=block_size,
+                 float_as_event=float_as_event
+             )
+             ct_gmax = jnp.inner(ct, ct_gmax[0])
+         else:
+             def map_fn(one_spk, one_ind):
+                 if spikes.dtype == jnp.bool_:
+                     return jax.lax.cond(
+                         one_spk,
+                         lambda: ct[one_ind],
+                         lambda: jnp.zeros([n_conn], weight_info.dtype)
+                     )
+                 else:
+                     if float_as_event:
+                         return jax.lax.cond(
+                             one_spk == 0.,
+                             lambda: jnp.zeros([n_conn], weight_info.dtype),
+                             lambda: ct[one_ind]
+                         )
+                     else:
+                         return jax.lax.cond(
+                             one_spk == 0.,
+                             lambda: jnp.zeros([n_conn], weight_info.dtype),
+                             lambda: ct[one_ind] * one_spk
+                         )
+
+             ct_gmax = jax.vmap(map_fn)(spikes, indices)
+         return spikes, (ad.Zero(weights) if type(ct) is ad.Zero else ct_gmax), indices
+
+
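The spike cotangent above can be sanity-checked against jax.grad of an equivalent dense scatter-add (a sketch with float_as_event=False semantics, where the output accumulates weights[i, j] * spikes[i]):

    import jax
    import jax.numpy as jnp

    def dense(spikes, weights, indices, n_post):
        return jnp.zeros(n_post).at[indices].add(weights * spikes[:, None])

    spikes = jnp.array([0.0, 1.0, 2.0])
    weights = jnp.ones((3, 2))
    indices = jnp.array([[0, 1], [1, 2], [0, 3]])
    ct = jnp.arange(4.0)  # incoming cotangent dL/dy
    g = jax.grad(lambda s: dense(s, weights, indices, 4) @ ct)(spikes)
    manual = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weights)
    assert jnp.allclose(g, manual)  # matches ct_spk in transpose_rule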
+ event_ellmv_p = XLACustomOp(
+     'event_ell_mv',
+     cpu_kernel_generator=cpu_kernel_generator,
+     gpu_kernel_generator=gpu_kernel_generator,
+ )
+ event_ellmv_p.defjvp(jvp_spikes, jvp_weights, None)
+ event_ellmv_p.def_transpose_rule(transpose_rule)
+
+
+ def event_ellmv_p_call(spikes, weights, indices, *, n_post, block_size, float_as_event):
+     n_conn = indices.shape[1]
+     if block_size is None:
+         if n_conn <= 16:
+             block_size = 16
+         elif n_conn <= 32:
+             block_size = 32
+         elif n_conn <= 64:
+             block_size = 64
+         elif n_conn <= 128:
+             block_size = 128
+         elif n_conn <= 256:
+             block_size = 256
+         else:
+             block_size = 128
+     return event_ellmv_p(
+         spikes,
+         weights,
+         indices,
+         outs=[jax.ShapeDtypeStruct([n_post], weights.dtype)],
+         block_size=block_size,
+         float_as_event=float_as_event,
+         n_pre=spikes.shape[0],
+         n_conn=indices.shape[1],
+         n_post=n_post,
+         weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
+         spike_info=jax.ShapeDtypeStruct(spikes.shape, spikes.dtype),
+     )
+
+
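Note that the auto-selection above rounds n_conn up to the next power of two between 16 and 256 (for example, n_conn = 20 selects 32 and n_conn = 100 selects 128), while rows wider than 256 connections fall back to block_size = 128 rather than growing the tile further.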
+ def ell_cpu_kernel_generator(
+     weight_info: jax.ShapeDtypeStruct,
+     **kwargs
+ ):
+     import numba  # pylint: disable=import-outside-toplevel
+
+     if weight_info.size == 1:
+         @numba.njit
+         def ell_mv(vector, weights, indices, posts):
+             posts[:] = 0.
+             w = weights[()]
+             for i in range(vector.shape[0]):
+                 wv = w * vector[i]
+                 for j in range(indices.shape[1]):
+                     posts[indices[i, j]] += wv
+
+     else:
+         @numba.njit
+         def ell_mv(vector, weights, indices, posts):
+             posts[:] = 0.
+             for i in range(vector.shape[0]):
+                 for j in range(indices.shape[1]):
+                     posts[indices[i, j]] += weights[i, j] * vector[i]
+
+     return ell_mv
+
+
+ def ell_gpu_kernel_generator(
+     block_size: int,
+     n_pre: int,
+     n_conn: int,
+     n_post: int,
+     weight_info: jax.ShapeDtypeStruct,
+     **kwargs
+ ):
+     homo = weight_info.size == 1
+
+     if homo:
+         def _kernel(
+             vec_ref, ind_ref, _, out_ref,
+         ):
+             # Each block handles a [block_size] slice of the vector and a
+             # [block_size, block_size] tile of the indices and weights.
+             #
+             # -------------------------------
+             # vec_ref: [block_size]
+             # ind_ref: [block_size, block_size]
+             # out_ref: [n_post]
+
+             r_pid = pl.program_id(0)
+             c_start = pl.program_id(1) * block_size
+             mask = jnp.arange(block_size) + c_start < n_conn
+             row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+
+             def body_fn(j, _):
+                 y = vec_ref[j] * jnp.ones(block_size, dtype=weight_info.dtype)
+                 ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                 pl.atomic_add(out_ref, ind, y, mask=mask)
+
+             jax.lax.fori_loop(0, row_length, body_fn, None)
+
+         # homogeneous weights
+         kernel = pl.pallas_call(
+             _kernel,
+             out_shape=[
+                 jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+             ],
+             in_specs=[
+                 pl.BlockSpec((block_size,), lambda i, j: i),  # vec_ref
+                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
+                 pl.BlockSpec((n_post,), lambda i, j: 0)  # out_ref
+             ],
+             grid=(
+                 pl.cdiv(n_pre, block_size),
+                 pl.cdiv(n_conn, block_size),
+             ),
+             input_output_aliases={2: 0},
+             interpret=False
+         )
+         return (lambda vector, weight, indices:
+                 [kernel(vector, indices, jnp.zeros(n_post, dtype=weight.dtype))[0] * weight])
+
+     else:
+         def _kernel(
+             vec_ref, ind_ref, w_ref, _, out_ref,
+         ):
+             # Each block handles a [block_size] slice of the vector and a
+             # [block_size, block_size] tile of the indices and weights.
+             #
+             # -------------------------------
+             # vec_ref: [block_size]
+             # ind_ref: [block_size, block_size]
+             # w_ref: [block_size, block_size]
+             # out_ref: [n_post]
+
+             r_pid = pl.program_id(0)
+             c_start = pl.program_id(1) * block_size
+             mask = jnp.arange(block_size) + c_start < n_conn
+             row_length = jnp.minimum(n_pre - r_pid * block_size, block_size)
+
+             def body_fn(j, _):
+                 w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask)
+                 y = w * vec_ref[j]
+                 ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask)
+                 pl.atomic_add(out_ref, ind, y, mask=mask)
+
+             jax.lax.fori_loop(0, row_length, body_fn, None)
+
+         # heterogeneous weights
+         kernel = pl.pallas_call(
+             _kernel,
+             out_shape=[
+                 jax.ShapeDtypeStruct((n_post,), weight_info.dtype),
+             ],
+             in_specs=[
+                 pl.BlockSpec((block_size,), lambda i, j: i),  # vec_ref
+                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # ind_ref
+                 pl.BlockSpec((block_size, block_size), lambda i, j: (i, j)),  # w_ref
+                 pl.BlockSpec((n_post,), lambda i, j: 0)  # out_ref
+             ],
+             grid=(
+                 pl.cdiv(n_pre, block_size),
+                 pl.cdiv(n_conn, block_size),
+             ),
+             input_output_aliases={3: 0},
+             interpret=False
+         )
+         return lambda vector, weight, indices: kernel(vector, indices, weight, jnp.zeros(n_post, dtype=weight.dtype))
+
+
+ def jvp_weights_no_spk(w_dot, vector, weights, indices, *, block_size, n_post, **kwargs):
+     return ellmv_p_call(
+         vector,
+         w_dot,
+         indices,
+         block_size=block_size,
+         n_post=n_post,
+     )
+
+
+ def transpose_rule_no_spk(
+     ct, vector, weights, indices,
+     *,
+     n_post, block_size, weight_info, **kwargs
+ ):
+     if ad.is_undefined_primal(indices):
+         raise ValueError("Cannot transpose with respect to sparse indices.")
+
+     ct = ct[0]
+
+     # ∂L/∂v = ∂L/∂y * ∂y/∂v
+     homo = weight_info.size == 1
+     if ad.is_undefined_primal(vector):
+         if homo:
+             # homogeneous weight
+             ct_vec = jax.vmap(lambda idx: jnp.sum(ct[idx] * weights))(indices)
+         else:
+             # heterogeneous weight
+             ct_vec = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weights)
+         return (ad.Zero(vector) if type(ct) is ad.Zero else ct_vec), weights, indices
+
+     else:
+         # ∂L/∂w = ∂L/∂y * ∂y/∂w
+         if homo:
+             # scalar
+             ct_gmax = ellmv_p_call(
+                 vector,
+                 jnp.asarray(1., dtype=weight_info.dtype),
+                 indices,
+                 block_size=block_size,
+                 n_post=n_post,
+             )
+             ct_gmax = jnp.inner(ct, ct_gmax[0])
+         else:
+             ct_gmax = jax.vmap(lambda vec, one_ind: ct[one_ind] * vec)(vector, indices)
+         return vector, (ad.Zero(weights) if type(ct) is ad.Zero else ct_gmax), indices
+
+
+ ellmv_p = XLACustomOp(
+     'ell_mv',
+     cpu_kernel_generator=ell_cpu_kernel_generator,
+     gpu_kernel_generator=ell_gpu_kernel_generator,
+ )
+ ellmv_p.defjvp(jvp_spikes, jvp_weights_no_spk, None)
+ ellmv_p.def_transpose_rule(transpose_rule_no_spk)
+
+
+ def ellmv_p_call(vector, weights, indices, *, n_post, block_size):
+     n_conn = indices.shape[1]
+     if block_size is None:
+         if n_conn <= 16:
+             block_size = 16
+         elif n_conn <= 32:
+             block_size = 32
+         elif n_conn <= 64:
+             block_size = 64
+         elif n_conn <= 128:
+             block_size = 128
+         elif n_conn <= 256:
+             block_size = 256
+         else:
+             block_size = 128
+     return ellmv_p(
+         vector,
+         weights,
+         indices,
+         n_post=n_post,
+         n_pre=indices.shape[0],
+         n_conn=indices.shape[1],
+         block_size=block_size,
+         weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
+         outs=[jax.ShapeDtypeStruct([n_post], weights.dtype)]
+     )
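Since both primitives register JVP and transpose rules, the operator is differentiable end to end. A closing sketch (not part of the diff; sizes and the `bst` alias are illustrative):

    import brainstate as bst
    import jax
    import jax.numpy as jnp

    proj = bst.event.FixedProb(in_size=(50,), out_size=(80,), prob=0.2, weight=1.0)

    def loss(spk):
        return proj.update(spk).sum()

    spk = (jax.random.uniform(jax.random.PRNGKey(0), (50,)) < 0.1).astype(jnp.float32)
    grads = jax.grad(loss)(spk)  # gradients flow through the custom JVP/transpose rules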