brainstate 0.0.2.post20240910__py2.py3-none-any.whl → 0.0.2.post20241009__py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (51)
  1. brainstate/__init__.py +4 -2
  2. brainstate/_module.py +102 -67
  3. brainstate/_state.py +2 -2
  4. brainstate/_visualization.py +47 -0
  5. brainstate/environ.py +116 -9
  6. brainstate/environ_test.py +56 -0
  7. brainstate/functional/_activations.py +134 -56
  8. brainstate/functional/_activations_test.py +331 -0
  9. brainstate/functional/_normalization.py +21 -10
  10. brainstate/init/_generic.py +4 -2
  11. brainstate/init/_regular_inits.py +7 -7
  12. brainstate/mixin.py +1 -1
  13. brainstate/nn/__init__.py +7 -2
  14. brainstate/nn/_base.py +2 -2
  15. brainstate/nn/_connections.py +4 -4
  16. brainstate/nn/_dynamics.py +5 -5
  17. brainstate/nn/_elementwise.py +9 -9
  18. brainstate/nn/_embedding.py +3 -3
  19. brainstate/nn/_normalizations.py +3 -3
  20. brainstate/nn/_others.py +2 -2
  21. brainstate/nn/_poolings.py +6 -6
  22. brainstate/nn/_rate_rnns.py +1 -1
  23. brainstate/nn/_readout.py +1 -1
  24. brainstate/nn/_synouts.py +1 -1
  25. brainstate/nn/event/__init__.py +25 -0
  26. brainstate/nn/event/_misc.py +34 -0
  27. brainstate/nn/event/csr.py +312 -0
  28. brainstate/nn/event/csr_test.py +118 -0
  29. brainstate/nn/event/fixed_probability.py +276 -0
  30. brainstate/nn/event/fixed_probability_test.py +127 -0
  31. brainstate/nn/event/linear.py +220 -0
  32. brainstate/nn/event/linear_test.py +111 -0
  33. brainstate/nn/metrics.py +390 -0
  34. brainstate/optim/__init__.py +5 -1
  35. brainstate/optim/_optax_optimizer.py +208 -0
  36. brainstate/optim/_optax_optimizer_test.py +14 -0
  37. brainstate/random/__init__.py +24 -0
  38. brainstate/{random.py → random/_rand_funs.py} +7 -1596
  39. brainstate/random/_rand_seed.py +169 -0
  40. brainstate/random/_rand_state.py +1491 -0
  41. brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
  42. brainstate/{random_test.py → random/random_test.py} +208 -191
  43. brainstate/transform/_jit.py +1 -1
  44. brainstate/transform/_jit_test.py +19 -0
  45. brainstate/transform/_make_jaxpr.py +1 -1
  46. {brainstate-0.0.2.post20240910.dist-info → brainstate-0.0.2.post20241009.dist-info}/METADATA +1 -1
  47. brainstate-0.0.2.post20241009.dist-info/RECORD +87 -0
  48. brainstate-0.0.2.post20240910.dist-info/RECORD +0 -70
  49. {brainstate-0.0.2.post20240910.dist-info → brainstate-0.0.2.post20241009.dist-info}/LICENSE +0 -0
  50. {brainstate-0.0.2.post20240910.dist-info → brainstate-0.0.2.post20241009.dist-info}/WHEEL +0 -0
  51. {brainstate-0.0.2.post20240910.dist-info → brainstate-0.0.2.post20241009.dist-info}/top_level.txt +0 -0
brainstate/nn/_poolings.py CHANGED
@@ -21,7 +21,7 @@ import functools
  from typing import Sequence, Optional
  from typing import Union, Tuple, Callable, List

- import brainunit as bu
+ import brainunit as u
  import jax
  import jax.numpy as jnp
  import numpy as np
@@ -29,7 +29,7 @@ import numpy as np
  from ._base import DnnLayer, ExplicitInOutSize
  from .. import environ
  from ..mixin import Mode
- from ..typing import Size
+ from brainstate.typing import Size

  __all__ = [
      'Flatten', 'Unflatten',
@@ -85,7 +85,7 @@ class Flatten(DnnLayer, ExplicitInOutSize):

          if in_size is not None:
              self.in_size = tuple(in_size)
-             y = jax.eval_shape(functools.partial(bu.math.flatten, start_axis=start_axis, end_axis=end_axis),
+             y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
                                 jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
              self.out_size = y.shape

@@ -101,7 +101,7 @@ class Flatten(DnnLayer, ExplicitInOutSize):
              start_axis = self.start_axis + dim_diff
          else:
              start_axis = x.ndim + self.start_axis
-         return bu.math.flatten(x, start_axis, self.end_axis)
+         return u.math.flatten(x, start_axis, self.end_axis)

      def __repr__(self) -> str:
          return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
@@ -153,12 +153,12 @@ class Unflatten(DnnLayer, ExplicitInOutSize):

          if in_size is not None:
              self.in_size = tuple(in_size)
-             y = jax.eval_shape(functools.partial(bu.math.unflatten, axis=axis, sizes=sizes),
+             y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
                                 jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
              self.out_size = y.shape

      def update(self, x):
-         return bu.math.unflatten(x, self.axis, self.sizes)
+         return u.math.unflatten(x, self.axis, self.sizes)

      def __repr__(self):
          return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
brainstate/nn/_rate_rnns.py CHANGED
@@ -27,7 +27,7 @@ from .. import random, init, functional
  from .._module import Module
  from .._state import ShortTermState, ParamState
  from ..mixin import DelayedInit, Mode
- from ..typing import ArrayLike
+ from brainstate.typing import ArrayLike

  __all__ = [
      'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
brainstate/nn/_readout.py CHANGED
@@ -29,7 +29,7 @@ from ._misc import exp_euler_step
  from .. import environ, init, surrogate
  from .._state import ShortTermState, ParamState
  from ..mixin import Mode
- from ..typing import Size, ArrayLike, DTypeLike
+ from brainstate.typing import Size, ArrayLike, DTypeLike

  __all__ = [
      'LeakyRateReadout',
brainstate/nn/_synouts.py CHANGED
@@ -23,7 +23,7 @@ import jax.numpy as jnp

  from .._module import Module
  from ..mixin import DelayedInit, BindCondData
- from ..typing import ArrayLike
+ from brainstate.typing import ArrayLike

  __all__ = [
      'SynOut', 'COBA', 'CUBA', 'MgBlock',
brainstate/nn/event/__init__.py ADDED
@@ -0,0 +1,25 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ from .csr import *
+ from .csr import __all__ as __all_csr
+ from .fixed_probability import *
+ from .fixed_probability import __all__ as __all_fixed_probability
+ from .linear import *
+ from .linear import __all__ as __all_linear
+
+ __all__ = __all_fixed_probability + __all_linear + __all_csr
+ del __all_fixed_probability, __all_linear, __all_csr
brainstate/nn/event/_misc.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ from typing import Union
+
+ import numpy as np
+
+ __all__ = [
+     'FloatScalar',
+     'IntScalar',
+ ]
+
+ FloatScalar = Union[
+     np.number,  # NumPy scalar types
+     float,      # Python scalar types
+ ]
+
+ IntScalar = Union[
+     np.number,  # NumPy scalar types
+     int,        # Python scalar types
+ ]
brainstate/nn/event/csr.py ADDED
@@ -0,0 +1,312 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import Union, Callable, Optional
+
+ import brainunit as u
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from brainstate._state import ParamState, State
+ from brainstate.init import param
+ from brainstate.mixin import Mode, Training
+ from brainstate.nn._base import DnnLayer
+ from brainstate.typing import ArrayLike
+ from ._misc import IntScalar
+
+ __all__ = [
+     'EventCSR',
+ ]
+
+
+ class EventCSR(DnnLayer):
+     """
+     The EventCSR module implements a fixed probability connection with CSR sparse data structure.
+
+     Parameters
+     ----------
+     n_pre : int
+         Number of pre-synaptic neurons.
+     n_post : int
+         Number of post-synaptic neurons.
+     weight : float or callable or jax.Array or brainunit.Quantity
+         Maximum synaptic conductance.
+     name : str, optional
+         Name of the module.
+     mode : brainstate.mixin.Mode, optional
+         Mode of the module.
+     """
+
+     def __init__(
+         self,
+         n_pre: IntScalar,
+         n_post: IntScalar,
+         indptr: ArrayLike,
+         indices: ArrayLike,
+         weight: Union[Callable, ArrayLike],
+         name: Optional[str] = None,
+         mode: Optional[Mode] = None,
+         grad_mode: str = 'vjp'
+     ):
+         super().__init__(name=name, mode=mode)
+
+         self.in_size = n_pre
+         self.out_size = n_post
+         self.n_pre = n_pre
+         self.n_post = n_post
+
+         # gradient mode
+         assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
+         self.grad_mode = grad_mode
+
+         # CSR data structure
+         indptr = jnp.asarray(indptr)
+         indices = jnp.asarray(indices)
+         assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
+         assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
+         assert indptr.size == n_pre + 1, f"indptr must have size {n_pre + 1}. Got: {indptr.size}"
+         self.indptr = indptr
+         self.indices = indices
+
+         # maximum synaptic conductance
+         weight = param(weight, (len(indices),), allow_none=False)
+         # if callable(weight):
+         #     pass
+         # else:
+         #     if u.math.size(weight) != 1 and u.math.size(weight) != len(self.indices):
+         #         raise ValueError(f"weight must be 1D or 2D with size {len(self.indices)}. Got: {u.math.size(weight)}")
+         if self.mode.has(Training):
+             weight = ParamState(weight)
+         self.weight = weight
+
+     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
+         weight = self.weight.value if isinstance(self.weight, State) else self.weight
+         if len(self.indices) == 0:
+             r = u.math.zeros(spk.shape[:-1] + (self.n_post,),
+                              dtype=weight.dtype,
+                              unit=u.get_unit(weight) * u.get_unit(spk))
+             return u.maybe_decimal(r)
+
+         device_kind = jax.devices()[0].platform  # spk.device.device_kind
+         if device_kind == 'cpu':
+             return cpu_event_csr(
+                 u.math.asarray(spk),
+                 u.math.asarray(self.indptr),
+                 u.math.asarray(self.indices),
+                 u.math.asarray(weight),
+                 n_post=self.n_post, grad_mode=self.grad_mode
+             )
+         elif device_kind in ['gpu', 'tpu']:
+             raise NotImplementedError()
+         else:
+             raise ValueError(f"Unsupported device: {device_kind}")
+
+
+ def cpu_event_csr(
+     spk: jax.Array,
+     indptr: jax.Array,
+     indices: jax.Array,
+     weight: Union[u.Quantity, jax.Array],
+     *,
+     n_post: int,
+     grad_mode: str = 'vjp'
+ ) -> Union[u.Quantity, jax.Array]:
+     """
+     The EventCSR module implements a fixed probability connection with CSR sparse data structure.
+
+     Parameters
+     ----------
+     spk : jax.Array
+         Spike events.
+     indptr : jax.Array
+         Index pointer of post connected neurons.
+     indices : jax.Array
+         Indices of post connected neurons.
+     weight : brainunit.Quantity or jax.Array
+         Maximum synaptic conductance.
+     n_post : int
+         Number of post-synaptic neurons.
+     grad_mode : str, optional
+         Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.
+
+     Returns
+     -------
+     post_data : brainunit.Quantity or jax.Array
+         Post synaptic data.
+     """
+     unit = u.get_unit(weight)
+     weight = u.get_mantissa(weight)
+
+     def mv(spk_vector):
+         assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
+         if grad_mode == 'vjp':
+             post_data = _cpu_event_csr_mv_vjp(spk_vector, indptr, indices, weight, n_post)
+         elif grad_mode == 'jvp':
+             post_data = _cpu_event_csr_mv_jvp(spk_vector, indptr, indices, weight, n_post)
+         else:
+             raise ValueError(f"Unsupported grad_mode: {grad_mode}")
+         return post_data
+
+     assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
+     assert weight.ndim in [1, 0], f"g_max must be 1D or 0D. Got: {weight.ndim}"
+     assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
+
+     if spk.ndim == 1:
+         post_data = mv(spk)
+     else:
+         shape = spk.shape[:-1]
+         post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
+         post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
+     return u.maybe_decimal(u.Quantity(post_data, unit=unit))
+
+
+ # --------------
+ # Implementation
+ # --------------
+
+
+ def _cpu_event_csr_mv(
+     spk: jax.Array,
+     indptr: jax.Array,
+     indices: jax.Array,
+     weight: Union[u.Quantity, jax.Array],
+     n_post: int
+ ) -> jax.Array:
+     bool_x = spk.dtype == jnp.bool_
+     homo_w = jnp.size(weight) == 1
+
+     def add_fn(post_val, i_start, i_end, sp):
+         def body_fn(x):
+             post, i = x
+             i_post = indices[i]
+             w = weight if homo_w else weight[i]
+             w = w if bool_x else w * sp
+             post = post.at[i_post].add(w)
+             return post, i + 1
+
+         return jax.lax.while_loop(lambda x: x[1] < i_end, body_fn, (post_val, i_start))[0]
+
+     def scan_fn(post, i):
+         sp = spk[i]  # pre-synaptic spike event
+         if bool_x:
+             post = jax.lax.cond(sp, lambda: add_fn(post, indptr[i], indptr[i + 1], sp), lambda: post)
+         else:
+             post = jax.lax.cond(sp == 0., lambda: post, lambda: add_fn(post, indptr[i], indptr[i + 1], sp))
+         return post, None
+
+     return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
+
+
+ # --------------
+ # VJP
+ # --------------
+
+ def _cpu_event_csr_mv_fwd(
+     spk: jax.Array,
+     indptr: jax.Array,
+     indices: jax.Array,
+     weight: Union[u.Quantity, jax.Array],
+     n_post: int
+ ):
+     return _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post), (spk, weight)
+
+
+ def _cpu_event_csr_mv_bwd(indptr, indices, n_post, res, ct):
+     spk, weight = res
+     homo = jnp.size(weight) == 1
+     bool_spk = spk.dtype == jnp.bool_
+
+     # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
+     def fn_spk(i_pre):
+         def body_fn(x):
+             r, i = x
+             i_post = indices[i]
+             r = r + (ct[i_post] if homo else ct[i_post] * weight[i])
+             return r, i + 1
+
+         p = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (0., indptr[i_pre]))[0]
+         p = p * weight if homo else p
+         return p
+
+     ct_spk = jax.vmap(fn_spk)(np.arange(len(spk)))
+
+     # ∂L/∂w = ∂L/∂y * ∂y/∂w
+     if homo:  # scalar
+         ct_gmax = _cpu_event_csr_mv(spk, indptr, indices, jnp.asarray(1.), n_post=n_post)
+         ct_gmax = jnp.inner(ct, ct_gmax)
+     else:
+         def single_post(dw, i_pre):
+             def body_fn(x):
+                 dw, i = x
+                 i_post = indices[i]
+                 dw = dw.at[i].add(ct[i_post] if bool_spk else ct[i_post] * spk[i_pre])
+                 return dw, i + 1
+
+             return jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (dw, indptr[i_pre]))[0]
+
+         def fn_w(dw, i_pre):
+             sp = spk[i_pre]
+             if bool_spk:
+                 return jax.lax.cond(sp, lambda: single_post(dw, i_pre), lambda: dw), None
+             else:
+                 return jax.lax.cond(sp == 0., lambda: dw, lambda: single_post(dw, i_pre)), None
+
+         ct_gmax = jax.lax.scan(fn_w, jnp.zeros_like(weight), np.arange(len(spk)))[0]
+     return ct_spk, ct_gmax
+
+
+ _cpu_event_csr_mv_vjp = jax.custom_vjp(_cpu_event_csr_mv, nondiff_argnums=(1, 2, 4))
+ _cpu_event_csr_mv_vjp.defvjp(_cpu_event_csr_mv_fwd, _cpu_event_csr_mv_bwd)
+
+
+ # --------------
+ # JVP
+ # --------------
+
+
+ def _cpu_event_csr_mv_jvp_rule(indptr, indices, n_post, primals, tangents):
+     # forward pass
+     spk, weight = primals
+     y = _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post)
+
+     # forward gradients
+     spk_dot, weight_dot = tangents
+     homo_w = jnp.size(weight) == 1
+
+     # ∂y/∂gmax
+     dweight = _cpu_event_csr_mv(spk, indptr, indices, weight_dot, n_post=n_post)
+
+     # ∂y/∂gspk
+     def scan_fn(post, i_pre):
+         def while_fn(x):
+             p, i, sp = x
+             i_post = indices[i]
+             p = p.at[i_post].add(sp if homo_w else sp * weight[i])
+             return p, i + 1, sp
+
+         post = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1],
+                                   while_fn,
+                                   (post, indptr[i_pre], spk_dot[i_pre]))[0]
+
+         return post, None
+
+     dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
+     dspk = (dspk * weight) if homo_w else dspk
+     return y, dweight + dspk
+
+
+ _cpu_event_csr_mv_jvp = jax.custom_jvp(_cpu_event_csr_mv, nondiff_argnums=(1, 2, 4))
+ _cpu_event_csr_mv_jvp.defjvp(_cpu_event_csr_mv_jvp_rule)
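
For orientation: `EventCSR` consumes a standard CSR layout, where `indices` stores one post-synaptic target per synapse and `indptr[i]:indptr[i+1]` delimits the synapses owned by pre-synaptic neuron `i`. A minimal usage sketch against the API added above (the connectivity numbers are illustrative only, not from the package):

    import jax.numpy as jnp
    import numpy as np
    import brainstate as bst

    # Toy connectivity: 3 pre-synaptic, 4 post-synaptic neurons.
    # pre 0 -> posts {1, 3};  pre 1 -> post {0};  pre 2 -> posts {1, 2}
    indptr = np.array([0, 2, 3, 5])      # synapses of pre i live at indices[indptr[i]:indptr[i+1]]
    indices = np.array([1, 3, 0, 1, 2])  # post-synaptic target of each synapse

    layer = bst.nn.EventCSR(3, 4, indptr, indices, weight=0.5)  # homogeneous weight
    spikes = jnp.array([True, False, True])                     # boolean spike events
    out = layer(spikes)  # shape (4,): weights of active synapses, accumulated per post neuron

With these numbers `out` is `[0., 1., 0.5, 0.5]`: pre neurons 0 and 2 both target post neuron 1, so their weights accumulate there, while the inactive pre neuron 1 contributes nothing. The tests below exercise the same forward path plus both gradient modes.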
brainstate/nn/event/csr_test.py ADDED
@@ -0,0 +1,118 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ import jax.numpy
+ import jax.numpy as jnp
+ import numpy as np
+ from absl.testing import parameterized
+
+ import brainstate as bst
+ from brainstate.nn.event.csr import EventCSR
+
+
+ def _get_csr(n_pre, n_post, prob):
+     n_conn = int(n_post * prob)
+     indptr = np.arange(n_pre + 1) * n_conn
+     indices = np.random.randint(0, n_post, (n_pre * n_conn,))
+     return indptr, indices
+
+
+ def true_fn(x, w, indices, indptr, n_out):
+     homo_w = jnp.size(w) == 1
+
+     post = jnp.zeros((n_out,))
+     for i_pre in range(x.shape[0]):
+         ids = indices[indptr[i_pre]: indptr[i_pre + 1]]
+         post = post.at[ids].add(w * x[i_pre] if homo_w else w[indptr[i_pre]: indptr[i_pre + 1]] * x[i_pre])
+     return post
+
+
+ class TestFixedProbCSR(parameterized.TestCase):
+     @parameterized.product(
+         homo_w=[True, False],
+     )
+     def test1(self, homo_w):
+         x = bst.random.rand(20) < 0.1
+         indptr, indices = _get_csr(20, 40, 0.1)
+         m = EventCSR(20, 40, indptr, indices, 1.5 if homo_w else bst.init.Normal())
+         y = m(x)
+         y2 = true_fn(x, m.weight, indices, indptr, 40)
+         self.assertTrue(jnp.allclose(y, y2))
+
+     @parameterized.product(
+         bool_x=[True, False],
+         homo_w=[True, False]
+     )
+     def test_vjp(self, bool_x, homo_w):
+         n_in = 20
+         n_out = 30
+         if bool_x:
+             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+         else:
+             x = bst.random.rand(n_in)
+
+         indptr, indices = _get_csr(n_in, n_out, 0.1)
+         fn = bst.nn.EventCSR(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
+         w = fn.weight
+
+         def f(x, w):
+             fn.weight = w
+             return fn(x).sum()
+
+         r = jax.grad(f, argnums=(0, 1))(x, w)
+
+         # -------------------
+         # TRUE gradients
+
+         def f2(x, w):
+             return true_fn(x, w, indices, indptr, n_out).sum()
+
+         r2 = jax.grad(f2, argnums=(0, 1))(x, w)
+         self.assertTrue(jnp.allclose(r[0], r2[0]))
+         self.assertTrue(jnp.allclose(r[1], r2[1]))
+
+     @parameterized.product(
+         bool_x=[True, False],
+         homo_w=[True, False]
+     )
+     def test_jvp(self, bool_x, homo_w):
+         n_in = 20
+         n_out = 30
+         if bool_x:
+             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+         else:
+             x = bst.random.rand(n_in)
+
+         indptr, indices = _get_csr(n_in, n_out, 0.1)
+         fn = EventCSR(n_in, n_out, indptr, indices,
+                       1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
+         w = fn.weight
+
+         def f(x, w):
+             fn.weight = w
+             return fn(x)
+
+         o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+
+         # -------------------
+         # TRUE gradients
+
+         def f2(x, w):
+             return true_fn(x, w, indices, indptr, n_out)
+
+         o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+         self.assertTrue(jnp.allclose(r1, r2))
+         self.assertTrue(jnp.allclose(o1, o2))
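
A note on the gradient plumbing exercised by `test_vjp` and `test_jvp`: `csr.py` registers its CPU kernel through `jax.custom_vjp` / `jax.custom_jvp` with `nondiff_argnums=(1, 2, 4)`, so `indptr`, `indices`, and `n_post` are treated as non-differentiable and are passed to the backward rule ahead of the residuals and the cotangent. A minimal, self-contained sketch of that calling convention (independent of brainstate; the function and its names are illustrative only):

    import jax

    def scale(x, factor):
        return x * factor

    scale_vjp = jax.custom_vjp(scale, nondiff_argnums=(1,))

    def scale_fwd(x, factor):
        # the forward rule sees all primal arguments and returns (output, residuals)
        return scale(x, factor), None

    def scale_bwd(factor, res, ct):
        # non-differentiable arguments come first, then residuals, then the cotangent;
        # return one cotangent per differentiable argument
        return (ct * factor,)

    scale_vjp.defvjp(scale_fwd, scale_bwd)

    print(jax.grad(scale_vjp)(2.0, 3.0))  # 3.0

This mirrors `_cpu_event_csr_mv_bwd(indptr, indices, n_post, res, ct)` in `csr.py`, where `res` carries the `(spk, weight)` pair saved by the forward rule.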