brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250126__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -2
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/__init__.py +10 -20
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/__init__.py +18 -37
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +6 -3
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +15 -4
- brainstate/compile/_make_jaxpr_test.py +10 -6
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +9 -6
- brainstate/graph/__init__.py +12 -16
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_elementwise/_dropout_test.py +1 -1
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/RECORD +40 -60
- brainstate/event/__init__.py +0 -29
- brainstate/event/_csr.py +0 -906
- brainstate/event/_csr_mv.py +0 -303
- brainstate/event/_csr_mv_benchmark.py +0 -14
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/event/_csr_test.py +0 -90
- brainstate/event/_fixedprob_mv.py +0 -730
- brainstate/event/_fixedprob_mv_benchmark.py +0 -128
- brainstate/event/_fixedprob_mv_test.py +0 -132
- brainstate/event/_linear_mv.py +0 -359
- brainstate/event/_linear_mv_benckmark.py +0 -82
- brainstate/event/_linear_mv_test.py +0 -117
- brainstate/event/_misc.py +0 -34
- brainstate/event/_xla_custom_op.py +0 -313
- brainstate/event/_xla_custom_op_test.py +0 -55
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/top_level.txt +0 -0
brainstate/event/_linear_mv_test.py
DELETED
@@ -1,117 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import annotations
-
-import jax
-import jax.numpy as jnp
-from absl.testing import parameterized
-
-import brainstate as bst
-from brainstate.event._linear_mv import Linear
-
-
-class TestEventLinear(parameterized.TestCase):
-    @parameterized.product(
-        homo_w=[True, False],
-        bool_x=[True, False],
-    )
-    def test1(self, homo_w, bool_x):
-        x = bst.random.rand(20) < 0.1
-        if not bool_x:
-            x = jnp.asarray(x, dtype=float)
-        m = Linear(20, 40, 1.5 if homo_w else bst.init.KaimingUniform(), float_as_event=bool_x)
-        y = m(x)
-        print(y)
-
-        self.assertTrue(jnp.allclose(y, (x.sum() * m.weight.value) if homo_w else (x @ m.weight.value)))
-
-    def test_grad_bool(self):
-        n_in = 20
-        n_out = 30
-        x = bst.random.rand(n_in) < 0.3
-        fn = Linear(n_in, n_out, bst.init.KaimingUniform())
-
-        with self.assertRaises(TypeError):
-            print(jax.grad(lambda x: fn(x).sum())(x))
-
-    @parameterized.product(
-        bool_x=[True, False],
-        homo_w=[True, False]
-    )
-    def test_vjp(self, bool_x, homo_w):
-        n_in = 20
-        n_out = 30
-        if bool_x:
-            x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
-        else:
-            x = bst.random.rand(n_in)
-
-        fn = Linear(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(), float_as_event=bool_x)
-        w = fn.weight.value
-
-        def f(x, w):
-            fn.weight.value = w
-            return fn(x).sum()
-
-        r1 = jax.grad(f, argnums=(0, 1))(x, w)
-
-        # -------------------
-        # TRUE gradients
-
-        def f2(x, w):
-            y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
-            return y.sum()
-
-        r2 = jax.grad(f2, argnums=(0, 1))(x, w)
-        self.assertTrue(jnp.allclose(r1[0], r2[0]))
-
-        if not jnp.allclose(r1[1], r2[1]):
-            print(r1[1] - r2[1])
-
-        self.assertTrue(jnp.allclose(r1[1], r2[1]))
-
-    @parameterized.product(
-        bool_x=[True, False],
-        homo_w=[True, False]
-    )
-    def test_jvp(self, bool_x, homo_w):
-        n_in = 20
-        n_out = 30
-        if bool_x:
-            x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
-        else:
-            x = bst.random.rand(n_in)
-
-        fn = Linear(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(),
-                    float_as_event=bool_x)
-        w = fn.weight.value
-
-        def f(x, w):
-            fn.weight.value = w
-            return fn(x)
-
-        o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-
-        # -------------------
-        # TRUE gradients
-
-        def f2(x, w):
-            y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
-            return y
-
-        o2, r2 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-        self.assertTrue(jnp.allclose(o1, o2))
-        self.assertTrue(jnp.allclose(r1, r2))
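Note: the removed tests above validate the event-driven `brainstate.event.Linear` against an ordinary dense matmul (the `f2` helpers). As a point of reference, here is a minimal standalone sketch of that dense ground-truth computation in plain JAX; the sizes and random inputs are hypothetical, not part of the removed code:

```python
import jax
import jax.numpy as jnp

n_in, n_out = 20, 30                                   # sizes used in the deleted tests
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = (jax.random.uniform(key_x, (n_in,)) < 0.3).astype(jnp.float32)  # spike-like input
w = jax.random.uniform(key_w, (n_in, n_out))           # dense weight matrix

def dense_reference(x, w):
    # ground truth the deleted tests compared the event-driven kernel against;
    # for a homogeneous scalar weight the expected value was x.sum() * w instead
    return x @ w

y = dense_reference(x, w)
dx, dw = jax.grad(lambda x, w: dense_reference(x, w).sum(), argnums=(0, 1))(x, w)
```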
brainstate/event/_misc.py
DELETED
@@ -1,34 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from __future__ import annotations
-
-from typing import Union
-
-import numpy as np
-
-__all__ = [
-    'FloatScalar',
-    'IntScalar',
-]
-
-FloatScalar = Union[
-    np.number,  # NumPy scalar types
-    float,  # Python scalar types
-]
-
-IntScalar = Union[
-    np.number,  # NumPy scalar types
-    int,  # Python scalar types
-]
brainstate/event/_xla_custom_op.py
DELETED
@@ -1,313 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import ctypes
-import functools
-import importlib.util
-from functools import partial
-from typing import Callable, Sequence, Tuple, Protocol
-
-import jax
-import numpy as np
-from jax import tree_util
-from jax.core import Primitive
-from jax.interpreters import batching, ad
-from jax.interpreters import xla, mlir
-from jaxlib.hlo_helpers import custom_call
-
-if jax.__version_info__ < (0, 4, 35):
-    from jax.lib import xla_client
-else:
-    import jax.extend as je
-
-numba_installed = importlib.util.find_spec('numba') is not None
-
-__all__ = [
-    'defjvp',
-    'XLACustomOp',
-]
-
-# [void* pointer,
-#  const char *name,
-#  PyCapsule_Destructor destructor]
-ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
-ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
-
-
-def defjvp(primitive, *jvp_rules):
-    """Define JVP rules for any JAX primitive.
-
-    This function is similar to ``jax.interpreters.ad.defjvp``.
-    However, the JAX one only supports primitive with ``multiple_results=False``.
-    ``brainpy.math.defjvp`` enables to define the independent JVP rule for
-    each input parameter no matter ``multiple_results=False/True``.
-
-    For examples, please see ``test_ad_support.py``.
-
-    Args:
-      primitive: Primitive, XLACustomOp.
-      *jvp_rules: The JVP translation rule for each primal.
-    """
-    assert isinstance(primitive, Primitive)
-    if primitive.multiple_results:
-        ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
-    else:
-        ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
-
-
-def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
-    assert primitive.multiple_results
-    val_out = tuple(primitive.bind(*primals, **params))
-    tree = tree_util.tree_structure(val_out)
-    tangents_out = []
-    for rule, t in zip(jvp_rules, tangents):
-        if rule is not None and type(t) is not ad.Zero:
-            r = tuple(rule(t, *primals, **params))
-            tangents_out.append(r)
-            assert tree_util.tree_structure(r) == tree
-    r = functools.reduce(
-        _add_tangents,
-        tangents_out,
-        tree_util.tree_map(
-            # compatible with JAX 0.4.34
-            lambda a: ad.Zero.from_primal_value(a) if jax.__version__ >= '0.4.34' else ad.Zero.from_value(a),
-            val_out
-        )
-    )
-    return val_out, r
-
-
-def _add_tangents(xs, ys):
-    return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))
-
-
-def _shape_to_layout(shape):
-    return tuple(range(len(shape) - 1, -1, -1))
-
-
-def _numba_mlir_cpu_translation_rule(
-    kernel,
-    debug: bool,
-    ctx,
-    *ins,
-    **kwargs
-):
-    if not numba_installed:
-        raise ImportError('Numba is required to compile the CPU kernel for the custom operator.')
-
-    from numba import types, carray, cfunc  # pylint: disable=import-error
-    from numba.core.dispatcher import Dispatcher  # pylint: disable=import-error
-
-    if not isinstance(kernel, Dispatcher):
-        kernel = kernel(**kwargs)
-    assert isinstance(kernel, Dispatcher), f'The kernel should be a Numba dispatcher. But we got {kernel}'
-
-    # output information
-    outs = ctx.avals_out
-    output_shapes = tuple([out.shape for out in outs])
-    output_dtypes = tuple([out.dtype for out in outs])
-    output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
-    result_types = [mlir.aval_to_ir_type(out) for out in outs]
-
-    # input information
-    avals_in = ctx.avals_in
-    input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
-    input_dtypes = tuple(inp.dtype for inp in avals_in)
-    input_shapes = tuple(inp.shape for inp in avals_in)
-
-    # compiling function
-    code_scope = dict(func_to_call=kernel,
-                      input_shapes=input_shapes,
-                      input_dtypes=input_dtypes,
-                      output_shapes=output_shapes,
-                      output_dtypes=output_dtypes, carray=carray)
-    args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
-               for i in range(len(input_shapes))]
-    if len(output_shapes) > 1:
-        args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
-                    for i in range(len(output_shapes))]
-        sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
-    else:
-        args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
-        sig = types.void(types.voidptr, types.CPointer(types.voidptr))
-    args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
-    code_string = '''
-def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
-    {args_in}
-    {args_out}
-    func_to_call({args_call})
-    '''.format(args_in="\n    ".join(args_in),
-               args_out="\n    ".join(args_out),
-               args_call=", ".join(args_call))
-    if debug:
-        print(code_string)
-    exec(compile(code_string.strip(), '', 'exec'), code_scope)
-    new_f = code_scope['numba_cpu_custom_call_target']
-
-    # register
-    xla_c_rule = cfunc(sig)(new_f)
-    target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
-    capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
-    if jax.__version_info__ < (0, 4, 35):
-        xla_client.register_custom_call_target(target_name, capsule, "cpu")
-    else:
-        je.ffi.register_ffi_target(target_name, capsule, "cpu", api_version=0)
-
-    # call
-    return custom_call(
-        call_target_name=target_name,
-        operands=ins,
-        operand_layouts=list(input_layouts),
-        result_layouts=list(output_layouts),
-        result_types=list(result_types),
-        has_side_effect=False,
-    ).results
-
-
-def register_numba_mlir_cpu_translation_rule(
-    primitive: jax.core.Primitive,
-    cpu_kernel: Callable,
-    debug: bool = False
-):
-    rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug)
-    mlir.register_lowering(primitive, rule, platform='cpu')
-
-
-class ShapeDtype(Protocol):
-
-    @property
-    def shape(self) -> Tuple[int, ...]:
-        ...
-
-    @property
-    def dtype(self) -> np.dtype:
-        ...
-
-
-class XLACustomOp:
-    """Creating a XLA custom call operator.
-
-    Args:
-      cpu_kernel_or_generator: Callable. The function defines the computation on CPU backend.
-      gpu_kernel_or_generator: Callable. The function defines the computation on GPU backend.
-      batching_translation: Callable. The batching translation rule of JAX.
-      jvp_translation: Callable. The JVP translation rule of JAX.
-      transpose_translation: Callable. The transpose translation rule of JAX.
-      name: str. The primitive name.
-    """
-
-    def __init__(
-        self,
-        name: str,
-        cpu_kernel_or_generator: Callable,
-        gpu_kernel_or_generator: Callable = None,
-        batching_translation: Callable = None,
-        jvp_translation: Callable = None,
-        transpose_translation: Callable = None,
-    ):
-        # primitive
-        self.primitive = jax.core.Primitive(name)
-        self.primitive.multiple_results = True
-
-        # abstract evaluation
-        self.primitive.def_impl(partial(xla.apply_primitive, self.primitive))
-        self.primitive.def_abstract_eval(self._abstract_eval)
-
-        # cpu kernel
-        if cpu_kernel_or_generator is not None:
-            self.def_cpu_kernel(cpu_kernel_or_generator)
-        if gpu_kernel_or_generator is not None:
-            self.def_gpu_kernel(gpu_kernel_or_generator)
-
-        # batching rule
-        if batching_translation is not None:
-            batching.primitive_batchers[self.primitive] = batching_translation
-
-        # jvp rule
-        if jvp_translation is not None:
-            ad.primitive_jvps[self.primitive] = jvp_translation
-
-        # transpose rule
-        if transpose_translation is not None:
-            ad.primitive_transposes[self.primitive] = transpose_translation
-
-    def _abstract_eval(self, *ins, outs: Sequence[ShapeDtype], **kwargs):
-        return tuple(outs)
-
-    def __call__(self, *ins, outs: Sequence[ShapeDtype], **kwargs):
-        assert isinstance(outs, (tuple, list)), 'The `outs` should be a tuple or list of shape-dtype pairs.'
-        outs = jax.tree.map(_transform_to_shapedarray, outs)
-        return self.primitive.bind(*ins, **kwargs, outs=tuple(outs))
-
-    def def_cpu_kernel(self, kernel_generator: Callable):
-        """
-        Define the CPU kernel using Numba.
-        """
-        register_numba_mlir_cpu_translation_rule(self.primitive, kernel_generator)
-
-    def def_gpu_kernel(self, kernel_generator: Callable):
-        """
-        Define the GPU kernel using the JAX Pallas language.
-        """
-        lower = mlir.lower_fun(
-            lambda *args, **kwargs: kernel_generator(**kwargs)(*args),
-            multiple_results=True
-        )
-        mlir.register_lowering(self.primitive, lower, platform='cuda')
-        mlir.register_lowering(self.primitive, lower, platform='tpu')
-
-    def def_batching_rule(self, fun):
-        """Define the batching rule.
-
-        Args:
-          fun: The batching rule.
-        """
-        batching.primitive_batchers[self.primitive] = fun
-
-    def def_jvp_rule(self, fun):
-        """Define the JVP rule.
-
-        Args:
-          fun: The JVP rule.
-        """
-        ad.primitive_jvps[self.primitive] = fun
-
-    def defjvp(self, *jvp_rules):
-        """
-        Define the JVP rule. Similar to ``jax.interpreters.ad.defjvp``,
-        but supports the Primitive with multiple results.
-
-        Args:
-          jvp_rules: The JVP rules.
-        """
-        defjvp(self.primitive, *jvp_rules)
-
-    def def_transpose_rule(self, fun):
-        """Define the transpose rule.
-
-        Args:
-          fun: The transpose rule.
-        """
-        ad.primitive_transposes[self.primitive] = fun
-
-    def def_xla_translation(self, platform, fun):
-        """Define the XLA translation rule.
-
-        Args:
-          platform: str. The computing platform.
-          fun: The XLA translation rule.
-        """
-        xla.backend_specific_translations[platform][self.primitive] = fun
-
-    def def_mlir_lowering(self, platform, fun):
-        """
-        Define the MLIR lowering rule.
-
-        Args:
-          platform: str. The computing platform.
-          fun: The lowering rule.
-        """
-        mlir.register_lowering(self.primitive, fun, platform)
-
-
-def _transform_to_shapedarray(a):
-    return jax.core.ShapedArray(a.shape, a.dtype)
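Note: the removed `defjvp`/`_standard_jvp` above call each JVP rule as `rule(tangent_of_one_primal, *primals, **bind_params)` and expect a result with the same tree structure as the primitive's outputs (a tuple, since `multiple_results=True`). Below is a minimal sketch of rules following that calling convention; the rule names and their linear-in-tangent bodies are hypothetical, not taken from the removed code:

```python
# Hypothetical JVP rules following the convention used by the removed _standard_jvp:
#   rule(tangent_of_that_primal, *primals, **bind_params) -> tuple shaped like the outputs

def jvp_rule_for_x(x_dot, x, w, **params):
    # contribution of the first primal's tangent; one entry per primitive output
    return (2.0 * x_dot,)

def jvp_rule_for_w(w_dot, x, w, **params):
    # contribution of the second primal's tangent
    return (3.0 * w_dot,)

# With the removed API these would have been attached as, e.g.:
#   op = XLACustomOp('my_op', cpu_kernel_generator)
#   op.defjvp(jvp_rule_for_x, jvp_rule_for_w)
# Passing None in place of a rule skips that primal, mirroring the
# `if rule is not None` check in _standard_jvp.
```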
brainstate/event/_xla_custom_op_test.py
DELETED
@@ -1,55 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import jax
-import jax.numpy as jnp
-from jax.experimental import pallas as pl
-
-import brainstate as bst
-
-
-def test1():
-    import numba
-    def add_vectors_kernel(x_ref, y_ref, o_ref):
-        x, y = x_ref[...], y_ref[...]
-        o_ref[...] = x + y
-
-    def cpu_kernel(**kwargs):
-        @numba.njit
-        def add_kernel_numba(x, y, out):
-            out[...] = x + y
-
-        return add_kernel_numba
-
-    def gpu_kernel(x_info):
-        return pl.pallas_call(
-            add_vectors_kernel,
-            out_shape=[jax.ShapeDtypeStruct(x_info.shape, x_info.dtype)],
-            interpret=jax.default_backend() == 'cpu',
-        )
-
-    prim = bst.event.XLACustomOp(
-        'add',
-        cpu_kernel,
-        gpu_kernel,
-    )
-
-    a = bst.random.rand(64)
-    b = bst.random.rand(64)
-    x_info = jax.ShapeDtypeStruct(a.shape, a.dtype)
-    r1 = prim(a, b, outs=[jax.ShapeDtypeStruct((64,), jax.numpy.float32)], x_info=x_info)
-    r2 = gpu_kernel(x_info)(a, b)
-
-    assert jnp.allclose(r1[0], r2[0])
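Note: in the deleted test above, the extra keyword `x_info` passed to `prim(...)` travels as a bind-time parameter: the removed CPU path calls the generator as `cpu_kernel(**kwargs)`, and the GPU lowering forwards the same keywords via `kernel_generator(**kwargs)(*args)`, which is how `gpu_kernel(x_info)` receives the shape and dtype information. A minimal sketch of a kernel generator written against that convention (hypothetical names, independent of the removed module):

```python
import numba
import jax

def make_add_cpu_kernel(x_info: jax.ShapeDtypeStruct = None, **unused_params):
    # The removed Numba translation rule invoked the generator with the
    # bind-time keyword parameters and expected a Numba dispatcher in return.
    @numba.njit
    def add_kernel(x, y, out):
        out[...] = x + y

    return add_kernel
```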