brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -146
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -470
- brainstate/nn/_delay_test.py +238 -0
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1361
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1120
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -208
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.7.dist-info/RECORD +0 -131
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/augment/_autograd.py
CHANGED
@@ -1,778 +1,778 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""
|
17
|
-
Gradient transformations are relatively simple compared to ``vmap`` or ``pmap`` augmentations.
|
18
|
-
This is because the gradient transformations do not operate on the Jaxpr; instead, most of them are
|
19
|
-
computed at the Python level. The one exception is the ``checkpoint`` transformation,
|
20
|
-
which has been moved into the ``compile`` module.
|
21
|
-
|
22
|
-
The wrapped gradient transformations here are made possible by using the following ideas:
|
23
|
-
1. All the states to compute the gradients should be known before the transformation.
|
24
|
-
They must be provided through the ``grad_states`` argument of every gradient transformation.
|
25
|
-
2. The states that have been written in the function should be collected and updated after the function call.
|
26
|
-
We record these states during the function call and update them after the function call.
|
27
|
-
|
28
|
-
"""
|
29
|
-
|
30
|
-
from functools import wraps, partial
|
31
|
-
from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
|
32
|
-
|
33
|
-
import brainunit as u
|
34
|
-
import jax
|
35
|
-
|
36
|
-
from brainstate._state import State
|
37
|
-
from brainstate._utils import set_module_as
|
38
|
-
from brainstate.compile._make_jaxpr import StatefulFunction
|
39
|
-
from brainstate.typing import PyTree, Missing
|
40
|
-
from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
|
41
|
-
|
42
|
-
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    'GradientTransform',
    'vector_grad',
    'grad',
    'jacrev',
    'jacfwd',
    'jacobian',
    'hessian',
]
|
45
|
-
|
46
|
-
# Generic type variable.
A = TypeVar('A')

# Aliases used in the return annotation of ``GradientTransform.__call__``.
# All three are the same ``PyTree`` alias; the distinct names only document
# the role each element of the returned tuple plays.
Gradient = LossValue = AuxData = PyTree
|
50
|
-
|
51
|
-
|
52
|
-
def _jacrev(
    fun,
    argnums=0,
    holomorphic=False,
    allow_int=False,
    has_aux=False,
    return_value=False,
    unit_aware=False,
):
    """Reverse-mode Jacobian of ``fun`` with optional value / auxiliary-data returns.

    The underlying Jacobian transform is always invoked with ``has_aux=True``.
    Its auxiliary slot carries the primal output (when ``return_value`` is set),
    the user's auxiliary data (when ``has_aux`` is set), both, or ``None``.

    Args:
        fun: Function to differentiate. When ``has_aux`` is True it must return
            an ``(output, aux)`` pair.
        argnums: Positional argument index (or indices) to differentiate.
        holomorphic: Forwarded to the Jacobian transform.
        allow_int: Forwarded to the Jacobian transform.
        has_aux: Whether ``fun`` returns auxiliary data.
        return_value: Whether to also return the primal output of ``fun``.
        unit_aware: Use ``brainunit.autograd.jacrev`` instead of ``jax.jacrev``.

    Returns:
        A function returning ``jac``, ``(jac, aux)``, ``(jac, value)`` or
        ``(jac, value, aux)`` depending on ``has_aux`` / ``return_value``.
    """

    @wraps(fun)
    def fun_wrapped(*args, **kwargs):
        # Return (output, carried), where ``carried`` is what the transform
        # hands back untouched as its auxiliary output.
        if has_aux:
            value, user_aux = fun(*args, **kwargs)
            carried = (value, user_aux) if return_value else user_aux
        else:
            value = fun(*args, **kwargs)
            carried = value if return_value else None
        return value, carried

    jacobian_fn = u.autograd.jacrev if unit_aware else jax.jacrev
    transform = jacobian_fn(
        fun_wrapped,
        argnums=argnums,
        holomorphic=holomorphic,
        allow_int=allow_int,
        has_aux=True,
    )

    @wraps(fun)
    def jacfun(*args, **kwargs):
        jac, carried = transform(*args, **kwargs)
        if not return_value:
            return (jac, carried) if has_aux else jac
        if has_aux:
            value, user_aux = carried
            return jac, value, user_aux
        return jac, carried

    return jacfun
|
98
|
-
|
99
|
-
|
100
|
-
def _jacfwd(
    fun,
    argnums=0,
    holomorphic=False,
    has_aux=False,
    return_value=False,
    unit_aware=False,
):
    """Forward-mode Jacobian of ``fun`` with optional value / auxiliary-data returns.

    The underlying Jacobian transform is always invoked with ``has_aux=True``.
    Its auxiliary slot carries the primal output (when ``return_value`` is set),
    the user's auxiliary data (when ``has_aux`` is set), both, or ``None``.

    Args:
        fun: Function to differentiate. When ``has_aux`` is True it must return
            an ``(output, aux)`` pair.
        argnums: Positional argument index (or indices) to differentiate.
        holomorphic: Forwarded to the Jacobian transform.
        has_aux: Whether ``fun`` returns auxiliary data.
        return_value: Whether to also return the primal output of ``fun``.
        unit_aware: Use ``brainunit.autograd.jacfwd`` instead of ``jax.jacfwd``.

    Returns:
        A function returning ``jac``, ``(jac, aux)``, ``(jac, value)`` or
        ``(jac, value, aux)`` depending on ``has_aux`` / ``return_value``.
    """

    @wraps(fun)
    def fun_wrapped(*args, **kwargs):
        # Return (output, carried), where ``carried`` is what the transform
        # hands back untouched as its auxiliary output.
        if has_aux:
            value, user_aux = fun(*args, **kwargs)
            carried = (value, user_aux) if return_value else user_aux
        else:
            value = fun(*args, **kwargs)
            carried = value if return_value else None
        return value, carried

    jacobian_fn = u.autograd.jacfwd if unit_aware else jax.jacfwd
    transform = jacobian_fn(
        fun_wrapped,
        argnums=argnums,
        holomorphic=holomorphic,
        has_aux=True,
    )

    @wraps(fun)
    def jacfun(*args, **kwargs):
        jac, carried = transform(*args, **kwargs)
        if not return_value:
            return (jac, carried) if has_aux else jac
        if has_aux:
            value, user_aux = carried
            return jac, value, user_aux
        return jac, carried

    return jacfun
|
143
|
-
|
144
|
-
|
145
|
-
# Type of the transformation factory (e.g. ``jax.grad``) that ``GradientTransform``
# calls as ``transform(fn, argnums=..., has_aux=True, **transform_params)``.
TransformFn = Callable
|
146
|
-
|
147
|
-
|
148
|
-
class GradientTransform(PrettyRepr):
    """
    Automatic Differentiation Transformations for the ``State`` system.

    This class implements gradient transformations for functions that operate on State objects.
    It allows for flexible configuration of gradient computation with respect to specified states
    and function arguments.

    Internally the target is wrapped into a function with the positional layout
    ``(grad_vals, other_vals, *args, **kwargs)``. Here ``grad_vals`` and ``other_vals``
    are dicts mapping ``id(state)`` to the state's value. The wrapped transform
    differentiates with respect to position 0 (``grad_vals``) and/or the user
    arguments, whose indices are shifted by 2.

    Attributes:
        target (Callable): The function to be transformed.
        stateful_target (StatefulFunction): A wrapper around the target function for state management.
        raw_argnums (Optional[Union[int, Sequence[int]]]): The original argnums specified by the user.
        true_argnums (Union[int, Tuple[int, ...]]): The adjusted argnums used internally.
        return_value (bool): Whether to return the function's value along with gradients.
        has_aux (bool): Whether the function returns auxiliary data.
        check_states (bool): Whether a gradient state missing from the state trace raises an error.
    """

    __module__ = "brainstate.augment"

    def __init__(
        self,
        target: Callable,
        transform: TransformFn,
        grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
        argnums: Optional[Union[int, Sequence[int]]] = None,
        return_value: bool = False,
        has_aux: bool = False,
        transform_params: Optional[Dict[str, Any]] = None,
        check_states: bool = True,
    ):
        """
        Initialize a ``GradientTransform`` instance.

        Args:
            target (Callable): The function to be transformed.
            transform (TransformFn): The transformation function to apply.
            grad_states (Optional[Union[State, Sequence[State], Dict[str, State]]]): States to compute gradients for.
            argnums (Optional[Union[int, Sequence[int]]]): Indices of arguments to differentiate with respect to.
                When both ``argnums`` and ``grad_states`` are None, ``argnums`` defaults to ``0``.
            return_value (bool): Whether to return the function's value along with gradients.
            has_aux (bool): Whether the function returns auxiliary data.
            transform_params (Optional[Dict[str, Any]]): Additional parameters for the transformation function.
            check_states (bool): If True, raise an error when a state in ``grad_states`` is not
                found in the state trace of ``target``. If False, such states are passed in
                with their current values instead.

        Raises:
            TypeError: If any grad_states are not State instances.
        """
        # gradient variables
        if isinstance(grad_states, dict):
            # Copy into a plain ``dict`` before flattening.
            grad_states = {k: v for k, v in grad_states.items()}
        self._grad_states, self._grad_tree = jax.tree.flatten(grad_states, is_leaf=lambda x: isinstance(x, State))
        # Ordered ids (used to rebuild the gradient tree) and an id -> state map
        # (used for O(1) membership tests and state lookup).
        self._grad_state_ids = [id(v) for v in self._grad_states]
        self._grad_id_to_state = {id(v): v for v in self._grad_states}
        if any(not isinstance(v, State) for v in self._grad_states):
            raise TypeError("All grad_states must be State instances.")
        self.check_states = check_states

        # parameters
        # The wrapped function is ``f(grad_vals, other_vals, *args)``: position 0
        # holds the gradient-state values, so user argnum ``i`` becomes ``i + 2``.
        if argnums is None and len(self._grad_states) == 0:
            argnums = 0
        if argnums is None:
            assert len(self._grad_states) > 0
            _argnums = 0
        elif isinstance(argnums, int):
            _argnums = (0, argnums + 2) if len(self._grad_states) > 0 else (argnums + 2)
        else:
            assert isinstance(argnums, (tuple, list))
            _argnums = tuple(a + 2 for a in argnums)
            if len(self._grad_states) > 0:
                _argnums = (0,) + _argnums
        self.raw_argnums = argnums
        self.true_argnums = _argnums
        self.return_value = return_value
        self.has_aux = has_aux

        # target
        assert callable(target), "The target should be a callable object."
        self.target = target
        self.stateful_target = StatefulFunction(target, name='gradient')

        # transform
        # The transform is always built with ``has_aux=True``: the auxiliary slot
        # carries the function outputs and the updated state values.
        grad_setting = dict() if transform_params is None else transform_params
        if self.has_aux:
            self._transform = transform(self._fun_with_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
        else:
            self._transform = transform(self._fun_without_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)

    def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
        yield PrettyType(self.__class__.__name__)
        yield PrettyAttr("target", self.target)
        yield PrettyAttr("grad_states", self._grad_states)
        yield PrettyAttr("grad_tree", self._grad_tree)
        yield PrettyAttr("argnums", self.raw_argnums)
        yield PrettyAttr("return_value", self.return_value)
        yield PrettyAttr("has_aux", self.has_aux)
        yield PrettyAttr("transform", self._transform)

    def _split_state_vals(self, state_trace):
        """
        Split state values into gradient and non-gradient states.

        Args:
            state_trace: The state trace containing all states.

        Returns:
            Tuple[Dict, Dict]: ``(grad_vals, other_vals)``, each mapping ``id(state)``
            to the state's current value.
        """
        grad_vals = dict()
        other_vals = dict()
        missing_ids = set(self._grad_state_ids)
        for st in state_trace.states:
            id_ = id(st)
            if id_ in missing_ids:
                grad_vals[id_] = st.value
                missing_ids.remove(id_)
            else:
                other_vals[id_] = st.value

        if missing_ids:
            if self.check_states:
                err = "Some states are not found in the state trace when performing gradient transformations.\n "
                for id_ in missing_ids:
                    st = self._grad_id_to_state[id_]
                    # NOTE(review): expected to raise, so presumably only the first
                    # missing state is reported — confirm in ``State``.
                    st.raise_error_with_source_info(ValueError(err + str(st)))
            else:
                # Grad states unused by the function still need an entry so that
                # ``_return`` can build the full gradient tree.
                for id_ in missing_ids:
                    grad_vals[id_] = self._grad_id_to_state[id_].value

        return grad_vals, other_vals

    def _merge_state_vals(self, grad_vals: Dict, other_vals: Dict, state_trace):
        """
        Merge gradient and non-gradient state values back into a single list.

        Args:
            grad_vals (Dict): Dictionary of gradient state values.
            other_vals (Dict): Dictionary of non-gradient state values.
            state_trace: The state trace containing all states.

        Returns:
            List: The state values, ordered like ``state_trace.states``.
        """
        # Dict membership is O(1); scanning the id list made this O(n * m).
        grad_ids = self._grad_id_to_state
        merged = []
        for st in state_trace.states:
            id_ = id(st)
            merged.append(grad_vals[id_] if id_ in grad_ids else other_vals[id_])
        return merged

    def _call_target(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
        """
        Call the target function with the given state values and arguments.

        Args:
            grad_vals (Dict): Dictionary of gradient state values.
            other_vals (Dict): Dictionary of non-gradient state values.
            *args: Positional arguments to pass to the target function.
            **kwargs: Keyword arguments to pass to the target function.

        Returns:
            Tuple: A tuple containing updated state values and the function output.
        """
        cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
        state_trace = self.stateful_target.get_state_trace(cache)
        state_vals = self._merge_state_vals(grad_vals, other_vals, state_trace)
        state_vals, out = self.stateful_target.jaxpr_call(state_vals, *args, **kwargs)
        return state_vals, out

    def _fun_with_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
        """
        Wrapper function for target functions that return auxiliary data.

        Args:
            grad_vals (Dict): Dictionary of gradient state values.
            other_vals (Dict): Dictionary of non-gradient state values.
            *args: Positional arguments to pass to the target function.
            **kwargs: Keyword arguments to pass to the target function.

        Returns:
            Tuple: A tuple containing the primary output and a tuple of (all outputs, updated state values).
        """
        # Users should return the auxiliary data like::
        # >>> # 1. example of return one data
        # >>> return scalar_loss, data
        # >>> # 2. example of return multiple data
        # >>> return scalar_loss, (data1, data2, ...)
        state_vals, outs = self._call_target(grad_vals, other_vals, *args, **kwargs)
        return outs[0], (outs, state_vals)

    def _fun_without_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
        """
        Wrapper function for target functions that do not return auxiliary data.

        Args:
            grad_vals (Dict): Dictionary of gradient state values.
            other_vals (Dict): Dictionary of non-gradient state values.
            *args: Positional arguments to pass to the target function.
            **kwargs: Keyword arguments to pass to the target function.

        Returns:
            Tuple: A tuple containing the output and a tuple of (output, updated state values).
        """
        state_vals, out = self._call_target(grad_vals, other_vals, *args, **kwargs)
        return out, (out, state_vals)

    def _return(self, rets, state_trace):
        """
        Process and format the return values from the gradient computation.

        Args:
            rets: The raw results from the gradient computation.
            state_trace: The state trace containing all states.

        Returns:
            Union[Gradient, Tuple]: The processed gradient results, potentially including function value and/or auxiliary data.
        """
        # unpack the return values
        grads, (outputs, new_state_vals) = rets

        # write the values produced during the call back into the states
        state_trace.assign_state_vals(new_state_vals)

        # check returned grads
        if len(self._grad_states) > 0:
            # With grad states, position 0 of the differentiated args is the dict
            # of state values; it is returned alone when ``argnums`` is None.
            grads_of_states = grads if self.raw_argnums is None else grads[0]
            grads_of_states = [grads_of_states[st_id] for st_id in self._grad_state_ids]
            if self.raw_argnums is None:
                grads = self._grad_tree.unflatten(grads_of_states)
            else:
                var_grads = self._grad_tree.unflatten(grads_of_states)
                arg_grads = grads[1] if isinstance(self.raw_argnums, int) else grads[1:]
                grads = (var_grads, arg_grads)

        # check returned value
        if self.return_value:
            # check aux
            if self.has_aux:
                return grads, outputs[0], outputs[1]
            else:
                return grads, outputs
        else:
            # check aux
            if self.has_aux:
                return grads, outputs[1]
            else:
                return grads

    def __call__(
        self, *args, **kwargs
    ) -> (
        Gradient |
        Tuple[Gradient, LossValue] |
        Tuple[Gradient, AuxData] |
        Tuple[Gradient, LossValue, AuxData]
    ):
        """
        Compute gradients by calling the transformed function.

        Args:
            *args: Positional arguments to pass to the target function.
            **kwargs: Keyword arguments to pass to the target function.

        Returns:
            Union[Gradient, Tuple]: The computed gradients, potentially including function value and/or auxiliary data.
        """

        # TODO: support jax.disable_jit()

        # compute the model
        self.stateful_target.make_jaxpr(*args, **kwargs)
        cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)

        # apply the gradient transformation
        state_trace = self.stateful_target.get_state_trace(cache)
        rets = self._transform(*self._split_state_vals(state_trace), *args, **kwargs)

        # analyze and return the results
        return self._return(rets, state_trace)
|
426
|
-
|
427
|
-
|
428
|
-
# Shared description of the return values of ``grad``, ``vector_grad`` and the
# Jacobian transforms. It is interpolated into their docstrings via ``%``, so it
# must not be edited independently of the ``GradientTransform._return`` logic.
_doc_of_return = '''

1. When ``grad_states`` is None
    - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
    - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
    - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
    - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
2. When ``grad_states`` is not None and ``argnums`` is None
    - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
    - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
    - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
    - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
3. When ``grad_states`` is not None and ``argnums`` is not None
    - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
    - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
    - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
    - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.

'''
|
447
|
-
|
448
|
-
|
449
|
-
@set_module_as("brainstate.augment")
def grad(
    fun: Callable = Missing(),
    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    holomorphic: Optional[bool] = False,
    allow_int: Optional[bool] = False,
    has_aux: Optional[bool] = None,
    return_value: Optional[bool] = False,
    unit_aware: bool = False,
    check_states: bool = True,
) -> GradientTransform | Callable[[Callable], GradientTransform]:
    """
    Compute the gradient of a scalar-valued function with respect to its arguments.

    %s

    Args:
        fun: callable. The scalar-valued function to be differentiated. If it is
            omitted, ``grad`` returns a decorator that takes the function and
            builds the ``GradientTransform``.
        grad_states: (State, Sequence[State], Dict[str, State]) optional. The states
            used in ``fun`` whose gradients are computed.
        argnums: (int or tuple of ints) optional. Specifies which positional
            argument(s) to differentiate with respect to. When both ``argnums`` and
            ``grad_states`` are None, it defaults to ``0``.
        holomorphic: (bool) optional. Whether fun is promised to be holomorphic.
            Default False.
        allow_int: (bool) optional. Whether to allow differentiating with respect to
            integer valued inputs. The gradient of an integer input will have a trivial
            vector-space dtype (float0). Default False.
        has_aux: (bool) optional. Indicates whether fun returns a pair where the
            first element is considered the output of the mathematical function to be
            differentiated and the second element is auxiliary data. Default False
            (``None`` is treated as False).
        return_value: (bool) optional. Indicates whether to return the value of the
            function along with the gradient. Default False.
        unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
            mode (``brainunit.autograd.grad`` is used instead of ``jax.grad``). Default False.
        check_states: (bool) optional. Whether to raise an error when a state in
            ``grad_states`` is not found in the state trace of ``fun``. If False,
            such states are passed in with their current values instead. Default True.

    Returns:
        A ``GradientTransform`` computing the gradient of ``fun`` or, when ``fun``
        is omitted, a decorator producing one. The structure of its outputs depends
        on ``grad_states``, ``argnums``, ``has_aux`` and ``return_value`` as listed above.
    """

    def transform(target_fn) -> GradientTransform:
        # Single construction point shared by the direct and the decorator usage.
        return GradientTransform(
            target=target_fn,
            transform=u.autograd.grad if unit_aware else jax.grad,
            grad_states=grad_states,
            argnums=argnums,
            return_value=return_value,
            has_aux=False if has_aux is None else has_aux,
            transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
            check_states=check_states
        )

    if isinstance(fun, Missing):
        # Used as ``@grad(...)``: defer construction until the function is supplied.
        return transform
    return transform(fun)
|
520
|
-
|
521
|
-
|
522
|
-
# Fill the ``%s`` placeholder in ``grad``'s docstring with the shared return-value
# table. Docstrings are ``None`` under ``python -OO``, so skip the substitution
# there instead of raising ``TypeError`` at import time.
if grad.__doc__:
    grad.__doc__ = grad.__doc__ % _doc_of_return
|
523
|
-
|
524
|
-
|
525
|
-
@set_module_as("brainstate.augment")
def vector_grad(
    func: Callable = Missing(),
    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    return_value: bool = False,
    has_aux: Optional[bool] = None,
    unit_aware: bool = False,
    check_states: bool = True,
) -> GradientTransform | Callable[[Callable], GradientTransform]:
    """Take vector-valued gradients for function ``func``.

    Same as :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`,
    the returns in this function are different for different argument settings.

    %s

    Parameters
    ----------
    func: Callable
        Function whose gradient is to be computed. If it is omitted, a decorator
        that takes the function is returned instead.
    grad_states : optional, State, sequence of State, dict of State
        The states in ``func`` to take their gradients.
    has_aux: optional, bool
        Indicates whether ``func`` returns a pair where the
        first element is considered the output of the mathematical function to be
        differentiated and the second element is auxiliary data. Default False
        (``None`` is treated as False).
    return_value : bool
        Whether to return the loss value.
    argnums: Optional, integer or sequence of integers. Specifies which
        positional argument(s) to differentiate with respect to. When both
        ``argnums`` and ``grad_states`` are None, it defaults to ``0``.
    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
        mode. Default False.
    check_states: (bool) optional. Whether to raise an error when a state in
        ``grad_states`` is not found in the state trace of ``func``. If False,
        such states are passed in with their current values instead. Default True.

    Returns
    -------
    func : GradientTransform
        The vector gradient function (or, when ``func`` is omitted, a decorator
        producing it).
    """

    def transform(target_fn) -> GradientTransform:
        # Single construction point shared by the direct and the decorator usage.
        return GradientTransform(
            target=target_fn,
            transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
            grad_states=grad_states,
            argnums=argnums,
            return_value=return_value,
            has_aux=False if has_aux is None else has_aux,
            check_states=check_states
        )

    if isinstance(func, Missing):
        # Used as ``@vector_grad(...)``: defer construction until the function is supplied.
        return transform
    return transform(func)
|
589
|
-
|
590
|
-
|
591
|
-
# Fill the ``%s`` placeholder in ``vector_grad``'s docstring with the shared
# return-value table. Docstrings are ``None`` under ``python -OO``, so skip the
# substitution there instead of raising ``TypeError`` at import time.
if vector_grad.__doc__:
    vector_grad.__doc__ = vector_grad.__doc__ % _doc_of_return
|
592
|
-
|
593
|
-
|
594
|
-
@set_module_as("brainstate.augment")
def jacrev(
    fun: Callable,
    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    has_aux: Optional[bool] = None,
    return_value: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    unit_aware: bool = False,
    check_states: bool = True,
) -> GradientTransform:
    """
    Reverse-mode Jacobian of ``fun``, extended to stateful functions and class methods.

    This mirrors JAX's ``jacrev`` but works on functions that read or write ``State``
    objects, and can additionally return the function value (``return_value``) and
    auxiliary data (``has_aux``).

    %s

    Parameters
    ----------
    fun: Callable
        Function whose Jacobian is to be computed.
    grad_states : optional, State, sequence of State, dict of State
        The states in ``fun`` to take the Jacobian with respect to.
    argnums: optional, integer or sequence of integers
        Positional argument(s) to differentiate with respect to. When omitted it
        defaults to ``0`` only if no ``grad_states`` are given.
    has_aux: optional, bool
        Whether ``fun`` returns a pair whose first element is the output to be
        differentiated and whose second element is auxiliary data. Default False.
    return_value : bool
        Whether to also return the value of ``fun``. Default False.
    holomorphic: optional, bool
        Whether ``fun`` is promised to be holomorphic. Default False.
    allow_int: optional, bool
        Whether to allow differentiating with respect to integer valued inputs.
        The gradient of an integer input will have a trivial vector-space dtype
        (float0). Default False.
    unit_aware: (bool) optional
        Whether to compute the Jacobian in the unit-aware mode. Default False.
    check_states: bool
        Whether to raise an error when a state in ``grad_states`` is not found
        among the states traced from ``fun``. Default True.

    Returns
    -------
    fun: GradientTransform
        The transformed object.
    """
    jac_options = dict(holomorphic=holomorphic, allow_int=allow_int, unit_aware=unit_aware)
    aux_flag = False if has_aux is None else has_aux
    return GradientTransform(
        target=fun,
        transform=_jacrev,
        grad_states=grad_states,
        argnums=argnums,
        return_value=return_value,
        has_aux=aux_flag,
        transform_params=jac_options,
        check_states=check_states,
    )


jacrev.__doc__ = jacrev.__doc__ % _doc_of_return

jacobian = jacrev
|
663
|
-
|
664
|
-
|
665
|
-
@set_module_as("brainstate.augment")
def jacfwd(
    func: Callable,
    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    has_aux: Optional[bool] = None,
    return_value: bool = False,
    holomorphic: bool = False,
    unit_aware: bool = False,
    check_states: bool = True,
) -> GradientTransform:
    """Forward-mode Jacobian of ``func``, extended to stateful functions and class methods.

    This mirrors JAX's ``jacfwd`` but works on functions that read or write ``State``
    objects, and can additionally return the function value (``return_value``) and
    auxiliary data (``has_aux``).

    %s

    Parameters
    ----------
    func: Callable
        Function whose Jacobian is to be computed.
    grad_states : optional, State, sequence of State, dict of State
        The states in ``func`` to take the Jacobian with respect to.
    has_aux: optional, bool
        Whether ``func`` returns a pair whose first element is the output to be
        differentiated and whose second element is auxiliary data. Default False.
    return_value : bool
        Whether to also return the value of ``func``. Default False.
    argnums: optional, integer or sequence of integers
        Positional argument(s) to differentiate with respect to. When omitted it
        defaults to ``0`` only if no ``grad_states`` are given.
    holomorphic: optional, bool
        Whether ``func`` is promised to be holomorphic. Default False.
    unit_aware: (bool) optional
        Whether to compute the Jacobian in the unit-aware mode. Default False.
    check_states: bool
        Whether to raise an error when a state in ``grad_states`` is not found
        among the states traced from ``func``. Default True.

    Returns
    -------
    obj: GradientTransform
        The transformed object.
    """
    aux_flag = False if has_aux is None else has_aux
    fwd_options = dict(holomorphic=holomorphic, unit_aware=unit_aware)
    return GradientTransform(
        target=func,
        transform=_jacfwd,
        grad_states=grad_states,
        argnums=argnums,
        return_value=return_value,
        has_aux=aux_flag,
        transform_params=fwd_options,
        check_states=check_states,
    )


jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
|
721
|
-
|
722
|
-
|
723
|
-
@set_module_as("brainstate.augment")
def hessian(
    func: Callable,
    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    return_value: bool = False,
    holomorphic: bool = False,
    has_aux: Optional[bool] = None,
    unit_aware: bool = False,
    check_states: bool = True,
) -> GradientTransform:
    """
    Hessian of ``func`` as a dense array.

    %s

    Parameters
    ----------
    func : callable
        Function whose Hessian is to be computed. Its arguments at positions
        specified by ``argnums`` should be arrays, scalars, or standard Python
        containers thereof. It should return arrays, scalars, or standard Python
        containers thereof.
    grad_states : optional, State, sequence of State, dict of State
        The states whose Hessian is computed. Every entry must be a ``State``.
    argnums: Optional, integer or sequence of integers
        Positional argument(s) to differentiate with respect to. When omitted it
        defaults to ``0`` only if no ``grad_states`` are given; otherwise only the
        states are differentiated.
    return_value : bool
        Whether to also return the value of ``func`` (not the Hessian). Default False.
    holomorphic : bool
        Indicates whether ``func`` is promised to be holomorphic. Default False.
    has_aux: Optional, bool
        Indicates whether ``func`` returns a pair where the first element is considered
        the output of the mathematical function to be differentiated and the second
        element is auxiliary data. Default False.
    unit_aware: (bool) optional
        Whether to use ``brainunit.autograd.hessian`` (unit-aware) instead of
        ``jax.hessian``. Default False.
    check_states: bool
        Whether to raise an error when a state in ``grad_states`` is not found
        among the states traced from ``func``. Default True.

    Returns
    -------
    obj: GradientTransform
        The transformed object.
    """
    return GradientTransform(
        target=func,
        transform=u.autograd.hessian if unit_aware else jax.hessian,
        grad_states=grad_states,
        argnums=argnums,
        return_value=return_value,
        has_aux=False if has_aux is None else has_aux,
        transform_params=dict(holomorphic=holomorphic),
        check_states=check_states,
    )


hessian.__doc__ = hessian.__doc__ % _doc_of_return
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
Gradient transformations are relatively simple compared to ``vmap`` or ``pmap`` augmentations.
|
18
|
+
This is because the gradient transformations do not operate on the Jaxpr; instead, most of them are
|
19
|
+
computed in the Python level. However, there is an exception, the ``checkpoint`` transformation,
|
20
|
+
which has been moved into the ``compile`` module.
|
21
|
+
|
22
|
+
The wrapped gradient transformations here are made possible by using the following ideas:
|
23
|
+
1. All the states to compute the gradients should be known before the transformation.
|
24
|
+
They must be provided through the ``grad_states`` argument of the gradient transformations.
|
25
|
+
2. The states that have been written in the function should be collected and updated after the function call.
|
26
|
+
We record these states during the function call and update them after the function call.
|
27
|
+
|
28
|
+
"""
|
29
|
+
|
30
|
+
from functools import wraps, partial
|
31
|
+
from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
|
32
|
+
|
33
|
+
import brainunit as u
|
34
|
+
import jax
|
35
|
+
|
36
|
+
from brainstate._state import State
|
37
|
+
from brainstate._utils import set_module_as
|
38
|
+
from brainstate.compile._make_jaxpr import StatefulFunction
|
39
|
+
from brainstate.typing import PyTree, Missing
|
40
|
+
from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
|
41
|
+
|
42
|
+
__all__ = [
    'GradientTransform',
    'vector_grad',
    'grad',
    'jacrev',
    'jacfwd',
    'jacobian',
    'hessian',
]

# Generic type variable.
A = TypeVar('A')

# Readable aliases for the pieces a gradient transformation can return;
# every one of them is simply a ``PyTree``.
Gradient = LossValue = AuxData = PyTree
|
50
|
+
|
51
|
+
|
52
|
+
def _jacrev(
|
53
|
+
fun,
|
54
|
+
argnums=0,
|
55
|
+
holomorphic=False,
|
56
|
+
allow_int=False,
|
57
|
+
has_aux=False,
|
58
|
+
return_value=False,
|
59
|
+
unit_aware=False,
|
60
|
+
):
|
61
|
+
@wraps(fun)
|
62
|
+
def fun_wrapped(*args, **kwargs):
|
63
|
+
if has_aux:
|
64
|
+
y, aux = fun(*args, **kwargs)
|
65
|
+
if return_value:
|
66
|
+
return y, (y, aux)
|
67
|
+
else:
|
68
|
+
return y, aux
|
69
|
+
else:
|
70
|
+
y = fun(*args, **kwargs)
|
71
|
+
if return_value:
|
72
|
+
return y, y
|
73
|
+
else:
|
74
|
+
return y, None
|
75
|
+
|
76
|
+
if unit_aware:
|
77
|
+
transform = u.autograd.jacrev(fun_wrapped,
|
78
|
+
argnums=argnums,
|
79
|
+
holomorphic=holomorphic,
|
80
|
+
allow_int=allow_int,
|
81
|
+
has_aux=True)
|
82
|
+
else:
|
83
|
+
transform = jax.jacrev(fun_wrapped,
|
84
|
+
argnums=argnums,
|
85
|
+
holomorphic=holomorphic,
|
86
|
+
allow_int=allow_int,
|
87
|
+
has_aux=True)
|
88
|
+
|
89
|
+
@wraps(fun)
|
90
|
+
def jacfun(*args, **kwargs):
|
91
|
+
jac, aux = transform(*args, **kwargs)
|
92
|
+
if return_value:
|
93
|
+
return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
|
94
|
+
else:
|
95
|
+
return (jac, aux) if has_aux else jac
|
96
|
+
|
97
|
+
return jacfun
|
98
|
+
|
99
|
+
|
100
|
+
def _jacfwd(
|
101
|
+
fun,
|
102
|
+
argnums=0,
|
103
|
+
holomorphic=False,
|
104
|
+
has_aux=False,
|
105
|
+
return_value=False,
|
106
|
+
unit_aware=False,
|
107
|
+
):
|
108
|
+
@wraps(fun)
|
109
|
+
def fun_wrapped(*args, **kwargs):
|
110
|
+
if has_aux:
|
111
|
+
y, aux = fun(*args, **kwargs)
|
112
|
+
if return_value:
|
113
|
+
return y, (y, aux)
|
114
|
+
else:
|
115
|
+
return y, aux
|
116
|
+
else:
|
117
|
+
y = fun(*args, **kwargs)
|
118
|
+
if return_value:
|
119
|
+
return y, y
|
120
|
+
else:
|
121
|
+
return y, None
|
122
|
+
|
123
|
+
if unit_aware:
|
124
|
+
transform = u.autograd.jacfwd(fun_wrapped,
|
125
|
+
argnums=argnums,
|
126
|
+
holomorphic=holomorphic,
|
127
|
+
has_aux=True)
|
128
|
+
else:
|
129
|
+
transform = jax.jacfwd(fun_wrapped,
|
130
|
+
argnums=argnums,
|
131
|
+
holomorphic=holomorphic,
|
132
|
+
has_aux=True)
|
133
|
+
|
134
|
+
@wraps(fun)
|
135
|
+
def jacfun(*args, **kwargs):
|
136
|
+
jac, aux = transform(*args, **kwargs)
|
137
|
+
if return_value:
|
138
|
+
return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
|
139
|
+
else:
|
140
|
+
return (jac, aux) if has_aux else jac
|
141
|
+
|
142
|
+
return jacfun
|
143
|
+
|
144
|
+
|
145
|
+
TransformFn = Callable


class GradientTransform(PrettyRepr):
    """
    Automatic Differentiation Transformations for the ``State`` system.

    This class implements gradient transformations for functions that operate on State objects.
    It allows for flexible configuration of gradient computation with respect to specified states
    and function arguments.

    The JAX-style ``transform`` is applied to an internal wrapper called as
    ``wrapper(grad_vals, other_vals, *args, **kwargs)``, where ``grad_vals`` and
    ``other_vals`` are dicts mapping ``id(state)`` to the state's value. Hence the
    user-facing ``argnums`` are shifted by ``2`` internally, and position ``0``
    (``grad_vals``) is also differentiated whenever ``grad_states`` is non-empty.

    Attributes:
        target (Callable): The function to be transformed.
        stateful_target (StatefulFunction): A wrapper around the target function for state management.
        raw_argnums (Optional[Union[int, Sequence[int]]]): The argnums specified by the user
            (``0`` when neither ``argnums`` nor ``grad_states`` is given).
        true_argnums (Union[int, Tuple[int, ...]]): The adjusted argnums used internally.
        return_value (bool): Whether to return the function's value along with gradients.
        has_aux (bool): Whether the function returns auxiliary data.
        check_states (bool): Whether a ``grad_states`` entry that is absent from the
            traced states raises an error.
    """

    __module__ = "brainstate.augment"

    def __init__(
        self,
        target: Callable,
        transform: TransformFn,
        grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
        argnums: Optional[Union[int, Sequence[int]]] = None,
        return_value: bool = False,
        has_aux: bool = False,
        transform_params: Optional[Dict[str, Any]] = None,
        check_states: bool = True,
    ):
        """
        Initialize a ``GradientTransform`` instance.

        Args:
            target (Callable): The function to be transformed.
            transform (TransformFn): The transformation function to apply. It is called as
                ``transform(wrapper, argnums=..., has_aux=True, **transform_params)``.
            grad_states (Optional[Union[State, Sequence[State], Dict[str, State]]]): States to compute gradients for.
            argnums (Optional[Union[int, Sequence[int]]]): Indices of arguments to differentiate with respect to.
            return_value (bool): Whether to return the function's value along with gradients.
            has_aux (bool): Whether the function returns auxiliary data.
            transform_params (Optional[Dict[str, Any]]): Additional parameters for the transformation function.
            check_states (bool): Whether to raise an error when a state in ``grad_states``
                is not found among the states traced from ``target``.

        Raises:
            TypeError: If any grad_states are not State instances, if ``argnums`` is not
                ``None``, an int, a tuple or a list, or if ``target`` is not callable.
        """
        # gradient variables
        if isinstance(grad_states, dict):
            grad_states = {k: v for k, v in grad_states.items()}
        self._grad_states, self._grad_tree = jax.tree.flatten(
            grad_states, is_leaf=lambda x: isinstance(x, State)
        )
        # Ordered ids (used to rebuild the gradient tree) and an id -> state map
        # (used for O(1) membership tests and error reporting).
        self._grad_state_ids = [id(v) for v in self._grad_states]
        self._grad_id_to_state = {id(v): v for v in self._grad_states}
        if any(not isinstance(v, State) for v in self._grad_states):
            raise TypeError("All grad_states must be State instances.")
        self.check_states = check_states

        # parameters
        has_grad_states = len(self._grad_states) > 0
        if argnums is None and not has_grad_states:
            # Nothing else to differentiate: default to the first positional argument.
            argnums = 0
        if argnums is None:
            # Only the states are differentiated; position 0 is the ``grad_vals`` dict.
            _argnums = 0
        elif isinstance(argnums, int):
            # ``+ 2`` skips the leading ``grad_vals`` and ``other_vals`` dicts.
            _argnums = (0, argnums + 2) if has_grad_states else argnums + 2
        elif isinstance(argnums, (tuple, list)):
            _argnums = tuple(a + 2 for a in argnums)
            if has_grad_states:
                _argnums = (0,) + _argnums
        else:
            raise TypeError(
                f"argnums must be None, an int, or a tuple/list of ints, but got {argnums!r}."
            )
        self.raw_argnums = argnums
        self.true_argnums = _argnums
        self.return_value = return_value
        self.has_aux = has_aux

        # target
        if not callable(target):
            raise TypeError("The target should be a callable object.")
        self.target = target
        self.stateful_target = StatefulFunction(target, name='gradient')

        # transform
        grad_setting = dict() if transform_params is None else transform_params
        fun = self._fun_with_aux if self.has_aux else self._fun_without_aux
        # ``has_aux=True`` is always requested: the wrappers return the full output and the
        # updated state values through the auxiliary channel.
        self._transform = transform(fun, argnums=self.true_argnums, has_aux=True, **grad_setting)

    def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
        yield PrettyType(self.__class__.__name__)
        yield PrettyAttr("target", self.target)
        yield PrettyAttr("grad_states", self._grad_states)
        yield PrettyAttr("grad_tree", self._grad_tree)
        yield PrettyAttr("argnums", self.raw_argnums)
        yield PrettyAttr("return_value", self.return_value)
        yield PrettyAttr("has_aux", self.has_aux)
        yield PrettyAttr("transform", self._transform)

    def _split_state_vals(self, state_trace):
        """
        Split state values into gradient and non-gradient states.

        Args:
            state_trace: The state trace containing all states.

        Returns:
            Tuple[Dict, Dict]: Dictionaries, keyed by ``id(state)``, of gradient and
            non-gradient state values.

        Raises:
            ValueError: (via ``raise_error_with_source_info``) if ``check_states`` is True
                and some ``grad_states`` are absent from ``state_trace``.
        """
        grad_vals = dict()
        other_vals = dict()
        missing_ids = set(self._grad_state_ids)
        for st in state_trace.states:
            id_ = id(st)
            if id_ in missing_ids:
                grad_vals[id_] = st.value
                missing_ids.remove(id_)
            else:
                other_vals[id_] = st.value
        if missing_ids:
            if self.check_states:
                err = "Some states are not found in the state trace when performing gradient transformations.\n "
                for id_ in missing_ids:
                    st = self._grad_id_to_state[id_]
                    st.raise_error_with_source_info(ValueError(err + str(st)))
            else:
                # Untraced grad states still get an entry so that ``_return`` can look them up.
                for id_ in missing_ids:
                    grad_vals[id_] = self._grad_id_to_state[id_].value

        return grad_vals, other_vals

    def _merge_state_vals(self, grad_vals: Dict, other_vals: Dict, state_trace):
        """
        Merge gradient and non-gradient state values back into a single list.

        Args:
            grad_vals (Dict): Dictionary of gradient state values.
            other_vals (Dict): Dictionary of non-gradient state values.
            state_trace: The state trace containing all states.

        Returns:
            List: The merged state values, in the order of ``state_trace.states``.
        """
        res = []
        for st in state_trace.states:
            id_ = id(st)
            # Dict membership is O(1), unlike scanning the ``_grad_state_ids`` list.
            res.append(grad_vals[id_] if id_ in self._grad_id_to_state else other_vals[id_])
        return res

    def _call_target(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
        """
        Call the target function with the given state values and arguments.

        Args:
            grad_vals (Dict): Dictionary of gradient state values.
            other_vals (Dict): Dictionary of non-gradient state values.
            *args: Positional arguments to pass to the target function.
            **kwargs: Keyword arguments to pass to the target function.

        Returns:
            Tuple: A tuple containing updated state values and the function output.
        """
        cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
        state_trace = self.stateful_target.get_state_trace(cache)
        state_vals = self._merge_state_vals(grad_vals, other_vals, state_trace)
        state_vals, out = self.stateful_target.jaxpr_call(state_vals, *args, **kwargs)
        return state_vals, out

    def _fun_with_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
        """
        Wrapper function for target functions that return auxiliary data.

        Args:
            grad_vals (Dict): Dictionary of gradient state values.
            other_vals (Dict): Dictionary of non-gradient state values.
            *args: Positional arguments to pass to the target function.
            **kwargs: Keyword arguments to pass to the target function.

        Returns:
            Tuple: A tuple containing the primary output and a tuple of (all outputs, updated state values).
        """
        # Users should return the auxiliary data like::
        # >>> # 1. example of return one data
        # >>> return scalar_loss, data
        # >>> # 2. example of return multiple data
        # >>> return scalar_loss, (data1, data2, ...)
        state_vals, outs = self._call_target(grad_vals, other_vals, *args, **kwargs)
        return outs[0], (outs, state_vals)

    def _fun_without_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
        """
        Wrapper function for target functions that do not return auxiliary data.

        Args:
            grad_vals (Dict): Dictionary of gradient state values.
            other_vals (Dict): Dictionary of non-gradient state values.
            *args: Positional arguments to pass to the target function.
            **kwargs: Keyword arguments to pass to the target function.

        Returns:
            Tuple: A tuple containing the output and a tuple of (output, updated state values).
        """
        state_vals, out = self._call_target(grad_vals, other_vals, *args, **kwargs)
        return out, (out, state_vals)

    def _return(self, rets, state_trace):
        """
        Process and format the return values from the gradient computation.

        Args:
            rets: The raw results from the gradient computation.
            state_trace: The state trace containing all states.

        Returns:
            Union[Gradient, Tuple]: The processed gradient results, potentially including function value and/or auxiliary data.
        """
        # unpack the return values
        grads, (outputs, new_state_vals) = rets

        # assign new values to the states
        state_trace.assign_state_vals(new_state_vals)

        # Gradients w.r.t. the ``grad_vals`` dict are keyed by ``id(state)``; restore the
        # structure the user passed as ``grad_states``.
        if len(self._grad_states) > 0:
            grads_of_states = grads if self.raw_argnums is None else grads[0]
            grads_of_states = [grads_of_states[st_id] for st_id in self._grad_state_ids]
            var_grads = self._grad_tree.unflatten(grads_of_states)
            if self.raw_argnums is None:
                grads = var_grads
            else:
                arg_grads = grads[1] if isinstance(self.raw_argnums, int) else grads[1:]
                grads = (var_grads, arg_grads)

        # check returned value
        if self.return_value:
            if self.has_aux:
                return grads, outputs[0], outputs[1]
            return grads, outputs
        if self.has_aux:
            return grads, outputs[1]
        return grads

    def __call__(
        self, *args, **kwargs
    ) -> (
        Gradient |
        Tuple[Gradient, LossValue] |
        Tuple[Gradient, AuxData] |
        Tuple[Gradient, LossValue, AuxData]
    ):
        """
        Compute gradients by calling the transformed function.

        Args:
            *args: Positional arguments to pass to the target function.
            **kwargs: Keyword arguments to pass to the target function.

        Returns:
            Union[Gradient, Tuple]: The computed gradients, potentially including function value and/or auxiliary data.
        """

        # TODO: support jax.disable_jit()

        # compute the model
        self.stateful_target.make_jaxpr(*args, **kwargs)
        cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)

        # apply the gradient transformation
        state_trace = self.stateful_target.get_state_trace(cache)
        rets = self._transform(*self._split_state_vals(state_trace), *args, **kwargs)

        # analyze and return the results
        return self._return(rets, state_trace)
|
426
|
+
|
427
|
+
|
428
|
+
_doc_of_return = '''
|
429
|
+
|
430
|
+
1. When ``grad_states`` is None
|
431
|
+
- ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
|
432
|
+
- ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
|
433
|
+
- ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
|
434
|
+
- ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
|
435
|
+
2. When ``grad_states`` is not None and ``argnums`` is None
|
436
|
+
- ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
|
437
|
+
- ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
|
438
|
+
- ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
|
439
|
+
- ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
|
440
|
+
3. When ``grad_states`` is not None and ``argnums`` is not None
|
441
|
+
- ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
|
442
|
+
- ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
|
443
|
+
- ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
|
444
|
+
- ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
|
445
|
+
|
446
|
+
'''
|
447
|
+
|
448
|
+
|
449
|
+
@set_module_as("brainstate.augment")
def grad(
    fun: Callable = Missing(),
    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    holomorphic: Optional[bool] = False,
    allow_int: Optional[bool] = False,
    has_aux: Optional[bool] = None,
    return_value: Optional[bool] = False,
    unit_aware: bool = False,
    check_states: bool = True,
) -> GradientTransform | Callable[[Callable], GradientTransform]:
    """
    Compute the gradient of a scalar-valued function with respect to its arguments.

    %s

    Args:
      fun: callable. The scalar-valued function to be differentiated. If omitted,
        ``grad`` returns a decorator that applies the remaining settings to the
        function it decorates.
      grad_states: (State, Sequence[State], Dict[str, State]) optional. The states
        in fun to take their gradients.
      argnums: (int or tuple of ints) optional. Specifies which positional
        argument(s) to differentiate with respect to. When omitted it defaults to
        ``0`` only if no ``grad_states`` are given.
      holomorphic: (bool) optional. Whether fun is promised to be holomorphic.
        Default False.
      allow_int: (bool) optional. Whether to allow differentiating with respect to
        integer valued inputs. The gradient of an integer input will have a trivial
        vector-space dtype (float0). Default False.
      has_aux: (bool) optional. Indicates whether fun returns a pair where the
        first element is considered the output of the mathematical function to be
        differentiated and the second element is auxiliary data. Default False.
      return_value: (bool) optional. Indicates whether to return the value of the
        function along with the gradient. Default False.
      unit_aware: (bool) optional. Whether to use ``brainunit.autograd.grad``
        (unit-aware) instead of ``jax.grad``. Default False.
      check_states: (bool) optional. Whether to raise an error when a state in
        ``grad_states`` is not found among the states traced from fun.
        Default True.

    Returns:
      A ``GradientTransform`` which takes the same arguments as ``fun`` and returns
      the gradients, optionally together with the function value and/or the
      auxiliary data, following the layout described above. If ``fun`` is omitted,
      a decorator producing such a ``GradientTransform`` is returned instead.
    """

    def transform(fun) -> GradientTransform:
        # Single construction point shared by the direct-call and decorator forms.
        return GradientTransform(
            target=fun,
            transform=u.autograd.grad if unit_aware else jax.grad,
            grad_states=grad_states,
            argnums=argnums,
            return_value=return_value,
            has_aux=False if has_aux is None else has_aux,
            transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
            check_states=check_states,
        )

    if isinstance(fun, Missing):
        return transform
    return transform(fun)


grad.__doc__ = grad.__doc__ % _doc_of_return
|
523
|
+
|
524
|
+
|
525
|
+
@set_module_as("brainstate.augment")
def vector_grad(
    func: Callable = Missing(),
    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    return_value: bool = False,
    has_aux: Optional[bool] = None,
    unit_aware: bool = False,
    check_states: bool = True,
) -> GradientTransform | Callable[[Callable], GradientTransform]:
    """Take vector-valued gradients for function ``func``.

    As with :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`, the structure
    of the returned values depends on the argument settings:

    %s

    Parameters
    ----------
    func: Callable, optional
        Function whose gradient is to be computed. If omitted, ``vector_grad``
        returns a decorator that applies the remaining settings to the function
        it decorates.
    grad_states : optional, State, sequence of State, dict of State
        The states in ``func`` to take their gradients.
    argnums: Optional, integer or sequence of integers
        Positional argument(s) to differentiate with respect to. When omitted it
        defaults to ``0`` only if no ``grad_states`` are given.
    return_value : bool
        Whether to also return the value of ``func``. Default False.
    has_aux: optional, bool
        Indicates whether ``func`` returns a pair where the first element is
        considered the output of the mathematical function to be differentiated
        and the second element is auxiliary data. Default False.
    unit_aware: (bool) optional
        Forwarded to ``brainunit.autograd.vector_grad`` to select the unit-aware
        mode. Default False.
    check_states: bool
        Whether to raise an error when a state in ``grad_states`` is not found
        among the states traced from ``func``. Default True.

    Returns
    -------
    GradientTransform or Callable[[Callable], GradientTransform]
        The vector gradient function, or a decorator producing it when ``func``
        is omitted.
    """

    def transform(fun) -> GradientTransform:
        # Single construction point shared by the direct-call and decorator forms.
        return GradientTransform(
            target=fun,
            transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
            grad_states=grad_states,
            argnums=argnums,
            return_value=return_value,
            has_aux=False if has_aux is None else has_aux,
            check_states=check_states,
        )

    if isinstance(func, Missing):
        return transform
    return transform(func)


vector_grad.__doc__ = vector_grad.__doc__ % _doc_of_return
|
592
|
+
|
593
|
+
|
594
|
+
@set_module_as("brainstate.augment")
def jacrev(
    fun: Callable,
    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    has_aux: Optional[bool] = None,
    return_value: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    unit_aware: bool = False,
    check_states: bool = True,
) -> GradientTransform:
    """
    Extending automatic Jacobian (reverse-mode) of ``fun`` to classes.

    This function extends the JAX official ``jacrev`` to make automatic Jacobian
    computation on functions and class functions. Moreover, it supports returning
    the function value (``return_value``) and auxiliary data (``has_aux``).

    %s

    Parameters
    ----------
    fun : Callable
        Function whose Jacobian is to be computed.
    grad_states : State, sequence of State, or dict of str to State, optional
        The states used inside ``fun`` to take gradients with respect to.
    argnums : int or sequence of int, optional
        Specifies which positional argument(s) to differentiate with respect to.
        The value is passed unchanged to ``GradientTransform``; ``None`` (the
        default) leaves the choice to it.
    has_aux : bool, optional
        Indicates whether ``fun`` returns a pair where the first element is
        considered the output of the mathematical function to be differentiated
        and the second element is auxiliary data. ``None`` is treated as False.
    return_value : bool
        Whether to also return the loss value. Default False.
    holomorphic : bool
        Indicates whether ``fun`` is promised to be holomorphic. Default False.
    allow_int : bool
        Whether to allow differentiating with respect to integer valued inputs.
        The gradient of an integer input will have a trivial vector-space dtype
        (float0). Default False.
    unit_aware : bool
        Whether to return the gradient in the unit-aware mode. Default False.
    check_states : bool
        Forwarded to ``GradientTransform`` as ``check_states``. Default True.

    Returns
    -------
    fun : GradientTransform
        The transformed object.
    """
    return GradientTransform(
        target=fun,
        transform=_jacrev,
        grad_states=grad_states,
        argnums=argnums,
        return_value=return_value,
        has_aux=False if has_aux is None else has_aux,
        # Reverse-mode specific options are forwarded to ``_jacrev``.
        transform_params=dict(
            holomorphic=holomorphic,
            allow_int=allow_int,
            unit_aware=unit_aware,
        ),
        check_states=check_states,
    )


# Fill the shared return-value section into the ``%s`` placeholder. Docstrings
# are stripped under ``python -OO`` (``__doc__`` is None), so guard the format.
if jacrev.__doc__ is not None:
    jacrev.__doc__ = jacrev.__doc__ % _doc_of_return

# Public alias: ``jacobian`` is the reverse-mode Jacobian.
jacobian = jacrev
|
663
|
+
|
664
|
+
|
665
|
+
@set_module_as("brainstate.augment")
def jacfwd(
    func: Callable,
    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    has_aux: Optional[bool] = None,
    return_value: bool = False,
    holomorphic: bool = False,
    unit_aware: bool = False,
    check_states: bool = True,
) -> GradientTransform:
    """Extending automatic Jacobian (forward-mode) of ``func`` to classes.

    This function extends the JAX official ``jacfwd`` to make automatic Jacobian
    computation on functions and class functions. Moreover, it supports returning
    the function value (``return_value``) and auxiliary data (``has_aux``).

    %s

    Parameters
    ----------
    func : Callable
        Function whose Jacobian is to be computed.
    grad_states : State, sequence of State, or dict of str to State, optional
        The states used inside ``func`` to take gradients with respect to.
    argnums : int or sequence of int, optional
        Specifies which positional argument(s) to differentiate with respect to.
        The value is passed unchanged to ``GradientTransform``; ``None`` (the
        default) leaves the choice to it.
    has_aux : bool, optional
        Indicates whether ``func`` returns a pair where the first element is
        considered the output of the mathematical function to be differentiated
        and the second element is auxiliary data. ``None`` is treated as False.
    return_value : bool
        Whether to also return the loss value. Default False.
    holomorphic : bool
        Indicates whether ``func`` is promised to be holomorphic. Default False.
    unit_aware : bool
        Whether to return the gradient in the unit-aware mode. Default False.
    check_states : bool
        Forwarded to ``GradientTransform`` as ``check_states``. Default True.

    Returns
    -------
    obj : GradientTransform
        The transformed object.
    """
    return GradientTransform(
        target=func,
        transform=_jacfwd,
        grad_states=grad_states,
        argnums=argnums,
        return_value=return_value,
        has_aux=False if has_aux is None else has_aux,
        # Forward-mode specific options are forwarded to ``_jacfwd``.
        transform_params=dict(holomorphic=holomorphic, unit_aware=unit_aware),
        check_states=check_states,
    )


# Fill the shared return-value section into the ``%s`` placeholder. Docstrings
# are stripped under ``python -OO`` (``__doc__`` is None), so guard the format.
if jacfwd.__doc__ is not None:
    jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
|
721
|
+
|
722
|
+
|
723
|
+
@set_module_as("brainstate.augment")
def hessian(
    func: Callable,
    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    return_value: bool = False,
    holomorphic: bool = False,
    has_aux: Optional[bool] = None,
    unit_aware: bool = False,
    check_states: bool = True,
) -> GradientTransform:
    """
    Hessian of ``func`` as a dense array.

    %s

    Parameters
    ----------
    func : callable
        Function whose Hessian is to be computed. Its arguments at positions
        specified by ``argnums`` should be arrays, scalars, or standard Python
        containers thereof. It should return arrays, scalars, or standard Python
        containers thereof.
    grad_states : State, sequence of State, or dict of str to State, optional
        The states used inside ``func`` to take gradients with respect to.
    argnums : int or sequence of int, optional
        Specifies which positional argument(s) to differentiate with respect to.
        The value is passed unchanged to ``GradientTransform``; ``None`` (the
        default) leaves the choice to it.
    return_value : bool
        Whether to also return the loss value. Default False.
    holomorphic : bool
        Indicates whether ``func`` is promised to be holomorphic. Default False.
    has_aux : bool, optional
        Indicates whether ``func`` returns a pair where the first element is
        considered the output of the mathematical function to be differentiated
        and the second element is auxiliary data. ``None`` is treated as False.
    unit_aware : bool
        If True, the Hessian is computed with ``u.autograd.hessian`` (unit-aware
        mode); otherwise ``jax.hessian`` is used. Default False.
    check_states : bool
        Forwarded to ``GradientTransform`` as ``check_states``. Default True.

    Returns
    -------
    obj : GradientTransform
        The transformed object.
    """
    # ``unit_aware`` selects the underlying transform itself rather than being
    # forwarded through ``transform_params``.
    transform = u.autograd.hessian if unit_aware else jax.hessian
    return GradientTransform(
        target=func,
        transform=transform,
        grad_states=grad_states,
        argnums=argnums,
        return_value=return_value,
        has_aux=False if has_aux is None else has_aux,
        transform_params=dict(holomorphic=holomorphic),
        check_states=check_states,
    )


# Fill the shared return-value section into the ``%s`` placeholder. Docstrings
# are stripped under ``python -OO`` (``__doc__`` is None), so guard the format.
if hessian.__doc__ is not None:
    hessian.__doc__ = hessian.__doc__ % _doc_of_return
|