brainstate 0.1.0.post20250211__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries, and is provided for informational purposes only.
Files changed (96)
  1. brainstate/_state.py +875 -93
  2. brainstate/_state_test.py +1 -3
  3. brainstate/augment/__init__.py +2 -2
  4. brainstate/augment/_autograd.py +257 -115
  5. brainstate/augment/_autograd_test.py +2 -3
  6. brainstate/augment/_eval_shape.py +3 -4
  7. brainstate/augment/_mapping.py +582 -62
  8. brainstate/augment/_mapping_test.py +114 -30
  9. brainstate/augment/_random.py +61 -7
  10. brainstate/compile/_ad_checkpoint.py +2 -3
  11. brainstate/compile/_conditions.py +4 -5
  12. brainstate/compile/_conditions_test.py +1 -2
  13. brainstate/compile/_error_if.py +1 -2
  14. brainstate/compile/_error_if_test.py +1 -2
  15. brainstate/compile/_jit.py +23 -16
  16. brainstate/compile/_jit_test.py +1 -2
  17. brainstate/compile/_loop_collect_return.py +18 -10
  18. brainstate/compile/_loop_collect_return_test.py +1 -1
  19. brainstate/compile/_loop_no_collection.py +5 -5
  20. brainstate/compile/_make_jaxpr.py +23 -21
  21. brainstate/compile/_make_jaxpr_test.py +1 -2
  22. brainstate/compile/_progress_bar.py +1 -2
  23. brainstate/compile/_unvmap.py +1 -0
  24. brainstate/compile/_util.py +4 -2
  25. brainstate/environ.py +4 -4
  26. brainstate/environ_test.py +1 -2
  27. brainstate/functional/_activations.py +1 -2
  28. brainstate/functional/_activations_test.py +1 -1
  29. brainstate/functional/_normalization.py +1 -2
  30. brainstate/functional/_others.py +1 -2
  31. brainstate/functional/_spikes.py +136 -20
  32. brainstate/graph/_graph_node.py +2 -43
  33. brainstate/graph/_graph_operation.py +4 -20
  34. brainstate/graph/_graph_operation_test.py +3 -4
  35. brainstate/init/_base.py +1 -2
  36. brainstate/init/_generic.py +1 -2
  37. brainstate/nn/__init__.py +4 -0
  38. brainstate/nn/_collective_ops.py +351 -48
  39. brainstate/nn/_collective_ops_test.py +36 -0
  40. brainstate/nn/_common.py +194 -0
  41. brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
  42. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
  43. brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
  44. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
  45. brainstate/nn/_dyn_impl/_inputs.py +1 -2
  46. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
  47. brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
  48. brainstate/nn/_dyn_impl/_readout.py +2 -3
  49. brainstate/nn/_dyn_impl/_readout_test.py +1 -2
  50. brainstate/nn/_dynamics/_dynamics_base.py +2 -3
  51. brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
  52. brainstate/nn/_dynamics/_state_delay.py +3 -3
  53. brainstate/nn/_dynamics/_synouts_test.py +1 -2
  54. brainstate/nn/_elementwise/_dropout.py +6 -7
  55. brainstate/nn/_elementwise/_dropout_test.py +1 -2
  56. brainstate/nn/_elementwise/_elementwise.py +1 -2
  57. brainstate/nn/_exp_euler.py +1 -2
  58. brainstate/nn/_exp_euler_test.py +1 -2
  59. brainstate/nn/_interaction/_conv.py +1 -2
  60. brainstate/nn/_interaction/_conv_test.py +1 -0
  61. brainstate/nn/_interaction/_linear.py +1 -2
  62. brainstate/nn/_interaction/_linear_test.py +1 -2
  63. brainstate/nn/_interaction/_normalizations.py +1 -2
  64. brainstate/nn/_interaction/_poolings.py +3 -4
  65. brainstate/nn/_module.py +63 -19
  66. brainstate/nn/_module_test.py +1 -2
  67. brainstate/nn/metrics.py +3 -4
  68. brainstate/optim/_lr_scheduler.py +1 -2
  69. brainstate/optim/_lr_scheduler_test.py +2 -3
  70. brainstate/optim/_optax_optimizer_test.py +1 -2
  71. brainstate/optim/_sgd_optimizer.py +2 -3
  72. brainstate/random/_rand_funs.py +1 -2
  73. brainstate/random/_rand_funs_test.py +2 -3
  74. brainstate/random/_rand_seed.py +2 -3
  75. brainstate/random/_rand_seed_test.py +1 -2
  76. brainstate/random/_rand_state.py +3 -4
  77. brainstate/surrogate.py +183 -35
  78. brainstate/transform.py +0 -3
  79. brainstate/typing.py +28 -25
  80. brainstate/util/__init__.py +9 -7
  81. brainstate/util/_caller.py +1 -2
  82. brainstate/util/_error.py +27 -0
  83. brainstate/util/_others.py +60 -15
  84. brainstate/util/{_dict.py → _pretty_pytree.py} +108 -29
  85. brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
  86. brainstate/util/_pretty_repr.py +128 -10
  87. brainstate/util/_pretty_table.py +2900 -0
  88. brainstate/util/_struct.py +11 -11
  89. brainstate/util/filter.py +472 -0
  90. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
  91. brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
  92. brainstate/util/_filter.py +0 -178
  93. brainstate-0.1.0.post20250211.dist-info/RECORD +0 -124
  94. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
  95. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
  96. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
brainstate/_state_test.py CHANGED
@@ -14,10 +14,8 @@
  # ==============================================================================
 
 
- import unittest
-
- import jax
  import jax.numpy as jnp
+ import unittest
 
  import brainstate as bst
 
brainstate/augment/__init__.py CHANGED
@@ -19,12 +19,12 @@ This module includes transformations for augmenting the functionalities of JAX c
 
  from ._autograd import GradientTransform, grad, vector_grad, hessian, jacobian, jacrev, jacfwd
  from ._eval_shape import abstract_init
- from ._mapping import vmap, pmap, map
+ from ._mapping import vmap, pmap, map, vmap_new_states
  from ._random import restore_rngs
 
  __all__ = [
  'GradientTransform', 'grad', 'vector_grad', 'hessian', 'jacobian', 'jacrev', 'jacfwd',
  'abstract_init',
- 'vmap', 'pmap', 'map',
+ 'vmap', 'pmap', 'map', 'vmap_new_states',
  'restore_rngs',
  ]
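
The new `vmap_new_states` transform is now re-exported at the `brainstate.augment` package level alongside `vmap`, `pmap`, and `map`. A minimal import check (a sketch only; the semantics of `vmap_new_states` are defined in `brainstate/augment/_mapping.py`, which also changes in this release):

```python
# Sketch: verify the newly exported symbol is available at the package level
# (assumes the 0.1.0.post20250216 wheel is installed).
from brainstate.augment import vmap, pmap, map, vmap_new_states

print(callable(vmap_new_states))  # expected: True
```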
brainstate/augment/_autograd.py CHANGED
@@ -29,14 +29,14 @@ The wrapped gradient transformations here are made possible by using the followi
 
  from __future__ import annotations
 
- from functools import wraps, partial
- from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
-
  import brainunit as u
  import jax
+ from functools import wraps, partial
+ from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
 
- from brainstate._state import State, StateTraceStack
+ from brainstate._state import State
  from brainstate._utils import set_module_as
+ from brainstate.compile._make_jaxpr import StatefulFunction
  from brainstate.typing import PyTree, Missing
  from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
 
@@ -149,7 +149,20 @@ TransformFn = Callable
  class GradientTransform(PrettyRepr):
  """
  Automatic Differentiation Transformations for the ``State`` system.
+
+ This class implements gradient transformations for functions that operate on State objects.
+ It allows for flexible configuration of gradient computation with respect to specified states
+ and function arguments.
+
+ Attributes:
+ target (Callable): The function to be transformed.
+ stateful_target (StatefulFunction): A wrapper around the target function for state management.
+ raw_argnums (Optional[Union[int, Sequence[int]]]): The original argnums specified by the user.
+ true_argnums (Union[int, Tuple[int, ...]]): The adjusted argnums used internally.
+ return_value (bool): Whether to return the function's value along with gradients.
+ has_aux (bool): Whether the function returns auxiliary data.
  """
+
  __module__ = "brainstate.augment"
 
  def __init__(
@@ -162,10 +175,26 @@ class GradientTransform(PrettyRepr):
  has_aux: bool = False,
  transform_params: Optional[Dict[str, Any]] = None,
  ):
+ """
+ Initialize a ``GradientTransform`` instance.
+
+ Args:
+ target (Callable): The function to be transformed.
+ transform (TransformFn): The transformation function to apply.
+ grad_states (Optional[Union[State, Sequence[State], Dict[str, State]]]): States to compute gradients for.
+ argnums (Optional[Union[int, Sequence[int]]]): Indices of arguments to differentiate with respect to.
+ return_value (bool): Whether to return the function's value along with gradients.
+ has_aux (bool): Whether the function returns auxiliary data.
+ transform_params (Optional[Dict[str, Any]]): Additional parameters for the transformation function.
+
+ Raises:
+ TypeError: If any grad_states are not State instances.
+ """
  # gradient variables
  if isinstance(grad_states, dict):
  grad_states = {k: v for k, v in grad_states.items()}
  self._grad_states, self._grad_tree = jax.tree.flatten(grad_states)
+ self._grad_state_ids = [id(v) for v in self._grad_states]
  if any(not isinstance(v, State) for v in self._grad_states):
  raise TypeError("All grad_states must be State instances.")
 
@@ -176,107 +205,209 @@ class GradientTransform(PrettyRepr):
  assert len(self._grad_states) > 0
  _argnums = 0
  elif isinstance(argnums, int):
- _argnums = (0, argnums + 1) if len(self._grad_states) > 0 else (argnums + 1)
+ _argnums = (0, argnums + 2) if len(self._grad_states) > 0 else (argnums + 2)
  else:
  assert isinstance(argnums, (tuple, list))
- _argnums = tuple(a + 1 for a in argnums)
+ _argnums = tuple(a + 2 for a in argnums)
  if len(self._grad_states) > 0:
  _argnums = (0,) + _argnums
- self._nonvar_argnums = argnums
- self._argnums = _argnums
- self._return_value = return_value
- self._has_aux = has_aux
+ self.raw_argnums = argnums
+ self.true_argnums = _argnums
+ self.return_value = return_value
+ self.has_aux = has_aux
 
  # target
+ assert callable(target), "The target should be a callable object."
  self.target = target
+ self.stateful_target = StatefulFunction(target, name='gradient')
 
  # transform
- self._states_to_be_written: Tuple[State, ...] = None
- _grad_setting = dict() if transform_params is None else transform_params
- if self._has_aux:
- self._transform = transform(self._fun_with_aux, argnums=self._argnums, has_aux=True, **_grad_setting)
+ grad_setting = dict() if transform_params is None else transform_params
+ if self.has_aux:
+ self._transform = transform(self._fun_with_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
  else:
- self._transform = transform(self._fun_without_aux, argnums=self._argnums, has_aux=True, **_grad_setting)
+ self._transform = transform(self._fun_without_aux, argnums=self.true_argnums, has_aux=True, **grad_setting)
 
  def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
  yield PrettyType(self.__class__.__name__)
  yield PrettyAttr("target", self.target)
  yield PrettyAttr("grad_states", self._grad_states)
  yield PrettyAttr("grad_tree", self._grad_tree)
- yield PrettyAttr("argnums", self._nonvar_argnums)
- yield PrettyAttr("return_value", self._return_value)
- yield PrettyAttr("has_aux", self._has_aux)
+ yield PrettyAttr("argnums", self.raw_argnums)
+ yield PrettyAttr("return_value", self.return_value)
+ yield PrettyAttr("has_aux", self.has_aux)
  yield PrettyAttr("transform", self._transform)
 
- def _call_target(self, *args, **kwargs):
- if self._states_to_be_written is None:
- with StateTraceStack() as stack:
- output = self.target(*args, **kwargs)
- # grad_ids = set([id(v) for v in self._grad_states])
- # self._states_to_be_written = [st for st in stack.get_write_states() if id(st) not in grad_ids]
- self._states_to_be_written = [st for st in stack.get_write_states()]
- else:
- output = self.target(*args, **kwargs)
- return output
-
- def _fun_with_aux(self, grad_values: tuple, *args, **kwargs):
- for v, d in zip(self._grad_states, grad_values):
- v.restore_value(d)
+ def _split_state_vals(self, state_trace):
+ """
+ Split state values into gradient and non-gradient states.
+
+ Args:
+ state_trace: The state trace containing all states.
+
+ Returns:
+ Tuple[Dict, Dict]: A tuple of dictionaries containing gradient and non-gradient state values.
+ """
+ grad_vals = dict()
+ other_vals = dict()
+ for st in state_trace.states:
+ id_ = id(st)
+ if id_ in self._grad_state_ids:
+ grad_vals[id_] = st.value
+ else:
+ other_vals[id_] = st.value
+ return grad_vals, other_vals
+
+ def _merge_state_vals(self, grad_vals: Dict, other_vals: Dict, state_trace):
+ """
+ Merge gradient and non-gradient state values back into a single list.
+
+ Args:
+ grad_vals (Dict): Dictionary of gradient state values.
+ other_vals (Dict): Dictionary of non-gradient state values.
+ state_trace: The state trace containing all states.
+
+ Returns:
+ List: A list of merged state values.
+ """
+ res = []
+ for st in state_trace.states:
+ id_ = id(st)
+ if id_ in self._grad_state_ids:
+ res.append(grad_vals[id_])
+ else:
+ res.append(other_vals[id_])
+ return res
+
+ def _call_target(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
+ """
+ Call the target function with the given state values and arguments.
+
+ Args:
+ grad_vals (Dict): Dictionary of gradient state values.
+ other_vals (Dict): Dictionary of non-gradient state values.
+ *args: Positional arguments to pass to the target function.
+ **kwargs: Keyword arguments to pass to the target function.
+
+ Returns:
+ Tuple: A tuple containing updated state values and the function output.
+ """
+ cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
+ state_trace = self.stateful_target.get_state_trace(cache)
+ state_vals = self._merge_state_vals(grad_vals, other_vals, state_trace)
+ state_vals, out = self.stateful_target.jaxpr_call(state_vals, *args, **kwargs)
+ return state_vals, out
+
+ def _fun_with_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
+ """
+ Wrapper function for target functions that return auxiliary data.
+
+ Args:
+ grad_vals (Dict): Dictionary of gradient state values.
+ other_vals (Dict): Dictionary of non-gradient state values.
+ *args: Positional arguments to pass to the target function.
+ **kwargs: Keyword arguments to pass to the target function.
+
+ Returns:
+ Tuple: A tuple containing the primary output and a tuple of (all outputs, updated state values).
+ """
  # Users should return the auxiliary data like::
  # >>> # 1. example of return one data
  # >>> return scalar_loss, data
  # >>> # 2. example of return multiple data
  # >>> return scalar_loss, (data1, data2, ...)
- outs = self._call_target(*args, **kwargs)
- # outputs: [0] is the value for gradient,
- # [1] is other values for return
- assert self._states_to_be_written is not None, "The states to be written should be collected."
- return outs[0], (outs, [v.value for v in self._grad_states], [v.value for v in self._states_to_be_written])
-
- def _fun_without_aux(self, grad_values: tuple, *args, **kwargs):
- for v, d in zip(self._grad_states, grad_values):
- v.restore_value(d)
- # Users should return the scalar value like this::
- # >>> return scalar_loss
- out = self._call_target(*args, **kwargs)
- assert self._states_to_be_written is not None, "The states to be written should be collected."
- return out, (out, [v.value for v in self._grad_states], [v.value for v in self._states_to_be_written])
-
- def _return(self, rets):
- grads, (outputs, new_grad_vals, new_dyn_vals) = rets
- for i, val in enumerate(new_grad_vals):
- self._grad_states[i].restore_value(val)
- for i, val in enumerate(new_dyn_vals):
- self._states_to_be_written[i].value = val
+ state_vals, outs = self._call_target(grad_vals, other_vals, *args, **kwargs)
+ return outs[0], (outs, state_vals)
+
+ def _fun_without_aux(self, grad_vals: Dict, other_vals: Dict, *args, **kwargs):
+ """
+ Wrapper function for target functions that do not return auxiliary data.
+
+ Args:
+ grad_vals (Dict): Dictionary of gradient state values.
+ other_vals (Dict): Dictionary of non-gradient state values.
+ *args: Positional arguments to pass to the target function.
+ **kwargs: Keyword arguments to pass to the target function.
+
+ Returns:
+ Tuple: A tuple containing the output and a tuple of (output, updated state values).
+ """
+ state_vals, out = self._call_target(grad_vals, other_vals, *args, **kwargs)
+ return out, (out, state_vals)
+
+ def _return(self, rets, state_trace):
+ """
+ Process and format the return values from the gradient computation.
+
+ Args:
+ rets: The raw results from the gradient computation.
+ state_trace: The state trace containing all states.
+
+ Returns:
+ Union[Gradient, Tuple]: The processed gradient results, potentially including function value and/or auxiliary data.
+ """
+ # unpack the return values
+ grads, (outputs, new_state_vals) = rets
+
+ # assign new values to the states
+ state_trace.assign_state_vals(new_state_vals)
 
  # check returned grads
  if len(self._grad_states) > 0:
- if self._nonvar_argnums is None:
- grads = self._grad_tree.unflatten(grads)
+ grads_of_states = grads if self.raw_argnums is None else grads[0]
+ grads_of_states = [grads_of_states[st_id] for st_id in self._grad_state_ids]
+ if self.raw_argnums is None:
+ grads = self._grad_tree.unflatten(grads_of_states)
  else:
- var_grads = self._grad_tree.unflatten(grads[0])
- arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
+ var_grads = self._grad_tree.unflatten(grads_of_states)
+ arg_grads = grads[1] if isinstance(self.raw_argnums, int) else grads[1:]
  grads = (var_grads, arg_grads)
 
  # check returned value
- if self._return_value:
+ if self.return_value:
  # check aux
- if self._has_aux:
+ if self.has_aux:
  return grads, outputs[0], outputs[1]
  else:
  return grads, outputs
  else:
  # check aux
- if self._has_aux:
+ if self.has_aux:
  return grads, outputs[1]
  else:
  return grads
 
  def __call__(
  self, *args, **kwargs
- ) -> Gradient | Tuple[Gradient, LossValue] | Tuple[Gradient, AuxData] | Tuple[Gradient, LossValue, AuxData]:
- rets = self._transform([v.value for v in self._grad_states], *args, **kwargs)
- return self._return(rets)
+ ) -> (
+ Gradient |
+ Tuple[Gradient, LossValue] |
+ Tuple[Gradient, AuxData] |
+ Tuple[Gradient, LossValue, AuxData]
+ ):
+ """
+ Compute gradients by calling the transformed function.
+
+ Args:
+ *args: Positional arguments to pass to the target function.
+ **kwargs: Keyword arguments to pass to the target function.
+
+ Returns:
+ Union[Gradient, Tuple]: The computed gradients, potentially including function value and/or auxiliary data.
+ """
+
+ # TODO: support jax.disable_jit()
+
+ # compute the model
+ self.stateful_target.make_jaxpr(*args, **kwargs)
+ cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
+
+ # apply the gradient transformation
+ state_trace = self.stateful_target.get_state_trace(cache)
+ rets = self._transform(*self._split_state_vals(state_trace), *args, **kwargs)
+
+ # analyze and return the results
+ return self._return(rets, state_trace)
 
 
  _doc_of_return = '''
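
In the rewritten `GradientTransform`, the target is traced once through `StatefulFunction`, the traced states' values are split into a gradient dictionary and a non-gradient dictionary keyed by `id(state)`, the JAX transform differentiates only with respect to the gradient dictionary, and all updated values are written back afterwards via `state_trace.assign_state_vals`; nothing mutates `State` objects while JAX is tracing. The following self-contained sketch (a toy `ToyState` class and hand-rolled `split`/`merge` helpers, not brainstate's actual API) illustrates that bookkeeping with plain `jax.grad`:

```python
# Simplified sketch of the split/merge-by-id bookkeeping used above
# (toy State class; not brainstate's real State or StatefulFunction).
import jax
import jax.numpy as jnp


class ToyState:
    """Toy stand-in for a State: a mutable holder of a value."""
    def __init__(self, value):
        self.value = value


w = ToyState(jnp.ones(3))    # state to differentiate with respect to
counter = ToyState(0)        # state that is only updated, never differentiated
states = [w, counter]
grad_ids = {id(w)}


def split(states):
    # Split current state values into grad / non-grad dicts keyed by id(state).
    grad_vals, other_vals = {}, {}
    for st in states:
        (grad_vals if id(st) in grad_ids else other_vals)[id(st)] = st.value
    return grad_vals, other_vals


def merge(grad_vals, other_vals, states):
    # Re-assemble a positional list of values, one per traced state.
    return [grad_vals[id(st)] if id(st) in grad_vals else other_vals[id(st)]
            for st in states]


def fun(grad_vals, other_vals, x):
    w_val, cnt = merge(grad_vals, other_vals, states)
    loss = jnp.sum((w_val * x) ** 2)
    new_state_vals = [w_val, cnt + 1]   # updated values, aligned with `states`
    return loss, new_state_vals


x = jnp.arange(3.0)
grads, new_vals = jax.grad(fun, argnums=0, has_aux=True)(*split(states), x)
for st, val in zip(states, new_vals):   # write the updated values back once
    st.value = val
print(grads[id(w)])                      # the gradient is keyed by id(w)
```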
@@ -347,25 +478,27 @@ def grad(
  """
  if isinstance(fun, Missing):
  def transform(fun) -> GradientTransform:
- return GradientTransform(target=fun,
- transform=u.autograd.grad if unit_aware else jax.grad,
- grad_states=grad_states,
- argnums=argnums,
- return_value=return_value,
- has_aux=False if has_aux is None else has_aux,
- transform_params=dict(holomorphic=holomorphic,
- allow_int=allow_int))
+ return GradientTransform(
+ target=fun,
+ transform=u.autograd.grad if unit_aware else jax.grad,
+ grad_states=grad_states,
+ argnums=argnums,
+ return_value=return_value,
+ has_aux=False if has_aux is None else has_aux,
+ transform_params=dict(holomorphic=holomorphic, allow_int=allow_int)
+ )
 
  return transform
 
- return GradientTransform(target=fun,
- transform=u.autograd.grad if unit_aware else jax.grad,
- grad_states=grad_states,
- argnums=argnums,
- return_value=return_value,
- has_aux=False if has_aux is None else has_aux,
- transform_params=dict(holomorphic=holomorphic,
- allow_int=allow_int))
+ return GradientTransform(
+ target=fun,
+ transform=u.autograd.grad if unit_aware else jax.grad,
+ grad_states=grad_states,
+ argnums=argnums,
+ return_value=return_value,
+ has_aux=False if has_aux is None else has_aux,
+ transform_params=dict(holomorphic=holomorphic, allow_int=allow_int)
+ )
 
 
  grad.__doc__ = grad.__doc__ % _doc_of_return
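
The user-facing behaviour of `grad` is unchanged by this reformatting: it still builds a `GradientTransform`, and it can be used either directly or as a decorator factory (the `isinstance(fun, Missing)` branch). A minimal usage sketch, assuming `bst.ParamState` holds the differentiable state and that the return structure follows the `_return` logic above (state gradients first, then argument gradients, then the loss when `return_value=True`):

```python
# Usage sketch for brainstate.augment.grad (assumes brainstate is installed;
# ParamState and the exact return structure follow the diff above).
import brainstate as bst
import jax.numpy as jnp

w = bst.ParamState(jnp.ones(3))

def loss_fn(x, scale):
    return scale * jnp.sum((w.value * x) ** 2)

# Differentiate w.r.t. the state `w` and the positional argument `scale`
# (argnums=1), and also return the loss value.
grad_fn = bst.augment.grad(loss_fn, grad_states=w, argnums=1, return_value=True)
(state_grads, scale_grad), loss = grad_fn(jnp.arange(3.0), 2.0)
```

When called without a function, e.g. `@bst.augment.grad(grad_states=w)`, the `Missing` branch returns the configured transform for use as a decorator.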
@@ -412,22 +545,26 @@ def vector_grad(
 
  if isinstance(func, Missing):
  def transform(fun) -> GradientTransform:
- return GradientTransform(target=fun,
- transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
- grad_states=grad_states,
- argnums=argnums,
- return_value=return_value,
- has_aux=False if has_aux is None else has_aux)
+ return GradientTransform(
+ target=fun,
+ transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
+ grad_states=grad_states,
+ argnums=argnums,
+ return_value=return_value,
+ has_aux=False if has_aux is None else has_aux
+ )
 
  return transform
 
  else:
- return GradientTransform(target=func,
- transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
- grad_states=grad_states,
- argnums=argnums,
- return_value=return_value,
- has_aux=False if has_aux is None else has_aux)
+ return GradientTransform(
+ target=func,
+ transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
+ grad_states=grad_states,
+ argnums=argnums,
+ return_value=return_value,
+ has_aux=False if has_aux is None else has_aux
+ )
 
 
  vector_grad.__doc__ = vector_grad.__doc__ % _doc_of_return
@@ -484,15 +621,17 @@ def jacrev(
  fun: GradientTransform
  The transformed object.
  """
- return GradientTransform(target=fun,
- transform=_jacrev,
- grad_states=grad_states,
- argnums=argnums,
- return_value=return_value,
- has_aux=False if has_aux is None else has_aux,
- transform_params=dict(holomorphic=holomorphic,
- allow_int=allow_int,
- unit_aware=unit_aware, ))
+ return GradientTransform(
+ target=fun,
+ transform=_jacrev,
+ grad_states=grad_states,
+ argnums=argnums,
+ return_value=return_value,
+ has_aux=False if has_aux is None else has_aux,
+ transform_params=dict(holomorphic=holomorphic,
+ allow_int=allow_int,
+ unit_aware=unit_aware, )
+ )
 
 
  jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
@@ -542,14 +681,15 @@ def jacfwd(
  The transformed object.
  """
 
- return GradientTransform(target=func,
- transform=_jacfwd,
- grad_states=grad_states,
- argnums=argnums,
- return_value=return_value,
- has_aux=False if has_aux is None else has_aux,
- transform_params=dict(holomorphic=holomorphic,
- unit_aware=unit_aware))
+ return GradientTransform(
+ target=func,
+ transform=_jacfwd,
+ grad_states=grad_states,
+ argnums=argnums,
+ return_value=return_value,
+ has_aux=False if has_aux is None else has_aux,
+ transform_params=dict(holomorphic=holomorphic, unit_aware=unit_aware)
+ )
 
 
  jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
@@ -597,13 +737,15 @@ def hessian(
  obj: ObjectTransform
  The transformed object.
  """
- return GradientTransform(target=func,
- transform=u.autograd.hessian if unit_aware else jax.hessian,
- grad_states=grad_states,
- argnums=argnums,
- return_value=return_value,
- has_aux=False if has_aux is None else has_aux,
- transform_params=dict(holomorphic=holomorphic))
+ return GradientTransform(
+ target=func,
+ transform=u.autograd.hessian if unit_aware else jax.hessian,
+ grad_states=grad_states,
+ argnums=argnums,
+ return_value=return_value,
+ has_aux=False if has_aux is None else has_aux,
+ transform_params=dict(holomorphic=holomorphic)
+ )
 
 
  hessian.__doc__ = hessian.__doc__ % _doc_of_return
brainstate/augment/_autograd_test.py CHANGED
@@ -16,13 +16,12 @@
  # -*- coding: utf-8 -*-
  from __future__ import annotations
 
- import unittest
- from pprint import pprint
-
  import brainunit as u
  import jax
  import jax.numpy as jnp
  import pytest
+ import unittest
+ from pprint import pprint
 
  import brainstate as bst
  from brainstate.augment._autograd import _jacfwd
brainstate/augment/_eval_shape.py CHANGED
@@ -16,12 +16,11 @@
  from __future__ import annotations
 
  import functools
- from typing import Any, TypeVar, Callable, Sequence, Union
-
  import jax
+ from typing import Any, TypeVar, Callable, Sequence, Union
 
+ from brainstate import random
  from brainstate.graph import Node, flatten, unflatten
- from brainstate.random import DEFAULT, RandomState
  from ._random import restore_rngs
 
  __all__ = [
@@ -34,7 +33,7 @@ A = TypeVar('A')
  def abstract_init(
  fn: Callable[..., A],
  *args: Any,
- rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
+ rngs: Union[random.RandomState, Sequence[random.RandomState]] = random.DEFAULT,
  **kwargs: Any,
  ) -> A:
  """