brainstate 0.1.0.post20250120__py2.py3-none-any.whl → 0.1.0.post20250127__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. brainstate/__init__.py +1 -2
  2. brainstate/augment/__init__.py +10 -20
  3. brainstate/compile/__init__.py +18 -37
  4. brainstate/compile/_make_jaxpr.py +9 -2
  5. brainstate/compile/_make_jaxpr_test.py +10 -6
  6. brainstate/compile/_progress_bar.py +49 -6
  7. brainstate/compile/_unvmap.py +3 -3
  8. brainstate/graph/__init__.py +12 -12
  9. brainstate/nn/_dyn_impl/_inputs.py +4 -2
  10. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  11. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/METADATA +1 -1
  12. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/RECORD +15 -29
  13. brainstate/event/__init__.py +0 -27
  14. brainstate/event/_csr.py +0 -1149
  15. brainstate/event/_csr_benchmark.py +0 -14
  16. brainstate/event/_csr_mv.py +0 -303
  17. brainstate/event/_csr_test.py +0 -277
  18. brainstate/event/_fixedprob_mv.py +0 -730
  19. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  20. brainstate/event/_fixedprob_mv_test.py +0 -132
  21. brainstate/event/_linear_mv.py +0 -359
  22. brainstate/event/_linear_mv_benckmark.py +0 -82
  23. brainstate/event/_linear_mv_test.py +0 -117
  24. brainstate/event/_misc.py +0 -34
  25. brainstate/event/_xla_custom_op.py +0 -317
  26. brainstate/event/_xla_custom_op_test.py +0 -55
  27. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/LICENSE +0 -0
  28. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/WHEEL +0 -0
  29. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/top_level.txt +0 -0
@@ -1,117 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from __future__ import annotations
-
- import jax
- import jax.numpy as jnp
- from absl.testing import parameterized
-
- import brainstate as bst
- from brainstate.event._linear_mv import Linear
-
-
- class TestEventLinear(parameterized.TestCase):
-     @parameterized.product(
-         homo_w=[True, False],
-         bool_x=[True, False],
-     )
-     def test1(self, homo_w, bool_x):
-         x = bst.random.rand(20) < 0.1
-         if not bool_x:
-             x = jnp.asarray(x, dtype=float)
-         m = Linear(20, 40, 1.5 if homo_w else bst.init.KaimingUniform(), float_as_event=bool_x)
-         y = m(x)
-         print(y)
-
-         self.assertTrue(jnp.allclose(y, (x.sum() * m.weight.value) if homo_w else (x @ m.weight.value)))
-
-     def test_grad_bool(self):
-         n_in = 20
-         n_out = 30
-         x = bst.random.rand(n_in) < 0.3
-         fn = Linear(n_in, n_out, bst.init.KaimingUniform())
-
-         with self.assertRaises(TypeError):
-             print(jax.grad(lambda x: fn(x).sum())(x))
-
-     @parameterized.product(
-         bool_x=[True, False],
-         homo_w=[True, False]
-     )
-     def test_vjp(self, bool_x, homo_w):
-         n_in = 20
-         n_out = 30
-         if bool_x:
-             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
-         else:
-             x = bst.random.rand(n_in)
-
-         fn = Linear(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(), float_as_event=bool_x)
-         w = fn.weight.value
-
-         def f(x, w):
-             fn.weight.value = w
-             return fn(x).sum()
-
-         r1 = jax.grad(f, argnums=(0, 1))(x, w)
-
-         # -------------------
-         # TRUE gradients
-
-         def f2(x, w):
-             y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
-             return y.sum()
-
-         r2 = jax.grad(f2, argnums=(0, 1))(x, w)
-         self.assertTrue(jnp.allclose(r1[0], r2[0]))
-
-         if not jnp.allclose(r1[1], r2[1]):
-             print(r1[1] - r2[1])
-
-         self.assertTrue(jnp.allclose(r1[1], r2[1]))
-
-     @parameterized.product(
-         bool_x=[True, False],
-         homo_w=[True, False]
-     )
-     def test_jvp(self, bool_x, homo_w):
-         n_in = 20
-         n_out = 30
-         if bool_x:
-             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
-         else:
-             x = bst.random.rand(n_in)
-
-         fn = Linear(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(),
-                     float_as_event=bool_x)
-         w = fn.weight.value
-
-         def f(x, w):
-             fn.weight.value = w
-             return fn(x)
-
-         o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-
-         # -------------------
-         # TRUE gradients
-
-         def f2(x, w):
-             y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
-             return y
-
-         o2, r2 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-         self.assertTrue(jnp.allclose(o1, o2))
-         self.assertTrue(jnp.allclose(r1, r2))
brainstate/event/_misc.py DELETED
@@ -1,34 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- from __future__ import annotations
-
- from typing import Union
-
- import numpy as np
-
- __all__ = [
-     'FloatScalar',
-     'IntScalar',
- ]
-
- FloatScalar = Union[
-     np.number,  # NumPy scalar types
-     float,  # Python scalar types
- ]
-
- IntScalar = Union[
-     np.number,  # NumPy scalar types
-     int,  # Python scalar types
- ]
brainstate/event/_xla_custom_op.py DELETED
@@ -1,317 +0,0 @@
- # -*- coding: utf-8 -*-
-
- import ctypes
- import functools
- import importlib.util
- from functools import partial
- from typing import Callable, Sequence, Tuple, Protocol
-
- import jax
- import numpy as np
- from jax import tree_util
- from jax.interpreters import batching, ad
- from jax.interpreters import xla, mlir
- from jaxlib.hlo_helpers import custom_call
-
- if jax.__version_info__ < (0, 4, 35):
-     from jax.lib import xla_client
- else:
-     import jax.extend as je
-
- if jax.__version_info__ < (0, 4, 38):
-     from jax.core import Primitive
- else:
-     from jax.extend.core import Primitive
-
- numba_installed = importlib.util.find_spec('numba') is not None
-
- __all__ = [
-     'defjvp',
-     'XLACustomOp',
- ]
-
- # [void* pointer,
- #  const char *name,
- #  PyCapsule_Destructor destructor]
- ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
- ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
-
-
- def defjvp(primitive, *jvp_rules):
-     """Define JVP rules for any JAX primitive.
-
-     This function is similar to ``jax.interpreters.ad.defjvp``.
-     However, the JAX one only supports primitive with ``multiple_results=False``.
-     ``brainpy.math.defjvp`` enables to define the independent JVP rule for
-     each input parameter no matter ``multiple_results=False/True``.
-
-     For examples, please see ``test_ad_support.py``.
-
-     Args:
-       primitive: Primitive, XLACustomOp.
-       *jvp_rules: The JVP translation rule for each primal.
-     """
-     assert isinstance(primitive, Primitive)
-     if primitive.multiple_results:
-         ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
-     else:
-         ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
-
-
- def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
-     assert primitive.multiple_results
-     val_out = tuple(primitive.bind(*primals, **params))
-     tree = tree_util.tree_structure(val_out)
-     tangents_out = []
-     for rule, t in zip(jvp_rules, tangents):
-         if rule is not None and type(t) is not ad.Zero:
-             r = tuple(rule(t, *primals, **params))
-             tangents_out.append(r)
-             assert tree_util.tree_structure(r) == tree
-     r = functools.reduce(
-         _add_tangents,
-         tangents_out,
-         tree_util.tree_map(
-             # compatible with JAX 0.4.34
-             lambda a: ad.Zero.from_primal_value(a) if jax.__version__ >= '0.4.34' else ad.Zero.from_value(a),
-             val_out
-         )
-     )
-     return val_out, r
-
-
- def _add_tangents(xs, ys):
-     return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))
-
-
- def _shape_to_layout(shape):
-     return tuple(range(len(shape) - 1, -1, -1))
-
-
- def _numba_mlir_cpu_translation_rule(
-     kernel,
-     debug: bool,
-     ctx,
-     *ins,
-     **kwargs
- ):
-     if not numba_installed:
-         raise ImportError('Numba is required to compile the CPU kernel for the custom operator.')
-
-     from numba import types, carray, cfunc  # pylint: disable=import-error
-     from numba.core.dispatcher import Dispatcher  # pylint: disable=import-error
-
-     if not isinstance(kernel, Dispatcher):
-         kernel = kernel(**kwargs)
-     assert isinstance(kernel, Dispatcher), f'The kernel should be a Numba dispatcher. But we got {kernel}'
-
-     # output information
-     outs = ctx.avals_out
-     output_shapes = tuple([out.shape for out in outs])
-     output_dtypes = tuple([out.dtype for out in outs])
-     output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
-     result_types = [mlir.aval_to_ir_type(out) for out in outs]
-
-     # input information
-     avals_in = ctx.avals_in
-     input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
-     input_dtypes = tuple(inp.dtype for inp in avals_in)
-     input_shapes = tuple(inp.shape for inp in avals_in)
-
-     # compiling function
-     code_scope = dict(func_to_call=kernel,
-                       input_shapes=input_shapes,
-                       input_dtypes=input_dtypes,
-                       output_shapes=output_shapes,
-                       output_dtypes=output_dtypes, carray=carray)
-     args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
-                for i in range(len(input_shapes))]
-     if len(output_shapes) > 1:
-         args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
-                     for i in range(len(output_shapes))]
-         sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
-     else:
-         args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
-         sig = types.void(types.voidptr, types.CPointer(types.voidptr))
-     args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
-     code_string = '''
- def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
- {args_in}
- {args_out}
- func_to_call({args_call})
- '''.format(args_in="\n ".join(args_in),
-            args_out="\n ".join(args_out),
-            args_call=", ".join(args_call))
-     if debug:
-         print(code_string)
-     exec(compile(code_string.strip(), '', 'exec'), code_scope)
-     new_f = code_scope['numba_cpu_custom_call_target']
-
-     # register
-     xla_c_rule = cfunc(sig)(new_f)
-     target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
-     capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
-     if jax.__version_info__ < (0, 4, 35):
-         xla_client.register_custom_call_target(target_name, capsule, "cpu")
-     else:
-         je.ffi.register_ffi_target(target_name, capsule, "cpu", api_version=0)
-
-     # call
-     return custom_call(
-         call_target_name=target_name,
-         operands=ins,
-         operand_layouts=list(input_layouts),
-         result_layouts=list(output_layouts),
-         result_types=list(result_types),
-         has_side_effect=False,
-     ).results
-
-
- def register_numba_mlir_cpu_translation_rule(
-     primitive: Primitive,
-     cpu_kernel: Callable,
-     debug: bool = False
- ):
-     rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug)
-     mlir.register_lowering(primitive, rule, platform='cpu')
-
-
- class ShapeDtype(Protocol):
-
-     @property
-     def shape(self) -> Tuple[int, ...]:
-         ...
-
-     @property
-     def dtype(self) -> np.dtype:
-         ...
-
-
- class XLACustomOp:
-     """Creating a XLA custom call operator.
-
-     Args:
-       cpu_kernel_or_generator: Callable. The function defines the computation on CPU backend.
-       gpu_kernel_or_generator: Callable. The function defines the computation on GPU backend.
-       batching_translation: Callable. The batching translation rule of JAX.
-       jvp_translation: Callable. The JVP translation rule of JAX.
-       transpose_translation: Callable. The transpose translation rule of JAX.
-       name: str. The primitive name.
-     """
-
-     def __init__(
-         self,
-         name: str,
-         cpu_kernel_or_generator: Callable,
-         gpu_kernel_or_generator: Callable = None,
-         batching_translation: Callable = None,
-         jvp_translation: Callable = None,
-         transpose_translation: Callable = None,
-     ):
-         # primitive
-         self.primitive = Primitive(name)
-         self.primitive.multiple_results = True
-
-         # abstract evaluation
-         self.primitive.def_impl(partial(xla.apply_primitive, self.primitive))
-         self.primitive.def_abstract_eval(self._abstract_eval)
-
-         # cpu kernel
-         if cpu_kernel_or_generator is not None:
-             self.def_cpu_kernel(cpu_kernel_or_generator)
-         if gpu_kernel_or_generator is not None:
-             self.def_gpu_kernel(gpu_kernel_or_generator)
-
-         # batching rule
-         if batching_translation is not None:
-             batching.primitive_batchers[self.primitive] = batching_translation
-
-         # jvp rule
-         if jvp_translation is not None:
-             ad.primitive_jvps[self.primitive] = jvp_translation
-
-         # transpose rule
-         if transpose_translation is not None:
-             ad.primitive_transposes[self.primitive] = transpose_translation
-
-     def _abstract_eval(self, *ins, outs: Sequence[ShapeDtype], **kwargs):
-         return tuple(outs)
-
-     def __call__(self, *ins, outs: Sequence[ShapeDtype], **kwargs):
-         assert isinstance(outs, (tuple, list)), 'The `outs` should be a tuple or list of shape-dtype pairs.'
-         outs = jax.tree.map(_transform_to_shapedarray, outs)
-         return self.primitive.bind(*ins, **kwargs, outs=tuple(outs))
-
-     def def_cpu_kernel(self, kernel_generator: Callable):
-         """
-         Define the CPU kernel using Numba.
-         """
-         register_numba_mlir_cpu_translation_rule(self.primitive, kernel_generator)
-
-     def def_gpu_kernel(self, kernel_generator: Callable):
-         """
-         Define the GPU kernel using the JAX Pallas language.
-         """
-         lower = mlir.lower_fun(
-             lambda *args, **kwargs: kernel_generator(**kwargs)(*args),
-             multiple_results=True
-         )
-         mlir.register_lowering(self.primitive, lower, platform='cuda')
-         mlir.register_lowering(self.primitive, lower, platform='tpu')
-
-     def def_batching_rule(self, fun):
-         """Define the batching rule.
-
-         Args:
-           fun: The batching rule.
-         """
-         batching.primitive_batchers[self.primitive] = fun
-
-     def def_jvp_rule(self, fun):
-         """Define the JVP rule.
-
-         Args:
-           fun: The JVP rule.
-         """
-         ad.primitive_jvps[self.primitive] = fun
-
-     def defjvp(self, *jvp_rules):
-         """
-         Define the JVP rule. Similar to ``jax.interpreters.ad.defjvp``,
-         but supports the Primitive with multiple results.
-
-         Args:
-           jvp_rules: The JVP rules.
-         """
-         defjvp(self.primitive, *jvp_rules)
-
-     def def_transpose_rule(self, fun):
-         """Define the transpose rule.
-
-         Args:
-           fun: The transpose rule.
-         """
-         ad.primitive_transposes[self.primitive] = fun
-
-     def def_xla_translation(self, platform, fun):
-         """Define the XLA translation rule.
-
-         Args:
-           platform: str. The computing platform.
-           fun: The XLA translation rule.
-         """
-         xla.backend_specific_translations[platform][self.primitive] = fun
-
-     def def_mlir_lowering(self, platform, fun):
-         """
-         Define the MLIR lowering rule.
-
-         Args:
-           platform: str. The computing platform.
-           fun: The lowering rule.
-         """
-         mlir.register_lowering(self.primitive, fun, platform)
-
-
- def _transform_to_shapedarray(a):
-     return jax.core.ShapedArray(a.shape, a.dtype)
brainstate/event/_xla_custom_op_test.py DELETED
@@ -1,55 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import jax
- import jax.numpy as jnp
- from jax.experimental import pallas as pl
-
- import brainstate as bst
-
-
- def test1():
-     import numba
-     def add_vectors_kernel(x_ref, y_ref, o_ref):
-         x, y = x_ref[...], y_ref[...]
-         o_ref[...] = x + y
-
-     def cpu_kernel(**kwargs):
-         @numba.njit
-         def add_kernel_numba(x, y, out):
-             out[...] = x + y
-
-         return add_kernel_numba
-
-     def gpu_kernel(x_info):
-         return pl.pallas_call(
-             add_vectors_kernel,
-             out_shape=[jax.ShapeDtypeStruct(x_info.shape, x_info.dtype)],
-             interpret=jax.default_backend() == 'cpu',
-         )
-
-     prim = bst.event.XLACustomOp(
-         'add',
-         cpu_kernel,
-         gpu_kernel,
-     )
-
-     a = bst.random.rand(64)
-     b = bst.random.rand(64)
-     x_info = jax.ShapeDtypeStruct(a.shape, a.dtype)
-     r1 = prim(a, b, outs=[jax.ShapeDtypeStruct((64,), jax.numpy.float32)], x_info=x_info)
-     r2 = gpu_kernel(x_info)(a, b)
-
-     assert jnp.allclose(r1[0], r2[0])
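
This release deletes the entire brainstate/event/ subpackage. As an illustrative sketch only (the guard below is not part of the diff; the names are taken from the deleted files), downstream code that still relies on it can probe for the module before importing, or pin the previous wheel:

    # Hypothetical compatibility guard: brainstate.event is present in
    # 0.1.0.post20250120 but removed in 0.1.0.post20250127.
    import importlib.util

    if importlib.util.find_spec("brainstate.event") is not None:
        from brainstate.event import XLACustomOp  # available only in the older wheel
    else:
        XLACustomOp = None  # fall back, or pin: pip install brainstate==0.1.0.post20250120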