brainstate-0.1.4-py2.py3-none-any.whl → brainstate-0.1.6-py2.py3-none-any.whl
- brainstate/__init__.py +1 -1
- brainstate/_state.py +6 -5
- brainstate/augment/_autograd.py +31 -12
- brainstate/augment/_autograd_test.py +46 -46
- brainstate/augment/_eval_shape.py +4 -4
- brainstate/augment/_mapping.py +22 -17
- brainstate/augment/_mapping_test.py +162 -0
- brainstate/compile/_conditions.py +2 -2
- brainstate/compile/_make_jaxpr.py +59 -6
- brainstate/compile/_progress_bar.py +2 -2
- brainstate/environ.py +19 -19
- brainstate/functional/_activations_test.py +12 -12
- brainstate/graph/_graph_operation.py +69 -69
- brainstate/graph/_graph_operation_test.py +2 -2
- brainstate/mixin.py +0 -17
- brainstate/nn/_collective_ops.py +4 -4
- brainstate/nn/_common.py +7 -19
- brainstate/nn/_dropout_test.py +2 -2
- brainstate/nn/_dynamics.py +53 -35
- brainstate/nn/_elementwise.py +30 -30
- brainstate/nn/_exp_euler.py +13 -16
- brainstate/nn/_inputs.py +1 -1
- brainstate/nn/_linear.py +4 -4
- brainstate/nn/_module.py +6 -6
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +11 -11
- brainstate/nn/_normalizations_test.py +6 -6
- brainstate/nn/_poolings.py +24 -24
- brainstate/nn/_synapse.py +1 -12
- brainstate/nn/_utils.py +1 -1
- brainstate/nn/metrics.py +4 -4
- brainstate/optim/_optax_optimizer.py +8 -8
- brainstate/random/_rand_funs.py +37 -37
- brainstate/random/_rand_funs_test.py +3 -3
- brainstate/random/_rand_seed.py +7 -7
- brainstate/random/_rand_state.py +13 -7
- brainstate/surrogate.py +40 -40
- brainstate/util/pretty_pytree.py +10 -10
- brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
- brainstate/util/struct.py +7 -7
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/RECORD +45 -45
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
brainstate/_state.py
CHANGED
@@ -50,6 +50,7 @@ __all__ = [
     'LongTermState',
     'HiddenState',
     'ParamState',
+    'BatchState',
     'TreefyState',
     'FakeState',
 
@@ -115,10 +116,10 @@ def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
 
     Example::
 
-      >>> import brainstate as
+      >>> import brainstate as brainstate
       >>> import jax.numpy as jnp
-      >>> state =
-      >>> with
+      >>> state = brainstate.ShortTermState(jnp.zeros((2, 3)))
+      >>> with brainstate.check_state_value_tree():
       >>> # The line below will not raise an error.
       >>> state.value = jnp.zeros((2, 3))
       ...
@@ -162,10 +163,10 @@ def check_state_jax_tracer(val: bool = True) -> Generator[None, None, None]:
 
     Example::
 
      >>> import jax
-     >>> import brainstate as
+     >>> import brainstate as brainstate
      >>> import jax.numpy as jnp
      >>>
-     >>> a =
+     >>> a = brainstate.ShortTermState(jnp.zeros((2, 3)))
      >>>
      >>> @jax.jit
      >>> def run_state(b):
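For reference, the repaired `check_state_value_tree` docstring above assembles into usage roughly like the following sketch (taken from the docstring itself; it assumes the context manager is exported at the package top level as written there):

    import jax.numpy as jnp
    import brainstate

    # A short-term state holding a (2, 3) array.
    state = brainstate.ShortTermState(jnp.zeros((2, 3)))

    with brainstate.check_state_value_tree():
        # Assigning a value with the same tree structure passes the check.
        state.value = jnp.zeros((2, 3))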
brainstate/augment/_autograd.py
CHANGED
@@ -173,6 +173,7 @@ class GradientTransform(PrettyRepr):
         return_value: bool = False,
         has_aux: bool = False,
         transform_params: Optional[Dict[str, Any]] = None,
+        check_states: bool = True,
     ):
         """
         Initialize a ``GradientTransform`` instance.
@@ -192,11 +193,12 @@ class GradientTransform(PrettyRepr):
         # gradient variables
         if isinstance(grad_states, dict):
             grad_states = {k: v for k, v in grad_states.items()}
-        self._grad_states, self._grad_tree = jax.tree.flatten(grad_states)
+        self._grad_states, self._grad_tree = jax.tree.flatten(grad_states, is_leaf=lambda x: isinstance(x, State))
         self._grad_state_ids = [id(v) for v in self._grad_states]
         self._grad_id_to_state = {id(v): v for v in self._grad_states}
         if any(not isinstance(v, State) for v in self._grad_states):
             raise TypeError("All grad_states must be State instances.")
+        self.check_states = check_states
 
         # parameters
         if argnums is None and len(self._grad_states) == 0:
@@ -259,10 +261,15 @@
             else:
                 other_vals[id_] = st.value
         if len(all_ids):
-
-
-
-
+            if self.check_states:
+                err = f"Some states are not found in the state trace when performing gradient transformations.\n "
+                for i, id_ in enumerate(all_ids):
+                    st = self._grad_id_to_state[id_]
+                    st.raise_error_with_source_info(ValueError(err + str(st)))
+            else:
+                id2state = {id(st): st for st in self._grad_states}
+                for id_ in all_ids:
+                    grad_vals[id_] = id2state[id_].value
 
         return grad_vals, other_vals
 
@@ -449,6 +456,7 @@ def grad(
     has_aux: Optional[bool] = None,
     return_value: Optional[bool] = False,
     unit_aware: bool = False,
+    check_states: bool = True,
 ) -> GradientTransform | Callable[[Callable], GradientTransform]:
     """
     Compute the gradient of a scalar-valued function with respect to its arguments.
@@ -493,7 +501,8 @@
             argnums=argnums,
             return_value=return_value,
             has_aux=False if has_aux is None else has_aux,
-            transform_params=dict(holomorphic=holomorphic, allow_int=allow_int)
+            transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
+            check_states=check_states
         )
 
         return transform
@@ -505,7 +514,8 @@
         argnums=argnums,
         return_value=return_value,
         has_aux=False if has_aux is None else has_aux,
-        transform_params=dict(holomorphic=holomorphic, allow_int=allow_int)
+        transform_params=dict(holomorphic=holomorphic, allow_int=allow_int),
+        check_states=check_states
     )
 
 
@@ -520,6 +530,7 @@ def vector_grad(
     return_value: bool = False,
     has_aux: Optional[bool] = None,
     unit_aware: bool = False,
+    check_states: bool = True,
 ) -> GradientTransform | Callable[[Callable], GradientTransform]:
     """Take vector-valued gradients for function ``func``.
 
@@ -559,7 +570,8 @@
             grad_states=grad_states,
             argnums=argnums,
             return_value=return_value,
-            has_aux=False if has_aux is None else has_aux
+            has_aux=False if has_aux is None else has_aux,
+            check_states=check_states
         )
 
         return transform
@@ -571,7 +583,8 @@
         grad_states=grad_states,
         argnums=argnums,
         return_value=return_value,
-        has_aux=False if has_aux is None else has_aux
+        has_aux=False if has_aux is None else has_aux,
+        check_states=check_states
    )
 
 
@@ -588,6 +601,7 @@ def jacrev(
     holomorphic: bool = False,
     allow_int: bool = False,
     unit_aware: bool = False,
+    check_states: bool = True,
 ) -> GradientTransform:
     """
     Extending automatic Jacobian (reverse-mode) of ``func`` to classes.
@@ -638,7 +652,8 @@
         has_aux=False if has_aux is None else has_aux,
         transform_params=dict(holomorphic=holomorphic,
                               allow_int=allow_int,
-                              unit_aware=unit_aware, )
+                              unit_aware=unit_aware, ),
+        check_states=check_states
     )
 
 
@@ -656,6 +671,7 @@ def jacfwd(
     return_value: bool = False,
     holomorphic: bool = False,
     unit_aware: bool = False,
+    check_states: bool = True,
 ) -> GradientTransform:
     """Extending automatic Jacobian (forward-mode) of ``func`` to classes.
 
@@ -696,7 +712,8 @@ def jacfwd(
         argnums=argnums,
         return_value=return_value,
         has_aux=False if has_aux is None else has_aux,
-        transform_params=dict(holomorphic=holomorphic, unit_aware=unit_aware)
+        transform_params=dict(holomorphic=holomorphic, unit_aware=unit_aware),
+        check_states=check_states
    )
 
 
@@ -712,6 +729,7 @@ def hessian(
     holomorphic: bool = False,
     has_aux: Optional[bool] = None,
     unit_aware: bool = False,
+    check_states: bool = True,
 ) -> GradientTransform:
     """
     Hessian of ``func`` as a dense array.
@@ -752,7 +770,8 @@
         argnums=argnums,
         return_value=return_value,
         has_aux=False if has_aux is None else has_aux,
-        transform_params=dict(holomorphic=holomorphic)
+        transform_params=dict(holomorphic=holomorphic),
+        check_states=check_states
    )
 
 
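The recurring edit in this file threads a new `check_states` flag from `grad`, `vector_grad`, `jacrev`, `jacfwd`, and `hessian` into `GradientTransform`: when a listed grad state never appears in the traced function, `check_states=True` (the default) raises a `ValueError` with source info, while `check_states=False` falls back to the state's current value. A rough usage sketch of the relaxed mode, assuming the public signatures shown in the hunks above (illustrative, not taken from the package):

    import jax.numpy as jnp
    import brainstate

    w = brainstate.ParamState(jnp.ones(3))    # read inside the loss
    b = brainstate.ParamState(jnp.zeros(3))   # never read inside the loss

    def loss():
        return (w.value ** 2).sum()

    # With the default check_states=True this would raise, because `b` is not
    # found in the state trace of `loss`; check_states=False tolerates it.
    grads = brainstate.augment.grad(loss, grad_states=[w, b], check_states=False)()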
brainstate/augment/_autograd_test.py
CHANGED
@@ -619,7 +619,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.array([1., 2., 3.])
 # _y = jnp.array([10., 5.])
 #
-# class Test(
+# class Test(brainstate.nn.Module):
 # def __init__(self):
 # super(Test, self).__init__()
 # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -635,12 +635,12 @@ class TestClassFuncJacobian(unittest.TestCase):
 #
 # _jr = jax.jacfwd(f1)(_x, _y)
 # t = Test()
-# br =
+# br = brainstate.augment.jacfwd(t, grad_states=t.x)()
 # self.assertTrue((br == _jr).all())
 #
 # _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y)
 # t = Test()
-# br =
+# br = brainstate.augment.jacfwd(t, grad_states=[t.x, t.y])()
 # self.assertTrue((br[0] == _jr[0]).all())
 # self.assertTrue((br[1] == _jr[1]).all())
 #
@@ -652,7 +652,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.array([1., 2., 3.])
 # _y = jnp.array([10., 5.])
 #
-# class Test(
+# class Test(brainstate.nn.Module):
 # def __init__(self):
 # super(Test, self).__init__()
 # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -667,12 +667,12 @@ class TestClassFuncJacobian(unittest.TestCase):
 #
 # _jr = jax.jacrev(f1)(_x, _y)
 # t = Test()
-# br =
+# br = brainstate.augment.jacrev(t, grad_states=t.x)(_y)
 # self.assertTrue((br == _jr).all())
 #
 # _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y)
 # t = Test()
-# var_grads, arg_grads =
+# var_grads, arg_grads = brainstate.augment.jacrev(t, grad_states=t.x, argnums=0)(_y)
 # print(var_grads, )
 # print(arg_grads, )
 # self.assertTrue((var_grads == _jr[0]).all())
@@ -686,7 +686,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.array([1., 2., 3.])
 # _y = jnp.array([10., 5.])
 #
-# class Test(
+# class Test(brainstate.nn.Module):
 # def __init__(self):
 # super(Test, self).__init__()
 # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -701,12 +701,12 @@ class TestClassFuncJacobian(unittest.TestCase):
 #
 # _jr = jax.jacfwd(f1)(_x, _y)
 # t = Test()
-# br =
+# br = brainstate.augment.jacfwd(t, grad_states=t.x)(_y)
 # self.assertTrue((br == _jr).all())
 #
 # _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y)
 # t = Test()
-# var_grads, arg_grads =
+# var_grads, arg_grads = brainstate.augment.jacfwd(t, grad_states=t.x, argnums=0)(_y)
 # print(var_grads, )
 # print(arg_grads, )
 # self.assertTrue((var_grads == _jr[0]).all())
@@ -722,7 +722,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.array([1., 2., 3.])
 # _y = jnp.array([10., 5.])
 #
-# class Test(
+# class Test(brainstate.nn.Module):
 # def __init__(self):
 # super(Test, self).__init__()
 # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -737,13 +737,13 @@ class TestClassFuncJacobian(unittest.TestCase):
 #
 # _jr = jax.jacrev(f1)(_x, _y)
 # t = Test()
-# br, _ =
+# br, _ = brainstate.augment.jacrev(t, grad_states=t.x, has_aux=True)(_y)
 # self.assertTrue((br == _jr).all())
 #
 # t = Test()
 # _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y)
 # _aux = t(_y)[1]
-# (var_grads, arg_grads), aux =
+# (var_grads, arg_grads), aux = brainstate.augment.jacrev(t, grad_states=t.x, argnums=0, has_aux=True)(_y)
 # print(var_grads, )
 # print(arg_grads, )
 # self.assertTrue((var_grads == _jr[0]).all())
@@ -762,7 +762,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.array([1., 2., 3.])
 # _y = jnp.array([10., 5.])
 #
-# class Test(
+# class Test(brainstate.nn.Module):
 # def __init__(self):
 # super(Test, self).__init__()
 # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -777,7 +777,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 #
 # _jr = jax.jacfwd(f1)(_x, _y)
 # t = Test()
-# br, (c, d) =
+# br, (c, d) = brainstate.augment.jacfwd(t, grad_states=t.x, has_aux=True)(_y)
 # # print(_jr)
 # # print(br)
 # a = (br == _jr)
@@ -786,7 +786,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # t = Test()
 # _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y)
 # _aux = t(_y)[1]
-# (var_grads, arg_grads), aux =
+# (var_grads, arg_grads), aux = brainstate.augment.jacfwd(t, grad_states=t.x, argnums=0, has_aux=True)(_y)
 # print(var_grads, )
 # print(arg_grads, )
 # self.assertTrue((var_grads == _jr[0]).all())
@@ -805,7 +805,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.array([1., 2., 3.])
 # _y = jnp.array([10., 5.])
 #
-# class Test(
+# class Test(brainstate.nn.Module):
 # def __init__(self):
 # super(Test, self).__init__()
 # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -820,13 +820,13 @@ class TestClassFuncJacobian(unittest.TestCase):
 #
 # _jr = jax.jacrev(f1)(_x, _y)
 # t = Test()
-# br, _ =
+# br, _ = brainstate.augment.jacrev(t, grad_states=t.x, has_aux=True)(_y)
 # self.assertTrue((br == _jr).all())
 #
 # t = Test()
 # _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y)
 # _val, _aux = t(_y)
-# (var_grads, arg_grads), value, aux =
+# (var_grads, arg_grads), value, aux = brainstate.augment.jacrev(t, grad_states=t.x, argnums=0, has_aux=True, return_value=True)(_y)
 # print(var_grads, )
 # print(arg_grads, )
 # self.assertTrue((var_grads == _jr[0]).all())
@@ -846,7 +846,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.array([1., 2., 3.])
 # _y = jnp.array([10., 5.])
 #
-# class Test(
+# class Test(brainstate.nn.Module):
 # def __init__(self):
 # super(Test, self).__init__()
 # self.x = jnp.Variable(jnp.array([1., 2., 3.]))
@@ -861,13 +861,13 @@ class TestClassFuncJacobian(unittest.TestCase):
 #
 # _jr = jax.jacfwd(f1)(_x, _y)
 # t = Test()
-# br, _ =
+# br, _ = brainstate.augment.jacfwd(t, grad_states=t.x, has_aux=True)(_y)
 # self.assertTrue((br == _jr).all())
 #
 # t = Test()
 # _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y)
 # _val, _aux = t(_y)
-# (var_grads, arg_grads), value, aux =
+# (var_grads, arg_grads), value, aux = brainstate.augment.jacfwd(t, grad_states=t.x, argnums=0, has_aux=True, return_value=True)(_y)
 # print(_val, )
 # print('_aux: ', _aux, 'aux: ', aux)
 # print(var_grads, )
@@ -884,7 +884,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # def test1(self):
 # f = lambda x: 3 * x ** 2
 # _x = jnp.ones(10)
-# pprint(
+# pprint(brainstate.augment.vector_grad(f, argnums=0)(_x))
 #
 # def test2(self):
 # def f(x, y):
@@ -894,14 +894,14 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.ones(5)
 # _y = jnp.ones(5)
 #
-# g =
+# g = brainstate.augment.vector_grad(f, argnums=0)(_x, _y)
 # pprint(g)
 # self.assertTrue(jnp.array_equal(g, 2 * _x))
 #
-# g =
+# g = brainstate.augment.vector_grad(f, argnums=(0,))(_x, _y)
 # self.assertTrue(jnp.array_equal(g[0], 2 * _x))
 #
-# g =
+# g = brainstate.augment.vector_grad(f, argnums=(0, 1))(_x, _y)
 # pprint(g)
 # self.assertTrue(jnp.array_equal(g[0], 2 * _x))
 # self.assertTrue(jnp.array_equal(g[1], 2 * _y))
@@ -915,14 +915,14 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.ones(5)
 # _y = jnp.ones(5)
 #
-# g =
+# g = brainstate.augment.vector_grad(f, argnums=0)(_x, _y)
 # # pprint(g)
 # self.assertTrue(jnp.array_equal(g, 2 * _x + 3 * _x ** 2))
 #
-# g =
+# g = brainstate.augment.vector_grad(f, argnums=(0,))(_x, _y)
 # self.assertTrue(jnp.array_equal(g[0], 2 * _x + 3 * _x ** 2))
 #
-# g =
+# g = brainstate.augment.vector_grad(f, argnums=(0, 1))(_x, _y)
 # # pprint(g)
 # self.assertTrue(jnp.array_equal(g[0], 2 * _x + 3 * _x ** 2))
 # self.assertTrue(jnp.array_equal(g[1], 2 * _y + 3 * _y ** 2))
@@ -935,14 +935,14 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.ones((5, 5))
 # _y = jnp.ones((5, 5))
 #
-# g =
+# g = brainstate.augment.vector_grad(f, argnums=0)(_x, _y)
 # pprint(g)
 # self.assertTrue(jnp.array_equal(g, 2 * _x))
 #
-# g =
+# g = brainstate.augment.vector_grad(f, argnums=(0,))(_x, _y)
 # self.assertTrue(jnp.array_equal(g[0], 2 * _x))
 #
-# g =
+# g = brainstate.augment.vector_grad(f, argnums=(0, 1))(_x, _y)
 # pprint(g)
 # self.assertTrue(jnp.array_equal(g[0], 2 * _x))
 # self.assertTrue(jnp.array_equal(g[1], 2 * _y))
@@ -956,7 +956,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.ones(5)
 # _y = jnp.ones(5)
 #
-# g, aux =
+# g, aux = brainstate.augment.vector_grad(f, has_aux=True)(_x, _y)
 # pprint(g, )
 # pprint(aux)
 # self.assertTrue(jnp.array_equal(g, 2 * _x))
@@ -970,7 +970,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.ones(5)
 # _y = jnp.ones(5)
 #
-# g, value =
+# g, value = brainstate.augment.vector_grad(f, return_value=True)(_x, _y)
 # pprint(g, )
 # pprint(value)
 # self.assertTrue(jnp.array_equal(g, 2 * _x))
@@ -985,7 +985,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # _x = jnp.ones(5)
 # _y = jnp.ones(5)
 #
-# g, value, aux =
+# g, value, aux = brainstate.augment.vector_grad(f, has_aux=True, return_value=True)(_x, _y)
 # print('grad', g)
 # print('value', value)
 # print('aux', aux)
@@ -996,7 +996,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 #
 # class TestClassFuncVectorGrad(unittest.TestCase):
 # def test1(self):
-# class Test(
+# class Test(brainstate.nn.Module):
 # def __init__(self):
 # super(Test, self).__init__()
 # self.x = jnp.Variable(jnp.ones(5))
@@ -1007,13 +1007,13 @@ class TestClassFuncJacobian(unittest.TestCase):
 #
 # t = Test()
 #
-# g =
+# g = brainstate.augment.vector_grad(t, grad_states=t.x)()
 # self.assertTrue(jnp.array_equal(g, 2 * t.x))
 #
-# g =
+# g = brainstate.augment.vector_grad(t, grad_states=(t.x,))()
 # self.assertTrue(jnp.array_equal(g[0], 2 * t.x))
 #
-# g =
+# g = brainstate.augment.vector_grad(t, grad_states=(t.x, t.y))()
 # self.assertTrue(jnp.array_equal(g[0], 2 * t.x))
 # self.assertTrue(jnp.array_equal(g[1], 2 * t.y))
 #
@@ -1025,20 +1025,20 @@ class TestClassFuncJacobian(unittest.TestCase):
 #
 # class TestDebug(parameterized.TestCase):
 # def test_debug1(self):
-# a =
+# a = brainstate.random.RandomState()
 #
 # def f(b):
 # print(a.value)
 # return a + b + a.random()
 #
-# f =
+# f = brainstate.augment.vector_grad(f, argnums=0)
 # f(1.)
 #
 # with jax.disable_jit():
 # f(1.)
 #
 # @parameterized.product(
-# grad_fun=[
+# grad_fun=[brainstate.augment.grad, brainstate.augment.vector_grad]
 # )
 # def test_print_info1(self, grad_fun):
 # file = tempfile.TemporaryFile(mode='w+')
@@ -1075,7 +1075,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # self.assertTrue(file.read().strip() == expect_res.strip())
 #
 # @parameterized.product(
-# grad_fun=[
+# grad_fun=[brainstate.augment.grad, brainstate.augment.vector_grad]
 # )
 # def test_print_info2(self, grad_fun):
 # file = tempfile.TemporaryFile(mode='w+')
@@ -1117,7 +1117,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 # a = jnp.Variable(jnp.ones(2))
 # b = jnp.Variable(jnp.zeros(2))
 #
-# @
+# @brainstate.augment.vector_grad(argnums=0)
 # def f1(c):
 # a.value += 1
 # b.value += 10
@@ -1149,7 +1149,7 @@ class TestClassFuncJacobian(unittest.TestCase):
 #
 #
 # def run_fun(d):
-# @
+# @brainstate.augment.vector_grad(argnums=0)
 # def f1(c):
 # a.value += d
 # b.value += 10
@@ -1178,8 +1178,8 @@ class TestClassFuncJacobian(unittest.TestCase):
 # print('compiling f ...', file=file)
 # return a + b
 #
-# grad1 =
-# grad2 =
+# grad1 = brainstate.augment.grad(f)(1., 2.)  # call "f" twice, one for Variable finding, one for compiling
+# grad2 = brainstate.augment.vector_grad(f)(1., 2.)  # call "f" once for compiling
 #
 # file.seek(0)
 # print(file.read().strip())
brainstate/augment/_eval_shape.py
CHANGED
@@ -40,13 +40,13 @@ def abstract_init(
 
     Here's an example::
 
-      >>> import brainstate
+      >>> import brainstate
      >>> class MLP:
      ...   def __init__(self, n_in, n_mid, n_out):
-      ...     self.dense1 =
-      ...     self.dense2 =
+      ...     self.dense1 = brainstate.nn.Linear(n_in, n_mid)
+      ...     self.dense2 = brainstate.nn.Linear(n_mid, n_out)
 
-      >>> r =
+      >>> r = brainstate.augment.abstract_init(lambda: MLP(1, 2, 3))
      >>> r
      MLP(
        dense1=Linear(
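Assembled, the repaired `abstract_init` docstring reads roughly as the sketch below; judging by the module name `_eval_shape.py`, `abstract_init` evaluates the constructor abstractly rather than allocating real arrays (an assumption, not stated in the hunk):

    import brainstate

    class MLP:
        def __init__(self, n_in, n_mid, n_out):
            self.dense1 = brainstate.nn.Linear(n_in, n_mid)
            self.dense2 = brainstate.nn.Linear(n_mid, n_out)

    # Initialize the model abstractly, as in the docstring example above.
    r = brainstate.augment.abstract_init(lambda: MLP(1, 2, 3))
    print(r)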
brainstate/augment/_mapping.py
CHANGED
@@ -185,10 +185,10 @@ def _compile_stateful_function(
     if isinstance(in_axes, int):
         args = jax.tree.map(lambda x: _remove_axis(x, in_axes), args)
     elif isinstance(in_axes, tuple):
-        args = tuple(
-
-
-        )
+        args = tuple([
+            arg if in_axis is None else _remove_axis(arg, in_axis)
+            for arg, in_axis in zip(args, in_axes)
+        ])
     stateful_fn.make_jaxpr(state_vals, args)
     return stateful_fn.get_arg_cache_key(state_vals, args)
 
@@ -383,10 +383,7 @@ def _vmap_transform(
     stateful_fn.axis_env = axis_env
 
     # stateful function
-    stateful_fn = StatefulFunction(
-        _vmap_fn_for_compilation,
-        name='vmap',
-    )
+    stateful_fn = StatefulFunction(_vmap_fn_for_compilation, name='vmap')
 
     @functools.wraps(f)
     def new_fn_for_vmap(
@@ -460,7 +457,10 @@
         # analyze vmapping axis error
         for state in state_trace.get_write_states():
             leaves = jax.tree.leaves(state.value)
-            if
+            if (
+                any([isinstance(leaf, BatchTracer) and (leaf.batch_dim is not None) for leaf in leaves])
+                and state not in out_state_to_axis
+            ):
                 if isinstance(state, RandomState) and state in rng_sets:
                     continue
                 state.raise_error_with_source_info(
@@ -481,7 +481,7 @@
        return out_rng_keys, out_state_vmap_vals, out_state_oth_vals, outs
 
    @functools.wraps(f)
-    def vmapped_fn(*args):
+    def vmapped_fn(*args, **kwargs):
        """
        Applies vectorized mapping (vmap) to the input function while managing state.
 
@@ -503,6 +503,11 @@
        data structures (e.g., axis_to_in_states, in_state_to_axis) which
        should be defined in the broader context.
        """
+        if len(kwargs):
+            raise NotImplementedError(
+                "Keyword arguments `f(**kwargs)` are not supported in brainstate.augment.vmap"
+            )
+
        # in states values
        in_state_map_vals = [
            [st.value for st in states]
@@ -615,16 +620,16 @@ def vmap(
 
    These are several example usage::
 
-    >>> import brainstate as
+    >>> import brainstate as brainstate
    >>> import jax.numpy as jnp
 
-    >>> class Model(
+    >>> class Model(brainstate.nn.Module):
    >>>   def __init__(self):
    >>>     super().__init__()
    >>>
-    >>>     self.a =
-    >>>     self.b =
-    >>>     self.c =
+    >>>     self.a = brainstate.ShortTermState(brainstate.random.randn(5))
+    >>>     self.b = brainstate.ShortTermState(brainstate.random.randn(5))
+    >>>     self.c = brainstate.State(brainstate.random.randn(1))
 
    >>>   def __call__(self, *args, **kwargs):
    >>>     self.c.value = self.a.value * self.b.value
@@ -632,9 +637,9 @@ def vmap(
 
    >>> model = Model()
 
-    >>> r =
+    >>> r = brainstate.augment.vmap(
    >>>   model,
-    >>>   in_states=model.states(
+    >>>   in_states=model.states(brainstate.ShortTermState),
    >>>   out_states=model.c
    >>> )()
 
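Putting the `_mapping.py` changes together: `brainstate.augment.vmap` now rejects keyword arguments explicitly, and the write-state axis check only fires for batched leaves that are not declared in `out_states`. The repaired docstring example assembles into roughly the following sketch (positional call only; the state values are random, so outputs will vary):

    import brainstate

    class Model(brainstate.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = brainstate.ShortTermState(brainstate.random.randn(5))
            self.b = brainstate.ShortTermState(brainstate.random.randn(5))
            self.c = brainstate.State(brainstate.random.randn(1))

        def __call__(self, *args, **kwargs):
            self.c.value = self.a.value * self.b.value

    model = Model()

    # Keyword arguments to the vmapped callable now raise NotImplementedError;
    # everything here is passed positionally (or, as in the docstring, not at all).
    r = brainstate.augment.vmap(
        model,
        in_states=model.states(brainstate.ShortTermState),
        out_states=model.c,
    )()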