brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250126__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. brainstate/__init__.py +1 -2
  2. brainstate/_state.py +77 -44
  3. brainstate/_state_test.py +0 -17
  4. brainstate/augment/__init__.py +10 -20
  5. brainstate/augment/_eval_shape.py +9 -10
  6. brainstate/augment/_eval_shape_test.py +1 -1
  7. brainstate/augment/_mapping.py +265 -277
  8. brainstate/augment/_mapping_test.py +147 -175
  9. brainstate/compile/__init__.py +18 -37
  10. brainstate/compile/_ad_checkpoint.py +6 -4
  11. brainstate/compile/_jit.py +37 -28
  12. brainstate/compile/_loop_collect_return.py +6 -3
  13. brainstate/compile/_loop_no_collection.py +2 -0
  14. brainstate/compile/_make_jaxpr.py +15 -4
  15. brainstate/compile/_make_jaxpr_test.py +10 -6
  16. brainstate/compile/_progress_bar.py +68 -40
  17. brainstate/compile/_unvmap.py +9 -6
  18. brainstate/graph/__init__.py +12 -16
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  24. brainstate/nn/_interaction/_conv.py +4 -2
  25. brainstate/nn/_interaction/_linear.py +84 -10
  26. brainstate/random/_rand_funs.py +9 -2
  27. brainstate/random/_rand_seed.py +12 -2
  28. brainstate/random/_rand_state.py +50 -179
  29. brainstate/surrogate.py +5 -1
  30. brainstate/util/__init__.py +0 -4
  31. brainstate/util/_caller.py +1 -1
  32. brainstate/util/_dict.py +4 -1
  33. brainstate/util/_filter.py +1 -1
  34. brainstate/util/_pretty_repr.py +1 -1
  35. brainstate/util/_struct.py +1 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/METADATA +2 -1
  37. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/RECORD +40 -60
  38. brainstate/event/__init__.py +0 -29
  39. brainstate/event/_csr.py +0 -906
  40. brainstate/event/_csr_mv.py +0 -303
  41. brainstate/event/_csr_mv_benchmark.py +0 -14
  42. brainstate/event/_csr_mv_test.py +0 -118
  43. brainstate/event/_csr_test.py +0 -90
  44. brainstate/event/_fixedprob_mv.py +0 -730
  45. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  46. brainstate/event/_fixedprob_mv_test.py +0 -132
  47. brainstate/event/_linear_mv.py +0 -359
  48. brainstate/event/_linear_mv_benckmark.py +0 -82
  49. brainstate/event/_linear_mv_test.py +0 -117
  50. brainstate/event/_misc.py +0 -34
  51. brainstate/event/_xla_custom_op.py +0 -313
  52. brainstate/event/_xla_custom_op_test.py +0 -55
  53. brainstate/graph/_graph_context.py +0 -443
  54. brainstate/graph/_graph_context_test.py +0 -65
  55. brainstate/graph/_graph_convert.py +0 -246
  56. brainstate/util/_tracers.py +0 -68
  57. brainstate/util/_visualization.py +0 -47
  58. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/LICENSE +0 -0
  59. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/WHEEL +0 -0
  60. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/top_level.txt +0 -0
@@ -1,303 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from __future__ import annotations
16
-
17
- from typing import Union, Callable, Optional
18
-
19
- import brainunit as u
20
- import jax
21
- import jax.numpy as jnp
22
- import numpy as np
23
-
24
- from brainstate._state import ParamState
25
- from brainstate._utils import set_module_as
26
- from brainstate.init import param
27
- from brainstate.nn._module import Module
28
- from brainstate.typing import ArrayLike, Size
29
-
30
__all__ = [
    'CSRLinear',
]


def _last_dim(size: Size) -> int:
    """Return the trailing dimension of ``size``, which may be a bare integer or a shape sequence."""
    if isinstance(size, (int, np.integer)):
        return int(size)
    return size[-1]


class CSRLinear(Module):
    """
    The CSRLinear module implements a fixed probability connection with CSR sparse data structure.

    Row ``i`` of the CSR structure lists the post-synaptic targets of pre-synaptic
    neuron ``i``, i.e. ``indices[indptr[i]:indptr[i + 1]]``.

    Parameters
    ----------
    in_size : Size
        Number of pre-synaptic neurons, i.e., input size. Either an integer or a
        shape whose last entry is the number of pre-synaptic neurons.
    out_size : Size
        Number of post-synaptic neurons, i.e., output size. Either an integer or a
        shape whose last entry is the number of post-synaptic neurons.
    indptr : ArrayLike
        1D CSR row-pointer array of length ``n_pre + 1``.
    indices : ArrayLike
        1D array holding the post-synaptic index of every connection.
    weight : float or callable or jax.Array or brainunit.Quantity
        Maximum synaptic conductance or a function that returns the maximum synaptic
        conductance. Either a single (shared) value or one value per connection.
    name : str, optional
        Name of the module.
    grad_mode : str, optional
        Differentiation rule forwarded to :func:`cpu_event_csr`; ``'vjp'`` (default) or ``'jvp'``.

    Raises
    ------
    ValueError
        If ``grad_mode`` is unsupported or ``weight`` has the wrong number of elements.
    """

    __module__ = 'brainstate.event'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        indptr: ArrayLike,
        indices: ArrayLike,
        weight: Union[Callable, ArrayLike],
        name: Optional[str] = None,
        grad_mode: str = 'vjp',
    ):
        super().__init__(name=name)

        if grad_mode not in ('vjp', 'jvp'):
            raise ValueError(f"Unsupported grad_mode: {grad_mode}")
        self.grad_mode = grad_mode

        # network size; sizes may be given as plain ints, so read the trailing dim safely
        self.in_size = in_size
        self.out_size = out_size
        self.n_pre = _last_dim(in_size)
        self.n_post = _last_dim(out_size)

        # CSR data structure; converted under ensure_compile_time_eval so the arrays
        # (and the shape checks on them) are concrete even if constructed while tracing
        with jax.ensure_compile_time_eval():
            indptr = jnp.asarray(indptr)
            indices = jnp.asarray(indices)
            assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
            assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
            assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
            self.indptr = u.math.asarray(indptr)
            self.indices = u.math.asarray(indices)

        # maximum synaptic conductance: one shared value or one value per connection
        weight = param(weight, (len(indices),), allow_none=False)
        if u.math.size(weight) != 1 and u.math.size(weight) != len(self.indices):
            raise ValueError(
                f"weight must be a scalar or have size {len(self.indices)}. Got: {u.math.size(weight)}"
            )
        self.weight = ParamState(weight)

    def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
        """Propagate events ``spk`` (last axis = pre-synaptic neurons) to the post-synaptic side."""
        weight = self.weight.value

        # with no connections at all the output is identically zero
        if len(self.indices) == 0:
            r = u.math.zeros(spk.shape[:-1] + (self.n_post,),
                             dtype=weight.dtype,
                             unit=u.get_unit(weight) * u.get_unit(spk))
            return u.maybe_decimal(r)

        # CPU implementation
        return cpu_event_csr(
            u.math.asarray(spk),
            self.indptr,
            self.indices,
            u.math.asarray(weight),
            n_post=self.n_post,
            grad_mode=self.grad_mode,
        )
106
-
107
-
108
@set_module_as('brainstate.event')
def cpu_event_csr(
    spk: jax.Array,
    indptr: jax.Array,
    indices: jax.Array,
    weight: Union[u.Quantity, jax.Array],
    *,
    n_post: int,
    grad_mode: str = 'vjp'
) -> Union[u.Quantity, jax.Array]:
    """
    Event-driven product of pre-synaptic events with a CSR connectivity structure.

    Parameters
    ----------
    spk : jax.Array
        Spike events. The last axis indexes pre-synaptic neurons; leading axes are
        treated as batch dimensions and mapped over with ``jax.vmap``.
    indptr : jax.Array
        Index pointer of post connected neurons (one CSR row per pre-synaptic neuron).
    indices : jax.Array
        Indices of post connected neurons.
    weight : brainunit.Quantity or jax.Array
        Maximum synaptic conductance; a scalar or one value per connection.
    n_post : int
        Number of post-synaptic neurons.
    grad_mode : str, optional
        Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.

    Returns
    -------
    post_data : brainunit.Quantity or jax.Array
        Post synaptic data of shape ``spk.shape[:-1] + (n_post,)``, carrying the unit of ``weight``.

    Raises
    ------
    ValueError
        If ``grad_mode`` is neither ``'vjp'`` nor ``'jvp'``.
    """
    # pick the differentiation rule once instead of on every mapped call
    if grad_mode == 'vjp':
        kernel = _cpu_event_csr_mv_vjp
    elif grad_mode == 'jvp':
        kernel = _cpu_event_csr_mv_jvp
    else:
        raise ValueError(f"Unsupported grad_mode: {grad_mode}")

    unit = u.get_unit(weight)
    # plain Python scalars have no ``.ndim``/``.dtype``; the kernels need an array
    weight = jnp.asarray(u.get_mantissa(weight))

    assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
    assert weight.ndim in [1, 0], f"g_max must be 1D or 0D. Got: {weight.ndim}"
    assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"

    def mv(spk_vector):
        assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk_vector.ndim}"
        return kernel(spk_vector, indptr, indices, weight, n_post)

    if spk.ndim == 1:
        post_data = mv(spk)
    else:
        # flatten batch dims, map the 1D kernel, then restore the batch shape
        shape = spk.shape[:-1]
        post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
        post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
    return u.maybe_decimal(u.Quantity(post_data, unit=unit))
165
-
166
-
167
- # --------------
168
- # Implementation
169
- # --------------
170
-
171
-
172
- def _cpu_event_csr_mv(
173
- spk: jax.Array,
174
- indptr: jax.Array,
175
- indices: jax.Array,
176
- weight: Union[u.Quantity, jax.Array],
177
- n_post: int
178
- ) -> jax.Array:
179
- bool_x = spk.dtype == jnp.bool_
180
- homo_w = jnp.size(weight) == 1
181
-
182
- def add_fn(post_val, i_start, i_end, sp):
183
- def body_fn(x):
184
- post, i = x
185
- i_post = indices[i]
186
- w = weight if homo_w else weight[i]
187
- w = w if bool_x else w * sp
188
- post = post.at[i_post].add(w)
189
- return post, i + 1
190
-
191
- return jax.lax.while_loop(lambda x: x[1] < i_end, body_fn, (post_val, i_start))[0]
192
-
193
- def scan_fn(post, i):
194
- sp = spk[i] # pre-synaptic spike event
195
- if bool_x:
196
- post = jax.lax.cond(sp, lambda: add_fn(post, indptr[i], indptr[i + 1], sp), lambda: post)
197
- else:
198
- post = jax.lax.cond(sp == 0., lambda: post, lambda: add_fn(post, indptr[i], indptr[i + 1], sp))
199
- return post, None
200
-
201
- return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
202
-
203
-
204
- # --------------
205
- # VJP
206
- # --------------
207
-
208
def _cpu_event_csr_mv_fwd(
    spk: jax.Array,
    indptr: jax.Array,
    indices: jax.Array,
    weight: Union[u.Quantity, jax.Array],
    n_post: int
):
    """Forward rule of the custom VJP: returns the primal output and the residuals ``(spk, weight)``."""
    out = _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post)
    residuals = (spk, weight)
    return out, residuals
216
-
217
-
218
- def _cpu_event_csr_mv_bwd(indptr, indices, n_post, res, ct):
219
- spk, weight = res
220
- homo = jnp.size(weight) == 1
221
- bool_spk = spk.dtype == jnp.bool_
222
-
223
- # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
224
- def fn_spk(i_pre):
225
- def body_fn(x):
226
- r, i = x
227
- i_post = indices[i]
228
- r = r + (ct[i_post] if homo else ct[i_post] * weight[i])
229
- return r, i + 1
230
-
231
- p = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (0., indptr[i_pre]))[0]
232
- p = p * weight if homo else p
233
- return p
234
-
235
- ct_spk = jax.vmap(fn_spk)(np.arange(len(spk)))
236
-
237
- # ∂L/∂w = ∂L/∂y * ∂y/∂w
238
- if homo: # scalar
239
- ct_gmax = _cpu_event_csr_mv(spk, indptr, indices, jnp.asarray(1.), n_post=n_post)
240
- ct_gmax = jnp.inner(ct, ct_gmax)
241
- else:
242
- def single_post(dw, i_pre):
243
- def body_fn(x):
244
- dw, i = x
245
- i_post = indices[i]
246
- dw = dw.at[i].add(ct[i_post] if bool_spk else ct[i_post] * spk[i_pre])
247
- return dw, i + 1
248
-
249
- return jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1], body_fn, (dw, indptr[i_pre]))[0]
250
-
251
- def fn_w(dw, i_pre):
252
- sp = spk[i_pre]
253
- if bool_spk:
254
- return jax.lax.cond(sp, lambda: single_post(dw, i_pre), lambda: dw), None
255
- else:
256
- return jax.lax.cond(sp == 0., lambda: dw, lambda: single_post(dw, i_pre)), None
257
-
258
- ct_gmax = jax.lax.scan(fn_w, jnp.zeros_like(weight), np.arange(len(spk)))[0]
259
- return ct_spk, ct_gmax
260
-
261
-
262
# ``indptr`` (arg 1), ``indices`` (arg 2) and ``n_post`` (arg 4) describe the fixed
# sparsity structure; only ``spk`` and ``weight`` receive cotangents.
# NOTE(review): JAX does not accept tracers through custom_vjp ``nondiff_argnums``,
# so ``indptr``/``indices`` presumably must be concrete arrays here — confirm with callers.
_cpu_event_csr_mv_vjp = jax.custom_vjp(
    _cpu_event_csr_mv,
    nondiff_argnums=(1, 2, 4),
)
_cpu_event_csr_mv_vjp.defvjp(_cpu_event_csr_mv_fwd, _cpu_event_csr_mv_bwd)
264
-
265
-
266
- # --------------
267
- # JVP
268
- # --------------
269
-
270
-
271
def _cpu_event_csr_mv_jvp_rule(indptr, indices, n_post, primals, tangents):
    """
    JVP rule of the custom JVP.

    Parameters
    ----------
    indptr, indices, n_post
        Non-differentiable CSR structure and output size (the ``nondiff_argnums``).
    primals : tuple
        ``(spk, weight)``.
    tangents : tuple
        ``(spk_dot, weight_dot)``. ``spk_dot`` has dtype ``float0`` when ``spk`` is
        boolean or integer.

    Returns
    -------
    tuple
        The primal output and its tangent.
    """
    # forward pass
    spk, weight = primals
    y = _cpu_event_csr_mv(spk, indptr, indices, weight, n_post=n_post)

    # forward gradients
    spk_dot, weight_dot = tangents
    homo_w = jnp.size(weight) == 1

    # ∂y/∂w · weight_dot: the same product with the weight tangent in place of the weight
    dweight = _cpu_event_csr_mv(spk, indptr, indices, weight_dot, n_post=n_post)

    # boolean / integer events carry float0 tangents, which cannot be scattered into
    # a float accumulator and contribute nothing anyway
    if spk_dot.dtype == jax.dtypes.float0:
        return y, dweight

    # ∂y/∂spk · spk_dot: scatter every row's event tangent onto its targets
    def scan_fn(post, i_pre):
        sp = spk_dot[i_pre]

        def while_fn(x):
            p, i = x
            i_post = indices[i]
            p = p.at[i_post].add(sp if homo_w else sp * weight[i])
            return p, i + 1

        post = jax.lax.while_loop(lambda x: x[1] < indptr[i_pre + 1],
                                  while_fn,
                                  (post, indptr[i_pre]))[0]
        return post, None

    dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
    # a shared weight factors out of the scatter
    dspk = (dspk * weight) if homo_w else dspk
    return y, dweight + dspk
300
-
301
-
302
# ``indptr`` (arg 1), ``indices`` (arg 2) and ``n_post`` (arg 4) are static structure;
# tangents flow only through ``spk`` and ``weight``.
_cpu_event_csr_mv_jvp = jax.custom_jvp(
    _cpu_event_csr_mv,
    nondiff_argnums=(1, 2, 4),
)
_cpu_event_csr_mv_jvp.defjvp(_cpu_event_csr_mv_jvp_rule)
@@ -1,14 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
@@ -1,118 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from __future__ import annotations
17
-
18
- import jax.numpy
19
- import jax.numpy as jnp
20
- import numpy as np
21
- from absl.testing import parameterized
22
-
23
- import brainstate as bst
24
-
25
-
26
- def _get_csr(n_pre, n_post, prob):
27
- n_conn = int(n_post * prob)
28
- indptr = np.arange(n_pre + 1) * n_conn
29
- indices = np.random.randint(0, n_post, (n_pre * n_conn,))
30
- return indptr, indices
31
-
32
-
33
def true_fn(x, w, indices, indptr, n_out):
    """
    Dense-loop reference for the event CSR product.

    For each pre-synaptic row ``r``, adds ``w * x[r]`` (shared weight) or
    ``w[indptr[r]:indptr[r + 1]] * x[r]`` (per-connection weights) to the output
    entries ``indices[indptr[r]:indptr[r + 1]]``. Returns a length-``n_out`` array.
    """
    scalar_w = jnp.size(w) == 1

    out = jnp.zeros((n_out,))
    for row in range(x.shape[0]):
        lo, hi = indptr[row], indptr[row + 1]
        targets = indices[lo:hi]
        if scalar_w:
            contrib = w * x[row]
        else:
            contrib = w[lo:hi] * x[row]
        out = out.at[targets].add(contrib)
    return out
41
-
42
-
43
- # class TestFixedProbCSR(parameterized.TestCase):
44
- # @parameterized.product(
45
- # homo_w=[True, False],
46
- # )
47
- # def test1(self, homo_w):
48
- # x = bst.random.rand(20) < 0.1
49
- # indptr, indices = _get_csr(20, 40, 0.1)
50
- # m = bst.event.CSRLinear(20, 40, indptr, indices, 1.5 if homo_w else bst.init.Normal())
51
- # y = m(x)
52
- # y2 = true_fn(x, m.weight.value, indices, indptr, 40)
53
- # self.assertTrue(jnp.allclose(y, y2))
54
- #
55
- # @parameterized.product(
56
- # bool_x=[True, False],
57
- # homo_w=[True, False]
58
- # )
59
- # def test_vjp(self, bool_x, homo_w):
60
- # n_in = 20
61
- # n_out = 30
62
- # if bool_x:
63
- # x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
64
- # else:
65
- # x = bst.random.rand(n_in)
66
- #
67
- # indptr, indices = _get_csr(n_in, n_out, 0.1)
68
- # fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
69
- # w = fn.weight.value
70
- #
71
- # def f(x, w):
72
- # fn.weight.value = w
73
- # return fn(x).sum()
74
- #
75
- # r = jax.grad(f, argnums=(0, 1))(x, w)
76
- #
77
- # # -------------------
78
- # # TRUE gradients
79
- #
80
- # def f2(x, w):
81
- # return true_fn(x, w, indices, indptr, n_out).sum()
82
- #
83
- # r2 = jax.grad(f2, argnums=(0, 1))(x, w)
84
- # self.assertTrue(jnp.allclose(r[0], r2[0]))
85
- # self.assertTrue(jnp.allclose(r[1], r2[1]))
86
- #
87
- # @parameterized.product(
88
- # bool_x=[True, False],
89
- # homo_w=[True, False]
90
- # )
91
- # def test_jvp(self, bool_x, homo_w):
92
- # n_in = 20
93
- # n_out = 30
94
- # if bool_x:
95
- # x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
96
- # else:
97
- # x = bst.random.rand(n_in)
98
- #
99
- # indptr, indices = _get_csr(n_in, n_out, 0.1)
100
- # fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
101
- # 1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
102
- # w = fn.weight.value
103
- #
104
- # def f(x, w):
105
- # fn.weight.value = w
106
- # return fn(x)
107
- #
108
- # o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
109
- #
110
- # # -------------------
111
- # # TRUE gradients
112
- #
113
- # def f2(x, w):
114
- # return true_fn(x, w, indices, indptr, n_out)
115
- #
116
- # o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
117
- # self.assertTrue(jnp.allclose(r1, r2))
118
- # self.assertTrue(jnp.allclose(o1, o2))
@@ -1,90 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- # -*- coding: utf-8 -*-
16
-
17
-
18
- import unittest
19
-
20
- import brainunit as u
21
-
22
- import brainstate as bst
23
-
24
-
25
class TestCSR(unittest.TestCase):
    """Compare ``bst.event.CSR`` products against the equivalent dense products."""

    def _assert_matches_dense(self, dense, sparse, make_events):
        """Check ``dense @ v == sparse @ v`` and ``v @ dense == v @ sparse`` for freshly drawn events."""
        n_rows, n_cols = dense.shape
        dense_f = dense.astype(float)

        # right product: events over the columns
        right = make_events(n_cols)
        self.assertTrue(u.math.allclose(dense_f @ right.astype(float), sparse @ right))

        # left product: events over the rows
        left = make_events(n_rows)
        self.assertTrue(u.math.allclose(left.astype(float) @ dense_f, left @ sparse))

    @staticmethod
    def _bool_events(n):
        return bst.random.rand(n) < 0.5

    @staticmethod
    def _float_events(n):
        return (bst.random.rand(n) < 0.5).astype(float)

    def test_event_homo_bool(self):
        for dat in [1., 2., 3.]:
            mask = (bst.random.rand(10, 20) < 0.1).astype(float) * dat
            ref = u.sparse.CSR.fromdense(mask)
            csr = bst.event.CSR((dat, ref.indices, ref.indptr), shape=mask.shape)
            self._assert_matches_dense(mask, csr, self._bool_events)

    def test_event_homo_heter(self):
        mat = bst.random.rand(10, 20)
        mask = (bst.random.rand(10, 20) < 0.1) * mat
        ref = u.sparse.CSR.fromdense(mask)
        csr = bst.event.CSR((ref.data, ref.indices, ref.indptr), shape=mask.shape)
        self._assert_matches_dense(mask, csr, self._bool_events)

    def test_event_heter_float_as_bool(self):
        mat = bst.random.rand(10, 20)
        mask = (mat < 0.1).astype(float) * mat
        ref = u.sparse.CSR.fromdense(mask)
        csr = bst.event.CSR((ref.data, ref.indices, ref.indptr), shape=mask.shape)
        self._assert_matches_dense(mask, csr, self._float_events)