brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -1,312 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Union, Callable, Optional
17
-
18
- import brainunit as u
19
- import jax
20
- import jax.numpy as jnp
21
- import numpy as np
22
-
23
- from brainstate._state import ParamState, State
24
- from brainstate.init import param
25
- from brainstate.mixin import Mode, Training
26
- from brainstate.nn._base import DnnLayer
27
- from brainstate.typing import ArrayLike
28
- from ._misc import IntScalar
29
-
30
- __all__ = [
31
- 'EventCSR',
32
- ]
33
-
34
-
35
- class EventCSR(DnnLayer):
36
- """
37
- The EventCSR module implements a fixed probability connection with CSR sparse data structure.
38
-
39
- Parameters
40
- ----------
41
- n_pre : int
42
- Number of pre-synaptic neurons.
43
- n_post : int
44
- Number of post-synaptic neurons.
45
- weight : float or callable or jax.Array or brainunit.Quantity
46
- Maximum synaptic conductance.
47
- name : str, optional
48
- Name of the module.
49
- mode : brainstate.mixin.Mode, optional
50
- Mode of the module.
51
- """
52
-
53
- def __init__(
54
- self,
55
- n_pre: IntScalar,
56
- n_post: IntScalar,
57
- indptr: ArrayLike,
58
- indices: ArrayLike,
59
- weight: Union[Callable, ArrayLike],
60
- name: Optional[str] = None,
61
- mode: Optional[Mode] = None,
62
- grad_mode: str = 'vjp'
63
- ):
64
- super().__init__(name=name, mode=mode)
65
-
66
- self.in_size = n_pre
67
- self.out_size = n_post
68
- self.n_pre = n_pre
69
- self.n_post = n_post
70
-
71
- # gradient mode
72
- assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
73
- self.grad_mode = grad_mode
74
-
75
- # CSR data structure
76
- indptr = jnp.asarray(indptr)
77
- indices = jnp.asarray(indices)
78
- assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
79
- assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
80
- assert indptr.size == n_pre + 1, f"indptr must have size {n_pre + 1}. Got: {indptr.size}"
81
- self.indptr = indptr
82
- self.indices = indices
83
-
84
- # maximum synaptic conductance
85
- weight = param(weight, (len(indices),), allow_none=False)
86
- # if callable(weight):
87
- # pass
88
- # else:
89
- # if u.math.size(weight) != 1 and u.math.size(weight) != len(self.indices):
90
- # raise ValueError(f"weight must be 1D or 2D with size {len(self.indices)}. Got: {u.math.size(weight)}")
91
- if self.mode.has(Training):
92
- weight = ParamState(weight)
93
- self.weight = weight
94
-
95
- def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
96
- weight = self.weight.value if isinstance(self.weight, State) else self.weight
97
- if len(self.indices) == 0:
98
- r = u.math.zeros(spk.shape[:-1] + (self.n_post,),
99
- dtype=weight.dtype,
100
- unit=u.get_unit(weight) * u.get_unit(spk))
101
- return u.maybe_decimal(r)
102
-
103
- device_kind = jax.devices()[0].platform # spk.device.device_kind
104
- if device_kind == 'cpu':
105
- return cpu_event_csr(
106
- u.math.asarray(spk),
107
- u.math.asarray(self.indptr),
108
- u.math.asarray(self.indices),
109
- u.math.asarray(weight),
110
- n_post=self.n_post, grad_mode=self.grad_mode
111
- )
112
- elif device_kind in ['gpu', 'tpu']:
113
- raise NotImplementedError()
114
- else:
115
- raise ValueError(f"Unsupported device: {device_kind}")
116
-
117
-
118
- def cpu_event_csr(
119
- spk: jax.Array,
120
- indptr: jax.Array,
121
- indices: jax.Array,
122
- weight: Union[u.Quantity, jax.Array],
123
- *,
124
- n_post: int,
125
- grad_mode: str = 'vjp'
126
- ) -> Union[u.Quantity, jax.Array]:
127
- """
128
- The EventCSR module implements a fixed probability connection with CSR sparse data structure.
129
-
130
- Parameters
131
- ----------
132
- spk : jax.Array
133
- Spike events.
134
- indptr : jax.Array
135
- Index pointer of post connected neurons.
136
- indices : jax.Array
137
- Indices of post connected neurons.
138
- weight : brainunit.Quantity or jax.Array
139
- Maximum synaptic conductance.
140
- n_post : int
141
- Number of post-synaptic neurons.
142
- grad_mode : str, optional
143
- Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.
144
-
145
- Returns
146
- -------
147
- post_data : brainunit.Quantity or jax.Array
148
- Post synaptic data.
149
- """
150
- unit = u.get_unit(weight)
151
- weight = u.get_mantissa(weight)
152
-
153
- def mv(spk_vector):
154
- assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
155
- if grad_mode == 'vjp':
156
- post_data = _cpu_event_csr_mv_vjp(spk_vector, indptr, indices, weight, n_post)
157
- elif grad_mode == 'jvp':
158
- post_data = _cpu_event_csr_mv_jvp(spk_vector, indptr, indices, weight, n_post)
159
- else:
160
- raise ValueError(f"Unsupported grad_mode: {grad_mode}")
161
- return post_data
162
-
163
- assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
164
- assert weight.ndim in [1, 0], f"g_max must be 1D or 0D. Got: {weight.ndim}"
165
- assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
166
-
167
- if spk.ndim == 1:
168
- post_data = mv(spk)
169
- else:
170
- shape = spk.shape[:-1]
171
- post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
172
- post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
173
- return u.maybe_decimal(u.Quantity(post_data, unit=unit))
174
-
175
-
176
- # --------------
177
- # Implementation
178
- # --------------
179
-
180
-
181
- def _cpu_event_csr_mv(
182
- spk: jax.Array,
183
- indptr: jax.Array,
184
- indices: jax.Array,
185
- weight: Union[u.Quantity, jax.Array],
186
- n_post: int
187
- ) -> jax.Array:
188
- bool_x = spk.dtype == jnp.bool_
189
- homo_w = jnp.size(weight) == 1
190
-
191
- def add_fn(post_val, i_start, i_end, sp):
192
- def body_fn(x):
193
- post, i = x
194
- i_post = indices[i]
195
- w = weight if homo_w else weight[i]
196
- w = w if bool_x else w * sp
197
- post = post.at[i_post].add(w)
198
- return post, i + 1
199
-
200
- return jax.lax.while_loop(lambda x: x[1] < i_end, body_fn, (post_val, i_start))[0]
201
-
202
- def scan_fn(post, i):
203
- sp = spk[i] # pre-synaptic spike event
204
- if bool_x:
205
- post = jax.lax.cond(sp, lambda: add_fn(post, indptr[i], indptr[i + 1], sp), lambda: post)
206
- else:
207
- post = jax.lax.cond(sp == 0., lambda: post, lambda: add_fn(post, indptr[i], indptr[i + 1], sp))
208
- return post, None
209
-
210
- return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
211
-
212
-
213
- # --------------
214
- # VJP
215
- # --------------
216
-
217
- def _cpu_event_csr_mv_fwd(
218
- spk: jax.Array,
219
- indptr: jax.Array,
220
- indices: jax.Array,
221
- weight: Union[u.Quantity, jax.Array],
222
- n_post: int
223
- ):
224
- return _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post), (spk, weight)
225
-
226
-
227
- def _cpu_event_csr_mv_bwd(indptr, indices, n_post, res, ct):
228
- spk, weight = res
229
- homo = jnp.size(weight) == 1
230
- bool_spk = spk.dtype == jnp.bool_
231
-
232
- # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
233
- def fn_spk(i_pre):
234
- def body_fn(x):
235
- r, i = x
236
- i_post = indices[i]
237
- r = r + (ct[i_post] if homo else ct[i_post] * weight[i])
238
- return r, i + 1
239
-
240
- p = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (0., indptr[i_pre]))[0]
241
- p = p * weight if homo else p
242
- return p
243
-
244
- ct_spk = jax.vmap(fn_spk)(np.arange(len(spk)))
245
-
246
- # ∂L/∂w = ∂L/∂y * ∂y/∂w
247
- if homo: # scalar
248
- ct_gmax = _cpu_event_csr_mv(spk, indptr, indices, jnp.asarray(1.), n_post=n_post)
249
- ct_gmax = jnp.inner(ct, ct_gmax)
250
- else:
251
- def single_post(dw, i_pre):
252
- def body_fn(x):
253
- dw, i = x
254
- i_post = indices[i]
255
- dw = dw.at[i].add(ct[i_post] if bool_spk else ct[i_post] * spk[i_pre])
256
- return dw, i + 1
257
-
258
- return jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (dw, indptr[i_pre]))[0]
259
-
260
- def fn_w(dw, i_pre):
261
- sp = spk[i_pre]
262
- if bool_spk:
263
- return jax.lax.cond(sp, lambda: single_post(dw, i_pre), lambda: dw), None
264
- else:
265
- return jax.lax.cond(sp == 0., lambda: dw, lambda: single_post(dw, i_pre)), None
266
-
267
- ct_gmax = jax.lax.scan(fn_w, jnp.zeros_like(weight), np.arange(len(spk)))[0]
268
- return ct_spk, ct_gmax
269
-
270
-
271
- _cpu_event_csr_mv_vjp = jax.custom_vjp(_cpu_event_csr_mv, nondiff_argnums=(1, 2, 4))
272
- _cpu_event_csr_mv_vjp.defvjp(_cpu_event_csr_mv_fwd, _cpu_event_csr_mv_bwd)
273
-
274
-
275
- # --------------
276
- # JVP
277
- # --------------
278
-
279
-
280
- def _cpu_event_csr_mv_jvp_rule(indptr, indices, n_post, primals, tangents):
281
- # forward pass
282
- spk, weight = primals
283
- y = _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post)
284
-
285
- # forward gradients
286
- spk_dot, weight_dot = tangents
287
- homo_w = jnp.size(weight) == 1
288
-
289
- # ∂y/∂gmax
290
- dweight = _cpu_event_csr_mv(spk, indptr, indices, weight_dot, n_post=n_post)
291
-
292
- # ∂y/∂gspk
293
- def scan_fn(post, i_pre):
294
- def while_fn(x):
295
- p, i, sp = x
296
- i_post = indices[i]
297
- p = p.at[i_post].add(sp if homo_w else sp * weight[i])
298
- return p, i + 1, sp
299
-
300
- post = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1],
301
- while_fn,
302
- (post, indptr[i_pre], spk_dot[i_pre]))[0]
303
-
304
- return post, None
305
-
306
- dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
307
- dspk = (dspk * weight) if homo_w else dspk
308
- return y, dweight + dspk
309
-
310
-
311
- _cpu_event_csr_mv_jvp = jax.custom_jvp(_cpu_event_csr_mv, nondiff_argnums=(1, 2, 4))
312
- _cpu_event_csr_mv_jvp.defjvp(_cpu_event_csr_mv_jvp_rule)
@@ -1,118 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import jax.numpy
18
- import jax.numpy as jnp
19
- import numpy as np
20
- from absl.testing import parameterized
21
-
22
- import brainstate as bst
23
- from brainstate.nn.event.csr import EventCSR
24
-
25
-
26
- def _get_csr(n_pre, n_post, prob):
27
- n_conn = int(n_post * prob)
28
- indptr = np.arange(n_pre + 1) * n_conn
29
- indices = np.random.randint(0, n_post, (n_pre * n_conn,))
30
- return indptr, indices
31
-
32
-
33
- def true_fn(x, w, indices, indptr, n_out):
34
- homo_w = jnp.size(w) == 1
35
-
36
- post = jnp.zeros((n_out,))
37
- for i_pre in range(x.shape[0]):
38
- ids = indices[indptr[i_pre]: indptr[i_pre + 1]]
39
- post = post.at[ids].add(w * x[i_pre] if homo_w else w[indptr[i_pre]: indptr[i_pre + 1]] * x[i_pre])
40
- return post
41
-
42
-
43
- class TestFixedProbCSR(parameterized.TestCase):
44
- @parameterized.product(
45
- homo_w=[True, False],
46
- )
47
- def test1(self, homo_w):
48
- x = bst.random.rand(20) < 0.1
49
- indptr, indices = _get_csr(20, 40, 0.1)
50
- m = EventCSR(20, 40, indptr, indices, 1.5 if homo_w else bst.init.Normal())
51
- y = m(x)
52
- y2 = true_fn(x, m.weight, indices, indptr, 40)
53
- self.assertTrue(jnp.allclose(y, y2))
54
-
55
- @parameterized.product(
56
- bool_x=[True, False],
57
- homo_w=[True, False]
58
- )
59
- def test_vjp(self, bool_x, homo_w):
60
- n_in = 20
61
- n_out = 30
62
- if bool_x:
63
- x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
64
- else:
65
- x = bst.random.rand(n_in)
66
-
67
- indptr, indices = _get_csr(n_in, n_out, 0.1)
68
- fn = bst.nn.EventCSR(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
69
- w = fn.weight
70
-
71
- def f(x, w):
72
- fn.weight = w
73
- return fn(x).sum()
74
-
75
- r = jax.grad(f, argnums=(0, 1))(x, w)
76
-
77
- # -------------------
78
- # TRUE gradients
79
-
80
- def f2(x, w):
81
- return true_fn(x, w, indices, indptr, n_out).sum()
82
-
83
- r2 = jax.grad(f2, argnums=(0, 1))(x, w)
84
- self.assertTrue(jnp.allclose(r[0], r2[0]))
85
- self.assertTrue(jnp.allclose(r[1], r2[1]))
86
-
87
- @parameterized.product(
88
- bool_x=[True, False],
89
- homo_w=[True, False]
90
- )
91
- def test_jvp(self, bool_x, homo_w):
92
- n_in = 20
93
- n_out = 30
94
- if bool_x:
95
- x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
96
- else:
97
- x = bst.random.rand(n_in)
98
-
99
- indptr, indices = _get_csr(n_in, n_out, 0.1)
100
- fn = EventCSR(n_in, n_out, indptr, indices,
101
- 1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
102
- w = fn.weight
103
-
104
- def f(x, w):
105
- fn.weight = w
106
- return fn(x)
107
-
108
- o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
109
-
110
- # -------------------
111
- # TRUE gradients
112
-
113
- def f2(x, w):
114
- return true_fn(x, w, indices, indptr, n_out)
115
-
116
- o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
117
- self.assertTrue(jnp.allclose(r1, r2))
118
- self.assertTrue(jnp.allclose(o1, o2))
@@ -1,276 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Union, Callable, Optional
17
-
18
- import brainunit as u
19
- import jax
20
- import jax.numpy as jnp
21
- import numpy as np
22
-
23
- from brainstate._state import ParamState, State
24
- from brainstate.init import param
25
- from brainstate.mixin import Mode, Training
26
- from brainstate.nn._base import DnnLayer
27
- from brainstate.random import RandomState
28
- from brainstate.transform import for_loop
29
- from brainstate.typing import ArrayLike
30
- from ._misc import FloatScalar, IntScalar
31
-
32
- __all__ = [
33
- 'EventFixedProb',
34
- ]
35
-
36
-
37
- class EventFixedProb(DnnLayer):
38
- """
39
- The EventFixedProb module implements a fixed probability connection with CSR sparse data structure.
40
-
41
- Parameters
42
- ----------
43
- n_pre : int
44
- Number of pre-synaptic neurons.
45
- n_post : int
46
- Number of post-synaptic neurons.
47
- prob : float
48
- Probability of connection.
49
- weight : float or callable or jax.Array or brainunit.Quantity
50
- Maximum synaptic conductance.
51
- allow_multi_conn : bool, optional
52
- Whether multiple connections are allowed from a single pre-synaptic neuron.
53
- Default is True, meaning that a value of ``a`` can be selected multiple times.
54
- prob : float
55
- Probability of connection.
56
- name : str, optional
57
- Name of the module.
58
- mode : brainstate.mixin.Mode, optional
59
- Mode of the module.
60
- """
61
-
62
- def __init__(
63
- self,
64
- n_pre: IntScalar,
65
- n_post: IntScalar,
66
- prob: FloatScalar,
67
- weight: Union[Callable, ArrayLike],
68
- allow_multi_conn: bool = True,
69
- seed: Optional[int] = None,
70
- name: Optional[str] = None,
71
- mode: Optional[Mode] = None,
72
- grad_mode: str = 'vjp'
73
- ):
74
- super().__init__(name=name, mode=mode)
75
- self.n_pre = n_pre
76
- self.n_post = n_post
77
- self.in_size = n_pre
78
- self.out_size = n_post
79
-
80
- self.n_conn = int(n_post * prob)
81
- if self.n_conn < 1:
82
- raise ValueError(f"The number of connections must be at least 1. Got: int({n_post} * {prob}) = {self.n_conn}")
83
-
84
- assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
85
- self.grad_mode = grad_mode
86
-
87
- # indices of post connected neurons
88
- if allow_multi_conn:
89
- self.indices = np.random.RandomState(seed).randint(0, n_post, size=(self.n_pre, self.n_conn))
90
- else:
91
- rng = RandomState(seed)
92
- self.indices = for_loop(lambda i: rng.choice(n_post, size=(self.n_conn,), replace=False), np.arange(n_pre))
93
-
94
- # maximum synaptic conductance
95
- weight = param(weight, (self.n_pre, self.n_conn), allow_none=False)
96
- if self.mode.has(Training):
97
- weight = ParamState(weight)
98
- self.weight = weight
99
-
100
- def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
101
- weight = self.weight.value if isinstance(self.weight, State) else self.weight
102
- device_kind = jax.devices()[0].platform # spk.device.device_kind
103
- if device_kind == 'cpu':
104
- return cpu_event_fixed_prob(u.math.asarray(self.indices),
105
- u.math.asarray(weight),
106
- u.math.asarray(spk),
107
- n_post=self.n_post, grad_mode=self.grad_mode)
108
- elif device_kind in ['gpu', 'tpu']:
109
- raise NotImplementedError()
110
- else:
111
- raise ValueError(f"Unsupported device: {device_kind}")
112
-
113
-
114
- def cpu_event_fixed_prob(
115
- indices: jax.Array,
116
- weight: Union[u.Quantity, jax.Array],
117
- spk: jax.Array,
118
- *,
119
- n_post: int,
120
- grad_mode: str = 'vjp'
121
- ) -> Union[u.Quantity, jax.Array]:
122
- """
123
- The EventFixedProb module implements a fixed probability connection with CSR sparse data structure.
124
-
125
- Parameters
126
- ----------
127
- n_post : int
128
- Number of post-synaptic neurons.
129
- weight : brainunit.Quantity or jax.Array
130
- Maximum synaptic conductance.
131
- spk : jax.Array
132
- Spike events.
133
- indices : jax.Array
134
- Indices of post connected neurons.
135
- grad_mode : str, optional
136
- Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.
137
-
138
- Returns
139
- -------
140
- post_data : brainunit.Quantity or jax.Array
141
- Post synaptic data.
142
- """
143
- unit = u.get_unit(weight)
144
- weight = u.get_mantissa(weight)
145
- indices = jnp.asarray(indices)
146
- spk = jnp.asarray(spk)
147
-
148
- def mv(spk_vector):
149
- assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
150
- if grad_mode == 'vjp':
151
- post_data = _cpu_event_fixed_prob_mv_vjp(indices, weight, spk_vector, n_post)
152
- elif grad_mode == 'jvp':
153
- post_data = _cpu_event_fixed_prob_mv_jvp(indices, weight, spk_vector, n_post)
154
- else:
155
- raise ValueError(f"Unsupported grad_mode: {grad_mode}")
156
- return post_data
157
-
158
- assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
159
- assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"
160
- assert indices.ndim == 2, f"indices must be 2D. Got: {indices.ndim}"
161
-
162
- if spk.ndim == 1:
163
- post_data = mv(spk)
164
- else:
165
- shape = spk.shape[:-1]
166
- post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
167
- post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
168
- return u.maybe_decimal(u.Quantity(post_data, unit=unit))
169
-
170
-
171
- # -------------------
172
- # CPU Implementation
173
- # -------------------
174
-
175
-
176
- def _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post: int) -> jax.Array:
177
- def scan_fn(post, i):
178
- w = g_max if jnp.size(g_max) == 1 else g_max[i]
179
- ids = indices[i]
180
- sp = spk[i]
181
- if spk.dtype == jnp.bool_:
182
- post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
183
- else:
184
- post = jax.lax.cond(sp == 0., lambda: post, lambda: post.at[ids].add(w * sp))
185
- return post, None
186
-
187
- return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=g_max.dtype), np.arange(len(spk)))[0]
188
-
189
-
190
- # --------------
191
- # VJP
192
- # --------------
193
-
194
- def _cpu_event_fixed_prob_mv_fwd(indices, g_max, spk, n_post):
195
- return _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post=n_post), (g_max, spk)
196
-
197
-
198
- def _cpu_event_fixed_prob_mv_bwd(indices, n_post, res, ct):
199
- weight, spk = res
200
-
201
- # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
202
- homo = jnp.size(weight) == 1
203
- if homo: # homogeneous weight
204
- ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weight))(indices)
205
- else: # heterogeneous weight
206
- ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weight)
207
-
208
- # ∂L/∂w = ∂L/∂y * ∂y/∂w
209
- if homo: # scalar
210
- ct_gmax = _cpu_event_fixed_prob_mv(indices, jnp.asarray(1.), spk, n_post=n_post)
211
- ct_gmax = jnp.inner(ct, ct_gmax)
212
- else:
213
- def scan_fn(d_gmax, i):
214
- if spk.dtype == jnp.bool_:
215
- d_gmax = jax.lax.cond(spk[i], lambda: d_gmax.at[i].add(ct[indices[i]]), lambda: d_gmax)
216
- else:
217
- d_gmax = jax.lax.cond(spk[i] == 0., lambda: d_gmax, lambda: d_gmax.at[i].add(ct[indices[i]] * spk[i]))
218
- return d_gmax, None
219
-
220
- ct_gmax = jax.lax.scan(scan_fn, jnp.zeros_like(weight), np.arange(len(spk)))[0]
221
- return ct_gmax, ct_spk
222
-
223
-
224
- _cpu_event_fixed_prob_mv_vjp = jax.custom_vjp(_cpu_event_fixed_prob_mv, nondiff_argnums=(0, 3))
225
- _cpu_event_fixed_prob_mv_vjp.defvjp(_cpu_event_fixed_prob_mv_fwd, _cpu_event_fixed_prob_mv_bwd)
226
-
227
-
228
- # --------------
229
- # JVP
230
- # --------------
231
-
232
-
233
- def _cpu_event_fixed_prob_mv_jvp_rule(indices, n_post, primals, tangents):
234
- # forward pass
235
- weight, spk = primals
236
- y = _cpu_event_fixed_prob_mv(indices, weight, spk, n_post=n_post)
237
-
238
- # forward gradients
239
- gmax_dot, spk_dot = tangents
240
-
241
- # ∂y/∂gmax
242
- dgmax = _cpu_event_fixed_prob_mv(indices, gmax_dot, spk, n_post=n_post)
243
-
244
- def scan_fn(post, i):
245
- ids = indices[i]
246
- w = weight if jnp.size(weight) == 1 else weight[i]
247
- post = post.at[ids].add(w * spk_dot[i])
248
- return post, None
249
-
250
- # ∂y/∂gspk
251
- dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
252
- return y, dgmax + dspk
253
-
254
-
255
- _cpu_event_fixed_prob_mv_jvp = jax.custom_jvp(_cpu_event_fixed_prob_mv, nondiff_argnums=(0, 3))
256
- _cpu_event_fixed_prob_mv_jvp.defjvp(_cpu_event_fixed_prob_mv_jvp_rule)
257
-
258
-
259
-
260
-
261
-
262
-
263
- def _gpu_event_fixed_prob_mv(indices, g_max, spk, n_post: int) -> jax.Array:
264
- def scan_fn(post, i):
265
- w = g_max if jnp.size(g_max) == 1 else g_max[i]
266
- ids = indices[i]
267
- sp = spk[i]
268
- if spk.dtype == jnp.bool_:
269
- post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
270
- else:
271
- post = jax.lax.cond(sp == 0., lambda: post, lambda: post.at[ids].add(w * sp))
272
- return post, None
273
-
274
- return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=g_max.dtype), np.arange(len(spk)))[0]
275
-
276
-