brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,127 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import jax.numpy
18
- import jax.numpy as jnp
19
- from absl.testing import parameterized
20
-
21
- import brainstate as bst
22
- from brainstate.nn.event.fixed_probability import EventFixedProb
23
-
24
-
25
class TestFixedProbCSR(parameterized.TestCase):
    """Tests for the event-driven ``EventFixedProb`` layer and its gradient rules."""

    @staticmethod
    def _reference_post(x, w, indices, n_post, homo_w):
        """Dense reference: scatter-add every presynaptic input into the targets listed in ``indices``."""
        post = jnp.zeros((n_post,))
        for pre_id in range(x.shape[0]):
            contribution = w * x[pre_id] if homo_w else w[pre_id] * x[pre_id]
            post = post.at[indices[pre_id]].add(contribution)
        return post

    @parameterized.product(allow_multi_conn=[True, False])
    def test1(self, allow_multi_conn):
        spikes = bst.random.rand(20) < 0.1
        layer = EventFixedProb(20, 40, 0.1, 1.0, seed=123, allow_multi_conn=allow_multi_conn)
        out = layer(spikes)
        print(out)

        # Same connectivity seed, but a non-homogeneous (initializer-based) weight.
        layer_init = EventFixedProb(20, 40, 0.1, bst.init.KaimingUniform(), seed=123)
        print(layer_init(spikes))

    def test_grad_bool(self):
        n_in, n_out = 20, 30
        spikes = bst.random.rand(n_in) < 0.3
        layer = EventFixedProb(n_in, n_out, 0.1, bst.init.KaimingUniform(), seed=123)

        # Differentiating w.r.t. a boolean input must be rejected by JAX.
        with self.assertRaises(TypeError):
            print(jax.grad(lambda s: layer(s).sum())(spikes))

    @parameterized.product(bool_x=[True, False], homo_w=[True, False])
    def test_vjp(self, bool_x, homo_w):
        n_in, n_out = 20, 30
        x = bst.random.rand(n_in)
        if bool_x:
            x = jax.numpy.asarray(x < 0.3, dtype=float)

        weight_init = 1.5 if homo_w else bst.init.KaimingUniform()
        layer = EventFixedProb(n_in, n_out, 0.1, weight_init, seed=123)
        w = layer.weight

        def loss(x, w):
            layer.weight = w
            return layer(x).sum()

        got = bst.transform.grad(loss, argnums=(0, 1))(x, w)

        def reference_loss(x, w):
            return self._reference_post(x, w, layer.indices, n_out, homo_w).sum()

        want = jax.grad(reference_loss, argnums=(0, 1))(x, w)
        for got_i, want_i in zip(got, want):
            self.assertTrue(jnp.allclose(got_i, want_i))
        print(got[1])

    @parameterized.product(bool_x=[True, False], homo_w=[True, False])
    def test_jvp(self, bool_x, homo_w):
        n_in, n_out = 20, 30
        x = bst.random.rand(n_in)
        if bool_x:
            x = jax.numpy.asarray(x < 0.3, dtype=float)

        layer = EventFixedProb(
            n_in, n_out, 0.1,
            1.5 if homo_w else bst.init.KaimingUniform(),
            seed=123, grad_mode='jvp',
        )
        w = layer.weight

        def forward(x, w):
            layer.weight = w
            return layer(x)

        tangents = (jnp.ones_like(x), jnp.ones_like(w))
        out_got, tan_got = jax.jvp(forward, (x, w), tangents)

        def reference(x, w):
            return self._reference_post(x, w, layer.indices, n_out, homo_w)

        out_want, tan_want = jax.jvp(reference, (x, w), tangents)
        self.assertTrue(jnp.allclose(tan_got, tan_want))
        self.assertTrue(jnp.allclose(out_got, out_want))
@@ -1,220 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Union, Callable, Optional
17
-
18
- import brainunit as u
19
- import jax
20
- import jax.numpy as jnp
21
- import numpy as np
22
-
23
- from brainstate._state import ParamState, State
24
- from brainstate.init import param
25
- from brainstate.mixin import Mode, Training
26
- from brainstate.nn._base import DnnLayer
27
- from brainstate.typing import ArrayLike
28
- from ._misc import IntScalar
29
-
30
- __all__ = [
31
- 'EventDense',
32
- ]
33
-
34
-
35
class EventDense(DnnLayer):
    """
    Dense (all-to-all) linear layer driven by presynaptic spike events.

    The forward pass computes ``spk @ weight`` for a weight of shape
    ``(n_pre, n_post)``. If the weight is a 0-D scalar, it is shared by all
    connections and every post-synaptic neuron receives ``sum(spk) * weight``.
    The computation is delegated to ``cpu_event_linear``.

    Parameters
    ----------
    n_pre : int
        Number of pre-synaptic neurons.
    n_post : int
        Number of post-synaptic neurons.
    weight : float or callable or jax.Array or brainunit.Quantity
        Synaptic weight (maximum synaptic conductance). It is materialized via
        ``brainstate.init.param`` with target shape ``(n_pre, n_post)``.
    name : str, optional
        Name of the module.
    mode : brainstate.mixin.Mode, optional
        Mode of the module. In a ``Training`` mode the weight is wrapped in a
        ``ParamState`` so that it is treated as a trainable parameter.
    grad_mode : str, optional
        Which custom derivative rule is used for the event-driven product:
        ``'vjp'`` (default, reverse mode) or ``'jvp'`` (forward mode).
    """

    def __init__(
        self,
        n_pre: IntScalar,
        n_post: IntScalar,
        weight: Union[Callable, ArrayLike],
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
        grad_mode: str = 'vjp'
    ):
        super().__init__(name=name, mode=mode)
        self.n_pre = n_pre
        self.n_post = n_post
        self.in_size = n_pre
        self.out_size = n_post

        assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
        self.grad_mode = grad_mode

        # Synaptic weight; trainable (ParamState) only in a Training mode.
        weight = param(weight, (self.n_pre, self.n_post), allow_none=False)
        if self.mode.has(Training):
            weight = ParamState(weight)
        self.weight = weight

    def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
        """
        Propagate the spike events ``spk`` (shape ``(..., n_pre)``) to the
        post-synaptic neurons.

        Raises
        ------
        NotImplementedError
            On GPU/TPU, for which no kernel exists yet.
        ValueError
            On any other, unknown platform.
        """
        weight = self.weight.value if isinstance(self.weight, State) else self.weight

        # NOTE: dispatch uses the platform of the default JAX device, not the
        # device on which ``spk`` actually lives.
        device_kind = jax.devices()[0].platform
        if device_kind == 'cpu':
            return cpu_event_linear(u.math.asarray(weight),
                                    u.math.asarray(spk),
                                    n_post=self.n_post,
                                    grad_mode=self.grad_mode)
        if device_kind in ('gpu', 'tpu'):
            raise NotImplementedError(f"EventDense is not implemented on {device_kind} yet.")
        raise ValueError(f"Unsupported device: {device_kind}")
92
-
93
-
94
def cpu_event_linear(
    g_max: Union[u.Quantity, jax.Array],
    spk: jax.Array,
    *,
    n_post: int = None,
    grad_mode: str = 'vjp'
) -> Union[u.Quantity, jax.Array]:
    """
    Event-driven dense matrix-vector product on CPU: ``post = spk @ g_max``.

    Parameters
    ----------
    g_max : brainunit.Quantity or jax.Array
        Synaptic weights. Either a 2-D matrix of shape ``(n_pre, n_post)``, or
        a 0-D scalar (homogeneous weight shared by all connections).
    spk : jax.Array
        Spike events of shape ``(..., n_pre)``, boolean or numeric. Leading
        dimensions are treated as batch dimensions.
    n_post : int, optional
        Number of post-synaptic neurons. It is required when ``g_max`` is a
        scalar, because the output length cannot then be inferred from the
        weight.
    grad_mode : str, optional
        Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.

    Returns
    -------
    post_data : brainunit.Quantity or jax.Array
        Post-synaptic data of shape ``spk.shape[:-1] + (n_post,)``, carrying
        the unit of ``g_max``.
    """
    unit = u.get_unit(g_max)
    g_max = u.get_mantissa(g_max)
    spk = jnp.asarray(spk)

    def mv(spk_vector):
        assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk_vector.ndim}"
        if g_max.ndim == 0:
            # Homogeneous weight: every post neuron receives sum(spk) * g_max.
            # (Testing ndim rather than size keeps a 2-D (1, 1) matrix on the
            # dense path, which returns the correct (n_post,) shape.)
            assert isinstance(n_post, (int, np.integer)), \
                f"n_post must be an integer when weight is homogenous. Got: {n_post}"
            return jnp.ones((int(n_post),), dtype=g_max.dtype) * (jnp.sum(spk_vector) * g_max)

        if grad_mode == 'vjp':
            return _cpu_event_linear_mv_vjp(g_max, spk_vector)
        if grad_mode == 'jvp':
            return _cpu_event_linear_mv_jvp(g_max, spk_vector)
        raise ValueError(f"Unsupported grad_mode: {grad_mode}")

    assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
    assert g_max.ndim in [2, 0], f"weight must be 2D or 0D. Got: {g_max.ndim}"

    if spk.ndim == 1:
        post_data = mv(spk)
    else:
        # Flatten all batch dimensions, vmap the 1-D kernel, then restore them.
        shape = spk.shape[:-1]
        post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
        post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
    return u.maybe_decimal(u.Quantity(post_data, unit=unit))
149
-
150
-
151
- # --------------
152
- # Implementation
153
- # --------------
154
-
155
-
156
def _cpu_event_linear_mv(g_max, spk) -> jax.Array:
    """
    Event-driven dense product ``spk @ g_max`` for a single spike vector.

    Row ``i`` of ``g_max`` is accumulated only when ``spk[i]`` carries an
    event: ``True`` for boolean spikes, or non-zero for numeric spikes. In the
    numeric case the row is scaled by the spike value.

    Parameters
    ----------
    g_max : jax.Array
        Weight matrix of shape ``(n_pre, n_post)``.
    spk : jax.Array
        1-D spike vector of length ``n_pre`` (boolean or numeric).

    Returns
    -------
    jax.Array
        Vector of shape ``(n_post,)`` with the dtype of ``g_max``.
    """
    out_dtype = g_max.dtype
    is_bool = spk.dtype == jnp.bool_

    def scan_fn(post, i):
        sp = spk[i]
        if is_bool:
            post = jax.lax.cond(sp, lambda: post + g_max[i], lambda: post)
        else:
            # Cast the contribution so both ``cond`` branches (and the scan
            # carry) keep ``out_dtype``, even when ``spk`` has a wider dtype
            # than ``g_max``. Otherwise lax.cond/scan raise a type mismatch.
            post = jax.lax.cond(
                sp == 0.,
                lambda: post,
                lambda: post + (g_max[i] * sp).astype(out_dtype),
            )
        return post, None

    init = jnp.zeros(g_max.shape[1], dtype=out_dtype)
    return jax.lax.scan(scan_fn, init, np.arange(len(spk)))[0]
166
-
167
-
168
- # --------------
169
- # VJP
170
- # --------------
171
-
172
def _cpu_event_linear_mv_fwd(g_max, spk):
    """Forward pass of the custom VJP: return the product and keep ``(g_max, spk)`` as residuals."""
    residuals = (g_max, spk)
    out = _cpu_event_linear_mv(g_max, spk)
    return out, residuals
174
-
175
-
176
def _cpu_event_linear_mv_bwd(res, ct):
    """
    Backward pass of the custom VJP for ``_cpu_event_linear_mv``.

    Parameters
    ----------
    res : tuple
        Residuals ``(g_max, spk)`` saved by the forward pass.
    ct : jax.Array
        Cotangent of the output, shape ``(n_post,)``.

    Returns
    -------
    tuple
        ``(ct_gmax, ct_spk)``, where ``ct_gmax`` has shape ``(n_pre, n_post)``
        and dtype of ``ct``, and ``ct_spk`` has shape ``(n_pre,)``.
    """
    g_max, spk = res
    is_bool = spk.dtype == jnp.bool_

    # ∂L/∂spk = g_max @ ∂L/∂y; matched to the spike dtype for float spikes.
    ct_spk = jnp.matmul(g_max, ct)
    if jnp.issubdtype(spk.dtype, jnp.inexact):
        ct_spk = ct_spk.astype(spk.dtype)

    # ∂L/∂w[i, :] = ∂L/∂y * spk[i], non-zero only where spk[i] carries an event.
    def map_fn(sp):
        if is_bool:
            return jax.lax.cond(sp, lambda: ct, lambda: jnp.zeros_like(ct))
        # Cast so both cond branches have the dtype of ``ct`` even when the
        # spikes are wider than the weights (lax.cond requires equal types).
        return jax.lax.cond(
            sp == 0.,
            lambda: jnp.zeros_like(ct),
            lambda: (ct * sp).astype(ct.dtype),
        )

    ct_gmax = jax.vmap(map_fn)(spk)
    return ct_gmax, ct_spk
192
-
193
-
194
# Event-driven mat-vec with a hand-written reverse-mode (VJP) rule.
_cpu_event_linear_mv_vjp = jax.custom_vjp(_cpu_event_linear_mv)
_cpu_event_linear_mv_vjp.defvjp(
    fwd=_cpu_event_linear_mv_fwd,
    bwd=_cpu_event_linear_mv_bwd,
)
196
-
197
-
198
- # --------------
199
- # JVP
200
- # --------------
201
-
202
-
203
def _cpu_event_linear_mv_jvp_rule(primals, tangents):
    """
    Forward-mode rule for ``_cpu_event_linear_mv``.

    ``y = spk @ g_max`` is bilinear, so its tangent is
    ``spk @ g_max_dot + spk_dot @ g_max``.

    Parameters
    ----------
    primals : tuple
        ``(g_max, spk)``.
    tangents : tuple
        ``(g_max_dot, spk_dot)``.

    Returns
    -------
    tuple
        ``(y, y_dot)``.
    """
    g_max, spk = primals
    gmax_dot, spk_dot = tangents

    # forward pass
    y = _cpu_event_linear_mv(g_max, spk)

    # ∂y/∂g_max · g_max_dot, reusing the event-driven kernel
    y_dot = _cpu_event_linear_mv(gmax_dot, spk)

    # ∂y/∂spk · spk_dot. Boolean/integer spikes are not differentiable: JAX
    # gives them float0 tangents, which cannot enter a matmul, so skip them.
    if jnp.issubdtype(spk.dtype, jnp.inexact):
        y_dot = y_dot + spk_dot @ g_max
    return y, y_dot
217
-
218
-
219
# Event-driven mat-vec with a hand-written forward-mode (JVP) rule.
_cpu_event_linear_mv_jvp = jax.custom_jvp(_cpu_event_linear_mv)
_cpu_event_linear_mv_jvp.defjvp(jvp=_cpu_event_linear_mv_jvp_rule)
@@ -1,111 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import jax
18
- import jax.numpy as jnp
19
- from absl.testing import parameterized
20
-
21
- import brainstate as bst
22
- from brainstate.nn.event.linear import EventDense
23
-
24
-
25
class TestEventLinear(parameterized.TestCase):
    """Checks ``EventDense`` outputs and its custom derivative rules against dense ``x @ w`` references."""

    @parameterized.product(
        homo_w=[True, False],
        bool_x=[True, False],
    )
    def test1(self, homo_w, bool_x):
        x = bst.random.rand(20) < 0.1
        if not bool_x:
            x = jnp.asarray(x, dtype=float)
        m = EventDense(20, 40, 1.5 if homo_w else bst.init.KaimingUniform())
        y = m(x)

        # The homogeneous reference is a scalar, so allclose alone would not
        # catch a wrong output shape.
        self.assertEqual(y.shape, (40,))
        self.assertTrue(jnp.allclose(y, (x.sum() * m.weight) if homo_w else (x @ m.weight)))

    def test_grad_bool(self):
        n_in = 20
        n_out = 30
        x = bst.random.rand(n_in) < 0.3
        fn = EventDense(n_in, n_out, bst.init.KaimingUniform())

        # Differentiating w.r.t. a boolean input must be rejected.
        with self.assertRaises(TypeError):
            print(jax.grad(lambda x: fn(x).sum())(x))

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_vjp(self, bool_x, homo_w):
        n_in = 20
        n_out = 30
        if bool_x:
            x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
        else:
            x = bst.random.rand(n_in)

        fn = EventDense(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform())
        w = fn.weight

        def f(x, w):
            fn.weight = w
            return fn(x).sum()

        r1 = jax.grad(f, argnums=(0, 1))(x, w)

        # -------------------
        # TRUE gradients (dense reference)

        def f2(x, w):
            y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
            return y.sum()

        r2 = jax.grad(f2, argnums=(0, 1))(x, w)
        self.assertTrue(jnp.allclose(r1[0], r2[0]))
        self.assertTrue(jnp.allclose(r1[1], r2[1]))

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_jvp(self, bool_x, homo_w):
        n_in = 20
        n_out = 30
        if bool_x:
            x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
        else:
            x = bst.random.rand(n_in)

        fn = EventDense(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(), grad_mode='jvp')
        w = fn.weight

        def f(x, w):
            fn.weight = w
            return fn(x)

        o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))

        # -------------------
        # TRUE gradients (dense reference)

        def f2(x, w):
            y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
            return y

        # Must use the dense reference ``f2``. Differentiating ``f`` again
        # would compare the layer against itself.
        o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
        self.assertTrue(jnp.allclose(o1, o2))
        self.assertTrue(jnp.allclose(r1, r2))