brainstate 0.1.5__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_state.py +5 -5
  3. brainstate/augment/_autograd.py +31 -12
  4. brainstate/augment/_autograd_test.py +46 -46
  5. brainstate/augment/_eval_shape.py +4 -4
  6. brainstate/augment/_mapping.py +13 -8
  7. brainstate/compile/_conditions.py +2 -2
  8. brainstate/compile/_make_jaxpr.py +59 -6
  9. brainstate/compile/_progress_bar.py +2 -2
  10. brainstate/environ.py +19 -19
  11. brainstate/functional/_activations_test.py +12 -12
  12. brainstate/graph/_graph_operation.py +69 -69
  13. brainstate/graph/_graph_operation_test.py +2 -2
  14. brainstate/mixin.py +0 -17
  15. brainstate/nn/_collective_ops.py +4 -4
  16. brainstate/nn/_dropout_test.py +2 -2
  17. brainstate/nn/_dynamics.py +53 -35
  18. brainstate/nn/_elementwise.py +30 -30
  19. brainstate/nn/_linear.py +4 -4
  20. brainstate/nn/_module.py +6 -6
  21. brainstate/nn/_module_test.py +1 -1
  22. brainstate/nn/_normalizations.py +11 -11
  23. brainstate/nn/_normalizations_test.py +6 -6
  24. brainstate/nn/_poolings.py +24 -24
  25. brainstate/nn/_synapse.py +1 -12
  26. brainstate/nn/_utils.py +1 -1
  27. brainstate/nn/metrics.py +4 -4
  28. brainstate/optim/_optax_optimizer.py +8 -8
  29. brainstate/random/_rand_funs.py +37 -37
  30. brainstate/random/_rand_funs_test.py +3 -3
  31. brainstate/random/_rand_seed.py +7 -7
  32. brainstate/surrogate.py +40 -40
  33. brainstate/util/pretty_pytree.py +10 -10
  34. brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
  35. brainstate/util/struct.py +7 -7
  36. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
  37. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/RECORD +40 -40
  38. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
  39. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
  40. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
brainstate/compile/_make_jaxpr.py CHANGED
@@ -352,6 +352,13 @@ class StatefulFunction(PrettyObject):
             cache_key = default_cache_key
         return self.get_state_trace(cache_key).get_write_states()
 
+    def _check_input(self, x):
+        if isinstance(x, State):
+            raise ValueError(
+                'Inputs for brainstate transformations cannot be an instance of State. '
+                f'But we got {x}'
+            )
+
     def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
         """
         Get the static arguments from the arguments.
@@ -370,22 +377,35 @@ class StatefulFunction(PrettyObject):
                     static_args.append(arg)
                 else:
                     dyn_args.append(arg)
-            dyn_args = jax.tree.map(shaped_abstractify, jax.tree.leaves(dyn_args))
+            dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
             static_kwargs, dyn_kwargs = [], []
             for k, v in kwargs.items():
                 if k in self.static_argnames:
                     static_kwargs.append((k, v))
                 else:
                     dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
-            return tuple([tuple(static_args), tuple(dyn_args), tuple(static_kwargs), tuple(dyn_kwargs)])
+
+            static_args = make_hashable(tuple(static_args))
+            dyn_args = make_hashable(tuple(dyn_args))
+            static_kwargs = make_hashable(static_kwargs)
+            dyn_kwargs = make_hashable(dyn_kwargs)
+
+            cache_key = (static_args, dyn_args, static_kwargs, dyn_kwargs)
         elif self.cache_type is None:
            num_arg = len(args)
            static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
            static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
-            return tuple([static_args, static_kwargs])
+
+            # Make everything hashable
+            static_args = make_hashable(static_args)
+            static_kwargs = make_hashable(static_kwargs)
+
+            cache_key = (static_args, static_kwargs)
         else:
             raise ValueError(f"Invalid cache type: {self.cache_type}")
 
+        return cache_key
+
     def compile_function_and_get_states(self, *args, **kwargs) -> Tuple[State, ...]:
         """
         Compile the function, and get the states that are read and written by this function.
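This change stops returning ad-hoc tuples from `get_arg_cache_key` and instead routes every component through `make_hashable` (added at the bottom of this file, shown later in this diff) before assembling `cache_key`. The key indexes `self._cached_state_trace`, so each element must be hashable. A minimal sketch of the failure mode being guarded against, using plain Python values only (the names are illustrative, not brainstate API):

    # Unhashable containers cannot index a dict; this is what raw argument
    # pytrees would do to the compilation cache.
    cache = {}
    try:
        cache[('f', [1, 2], {'k': 3})] = 'trace'   # list/dict inside the key
    except TypeError as e:
        print(e)                                   # unhashable type: 'list'

    # Normalized to nested tuples (the essence of make_hashable), it works.
    key = ('f', (1, 2), (('k', 3),))
    cache[key] = 'trace'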
@@ -480,6 +500,9 @@ class StatefulFunction(PrettyObject):
         # static args
         cache_key = self.get_arg_cache_key(*args, **kwargs)
 
+        # check input types
+        jax.tree.map(self._check_input, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
+
         if cache_key not in self._cached_state_trace:
             try:
                 # jaxpr
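The `is_leaf` predicate here is what lets `_check_input` see whole `State` objects: `State` is registered as a JAX pytree, so without `is_leaf` the map would recurse into it and visit only its raw arrays, and the `isinstance` check would never fire. A self-contained sketch of the pattern with a hypothetical stand-in class (`Leaf` is not brainstate API):

    import jax

    class Leaf:  # hypothetical stand-in for brainstate.State
        def __repr__(self):
            return 'Leaf(...)'

    def check(x):
        if isinstance(x, Leaf):
            raise ValueError(f'Inputs cannot be an instance of State. But we got {x}')

    inputs = ((1.0, {'a': Leaf()}), {})  # mirrors the (args, kwargs) pair above
    try:
        # For the real State (a registered pytree), is_leaf is what stops the
        # traversal from descending into the State's underlying arrays.
        jax.tree.map(check, inputs, is_leaf=lambda x: isinstance(x, Leaf))
    except ValueError as e:
        print(e)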
@@ -637,15 +660,15 @@ def make_jaxpr(
     instead give a few examples.
 
     >>> import jax
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>>
     >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
     >>> print(f(3.0))
     -0.83602
-    >>> jaxpr, states = bst.compile.make_jaxpr(f)(3.0)
+    >>> jaxpr, states = brainstate.compile.make_jaxpr(f)(3.0)
     >>> jaxpr
     { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
-    >>> jaxpr, states = bst.compile.make_jaxpr(jax.grad(f))(3.0)
+    >>> jaxpr, states = brainstate.compile.make_jaxpr(jax.grad(f))(3.0)
     >>> jaxpr
     { lambda ; a:f32[]. let
         b:f32[] = cos a
@@ -844,3 +867,33 @@ def _make_jaxpr(
     if hasattr(fun, "__name__"):
         make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
     return make_jaxpr_f
+
+
+def make_hashable(obj):
+    """Convert a pytree into a hashable representation."""
+    if isinstance(obj, (list, tuple)):
+        return tuple(make_hashable(item) for item in obj)
+    elif isinstance(obj, dict):
+        return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
+    elif isinstance(obj, set):
+        return frozenset(make_hashable(item) for item in obj)
+    elif hasattr(obj, '__dict__'):  # Handle custom objects
+        return (
+            obj.__class__.__name__,
+            tuple(
+                sorted(
+                    (k, make_hashable(v))
+                    for k, v in obj.__dict__.items()
+                    if not k.startswith('_')
+                )
+            )
+        )
+    else:
+        # # Use JAX's tree_util for any other pytree structures
+        # try:
+        #     leaves, treedef = jax.tree_util.tree_flatten(obj)
+        #     hashable_leaves = tuple(make_hashable(leaf) for leaf in leaves)
+        #     return (str(treedef), hashable_leaves)
+        # except:
+        #     # Assume obj is already hashable
+        return obj
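For a sense of what `make_hashable` produces (values are illustrative):

    key = make_hashable({'lr': 0.1, 'dims': [2, 3]})
    print(key)               # (('dims', (2, 3)), ('lr', 0.1))
    cache = {key: 'entry'}   # the normalized form can index a dict

One caveat worth noting: the `__dict__` branch keys arbitrary objects by class name plus public attributes, so two distinct instances with equal public attributes map to the same cache entry, and underscore-prefixed attributes are ignored.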
brainstate/compile/_progress_bar.py CHANGED
@@ -53,7 +53,7 @@ class ProgressBar(object):
 
     .. code-block:: python
 
-        a = bst.State(1.)
+        a = brainstate.State(1.)
         def loop_fn(x):
             a.value = x.value + 1.
             return jnp.sum(x ** 2)
@@ -61,7 +61,7 @@ class ProgressBar(object):
         pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
                                  lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
 
-        bst.compile.for_loop(loop_fn, xs, pbar=pbar)
+        brainstate.compile.for_loop(loop_fn, xs, pbar=pbar)
 
     In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
     the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
brainstate/environ.py CHANGED
@@ -76,9 +76,9 @@ def context(**kwargs):
 
     For instance::
 
-    >>> import brainstate as bst
-    >>> with bst.environ.context(dt=0.1) as env:
-    ...     dt = bst.environ.get('dt')
+    >>> import brainstate as brainstate
+    >>> with brainstate.environ.context(dt=0.1) as env:
+    ...     dt = brainstate.environ.get('dt')
     ...     print(env)
 
     """
@@ -424,10 +424,10 @@ def dftype() -> DTypeLike:
 
     For example, if the precision is set to 32, the default floating data type is ``np.float32``.
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with bst.environ.context(precision=32):
-    ...     a = np.zeros(1, dtype=bst.environ.dftype())
+    >>> with brainstate.environ.context(precision=32):
+    ...     a = np.zeros(1, dtype=brainstate.environ.dftype())
     >>> print(a.dtype)
 
     Returns
@@ -448,10 +448,10 @@ def ditype() -> DTypeLike:
 
     For example, if the precision is set to 32, the default integer data type is ``np.int32``.
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with bst.environ.context(precision=32):
-    ...     a = np.zeros(1, dtype=bst.environ.ditype())
+    >>> with brainstate.environ.context(precision=32):
+    ...     a = np.zeros(1, dtype=brainstate.environ.ditype())
     >>> print(a.dtype)
     int32
 
@@ -474,10 +474,10 @@ def dutype() -> DTypeLike:
 
     For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with bst.environ.context(precision=32):
-    ...     a = np.zeros(1, dtype=bst.environ.dutype())
+    >>> with brainstate.environ.context(precision=32):
+    ...     a = np.zeros(1, dtype=brainstate.environ.dutype())
     >>> print(a.dtype)
     uint32
 
@@ -499,10 +499,10 @@ def dctype() -> DTypeLike:
 
     For example, if the precision is set to 32, the default complex data type is ``np.complex64``.
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with bst.environ.context(precision=32):
-    ...     a = np.zeros(1, dtype=bst.environ.dctype())
+    >>> with brainstate.environ.context(precision=32):
+    ...     a = np.zeros(1, dtype=brainstate.environ.dctype())
     >>> print(a.dtype)
     complex64
 
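The four getters (`dftype`, `ditype`, `dutype`, `dctype`) all derive from the same `precision` setting; a combined sketch, assuming only the API shown in these docstrings:

    import brainstate

    with brainstate.environ.context(precision=32):
        print(brainstate.environ.dftype())   # np.float32
        print(brainstate.environ.ditype())   # np.int32
        print(brainstate.environ.dutype())   # np.uint32
        print(brainstate.environ.dctype())   # np.complex64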
@@ -529,19 +529,19 @@ def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bo
 
     For example, you can register a default behavior for the key 'dt' by::
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> def dt_behavior(dt):
     ...     print(f'Set the default dt to {dt}.')
     ...
-    >>> bst.environ.register_default_behavior('dt', dt_behavior)
+    >>> brainstate.environ.register_default_behavior('dt', dt_behavior)
 
     Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
     `dt_behavior` will be called with
     `dt_behavior(0.1)`.
 
-    >>> bst.environ.set(dt=0.1)
+    >>> brainstate.environ.set(dt=0.1)
     Set the default dt to 0.1.
-    >>> with bst.environ.context(dt=0.2):
+    >>> with brainstate.environ.context(dt=0.2):
     ...     pass
     Set the default dt to 0.2.
     Set the default dt to 0.1.
brainstate/functional/_activations_test.py CHANGED
@@ -70,39 +70,39 @@ class NNFunctionsTest(jtu.JaxTestCase):
                          check_dtypes=False)
 
    # def testSquareplusGrad(self):
-    #     check_grads(bst.functional.squareplus, (1e-8,), order=4,
+    #     check_grads(brainstate.functional.squareplus, (1e-8,), order=4,
    #     )
 
    # def testSquareplusGradZero(self):
-    #     check_grads(bst.functional.squareplus, (0.,), order=1,
+    #     check_grads(brainstate.functional.squareplus, (0.,), order=1,
    #     )
 
    # def testSquareplusGradNegInf(self):
-    #     check_grads(bst.functional.squareplus, (-float('inf'),), order=1,
+    #     check_grads(brainstate.functional.squareplus, (-float('inf'),), order=1,
    #     )
 
    # def testSquareplusGradNan(self):
-    #     check_grads(bst.functional.squareplus, (float('nan'),), order=1,
+    #     check_grads(brainstate.functional.squareplus, (float('nan'),), order=1,
    #     )
 
    # @parameterized.parameters([float] + jtu.dtypes.floating)
    # def testSquareplusZero(self, dtype):
-    #     self.assertEqual(dtype(1), bst.functional.squareplus(dtype(0), dtype(4)))
+    #     self.assertEqual(dtype(1), brainstate.functional.squareplus(dtype(0), dtype(4)))
    #
    # def testMishGrad(self):
-    #     check_grads(bst.functional.mish, (1e-8,), order=4,
+    #     check_grads(brainstate.functional.mish, (1e-8,), order=4,
    #     )
    #
    # def testMishGradZero(self):
-    #     check_grads(bst.functional.mish, (0.,), order=1,
+    #     check_grads(brainstate.functional.mish, (0.,), order=1,
    #     )
    #
    # def testMishGradNegInf(self):
-    #     check_grads(bst.functional.mish, (-float('inf'),), order=1,
+    #     check_grads(brainstate.functional.mish, (-float('inf'),), order=1,
    #     )
    #
    # def testMishGradNan(self):
-    #     check_grads(bst.functional.mish, (float('nan'),), order=1,
+    #     check_grads(brainstate.functional.mish, (float('nan'),), order=1,
    #     )
 
    @parameterized.parameters([float] + jtu.dtypes.floating)
@@ -137,7 +137,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
        self.assertAllClose(brainstate.functional.sparse_sigmoid(0.), .5, check_dtypes=False)
 
    # def testSquareplusValue(self):
-    #     val = bst.functional.squareplus(1e3)
+    #     val = brainstate.functional.squareplus(1e3)
    #     self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
 
    def testMishValue(self):
@@ -177,7 +177,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
                     brainstate.functional.softplus,
                     brainstate.functional.sparse_plus,
                     brainstate.functional.sigmoid,
-                     # bst.functional.squareplus,
+                     # brainstate.functional.squareplus,
                     brainstate.functional.mish)))
    def testDtypeMatchesInput(self, dtype, fn):
        x = jnp.zeros((), dtype=dtype)
@@ -306,7 +306,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
 
    def testCustomJVPLeak2(self):
        # https://github.com/google/jax/issues/8171
-        # The above test uses jax.bst.functional.sigmoid, as in the original #8171, but that
+        # The above test uses jax.brainstate.functional.sigmoid, as in the original #8171, but that
        # function no longer actually has a custom_jvp! So we inline the old def.
 
        @jax.custom_jvp
brainstate/graph/_graph_operation.py CHANGED
@@ -473,8 +473,8 @@ def flatten(
 
     Example::
 
-    >>> import brainstate as bst
-    >>> node = bst.graph.Node()
+    >>> import brainstate as brainstate
+    >>> node = brainstate.graph.Node()
     >>> graph_def, state_mapping = flatten(node)
     >>> print(graph_def)
     >>> print(state_mapping)
@@ -709,15 +709,15 @@ def unflatten(
 
     Example::
 
-    >>> import brainstate as bst
-    >>> class MyNode(bst.graph.Node):
+    >>> import brainstate as brainstate
+    >>> class MyNode(brainstate.graph.Node):
     ...     def __init__(self):
-    ...         self.a = bst.nn.Linear(2, 3)
-    ...         self.b = bst.nn.Linear(3, 4)
-    ...         self.c = [bst.nn.Linear(4, 5), bst.nn.Linear(5, 6)]
-    ...         self.d = {'x': bst.nn.Linear(6, 7), 'y': bst.nn.Linear(7, 8)}
+    ...         self.a = brainstate.nn.Linear(2, 3)
+    ...         self.b = brainstate.nn.Linear(3, 4)
+    ...         self.c = [brainstate.nn.Linear(4, 5), brainstate.nn.Linear(5, 6)]
+    ...         self.d = {'x': brainstate.nn.Linear(6, 7), 'y': brainstate.nn.Linear(7, 8)}
     ...
-    >>> graphdef, statetree = bst.graph.flatten(MyNode())
+    >>> graphdef, statetree = brainstate.graph.flatten(MyNode())
     >>> statetree
     NestedDict({
       'a': {
@@ -764,7 +764,7 @@ def unflatten(
         }
       }
     })
-    >>> node = bst.graph.unflatten(graphdef, statetree)
+    >>> node = brainstate.graph.unflatten(graphdef, statetree)
     >>> node
     MyNode(
       a=Linear(
@@ -942,21 +942,21 @@ def pop_states(
 
     Example usage::
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import jax.numpy as jnp
 
-    >>> class Model(bst.nn.Module):
+    >>> class Model(brainstate.nn.Module):
     ...     def __init__(self):
     ...         super().__init__()
-    ...         self.a = bst.nn.Linear(2, 3)
-    ...         self.b = bst.nn.LIF([10, 2])
+    ...         self.a = brainstate.nn.Linear(2, 3)
+    ...         self.b = brainstate.nn.LIF([10, 2])
 
     >>> model = Model()
-    >>> with bst.catch_new_states('new'):
-    ...     bst.nn.init_all_states(model)
+    >>> with brainstate.catch_new_states('new'):
+    ...     brainstate.nn.init_all_states(model)
 
     >>> assert len(model.states()) == 2
-    >>> model_states = bst.graph.pop_states(model, 'new')
+    >>> model_states = brainstate.graph.pop_states(model, 'new')
     >>> model_states
     NestedDict({
       'b': {
@@ -1046,16 +1046,16 @@ def treefy_split(
 
     Example usage::
 
-    >>> from joblib.testing import param >>> import brainstate as bst
+    >>> from joblib.testing import param >>> import brainstate as brainstate
     >>> import jax, jax.numpy as jnp
     ...
-    >>> class Foo(bst.graph.Node):
+    >>> class Foo(brainstate.graph.Node):
     ...     def __init__(self):
-    ...         self.a = bst.nn.BatchNorm1d([10, 2])
-    ...         self.b = bst.nn.Linear(2, 3)
+    ...         self.a = brainstate.nn.BatchNorm1d([10, 2])
+    ...         self.b = brainstate.nn.Linear(2, 3)
     ...
     >>> node = Foo()
-    >>> graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)
+    >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
     ...
     >>> params
     NestedDict({
@@ -1119,21 +1119,21 @@ def treefy_merge(
 
     Example usage::
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import jax, jax.numpy as jnp
     ...
-    >>> class Foo(bst.graph.Node):
+    >>> class Foo(brainstate.graph.Node):
     ...     def __init__(self):
-    ...         self.a = bst.nn.BatchNorm1d([10, 2])
-    ...         self.b = bst.nn.Linear(2, 3)
+    ...         self.a = brainstate.nn.BatchNorm1d([10, 2])
+    ...         self.b = brainstate.nn.Linear(2, 3)
     ...
     >>> node = Foo()
-    >>> graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)
+    >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
     ...
-    >>> new_node = bst.graph.treefy_merge(graphdef, params, others)
+    >>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
     >>> assert isinstance(new_node, Foo)
-    >>> assert isinstance(new_node.b, bst.nn.BatchNorm1d)
-    >>> assert isinstance(new_node.a, bst.nn.Linear)
+    >>> assert isinstance(new_node.b, brainstate.nn.BatchNorm1d)
+    >>> assert isinstance(new_node.a, brainstate.nn.Linear)
 
     :func:`split` and :func:`merge` are primarily used to interact directly with JAX
     transformations, see
@@ -1302,22 +1302,22 @@ def treefy_states(
 
     Example usage::
 
-    >>> import brainstate as bst
-    >>> class Model(bst.nn.Module):
+    >>> import brainstate as brainstate
+    >>> class Model(brainstate.nn.Module):
     ...     def __init__(self):
     ...         super().__init__()
-    ...         self.l1 = bst.nn.Linear(2, 3)
-    ...         self.l2 = bst.nn.Linear(3, 4)
+    ...         self.l1 = brainstate.nn.Linear(2, 3)
+    ...         self.l2 = brainstate.nn.Linear(3, 4)
     ...     def __call__(self, x):
     ...         return self.l2(self.l1(x))
 
     >>> model = Model()
     >>> # get the learnable parameters from the batch norm and linear layer
-    >>> params = bst.graph.treefy_states(model, bst.ParamState)
+    >>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
     >>> # get them separately
-    >>> params, others = bst.graph.treefy_states(model, bst.ParamState, bst.ShortTermState)
+    >>> params, others = brainstate.graph.treefy_states(model, brainstate.ParamState, brainstate.ShortTermState)
     >>> # get them together
-    >>> states = bst.graph.treefy_states(model)
+    >>> states = brainstate.graph.treefy_states(model)
 
     Args:
       node: A graph node object.
@@ -1403,11 +1403,11 @@ def graphdef(node: Any, /) -> GraphDef[Any]:
 
     Example usage::
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
 
-    >>> model = bst.nn.Linear(2, 3)
-    >>> graphdef, _ = bst.graph.treefy_split(model)
-    >>> assert graphdef == bst.graph.graphdef(model)
+    >>> model = brainstate.nn.Linear(2, 3)
+    >>> graphdef, _ = brainstate.graph.treefy_split(model)
+    >>> assert graphdef == brainstate.graph.graphdef(model)
 
     Args:
       node: A graph node object.
@@ -1426,8 +1426,8 @@ def clone(node: Node) -> Node:
 
     Example usage::
 
-    >>> import brainstate as bst
-    >>> model = bst.nn.Linear(2, 3)
+    >>> import brainstate as brainstate
+    >>> model = brainstate.nn.Linear(2, 3)
     >>> cloned_model = clone(model)
     >>> model.weight.value['bias'] += 1
     >>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
@@ -1456,15 +1456,15 @@ def call(
 
     Example::
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import jax
     >>> import jax.numpy as jnp
     ...
-    >>> class StatefulLinear(bst.graph.Node):
+    >>> class StatefulLinear(brainstate.graph.Node):
     ...     def __init__(self, din, dout):
-    ...         self.w = bst.ParamState(bst.random.rand(din, dout))
-    ...         self.b = bst.ParamState(jnp.zeros((dout,)))
-    ...         self.count = bst.State(jnp.array(0, dtype=jnp.uint32))
+    ...         self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
+    ...         self.b = brainstate.ParamState(jnp.zeros((dout,)))
+    ...         self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
     ...
     ...     def increment(self):
     ...         self.count.value += 1
@@ -1474,18 +1474,18 @@ def call(
     ...         return x @ self.w.value + self.b.value
     ...
     >>> linear = StatefulLinear(3, 2)
-    >>> linear_state = bst.graph.treefy_split(linear)
+    >>> linear_state = brainstate.graph.treefy_split(linear)
     ...
     >>> @jax.jit
     ... def forward(x, linear_state):
-    ...     y, linear_state = bst.graph.call(linear_state)(x)
+    ...     y, linear_state = brainstate.graph.call(linear_state)(x)
     ...     return y, linear_state
     ...
     >>> x = jnp.ones((1, 3))
     >>> y, linear_state = forward(x, linear_state)
     >>> y, linear_state = forward(x, linear_state)
     ...
-    >>> linear = bst.graph.treefy_merge(*linear_state)
+    >>> linear = brainstate.graph.treefy_merge(*linear_state)
     >>> linear.count.value
     Array(2, dtype=uint32)
 
@@ -1494,11 +1494,11 @@ def call(
     is used to call the ``increment`` method of the ``StatefulLinear`` module
     at the ``b`` key of a ``nodes`` dictionary.
 
-    >>> class StatefulLinear(bst.graph.Node):
+    >>> class StatefulLinear(brainstate.graph.Node):
     ...     def __init__(self, din, dout):
-    ...         self.w = bst.ParamState(bst.random.rand(din, dout))
-    ...         self.b = bst.ParamState(jnp.zeros((dout,)))
-    ...         self.count = bst.State(jnp.array(0, dtype=jnp.uint32))
+    ...         self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
+    ...         self.b = brainstate.ParamState(jnp.zeros((dout,)))
+    ...         self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
     ...
     ...     def increment(self):
     ...         self.count.value += 1
@@ -1514,7 +1514,7 @@ def call(
     ...
     >>> node_state = treefy_split(nodes)
     >>> # use attribute access
-    >>> _, node_state = bst.graph.call(node_state)['b'].increment()
+    >>> _, node_state = brainstate.graph.call(node_state)['b'].increment()
     ...
     >>> nodes = treefy_merge(*node_state)
     >>> nodes['a'].count.value
@@ -1544,19 +1544,19 @@ def iter_leaf(
     root. Repeated nodes are visited only once. Leaves include static values.
 
     Example::
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import jax.numpy as jnp
     ...
-    >>> class Linear(bst.nn.Module):
+    >>> class Linear(brainstate.nn.Module):
     ...     def __init__(self, din, dout):
     ...         super().__init__()
-    ...         self.weight = bst.ParamState(bst.random.randn(din, dout))
-    ...         self.bias = bst.ParamState(bst.random.randn(dout))
+    ...         self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
+    ...         self.bias = brainstate.ParamState(brainstate.random.randn(dout))
     ...         self.a = 1
     ...
     >>> module = Linear(3, 4)
     ...
-    >>> for path, value in bst.graph.iter_leaf([module, module]):
+    >>> for path, value in brainstate.graph.iter_leaf([module, module]):
     ...     print(path, type(value).__name__)
     ...
     (0, 'a') int
@@ -1616,21 +1616,21 @@ def iter_node(
     root. Repeated nodes are visited only once. Leaves include static values.
 
     Example::
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import jax.numpy as jnp
     ...
-    >>> class Model(bst.nn.Module):
+    >>> class Model(brainstate.nn.Module):
     ...     def __init__(self):
     ...         super().__init__()
-    ...         self.a = bst.nn.Linear(1, 2)
-    ...         self.b = bst.nn.Linear(2, 3)
-    ...         self.c = [bst.nn.Linear(3, 4), bst.nn.Linear(4, 5)]
-    ...         self.d = {'x': bst.nn.Linear(5, 6), 'y': bst.nn.Linear(6, 7)}
-    ...         self.b.a = bst.nn.LIF(2)
+    ...         self.a = brainstate.nn.Linear(1, 2)
+    ...         self.b = brainstate.nn.Linear(2, 3)
+    ...         self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
+    ...         self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
+    ...         self.b.a = brainstate.nn.LIF(2)
     ...
     >>> model = Model()
     ...
-    >>> for path, node in bst.graph.iter_node([model, model]):
+    >>> for path, node in brainstate.graph.iter_node([model, model]):
     ...     print(path, node.__class__.__name__)
     ...
     (0, 'a') Linear
brainstate/graph/_graph_operation_test.py CHANGED
@@ -443,7 +443,7 @@ class TestGraphOperation(unittest.TestCase):
        graphdef, statetree = brainstate.graph.flatten(MyNode())
        # print(graphdef)
        print(statetree)
-        # print(bst.graph.unflatten(graphdef, statetree))
+        # print(brainstate.graph.unflatten(graphdef, statetree))
 
    def test_split(self):
        class Foo(brainstate.graph.Node):
@@ -530,7 +530,7 @@ class TestGraphOperation(unittest.TestCase):
        print(graph_def)
        print(treefy_states)
 
-        # states = bst.graph.states(model)
+        # states = brainstate.graph.states(model)
        # print(states)
        # nest_states = states.to_nest()
        # print(nest_states)
brainstate/mixin.py CHANGED
@@ -27,7 +27,6 @@ __all__ = [
     'ParamDescriber',
     'AlignPost',
     'BindCondData',
-    'UpdateReturn',
 
     # types
     'JointTypes',
@@ -171,22 +170,6 @@ def not_implemented(func):
     return wrapper
 
 
-
-class UpdateReturn(Mixin):
-    @not_implemented
-    def update_return(self) -> PyTree:
-        """
-        The update function return of the model.
-
-        This function requires no parameters and must return a PyTree.
-
-        It is usually used for delay initialization, for example, ``Dynamics.output_delay`` relies on this function to
-        initialize the output delay.
-
-        """
-        raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
-
-
 class _MetaUnionType(type):
     def __new__(cls, name, bases, dct):
         if isinstance(bases, type):
brainstate/nn/_collective_ops.py CHANGED
@@ -56,10 +56,10 @@ def call_order(level: int = 0, check_order_boundary: bool = True):
 
     The lower the level, the earlier the function is called.
 
-    >>> import brainstate as bst
-    >>> bst.nn.call_order(0)
-    >>> bst.nn.call_order(-1)
-    >>> bst.nn.call_order(-2)
+    >>> import brainstate as brainstate
+    >>> brainstate.nn.call_order(0)
+    >>> brainstate.nn.call_order(-1)
+    >>> brainstate.nn.call_order(-2)
 
     Parameters
     ----------