brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241125__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd.py +121 -120
  5. brainstate/augment/_autograd_test.py +97 -0
  6. brainstate/event/__init__.py +10 -8
  7. brainstate/event/_csr_benchmark.py +14 -0
  8. brainstate/event/{_csr.py → _csr_mv.py} +26 -18
  9. brainstate/event/_csr_mv_benchmark.py +14 -0
  10. brainstate/event/_fixedprob_mv.py +708 -0
  11. brainstate/event/_fixedprob_mv_benchmark.py +128 -0
  12. brainstate/event/{_fixed_probability_test.py → _fixedprob_mv_test.py} +13 -10
  13. brainstate/event/_linear_mv.py +359 -0
  14. brainstate/event/_linear_mv_benckmark.py +82 -0
  15. brainstate/event/{_linear_test.py → _linear_mv_test.py} +9 -4
  16. brainstate/event/_xla_custom_op.py +309 -0
  17. brainstate/event/_xla_custom_op_test.py +55 -0
  18. brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
  19. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
  20. brainstate/nn/_dynamics/_projection_base.py +1 -1
  21. brainstate/nn/_exp_euler.py +1 -1
  22. brainstate/nn/_interaction/__init__.py +13 -4
  23. brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
  24. brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
  25. brainstate/nn/_interaction/_linear.py +582 -0
  26. brainstate/nn/_interaction/_linear_test.py +42 -0
  27. brainstate/optim/_lr_scheduler.py +1 -1
  28. brainstate/optim/_optax_optimizer.py +19 -0
  29. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/METADATA +2 -2
  30. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/RECORD +34 -24
  31. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/top_level.txt +1 -0
  32. brainstate/event/_fixed_probability.py +0 -271
  33. brainstate/event/_linear.py +0 -219
  34. /brainstate/event/{_csr_test.py → _csr_mv_test.py} +0 -0
  35. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/LICENSE +0 -0
  36. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/WHEEL +0 -0
@@ -0,0 +1,309 @@
+ # -*- coding: utf-8 -*-
+
+ import ctypes
+ import functools
+ import importlib.util
+ from functools import partial
+ from typing import Callable, Sequence, Tuple, Protocol
+
+ import jax
+ import numpy as np
+ from jax import tree_util
+ from jax.core import Primitive
+ from jax.interpreters import batching, ad
+ from jax.interpreters import xla, mlir
+ from jax.lib import xla_client
+ from jaxlib.hlo_helpers import custom_call
+
+ numba_installed = importlib.util.find_spec('numba') is not None
+
+ __all__ = [
+     'defjvp',
+     'XLACustomOp',
+ ]
+
+ # [void* pointer,
+ #  const char *name,
+ #  PyCapsule_Destructor destructor]
+ ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
+ ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
+
+
+ def defjvp(primitive, *jvp_rules):
+     """Define JVP rules for any JAX primitive.
+
+     This function is similar to ``jax.interpreters.ad.defjvp``.
+     However, the JAX version only supports primitives with ``multiple_results=False``.
+     This ``defjvp`` allows an independent JVP rule to be defined for each
+     input parameter, whether ``multiple_results`` is ``False`` or ``True``.
+
+     For examples, please see ``test_ad_support.py``.
+
+     Args:
+         primitive: Primitive, XLACustomOp.
+         *jvp_rules: The JVP translation rule for each primal.
+     """
+     assert isinstance(primitive, Primitive)
+     if primitive.multiple_results:
+         ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
+     else:
+         ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
+
+
+ def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
+     assert primitive.multiple_results
+     val_out = tuple(primitive.bind(*primals, **params))
+     tree = tree_util.tree_structure(val_out)
+     tangents_out = []
+     for rule, t in zip(jvp_rules, tangents):
+         if rule is not None and type(t) is not ad.Zero:
+             r = tuple(rule(t, *primals, **params))
+             tangents_out.append(r)
+             assert tree_util.tree_structure(r) == tree
+     r = functools.reduce(
+         _add_tangents,
+         tangents_out,
+         tree_util.tree_map(
+             # compatible with JAX 0.4.34
+             lambda a: ad.Zero.from_primal_value(a) if jax.__version__ >= '0.4.34' else ad.Zero.from_value(a),
+             val_out
+         )
+     )
+     return val_out, r
+
+
+ def _add_tangents(xs, ys):
+     return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))
+
+
+ def _shape_to_layout(shape):
+     return tuple(range(len(shape) - 1, -1, -1))
+
+
+ def _numba_mlir_cpu_translation_rule(
+     kernel,
+     debug: bool,
+     ctx,
+     *ins,
+     **kwargs
+ ):
+     if not numba_installed:
+         raise ImportError('Numba is required to compile the CPU kernel for the custom operator.')
+
+     from numba import types, carray, cfunc  # pylint: disable=import-error
+     from numba.core.dispatcher import Dispatcher  # pylint: disable=import-error
+
+     if not isinstance(kernel, Dispatcher):
+         kernel = kernel(**kwargs)
+     assert isinstance(kernel, Dispatcher), f'The kernel should be a Numba dispatcher. But we got {kernel}'
+
+     # output information
+     outs = ctx.avals_out
+     output_shapes = tuple([out.shape for out in outs])
+     output_dtypes = tuple([out.dtype for out in outs])
+     output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
+     result_types = [mlir.aval_to_ir_type(out) for out in outs]
+
+     # input information
+     avals_in = ctx.avals_in
+     input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
+     input_dtypes = tuple(inp.dtype for inp in avals_in)
+     input_shapes = tuple(inp.shape for inp in avals_in)
+
+     # compiling function
+     code_scope = dict(func_to_call=kernel,
+                       input_shapes=input_shapes,
+                       input_dtypes=input_dtypes,
+                       output_shapes=output_shapes,
+                       output_dtypes=output_dtypes, carray=carray)
+     args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
+                for i in range(len(input_shapes))]
+     if len(output_shapes) > 1:
+         args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
+                     for i in range(len(output_shapes))]
+         sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
+     else:
+         args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
+         sig = types.void(types.voidptr, types.CPointer(types.voidptr))
+     args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
+     code_string = '''
+ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
+     {args_in}
+     {args_out}
+     func_to_call({args_call})
+     '''.format(args_in="\n    ".join(args_in),
+                args_out="\n    ".join(args_out),
+                args_call=", ".join(args_call))
+     if debug:
+         print(code_string)
+     exec(compile(code_string.strip(), '', 'exec'), code_scope)
+     new_f = code_scope['numba_cpu_custom_call_target']
+
+     # register
+     xla_c_rule = cfunc(sig)(new_f)
+     target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
+     capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
+     xla_client.register_custom_call_target(target_name, capsule, "cpu")
+
+     # call
+     return custom_call(
+         call_target_name=target_name,
+         operands=ins,
+         operand_layouts=list(input_layouts),
+         result_layouts=list(output_layouts),
+         result_types=list(result_types),
+         has_side_effect=False,
+     ).results
+
+
+ def register_numba_mlir_cpu_translation_rule(
+     primitive: jax.core.Primitive,
+     cpu_kernel: Callable,
+     debug: bool = False
+ ):
+     rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug)
+     mlir.register_lowering(primitive, rule, platform='cpu')
+
+
+ class ShapeDtype(Protocol):
+
+     @property
+     def shape(self) -> Tuple[int, ...]:
+         ...
+
+     @property
+     def dtype(self) -> np.dtype:
+         ...
+
+
+ class XLACustomOp:
+     """Create an XLA custom call operator.
+
+     Args:
+         cpu_kernel_generator: Callable. The function that defines the computation on the CPU backend.
+         gpu_kernel_generator: Callable. The function that defines the computation on the GPU backend.
+         batching_translation: Callable. The batching translation rule of JAX.
+         jvp_translation: Callable. The JVP translation rule of JAX.
+         transpose_translation: Callable. The transpose translation rule of JAX.
+         name: str. The primitive name.
+     """
+
+     def __init__(
+         self,
+         name: str,
+         cpu_kernel_generator: Callable,
+         gpu_kernel_generator: Callable = None,
+         batching_translation: Callable = None,
+         jvp_translation: Callable = None,
+         transpose_translation: Callable = None,
+     ):
+         # set cpu_kernel and gpu_kernel
+         self.cpu_kernel = cpu_kernel_generator
+
+         # primitive
+         self.primitive = jax.core.Primitive(name)
+         self.primitive.multiple_results = True
+
+         # abstract evaluation
+         self.primitive.def_impl(partial(xla.apply_primitive, self.primitive))
+         self.primitive.def_abstract_eval(self._abstract_eval)
+
+         # cpu kernel
+         if cpu_kernel_generator is not None:
+             self.def_cpu_kernel(cpu_kernel_generator)
+         if gpu_kernel_generator is not None:
+             self.def_gpu_kernel(gpu_kernel_generator)
+
+         # batching rule
+         if batching_translation is not None:
+             batching.primitive_batchers[self.primitive] = batching_translation
+
+         # jvp rule
+         if jvp_translation is not None:
+             ad.primitive_jvps[self.primitive] = jvp_translation
+
+         # transpose rule
+         if transpose_translation is not None:
+             ad.primitive_transposes[self.primitive] = transpose_translation
+
+     def _abstract_eval(self, *ins, outs: Sequence[ShapeDtype], **kwargs):
+         return tuple(outs)
+
+     def __call__(self, *ins, outs: Sequence[ShapeDtype], **kwargs):
+         assert isinstance(outs, (tuple, list)), 'The `outs` should be a tuple or list of shape-dtype pairs.'
+         outs = jax.tree.map(_transform_to_shapedarray, outs)
+         return self.primitive.bind(*ins, **kwargs, outs=tuple(outs))
+
+     def def_cpu_kernel(self, kernel_generator: Callable):
+         """
+         Define the CPU kernel using Numba.
+         """
+         register_numba_mlir_cpu_translation_rule(self.primitive, kernel_generator)
+
+     def def_gpu_kernel(self, kernel_generator: Callable):
+         """
+         Define the GPU kernel using the JAX Pallas language.
+         """
+         lower = mlir.lower_fun(
+             lambda *args, **kwargs: kernel_generator(**kwargs)(*args),
+             multiple_results=True
+         )
+         mlir.register_lowering(self.primitive, lower, platform='cuda')
+         mlir.register_lowering(self.primitive, lower, platform='tpu')
+
+     def def_batching_rule(self, fun):
+         """Define the batching rule.
+
+         Args:
+             fun: The batching rule.
+         """
+         batching.primitive_batchers[self.primitive] = fun
+
+     def def_jvp_rule(self, fun):
+         """Define the JVP rule.
+
+         Args:
+             fun: The JVP rule.
+         """
+         ad.primitive_jvps[self.primitive] = fun
+
+     def defjvp(self, *jvp_rules):
+         """
+         Define the JVP rule. Similar to ``jax.interpreters.ad.defjvp``,
+         but supports the Primitive with multiple results.
+
+         Args:
+             jvp_rules: The JVP rules.
+         """
+         defjvp(self.primitive, *jvp_rules)
+
+     def def_transpose_rule(self, fun):
+         """Define the transpose rule.
+
+         Args:
+             fun: The transpose rule.
+         """
+         ad.primitive_transposes[self.primitive] = fun
+
+     def def_xla_translation(self, platform, fun):
+         """Define the XLA translation rule.
+
+         Args:
+             platform: str. The computing platform.
+             fun: The XLA translation rule.
+         """
+         xla.backend_specific_translations[platform][self.primitive] = fun
+
+     def def_mlir_lowering(self, platform, fun):
+         """
+         Define the MLIR lowering rule.
+
+         Args:
+             platform: str. The computing platform.
+             fun: The lowering rule.
+         """
+         mlir.register_lowering(self.primitive, fun, platform)
+
+
+ def _transform_to_shapedarray(a):
+     return jax.core.ShapedArray(a.shape, a.dtype)
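
For context, a minimal usage sketch of the new `XLACustomOp` with a Numba CPU kernel. This is illustrative only and not part of the diff; the operator name 'double' and the helper names `double_op`/`double_cpu_kernel` are hypothetical, while the API shape follows the module added above.

import jax
import jax.numpy as jnp
import numba

import brainstate as bst


def double_cpu_kernel(**params):
    # The kernel generator receives the call-time keyword parameters
    # (including `outs`) and returns a Numba dispatcher that writes its
    # result into the output buffer in place.
    @numba.njit
    def double_kernel(x, out):
        out[...] = x + x
    return double_kernel


# Register a primitive that only has a CPU (Numba) lowering.
double_op = bst.event.XLACustomOp('double', cpu_kernel_generator=double_cpu_kernel)

x = jnp.arange(4, dtype=jnp.float32)
# `outs` declares the output shapes/dtypes; the primitive returns a sequence of outputs.
y = double_op(x, outs=[jax.ShapeDtypeStruct(x.shape, x.dtype)])[0]
assert jnp.allclose(y, x + x)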
@@ -0,0 +1,55 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import jax
+ import jax.numpy as jnp
+ from jax.experimental import pallas as pl
+
+ import brainstate as bst
+
+
+ def test1():
+     import numba
+     def add_vectors_kernel(x_ref, y_ref, o_ref):
+         x, y = x_ref[...], y_ref[...]
+         o_ref[...] = x + y
+
+     def cpu_kernel(**kwargs):
+         @numba.njit
+         def add_kernel_numba(x, y, out):
+             out[...] = x + y
+
+         return add_kernel_numba
+
+     def gpu_kernel(x_info):
+         return pl.pallas_call(
+             add_vectors_kernel,
+             out_shape=[jax.ShapeDtypeStruct(x_info.shape, x_info.dtype)],
+             interpret=jax.default_backend() == 'cpu',
+         )
+
+     prim = bst.event.XLACustomOp(
+         'add',
+         cpu_kernel,
+         gpu_kernel,
+     )
+
+     a = bst.random.rand(64)
+     b = bst.random.rand(64)
+     x_info = jax.ShapeDtypeStruct(a.shape, a.dtype)
+     r1 = prim(a, b, outs=[jax.ShapeDtypeStruct((64,), jax.numpy.float32)], x_info=x_info)
+     r2 = gpu_kernel(x_info)(a, b)
+
+     assert jnp.allclose(r1[0], r2[0])
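
The test above exercises the CPU and GPU lowerings but not the new JVP support. Below is a minimal sketch of how `defjvp` might be attached to such an operator; it is illustrative and not part of the diff, and `mul_op`, `mul_cpu_kernel`, and `f` are hypothetical names.

import jax
import jax.numpy as jnp
import numba

import brainstate as bst


def mul_cpu_kernel(**params):
    @numba.njit
    def mul_kernel(x, y, out):
        out[...] = x * y
    return mul_kernel


mul_op = bst.event.XLACustomOp('mul_example', cpu_kernel_generator=mul_cpu_kernel)

# One JVP rule per primal; each rule returns a sequence matching the op's outputs,
# as required for multiple-results primitives by the `defjvp` added above.
mul_op.defjvp(
    lambda x_dot, x, y, **params: [x_dot * y],
    lambda y_dot, x, y, **params: [x * y_dot],
)


def f(x, y):
    return mul_op(x, y, outs=[jax.ShapeDtypeStruct(x.shape, x.dtype)])[0]


x = jnp.arange(1.0, 4.0)
y = jnp.full_like(x, 3.0)
primal, tangent = jax.jvp(f, (x, y), (jnp.ones_like(x), jnp.zeros_like(y)))
assert jnp.allclose(tangent, y)  # d(x*y)/dx along dx=1, dy=0 is y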
@@ -112,10 +112,8 @@ class STP(Synapse):
          self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)

      def update(self, pre_spike):
-         du = lambda u: self.U - u / self.tau_f
-         dx = lambda x: (1 - x) / self.tau_d
-         u = exp_euler_step(du, self.u.value)
-         x = exp_euler_step(dx, self.x.value)
+         u = exp_euler_step(lambda u: self.U - u / self.tau_f, self.u.value)
+         x = exp_euler_step(lambda x: (1 - x) / self.tau_d, self.x.value)

          # --- original code:
          # if pre_spike.dtype == jax.numpy.bool_:
@@ -131,7 +129,7 @@ class STP(Synapse):

          self.u.value = u
          self.x.value = x
-         return u * x
+         return u * x * pre_spike


  class STD(Synapse):
@@ -167,8 +165,7 @@ class STD(Synapse):
          self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)

      def update(self, pre_spike):
-         dx = lambda x: (1 - x) / self.tau
-         x = exp_euler_step(dx, self.x.value)
+         x = exp_euler_step(lambda x: (1 - x) / self.tau, self.x.value)

          # --- original code:
          # self.x.value = bm.where(pre_spike, x - self.U * self.x, x)
@@ -176,7 +173,7 @@ class STD(Synapse):
          # --- simplified code:
          self.x.value = x - pre_spike * self.U * self.x.value

-         return self.x.value
+         return self.x.value * pre_spike


  class AMPA(Synapse):
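
The STP/STD hunks above assume `pre_spike` is boolean or 0/1-coded: the arithmetic form `x - pre_spike * U * x` then reproduces the original `where(pre_spike, x - U * x, x)` branch, and the new return values gate the synaptic variables by `pre_spike`. A quick standalone check of that equivalence (illustrative, not from the package):

import jax.numpy as jnp

U = 0.4
x = jnp.array([0.2, 0.9, 1.0])
pre_spike = jnp.array([1.0, 0.0, 1.0])  # 0/1-coded spikes

branchy = jnp.where(pre_spike > 0, x - U * x, x)
arithmetic = x - pre_spike * U * x
assert jnp.allclose(branchy, arithmetic)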
@@ -315,6 +312,4 @@ class GABAa(AMPA):
          T: ArrayLike = 1.0 * u.mM,
          T_dur: ArrayLike = 1.0 * u.ms,
      ):
-         super().__init__(alpha=alpha, beta=beta, T=T,
-                          T_dur=T_dur, name=name,
-                          in_size=in_size)
+         super().__init__(alpha=alpha, beta=beta, T=T, T_dur=T_dur, name=name, in_size=in_size)
@@ -23,7 +23,7 @@ import jax.numpy as jnp

  from brainstate import random, init, functional
  from brainstate._state import HiddenState, ParamState
- from brainstate.nn._interaction._connections import Linear
+ from brainstate.nn._interaction._linear import Linear
  from brainstate.nn._module import Module
  from brainstate.typing import ArrayLike

@@ -207,7 +207,7 @@ class AlignPostProj(Interaction):


  class DeltaProj(Interaction):
-     """Full-chain of the synaptic projection for the Delta synapse model.
+     r"""Full-chain of the synaptic projection for the Delta synapse model.

      The synaptic projection requires the input is the spiking data, otherwise
      the synapse is not the Delta synapse model.
@@ -32,7 +32,7 @@ __all__ = [
  def exp_euler_step(
      fn: Callable, *args, **kwargs
  ):
-     """
+     r"""
      One-step Exponential Euler method for solving ODEs.

      Examples
@@ -13,20 +13,29 @@
  # limitations under the License.
  # ==============================================================================

- from ._connections import *
- from ._connections import __all__ as connections_all
+ from ._conv import *
+ from ._conv import __all__ as conv_all
  from ._embedding import *
  from ._embedding import __all__ as embed_all
+ from ._linear import *
+ from ._linear import __all__ as linear_all
  from ._normalizations import *
  from ._normalizations import __all__ as normalizations_all
  from ._poolings import *
  from ._poolings import __all__ as poolings_all

  __all__ = (
-     connections_all +
+     conv_all +
+     linear_all +
      normalizations_all +
      poolings_all +
      embed_all
  )

- del connections_all, normalizations_all, poolings_all, embed_all
+ del (
+     conv_all,
+     linear_all,
+     normalizations_all,
+     poolings_all,
+     embed_all
+ )