brainstate 0.0.2.post20240913__py2.py3-none-any.whl → 0.0.2.post20241010__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. brainstate/__init__.py +4 -2
  2. brainstate/_module.py +102 -67
  3. brainstate/_state.py +2 -2
  4. brainstate/_visualization.py +47 -0
  5. brainstate/environ.py +116 -9
  6. brainstate/environ_test.py +56 -0
  7. brainstate/functional/_activations.py +134 -56
  8. brainstate/functional/_activations_test.py +331 -0
  9. brainstate/functional/_normalization.py +21 -10
  10. brainstate/init/_generic.py +4 -2
  11. brainstate/mixin.py +1 -1
  12. brainstate/nn/__init__.py +7 -2
  13. brainstate/nn/_base.py +2 -2
  14. brainstate/nn/_connections.py +4 -4
  15. brainstate/nn/_dynamics.py +5 -5
  16. brainstate/nn/_elementwise.py +9 -9
  17. brainstate/nn/_embedding.py +3 -3
  18. brainstate/nn/_normalizations.py +3 -3
  19. brainstate/nn/_others.py +2 -2
  20. brainstate/nn/_poolings.py +6 -6
  21. brainstate/nn/_rate_rnns.py +1 -1
  22. brainstate/nn/_readout.py +1 -1
  23. brainstate/nn/_synouts.py +1 -1
  24. brainstate/nn/event/__init__.py +25 -0
  25. brainstate/nn/event/_misc.py +34 -0
  26. brainstate/nn/event/csr.py +312 -0
  27. brainstate/nn/event/csr_test.py +118 -0
  28. brainstate/nn/event/fixed_probability.py +276 -0
  29. brainstate/nn/event/fixed_probability_test.py +127 -0
  30. brainstate/nn/event/linear.py +220 -0
  31. brainstate/nn/event/linear_test.py +111 -0
  32. brainstate/nn/metrics.py +390 -0
  33. brainstate/optim/__init__.py +5 -1
  34. brainstate/optim/_optax_optimizer.py +208 -0
  35. brainstate/optim/_optax_optimizer_test.py +14 -0
  36. brainstate/random/__init__.py +24 -0
  37. brainstate/{random.py → random/_rand_funs.py} +7 -1596
  38. brainstate/random/_rand_seed.py +169 -0
  39. brainstate/random/_rand_state.py +1498 -0
  40. brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
  41. brainstate/{random_test.py → random/random_test.py} +208 -191
  42. brainstate/transform/_jit.py +1 -1
  43. brainstate/transform/_jit_test.py +19 -0
  44. brainstate/transform/_make_jaxpr.py +1 -1
  45. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/METADATA +1 -1
  46. brainstate-0.0.2.post20241010.dist-info/RECORD +87 -0
  47. brainstate-0.0.2.post20240913.dist-info/RECORD +0 -70
  48. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/LICENSE +0 -0
  49. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/WHEEL +0 -0
  50. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/top_level.txt +0 -0
brainstate/nn/event/fixed_probability.py
@@ -0,0 +1,276 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import Union, Callable, Optional
+
+ import brainunit as u
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from brainstate._state import ParamState, State
+ from brainstate.init import param
+ from brainstate.mixin import Mode, Training
+ from brainstate.nn._base import DnnLayer
+ from brainstate.random import RandomState
+ from brainstate.transform import for_loop
+ from brainstate.typing import ArrayLike
+ from ._misc import FloatScalar, IntScalar
+
+ __all__ = [
+     'EventFixedProb',
+ ]
+
+
+ class EventFixedProb(DnnLayer):
+ """
39
+ The EventFixedProb module implements a fixed probability connection with CSR sparse data structure.
40
+
41
+ Parameters
42
+ ----------
43
+ n_pre : int
44
+ Number of pre-synaptic neurons.
45
+ n_post : int
46
+ Number of post-synaptic neurons.
47
+ prob : float
48
+ Probability of connection.
49
+ weight : float or callable or jax.Array or brainunit.Quantity
50
+ Maximum synaptic conductance.
51
+ allow_multi_conn : bool, optional
52
+ Whether multiple connections are allowed from a single pre-synaptic neuron.
53
+ Default is True, meaning that a value of ``a`` can be selected multiple times.
54
+ prob : float
55
+ Probability of connection.
56
+ name : str, optional
57
+ Name of the module.
58
+ mode : brainstate.mixin.Mode, optional
59
+ Mode of the module.
60
+ """
61
+
+     def __init__(
+         self,
+         n_pre: IntScalar,
+         n_post: IntScalar,
+         prob: FloatScalar,
+         weight: Union[Callable, ArrayLike],
+         allow_multi_conn: bool = True,
+         seed: Optional[int] = None,
+         name: Optional[str] = None,
+         mode: Optional[Mode] = None,
+         grad_mode: str = 'vjp'
+     ):
+         super().__init__(name=name, mode=mode)
+         self.n_pre = n_pre
+         self.n_post = n_post
+         self.in_size = n_pre
+         self.out_size = n_post
+
+         self.n_conn = int(n_post * prob)
+         if self.n_conn < 1:
+             raise ValueError(f"The number of connections must be at least 1. Got: int({n_post} * {prob}) = {self.n_conn}")
+
+         assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
+         self.grad_mode = grad_mode
+
+         # indices of post connected neurons
+         if allow_multi_conn:
+             self.indices = np.random.RandomState(seed).randint(0, n_post, size=(self.n_pre, self.n_conn))
+         else:
+             rng = RandomState(seed)
+             self.indices = for_loop(lambda i: rng.choice(n_post, size=(self.n_conn,), replace=False), np.arange(n_pre))
+
+         # maximum synaptic conductance
+         weight = param(weight, (self.n_pre, self.n_conn), allow_none=False)
+         if self.mode.has(Training):
+             weight = ParamState(weight)
+         self.weight = weight
+
+     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
+         weight = self.weight.value if isinstance(self.weight, State) else self.weight
+         device_kind = jax.devices()[0].platform  # spk.device.device_kind
+         if device_kind == 'cpu':
+             return cpu_event_fixed_prob(u.math.asarray(self.indices),
+                                         u.math.asarray(weight),
+                                         u.math.asarray(spk),
+                                         n_post=self.n_post, grad_mode=self.grad_mode)
+         elif device_kind in ['gpu', 'tpu']:
+             raise NotImplementedError()
+         else:
+             raise ValueError(f"Unsupported device: {device_kind}")
+
+
+ def cpu_event_fixed_prob(
+     indices: jax.Array,
+     weight: Union[u.Quantity, jax.Array],
+     spk: jax.Array,
+     *,
+     n_post: int,
+     grad_mode: str = 'vjp'
+ ) -> Union[u.Quantity, jax.Array]:
+ """
123
+ The EventFixedProb module implements a fixed probability connection with CSR sparse data structure.
124
+
125
+ Parameters
126
+ ----------
127
+ n_post : int
128
+ Number of post-synaptic neurons.
129
+ weight : brainunit.Quantity or jax.Array
130
+ Maximum synaptic conductance.
131
+ spk : jax.Array
132
+ Spike events.
133
+ indices : jax.Array
134
+ Indices of post connected neurons.
135
+ grad_mode : str, optional
136
+ Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.
137
+
138
+ Returns
139
+ -------
140
+ post_data : brainunit.Quantity or jax.Array
141
+ Post synaptic data.
142
+ """
143
+ unit = u.get_unit(weight)
144
+ weight = u.get_mantissa(weight)
145
+ indices = jnp.asarray(indices)
146
+ spk = jnp.asarray(spk)
147
+
148
+ def mv(spk_vector):
149
+ assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
150
+ if grad_mode == 'vjp':
151
+ post_data = _cpu_event_fixed_prob_mv_vjp(indices, weight, spk_vector, n_post)
152
+ elif grad_mode == 'jvp':
153
+ post_data = _cpu_event_fixed_prob_mv_jvp(indices, weight, spk_vector, n_post)
154
+ else:
155
+ raise ValueError(f"Unsupported grad_mode: {grad_mode}")
156
+ return post_data
157
+
158
+ assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
159
+ assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"
160
+ assert indices.ndim == 2, f"indices must be 2D. Got: {indices.ndim}"
161
+
162
+ if spk.ndim == 1:
163
+ post_data = mv(spk)
164
+ else:
165
+ shape = spk.shape[:-1]
166
+ post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
167
+ post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
168
+ return u.maybe_decimal(u.Quantity(post_data, unit=unit))
169
+
+
+ # -------------------
+ # CPU Implementation
+ # -------------------
+
+
+ def _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post: int) -> jax.Array:
+     def scan_fn(post, i):
+         w = g_max if jnp.size(g_max) == 1 else g_max[i]
+         ids = indices[i]
+         sp = spk[i]
+         if spk.dtype == jnp.bool_:
+             post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
+         else:
+             post = jax.lax.cond(sp == 0., lambda: post, lambda: post.at[ids].add(w * sp))
+         return post, None
+
+     return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=g_max.dtype), np.arange(len(spk)))[0]
+
+
+ # --------------
+ # VJP
+ # --------------
+
+ def _cpu_event_fixed_prob_mv_fwd(indices, g_max, spk, n_post):
+     return _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post=n_post), (g_max, spk)
+
+
+ def _cpu_event_fixed_prob_mv_bwd(indices, n_post, res, ct):
+     weight, spk = res
+
+     # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
+     homo = jnp.size(weight) == 1
+     if homo:  # homogeneous weight
+         ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weight))(indices)
+     else:  # heterogeneous weight
+         ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weight)
+
+     # ∂L/∂w = ∂L/∂y * ∂y/∂w
+     if homo:  # scalar
+         ct_gmax = _cpu_event_fixed_prob_mv(indices, jnp.asarray(1.), spk, n_post=n_post)
+         ct_gmax = jnp.inner(ct, ct_gmax)
+     else:
+         def scan_fn(d_gmax, i):
+             if spk.dtype == jnp.bool_:
+                 d_gmax = jax.lax.cond(spk[i], lambda: d_gmax.at[i].add(ct[indices[i]]), lambda: d_gmax)
+             else:
+                 d_gmax = jax.lax.cond(spk[i] == 0., lambda: d_gmax, lambda: d_gmax.at[i].add(ct[indices[i]] * spk[i]))
+             return d_gmax, None
+
+         ct_gmax = jax.lax.scan(scan_fn, jnp.zeros_like(weight), np.arange(len(spk)))[0]
+     return ct_gmax, ct_spk
+
+
+ _cpu_event_fixed_prob_mv_vjp = jax.custom_vjp(_cpu_event_fixed_prob_mv, nondiff_argnums=(0, 3))
+ _cpu_event_fixed_prob_mv_vjp.defvjp(_cpu_event_fixed_prob_mv_fwd, _cpu_event_fixed_prob_mv_bwd)
+
+
+ # --------------
+ # JVP
+ # --------------
+
+
+ def _cpu_event_fixed_prob_mv_jvp_rule(indices, n_post, primals, tangents):
+     # forward pass
+     weight, spk = primals
+     y = _cpu_event_fixed_prob_mv(indices, weight, spk, n_post=n_post)
+
+     # forward tangents
+     gmax_dot, spk_dot = tangents
+
+     # ∂y/∂gmax
+     dgmax = _cpu_event_fixed_prob_mv(indices, gmax_dot, spk, n_post=n_post)
+
+     def scan_fn(post, i):
+         ids = indices[i]
+         w = weight if jnp.size(weight) == 1 else weight[i]
+         post = post.at[ids].add(w * spk_dot[i])
+         return post, None
+
+     # ∂y/∂spk
+     dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
+     return y, dgmax + dspk
+
+
+ _cpu_event_fixed_prob_mv_jvp = jax.custom_jvp(_cpu_event_fixed_prob_mv, nondiff_argnums=(0, 3))
+ _cpu_event_fixed_prob_mv_jvp.defjvp(_cpu_event_fixed_prob_mv_jvp_rule)
+
+
+
+
+
+
+ def _gpu_event_fixed_prob_mv(indices, g_max, spk, n_post: int) -> jax.Array:
+     def scan_fn(post, i):
+         w = g_max if jnp.size(g_max) == 1 else g_max[i]
+         ids = indices[i]
+         sp = spk[i]
+         if spk.dtype == jnp.bool_:
+             post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
+         else:
+             post = jax.lax.cond(sp == 0., lambda: post, lambda: post.at[ids].add(w * sp))
+         return post, None
+
+     return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=g_max.dtype), np.arange(len(spk)))[0]
+
+
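For orientation, a minimal usage sketch of the new EventFixedProb layer (illustrative, not taken from the package itself). It assumes the import path brainstate.nn.event.fixed_probability shown above and CPU execution, which is the only backend implemented in this version; the test file below exercises the same API more thoroughly.

    import brainstate as bst
    from brainstate.nn.event.fixed_probability import EventFixedProb

    # 20 pre-synaptic neurons project to 40 post-synaptic neurons with 10% connection
    # probability and a homogeneous maximum conductance of 1.0; `seed` fixes the connectivity.
    layer = EventFixedProb(20, 40, prob=0.1, weight=1.0, seed=123)

    spikes = bst.random.rand(20) < 0.1   # boolean spike vector of length n_pre
    post = layer(spikes)                 # event-driven projection, shape (40,)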
brainstate/nn/event/fixed_probability_test.py
@@ -0,0 +1,127 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ import jax.numpy
+ import jax.numpy as jnp
+ from absl.testing import parameterized
+
+ import brainstate as bst
+ from brainstate.nn.event.fixed_probability import EventFixedProb
+
+
+ class TestFixedProbCSR(parameterized.TestCase):
+     @parameterized.product(
+         allow_multi_conn=[True, False]
+     )
+     def test1(self, allow_multi_conn):
+         x = bst.random.rand(20) < 0.1
+         # x = bst.random.rand(20)
+         m = EventFixedProb(20, 40, 0.1, 1.0, seed=123, allow_multi_conn=allow_multi_conn)
+         y = m(x)
+         print(y)
+
+         m2 = EventFixedProb(20, 40, 0.1, bst.init.KaimingUniform(), seed=123)
+         print(m2(x))
+
+     def test_grad_bool(self):
+         n_in = 20
+         n_out = 30
+         x = bst.random.rand(n_in) < 0.3
+         fn = EventFixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)
+
+         def f(x):
+             return fn(x).sum()
+
+         with self.assertRaises(TypeError):
+             print(jax.grad(f)(x))
+
+     @parameterized.product(
+         bool_x=[True, False],
+         homo_w=[True, False]
+     )
+     def test_vjp(self, bool_x, homo_w):
+         n_in = 20
+         n_out = 30
+         if bool_x:
+             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+         else:
+             x = bst.random.rand(n_in)
+
+         if homo_w:
+             fn = EventFixedProb(n_in, n_out, 0.1, 1.5, seed=123)
+         else:
+             fn = EventFixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)
+         w = fn.weight
+
+         def f(x, w):
+             fn.weight = w
+             return fn(x).sum()
+
+         r = bst.transform.grad(f, argnums=(0, 1))(x, w)
+
+         # -------------------
+         # TRUE gradients
+
+         def true_fn(x, w, indices, n_post):
+             post = jnp.zeros((n_post,))
+             for i in range(n_in):
+                 post = post.at[indices[i]].add(w * x[i] if homo_w else w[i] * x[i])
+             return post
+
+         def f2(x, w):
+             return true_fn(x, w, fn.indices, n_out).sum()
+
+         r2 = jax.grad(f2, argnums=(0, 1))(x, w)
+         self.assertTrue(jnp.allclose(r[0], r2[0]))
+         self.assertTrue(jnp.allclose(r[1], r2[1]))
+         print(r[1])
+
+     @parameterized.product(
+         bool_x=[True, False],
+         homo_w=[True, False]
+     )
+     def test_jvp(self, bool_x, homo_w):
+         n_in = 20
+         n_out = 30
+         if bool_x:
+             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+         else:
+             x = bst.random.rand(n_in)
+
+         fn = EventFixedProb(n_in, n_out, 0.1, 1.5 if homo_w else bst.init.KaimingUniform(), seed=123, grad_mode='jvp')
+         w = fn.weight
+
+         def f(x, w):
+             fn.weight = w
+             return fn(x)
+
+         o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+
+         # -------------------
+         # TRUE gradients
+
+         def true_fn(x, w, indices, n_post):
+             post = jnp.zeros((n_post,))
+             for i in range(n_in):
+                 post = post.at[indices[i]].add(w * x[i] if homo_w else w[i] * x[i])
+             return post
+
+         def f2(x, w):
+             return true_fn(x, w, fn.indices, n_out)
+
+         o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+         self.assertTrue(jnp.allclose(r1, r2))
+         self.assertTrue(jnp.allclose(o1, o2))
brainstate/nn/event/linear.py
@@ -0,0 +1,220 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import Union, Callable, Optional
+
+ import brainunit as u
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from brainstate._state import ParamState, State
+ from brainstate.init import param
+ from brainstate.mixin import Mode, Training
+ from brainstate.nn._base import DnnLayer
+ from brainstate.typing import ArrayLike
+ from ._misc import IntScalar
+
+ __all__ = [
+     'EventDense',
+ ]
+
+
+ class EventDense(DnnLayer):
+ """
37
+ The EventFixedProb module implements a fixed probability connection with CSR sparse data structure.
38
+
39
+ Parameters
40
+ ----------
41
+ n_pre : int
42
+ Number of pre-synaptic neurons.
43
+ n_post : int
44
+ Number of post-synaptic neurons.
45
+ weight : float or callable or jax.Array or brainunit.Quantity
46
+ Maximum synaptic conductance.
47
+ name : str, optional
48
+ Name of the module.
49
+ mode : brainstate.mixin.Mode, optional
50
+ Mode of the module.
51
+ """
52
+
+     def __init__(
+         self,
+         n_pre: IntScalar,
+         n_post: IntScalar,
+         weight: Union[Callable, ArrayLike],
+         name: Optional[str] = None,
+         mode: Optional[Mode] = None,
+         grad_mode: str = 'vjp'
+     ):
+         super().__init__(name=name, mode=mode)
+         self.n_pre = n_pre
+         self.n_post = n_post
+         self.in_size = n_pre
+         self.out_size = n_post
+
+         assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
+         self.grad_mode = grad_mode
+
+         # maximum synaptic conductance
+         weight = param(weight, (self.n_pre, self.n_post), allow_none=False)
+         if self.mode.has(Training):
+             weight = ParamState(weight)
+         self.weight = weight
+
+     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
+         weight = self.weight.value if isinstance(self.weight, State) else self.weight
+         # if u.math.size(weight) == 1:
+         #     return u.math.ones(self.n_post) * (u.math.sum(spk) * weight)
+
+         device_kind = jax.devices()[0].platform  # spk.device.device_kind
+         if device_kind == 'cpu':
+             return cpu_event_linear(u.math.asarray(weight),
+                                     u.math.asarray(spk),
+                                     n_post=self.n_post,
+                                     grad_mode=self.grad_mode)
+         elif device_kind in ['gpu', 'tpu']:
+             raise NotImplementedError()
+         else:
+             raise ValueError(f"Unsupported device: {device_kind}")
+
+
+ def cpu_event_linear(
+     g_max: Union[u.Quantity, jax.Array],
+     spk: jax.Array,
+     *,
+     n_post: Optional[int] = None,
+     grad_mode: str = 'vjp'
+ ) -> Union[u.Quantity, jax.Array]:
+     """
+     Event-driven dense matrix-vector product for a linear (all-to-all) connection (CPU implementation).
+
+     Parameters
+     ----------
+     g_max : brainunit.Quantity or jax.Array
+         Maximum synaptic conductance.
+     spk : jax.Array
+         Spike events.
+     n_post : int, optional
+         Number of post-synaptic neurons. Required when ``g_max`` is a scalar (homogeneous weight).
+     grad_mode : str, optional
+         Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.
+
+     Returns
+     -------
+     post_data : brainunit.Quantity or jax.Array
+         Post-synaptic data.
+     """
+     unit = u.get_unit(g_max)
+     g_max = u.get_mantissa(g_max)
+     spk = jnp.asarray(spk)
+
+     def mv(spk_vector):
+         assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk_vector.ndim}"
+         if jnp.size(g_max) == 1:
+             assert isinstance(n_post, int), f"n_post must be an integer when the weight is homogeneous. Got: {n_post}"
+             # return jnp.full((n_post,), fill_value=jnp.sum(spk_vector) * g_max)
+             return jnp.ones((n_post,), dtype=g_max.dtype) * (jnp.sum(spk_vector) * g_max)
+
+         if grad_mode == 'vjp':
+             post = _cpu_event_linear_mv_vjp(g_max, spk_vector)
+         elif grad_mode == 'jvp':
+             post = _cpu_event_linear_mv_jvp(g_max, spk_vector)
+         else:
+             raise ValueError(f"Unsupported grad_mode: {grad_mode}")
+         return post
+
+     assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
+     assert g_max.ndim in [2, 0], f"g_max must be 2D or 0D. Got: {g_max.ndim}"
+
+     if spk.ndim == 1:
+         post_data = mv(spk)
+     else:
+         shape = spk.shape[:-1]
+         post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
+         post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
+     return u.maybe_decimal(u.Quantity(post_data, unit=unit))
+
+
+ # --------------
+ # Implementation
+ # --------------
+
+
+ def _cpu_event_linear_mv(g_max, spk) -> jax.Array:
+     def scan_fn(post, i):
+         sp = spk[i]
+         if spk.dtype == jnp.bool_:
+             post = jax.lax.cond(sp, lambda: post + g_max[i], lambda: post)
+         else:
+             post = jax.lax.cond(sp == 0., lambda: post, lambda: post + g_max[i] * sp)
+         return post, None
+
+     return jax.lax.scan(scan_fn, jnp.zeros(g_max.shape[1], dtype=g_max.dtype), np.arange(len(spk)))[0]
+
+
+ # --------------
+ # VJP
+ # --------------
+
+ def _cpu_event_linear_mv_fwd(g_max, spk):
+     return _cpu_event_linear_mv(g_max, spk), (g_max, spk)
+
+
+ def _cpu_event_linear_mv_bwd(res, ct):
+     g_max, spk = res
+
+     # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
+     ct_spk = jnp.matmul(g_max, ct)
+
+     # ∂L/∂w = ∂L/∂y * ∂y/∂w
+     def map_fn(sp):
+         if spk.dtype == jnp.bool_:
+             d_gmax = jax.lax.cond(sp, lambda: ct, lambda: jnp.zeros_like(ct))
+         else:
+             d_gmax = jax.lax.cond(sp == 0., lambda: jnp.zeros_like(ct), lambda: ct * sp)
+         return d_gmax
+
+     ct_gmax = jax.vmap(map_fn)(spk)
+     return ct_gmax, ct_spk
+
+
+ _cpu_event_linear_mv_vjp = jax.custom_vjp(_cpu_event_linear_mv)
+ _cpu_event_linear_mv_vjp.defvjp(_cpu_event_linear_mv_fwd, _cpu_event_linear_mv_bwd)
+
+
+ # --------------
+ # JVP
+ # --------------
+
+
+ def _cpu_event_linear_mv_jvp_rule(primals, tangents):
+     # forward pass
+     g_max, spk = primals
+     y = _cpu_event_linear_mv(g_max, spk)
+
+     # forward tangents
+     gmax_dot, spk_dot = tangents
+
+     # ∂y/∂gmax
+     dgmax = _cpu_event_linear_mv(gmax_dot, spk)
+
+     # ∂y/∂spk
+     dspk = spk_dot @ g_max
+     return y, dgmax + dspk
+
+
+ _cpu_event_linear_mv_jvp = jax.custom_jvp(_cpu_event_linear_mv)
+ _cpu_event_linear_mv_jvp.defjvp(_cpu_event_linear_mv_jvp_rule)
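Likewise, a minimal usage sketch for the new EventDense layer (illustrative, not taken from the package itself), assuming the module path brainstate/nn/event/linear.py added above and CPU execution:

    import brainstate as bst
    from brainstate.nn.event.linear import EventDense

    # dense (all-to-all) event-driven projection from 20 pre- to 30 post-synaptic neurons,
    # with heterogeneous weights drawn from a Kaiming-uniform initializer
    layer = EventDense(20, 30, bst.init.KaimingUniform())

    spikes = bst.random.rand(20) < 0.3   # boolean spike vector of length n_pre
    post = layer(spikes)                 # post-synaptic data, shape (30,)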