brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
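The renames in this list amount to a reorganization of the public namespace: `brainstate/transform` is split into `brainstate/compile` (jit, conditionals, loops, jaxpr construction) and `brainstate/augment` (autograd, mapping, eval_shape), the event-driven operators move from `brainstate/nn/event` to a top-level `brainstate/event` package, and `nn` itself is regrouped into `_dyn_impl`, `_dynamics`, `_elementwise`, and `_interaction`. The snippet below is a rough, unverified sketch of how downstream imports would move between the two versions; it is inferred only from the file paths above, so the actual re-exports in `brainstate/__init__.py` may differ.

import brainstate as bst

# 0.0.2.post20241010 layout (old paths, inferred from the left-hand sides above):
#   brainstate.transform._jit, brainstate.transform._autograd, brainstate.nn.event.linear, ...

# 0.1.0.post20241122 layout (new paths, inferred from the right-hand sides above):
from brainstate import compile    # jit / conditionals / loops / make_jaxpr (compile/_jit.py, _conditions.py, ...)
from brainstate import augment    # autograd, mapping, eval_shape (augment/_autograd.py, _mapping.py, ...)
from brainstate import event      # event-driven operators and XLACustomOp (event/_linear.py, _xla_custom_op.py, ...)

# Paths confirmed by the diffs shown later on this page:
from brainstate.event._linear import Linear      # used in event/_linear_test.py
custom_op_cls = bst.event.XLACustomOp            # used in event/_xla_custom_op_test.py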
@@ -0,0 +1,117 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import jax
+ import jax.numpy as jnp
+ from absl.testing import parameterized
+
+ import brainstate as bst
+ from brainstate.event._linear import Linear
+
+
+ class TestEventLinear(parameterized.TestCase):
+     @parameterized.product(
+         homo_w=[True, False],
+         bool_x=[True, False],
+     )
+     def test1(self, homo_w, bool_x):
+         x = bst.random.rand(20) < 0.1
+         if not bool_x:
+             x = jnp.asarray(x, dtype=float)
+         m = Linear(20, 40, 1.5 if homo_w else bst.init.KaimingUniform(), float_as_event=bool_x)
+         y = m(x)
+         print(y)
+
+         self.assertTrue(jnp.allclose(y, (x.sum() * m.weight.value) if homo_w else (x @ m.weight.value)))
+
+     def test_grad_bool(self):
+         n_in = 20
+         n_out = 30
+         x = bst.random.rand(n_in) < 0.3
+         fn = Linear(n_in, n_out, bst.init.KaimingUniform())
+
+         with self.assertRaises(TypeError):
+             print(jax.grad(lambda x: fn(x).sum())(x))
+
+     @parameterized.product(
+         bool_x=[True, False],
+         homo_w=[True, False]
+     )
+     def test_vjp(self, bool_x, homo_w):
+         n_in = 20
+         n_out = 30
+         if bool_x:
+             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+         else:
+             x = bst.random.rand(n_in)
+
+         fn = Linear(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(), float_as_event=bool_x)
+         w = fn.weight.value
+
+         def f(x, w):
+             fn.weight.value = w
+             return fn(x).sum()
+
+         r1 = jax.grad(f, argnums=(0, 1))(x, w)
+
+         # -------------------
+         # TRUE gradients
+
+         def f2(x, w):
+             y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
+             return y.sum()
+
+         r2 = jax.grad(f2, argnums=(0, 1))(x, w)
+         self.assertTrue(jnp.allclose(r1[0], r2[0]))
+
+         if not jnp.allclose(r1[1], r2[1]):
+             print(r1[1] - r2[1])
+
+         self.assertTrue(jnp.allclose(r1[1], r2[1]))
+
+     @parameterized.product(
+         bool_x=[True, False],
+         homo_w=[True, False]
+     )
+     def test_jvp(self, bool_x, homo_w):
+         n_in = 20
+         n_out = 30
+         if bool_x:
+             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+         else:
+             x = bst.random.rand(n_in)
+
+         fn = Linear(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(),
+                     float_as_event=bool_x)
+         w = fn.weight.value
+
+         def f(x, w):
+             fn.weight.value = w
+             return fn(x)
+
+         o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+
+         # -------------------
+         # TRUE gradients
+
+         def f2(x, w):
+             y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
+             return y
+
+         o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+         self.assertTrue(jnp.allclose(o1, o2))
+         self.assertTrue(jnp.allclose(r1, r2))
@@ -12,23 +12,23 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
-
+ from __future__ import annotations

  from typing import Union

  import numpy as np

  __all__ = [
- 'FloatScalar',
- 'IntScalar',
+ 'FloatScalar',
+ 'IntScalar',
  ]

  FloatScalar = Union[
- np.number,  # NumPy scalar types
- float,  # Python scalar types
+ np.number,  # NumPy scalar types
+ float,  # Python scalar types
  ]

  IntScalar = Union[
- np.number,  # NumPy scalar types
- int,  # Python scalar types
+ np.number,  # NumPy scalar types
+ int,  # Python scalar types
  ]
@@ -0,0 +1,312 @@
+ # -*- coding: utf-8 -*-
+
+ import ctypes
+ import functools
+ import importlib.util
+ from functools import partial
+ from typing import Callable, Sequence, Tuple, Protocol
+
+ import jax
+ import numpy as np
+ from jax import tree_util
+ from jax.core import Primitive
+ from jax.interpreters import batching, ad
+ from jax.interpreters import xla, mlir
+ from jax.lib import xla_client
+ from jaxlib.hlo_helpers import custom_call
+
+ numba_installed = importlib.util.find_spec('numba') is not None
+
+ if numba_installed:
+     import numba  # pylint: disable=import-error
+     from numba import types, carray, cfunc  # pylint: disable=import-error
+     from numba.core.dispatcher import Dispatcher  # pylint: disable=import-error
+ else:
+     numba = None
+
+ __all__ = [
+     'XLACustomOp',
+ ]
+
+ # [void* pointer,
+ #  const char *name,
+ #  PyCapsule_Destructor destructor]
+ ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
+ ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
+
+
+ def defjvp(primitive, *jvp_rules):
+     """Define JVP rules for any JAX primitive.
+
+     This function is similar to ``jax.interpreters.ad.defjvp``.
+     However, the JAX one only supports primitives with ``multiple_results=False``.
+     This function allows an independent JVP rule to be defined for each input
+     parameter, regardless of whether ``multiple_results`` is ``False`` or ``True``.
+
+     For examples, please see ``test_ad_support.py``.
+
+     Args:
+       primitive: Primitive, XLACustomOp.
+       *jvp_rules: The JVP translation rule for each primal.
+     """
+     assert isinstance(primitive, Primitive)
+     if primitive.multiple_results:
+         ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
+     else:
+         ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
+
+
+ def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
+     assert primitive.multiple_results
+     val_out = tuple(primitive.bind(*primals, **params))
+     tree = tree_util.tree_structure(val_out)
+     tangents_out = []
+     for rule, t in zip(jvp_rules, tangents):
+         if rule is not None and type(t) is not ad.Zero:
+             r = tuple(rule(t, *primals, **params))
+             tangents_out.append(r)
+             assert tree_util.tree_structure(r) == tree
+     r = functools.reduce(
+         _add_tangents,
+         tangents_out,
+         tree_util.tree_map(
+             # compatible with JAX 0.4.34
+             lambda a: ad.Zero.from_primal_value(a) if jax.__version__ >= '0.4.34' else ad.Zero.from_value(a),
+             val_out
+         )
+     )
+     return val_out, r
+
+
+ def _add_tangents(xs, ys):
+     return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))
+
+
+ def _shape_to_layout(shape):
+     return tuple(range(len(shape) - 1, -1, -1))
+
+
+ def _numba_mlir_cpu_translation_rule(
+     kernel,
+     debug: bool,
+     ctx,
+     *ins,
+     **kwargs
+ ):
+     if numba is None:
+         raise ImportError('Numba is required to compile the CPU kernel for the custom operator.')
+
+     if not isinstance(kernel, Dispatcher):
+         kernel = kernel(**kwargs)
+     assert isinstance(kernel, Dispatcher), f'The kernel should be a Numba dispatcher. But we got {kernel}'
+
+     # output information
+     outs = ctx.avals_out
+     output_shapes = tuple([out.shape for out in outs])
+     output_dtypes = tuple([out.dtype for out in outs])
+     output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
+     result_types = [mlir.aval_to_ir_type(out) for out in outs]
+
+     # input information
+     avals_in = ctx.avals_in
+     input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
+     input_dtypes = tuple(inp.dtype for inp in avals_in)
+     input_shapes = tuple(inp.shape for inp in avals_in)
+
+     # compiling function
+     code_scope = dict(func_to_call=kernel,
+                       input_shapes=input_shapes,
+                       input_dtypes=input_dtypes,
+                       output_shapes=output_shapes,
+                       output_dtypes=output_dtypes, carray=carray)
+     args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
+                for i in range(len(input_shapes))]
+     if len(output_shapes) > 1:
+         args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
+                     for i in range(len(output_shapes))]
+         sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
+     else:
+         args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
+         sig = types.void(types.voidptr, types.CPointer(types.voidptr))
+     args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
+     code_string = '''
+ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
+     {args_in}
+     {args_out}
+     func_to_call({args_call})
+     '''.format(args_in="\n    ".join(args_in),
+                args_out="\n    ".join(args_out),
+                args_call=", ".join(args_call))
+     if debug:
+         print(code_string)
+     exec(compile(code_string.strip(), '', 'exec'), code_scope)
+     new_f = code_scope['numba_cpu_custom_call_target']
+
+     # register
+     xla_c_rule = cfunc(sig)(new_f)
+     target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
+     capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
+     xla_client.register_custom_call_target(target_name, capsule, "cpu")
+
+     # call
+     return custom_call(
+         call_target_name=target_name,
+         operands=ins,
+         operand_layouts=list(input_layouts),
+         result_layouts=list(output_layouts),
+         result_types=list(result_types),
+         has_side_effect=False,
+     ).results
+
+
+ def register_numba_mlir_cpu_translation_rule(
+     primitive: jax.core.Primitive,
+     cpu_kernel: Callable,
+     debug: bool = False
+ ):
+     rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug)
+     mlir.register_lowering(primitive, rule, platform='cpu')
+
+
+ class ShapeDtype(Protocol):
+
+     @property
+     def shape(self) -> Tuple[int, ...]:
+         ...
+
+     @property
+     def dtype(self) -> np.dtype:
+         ...
+
+
+ class XLACustomOp:
+     """Creating an XLA custom call operator.
+
+     Args:
+       cpu_kernel_generator: Callable. The function that defines the computation on the CPU backend.
+       gpu_kernel_generator: Callable. The function that defines the computation on the GPU backend.
+       batching_translation: Callable. The batching translation rule of JAX.
+       jvp_translation: Callable. The JVP translation rule of JAX.
+       transpose_translation: Callable. The transpose translation rule of JAX.
+       name: str. The primitive name.
+     """
+
+     def __init__(
+         self,
+         name: str,
+         cpu_kernel_generator: Callable,
+         gpu_kernel_generator: Callable = None,
+         batching_translation: Callable = None,
+         jvp_translation: Callable = None,
+         transpose_translation: Callable = None,
+     ):
+         # set cpu_kernel and gpu_kernel
+         self.cpu_kernel = cpu_kernel_generator
+
+         # primitive
+         self.primitive = jax.core.Primitive(name)
+         self.primitive.multiple_results = True
+
+         # abstract evaluation
+         self.primitive.def_impl(partial(xla.apply_primitive, self.primitive))
+         self.primitive.def_abstract_eval(self._abstract_eval)
+
+         # cpu kernel
+         if cpu_kernel_generator is not None:
+             self.def_cpu_kernel(cpu_kernel_generator)
+         if gpu_kernel_generator is not None:
+             self.def_gpu_kernel(gpu_kernel_generator)
+
+         # batching rule
+         if batching_translation is not None:
+             batching.primitive_batchers[self.primitive] = batching_translation
+
+         # jvp rule
+         if jvp_translation is not None:
+             ad.primitive_jvps[self.primitive] = jvp_translation
+
+         # transpose rule
+         if transpose_translation is not None:
+             ad.primitive_transposes[self.primitive] = transpose_translation
+
+     def _abstract_eval(self, *ins, outs: Sequence[ShapeDtype], **kwargs):
+         return tuple(outs)
+
+     def __call__(self, *ins, outs: Sequence[ShapeDtype], **kwargs):
+         assert isinstance(outs, (tuple, list)), 'The `outs` should be a tuple or list of shape-dtype pairs.'
+         outs = jax.tree.map(_transform_to_shapedarray, outs)
+         return self.primitive.bind(*ins, **kwargs, outs=tuple(outs))
+
+     def def_cpu_kernel(self, kernel_generator: Callable):
+         """
+         Define the CPU kernel using Numba.
+         """
+         register_numba_mlir_cpu_translation_rule(self.primitive, kernel_generator)
+
+     def def_gpu_kernel(self, kernel_generator: Callable):
+         """
+         Define the GPU kernel using the JAX Pallas language.
+         """
+         lower = mlir.lower_fun(
+             lambda *args, **kwargs: kernel_generator(**kwargs)(*args),
+             multiple_results=True
+         )
+         mlir.register_lowering(self.primitive, lower, platform='cuda')
+         mlir.register_lowering(self.primitive, lower, platform='tpu')
+
+     def def_batching_rule(self, fun):
+         """Define the batching rule.
+
+         Args:
+           fun: The batching rule.
+         """
+         batching.primitive_batchers[self.primitive] = fun
+
+     def def_jvp_rule(self, fun):
+         """Define the JVP rule.
+
+         Args:
+           fun: The JVP rule.
+         """
+         ad.primitive_jvps[self.primitive] = fun
+
+     def defjvp(self, *jvp_rules):
+         """
+         Define the JVP rule. Similar to ``jax.interpreters.ad.defjvp``,
+         but supports the Primitive with multiple results.
+
+         Args:
+           jvp_rules: The JVP rules.
+         """
+         defjvp(self.primitive, *jvp_rules)
+
+     def def_transpose_rule(self, fun):
+         """Define the transpose rule.
+
+         Args:
+           fun: The transpose rule.
+         """
+         ad.primitive_transposes[self.primitive] = fun
+
+     def def_xla_translation(self, platform, fun):
+         """Define the XLA translation rule.
+
+         Args:
+           platform: str. The computing platform.
+           fun: The XLA translation rule.
+         """
+         xla.backend_specific_translations[platform][self.primitive] = fun
+
+     def def_mlir_lowering(self, platform, fun):
+         """
+         Define the MLIR lowering rule.
+
+         Args:
+           platform: str. The computing platform.
+           fun: The lowering rule.
+         """
+         mlir.register_lowering(self.primitive, fun, platform)
+
+
+ def _transform_to_shapedarray(a):
+     return jax.core.ShapedArray(a.shape, a.dtype)
@@ -0,0 +1,55 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import jax
+ import jax.numpy as jnp
+ from jax.experimental import pallas as pl
+
+ import brainstate as bst
+
+
+ def test1():
+     import numba
+     def add_vectors_kernel(x_ref, y_ref, o_ref):
+         x, y = x_ref[...], y_ref[...]
+         o_ref[...] = x + y
+
+     def cpu_kernel(**kwargs):
+         @numba.njit
+         def add_kernel_numba(x, y, out):
+             out[...] = x + y
+
+         return add_kernel_numba
+
+     def gpu_kernel(x_info):
+         return pl.pallas_call(
+             add_vectors_kernel,
+             out_shape=[jax.ShapeDtypeStruct(x_info.shape, x_info.dtype)],
+             interpret=jax.default_backend() == 'cpu',
+         )
+
+     prim = bst.event.XLACustomOp(
+         'add',
+         cpu_kernel,
+         gpu_kernel,
+     )
+
+     a = bst.random.rand(64)
+     b = bst.random.rand(64)
+     x_info = jax.ShapeDtypeStruct(a.shape, a.dtype)
+     r1 = prim(a, b, outs=[jax.ShapeDtypeStruct((64,), jax.numpy.float32)], x_info=x_info)
+     r2 = gpu_kernel(x_info)(a, b)
+
+     assert jnp.allclose(r1[0], r2[0])
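The test above exercises only the forward path. The new `XLACustomOp` also exposes `defjvp`, which (per `_standard_jvp` in `_xla_custom_op.py`) accepts one rule per input; each rule receives that input's tangent, then all primals, then the bind-time keyword arguments (here `outs` and `x_info`), and must return a tuple matching the primitive's output structure. Below is a minimal, untested sketch of attaching JVP rules to the `add` operator from `test1`; it is an illustration derived from the code shown above, not part of the package's own test suite.

# Hypothetical continuation of test1 (not in the package): register one JVP rule
# per input of the element-wise `add` primitive. Since d(a + b) = da + db, each
# rule forwards its tangent through the same kernel with a zero partner input.

def jvp_a(da, a, b, *, outs, x_info):
    return prim(da, jnp.zeros_like(b), outs=outs, x_info=x_info)

def jvp_b(db, a, b, *, outs, x_info):
    return prim(jnp.zeros_like(a), db, outs=outs, x_info=x_info)

prim.defjvp(jvp_a, jvp_b)

# Afterwards jax.jvp can differentiate through the custom call:
#   f = lambda a, b: prim(a, b, outs=[jax.ShapeDtypeStruct((64,), jnp.float32)],
#                         x_info=x_info)[0]
#   out, dout = jax.jvp(f, (a, b), (jnp.ones_like(a), jnp.ones_like(b)))
#   # with unit tangents, dout is 2.0 everywhere, since d(a + b) = da + db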