brainstate 0.1.0.post20250211__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +875 -93
- brainstate/_state_test.py +1 -3
- brainstate/augment/__init__.py +2 -2
- brainstate/augment/_autograd.py +257 -115
- brainstate/augment/_autograd_test.py +2 -3
- brainstate/augment/_eval_shape.py +3 -4
- brainstate/augment/_mapping.py +582 -62
- brainstate/augment/_mapping_test.py +114 -30
- brainstate/augment/_random.py +61 -7
- brainstate/compile/_ad_checkpoint.py +2 -3
- brainstate/compile/_conditions.py +4 -5
- brainstate/compile/_conditions_test.py +1 -2
- brainstate/compile/_error_if.py +1 -2
- brainstate/compile/_error_if_test.py +1 -2
- brainstate/compile/_jit.py +23 -16
- brainstate/compile/_jit_test.py +1 -2
- brainstate/compile/_loop_collect_return.py +18 -10
- brainstate/compile/_loop_collect_return_test.py +1 -1
- brainstate/compile/_loop_no_collection.py +5 -5
- brainstate/compile/_make_jaxpr.py +23 -21
- brainstate/compile/_make_jaxpr_test.py +1 -2
- brainstate/compile/_progress_bar.py +1 -2
- brainstate/compile/_unvmap.py +1 -0
- brainstate/compile/_util.py +4 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +1 -2
- brainstate/functional/_activations.py +1 -2
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +1 -2
- brainstate/functional/_others.py +1 -2
- brainstate/functional/_spikes.py +136 -20
- brainstate/graph/_graph_node.py +2 -43
- brainstate/graph/_graph_operation.py +4 -20
- brainstate/graph/_graph_operation_test.py +3 -4
- brainstate/init/_base.py +1 -2
- brainstate/init/_generic.py +1 -2
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +351 -48
- brainstate/nn/_collective_ops_test.py +36 -0
- brainstate/nn/_common.py +194 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
- brainstate/nn/_dyn_impl/_inputs.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
- brainstate/nn/_dyn_impl/_readout.py +2 -3
- brainstate/nn/_dyn_impl/_readout_test.py +1 -2
- brainstate/nn/_dynamics/_dynamics_base.py +2 -3
- brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +1 -2
- brainstate/nn/_elementwise/_dropout.py +6 -7
- brainstate/nn/_elementwise/_dropout_test.py +1 -2
- brainstate/nn/_elementwise/_elementwise.py +1 -2
- brainstate/nn/_exp_euler.py +1 -2
- brainstate/nn/_exp_euler_test.py +1 -2
- brainstate/nn/_interaction/_conv.py +1 -2
- brainstate/nn/_interaction/_conv_test.py +1 -0
- brainstate/nn/_interaction/_linear.py +1 -2
- brainstate/nn/_interaction/_linear_test.py +1 -2
- brainstate/nn/_interaction/_normalizations.py +1 -2
- brainstate/nn/_interaction/_poolings.py +3 -4
- brainstate/nn/_module.py +63 -19
- brainstate/nn/_module_test.py +1 -2
- brainstate/nn/metrics.py +3 -4
- brainstate/optim/_lr_scheduler.py +1 -2
- brainstate/optim/_lr_scheduler_test.py +2 -3
- brainstate/optim/_optax_optimizer_test.py +1 -2
- brainstate/optim/_sgd_optimizer.py +2 -3
- brainstate/random/_rand_funs.py +1 -2
- brainstate/random/_rand_funs_test.py +2 -3
- brainstate/random/_rand_seed.py +2 -3
- brainstate/random/_rand_seed_test.py +1 -2
- brainstate/random/_rand_state.py +3 -4
- brainstate/surrogate.py +183 -35
- brainstate/transform.py +0 -3
- brainstate/typing.py +28 -25
- brainstate/util/__init__.py +9 -7
- brainstate/util/_caller.py +1 -2
- brainstate/util/_error.py +27 -0
- brainstate/util/_others.py +60 -15
- brainstate/util/{_dict.py → _pretty_pytree.py} +108 -29
- brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
- brainstate/util/_pretty_repr.py +128 -10
- brainstate/util/_pretty_table.py +2900 -0
- brainstate/util/_struct.py +11 -11
- brainstate/util/filter.py +472 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
- brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
- brainstate/util/_filter.py +0 -178
- brainstate-0.1.0.post20250211.dist-info/RECORD +0 -124
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
brainstate/_state_test.py
CHANGED
brainstate/augment/__init__.py
CHANGED
@@ -19,12 +19,12 @@ This module includes transformations for augmenting the functionalities of JAX c

 from ._autograd import GradientTransform, grad, vector_grad, hessian, jacobian, jacrev, jacfwd
 from ._eval_shape import abstract_init
-from ._mapping import vmap, pmap, map
+from ._mapping import vmap, pmap, map, vmap_new_states
 from ._random import restore_rngs

 __all__ = [
     'GradientTransform', 'grad', 'vector_grad', 'hessian', 'jacobian', 'jacrev', 'jacfwd',
     'abstract_init',
-    'vmap', 'pmap', 'map',
+    'vmap', 'pmap', 'map', 'vmap_new_states',
     'restore_rngs',
 ]
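The user-facing change in this module is the new vmap_new_states export. A minimal import sketch, grounded only in the __all__ change above; the function's signature lives in brainstate/augment/_mapping.py and is not part of this diff, so no invocation is shown:

    # vmap_new_states is now re-exported from brainstate.augment alongside
    # the existing mapping transforms (per the updated __all__ above).
    from brainstate.augment import vmap, pmap, map, vmap_new_states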
brainstate/augment/_autograd.py
CHANGED
@@ -29,14 +29,14 @@ The wrapped gradient transformations here are made possible by using the followi

 from __future__ import annotations

-from functools import wraps, partial
-from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
-
 import brainunit as u
 import jax
+from functools import wraps, partial
+from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator

-from brainstate._state import State
+from brainstate._state import State
 from brainstate._utils import set_module_as
+from brainstate.compile._make_jaxpr import StatefulFunction
 from brainstate.typing import PyTree, Missing
 from brainstate.util import PrettyType, PrettyAttr, PrettyRepr

@@ -149,7 +149,20 @@ TransformFn = Callable
 class GradientTransform(PrettyRepr):
     """
     Automatic Differentiation Transformations for the ``State`` system.
+
+    This class implements gradient transformations for functions that operate on State objects.
+    It allows for flexible configuration of gradient computation with respect to specified states
+    and function arguments.
+
+    Attributes:
+        target (Callable): The function to be transformed.
+        stateful_target (StatefulFunction): A wrapper around the target function for state management.
+        raw_argnums (Optional[Union[int, Sequence[int]]]): The original argnums specified by the user.
+        true_argnums (Union[int, Tuple[int, ...]]): The adjusted argnums used internally.
+        return_value (bool): Whether to return the function's value along with gradients.
+        has_aux (bool): Whether the function returns auxiliary data.
     """
+
     __module__ = "brainstate.augment"

     def __init__(
@@ -162,10 +175,26 @@ class GradientTransform(PrettyRepr):
         has_aux: bool = False,
         transform_params: Optional[Dict[str, Any]] = None,
     ):
+        """
+        Initialize a ``GradientTransform`` instance.
+
+        Args:
+            target (Callable): The function to be transformed.
+            transform (TransformFn): The transformation function to apply.
+            grad_states (Optional[Union[State, Sequence[State], Dict[str, State]]]): States to compute gradients for.
+            argnums (Optional[Union[int, Sequence[int]]]): Indices of arguments to differentiate with respect to.
+            return_value (bool): Whether to return the function's value along with gradients.
+            has_aux (bool): Whether the function returns auxiliary data.
+            transform_params (Optional[Dict[str, Any]]): Additional parameters for the transformation function.
+
+        Raises:
+            TypeError: If any grad_states are not State instances.
+        """
         # gradient variables
         if isinstance(grad_states, dict):
             grad_states = {k: v for k, v in grad_states.items()}
         self._grad_states, self._grad_tree = jax.tree.flatten(grad_states)
+        self._grad_state_ids = [id(v) for v in self._grad_states]
         if any(not isinstance(v, State) for v in self._grad_states):
             raise TypeError("All grad_states must be State instances.")

@@ -176,107 +205,209 @@ class GradientTransform(PrettyRepr):
             assert len(self._grad_states) > 0
             _argnums = 0
         elif isinstance(argnums, int):
-            _argnums = (0, argnums +
+            _argnums = (0, argnums + 2) if len(self._grad_states) > 0 else (argnums + 2)
         else:
             assert isinstance(argnums, (tuple, list))
-            _argnums = tuple(a +
+            _argnums = tuple(a + 2 for a in argnums)
             if len(self._grad_states) > 0:
                 _argnums = (0,) + _argnums
-        self.
-        self.
-        self.
-        self.
+        self.raw_argnums = argnums
+        self.true_argnums = _argnums
+        self.return_value = return_value
+        self.has_aux = has_aux

         # target
+        assert callable(target), "The target should be a callable object."
         self.target = target
+        self.stateful_target = StatefulFunction(target, name='gradient')

         # transform
-
-
-
-            self._transform = transform(self._fun_with_aux, argnums=self._argnums, has_aux=True, **_grad_setting)
+        grad_setting = dict() if transform_params is None else transform_params
+        if self.has_aux:
+            self._transform = transform(self._fun_with_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
         else:
-            self._transform = transform(self._fun_without_aux, argnums=self.
+            self._transform = transform(self._fun_without_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)

     def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
         yield PrettyType(self.__class__.__name__)
         yield PrettyAttr("target", self.target)
         yield PrettyAttr("grad_states", self._grad_states)
         yield PrettyAttr("grad_tree", self._grad_tree)
-        yield PrettyAttr("argnums", self.
-        yield PrettyAttr("return_value", self.
-        yield PrettyAttr("has_aux", self.
+        yield PrettyAttr("argnums", self.raw_argnums)
+        yield PrettyAttr("return_value", self.return_value)
+        yield PrettyAttr("has_aux", self.has_aux)
         yield PrettyAttr("transform", self._transform)

-    def
-
-
-
-
-
-
-
-
-
-
-
-        for
-
+    def _split_state_vals(self, state_trace):
+        """
+        Split state values into gradient and non-gradient states.
+
+        Args:
+            state_trace: The state trace containing all states.
+
+        Returns:
+            Tuple[Dict, Dict]: A tuple of dictionaries containing gradient and non-gradient state values.
+        """
+        grad_vals = dict()
+        other_vals = dict()
+        for st in state_trace.states:
+            id_ = id(st)
+            if id_ in self._grad_state_ids:
+                grad_vals[id_] = st.value
+            else:
+                other_vals[id_] = st.value
+        return grad_vals, other_vals
+
+    def _merge_state_vals(self, grad_vals: Dict, other_vals: Dict, state_trace):
+        """
+        Merge gradient and non-gradient state values back into a single list.
+
+        Args:
+            grad_vals (Dict): Dictionary of gradient state values.
+            other_vals (Dict): Dictionary of non-gradient state values.
+            state_trace: The state trace containing all states.
+
+        Returns:
+            List: A list of merged state values.
+        """
+        res = []
+        for st in state_trace.states:
+            id_ = id(st)
+            if id_ in self._grad_state_ids:
+                res.append(grad_vals[id_])
+            else:
+                res.append(other_vals[id_])
+        return res
+
+    def _call_target(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
+        """
+        Call the target function with the given state values and arguments.
+
+        Args:
+            grad_vals (Dict): Dictionary of gradient state values.
+            other_vals (Dict): Dictionary of non-gradient state values.
+            *args: Positional arguments to pass to the target function.
+            **kwargs: Keyword arguments to pass to the target function.
+
+        Returns:
+            Tuple: A tuple containing updated state values and the function output.
+        """
+        cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
+        state_trace = self.stateful_target.get_state_trace(cache)
+        state_vals = self._merge_state_vals(grad_vals, other_vals, state_trace)
+        state_vals, out = self.stateful_target.jaxpr_call(state_vals, *args, **kwargs)
+        return state_vals, out
+
+    def _fun_with_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
+        """
+        Wrapper function for target functions that return auxiliary data.
+
+        Args:
+            grad_vals (Dict): Dictionary of gradient state values.
+            other_vals (Dict): Dictionary of non-gradient state values.
+            *args: Positional arguments to pass to the target function.
+            **kwargs: Keyword arguments to pass to the target function.
+
+        Returns:
+            Tuple: A tuple containing the primary output and a tuple of (all outputs, updated state values).
+        """
         # Users should return the auxiliary data like::
         # >>> # 1. example of return one data
         # >>> return scalar_loss, data
         # >>> # 2. example of return multiple data
         # >>> return scalar_loss, (data1, data2, ...)
-        outs = self._call_target(*args, **kwargs)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        state_vals, outs = self._call_target(grad_vals, other_vals, *args, **kwargs)
+        return outs[0], (outs, state_vals)
+
+    def _fun_without_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
+        """
+        Wrapper function for target functions that do not return auxiliary data.
+
+        Args:
+            grad_vals (Dict): Dictionary of gradient state values.
+            other_vals (Dict): Dictionary of non-gradient state values.
+            *args: Positional arguments to pass to the target function.
+            **kwargs: Keyword arguments to pass to the target function.
+
+        Returns:
+            Tuple: A tuple containing the output and a tuple of (output, updated state values).
+        """
+        state_vals, out = self._call_target(grad_vals, other_vals, *args, **kwargs)
+        return out, (out, state_vals)
+
+    def _return(self, rets, state_trace):
+        """
+        Process and format the return values from the gradient computation.
+
+        Args:
+            rets: The raw results from the gradient computation.
+            state_trace: The state trace containing all states.
+
+        Returns:
+            Union[Gradient, Tuple]: The processed gradient results, potentially including function value and/or auxiliary data.
+        """
+        # unpack the return values
+        grads, (outputs, new_state_vals) = rets
+
+        # assign new values to the states
+        state_trace.assign_state_vals(new_state_vals)

         # check returned grads
         if len(self._grad_states) > 0:
-            if self.
-
+            grads_of_states = grads if self.raw_argnums is None else grads[0]
+            grads_of_states = [grads_of_states[st_id] for st_id in self._grad_state_ids]
+            if self.raw_argnums is None:
+                grads = self._grad_tree.unflatten(grads_of_states)
             else:
-                var_grads = self._grad_tree.unflatten(
-                arg_grads = grads[1] if isinstance(self.
+                var_grads = self._grad_tree.unflatten(grads_of_states)
+                arg_grads = grads[1] if isinstance(self.raw_argnums, int) else grads[1:]
                 grads = (var_grads, arg_grads)

         # check returned value
-        if self.
+        if self.return_value:
             # check aux
-            if self.
+            if self.has_aux:
                 return grads, outputs[0], outputs[1]
             else:
                 return grads, outputs
         else:
             # check aux
-            if self.
+            if self.has_aux:
                 return grads, outputs[1]
             else:
                 return grads

     def __call__(
         self, *args, **kwargs
-    ) ->
-
-
+    ) -> (
+        Gradient |
+        Tuple[Gradient, LossValue] |
+        Tuple[Gradient, AuxData] |
+        Tuple[Gradient, LossValue, AuxData]
+    ):
+        """
+        Compute gradients by calling the transformed function.
+
+        Args:
+            *args: Positional arguments to pass to the target function.
+            **kwargs: Keyword arguments to pass to the target function.
+
+        Returns:
+            Union[Gradient, Tuple]: The computed gradients, potentially including function value and/or auxiliary data.
+        """
+
+        # TODO: support jax.disable_jit()
+
+        # compute the model
+        self.stateful_target.make_jaxpr(*args, **kwargs)
+        cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
+
+        # apply the gradient transformation
+        state_trace = self.stateful_target.get_state_trace(cache)
+        rets = self._transform(*self._split_state_vals(state_trace), *args, **kwargs)
+
+        # analyze and return the results
+        return self._return(rets, state_trace)


 _doc_of_return = '''
@@ -347,25 +478,27 @@ def grad(
     """
     if isinstance(fun, Missing):
        def transform(fun) -> GradientTransform:
-            return GradientTransform(
-
-
-
-
-
-
-
+            return GradientTransform(
+                target=fun,
+                transform=u.autograd.grad if unit_aware else jax.grad,
+                grad_states=grad_states,
+                argnums=argnums,
+                return_value=return_value,
+                has_aux=False if has_aux is None else has_aux,
+                transform_params=dict(holomorphic=holomorphic, allow_int=allow_int)
+            )

         return transform

-    return GradientTransform(
-
-
-
-
-
-
-
+    return GradientTransform(
+        target=fun,
+        transform=u.autograd.grad if unit_aware else jax.grad,
+        grad_states=grad_states,
+        argnums=argnums,
+        return_value=return_value,
+        has_aux=False if has_aux is None else has_aux,
+        transform_params=dict(holomorphic=holomorphic, allow_int=allow_int)
+    )


 grad.__doc__ = grad.__doc__ % _doc_of_return
@@ -412,22 +545,26 @@ def vector_grad(

     if isinstance(func, Missing):
         def transform(fun) -> GradientTransform:
-            return GradientTransform(
-
-
-
-
-
+            return GradientTransform(
+                target=fun,
+                transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
+                grad_states=grad_states,
+                argnums=argnums,
+                return_value=return_value,
+                has_aux=False if has_aux is None else has_aux
+            )

         return transform

     else:
-        return GradientTransform(
-
-
-
-
-
+        return GradientTransform(
+            target=func,
+            transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
+            grad_states=grad_states,
+            argnums=argnums,
+            return_value=return_value,
+            has_aux=False if has_aux is None else has_aux
+        )


 vector_grad.__doc__ = vector_grad.__doc__ % _doc_of_return
@@ -484,15 +621,17 @@ def jacrev(
     fun: GradientTransform
       The transformed object.
     """
-    return GradientTransform(
-
-
-
-
-
-
-
-
+    return GradientTransform(
+        target=fun,
+        transform=_jacrev,
+        grad_states=grad_states,
+        argnums=argnums,
+        return_value=return_value,
+        has_aux=False if has_aux is None else has_aux,
+        transform_params=dict(holomorphic=holomorphic,
+                              allow_int=allow_int,
+                              unit_aware=unit_aware, )
+    )


 jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
@@ -542,14 +681,15 @@ def jacfwd(
       The transformed object.
     """

-    return GradientTransform(
-
-
-
-
-
-
-
+    return GradientTransform(
+        target=func,
+        transform=_jacfwd,
+        grad_states=grad_states,
+        argnums=argnums,
+        return_value=return_value,
+        has_aux=False if has_aux is None else has_aux,
+        transform_params=dict(holomorphic=holomorphic, unit_aware=unit_aware)
+    )


 jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
@@ -597,13 +737,15 @@ def hessian(
     obj: ObjectTransform
       The transformed object.
     """
-    return GradientTransform(
-
-
-
-
-
-
+    return GradientTransform(
+        target=func,
+        transform=u.autograd.hessian if unit_aware else jax.hessian,
+        grad_states=grad_states,
+        argnums=argnums,
+        return_value=return_value,
+        has_aux=False if has_aux is None else has_aux,
+        transform_params=dict(holomorphic=holomorphic)
+    )


 hessian.__doc__ = hessian.__doc__ % _doc_of_return
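To put the refactored GradientTransform machinery above in context, here is a minimal usage sketch of brainstate.augment.grad over a single State. The ParamState, loss function, and input are hypothetical; the grad_states/return_value keywords and the (grads, value) return pair follow the wrapper signatures and __call__ return types shown in this diff:

    import jax.numpy as jnp
    import brainstate as bst

    w = bst.ParamState(jnp.ones(3))          # hypothetical trainable state

    def loss(x):
        # the state is read inside the function; StatefulFunction traces it
        return jnp.sum((w.value * x) ** 2)

    # grad(...) builds a GradientTransform; calling it traces the function,
    # applies the JAX gradient transform, and writes updated state values
    # back before returning the results.
    f_grad = bst.augment.grad(loss, grad_states=w, return_value=True)
    grads, value = f_grad(jnp.arange(3.0))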
brainstate/augment/_autograd_test.py
CHANGED
@@ -16,13 +16,12 @@
 # -*- coding: utf-8 -*-
 from __future__ import annotations

-import unittest
-from pprint import pprint
-
 import brainunit as u
 import jax
 import jax.numpy as jnp
 import pytest
+import unittest
+from pprint import pprint

 import brainstate as bst
 from brainstate.augment._autograd import _jacfwd
brainstate/augment/_eval_shape.py
CHANGED
@@ -16,12 +16,11 @@
 from __future__ import annotations

 import functools
-from typing import Any, TypeVar, Callable, Sequence, Union
-
 import jax
+from typing import Any, TypeVar, Callable, Sequence, Union

+from brainstate import random
 from brainstate.graph import Node, flatten, unflatten
-from brainstate.random import DEFAULT, RandomState
 from ._random import restore_rngs

 __all__ = [
@@ -34,7 +33,7 @@ A = TypeVar('A')
 def abstract_init(
     fn: Callable[..., A],
     *args: Any,
-    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
+    rngs: Union[random.RandomState, Sequence[random.RandomState]] = random.DEFAULT,
     **kwargs: Any,
 ) -> A:
     """
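The _eval_shape.py change above only moves the rngs typing and default onto the brainstate.random namespace. A minimal sketch of calling abstract_init with an explicit RandomState; the nn.Linear module and its sizes are hypothetical, and only the rngs keyword follows the updated signature shown above:

    import brainstate as bst

    # Build the module abstractly (shapes and dtypes only, no real arrays),
    # passing an explicit RandomState as permitted by the new annotation.
    abstract_net = bst.augment.abstract_init(
        lambda: bst.nn.Linear(4, 8),
        rngs=bst.random.RandomState(0),
    )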