brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
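The renames above fold the former augment and compile submodules into a single transform package (and move the init helpers under nn); the expanded diff that follows matches the line counts of item 72, brainstate/{augment → transform}/_autograd.py. As a hedged sketch of what the rename means for downstream imports (paths inferred from this file list only, not an official migration guide), code written against 0.1.x would be updated roughly as follows:

# Import-path migration sketch inferred from the renamed files above.
# The 0.1.x paths and the exact public re-exports of 0.2.1 are assumptions
# based on this file list, so verify against the installed package.
import jax.numpy as jnp
import brainstate

# 0.1.x (old layout):
#   from brainstate.augment import grad, vector_grad   # autograd lived in `augment`
#   from brainstate.compile import jit                 # compilation lived in `compile`

# 0.2.x (new layout): the same utilities live under brainstate.transform
def loss(x):
    return jnp.sum(x ** 2)

grad_fn = brainstate.transform.grad(loss)    # same call pattern, new namespace
print(grad_fn(jnp.array([1.0, 2.0, 3.0])))   # gradient of sum(x**2) is 2*x -> [2. 4. 6.]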
@@ -1,778 +1,1025 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Gradient transformations are relatively simple compared to ``vmap`` or ``pmap`` augmentations.
18
- This is because the gradient transformations do not use the Jaxpr; instead, most of them are
19
- computed at the Python level. However, there is one exception: the ``checkpoint`` transformation,
20
- which has been moved into the ``compile`` module.
21
-
22
- The wrapped gradient transformations here are made possible by using the following ideas:
23
- 1. All the states to compute the gradients should be known before the transformation.
24
- They must be provided through the ``grad_states`` argument in any of the gradient transformations.
25
- 2. The states that have been written in the function should be collected and updated after the function call.
26
- We record these states during the function call and update them after the function call.
27
-
28
- """
29
-
30
- from functools import wraps, partial
31
- from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
32
-
33
- import brainunit as u
34
- import jax
35
-
36
- from brainstate._state import State
37
- from brainstate._utils import set_module_as
38
- from brainstate.compile._make_jaxpr import StatefulFunction
39
- from brainstate.typing import PyTree, Missing
40
- from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
41
-
42
- __all__ = [
43
- 'GradientTransform', 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
44
- ]
45
-
46
- A = TypeVar('A')
47
- Gradient = PyTree
48
- LossValue = PyTree
49
- AuxData = PyTree
50
-
51
-
52
- def _jacrev(
53
- fun,
54
- argnums=0,
55
- holomorphic=False,
56
- allow_int=False,
57
- has_aux=False,
58
- return_value=False,
59
- unit_aware=False,
60
- ):
61
- @wraps(fun)
62
- def fun_wrapped(*args, **kwargs):
63
- if has_aux:
64
- y, aux = fun(*args, **kwargs)
65
- if return_value:
66
- return y, (y, aux)
67
- else:
68
- return y, aux
69
- else:
70
- y = fun(*args, **kwargs)
71
- if return_value:
72
- return y, y
73
- else:
74
- return y, None
75
-
76
- if unit_aware:
77
- transform = u.autograd.jacrev(fun_wrapped,
78
- argnums=argnums,
79
- holomorphic=holomorphic,
80
- allow_int=allow_int,
81
- has_aux=True)
82
- else:
83
- transform = jax.jacrev(fun_wrapped,
84
- argnums=argnums,
85
- holomorphic=holomorphic,
86
- allow_int=allow_int,
87
- has_aux=True)
88
-
89
- @wraps(fun)
90
- def jacfun(*args, **kwargs):
91
- jac, aux = transform(*args, **kwargs)
92
- if return_value:
93
- return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
94
- else:
95
- return (jac, aux) if has_aux else jac
96
-
97
- return jacfun
98
-
99
-
100
- def _jacfwd(
101
- fun,
102
- argnums=0,
103
- holomorphic=False,
104
- has_aux=False,
105
- return_value=False,
106
- unit_aware=False,
107
- ):
108
- @wraps(fun)
109
- def fun_wrapped(*args, **kwargs):
110
- if has_aux:
111
- y, aux = fun(*args, **kwargs)
112
- if return_value:
113
- return y, (y, aux)
114
- else:
115
- return y, aux
116
- else:
117
- y = fun(*args, **kwargs)
118
- if return_value:
119
- return y, y
120
- else:
121
- return y, None
122
-
123
- if unit_aware:
124
- transform = u.autograd.jacfwd(fun_wrapped,
125
- argnums=argnums,
126
- holomorphic=holomorphic,
127
- has_aux=True)
128
- else:
129
- transform = jax.jacfwd(fun_wrapped,
130
- argnums=argnums,
131
- holomorphic=holomorphic,
132
- has_aux=True)
133
-
134
- @wraps(fun)
135
- def jacfun(*args, **kwargs):
136
- jac, aux = transform(*args, **kwargs)
137
- if return_value:
138
- return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
139
- else:
140
- return (jac, aux) if has_aux else jac
141
-
142
- return jacfun
143
-
144
-
145
- TransformFn = Callable
146
-
147
-
148
- class GradientTransform(PrettyRepr):
149
- """
150
- Automatic Differentiation Transformations for the ``State`` system.
151
-
152
- This class implements gradient transformations for functions that operate on State objects.
153
- It allows for flexible configuration of gradient computation with respect to specified states
154
- and function arguments.
155
-
156
- Attributes:
157
- target (Callable): The function to be transformed.
158
- stateful_target (StatefulFunction): A wrapper around the target function for state management.
159
- raw_argnums (Optional[Union[int, Sequence[int]]]): The original argnums specified by the user.
160
- true_argnums (Union[int, Tuple[int, ...]]): The adjusted argnums used internally.
161
- return_value (bool): Whether to return the function's value along with gradients.
162
- has_aux (bool): Whether the function returns auxiliary data.
163
- """
164
-
165
- __module__ = "brainstate.augment"
166
-
167
- def __init__(
168
- self,
169
- target: Callable,
170
- transform: TransformFn,
171
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
172
- argnums: Optional[Union[int, Sequence[int]]] = None,
173
- return_value: bool = False,
174
- has_aux: bool = False,
175
- transform_params: Optional[Dict[str, Any]] = None,
176
- check_states: bool = True,
177
- ):
178
- """
179
- Initialize a ``GradientTransform`` instance.
180
-
181
- Args:
182
- target (Callable): The function to be transformed.
183
- transform (TransformFn): The transformation function to apply.
184
- grad_states (Optional[Union[State, Sequence[State], Dict[str, State]]]): States to compute gradients for.
185
- argnums (Optional[Union[int, Sequence[int]]]): Indices of arguments to differentiate with respect to.
186
- return_value (bool): Whether to return the function's value along with gradients.
187
- has_aux (bool): Whether the function returns auxiliary data.
188
- transform_params (Optional[Dict[str, Any]]): Additional parameters for the transformation function.
189
-
190
- Raises:
191
- TypeError: If any grad_states are not State instances.
192
- """
193
- # gradient variables
194
- if isinstance(grad_states, dict):
195
- grad_states = {k: v for k, v in grad_states.items()}
196
- self._grad_states, self._grad_tree = jax.tree.flatten(grad_states, is_leaf=lambda x: isinstance(x, State))
197
- self._grad_state_ids = [id(v) for v in self._grad_states]
198
- self._grad_id_to_state = {id(v): v for v in self._grad_states}
199
- if any(not isinstance(v, State) for v in self._grad_states):
200
- raise TypeError("All grad_states must be State instances.")
201
- self.check_states = check_states
202
-
203
- # parameters
204
- if argnums is None and len(self._grad_states) == 0:
205
- argnums = 0
206
- if argnums is None:
207
- assert len(self._grad_states) > 0
208
- _argnums = 0
209
- elif isinstance(argnums, int):
210
- _argnums = (0, argnums + 2) if len(self._grad_states) > 0 else (argnums + 2)
211
- else:
212
- assert isinstance(argnums, (tuple, list))
213
- _argnums = tuple(a + 2 for a in argnums)
214
- if len(self._grad_states) > 0:
215
- _argnums = (0,) + _argnums
216
- self.raw_argnums = argnums
217
- self.true_argnums = _argnums
218
- self.return_value = return_value
219
- self.has_aux = has_aux
220
-
221
- # target
222
- assert callable(target), "The target should be a callable object."
223
- self.target = target
224
- self.stateful_target = StatefulFunction(target, name='gradient')
225
-
226
- # transform
227
- grad_setting = dict() if transform_params is None else transform_params
228
- if self.has_aux:
229
- self._transform = transform(self._fun_with_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
230
- else:
231
- self._transform = transform(self._fun_without_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
232
-
233
- def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
234
- yield PrettyType(self.__class__.__name__)
235
- yield PrettyAttr("target", self.target)
236
- yield PrettyAttr("grad_states", self._grad_states)
237
- yield PrettyAttr("grad_tree", self._grad_tree)
238
- yield PrettyAttr("argnums", self.raw_argnums)
239
- yield PrettyAttr("return_value", self.return_value)
240
- yield PrettyAttr("has_aux", self.has_aux)
241
- yield PrettyAttr("transform", self._transform)
242
-
243
- def _split_state_vals(self, state_trace):
244
- """
245
- Split state values into gradient and non-gradient states.
246
-
247
- Args:
248
- state_trace: The state trace containing all states.
249
-
250
- Returns:
251
- Tuple[Dict, Dict]: A tuple of dictionaries containing gradient and non-gradient state values.
252
- """
253
- grad_vals = dict()
254
- other_vals = dict()
255
- all_ids = set(self._grad_state_ids)
256
- for st in state_trace.states:
257
- id_ = id(st)
258
- if id_ in all_ids:
259
- grad_vals[id_] = st.value
260
- all_ids.remove(id_)
261
- else:
262
- other_vals[id_] = st.value
263
- if len(all_ids):
264
- if self.check_states:
265
- err = f"Some states are not found in the state trace when performing gradient transformations.\n "
266
- for i, id_ in enumerate(all_ids):
267
- st = self._grad_id_to_state[id_]
268
- st.raise_error_with_source_info(ValueError(err + str(st)))
269
- else:
270
- id2state = {id(st): st for st in self._grad_states}
271
- for id_ in all_ids:
272
- grad_vals[id_] = id2state[id_].value
273
-
274
- return grad_vals, other_vals
275
-
276
- def _merge_state_vals(self, grad_vals: Dict, other_vals: Dict, state_trace):
277
- """
278
- Merge gradient and non-gradient state values back into a single list.
279
-
280
- Args:
281
- grad_vals (Dict): Dictionary of gradient state values.
282
- other_vals (Dict): Dictionary of non-gradient state values.
283
- state_trace: The state trace containing all states.
284
-
285
- Returns:
286
- List: A list of merged state values.
287
- """
288
- res = []
289
- for st in state_trace.states:
290
- id_ = id(st)
291
- if id_ in self._grad_state_ids:
292
- res.append(grad_vals[id_])
293
- else:
294
- res.append(other_vals[id_])
295
- return res
296
-
297
- def _call_target(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
298
- """
299
- Call the target function with the given state values and arguments.
300
-
301
- Args:
302
- grad_vals (Dict): Dictionary of gradient state values.
303
- other_vals (Dict): Dictionary of non-gradient state values.
304
- *args: Positional arguments to pass to the target function.
305
- **kwargs: Keyword arguments to pass to the target function.
306
-
307
- Returns:
308
- Tuple: A tuple containing updated state values and the function output.
309
- """
310
- cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
311
- state_trace = self.stateful_target.get_state_trace(cache)
312
- state_vals = self._merge_state_vals(grad_vals, other_vals, state_trace)
313
- state_vals, out = self.stateful_target.jaxpr_call(state_vals, *args, **kwargs)
314
- return state_vals, out
315
-
316
- def _fun_with_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
317
- """
318
- Wrapper function for target functions that return auxiliary data.
319
-
320
- Args:
321
- grad_vals (Dict): Dictionary of gradient state values.
322
- other_vals (Dict): Dictionary of non-gradient state values.
323
- *args: Positional arguments to pass to the target function.
324
- **kwargs: Keyword arguments to pass to the target function.
325
-
326
- Returns:
327
- Tuple: A tuple containing the primary output and a tuple of (all outputs, updated state values).
328
- """
329
- # Users should return the auxiliary data like::
330
- # >>> # 1. example of return one data
331
- # >>> return scalar_loss, data
332
- # >>> # 2. example of return multiple data
333
- # >>> return scalar_loss, (data1, data2, ...)
334
- state_vals, outs = self._call_target(grad_vals, other_vals, *args, **kwargs)
335
- return outs[0], (outs, state_vals)
336
-
337
- def _fun_without_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
338
- """
339
- Wrapper function for target functions that do not return auxiliary data.
340
-
341
- Args:
342
- grad_vals (Dict): Dictionary of gradient state values.
343
- other_vals (Dict): Dictionary of non-gradient state values.
344
- *args: Positional arguments to pass to the target function.
345
- **kwargs: Keyword arguments to pass to the target function.
346
-
347
- Returns:
348
- Tuple: A tuple containing the output and a tuple of (output, updated state values).
349
- """
350
- state_vals, out = self._call_target(grad_vals, other_vals, *args, **kwargs)
351
- return out, (out, state_vals)
352
-
353
- def _return(self, rets, state_trace):
354
- """
355
- Process and format the return values from the gradient computation.
356
-
357
- Args:
358
- rets: The raw results from the gradient computation.
359
- state_trace: The state trace containing all states.
360
-
361
- Returns:
362
- Union[Gradient, Tuple]: The processed gradient results, potentially including function value and/or auxiliary data.
363
- """
364
- # unpack the return values
365
- grads, (outputs, new_state_vals) = rets
366
-
367
- # assign new values to the states
368
- state_trace.assign_state_vals(new_state_vals)
369
-
370
- # check returned grads
371
- if len(self._grad_states) > 0:
372
- grads_of_states = grads if self.raw_argnums is None else grads[0]
373
- grads_of_states = [grads_of_states[st_id] for st_id in self._grad_state_ids]
374
- if self.raw_argnums is None:
375
- grads = self._grad_tree.unflatten(grads_of_states)
376
- else:
377
- var_grads = self._grad_tree.unflatten(grads_of_states)
378
- arg_grads = grads[1] if isinstance(self.raw_argnums, int) else grads[1:]
379
- grads = (var_grads, arg_grads)
380
-
381
- # check returned value
382
- if self.return_value:
383
- # check aux
384
- if self.has_aux:
385
- return grads, outputs[0], outputs[1]
386
- else:
387
- return grads, outputs
388
- else:
389
- # check aux
390
- if self.has_aux:
391
- return grads, outputs[1]
392
- else:
393
- return grads
394
-
395
- def __call__(
396
- self, *args, **kwargs
397
- ) -> (
398
- Gradient |
399
- Tuple[Gradient, LossValue] |
400
- Tuple[Gradient, AuxData] |
401
- Tuple[Gradient, LossValue, AuxData]
402
- ):
403
- """
404
- Compute gradients by calling the transformed function.
405
-
406
- Args:
407
- *args: Positional arguments to pass to the target function.
408
- **kwargs: Keyword arguments to pass to the target function.
409
-
410
- Returns:
411
- Union[Gradient, Tuple]: The computed gradients, potentially including function value and/or auxiliary data.
412
- """
413
-
414
- # TODO: support jax.disable_jit()
415
-
416
- # compute the model
417
- self.stateful_target.make_jaxpr(*args, **kwargs)
418
- cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
419
-
420
- # apply the gradient transformation
421
- state_trace = self.stateful_target.get_state_trace(cache)
422
- rets = self._transform(*self._split_state_vals(state_trace), *args, **kwargs)
423
-
424
- # analyze and return the results
425
- return self._return(rets, state_trace)
426
-
427
-
428
- _doc_of_return = '''
429
-
430
- 1. When ``grad_states`` is None
431
- - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
432
- - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
433
- - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
434
- - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
435
- 2. When ``grad_states`` is not None and ``argnums`` is None
436
- - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
437
- - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
438
- - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
439
- - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
440
- 3. When ``grad_states`` is not None and ``argnums`` is not None
441
- - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
442
- - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
443
- - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
444
- - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
445
-
446
- '''
447
-
448
-
449
- @set_module_as("brainstate.augment")
450
- def grad(
451
- fun: Callable = Missing(),
452
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
453
- argnums: Optional[Union[int, Sequence[int]]] = None,
454
- holomorphic: Optional[bool] = False,
455
- allow_int: Optional[bool] = False,
456
- has_aux: Optional[bool] = None,
457
- return_value: Optional[bool] = False,
458
- unit_aware: bool = False,
459
- check_states: bool = True,
460
- ) -> GradientTransform | Callable[[Callable], GradientTransform]:
461
- """
462
- Compute the gradient of a scalar-valued function with respect to its arguments.
463
-
464
- %s
465
-
466
- Args:
467
- fun: callable. the scalar-valued function to be differentiated.
468
- allow_int: (bool) optional. Whether to allow differentiating with respect to
469
- integer valued inputs. The gradient of an integer input will have a trivial
470
- vector-space dtype (float0). Default False.
471
- holomorphic: (bool) optional. Whether fun is promised to be holomorphic.
472
- Default False.
473
- grad_states: (State, Sequence[State], Dict[str, State]) optional. The variables
474
- in fun to take their gradients.
475
- fun: the scalar-valued function to be differentiated.
476
- argnums: (int or tuple of ints) optional. Specifies which positional
477
- argument(s) to differentiate with respect to.
478
- has_aux: (bool) optional. Indicates whether fun returns a pair where the
479
- first element is considered the output of the mathematical function to be
480
- differentiated and the second element is auxiliary data. Default False.
481
- return_value: (bool) optional. Indicates whether to return the value of the
482
- function along with the gradient. Default False.
483
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
484
- mode. Default False.
485
-
486
- Returns:
487
- A function which computes the gradient of fun. The function takes the same
488
- arguments as `fun`, but returns the gradient instead. If `has_aux` is True,
489
- the function returns a pair where the first element is the gradient and the
490
- second element is the auxiliary data. If `return_value` is True, the function
491
- returns a pair where the first element is the gradient and the second element
492
- is the value of the function.
493
-
494
- """
495
- if isinstance(fun, Missing):
496
- def transform(fun) -> GradientTransform:
497
- return GradientTransform(
498
- target=fun,
499
- transform=u.autograd.grad if unit_aware else jax.grad,
500
- grad_states=grad_states,
501
- argnums=argnums,
502
- return_value=return_value,
503
- has_aux=False if has_aux is None else has_aux,
504
- transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
505
- check_states=check_states
506
- )
507
-
508
- return transform
509
-
510
- return GradientTransform(
511
- target=fun,
512
- transform=u.autograd.grad if unit_aware else jax.grad,
513
- grad_states=grad_states,
514
- argnums=argnums,
515
- return_value=return_value,
516
- has_aux=False if has_aux is None else has_aux,
517
- transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
518
- check_states=check_states
519
- )
520
-
521
-
522
- grad.__doc__ = grad.__doc__ % _doc_of_return
523
-
524
-
525
- @set_module_as("brainstate.augment")
526
- def vector_grad(
527
- func: Callable = Missing(),
528
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
529
- argnums: Optional[Union[int, Sequence[int]]] = None,
530
- return_value: bool = False,
531
- has_aux: Optional[bool] = None,
532
- unit_aware: bool = False,
533
- check_states: bool = True,
534
- ) -> GradientTransform | Callable[[Callable], GradientTransform]:
535
- """Take vector-valued gradients for function ``func``.
536
-
537
- Same as :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`,
538
- the returns in this function are different for different argument settings.
539
-
540
- %s
541
-
542
- Parameters
543
- ----------
544
- func: Callable
545
- Function whose gradient is to be computed.
546
- grad_states : optional, ArrayType, sequence of ArrayType, dict
547
- The variables in ``func`` to take their gradients.
548
- has_aux: optional, bool
549
- Indicates whether ``fun`` returns a pair where the
550
- first element is considered the output of the mathematical function to be
551
- differentiated and the second element is auxiliary data. Default False.
552
- return_value : bool
553
- Whether return the loss value.
554
- argnums: Optional, integer or sequence of integers. Specifies which
555
- positional argument(s) to differentiate with respect to (default ``0``).
556
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
557
- mode. Default False.
558
-
559
- Returns
560
- -------
561
- func : GradientTransform
562
- The vector gradient function.
563
- """
564
-
565
- if isinstance(func, Missing):
566
- def transform(fun) -> GradientTransform:
567
- return GradientTransform(
568
- target=fun,
569
- transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
570
- grad_states=grad_states,
571
- argnums=argnums,
572
- return_value=return_value,
573
- has_aux=False if has_aux is None else has_aux,
574
- check_states=check_states
575
- )
576
-
577
- return transform
578
-
579
- else:
580
- return GradientTransform(
581
- target=func,
582
- transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
583
- grad_states=grad_states,
584
- argnums=argnums,
585
- return_value=return_value,
586
- has_aux=False if has_aux is None else has_aux,
587
- check_states=check_states
588
- )
589
-
590
-
591
- vector_grad.__doc__ = vector_grad.__doc__ % _doc_of_return
592
-
593
-
594
- @set_module_as("brainstate.augment")
595
- def jacrev(
596
- fun: Callable,
597
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
598
- argnums: Optional[Union[int, Sequence[int]]] = None,
599
- has_aux: Optional[bool] = None,
600
- return_value: bool = False,
601
- holomorphic: bool = False,
602
- allow_int: bool = False,
603
- unit_aware: bool = False,
604
- check_states: bool = True,
605
- ) -> GradientTransform:
606
- """
607
- Extending automatic Jacobian (reverse-mode) of ``func`` to classes.
608
-
609
- This function extends the JAX official ``jacrev`` to make automatic jacobian
610
- computation on functions and class functions. Moreover, it supports returning
611
- value ("return_value") and returning auxiliary data ("has_aux").
612
-
613
- %s
614
-
615
-
616
- Parameters
617
- ----------
618
- fun: Callable
619
- Function whose Jacobian is to be computed.
620
- grad_states : optional, ArrayType, sequence of ArrayType, dict
621
- The variables in ``func`` to take their gradients.
622
- has_aux: optional, bool
623
- Indicates whether ``fun`` returns a pair where the
624
- first element is considered the output of the mathematical function to be
625
- differentiated and the second element is auxiliary data. Default False.
626
- return_value : bool
627
- Whether return the loss value.
628
- argnums: Optional, integer or sequence of integers.
629
- Specifies which
630
- positional argument(s) to differentiate with respect to (default ``0``).
631
- holomorphic: Optional, bool.
632
- Indicates whether ``fun`` is promised to be
633
- holomorphic. Default False.
634
- allow_int: Optional, bool.
635
- Whether to allow differentiating with
636
- respect to integer valued inputs. The gradient of an integer input will
637
- have a trivial vector-space dtype (float0). Default False.
638
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
639
- mode. Default False.
640
-
641
- Returns
642
- -------
643
- fun: GradientTransform
644
- The transformed object.
645
- """
646
- return GradientTransform(
647
- target=fun,
648
- transform=_jacrev,
649
- grad_states=grad_states,
650
- argnums=argnums,
651
- return_value=return_value,
652
- has_aux=False if has_aux is None else has_aux,
653
- transform_params=dict(holomorphic=holomorphic,
654
- allow_int=allow_int,
655
- unit_aware=unit_aware, ),
656
- check_states=check_states
657
- )
658
-
659
-
660
- jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
661
-
662
- jacobian = jacrev
663
-
664
-
665
- @set_module_as("brainstate.augment")
666
- def jacfwd(
667
- func: Callable,
668
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
669
- argnums: Optional[Union[int, Sequence[int]]] = None,
670
- has_aux: Optional[bool] = None,
671
- return_value: bool = False,
672
- holomorphic: bool = False,
673
- unit_aware: bool = False,
674
- check_states: bool = True,
675
- ) -> GradientTransform:
676
- """Extending automatic Jacobian (forward-mode) of ``func`` to classes.
677
-
678
- This function extends the JAX official ``jacfwd`` to make automatic jacobian
679
- computation on functions and class functions. Moreover, it supports returning
680
- value ("return_value") and returning auxiliary data ("has_aux").
681
-
682
- %s
683
-
684
- Parameters
685
- ----------
686
- func: Function whose Jacobian is to be computed.
687
- grad_states : optional, ArrayType, sequence of ArrayType, dict
688
- The variables in ``func`` to take their gradients.
689
- has_aux: optional, bool
690
- Indicates whether ``fun`` returns a pair where the
691
- first element is considered the output of the mathematical function to be
692
- differentiated and the second element is auxiliary data. Default False.
693
- return_value : bool
694
- Whether return the loss value.
695
- argnums: Optional, integer or sequence of integers. Specifies which
696
- positional argument(s) to differentiate with respect to (default ``0``).
697
- holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
698
- holomorphic. Default False.
699
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
700
- mode. Default False.
701
-
702
- Returns
703
- -------
704
- obj: GradientTransform
705
- The transformed object.
706
- """
707
-
708
- return GradientTransform(
709
- target=func,
710
- transform=_jacfwd,
711
- grad_states=grad_states,
712
- argnums=argnums,
713
- return_value=return_value,
714
- has_aux=False if has_aux is None else has_aux,
715
- transform_params=dict(holomorphic=holomorphic, unit_aware=unit_aware),
716
- check_states=check_states
717
- )
718
-
719
-
720
- jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
721
-
722
-
723
- @set_module_as("brainstate.augment")
724
- def hessian(
725
- func: Callable,
726
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
727
- argnums: Optional[Union[int, Sequence[int]]] = None,
728
- return_value: bool = False,
729
- holomorphic: bool = False,
730
- has_aux: Optional[bool] = None,
731
- unit_aware: bool = False,
732
- check_states: bool = True,
733
- ) -> GradientTransform:
734
- """
735
- Hessian of ``func`` as a dense array.
736
-
737
- %s
738
-
739
- Parameters
740
- ----------
741
- func : callable
742
- Function whose Hessian is to be computed. Its arguments at positions
743
- specified by ``argnums`` should be arrays, scalars, or standard Python
744
- containers thereof. It should return arrays, scalars, or standard Python
745
- containers thereof.
746
- grad_states : optional, ArrayCollector, sequence of ArrayType
747
- The variables required to compute their gradients.
748
- argnums: Optional, integer or sequence of integers
749
- Specifies which positional argument(s) to differentiate with respect to (default ``0``).
750
- holomorphic : bool
751
- Indicates whether ``fun`` is promised to be holomorphic. Default False.
752
- return_value : bool
753
- Whether return the hessian values.
754
- has_aux: Optional, bool
755
- Indicates whether ``fun`` returns a pair where the first element is considered
756
- the output of the mathematical function to be differentiated and the second
757
- element is auxiliary data. Default False.
758
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
759
- mode. Default False.
760
-
761
- Returns
762
- -------
763
- obj: ObjectTransform
764
- The transformed object.
765
- """
766
- return GradientTransform(
767
- target=func,
768
- transform=u.autograd.hessian if unit_aware else jax.hessian,
769
- grad_states=grad_states,
770
- argnums=argnums,
771
- return_value=return_value,
772
- has_aux=False if has_aux is None else has_aux,
773
- transform_params=dict(holomorphic=holomorphic),
774
- check_states=check_states
775
- )
776
-
777
-
778
- hessian.__doc__ = hessian.__doc__ % _doc_of_return
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Gradient transformations are relatively simple compared to ``vmap`` or ``pmap`` augmentations.
18
+ This is because the gradient transformations do not use the Jaxpr; instead, most of them are
19
+ computed at the Python level. However, there is one exception: the ``checkpoint`` transformation,
20
+ which has been moved into the ``transform`` module.
21
+
22
+ The wrapped gradient transformations here are made possible by using the following ideas:
23
+ 1. All the states to compute the gradients should be known before the transformation.
24
+ They must be provided through the ``grad_states`` argument in any of the gradient transformations.
25
+ 2. The states that have been written in the function should be collected and updated after the function call.
26
+ We record these states during the function call and update them after the function call.
27
+
28
+ """
29
+
30
+ from functools import wraps, partial
31
+ from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
32
+
33
+ import brainunit as u
34
+ import jax
35
+
36
+ from brainstate._state import State
37
+ from brainstate._utils import set_module_as
38
+ from brainstate.transform._make_jaxpr import StatefulFunction
39
+ from brainstate.typing import PyTree, Missing
40
+ from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
41
+
42
+ __all__ = [
43
+ 'GradientTransform', 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
44
+ ]
45
+
46
+ A = TypeVar('A')
47
+ Gradient = PyTree
48
+ LossValue = PyTree
49
+ AuxData = PyTree
50
+
51
+
52
+ def _jacrev(
53
+ fun,
54
+ argnums=0,
55
+ holomorphic=False,
56
+ allow_int=False,
57
+ has_aux=False,
58
+ return_value=False,
59
+ unit_aware=False,
60
+ ):
61
+ @wraps(fun)
62
+ def fun_wrapped(*args, **kwargs):
63
+ if has_aux:
64
+ y, aux = fun(*args, **kwargs)
65
+ if return_value:
66
+ return y, (y, aux)
67
+ else:
68
+ return y, aux
69
+ else:
70
+ y = fun(*args, **kwargs)
71
+ if return_value:
72
+ return y, y
73
+ else:
74
+ return y, None
75
+
76
+ if unit_aware:
77
+ transform = u.autograd.jacrev(fun_wrapped,
78
+ argnums=argnums,
79
+ holomorphic=holomorphic,
80
+ allow_int=allow_int,
81
+ has_aux=True)
82
+ else:
83
+ transform = jax.jacrev(fun_wrapped,
84
+ argnums=argnums,
85
+ holomorphic=holomorphic,
86
+ allow_int=allow_int,
87
+ has_aux=True)
88
+
89
+ @wraps(fun)
90
+ def jacfun(*args, **kwargs):
91
+ jac, aux = transform(*args, **kwargs)
92
+ if return_value:
93
+ return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
94
+ else:
95
+ return (jac, aux) if has_aux else jac
96
+
97
+ return jacfun
98
+
99
+
100
+ def _jacfwd(
101
+ fun,
102
+ argnums=0,
103
+ holomorphic=False,
104
+ has_aux=False,
105
+ return_value=False,
106
+ unit_aware=False,
107
+ ):
108
+ @wraps(fun)
109
+ def fun_wrapped(*args, **kwargs):
110
+ if has_aux:
111
+ y, aux = fun(*args, **kwargs)
112
+ if return_value:
113
+ return y, (y, aux)
114
+ else:
115
+ return y, aux
116
+ else:
117
+ y = fun(*args, **kwargs)
118
+ if return_value:
119
+ return y, y
120
+ else:
121
+ return y, None
122
+
123
+ if unit_aware:
124
+ transform = u.autograd.jacfwd(fun_wrapped,
125
+ argnums=argnums,
126
+ holomorphic=holomorphic,
127
+ has_aux=True)
128
+ else:
129
+ transform = jax.jacfwd(fun_wrapped,
130
+ argnums=argnums,
131
+ holomorphic=holomorphic,
132
+ has_aux=True)
133
+
134
+ @wraps(fun)
135
+ def jacfun(*args, **kwargs):
136
+ jac, aux = transform(*args, **kwargs)
137
+ if return_value:
138
+ return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
139
+ else:
140
+ return (jac, aux) if has_aux else jac
141
+
142
+ return jacfun
143
+
144
+
145
+ TransformFn = Callable
146
+
147
+
148
+ class GradientTransform(PrettyRepr):
149
+ """
150
+ Automatic Differentiation Transformations for the ``State`` system.
151
+
152
+ This class implements gradient transformations for functions that operate on State objects.
153
+ It allows for flexible configuration of gradient computation with respect to specified states
154
+ and function arguments.
155
+
156
+ Parameters
157
+ ----------
158
+ target : callable
159
+ The function to be transformed.
160
+ transform : callable
161
+ The transformation function to apply.
162
+ grad_states : State, sequence of State, or dict of State, optional
163
+ States to compute gradients for.
164
+ argnums : int or sequence of int, optional
165
+ Indices of arguments to differentiate with respect to.
166
+ return_value : bool, default False
167
+ Whether to return the function's value along with gradients.
168
+ has_aux : bool, default False
169
+ Whether the function returns auxiliary data.
170
+ transform_params : dict, optional
171
+ Additional parameters for the transformation function.
172
+ check_states : bool, default True
173
+ Whether to check that all grad_states are found in the function.
174
+
175
+ Attributes
176
+ ----------
177
+ target : callable
178
+ The function to be transformed.
179
+ stateful_target : StatefulFunction
180
+ A wrapper around the target function for state management.
181
+ raw_argnums : int, sequence of int, or None
182
+ The original argnums specified by the user.
183
+ true_argnums : int or tuple of int
184
+ The adjusted argnums used internally.
185
+ return_value : bool
186
+ Whether to return the function's value along with gradients.
187
+ has_aux : bool
188
+ Whether the function returns auxiliary data.
189
+
190
+ Examples
191
+ --------
192
+ Basic gradient computation with states:
193
+
194
+ .. code-block:: python
195
+
196
+ >>> import brainstate
197
+ >>> import jax.numpy as jnp
198
+ >>>
199
+ >>> # Create states
200
+ >>> weight = brainstate.State(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
201
+ >>> bias = brainstate.State(jnp.array([0.5, -0.5]))
202
+ >>>
203
+ >>> def loss_fn(x):
204
+ ... y = x @ weight.value + bias.value
205
+ ... return jnp.sum(y ** 2)
206
+ >>>
207
+ >>> # Create gradient transform
208
+ >>> grad_transform = brainstate.transform.GradientTransform(
209
+ ... target=loss_fn,
210
+ ... transform=jax.grad,
211
+ ... grad_states=[weight, bias]
212
+ ... )
213
+ >>>
214
+ >>> # Compute gradients
215
+ >>> x = jnp.array([1.0, 2.0])
216
+ >>> grads = grad_transform(x)
217
+
218
+ With function arguments and auxiliary data:
219
+
220
+ .. code-block:: python
221
+
222
+ >>> def loss_fn_with_aux(x, scale):
223
+ ... y = x @ weight.value + bias.value
224
+ ... loss = jnp.sum((y * scale) ** 2)
225
+ ... return loss, {"predictions": y, "scale": scale}
226
+ >>>
227
+ >>> grad_transform = brainstate.transform.GradientTransform(
228
+ ... target=loss_fn_with_aux,
229
+ ... transform=jax.grad,
230
+ ... grad_states=[weight, bias],
231
+ ... argnums=[0, 1], # gradient w.r.t x and scale
232
+ ... has_aux=True,
233
+ ... return_value=True
234
+ ... )
235
+ >>>
236
+ >>> grads, loss_value, aux_data = grad_transform(x, 2.0)
237
+ """
238
+
239
+ __module__ = "brainstate.transform"
240
+
241
+ def __init__(
242
+ self,
243
+ target: Callable,
244
+ transform: TransformFn,
245
+ grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
246
+ argnums: Optional[Union[int, Sequence[int]]] = None,
247
+ return_value: bool = False,
248
+ has_aux: bool = False,
249
+ transform_params: Optional[Dict[str, Any]] = None,
250
+ check_states: bool = True,
251
+ ):
252
+ """
253
+ Initialize a ``GradientTransform`` instance.
254
+
255
+ Parameters
256
+ ----------
257
+ target : callable
258
+ The function to be transformed.
259
+ transform : callable
260
+ The transformation function to apply.
261
+ grad_states : State, sequence of State, or dict of State, optional
262
+ States to compute gradients for.
263
+ argnums : int or sequence of int, optional
264
+ Indices of arguments to differentiate with respect to.
265
+ return_value : bool, default False
266
+ Whether to return the function's value along with gradients.
267
+ has_aux : bool, default False
268
+ Whether the function returns auxiliary data.
269
+ transform_params : dict, optional
270
+ Additional parameters for the transformation function.
271
+ check_states : bool, default True
272
+ Whether to check that all grad_states are found in the function.
273
+
274
+ Raises
275
+ ------
276
+ TypeError
277
+ If any grad_states are not State instances.
278
+ """
279
+ # gradient variables
280
+ if isinstance(grad_states, dict):
281
+ grad_states = {k: v for k, v in grad_states.items()}
282
+ self._grad_states, self._grad_tree = jax.tree.flatten(grad_states, is_leaf=lambda x: isinstance(x, State))
283
+ self._grad_state_ids = [id(v) for v in self._grad_states]
284
+ self._grad_id_to_state = {id(v): v for v in self._grad_states}
285
+ if any(not isinstance(v, State) for v in self._grad_states):
286
+ raise TypeError("All grad_states must be State instances.")
287
+ self.check_states = check_states
288
+
289
+ # parameters
290
+ if argnums is None and len(self._grad_states) == 0:
291
+ argnums = 0
292
+ if argnums is None:
293
+ assert len(self._grad_states) > 0
294
+ _argnums = 0
295
+ elif isinstance(argnums, int):
296
+ _argnums = (0, argnums + 2) if len(self._grad_states) > 0 else (argnums + 2)
297
+ else:
298
+ assert isinstance(argnums, (tuple, list))
299
+ _argnums = tuple(a + 2 for a in argnums)
300
+ if len(self._grad_states) > 0:
301
+ _argnums = (0,) + _argnums
302
+ self.raw_argnums = argnums
303
+ self.true_argnums = _argnums
304
+ self.return_value = return_value
305
+ self.has_aux = has_aux
306
+
307
+ # target
308
+ assert callable(target), "The target should be a callable object."
309
+ self.target = target
310
+ self.stateful_target = StatefulFunction(target, name='gradient', return_only_write=False)
311
+
312
+ # transform
313
+ grad_setting = dict() if transform_params is None else transform_params
314
+ if self.has_aux:
315
+ self._transform = transform(self._fun_with_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
316
+ else:
317
+ self._transform = transform(self._fun_without_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
318
+
319
+ def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
320
+ yield PrettyType(self.__class__.__name__)
321
+ yield PrettyAttr("target", self.target)
322
+ yield PrettyAttr("grad_states", self._grad_states)
323
+ yield PrettyAttr("grad_tree", self._grad_tree)
324
+ yield PrettyAttr("argnums", self.raw_argnums)
325
+ yield PrettyAttr("return_value", self.return_value)
326
+ yield PrettyAttr("has_aux", self.has_aux)
327
+ yield PrettyAttr("transform", self._transform)
328
+
329
+ def _split_state_vals(self, state_trace):
330
+ """
331
+ Split state values into gradient and non-gradient states.
332
+
333
+ Args:
334
+ state_trace: The state trace containing all states.
335
+
336
+ Returns:
337
+ Tuple[Dict, Dict]: A tuple of dictionaries containing gradient and non-gradient state values.
338
+ """
339
+ grad_vals = dict()
340
+ other_vals = dict()
341
+ all_ids = set(self._grad_state_ids)
342
+ for st in state_trace.states:
343
+ id_ = id(st)
344
+ if id_ in all_ids:
345
+ grad_vals[id_] = st.value
346
+ all_ids.remove(id_)
347
+ else:
348
+ other_vals[id_] = st.value
349
+ if len(all_ids):
350
+ if self.check_states:
351
+ err = f"Some states are not found in the state trace when performing gradient transformations.\n "
352
+ for i, id_ in enumerate(all_ids):
353
+ st = self._grad_id_to_state[id_]
354
+ st.raise_error_with_source_info(ValueError(err + str(st)))
355
+ else:
356
+ id2state = {id(st): st for st in self._grad_states}
357
+ for id_ in all_ids:
358
+ grad_vals[id_] = id2state[id_].value
359
+
360
+ return grad_vals, other_vals
361
+
362
+ def _merge_state_vals(self, grad_vals: Dict, other_vals: Dict, state_trace):
363
+ """
364
+ Merge gradient and non-gradient state values back into a single list.
365
+
366
+ Args:
367
+ grad_vals (Dict): Dictionary of gradient state values.
368
+ other_vals (Dict): Dictionary of non-gradient state values.
369
+ state_trace: The state trace containing all states.
370
+
371
+ Returns:
372
+ List: A list of merged state values.
373
+ """
374
+ res = []
375
+ for st in state_trace.states:
376
+ id_ = id(st)
377
+ if id_ in self._grad_state_ids:
378
+ res.append(grad_vals[id_])
379
+ else:
380
+ res.append(other_vals[id_])
381
+ return res
382
+
383
+ def _call_target(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
384
+ """
385
+ Call the target function with the given state values and arguments.
386
+
387
+ Args:
388
+ grad_vals (Dict): Dictionary of gradient state values.
389
+ other_vals (Dict): Dictionary of non-gradient state values.
390
+ *args: Positional arguments to pass to the target function.
391
+ **kwargs: Keyword arguments to pass to the target function.
392
+
393
+ Returns:
394
+ Tuple: A tuple containing updated state values and the function output.
395
+ """
396
+ state_trace = self.stateful_target.get_state_trace(*args, **kwargs, compile_if_miss=True)
397
+ state_vals = self._merge_state_vals(grad_vals, other_vals, state_trace)
398
+ state_vals, out = self.stateful_target.jaxpr_call(state_vals, *args, **kwargs)
399
+ return state_vals, out
400
+
401
+ def _fun_with_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
402
+ """
403
+ Wrapper function for target functions that return auxiliary data.
404
+
405
+ Args:
406
+ grad_vals (Dict): Dictionary of gradient state values.
407
+ other_vals (Dict): Dictionary of non-gradient state values.
408
+ *args: Positional arguments to pass to the target function.
409
+ **kwargs: Keyword arguments to pass to the target function.
410
+
411
+ Returns:
412
+ Tuple: A tuple containing the primary output and a tuple of (all outputs, updated state values).
413
+ """
414
+ # Users should return the auxiliary data like::
415
+ # >>> # 1. example of return one data
416
+ # >>> return scalar_loss, data
417
+ # >>> # 2. example of return multiple data
418
+ # >>> return scalar_loss, (data1, data2, ...)
419
+ state_vals, outs = self._call_target(grad_vals, other_vals, *args, **kwargs)
420
+ return outs[0], (outs, state_vals)
421
+
422
+ def _fun_without_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
423
+ """
424
+ Wrapper function for target functions that do not return auxiliary data.
425
+
426
+ Args:
427
+ grad_vals (Dict): Dictionary of gradient state values.
428
+ other_vals (Dict): Dictionary of non-gradient state values.
429
+ *args: Positional arguments to pass to the target function.
430
+ **kwargs: Keyword arguments to pass to the target function.
431
+
432
+ Returns:
433
+ Tuple: A tuple containing the output and a tuple of (output, updated state values).
434
+ """
435
+ state_vals, out = self._call_target(grad_vals, other_vals, *args, **kwargs)
436
+ return out, (out, state_vals)
437
+
438
+ def _return(self, rets, state_trace):
439
+ """
440
+ Process and format the return values from the gradient computation.
441
+
442
+ Args:
443
+ rets: The raw results from the gradient computation.
444
+ state_trace: The state trace containing all states.
445
+
446
+ Returns:
447
+ Union[Gradient, Tuple]: The processed gradient results, potentially including function value and/or auxiliary data.
448
+ """
449
+ # unpack the return values
450
+ grads, (outputs, new_state_vals) = rets
451
+
452
+ # assign new values to the states
453
+ state_trace.assign_state_vals(new_state_vals)
454
+
455
+ # check returned grads
456
+ if len(self._grad_states) > 0:
457
+ grads_of_states = grads if self.raw_argnums is None else grads[0]
458
+ grads_of_states = [grads_of_states[st_id] for st_id in self._grad_state_ids]
459
+ if self.raw_argnums is None:
460
+ grads = self._grad_tree.unflatten(grads_of_states)
461
+ else:
462
+ var_grads = self._grad_tree.unflatten(grads_of_states)
463
+ arg_grads = grads[1] if isinstance(self.raw_argnums, int) else grads[1:]
464
+ grads = (var_grads, arg_grads)
465
+
466
+ # check returned value
467
+ if self.return_value:
468
+ # check aux
469
+ if self.has_aux:
470
+ return grads, outputs[0], outputs[1]
471
+ else:
472
+ return grads, outputs
473
+ else:
474
+ # check aux
475
+ if self.has_aux:
476
+ return grads, outputs[1]
477
+ else:
478
+ return grads
479
+
480
+ def __call__(
481
+ self, *args, **kwargs
482
+ ) -> (
483
+ Gradient |
484
+ Tuple[Gradient, LossValue] |
485
+ Tuple[Gradient, AuxData] |
486
+ Tuple[Gradient, LossValue, AuxData]
487
+ ):
488
+ """
489
+ Compute gradients by calling the transformed function.
490
+
491
+ Parameters
492
+ ----------
493
+ *args
494
+ Positional arguments to pass to the target function.
495
+ **kwargs
496
+ Keyword arguments to pass to the target function.
497
+
498
+ Returns
499
+ -------
500
+ Gradient or tuple
501
+ The computed gradients, potentially including function value and/or auxiliary data.
502
+ The exact return structure depends on the settings of return_value and has_aux.
503
+ """
504
+
505
+ # TODO: support jax.disable_jit()
506
+
507
+ # compute the model
508
+ self.stateful_target.make_jaxpr(*args, **kwargs)
509
+ cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
510
+
511
+ # apply the gradient transformation
512
+ state_trace = self.stateful_target.get_state_trace_by_cache(cache)
513
+ rets = self._transform(*self._split_state_vals(state_trace), *args, **kwargs)
514
+
515
+ # analyze and return the results
516
+ return self._return(rets, state_trace)
517
+
518
+
519
+ @set_module_as("brainstate.transform")
520
+ def grad(
521
+ fun: Callable = Missing(),
522
+ grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
523
+ argnums: Optional[Union[int, Sequence[int]]] = None,
524
+ holomorphic: Optional[bool] = False,
525
+ allow_int: Optional[bool] = False,
526
+ has_aux: Optional[bool] = None,
527
+ return_value: Optional[bool] = False,
528
+ unit_aware: bool = False,
529
+ check_states: bool = True,
530
+ ) -> GradientTransform | Callable[[Callable], GradientTransform]:
531
+ """
532
+ Compute the gradient of a scalar-valued function with respect to its arguments.
533
+
534
+
+     1. When ``grad_states`` is None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+        - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+     2. When ``grad_states`` is not None and ``argnums`` is None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+        - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+     3. When ``grad_states`` is not None and ``argnums`` is not None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+        - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+
+     Parameters
+     ----------
+     fun : callable, optional
+         The scalar-valued function to be differentiated.
+     grad_states : State, sequence of State, or dict of State, optional
+         The :class:`State` variables in ``fun`` whose gradients are computed.
+     argnums : int or sequence of int, optional
+         Specifies which positional argument(s) to differentiate with respect to.
+     holomorphic : bool, default False
+         Whether ``fun`` is promised to be holomorphic.
+     allow_int : bool, default False
+         Whether to allow differentiating with respect to integer valued inputs.
+         The gradient of an integer input will have a trivial vector-space dtype (float0).
+     has_aux : bool, optional
+         Indicates whether ``fun`` returns a pair where the first element is considered
+         the output of the mathematical function to be differentiated and the second
+         element is auxiliary data.
+     return_value : bool, default False
+         Indicates whether to return the value of the function along with the gradient.
+     unit_aware : bool, default False
+         Whether to compute the gradient in the unit-aware mode.
+     check_states : bool, default True
+         Whether to check that all ``grad_states`` are found in the function.
+
+     Returns
+     -------
+     GradientTransform or callable
+         A function that computes the gradient of ``fun``. It takes the same arguments
+         as ``fun`` but returns the gradient instead. If ``has_aux`` is True, the gradient
+         is paired with the auxiliary data; if ``return_value`` is True, it is paired with
+         the value of the function. If ``fun`` is not given, a decorator is returned that
+         wraps a function into a :class:`GradientTransform`.
+
+     Examples
+     --------
+     Basic gradient computation:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>>
+         >>> # Simple function gradient
+         >>> def f(x):
+         ...     return jnp.sum(x ** 2)
+         >>>
+         >>> grad_f = brainstate.transform.grad(f)
+         >>> x = jnp.array([1.0, 2.0, 3.0])
+         >>> gradient = grad_f(x)
+
+     Gradient with respect to states:
+
+     .. code-block:: python
+
+         >>> # Create states
+         >>> weight = brainstate.State(jnp.array([1.0, 2.0]))
+         >>> bias = brainstate.State(jnp.array([0.5]))
+         >>>
+         >>> def loss_fn(x):
+         ...     # reduce to a scalar so the output can be differentiated with ``grad``
+         ...     prediction = jnp.dot(x, weight.value) + bias.value
+         ...     return jnp.sum(prediction ** 2)
+         >>>
+         >>> # Compute gradients with respect to states
+         >>> grad_fn = brainstate.transform.grad(loss_fn, grad_states=[weight, bias])
+         >>> x = jnp.array([1.0, 2.0])
+         >>> state_grads = grad_fn(x)
+
+     With auxiliary data and return value:
+
+     .. code-block:: python
+
+         >>> def loss_with_aux(x):
+         ...     prediction = jnp.dot(x, weight.value) + bias.value
+         ...     loss = jnp.sum(prediction ** 2)
+         ...     return loss, {"prediction": prediction}
+         >>>
+         >>> grad_fn = brainstate.transform.grad(
+         ...     loss_with_aux,
+         ...     grad_states=[weight, bias],
+         ...     has_aux=True,
+         ...     return_value=True
+         ... )
+         >>> grads, loss_value, aux_data = grad_fn(x)
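+
+     Used as a decorator (a minimal sketch; when ``fun`` is omitted, ``grad`` returns a
+     decorator that wraps the decorated function into a ``GradientTransform``; it reuses
+     the hypothetical ``weight``, ``bias``, and ``x`` from the examples above):
+
+     .. code-block:: python
+
+         >>> @brainstate.transform.grad(grad_states=[weight, bias])
+         ... def decorated_loss(x):
+         ...     prediction = jnp.dot(x, weight.value) + bias.value
+         ...     return jnp.sum(prediction ** 2)
+         >>>
+         >>> state_grads = decorated_loss(x)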
+     """
+     if isinstance(fun, Missing):
+         def transform(fun) -> GradientTransform:
+             return GradientTransform(
+                 target=fun,
+                 transform=u.autograd.grad if unit_aware else jax.grad,
+                 grad_states=grad_states,
+                 argnums=argnums,
+                 return_value=return_value,
+                 has_aux=False if has_aux is None else has_aux,
+                 transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
+                 check_states=check_states
+             )
+
+         return transform
+
+     return GradientTransform(
+         target=fun,
+         transform=u.autograd.grad if unit_aware else jax.grad,
+         grad_states=grad_states,
+         argnums=argnums,
+         return_value=return_value,
+         has_aux=False if has_aux is None else has_aux,
+         transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
+         check_states=check_states
+     )
+
+
+ @set_module_as("brainstate.transform")
+ def vector_grad(
+     func: Callable = Missing(),
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     return_value: bool = False,
+     has_aux: Optional[bool] = None,
+     unit_aware: bool = False,
+     check_states: bool = True,
+ ) -> GradientTransform | Callable[[Callable], GradientTransform]:
+     """
+     Take vector-valued gradients of the function ``func``.
+
+     As with :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`, the return
+     structure depends on the argument settings:
+
+     1. When ``grad_states`` is None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+        - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+     2. When ``grad_states`` is not None and ``argnums`` is None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+        - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+     3. When ``grad_states`` is not None and ``argnums`` is not None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+        - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+
+     Parameters
+     ----------
+     func : callable, optional
+         Function whose gradient is to be computed.
+     grad_states : State, sequence of State, or dict of State, optional
+         The :class:`State` variables in ``func`` whose gradients are computed.
+     argnums : int or sequence of int, optional
+         Specifies which positional argument(s) to differentiate with respect to.
+     return_value : bool, default False
+         Whether to return the value of ``func`` along with the gradients.
+     has_aux : bool, optional
+         Indicates whether ``func`` returns a pair where the first element is considered
+         the output of the mathematical function to be differentiated and the second
+         element is auxiliary data.
+     unit_aware : bool, default False
+         Whether to compute the gradient in the unit-aware mode.
+     check_states : bool, default True
+         Whether to check that all ``grad_states`` are found in the function.
+
+     Returns
+     -------
+     GradientTransform or callable
+         The vector gradient function, or a decorator when ``func`` is not given.
+
+     Examples
+     --------
+     Basic vector gradient computation:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>>
+         >>> # Vector-valued function
+         >>> def f(x):
+         ...     return jnp.array([x[0]**2, x[1]**3, x[0]*x[1]])
+         >>>
+         >>> vector_grad_f = brainstate.transform.vector_grad(f)
+         >>> x = jnp.array([2.0, 3.0])
+         >>> gradients = vector_grad_f(x)  # same shape as x: (2,)
+
+     With states:
+
+     .. code-block:: python
+
+         >>> params = brainstate.State(jnp.array([1.0, 2.0]))
+         >>>
+         >>> def model(x):
+         ...     return jnp.array([
+         ...         x * params.value[0],
+         ...         x**2 * params.value[1]
+         ...     ])
+         >>>
+         >>> vector_grad_fn = brainstate.transform.vector_grad(
+         ...     model, grad_states=[params]
+         ... )
+         >>> x = 3.0
+         >>> param_grads = vector_grad_fn(x)
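+
+     Used as a decorator (a minimal sketch; when ``func`` is omitted, ``vector_grad``
+     returns a decorator that wraps the decorated function into a ``GradientTransform``;
+     it reuses the hypothetical ``params`` state from the example above):
+
+     .. code-block:: python
+
+         >>> @brainstate.transform.vector_grad(grad_states=[params])
+         ... def decorated_model(x):
+         ...     return jnp.array([x * params.value[0], x**2 * params.value[1]])
+         >>>
+         >>> param_grads = decorated_model(3.0)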
+     """
+
+     if isinstance(func, Missing):
+         def transform(fun) -> GradientTransform:
+             return GradientTransform(
+                 target=fun,
+                 transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
+                 grad_states=grad_states,
+                 argnums=argnums,
+                 return_value=return_value,
+                 has_aux=False if has_aux is None else has_aux,
+                 check_states=check_states
+             )
+
+         return transform
+
+     else:
+         return GradientTransform(
+             target=func,
+             transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
+             grad_states=grad_states,
+             argnums=argnums,
+             return_value=return_value,
+             has_aux=False if has_aux is None else has_aux,
+             check_states=check_states
+         )
+
+
+ @set_module_as("brainstate.transform")
+ def jacrev(
+     fun: Callable,
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     has_aux: Optional[bool] = None,
+     return_value: bool = False,
+     holomorphic: bool = False,
+     allow_int: bool = False,
+     unit_aware: bool = False,
+     check_states: bool = True,
+ ) -> GradientTransform:
+     """
+     Compute the Jacobian of ``fun`` using reverse-mode automatic differentiation.
+
+     This function extends the JAX official ``jacrev`` to support automatic Jacobian
+     computation for both plain functions and class methods with :class:`State` variables.
+     Moreover, it supports returning the function value (``return_value``) and returning
+     auxiliary data (``has_aux``).
+
+     The return structure depends on the argument settings:
+
+     1. When ``grad_states`` is None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+        - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+     2. When ``grad_states`` is not None and ``argnums`` is None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+        - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+     3. When ``grad_states`` is not None and ``argnums`` is not None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+        - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+
+     Parameters
+     ----------
+     fun : callable
+         Function whose Jacobian is to be computed.
+     grad_states : State, sequence of State, or dict of State, optional
+         The :class:`State` variables in ``fun`` whose gradients are computed.
+     has_aux : bool, optional
+         Indicates whether ``fun`` returns a pair where the first element is considered
+         the output of the mathematical function to be differentiated and the second
+         element is auxiliary data. Default False.
+     return_value : bool, default False
+         Whether to return the value of ``fun`` along with the Jacobian.
+     argnums : int or sequence of int, optional
+         Specifies which positional argument(s) to differentiate with respect to (default ``0``).
+     holomorphic : bool, default False
+         Indicates whether ``fun`` is promised to be holomorphic.
+     allow_int : bool, default False
+         Whether to allow differentiating with respect to integer valued inputs. The
+         gradient of an integer input will have a trivial vector-space dtype (float0).
+     unit_aware : bool, default False
+         Whether to compute the Jacobian in the unit-aware mode.
+     check_states : bool, default True
+         Whether to check that all ``grad_states`` are found in the function.
+
+     Returns
+     -------
+     GradientTransform
+         The transformed object.
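+
+     Examples
+     --------
+     A minimal sketch, assuming a small vector-valued function of one array argument:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>>
+         >>> def f(x):
+         ...     return jnp.array([x[0] ** 2, x[0] * x[1]])
+         >>>
+         >>> jac_f = brainstate.transform.jacrev(f)
+         >>> jac_f(jnp.array([2.0, 3.0]))  # Jacobian with shape (2, 2)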
+     """
+     return GradientTransform(
+         target=fun,
+         transform=_jacrev,
+         grad_states=grad_states,
+         argnums=argnums,
+         return_value=return_value,
+         has_aux=False if has_aux is None else has_aux,
+         transform_params=dict(holomorphic=holomorphic,
+                               allow_int=allow_int,
+                               unit_aware=unit_aware),
+         check_states=check_states
+     )
+
+
+ jacobian = jacrev
+
+
+ @set_module_as("brainstate.transform")
+ def jacfwd(
+     func: Callable,
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     has_aux: Optional[bool] = None,
+     return_value: bool = False,
+     holomorphic: bool = False,
+     unit_aware: bool = False,
+     check_states: bool = True,
+ ) -> GradientTransform:
+     """
+     Compute the Jacobian of ``func`` using forward-mode automatic differentiation.
+
+     This function extends the JAX official ``jacfwd`` to support automatic Jacobian
+     computation for both plain functions and class methods with :class:`State` variables.
+     Moreover, it supports returning the function value (``return_value``) and returning
+     auxiliary data (``has_aux``).
+
+     The return structure depends on the argument settings:
+
+     1. When ``grad_states`` is None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+        - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+     2. When ``grad_states`` is not None and ``argnums`` is None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+        - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+     3. When ``grad_states`` is not None and ``argnums`` is not None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+        - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+
+     Parameters
+     ----------
+     func : callable
+         Function whose Jacobian is to be computed.
+     grad_states : State, sequence of State, or dict of State, optional
+         The :class:`State` variables in ``func`` whose gradients are computed.
+     has_aux : bool, optional
+         Indicates whether ``func`` returns a pair where the first element is considered
+         the output of the mathematical function to be differentiated and the second
+         element is auxiliary data. Default False.
+     return_value : bool, default False
+         Whether to return the value of ``func`` along with the Jacobian.
+     argnums : int or sequence of int, optional
+         Specifies which positional argument(s) to differentiate with respect to (default ``0``).
+     holomorphic : bool, default False
+         Indicates whether ``func`` is promised to be holomorphic.
+     unit_aware : bool, default False
+         Whether to compute the Jacobian in the unit-aware mode.
+     check_states : bool, default True
+         Whether to check that all ``grad_states`` are found in the function.
+
+     Returns
+     -------
+     GradientTransform
+         The transformed object.
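+
+     Examples
+     --------
+     A minimal sketch, assuming a small vector-valued function of one array argument:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>>
+         >>> def f(x):
+         ...     return jnp.stack([jnp.sum(x ** 2), jnp.prod(x), x[0] - x[1]])
+         >>>
+         >>> jac_f = brainstate.transform.jacfwd(f)
+         >>> jac_f(jnp.array([2.0, 3.0]))  # Jacobian with shape (3, 2)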
+     """
+
+     return GradientTransform(
+         target=func,
+         transform=_jacfwd,
+         grad_states=grad_states,
+         argnums=argnums,
+         return_value=return_value,
+         has_aux=False if has_aux is None else has_aux,
+         transform_params=dict(holomorphic=holomorphic, unit_aware=unit_aware),
+         check_states=check_states
+     )
+
+
+ @set_module_as("brainstate.transform")
+ def hessian(
+     func: Callable,
+     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+     argnums: Optional[Union[int, Sequence[int]]] = None,
+     return_value: bool = False,
+     holomorphic: bool = False,
+     has_aux: Optional[bool] = None,
+     unit_aware: bool = False,
+     check_states: bool = True,
+ ) -> GradientTransform:
+     """
+     Compute the Hessian of ``func`` as a dense array.
+
+     The return structure depends on the argument settings:
+
+     1. When ``grad_states`` is None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+        - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+     2. When ``grad_states`` is not None and ``argnums`` is None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+        - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+     3. When ``grad_states`` is not None and ``argnums`` is not None
+
+        - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+        - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+        - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+        - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+
+     Parameters
+     ----------
+     func : callable
+         Function whose Hessian is to be computed. Its arguments at positions
+         specified by ``argnums`` should be arrays, scalars, or standard Python
+         containers thereof. It should return arrays, scalars, or standard Python
+         containers thereof.
+     grad_states : State, sequence of State, or dict of State, optional
+         The :class:`State` variables whose Hessians are computed.
+     argnums : int or sequence of int, optional
+         Specifies which positional argument(s) to differentiate with respect to (default ``0``).
+     holomorphic : bool, default False
+         Indicates whether ``func`` is promised to be holomorphic.
+     return_value : bool, default False
+         Whether to return the value of ``func`` along with the Hessian.
+     has_aux : bool, optional
+         Indicates whether ``func`` returns a pair where the first element is considered
+         the output of the mathematical function to be differentiated and the second
+         element is auxiliary data. Default False.
+     unit_aware : bool, default False
+         Whether to compute the Hessian in the unit-aware mode.
+     check_states : bool, default True
+         Whether to check that all ``grad_states`` are found in the function.
+
+     Returns
+     -------
+     GradientTransform
+         The transformed object.
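+
+     Examples
+     --------
+     A minimal sketch, assuming a scalar-valued function of one array argument:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>>
+         >>> def f(x):
+         ...     return jnp.sum(x ** 3)
+         >>>
+         >>> hess_f = brainstate.transform.hessian(f)
+         >>> hess_f(jnp.array([1.0, 2.0]))  # dense Hessian with shape (2, 2)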
+     """
+     return GradientTransform(
+         target=func,
+         transform=u.autograd.hessian if unit_aware else jax.hessian,
+         grad_states=grad_states,
+         argnums=argnums,
+         return_value=return_value,
+         has_aux=False if has_aux is None else has_aux,
+         transform_params=dict(holomorphic=holomorphic),
+         check_states=check_states
+     )