brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,778 +1,778 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Gradient transformations are relatively simple compared to ``vmap`` or ``pmap`` augmentations.
18
- This is because the gradient transformations do not use the Jaxpr; instead, most of them are
19
- computed at the Python level. However, there is one exception, the ``checkpoint`` transformation,
20
- which has been moved into the ``compile`` module.
21
-
22
- The wrapped gradient transformations here are made possible by using the following ideas:
23
- 1. All the states needed to compute the gradients should be known before the transformation.
24
- These must be provided through the ``grad_states`` argument of any gradient transformation.
25
- 2. The states that have been written in the function should be collected and updated after the function call.
26
- We record these states during the function call and update them after the call.
27
-
28
- """
29
-
30
- from functools import wraps, partial
31
- from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
32
-
33
- import brainunit as u
34
- import jax
35
-
36
- from brainstate._state import State
37
- from brainstate._utils import set_module_as
38
- from brainstate.compile._make_jaxpr import StatefulFunction
39
- from brainstate.typing import PyTree, Missing
40
- from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
41
-
42
- __all__ = [
43
- 'GradientTransform', 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
44
- ]
45
-
46
- A = TypeVar('A')
47
- Gradient = PyTree
48
- LossValue = PyTree
49
- AuxData = PyTree
50
-
51
-
52
- def _jacrev(
53
- fun,
54
- argnums=0,
55
- holomorphic=False,
56
- allow_int=False,
57
- has_aux=False,
58
- return_value=False,
59
- unit_aware=False,
60
- ):
61
- @wraps(fun)
62
- def fun_wrapped(*args, **kwargs):
63
- if has_aux:
64
- y, aux = fun(*args, **kwargs)
65
- if return_value:
66
- return y, (y, aux)
67
- else:
68
- return y, aux
69
- else:
70
- y = fun(*args, **kwargs)
71
- if return_value:
72
- return y, y
73
- else:
74
- return y, None
75
-
76
- if unit_aware:
77
- transform = u.autograd.jacrev(fun_wrapped,
78
- argnums=argnums,
79
- holomorphic=holomorphic,
80
- allow_int=allow_int,
81
- has_aux=True)
82
- else:
83
- transform = jax.jacrev(fun_wrapped,
84
- argnums=argnums,
85
- holomorphic=holomorphic,
86
- allow_int=allow_int,
87
- has_aux=True)
88
-
89
- @wraps(fun)
90
- def jacfun(*args, **kwargs):
91
- jac, aux = transform(*args, **kwargs)
92
- if return_value:
93
- return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
94
- else:
95
- return (jac, aux) if has_aux else jac
96
-
97
- return jacfun
98
-
99
-
100
- def _jacfwd(
101
- fun,
102
- argnums=0,
103
- holomorphic=False,
104
- has_aux=False,
105
- return_value=False,
106
- unit_aware=False,
107
- ):
108
- @wraps(fun)
109
- def fun_wrapped(*args, **kwargs):
110
- if has_aux:
111
- y, aux = fun(*args, **kwargs)
112
- if return_value:
113
- return y, (y, aux)
114
- else:
115
- return y, aux
116
- else:
117
- y = fun(*args, **kwargs)
118
- if return_value:
119
- return y, y
120
- else:
121
- return y, None
122
-
123
- if unit_aware:
124
- transform = u.autograd.jacfwd(fun_wrapped,
125
- argnums=argnums,
126
- holomorphic=holomorphic,
127
- has_aux=True)
128
- else:
129
- transform = jax.jacfwd(fun_wrapped,
130
- argnums=argnums,
131
- holomorphic=holomorphic,
132
- has_aux=True)
133
-
134
- @wraps(fun)
135
- def jacfun(*args, **kwargs):
136
- jac, aux = transform(*args, **kwargs)
137
- if return_value:
138
- return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
139
- else:
140
- return (jac, aux) if has_aux else jac
141
-
142
- return jacfun
143
-
144
-
145
- TransformFn = Callable
146
-
147
-
148
- class GradientTransform(PrettyRepr):
149
- """
150
- Automatic Differentiation Transformations for the ``State`` system.
151
-
152
- This class implements gradient transformations for functions that operate on State objects.
153
- It allows for flexible configuration of gradient computation with respect to specified states
154
- and function arguments.
155
-
156
- Attributes:
157
- target (Callable): The function to be transformed.
158
- stateful_target (StatefulFunction): A wrapper around the target function for state management.
159
- raw_argnums (Optional[Union[int, Sequence[int]]]): The original argnums specified by the user.
160
- true_argnums (Union[int, Tuple[int, ...]]): The adjusted argnums used internally.
161
- return_value (bool): Whether to return the function's value along with gradients.
162
- has_aux (bool): Whether the function returns auxiliary data.
163
- """
164
-
165
- __module__ = "brainstate.augment"
166
-
167
- def __init__(
168
- self,
169
- target: Callable,
170
- transform: TransformFn,
171
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
172
- argnums: Optional[Union[int, Sequence[int]]] = None,
173
- return_value: bool = False,
174
- has_aux: bool = False,
175
- transform_params: Optional[Dict[str, Any]] = None,
176
- check_states: bool = True,
177
- ):
178
- """
179
- Initialize a ``GradientTransform`` instance.
180
-
181
- Args:
182
- target (Callable): The function to be transformed.
183
- transform (TransformFn): The transformation function to apply.
184
- grad_states (Optional[Union[State, Sequence[State], Dict[str, State]]]): States to compute gradients for.
185
- argnums (Optional[Union[int, Sequence[int]]]): Indices of arguments to differentiate with respect to.
186
- return_value (bool): Whether to return the function's value along with gradients.
187
- has_aux (bool): Whether the function returns auxiliary data.
188
- transform_params (Optional[Dict[str, Any]]): Additional parameters for the transformation function.
189
-
190
- Raises:
191
- TypeError: If any grad_states are not State instances.
192
- """
193
- # gradient variables
194
- if isinstance(grad_states, dict):
195
- grad_states = {k: v for k, v in grad_states.items()}
196
- self._grad_states, self._grad_tree = jax.tree.flatten(grad_states, is_leaf=lambda x: isinstance(x, State))
197
- self._grad_state_ids = [id(v) for v in self._grad_states]
198
- self._grad_id_to_state = {id(v): v for v in self._grad_states}
199
- if any(not isinstance(v, State) for v in self._grad_states):
200
- raise TypeError("All grad_states must be State instances.")
201
- self.check_states = check_states
202
-
203
- # parameters
204
- if argnums is None and len(self._grad_states) == 0:
205
- argnums = 0
206
- if argnums is None:
207
- assert len(self._grad_states) > 0
208
- _argnums = 0
209
- elif isinstance(argnums, int):
210
- _argnums = (0, argnums + 2) if len(self._grad_states) > 0 else (argnums + 2)
211
- else:
212
- assert isinstance(argnums, (tuple, list))
213
- _argnums = tuple(a + 2 for a in argnums)
214
- if len(self._grad_states) > 0:
215
- _argnums = (0,) + _argnums
216
- self.raw_argnums = argnums
217
- self.true_argnums = _argnums
218
- self.return_value = return_value
219
- self.has_aux = has_aux
220
-
221
- # target
222
- assert callable(target), "The target should be a callable object."
223
- self.target = target
224
- self.stateful_target = StatefulFunction(target, name='gradient')
225
-
226
- # transform
227
- grad_setting = dict() if transform_params is None else transform_params
228
- if self.has_aux:
229
- self._transform = transform(self._fun_with_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
230
- else:
231
- self._transform = transform(self._fun_without_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
232
-
233
- def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
234
- yield PrettyType(self.__class__.__name__)
235
- yield PrettyAttr("target", self.target)
236
- yield PrettyAttr("grad_states", self._grad_states)
237
- yield PrettyAttr("grad_tree", self._grad_tree)
238
- yield PrettyAttr("argnums", self.raw_argnums)
239
- yield PrettyAttr("return_value", self.return_value)
240
- yield PrettyAttr("has_aux", self.has_aux)
241
- yield PrettyAttr("transform", self._transform)
242
-
243
- def _split_state_vals(self, state_trace):
244
- """
245
- Split state values into gradient and non-gradient states.
246
-
247
- Args:
248
- state_trace: The state trace containing all states.
249
-
250
- Returns:
251
- Tuple[Dict, Dict]: A tuple of dictionaries containing gradient and non-gradient state values.
252
- """
253
- grad_vals = dict()
254
- other_vals = dict()
255
- all_ids = set(self._grad_state_ids)
256
- for st in state_trace.states:
257
- id_ = id(st)
258
- if id_ in all_ids:
259
- grad_vals[id_] = st.value
260
- all_ids.remove(id_)
261
- else:
262
- other_vals[id_] = st.value
263
- if len(all_ids):
264
- if self.check_states:
265
- err = f"Some states are not found in the state trace when performing gradient transformations.\n "
266
- for i, id_ in enumerate(all_ids):
267
- st = self._grad_id_to_state[id_]
268
- st.raise_error_with_source_info(ValueError(err + str(st)))
269
- else:
270
- id2state = {id(st): st for st in self._grad_states}
271
- for id_ in all_ids:
272
- grad_vals[id_] = id2state[id_].value
273
-
274
- return grad_vals, other_vals
275
-
276
- def _merge_state_vals(self, grad_vals: Dict, other_vals: Dict, state_trace):
277
- """
278
- Merge gradient and non-gradient state values back into a single list.
279
-
280
- Args:
281
- grad_vals (Dict): Dictionary of gradient state values.
282
- other_vals (Dict): Dictionary of non-gradient state values.
283
- state_trace: The state trace containing all states.
284
-
285
- Returns:
286
- List: A list of merged state values.
287
- """
288
- res = []
289
- for st in state_trace.states:
290
- id_ = id(st)
291
- if id_ in self._grad_state_ids:
292
- res.append(grad_vals[id_])
293
- else:
294
- res.append(other_vals[id_])
295
- return res
296
-
297
- def _call_target(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
298
- """
299
- Call the target function with the given state values and arguments.
300
-
301
- Args:
302
- grad_vals (Dict): Dictionary of gradient state values.
303
- other_vals (Dict): Dictionary of non-gradient state values.
304
- *args: Positional arguments to pass to the target function.
305
- **kwargs: Keyword arguments to pass to the target function.
306
-
307
- Returns:
308
- Tuple: A tuple containing updated state values and the function output.
309
- """
310
- cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
311
- state_trace = self.stateful_target.get_state_trace(cache)
312
- state_vals = self._merge_state_vals(grad_vals, other_vals, state_trace)
313
- state_vals, out = self.stateful_target.jaxpr_call(state_vals, *args, **kwargs)
314
- return state_vals, out
315
-
316
- def _fun_with_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
317
- """
318
- Wrapper function for target functions that return auxiliary data.
319
-
320
- Args:
321
- grad_vals (Dict): Dictionary of gradient state values.
322
- other_vals (Dict): Dictionary of non-gradient state values.
323
- *args: Positional arguments to pass to the target function.
324
- **kwargs: Keyword arguments to pass to the target function.
325
-
326
- Returns:
327
- Tuple: A tuple containing the primary output and a tuple of (all outputs, updated state values).
328
- """
329
- # Users should return the auxiliary data like::
330
- # >>> # 1. example of return one data
331
- # >>> return scalar_loss, data
332
- # >>> # 2. example of return multiple data
333
- # >>> return scalar_loss, (data1, data2, ...)
334
- state_vals, outs = self._call_target(grad_vals, other_vals, *args, **kwargs)
335
- return outs[0], (outs, state_vals)
336
-
337
- def _fun_without_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
338
- """
339
- Wrapper function for target functions that do not return auxiliary data.
340
-
341
- Args:
342
- grad_vals (Dict): Dictionary of gradient state values.
343
- other_vals (Dict): Dictionary of non-gradient state values.
344
- *args: Positional arguments to pass to the target function.
345
- **kwargs: Keyword arguments to pass to the target function.
346
-
347
- Returns:
348
- Tuple: A tuple containing the output and a tuple of (output, updated state values).
349
- """
350
- state_vals, out = self._call_target(grad_vals, other_vals, *args, **kwargs)
351
- return out, (out, state_vals)
352
-
353
- def _return(self, rets, state_trace):
354
- """
355
- Process and format the return values from the gradient computation.
356
-
357
- Args:
358
- rets: The raw results from the gradient computation.
359
- state_trace: The state trace containing all states.
360
-
361
- Returns:
362
- Union[Gradient, Tuple]: The processed gradient results, potentially including function value and/or auxiliary data.
363
- """
364
- # unpack the return values
365
- grads, (outputs, new_state_vals) = rets
366
-
367
- # assign new values to the states
368
- state_trace.assign_state_vals(new_state_vals)
369
-
370
- # check returned grads
371
- if len(self._grad_states) > 0:
372
- grads_of_states = grads if self.raw_argnums is None else grads[0]
373
- grads_of_states = [grads_of_states[st_id] for st_id in self._grad_state_ids]
374
- if self.raw_argnums is None:
375
- grads = self._grad_tree.unflatten(grads_of_states)
376
- else:
377
- var_grads = self._grad_tree.unflatten(grads_of_states)
378
- arg_grads = grads[1] if isinstance(self.raw_argnums, int) else grads[1:]
379
- grads = (var_grads, arg_grads)
380
-
381
- # check returned value
382
- if self.return_value:
383
- # check aux
384
- if self.has_aux:
385
- return grads, outputs[0], outputs[1]
386
- else:
387
- return grads, outputs
388
- else:
389
- # check aux
390
- if self.has_aux:
391
- return grads, outputs[1]
392
- else:
393
- return grads
394
-
395
- def __call__(
396
- self, *args, **kwargs
397
- ) -> (
398
- Gradient |
399
- Tuple[Gradient, LossValue] |
400
- Tuple[Gradient, AuxData] |
401
- Tuple[Gradient, LossValue, AuxData]
402
- ):
403
- """
404
- Compute gradients by calling the transformed function.
405
-
406
- Args:
407
- *args: Positional arguments to pass to the target function.
408
- **kwargs: Keyword arguments to pass to the target function.
409
-
410
- Returns:
411
- Union[Gradient, Tuple]: The computed gradients, potentially including function value and/or auxiliary data.
412
- """
413
-
414
- # TODO: support jax.disable_jit()
415
-
416
- # compute the model
417
- self.stateful_target.make_jaxpr(*args, **kwargs)
418
- cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
419
-
420
- # apply the gradient transformation
421
- state_trace = self.stateful_target.get_state_trace(cache)
422
- rets = self._transform(*self._split_state_vals(state_trace), *args, **kwargs)
423
-
424
- # analyze and return the results
425
- return self._return(rets, state_trace)
426
-
427
-
428
- _doc_of_return = '''
429
-
430
- 1. When ``grad_states`` is None
431
- - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
432
- - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
433
- - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
434
- - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
435
- 2. When ``grad_states`` is not None and ``argnums`` is None
436
- - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
437
- - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
438
- - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
439
- - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
440
- 3. When ``grad_states`` is not None and ``argnums`` is not None
441
- - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
442
- - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
443
- - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
444
- - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
445
-
446
- '''
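For instance, case 3 with ``has_aux=True`` and ``return_value=True`` unpacks as below (a hedged sketch, not from the diff, assuming the ``State`` API used throughout this file):

    # illustrative sketch only
    import jax.numpy as jnp
    import brainstate

    w = brainstate.State(jnp.ones(3))

    def loss(x):
        err = w.value * x - 1.0
        return jnp.sum(err ** 2), err                # (scalar_loss, aux_data)

    f = brainstate.augment.grad(loss, grad_states=w, argnums=0,
                                has_aux=True, return_value=True)
    (var_grads, arg_grads), loss_value, aux_data = f(jnp.arange(3.))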
447
-
448
-
449
- @set_module_as("brainstate.augment")
450
- def grad(
451
- fun: Callable = Missing(),
452
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
453
- argnums: Optional[Union[int, Sequence[int]]] = None,
454
- holomorphic: Optional[bool] = False,
455
- allow_int: Optional[bool] = False,
456
- has_aux: Optional[bool] = None,
457
- return_value: Optional[bool] = False,
458
- unit_aware: bool = False,
459
- check_states: bool = True,
460
- ) -> GradientTransform | Callable[[Callable], GradientTransform]:
461
- """
462
- Compute the gradient of a scalar-valued function with respect to its arguments.
463
-
464
- %s
465
-
466
- Args:
467
- fun: callable. The scalar-valued function to be differentiated.
468
- allow_int: (bool) optional. Whether to allow differentiating with respect to
469
- integer valued inputs. The gradient of an integer input will have a trivial
470
- vector-space dtype (float0). Default False.
471
- holomorphic: (bool) optional. Whether fun is promised to be holomorphic.
472
- Default False.
473
- grad_states: (State, Sequence[State], Dict[str, State]) optional. The variables
474
- in fun to take their gradients.
476
- argnums: (int or tuple of ints) optional. Specifies which positional
477
- argument(s) to differentiate with respect to.
478
- has_aux: (bool) optional. Indicates whether fun returns a pair where the
479
- first element is considered the output of the mathematical function to be
480
- differentiated and the second element is auxiliary data. Default False.
481
- return_value: (bool) optional. Indicates whether to return the value of the
482
- function along with the gradient. Default False.
483
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
484
- mode. Default False.
485
-
486
- Returns:
487
- A function which computes the gradient of fun. The function takes the same
488
- arguments as `fun`, but returns the gradient instead. If `has_aux` is True,
489
- the function returns a pair where the first element is the gradient and the
490
- second element is the auxiliary data. If `return_value` is True, the function
491
- returns a pair where the first element is the gradient and the second element
492
- is the value of the function.
493
-
494
- """
495
- if isinstance(fun, Missing):
496
- def transform(fun) -> GradientTransform:
497
- return GradientTransform(
498
- target=fun,
499
- transform=u.autograd.grad if unit_aware else jax.grad,
500
- grad_states=grad_states,
501
- argnums=argnums,
502
- return_value=return_value,
503
- has_aux=False if has_aux is None else has_aux,
504
- transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
505
- check_states=check_states
506
- )
507
-
508
- return transform
509
-
510
- return GradientTransform(
511
- target=fun,
512
- transform=u.autograd.grad if unit_aware else jax.grad,
513
- grad_states=grad_states,
514
- argnums=argnums,
515
- return_value=return_value,
516
- has_aux=False if has_aux is None else has_aux,
517
- transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
518
- check_states=check_states
519
- )
520
-
521
-
522
- grad.__doc__ = grad.__doc__ % _doc_of_return
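Because ``fun`` defaults to ``Missing()``, ``grad`` can also be used as a decorator; a minimal sketch under the same assumptions as the earlier examples:

    # illustrative sketch only
    import jax.numpy as jnp
    import brainstate

    w = brainstate.State(jnp.zeros(4))

    @brainstate.augment.grad(grad_states=w)          # decorator form: ``fun`` omitted
    def loss(x):
        return jnp.mean((w.value - x) ** 2)

    w_grad = loss(jnp.ones(4))                       # only ``var_grads`` (case 2 above)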
523
-
524
-
525
- @set_module_as("brainstate.augment")
526
- def vector_grad(
527
- func: Callable = Missing(),
528
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
529
- argnums: Optional[Union[int, Sequence[int]]] = None,
530
- return_value: bool = False,
531
- has_aux: Optional[bool] = None,
532
- unit_aware: bool = False,
533
- check_states: bool = True,
534
- ) -> GradientTransform | Callable[[Callable], GradientTransform]:
535
- """Take vector-valued gradients for function ``func``.
536
-
537
- As with :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`,
538
- the return value of this function depends on the argument settings.
539
-
540
- %s
541
-
542
- Parameters
543
- ----------
544
- func: Callable
545
- Function whose gradient is to be computed.
546
- grad_states : optional, ArrayType, sequence of ArrayType, dict
547
- The variables in ``func`` to take their gradients.
548
- has_aux: optional, bool
549
- Indicates whether ``fun`` returns a pair where the
550
- first element is considered the output of the mathematical function to be
551
- differentiated and the second element is auxiliary data. Default False.
552
- return_value : bool
553
- Whether to return the loss value.
554
- argnums: Optional, integer or sequence of integers. Specifies which
555
- positional argument(s) to differentiate with respect to (default ``0``).
556
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
557
- mode. Default False.
558
-
559
- Returns
560
- -------
561
- func : GradientTransform
562
- The vector gradient function.
563
- """
564
-
565
- if isinstance(func, Missing):
566
- def transform(fun) -> GradientTransform:
567
- return GradientTransform(
568
- target=fun,
569
- transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
570
- grad_states=grad_states,
571
- argnums=argnums,
572
- return_value=return_value,
573
- has_aux=False if has_aux is None else has_aux,
574
- check_states=check_states
575
- )
576
-
577
- return transform
578
-
579
- else:
580
- return GradientTransform(
581
- target=func,
582
- transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
583
- grad_states=grad_states,
584
- argnums=argnums,
585
- return_value=return_value,
586
- has_aux=False if has_aux is None else has_aux,
587
- check_states=check_states
588
- )
589
-
590
-
591
- vector_grad.__doc__ = vector_grad.__doc__ % _doc_of_return
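A hedged sketch of the call pattern (the target may return a non-scalar; the element-wise gradient semantics are assumed from ``u.autograd.vector_grad``):

    # illustrative sketch only; semantics of vector_grad assumed, not verified here
    import jax.numpy as jnp
    import brainstate

    v = brainstate.State(jnp.ones(3))

    def f(x):
        return v.value * x ** 2                      # vector-valued output is allowed

    v_grad = brainstate.augment.vector_grad(f, grad_states=v)(jnp.arange(3.))
    # v_grad has the same shape as ``v.value``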
592
-
593
-
594
- @set_module_as("brainstate.augment")
595
- def jacrev(
596
- fun: Callable,
597
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
598
- argnums: Optional[Union[int, Sequence[int]]] = None,
599
- has_aux: Optional[bool] = None,
600
- return_value: bool = False,
601
- holomorphic: bool = False,
602
- allow_int: bool = False,
603
- unit_aware: bool = False,
604
- check_states: bool = True,
605
- ) -> GradientTransform:
606
- """
607
- Extending automatic Jacobian (reverse-mode) of ``func`` to classes.
608
-
609
- This function extends the official JAX ``jacrev`` to support automatic Jacobian
610
- computation for both functions and class methods. Moreover, it supports returning
611
- the function value (``return_value``) and auxiliary data (``has_aux``).
612
-
613
- %s
614
-
615
-
616
- Parameters
617
- ----------
618
- fun: Callable
619
- Function whose Jacobian is to be computed.
620
- grad_states : optional, ArrayType, sequence of ArrayType, dict
621
- The variables in ``func`` to take their gradients.
622
- has_aux: optional, bool
623
- Indicates whether ``fun`` returns a pair where the
624
- first element is considered the output of the mathematical function to be
625
- differentiated and the second element is auxiliary data. Default False.
626
- return_value : bool
627
- Whether to return the loss value.
628
- argnums: Optional, integer or sequence of integers.
629
- Specifies which
630
- positional argument(s) to differentiate with respect to (default ``0``).
631
- holomorphic: Optional, bool.
632
- Indicates whether ``fun`` is promised to be
633
- holomorphic. Default False.
634
- allow_int: Optional, bool.
635
- Whether to allow differentiating with
636
- respect to integer valued inputs. The gradient of an integer input will
637
- have a trivial vector-space dtype (float0). Default False.
638
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
639
- mode. Default False.
640
-
641
- Returns
642
- -------
643
- fun: GradientTransform
644
- The transformed object.
645
- """
646
- return GradientTransform(
647
- target=fun,
648
- transform=_jacrev,
649
- grad_states=grad_states,
650
- argnums=argnums,
651
- return_value=return_value,
652
- has_aux=False if has_aux is None else has_aux,
653
- transform_params=dict(holomorphic=holomorphic,
654
- allow_int=allow_int,
655
- unit_aware=unit_aware, ),
656
- check_states=check_states
657
- )
658
-
659
-
660
- jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
661
-
662
- jacobian = jacrev
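A minimal sketch of ``jacrev`` against a state (not from the diff; ``jacobian`` above is simply an alias of ``jacrev``):

    # illustrative sketch only
    import jax.numpy as jnp
    import brainstate

    w = brainstate.State(jnp.ones(3))

    def f(x):
        return jnp.tanh(w.value * x)                 # R^3 -> R^3

    jac_w = brainstate.augment.jacrev(f, grad_states=w)(jnp.arange(3.))
    # jac_w has shape (3, 3): d f_i / d w_j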
663
-
664
-
665
- @set_module_as("brainstate.augment")
666
- def jacfwd(
667
- func: Callable,
668
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
669
- argnums: Optional[Union[int, Sequence[int]]] = None,
670
- has_aux: Optional[bool] = None,
671
- return_value: bool = False,
672
- holomorphic: bool = False,
673
- unit_aware: bool = False,
674
- check_states: bool = True,
675
- ) -> GradientTransform:
676
- """Extending automatic Jacobian (forward-mode) of ``func`` to classes.
677
-
678
- This function extends the official JAX ``jacfwd`` to support automatic Jacobian
679
- computation for both functions and class methods. Moreover, it supports returning
680
- the function value (``return_value``) and auxiliary data (``has_aux``).
681
-
682
- %s
683
-
684
- Parameters
685
- ----------
686
- func: Function whose Jacobian is to be computed.
687
- grad_states : optional, ArrayType, sequence of ArrayType, dict
688
- The variables in ``func`` to take their gradients.
689
- has_aux: optional, bool
690
- Indicates whether ``fun`` returns a pair where the
691
- first element is considered the output of the mathematical function to be
692
- differentiated and the second element is auxiliary data. Default False.
693
- return_value : bool
694
- Whether to return the loss value.
695
- argnums: Optional, integer or sequence of integers. Specifies which
696
- positional argument(s) to differentiate with respect to (default ``0``).
697
- holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
698
- holomorphic. Default False.
699
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
700
- mode. Default False.
701
-
702
- Returns
703
- -------
704
- obj: GradientTransform
705
- The transformed object.
706
- """
707
-
708
- return GradientTransform(
709
- target=func,
710
- transform=_jacfwd,
711
- grad_states=grad_states,
712
- argnums=argnums,
713
- return_value=return_value,
714
- has_aux=False if has_aux is None else has_aux,
715
- transform_params=dict(holomorphic=holomorphic, unit_aware=unit_aware),
716
- check_states=check_states
717
- )
718
-
719
-
720
- jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
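Forward mode is typically the cheaper choice when the Jacobian is "tall" (more outputs than differentiated inputs); a hedged sketch of the call pattern:

    # illustrative sketch only
    import jax.numpy as jnp
    import brainstate

    def f(x):                                        # 2 inputs -> 4 outputs
        return jnp.concatenate([x, x ** 2])

    jac_x = brainstate.augment.jacfwd(f, argnums=0)(jnp.ones(2))   # shape (4, 2)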
721
-
722
-
723
- @set_module_as("brainstate.augment")
724
- def hessian(
725
- func: Callable,
726
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
727
- argnums: Optional[Union[int, Sequence[int]]] = None,
728
- return_value: bool = False,
729
- holomorphic: bool = False,
730
- has_aux: Optional[bool] = None,
731
- unit_aware: bool = False,
732
- check_states: bool = True,
733
- ) -> GradientTransform:
734
- """
735
- Hessian of ``func`` as a dense array.
736
-
737
- %s
738
-
739
- Parameters
740
- ----------
741
- func : callable
742
- Function whose Hessian is to be computed. Its arguments at positions
743
- specified by ``argnums`` should be arrays, scalars, or standard Python
744
- containers thereof. It should return arrays, scalars, or standard Python
745
- containers thereof.
746
- grad_states : optional, ArrayCollector, sequence of ArrayType
747
- The variables required to compute their gradients.
748
- argnums: Optional, integer or sequence of integers
749
- Specifies which positional argument(s) to differentiate with respect to (default ``0``).
750
- holomorphic : bool
751
- Indicates whether ``fun`` is promised to be holomorphic. Default False.
752
- return_value : bool
753
- Whether to return the function value along with the Hessian.
754
- has_aux: Optional, bool
755
- Indicates whether ``fun`` returns a pair where the first element is considered
756
- the output of the mathematical function to be differentiated and the second
757
- element is auxiliary data. Default False.
758
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
759
- mode. Default False.
760
-
761
- Returns
762
- -------
763
- obj: GradientTransform
764
- The transformed object.
765
- """
766
- return GradientTransform(
767
- target=func,
768
- transform=u.autograd.hessian if unit_aware else jax.hessian,
769
- grad_states=grad_states,
770
- argnums=argnums,
771
- return_value=return_value,
772
- has_aux=False if has_aux is None else has_aux,
773
- transform_params=dict(holomorphic=holomorphic),
774
- check_states=check_states
775
- )
776
-
777
-
778
- hessian.__doc__ = hessian.__doc__ % _doc_of_return
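And a small sketch of ``hessian`` on a plain scalar function (same assumptions as the earlier examples):

    # illustrative sketch only
    import jax.numpy as jnp
    import brainstate

    def f(x):
        return jnp.sum(x ** 3)

    h = brainstate.augment.hessian(f)(jnp.arange(1., 4.))          # shape (3, 3)
    # the diagonal equals 6 * x; off-diagonal entries are zero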
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Gradient transformations are relatively simple compared to ``vmap`` or ``pmap`` augmentations.
18
+ This is because the gradient transformations do not use the Jaxpr; instead, most of them are
19
+ computed at the Python level. However, there is one exception, the ``checkpoint`` transformation,
20
+ which has been moved into the ``compile`` module.
21
+
22
+ The wrapped gradient transformations here are made possible by using the following ideas:
23
+ 1. All the states needed to compute the gradients should be known before the transformation.
24
+ These must be provided through the ``grad_states`` argument of any gradient transformation.
25
+ 2. The states that have been written in the function should be collected and updated after the function call.
26
+ We record these states during the function call and update them after the call.
27
+
28
+ """
29
+
30
+ from functools import wraps, partial
31
+ from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
32
+
33
+ import brainunit as u
34
+ import jax
35
+
36
+ from brainstate._state import State
37
+ from brainstate._utils import set_module_as
38
+ from brainstate.compile._make_jaxpr import StatefulFunction
39
+ from brainstate.typing import PyTree, Missing
40
+ from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
41
+
42
+ __all__ = [
43
+ 'GradientTransform', 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
44
+ ]
45
+
46
+ A = TypeVar('A')
47
+ Gradient = PyTree
48
+ LossValue = PyTree
49
+ AuxData = PyTree
50
+
51
+
52
+ def _jacrev(
53
+ fun,
54
+ argnums=0,
55
+ holomorphic=False,
56
+ allow_int=False,
57
+ has_aux=False,
58
+ return_value=False,
59
+ unit_aware=False,
60
+ ):
61
+ @wraps(fun)
62
+ def fun_wrapped(*args, **kwargs):
63
+ if has_aux:
64
+ y, aux = fun(*args, **kwargs)
65
+ if return_value:
66
+ return y, (y, aux)
67
+ else:
68
+ return y, aux
69
+ else:
70
+ y = fun(*args, **kwargs)
71
+ if return_value:
72
+ return y, y
73
+ else:
74
+ return y, None
75
+
76
+ if unit_aware:
77
+ transform = u.autograd.jacrev(fun_wrapped,
78
+ argnums=argnums,
79
+ holomorphic=holomorphic,
80
+ allow_int=allow_int,
81
+ has_aux=True)
82
+ else:
83
+ transform = jax.jacrev(fun_wrapped,
84
+ argnums=argnums,
85
+ holomorphic=holomorphic,
86
+ allow_int=allow_int,
87
+ has_aux=True)
88
+
89
+ @wraps(fun)
90
+ def jacfun(*args, **kwargs):
91
+ jac, aux = transform(*args, **kwargs)
92
+ if return_value:
93
+ return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
94
+ else:
95
+ return (jac, aux) if has_aux else jac
96
+
97
+ return jacfun
98
+
99
+
100
+ def _jacfwd(
101
+ fun,
102
+ argnums=0,
103
+ holomorphic=False,
104
+ has_aux=False,
105
+ return_value=False,
106
+ unit_aware=False,
107
+ ):
108
+ @wraps(fun)
109
+ def fun_wrapped(*args, **kwargs):
110
+ if has_aux:
111
+ y, aux = fun(*args, **kwargs)
112
+ if return_value:
113
+ return y, (y, aux)
114
+ else:
115
+ return y, aux
116
+ else:
117
+ y = fun(*args, **kwargs)
118
+ if return_value:
119
+ return y, y
120
+ else:
121
+ return y, None
122
+
123
+ if unit_aware:
124
+ transform = u.autograd.jacfwd(fun_wrapped,
125
+ argnums=argnums,
126
+ holomorphic=holomorphic,
127
+ has_aux=True)
128
+ else:
129
+ transform = jax.jacfwd(fun_wrapped,
130
+ argnums=argnums,
131
+ holomorphic=holomorphic,
132
+ has_aux=True)
133
+
134
+ @wraps(fun)
135
+ def jacfun(*args, **kwargs):
136
+ jac, aux = transform(*args, **kwargs)
137
+ if return_value:
138
+ return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
139
+ else:
140
+ return (jac, aux) if has_aux else jac
141
+
142
+ return jacfun
143
+
144
+
145
+ TransformFn = Callable
146
+
147
+
148
+ class GradientTransform(PrettyRepr):
149
+ """
150
+ Automatic Differentiation Transformations for the ``State`` system.
151
+
152
+ This class implements gradient transformations for functions that operate on State objects.
153
+ It allows for flexible configuration of gradient computation with respect to specified states
154
+ and function arguments.
155
+
156
+ Attributes:
157
+ target (Callable): The function to be transformed.
158
+ stateful_target (StatefulFunction): A wrapper around the target function for state management.
159
+ raw_argnums (Optional[Union[int, Sequence[int]]]): The original argnums specified by the user.
160
+ true_argnums (Union[int, Tuple[int, ...]]): The adjusted argnums used internally.
161
+ return_value (bool): Whether to return the function's value along with gradients.
162
+ has_aux (bool): Whether the function returns auxiliary data.
163
+ """
164
+
165
+ __module__ = "brainstate.augment"
166
+
167
+ def __init__(
168
+ self,
169
+ target: Callable,
170
+ transform: TransformFn,
171
+ grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
172
+ argnums: Optional[Union[int, Sequence[int]]] = None,
173
+ return_value: bool = False,
174
+ has_aux: bool = False,
175
+ transform_params: Optional[Dict[str, Any]] = None,
176
+ check_states: bool = True,
177
+ ):
178
+ """
179
+ Initialize a ``GradientTransform`` instance.
180
+
181
+ Args:
182
+ target (Callable): The function to be transformed.
183
+ transform (TransformFn): The transformation function to apply.
184
+ grad_states (Optional[Union[State, Sequence[State], Dict[str, State]]]): States to compute gradients for.
185
+ argnums (Optional[Union[int, Sequence[int]]]): Indices of arguments to differentiate with respect to.
186
+ return_value (bool): Whether to return the function's value along with gradients.
187
+ has_aux (bool): Whether the function returns auxiliary data.
188
+ transform_params (Optional[Dict[str, Any]]): Additional parameters for the transformation function.
189
+
190
+ Raises:
191
+ TypeError: If any grad_states are not State instances.
192
+ """
193
+ # gradient variables
194
+ if isinstance(grad_states, dict):
195
+ grad_states = {k: v for k, v in grad_states.items()}
196
+ self._grad_states, self._grad_tree = jax.tree.flatten(grad_states, is_leaf=lambda x: isinstance(x, State))
197
+ self._grad_state_ids = [id(v) for v in self._grad_states]
198
+ self._grad_id_to_state = {id(v): v for v in self._grad_states}
199
+ if any(not isinstance(v, State) for v in self._grad_states):
200
+ raise TypeError("All grad_states must be State instances.")
201
+ self.check_states = check_states
202
+
203
+ # parameters
204
+ if argnums is None and len(self._grad_states) == 0:
205
+ argnums = 0
206
+ if argnums is None:
207
+ assert len(self._grad_states) > 0
208
+ _argnums = 0
209
+ elif isinstance(argnums, int):
210
+ _argnums = (0, argnums + 2) if len(self._grad_states) > 0 else (argnums + 2)
211
+ else:
212
+ assert isinstance(argnums, (tuple, list))
213
+ _argnums = tuple(a + 2 for a in argnums)
214
+ if len(self._grad_states) > 0:
215
+ _argnums = (0,) + _argnums
216
+ self.raw_argnums = argnums
217
+ self.true_argnums = _argnums
218
+ self.return_value = return_value
219
+ self.has_aux = has_aux
220
+
221
+ # target
222
+ assert callable(target), "The target should be a callable object."
223
+ self.target = target
224
+ self.stateful_target = StatefulFunction(target, name='gradient')
225
+
226
+ # transform
227
+ grad_setting = dict() if transform_params is None else transform_params
228
+ if self.has_aux:
229
+ self._transform = transform(self._fun_with_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
230
+ else:
231
+ self._transform = transform(self._fun_without_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
232
+
233
+ def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
234
+ yield PrettyType(self.__class__.__name__)
235
+ yield PrettyAttr("target", self.target)
236
+ yield PrettyAttr("grad_states", self._grad_states)
237
+ yield PrettyAttr("grad_tree", self._grad_tree)
238
+ yield PrettyAttr("argnums", self.raw_argnums)
239
+ yield PrettyAttr("return_value", self.return_value)
240
+ yield PrettyAttr("has_aux", self.has_aux)
241
+ yield PrettyAttr("transform", self._transform)
242
+
243
+ def _split_state_vals(self, state_trace):
244
+ """
245
+ Split state values into gradient and non-gradient states.
246
+
247
+ Args:
248
+ state_trace: The state trace containing all states.
249
+
250
+ Returns:
251
+ Tuple[Dict, Dict]: A tuple of dictionaries containing gradient and non-gradient state values.
252
+ """
253
+ grad_vals = dict()
254
+ other_vals = dict()
255
+ all_ids = set(self._grad_state_ids)
256
+ for st in state_trace.states:
257
+ id_ = id(st)
258
+ if id_ in all_ids:
259
+ grad_vals[id_] = st.value
260
+ all_ids.remove(id_)
261
+ else:
262
+ other_vals[id_] = st.value
263
+ if len(all_ids):
264
+ if self.check_states:
265
+ err = f"Some states are not found in the state trace when performing gradient transformations.\n "
266
+ for i, id_ in enumerate(all_ids):
267
+ st = self._grad_id_to_state[id_]
268
+ st.raise_error_with_source_info(ValueError(err + str(st)))
269
+ else:
270
+ id2state = {id(st): st for st in self._grad_states}
271
+ for id_ in all_ids:
272
+ grad_vals[id_] = id2state[id_].value
273
+
274
+ return grad_vals, other_vals
275
+
276
+ def _merge_state_vals(self, grad_vals: Dict, other_vals: Dict, state_trace):
277
+ """
278
+ Merge gradient and non-gradient state values back into a single list.
279
+
280
+ Args:
281
+ grad_vals (Dict): Dictionary of gradient state values.
282
+ other_vals (Dict): Dictionary of non-gradient state values.
283
+ state_trace: The state trace containing all states.
284
+
285
+ Returns:
286
+ List: A list of merged state values.
287
+ """
288
+ res = []
289
+ for st in state_trace.states:
290
+ id_ = id(st)
291
+ if id_ in self._grad_state_ids:
292
+ res.append(grad_vals[id_])
293
+ else:
294
+ res.append(other_vals[id_])
295
+ return res
296
+
297
+ def _call_target(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
298
+ """
299
+ Call the target function with the given state values and arguments.
300
+
301
+ Args:
302
+ grad_vals (Dict): Dictionary of gradient state values.
303
+ other_vals (Dict): Dictionary of non-gradient state values.
304
+ *args: Positional arguments to pass to the target function.
305
+ **kwargs: Keyword arguments to pass to the target function.
306
+
307
+ Returns:
308
+ Tuple: A tuple containing updated state values and the function output.
309
+ """
310
+ cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
311
+ state_trace = self.stateful_target.get_state_trace(cache)
312
+ state_vals = self._merge_state_vals(grad_vals, other_vals, state_trace)
313
+ state_vals, out = self.stateful_target.jaxpr_call(state_vals, *args, **kwargs)
314
+ return state_vals, out
315
+
316
+ def _fun_with_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
317
+ """
318
+ Wrapper function for target functions that return auxiliary data.
319
+
320
+ Args:
321
+ grad_vals (Dict): Dictionary of gradient state values.
322
+ other_vals (Dict): Dictionary of non-gradient state values.
323
+ *args: Positional arguments to pass to the target function.
324
+ **kwargs: Keyword arguments to pass to the target function.
325
+
326
+ Returns:
327
+ Tuple: A tuple containing the primary output and a tuple of (all outputs, updated state values).
328
+ """
329
+ # Users should return the auxiliary data like::
330
+ # >>> # 1. example of return one data
331
+ # >>> return scalar_loss, data
332
+ # >>> # 2. example of return multiple data
333
+ # >>> return scalar_loss, (data1, data2, ...)
334
+ state_vals, outs = self._call_target(grad_vals, other_vals, *args, **kwargs)
335
+ return outs[0], (outs, state_vals)
336
+
337
+ def _fun_without_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
338
+ """
339
+ Wrapper function for target functions that do not return auxiliary data.
340
+
341
+ Args:
342
+ grad_vals (Dict): Dictionary of gradient state values.
343
+ other_vals (Dict): Dictionary of non-gradient state values.
344
+ *args: Positional arguments to pass to the target function.
345
+ **kwargs: Keyword arguments to pass to the target function.
346
+
347
+ Returns:
348
+ Tuple: A tuple containing the output and a tuple of (output, updated state values).
349
+ """
350
+ state_vals, out = self._call_target(grad_vals, other_vals, *args, **kwargs)
351
+ return out, (out, state_vals)
352
+
353
+ def _return(self, rets, state_trace):
354
+ """
355
+ Process and format the return values from the gradient computation.
356
+
357
+ Args:
358
+ rets: The raw results from the gradient computation.
359
+ state_trace: The state trace containing all states.
360
+
361
+ Returns:
362
+ Union[Gradient, Tuple]: The processed gradient results, potentially including function value and/or auxiliary data.
363
+ """
364
+ # unpack the return values
365
+ grads, (outputs, new_state_vals) = rets
366
+
367
+ # assign new values to the states
368
+ state_trace.assign_state_vals(new_state_vals)
369
+
370
+ # check returned grads
371
+ if len(self._grad_states) > 0:
372
+ grads_of_states = grads if self.raw_argnums is None else grads[0]
373
+ grads_of_states = [grads_of_states[st_id] for st_id in self._grad_state_ids]
374
+ if self.raw_argnums is None:
375
+ grads = self._grad_tree.unflatten(grads_of_states)
376
+ else:
377
+ var_grads = self._grad_tree.unflatten(grads_of_states)
378
+ arg_grads = grads[1] if isinstance(self.raw_argnums, int) else grads[1:]
379
+ grads = (var_grads, arg_grads)
380
+
381
+ # check returned value
382
+ if self.return_value:
383
+ # check aux
384
+ if self.has_aux:
385
+ return grads, outputs[0], outputs[1]
386
+ else:
387
+ return grads, outputs
388
+ else:
389
+ # check aux
390
+ if self.has_aux:
391
+ return grads, outputs[1]
392
+ else:
393
+ return grads
394
+
395
+ def __call__(
396
+ self, *args, **kwargs
397
+ ) -> (
398
+ Gradient |
399
+ Tuple[Gradient, LossValue] |
400
+ Tuple[Gradient, AuxData] |
401
+ Tuple[Gradient, LossValue, AuxData]
402
+ ):
403
+ """
404
+ Compute gradients by calling the transformed function.
405
+
406
+ Args:
407
+ *args: Positional arguments to pass to the target function.
408
+ **kwargs: Keyword arguments to pass to the target function.
409
+
410
+ Returns:
411
+ Union[Gradient, Tuple]: The computed gradients, potentially including function value and/or auxiliary data.
412
+ """
413
+
414
+ # TODO: support jax.disable_jit()
415
+
416
+ # compute the model
417
+ self.stateful_target.make_jaxpr(*args, **kwargs)
418
+ cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
419
+
420
+ # apply the gradient transformation
421
+ state_trace = self.stateful_target.get_state_trace(cache)
422
+ rets = self._transform(*self._split_state_vals(state_trace), *args, **kwargs)
423
+
424
+ # analyze and return the results
425
+ return self._return(rets, state_trace)
426
+
427
+
428
+ _doc_of_return = '''
429
+
430
+ 1. When ``grad_states`` is None
431
+ - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
432
+ - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
433
+ - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
434
+ - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
435
+ 2. When ``grad_states`` is not None and ``argnums`` is None
436
+ - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
437
+ - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
438
+ - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
439
+ - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
440
+ 3. When ``grad_states`` is not None and ``argnums`` is not None
441
+ - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
442
+ - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
443
+ - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
444
+ - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
445
+
446
+ '''
447
+
448
+
449
+ @set_module_as("brainstate.augment")
450
+ def grad(
451
+ fun: Callable = Missing(),
452
+ grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
453
+ argnums: Optional[Union[int, Sequence[int]]] = None,
454
+ holomorphic: Optional[bool] = False,
455
+ allow_int: Optional[bool] = False,
456
+ has_aux: Optional[bool] = None,
457
+ return_value: Optional[bool] = False,
458
+ unit_aware: bool = False,
459
+ check_states: bool = True,
460
+ ) -> GradientTransform | Callable[[Callable], GradientTransform]:
461
+ """
462
+ Compute the gradient of a scalar-valued function with respect to its arguments.
463
+
464
+ %s
465
+
466
+ Args:
467
+ fun: callable. The scalar-valued function to be differentiated.
468
+ allow_int: (bool) optional. Whether to allow differentiating with respect to
469
+ integer valued inputs. The gradient of an integer input will have a trivial
470
+ vector-space dtype (float0). Default False.
471
+ holomorphic: (bool) optional. Whether fun is promised to be holomorphic.
472
+ Default False.
473
+ grad_states: (State, Sequence[State], Dict[str, State]) optional. The variables
474
+ in fun to take their gradients.
476
+ argnums: (int or tuple of ints) optional. Specifies which positional
477
+ argument(s) to differentiate with respect to.
478
+ has_aux: (bool) optional. Indicates whether fun returns a pair where the
479
+ first element is considered the output of the mathematical function to be
480
+ differentiated and the second element is auxiliary data. Default False.
481
+ return_value: (bool) optional. Indicates whether to return the value of the
482
+ function along with the gradient. Default False.
483
+ unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
484
+ mode. Default False.
485
+
486
+ Returns:
487
+ A function which computes the gradient of fun. The function takes the same
488
+ arguments as `fun`, but returns the gradient instead. If `has_aux` is True,
489
+ the function returns a pair where the first element is the gradient and the
490
+ second element is the auxiliary data. If `return_value` is True, the function
491
+ returns a pair where the first element is the gradient and the second element
492
+ is the value of the function.
493
+
494
+ """
495
+ if isinstance(fun, Missing):
496
+ def transform(fun) -> GradientTransform:
497
+ return GradientTransform(
498
+ target=fun,
499
+ transform=u.autograd.grad if unit_aware else jax.grad,
500
+ grad_states=grad_states,
501
+ argnums=argnums,
502
+ return_value=return_value,
503
+ has_aux=False if has_aux is None else has_aux,
504
+ transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
505
+ check_states=check_states
506
+ )
507
+
508
+ return transform
509
+
510
+ return GradientTransform(
511
+ target=fun,
512
+ transform=u.autograd.grad if unit_aware else jax.grad,
513
+ grad_states=grad_states,
514
+ argnums=argnums,
515
+ return_value=return_value,
516
+ has_aux=False if has_aux is None else has_aux,
517
+ transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
518
+ check_states=check_states
519
+ )
520
+
521
+
522
+ grad.__doc__ = grad.__doc__ % _doc_of_return
523
+
524
+
525
+ @set_module_as("brainstate.augment")
526
+ def vector_grad(
527
+ func: Callable = Missing(),
528
+ grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
529
+ argnums: Optional[Union[int, Sequence[int]]] = None,
530
+ return_value: bool = False,
531
+ has_aux: Optional[bool] = None,
532
+ unit_aware: bool = False,
533
+ check_states: bool = True,
534
+ ) -> GradientTransform | Callable[[Callable], GradientTransform]:
535
+ """Take vector-valued gradients for function ``func``.
536
+
537
+ As with :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`,
538
+ the return value of this function depends on the argument settings.
539
+
540
+ %s
541
+
542
+ Parameters
543
+ ----------
544
+ func: Callable
545
+ Function whose gradient is to be computed.
546
+ grad_states : optional, ArrayType, sequence of ArrayType, dict
547
+ The variables in ``func`` to take their gradients.
548
+ has_aux: optional, bool
549
+ Indicates whether ``fun`` returns a pair where the
550
+ first element is considered the output of the mathematical function to be
551
+ differentiated and the second element is auxiliary data. Default False.
552
+ return_value : bool
553
+ Whether return the loss value.
554
+ argnums: Optional, integer or sequence of integers. Specifies which
555
+ positional argument(s) to differentiate with respect to (default ``0``).
556
+ unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
557
+ mode. Default False.
558
+
559
+ Returns
560
+ -------
561
+ func : GradientTransform
562
+ The vector gradient function.
563
+ """
564
+
565
+ if isinstance(func, Missing):
566
+ def transform(fun) -> GradientTransform:
567
+ return GradientTransform(
568
+ target=fun,
569
+ transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
570
+ grad_states=grad_states,
571
+ argnums=argnums,
572
+ return_value=return_value,
573
+ has_aux=False if has_aux is None else has_aux,
574
+ check_states=check_states
575
+ )
576
+
577
+ return transform
578
+
579
+ else:
580
+ return GradientTransform(
581
+ target=func,
582
+ transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
583
+ grad_states=grad_states,
584
+ argnums=argnums,
585
+ return_value=return_value,
586
+ has_aux=False if has_aux is None else has_aux,
587
+ check_states=check_states
588
+ )
589
+
590
+
591
+ vector_grad.__doc__ = vector_grad.__doc__ % _doc_of_return
592
+
593
+
594
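A minimal sketch of `vector_grad` under the same assumptions (the `brainstate.augment` namespace is importable). For an elementwise vector-valued function it behaves like a VJP with a ones cotangent, so the result has the same shape as the input; that interpretation is inferred from the docstring rather than stated in the diff.

```python
# Usage sketch for vector_grad (behavior inferred, not authoritative).
import jax.numpy as jnp
import brainstate

def residual(x):
    # vector-valued output; no scalar reduction is required
    return jnp.sin(x) * 2.0

dres = brainstate.augment.vector_grad(residual)   # argnums defaults to 0
dx = dres(jnp.linspace(0.0, 1.0, 5))              # elementwise derivative, same shape as x
```
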
+ @set_module_as("brainstate.augment")
+ def jacrev(
+ fun: Callable,
+ grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+ argnums: Optional[Union[int, Sequence[int]]] = None,
+ has_aux: Optional[bool] = None,
+ return_value: bool = False,
+ holomorphic: bool = False,
+ allow_int: bool = False,
+ unit_aware: bool = False,
+ check_states: bool = True,
+ ) -> GradientTransform:
+ """
+ Extend automatic Jacobian computation (reverse-mode) of ``fun`` to classes.
+
+ This function extends the official JAX ``jacrev`` so that automatic Jacobian
+ computation works on both plain functions and class methods. Moreover, it supports
+ returning the function value ("return_value") and auxiliary data ("has_aux").
+
+ %s
+
+ Parameters
+ ----------
+ fun: Callable
+ Function whose Jacobian is to be computed.
+ grad_states : optional, ArrayType, sequence of ArrayType, dict
+ The variables in ``fun`` to take their gradients.
+ has_aux: optional, bool
+ Indicates whether ``fun`` returns a pair where the
+ first element is considered the output of the mathematical function to be
+ differentiated and the second element is auxiliary data. Default False.
+ return_value : bool
+ Whether to return the loss value. Default False.
+ argnums: optional, integer or sequence of integers
+ Specifies which positional argument(s) to differentiate with respect to (default ``0``).
+ holomorphic: optional, bool
+ Indicates whether ``fun`` is promised to be holomorphic. Default False.
+ allow_int: optional, bool
+ Whether to allow differentiating with respect to integer valued inputs.
+ The gradient of an integer input will have a trivial vector-space dtype
+ (float0). Default False.
+ unit_aware: optional, bool
+ Whether to compute the Jacobian in unit-aware mode. Default False.
+
+ Returns
+ -------
+ fun: GradientTransform
+ The transformed object.
+ """
+ return GradientTransform(
+ target=fun,
+ transform=_jacrev,
+ grad_states=grad_states,
+ argnums=argnums,
+ return_value=return_value,
+ has_aux=False if has_aux is None else has_aux,
+ transform_params=dict(holomorphic=holomorphic, allow_int=allow_int, unit_aware=unit_aware),
+ check_states=check_states
+ )
+
+
+ jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
+
+ jacobian = jacrev
+
+
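For the reverse-mode Jacobian, a short sketch along the lines of JAX's own `jacrev` (the brainstate wrapper adds `grad_states`, `return_value`, and `has_aux` on top); the function `f` below is purely illustrative.

```python
# Usage sketch for jacrev (reverse mode: efficient when outputs are few, inputs many).
import jax.numpy as jnp
import brainstate

def f(x):
    # R^3 -> R^2 map, so the Jacobian is a 2 x 3 matrix
    return jnp.array([x[0] * x[1], x[1] + x[2] ** 2])

J = brainstate.augment.jacrev(f)(jnp.array([1.0, 2.0, 3.0]))   # shape (2, 3)
```
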
+ @set_module_as("brainstate.augment")
+ def jacfwd(
+ func: Callable,
+ grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+ argnums: Optional[Union[int, Sequence[int]]] = None,
+ has_aux: Optional[bool] = None,
+ return_value: bool = False,
+ holomorphic: bool = False,
+ unit_aware: bool = False,
+ check_states: bool = True,
+ ) -> GradientTransform:
+ """Extend automatic Jacobian computation (forward-mode) of ``func`` to classes.
+
+ This function extends the official JAX ``jacfwd`` so that automatic Jacobian
+ computation works on both plain functions and class methods. Moreover, it supports
+ returning the function value ("return_value") and auxiliary data ("has_aux").
+
+ %s
+
+ Parameters
+ ----------
+ func: Callable
+ Function whose Jacobian is to be computed.
+ grad_states : optional, ArrayType, sequence of ArrayType, dict
+ The variables in ``func`` to take their gradients.
+ has_aux: optional, bool
+ Indicates whether ``func`` returns a pair where the
+ first element is considered the output of the mathematical function to be
+ differentiated and the second element is auxiliary data. Default False.
+ return_value : bool
+ Whether to return the loss value. Default False.
+ argnums: optional, integer or sequence of integers
+ Specifies which positional argument(s) to differentiate with respect to (default ``0``).
+ holomorphic: optional, bool
+ Indicates whether ``func`` is promised to be holomorphic. Default False.
+ unit_aware: optional, bool
+ Whether to compute the Jacobian in unit-aware mode. Default False.
+
+ Returns
+ -------
+ obj: GradientTransform
+ The transformed object.
+ """
+
+ return GradientTransform(
+ target=func,
+ transform=_jacfwd,
+ grad_states=grad_states,
+ argnums=argnums,
+ return_value=return_value,
+ has_aux=False if has_aux is None else has_aux,
+ transform_params=dict(holomorphic=holomorphic, unit_aware=unit_aware),
+ check_states=check_states
+ )
+
+
+ jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
+
+
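Forward mode is the complementary choice when there are few inputs and many outputs; a sketch with a hypothetical curve `g`, again assuming only the `brainstate.augment.jacfwd` entry point shown above:

```python
# Usage sketch for jacfwd (forward mode: efficient when inputs are few, outputs many).
import jax.numpy as jnp
import brainstate

def g(t):
    # R -> R^3 curve; the Jacobian is the tangent vector of shape (3,)
    return jnp.array([jnp.cos(t), jnp.sin(t), t ** 2])

tangent = brainstate.augment.jacfwd(g)(0.5)   # shape (3,)
```
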
+ @set_module_as("brainstate.augment")
+ def hessian(
+ func: Callable,
+ grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+ argnums: Optional[Union[int, Sequence[int]]] = None,
+ return_value: bool = False,
+ holomorphic: bool = False,
+ has_aux: Optional[bool] = None,
+ unit_aware: bool = False,
+ check_states: bool = True,
+ ) -> GradientTransform:
+ """
+ Hessian of ``func`` as a dense array.
+
+ %s
+
+ Parameters
+ ----------
+ func : callable
+ Function whose Hessian is to be computed. Its arguments at positions
+ specified by ``argnums`` should be arrays, scalars, or standard Python
+ containers thereof. It should return arrays, scalars, or standard Python
+ containers thereof.
+ grad_states : optional, ArrayCollector, sequence of ArrayType
+ The variables required to compute their gradients.
+ argnums: optional, integer or sequence of integers
+ Specifies which positional argument(s) to differentiate with respect to (default ``0``).
+ holomorphic : bool
+ Indicates whether ``func`` is promised to be holomorphic. Default False.
+ return_value : bool
+ Whether to return the function value along with the Hessian. Default False.
+ has_aux: optional, bool
+ Indicates whether ``func`` returns a pair where the first element is considered
+ the output of the mathematical function to be differentiated and the second
+ element is auxiliary data. Default False.
+ unit_aware: optional, bool
+ Whether to compute the Hessian in unit-aware mode. Default False.
+
+ Returns
+ -------
+ obj: GradientTransform
+ The transformed object.
+ """
+ return GradientTransform(
+ target=func,
+ transform=u.autograd.hessian if unit_aware else jax.hessian,
+ grad_states=grad_states,
+ argnums=argnums,
+ return_value=return_value,
+ has_aux=False if has_aux is None else has_aux,
+ transform_params=dict(holomorphic=holomorphic),
+ check_states=check_states
+ )
+
+
+ hessian.__doc__ = hessian.__doc__ % _doc_of_return
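Finally, a sketch of the `hessian` transform on a scalar quadratic form: for `f(x) = x @ A @ x` the dense Hessian is `A + A.T`. Only the `brainstate.augment.hessian` entry point shown above is assumed.

```python
# Usage sketch for hessian (dense second derivative of a scalar-valued function).
import jax.numpy as jnp
import brainstate

A = jnp.array([[2.0, 0.5],
               [0.5, 1.0]])

def quad(x):
    # scalar-valued quadratic form x^T A x
    return x @ A @ x

H = brainstate.augment.hessian(quad)(jnp.array([1.0, -1.0]))   # equals A + A.T
```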