brainstate-0.1.4-py2.py3-none-any.whl → brainstate-0.1.6-py2.py3-none-any.whl

Files changed (45)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_state.py +6 -5
  3. brainstate/augment/_autograd.py +31 -12
  4. brainstate/augment/_autograd_test.py +46 -46
  5. brainstate/augment/_eval_shape.py +4 -4
  6. brainstate/augment/_mapping.py +22 -17
  7. brainstate/augment/_mapping_test.py +162 -0
  8. brainstate/compile/_conditions.py +2 -2
  9. brainstate/compile/_make_jaxpr.py +59 -6
  10. brainstate/compile/_progress_bar.py +2 -2
  11. brainstate/environ.py +19 -19
  12. brainstate/functional/_activations_test.py +12 -12
  13. brainstate/graph/_graph_operation.py +69 -69
  14. brainstate/graph/_graph_operation_test.py +2 -2
  15. brainstate/mixin.py +0 -17
  16. brainstate/nn/_collective_ops.py +4 -4
  17. brainstate/nn/_common.py +7 -19
  18. brainstate/nn/_dropout_test.py +2 -2
  19. brainstate/nn/_dynamics.py +53 -35
  20. brainstate/nn/_elementwise.py +30 -30
  21. brainstate/nn/_exp_euler.py +13 -16
  22. brainstate/nn/_inputs.py +1 -1
  23. brainstate/nn/_linear.py +4 -4
  24. brainstate/nn/_module.py +6 -6
  25. brainstate/nn/_module_test.py +1 -1
  26. brainstate/nn/_normalizations.py +11 -11
  27. brainstate/nn/_normalizations_test.py +6 -6
  28. brainstate/nn/_poolings.py +24 -24
  29. brainstate/nn/_synapse.py +1 -12
  30. brainstate/nn/_utils.py +1 -1
  31. brainstate/nn/metrics.py +4 -4
  32. brainstate/optim/_optax_optimizer.py +8 -8
  33. brainstate/random/_rand_funs.py +37 -37
  34. brainstate/random/_rand_funs_test.py +3 -3
  35. brainstate/random/_rand_seed.py +7 -7
  36. brainstate/random/_rand_state.py +13 -7
  37. brainstate/surrogate.py +40 -40
  38. brainstate/util/pretty_pytree.py +10 -10
  39. brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
  40. brainstate/util/struct.py +7 -7
  41. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
  42. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/RECORD +45 -45
  43. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
  44. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
  45. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
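Most of the churn between these versions is documentation: docstring examples across the graph, nn, random, and surrogate modules replace the "bst" alias with the full package name. The two spellings bind the same module object, so user code is unaffected either way; a minimal illustration, where the alias choice is the only difference:

    # Equivalent imports; the 0.1.6 docstrings standardize on the second form.
    import brainstate as bst
    bst.nn.Linear(2, 3)

    import brainstate
    brainstate.nn.Linear(2, 3)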
brainstate/graph/_graph_operation.py CHANGED
@@ -473,8 +473,8 @@ def flatten(
 
     Example::
 
-    >>> import brainstate as bst
-    >>> node = bst.graph.Node()
+    >>> import brainstate as brainstate
+    >>> node = brainstate.graph.Node()
     >>> graph_def, state_mapping = flatten(node)
     >>> print(graph_def)
     >>> print(state_mapping)
@@ -709,15 +709,15 @@ def unflatten(
 
     Example::
 
-    >>> import brainstate as bst
-    >>> class MyNode(bst.graph.Node):
+    >>> import brainstate as brainstate
+    >>> class MyNode(brainstate.graph.Node):
     ...     def __init__(self):
-    ...         self.a = bst.nn.Linear(2, 3)
-    ...         self.b = bst.nn.Linear(3, 4)
-    ...         self.c = [bst.nn.Linear(4, 5), bst.nn.Linear(5, 6)]
-    ...         self.d = {'x': bst.nn.Linear(6, 7), 'y': bst.nn.Linear(7, 8)}
+    ...         self.a = brainstate.nn.Linear(2, 3)
+    ...         self.b = brainstate.nn.Linear(3, 4)
+    ...         self.c = [brainstate.nn.Linear(4, 5), brainstate.nn.Linear(5, 6)]
+    ...         self.d = {'x': brainstate.nn.Linear(6, 7), 'y': brainstate.nn.Linear(7, 8)}
     ...
-    >>> graphdef, statetree = bst.graph.flatten(MyNode())
+    >>> graphdef, statetree = brainstate.graph.flatten(MyNode())
     >>> statetree
     NestedDict({
       'a': {
@@ -764,7 +764,7 @@ def unflatten(
         }
       }
     })
-    >>> node = bst.graph.unflatten(graphdef, statetree)
+    >>> node = brainstate.graph.unflatten(graphdef, statetree)
    >>> node
     MyNode(
       a=Linear(
@@ -942,21 +942,21 @@ def pop_states(
 
     Example usage::
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import jax.numpy as jnp
 
-    >>> class Model(bst.nn.Module):
+    >>> class Model(brainstate.nn.Module):
     ...     def __init__(self):
     ...         super().__init__()
-    ...         self.a = bst.nn.Linear(2, 3)
-    ...         self.b = bst.nn.LIF([10, 2])
+    ...         self.a = brainstate.nn.Linear(2, 3)
+    ...         self.b = brainstate.nn.LIF([10, 2])
 
     >>> model = Model()
-    >>> with bst.catch_new_states('new'):
-    ...     bst.nn.init_all_states(model)
+    >>> with brainstate.catch_new_states('new'):
+    ...     brainstate.nn.init_all_states(model)
 
     >>> assert len(model.states()) == 2
-    >>> model_states = bst.graph.pop_states(model, 'new')
+    >>> model_states = brainstate.graph.pop_states(model, 'new')
     >>> model_states
     NestedDict({
       'b': {
@@ -1046,16 +1046,16 @@ def treefy_split(
 
     Example usage::
 
-    >>> from joblib.testing import param >>> import brainstate as bst
+    >>> from joblib.testing import param >>> import brainstate as brainstate
     >>> import jax, jax.numpy as jnp
     ...
-    >>> class Foo(bst.graph.Node):
+    >>> class Foo(brainstate.graph.Node):
     ...     def __init__(self):
-    ...         self.a = bst.nn.BatchNorm1d([10, 2])
-    ...         self.b = bst.nn.Linear(2, 3)
+    ...         self.a = brainstate.nn.BatchNorm1d([10, 2])
+    ...         self.b = brainstate.nn.Linear(2, 3)
     ...
     >>> node = Foo()
-    >>> graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)
+    >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
     ...
     >>> params
     NestedDict({
@@ -1119,21 +1119,21 @@ def treefy_merge(
 
     Example usage::
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import jax, jax.numpy as jnp
     ...
-    >>> class Foo(bst.graph.Node):
+    >>> class Foo(brainstate.graph.Node):
     ...     def __init__(self):
-    ...         self.a = bst.nn.BatchNorm1d([10, 2])
-    ...         self.b = bst.nn.Linear(2, 3)
+    ...         self.a = brainstate.nn.BatchNorm1d([10, 2])
+    ...         self.b = brainstate.nn.Linear(2, 3)
     ...
     >>> node = Foo()
-    >>> graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)
+    >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
     ...
-    >>> new_node = bst.graph.treefy_merge(graphdef, params, others)
+    >>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
     >>> assert isinstance(new_node, Foo)
-    >>> assert isinstance(new_node.b, bst.nn.BatchNorm1d)
-    >>> assert isinstance(new_node.a, bst.nn.Linear)
+    >>> assert isinstance(new_node.b, brainstate.nn.BatchNorm1d)
+    >>> assert isinstance(new_node.a, brainstate.nn.Linear)
 
     :func:`split` and :func:`merge` are primarily used to interact directly with JAX
     transformations, see
@@ -1302,22 +1302,22 @@ def treefy_states(
 
     Example usage::
 
-    >>> import brainstate as bst
-    >>> class Model(bst.nn.Module):
+    >>> import brainstate as brainstate
+    >>> class Model(brainstate.nn.Module):
     ...     def __init__(self):
     ...         super().__init__()
-    ...         self.l1 = bst.nn.Linear(2, 3)
-    ...         self.l2 = bst.nn.Linear(3, 4)
+    ...         self.l1 = brainstate.nn.Linear(2, 3)
+    ...         self.l2 = brainstate.nn.Linear(3, 4)
     ...     def __call__(self, x):
     ...         return self.l2(self.l1(x))
 
     >>> model = Model()
     >>> # get the learnable parameters from the batch norm and linear layer
-    >>> params = bst.graph.treefy_states(model, bst.ParamState)
+    >>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
     >>> # get them separately
-    >>> params, others = bst.graph.treefy_states(model, bst.ParamState, bst.ShortTermState)
+    >>> params, others = brainstate.graph.treefy_states(model, brainstate.ParamState, brainstate.ShortTermState)
     >>> # get them together
-    >>> states = bst.graph.treefy_states(model)
+    >>> states = brainstate.graph.treefy_states(model)
 
     Args:
       node: A graph node object.
@@ -1403,11 +1403,11 @@ def graphdef(node: Any, /) -> GraphDef[Any]:
 
     Example usage::
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
 
-    >>> model = bst.nn.Linear(2, 3)
-    >>> graphdef, _ = bst.graph.treefy_split(model)
-    >>> assert graphdef == bst.graph.graphdef(model)
+    >>> model = brainstate.nn.Linear(2, 3)
+    >>> graphdef, _ = brainstate.graph.treefy_split(model)
+    >>> assert graphdef == brainstate.graph.graphdef(model)
 
     Args:
       node: A graph node object.
@@ -1426,8 +1426,8 @@ def clone(node: Node) -> Node:
 
     Example usage::
 
-    >>> import brainstate as bst
-    >>> model = bst.nn.Linear(2, 3)
+    >>> import brainstate as brainstate
+    >>> model = brainstate.nn.Linear(2, 3)
     >>> cloned_model = clone(model)
     >>> model.weight.value['bias'] += 1
     >>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
@@ -1456,15 +1456,15 @@ def call(
 
     Example::
 
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import jax
     >>> import jax.numpy as jnp
     ...
-    >>> class StatefulLinear(bst.graph.Node):
+    >>> class StatefulLinear(brainstate.graph.Node):
     ...     def __init__(self, din, dout):
-    ...         self.w = bst.ParamState(bst.random.rand(din, dout))
-    ...         self.b = bst.ParamState(jnp.zeros((dout,)))
-    ...         self.count = bst.State(jnp.array(0, dtype=jnp.uint32))
+    ...         self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
+    ...         self.b = brainstate.ParamState(jnp.zeros((dout,)))
+    ...         self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
     ...
     ...     def increment(self):
     ...         self.count.value += 1
@@ -1474,18 +1474,18 @@ def call(
     ...         return x @ self.w.value + self.b.value
     ...
     >>> linear = StatefulLinear(3, 2)
-    >>> linear_state = bst.graph.treefy_split(linear)
+    >>> linear_state = brainstate.graph.treefy_split(linear)
     ...
     >>> @jax.jit
     ... def forward(x, linear_state):
-    ...     y, linear_state = bst.graph.call(linear_state)(x)
+    ...     y, linear_state = brainstate.graph.call(linear_state)(x)
     ...     return y, linear_state
     ...
     >>> x = jnp.ones((1, 3))
     >>> y, linear_state = forward(x, linear_state)
     >>> y, linear_state = forward(x, linear_state)
     ...
-    >>> linear = bst.graph.treefy_merge(*linear_state)
+    >>> linear = brainstate.graph.treefy_merge(*linear_state)
     >>> linear.count.value
     Array(2, dtype=uint32)
 
@@ -1494,11 +1494,11 @@ def call(
     is used to call the ``increment`` method of the ``StatefulLinear`` module
     at the ``b`` key of a ``nodes`` dictionary.
 
-    >>> class StatefulLinear(bst.graph.Node):
+    >>> class StatefulLinear(brainstate.graph.Node):
     ...     def __init__(self, din, dout):
-    ...         self.w = bst.ParamState(bst.random.rand(din, dout))
-    ...         self.b = bst.ParamState(jnp.zeros((dout,)))
-    ...         self.count = bst.State(jnp.array(0, dtype=jnp.uint32))
+    ...         self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
+    ...         self.b = brainstate.ParamState(jnp.zeros((dout,)))
+    ...         self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
     ...
     ...     def increment(self):
     ...         self.count.value += 1
@@ -1514,7 +1514,7 @@ def call(
     ...
     >>> node_state = treefy_split(nodes)
     >>> # use attribute access
-    >>> _, node_state = bst.graph.call(node_state)['b'].increment()
+    >>> _, node_state = brainstate.graph.call(node_state)['b'].increment()
     ...
     >>> nodes = treefy_merge(*node_state)
     >>> nodes['a'].count.value
@@ -1544,19 +1544,19 @@ def iter_leaf(
     root. Repeated nodes are visited only once. Leaves include static values.
 
     Example::
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import jax.numpy as jnp
     ...
-    >>> class Linear(bst.nn.Module):
+    >>> class Linear(brainstate.nn.Module):
     ...     def __init__(self, din, dout):
     ...         super().__init__()
-    ...         self.weight = bst.ParamState(bst.random.randn(din, dout))
-    ...         self.bias = bst.ParamState(bst.random.randn(dout))
+    ...         self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
+    ...         self.bias = brainstate.ParamState(brainstate.random.randn(dout))
     ...         self.a = 1
     ...
     >>> module = Linear(3, 4)
     ...
-    >>> for path, value in bst.graph.iter_leaf([module, module]):
+    >>> for path, value in brainstate.graph.iter_leaf([module, module]):
     ...     print(path, type(value).__name__)
     ...
     (0, 'a') int
@@ -1616,21 +1616,21 @@ def iter_node(
     root. Repeated nodes are visited only once. Leaves include static values.
 
     Example::
-    >>> import brainstate as bst
+    >>> import brainstate as brainstate
     >>> import jax.numpy as jnp
     ...
-    >>> class Model(bst.nn.Module):
+    >>> class Model(brainstate.nn.Module):
     ...     def __init__(self):
     ...         super().__init__()
-    ...         self.a = bst.nn.Linear(1, 2)
-    ...         self.b = bst.nn.Linear(2, 3)
-    ...         self.c = [bst.nn.Linear(3, 4), bst.nn.Linear(4, 5)]
-    ...         self.d = {'x': bst.nn.Linear(5, 6), 'y': bst.nn.Linear(6, 7)}
-    ...         self.b.a = bst.nn.LIF(2)
+    ...         self.a = brainstate.nn.Linear(1, 2)
+    ...         self.b = brainstate.nn.Linear(2, 3)
+    ...         self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
+    ...         self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
+    ...         self.b.a = brainstate.nn.LIF(2)
     ...
     >>> model = Model()
     ...
-    >>> for path, node in bst.graph.iter_node([model, model]):
+    >>> for path, node in brainstate.graph.iter_node([model, model]):
     ...     print(path, node.__class__.__name__)
     ...
     (0, 'a') Linear
brainstate/graph/_graph_operation_test.py CHANGED
@@ -443,7 +443,7 @@ class TestGraphOperation(unittest.TestCase):
         graphdef, statetree = brainstate.graph.flatten(MyNode())
         # print(graphdef)
         print(statetree)
-        # print(bst.graph.unflatten(graphdef, statetree))
+        # print(brainstate.graph.unflatten(graphdef, statetree))
 
     def test_split(self):
         class Foo(brainstate.graph.Node):
@@ -530,7 +530,7 @@ class TestGraphOperation(unittest.TestCase):
         print(graph_def)
         print(treefy_states)
 
-        # states = bst.graph.states(model)
+        # states = brainstate.graph.states(model)
         # print(states)
         # nest_states = states.to_nest()
         # print(nest_states)
brainstate/mixin.py CHANGED
@@ -27,7 +27,6 @@ __all__ = [
     'ParamDescriber',
     'AlignPost',
     'BindCondData',
-    'UpdateReturn',
 
     # types
     'JointTypes',
@@ -171,22 +170,6 @@ def not_implemented(func):
     return wrapper
 
 
-
-class UpdateReturn(Mixin):
-    @not_implemented
-    def update_return(self) -> PyTree:
-        """
-        The update function return of the model.
-
-        This function requires no parameters and must return a PyTree.
-
-        It is usually used for delay initialization, for example, ``Dynamics.output_delay`` relies on this function to
-        initialize the output delay.
-
-        """
-        raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
-
-
 class _MetaUnionType(type):
     def __new__(cls, name, bases, dct):
         if isinstance(bases, type):
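The removal of UpdateReturn is the one breaking API change in this section: Dynamics no longer inherits it (see the _dynamics.py hunks below), and importing it from brainstate.mixin now fails. A minimal migration sketch, assuming downstream code only declared the method; MyDynamics and its body are illustrative, not part of brainstate:

    import jax.numpy as jnp
    import brainstate

    class MyDynamics(brainstate.nn.Dynamics):
        def __init__(self, in_size):
            super().__init__(in_size)
            self.v = brainstate.ShortTermState(jnp.zeros(self.varshape))

        def update(self, x):
            self.v.value += x
            return self.v.value

        # Formerly declared via the UpdateReturn mixin; now just an ordinary
        # method on the subclass, with the same "return a PyTree" contract.
        def update_return(self):
            return self.v.value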
brainstate/nn/_collective_ops.py CHANGED
@@ -56,10 +56,10 @@ def call_order(level: int = 0, check_order_boundary: bool = True):
 
     The lower the level, the earlier the function is called.
 
-    >>> import brainstate as bst
-    >>> bst.nn.call_order(0)
-    >>> bst.nn.call_order(-1)
-    >>> bst.nn.call_order(-2)
+    >>> import brainstate as brainstate
+    >>> brainstate.nn.call_order(0)
+    >>> brainstate.nn.call_order(-1)
+    >>> brainstate.nn.call_order(-2)
 
     Parameters
     ----------
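For context, call_order is a decorator that tags methods so collective operations visit them in a defined order. A hedged sketch of typical use; the class and method names are illustrative, and the exact collective-op entry point may vary:

    import brainstate

    class Net(brainstate.nn.Module):
        @brainstate.nn.call_order(0)    # lowest level: visited first
        def init_state_a(self):
            ...

        @brainstate.nn.call_order(1)    # visited after level 0
        def init_state_b(self):
            ...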
brainstate/nn/_common.py CHANGED
@@ -118,14 +118,14 @@ class Vmap(Module):
     This class wraps a module and applies vectorized mapping to its execution,
     allowing for efficient parallel processing across specified axes.
 
-    Attributes:
+    Args:
         module (Module): The module to be vmapped.
-        in_axes (int | None | Sequence[Any]): Specifies how to map over inputs.
-        out_axes (Any): Specifies how to map over outputs.
-        vmap_states (Filter | Dict[Filter, int]): Specifies which states to vmap and on which axes.
-        vmap_out_states (Filter | Dict[Filter, int]): Specifies which output states to vmap and on which axes.
-        axis_name (AxisName | None): Name of the axis being mapped over.
-        axis_size (int | None): Size of the axis being mapped over.
+        in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
+        out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
+        vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
+        vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
+        axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
+        axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
     """
 
     def __init__(
@@ -138,18 +138,6 @@ class Vmap(Module):
         axis_name: AxisName | None = None,
         axis_size: int | None = None,
     ):
-        """
-        Initialize the Vmap instance.
-
-        Args:
-            module (Module): The module to be vmapped.
-            in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
-            out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
-            vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
-            vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
-            axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
-            axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
-        """
         super().__init__()
 
         # parameters
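With the constructor docstring now merged into the class docstring, the documented defaults are in_axes=0 and out_axes=0. A hedged usage sketch, assuming the wrapper forwards calls to the wrapped module; shapes are illustrative:

    import jax.numpy as jnp
    import brainstate

    layer = brainstate.nn.Linear(3, 4)
    vlayer = brainstate.nn.Vmap(layer, in_axes=0, out_axes=0)
    ys = vlayer(jnp.ones((8, 3)))    # maps over the leading axis of the input
    # expected: ys.shape == (8, 4)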
brainstate/nn/_dropout_test.py CHANGED
@@ -60,9 +60,9 @@ class TestDropout(unittest.TestCase):
         np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
 
     # def test_Dropout1d(self):
-    #     dropout_layer = bst.nn.Dropout1d(prob=0.5)
+    #     dropout_layer = brainstate.nn.Dropout1d(prob=0.5)
     #     input_data = np.random.randn(2, 3, 4)
-    #     with bst.environ.context(fit=True):
+    #     with brainstate.environ.context(fit=True):
     #         output_data = dropout_layer(input_data)
     #         self.assertEqual(input_data.shape, output_data.shape)
     #         self.assertTrue(np.any(output_data == 0))
brainstate/nn/_dynamics.py CHANGED
@@ -42,7 +42,7 @@ import numpy as np
 from brainstate import environ
 from brainstate._state import State
 from brainstate.graph import Node
-from brainstate.mixin import ParamDescriber, UpdateReturn
+from brainstate.mixin import ParamDescriber
 from brainstate.typing import Size, ArrayLike, PyTree
 from ._delay import StateWithDelay, Delay
 from ._module import Module
@@ -101,7 +101,7 @@ class Projection(Module):
         raise ValueError('Do not implement the update() function.')
 
 
-class Dynamics(Module, UpdateReturn):
+class Dynamics(Module):
     """
     Base class for implementing neural dynamics models in BrainState.
 
@@ -214,13 +214,13 @@ class Dynamics(Module, UpdateReturn):
         # in-/out- size of neuron population
         self.out_size = self.in_size
 
-    def __pretty_repr_item__(self, name, value):
-        if name in [
-            '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
-            '_in_size', '_out_size', '_name', '_mode',
-        ]:
-            return (name, value) if value is None else (name[1:], value)  # skip the first `_`
-        return super().__pretty_repr_item__(name, value)
+    # def __pretty_repr_item__(self, name, value):
+    #     if name in [
+    #         '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
+    #         '_in_size', '_out_size', '_name', '_mode',
+    #     ]:
+    #         return (name, value) if value is None else (name[1:], value)  # skip the first `_`
+    #     return super().__pretty_repr_item__(name, value)
 
     @property
     def varshape(self):
@@ -470,21 +470,30 @@ class Dynamics(Module, UpdateReturn):
         if self._current_inputs is None:
             return init
         if label is None:
-            # no label
-            for key in tuple(self._current_inputs.keys()):
-                out = self._current_inputs[key]
-                init = init + (out(*args, **kwargs) if callable(out) else out)
-                if not callable(out):
-                    self._current_inputs.pop(key)
+            filter_fn = lambda k: True
         else:
-            # has label
             label_repr = _input_label_start(label)
-            for key in tuple(self._current_inputs.keys()):
-                if key.startswith(label_repr):
-                    out = self._current_inputs[key]
-                    init = init + (out(*args, **kwargs) if callable(out) else out)
-                    if not callable(out):
-                        self._current_inputs.pop(key)
+            filter_fn = lambda k: k.startswith(label_repr)
+        for key in tuple(self._current_inputs.keys()):
+            if filter_fn(key):
+                out = self._current_inputs[key]
+                if callable(out):
+                    try:
+                        init = init + out(*args, **kwargs)
+                    except Exception as e:
+                        raise ValueError(
+                            f'Error in delta input value {key}: {out}\n'
+                            f'Error: {e}'
+                        ) from e
+                else:
+                    try:
+                        init = init + out
+                    except Exception as e:
+                        raise ValueError(
+                            f'Error in delta input value {key}: {out}\n'
+                            f'Error: {e}'
+                        ) from e
+                    self._current_inputs.pop(key)
         return init
 
     def sum_delta_inputs(
@@ -529,21 +538,30 @@ class Dynamics(Module, UpdateReturn):
         if self._delta_inputs is None:
             return init
         if label is None:
-            # no label
-            for key in tuple(self._delta_inputs.keys()):
-                out = self._delta_inputs[key]
-                init = init + (out(*args, **kwargs) if callable(out) else out)
-                if not callable(out):
-                    self._delta_inputs.pop(key)
+            filter_fn = lambda k: True
         else:
-            # has label
             label_repr = _input_label_start(label)
-            for key in tuple(self._delta_inputs.keys()):
-                if key.startswith(label_repr):
-                    out = self._delta_inputs[key]
-                    init = init + (out(*args, **kwargs) if callable(out) else out)
-                    if not callable(out):
-                        self._delta_inputs.pop(key)
+            filter_fn = lambda k: k.startswith(label_repr)
+        for key in tuple(self._delta_inputs.keys()):
+            if filter_fn(key):
+                out = self._delta_inputs[key]
+                if callable(out):
+                    try:
+                        init = init + out(*args, **kwargs)
+                    except Exception as e:
+                        raise ValueError(
+                            f'Error in delta input function {key}: {out}\n'
+                            f'Error: {e}'
+                        ) from e
+                else:
+                    try:
+                        init = init + out
+                    except Exception as e:
+                        raise ValueError(
+                            f'Error in delta input value {key}: {out}\n'
+                            f'Error: {e}'
+                        ) from e
+                    self._delta_inputs.pop(key)
         return init
 
     @property
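Both input-summation methods now share one shape: build a key filter up front, loop once, and wrap each accumulation so a failure names the offending key (note the released messages say "delta input" even in the current-input path). A standalone distillation of the pattern, not brainstate API:

    from typing import Any, Callable, Dict, Optional

    def sum_inputs(inputs: Dict[str, Any], init: Any,
                   label: Optional[str] = None, *args, **kwargs) -> Any:
        # Select keys: everything, or only those under the given label prefix.
        if label is None:
            filter_fn: Callable[[str], bool] = lambda k: True
        else:
            prefix = f'{label}-'            # placeholder for _input_label_start
            filter_fn = lambda k: k.startswith(prefix)
        for key in tuple(inputs.keys()):
            if not filter_fn(key):
                continue
            out = inputs[key]
            try:
                init = init + (out(*args, **kwargs) if callable(out) else out)
            except Exception as e:
                raise ValueError(f'Error in input {key}: {out}\nError: {e}') from e
            if not callable(out):
                inputs.pop(key)             # one-shot values are consumed once
        return init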