brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -22,35 +22,37 @@ import brainstate as bst
22
22
 
23
23
 
24
24
  class TestEnviron(unittest.TestCase):
25
- def test_precision(self):
26
- with bst.environ.context(precision=64):
27
- a = bst.random.randn(1)
28
- self.assertEqual(a.dtype, jnp.float64)
29
-
30
- with bst.environ.context(precision=32):
31
- a = bst.random.randn(1)
32
- self.assertEqual(a.dtype, jnp.float32)
33
-
34
- with bst.environ.context(precision=16):
35
- a = bst.random.randn(1)
36
- self.assertEqual(a.dtype, jnp.bfloat16)
37
-
38
- def test_platform(self):
39
- with self.assertRaises(ValueError):
40
- with bst.environ.context(platform='cpu'):
41
- a = bst.random.randn(1)
42
- self.assertEqual(a.device(), 'cpu')
43
-
44
- def test_register_default_behavior(self):
45
- dt_ = 0.1
46
-
47
- def dt_behavior(dt):
48
- nonlocal dt_
49
- dt_ = dt
50
- print(f'dt: {dt}')
51
-
52
- bst.environ.register_default_behavior('dt', dt_behavior)
53
-
54
- with bst.environ.context(dt=0.2):
55
- self.assertEqual(dt_, 0.2)
56
- self.assertEqual(dt_, 0.1)
25
+ def test_precision(self):
26
+ with bst.environ.context(precision=64):
27
+ a = bst.random.randn(1)
28
+ self.assertEqual(a.dtype, jnp.float64)
29
+
30
+ with bst.environ.context(precision=32):
31
+ a = bst.random.randn(1)
32
+ self.assertEqual(a.dtype, jnp.float32)
33
+
34
+ with bst.environ.context(precision=16):
35
+ a = bst.random.randn(1)
36
+ self.assertEqual(a.dtype, jnp.bfloat16)
37
+
38
+ def test_platform(self):
39
+ with self.assertRaises(ValueError):
40
+ with bst.environ.context(platform='cpu'):
41
+ a = bst.random.randn(1)
42
+ self.assertEqual(a.device(), 'cpu')
43
+
44
+ def test_register_default_behavior(self):
45
+ bst.environ.set(dt=0.1)
46
+
47
+ dt_ = 0.1
48
+
49
+ def dt_behavior(dt):
50
+ nonlocal dt_
51
+ dt_ = dt
52
+ print(f'dt: {dt}')
53
+
54
+ bst.environ.register_default_behavior('dt', dt_behavior)
55
+
56
+ with bst.environ.context(dt=0.2):
57
+ self.assertEqual(dt_, 0.2)
58
+ self.assertEqual(dt_, 0.1)
@@ -14,12 +14,12 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from .csr import *
18
- from .csr import __all__ as __all_csr
19
- from .fixed_probability import *
20
- from .fixed_probability import __all__ as __all_fixed_probability
21
- from .linear import *
22
- from .linear import __all__ as __all_linear
17
+ from ._csr import *
18
+ from ._csr import __all__ as __all_csr
19
+ from ._fixed_probability import *
20
+ from ._fixed_probability import __all__ as __all_fixed_probability
21
+ from ._linear import *
22
+ from ._linear import __all__ as __all_linear
23
23
 
24
24
  __all__ = __all_fixed_probability + __all_linear + __all_csr
25
25
  del __all_fixed_probability, __all_linear, __all_csr
@@ -0,0 +1,308 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from __future__ import annotations
16
+
17
+ from typing import Union, Callable, Optional
18
+
19
+ import brainunit as u
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+
24
+ from brainstate._state import ParamState, State
25
+ from brainstate._utils import set_module_as
26
+ from brainstate.init import param
27
+ from brainstate.nn._module import Module
28
+ from brainstate.typing import ArrayLike
29
+ from ._misc import IntScalar
30
+
31
+ __all__ = [
32
+ 'CSRLinear',
33
+ ]
34
+
35
+
36
class CSRLinear(Module):
    """
    Event-driven linear layer whose connectivity is stored in CSR format.

    The post-synaptic targets of pre-synaptic neuron ``i`` are
    ``indices[indptr[i]:indptr[i + 1]]``.

    Parameters
    ----------
    n_pre : int
        Number of pre-synaptic neurons.
    n_post : int
        Number of post-synaptic neurons.
    indptr : ArrayLike
        1D CSR row-pointer array of size ``n_pre + 1``.
    indices : ArrayLike
        1D array of post-synaptic indices, one entry per connection.
    weight : float or callable or jax.Array or brainunit.Quantity
        Maximum synaptic conductance. Either a single value shared by all
        connections or one value per entry of ``indices``.
    name : str, optional
        Name of the module.
    grad_mode : str, optional
        Differentiation rule used by the kernel: ``'vjp'`` (default) or ``'jvp'``.
    """

    __module__ = 'brainstate.event'

    def __init__(
        self,
        n_pre: IntScalar,
        n_post: IntScalar,
        indptr: ArrayLike,
        indices: ArrayLike,
        weight: Union[Callable, ArrayLike],
        name: Optional[str] = None,
        grad_mode: str = 'vjp'
    ):
        super().__init__(name=name)

        self.in_size = n_pre
        self.out_size = n_post
        self.n_pre = n_pre
        self.n_post = n_post

        # gradient mode
        assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
        self.grad_mode = grad_mode

        # CSR data structure
        indptr = jnp.asarray(indptr)
        indices = jnp.asarray(indices)
        assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
        assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
        assert indptr.size == n_pre + 1, f"indptr must have size {n_pre + 1}. Got: {indptr.size}"
        self.indptr = u.math.asarray(indptr)
        self.indices = u.math.asarray(indices)

        # maximum synaptic conductance: one shared value or one per connection
        weight = param(weight, (len(indices),), allow_none=False)
        n_weight = u.math.size(weight)
        if n_weight != 1 and n_weight != len(self.indices):
            raise ValueError(
                f"weight must be a scalar or have one entry per connection "
                f"({len(self.indices)}). Got size: {n_weight}"
            )
        self.weight = ParamState(weight)

    def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
        """
        Propagate spike events through the CSR connectivity.

        Parameters
        ----------
        spk : jax.Array
            Spike events; the last axis indexes pre-synaptic neurons.

        Returns
        -------
        jax.Array or brainunit.Quantity
            Post-synaptic data of shape ``spk.shape[:-1] + (n_post,)``.

        Raises
        ------
        NotImplementedError
            If the first JAX device is a GPU or TPU.
        ValueError
            If the first JAX device is of any other unsupported kind.
        """
        weight = self.weight.value if isinstance(self.weight, State) else self.weight

        # No connections at all: skip the kernel and return unit-aware zeros.
        if len(self.indices) == 0:
            r = u.math.zeros(spk.shape[:-1] + (self.n_post,),
                             dtype=weight.dtype,
                             unit=u.get_unit(weight) * u.get_unit(spk))
            return u.maybe_decimal(r)

        # The platform is read from the first available device, not from ``spk``.
        device_kind = jax.devices()[0].platform  # spk.device.device_kind
        if device_kind == 'cpu':
            return cpu_event_csr(
                u.math.asarray(spk),
                self.indptr,
                self.indices,
                u.math.asarray(weight),
                n_post=self.n_post, grad_mode=self.grad_mode
            )
        elif device_kind in ['gpu', 'tpu']:
            raise NotImplementedError(
                f"CSRLinear has no kernel for '{device_kind}' devices; only 'cpu' is supported."
            )
        else:
            raise ValueError(f"Unsupported device: {device_kind}")
111
+
112
+
113
@set_module_as('brainstate.event')
def cpu_event_csr(
    spk: jax.Array,
    indptr: jax.Array,
    indices: jax.Array,
    weight: Union[u.Quantity, jax.Array],
    *,
    n_post: int,
    grad_mode: str = 'vjp'
) -> Union[u.Quantity, jax.Array]:
    """
    Event-driven CSR matrix-vector product on CPU.

    The physical unit of ``weight`` is stripped before the kernel runs and
    re-attached to the result. Inputs with leading batch axes are flattened
    to 2D, mapped row by row with ``jax.vmap``, and reshaped back.

    Parameters
    ----------
    spk : jax.Array
        Spike events; the last axis indexes pre-synaptic neurons.
    indptr : jax.Array
        Index pointer of post connected neurons.
    indices : jax.Array
        Indices of post connected neurons.
    weight : brainunit.Quantity or jax.Array
        Maximum synaptic conductance (0D or 1D).
    n_post : int
        Number of post-synaptic neurons.
    grad_mode : str, optional
        Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.

    Returns
    -------
    post_data : brainunit.Quantity or jax.Array
        Post synaptic data of shape ``spk.shape[:-1] + (n_post,)``.

    Raises
    ------
    ValueError
        If ``grad_mode`` is neither 'vjp' nor 'jvp'.
    """
    unit = u.get_unit(weight)
    weight = u.get_mantissa(weight)

    def mv(spk_vector):
        # Report the rank of the vector actually being checked.
        assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk_vector.ndim}"
        if grad_mode == 'vjp':
            post_data = _cpu_event_csr_mv_vjp(spk_vector, indptr, indices, weight, n_post)
        elif grad_mode == 'jvp':
            post_data = _cpu_event_csr_mv_jvp(spk_vector, indptr, indices, weight, n_post)
        else:
            raise ValueError(f"Unsupported grad_mode: {grad_mode}")
        return post_data

    assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
    assert weight.ndim in [1, 0], f"weight must be 1D or 0D. Got: {weight.ndim}"
    assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"

    if spk.ndim == 1:
        post_data = mv(spk)
    else:
        # Collapse all leading axes into one batch axis, then restore them.
        shape = spk.shape[:-1]
        post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
        post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
    return u.maybe_decimal(u.Quantity(post_data, unit=unit))
170
+
171
+
172
+ # --------------
173
+ # Implementation
174
+ # --------------
175
+
176
+
177
def _cpu_event_csr_mv(
    spk: jax.Array,
    indptr: jax.Array,
    indices: jax.Array,
    weight: Union[u.Quantity, jax.Array],
    n_post: int
) -> jax.Array:
    """
    Accumulate the post-synaptic input produced by one spike vector.

    For every pre-synaptic neuron whose entry in ``spk`` is active (``True``
    for boolean input, non-zero otherwise), the weights of its CSR slots
    ``indptr[i]:indptr[i + 1]`` are added to the targets named by ``indices``.
    Non-boolean spikes additionally scale the weight by the spike value.
    """
    spikes_are_bool = spk.dtype == jnp.bool_
    shared_weight = jnp.size(weight) == 1

    def accumulate_row(acc, first, stop, sp):
        # Walk the CSR slots [first, stop) belonging to one pre-synaptic neuron.
        def has_more(carry):
            return carry[1] < stop

        def step(carry):
            acc_now, slot = carry
            contribution = weight if shared_weight else weight[slot]
            if not spikes_are_bool:
                contribution = contribution * sp
            return acc_now.at[indices[slot]].add(contribution), slot + 1

        acc_out, _ = jax.lax.while_loop(has_more, step, (acc, first))
        return acc_out

    def per_neuron(acc, i_pre):
        sp = spk[i_pre]  # pre-synaptic spike event
        fired = sp if spikes_are_bool else sp != 0.
        acc = jax.lax.cond(
            fired,
            lambda: accumulate_row(acc, indptr[i_pre], indptr[i_pre + 1], sp),
            lambda: acc,
        )
        return acc, None

    start = jnp.zeros((n_post,), dtype=weight.dtype)
    total, _ = jax.lax.scan(per_neuron, start, np.arange(len(spk)))
    return total
207
+
208
+
209
+ # --------------
210
+ # VJP
211
+ # --------------
212
+
213
def _cpu_event_csr_mv_fwd(
    spk: jax.Array,
    indptr: jax.Array,
    indices: jax.Array,
    weight: Union[u.Quantity, jax.Array],
    n_post: int
):
    """Forward half of the custom VJP: run the kernel and keep ``(spk, weight)`` as residuals."""
    primal_out = _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post)
    residuals = (spk, weight)
    return primal_out, residuals
221
+
222
+
223
def _cpu_event_csr_mv_bwd(indptr, indices, n_post, res, ct):
    """
    Backward half of the custom VJP for ``_cpu_event_csr_mv``.

    Parameters
    ----------
    indptr, indices, n_post
        The non-differentiable arguments of the kernel.
    res
        Residuals saved by the forward rule: ``(spk, weight)``.
    ct
        Cotangent of the kernel output; indexed by post-synaptic neuron.

    Returns
    -------
    tuple
        ``(ct_spk, ct_gmax)``: cotangents for ``spk`` and ``weight``.
    """
    spk, weight = res
    homo = jnp.size(weight) == 1
    bool_spk = spk.dtype == jnp.bool_

    # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
    def fn_spk(i_pre):
        # Sum the cotangents of every target of neuron ``i_pre``; with
        # per-connection weights each term is weighted inside the loop.
        def body_fn(x):
            r, i = x
            i_post = indices[i]
            r = r + (ct[i_post] if homo else ct[i_post] * weight[i])
            return r, i + 1

        p = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (0., indptr[i_pre]))[0]
        # A shared weight is factored out of the sum and applied once.
        p = p * weight if homo else p
        return p

    ct_spk = jax.vmap(fn_spk)(np.arange(len(spk)))

    # ∂L/∂w = ∂L/∂y * ∂y/∂w
    if homo:  # scalar
        # Re-run the kernel with unit weight, then contract with the cotangent.
        ct_gmax = _cpu_event_csr_mv(spk, indptr, indices, jnp.asarray(1.), n_post=n_post)
        ct_gmax = jnp.inner(ct, ct_gmax)
    else:
        def single_post(dw, i_pre):
            # Fill the weight cotangent for every CSR slot of neuron ``i_pre``.
            def body_fn(x):
                dw, i = x
                i_post = indices[i]
                dw = dw.at[i].add(ct[i_post] if bool_spk else ct[i_post] * spk[i_pre])
                return dw, i + 1

            return jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (dw, indptr[i_pre]))[0]

        def fn_w(dw, i_pre):
            # Neurons that did not spike leave the weight cotangent unchanged.
            sp = spk[i_pre]
            if bool_spk:
                return jax.lax.cond(sp, lambda: single_post(dw, i_pre), lambda: dw), None
            else:
                return jax.lax.cond(sp == 0., lambda: dw, lambda: single_post(dw, i_pre)), None

        ct_gmax = jax.lax.scan(fn_w, jnp.zeros_like(weight), np.arange(len(spk)))[0]
    return ct_spk, ct_gmax
265
+
266
+
267
# Reverse-mode variant of the kernel. Positional arguments 1, 2 and 4
# (``indptr``, ``indices``, ``n_post``) are declared non-differentiable, so
# only ``spk`` and ``weight`` receive cotangents from the backward rule.
_cpu_event_csr_mv_vjp = jax.custom_vjp(_cpu_event_csr_mv, nondiff_argnums=(1, 2, 4))
_cpu_event_csr_mv_vjp.defvjp(_cpu_event_csr_mv_fwd, _cpu_event_csr_mv_bwd)
269
+
270
+
271
+ # --------------
272
+ # JVP
273
+ # --------------
274
+
275
+
276
def _cpu_event_csr_mv_jvp_rule(indptr, indices, n_post, primals, tangents):
    """
    JVP rule for ``_cpu_event_csr_mv``.

    Parameters
    ----------
    indptr, indices, n_post
        The non-differentiable arguments of the kernel.
    primals
        ``(spk, weight)``.
    tangents
        ``(spk_dot, weight_dot)``.

    Returns
    -------
    tuple
        ``(y, y_dot)``: the primal output and its tangent, the sum of the
        weight-tangent and spike-tangent contributions.
    """
    # forward pass
    spk, weight = primals
    y = _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post)

    # forward gradients
    spk_dot, weight_dot = tangents
    homo_w = jnp.size(weight) == 1

    # ∂y/∂gmax
    # Same kernel, with the weight tangent standing in for the weight.
    dweight = _cpu_event_csr_mv(spk, indptr, indices, weight_dot, n_post=n_post)

    # ∂y/∂gspk
    def scan_fn(post, i_pre):
        # Scatter the spike tangent of neuron ``i_pre`` onto all of its
        # targets; unlike the primal kernel, no neuron is skipped here.
        def while_fn(x):
            p, i, sp = x
            i_post = indices[i]
            p = p.at[i_post].add(sp if homo_w else sp * weight[i])
            return p, i + 1, sp

        post = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1],
                                  while_fn,
                                  (post, indptr[i_pre], spk_dot[i_pre]))[0]

        return post, None

    dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
    # A shared weight is applied once after the accumulation.
    dspk = (dspk * weight) if homo_w else dspk
    return y, dweight + dspk
305
+
306
+
307
# Forward-mode variant of the kernel. As for the VJP variant, positional
# arguments 1, 2 and 4 (``indptr``, ``indices``, ``n_post``) are
# non-differentiable; the rule receives them ahead of the primals and tangents.
_cpu_event_csr_mv_jvp = jax.custom_jvp(_cpu_event_csr_mv, nondiff_argnums=(1, 2, 4))
_cpu_event_csr_mv_jvp.defjvp(_cpu_event_csr_mv_jvp_rule)
@@ -0,0 +1,118 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import jax.numpy
19
+ import jax.numpy as jnp
20
+ import numpy as np
21
+ from absl.testing import parameterized
22
+
23
+ import brainstate as bst
24
+
25
+
26
def _get_csr(n_pre, n_post, prob):
    """Build a random CSR pattern in which every row has ``int(n_post * prob)`` targets."""
    fan_out = int(n_post * prob)
    # Every row has the same length, so the row pointer is an arithmetic ramp.
    row_ptr = np.arange(n_pre + 1) * fan_out
    targets = np.random.randint(0, n_post, (n_pre * fan_out,))
    return row_ptr, targets
31
+
32
+
33
def true_fn(x, w, indices, indptr, n_out):
    """Reference implementation: accumulate each row's contribution with a plain Python loop."""
    scalar_weight = jnp.size(w) == 1

    acc = jnp.zeros((n_out,))
    for row in range(x.shape[0]):
        lo, hi = indptr[row], indptr[row + 1]
        targets = indices[lo:hi]
        if scalar_weight:
            contribution = w * x[row]
        else:
            contribution = w[lo:hi] * x[row]
        acc = acc.at[targets].add(contribution)
    return acc
41
+
42
+
43
class TestFixedProbCSR(parameterized.TestCase):
    """Compares ``bst.event.CSRLinear`` with the loop-based reference ``true_fn``."""

    @parameterized.product(
        homo_w=[True, False],
    )
    def test1(self, homo_w):
        # Forward pass on boolean spikes.
        n_pre, n_post = 20, 40
        spikes = bst.random.rand(n_pre) < 0.1
        indptr, indices = _get_csr(n_pre, n_post, 0.1)
        weight_init = 1.5 if homo_w else bst.init.Normal()
        layer = bst.event.CSRLinear(n_pre, n_post, indptr, indices, weight_init)
        actual = layer(spikes)
        expected = true_fn(spikes, layer.weight.value, indices, indptr, n_post)
        self.assertTrue(jnp.allclose(actual, expected))

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_vjp(self, bool_x, homo_w):
        # Reverse-mode gradients w.r.t. both the input and the weight.
        n_in, n_out = 20, 30
        draws = bst.random.rand(n_in)
        # The "bool" case is a 0/1 vector cast to float so it stays differentiable.
        x = jnp.asarray(draws < 0.3, dtype=float) if bool_x else draws

        indptr, indices = _get_csr(n_in, n_out, 0.1)
        weight_init = 1.5 if homo_w else bst.init.Normal()
        fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, weight_init)
        w = fn.weight.value

        def layer_loss(x, w):
            fn.weight.value = w
            return fn(x).sum()

        def reference_loss(x, w):
            return true_fn(x, w, indices, indptr, n_out).sum()

        grads = jax.grad(layer_loss, argnums=(0, 1))(x, w)
        expected_grads = jax.grad(reference_loss, argnums=(0, 1))(x, w)
        self.assertTrue(jnp.allclose(grads[0], expected_grads[0]))
        self.assertTrue(jnp.allclose(grads[1], expected_grads[1]))

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_jvp(self, bool_x, homo_w):
        # Forward-mode output and tangent with all-ones tangents.
        n_in, n_out = 20, 30
        draws = bst.random.rand(n_in)
        x = jnp.asarray(draws < 0.3, dtype=float) if bool_x else draws

        indptr, indices = _get_csr(n_in, n_out, 0.1)
        weight_init = 1.5 if homo_w else bst.init.Normal()
        fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, weight_init, grad_mode='jvp')
        w = fn.weight.value

        def layer_out(x, w):
            fn.weight.value = w
            return fn(x)

        def reference_out(x, w):
            return true_fn(x, w, indices, indptr, n_out)

        unit_tangents = (jnp.ones_like(x), jnp.ones_like(w))
        o1, r1 = jax.jvp(layer_out, (x, w), unit_tangents)
        o2, r2 = jax.jvp(reference_out, (x, w), unit_tangents)
        self.assertTrue(jnp.allclose(r1, r2))
        self.assertTrue(jnp.allclose(o1, o2))