brainstate 0.1.5__py2.py3-none-any.whl → 0.1.7__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_state.py +5 -5
  3. brainstate/augment/_autograd.py +31 -12
  4. brainstate/augment/_autograd_test.py +46 -46
  5. brainstate/augment/_eval_shape.py +4 -4
  6. brainstate/augment/_mapping.py +13 -8
  7. brainstate/compile/_conditions.py +2 -2
  8. brainstate/compile/_make_jaxpr.py +48 -6
  9. brainstate/compile/_progress_bar.py +2 -2
  10. brainstate/environ.py +19 -19
  11. brainstate/functional/_activations_test.py +12 -12
  12. brainstate/graph/_graph_operation.py +69 -69
  13. brainstate/graph/_graph_operation_test.py +2 -2
  14. brainstate/mixin.py +0 -17
  15. brainstate/nn/_collective_ops.py +4 -4
  16. brainstate/nn/_dropout_test.py +2 -2
  17. brainstate/nn/_dynamics.py +53 -35
  18. brainstate/nn/_elementwise.py +30 -30
  19. brainstate/nn/_linear.py +4 -4
  20. brainstate/nn/_module.py +6 -6
  21. brainstate/nn/_module_test.py +1 -1
  22. brainstate/nn/_normalizations.py +11 -11
  23. brainstate/nn/_normalizations_test.py +6 -6
  24. brainstate/nn/_poolings.py +24 -24
  25. brainstate/nn/_synapse.py +1 -12
  26. brainstate/nn/_utils.py +1 -1
  27. brainstate/nn/metrics.py +4 -4
  28. brainstate/optim/_optax_optimizer.py +8 -8
  29. brainstate/random/_rand_funs.py +37 -37
  30. brainstate/random/_rand_funs_test.py +3 -3
  31. brainstate/random/_rand_seed.py +7 -7
  32. brainstate/surrogate.py +40 -40
  33. brainstate/util/pretty_pytree.py +10 -10
  34. brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
  35. brainstate/util/struct.py +7 -7
  36. {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/METADATA +12 -12
  37. {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/RECORD +40 -40
  38. {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/WHEEL +1 -1
  39. {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/LICENSE +0 -0
  40. {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
  A ``State``-based Transformation System for Program Compilation and Augmentation
  """
 
- __version__ = "0.1.5"
+ __version__ = "0.1.7"
 
  from . import augment
  from . import compile
brainstate/_state.py CHANGED
@@ -116,10 +116,10 @@ def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
 
  Example::
 
- >>> import brainstate as bst
+ >>> import brainstate as brainstate
  >>> import jax.numpy as jnp
- >>> state = bst.ShortTermState(jnp.zeros((2, 3)))
- >>> with bst.check_state_value_tree():
+ >>> state = brainstate.ShortTermState(jnp.zeros((2, 3)))
+ >>> with brainstate.check_state_value_tree():
  >>> # The line below will not raise an error.
  >>> state.value = jnp.zeros((2, 3))
  ...
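
For reference, a self-contained version of the ``check_state_value_tree`` example above; the mismatched assignment and the exception type are illustrative assumptions, not taken from this diff::

    import jax.numpy as jnp
    import brainstate

    state = brainstate.ShortTermState(jnp.zeros((2, 3)))
    with brainstate.check_state_value_tree():
        # Same tree structure as the initial value: accepted.
        state.value = jnp.ones((2, 3))
        # A different tree structure (a tuple instead of a single array) should
        # be rejected while the check is active (exception type assumed).
        try:
            state.value = (jnp.ones(3), jnp.ones(3))
        except Exception as err:
            print('rejected:', err)
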
@@ -163,10 +163,10 @@ def check_state_jax_tracer(val: bool = True) -> Generator[None, None, None]:
  Example::
 
  >>> import jax
- >>> import brainstate as bst
+ >>> import brainstate as brainstate
  >>> import jax.numpy as jnp
  >>>
- >>> a = bst.ShortTermState(jnp.zeros((2, 3)))
+ >>> a = brainstate.ShortTermState(jnp.zeros((2, 3)))
  >>>
  >>> @jax.jit
  >>> def run_state(b):
brainstate/augment/_autograd.py CHANGED
@@ -173,6 +173,7 @@ class GradientTransform(PrettyRepr):
  return_value: bool = False,
  has_aux: bool = False,
  transform_params: Optional[Dict[str, Any]] = None,
+ check_states: bool = True,
  ):
  """
  Initialize a ``GradientTransform`` instance.
@@ -192,11 +193,12 @@ class GradientTransform(PrettyRepr):
  # gradient variables
  if isinstance(grad_states, dict):
  grad_states = {k: v for k, v in grad_states.items()}
- self._grad_states, self._grad_tree = jax.tree.flatten(grad_states)
+ self._grad_states, self._grad_tree = jax.tree.flatten(grad_states, is_leaf=lambda x: isinstance(x, State))
  self._grad_state_ids = [id(v) for v in self._grad_states]
  self._grad_id_to_state = {id(v): v for v in self._grad_states}
  if any(not isinstance(v, State) for v in self._grad_states):
  raise TypeError("All grad_states must be State instances.")
+ self.check_states = check_states
 
  # parameters
  if argnums is None and len(self._grad_states) == 0:
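
The ``is_leaf`` argument added above keeps every ``State`` instance intact when ``grad_states`` is a nested pytree (for example a dict of states). A minimal, brainstate-independent sketch of the same idea, with a hypothetical ``Box`` class standing in for ``State``::

    import jax

    class Box:
        # Hypothetical stand-in for a State-like container.
        def __init__(self, value):
            self.value = value

    tree = {'w': Box(1.0), 'b': Box(2.0)}

    # Treat every Box as a leaf so flattening never descends into it, even if
    # Box were registered as a pytree node.
    leaves, treedef = jax.tree.flatten(tree, is_leaf=lambda x: isinstance(x, Box))
    assert all(isinstance(leaf, Box) for leaf in leaves)
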
@@ -259,10 +261,15 @@
  else:
  other_vals[id_] = st.value
  if len(all_ids):
- err = f"Some states are not found in the state trace when performing gradient transformations.\n "
- for i, id_ in enumerate(all_ids):
- st = self._grad_id_to_state[id_]
- st.raise_error_with_source_info(ValueError(err + str(st)))
+ if self.check_states:
+ err = f"Some states are not found in the state trace when performing gradient transformations.\n "
+ for i, id_ in enumerate(all_ids):
+ st = self._grad_id_to_state[id_]
+ st.raise_error_with_source_info(ValueError(err + str(st)))
+ else:
+ id2state = {id(st): st for st in self._grad_states}
+ for id_ in all_ids:
+ grad_vals[id_] = id2state[id_].value
 
  return grad_vals, other_vals
 
@@ -449,6 +456,7 @@ def grad(
  has_aux: Optional[bool] = None,
  return_value: Optional[bool] = False,
  unit_aware: bool = False,
+ check_states: bool = True,
  ) -> GradientTransform | Callable[[Callable], GradientTransform]:
  """
  Compute the gradient of a scalar-valued function with respect to its arguments.
@@ -493,7 +501,8 @@ def grad(
  argnums=argnums,
  return_value=return_value,
  has_aux=False if has_aux is None else has_aux,
- transform_params=dict(holomorphic=holomorphic, allow_int=allow_int)
+ transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
+ check_states=check_states
  )
 
  return transform
@@ -505,7 +514,8 @@ def grad(
  argnums=argnums,
  return_value=return_value,
  has_aux=False if has_aux is None else has_aux,
- transform_params=dict(holomorphic=holomorphic, allow_int=allow_int)
+ transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
+ check_states=check_states
  )
 
 
@@ -520,6 +530,7 @@ def vector_grad(
  return_value: bool = False,
  has_aux: Optional[bool] = None,
  unit_aware: bool = False,
+ check_states: bool = True,
  ) -> GradientTransform | Callable[[Callable], GradientTransform]:
  """Take vector-valued gradients for function ``func``.
 
@@ -559,7 +570,8 @@ def vector_grad(
  grad_states=grad_states,
  argnums=argnums,
  return_value=return_value,
- has_aux=False if has_aux is None else has_aux
+ has_aux=False if has_aux is None else has_aux,
+ check_states=check_states
  )
 
  return transform
@@ -571,7 +583,8 @@ def vector_grad(
  grad_states=grad_states,
  argnums=argnums,
  return_value=return_value,
- has_aux=False if has_aux is None else has_aux
+ has_aux=False if has_aux is None else has_aux,
+ check_states=check_states
  )
 
 
@@ -588,6 +601,7 @@ def jacrev(
  holomorphic: bool = False,
  allow_int: bool = False,
  unit_aware: bool = False,
+ check_states: bool = True,
  ) -> GradientTransform:
  """
  Extending automatic Jacobian (reverse-mode) of ``func`` to classes.
@@ -638,7 +652,8 @@ def jacrev(
  has_aux=False if has_aux is None else has_aux,
  transform_params=dict(holomorphic=holomorphic,
  allow_int=allow_int,
- unit_aware=unit_aware, )
+ unit_aware=unit_aware, ),
+ check_states=check_states
  )
 
 
@@ -656,6 +671,7 @@ def jacfwd(
  return_value: bool = False,
  holomorphic: bool = False,
  unit_aware: bool = False,
+ check_states: bool = True,
  ) -> GradientTransform:
  """Extending automatic Jacobian (forward-mode) of ``func`` to classes.
 
@@ -696,7 +712,8 @@ def jacfwd(
  argnums=argnums,
  return_value=return_value,
  has_aux=False if has_aux is None else has_aux,
- transform_params=dict(holomorphic=holomorphic, unit_aware=unit_aware)
+ transform_params=dict(holomorphic=holomorphic, unit_aware=unit_aware),
+ check_states=check_states
  )
 
 
@@ -712,6 +729,7 @@ def hessian(
  holomorphic: bool = False,
  has_aux: Optional[bool] = None,
  unit_aware: bool = False,
+ check_states: bool = True,
  ) -> GradientTransform:
  """
  Hessian of ``func`` as a dense array.
@@ -752,7 +770,8 @@ def hessian(
  argnums=argnums,
  return_value=return_value,
  has_aux=False if has_aux is None else has_aux,
- transform_params=dict(holomorphic=holomorphic)
+ transform_params=dict(holomorphic=holomorphic),
+ check_states=check_states
  )
 
 
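The hunks above thread a new ``check_states`` flag from ``grad``, ``vector_grad``, ``jacrev``, ``jacfwd`` and ``hessian`` into ``GradientTransform``. A hedged usage sketch, assuming the semantics implied by the diff (with ``check_states=False``, grad states that the traced function never reads are tolerated instead of raising); names and values are illustrative::

    import jax.numpy as jnp
    import brainstate

    w_used = brainstate.ParamState(jnp.ones(3))
    w_unused = brainstate.ParamState(jnp.ones(3))   # never read inside loss()

    def loss():
        return (w_used.value ** 2).sum()

    # With the default check_states=True this case is expected to raise, because
    # w_unused does not show up in the state trace of `loss`.
    grads = brainstate.augment.grad(
        loss,
        grad_states={'used': w_used, 'unused': w_unused},
        check_states=False,
    )()
    print(grads)
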
brainstate/augment/_autograd_test.py CHANGED
@@ -619,7 +619,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.array([1., 2., 3.])
  # _y = jnp.array([10., 5.])
  #
- # class Test(bst.nn.Module):
+ # class Test(brainstate.nn.Module):
  # def __init__(self):
  # super(Test, self).__init__()
  # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -635,12 +635,12 @@ class TestClassFuncJacobian(unittest.TestCase):
  #
  # _jr = jax.jacfwd(f1)(_x, _y)
  # t = Test()
- # br = bst.augment.jacfwd(t, grad_states=t.x)()
+ # br = brainstate.augment.jacfwd(t, grad_states=t.x)()
  # self.assertTrue((br == _jr).all())
  #
  # _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y)
  # t = Test()
- # br = bst.augment.jacfwd(t, grad_states=[t.x, t.y])()
+ # br = brainstate.augment.jacfwd(t, grad_states=[t.x, t.y])()
  # self.assertTrue((br[0] == _jr[0]).all())
  # self.assertTrue((br[1] == _jr[1]).all())
  #
@@ -652,7 +652,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.array([1., 2., 3.])
  # _y = jnp.array([10., 5.])
  #
- # class Test(bst.nn.Module):
+ # class Test(brainstate.nn.Module):
  # def __init__(self):
  # super(Test, self).__init__()
  # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -667,12 +667,12 @@ class TestClassFuncJacobian(unittest.TestCase):
  #
  # _jr = jax.jacrev(f1)(_x, _y)
  # t = Test()
- # br = bst.augment.jacrev(t, grad_states=t.x)(_y)
+ # br = brainstate.augment.jacrev(t, grad_states=t.x)(_y)
  # self.assertTrue((br == _jr).all())
  #
  # _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y)
  # t = Test()
- # var_grads, arg_grads = bst.augment.jacrev(t, grad_states=t.x, argnums=0)(_y)
+ # var_grads, arg_grads = brainstate.augment.jacrev(t, grad_states=t.x, argnums=0)(_y)
  # print(var_grads, )
  # print(arg_grads, )
  # self.assertTrue((var_grads == _jr[0]).all())
@@ -686,7 +686,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.array([1., 2., 3.])
  # _y = jnp.array([10., 5.])
  #
- # class Test(bst.nn.Module):
+ # class Test(brainstate.nn.Module):
  # def __init__(self):
  # super(Test, self).__init__()
  # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -701,12 +701,12 @@ class TestClassFuncJacobian(unittest.TestCase):
  #
  # _jr = jax.jacfwd(f1)(_x, _y)
  # t = Test()
- # br = bst.augment.jacfwd(t, grad_states=t.x)(_y)
+ # br = brainstate.augment.jacfwd(t, grad_states=t.x)(_y)
  # self.assertTrue((br == _jr).all())
  #
  # _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y)
  # t = Test()
- # var_grads, arg_grads = bst.augment.jacfwd(t, grad_states=t.x, argnums=0)(_y)
+ # var_grads, arg_grads = brainstate.augment.jacfwd(t, grad_states=t.x, argnums=0)(_y)
  # print(var_grads, )
  # print(arg_grads, )
  # self.assertTrue((var_grads == _jr[0]).all())
@@ -722,7 +722,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.array([1., 2., 3.])
  # _y = jnp.array([10., 5.])
  #
- # class Test(bst.nn.Module):
+ # class Test(brainstate.nn.Module):
  # def __init__(self):
  # super(Test, self).__init__()
  # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -737,13 +737,13 @@ class TestClassFuncJacobian(unittest.TestCase):
  #
  # _jr = jax.jacrev(f1)(_x, _y)
  # t = Test()
- # br, _ = bst.augment.jacrev(t, grad_states=t.x, has_aux=True)(_y)
+ # br, _ = brainstate.augment.jacrev(t, grad_states=t.x, has_aux=True)(_y)
  # self.assertTrue((br == _jr).all())
  #
  # t = Test()
  # _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y)
  # _aux = t(_y)[1]
- # (var_grads, arg_grads), aux = bst.augment.jacrev(t, grad_states=t.x, argnums=0, has_aux=True)(_y)
+ # (var_grads, arg_grads), aux = brainstate.augment.jacrev(t, grad_states=t.x, argnums=0, has_aux=True)(_y)
  # print(var_grads, )
  # print(arg_grads, )
  # self.assertTrue((var_grads == _jr[0]).all())
@@ -762,7 +762,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.array([1., 2., 3.])
  # _y = jnp.array([10., 5.])
  #
- # class Test(bst.nn.Module):
+ # class Test(brainstate.nn.Module):
  # def __init__(self):
  # super(Test, self).__init__()
  # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -777,7 +777,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  #
  # _jr = jax.jacfwd(f1)(_x, _y)
  # t = Test()
- # br, (c, d) = bst.augment.jacfwd(t, grad_states=t.x, has_aux=True)(_y)
+ # br, (c, d) = brainstate.augment.jacfwd(t, grad_states=t.x, has_aux=True)(_y)
  # # print(_jr)
  # # print(br)
  # a = (br == _jr)
@@ -786,7 +786,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # t = Test()
  # _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y)
  # _aux = t(_y)[1]
- # (var_grads, arg_grads), aux = bst.augment.jacfwd(t, grad_states=t.x, argnums=0, has_aux=True)(_y)
+ # (var_grads, arg_grads), aux = brainstate.augment.jacfwd(t, grad_states=t.x, argnums=0, has_aux=True)(_y)
  # print(var_grads, )
  # print(arg_grads, )
  # self.assertTrue((var_grads == _jr[0]).all())
@@ -805,7 +805,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.array([1., 2., 3.])
  # _y = jnp.array([10., 5.])
  #
- # class Test(bst.nn.Module):
+ # class Test(brainstate.nn.Module):
  # def __init__(self):
  # super(Test, self).__init__()
  # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -820,13 +820,13 @@ class TestClassFuncJacobian(unittest.TestCase):
  #
  # _jr = jax.jacrev(f1)(_x, _y)
  # t = Test()
- # br, _ = bst.augment.jacrev(t, grad_states=t.x, has_aux=True)(_y)
+ # br, _ = brainstate.augment.jacrev(t, grad_states=t.x, has_aux=True)(_y)
  # self.assertTrue((br == _jr).all())
  #
  # t = Test()
  # _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y)
  # _val, _aux = t(_y)
- # (var_grads, arg_grads), value, aux = bst.augment.jacrev(t, grad_states=t.x, argnums=0, has_aux=True, return_value=True)(_y)
+ # (var_grads, arg_grads), value, aux = brainstate.augment.jacrev(t, grad_states=t.x, argnums=0, has_aux=True, return_value=True)(_y)
  # print(var_grads, )
  # print(arg_grads, )
  # self.assertTrue((var_grads == _jr[0]).all())
@@ -846,7 +846,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.array([1., 2., 3.])
  # _y = jnp.array([10., 5.])
  #
- # class Test(bst.nn.Module):
+ # class Test(brainstate.nn.Module):
  # def __init__(self):
  # super(Test, self).__init__()
  # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -861,13 +861,13 @@ class TestClassFuncJacobian(unittest.TestCase):
  #
  # _jr = jax.jacfwd(f1)(_x, _y)
  # t = Test()
- # br, _ = bst.augment.jacfwd(t, grad_states=t.x, has_aux=True)(_y)
+ # br, _ = brainstate.augment.jacfwd(t, grad_states=t.x, has_aux=True)(_y)
  # self.assertTrue((br == _jr).all())
  #
  # t = Test()
  # _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y)
  # _val, _aux = t(_y)
- # (var_grads, arg_grads), value, aux = bst.augment.jacfwd(t, grad_states=t.x, argnums=0, has_aux=True, return_value=True)(_y)
+ # (var_grads, arg_grads), value, aux = brainstate.augment.jacfwd(t, grad_states=t.x, argnums=0, has_aux=True, return_value=True)(_y)
  # print(_val, )
  # print('_aux: ', _aux, 'aux: ', aux)
  # print(var_grads, )
@@ -884,7 +884,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # def test1(self):
  # f = lambda x: 3 * x ** 2
  # _x = jnp.ones(10)
- # pprint(bst.augment.vector_grad(f, argnums=0)(_x))
+ # pprint(brainstate.augment.vector_grad(f, argnums=0)(_x))
  #
  # def test2(self):
  # def f(x, y):
@@ -894,14 +894,14 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.ones(5)
  # _y = jnp.ones(5)
  #
- # g = bst.augment.vector_grad(f, argnums=0)(_x, _y)
+ # g = brainstate.augment.vector_grad(f, argnums=0)(_x, _y)
  # pprint(g)
  # self.assertTrue(jnp.array_equal(g, 2 * _x))
  #
- # g = bst.augment.vector_grad(f, argnums=(0,))(_x, _y)
+ # g = brainstate.augment.vector_grad(f, argnums=(0,))(_x, _y)
  # self.assertTrue(jnp.array_equal(g[0], 2 * _x))
  #
- # g = bst.augment.vector_grad(f, argnums=(0, 1))(_x, _y)
+ # g = brainstate.augment.vector_grad(f, argnums=(0, 1))(_x, _y)
  # pprint(g)
  # self.assertTrue(jnp.array_equal(g[0], 2 * _x))
  # self.assertTrue(jnp.array_equal(g[1], 2 * _y))
@@ -915,14 +915,14 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.ones(5)
  # _y = jnp.ones(5)
  #
- # g = bst.augment.vector_grad(f, argnums=0)(_x, _y)
+ # g = brainstate.augment.vector_grad(f, argnums=0)(_x, _y)
  # # pprint(g)
  # self.assertTrue(jnp.array_equal(g, 2 * _x + 3 * _x ** 2))
  #
- # g = bst.augment.vector_grad(f, argnums=(0,))(_x, _y)
+ # g = brainstate.augment.vector_grad(f, argnums=(0,))(_x, _y)
  # self.assertTrue(jnp.array_equal(g[0], 2 * _x + 3 * _x ** 2))
  #
- # g = bst.augment.vector_grad(f, argnums=(0, 1))(_x, _y)
+ # g = brainstate.augment.vector_grad(f, argnums=(0, 1))(_x, _y)
  # # pprint(g)
  # self.assertTrue(jnp.array_equal(g[0], 2 * _x + 3 * _x ** 2))
  # self.assertTrue(jnp.array_equal(g[1], 2 * _y + 3 * _y ** 2))
@@ -935,14 +935,14 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.ones((5, 5))
  # _y = jnp.ones((5, 5))
  #
- # g = bst.augment.vector_grad(f, argnums=0)(_x, _y)
+ # g = brainstate.augment.vector_grad(f, argnums=0)(_x, _y)
  # pprint(g)
  # self.assertTrue(jnp.array_equal(g, 2 * _x))
  #
- # g = bst.augment.vector_grad(f, argnums=(0,))(_x, _y)
+ # g = brainstate.augment.vector_grad(f, argnums=(0,))(_x, _y)
  # self.assertTrue(jnp.array_equal(g[0], 2 * _x))
  #
- # g = bst.augment.vector_grad(f, argnums=(0, 1))(_x, _y)
+ # g = brainstate.augment.vector_grad(f, argnums=(0, 1))(_x, _y)
  # pprint(g)
  # self.assertTrue(jnp.array_equal(g[0], 2 * _x))
  # self.assertTrue(jnp.array_equal(g[1], 2 * _y))
@@ -956,7 +956,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.ones(5)
  # _y = jnp.ones(5)
  #
- # g, aux = bst.augment.vector_grad(f, has_aux=True)(_x, _y)
+ # g, aux = brainstate.augment.vector_grad(f, has_aux=True)(_x, _y)
  # pprint(g, )
  # pprint(aux)
  # self.assertTrue(jnp.array_equal(g, 2 * _x))
@@ -970,7 +970,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.ones(5)
  # _y = jnp.ones(5)
  #
- # g, value = bst.augment.vector_grad(f, return_value=True)(_x, _y)
+ # g, value = brainstate.augment.vector_grad(f, return_value=True)(_x, _y)
  # pprint(g, )
  # pprint(value)
  # self.assertTrue(jnp.array_equal(g, 2 * _x))
@@ -985,7 +985,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # _x = jnp.ones(5)
  # _y = jnp.ones(5)
  #
- # g, value, aux = bst.augment.vector_grad(f, has_aux=True, return_value=True)(_x, _y)
+ # g, value, aux = brainstate.augment.vector_grad(f, has_aux=True, return_value=True)(_x, _y)
  # print('grad', g)
  # print('value', value)
  # print('aux', aux)
@@ -996,7 +996,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  #
  # class TestClassFuncVectorGrad(unittest.TestCase):
  # def test1(self):
- # class Test(bst.nn.Module):
+ # class Test(brainstate.nn.Module):
  # def __init__(self):
  # super(Test, self).__init__()
  # self.x = jnp.Variable(jnp.ones(5))
@@ -1007,13 +1007,13 @@ class TestClassFuncJacobian(unittest.TestCase):
  #
  # t = Test()
  #
- # g = bst.augment.vector_grad(t, grad_states=t.x)()
+ # g = brainstate.augment.vector_grad(t, grad_states=t.x)()
  # self.assertTrue(jnp.array_equal(g, 2 * t.x))
  #
- # g = bst.augment.vector_grad(t, grad_states=(t.x,))()
+ # g = brainstate.augment.vector_grad(t, grad_states=(t.x,))()
  # self.assertTrue(jnp.array_equal(g[0], 2 * t.x))
  #
- # g = bst.augment.vector_grad(t, grad_states=(t.x, t.y))()
+ # g = brainstate.augment.vector_grad(t, grad_states=(t.x, t.y))()
  # self.assertTrue(jnp.array_equal(g[0], 2 * t.x))
  # self.assertTrue(jnp.array_equal(g[1], 2 * t.y))
  #
@@ -1025,20 +1025,20 @@ class TestClassFuncJacobian(unittest.TestCase):
  #
  # class TestDebug(parameterized.TestCase):
  # def test_debug1(self):
- # a = bst.random.RandomState()
+ # a = brainstate.random.RandomState()
  #
  # def f(b):
  # print(a.value)
  # return a + b + a.random()
  #
- # f = bst.augment.vector_grad(f, argnums=0)
+ # f = brainstate.augment.vector_grad(f, argnums=0)
  # f(1.)
  #
  # with jax.disable_jit():
  # f(1.)
  #
  # @parameterized.product(
- # grad_fun=[bst.augment.grad, bst.augment.vector_grad]
+ # grad_fun=[brainstate.augment.grad, brainstate.augment.vector_grad]
  # )
  # def test_print_info1(self, grad_fun):
  # file = tempfile.TemporaryFile(mode='w+')
@@ -1075,7 +1075,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # self.assertTrue(file.read().strip() == expect_res.strip())
  #
  # @parameterized.product(
- # grad_fun=[bst.augment.grad, bst.augment.vector_grad]
+ # grad_fun=[brainstate.augment.grad, brainstate.augment.vector_grad]
  # )
  # def test_print_info2(self, grad_fun):
  # file = tempfile.TemporaryFile(mode='w+')
@@ -1117,7 +1117,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  # a = jnp.Variable(jnp.ones(2))
  # b = jnp.Variable(jnp.zeros(2))
  #
- # @bst.augment.vector_grad(argnums=0)
+ # @brainstate.augment.vector_grad(argnums=0)
  # def f1(c):
  # a.value += 1
  # b.value += 10
@@ -1149,7 +1149,7 @@ class TestClassFuncJacobian(unittest.TestCase):
  #
  #
  # def run_fun(d):
- # @bst.augment.vector_grad(argnums=0)
+ # @brainstate.augment.vector_grad(argnums=0)
  # def f1(c):
  # a.value += d
  # b.value += 10
@@ -1178,8 +1178,8 @@ class TestClassFuncJacobian(unittest.TestCase):
  # print('compiling f ...', file=file)
  # return a + b
  #
- # grad1 = bst.augment.grad(f)(1., 2.) # call "f" twice, one for Variable finding, one for compiling
- # grad2 = bst.augment.vector_grad(f)(1., 2.) # call "f" once for compiling
+ # grad1 = brainstate.augment.grad(f)(1., 2.) # call "f" twice, one for Variable finding, one for compiling
+ # grad2 = brainstate.augment.vector_grad(f)(1., 2.) # call "f" once for compiling
  #
  # file.seek(0)
  # print(file.read().strip())
brainstate/augment/_eval_shape.py CHANGED
@@ -40,13 +40,13 @@ def abstract_init(
 
  Here's an example::
 
- >>> import brainstate as bst
+ >>> import brainstate
  >>> class MLP:
  ... def __init__(self, n_in, n_mid, n_out):
- ... self.dense1 = bst.nn.Linear(n_in, n_mid)
- ... self.dense2 = bst.nn.Linear(n_mid, n_out)
+ ... self.dense1 = brainstate.nn.Linear(n_in, n_mid)
+ ... self.dense2 = brainstate.nn.Linear(n_mid, n_out)
 
- >>> r = bst.augment.abstract_init(lambda: MLP(1, 2, 3))
+ >>> r = brainstate.augment.abstract_init(lambda: MLP(1, 2, 3))
  >>> r
  MLP(
  dense1=Linear(
brainstate/augment/_mapping.py CHANGED
@@ -481,7 +481,7 @@ def _vmap_transform(
  return out_rng_keys, out_state_vmap_vals, out_state_oth_vals, outs
 
  @functools.wraps(f)
- def vmapped_fn(*args):
+ def vmapped_fn(*args, **kwargs):
  """
  Applies vectorized mapping (vmap) to the input function while managing state.
 
@@ -503,6 +503,11 @@ def _vmap_transform(
  data structures (e.g., axis_to_in_states, in_state_to_axis) which
  should be defined in the broader context.
  """
+ if len(kwargs):
+ raise NotImplementedError(
+ "Keyword arguments `f(**kwargs)` are not supported in brainstate.augment.vmap"
+ )
+
  # in states values
  in_state_map_vals = [
  [st.value for st in states]
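
The guard above makes the restriction explicit: ``brainstate.augment.vmap`` only accepts positional arguments. A small sketch of the implied call pattern, assuming the default behaviour of mapping over axis 0 of positional arguments as in ``jax.vmap``::

    import jax.numpy as jnp
    import brainstate

    def double(x):
        return x * 2.0

    batched = brainstate.augment.vmap(double)
    print(batched(jnp.arange(4.0)))       # positional argument: supported
    # batched(x=jnp.arange(4.0))          # keyword argument: raises NotImplementedError
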
@@ -615,16 +620,16 @@ def vmap(
 
  These are several example usage::
 
- >>> import brainstate as bst
+ >>> import brainstate as brainstate
  >>> import jax.numpy as jnp
 
- >>> class Model(bst.nn.Module):
+ >>> class Model(brainstate.nn.Module):
  >>> def __init__(self):
  >>> super().__init__()
  >>>
- >>> self.a = bst.ShortTermState(bst.random.randn(5))
- >>> self.b = bst.ShortTermState(bst.random.randn(5))
- >>> self.c = bst.State(bst.random.randn(1))
+ >>> self.a = brainstate.ShortTermState(brainstate.random.randn(5))
+ >>> self.b = brainstate.ShortTermState(brainstate.random.randn(5))
+ >>> self.c = brainstate.State(brainstate.random.randn(1))
 
  >>> def __call__(self, *args, **kwargs):
  >>> self.c.value = self.a.value * self.b.value
@@ -632,9 +637,9 @@ def vmap(
 
  >>> model = Model()
 
- >>> r = bst.augment.vmap(
+ >>> r = brainstate.augment.vmap(
  >>> model,
- >>> in_states=model.states(bst.ShortTermState),
+ >>> in_states=model.states(brainstate.ShortTermState),
  >>> out_states=model.c
  >>> )()
 
brainstate/compile/_conditions.py CHANGED
@@ -203,9 +203,9 @@ def ifelse(conditions, branches, *operands, check_cond: bool = True):
  Examples
  --------
 
- >>> import brainstate as bst
+ >>> import brainstate
  >>> def f(a):
- >>> return bst.compile.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
+ >>> return brainstate.compile.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
  >>> branches=[lambda: 1,
  >>> lambda: 2,
  >>> lambda: 3,