brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
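Taken together, the renames above describe a package-level reorganization: the old brainstate.transform module is split into brainstate.augment (autograd, eval_shape, vmap-style mapping, random augmentation) and brainstate.compile (jit, conditionals, loops, checkpointing, make_jaxpr); brainstate.nn.event becomes the top-level brainstate.event package; and the flat brainstate/util.py becomes the brainstate.util package. A minimal sketch of what this means for imports, inferred only from the file paths above and assuming the subpackages are re-exported by the top-level __init__ (not verified against the released wheels):

    # Hypothetical import mapping inferred from the renamed paths above;
    # not verified against the released wheels.
    import brainstate

    # 0.0.2: brainstate.transform held the autograd/jit/cond/loop transforms,
    #        and event-driven operators lived under brainstate.nn.event.
    # 0.1.0: the same functionality is split across three subpackages:
    brainstate.augment   # _autograd, _eval_shape, _mapping, _random
    brainstate.compile   # _jit, _conditions, _loop_*, _ad_checkpoint, _make_jaxpr
    brainstate.event     # _csr, _fixed_probability, _linear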
brainstate/augment/_autograd.py (new file, +608 lines)
@@ -0,0 +1,608 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """
+ Gradient transformations are relatively simple compared to ``vmap`` or ``pmap`` augmentations.
+ This is because gradient transformations do not operate on the Jaxpr; instead, most of them are
+ computed at the Python level. The one exception is the ``checkpoint`` transformation,
+ which has been moved into the ``compile`` module.
+
+ The wrapped gradient transformations here are made possible by the following ideas:
+ 1. All the states used to compute the gradients must be known before the transformation.
+    They must be provided through the ``grad_states`` argument of any gradient transformation.
+ 2. The states that are written in the function must be collected and updated after the function call.
+    We record these states during the function call and update them afterwards.
+
+ """
+
+ from __future__ import annotations
+
+ import inspect
+ from functools import partial, wraps
+ from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
+
+ import jax
+ from jax import numpy as jnp
+ from jax._src.api import _vjp
+ from jax.api_util import argnums_partial
+ from jax.extend import linear_util
+
+ from brainstate._state import State, StateTraceStack
+ from brainstate._utils import set_module_as
+ from brainstate.typing import PyTree, Missing
+ from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
+
+ __all__ = [
+     'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
+ ]
+
+ A = TypeVar('A')
+ Gradient = PyTree
+ LossValue = PyTree
+ AuxData = PyTree
+
+
+ def _isgeneratorfunction(fun):
+     # re-implemented here because of https://bugs.python.org/issue33261
+     while inspect.ismethod(fun):
+         fun = fun.__func__
+     while isinstance(fun, partial):
+         fun = fun.func
+     return inspect.isfunction(fun) and bool(fun.__code__.co_flags & inspect.CO_GENERATOR)
+
+
+ def _check_callable(fun):
+     # In Python 3.10+, the only thing stopping us from supporting staticmethods
+     # is that we can't take weak references to them, which the C++ JIT requires.
+     if isinstance(fun, staticmethod):
+         raise TypeError(f"staticmethod arguments are not supported, got {fun}")
+     if not callable(fun):
+         raise TypeError(f"Expected a callable value, got {fun}")
+     if _isgeneratorfunction(fun):
+         raise TypeError(f"Expected a function, got a generator function: {fun}")
+
+
+ def functional_vector_grad(func, argnums=0, return_value: bool = False, has_aux: bool = False):
+     """
+     Compute the gradient of a vector with respect to the input.
+     """
+     _check_callable(func)
+
+     @wraps(func)
+     def grad_fun(*args, **kwargs):
+         f = linear_util.wrap_init(func, kwargs)
+         f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
+         if has_aux:
+             y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)
+         else:
+             y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False)
+         leaves, tree = jax.tree.flatten(y)
+         tangents = jax.tree.unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves])
+         grads = vjp_fn(tangents)
+         if isinstance(argnums, int):
+             grads = grads[0]
+         if has_aux:
+             return (grads, y, aux) if return_value else (grads, aux)
+         else:
+             return (grads, y) if return_value else grads
+
+     return grad_fun
+
+
+ def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
+     @wraps(fun)
+     def fun_wrapped(*args, **kwargs):
+         if has_aux:
+             y, aux = fun(*args, **kwargs)
+             if return_value:
+                 return y, (y, aux)
+             else:
+                 return y, aux
+         else:
+             y = fun(*args, **kwargs)
+             if return_value:
+                 return y, y
+             else:
+                 return y, None
+
+     transform = jax.jacrev(fun_wrapped, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int, has_aux=True)
+
+     @wraps(fun)
+     def jacfun(*args, **kwargs):
+         jac, aux = transform(*args, **kwargs)
+         if return_value:
+             return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
+         else:
+             return (jac, aux) if has_aux else jac
+
+     return jacfun
+
+
+ def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False):
+     @wraps(fun)
+     def fun_wrapped(*args, **kwargs):
+         if has_aux:
+             y, aux = fun(*args, **kwargs)
+             if return_value:
+                 return y, (y, aux)
+             else:
+                 return y, aux
+         else:
+             y = fun(*args, **kwargs)
+             if return_value:
+                 return y, y
+             else:
+                 return y, None
+
+     transform = jax.jacfwd(fun_wrapped, argnums=argnums, holomorphic=holomorphic, has_aux=True)
+
+     @wraps(fun)
+     def jacfun(*args, **kwargs):
+         jac, aux = transform(*args, **kwargs)
+         if return_value:
+             return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
+         else:
+             return (jac, aux) if has_aux else jac
+
+     return jacfun
+
+
+ class GradientTransform(PrettyRepr):
+     """
+     Automatic Differentiation Transformations for the ``State`` system.
+     """
+     __module__ = "brainstate.augment"
+
+     def __init__(
+         self,
+         target: Callable,
+         transform: Callable,
+         grad_states: Any,
+         argnums: Optional[Union[int, Sequence[int]]],
+         return_value: bool,
+         has_aux: bool,
+         transform_params: Optional[Dict[str, Any]] = None,
+     ):
+         # gradient variables
+         if isinstance(grad_states, dict):
+             grad_states = {k: v for k, v in grad_states.items()}
+         self._grad_states, self._grad_tree = jax.tree.flatten(grad_states)
+         if any(not isinstance(v, State) for v in self._grad_states):
+             raise TypeError("All grad_states must be State instances.")
+
+         # parameters
+         if argnums is None and len(self._grad_states) == 0:
+             argnums = 0
+         if argnums is None:
+             assert len(self._grad_states) > 0
+             _argnums = 0
+         elif isinstance(argnums, int):
+             _argnums = (0, argnums + 1) if len(self._grad_states) > 0 else (argnums + 1)
+         else:
+             assert isinstance(argnums, (tuple, list))
+             _argnums = tuple(a + 1 for a in argnums)
+             if len(self._grad_states) > 0:
+                 _argnums = (0,) + _argnums
+         self._nonvar_argnums = argnums
+         self._argnums = _argnums
+         self._return_value = return_value
+         self._has_aux = has_aux
+
+         # target
+         self.target = target
+
+         # transform
+         self._states_to_be_written: Tuple[State, ...] = None
+         _grad_setting = dict() if transform_params is None else transform_params
+         if self._has_aux:
+             self._transform = transform(self._fun_with_aux, argnums=self._argnums, has_aux=True, **_grad_setting)
+         else:
+             self._transform = transform(self._fun_without_aux, argnums=self._argnums, has_aux=True, **_grad_setting)
+
+     def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
+         yield PrettyType(self.__class__.__name__)
+         yield PrettyAttr("target", self.target)
+         yield PrettyAttr("grad_states", self._grad_states)
+         yield PrettyAttr("grad_tree", self._grad_tree)
+         yield PrettyAttr("argnums", self._nonvar_argnums)
+         yield PrettyAttr("return_value", self._return_value)
+         yield PrettyAttr("has_aux", self._has_aux)
+         yield PrettyAttr("transform", self._transform)
+
+     def _call_target(self, *args, **kwargs):
+         if self._states_to_be_written is None:
+             with StateTraceStack() as stack:
+                 output = self.target(*args, **kwargs)
+                 # grad_ids = set([id(v) for v in self._grad_states])
+                 # self._states_to_be_written = [st for st in stack.get_write_states() if id(st) not in grad_ids]
+                 self._states_to_be_written = [st for st in stack.get_write_states()]
+         else:
+             output = self.target(*args, **kwargs)
+         return output
+
+     def _fun_with_aux(self, grad_values: tuple, *args, **kwargs):
+         for v, d in zip(self._grad_states, grad_values):
+             v.restore_value(d)
+         # Users should return the auxiliary data like::
+         # >>> # 1. example of return one data
+         # >>> return scalar_loss, data
+         # >>> # 2. example of return multiple data
+         # >>> return scalar_loss, (data1, data2, ...)
+         outs = self._call_target(*args, **kwargs)
+         # outputs: [0] is the value for gradient,
+         # [1] is other values for return
+         assert self._states_to_be_written is not None, "The states to be written should be collected."
+         return outs[0], (outs, [v.value for v in self._grad_states], [v.value for v in self._states_to_be_written])
+
+     def _fun_without_aux(self, grad_values: tuple, *args, **kwargs):
+         for v, d in zip(self._grad_states, grad_values):
+             v.restore_value(d)
+         # Users should return the scalar value like this::
+         # >>> return scalar_loss
+         out = self._call_target(*args, **kwargs)
+         assert self._states_to_be_written is not None, "The states to be written should be collected."
+         return out, (out, [v.value for v in self._grad_states], [v.value for v in self._states_to_be_written])
+
+     def _return(self, rets):
+         grads, (outputs, new_grad_vals, new_dyn_vals) = rets
+         for i, val in enumerate(new_grad_vals):
+             self._grad_states[i].restore_value(val)
+         for i, val in enumerate(new_dyn_vals):
+             self._states_to_be_written[i].value = val
+
+         # check returned grads
+         if len(self._grad_states) > 0:
+             if self._nonvar_argnums is None:
+                 grads = self._grad_tree.unflatten(grads)
+             else:
+                 var_grads = self._grad_tree.unflatten(grads[0])
+                 arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
+                 grads = (var_grads, arg_grads)
+
+         # check returned value
+         if self._return_value:
+             # check aux
+             if self._has_aux:
+                 return grads, outputs[0], outputs[1]
+             else:
+                 return grads, outputs
+         else:
+             # check aux
+             if self._has_aux:
+                 return grads, outputs[1]
+             else:
+                 return grads
+
+     def __call__(
+         self, *args, **kwargs
+     ) -> Gradient | Tuple[Gradient, LossValue] | Tuple[Gradient, AuxData] | Tuple[Gradient, LossValue, AuxData]:
+         rets = self._transform([v.value for v in self._grad_states], *args, **kwargs)
+         return self._return(rets)
+
+
+ _doc_of_return = '''
+
+ 1. When ``grad_states`` is None
+    - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+    - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+    - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+    - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+ 2. When ``grad_states`` is not None and ``argnums`` is None
+    - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+    - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+    - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+    - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+ 3. When ``grad_states`` is not None and ``argnums`` is not None
+    - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+    - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+    - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+    - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+
+ '''
+
+
+ @set_module_as("brainstate.augment")
+ def grad(
+     fun: Callable = Missing(),
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     holomorphic: Optional[bool] = False,
+     allow_int: Optional[bool] = False,
+     reduce_axes: Optional[Sequence[str]] = (),
+     has_aux: Optional[bool] = None,
+     return_value: Optional[bool] = False,
+ ) -> GradientTransform | Callable[[Callable], GradientTransform]:
+     """
+     Compute the gradient of a scalar-valued function with respect to its arguments.
+
+     %s
+
+     Args:
+       fun: callable. The scalar-valued function to be differentiated.
+       reduce_axes: (Sequence[str]) optional. Specifies the axes to reduce over when
+         differentiating with respect to array-valued arguments. The default, (),
+         means to differentiate each element of the output with respect to each
+         element of the argument. If the argument is an array, this argument controls
+         how many axes the output of grad has.
+       allow_int: (bool) optional. Whether to allow differentiating with respect to
+         integer valued inputs. The gradient of an integer input will have a trivial
+         vector-space dtype (float0). Default False.
+       holomorphic: (bool) optional. Whether fun is promised to be holomorphic.
+         Default False.
+       grad_states: (State, Sequence[State], Dict[str, State]) optional. The variables
+         in fun whose gradients should be taken.
+       argnums: (int or tuple of ints) optional. Specifies which positional
+         argument(s) to differentiate with respect to.
+       has_aux: (bool) optional. Indicates whether fun returns a pair where the
+         first element is considered the output of the mathematical function to be
+         differentiated and the second element is auxiliary data. Default False.
+       return_value: (bool) optional. Indicates whether to return the value of the
+         function along with the gradient. Default False.
+
+     Returns:
+       A function which computes the gradient of fun. The function takes the same
+       arguments as `fun`, but returns the gradient instead. If `has_aux` is True,
+       the function returns a pair where the first element is the gradient and the
+       second element is the auxiliary data. If `return_value` is True, the function
+       returns a pair where the first element is the gradient and the second element
+       is the value of the function.
+
+     """
+     if isinstance(fun, Missing):
+         def transform(fun) -> GradientTransform:
+             return GradientTransform(target=fun,
+                                      transform=jax.grad,
+                                      grad_states=grad_states,
+                                      argnums=argnums,
+                                      return_value=return_value,
+                                      has_aux=False if has_aux is None else has_aux,
+                                      transform_params=dict(holomorphic=holomorphic,
+                                                            allow_int=allow_int,
+                                                            reduce_axes=reduce_axes))
+
+         return transform
+
+     return GradientTransform(target=fun,
+                              transform=jax.grad,
+                              grad_states=grad_states,
+                              argnums=argnums,
+                              return_value=return_value,
+                              has_aux=False if has_aux is None else has_aux,
+                              transform_params=dict(holomorphic=holomorphic,
+                                                    allow_int=allow_int,
+                                                    reduce_axes=reduce_axes))
+
+
+ grad.__doc__ = grad.__doc__ % _doc_of_return
+
+
+ @set_module_as("brainstate.augment")
+ def vector_grad(
+     func: Callable = Missing(),
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     return_value: bool = False,
+     has_aux: Optional[bool] = None,
+ ) -> GradientTransform | Callable[[Callable], GradientTransform]:
+     """Take vector-valued gradients for function ``func``.
+
+     As with :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`,
+     the return type of this function depends on the argument settings.
+
+     %s
+
+     Parameters
+     ----------
+     func: Callable
+       Function whose gradient is to be computed.
+     grad_states : optional, ArrayType, sequence of ArrayType, dict
+       The variables in ``func`` whose gradients should be taken.
+     has_aux: optional, bool
+       Indicates whether ``fun`` returns a pair where the
+       first element is considered the output of the mathematical function to be
+       differentiated and the second element is auxiliary data. Default False.
+     return_value : bool
+       Whether to return the loss value.
+     argnums: Optional, integer or sequence of integers. Specifies which
+       positional argument(s) to differentiate with respect to (default ``0``).
+
+     Returns
+     -------
+     func : GradientTransform
+       The vector gradient function.
+     """
+
+     if isinstance(func, Missing):
+         def transform(fun) -> GradientTransform:
+             return GradientTransform(target=fun,
+                                      transform=functional_vector_grad,
+                                      grad_states=grad_states,
+                                      argnums=argnums,
+                                      return_value=return_value,
+                                      has_aux=False if has_aux is None else has_aux)
+
+         return transform
+
+     else:
+         return GradientTransform(target=func,
+                                  transform=functional_vector_grad,
+                                  grad_states=grad_states,
+                                  argnums=argnums,
+                                  return_value=return_value,
+                                  has_aux=False if has_aux is None else has_aux)
+
+
+ vector_grad.__doc__ = vector_grad.__doc__ % _doc_of_return
+
+
+ @set_module_as("brainstate.augment")
+ def jacrev(
+     fun: Callable,
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     has_aux: Optional[bool] = None,
+     return_value: bool = False,
+     holomorphic: bool = False,
+     allow_int: bool = False,
+ ) -> GradientTransform:
+     """
+     Extend automatic Jacobian computation (reverse-mode) of ``func`` to classes.
+
+     This function extends the official JAX ``jacrev`` so that automatic Jacobian
+     computation works on both functions and class methods. Moreover, it supports
+     returning the function value ("return_value") and auxiliary data ("has_aux").
+
+     %s
+
+     Parameters
+     ----------
+     fun: Function whose Jacobian is to be computed.
+     grad_states : optional, ArrayType, sequence of ArrayType, dict
+       The variables in ``func`` whose gradients should be taken.
+     has_aux: optional, bool
+       Indicates whether ``fun`` returns a pair where the
+       first element is considered the output of the mathematical function to be
+       differentiated and the second element is auxiliary data. Default False.
+     return_value : bool
+       Whether to return the loss value.
+     argnums: Optional, integer or sequence of integers.
+       Specifies which positional argument(s) to differentiate with respect to
+       (default ``0``).
+     holomorphic: Optional, bool.
+       Indicates whether ``fun`` is promised to be holomorphic. Default False.
+     allow_int: Optional, bool.
+       Whether to allow differentiating with respect to integer valued inputs.
+       The gradient of an integer input will have a trivial vector-space dtype
+       (float0). Default False.
+
+     Returns
+     -------
+     fun: GradientTransform
+       The transformed object.
+     """
+     return GradientTransform(target=fun,
+                              transform=_jacrev,
+                              grad_states=grad_states,
+                              argnums=argnums,
+                              return_value=return_value,
+                              has_aux=False if has_aux is None else has_aux,
+                              transform_params=dict(holomorphic=holomorphic,
+                                                    allow_int=allow_int))
+
+
+ jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
+
+ jacobian = jacrev
+
+
+ @set_module_as("brainstate.augment")
+ def jacfwd(
+     func: Callable,
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     has_aux: Optional[bool] = None,
+     return_value: bool = False,
+     holomorphic: bool = False,
+ ) -> GradientTransform:
+     """Extend automatic Jacobian computation (forward-mode) of ``func`` to classes.
+
+     This function extends the official JAX ``jacfwd`` so that automatic Jacobian
+     computation works on both functions and class methods. Moreover, it supports
+     returning the function value ("return_value") and auxiliary data ("has_aux").
+
+     %s
+
+     Parameters
+     ----------
+     func: Function whose Jacobian is to be computed.
+     grad_states : optional, ArrayType, sequence of ArrayType, dict
+       The variables in ``func`` whose gradients should be taken.
+     has_aux: optional, bool
+       Indicates whether ``fun`` returns a pair where the
+       first element is considered the output of the mathematical function to be
+       differentiated and the second element is auxiliary data. Default False.
+     return_value : bool
+       Whether to return the loss value.
+     argnums: Optional, integer or sequence of integers. Specifies which
+       positional argument(s) to differentiate with respect to (default ``0``).
+     holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
+       holomorphic. Default False.
+
+     Returns
+     -------
+     obj: GradientTransform
+       The transformed object.
+     """
+
+     return GradientTransform(target=func,
+                              transform=_jacfwd,
+                              grad_states=grad_states,
+                              argnums=argnums,
+                              return_value=return_value,
+                              has_aux=False if has_aux is None else has_aux,
+                              transform_params=dict(holomorphic=holomorphic))
+
+
+ jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
+
+
+ @set_module_as("brainstate.augment")
+ def hessian(
+     func: Callable,
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     has_aux: bool = False,
+     return_value: bool = False,
+     holomorphic: bool = False,
+ ) -> GradientTransform:
+     """
+     Hessian of ``func`` as a dense array.
+
+     %s
+
+     Parameters
+     ----------
+     func : callable
+       Function whose Hessian is to be computed. Its arguments at positions
+       specified by ``argnums`` should be arrays, scalars, or standard Python
+       containers thereof. It should return arrays, scalars, or standard Python
+       containers thereof.
+     grad_states : optional, ArrayCollector, sequence of ArrayType
+       The variables whose gradients are required.
+     argnums: Optional, integer or sequence of integers
+       Specifies which positional argument(s) to differentiate with respect to (default ``0``).
+     holomorphic : bool
+       Indicates whether ``fun`` is promised to be holomorphic. Default False.
+     return_value : bool
+       Whether to return the value of the function along with the Hessian.
+
+     Returns
+     -------
+     obj: GradientTransform
+       The transformed object.
+     """
+     return GradientTransform(target=func,
+                              transform=jax.hessian,
+                              grad_states=grad_states,
+                              argnums=argnums,
+                              return_value=return_value,
+                              has_aux=False if has_aux is None else has_aux,
+                              transform_params=dict(holomorphic=holomorphic))
+
+
+ hessian.__doc__ = hessian.__doc__ % _doc_of_return
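
To make the return conventions documented in _doc_of_return above concrete, here is a minimal usage sketch of the grad transformation defined in this file. It assumes that a State subclass such as ParamState and the augment subpackage are exported from the top-level brainstate package, as the file listing suggests; treat it as an illustration of the API shown in this diff, not as code verified against the released wheel.

    import jax.numpy as jnp
    import brainstate

    # A trainable state whose gradient we want (ParamState is assumed to be
    # a State subclass exported at the package top level).
    w = brainstate.ParamState(jnp.ones(3))

    def loss(x):
        # Reads w inside the function; grad_states tells the transform about it.
        return jnp.sum((w.value * x) ** 2)

    x = jnp.array([1.0, 2.0, 3.0])

    # grad_states given, argnums=None -> only the state gradients are returned
    # (case 2 in _doc_of_return).
    w_grad = brainstate.augment.grad(loss, grad_states=w)(x)

    # return_value=True additionally returns the loss: (var_grads, loss_value).
    w_grad, loss_value = brainstate.augment.grad(loss, grad_states=w, return_value=True)(x)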