brainstate-0.2.1-py2.py3-none-any.whl → brainstate-0.2.2-py2.py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/transform/__init__.py
CHANGED
(Note: several removed lines below are truncated in the source rendering and are reproduced as shown.)
@@ -1,59 +1,56 @@
-# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-
-from ._ad_checkpoint import *
-from ._ad_checkpoint import __all__ as _ad_checkpoint_all
-from ._autograd import *
-from ._autograd import __all__ as _autograd_all
-from ._conditions import *
-from ._conditions import __all__ as _conditions_all
-from ._error_if import *
-from ._error_if import __all__ as _error_if_all
-from .
-from .
-from ._jit import *
-from ._jit import __all__ as _jit_all
-from ._loop_collect_return import *
-from ._loop_collect_return import __all__ as _loop_collect_return_all
-from ._loop_no_collection import *
-from ._loop_no_collection import __all__ as _loop_no_collection_all
-from ._make_jaxpr import *
-from ._make_jaxpr import __all__ as _make_jaxpr_all
-from ._mapping import *
-from ._mapping import __all__ as _mapping_all
-from ._progress_bar import *
-from ._progress_bar import __all__ as _progress_bar_all
-from .
-from .
-
-
-
-__all__
-
-
-del
-del
-del
-del
-del
-del
-del
-del
-del
-del
-del _progress_bar_all
-del _random_all
-del _unvmap_all
+# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from ._ad_checkpoint import *
+from ._ad_checkpoint import __all__ as _ad_checkpoint_all
+from ._autograd import *
+from ._autograd import __all__ as _autograd_all
+from ._conditions import *
+from ._conditions import __all__ as _conditions_all
+from ._error_if import *
+from ._error_if import __all__ as _error_if_all
+from ._find_state import *
+from ._find_state import __all__ as _find_all
+from ._jit import *
+from ._jit import __all__ as _jit_all
+from ._loop_collect_return import *
+from ._loop_collect_return import __all__ as _loop_collect_return_all
+from ._loop_no_collection import *
+from ._loop_no_collection import __all__ as _loop_no_collection_all
+from ._make_jaxpr import *
+from ._make_jaxpr import __all__ as _make_jaxpr_all
+from ._mapping import *
+from ._mapping import __all__ as _mapping_all
+from ._progress_bar import *
+from ._progress_bar import __all__ as _progress_bar_all
+from ._unvmap import *
+from ._unvmap import __all__ as _unvmap_all
+
+__all__ = _ad_checkpoint_all + _autograd_all + _conditions_all + _error_if_all + _find_all
+__all__ += _jit_all + _loop_collect_return_all + _loop_no_collection_all
+__all__ += _make_jaxpr_all + _mapping_all + _progress_bar_all + _unvmap_all
+del _find_all
+del _ad_checkpoint_all
+del _autograd_all
+del _conditions_all
+del _error_if_all
+del _jit_all
+del _loop_collect_return_all
+del _loop_no_collection_all
+del _make_jaxpr_all
+del _mapping_all
+del _progress_bar_all
+del _unvmap_all
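The rewritten __init__.py assembles the package's public API by star-importing each submodule, collecting its __all__ under a private alias, concatenating the aliases into the package-level __all__, and finally deleting the aliases so they do not leak into the namespace. A minimal sketch of the same pattern for a hypothetical two-module package (module names are illustrative, not brainstate's):

# mypkg/__init__.py (hypothetical, for illustration only)
# Star-import each submodule, stash its __all__ under a private alias,
# build the package __all__ from the aliases, then delete the aliases.
from ._ops import *
from ._ops import __all__ as _ops_all
from ._utils import *
from ._utils import __all__ as _utils_all

__all__ = _ops_all + _utils_all

del _ops_all
del _utils_all

Under this scheme, adding or dropping a submodule only touches its paired import/alias lines and the __all__ arithmetic, which is exactly the shape of this diff: _find_state added, _eval_shape and _random removed.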
brainstate/transform/_ad_checkpoint.py
CHANGED
@@ -1,176 +1,176 @@
(The removed and added sides of this hunk are line-for-line identical in this rendering, so the file body is shown once, with indentation restored; the full-file churn most likely reflects formatting-only changes.)

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import functools
from typing import Callable, Tuple, Union

import jax

from brainstate._utils import set_module_as
from brainstate.typing import Missing
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple

__all__ = [
    'checkpoint',
    'remat'
]


@set_module_as('brainstate.transform')
def checkpoint(
    fun: Callable = Missing(),
    *,
    prevent_cse: bool = True,
    policy: Callable[..., bool] | None = None,
    static_argnums: int | Tuple[int, ...] = (),
) -> Union[Callable, Callable[[Callable], Callable]]:
    """Make ``fun`` recompute internal linearization points when differentiated.

    This decorator wraps :func:`jax.checkpoint` (also exposed as :func:`jax.remat`) to
    rematerialize intermediate values during reverse-mode automatic differentiation.
    It allows trading additional computation for reduced peak memory when evaluating
    functions with :func:`jax.grad`, :func:`jax.vjp`, or :func:`jax.linearize`.

    Parameters
    ----------
    fun : Callable, optional
        Function whose autodiff evaluation strategy should use rematerialization.
        Positional and keyword arguments may be arrays, scalars, or arbitrarily
        nested Python containers of those types.
    prevent_cse : bool, default True
        Whether to prevent common-subexpression-elimination (CSE) optimizations in
        the generated HLO. Disabling CSE is usually necessary under
        :func:`jax.jit`/:func:`jax.pmap` so that rematerialization is not optimized
        away. Set to ``False`` when decorating code inside control-flow primitives
        (for example, :func:`jax.lax.scan`) where CSE is already handled safely.
    policy : Callable[..., bool], optional
        Callable drawn from :mod:`jax.checkpoint_policies` that decides which
        primitive outputs may be saved as residuals instead of being recomputed. The
        callable receives type-level information about a primitive application and
        returns ``True`` when the corresponding value can be cached.
    static_argnums : int or tuple of int, optional
        Indices of arguments to treat as static during tracing. Marking arguments as
        static can avoid :class:`jax.errors.ConcretizationTypeError` at the expense
        of additional retracing when those arguments change.

    Returns
    -------
    callable
        A function with the same input/output behaviour as ``fun``. When
        differentiated, it rematerializes intermediate linearization points instead
        of storing them, reducing memory pressure at the cost of extra computation.

    Notes
    -----
    Reverse-mode autodiff normally stores all linearization points during the
    forward pass so that they can be reused during the backward pass. This storage
    can dominate memory usage, particularly on accelerators where memory accesses
    are expensive. Applying ``checkpoint`` causes those values to be recomputed on
    the backward pass from the saved inputs instead of being cached.

    The decorator can be composed recursively to express sophisticated
    rematerialization strategies. For functions with data-dependent Python control
    flow, specify ``static_argnums`` (and, if needed,
    :func:`jax.ensure_compile_time_eval`) so that branching conditions are evaluated
    at trace time.

    Examples
    --------
    Use :func:`jax.checkpoint` to trade computation for memory:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp

        >>> @brainstate.transform.checkpoint
        ... def g(x):
        ...     y = jnp.sin(x)
        ...     z = jnp.sin(y)
        ...     return z

        >>> value, grad = jax.value_and_grad(g)(2.0)

    Compose checkpoints recursively to control the rematerialization granularity:

    .. code-block:: python

        >>> import jax

        >>> def recursive_checkpoint(funs):
        ...     if len(funs) == 1:
        ...         return funs[0]
        ...     if len(funs) == 2:
        ...         f1, f2 = funs
        ...         return lambda x: f1(f2(x))
        ...     f1 = recursive_checkpoint(funs[: len(funs) // 2])
        ...     f2 = recursive_checkpoint(funs[len(funs) // 2 :])
        ...     return lambda x: f1(jax.checkpoint(f2)(x))

    When control flow depends on argument values, mark the relevant arguments as
    static:

    .. code-block:: python

        >>> from functools import partial
        >>> import jax
        >>> import brainstate

        >>> @brainstate.transform.checkpoint(static_argnums=(1,))
        ... def foo(x, is_training):
        ...     if is_training:
        ...         ...
        ...     else:
        ...         ...

        >>> @brainstate.transform.checkpoint(static_argnums=(1,))
        ... def foo_with_eval(x, y):
        ...     with jax.ensure_compile_time_eval():
        ...         y_pos = y > 0
        ...     if y_pos:
        ...         ...
        ...     else:
        ...         ...

    As an alternative to ``static_argnums``, compute values that drive control flow
    outside the decorated function and close over them in the JAX-traced callable.
    """
    if isinstance(fun, Missing):
        return lambda f: checkpoint(f, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums)

    static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
    fun = StatefulFunction(fun, static_argnums=static_argnums, name='checkpoint')
    checkpointed_fun = jax.checkpoint(
        fun.jaxpr_call,
        prevent_cse=prevent_cse,
        policy=policy,
        static_argnums=tuple(i + 1 for i in static_argnums)
    )

    @functools.wraps(fun.fun)
    def remat_fun(*args, **params):
        # compile the function and get the state trace
        state_trace = fun.get_state_trace(*args, **params, compile_if_miss=True)
        read_state_vals = state_trace.get_read_state_values(True)
        # call the checkpointed function
        write_state_vals, outs = checkpointed_fun(state_trace.get_state_values(), *args, **params)
        # write the state values back to the states
        state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
        return outs

    return remat_fun


remat = checkpoint
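In remat_fun above, brainstate threads state values through jax.checkpoint explicitly: the current state values enter as an extra leading argument (hence the i + 1 shift applied to static_argnums), and the written values come back as an extra output that is assigned back to the states after the call. A minimal pure-JAX sketch of that input/output threading, with a plain dict standing in for the state trace machinery (names hypothetical):

import jax
import jax.numpy as jnp

# A dict of current state values stands in for brainstate's state trace.
# jax.checkpoint only tracks values that are traced inputs and outputs,
# so mutable state must flow through the function explicitly.
def step(state_vals, x):
    w = state_vals['w']
    y = jnp.tanh(x @ w)
    return {'w': w}, y  # (written-back state values, outputs)

checkpointed_step = jax.checkpoint(step)  # rematerialize intermediates

def loss(state_vals, x):
    new_vals, y = checkpointed_step(state_vals, x)
    # A real wrapper would assign new_vals back onto the State objects here.
    return jnp.sum(y ** 2)

vals = {'w': jnp.ones((4, 4))}
grads = jax.grad(loss)(vals, jnp.ones((2, 4)))  # {'w': (4, 4)} gradient pytree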
brainstate/transform/_ad_checkpoint_test.py
CHANGED
@@ -1,49 +1,49 @@
(As above, both sides of this hunk are identical in this rendering; the file body is shown once, with indentation restored.)

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import jax
import jax.numpy as jnp
from absl.testing import absltest

import brainstate


class TestRemat(absltest.TestCase):
    def test_basic_remat(self):
        module = brainstate.compile.remat(brainstate.nn.Linear(2, 3))
        y = module(jnp.ones((1, 2)))
        assert y.shape == (1, 3)

    def test_remat_with_scan(self):
        class ScanLinear(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = brainstate.nn.Linear(3, 3)

            def __call__(self, x: jax.Array):
                @brainstate.compile.remat
                def fun(x: jax.Array, _):
                    x = self.linear(x)
                    return x, None

                return brainstate.compile.scan(fun, x, None, length=10)[0]

        m = ScanLinear()

        assert m.linear.weight.value['weight'].shape == (3, 3)
        assert m.linear.weight.value['bias'].shape == (3,)

        y = m(jnp.ones((10, 3)))
        assert y.shape == (10, 3)
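The recursive composition sketched in the checkpoint docstring above is runnable as written; for example, applying it to a chain of eight sines differentiates the chain while keeping only logarithmically many residuals live on the backward pass (plain JAX, following the docstring's recursive_checkpoint):

import jax
import jax.numpy as jnp

# recursive_checkpoint as given in the checkpoint docstring: checkpoint
# the second half of the chain at every level of a binary recursion.
def recursive_checkpoint(funs):
    if len(funs) == 1:
        return funs[0]
    if len(funs) == 2:
        f1, f2 = funs
        return lambda x: f1(f2(x))
    f1 = recursive_checkpoint(funs[: len(funs) // 2])
    f2 = recursive_checkpoint(funs[len(funs) // 2:])
    return lambda x: f1(jax.checkpoint(f2)(x))

f = recursive_checkpoint([jnp.sin] * 8)
value, grad = jax.value_and_grad(f)(1.0)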