brainstate 0.1.4__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_state.py +6 -5
- brainstate/augment/_autograd.py +31 -12
- brainstate/augment/_autograd_test.py +46 -46
- brainstate/augment/_eval_shape.py +4 -4
- brainstate/augment/_mapping.py +22 -17
- brainstate/augment/_mapping_test.py +162 -0
- brainstate/compile/_conditions.py +2 -2
- brainstate/compile/_make_jaxpr.py +59 -6
- brainstate/compile/_progress_bar.py +2 -2
- brainstate/environ.py +19 -19
- brainstate/functional/_activations_test.py +12 -12
- brainstate/graph/_graph_operation.py +69 -69
- brainstate/graph/_graph_operation_test.py +2 -2
- brainstate/mixin.py +0 -17
- brainstate/nn/_collective_ops.py +4 -4
- brainstate/nn/_common.py +7 -19
- brainstate/nn/_dropout_test.py +2 -2
- brainstate/nn/_dynamics.py +53 -35
- brainstate/nn/_elementwise.py +30 -30
- brainstate/nn/_exp_euler.py +13 -16
- brainstate/nn/_inputs.py +1 -1
- brainstate/nn/_linear.py +4 -4
- brainstate/nn/_module.py +6 -6
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +11 -11
- brainstate/nn/_normalizations_test.py +6 -6
- brainstate/nn/_poolings.py +24 -24
- brainstate/nn/_synapse.py +1 -12
- brainstate/nn/_utils.py +1 -1
- brainstate/nn/metrics.py +4 -4
- brainstate/optim/_optax_optimizer.py +8 -8
- brainstate/random/_rand_funs.py +37 -37
- brainstate/random/_rand_funs_test.py +3 -3
- brainstate/random/_rand_seed.py +7 -7
- brainstate/random/_rand_state.py +13 -7
- brainstate/surrogate.py +40 -40
- brainstate/util/pretty_pytree.py +10 -10
- brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
- brainstate/util/struct.py +7 -7
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/RECORD +45 -45
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
@@ -473,8 +473,8 @@ def flatten(
|
|
473
473
|
|
474
474
|
Example::
|
475
475
|
|
476
|
-
>>> import brainstate as
|
477
|
-
>>> node =
|
476
|
+
>>> import brainstate as brainstate
|
477
|
+
>>> node = brainstate.graph.Node()
|
478
478
|
>>> graph_def, state_mapping = flatten(node)
|
479
479
|
>>> print(graph_def)
|
480
480
|
>>> print(state_mapping)
|
@@ -709,15 +709,15 @@ def unflatten(
|
|
709
709
|
|
710
710
|
Example::
|
711
711
|
|
712
|
-
>>> import brainstate as
|
713
|
-
>>> class MyNode(
|
712
|
+
>>> import brainstate as brainstate
|
713
|
+
>>> class MyNode(brainstate.graph.Node):
|
714
714
|
... def __init__(self):
|
715
|
-
... self.a =
|
716
|
-
... self.b =
|
717
|
-
... self.c = [
|
718
|
-
... self.d = {'x':
|
715
|
+
... self.a = brainstate.nn.Linear(2, 3)
|
716
|
+
... self.b = brainstate.nn.Linear(3, 4)
|
717
|
+
... self.c = [brainstate.nn.Linear(4, 5), brainstate.nn.Linear(5, 6)]
|
718
|
+
... self.d = {'x': brainstate.nn.Linear(6, 7), 'y': brainstate.nn.Linear(7, 8)}
|
719
719
|
...
|
720
|
-
>>> graphdef, statetree =
|
720
|
+
>>> graphdef, statetree = brainstate.graph.flatten(MyNode())
|
721
721
|
>>> statetree
|
722
722
|
NestedDict({
|
723
723
|
'a': {
|
@@ -764,7 +764,7 @@ def unflatten(
|
|
764
764
|
}
|
765
765
|
}
|
766
766
|
})
|
767
|
-
>>> node =
|
767
|
+
>>> node = brainstate.graph.unflatten(graphdef, statetree)
|
768
768
|
>>> node
|
769
769
|
MyNode(
|
770
770
|
a=Linear(
|
@@ -942,21 +942,21 @@ def pop_states(
|
|
942
942
|
|
943
943
|
Example usage::
|
944
944
|
|
945
|
-
>>> import brainstate as
|
945
|
+
>>> import brainstate as brainstate
|
946
946
|
>>> import jax.numpy as jnp
|
947
947
|
|
948
|
-
>>> class Model(
|
948
|
+
>>> class Model(brainstate.nn.Module):
|
949
949
|
... def __init__(self):
|
950
950
|
... super().__init__()
|
951
|
-
... self.a =
|
952
|
-
... self.b =
|
951
|
+
... self.a = brainstate.nn.Linear(2, 3)
|
952
|
+
... self.b = brainstate.nn.LIF([10, 2])
|
953
953
|
|
954
954
|
>>> model = Model()
|
955
|
-
>>> with
|
956
|
-
...
|
955
|
+
>>> with brainstate.catch_new_states('new'):
|
956
|
+
... brainstate.nn.init_all_states(model)
|
957
957
|
|
958
958
|
>>> assert len(model.states()) == 2
|
959
|
-
>>> model_states =
|
959
|
+
>>> model_states = brainstate.graph.pop_states(model, 'new')
|
960
960
|
>>> model_states
|
961
961
|
NestedDict({
|
962
962
|
'b': {
|
@@ -1046,16 +1046,16 @@ def treefy_split(
|
|
1046
1046
|
|
1047
1047
|
Example usage::
|
1048
1048
|
|
1049
|
-
>>> from joblib.testing import param >>> import brainstate as
|
1049
|
+
>>> from joblib.testing import param >>> import brainstate as brainstate
|
1050
1050
|
>>> import jax, jax.numpy as jnp
|
1051
1051
|
...
|
1052
|
-
>>> class Foo(
|
1052
|
+
>>> class Foo(brainstate.graph.Node):
|
1053
1053
|
... def __init__(self):
|
1054
|
-
... self.a =
|
1055
|
-
... self.b =
|
1054
|
+
... self.a = brainstate.nn.BatchNorm1d([10, 2])
|
1055
|
+
... self.b = brainstate.nn.Linear(2, 3)
|
1056
1056
|
...
|
1057
1057
|
>>> node = Foo()
|
1058
|
-
>>> graphdef, params, others =
|
1058
|
+
>>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
|
1059
1059
|
...
|
1060
1060
|
>>> params
|
1061
1061
|
NestedDict({
|
@@ -1119,21 +1119,21 @@ def treefy_merge(
|
|
1119
1119
|
|
1120
1120
|
Example usage::
|
1121
1121
|
|
1122
|
-
>>> import brainstate as
|
1122
|
+
>>> import brainstate as brainstate
|
1123
1123
|
>>> import jax, jax.numpy as jnp
|
1124
1124
|
...
|
1125
|
-
>>> class Foo(
|
1125
|
+
>>> class Foo(brainstate.graph.Node):
|
1126
1126
|
... def __init__(self):
|
1127
|
-
... self.a =
|
1128
|
-
... self.b =
|
1127
|
+
... self.a = brainstate.nn.BatchNorm1d([10, 2])
|
1128
|
+
... self.b = brainstate.nn.Linear(2, 3)
|
1129
1129
|
...
|
1130
1130
|
>>> node = Foo()
|
1131
|
-
>>> graphdef, params, others =
|
1131
|
+
>>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
|
1132
1132
|
...
|
1133
|
-
>>> new_node =
|
1133
|
+
>>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
|
1134
1134
|
>>> assert isinstance(new_node, Foo)
|
1135
|
-
>>> assert isinstance(new_node.b,
|
1136
|
-
>>> assert isinstance(new_node.a,
|
1135
|
+
>>> assert isinstance(new_node.b, brainstate.nn.BatchNorm1d)
|
1136
|
+
>>> assert isinstance(new_node.a, brainstate.nn.Linear)
|
1137
1137
|
|
1138
1138
|
:func:`split` and :func:`merge` are primarily used to interact directly with JAX
|
1139
1139
|
transformations, see
|
@@ -1302,22 +1302,22 @@ def treefy_states(
|
|
1302
1302
|
|
1303
1303
|
Example usage::
|
1304
1304
|
|
1305
|
-
>>> import brainstate as
|
1306
|
-
>>> class Model(
|
1305
|
+
>>> import brainstate as brainstate
|
1306
|
+
>>> class Model(brainstate.nn.Module):
|
1307
1307
|
... def __init__(self):
|
1308
1308
|
... super().__init__()
|
1309
|
-
... self.l1 =
|
1310
|
-
... self.l2 =
|
1309
|
+
... self.l1 = brainstate.nn.Linear(2, 3)
|
1310
|
+
... self.l2 = brainstate.nn.Linear(3, 4)
|
1311
1311
|
... def __call__(self, x):
|
1312
1312
|
... return self.l2(self.l1(x))
|
1313
1313
|
|
1314
1314
|
>>> model = Model()
|
1315
1315
|
>>> # get the learnable parameters from the batch norm and linear layer
|
1316
|
-
>>> params =
|
1316
|
+
>>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
|
1317
1317
|
>>> # get them separately
|
1318
|
-
>>> params, others =
|
1318
|
+
>>> params, others = brainstate.graph.treefy_states(model, brainstate.ParamState, brainstate.ShortTermState)
|
1319
1319
|
>>> # get them together
|
1320
|
-
>>> states =
|
1320
|
+
>>> states = brainstate.graph.treefy_states(model)
|
1321
1321
|
|
1322
1322
|
Args:
|
1323
1323
|
node: A graph node object.
|
@@ -1403,11 +1403,11 @@ def graphdef(node: Any, /) -> GraphDef[Any]:
|
|
1403
1403
|
|
1404
1404
|
Example usage::
|
1405
1405
|
|
1406
|
-
>>> import brainstate as
|
1406
|
+
>>> import brainstate as brainstate
|
1407
1407
|
|
1408
|
-
>>> model =
|
1409
|
-
>>> graphdef, _ =
|
1410
|
-
>>> assert graphdef ==
|
1408
|
+
>>> model = brainstate.nn.Linear(2, 3)
|
1409
|
+
>>> graphdef, _ = brainstate.graph.treefy_split(model)
|
1410
|
+
>>> assert graphdef == brainstate.graph.graphdef(model)
|
1411
1411
|
|
1412
1412
|
Args:
|
1413
1413
|
node: A graph node object.
|
@@ -1426,8 +1426,8 @@ def clone(node: Node) -> Node:
|
|
1426
1426
|
|
1427
1427
|
Example usage::
|
1428
1428
|
|
1429
|
-
>>> import brainstate as
|
1430
|
-
>>> model =
|
1429
|
+
>>> import brainstate as brainstate
|
1430
|
+
>>> model = brainstate.nn.Linear(2, 3)
|
1431
1431
|
>>> cloned_model = clone(model)
|
1432
1432
|
>>> model.weight.value['bias'] += 1
|
1433
1433
|
>>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
|
@@ -1456,15 +1456,15 @@ def call(
|
|
1456
1456
|
|
1457
1457
|
Example::
|
1458
1458
|
|
1459
|
-
>>> import brainstate as
|
1459
|
+
>>> import brainstate as brainstate
|
1460
1460
|
>>> import jax
|
1461
1461
|
>>> import jax.numpy as jnp
|
1462
1462
|
...
|
1463
|
-
>>> class StatefulLinear(
|
1463
|
+
>>> class StatefulLinear(brainstate.graph.Node):
|
1464
1464
|
... def __init__(self, din, dout):
|
1465
|
-
... self.w =
|
1466
|
-
... self.b =
|
1467
|
-
... self.count =
|
1465
|
+
... self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
|
1466
|
+
... self.b = brainstate.ParamState(jnp.zeros((dout,)))
|
1467
|
+
... self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
|
1468
1468
|
...
|
1469
1469
|
... def increment(self):
|
1470
1470
|
... self.count.value += 1
|
@@ -1474,18 +1474,18 @@ def call(
|
|
1474
1474
|
... return x @ self.w.value + self.b.value
|
1475
1475
|
...
|
1476
1476
|
>>> linear = StatefulLinear(3, 2)
|
1477
|
-
>>> linear_state =
|
1477
|
+
>>> linear_state = brainstate.graph.treefy_split(linear)
|
1478
1478
|
...
|
1479
1479
|
>>> @jax.jit
|
1480
1480
|
... def forward(x, linear_state):
|
1481
|
-
... y, linear_state =
|
1481
|
+
... y, linear_state = brainstate.graph.call(linear_state)(x)
|
1482
1482
|
... return y, linear_state
|
1483
1483
|
...
|
1484
1484
|
>>> x = jnp.ones((1, 3))
|
1485
1485
|
>>> y, linear_state = forward(x, linear_state)
|
1486
1486
|
>>> y, linear_state = forward(x, linear_state)
|
1487
1487
|
...
|
1488
|
-
>>> linear =
|
1488
|
+
>>> linear = brainstate.graph.treefy_merge(*linear_state)
|
1489
1489
|
>>> linear.count.value
|
1490
1490
|
Array(2, dtype=uint32)
|
1491
1491
|
|
@@ -1494,11 +1494,11 @@ def call(
|
|
1494
1494
|
is used to call the ``increment`` method of the ``StatefulLinear`` module
|
1495
1495
|
at the ``b`` key of a ``nodes`` dictionary.
|
1496
1496
|
|
1497
|
-
>>> class StatefulLinear(
|
1497
|
+
>>> class StatefulLinear(brainstate.graph.Node):
|
1498
1498
|
... def __init__(self, din, dout):
|
1499
|
-
... self.w =
|
1500
|
-
... self.b =
|
1501
|
-
... self.count =
|
1499
|
+
... self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
|
1500
|
+
... self.b = brainstate.ParamState(jnp.zeros((dout,)))
|
1501
|
+
... self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
|
1502
1502
|
...
|
1503
1503
|
... def increment(self):
|
1504
1504
|
... self.count.value += 1
|
@@ -1514,7 +1514,7 @@ def call(
|
|
1514
1514
|
...
|
1515
1515
|
>>> node_state = treefy_split(nodes)
|
1516
1516
|
>>> # use attribute access
|
1517
|
-
>>> _, node_state =
|
1517
|
+
>>> _, node_state = brainstate.graph.call(node_state)['b'].increment()
|
1518
1518
|
...
|
1519
1519
|
>>> nodes = treefy_merge(*node_state)
|
1520
1520
|
>>> nodes['a'].count.value
|
@@ -1544,19 +1544,19 @@ def iter_leaf(
|
|
1544
1544
|
root. Repeated nodes are visited only once. Leaves include static values.
|
1545
1545
|
|
1546
1546
|
Example::
|
1547
|
-
>>> import brainstate as
|
1547
|
+
>>> import brainstate as brainstate
|
1548
1548
|
>>> import jax.numpy as jnp
|
1549
1549
|
...
|
1550
|
-
>>> class Linear(
|
1550
|
+
>>> class Linear(brainstate.nn.Module):
|
1551
1551
|
... def __init__(self, din, dout):
|
1552
1552
|
... super().__init__()
|
1553
|
-
... self.weight =
|
1554
|
-
... self.bias =
|
1553
|
+
... self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
|
1554
|
+
... self.bias = brainstate.ParamState(brainstate.random.randn(dout))
|
1555
1555
|
... self.a = 1
|
1556
1556
|
...
|
1557
1557
|
>>> module = Linear(3, 4)
|
1558
1558
|
...
|
1559
|
-
>>> for path, value in
|
1559
|
+
>>> for path, value in brainstate.graph.iter_leaf([module, module]):
|
1560
1560
|
... print(path, type(value).__name__)
|
1561
1561
|
...
|
1562
1562
|
(0, 'a') int
|
@@ -1616,21 +1616,21 @@ def iter_node(
|
|
1616
1616
|
root. Repeated nodes are visited only once. Leaves include static values.
|
1617
1617
|
|
1618
1618
|
Example::
|
1619
|
-
>>> import brainstate as
|
1619
|
+
>>> import brainstate as brainstate
|
1620
1620
|
>>> import jax.numpy as jnp
|
1621
1621
|
...
|
1622
|
-
>>> class Model(
|
1622
|
+
>>> class Model(brainstate.nn.Module):
|
1623
1623
|
... def __init__(self):
|
1624
1624
|
... super().__init__()
|
1625
|
-
... self.a =
|
1626
|
-
... self.b =
|
1627
|
-
... self.c = [
|
1628
|
-
... self.d = {'x':
|
1629
|
-
... self.b.a =
|
1625
|
+
... self.a = brainstate.nn.Linear(1, 2)
|
1626
|
+
... self.b = brainstate.nn.Linear(2, 3)
|
1627
|
+
... self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
|
1628
|
+
... self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
|
1629
|
+
... self.b.a = brainstate.nn.LIF(2)
|
1630
1630
|
...
|
1631
1631
|
>>> model = Model()
|
1632
1632
|
...
|
1633
|
-
>>> for path, node in
|
1633
|
+
>>> for path, node in brainstate.graph.iter_node([model, model]):
|
1634
1634
|
... print(path, node.__class__.__name__)
|
1635
1635
|
...
|
1636
1636
|
(0, 'a') Linear
|
@@ -443,7 +443,7 @@ class TestGraphOperation(unittest.TestCase):
|
|
443
443
|
graphdef, statetree = brainstate.graph.flatten(MyNode())
|
444
444
|
# print(graphdef)
|
445
445
|
print(statetree)
|
446
|
-
# print(
|
446
|
+
# print(brainstate.graph.unflatten(graphdef, statetree))
|
447
447
|
|
448
448
|
def test_split(self):
|
449
449
|
class Foo(brainstate.graph.Node):
|
@@ -530,7 +530,7 @@ class TestGraphOperation(unittest.TestCase):
|
|
530
530
|
print(graph_def)
|
531
531
|
print(treefy_states)
|
532
532
|
|
533
|
-
# states =
|
533
|
+
# states = brainstate.graph.states(model)
|
534
534
|
# print(states)
|
535
535
|
# nest_states = states.to_nest()
|
536
536
|
# print(nest_states)
|
brainstate/mixin.py
CHANGED
@@ -27,7 +27,6 @@ __all__ = [
|
|
27
27
|
'ParamDescriber',
|
28
28
|
'AlignPost',
|
29
29
|
'BindCondData',
|
30
|
-
'UpdateReturn',
|
31
30
|
|
32
31
|
# types
|
33
32
|
'JointTypes',
|
@@ -171,22 +170,6 @@ def not_implemented(func):
|
|
171
170
|
return wrapper
|
172
171
|
|
173
172
|
|
174
|
-
|
175
|
-
class UpdateReturn(Mixin):
|
176
|
-
@not_implemented
|
177
|
-
def update_return(self) -> PyTree:
|
178
|
-
"""
|
179
|
-
The update function return of the model.
|
180
|
-
|
181
|
-
This function requires no parameters and must return a PyTree.
|
182
|
-
|
183
|
-
It is usually used for delay initialization, for example, ``Dynamics.output_delay`` relies on this function to
|
184
|
-
initialize the output delay.
|
185
|
-
|
186
|
-
"""
|
187
|
-
raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
|
188
|
-
|
189
|
-
|
190
173
|
class _MetaUnionType(type):
|
191
174
|
def __new__(cls, name, bases, dct):
|
192
175
|
if isinstance(bases, type):
|
brainstate/nn/_collective_ops.py
CHANGED
@@ -56,10 +56,10 @@ def call_order(level: int = 0, check_order_boundary: bool = True):
|
|
56
56
|
|
57
57
|
The lower the level, the earlier the function is called.
|
58
58
|
|
59
|
-
>>> import brainstate as
|
60
|
-
>>>
|
61
|
-
>>>
|
62
|
-
>>>
|
59
|
+
>>> import brainstate as brainstate
|
60
|
+
>>> brainstate.nn.call_order(0)
|
61
|
+
>>> brainstate.nn.call_order(-1)
|
62
|
+
>>> brainstate.nn.call_order(-2)
|
63
63
|
|
64
64
|
Parameters
|
65
65
|
----------
|
brainstate/nn/_common.py
CHANGED
@@ -118,14 +118,14 @@ class Vmap(Module):
|
|
118
118
|
This class wraps a module and applies vectorized mapping to its execution,
|
119
119
|
allowing for efficient parallel processing across specified axes.
|
120
120
|
|
121
|
-
|
121
|
+
Args:
|
122
122
|
module (Module): The module to be vmapped.
|
123
|
-
in_axes (int | None | Sequence[Any]): Specifies how to map over inputs.
|
124
|
-
out_axes (Any): Specifies how to map over outputs.
|
125
|
-
vmap_states (Filter | Dict[Filter, int]): Specifies which states to vmap and on which axes.
|
126
|
-
vmap_out_states (Filter | Dict[Filter, int]): Specifies which output states to vmap and on which axes.
|
127
|
-
axis_name (AxisName | None): Name of the axis being mapped over.
|
128
|
-
axis_size (int | None): Size of the axis being mapped over.
|
123
|
+
in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
|
124
|
+
out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
|
125
|
+
vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
|
126
|
+
vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
|
127
|
+
axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
|
128
|
+
axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
|
129
129
|
"""
|
130
130
|
|
131
131
|
def __init__(
|
@@ -138,18 +138,6 @@ class Vmap(Module):
|
|
138
138
|
axis_name: AxisName | None = None,
|
139
139
|
axis_size: int | None = None,
|
140
140
|
):
|
141
|
-
"""
|
142
|
-
Initialize the Vmap instance.
|
143
|
-
|
144
|
-
Args:
|
145
|
-
module (Module): The module to be vmapped.
|
146
|
-
in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
|
147
|
-
out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
|
148
|
-
vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
|
149
|
-
vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
|
150
|
-
axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
|
151
|
-
axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
|
152
|
-
"""
|
153
141
|
super().__init__()
|
154
142
|
|
155
143
|
# parameters
|
brainstate/nn/_dropout_test.py
CHANGED
@@ -60,9 +60,9 @@ class TestDropout(unittest.TestCase):
|
|
60
60
|
np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
|
61
61
|
|
62
62
|
# def test_Dropout1d(self):
|
63
|
-
# dropout_layer =
|
63
|
+
# dropout_layer = brainstate.nn.Dropout1d(prob=0.5)
|
64
64
|
# input_data = np.random.randn(2, 3, 4)
|
65
|
-
# with
|
65
|
+
# with brainstate.environ.context(fit=True):
|
66
66
|
# output_data = dropout_layer(input_data)
|
67
67
|
# self.assertEqual(input_data.shape, output_data.shape)
|
68
68
|
# self.assertTrue(np.any(output_data == 0))
|
brainstate/nn/_dynamics.py
CHANGED
@@ -42,7 +42,7 @@ import numpy as np
|
|
42
42
|
from brainstate import environ
|
43
43
|
from brainstate._state import State
|
44
44
|
from brainstate.graph import Node
|
45
|
-
from brainstate.mixin import ParamDescriber
|
45
|
+
from brainstate.mixin import ParamDescriber
|
46
46
|
from brainstate.typing import Size, ArrayLike, PyTree
|
47
47
|
from ._delay import StateWithDelay, Delay
|
48
48
|
from ._module import Module
|
@@ -101,7 +101,7 @@ class Projection(Module):
|
|
101
101
|
raise ValueError('Do not implement the update() function.')
|
102
102
|
|
103
103
|
|
104
|
-
class Dynamics(Module
|
104
|
+
class Dynamics(Module):
|
105
105
|
"""
|
106
106
|
Base class for implementing neural dynamics models in BrainState.
|
107
107
|
|
@@ -214,13 +214,13 @@ class Dynamics(Module, UpdateReturn):
|
|
214
214
|
# in-/out- size of neuron population
|
215
215
|
self.out_size = self.in_size
|
216
216
|
|
217
|
-
def __pretty_repr_item__(self, name, value):
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
217
|
+
# def __pretty_repr_item__(self, name, value):
|
218
|
+
# if name in [
|
219
|
+
# '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
|
220
|
+
# '_in_size', '_out_size', '_name', '_mode',
|
221
|
+
# ]:
|
222
|
+
# return (name, value) if value is None else (name[1:], value) # skip the first `_`
|
223
|
+
# return super().__pretty_repr_item__(name, value)
|
224
224
|
|
225
225
|
@property
|
226
226
|
def varshape(self):
|
@@ -470,21 +470,30 @@ class Dynamics(Module, UpdateReturn):
|
|
470
470
|
if self._current_inputs is None:
|
471
471
|
return init
|
472
472
|
if label is None:
|
473
|
-
|
474
|
-
for key in tuple(self._current_inputs.keys()):
|
475
|
-
out = self._current_inputs[key]
|
476
|
-
init = init + (out(*args, **kwargs) if callable(out) else out)
|
477
|
-
if not callable(out):
|
478
|
-
self._current_inputs.pop(key)
|
473
|
+
filter_fn = lambda k: True
|
479
474
|
else:
|
480
|
-
# has label
|
481
475
|
label_repr = _input_label_start(label)
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
476
|
+
filter_fn = lambda k: k.startswith(label_repr)
|
477
|
+
for key in tuple(self._current_inputs.keys()):
|
478
|
+
if filter_fn(key):
|
479
|
+
out = self._current_inputs[key]
|
480
|
+
if callable(out):
|
481
|
+
try:
|
482
|
+
init = init + out(*args, **kwargs)
|
483
|
+
except Exception as e:
|
484
|
+
raise ValueError(
|
485
|
+
f'Error in delta input value {key}: {out}\n'
|
486
|
+
f'Error: {e}'
|
487
|
+
) from e
|
488
|
+
else:
|
489
|
+
try:
|
490
|
+
init = init + out
|
491
|
+
except Exception as e:
|
492
|
+
raise ValueError(
|
493
|
+
f'Error in delta input value {key}: {out}\n'
|
494
|
+
f'Error: {e}'
|
495
|
+
) from e
|
496
|
+
self._current_inputs.pop(key)
|
488
497
|
return init
|
489
498
|
|
490
499
|
def sum_delta_inputs(
|
@@ -529,21 +538,30 @@ class Dynamics(Module, UpdateReturn):
|
|
529
538
|
if self._delta_inputs is None:
|
530
539
|
return init
|
531
540
|
if label is None:
|
532
|
-
|
533
|
-
for key in tuple(self._delta_inputs.keys()):
|
534
|
-
out = self._delta_inputs[key]
|
535
|
-
init = init + (out(*args, **kwargs) if callable(out) else out)
|
536
|
-
if not callable(out):
|
537
|
-
self._delta_inputs.pop(key)
|
541
|
+
filter_fn = lambda k: True
|
538
542
|
else:
|
539
|
-
# has label
|
540
543
|
label_repr = _input_label_start(label)
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
544
|
+
filter_fn = lambda k: k.startswith(label_repr)
|
545
|
+
for key in tuple(self._delta_inputs.keys()):
|
546
|
+
if filter_fn(key):
|
547
|
+
out = self._delta_inputs[key]
|
548
|
+
if callable(out):
|
549
|
+
try:
|
550
|
+
init = init + out(*args, **kwargs)
|
551
|
+
except Exception as e:
|
552
|
+
raise ValueError(
|
553
|
+
f'Error in delta input function {key}: {out}\n'
|
554
|
+
f'Error: {e}'
|
555
|
+
) from e
|
556
|
+
else:
|
557
|
+
try:
|
558
|
+
init = init + out
|
559
|
+
except Exception as e:
|
560
|
+
raise ValueError(
|
561
|
+
f'Error in delta input value {key}: {out}\n'
|
562
|
+
f'Error: {e}'
|
563
|
+
) from e
|
564
|
+
self._delta_inputs.pop(key)
|
547
565
|
return init
|
548
566
|
|
549
567
|
@property
|