brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,271 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from __future__ import annotations
+
+ from typing import Union, Callable, Optional
+
+ import brainunit as u
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from brainstate._state import ParamState
+ from brainstate._utils import set_module_as
+ from brainstate.compile import for_loop
+ from brainstate.init import param
+ from brainstate.nn._module import Module
+ from brainstate.random import RandomState
+ from brainstate.typing import ArrayLike
+ from ._misc import FloatScalar, IntScalar
+
+ __all__ = [
+     'FixedProb',
+ ]
+
+
+ class FixedProb(Module):
+     """
+     The FixedProb module implements a fixed-probability connection with a CSR sparse data structure.
+
+     Parameters
+     ----------
+     n_pre : int
+         Number of pre-synaptic neurons.
+     n_post : int
+         Number of post-synaptic neurons.
+     prob : float
+         Probability of connection.
+     weight : float or callable or jax.Array or brainunit.Quantity
+         Maximum synaptic conductance.
+     allow_multi_conn : bool, optional
+         Whether a pre-synaptic neuron may connect to the same post-synaptic neuron
+         more than once. Default is True, meaning the same post-synaptic index can
+         be sampled multiple times.
+     seed : int, optional
+         Seed for the random number generator that samples the connectivity.
+     name : str, optional
+         Name of the module.
+     grad_mode : str, optional
+         Gradient mode, either ``'vjp'`` (default) or ``'jvp'``.
+     """
+
+     __module__ = 'brainstate.event'
+
+     def __init__(
+         self,
+         n_pre: IntScalar,
+         n_post: IntScalar,
+         prob: FloatScalar,
+         weight: Union[Callable, ArrayLike],
+         allow_multi_conn: bool = True,
+         seed: Optional[int] = None,
+         name: Optional[str] = None,
+         grad_mode: str = 'vjp'
+     ):
+         super().__init__(name=name)
+         self.n_pre = n_pre
+         self.n_post = n_post
+         self.in_size = n_pre
+         self.out_size = n_post
+
+         self.n_conn = int(n_post * prob)
+         if self.n_conn < 1:
+             raise ValueError(
+                 f"The number of connections must be at least 1. Got: int({n_post} * {prob}) = {self.n_conn}")
+
+         assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
+         self.grad_mode = grad_mode
+
+         # indices of connected post-synaptic neurons
+         if allow_multi_conn:
+             self.indices = np.random.RandomState(seed).randint(0, n_post, size=(self.n_pre, self.n_conn))
+         else:
+             rng = RandomState(seed)
+             self.indices = for_loop(lambda i: rng.choice(n_post, size=(self.n_conn,), replace=False),
+                                     np.arange(n_pre))
+         self.indices = u.math.asarray(self.indices)
+
+         # maximum synaptic conductance
+         weight = param(weight, (self.n_pre, self.n_conn), allow_none=False)
+         self.weight = ParamState(weight)
+
+     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
+         device_kind = jax.devices()[0].platform  # spk.device.device_kind
+         if device_kind == 'cpu':
+             return cpu_fixed_prob(self.indices,
+                                   u.math.asarray(self.weight.value),
+                                   u.math.asarray(spk),
+                                   n_post=self.n_post,
+                                   grad_mode=self.grad_mode)
+         elif device_kind in ['gpu', 'tpu']:
+             raise NotImplementedError()
+         else:
+             raise ValueError(f"Unsupported device: {device_kind}")
+
+
+ @set_module_as('brainstate.event')
+ def cpu_fixed_prob(
+     indices: jax.Array,
+     weight: Union[u.Quantity, jax.Array],
+     spk: jax.Array,
+     *,
+     n_post: int,
+     grad_mode: str = 'vjp'
+ ) -> Union[u.Quantity, jax.Array]:
+     """
+     Compute the post-synaptic data of a fixed-probability connection on CPU.
+
+     Parameters
+     ----------
+     indices : jax.Array
+         Indices of connected post-synaptic neurons.
+     weight : brainunit.Quantity or jax.Array
+         Maximum synaptic conductance.
+     spk : jax.Array
+         Spike events.
+     n_post : int
+         Number of post-synaptic neurons.
+     grad_mode : str, optional
+         Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.
+
+     Returns
+     -------
+     post_data : brainunit.Quantity or jax.Array
+         Post-synaptic data.
+     """
+     unit = u.get_unit(weight)
+     weight = u.get_mantissa(weight)
+     indices = jnp.asarray(indices)
+     spk = jnp.asarray(spk)
+
+     def mv(spk_vector):
+         assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk_vector.ndim}"
+         if grad_mode == 'vjp':
+             post_data = _cpu_event_fixed_prob_mv_vjp(indices, weight, spk_vector, n_post)
+         elif grad_mode == 'jvp':
+             post_data = _cpu_event_fixed_prob_mv_jvp(indices, weight, spk_vector, n_post)
+         else:
+             raise ValueError(f"Unsupported grad_mode: {grad_mode}")
+         return post_data
+
+     assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
+     assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"
+     assert indices.ndim == 2, f"indices must be 2D. Got: {indices.ndim}"
+
+     if spk.ndim == 1:
+         post_data = mv(spk)
+     else:
+         shape = spk.shape[:-1]
+         post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
+         post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
+     return u.maybe_decimal(u.Quantity(post_data, unit=unit))
+
+
+ # -------------------
+ # CPU Implementation
+ # -------------------
+
+
+ def _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post: int) -> jax.Array:
+     def scan_fn(post, i):
+         w = g_max if jnp.size(g_max) == 1 else g_max[i]
+         ids = indices[i]
+         sp = spk[i]
+         if spk.dtype == jnp.bool_:
+             post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
+         else:
+             post = jax.lax.cond(sp == 0., lambda: post, lambda: post.at[ids].add(w * sp))
+         return post, None
+
+     return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=g_max.dtype), np.arange(len(spk)))[0]
+
+
+ # --------------
+ # VJP
+ # --------------
+
+ def _cpu_event_fixed_prob_mv_fwd(indices, g_max, spk, n_post):
+     return _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post=n_post), (g_max, spk)
+
+
+ def _cpu_event_fixed_prob_mv_bwd(indices, n_post, res, ct):
+     weight, spk = res
+
+     # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
+     homo = jnp.size(weight) == 1
+     if homo:  # homogeneous weight
+         ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weight))(indices)
+     else:  # heterogeneous weight
+         ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weight)
+
+     # ∂L/∂w = ∂L/∂y * ∂y/∂w
+     if homo:  # scalar
+         ct_gmax = _cpu_event_fixed_prob_mv(indices, jnp.asarray(1.), spk, n_post=n_post)
+         ct_gmax = jnp.inner(ct, ct_gmax)
+     else:
+         def scan_fn(d_gmax, i):
+             if spk.dtype == jnp.bool_:
+                 d_gmax = jax.lax.cond(spk[i], lambda: d_gmax.at[i].add(ct[indices[i]]), lambda: d_gmax)
+             else:
+                 d_gmax = jax.lax.cond(spk[i] == 0., lambda: d_gmax,
+                                       lambda: d_gmax.at[i].add(ct[indices[i]] * spk[i]))
+             return d_gmax, None
+
+         ct_gmax = jax.lax.scan(scan_fn, jnp.zeros_like(weight), np.arange(len(spk)))[0]
+     return ct_gmax, ct_spk
+
+
+ _cpu_event_fixed_prob_mv_vjp = jax.custom_vjp(_cpu_event_fixed_prob_mv, nondiff_argnums=(0, 3))
+ _cpu_event_fixed_prob_mv_vjp.defvjp(_cpu_event_fixed_prob_mv_fwd, _cpu_event_fixed_prob_mv_bwd)
+
+
+ # --------------
+ # JVP
+ # --------------
+
+
+ def _cpu_event_fixed_prob_mv_jvp_rule(indices, n_post, primals, tangents):
+     # forward pass
+     weight, spk = primals
+     y = _cpu_event_fixed_prob_mv(indices, weight, spk, n_post=n_post)
+
+     # forward gradients
+     gmax_dot, spk_dot = tangents
+
+     # ∂y/∂gmax
+     dgmax = _cpu_event_fixed_prob_mv(indices, gmax_dot, spk, n_post=n_post)
+
+     def scan_fn(post, i):
+         ids = indices[i]
+         w = weight if jnp.size(weight) == 1 else weight[i]
+         post = post.at[ids].add(w * spk_dot[i])
+         return post, None
+
+     # ∂y/∂spk
+     dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
+     return y, dgmax + dspk
+
+
+ _cpu_event_fixed_prob_mv_jvp = jax.custom_jvp(_cpu_event_fixed_prob_mv, nondiff_argnums=(0, 3))
+ _cpu_event_fixed_prob_mv_jvp.defjvp(_cpu_event_fixed_prob_mv_jvp_rule)
+
+
+ # Placeholder: currently identical to the CPU implementation; ``FixedProb.update``
+ # raises NotImplementedError on GPU/TPU.
+ def _gpu_event_fixed_prob_mv(indices, g_max, spk, n_post: int) -> jax.Array:
+     def scan_fn(post, i):
+         w = g_max if jnp.size(g_max) == 1 else g_max[i]
+         ids = indices[i]
+         sp = spk[i]
+         if spk.dtype == jnp.bool_:
+             post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
+         else:
+             post = jax.lax.cond(sp == 0., lambda: post, lambda: post.at[ids].add(w * sp))
+         return post, None
+
+     return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=g_max.dtype), np.arange(len(spk)))[0]
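
For orientation, here is a minimal usage sketch of the new FixedProb module added above. This is a sketch under assumptions, not part of the diff: it presumes the 0.1.0 ``brainstate.event`` namespace re-exports ``FixedProb`` (consistent with its ``__module__`` attribute) and a CPU backend, since the GPU/TPU paths above raise ``NotImplementedError``.

    import brainstate as bst

    # Sketch only: a 100 -> 200 projection where each pre-neuron reaches ~10% of post-neurons.
    layer = bst.event.FixedProb(n_pre=100, n_post=200, prob=0.1, weight=0.5, seed=42)
    spikes = bst.random.rand(100) < 0.2  # boolean spike vector
    out = layer(spikes)                  # event-driven sparse matvec, shape (200,)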
@@ -0,0 +1,128 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import jax
+ import jax.numpy as jnp
+ from absl.testing import parameterized
+
+ import brainstate as bst
+ from brainstate.event._fixed_probability import FixedProb
+
+
+ class TestFixedProbCSR(parameterized.TestCase):
+     @parameterized.product(
+         allow_multi_conn=[True, False]
+     )
+     def test1(self, allow_multi_conn):
+         x = bst.random.rand(20) < 0.1
+         # x = bst.random.rand(20)
+         m = FixedProb(20, 40, 0.1, 1.0, seed=123, allow_multi_conn=allow_multi_conn)
+         y = m(x)
+         print(y)
+
+         m2 = FixedProb(20, 40, 0.1, bst.init.KaimingUniform(), seed=123)
+         print(m2(x))
+
+     def test_grad_bool(self):
+         n_in = 20
+         n_out = 30
+         x = bst.random.rand(n_in) < 0.3
+         fn = FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)
+
+         def f(x):
+             return fn(x).sum()
+
+         with self.assertRaises(TypeError):
+             print(jax.grad(f)(x))
+
+     @parameterized.product(
+         bool_x=[True, False],
+         homo_w=[True, False]
+     )
+     def test_vjp(self, bool_x, homo_w):
+         n_in = 20
+         n_out = 30
+         if bool_x:
+             x = jnp.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+         else:
+             x = bst.random.rand(n_in)
+
+         if homo_w:
+             fn = FixedProb(n_in, n_out, 0.1, 1.5, seed=123)
+         else:
+             fn = FixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)
+         w = fn.weight.value
+
+         def f(x, w):
+             fn.weight.value = w
+             return fn(x).sum()
+
+         r = bst.transform.grad(f, argnums=(0, 1))(x, w)
+
+         # -------------------
+         # TRUE gradients
+
+         def true_fn(x, w, indices, n_post):
+             post = jnp.zeros((n_post,))
+             for i in range(n_in):
+                 post = post.at[indices[i]].add(w * x[i] if homo_w else w[i] * x[i])
+             return post
+
+         def f2(x, w):
+             return true_fn(x, w, fn.indices, n_out).sum()
+
+         r2 = jax.grad(f2, argnums=(0, 1))(x, w)
+         self.assertTrue(jnp.allclose(r[0], r2[0]))
+         self.assertTrue(jnp.allclose(r[1], r2[1]))
+         print(r[1])
+
+     @parameterized.product(
+         bool_x=[True, False],
+         homo_w=[True, False]
+     )
+     def test_jvp(self, bool_x, homo_w):
+         n_in = 20
+         n_out = 30
+         if bool_x:
+             x = jnp.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+         else:
+             x = bst.random.rand(n_in)
+
+         fn = FixedProb(n_in, n_out, 0.1, 1.5 if homo_w else bst.init.KaimingUniform(), seed=123, grad_mode='jvp')
+         w = fn.weight.value
+
+         def f(x, w):
+             fn.weight.value = w
+             return fn(x)
+
+         o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+
+         # -------------------
+         # TRUE gradients
+
+         def true_fn(x, w, indices, n_post):
+             post = jnp.zeros((n_post,))
+             for i in range(n_in):
+                 post = post.at[indices[i]].add(w * x[i] if homo_w else w[i] * x[i])
+             return post
+
+         def f2(x, w):
+             return true_fn(x, w, fn.indices, n_out)
+
+         o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+         self.assertTrue(jnp.allclose(r1, r2))
+         self.assertTrue(jnp.allclose(o1, o2))
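
The tests above validate the custom VJP/JVP rules against a dense scatter-add reference. A condensed, self-contained version of that equivalence check (plain JAX; the names and shapes here are hypothetical, not from the diff) might look like:

    import jax
    import jax.numpy as jnp

    def dense_ref(w, spk, indices, n_post):
        # Every pre-neuron i scatter-adds w[i] * spk[i] onto its sampled post indices.
        post = jnp.zeros((n_post,))
        for i in range(spk.shape[0]):
            post = post.at[indices[i]].add(w[i] * spk[i])
        return post

    indices = jax.random.randint(jax.random.PRNGKey(0), (20, 3), 0, 30)  # 20 pre, 3 conns each
    w = jax.random.normal(jax.random.PRNGKey(1), (20, 3))
    spk = (jax.random.uniform(jax.random.PRNGKey(2), (20,)) > 0.7).astype(jnp.float32)

    # Gradients w.r.t. both the weights and the (float-encoded) spikes.
    grads = jax.grad(lambda w, s: dense_ref(w, s, indices, 30).sum(), argnums=(0, 1))(w, spk)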
@@ -0,0 +1,219 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from __future__ import annotations
+
+ from typing import Union, Callable, Optional
+
+ import brainunit as u
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from brainstate._state import ParamState, State
+ from brainstate._utils import set_module_as
+ from brainstate.init import param
+ from brainstate.nn._module import Module
+ from brainstate.typing import ArrayLike
+ from ._misc import IntScalar
+
+ __all__ = [
+     'Linear',
+ ]
+
+
+ class Linear(Module):
+     """
+     The Linear module implements an event-driven dense linear connection, in which
+     every pre-synaptic neuron projects to every post-synaptic neuron.
+
+     Parameters
+     ----------
+     n_pre : int
+         Number of pre-synaptic neurons.
+     n_post : int
+         Number of post-synaptic neurons.
+     weight : float or callable or jax.Array or brainunit.Quantity
+         Maximum synaptic conductance.
+     name : str, optional
+         Name of the module.
+     grad_mode : str, optional
+         Gradient mode, either ``'vjp'`` (default) or ``'jvp'``.
+     """
+
+     __module__ = 'brainstate.event'
+
+     def __init__(
+         self,
+         n_pre: IntScalar,
+         n_post: IntScalar,
+         weight: Union[Callable, ArrayLike],
+         name: Optional[str] = None,
+         grad_mode: str = 'vjp'
+     ):
+         super().__init__(name=name)
+         self.n_pre = n_pre
+         self.n_post = n_post
+         self.in_size = n_pre
+         self.out_size = n_post
+
+         assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
+         self.grad_mode = grad_mode
+
+         # maximum synaptic conductance
+         weight = param(weight, (self.n_pre, self.n_post), allow_none=False)
+         self.weight = ParamState(weight)
+
+     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
+         weight = self.weight.value if isinstance(self.weight, State) else self.weight
+         if u.math.size(weight) == 1:
+             return u.math.ones(self.n_post) * (u.math.sum(spk) * weight)
+
+         device_kind = jax.devices()[0].platform  # spk.device.device_kind
+         if device_kind == 'cpu':
+             return cpu_event_linear(u.math.asarray(weight),
+                                     u.math.asarray(spk),
+                                     n_post=self.n_post,
+                                     grad_mode=self.grad_mode)
+         elif device_kind in ['gpu', 'tpu']:
+             raise NotImplementedError()
+         else:
+             raise ValueError(f"Unsupported device: {device_kind}")
+
+
+ @set_module_as('brainstate.event')
+ def cpu_event_linear(
+     g_max: Union[u.Quantity, jax.Array],
+     spk: jax.Array,
+     *,
+     n_post: Optional[int] = None,
+     grad_mode: str = 'vjp'
+ ) -> Union[u.Quantity, jax.Array]:
+     """
+     Compute the post-synaptic data of an event-driven dense linear connection on CPU.
+
+     Parameters
+     ----------
+     g_max : brainunit.Quantity or jax.Array
+         Maximum synaptic conductance.
+     spk : jax.Array
+         Spike events.
+     n_post : int, optional
+         Number of post-synaptic neurons. Required when ``g_max`` is homogeneous (0D).
+     grad_mode : str, optional
+         Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.
+
+     Returns
+     -------
+     post_data : brainunit.Quantity or jax.Array
+         Post-synaptic data.
+     """
+     unit = u.get_unit(g_max)
+     g_max = u.get_mantissa(g_max)
+     spk = jnp.asarray(spk)
+
+     def mv(spk_vector):
+         assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk_vector.ndim}"
+         if jnp.size(g_max) == 1:
+             assert isinstance(n_post, int), f"n_post must be an integer when weight is homogeneous. Got: {n_post}"
+             # equivalent: jnp.full((n_post,), jnp.sum(spk_vector) * g_max)
+             return jnp.ones((n_post,), dtype=g_max.dtype) * (jnp.sum(spk_vector) * g_max)
+
+         if grad_mode == 'vjp':
+             post = _cpu_event_linear_mv_vjp(g_max, spk_vector)
+         elif grad_mode == 'jvp':
+             post = _cpu_event_linear_mv_jvp(g_max, spk_vector)
+         else:
+             raise ValueError(f"Unsupported grad_mode: {grad_mode}")
+         return post
+
+     assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
+     assert g_max.ndim in [2, 0], f"weight must be 2D or 0D. Got: {g_max.ndim}"
+
+     if spk.ndim == 1:
+         post_data = mv(spk)
+     else:
+         shape = spk.shape[:-1]
+         post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
+         post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
+     return u.maybe_decimal(u.Quantity(post_data, unit=unit))
+
+
+ # --------------
+ # Implementation
+ # --------------
+
+
+ def _cpu_event_linear_mv(g_max, spk) -> jax.Array:
+     def scan_fn(post, i):
+         sp = spk[i]
+         if spk.dtype == jnp.bool_:
+             post = jax.lax.cond(sp, lambda: post + g_max[i], lambda: post)
+         else:
+             post = jax.lax.cond(sp == 0., lambda: post, lambda: post + g_max[i] * sp)
+         return post, None
+
+     return jax.lax.scan(scan_fn, jnp.zeros(g_max.shape[1], dtype=g_max.dtype), np.arange(len(spk)))[0]
+
+
+ # --------------
+ # VJP
+ # --------------
+
+ def _cpu_event_linear_mv_fwd(g_max, spk):
+     return _cpu_event_linear_mv(g_max, spk), (g_max, spk)
+
+
+ def _cpu_event_linear_mv_bwd(res, ct):
+     g_max, spk = res
+
+     # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
+     ct_spk = jnp.matmul(g_max, ct)
+
+     # ∂L/∂w = ∂L/∂y * ∂y/∂w
+     def map_fn(sp):
+         if spk.dtype == jnp.bool_:
+             d_gmax = jax.lax.cond(sp, lambda: ct, lambda: jnp.zeros_like(ct))
+         else:
+             d_gmax = jax.lax.cond(sp == 0., lambda: jnp.zeros_like(ct), lambda: ct * sp)
+         return d_gmax
+
+     ct_gmax = jax.vmap(map_fn)(spk)
+     return ct_gmax, ct_spk
+
+
+ _cpu_event_linear_mv_vjp = jax.custom_vjp(_cpu_event_linear_mv)
+ _cpu_event_linear_mv_vjp.defvjp(_cpu_event_linear_mv_fwd, _cpu_event_linear_mv_bwd)
+
+
+ # --------------
+ # JVP
+ # --------------
+
+
+ def _cpu_event_linear_mv_jvp_rule(primals, tangents):
+     # forward pass
+     g_max, spk = primals
+     y = _cpu_event_linear_mv(g_max, spk)
+
+     # forward gradients
+     gmax_dot, spk_dot = tangents
+
+     # ∂y/∂gmax
+     dgmax = _cpu_event_linear_mv(gmax_dot, spk)
+
+     # ∂y/∂spk
+     dspk = spk_dot @ g_max
+     return y, dgmax + dspk
+
+
+ _cpu_event_linear_mv_jvp = jax.custom_jvp(_cpu_event_linear_mv)
+ _cpu_event_linear_mv_jvp.defjvp(_cpu_event_linear_mv_jvp_rule)
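
A matching sketch for the dense event-driven Linear module above, under the same assumptions as the ``FixedProb`` sketch (0.1.0 ``brainstate.event`` namespace, CPU backend):

    import brainstate as bst

    # Sketch only: dense event-driven readout from 50 spiking units to 10 outputs.
    layer = bst.event.Linear(50, 10, weight=bst.init.KaimingUniform())
    spikes = bst.random.rand(50) < 0.3
    out = layer(spikes)  # sums the weight rows of active pre-neurons, shape (10,)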