brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
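The biggest structural change in this release is the split of the old `brainstate.transform` namespace into `brainstate.compile` (jit, conditionals, loops, checkpoint/remat, jaxpr tracing) and a new `brainstate.augment` package (autograd, vmap-style mapping, eval_shape, RNG handling), plus the promotion of `brainstate/nn/event` to a top-level `brainstate.event` package. A minimal migration sketch, inferred from the file moves above (the exact re-exported names are assumptions, not confirmed by this diff):

    import jax.numpy as jnp
    import brainstate as bst

    def fn(x):
      return jnp.sin(x)

    # 0.0.2.x: transformations lived under brainstate.transform,
    # e.g. the helpers in transform/_jit.py and transform/_mapping.py.
    # 0.1.0: compiler-style transforms now live in brainstate.compile ...
    jitted = bst.compile.jit(fn)

    # ... and batching/autodiff augmentations in brainstate.augment.
    batched = bst.augment.vmap(fn)

    print(jitted(jnp.ones(())), batched(jnp.ones((3,))).shape)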
--- /dev/null
+++ brainstate/augment/_mapping_test.py
@@ -0,0 +1,210 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+import unittest
+
+import jax.core
+import jax.numpy as jnp
+
+import brainstate as bst
+
+
+class TestVmap(unittest.TestCase):
+  def test_vmap_return_keep_reference_return(self):
+    @bst.augment.vmap(in_axes=0, out_axes=0)
+    def create_model(key):
+      bst.random.set_key(key)
+      m1 = bst.nn.Linear(2, 3)
+
+      m2 = bst.nn.Linear(3, 4)
+      m2.a = m1
+      m3 = bst.nn.Linear(3, 5)
+      m3.a = m1
+      self.assertTrue(id(m2.a) == id(m3.a))
+      return m2, m3
+
+    m2, m3 = create_model(bst.random.split_key(10))
+    self.assertTrue(id(m2.a) == id(m3.a))
+    jax.core.concrete_or_error(None, bst.random.DEFAULT.value)
+
+  def test_vmap_return_keep_reference_pass_into_fun(self):
+    @bst.augment.vmap(in_axes=(None, None, 0), out_axes=0)
+    def run_model(m2, m3, x):
+      self.assertTrue(id(m2.a) == id(m3.a))
+      self.assertTrue(id(m2) != m2_id)
+      self.assertTrue(id(m3) != m3_id)
+      return m2(x), m3(x)
+
+    m1 = bst.nn.Linear(2, 3)
+    m2 = bst.nn.Linear(4, 3)
+    m2.a = m1
+    m3 = bst.nn.Linear(4, 5)
+    m3.a = m1
+    m3_id = id(m3)
+    m2_id = id(m2)
+    r1, r2 = run_model(m2, m3, jnp.ones((4, 3, 4)))
+
+  def test_vmap_set_key(self):
+    @bst.augment.vmap(in_axes=0, out_axes=0)
+    def create_model(key):
+      bst.random.set_key(key)
+      return bst.nn.Linear(2, 3)
+
+    model = create_model(bst.random.split_keys(10))
+    print(model.weight.value_call(jnp.shape))
+    model.weight.value_call(lambda x: jax.core.concrete_or_error(None, x))
+    bst.random.seed()
+
+  def test_vmap_input(self):
+    model = bst.nn.Linear(2, 3)
+    print(id(model), id(model.weight))
+    model_id = id(model)
+    weight_id = id(model.weight)
+
+    x = jnp.ones((5, 2))
+
+    @bst.augment.vmap
+    def forward(x):
+      self.assertTrue(id(model) == model_id)
+      self.assertTrue(id(model.weight) == weight_id)
+      return model(x)
+
+    y = forward(x)
+    self.assertTrue(y.shape == (5, 3))
+    print(y.shape)
+    print(model.weight.value_call(jnp.shape))
+    print(model.weight.value)
+
+  def test_vmap_model(self):
+    model = bst.nn.Linear(2, 3)
+    model_id = id(model)
+    weight_id = id(model.weight)
+    print(id(model), id(model.weight))
+    x = jnp.ones((5, 2))
+
+    @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
+    def forward(model, x):
+      self.assertTrue(id(model) != model_id)
+      self.assertTrue(id(model.weight) != weight_id)
+      print(id(model), id(model.weight))
+      return model(x)
+
+    y = forward(model, x)
+    print(y.shape)
+    print(model.weight.value_call(jnp.shape))
+    print(model.weight.value)
+
+  def test_vmap1(self):
+    model = bst.nn.Linear(2, 3)
+    x = jnp.ones((5, 2))
+
+    @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
+    def forward(model, x):
+      return model(x)
+
+    y = forward(model, x)
+    print(y.shape)
+
+  def test_vmap2(self):
+    class LinearEnsemble(bst.nn.Module):
+      def __init__(self, num):
+        super().__init__()
+        self.w = bst.ParamState(bst.random.random((num, 2, 3)))
+
+    model = LinearEnsemble(5)
+    x = jnp.ones((2,))
+
+    @bst.augment.vmap(in_axes=(0, None), out_axes=0)
+    def forward(model, x):
+      return jnp.dot(x, model.w.value)
+
+    y = forward(model, x)
+    print(y.shape)
+
+  def test_vmap3(self):
+    class Foo(bst.nn.Module):
+      def __init__(self):
+        super().__init__()
+        self.a = bst.ParamState(jnp.arange(4))
+        self.b = bst.ShortTermState(jnp.arange(4))
+
+    state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
+
+    @bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
+    def mul(foo):
+      return foo.a.value * foo.b.value
+
+    foo = Foo()
+    y = mul(foo)
+    print(y.shape)
+
+  def test_vmap4(self):
+    class Foo(bst.nn.Module):
+      def __init__(self):
+        super().__init__()
+        self.a = bst.ParamState(jnp.arange(4))
+        self.b = bst.ShortTermState(jnp.arange(4))
+
+      def __call__(self):
+        self.b.value = self.a.value * self.b.value
+
+    @bst.augment.vmap
+    def mul(foo):
+      foo()
+      return foo
+
+    foo = Foo()
+    with bst.StateTraceStack() as trace:
+      m = mul(foo)
+
+    self.assertTrue(m is foo)
+    print(m.a.value, foo.a.value)
+    self.assertTrue(jnp.allclose(m.a.value, foo.a.value))
+    print(m.b.value, foo.b.value)
+    self.assertTrue(jnp.allclose(m.b.value, foo.b.value))
+    print(trace.get_write_states())
+    self.assertTrue(len(trace.get_write_states()) == 1)
+    print(trace.get_read_states())
+    self.assertTrue(len(trace.get_read_states()) == 2)
+
+  def test_vmap5(self):
+    class Foo(bst.nn.Module):
+      def __init__(self):
+        super().__init__()
+        self.a = bst.ParamState(jnp.arange(4))
+        self.b = bst.ShortTermState(jnp.arange(4))
+
+      def __call__(self):
+        self.b.value = self.a.value * self.b.value
+
+    @bst.augment.vmap
+    def mul(foo):
+      foo()
+
+    foo = Foo()
+    with bst.StateTraceStack() as trace:
+      mul(foo)
+
+    print(foo.a.value)
+    print(foo.b.value)
+    self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
+    self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
+
+    print(trace.get_write_states())
+    print(trace.get_read_states())
+
+
--- /dev/null
+++ brainstate/augment/_random.py
@@ -0,0 +1,99 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+import functools
+from typing import Callable, Sequence, Union
+
+from brainstate.random import DEFAULT, RandomState
+from brainstate.typing import Missing
+
+__all__ = [
+  'restore_rngs'
+]
+
+
+class RngRestore:
+  """
+  Backup and restore the random state of a sequence of RandomState instances.
+  """
+
+  def __init__(self, rngs: Sequence[RandomState]):
+    self.rngs: Sequence[RandomState] = rngs
+    self.rng_keys = []
+
+  def backup(self):
+    """
+    Backup the current random key of the RandomState instances.
+    """
+    self.rng_keys = [rng.value for rng in self.rngs]
+
+  def restore(self):
+    """
+    Restore the random key of the RandomState instances.
+    """
+    for rng, key in zip(self.rngs, self.rng_keys):
+      rng.restore_value(key)
+    self.rng_keys = []
+
+
+def _rng_backup(
+    fn: Callable,
+    rngs: Union[RandomState, Sequence[RandomState]]
+) -> Callable:
+  rng_restorer = RngRestore(rngs)
+
+  @functools.wraps(fn)
+  def wrapper(*args, **kwargs):
+    # backup the random state
+    rng_restorer.backup()
+    # call the function
+    out = fn(*args, **kwargs)
+    # restore the random state
+    rng_restorer.restore()
+    return out
+
+  return wrapper
+
+
+def restore_rngs(
+    fn: Callable = Missing(),
+    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
+) -> Callable:
+  """
+  Backup the current random state and restore it after the function call.
+
+  Parameters
+  ----------
+  fn : Callable, optional
+    The function to be wrapped.
+  rngs : Union[RandomState, Sequence[RandomState]]
+    The random state to be backed up and restored. If not provided, the default RandomState instance will be used.
+
+  Returns
+  -------
+  Callable
+    The wrapped function.
+  """
+  if isinstance(fn, Missing):
+    return functools.partial(restore_rngs, rngs=rngs)
+
+  if isinstance(rngs, RandomState):
+    rngs = [rngs]
+  assert isinstance(rngs, Sequence), 'rngs must be a RandomState or a sequence of RandomState instances.'
+  for rng in rngs:
+    assert isinstance(rng, RandomState), 'rngs must be a RandomState or a sequence of RandomState instances.'
+  return _rng_backup(fn, rngs=rngs)
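As `_rng_backup` shows, the wrapper snapshots the key of each `RandomState` before the wrapped call and restores it afterwards, so random draws inside the function do not advance the global key. A minimal usage sketch (assuming the decorator is re-exported as `bst.augment.restore_rngs`, per the module path above):

    import brainstate as bst

    @bst.augment.restore_rngs
    def sample_noise(shape):
      # this draw would normally advance the default RandomState key
      return bst.random.normal(size=shape)

    key_before = bst.random.DEFAULT.value
    _ = sample_noise((3,))
    key_after = bst.random.DEFAULT.value   # restored: same key as before the call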
--- brainstate/transform/__init__.py
+++ brainstate/compile/__init__.py
@@ -14,11 +14,11 @@
 # ==============================================================================
 
 """
-This module contains the functions for the
+This module contains the functions for the compilation of JAX code.
 """
 
-from .
-from .
+from ._ad_checkpoint import *
+from ._ad_checkpoint import __all__ as _ad_checkpoint_all
 from ._conditions import *
 from ._conditions import __all__ as _conditions_all
 from ._error_if import *
@@ -26,20 +26,32 @@ from ._error_if import __all__ as _jit_error_all
 from ._jit import *
 from ._jit import __all__ as _jit_all
 from ._loop_collect_return import *
-from ._loop_collect_return import __all__ as
+from ._loop_collect_return import __all__ as _loops_collection
 from ._loop_no_collection import *
-from ._loop_no_collection import __all__ as
+from ._loop_no_collection import __all__ as _loops_no_collection
 from ._make_jaxpr import *
 from ._make_jaxpr import __all__ as _make_jaxpr_all
-from ._mapping import *
-from ._mapping import __all__ as _mapping_all
 from ._progress_bar import *
 from ._progress_bar import __all__ as _progress_bar_all
 
-__all__ = (
-
-
+__all__ = (
+  _jit_error_all
+  + _conditions_all
+  + _make_jaxpr_all
+  + _jit_all
+  + _progress_bar_all
+  + _loops_collection
+  + _loops_no_collection
+  + _ad_checkpoint_all
+)
 
-del (
-
-
+del (
+  _jit_error_all,
+  _conditions_all,
+  _loops_collection,
+  _make_jaxpr_all,
+  _jit_all,
+  _progress_bar_all,
+  _loops_no_collection,
+  _ad_checkpoint_all
+)
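The `__all__` bookkeeping above is plain list concatenation: each submodule exports its list of public names, the package joins them with `+`, and `del` then drops the helper names from the package namespace. A generic illustration of the idiom (the name lists are made up for the example):

    _jit_all = ['jit']
    _conditions_all = ['cond', 'switch', 'ifelse']

    __all__ = (
      _jit_all
      + _conditions_all
    )
    del (
      _jit_all,
      _conditions_all,
    )

    print(__all__)  # ['jit', 'cond', 'switch', 'ifelse']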
--- /dev/null
+++ brainstate/compile/_ad_checkpoint.py
@@ -0,0 +1,204 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+import functools
+from typing import Callable, Tuple, Union
+
+import jax
+
+from brainstate.typing import Missing
+from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
+from ._util import write_back_state_values
+
+__all__ = [
+  'checkpoint',
+  'remat'
+]
+
+
+def checkpoint(
+    fun: Callable = Missing(),
+    *,
+    prevent_cse: bool = True,
+    policy: Callable[..., bool] | None = None,
+    static_argnums: int | Tuple[int, ...] = (),
+) -> Union[Callable, Callable[[Callable], Callable]]:
+  """Make ``fun`` recompute internal linearization points when differentiated.
+
+  The :func:`jax.checkpoint` decorator, aliased to :func:`jax.remat`, provides a
+  way to trade off computation time and memory cost in the context of automatic
+  differentiation, especially with reverse-mode autodiff like :func:`jax.grad`
+  and :func:`jax.vjp` but also with :func:`jax.linearize`.
+
+  When differentiating a function in reverse-mode, by default all the
+  linearization points (e.g. inputs to elementwise nonlinear primitive
+  operations) are stored when evaluating the forward pass so that they can be
+  reused on the backward pass. This evaluation strategy can lead to a high
+  memory cost, or even to poor performance on hardware accelerators where memory
+  access is much more expensive than FLOPs.
+
+  An alternative evaluation strategy is for some of the linearization points to
+  be recomputed (i.e. rematerialized) rather than stored. This approach can
+  reduce memory usage at the cost of increased computation.
+
+  This function decorator produces a new version of ``fun`` which follows
+  the rematerialization strategy rather than the default store-everything
+  strategy. That is, it returns a new version of ``fun`` which, when
+  differentiated, doesn't store any of its intermediate linearization points.
+  Instead, these linearization points are recomputed from the function's saved
+  inputs.
+
+  See the examples below.
+
+  Args:
+    fun: Function for which the autodiff evaluation strategy is to be changed
+      from the default of storing all intermediate linearization points to
+      recomputing them. Its arguments and return value should be arrays,
+      scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
+    prevent_cse: Optional, boolean keyword-only argument indicating whether to
+      prevent common subexpression elimination (CSE) optimizations in the HLO
+      generated from differentiation. This CSE prevention has costs because it
+      can foil other optimizations, and because it can incur high overheads on
+      some backends, especially GPU. The default is True because otherwise,
+      under a :func:`~jax.jit` or :func:`~jax.pmap`, CSE can defeat the purpose
+      of this decorator.
+      But in some settings, like when used inside a :func:`~jax.lax.scan`, this
+      CSE prevention mechanism is unnecessary, in which case ``prevent_cse`` can
+      be set to False.
+    static_argnums: Optional, int or sequence of ints, a keyword-only argument
+      indicating which argument values on which to specialize for tracing and
+      caching purposes. Specifying arguments as static can avoid
+      ConcretizationTypeErrors when tracing, but at the cost of more retracing
+      overheads. See the example below.
+    policy: Optional, callable keyword-only argument. It should be one of the
+      attributes of ``jax.checkpoint_policies``. The callable takes as input a
+      type-level specification of a first-order primitive application and
+      returns a boolean indicating whether the corresponding output value(s) can
+      be saved as residuals (or instead must be recomputed in the (co)tangent
+      computation if needed).
+
+  Returns:
+    A function (callable) with the same input/output behavior as ``fun`` but
+    which, when differentiated using e.g. :func:`jax.grad`, :func:`jax.vjp`, or
+    :func:`jax.linearize`, recomputes rather than stores intermediate
+    linearization points, thus potentially saving memory at the cost of extra
+    computation.
+
+  Here is a simple example:
+
+  >>> import jax
+  >>> import jax.numpy as jnp
+
+  >>> @jax.checkpoint
+  ... def g(x):
+  ...   y = jnp.sin(x)
+  ...   z = jnp.sin(y)
+  ...   return z
+  ...
+  >>> jax.value_and_grad(g)(2.0)
+  (Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))
+
+  Here, the same value is produced whether or not the :func:`jax.checkpoint`
+  decorator is present. When the decorator is not present, the values
+  ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))`` are computed on the forward
+  pass and are stored for use in the backward pass, because they are needed
+  on the backward pass and depend only on the primal inputs. When using
+  :func:`jax.checkpoint`, the forward pass will compute only the primal outputs
+  and only the primal inputs (``2.0``) will be stored for the backward pass.
+  At that time, the value ``jnp.sin(2.0)`` is recomputed, along with the values
+  ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.
+
+  While :func:`jax.checkpoint` controls what values are stored from the
+  forward-pass to be used on the backward pass, the total amount of memory
+  required to evaluate a function or its VJP depends on many additional internal
+  details of that function. Those details include which numerical primitives are
+  used, how they're composed, where jit and control flow primitives like scan
+  are used, and other factors.
+
+  The :func:`jax.checkpoint` decorator can be applied recursively to express
+  sophisticated autodiff rematerialization strategies. For example:
+
+  >>> def recursive_checkpoint(funs):
+  ...   if len(funs) == 1:
+  ...     return funs[0]
+  ...   elif len(funs) == 2:
+  ...     f1, f2 = funs
+  ...     return lambda x: f1(f2(x))
+  ...   else:
+  ...     f1 = recursive_checkpoint(funs[:len(funs)//2])
+  ...     f2 = recursive_checkpoint(funs[len(funs)//2:])
+  ...     return lambda x: f1(jax.checkpoint(f2)(x))
+  ...
+
+  If ``fun`` involves Python control flow that depends on argument values,
+  it may be necessary to use the ``static_argnums`` parameter. For example,
+  consider a boolean flag argument::
+
+    from functools import partial
+
+    @partial(jax.checkpoint, static_argnums=(1,))
+    def foo(x, is_training):
+      if is_training:
+        ...
+      else:
+        ...
+
+  Here, the use of ``static_argnums`` allows the ``if`` statement's condition
+  to depends on the value of ``is_training``. The cost to using
+  ``static_argnums`` is that it introduces re-tracing overheads across calls:
+  in the example, ``foo`` is re-traced every time it is called with a new value
+  of ``is_training``. In some situations, ``jax.ensure_compile_time_eval``
+  is needed as well::
+
+    @partial(jax.checkpoint, static_argnums=(1,))
+    def foo(x, y):
+      with jax.ensure_compile_time_eval():
+        y_pos = y > 0
+      if y_pos:
+        ...
+      else:
+        ...
+
+  As an alternative to using ``static_argnums`` (and
+  ``jax.ensure_compile_time_eval``), it may be easier to compute some values
+  outside the :func:`jax.checkpoint`-decorated function and then close over them.
+  """
+  if isinstance(fun, Missing):
+    return lambda f: checkpoint(f, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums)
+
+  static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
+  fun = StatefulFunction(fun, static_argnums=static_argnums)
+  checkpointed_fun = jax.checkpoint(fun.jaxpr_call,
+                                    prevent_cse=prevent_cse,
+                                    policy=policy,
+                                    static_argnums=tuple(i + 1 for i in static_argnums))
+
+  @functools.wraps(fun.fun)
+  def remat_fun(*args, **params):
+    # compile the function and get the state trace
+    state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
+    read_state_vals = state_trace.get_read_state_values()
+    # call the checkpointed function
+    write_state_vals, outs = checkpointed_fun(state_trace.get_state_values(), *args, **params)
+    # write the state values back to the states
+    write_back_state_values(state_trace, read_state_vals, write_state_vals)
+    return outs
+
+  return remat_fun
+
+
+remat = checkpoint
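Unlike a bare `jax.checkpoint`, the wrapper above first traces `fun` with `StatefulFunction`, threads the traced `State` values through the checkpointed jaxpr call, and writes the results back afterwards, so brainstate states keep working under rematerialization. A hedged usage sketch (the state update is illustrative, not taken from the package):

    import jax.numpy as jnp
    import brainstate as bst

    counter = bst.ShortTermState(jnp.zeros(()))

    @bst.compile.checkpoint
    def step(x):
      counter.value = counter.value + 1.0   # state write, threaded through the wrapper
      return jnp.sin(x) * counter.value

    y = step(jnp.ones(()))
    print(y, counter.value)   # counter has been written back in place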
--- /dev/null
+++ brainstate/compile/_ad_checkpoint_test.py
@@ -0,0 +1,51 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+import jax
+import jax.numpy as jnp
+from absl.testing import absltest
+
+import brainstate as bst
+
+
+class TestRemat(absltest.TestCase):
+  def test_basic_remat(self):
+    module = bst.compile.remat(bst.nn.Linear(2, 3))
+    y = module(jnp.ones((1, 2)))
+    assert y.shape == (1, 3)
+
+  def test_remat_with_scan(self):
+    class ScanLinear(bst.nn.Module):
+      def __init__(self):
+        super().__init__()
+        self.linear = bst.nn.Linear(3, 3)
+
+      def __call__(self, x: jax.Array):
+        @bst.compile.remat
+        def fun(x: jax.Array, _):
+          x = self.linear(x)
+          return x, None
+
+        return bst.compile.scan(fun, x, None, length=10)[0]
+
+    m = ScanLinear()
+
+    assert m.linear.weight.value['weight'].shape == (3, 3)
+    assert m.linear.weight.value['bias'].shape == (3,)
+
+    y = m(jnp.ones((10, 3)))
+    assert y.shape == (10, 3)