brainstate 0.1.0.post20250120__py2.py3-none-any.whl → 0.1.0.post20250127__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (29)
  1. brainstate/__init__.py +1 -2
  2. brainstate/augment/__init__.py +10 -20
  3. brainstate/compile/__init__.py +18 -37
  4. brainstate/compile/_make_jaxpr.py +9 -2
  5. brainstate/compile/_make_jaxpr_test.py +10 -6
  6. brainstate/compile/_progress_bar.py +49 -6
  7. brainstate/compile/_unvmap.py +3 -3
  8. brainstate/graph/__init__.py +12 -12
  9. brainstate/nn/_dyn_impl/_inputs.py +4 -2
  10. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  11. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/METADATA +1 -1
  12. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/RECORD +15 -29
  13. brainstate/event/__init__.py +0 -27
  14. brainstate/event/_csr.py +0 -1149
  15. brainstate/event/_csr_benchmark.py +0 -14
  16. brainstate/event/_csr_mv.py +0 -303
  17. brainstate/event/_csr_test.py +0 -277
  18. brainstate/event/_fixedprob_mv.py +0 -730
  19. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  20. brainstate/event/_fixedprob_mv_test.py +0 -132
  21. brainstate/event/_linear_mv.py +0 -359
  22. brainstate/event/_linear_mv_benckmark.py +0 -82
  23. brainstate/event/_linear_mv_test.py +0 -117
  24. brainstate/event/_misc.py +0 -34
  25. brainstate/event/_xla_custom_op.py +0 -317
  26. brainstate/event/_xla_custom_op_test.py +0 -55
  27. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/LICENSE +0 -0
  28. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/WHEEL +0 -0
  29. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/top_level.txt +0 -0
@@ -1,14 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
@@ -1,303 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from __future__ import annotations
16
-
17
- from typing import Union, Callable, Optional
18
-
19
- import brainunit as u
20
- import jax
21
- import jax.numpy as jnp
22
- import numpy as np
23
-
24
- from brainstate._state import ParamState
25
- from brainstate._utils import set_module_as
26
- from brainstate.init import param
27
- from brainstate.nn._module import Module
28
- from brainstate.typing import ArrayLike, Size
29
-
30
# Names exported by ``from ... import *``.
__all__ = ['CSRLinear']
33
-
34
-
35
class CSRLinear(Module):
    """
    The CSRLinear module implements a fixed probability connection with CSR sparse data structure.

    Parameters
    ----------
    in_size : Size
        Number of pre-synaptic neurons, i.e., input size. The last entry is the
        number of rows of the CSR matrix. A bare integer is also accepted.
    out_size : Size
        Number of post-synaptic neurons, i.e., output size. The last entry is the
        number of columns of the CSR matrix. A bare integer is also accepted.
    indptr : ArrayLike
        1D CSR row pointer of length ``n_pre + 1``. The connections of
        pre-synaptic neuron ``i`` occupy positions ``indptr[i]:indptr[i + 1]``
        of ``indices`` (and of ``weight`` when it is per-connection).
    indices : ArrayLike
        1D array with the post-synaptic (column) index of every connection.
    weight : float or callable or jax.Array or brainunit.Quantity
        Maximum synaptic conductance or a function that returns the maximum synaptic
        conductance. Either a single value shared by all connections or one value
        per connection.
    name : str, optional
        Name of the module.
    """

    __module__ = 'brainstate.event'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        indptr: ArrayLike,
        indices: ArrayLike,
        weight: Union[Callable, ArrayLike],
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # network size; promote a bare integer to a 1-tuple so that ``[-1]``
        # works for both ``10`` and ``(10,)``
        if isinstance(in_size, (int, np.integer)):
            in_size = (int(in_size),)
        if isinstance(out_size, (int, np.integer)):
            out_size = (int(out_size),)
        self.in_size = in_size
        self.out_size = out_size
        self.n_pre = self.in_size[-1]
        self.n_post = self.out_size[-1]

        # CSR data structure (evaluated eagerly so the shape checks see concrete arrays)
        with jax.ensure_compile_time_eval():
            indptr = jnp.asarray(indptr)
            indices = jnp.asarray(indices)
            assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
            assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
            assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
            self.indptr = u.math.asarray(indptr)
            self.indices = u.math.asarray(indices)

        # maximum synaptic conductance: a shared scalar or one value per connection
        weight = param(weight, (len(indices),), allow_none=False)
        if u.math.size(weight) != 1 and u.math.size(weight) != len(self.indices):
            raise ValueError(
                f"weight must be a scalar or 1D with size {len(self.indices)}. "
                f"Got: {u.math.size(weight)}"
            )
        self.weight = ParamState(weight)

    def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
        """Propagate pre-synaptic events through the CSR connectivity.

        ``spk`` holds the pre-synaptic neurons on its last axis; leading axes
        are treated as batch axes. The result has shape
        ``spk.shape[:-1] + (n_post,)``.
        """
        weight = self.weight.value

        # no connections at all: the output is identically zero
        if len(self.indices) == 0:
            r = u.math.zeros(spk.shape[:-1] + (self.n_post,),
                             dtype=weight.dtype,
                             unit=u.get_unit(weight) * u.get_unit(spk))
            return u.maybe_decimal(r)

        # CPU implementation (the only one wired up here)
        return cpu_event_csr(
            u.math.asarray(spk),
            self.indptr,
            self.indices,
            u.math.asarray(weight),
            n_post=self.n_post,
        )
106
-
107
-
108
@set_module_as('brainstate.event')
def cpu_event_csr(
    spk: jax.Array,
    indptr: jax.Array,
    indices: jax.Array,
    weight: Union[u.Quantity, jax.Array],
    *,
    n_post: int,
    grad_mode: str = 'vjp'
) -> Union[u.Quantity, jax.Array]:
    """
    Event-driven product of spike events with a CSR synaptic matrix (CPU kernel).

    For every pre-synaptic neuron ``i`` whose event is ``True`` (boolean spikes)
    or non-zero (numeric spikes), the weights stored at positions
    ``indptr[i]:indptr[i + 1]`` are added to the post-synaptic neurons
    ``indices[indptr[i]:indptr[i + 1]]``; numeric events also scale the weights.

    Parameters
    ----------
    spk : jax.Array
        Spike events with the pre-synaptic neurons on the last axis. Any leading
        axes are treated as batch axes.
    indptr : jax.Array
        Index pointer of post connected neurons.
    indices : jax.Array
        Indices of post connected neurons.
    weight : brainunit.Quantity or jax.Array
        Maximum synaptic conductance: a scalar shared by all connections, or a
        1D array with one value per connection.
    n_post : int
        Number of post-synaptic neurons.
    grad_mode : str, optional
        Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.

    Returns
    -------
    post_data : brainunit.Quantity or jax.Array
        Post-synaptic data of shape ``spk.shape[:-1] + (n_post,)``, carrying the
        unit of ``weight``.

    Raises
    ------
    ValueError
        If ``grad_mode`` is neither ``'vjp'`` nor ``'jvp'``.
    """
    unit = u.get_unit(weight)
    weight = u.get_mantissa(weight)

    assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
    assert weight.ndim in [1, 0], f"g_max must be 1D or 0D. Got: {weight.ndim}"
    assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"

    # choose the differentiation rule once instead of inside every (vmapped) call
    if grad_mode == 'vjp':
        kernel = _cpu_event_csr_mv_vjp
    elif grad_mode == 'jvp':
        kernel = _cpu_event_csr_mv_jvp
    else:
        raise ValueError(f"Unsupported grad_mode: {grad_mode}")

    def mv(spk_vector):
        assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk_vector.ndim}"
        return kernel(spk_vector, indptr, indices, weight, n_post)

    if spk.ndim == 1:
        post_data = mv(spk)
    else:
        # flatten all batch axes, map the 1D kernel over them, then restore them
        shape = spk.shape[:-1]
        post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
        post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
    return u.maybe_decimal(u.Quantity(post_data, unit=unit))
165
-
166
-
167
- # --------------
168
- # Implementation
169
- # --------------
170
-
171
-
172
def _cpu_event_csr_mv(
    spk: jax.Array,
    indptr: jax.Array,
    indices: jax.Array,
    weight: Union[u.Quantity, jax.Array],
    n_post: int
) -> jax.Array:
    """Accumulate ``spk @ W`` for a CSR matrix ``W`` one pre-synaptic row at a time.

    A row is skipped when its event is ``False`` (boolean spikes) or equal to
    ``0`` (numeric spikes). Boolean events add the stored weights unchanged;
    numeric events scale them by the event value. ``weight`` is either a single
    shared value (size 1) or one value per stored connection. The result is a
    length-``n_post`` array with ``weight``'s dtype.
    """
    is_bool_event = spk.dtype == jnp.bool_
    is_homogeneous = jnp.size(weight) == 1

    def accumulate_row(out, start, stop, event):
        # Walk connections ``start .. stop-1`` and scatter-add them into ``out``.
        def keep_going(carry):
            return carry[1] < stop

        def visit(carry):
            acc, k = carry
            contrib = weight if is_homogeneous else weight[k]
            if not is_bool_event:
                contrib = contrib * event
            return acc.at[indices[k]].add(contrib), k + 1

        return jax.lax.while_loop(keep_going, visit, (out, start))[0]

    def step(out, row):
        event = spk[row]
        fires = event if is_bool_event else event != 0.
        out = jax.lax.cond(
            fires,
            lambda: accumulate_row(out, indptr[row], indptr[row + 1], event),
            lambda: out,
        )
        return out, None

    initial = jnp.zeros((n_post,), dtype=weight.dtype)
    final, _ = jax.lax.scan(step, initial, np.arange(len(spk)))
    return final
202
-
203
-
204
- # --------------
205
- # VJP
206
- # --------------
207
-
208
def _cpu_event_csr_mv_fwd(
    spk: jax.Array,
    indptr: jax.Array,
    indices: jax.Array,
    weight: Union[u.Quantity, jax.Array],
    n_post: int
):
    """Forward rule of the custom VJP.

    Returns the primal output together with the residuals ``(spk, weight)``
    consumed by the backward rule; ``indptr``, ``indices`` and ``n_post`` are
    non-differentiable and reach the backward rule separately.
    """
    out = _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post)
    residuals = (spk, weight)
    return out, residuals
216
-
217
-
218
def _cpu_event_csr_mv_bwd(indptr, indices, n_post, res, ct):
    """Backward rule of the custom VJP of ``_cpu_event_csr_mv``.

    Parameters
    ----------
    indptr, indices, n_post
        Non-differentiable arguments of the primal call.
    res : tuple
        Residuals ``(spk, weight)`` saved by the forward rule.
    ct : jax.Array
        Cotangent of the output, indexed by post-synaptic neuron.

    Returns
    -------
    tuple
        ``(ct_spk, ct_weight)``; ``ct_weight`` is cast to ``weight``'s dtype.
    """
    spk, weight = res
    homo = jnp.size(weight) == 1
    bool_spk = spk.dtype == jnp.bool_

    # ∂L/∂spk = ∂L/∂y * ∂y/∂spk: for each pre neuron, sum the cotangents of its targets
    def fn_spk(i_pre):
        def body_fn(x):
            r, i = x
            i_post = indices[i]
            r = r + (ct[i_post] if homo else ct[i_post] * weight[i])
            return r, i + 1

        p = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (0., indptr[i_pre]))[0]
        p = p * weight if homo else p
        return p

    ct_spk = jax.vmap(fn_spk)(np.arange(len(spk)))

    # ∂L/∂w = ∂L/∂y * ∂y/∂w
    if homo:  # scalar
        # Propagate with a unit weight of *weight's* dtype. ``jnp.asarray(1.)``
        # would use JAX's default float dtype, so the cotangent could differ in
        # dtype from ``weight`` (e.g. float32 weights with x64 enabled, or
        # bfloat16 weights); custom_vjp requires bwd outputs to match the
        # primal dtypes.
        unit_weight = jnp.ones((), dtype=weight.dtype)
        ct_gmax = _cpu_event_csr_mv(spk, indptr, indices, unit_weight, n_post=n_post)
        ct_gmax = jnp.inner(ct, ct_gmax).astype(weight.dtype)
    else:
        def single_post(dw, i_pre):
            def body_fn(x):
                dw, i = x
                i_post = indices[i]
                dw = dw.at[i].add(ct[i_post] if bool_spk else ct[i_post] * spk[i_pre])
                return dw, i + 1

            return jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (dw, indptr[i_pre]))[0]

        def fn_w(dw, i_pre):
            # rows without an event contribute nothing to the weight gradient
            sp = spk[i_pre]
            if bool_spk:
                return jax.lax.cond(sp, lambda: single_post(dw, i_pre), lambda: dw), None
            else:
                return jax.lax.cond(sp == 0., lambda: dw, lambda: single_post(dw, i_pre)), None

        ct_gmax = jax.lax.scan(fn_w, jnp.zeros_like(weight), np.arange(len(spk)))[0]
    return ct_spk, ct_gmax
260
-
261
-
262
# Reverse-mode wrapper: only ``spk`` and ``weight`` are differentiated; the
# arguments at positions 1, 2 and 4 (``indptr``, ``indices``, ``n_post``) are
# declared non-differentiable.
_cpu_event_csr_mv_vjp = jax.custom_vjp(
    _cpu_event_csr_mv,
    nondiff_argnums=(1, 2, 4),
)
_cpu_event_csr_mv_vjp.defvjp(
    fwd=_cpu_event_csr_mv_fwd,
    bwd=_cpu_event_csr_mv_bwd,
)
264
-
265
-
266
- # --------------
267
- # JVP
268
- # --------------
269
-
270
-
271
def _cpu_event_csr_mv_jvp_rule(indptr, indices, n_post, primals, tangents):
    """Forward-mode rule of ``_cpu_event_csr_mv``.

    The output is linear in ``weight`` and (for numeric spikes) in ``spk``, so
    the tangent is ``spk @ W(weight_dot) + spk_dot @ W(weight)``. Boolean (or
    integer) spikes are non-differentiable; JAX passes them a ``float0``
    tangent, and their contribution is zero.
    """
    # forward pass
    spk, weight = primals
    y = _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post)

    # forward gradients
    spk_dot, weight_dot = tangents
    homo_w = jnp.size(weight) == 1

    # ∂y/∂gmax
    dweight = _cpu_event_csr_mv(spk, indptr, indices, weight_dot, n_post=n_post)

    # A float0 tangent cannot be indexed by a traced position or scattered into
    # a float buffer, so skip the ∂y/∂spk term entirely.
    if spk_dot.dtype == jax.dtypes.float0:
        return y, dweight

    # ∂y/∂gspk
    def scan_fn(post, i_pre):
        def while_fn(x):
            p, i, sp = x
            i_post = indices[i]
            p = p.at[i_post].add(sp if homo_w else sp * weight[i])
            return p, i + 1, sp

        post = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1],
                                  while_fn,
                                  (post, indptr[i_pre], spk_dot[i_pre]))[0]

        return post, None

    dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
    # a homogeneous weight was factored out of the loop above; apply it once here
    dspk = (dspk * weight) if homo_w else dspk
    return y, dweight + dspk
300
-
301
-
302
# Forward-mode wrapper: ``indptr``, ``indices`` and ``n_post`` (positions 1, 2
# and 4) are non-differentiable and are passed first to the JVP rule.
_cpu_event_csr_mv_jvp = jax.custom_jvp(
    _cpu_event_csr_mv,
    nondiff_argnums=(1, 2, 4),
)
_cpu_event_csr_mv_jvp.defjvp(jvp=_cpu_event_csr_mv_jvp_rule)
@@ -1,277 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- # -*- coding: utf-8 -*-
16
-
17
-
18
- import unittest
19
-
20
- import brainunit as u
21
- import jax
22
- import jax.numpy as jnp
23
- import numpy as np
24
-
25
- import brainstate as bst
26
-
27
-
28
class TestCSR(unittest.TestCase):
    """Compare event-driven ``bst.event.CSR`` products with dense products."""

    def _assert_matches_dense(self, dense, csr, to_float=False):
        # Right product ``csr @ v``: one event per column of ``dense``.
        events = bst.random.rand(dense.shape[1]) < 0.5
        if to_float:
            events = events.astype(float)
        self.assertTrue(
            u.math.allclose(dense.astype(float) @ events.astype(float), csr @ events)
        )

        # Left product ``v @ csr``: one event per row of ``dense``.
        events = bst.random.rand(dense.shape[0]) < 0.5
        if to_float:
            events = events.astype(float)
        self.assertTrue(
            u.math.allclose(events.astype(float) @ dense.astype(float), events @ csr)
        )

    def test_event_homo_bool(self):
        for scale in (1., 2., 3.):
            dense = (bst.random.rand(10, 20) < 0.1).astype(float) * scale
            pattern = u.sparse.CSR.fromdense(dense)
            csr = bst.event.CSR((scale, pattern.indices, pattern.indptr), shape=dense.shape)
            self._assert_matches_dense(dense, csr)

    def test_event_homo_heter(self):
        mat = bst.random.rand(10, 20)
        dense = (bst.random.rand(10, 20) < 0.1) * mat
        pattern = u.sparse.CSR.fromdense(dense)
        csr = bst.event.CSR((pattern.data, pattern.indices, pattern.indptr), shape=dense.shape)
        self._assert_matches_dense(dense, csr)

    def test_event_heter_float_as_bool(self):
        mat = bst.random.rand(10, 20)
        dense = (mat < 0.1).astype(float) * mat
        pattern = u.sparse.CSR.fromdense(dense)
        csr = bst.event.CSR((pattern.data, pattern.indices, pattern.indptr), shape=dense.shape)
        self._assert_matches_dense(dense, csr, to_float=True)
94
-
95
-
96
def _get_csr(n_pre, n_post, prob):
    """Build a random CSR pattern with the same number of connections per row.

    Each of the ``n_pre`` rows receives ``round(n_post * prob)`` column indices,
    drawn independently (so duplicates are possible) from ``[0, n_post)``.

    Returns
    -------
    (indptr, indices) : tuple of numpy.ndarray
    """
    # round instead of truncating: e.g. 100 * 0.29 evaluates to
    # 28.999999999999996, which int() would turn into 28
    n_conn = int(round(n_post * prob))
    indptr = np.arange(n_pre + 1) * n_conn
    indices = np.random.randint(0, n_post, (n_pre * n_conn,))
    return indptr, indices
101
-
102
-
103
def vector_csr(x, w, indices, indptr, shape):
    """Dense-loop reference for ``x @ W`` with ``W`` a ``shape``-sized CSR matrix.

    ``w`` is a single shared weight or one weight per stored entry; repeated
    column indices accumulate.
    """
    shared = jnp.size(w) == 1
    out = jnp.zeros((shape[1],))
    for row, xi in enumerate(x):
        lo, hi = indptr[row], indptr[row + 1]
        contrib = w * xi if shared else w[lo:hi] * xi
        out = out.at[indices[lo:hi]].add(contrib)
    return out
110
-
111
-
112
def matrix_csr(xs, w, indices, indptr, shape):
    """Dense-loop reference for ``xs @ W`` (batch of row vectors times a CSR matrix).

    ``xs`` has one row per batch item and ``shape[0]`` columns; ``w`` is a single
    shared weight or one weight per stored entry.
    """
    shared = jnp.size(w) == 1
    out = jnp.zeros((xs.shape[0], shape[1]))
    for row in range(xs.shape[1]):
        lo, hi = indptr[row], indptr[row + 1]
        column = xs[:, row: row + 1]
        scale = w if shared else w[lo:hi]
        out = out.at[:, indices[lo:hi]].add(scale * column)
    return out
123
-
124
-
125
def csr_vector(x, w, indices, indptr, shape):
    """Dense-loop reference for ``W @ x`` with ``W`` a ``shape``-sized CSR matrix.

    ``w`` is a single shared weight or one weight per stored entry.
    """
    shared = jnp.size(w) == 1
    result = jnp.zeros([shape[0]])
    for row in range(shape[0]):
        lo, hi = indptr[row], indptr[row + 1]
        weights = w if shared else w[lo:hi]
        result = result.at[row].set(jnp.sum(x[indices[lo:hi]] * weights))
    return result
133
-
134
-
135
def csr_matrix(xs, w, indices, indptr, shape):
    """Dense-loop reference for ``W @ xs`` with ``W`` a ``shape``-sized CSR matrix.

    ``xs`` has ``shape[1]`` rows; ``w`` is a single shared weight or one weight
    per stored entry.
    """
    shared = jnp.size(w) == 1
    result = jnp.zeros([shape[0], xs.shape[1]])
    for row in range(shape[0]):
        lo, hi = indptr[row], indptr[row + 1]
        # per-entry weights become a column so they scale whole rows of ``xs``
        weights = w if shared else jnp.expand_dims(w[lo:hi], axis=1)
        result = result.at[row].set(jnp.sum(xs[indices[lo:hi]] * weights, axis=0))
    return result
144
-
145
-
146
class TestVectorCSR(unittest.TestCase):
    """``vector @ CSR`` checked against the dense-loop reference ``vector_csr``."""

    def test_vector_csr(self, ):
        m, n = 20, 40
        x = bst.random.rand(m) < 0.1
        indptr, indices = _get_csr(m, n, 0.1)

        for homo_w in [True, False]:
            # subTest labels the weight configuration instead of printing it
            with self.subTest(homo_w=homo_w):
                data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
                csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
                y = x @ csr
                y2 = vector_csr(x, csr.data, indices, indptr, [m, n])
                self.assertTrue(jnp.allclose(y, y2))

    def test_vector_csr_vmap_vector(self):
        n_batch, m, n = 10, 20, 40
        xs = bst.random.rand(n_batch, m) < 0.1
        indptr, indices = _get_csr(m, n, 0.1)

        for homo_w in [True, False]:
            with self.subTest(homo_w=homo_w):
                data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
                csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
                y = jax.vmap(lambda x: x @ csr)(xs)
                y2 = jax.vmap(lambda x: vector_csr(x, csr.data, indices, indptr, [m, n]))(xs)
                self.assertTrue(jnp.allclose(y, y2))
171
-
172
-
173
class TestMatrixCSR(unittest.TestCase):
    """``matrix @ CSR`` checked against the dense-loop reference ``matrix_csr``."""

    def test_matrix_csr(self):
        n_batch, n_pre, n_post = 10, 20, 40
        events = bst.random.rand(n_batch, n_pre) < 0.1
        indptr, indices = _get_csr(n_pre, n_post, 0.1)

        for shared_weight in (True, False):
            weights = 1.5 if shared_weight else bst.init.Normal()(indices.shape)
            sparse = bst.event.CSR([weights, indices, indptr], shape=(n_pre, n_post))
            actual = events @ sparse
            expected = matrix_csr(events, sparse.data, indices, indptr, [n_pre, n_post])
            self.assertTrue(jnp.allclose(actual, expected))
185
-
186
-
187
class TestCSRVector(unittest.TestCase):
    """``CSR @ vector`` checked against the dense-loop reference ``csr_vector``."""

    def test_csr_vector(self):
        n_pre, n_post = 20, 40
        events = bst.random.rand(n_post) < 0.1
        indptr, indices = _get_csr(n_pre, n_post, 0.1)

        for shared_weight in (True, False):
            weights = 1.5 if shared_weight else bst.init.Normal()(indices.shape)
            sparse = bst.event.CSR([weights, indices, indptr], shape=(n_pre, n_post))
            actual = sparse @ events
            expected = csr_vector(events, sparse.data, indices, indptr, [n_pre, n_post])
            self.assertTrue(jnp.allclose(actual, expected))
199
-
200
-
201
class TestCSRMatrix(unittest.TestCase):
    """``CSR @ matrix`` checked against the dense-loop reference ``csr_matrix``."""

    def test_csr_matrix(self):
        m, n, k = 20, 40, 10
        matrix = bst.random.rand(n, k) < 0.1
        indptr, indices = _get_csr(m, n, 0.1)

        for homo_w in [True, False]:
            with self.subTest(homo_w=homo_w):
                data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
                csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
                y = csr @ matrix
                y2 = csr_matrix(matrix, csr.data, indices, indptr, [m, n])
                self.assertTrue(jnp.allclose(y, y2))

    # TODO: gradient (VJP/JVP) tests of the event CSR product against a dense
    # reference implementation are not implemented yet.