brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -22,35 +22,37 @@ import brainstate as bst
22
22
 
23
23
 
24
24
  class TestEnviron(unittest.TestCase):
25
- def test_precision(self):
26
- with bst.environ.context(precision=64):
27
- a = bst.random.randn(1)
28
- self.assertEqual(a.dtype, jnp.float64)
29
-
30
- with bst.environ.context(precision=32):
31
- a = bst.random.randn(1)
32
- self.assertEqual(a.dtype, jnp.float32)
33
-
34
- with bst.environ.context(precision=16):
35
- a = bst.random.randn(1)
36
- self.assertEqual(a.dtype, jnp.bfloat16)
37
-
38
- def test_platform(self):
39
- with self.assertRaises(ValueError):
40
- with bst.environ.context(platform='cpu'):
41
- a = bst.random.randn(1)
42
- self.assertEqual(a.device(), 'cpu')
43
-
44
- def test_register_default_behavior(self):
45
- dt_ = 0.1
46
-
47
- def dt_behavior(dt):
48
- nonlocal dt_
49
- dt_ = dt
50
- print(f'dt: {dt}')
51
-
52
- bst.environ.register_default_behavior('dt', dt_behavior)
53
-
54
- with bst.environ.context(dt=0.2):
55
- self.assertEqual(dt_, 0.2)
56
- self.assertEqual(dt_, 0.1)
25
def test_precision(self):
    # Each precision setting must change the dtype produced by the random module.
    expectations = (
        (64, jnp.float64),
        (32, jnp.float32),
        (16, jnp.bfloat16),
    )
    for bits, expected_dtype in expectations:
        with bst.environ.context(precision=bits):
            sample = bst.random.randn(1)
            self.assertEqual(sample.dtype, expected_dtype)
37
+
38
def test_platform(self):
    # Entering a 'cpu' platform context (and sampling inside it) is expected to raise ValueError.
    with self.assertRaises(ValueError), bst.environ.context(platform='cpu'):
        sample = bst.random.randn(1)
        self.assertEqual(sample.device(), 'cpu')
43
+
44
def test_register_default_behavior(self):
    bst.environ.set(dt=0.1)

    observed_dt = 0.1

    def on_dt_change(dt):
        # Record every value the environment hands to the registered 'dt' behavior.
        nonlocal observed_dt
        observed_dt = dt
        print(f'dt: {dt}')

    bst.environ.register_default_behavior('dt', on_dt_change)

    # Entering the context should report 0.2; leaving it should report 0.1 again.
    with bst.environ.context(dt=0.2):
        self.assertEqual(observed_dt, 0.2)
    self.assertEqual(observed_dt, 0.1)
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
from ._csr import *
from ._csr import __all__ as __all_csr
from ._fixed_probability import *
from ._fixed_probability import __all__ as __all_fixed_probability
from ._linear import *
from ._linear import __all__ as __all_linear
from ._xla_custom_op import *
from ._xla_custom_op import __all__ as __all_xla_custom_op

# Public API of the package: the concatenation (in this order) of each submodule's ``__all__``.
__all__ = [*__all_fixed_probability, *__all_linear, *__all_csr, *__all_xla_custom_op]
del __all_fixed_probability, __all_linear, __all_csr, __all_xla_custom_op
@@ -0,0 +1,316 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from __future__ import annotations
16
+
17
+ from typing import Union, Callable, Optional
18
+
19
+ import brainunit as u
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+
24
+ from brainstate._state import ParamState
25
+ from brainstate._utils import set_module_as
26
+ from brainstate.init import param
27
+ from brainstate.nn._module import Module
28
+ from brainstate.typing import ArrayLike, Size
29
+
30
# Only the module class is exported; ``cpu_event_csr`` is not listed, so a star import of
# this file does not re-export it.
__all__ = ['CSRLinear']
33
+
34
+
35
class CSRLinear(Module):
    """
    The CSRLinear module implements a fixed probability connection with CSR sparse data structure.

    Connectivity is given in compressed sparse row (CSR) form with one row per
    pre-synaptic neuron: the post-synaptic targets of pre-synaptic neuron ``i`` are
    ``indices[indptr[i]:indptr[i + 1]]``.

    Parameters
    ----------
    in_size : Size
        Number of pre-synaptic neurons, i.e., input size. Its last entry is used as
        ``n_pre`` and must equal ``len(indptr) - 1``.
    out_size : Size
        Number of post-synaptic neurons, i.e., output size. Its last entry is used as ``n_post``.
    indptr : ArrayLike
        1D CSR row-pointer array of length ``n_pre + 1``.
    indices : ArrayLike
        1D CSR column-index array holding the post-synaptic index of every connection.
    weight : float or callable or jax.Array or brainunit.Quantity
        Maximum synaptic conductance or a function that returns the maximum synaptic conductance.
        Either a single shared value or one value per connection (``len(indices)`` values).
    name : str, optional
        Name of the module.
    grad_mode : str, optional
        ``'vjp'`` (default) or ``'jvp'``: selects the kernel variant carrying a custom
        reverse-mode or forward-mode differentiation rule.
    """

    __module__ = 'brainstate.event'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        indptr: ArrayLike,
        indices: ArrayLike,
        weight: Union[Callable, ArrayLike],
        name: Optional[str] = None,
        grad_mode: str = 'vjp'
    ):
        super().__init__(name=name)

        # network size
        self.in_size = in_size
        self.out_size = out_size
        self.n_pre = self.in_size[-1]
        self.n_post = self.out_size[-1]

        # gradient mode
        assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
        self.grad_mode = grad_mode

        # CSR data structure
        indptr = jnp.asarray(indptr)
        indices = jnp.asarray(indices)
        assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
        assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
        assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
        with jax.ensure_compile_time_eval():
            # Evaluate eagerly (not staged into an enclosing trace) when the inputs are concrete.
            self.indptr = u.math.asarray(indptr)
            self.indices = u.math.asarray(indices)

        # maximum synaptic conductance: one shared value, or one value per connection
        weight = param(weight, (len(indices),), allow_none=False)
        if u.math.size(weight) != 1 and u.math.size(weight) != len(self.indices):
            raise ValueError(
                f"weight must be a scalar or 1D with size {len(self.indices)}. "
                f"Got: {u.math.size(weight)}"
            )
        self.weight = ParamState(weight)

    def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
        """
        Propagate pre-synaptic events ``spk`` (last axis of length ``n_pre``) through the connectivity.

        Returns
        -------
        jax.Array or brainunit.Quantity
            Post-synaptic input of shape ``spk.shape[:-1] + (n_post,)``.

        Raises
        ------
        NotImplementedError
            When the first available JAX device is a GPU or TPU.
        ValueError
            For any other unsupported platform.
        """
        weight = self.weight.value

        # no connections at all: return zeros of the expected shape/unit
        if len(self.indices) == 0:
            r = u.math.zeros(spk.shape[:-1] + (self.n_post,),
                             dtype=weight.dtype,
                             unit=u.get_unit(weight) * u.get_unit(spk))
            return u.maybe_decimal(r)

        # NOTE: the backend is taken from the first available device, not from where ``spk`` lives.
        device_kind = jax.devices()[0].platform  # spk.device.device_kind

        # CPU implementation
        if device_kind == 'cpu':
            return cpu_event_csr(
                u.math.asarray(spk),
                self.indptr,
                self.indices,
                u.math.asarray(weight),
                n_post=self.n_post, grad_mode=self.grad_mode
            )

        # GPU/TPU implementation
        elif device_kind in ['gpu', 'tpu']:
            raise NotImplementedError()

        else:
            raise ValueError(f"Unsupported device: {device_kind}")
119
+
120
+
121
@set_module_as('brainstate.event')
def cpu_event_csr(
    spk: jax.Array,
    indptr: jax.Array,
    indices: jax.Array,
    weight: Union[u.Quantity, jax.Array],
    *,
    n_post: int,
    grad_mode: str = 'vjp'
) -> Union[u.Quantity, jax.Array]:
    """
    Event-driven product of spike vectors with a CSR sparse connectivity matrix.

    Row ``i`` of the sparse matrix (pre-synaptic neuron ``i``) stores its post-synaptic
    targets in ``indices[indptr[i]:indptr[i + 1]]``. ``spk`` may carry arbitrary leading
    batch axes; its last axis indexes pre-synaptic neurons.

    Parameters
    ----------
    spk : jax.Array
        Spike events, shape ``(..., n_pre)``; boolean or numeric.
    indptr : jax.Array
        Index pointer of post connected neurons.
    indices : jax.Array
        Indices of post connected neurons.
    weight : brainunit.Quantity or jax.Array
        Maximum synaptic conductance: a scalar (shared by all connections) or a 1D array
        with one value per connection.
    n_post : int
        Number of post-synaptic neurons.
    grad_mode : str, optional
        Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.

    Returns
    -------
    post_data : brainunit.Quantity or jax.Array
        Post synaptic data of shape ``spk.shape[:-1] + (n_post,)``, carrying the unit of ``weight``.

    Raises
    ------
    ValueError
        If ``grad_mode`` is neither ``'vjp'`` nor ``'jvp'``.
    """
    unit = u.get_unit(weight)
    # Normalize to a JAX array so plain Python / NumPy weights also expose ``.ndim``/``.dtype``.
    weight = jnp.asarray(u.get_mantissa(weight))

    def mv(spk_vector):
        # One unbatched spike vector -> length-``n_post`` result.
        assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk_vector.ndim}"
        if grad_mode == 'vjp':
            post_data = _cpu_event_csr_mv_vjp(spk_vector, indptr, indices, weight, n_post)
        elif grad_mode == 'jvp':
            post_data = _cpu_event_csr_mv_jvp(spk_vector, indptr, indices, weight, n_post)
        else:
            raise ValueError(f"Unsupported grad_mode: {grad_mode}")
        return post_data

    assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
    assert weight.ndim in [1, 0], f"weight must be 1D or 0D. Got: {weight.ndim}"
    assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"

    if spk.ndim == 1:
        post_data = mv(spk)
    else:
        # Flatten all leading batch axes, vmap over them, then restore the batch shape.
        shape = spk.shape[:-1]
        post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
        post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
    return u.maybe_decimal(u.Quantity(post_data, unit=unit))
178
+
179
+
180
+ # --------------
181
+ # Implementation
182
+ # --------------
183
+
184
+
185
+ def _cpu_event_csr_mv(
186
+ spk: jax.Array,
187
+ indptr: jax.Array,
188
+ indices: jax.Array,
189
+ weight: Union[u.Quantity, jax.Array],
190
+ n_post: int
191
+ ) -> jax.Array:
192
+ bool_x = spk.dtype == jnp.bool_
193
+ homo_w = jnp.size(weight) == 1
194
+
195
+ def add_fn(post_val, i_start, i_end, sp):
196
+ def body_fn(x):
197
+ post, i = x
198
+ i_post = indices[i]
199
+ w = weight if homo_w else weight[i]
200
+ w = w if bool_x else w * sp
201
+ post = post.at[i_post].add(w)
202
+ return post, i + 1
203
+
204
+ return jax.lax.while_loop(lambda x: x[1] < i_end, body_fn, (post_val, i_start))[0]
205
+
206
+ def scan_fn(post, i):
207
+ sp = spk[i] # pre-synaptic spike event
208
+ if bool_x:
209
+ post = jax.lax.cond(sp, lambda: add_fn(post, indptr[i], indptr[i + 1], sp), lambda: post)
210
+ else:
211
+ post = jax.lax.cond(sp == 0., lambda: post, lambda: add_fn(post, indptr[i], indptr[i + 1], sp))
212
+ return post, None
213
+
214
+ return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
215
+
216
+
217
+ # --------------
218
+ # VJP
219
+ # --------------
220
+
221
def _cpu_event_csr_mv_fwd(
    spk: jax.Array,
    indptr: jax.Array,
    indices: jax.Array,
    weight: Union[u.Quantity, jax.Array],
    n_post: int
):
    """Forward pass of the custom VJP.

    Returns the product together with the residuals ``(spk, weight)`` that the
    backward rule needs.
    """
    residuals = (spk, weight)
    product = _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post)
    return product, residuals
229
+
230
+
231
def _cpu_event_csr_mv_bwd(indptr, indices, n_post, res, ct):
    """Backward rule of the custom VJP for ``_cpu_event_csr_mv``.

    Parameters
    ----------
    indptr, indices, n_post
        Non-differentiable CSR structure and output size (the ``nondiff_argnums``).
    res : tuple
        Residuals ``(spk, weight)`` saved by ``_cpu_event_csr_mv_fwd``.
    ct : jax.Array
        Cotangent of the length-``n_post`` output.

    Returns
    -------
    tuple
        ``(ct_spk, ct_weight)``. ``ct_weight`` has exactly the shape and dtype of
        ``weight``. ``ct_spk`` is cast to ``spk``'s dtype when ``spk`` is floating.
    """
    spk, weight = res
    homo = jnp.size(weight) == 1
    bool_spk = spk.dtype == jnp.bool_

    # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
    def fn_spk(i_pre):
        # Sum the output cotangents over the targets of row ``i_pre``
        # (weighted per connection when weights are heterogeneous).
        def body_fn(x):
            r, i = x
            i_post = indices[i]
            r = r + (ct[i_post] if homo else ct[i_post] * weight[i])
            return r, i + 1

        p = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (0., indptr[i_pre]))[0]
        p = p * weight if homo else p
        return p

    ct_spk = jax.vmap(fn_spk)(np.arange(len(spk)))
    if jnp.issubdtype(spk.dtype, jnp.floating):
        # Keep the cotangent in the primal's dtype (e.g. float32 spikes with float64 weights).
        ct_spk = ct_spk.astype(spk.dtype)

    # ∂L/∂w = ∂L/∂y * ∂y/∂w
    if homo:  # scalar
        # Use a unit weight of ``weight``'s own dtype. A bare ``jnp.asarray(1.)`` follows the
        # global default float (float64 under x64), which made the cotangent's dtype differ
        # from float32 / bfloat16 weights and was rejected by custom_vjp.
        unit_weight = jnp.ones((), dtype=weight.dtype)
        ct_gmax = _cpu_event_csr_mv(spk, indptr, indices, unit_weight, n_post=n_post)
        ct_gmax = jnp.inner(ct, ct_gmax)
        # Match the primal exactly (shape ``()`` or ``(1,)``, same dtype).
        ct_gmax = jnp.reshape(ct_gmax, jnp.shape(weight)).astype(weight.dtype)
    else:
        def single_post(dw, i_pre):
            # Accumulate ∂L/∂w for every stored entry of row ``i_pre``.
            def body_fn(x):
                dw, i = x
                i_post = indices[i]
                dw = dw.at[i].add(ct[i_post] if bool_spk else ct[i_post] * spk[i_pre])
                return dw, i + 1

            return jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (dw, indptr[i_pre]))[0]

        def fn_w(dw, i_pre):
            # Rows without an event contribute nothing to the weight gradient.
            sp = spk[i_pre]
            if bool_spk:
                return jax.lax.cond(sp, lambda: single_post(dw, i_pre), lambda: dw), None
            else:
                return jax.lax.cond(sp == 0., lambda: dw, lambda: single_post(dw, i_pre)), None

        ct_gmax = jax.lax.scan(fn_w, jnp.zeros_like(weight), np.arange(len(spk)))[0]
    return ct_spk, ct_gmax
273
+
274
+
275
# Reverse-mode variant of the kernel. ``indptr``, ``indices`` and ``n_post`` (positional
# arguments 1, 2 and 4) are non-differentiable, so cotangents are produced only for
# ``spk`` and ``weight`` by the fwd/bwd pair registered below.
_cpu_event_csr_mv_vjp = jax.custom_vjp(_cpu_event_csr_mv, nondiff_argnums=(1, 2, 4))
_cpu_event_csr_mv_vjp.defvjp(_cpu_event_csr_mv_fwd, _cpu_event_csr_mv_bwd)
277
+
278
+
279
+ # --------------
280
+ # JVP
281
+ # --------------
282
+
283
+
284
def _cpu_event_csr_mv_jvp_rule(indptr, indices, n_post, primals, tangents):
    """Forward-mode (JVP) rule for ``_cpu_event_csr_mv``.

    For floating spikes the product is bilinear in ``(spk, weight)``, so the output
    tangent is ``mv(spk, weight_dot) + mv(spk_dot, weight)``. For non-floating spikes
    (boolean events, integers) only the weight term exists.

    Returns
    -------
    tuple
        ``(y, y_dot)``, both of length ``n_post``.
    """
    # forward pass
    spk, weight = primals
    y = _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post)

    # forward gradients
    spk_dot, weight_dot = tangents
    homo_w = jnp.size(weight) == 1

    # ∂y/∂w · w_dot
    dweight = _cpu_event_csr_mv(spk, indptr, indices, weight_dot, n_post=n_post)

    # Non-floating spikes carry a float0 tangent: it holds no information and cannot be
    # indexed or added into a float array, so the spike term is identically zero.
    if not jnp.issubdtype(spk.dtype, jnp.floating):
        return y, dweight

    # ∂y/∂spk · spk_dot (visits every row: zero-valued spikes can still have tangents)
    def scan_fn(post, i_pre):
        def while_fn(x):
            p, i, sp = x
            i_post = indices[i]
            p = p.at[i_post].add(sp if homo_w else sp * weight[i])
            return p, i + 1, sp

        post = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1],
                                  while_fn,
                                  (post, indptr[i_pre], spk_dot[i_pre]))[0]

        return post, None

    dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
    dspk = (dspk * weight) if homo_w else dspk
    return y, dweight + dspk
313
+
314
+
315
# Forward-mode variant of the kernel. ``indptr``, ``indices`` and ``n_post`` (positional
# arguments 1, 2 and 4) are non-differentiable and are passed first to the JVP rule,
# followed by the ``(spk, weight)`` primals and their tangents.
_cpu_event_csr_mv_jvp = jax.custom_jvp(_cpu_event_csr_mv, nondiff_argnums=(1, 2, 4))
_cpu_event_csr_mv_jvp.defjvp(_cpu_event_csr_mv_jvp_rule)
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,118 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import jax.numpy
19
+ import jax.numpy as jnp
20
+ import numpy as np
21
+ from absl.testing import parameterized
22
+
23
+ import brainstate as bst
24
+
25
+
26
+ def _get_csr(n_pre, n_post, prob):
27
+ n_conn = int(n_post * prob)
28
+ indptr = np.arange(n_pre + 1) * n_conn
29
+ indices = np.random.randint(0, n_post, (n_pre * n_conn,))
30
+ return indptr, indices
31
+
32
+
33
def true_fn(x, w, indices, indptr, n_out):
    """Loop-based reference for the event CSR product, used to check ``CSRLinear``.

    For every pre-synaptic index ``i``, the targets ``indices[indptr[i]:indptr[i + 1]]``
    receive ``x[i]`` times either the shared weight (size-1 ``w``) or the matching
    slice of ``w``.
    """
    shared = jnp.size(w) == 1
    out = jnp.zeros((n_out,))
    for row in range(x.shape[0]):
        lo, hi = indptr[row], indptr[row + 1]
        scale = w if shared else w[lo:hi]
        out = out.at[indices[lo:hi]].add(scale * x[row])
    return out
41
+
42
+
43
class TestFixedProbCSR(parameterized.TestCase):
    """Compare ``bst.event.CSRLinear`` (outputs, VJP and JVP) with the reference ``true_fn``."""

    @staticmethod
    def _random_input(n_in, bool_x):
        # ``bool_x`` draws a 0/1 spike pattern but stores it as float so that it stays
        # differentiable; otherwise the input is uniform random.
        if bool_x:
            return jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
        return bst.random.rand(n_in)

    @parameterized.product(
        homo_w=[True, False],
    )
    def test1(self, homo_w):
        spikes = bst.random.rand(20) < 0.1
        indptr, indices = _get_csr(20, 40, 0.1)
        layer = bst.event.CSRLinear(20, 40, indptr, indices, 1.5 if homo_w else bst.init.Normal())
        got = layer(spikes)
        want = true_fn(spikes, layer.weight.value, indices, indptr, 40)
        self.assertTrue(jnp.allclose(got, want))

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_vjp(self, bool_x, homo_w):
        n_in, n_out = 20, 30
        x = self._random_input(n_in, bool_x)
        indptr, indices = _get_csr(n_in, n_out, 0.1)
        layer = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
        w = layer.weight.value

        def loss(x, w):
            layer.weight.value = w
            return layer(x).sum()

        def reference_loss(x, w):
            return true_fn(x, w, indices, indptr, n_out).sum()

        got = jax.grad(loss, argnums=(0, 1))(x, w)
        want = jax.grad(reference_loss, argnums=(0, 1))(x, w)
        for g, r in zip(got, want):
            self.assertTrue(jnp.allclose(g, r))

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_jvp(self, bool_x, homo_w):
        n_in, n_out = 20, 30
        x = self._random_input(n_in, bool_x)
        indptr, indices = _get_csr(n_in, n_out, 0.1)
        layer = bst.event.CSRLinear(n_in, n_out, indptr, indices,
                                    1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
        w = layer.weight.value
        primals = (x, w)
        tangents = (jnp.ones_like(x), jnp.ones_like(w))

        def forward(x, w):
            layer.weight.value = w
            return layer(x)

        def reference(x, w):
            return true_fn(x, w, indices, indptr, n_out)

        out, out_dot = jax.jvp(forward, primals, tangents)
        ref_out, ref_dot = jax.jvp(reference, primals, tangents)
        self.assertTrue(jnp.allclose(out_dot, ref_dot))
        self.assertTrue(jnp.allclose(out, ref_out))