brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/augment/_autograd.py +9 -6
- brainstate/event/__init__.py +4 -2
- brainstate/event/_csr.py +26 -18
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_fixed_probability.py +589 -152
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +13 -10
- brainstate/event/_linear.py +267 -127
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +8 -3
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
- brainstate/nn/_dynamics/_projection_base.py +1 -1
- brainstate/nn/_exp_euler.py +1 -1
- brainstate/nn/_interaction/__init__.py +13 -4
- brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
- brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/optim/_lr_scheduler.py +1 -1
- brainstate/optim/_optax_optimizer.py +18 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/RECORD +30 -21
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/event/_linear_benckmark.py
ADDED
@@ -0,0 +1,82 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os

os.environ['JAX_TRACEBACK_FILTERING'] = 'off'

import jax

import time
import brainstate as bst


def forward(n_pre, n_post, spk_prob, as_float: bool):
    linear = bst.event.Linear(n_pre, n_post, weight=bst.init.KaimingUniform(), block_size=256)
    spike = (bst.random.rand(n_pre) < spk_prob)

    if as_float:
        spike = spike.astype(float)

    @jax.jit
    def f1(spike):
        return linear(spike)

    @jax.jit
    def f2(spike):
        return spike @ linear.weight.value

    y1 = jax.block_until_ready(f1(spike))
    y2 = jax.block_until_ready(f2(spike))
    print('max difference:', jax.numpy.abs(y1 - y2).max())

    n = 100
    t0 = time.time()
    for _ in range(n):
        jax.block_until_ready(f1(spike))
    r1 = time.time() - t0
    print(f"n_pre: {n_pre}, n_post: {n_post}, spike probability: {spk_prob}, Linear: {r1} s")

    t0 = time.time()
    for _ in range(n):
        jax.block_until_ready(f2(spike))
    r2 = time.time() - t0
    print(f"n_pre: {n_pre}, n_post: {n_post}, spike probability: {spk_prob}, Matmul: {r2} s")
    print('Acceleration ratio:', r2 / r1 - 1.)

    print()


def benchmark_forward():
    for n_pre, n_post in [
        (1000, 1000),
        (1000, 10000),
        (10000, 10000),
        (10000, 1000),
        (20000, 10000),
        (20000, 20000),
        # (10000, 100000),
    ]:
        forward(n_pre, n_post, 0.01, True)
        forward(n_pre, n_post, 0.1, True)
        print()
        print()


if __name__ == '__main__':
    # forward(1000, 2000, 0.01, True)
    # forward(2000, 4000, 0.01, True)
    # forward(10000, 20000, 0.01, True)
    benchmark_forward()
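Note: before timing, the benchmark above checks that bst.event.Linear matches a dense spike @ W product on the same weights. A minimal NumPy sketch (illustration only, not the brainstate API) of the equivalence being measured: for a 0/1 spike vector, the event-driven path only has to gather and sum the weight rows of the active presynaptic neurons.

import numpy as np

W = np.arange(12, dtype=np.float32).reshape(4, 3)   # (n_pre, n_post) weight matrix
spike = np.array([True, False, True, False])        # boolean spike vector

y_dense = spike.astype(np.float32) @ W              # dense: multiplies every row of W
y_event = W[spike].sum(axis=0)                      # event-driven: touches only rows with a spike

assert np.allclose(y_dense, y_event)

The "Acceleration ratio" printed by the benchmark is r2 / r1 - 1., i.e. how much slower the jitted dense matmul is relative to the event-driven kernel.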
brainstate/event/_linear_test.py
CHANGED
@@ -32,7 +32,7 @@ class TestEventLinear(parameterized.TestCase):
         x = bst.random.rand(20) < 0.1
         if not bool_x:
             x = jnp.asarray(x, dtype=float)
-        m = Linear(20, 40, 1.5 if homo_w else bst.init.KaimingUniform())
+        m = Linear(20, 40, 1.5 if homo_w else bst.init.KaimingUniform(), float_as_event=bool_x)
         y = m(x)
         print(y)

@@ -59,7 +59,7 @@ class TestEventLinear(parameterized.TestCase):
         else:
             x = bst.random.rand(n_in)

-        fn = Linear(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform())
+        fn = Linear(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(), float_as_event=bool_x)
         w = fn.weight.value

         def f(x, w):
@@ -77,6 +77,10 @@ class TestEventLinear(parameterized.TestCase):

         r2 = jax.grad(f2, argnums=(0, 1))(x, w)
         self.assertTrue(jnp.allclose(r1[0], r2[0]))
+
+        if not jnp.allclose(r1[1], r2[1]):
+            print(r1[1] - r2[1])
+
         self.assertTrue(jnp.allclose(r1[1], r2[1]))

     @parameterized.product(
@@ -91,7 +95,8 @@ class TestEventLinear(parameterized.TestCase):
         else:
             x = bst.random.rand(n_in)

-        fn = Linear(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(),
+        fn = Linear(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(),
+                    float_as_event=bool_x)
         w = fn.weight.value

         def f(x, w):
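Note: the hunks above thread the new float_as_event argument through the event-driven Linear layer. A hedged sketch of the intended usage, inferred from the test code (the exact semantics of the flag live in brainstate/event/_linear.py and may differ):

import brainstate as bst

n_in, n_out = 20, 40

# Third positional argument is the weight: a scalar for a homogeneous weight,
# or an initializer such as bst.init.KaimingUniform() for per-connection weights.
layer = bst.event.Linear(n_in, n_out, bst.init.KaimingUniform(), float_as_event=True)

spikes = (bst.random.rand(n_in) < 0.1).astype(float)  # spike events encoded as 0./1. floats
y = layer(spikes)  # assumption from the parameter name: nonzero floats are treated as events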
brainstate/event/_xla_custom_op.py
ADDED
@@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-

import ctypes
import functools
import importlib.util
from functools import partial
from typing import Callable, Sequence, Tuple, Protocol

import jax
import numpy as np
from jax import tree_util
from jax.core import Primitive
from jax.interpreters import batching, ad
from jax.interpreters import xla, mlir
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call

numba_installed = importlib.util.find_spec('numba') is not None

if numba_installed:
    import numba  # pylint: disable=import-error
    from numba import types, carray, cfunc  # pylint: disable=import-error
    from numba.core.dispatcher import Dispatcher  # pylint: disable=import-error
else:
    numba = None

__all__ = [
    'XLACustomOp',
]

# [void* pointer,
#  const char *name,
#  PyCapsule_Destructor destructor]
ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object


def defjvp(primitive, *jvp_rules):
    """Define JVP rules for any JAX primitive.

    This function is similar to ``jax.interpreters.ad.defjvp``.
    However, the JAX one only supports primitive with ``multiple_results=False``.
    ``brainpy.math.defjvp`` enables to define the independent JVP rule for
    each input parameter no matter ``multiple_results=False/True``.

    For examples, please see ``test_ad_support.py``.

    Args:
      primitive: Primitive, XLACustomOp.
      *jvp_rules: The JVP translation rule for each primal.
    """
    assert isinstance(primitive, Primitive)
    if primitive.multiple_results:
        ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
    else:
        ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)


def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
    assert primitive.multiple_results
    val_out = tuple(primitive.bind(*primals, **params))
    tree = tree_util.tree_structure(val_out)
    tangents_out = []
    for rule, t in zip(jvp_rules, tangents):
        if rule is not None and type(t) is not ad.Zero:
            r = tuple(rule(t, *primals, **params))
            tangents_out.append(r)
            assert tree_util.tree_structure(r) == tree
    r = functools.reduce(
        _add_tangents,
        tangents_out,
        tree_util.tree_map(
            # compatible with JAX 0.4.34
            lambda a: ad.Zero.from_primal_value(a) if jax.__version__ >= '0.4.34' else ad.Zero.from_value(a),
            val_out
        )
    )
    return val_out, r


def _add_tangents(xs, ys):
    return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))


def _shape_to_layout(shape):
    return tuple(range(len(shape) - 1, -1, -1))


def _numba_mlir_cpu_translation_rule(
    kernel,
    debug: bool,
    ctx,
    *ins,
    **kwargs
):
    if numba is None:
        raise ImportError('Numba is required to compile the CPU kernel for the custom operator.')

    if not isinstance(kernel, Dispatcher):
        kernel = kernel(**kwargs)
    assert isinstance(kernel, Dispatcher), f'The kernel should be a Numba dispatcher. But we got {kernel}'

    # output information
    outs = ctx.avals_out
    output_shapes = tuple([out.shape for out in outs])
    output_dtypes = tuple([out.dtype for out in outs])
    output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
    result_types = [mlir.aval_to_ir_type(out) for out in outs]

    # input information
    avals_in = ctx.avals_in
    input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
    input_dtypes = tuple(inp.dtype for inp in avals_in)
    input_shapes = tuple(inp.shape for inp in avals_in)

    # compiling function
    code_scope = dict(func_to_call=kernel,
                      input_shapes=input_shapes,
                      input_dtypes=input_dtypes,
                      output_shapes=output_shapes,
                      output_dtypes=output_dtypes, carray=carray)
    args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
               for i in range(len(input_shapes))]
    if len(output_shapes) > 1:
        args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
                    for i in range(len(output_shapes))]
        sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
    else:
        args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
        sig = types.void(types.voidptr, types.CPointer(types.voidptr))
    args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
    code_string = '''
def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
    {args_in}
    {args_out}
    func_to_call({args_call})
    '''.format(args_in="\n    ".join(args_in),
               args_out="\n    ".join(args_out),
               args_call=", ".join(args_call))
    if debug:
        print(code_string)
    exec(compile(code_string.strip(), '', 'exec'), code_scope)
    new_f = code_scope['numba_cpu_custom_call_target']

    # register
    xla_c_rule = cfunc(sig)(new_f)
    target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
    capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
    xla_client.register_custom_call_target(target_name, capsule, "cpu")

    # call
    return custom_call(
        call_target_name=target_name,
        operands=ins,
        operand_layouts=list(input_layouts),
        result_layouts=list(output_layouts),
        result_types=list(result_types),
        has_side_effect=False,
    ).results


def register_numba_mlir_cpu_translation_rule(
    primitive: jax.core.Primitive,
    cpu_kernel: Callable,
    debug: bool = False
):
    rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug)
    mlir.register_lowering(primitive, rule, platform='cpu')


class ShapeDtype(Protocol):

    @property
    def shape(self) -> Tuple[int, ...]:
        ...

    @property
    def dtype(self) -> np.dtype:
        ...


class XLACustomOp:
    """Creating a XLA custom call operator.

    Args:
      cpu_kernel_generator: Callable. The function defines the computation on CPU backend.
      gpu_kernel_generator: Callable. The function defines the computation on GPU backend.
      batching_translation: Callable. The batching translation rule of JAX.
      jvp_translation: Callable. The JVP translation rule of JAX.
      transpose_translation: Callable. The transpose translation rule of JAX.
      name: str. The primitive name.
    """

    def __init__(
        self,
        name: str,
        cpu_kernel_generator: Callable,
        gpu_kernel_generator: Callable = None,
        batching_translation: Callable = None,
        jvp_translation: Callable = None,
        transpose_translation: Callable = None,
    ):
        # set cpu_kernel and gpu_kernel
        self.cpu_kernel = cpu_kernel_generator

        # primitive
        self.primitive = jax.core.Primitive(name)
        self.primitive.multiple_results = True

        # abstract evaluation
        self.primitive.def_impl(partial(xla.apply_primitive, self.primitive))
        self.primitive.def_abstract_eval(self._abstract_eval)

        # cpu kernel
        if cpu_kernel_generator is not None:
            self.def_cpu_kernel(cpu_kernel_generator)
        if gpu_kernel_generator is not None:
            self.def_gpu_kernel(gpu_kernel_generator)

        # batching rule
        if batching_translation is not None:
            batching.primitive_batchers[self.primitive] = batching_translation

        # jvp rule
        if jvp_translation is not None:
            ad.primitive_jvps[self.primitive] = jvp_translation

        # transpose rule
        if transpose_translation is not None:
            ad.primitive_transposes[self.primitive] = transpose_translation

    def _abstract_eval(self, *ins, outs: Sequence[ShapeDtype], **kwargs):
        return tuple(outs)

    def __call__(self, *ins, outs: Sequence[ShapeDtype], **kwargs):
        assert isinstance(outs, (tuple, list)), 'The `outs` should be a tuple or list of shape-dtype pairs.'
        outs = jax.tree.map(_transform_to_shapedarray, outs)
        return self.primitive.bind(*ins, **kwargs, outs=tuple(outs))

    def def_cpu_kernel(self, kernel_generator: Callable):
        """
        Define the CPU kernel using Numba.
        """
        register_numba_mlir_cpu_translation_rule(self.primitive, kernel_generator)

    def def_gpu_kernel(self, kernel_generator: Callable):
        """
        Define the GPU kernel using the JAX Pallas language.
        """
        lower = mlir.lower_fun(
            lambda *args, **kwargs: kernel_generator(**kwargs)(*args),
            multiple_results=True
        )
        mlir.register_lowering(self.primitive, lower, platform='cuda')
        mlir.register_lowering(self.primitive, lower, platform='tpu')

    def def_batching_rule(self, fun):
        """Define the batching rule.

        Args:
          fun: The batching rule.
        """
        batching.primitive_batchers[self.primitive] = fun

    def def_jvp_rule(self, fun):
        """Define the JVP rule.

        Args:
          fun: The JVP rule.
        """
        ad.primitive_jvps[self.primitive] = fun

    def defjvp(self, *jvp_rules):
        """
        Define the JVP rule. Similar to ``jax.interpreters.ad.defjvp``,
        but supports the Primitive with multiple results.

        Args:
          jvp_rules: The JVP rules.
        """
        defjvp(self.primitive, *jvp_rules)

    def def_transpose_rule(self, fun):
        """Define the transpose rule.

        Args:
          fun: The transpose rule.
        """
        ad.primitive_transposes[self.primitive] = fun

    def def_xla_translation(self, platform, fun):
        """Define the XLA translation rule.

        Args:
          platform: str. The computing platform.
          fun: The XLA translation rule.
        """
        xla.backend_specific_translations[platform][self.primitive] = fun

    def def_mlir_lowering(self, platform, fun):
        """
        Define the MLIR lowering rule.

        Args:
          platform: str. The computing platform.
          fun: The lowering rule.
        """
        mlir.register_lowering(self.primitive, fun, platform)


def _transform_to_shapedarray(a):
    return jax.core.ShapedArray(a.shape, a.dtype)
brainstate/event/_xla_custom_op_test.py
ADDED
@@ -0,0 +1,55 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

import brainstate as bst


def test1():
    import numba
    def add_vectors_kernel(x_ref, y_ref, o_ref):
        x, y = x_ref[...], y_ref[...]
        o_ref[...] = x + y

    def cpu_kernel(**kwargs):
        @numba.njit
        def add_kernel_numba(x, y, out):
            out[...] = x + y

        return add_kernel_numba

    def gpu_kernel(x_info):
        return pl.pallas_call(
            add_vectors_kernel,
            out_shape=[jax.ShapeDtypeStruct(x_info.shape, x_info.dtype)],
            interpret=jax.default_backend() == 'cpu',
        )

    prim = bst.event.XLACustomOp(
        'add',
        cpu_kernel,
        gpu_kernel,
    )

    a = bst.random.rand(64)
    b = bst.random.rand(64)
    x_info = jax.ShapeDtypeStruct(a.shape, a.dtype)
    r1 = prim(a, b, outs=[jax.ShapeDtypeStruct((64,), jax.numpy.float32)], x_info=x_info)
    r2 = gpu_kernel(x_info)(a, b)

    assert jnp.allclose(r1[0], r2[0])
brainstate/nn/_dyn_impl/_dynamics_synapse.py
CHANGED
@@ -112,10 +112,8 @@ class STP(Synapse):
         self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)

     def update(self, pre_spike):
-
-
-        u = exp_euler_step(du, self.u.value)
-        x = exp_euler_step(dx, self.x.value)
+        u = exp_euler_step(lambda u: self.U - u / self.tau_f, self.u.value)
+        x = exp_euler_step(lambda x: (1 - x) / self.tau_d, self.x.value)

         # --- original code:
         # if pre_spike.dtype == jax.numpy.bool_:
@@ -131,7 +129,7 @@ class STP(Synapse):

         self.u.value = u
         self.x.value = x
-        return u * x
+        return u * x * pre_spike


 class STD(Synapse):
@@ -167,8 +165,7 @@ class STD(Synapse):
         self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)

     def update(self, pre_spike):
-
-        x = exp_euler_step(dx, self.x.value)
+        x = exp_euler_step(lambda x: (1 - x) / self.tau, self.x.value)

         # --- original code:
         # self.x.value = bm.where(pre_spike, x - self.U * self.x, x)
@@ -176,7 +173,7 @@ class STD(Synapse):
         # --- simplified code:
         self.x.value = x - pre_spike * self.U * self.x.value

-        return self.x.value
+        return self.x.value * pre_spike


 class AMPA(Synapse):
@@ -315,6 +312,4 @@ class GABAa(AMPA):
         T: ArrayLike = 1.0 * u.mM,
         T_dur: ArrayLike = 1.0 * u.ms,
     ):
-        super().__init__(alpha=alpha, beta=beta, T=T,
-                         T_dur=T_dur, name=name,
-                         in_size=in_size)
+        super().__init__(alpha=alpha, beta=beta, T=T, T_dur=T_dur, name=name, in_size=in_size)
brainstate/nn/_dyn_impl/_rate_rnns.py
CHANGED
@@ -23,7 +23,7 @@ import jax.numpy as jnp

 from brainstate import random, init, functional
 from brainstate._state import HiddenState, ParamState
-from brainstate.nn._interaction.
+from brainstate.nn._interaction._linear import Linear
 from brainstate.nn._module import Module
 from brainstate.typing import ArrayLike

brainstate/nn/_dynamics/_projection_base.py
CHANGED
@@ -207,7 +207,7 @@ class AlignPostProj(Interaction):


 class DeltaProj(Interaction):
-    """Full-chain of the synaptic projection for the Delta synapse model.
+    r"""Full-chain of the synaptic projection for the Delta synapse model.

     The synaptic projection requires the input is the spiking data, otherwise
     the synapse is not the Delta synapse model.
brainstate/nn/_exp_euler.py
CHANGED

brainstate/nn/_interaction/__init__.py
CHANGED
@@ -13,20 +13,29 @@
 # limitations under the License.
 # ==============================================================================

-from .
-from .
+from ._conv import *
+from ._conv import __all__ as conv_all
 from ._embedding import *
 from ._embedding import __all__ as embed_all
+from ._linear import *
+from ._linear import __all__ as linear_all
 from ._normalizations import *
 from ._normalizations import __all__ as normalizations_all
 from ._poolings import *
 from ._poolings import __all__ as poolings_all

 __all__ = (
-
+    conv_all +
+    linear_all +
     normalizations_all +
     poolings_all +
     embed_all
 )

-del
+del (
+    conv_all,
+    linear_all,
+    normalizations_all,
+    poolings_all,
+    embed_all
+)