brainstate-0.0.2.post20241009-py2.py3-none-any.whl → brainstate-0.1.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/augment/_mapping_test.py
@@ -0,0 +1,210 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import unittest
+
+ import jax.core
+ import jax.numpy as jnp
+
+ import brainstate as bst
+
+
+ class TestVmap(unittest.TestCase):
+     def test_vmap_return_keep_reference_return(self):
+         @bst.augment.vmap(in_axes=0, out_axes=0)
+         def create_model(key):
+             bst.random.set_key(key)
+             m1 = bst.nn.Linear(2, 3)
+
+             m2 = bst.nn.Linear(3, 4)
+             m2.a = m1
+             m3 = bst.nn.Linear(3, 5)
+             m3.a = m1
+             self.assertTrue(id(m2.a) == id(m3.a))
+             return m2, m3
+
+         m2, m3 = create_model(bst.random.split_key(10))
+         self.assertTrue(id(m2.a) == id(m3.a))
+         jax.core.concrete_or_error(None, bst.random.DEFAULT.value)
+
+     def test_vmap_return_keep_reference_pass_into_fun(self):
+         @bst.augment.vmap(in_axes=(None, None, 0), out_axes=0)
+         def run_model(m2, m3, x):
+             self.assertTrue(id(m2.a) == id(m3.a))
+             self.assertTrue(id(m2) != m2_id)
+             self.assertTrue(id(m3) != m3_id)
+             return m2(x), m3(x)
+
+         m1 = bst.nn.Linear(2, 3)
+         m2 = bst.nn.Linear(4, 3)
+         m2.a = m1
+         m3 = bst.nn.Linear(4, 5)
+         m3.a = m1
+         m3_id = id(m3)
+         m2_id = id(m2)
+         r1, r2 = run_model(m2, m3, jnp.ones((4, 3, 4)))
+
+     def test_vmap_set_key(self):
+         @bst.augment.vmap(in_axes=0, out_axes=0)
+         def create_model(key):
+             bst.random.set_key(key)
+             return bst.nn.Linear(2, 3)
+
+         model = create_model(bst.random.split_keys(10))
+         print(model.weight.value_call(jnp.shape))
+         model.weight.value_call(lambda x: jax.core.concrete_or_error(None, x))
+         bst.random.seed()
+
+     def test_vmap_input(self):
+         model = bst.nn.Linear(2, 3)
+         print(id(model), id(model.weight))
+         model_id = id(model)
+         weight_id = id(model.weight)
+
+         x = jnp.ones((5, 2))
+
+         @bst.augment.vmap
+         def forward(x):
+             self.assertTrue(id(model) == model_id)
+             self.assertTrue(id(model.weight) == weight_id)
+             return model(x)
+
+         y = forward(x)
+         self.assertTrue(y.shape == (5, 3))
+         print(y.shape)
+         print(model.weight.value_call(jnp.shape))
+         print(model.weight.value)
+
+     def test_vmap_model(self):
+         model = bst.nn.Linear(2, 3)
+         model_id = id(model)
+         weight_id = id(model.weight)
+         print(id(model), id(model.weight))
+         x = jnp.ones((5, 2))
+
+         @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
+         def forward(model, x):
+             self.assertTrue(id(model) != model_id)
+             self.assertTrue(id(model.weight) != weight_id)
+             print(id(model), id(model.weight))
+             return model(x)
+
+         y = forward(model, x)
+         print(y.shape)
+         print(model.weight.value_call(jnp.shape))
+         print(model.weight.value)
+
+     def test_vmap1(self):
+         model = bst.nn.Linear(2, 3)
+         x = jnp.ones((5, 2))
+
+         @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
+         def forward(model, x):
+             return model(x)
+
+         y = forward(model, x)
+         print(y.shape)
+
+     def test_vmap2(self):
+         class LinearEnsemble(bst.nn.Module):
+             def __init__(self, num):
+                 super().__init__()
+                 self.w = bst.ParamState(bst.random.random((num, 2, 3)))
+
+         model = LinearEnsemble(5)
+         x = jnp.ones((2,))
+
+         @bst.augment.vmap(in_axes=(0, None), out_axes=0)
+         def forward(model, x):
+             return jnp.dot(x, model.w.value)
+
+         y = forward(model, x)
+         print(y.shape)
+
+     def test_vmap3(self):
+         class Foo(bst.nn.Module):
+             def __init__(self):
+                 super().__init__()
+                 self.a = bst.ParamState(jnp.arange(4))
+                 self.b = bst.ShortTermState(jnp.arange(4))
+
+         state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
+
+         @bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
+         def mul(foo):
+             return foo.a.value * foo.b.value
+
+         foo = Foo()
+         y = mul(foo)
+         print(y.shape)
+
+     def test_vmap4(self):
+         class Foo(bst.nn.Module):
+             def __init__(self):
+                 super().__init__()
+                 self.a = bst.ParamState(jnp.arange(4))
+                 self.b = bst.ShortTermState(jnp.arange(4))
+
+             def __call__(self):
+                 self.b.value = self.a.value * self.b.value
+
+         @bst.augment.vmap
+         def mul(foo):
+             foo()
+             return foo
+
+         foo = Foo()
+         with bst.StateTraceStack() as trace:
+             m = mul(foo)
+
+         self.assertTrue(m is foo)
+         print(m.a.value, foo.a.value)
+         self.assertTrue(jnp.allclose(m.a.value, foo.a.value))
+         print(m.b.value, foo.b.value)
+         self.assertTrue(jnp.allclose(m.b.value, foo.b.value))
+         print(trace.get_write_states())
+         self.assertTrue(len(trace.get_write_states()) == 1)
+         print(trace.get_read_states())
+         self.assertTrue(len(trace.get_read_states()) == 2)
+
+     def test_vmap5(self):
+         class Foo(bst.nn.Module):
+             def __init__(self):
+                 super().__init__()
+                 self.a = bst.ParamState(jnp.arange(4))
+                 self.b = bst.ShortTermState(jnp.arange(4))
+
+             def __call__(self):
+                 self.b.value = self.a.value * self.b.value
+
+         @bst.augment.vmap
+         def mul(foo):
+             foo()
+
+         foo = Foo()
+         with bst.StateTraceStack() as trace:
+             mul(foo)
+
+         print(foo.a.value)
+         print(foo.b.value)
+         self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
+         self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
+
+         print(trace.get_write_states())
+         print(trace.get_read_states())
+
+
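The tests above exercise the new state-aware `vmap` from `brainstate.augment`. As a recap of the `StateAxes` pattern they rely on, here is a minimal standalone sketch (illustrative only, reusing the names from `test_vmap3`; it is not part of the diff):

    import jax.numpy as jnp
    import brainstate as bst

    class Foo(bst.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = bst.ParamState(jnp.arange(4))
            self.b = bst.ShortTermState(jnp.arange(4))

    # StateAxes picks a vmap axis per state type: ParamState is mapped
    # over axis 0, ShortTermState is shared across all mapped instances.
    state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})

    @bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
    def mul(foo):
        return foo.a.value * foo.b.value

    y = mul(Foo())  # `a` is vectorized, `b` is broadcast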
brainstate/augment/_random.py
@@ -0,0 +1,99 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import functools
+ from typing import Callable, Sequence, Union
+
+ from brainstate.random import DEFAULT, RandomState
+ from brainstate.typing import Missing
+
+ __all__ = [
+     'restore_rngs'
+ ]
+
+
+ class RngRestore:
+     """
+     Back up and restore the random state of a sequence of RandomState instances.
+     """
+
+     def __init__(self, rngs: Sequence[RandomState]):
+         self.rngs: Sequence[RandomState] = rngs
+         self.rng_keys = []
+
+     def backup(self):
+         """
+         Back up the current random key of the RandomState instances.
+         """
+         self.rng_keys = [rng.value for rng in self.rngs]
+
+     def restore(self):
+         """
+         Restore the random key of the RandomState instances.
+         """
+         for rng, key in zip(self.rngs, self.rng_keys):
+             rng.restore_value(key)
+         self.rng_keys = []
+
+
+ def _rng_backup(
+     fn: Callable,
+     rngs: Union[RandomState, Sequence[RandomState]]
+ ) -> Callable:
+     rng_restorer = RngRestore(rngs)
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         # back up the random state
+         rng_restorer.backup()
+         # call the function
+         out = fn(*args, **kwargs)
+         # restore the random state
+         rng_restorer.restore()
+         return out
+
+     return wrapper
+
+
+ def restore_rngs(
+     fn: Callable = Missing(),
+     rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
+ ) -> Callable:
+     """
+     Back up the current random state and restore it after the function call.
+
+     Parameters
+     ----------
+     fn : Callable, optional
+         The function to be wrapped.
+     rngs : Union[RandomState, Sequence[RandomState]]
+         The random state to be backed up and restored. If not provided, the
+         default RandomState instance will be used.
+
+     Returns
+     -------
+     Callable
+         The wrapped function.
+     """
+     if isinstance(fn, Missing):
+         return functools.partial(restore_rngs, rngs=rngs)
+
+     if isinstance(rngs, RandomState):
+         rngs = [rngs]
+     assert isinstance(rngs, Sequence), 'rngs must be a RandomState or a sequence of RandomState instances.'
+     for rng in rngs:
+         assert isinstance(rng, RandomState), 'rngs must be a RandomState or a sequence of RandomState instances.'
+     return _rng_backup(fn, rngs=rngs)
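A usage sketch for the new `restore_rngs` helper (assuming it is re-exported as `brainstate.augment.restore_rngs`; `noisy_init` is a hypothetical function, not from the diff):

    import brainstate as bst

    @bst.augment.restore_rngs  # defaults to the global DEFAULT RandomState
    def noisy_init(shape):
        return bst.random.normal(size=shape)

    key_before = bst.random.DEFAULT.value
    w = noisy_init((2, 3))
    # Samples were drawn inside noisy_init, but the global key was restored
    # afterwards, so bst.random.DEFAULT.value equals key_before again.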
brainstate/{transform → compile}/__init__.py
@@ -14,11 +14,11 @@
  # ==============================================================================

  """
- This module contains the functions for the transformation of the brain data.
+ This module contains the functions for the compilation of JAX code.
  """

- from ._autograd import *
- from ._autograd import __all__ as _gradients_all
+ from ._ad_checkpoint import *
+ from ._ad_checkpoint import __all__ as _ad_checkpoint_all
  from ._conditions import *
  from ._conditions import __all__ as _conditions_all
  from ._error_if import *
@@ -26,20 +26,32 @@ from ._error_if import __all__ as _jit_error_all
  from ._jit import *
  from ._jit import __all__ as _jit_all
  from ._loop_collect_return import *
- from ._loop_collect_return import __all__ as _loops_all
+ from ._loop_collect_return import __all__ as _loops_collection
  from ._loop_no_collection import *
- from ._loop_no_collection import __all__ as _loops_no_collection_all
+ from ._loop_no_collection import __all__ as _loops_no_collection
  from ._make_jaxpr import *
  from ._make_jaxpr import __all__ as _make_jaxpr_all
- from ._mapping import *
- from ._mapping import __all__ as _mapping_all
  from ._progress_bar import *
  from ._progress_bar import __all__ as _progress_bar_all

- __all__ = (_gradients_all + _jit_error_all + _conditions_all + _loops_all +
-            _make_jaxpr_all + _jit_all + _progress_bar_all + _loops_no_collection_all +
-            _mapping_all)
+ __all__ = (
+     _jit_error_all
+     + _conditions_all
+     + _make_jaxpr_all
+     + _jit_all
+     + _progress_bar_all
+     + _loops_collection
+     + _loops_no_collection
+     + _ad_checkpoint_all
+ )

- del (_gradients_all, _jit_error_all, _conditions_all, _loops_all,
-      _make_jaxpr_all, _jit_all, _progress_bar_all, _loops_no_collection_all,
-      _mapping_all)
+ del (
+     _jit_error_all,
+     _conditions_all,
+     _loops_collection,
+     _make_jaxpr_all,
+     _jit_all,
+     _progress_bar_all,
+     _loops_no_collection,
+     _ad_checkpoint_all
+ )
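This rename is part of the 0.1.0 split of the old `brainstate.transform` namespace, visible in the file list above: compilation-oriented transforms (`jit`, `cond`, `scan`, `remat`, ...) now live in `brainstate.compile`, while functional augmentations (`grad`, `vmap`, ...) moved to `brainstate.augment`. A brief orientation sketch (illustrative only):

    import brainstate as bst

    # was: bst.transform.jit / bst.transform.scan / bst.transform.grad
    double = bst.compile.jit(lambda x: x * 2)  # compilation transforms
    y = double(3.0)                            # -> 6.0
    # functional augmentations now live under bst.augment,
    # e.g. bst.augment.grad and bst.augment.vmap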
brainstate/compile/_ad_checkpoint.py
@@ -0,0 +1,204 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import functools
+ from typing import Callable, Tuple, Union
+
+ import jax
+
+ from brainstate.typing import Missing
+ from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
+ from ._util import write_back_state_values
+
+ __all__ = [
+     'checkpoint',
+     'remat'
+ ]
+
+
+ def checkpoint(
+     fun: Callable = Missing(),
+     *,
+     prevent_cse: bool = True,
+     policy: Callable[..., bool] | None = None,
+     static_argnums: int | Tuple[int, ...] = (),
+ ) -> Union[Callable, Callable[[Callable], Callable]]:
+     """Make ``fun`` recompute internal linearization points when differentiated.
+
+     The :func:`jax.checkpoint` decorator, aliased to :func:`jax.remat`, provides a
+     way to trade off computation time and memory cost in the context of automatic
+     differentiation, especially with reverse-mode autodiff like :func:`jax.grad`
+     and :func:`jax.vjp` but also with :func:`jax.linearize`.
+
+     When differentiating a function in reverse-mode, by default all the
+     linearization points (e.g. inputs to elementwise nonlinear primitive
+     operations) are stored when evaluating the forward pass so that they can be
+     reused on the backward pass. This evaluation strategy can lead to a high
+     memory cost, or even to poor performance on hardware accelerators where memory
+     access is much more expensive than FLOPs.
+
+     An alternative evaluation strategy is for some of the linearization points to
+     be recomputed (i.e. rematerialized) rather than stored. This approach can
+     reduce memory usage at the cost of increased computation.
+
+     This function decorator produces a new version of ``fun`` which follows
+     the rematerialization strategy rather than the default store-everything
+     strategy. That is, it returns a new version of ``fun`` which, when
+     differentiated, doesn't store any of its intermediate linearization points.
+     Instead, these linearization points are recomputed from the function's saved
+     inputs.
+
+     See the examples below.
+
+     Args:
+       fun: Function for which the autodiff evaluation strategy is to be changed
+         from the default of storing all intermediate linearization points to
+         recomputing them. Its arguments and return value should be arrays,
+         scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
+       prevent_cse: Optional, boolean keyword-only argument indicating whether to
+         prevent common subexpression elimination (CSE) optimizations in the HLO
+         generated from differentiation. This CSE prevention has costs because it
+         can foil other optimizations, and because it can incur high overheads on
+         some backends, especially GPU. The default is True because otherwise,
+         under a :func:`~jax.jit` or :func:`~jax.pmap`, CSE can defeat the purpose
+         of this decorator.
+         But in some settings, like when used inside a :func:`~jax.lax.scan`, this
+         CSE prevention mechanism is unnecessary, in which case ``prevent_cse`` can
+         be set to False.
+       static_argnums: Optional, int or sequence of ints, a keyword-only argument
+         indicating which argument values on which to specialize for tracing and
+         caching purposes. Specifying arguments as static can avoid
+         ConcretizationTypeErrors when tracing, but at the cost of more retracing
+         overheads. See the example below.
+       policy: Optional, callable keyword-only argument. It should be one of the
+         attributes of ``jax.checkpoint_policies``. The callable takes as input a
+         type-level specification of a first-order primitive application and
+         returns a boolean indicating whether the corresponding output value(s) can
+         be saved as residuals (or instead must be recomputed in the (co)tangent
+         computation if needed).
+
+     Returns:
+       A function (callable) with the same input/output behavior as ``fun`` but
+       which, when differentiated using e.g. :func:`jax.grad`, :func:`jax.vjp`, or
+       :func:`jax.linearize`, recomputes rather than stores intermediate
+       linearization points, thus potentially saving memory at the cost of extra
+       computation.
+
+     Here is a simple example:
+
+     >>> import jax
+     >>> import jax.numpy as jnp
+
+     >>> @jax.checkpoint
+     ... def g(x):
+     ...     y = jnp.sin(x)
+     ...     z = jnp.sin(y)
+     ...     return z
+     ...
+     >>> jax.value_and_grad(g)(2.0)
+     (Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))
+
+     Here, the same value is produced whether or not the :func:`jax.checkpoint`
+     decorator is present. When the decorator is not present, the values
+     ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))`` are computed on the forward
+     pass and are stored for use in the backward pass, because they are needed
+     on the backward pass and depend only on the primal inputs. When using
+     :func:`jax.checkpoint`, the forward pass will compute only the primal outputs
+     and only the primal inputs (``2.0``) will be stored for the backward pass.
+     At that time, the value ``jnp.sin(2.0)`` is recomputed, along with the values
+     ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.
+
+     While :func:`jax.checkpoint` controls what values are stored from the
+     forward-pass to be used on the backward pass, the total amount of memory
+     required to evaluate a function or its VJP depends on many additional internal
+     details of that function. Those details include which numerical primitives are
+     used, how they're composed, where jit and control flow primitives like scan
+     are used, and other factors.
+
+     The :func:`jax.checkpoint` decorator can be applied recursively to express
+     sophisticated autodiff rematerialization strategies. For example:
+
+     >>> def recursive_checkpoint(funs):
+     ...     if len(funs) == 1:
+     ...         return funs[0]
+     ...     elif len(funs) == 2:
+     ...         f1, f2 = funs
+     ...         return lambda x: f1(f2(x))
+     ...     else:
+     ...         f1 = recursive_checkpoint(funs[:len(funs)//2])
+     ...         f2 = recursive_checkpoint(funs[len(funs)//2:])
+     ...         return lambda x: f1(jax.checkpoint(f2)(x))
+     ...
+
+     If ``fun`` involves Python control flow that depends on argument values,
+     it may be necessary to use the ``static_argnums`` parameter. For example,
+     consider a boolean flag argument::
+
+         from functools import partial
+
+         @partial(jax.checkpoint, static_argnums=(1,))
+         def foo(x, is_training):
+             if is_training:
+                 ...
+             else:
+                 ...
+
+     Here, the use of ``static_argnums`` allows the ``if`` statement's condition
+     to depend on the value of ``is_training``. The cost of using
+     ``static_argnums`` is that it introduces re-tracing overheads across calls:
+     in the example, ``foo`` is re-traced every time it is called with a new value
+     of ``is_training``. In some situations, ``jax.ensure_compile_time_eval``
+     is needed as well::
+
+         @partial(jax.checkpoint, static_argnums=(1,))
+         def foo(x, y):
+             with jax.ensure_compile_time_eval():
+                 y_pos = y > 0
+             if y_pos:
+                 ...
+             else:
+                 ...
+
+     As an alternative to using ``static_argnums`` (and
+     ``jax.ensure_compile_time_eval``), it may be easier to compute some values
+     outside the :func:`jax.checkpoint`-decorated function and then close over them.
+     """
+     if isinstance(fun, Missing):
+         return lambda f: checkpoint(f, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums)
+
+     static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
+     fun = StatefulFunction(fun, static_argnums=static_argnums)
+     checkpointed_fun = jax.checkpoint(fun.jaxpr_call,
+                                       prevent_cse=prevent_cse,
+                                       policy=policy,
+                                       static_argnums=tuple(i + 1 for i in static_argnums))
+
+     @functools.wraps(fun.fun)
+     def remat_fun(*args, **params):
+         # compile the function and get the state trace
+         state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
+         read_state_vals = state_trace.get_read_state_values()
+         # call the checkpointed function
+         write_state_vals, outs = checkpointed_fun(state_trace.get_state_values(), *args, **params)
+         # write the state values back to the states
+         write_back_state_values(state_trace, read_state_vals, write_state_vals)
+         return outs
+
+     return remat_fun
+
+
+ remat = checkpoint
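Unlike plain :func:`jax.checkpoint`, this wrapper threads brainstate `State` values through `StatefulFunction`, so stateful functions can be rematerialized as well. A minimal sketch (the loss function below is a made-up example, not from the diff):

    import jax.numpy as jnp
    import brainstate as bst

    w = bst.ParamState(jnp.ones((3,)))

    @bst.compile.checkpoint  # equivalently: bst.compile.remat
    def loss(x):
        # reads `w` through the traced state; intermediates such as the
        # tanh activations are recomputed, not stored, under reverse-mode AD
        return jnp.sum(jnp.tanh(x * w.value) ** 2)

    y = loss(jnp.arange(3.0))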
brainstate/compile/_ad_checkpoint_test.py
@@ -0,0 +1,51 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import jax
+ import jax.numpy as jnp
+ from absl.testing import absltest
+
+ import brainstate as bst
+
+
+ class TestRemat(absltest.TestCase):
+     def test_basic_remat(self):
+         module = bst.compile.remat(bst.nn.Linear(2, 3))
+         y = module(jnp.ones((1, 2)))
+         assert y.shape == (1, 3)
+
+     def test_remat_with_scan(self):
+         class ScanLinear(bst.nn.Module):
+             def __init__(self):
+                 super().__init__()
+                 self.linear = bst.nn.Linear(3, 3)
+
+             def __call__(self, x: jax.Array):
+                 @bst.compile.remat
+                 def fun(x: jax.Array, _):
+                     x = self.linear(x)
+                     return x, None
+
+                 return bst.compile.scan(fun, x, None, length=10)[0]
+
+         m = ScanLinear()
+
+         assert m.linear.weight.value['weight'].shape == (3, 3)
+         assert m.linear.weight.value['bias'].shape == (3,)
+
+         y = m(jnp.ones((10, 3)))
+         assert y.shape == (10, 3)