brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,176 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import functools
|
17
|
+
from typing import Callable, Tuple, Union
|
18
|
+
|
19
|
+
import jax
|
20
|
+
|
21
|
+
from brainstate._utils import set_module_as
|
22
|
+
from brainstate.typing import Missing
|
23
|
+
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
|
24
|
+
|
25
|
+
# Public API of this module: ``remat`` is exported as an alias of ``checkpoint``.
__all__ = ['checkpoint', 'remat']
|
29
|
+
|
30
|
+
|
31
|
+
@set_module_as('brainstate.transform')
def checkpoint(
    fun: Callable = Missing(),
    *,
    prevent_cse: bool = True,
    policy: Callable[..., bool] | None = None,
    static_argnums: int | Tuple[int, ...] = (),
) -> Union[Callable, Callable[[Callable], Callable]]:
    """Make ``fun`` recompute internal linearization points when differentiated.

    This decorator wraps :func:`jax.checkpoint` (also exposed as :func:`jax.remat`) to
    rematerialize intermediate values during reverse-mode automatic differentiation.
    It allows trading additional computation for reduced peak memory when evaluating
    functions with :func:`jax.grad`, :func:`jax.vjp`, or :func:`jax.linearize`.

    The function is wrapped in a ``StatefulFunction``, so states read or written by
    ``fun`` are threaded through the checkpointed call and written back afterwards.

    Parameters
    ----------
    fun : Callable, optional
        Function whose autodiff evaluation strategy should use rematerialization.
        Positional and keyword arguments may be arrays, scalars, or arbitrarily
        nested Python containers of those types. When omitted, ``checkpoint``
        returns a decorator that applies the given keyword options, so it can be
        used as ``@checkpoint(static_argnums=...)``.
    prevent_cse : bool, default True
        Whether to prevent common-subexpression-elimination (CSE) optimizations in
        the generated HLO. Preventing CSE (the default) is usually necessary under
        :func:`jax.jit`/:func:`jax.pmap` so that rematerialization is not optimized
        away. Set to ``False`` when decorating code inside control-flow primitives
        (for example, :func:`jax.lax.scan`) where CSE prevention is unnecessary.
    policy : Callable[..., bool], optional
        Callable drawn from :mod:`jax.checkpoint_policies` that decides which
        primitive outputs may be saved as residuals instead of being recomputed. The
        callable receives type-level information about a primitive application and
        returns ``True`` when the corresponding value can be cached.
    static_argnums : int or tuple of int, optional
        Indices of arguments to treat as static during tracing. Marking arguments as
        static can avoid :class:`jax.errors.ConcretizationTypeError` at the expense
        of additional retracing when those arguments change.

    Returns
    -------
    callable
        A function with the same input/output behaviour as ``fun``. When
        differentiated, it rematerializes intermediate linearization points instead
        of storing them, reducing memory pressure at the cost of extra computation.
        If ``fun`` was omitted, a decorator producing such a function is returned.

    Notes
    -----
    Reverse-mode autodiff normally stores all linearization points during the
    forward pass so that they can be reused during the backward pass. This storage
    can dominate memory usage, particularly on accelerators where memory accesses
    are expensive. Applying ``checkpoint`` causes those values to be recomputed on
    the backward pass from the saved inputs instead of being cached.

    The decorator can be composed recursively to express sophisticated
    rematerialization strategies. For functions with data-dependent Python control
    flow, specify ``static_argnums`` (and, if needed,
    :func:`jax.ensure_compile_time_eval`) so that branching conditions are evaluated
    at trace time.

    Examples
    --------
    Use :func:`brainstate.transform.checkpoint` to trade computation for memory:

    .. code-block:: python

        >>> import brainstate
        >>> import jax
        >>> import jax.numpy as jnp

        >>> @brainstate.transform.checkpoint
        ... def g(x):
        ...     y = jnp.sin(x)
        ...     z = jnp.sin(y)
        ...     return z

        >>> value, grad = jax.value_and_grad(g)(2.0)

    Compose checkpoints recursively to control the rematerialization granularity:

    .. code-block:: python

        >>> import jax

        >>> def recursive_checkpoint(funs):
        ...     if len(funs) == 1:
        ...         return funs[0]
        ...     if len(funs) == 2:
        ...         f1, f2 = funs
        ...         return lambda x: f1(f2(x))
        ...     f1 = recursive_checkpoint(funs[: len(funs) // 2])
        ...     f2 = recursive_checkpoint(funs[len(funs) // 2 :])
        ...     return lambda x: f1(jax.checkpoint(f2)(x))

    When control flow depends on argument values, mark the relevant arguments as
    static:

    .. code-block:: python

        >>> import jax
        >>> import brainstate

        >>> @brainstate.transform.checkpoint(static_argnums=(1,))
        ... def foo(x, is_training):
        ...     if is_training:
        ...         ...
        ...     else:
        ...         ...

        >>> @brainstate.transform.checkpoint(static_argnums=(1,))
        ... def foo_with_eval(x, y):
        ...     with jax.ensure_compile_time_eval():
        ...         y_pos = y > 0
        ...     if y_pos:
        ...         ...
        ...     else:
        ...         ...

    As an alternative to ``static_argnums``, compute values that drive control flow
    outside the decorated function and close over them in the JAX-traced callable.
    """
    if isinstance(fun, Missing):
        # Used as ``@checkpoint(...)``: defer wrapping until the function arrives.
        return lambda f: checkpoint(f, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums)

    static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
    # Keep the user callable (``fun``) and its stateful wrapper under distinct names.
    stateful_fun = StatefulFunction(fun, static_argnums=static_argnums, name='checkpoint')
    checkpointed_fun = jax.checkpoint(
        stateful_fun.jaxpr_call,
        prevent_cse=prevent_cse,
        policy=policy,
        # ``jaxpr_call`` is invoked below with the state values prepended as its
        # first positional argument, so every user-facing static index shifts by one.
        static_argnums=tuple(i + 1 for i in static_argnums),
    )

    @functools.wraps(stateful_fun.fun)
    def remat_fun(*args, **params):
        # Compile the function (if not cached) and obtain its state trace.
        state_trace = stateful_fun.get_state_trace(*args, **params, compile_if_miss=True)
        # Snapshot the read-state values before the call; they are handed back to
        # ``assign_state_vals_v2`` together with the newly written values.
        read_state_vals = state_trace.get_read_state_values(True)
        # Run the rematerialized function with the current state values.
        write_state_vals, outs = checkpointed_fun(state_trace.get_state_values(), *args, **params)
        # Write the resulting state values back to the states.
        state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
        return outs

    return remat_fun
|
174
|
+
|
175
|
+
|
176
|
+
# ``remat`` is an alias of :func:`checkpoint`, mirroring JAX, where
# ``jax.checkpoint`` is also exposed as ``jax.remat``.
remat = checkpoint
|
@@ -1,49 +1,49 @@
|
|
1
|
-
# Copyright 2024
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import jax
|
17
|
-
import jax.numpy as jnp
|
18
|
-
from absl.testing import absltest
|
19
|
-
|
20
|
-
import brainstate
|
21
|
-
|
22
|
-
|
23
|
-
class TestRemat(absltest.TestCase):
    """Tests for ``brainstate.compile.remat`` on modules and inside ``scan``."""

    def test_basic_remat(self):
        # Wrapping a whole module with remat must preserve its output shape.
        layer = brainstate.nn.Linear(2, 3)
        wrapped = brainstate.compile.remat(layer)
        out = wrapped(jnp.ones((1, 2)))
        assert out.shape == (1, 3)

    def test_remat_with_scan(self):
        class ScanLinear(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = brainstate.nn.Linear(3, 3)

            def __call__(self, x: jax.Array):
                # Each scan step applies the shared linear layer under remat.
                @brainstate.compile.remat
                def step(carry: jax.Array, _):
                    new_carry = self.linear(carry)
                    return new_carry, None

                scanned = brainstate.compile.scan(step, x, None, length=10)
                return scanned[0]

        model = ScanLinear()

        assert model.linear.weight.value['weight'].shape == (3, 3)
        assert model.linear.weight.value['bias'].shape == (3,)

        result = model(jnp.ones((10, 3)))
        assert result.shape == (10, 3)
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import jax
|
17
|
+
import jax.numpy as jnp
|
18
|
+
from absl.testing import absltest
|
19
|
+
|
20
|
+
import brainstate
|
21
|
+
|
22
|
+
|
23
|
+
class TestRemat(absltest.TestCase):
    """Tests for ``brainstate.transform.remat`` / ``brainstate.transform.checkpoint``.

    ``brainstate.compile`` was removed in this release; ``remat`` and ``scan`` now
    live in ``brainstate.transform``, which is the module this test file sits in.
    """

    def test_basic_remat(self):
        # remat applied to a module instance should behave like the module itself.
        module = brainstate.transform.remat(brainstate.nn.Linear(2, 3))
        y = module(jnp.ones((1, 2)))
        assert y.shape == (1, 3)

    def test_remat_with_scan(self):
        class ScanLinear(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = brainstate.nn.Linear(3, 3)

            def __call__(self, x: jax.Array):
                # The scan body reads the layer's state under remat on every step.
                @brainstate.transform.remat
                def fun(x: jax.Array, _):
                    x = self.linear(x)
                    return x, None

                return brainstate.transform.scan(fun, x, None, length=10)[0]

        m = ScanLinear()

        assert m.linear.weight.value['weight'].shape == (3, 3)
        assert m.linear.weight.value['bias'].shape == (3,)

        y = m(jnp.ones((10, 3)))
        assert y.shape == (10, 3)

    def test_remat_grad_matches_plain_grad(self):
        # Rematerialization must not change gradient values, only how they are computed.
        def f(x):
            return jnp.sum(jnp.sin(jnp.sin(x)))

        x = jnp.arange(3.0)
        expected = jax.grad(f)(x)
        actual = jax.grad(brainstate.transform.checkpoint(f))(x)
        assert jnp.allclose(expected, actual)
|