brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/augment/_autograd.py (new file)
@@ -0,0 +1,611 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """
+ Gradient transformations are relatively simple compared to ``vmap`` or ``pmap`` augmentations,
+ because most of them are computed at the Python level rather than through the Jaxpr. The one
+ exception is the ``checkpoint`` transformation, which has been moved into the ``compile`` module.
+
+ The wrapped gradient transformations here are made possible by two ideas:
+
+ 1. All states used to compute the gradients must be known before the transformation.
+    They must be provided through the ``grad_states`` argument of any gradient transformation.
+ 2. The states written inside the function must be collected and updated after the function call.
+    We record these states during the function call and update them afterwards.
+ """
+
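A minimal usage sketch of the two ideas above, assuming the top-level re-exports ``brainstate.State`` and ``brainstate.augment.grad`` that this module's imports and ``__all__`` suggest:

    import jax.numpy as jnp
    import brainstate

    w = brainstate.State(jnp.ones(3))        # state whose gradient is requested

    def loss(x):
        # reads ``w`` inside the function; any states written here are
        # recorded and written back by the transformation machinery
        return jnp.sum((w.value * x) ** 2)

    f_grad = brainstate.augment.grad(loss, grad_states=w)
    grads = f_grad(jnp.arange(3.0))          # gradients w.r.t. ``w.value``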
+ from __future__ import annotations
+
+ import inspect
+ from functools import partial, wraps
+ from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
+
+ import jax
+ from jax import numpy as jnp
+ from jax._src.api import _vjp
+ from jax.api_util import argnums_partial
+ from jax.extend import linear_util
+
+ from brainstate._state import State, StateTraceStack
+ from brainstate._utils import set_module_as
+ from brainstate.typing import PyTree, Missing
+ from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
+
+ __all__ = [
+     'GradientTransform', 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
+ ]
+
+ A = TypeVar('A')
+ Gradient = PyTree
+ LossValue = PyTree
+ AuxData = PyTree
+
+
+ def _isgeneratorfunction(fun):
+     # re-implemented here because of https://bugs.python.org/issue33261
+     while inspect.ismethod(fun):
+         fun = fun.__func__
+     while isinstance(fun, partial):
+         fun = fun.func
+     return inspect.isfunction(fun) and bool(fun.__code__.co_flags & inspect.CO_GENERATOR)
+
+
+ def _check_callable(fun):
+     # In Python 3.10+, the only thing stopping us from supporting staticmethods
+     # is that we can't take weak references to them, which the C++ JIT requires.
+     if isinstance(fun, staticmethod):
+         raise TypeError(f"staticmethod arguments are not supported, got {fun}")
+     if not callable(fun):
+         raise TypeError(f"Expected a callable value, got {fun}")
+     if _isgeneratorfunction(fun):
+         raise TypeError(f"Expected a function, got a generator function: {fun}")
+
+
+ def functional_vector_grad(func, argnums=0, return_value: bool = False, has_aux: bool = False):
+     """
+     Compute the gradient of a vector-valued function with respect to its input.
+     """
+     _check_callable(func)
+
+     @wraps(func)
+     def grad_fun(*args, **kwargs):
+         f = linear_util.wrap_init(func, kwargs)
+         f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
+         if has_aux:
+             y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)
+         else:
+             y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False)
+         leaves, tree = jax.tree.flatten(y)
+         tangents = jax.tree.unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves])
+         grads = vjp_fn(tangents)
+         if isinstance(argnums, int):
+             grads = grads[0]
+         if has_aux:
+             return (grads, y, aux) if return_value else (grads, aux)
+         else:
+             return (grads, y) if return_value else grads
+
+     return grad_fun
+
+
+ def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
+     @wraps(fun)
+     def fun_wrapped(*args, **kwargs):
+         if has_aux:
+             y, aux = fun(*args, **kwargs)
+             if return_value:
+                 return y, (y, aux)
+             else:
+                 return y, aux
+         else:
+             y = fun(*args, **kwargs)
+             if return_value:
+                 return y, y
+             else:
+                 return y, None
+
+     transform = jax.jacrev(fun_wrapped, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int, has_aux=True)
+
+     @wraps(fun)
+     def jacfun(*args, **kwargs):
+         jac, aux = transform(*args, **kwargs)
+         if return_value:
+             return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
+         else:
+             return (jac, aux) if has_aux else jac
+
+     return jacfun
+
+
+ def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False):
+     @wraps(fun)
+     def fun_wrapped(*args, **kwargs):
+         if has_aux:
+             y, aux = fun(*args, **kwargs)
+             if return_value:
+                 return y, (y, aux)
+             else:
+                 return y, aux
+         else:
+             y = fun(*args, **kwargs)
+             if return_value:
+                 return y, y
+             else:
+                 return y, None
+
+     transform = jax.jacfwd(fun_wrapped, argnums=argnums, holomorphic=holomorphic, has_aux=True)
+
+     @wraps(fun)
+     def jacfun(*args, **kwargs):
+         jac, aux = transform(*args, **kwargs)
+         if return_value:
+             return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
+         else:
+             return (jac, aux) if has_aux else jac
+
+     return jacfun
+
+
+ TransformFn = Callable
+
+
+ class GradientTransform(PrettyRepr):
+     """
+     Automatic Differentiation Transformations for the ``State`` system.
+     """
+     __module__ = "brainstate.augment"
+
+     def __init__(
+         self,
+         target: Callable,
+         transform: TransformFn,
+         grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+         argnums: Optional[Union[int, Sequence[int]]] = None,
+         return_value: bool = False,
+         has_aux: bool = False,
+         transform_params: Optional[Dict[str, Any]] = None,
+     ):
+         # gradient variables
+         if isinstance(grad_states, dict):
+             grad_states = {k: v for k, v in grad_states.items()}
+         self._grad_states, self._grad_tree = jax.tree.flatten(grad_states)
+         if any(not isinstance(v, State) for v in self._grad_states):
+             raise TypeError("All grad_states must be State instances.")
+
+         # parameters
+         if argnums is None and len(self._grad_states) == 0:
+             argnums = 0
+         if argnums is None:
+             assert len(self._grad_states) > 0
+             _argnums = 0
+         elif isinstance(argnums, int):
+             _argnums = (0, argnums + 1) if len(self._grad_states) > 0 else (argnums + 1)
+         else:
+             assert isinstance(argnums, (tuple, list))
+             _argnums = tuple(a + 1 for a in argnums)
+             if len(self._grad_states) > 0:
+                 _argnums = (0,) + _argnums
+         self._nonvar_argnums = argnums
+         self._argnums = _argnums
+         self._return_value = return_value
+         self._has_aux = has_aux
+
+         # target
+         self.target = target
+
+         # transform
+         self._states_to_be_written: Tuple[State, ...] = None
+         _grad_setting = dict() if transform_params is None else transform_params
+         if self._has_aux:
+             self._transform = transform(self._fun_with_aux, argnums=self._argnums, has_aux=True, **_grad_setting)
+         else:
+             self._transform = transform(self._fun_without_aux, argnums=self._argnums, has_aux=True, **_grad_setting)
+
+     def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
+         yield PrettyType(self.__class__.__name__)
+         yield PrettyAttr("target", self.target)
+         yield PrettyAttr("grad_states", self._grad_states)
+         yield PrettyAttr("grad_tree", self._grad_tree)
+         yield PrettyAttr("argnums", self._nonvar_argnums)
+         yield PrettyAttr("return_value", self._return_value)
+         yield PrettyAttr("has_aux", self._has_aux)
+         yield PrettyAttr("transform", self._transform)
+
+     def _call_target(self, *args, **kwargs):
+         if self._states_to_be_written is None:
+             with StateTraceStack() as stack:
+                 output = self.target(*args, **kwargs)
+                 # grad_ids = set([id(v) for v in self._grad_states])
+                 # self._states_to_be_written = [st for st in stack.get_write_states() if id(st) not in grad_ids]
+                 self._states_to_be_written = [st for st in stack.get_write_states()]
+         else:
+             output = self.target(*args, **kwargs)
+         return output
+
+     def _fun_with_aux(self, grad_values: tuple, *args, **kwargs):
+         for v, d in zip(self._grad_states, grad_values):
+             v.restore_value(d)
+         # Users should return the auxiliary data like::
+         # >>> # 1. example of return one data
+         # >>> return scalar_loss, data
+         # >>> # 2. example of return multiple data
+         # >>> return scalar_loss, (data1, data2, ...)
+         outs = self._call_target(*args, **kwargs)
+         # outputs: [0] is the value for gradient,
+         #          [1] is other values for return
+         assert self._states_to_be_written is not None, "The states to be written should be collected."
+         return outs[0], (outs, [v.value for v in self._grad_states], [v.value for v in self._states_to_be_written])
+
+     def _fun_without_aux(self, grad_values: tuple, *args, **kwargs):
+         for v, d in zip(self._grad_states, grad_values):
+             v.restore_value(d)
+         # Users should return the scalar value like this::
+         # >>> return scalar_loss
+         out = self._call_target(*args, **kwargs)
+         assert self._states_to_be_written is not None, "The states to be written should be collected."
+         return out, (out, [v.value for v in self._grad_states], [v.value for v in self._states_to_be_written])
+
+     def _return(self, rets):
+         grads, (outputs, new_grad_vals, new_dyn_vals) = rets
+         for i, val in enumerate(new_grad_vals):
+             self._grad_states[i].restore_value(val)
+         for i, val in enumerate(new_dyn_vals):
+             self._states_to_be_written[i].value = val
+
+         # check returned grads
+         if len(self._grad_states) > 0:
+             if self._nonvar_argnums is None:
+                 grads = self._grad_tree.unflatten(grads)
+             else:
+                 var_grads = self._grad_tree.unflatten(grads[0])
+                 arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
+                 grads = (var_grads, arg_grads)
+
+         # check returned value
+         if self._return_value:
+             # check aux
+             if self._has_aux:
+                 return grads, outputs[0], outputs[1]
+             else:
+                 return grads, outputs
+         else:
+             # check aux
+             if self._has_aux:
+                 return grads, outputs[1]
+             else:
+                 return grads
+
+     def __call__(
+         self, *args, **kwargs
+     ) -> Gradient | Tuple[Gradient, LossValue] | Tuple[Gradient, AuxData] | Tuple[Gradient, LossValue, AuxData]:
+         rets = self._transform([v.value for v in self._grad_states], *args, **kwargs)
+         return self._return(rets)
+
+
+ _doc_of_return = '''
+
+     1. When ``grad_states`` is None
+         - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+         - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+         - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+         - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+     2. When ``grad_states`` is not None and ``argnums`` is None
+         - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+         - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+         - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+         - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+     3. When ``grad_states`` is not None and ``argnums`` is not None
+         - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+         - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+         - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+         - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+
+ '''
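For instance, case 3 above unpacks as follows (a sketch reusing the hypothetical ``w``, ``loss``, and imports from the earlier example, with ``has_aux=False``):

    f_grad = brainstate.augment.grad(loss, grad_states={'w': w}, argnums=0, return_value=True)
    (var_grads, arg_grads), loss_value = f_grad(jnp.arange(3.0))
    # ``var_grads`` mirrors the structure of ``grad_states``, i.e. {'w': ...};
    # ``arg_grads`` is the gradient w.r.t. positional argument 0.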
+
+
+ @set_module_as("brainstate.augment")
+ def grad(
+     fun: Callable = Missing(),
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     holomorphic: Optional[bool] = False,
+     allow_int: Optional[bool] = False,
+     reduce_axes: Optional[Sequence[str]] = (),
+     has_aux: Optional[bool] = None,
+     return_value: Optional[bool] = False,
+ ) -> GradientTransform | Callable[[Callable], GradientTransform]:
+     """
+     Compute the gradient of a scalar-valued function with respect to its arguments.
+
+     %s
+
+     Args:
+         fun: callable. The scalar-valued function to be differentiated.
+         grad_states: (State, Sequence[State], Dict[str, State]) optional. The variables
+             in ``fun`` to take their gradients.
+         argnums: (int or tuple of ints) optional. Specifies which positional
+             argument(s) to differentiate with respect to.
+         holomorphic: (bool) optional. Whether ``fun`` is promised to be holomorphic.
+             Default False.
+         allow_int: (bool) optional. Whether to allow differentiating with respect to
+             integer valued inputs. The gradient of an integer input will have a trivial
+             vector-space dtype (float0). Default False.
+         reduce_axes: (Sequence[str]) optional. Named axes to reduce gradients over.
+             If an axis listed here is mapped and ``fun`` implicitly broadcasts a value
+             over it, the backward pass sums the corresponding gradient across that axis.
+             The default, (), performs no reduction.
+         has_aux: (bool) optional. Indicates whether ``fun`` returns a pair where the
+             first element is considered the output of the mathematical function to be
+             differentiated and the second element is auxiliary data. Default False.
+         return_value: (bool) optional. Indicates whether to return the value of the
+             function along with the gradient. Default False.
+
+     Returns:
+         A function which computes the gradient of ``fun``. The function takes the same
+         arguments as ``fun``, but returns the gradient instead. If ``has_aux`` is True,
+         the function returns a pair where the first element is the gradient and the
+         second element is the auxiliary data. If ``return_value`` is True, the function
+         returns a pair where the first element is the gradient and the second element
+         is the value of the function.
+     """
+     if isinstance(fun, Missing):
+         def transform(fun) -> GradientTransform:
+             return GradientTransform(target=fun,
+                                      transform=jax.grad,
+                                      grad_states=grad_states,
+                                      argnums=argnums,
+                                      return_value=return_value,
+                                      has_aux=False if has_aux is None else has_aux,
+                                      transform_params=dict(holomorphic=holomorphic,
+                                                            allow_int=allow_int,
+                                                            reduce_axes=reduce_axes))
+
+         return transform
+
+     return GradientTransform(target=fun,
+                              transform=jax.grad,
+                              grad_states=grad_states,
+                              argnums=argnums,
+                              return_value=return_value,
+                              has_aux=False if has_aux is None else has_aux,
+                              transform_params=dict(holomorphic=holomorphic,
+                                                    allow_int=allow_int,
+                                                    reduce_axes=reduce_axes))
+
+
+ grad.__doc__ = grad.__doc__ % _doc_of_return
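Because ``fun`` defaults to ``Missing()``, ``grad`` can also be applied as a decorator; a short sketch with the same hypothetical ``w`` as above:

    @brainstate.augment.grad(grad_states=w, return_value=True)
    def loss(x):
        return jnp.sum((w.value * x) ** 2)

    grads, loss_value = loss(jnp.arange(3.0))   # case 2 above with return_value=True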
+
+
+ @set_module_as("brainstate.augment")
+ def vector_grad(
+     func: Callable = Missing(),
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     return_value: bool = False,
+     has_aux: Optional[bool] = None,
+ ) -> GradientTransform | Callable[[Callable], GradientTransform]:
+     """Take vector-valued gradients for function ``func``.
+
+     As with :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`, the returned
+     values differ depending on the argument settings.
+
+     %s
+
+     Parameters
+     ----------
+     func: Callable
+         Function whose gradient is to be computed.
+     grad_states : optional, State, sequence of State, dict
+         The variables in ``func`` to take their gradients.
+     has_aux: optional, bool
+         Indicates whether ``func`` returns a pair where the
+         first element is considered the output of the mathematical function to be
+         differentiated and the second element is auxiliary data. Default False.
+     return_value : bool
+         Whether to return the loss value.
+     argnums: optional, integer or sequence of integers
+         Specifies which positional argument(s) to differentiate with respect to
+         (default ``0``).
+
+     Returns
+     -------
+     func : GradientTransform
+         The vector gradient function.
+     """
+
+     if isinstance(func, Missing):
+         def transform(fun) -> GradientTransform:
+             return GradientTransform(target=fun,
+                                      transform=functional_vector_grad,
+                                      grad_states=grad_states,
+                                      argnums=argnums,
+                                      return_value=return_value,
+                                      has_aux=False if has_aux is None else has_aux)
+
+         return transform
+
+     else:
+         return GradientTransform(target=func,
+                                  transform=functional_vector_grad,
+                                  grad_states=grad_states,
+                                  argnums=argnums,
+                                  return_value=return_value,
+                                  has_aux=False if has_aux is None else has_aux)
+
+
+ vector_grad.__doc__ = vector_grad.__doc__ % _doc_of_return
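A small sketch of ``vector_grad`` on a vector-valued function (hypothetical names; the result equals the gradient of the summed outputs, i.e. the elementwise derivative here):

    def f(x):
        return x ** 3 + 2.0 * x

    df = brainstate.augment.vector_grad(f)
    df(jnp.array([1.0, 2.0]))    # expected: 3 * x**2 + 2, evaluated elementwise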
+
+
+ @set_module_as("brainstate.augment")
+ def jacrev(
+     fun: Callable,
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     has_aux: Optional[bool] = None,
+     return_value: bool = False,
+     holomorphic: bool = False,
+     allow_int: bool = False,
+ ) -> GradientTransform:
+     """
+     Extend the automatic reverse-mode Jacobian of ``fun`` to classes.
+
+     This function extends the official JAX ``jacrev`` to enable automatic Jacobian
+     computation for both plain functions and class methods. Moreover, it supports
+     returning the function value (``return_value``) and auxiliary data (``has_aux``).
+
+     %s
+
+     Parameters
+     ----------
+     fun: Callable
+         Function whose Jacobian is to be computed.
+     grad_states : optional, State, sequence of State, dict
+         The variables in ``fun`` to take their gradients.
+     has_aux: optional, bool
+         Indicates whether ``fun`` returns a pair where the
+         first element is considered the output of the mathematical function to be
+         differentiated and the second element is auxiliary data. Default False.
+     return_value : bool
+         Whether to return the loss value.
+     argnums: optional, integer or sequence of integers
+         Specifies which positional argument(s) to differentiate with respect to
+         (default ``0``).
+     holomorphic: optional, bool
+         Indicates whether ``fun`` is promised to be holomorphic. Default False.
+     allow_int: optional, bool
+         Whether to allow differentiating with respect to integer valued inputs.
+         The gradient of an integer input will have a trivial vector-space dtype
+         (float0). Default False.
+
+     Returns
+     -------
+     fun: GradientTransform
+         The transformed object.
+     """
+     return GradientTransform(target=fun,
+                              transform=_jacrev,
+                              grad_states=grad_states,
+                              argnums=argnums,
+                              return_value=return_value,
+                              has_aux=False if has_aux is None else has_aux,
+                              transform_params=dict(holomorphic=holomorphic,
+                                                    allow_int=allow_int))
+
+
+ jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
+
+ jacobian = jacrev
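A sketch of the Jacobian wrappers on a hypothetical function; ``jacfwd`` below accepts the same arguments except ``allow_int``:

    def f(x):
        return jnp.stack([x[0] ** 2, x[0] * x[1]])

    x = jnp.array([1.0, 2.0])
    J_rev = brainstate.augment.jacrev(f)(x)   # 2x2 Jacobian via reverse mode
    J_fwd = brainstate.augment.jacfwd(f)(x)   # same Jacobian via forward mode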
+
+
+ @set_module_as("brainstate.augment")
+ def jacfwd(
+     func: Callable,
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     has_aux: Optional[bool] = None,
+     return_value: bool = False,
+     holomorphic: bool = False,
+ ) -> GradientTransform:
+     """Extend the automatic forward-mode Jacobian of ``func`` to classes.
+
+     This function extends the official JAX ``jacfwd`` to enable automatic Jacobian
+     computation for both plain functions and class methods. Moreover, it supports
+     returning the function value (``return_value``) and auxiliary data (``has_aux``).
+
+     %s
+
+     Parameters
+     ----------
+     func: Callable
+         Function whose Jacobian is to be computed.
+     grad_states : optional, State, sequence of State, dict
+         The variables in ``func`` to take their gradients.
+     has_aux: optional, bool
+         Indicates whether ``func`` returns a pair where the
+         first element is considered the output of the mathematical function to be
+         differentiated and the second element is auxiliary data. Default False.
+     return_value : bool
+         Whether to return the loss value.
+     argnums: optional, integer or sequence of integers
+         Specifies which positional argument(s) to differentiate with respect to
+         (default ``0``).
+     holomorphic: optional, bool
+         Indicates whether ``func`` is promised to be holomorphic. Default False.
+
+     Returns
+     -------
+     obj: GradientTransform
+         The transformed object.
+     """
+
+     return GradientTransform(target=func,
+                              transform=_jacfwd,
+                              grad_states=grad_states,
+                              argnums=argnums,
+                              return_value=return_value,
+                              has_aux=False if has_aux is None else has_aux,
+                              transform_params=dict(holomorphic=holomorphic))
+
+
+ jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
+
+
+ @set_module_as("brainstate.augment")
+ def hessian(
+     func: Callable,
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     has_aux: bool = False,
+     return_value: bool = False,
+     holomorphic: bool = False,
+ ) -> GradientTransform:
+     """
+     Hessian of ``func`` as a dense array.
+
+     %s
+
+     Parameters
+     ----------
+     func : callable
+         Function whose Hessian is to be computed. Its arguments at positions
+         specified by ``argnums`` should be arrays, scalars, or standard Python
+         containers thereof. It should return arrays, scalars, or standard Python
+         containers thereof.
+     grad_states : optional, State, sequence of State, dict
+         The variables in ``func`` to take their gradients.
+     argnums: optional, integer or sequence of integers
+         Specifies which positional argument(s) to differentiate with respect to (default ``0``).
+     holomorphic : bool
+         Indicates whether ``func`` is promised to be holomorphic. Default False.
+     return_value : bool
+         Whether to return the value of ``func`` along with the Hessian.
+
+     Returns
+     -------
+     obj: GradientTransform
+         The transformed object.
+     """
+     return GradientTransform(target=func,
+                              transform=jax.hessian,
+                              grad_states=grad_states,
+                              argnums=argnums,
+                              return_value=return_value,
+                              has_aux=False if has_aux is None else has_aux,
+                              transform_params=dict(holomorphic=holomorphic))
+
+
+ hessian.__doc__ = hessian.__doc__ % _doc_of_return
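Finally, a sketch of ``hessian`` on a hypothetical scalar-valued function:

    def quadratic(x):
        return jnp.sum(x ** 2) + x[0] * x[1]

    H = brainstate.augment.hessian(quadratic)(jnp.array([1.0, 2.0]))
    # dense 2x2 matrix: [[2., 1.], [1., 2.]]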