brainstate 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/transform/_controls_test.py
@@ -0,0 +1,220 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import unittest
+
+ import jax
+ import jax.numpy as jnp
+
+ import brainstate as bc
+
+
+ class TestCond(unittest.TestCase):
+   def test1(self):
+     bc.random.seed(1)
+     bc.transform.cond(True, lambda: bc.random.random(10), lambda: bc.random.random(10))
+     bc.transform.cond(False, lambda: bc.random.random(10), lambda: bc.random.random(10))
+
+   def test2(self):
+     st1 = bc.State(bc.random.rand(10))
+     st2 = bc.State(bc.random.rand(2))
+     st3 = bc.State(bc.random.rand(5))
+     st4 = bc.State(bc.random.rand(2, 10))
+
+     def true_fun(x):
+       st1.value = st2.value @ st4.value + x
+
+     def false_fun(x):
+       st3.value = (st3.value + 1.) * x
+
+     bc.transform.cond(True, true_fun, false_fun, 2.)
+     assert not isinstance(st1.value, jax.core.Tracer)
+     assert not isinstance(st2.value, jax.core.Tracer)
+     assert not isinstance(st3.value, jax.core.Tracer)
+     assert not isinstance(st4.value, jax.core.Tracer)
+
+
+ class TestSwitch(unittest.TestCase):
+   def testSwitch(self):
+     def branch(x):
+       y = jax.lax.mul(2, x)
+       return y, jax.lax.mul(2, y)
+
+     branches = [lambda x: (x, x),
+                 branch,
+                 lambda x: (x, -x)]
+
+     def fun(x):
+       if x <= 0:
+         return branches[0](x)
+       elif x == 1:
+         return branches[1](x)
+       else:
+         return branches[2](x)
+
+     def cfun(x):
+       return bc.transform.switch(x, branches, x)
+
+     self.assertEqual(fun(-1), cfun(-1))
+     self.assertEqual(fun(0), cfun(0))
+     self.assertEqual(fun(1), cfun(1))
+     self.assertEqual(fun(2), cfun(2))
+     self.assertEqual(fun(3), cfun(3))
+
+     cfun = jax.jit(cfun)
+
+     self.assertEqual(fun(-1), cfun(-1))
+     self.assertEqual(fun(0), cfun(0))
+     self.assertEqual(fun(1), cfun(1))
+     self.assertEqual(fun(2), cfun(2))
+     self.assertEqual(fun(3), cfun(3))
+
+   def testSwitchMultiOperands(self):
+     branches = [jax.lax.add, jax.lax.mul]
+
+     def fun(x):
+       i = 0 if x <= 0 else 1
+       return branches[i](x, x)
+
+     def cfun(x):
+       return bc.transform.switch(x, branches, x, x)
+
+     self.assertEqual(fun(-1), cfun(-1))
+     self.assertEqual(fun(0), cfun(0))
+     self.assertEqual(fun(1), cfun(1))
+     self.assertEqual(fun(2), cfun(2))
+     cfun = jax.jit(cfun)
+     self.assertEqual(fun(-1), cfun(-1))
+     self.assertEqual(fun(0), cfun(0))
+     self.assertEqual(fun(1), cfun(1))
+     self.assertEqual(fun(2), cfun(2))
+
+   def testSwitchResidualsMerge(self):
+     def get_conds(fun):
+       jaxpr = jax.make_jaxpr(jax.grad(fun))(0., 0)
+       return [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == 'cond']
+
+     def branch_invars_len(cond_eqn):
+       lens = [len(jaxpr.jaxpr.invars) for jaxpr in cond_eqn.params['branches']]
+       assert len(set(lens)) == 1
+       return lens[0]
+
+     def branch_outvars_len(cond_eqn):
+       lens = [len(jaxpr.jaxpr.outvars) for jaxpr in cond_eqn.params['branches']]
+       assert len(set(lens)) == 1
+       return lens[0]
+
+     branches1 = [lambda x: jnp.sin(x),
+                  lambda x: jnp.cos(x)]  # branch residuals overlap, should be reused
+     branches2 = branches1 + [lambda x: jnp.sinh(x)]  # another overlapping residual, expect reuse
+     branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)]  # requires one more residual slot
+
+     def fun1(x, i):
+       return bc.transform.switch(i + 1, branches1, x)
+
+     def fun2(x, i):
+       return bc.transform.switch(i + 1, branches2, x)
+
+     def fun3(x, i):
+       return bc.transform.switch(i + 1, branches3, x)
+
+     fwd1, bwd1 = get_conds(fun1)
+     fwd2, bwd2 = get_conds(fun2)
+     fwd3, bwd3 = get_conds(fun3)
+
+     fwd1_num_out = branch_outvars_len(fwd1)
+     fwd2_num_out = branch_outvars_len(fwd2)
+     fwd3_num_out = branch_outvars_len(fwd3)
+     assert fwd1_num_out == fwd2_num_out
+     assert fwd3_num_out == fwd2_num_out + 1
+
+     bwd1_num_in = branch_invars_len(bwd1)
+     bwd2_num_in = branch_invars_len(bwd2)
+     bwd3_num_in = branch_invars_len(bwd3)
+     assert bwd1_num_in == bwd2_num_in
+     assert bwd3_num_in == bwd2_num_in + 1
+
+   def testOneBranchSwitch(self):
+     branch = lambda x: -x
+     f = lambda i, x: bc.transform.switch(i, [branch], x)
+     x = 7.
+     self.assertEqual(f(-1, x), branch(x))
+     self.assertEqual(f(0, x), branch(x))
+     self.assertEqual(f(1, x), branch(x))
+     cf = jax.jit(f)
+     self.assertEqual(cf(-1, x), branch(x))
+     self.assertEqual(cf(0, x), branch(x))
+     self.assertEqual(cf(1, x), branch(x))
+     cf = jax.jit(f, static_argnums=0)
+     self.assertEqual(cf(-1, x), branch(x))
+     self.assertEqual(cf(0, x), branch(x))
+     self.assertEqual(cf(1, x), branch(x))
+
+
+ class TestIfElse(unittest.TestCase):
+   def test1(self):
+     def f(a):
+       return bc.transform.ifelse(conditions=[a < 0,
+                                              a >= 0 and a < 2,
+                                              a >= 2 and a < 5,
+                                              a >= 5 and a < 10,
+                                              a >= 10],
+                                  branches=[lambda: 1,
+                                            lambda: 2,
+                                            lambda: 3,
+                                            lambda: 4,
+                                            lambda: 5])
+
+     self.assertTrue(f(3) == 3)
+     self.assertTrue(f(1) == 2)
+     self.assertTrue(f(-1) == 1)
+
+   def test_vmap(self):
+     def f(operands):
+       f = lambda a: bc.transform.ifelse([a > 10,
+                                          jnp.logical_and(a <= 10, a > 5),
+                                          jnp.logical_and(a <= 5, a > 2),
+                                          jnp.logical_and(a <= 2, a > 0),
+                                          a <= 0],
+                                         [lambda _: 1,
+                                          lambda _: 2,
+                                          lambda _: 3,
+                                          lambda _: 4,
+                                          lambda _: 5, ],
+                                         a)
+       return jax.vmap(f)(operands)
+
+     r = f(bc.random.randint(-20, 20, 200))
+     self.assertTrue(r.size == 200)
+
+   def test_grad1(self):
+     def F2(x):
+       return bc.transform.ifelse((x >= 10, x < 10),
+                                   [lambda x: x, lambda x: x ** 2, ],
+                                   x)
+
+     self.assertTrue(jax.grad(F2)(9.0) == 18.)
+     self.assertTrue(jax.grad(F2)(11.0) == 1.)
+
+   def test_grad2(self):
+     def F3(x):
+       return bc.transform.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
+                                   [lambda x: x,
+                                    lambda x: x ** 2,
+                                    lambda x: x ** 4, ],
+                                   x)
+
+     self.assertTrue(jax.grad(F3)(9.0) == 18.)
+     self.assertTrue(jax.grad(F3)(11.0) == 1.)
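The tests above exercise the three control-flow wrappers (`cond`, `switch`, `ifelse`) in `brainstate.transform`. A minimal usage sketch that mirrors the call signatures seen in the test code; the return-value semantics (the selected branch's output) are assumed to follow `jax.lax.cond`-style behavior and are not spelled out in this diff:

import jax.numpy as jnp
import brainstate as bc

# cond: boolean predicate, two branches, then the operand(s) passed to the branch.
y = bc.transform.cond(True, lambda x: x + 1., lambda x: x - 1., 2.)

# switch: integer index, a list of branches, then the operand(s).
z = bc.transform.switch(1, [jnp.sin, jnp.cos, jnp.tanh], 0.5)

# ifelse: one boolean condition per branch, evaluated like a chained if/elif.
w = bc.transform.ifelse(conditions=[z > 0., z <= 0.],
                        branches=[lambda: 1., lambda: -1.])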
brainstate/transform/_jit.py
@@ -0,0 +1,239 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import functools
+ from collections.abc import Iterable, Sequence
+ from typing import (Any, Callable, Union)
+
+ import jax
+ from jax._src import sharding_impls
+ from jax.lib import xla_client as xc
+
+ from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
+ from brainstate._utils import set_module_as
+
+ __all__ = ['jit']
+
+
+ class JittedFunction(Callable):
+   """
+   A wrapped version of ``fun``, set up for just-in-time compilation.
+   """
+   origin_fun: Callable  # the original function
+   stateful_fun: StatefulFunction  # the stateful function for extracting states
+   jitted_fun: jax.stages.Wrapped  # the jitted function
+   clear_cache: Callable  # clear the cache of the jitted function
+
+
+ def _get_jitted_fun(
+     fun: Callable,
+     in_shardings,
+     out_shardings,
+     static_argnums,
+     donate_argnums,
+     donate_argnames,
+     keep_unused,
+     device,
+     backend,
+     inline,
+     abstracted_axes,
+     **kwargs
+ ) -> JittedFunction:
+   static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
+   # TODO: add to cache stack for clear_cache
+   fun = StatefulFunction(fun, static_argnums=static_argnums, abstracted_axes=abstracted_axes, cache_type='jit')
+   jit_fun = jax.jit(fun.jaxpr_call,
+                     static_argnums=tuple(i + 1 for i in static_argnums),
+                     donate_argnums=donate_argnums,
+                     donate_argnames=donate_argnames,
+                     keep_unused=keep_unused,
+                     device=device,
+                     backend=backend,
+                     inline=inline,
+                     in_shardings=in_shardings,
+                     out_shardings=out_shardings,
+                     abstracted_axes=abstracted_axes,
+                     **kwargs)
+
+   @functools.wraps(fun.fun)
+   def jitted_fun(*args, **params):
+     if jax.config.jax_disable_jit:
+       return fun.fun(*args, **params)
+     states = fun.compile_and_get_states_by_static_args(*args, **kwargs)
+     state_vals, outs = jit_fun([st.value for st in states], *args, **params)
+     _assign_state_values(states, state_vals)
+     return outs
+
+   def clear_cache():
+     # clear the cache of the stateful function
+     fun.clear_cache()
+     # clear the cache of the jitted function
+     jit_fun.clear_cache()
+
+   jitted_fun: JittedFunction
+   # the original function
+   jitted_fun.origin_fun = fun.fun
+   # the stateful function for extracting states
+   jitted_fun.stateful_fun = fun
+   # the jitted function
+   jitted_fun.jitted_fun = jit_fun
+   # clear cache
+   jitted_fun.clear_cache = clear_cache
+
+   return jitted_fun
+
+
+ @set_module_as('brainstate.transform')
+ def jit(
+     fun: Callable = None,
+     in_shardings=sharding_impls.UNSPECIFIED,
+     out_shardings=sharding_impls.UNSPECIFIED,
+     static_argnums: int | Sequence[int] | None = None,
+     donate_argnums: int | Sequence[int] | None = None,
+     donate_argnames: str | Iterable[str] | None = None,
+     keep_unused: bool = False,
+     device: xc.Device | None = None,
+     backend: str | None = None,
+     inline: bool = False,
+     abstracted_axes: Any | None = None,
+     **kwargs
+ ) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
+   """
+   Sets up ``fun`` for just-in-time compilation with XLA.
+
+   Does not support setting ``static_argnames`` as in ``jax.jit()``.
+
+   Args:
+     fun: Function to be jitted.
+     in_shardings: Pytree of structure matching that of arguments to ``fun``,
+       with all actual arguments replaced by resource assignment specifications.
+       It is also valid to specify a pytree prefix (e.g. one value in place of a
+       whole subtree), in which case the leaves get broadcast to all values in
+       that subtree.
+
+       The ``in_shardings`` argument is optional. JAX will infer the shardings
+       from the input :py:class:`jax.Array`'s and defaults to replicating the input
+       if the sharding cannot be inferred.
+
+       The valid resource assignment specifications are:
+         - :py:class:`XLACompatibleSharding`, which will decide how the value
+           will be partitioned. With this, using a mesh context manager is not
+           required.
+         - :py:obj:`None`, will give JAX the freedom to choose whatever sharding
+           it wants.
+           For in_shardings, JAX will mark it as replicated but this behavior
+           can change in the future.
+           For out_shardings, we will rely on the XLA GSPMD partitioner to
+           determine the output shardings.
+
+       The size of every dimension has to be a multiple of the total number of
+       resources assigned to it. This is similar to pjit's in_shardings.
+     out_shardings: Like ``in_shardings``, but specifies resource
+       assignment for function outputs. This is similar to pjit's
+       out_shardings.
+
+       The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
+       will use GSPMD's sharding propagation to figure out what the sharding of the
+       output(s) should be.
+     static_argnums: An optional int or collection of ints that specify which
+       positional arguments to treat as static (compile-time constant).
+       Operations that only depend on static arguments will be constant-folded in
+       Python (during tracing), and so the corresponding argument values can be
+       any Python object.
+
+       Static arguments should be hashable, meaning both ``__hash__`` and
+       ``__eq__`` are implemented, and immutable. Calling the jitted function
+       with different values for these constants will trigger recompilation.
+       Arguments that are not arrays or containers thereof must be marked as
+       static.
+
+       If neither ``static_argnums`` nor ``static_argnames`` is provided, no
+       arguments are treated as static. If ``static_argnums`` is not provided but
+       ``static_argnames`` is, or vice versa, JAX uses
+       :code:`inspect.signature(fun)` to find any positional arguments that
+       correspond to ``static_argnames``
+       (or vice versa). If both ``static_argnums`` and ``static_argnames`` are
+       provided, ``inspect.signature`` is not used, and only actual
+       parameters listed in either ``static_argnums`` or ``static_argnames`` will
+       be treated as static.
+     donate_argnums: Specify which positional argument buffers are "donated" to
+       the computation. It is safe to donate argument buffers if you no longer
+       need them once the computation has finished. In some cases XLA can make
+       use of donated buffers to reduce the amount of memory needed to perform a
+       computation, for example recycling one of your input buffers to store a
+       result. You should not reuse buffers that you donate to a computation, JAX
+       will raise an error if you try to. By default, no argument buffers are
+       donated.
+
+       If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
+       arguments are donated. If ``donate_argnums`` is not provided but
+       ``donate_argnames`` is, or vice versa, JAX uses
+       :code:`inspect.signature(fun)` to find any positional arguments that
+       correspond to ``donate_argnames``
+       (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
+       provided, ``inspect.signature`` is not used, and only actual
+       parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
+       be donated.
+
+       For more details on buffer donation see the
+       `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
+     donate_argnames: An optional string or collection of strings specifying
+       which named arguments are donated to the computation. See the
+       comment on ``donate_argnums`` for details. If not
+       provided but ``donate_argnums`` is set, the default is based on calling
+       ``inspect.signature(fun)`` to find corresponding named arguments.
+     keep_unused: If `False` (the default), arguments that JAX determines to be
+       unused by `fun` *may* be dropped from resulting compiled XLA executables.
+       Such arguments will not be transferred to the device nor provided to the
+       underlying executable. If `True`, unused arguments will not be pruned.
+     device: This is an experimental feature and the API is likely to change.
+       Optional, the Device the jitted function will run on. (Available devices
+       can be retrieved via :py:func:`jax.devices`.) The default is inherited
+       from XLA's DeviceAssignment logic and is usually to use
+       ``jax.devices()[0]``.
+     backend: This is an experimental feature and the API is likely to change.
+       Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
+       ``'tpu'``.
+     inline: Specify whether this function should be inlined into enclosing
+       jaxprs (rather than being represented as an application of the xla_call
+       primitive with its own subjaxpr). Default False.
+     abstracted_axes:
+
+   Returns:
+     A wrapped version of ``fun``, set up for just-in-time compilation.
+     The returned object is a :py:class:`JittedFunction` that can be called with the same arguments
+     and has the following attributes and methods:
+
+     - ``stateful_fun``: the stateful function for extracting states, an instance of :py:class:`StatefulFunction`.
+     - ``origin_fun(*args, **kwargs)``: the original function
+     - ``jitted_fun(*args, **kwargs)``: the jitted function
+     - ``clear_cache(*args, **kwargs)``: clear the cache of the jitted function
+
+   """
+
+   if fun is None:
+     def wrapper(fun_again: Callable) -> JittedFunction:
+       return _get_jitted_fun(fun_again, in_shardings, out_shardings, static_argnums,
+                              donate_argnums, donate_argnames, keep_unused,
+                              device, backend, inline, abstracted_axes, **kwargs)
+     return wrapper
+
+   else:
+     return _get_jitted_fun(fun, in_shardings, out_shardings, static_argnums,
+                            donate_argnums, donate_argnames, keep_unused,
+                            device, backend, inline, abstracted_axes, **kwargs)
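A hypothetical usage sketch for the `jit` wrapper defined above. The attribute names (`stateful_fun`, `origin_fun`, `jitted_fun`, `clear_cache`) come straight from the docstring; the assumption that `State` objects touched inside the function are captured and written back rests on `StatefulFunction`, whose implementation lives in `_make_jaxpr.py` and is not shown in this hunk:

import brainstate as bc

counter = bc.State(0.)

@bc.transform.jit
def step(x):
  counter.value = counter.value + 1.  # state write, assumed to be tracked and written back
  return x * counter.value

y = step(2.)          # callable with the same arguments as the original function
step.clear_cache()    # clears both the StatefulFunction cache and the jax.jit cache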
brainstate/transform/_jit_error.py
@@ -0,0 +1,158 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ from functools import wraps, partial
+ from typing import Callable, Union
+
+ import jax
+ from jax import numpy as jnp
+ from jax.core import Primitive, ShapedArray
+ from jax.interpreters import batching, mlir, xla
+ from jax.lax import cond
+
+ from brainstate._utils import set_module_as
+
+ __all__ = [
+   'jit_error',
+ ]
+
+
+ @set_module_as('brainstate.transform')
+ def remove_vmap(x, op='any'):
+   if op == 'any':
+     return _any_without_vmap(x)
+   elif op == 'all':
+     return _all_without_vmap(x)
+   else:
+     raise ValueError(f'Do not support type: {op}')
+
+
+ _any_no_vmap_prim = Primitive('any_no_vmap')
+
+
+ def _any_without_vmap(x):
+   return _any_no_vmap_prim.bind(x)
+
+
+ def _any_without_vmap_imp(x):
+   return jnp.any(x)
+
+
+ def _any_without_vmap_abs(x):
+   return ShapedArray(shape=(), dtype=jnp.bool_)
+
+
+ def _any_without_vmap_batch(x, batch_axes):
+   (x,) = x
+   return _any_without_vmap(x), batching.not_mapped
+
+
+ _any_no_vmap_prim.def_impl(_any_without_vmap_imp)
+ _any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs)
+ batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch
+ if hasattr(xla, "lower_fun"):
+   xla.register_translation(_any_no_vmap_prim,
+                            xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True))
+ mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False))
+
+ _all_no_vmap_prim = Primitive('all_no_vmap')
+
+
+ def _all_without_vmap(x):
+   return _all_no_vmap_prim.bind(x)
+
+
+ def _all_without_vmap_imp(x):
+   return jnp.all(x)
+
+
+ def _all_without_vmap_abs(x):
+   return ShapedArray(shape=(), dtype=jnp.bool_)
+
+
+ def _all_without_vmap_batch(x, batch_axes):
+   (x,) = x
+   return _all_without_vmap(x), batching.not_mapped
+
+
+ _all_no_vmap_prim.def_impl(_all_without_vmap_imp)
+ _all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs)
+ batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch
+ if hasattr(xla, "lower_fun"):
+   xla.register_translation(_all_no_vmap_prim,
+                            xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True))
+ mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False))
+
+
+ def _err_jit_true_branch(err_fun, x):
+   jax.debug.callback(err_fun, x)
+   return
+
+
+ def _err_jit_false_branch(x):
+   return
+
+
+ def _cond(err_fun, pred, err_arg):
+   @wraps(err_fun)
+   def true_err_fun(*arg):
+     err_fun(*arg)
+
+   cond(pred,
+        partial(_err_jit_true_branch, true_err_fun),
+        _err_jit_false_branch,
+        err_arg)
+
+
+ def _error_msg(msg, *arg):
+   if len(arg) == 0:
+     raise ValueError(msg)
+   else:
+     raise ValueError(msg.format(arg))
+
+
+ @set_module_as('brainstate.transform')
+ def jit_error(pred, err_fun: Union[Callable, str], err_arg=None, scope: str = 'any'):
+   """Check errors in a jit function.
+
+   >>> def error(arg):
+   ...   raise ValueError(f'error {arg}')
+   >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
+   >>> jit_error(x.sum() < 5., error, err_arg=x)
+
+   Parameters
+   ----------
+   pred: bool, Array
+     The boolean prediction.
+   err_fun: callable, str
+     The error function, which raises errors, or a string error message.
+   err_arg: any
+     The arguments passed into ``err_fun``.
+   scope: str
+     The scope of the error message. Can be None, 'all' or 'any'.
+   """
+   if isinstance(err_fun, str):
+     err_fun = partial(_error_msg, err_fun)
+   if scope is None:
+     pred = pred
+   elif scope == 'all':
+     pred = remove_vmap(pred, 'all')
+   elif scope == 'any':
+     pred = remove_vmap(pred, 'any')
+   else:
+     raise ValueError(f"Unknown scope: {scope}")
+   _cond(err_fun, pred, err_arg)
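A short, hedged usage sketch for `jit_error`. Passing a string for `err_fun` takes the `_error_msg` path above, which formats the error argument into the message; the sketch assumes `jit_error` is re-exported under `brainstate.transform`, as its `__all__` entry and the `set_module_as` decorator suggest:

import jax.numpy as jnp
import brainstate as bc

@bc.transform.jit
def safe_log(x):
  # Raises ValueError at runtime (via jax.debug.callback) when any element is non-positive;
  # the default scope='any' reduces the element-wise predicate with the any_no_vmap primitive.
  bc.transform.jit_error(x <= 0., 'non-positive input: {}', err_arg=x)
  return jnp.log(x)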