brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -1,127 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import jax.numpy
18
- import jax.numpy as jnp
19
- from absl.testing import parameterized
20
-
21
- import brainstate as bst
22
- from brainstate.nn.event.fixed_probability import EventFixedProb
23
-
24
-
25
class TestFixedProbCSR(parameterized.TestCase):
    """Smoke and gradient tests for ``EventFixedProb``.

    The gradient tests compare the layer against a dense scatter-add reference
    that is driven by the layer's own ``indices`` attribute.
    """

    @staticmethod
    def _random_input(size, as_events):
        """Draw a float input vector of length ``size``.

        With ``as_events`` the uniform sample is thresholded at 0.3 and cast to
        float (a 0/1 pattern); otherwise the raw uniform sample is returned.
        """
        sample = bst.random.rand(size)
        if as_events:
            return jax.numpy.asarray(sample < 0.3, dtype=float)
        return sample

    @staticmethod
    def _reference_post(x, w, indices, n_pre, n_post, homo_w):
        """Scatter-add ``w * x[i]`` (or ``w[i] * x[i]``) into a zero vector.

        Row ``i`` of ``indices`` selects the output positions that receive the
        contribution of input ``i``.
        """
        acc = jnp.zeros((n_post,))
        for row in range(n_pre):
            contribution = w * x[row] if homo_w else w[row] * x[row]
            acc = acc.at[indices[row]].add(contribution)
        return acc

    @parameterized.product(
        allow_multi_conn=[True, False]
    )
    def test1(self, allow_multi_conn):
        """The layer can be built and called with scalar and initializer weights."""
        events = bst.random.rand(20) < 0.1

        scalar_layer = EventFixedProb(20, 40, 0.1, 1.0, seed=123, allow_multi_conn=allow_multi_conn)
        out = scalar_layer(events)
        print(out)

        init_layer = EventFixedProb(20, 40, 0.1, bst.init.KaimingUniform(), seed=123)
        print(init_layer(events))

    def test_grad_bool(self):
        """Differentiating with respect to a boolean input raises ``TypeError``."""
        n_in, n_out = 20, 30
        events = bst.random.rand(n_in) < 0.3
        fn = EventFixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)

        def total(x):
            return fn(x).sum()

        with self.assertRaises(TypeError):
            print(jax.grad(total)(events))

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_vjp(self, bool_x, homo_w):
        """Reverse-mode gradients match the scatter-add reference."""
        n_in, n_out = 20, 30
        x = self._random_input(n_in, bool_x)

        if homo_w:
            fn = EventFixedProb(n_in, n_out, 0.1, 1.5, seed=123)
        else:
            fn = EventFixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)
        w = fn.weight

        def layer_loss(x, w):
            # The weight is swapped in so it can be treated as a differentiable argument.
            fn.weight = w
            return fn(x).sum()

        def reference_loss(x, w):
            return self._reference_post(x, w, fn.indices, n_in, n_out, homo_w).sum()

        got = bst.transform.grad(layer_loss, argnums=(0, 1))(x, w)
        want = jax.grad(reference_loss, argnums=(0, 1))(x, w)

        self.assertTrue(jnp.allclose(got[0], want[0]))
        self.assertTrue(jnp.allclose(got[1], want[1]))
        print(got[1])

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_jvp(self, bool_x, homo_w):
        """Forward-mode outputs and tangents match the scatter-add reference."""
        n_in, n_out = 20, 30
        x = self._random_input(n_in, bool_x)

        fn = EventFixedProb(n_in, n_out, 0.1, 1.5 if homo_w else bst.init.KaimingUniform(), seed=123, grad_mode='jvp')
        w = fn.weight

        def layer_out(x, w):
            # The weight is swapped in so it can be treated as a differentiable argument.
            fn.weight = w
            return fn(x)

        def reference_out(x, w):
            return self._reference_post(x, w, fn.indices, n_in, n_out, homo_w)

        o1, r1 = jax.jvp(layer_out, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
        o2, r2 = jax.jvp(reference_out, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))

        self.assertTrue(jnp.allclose(r1, r2))
        self.assertTrue(jnp.allclose(o1, o2))
@@ -1,220 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Union, Callable, Optional
17
-
18
- import brainunit as u
19
- import jax
20
- import jax.numpy as jnp
21
- import numpy as np
22
-
23
- from brainstate._state import ParamState, State
24
- from brainstate.init import param
25
- from brainstate.mixin import Mode, Training
26
- from brainstate.nn._base import DnnLayer
27
- from brainstate.typing import ArrayLike
28
- from ._misc import IntScalar
29
-
30
- __all__ = [
31
- 'EventDense',
32
- ]
33
-
34
-
35
class EventDense(DnnLayer):
    """
    Dense linear layer whose forward pass is driven by spike events.

    The weight is created with ``param(weight, (n_pre, n_post))`` and, on
    ``update``, handed together with the incoming spikes to
    ``cpu_event_linear``.

    Parameters
    ----------
    n_pre : int
        Number of pre-synaptic neurons.
    n_post : int
        Number of post-synaptic neurons.
    weight : float or callable or jax.Array or brainunit.Quantity
        Synaptic weight, either a value or an initializer; it is resolved by
        ``param`` against the shape ``(n_pre, n_post)``.
    name : str, optional
        Name of the module.
    mode : brainstate.mixin.Mode, optional
        Mode of the module.
    grad_mode : str, optional
        Which custom differentiation rule the kernel uses: ``'vjp'`` (default)
        or ``'jvp'``.
    """

    def __init__(
        self,
        n_pre: IntScalar,
        n_post: IntScalar,
        weight: Union[Callable, ArrayLike],
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
        grad_mode: str = 'vjp'
    ):
        super().__init__(name=name, mode=mode)

        # Layer geometry, exposed under both naming conventions.
        self.n_pre = n_pre
        self.n_post = n_post
        self.in_size = n_pre
        self.out_size = n_post

        assert grad_mode in ('vjp', 'jvp'), f"Unsupported grad_mode: {grad_mode}"
        self.grad_mode = grad_mode

        # The weight becomes a trainable ParamState only when the mode includes Training.
        initial = param(weight, (self.n_pre, self.n_post), allow_none=False)
        self.weight = ParamState(initial) if self.mode.has(Training) else initial

    def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
        """Apply the layer to ``spk`` and return the post-synaptic values.

        Only the CPU platform is implemented. ``'gpu'`` and ``'tpu'`` raise
        ``NotImplementedError`` and any other platform raises ``ValueError``.
        """
        weight = self.weight
        if isinstance(weight, State):
            weight = weight.value

        # The platform of the first visible JAX device selects the kernel.
        platform = jax.devices()[0].platform
        if platform == 'cpu':
            return cpu_event_linear(
                u.math.asarray(weight),
                u.math.asarray(spk),
                n_post=self.n_post,
                grad_mode=self.grad_mode,
            )
        if platform in ('gpu', 'tpu'):
            raise NotImplementedError()
        raise ValueError(f"Unsupported device: {platform}")
92
-
93
-
94
def cpu_event_linear(
    g_max: Union[u.Quantity, jax.Array],
    spk: jax.Array,
    *,
    n_post: Optional[int] = None,
    grad_mode: str = 'vjp'
) -> Union[u.Quantity, jax.Array]:
    """
    Event-driven dense linear transform of spike events on CPU.

    For a 2-D ``g_max`` the result along the last axis of ``spk`` is the
    vector-matrix product with ``g_max``, computed by the event-driven kernel.
    For a homogeneous (single-element) ``g_max`` every output entry equals
    ``sum(spk) * g_max``.

    Parameters
    ----------
    g_max : brainunit.Quantity or jax.Array
        Synaptic weight: a 0-D (homogeneous) value or a 2-D matrix. Its unit,
        as reported by ``brainunit.get_unit``, is re-attached to the result.
    spk : jax.Array
        Spike events with at least one dimension. Leading dimensions are
        treated as batch dimensions and mapped over.
    n_post : int, optional
        Number of post-synaptic neurons. Required when ``g_max`` is homogeneous.
    grad_mode : str, optional
        Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.

    Returns
    -------
    post_data : brainunit.Quantity or jax.Array
        Post synaptic data, shaped ``spk.shape[:-1] + (n_post,)``.

    Raises
    ------
    ValueError
        If ``grad_mode`` is neither 'vjp' nor 'jvp' (non-homogeneous weights).
    AssertionError
        If the ranks of ``spk`` / ``g_max`` are unsupported, or ``n_post`` is
        not an integer while the weight is homogeneous.
    """
    unit = u.get_unit(g_max)
    g_max = u.get_mantissa(g_max)
    spk = jnp.asarray(spk)

    assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
    assert g_max.ndim in [2, 0], f"weight must be 2D or 0D. Got: {g_max.ndim}"

    def mv(spk_vector):
        # Report the rank of the vector actually being checked, not of the outer batch.
        assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk_vector.ndim}"
        if jnp.size(g_max) == 1:
            # NumPy integers are valid sizes too, so accept them alongside int.
            assert isinstance(n_post, (int, np.integer)), (
                f"n_post must be an integer when weight is homogenous. Got: {n_post}"
            )
            # Collapse a single-element weight (e.g. shape (1, 1)) to a scalar so the
            # result is always a 1-D vector of length n_post.
            scalar_weight = jnp.reshape(g_max, ())
            return jnp.ones((n_post,), dtype=g_max.dtype) * (jnp.sum(spk_vector) * scalar_weight)

        if grad_mode == 'vjp':
            return _cpu_event_linear_mv_vjp(g_max, spk_vector)
        if grad_mode == 'jvp':
            return _cpu_event_linear_mv_jvp(g_max, spk_vector)
        raise ValueError(f"Unsupported grad_mode: {grad_mode}")

    if spk.ndim == 1:
        post_data = mv(spk)
    else:
        # Flatten all leading axes into one batch axis, map the kernel over it,
        # then restore the original leading shape.
        shape = spk.shape[:-1]
        post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
        post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
    return u.maybe_decimal(u.Quantity(post_data, unit=unit))
149
-
150
-
151
- # --------------
152
- # Implementation
153
- # --------------
154
-
155
-
156
- def _cpu_event_linear_mv(g_max, spk) -> jax.Array:
157
- def scan_fn(post, i):
158
- sp = spk[i]
159
- if spk.dtype == jnp.bool_:
160
- post = jax.lax.cond(sp, lambda: post + g_max[i], lambda: post)
161
- else:
162
- post = jax.lax.cond(sp == 0., lambda: post, lambda: post + g_max[i] * sp)
163
- return post, None
164
-
165
- return jax.lax.scan(scan_fn, jnp.zeros(g_max.shape[1], dtype=g_max.dtype), np.arange(len(spk)))[0]
166
-
167
-
168
- # --------------
169
- # VJP
170
- # --------------
171
-
172
def _cpu_event_linear_mv_fwd(g_max, spk):
    """Forward rule of the custom VJP: the primal output plus ``(g_max, spk)`` residuals."""
    residuals = (g_max, spk)
    primal_out = _cpu_event_linear_mv(g_max, spk)
    return primal_out, residuals
174
-
175
-
176
- def _cpu_event_linear_mv_bwd(res, ct):
177
- g_max, spk = res
178
-
179
- # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
180
- ct_spk = jnp.matmul(g_max, ct)
181
-
182
- # ∂L/∂w = ∂L/∂y * ∂y/∂w
183
- def map_fn(sp):
184
- if spk.dtype == jnp.bool_:
185
- d_gmax = jax.lax.cond(sp, lambda: ct, lambda: jnp.zeros_like(ct))
186
- else:
187
- d_gmax = jax.lax.cond(sp == 0., lambda: jnp.zeros_like(ct), lambda: ct * sp)
188
- return d_gmax
189
-
190
- ct_gmax = jax.vmap(map_fn)(spk)
191
- return ct_gmax, ct_spk
192
-
193
-
194
# Kernel variant that differentiates through the hand-written forward/backward pair.
_cpu_event_linear_mv_vjp = jax.custom_vjp(_cpu_event_linear_mv)
_cpu_event_linear_mv_vjp.defvjp(
    fwd=_cpu_event_linear_mv_fwd,
    bwd=_cpu_event_linear_mv_bwd,
)
196
-
197
-
198
- # --------------
199
- # JVP
200
- # --------------
201
-
202
-
203
def _cpu_event_linear_mv_jvp_rule(primals, tangents):
    """JVP rule for the event-driven kernel.

    Returns the primal output and the output tangent, which is the sum of the
    weight contribution and the spike contribution.
    """
    g_max, spk = primals
    gmax_dot, spk_dot = tangents

    # Primal output.
    y = _cpu_event_linear_mv(g_max, spk)

    # Weight contribution: the same event-driven product applied to the weight tangent.
    tangent_from_weights = _cpu_event_linear_mv(gmax_dot, spk)

    # Spike contribution: dense product of the spike tangent with the weights.
    tangent_from_spikes = spk_dot @ g_max

    return y, tangent_from_weights + tangent_from_spikes
217
-
218
-
219
# Kernel variant that differentiates through the hand-written JVP rule.
_cpu_event_linear_mv_jvp = jax.custom_jvp(_cpu_event_linear_mv)
_cpu_event_linear_mv_jvp.defjvp(jvp=_cpu_event_linear_mv_jvp_rule)
@@ -1,111 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import jax
18
- import jax.numpy as jnp
19
- from absl.testing import parameterized
20
-
21
- import brainstate as bst
22
- from brainstate.nn.event.linear import EventDense
23
-
24
-
25
class TestEventLinear(parameterized.TestCase):
    """Forward, VJP and JVP checks of ``EventDense`` against a dense matmul."""

    @parameterized.product(
        homo_w=[True, False],
        bool_x=[True, False],
    )
    def test1(self, homo_w, bool_x):
        """The output equals ``x @ W``, or ``x.sum() * w`` for a scalar weight."""
        x = bst.random.rand(20) < 0.1
        if not bool_x:
            x = jnp.asarray(x, dtype=float)
        m = EventDense(20, 40, 1.5 if homo_w else bst.init.KaimingUniform())
        y = m(x)

        expected = (x.sum() * m.weight) if homo_w else (x @ m.weight)
        self.assertTrue(jnp.allclose(y, expected))

    def test_grad_bool(self):
        """Differentiating with respect to a boolean input raises ``TypeError``."""
        n_in = 20
        n_out = 30
        x = bst.random.rand(n_in) < 0.3
        fn = EventDense(n_in, n_out, bst.init.KaimingUniform())

        with self.assertRaises(TypeError):
            jax.grad(lambda x: fn(x).sum())(x)

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_vjp(self, bool_x, homo_w):
        """Reverse-mode gradients match those of the dense reference."""
        n_in = 20
        n_out = 30
        if bool_x:
            x = jnp.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
        else:
            x = bst.random.rand(n_in)

        fn = EventDense(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform())
        w = fn.weight

        def f(x, w):
            # The weight is swapped in so it can be treated as a differentiable argument.
            fn.weight = w
            return fn(x).sum()

        r1 = jax.grad(f, argnums=(0, 1))(x, w)

        # -------------------
        # TRUE gradients

        def f2(x, w):
            y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
            return y.sum()

        r2 = jax.grad(f2, argnums=(0, 1))(x, w)
        self.assertTrue(jnp.allclose(r1[0], r2[0]))
        self.assertTrue(jnp.allclose(r1[1], r2[1]))

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_jvp(self, bool_x, homo_w):
        """Forward-mode outputs and tangents match those of the dense reference."""
        n_in = 20
        n_out = 30
        if bool_x:
            x = jnp.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
        else:
            x = bst.random.rand(n_in)

        fn = EventDense(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(), grad_mode='jvp')
        w = fn.weight

        def f(x, w):
            # The weight is swapped in so it can be treated as a differentiable argument.
            fn.weight = w
            return fn(x)

        o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))

        # -------------------
        # TRUE gradients

        def f2(x, w):
            y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
            return y

        # The reference must come from the dense function f2; using f here would
        # compare the layer with itself and could never fail.
        o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
        self.assertTrue(jnp.allclose(o1, o2))
        self.assertTrue(jnp.allclose(r1, r2))