brainstate-0.1.5-py2.py3-none-any.whl → brainstate-0.1.7-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_state.py +5 -5
- brainstate/augment/_autograd.py +31 -12
- brainstate/augment/_autograd_test.py +46 -46
- brainstate/augment/_eval_shape.py +4 -4
- brainstate/augment/_mapping.py +13 -8
- brainstate/compile/_conditions.py +2 -2
- brainstate/compile/_make_jaxpr.py +48 -6
- brainstate/compile/_progress_bar.py +2 -2
- brainstate/environ.py +19 -19
- brainstate/functional/_activations_test.py +12 -12
- brainstate/graph/_graph_operation.py +69 -69
- brainstate/graph/_graph_operation_test.py +2 -2
- brainstate/mixin.py +0 -17
- brainstate/nn/_collective_ops.py +4 -4
- brainstate/nn/_dropout_test.py +2 -2
- brainstate/nn/_dynamics.py +53 -35
- brainstate/nn/_elementwise.py +30 -30
- brainstate/nn/_linear.py +4 -4
- brainstate/nn/_module.py +6 -6
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +11 -11
- brainstate/nn/_normalizations_test.py +6 -6
- brainstate/nn/_poolings.py +24 -24
- brainstate/nn/_synapse.py +1 -12
- brainstate/nn/_utils.py +1 -1
- brainstate/nn/metrics.py +4 -4
- brainstate/optim/_optax_optimizer.py +8 -8
- brainstate/random/_rand_funs.py +37 -37
- brainstate/random/_rand_funs_test.py +3 -3
- brainstate/random/_rand_seed.py +7 -7
- brainstate/surrogate.py +40 -40
- brainstate/util/pretty_pytree.py +10 -10
- brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
- brainstate/util/struct.py +7 -7
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/METADATA +12 -12
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/RECORD +40 -40
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/WHEEL +1 -1
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/LICENSE +0 -0
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/top_level.txt +0 -0
brainstate/compile/_make_jaxpr.py
CHANGED
@@ -352,6 +352,13 @@ class StatefulFunction(PrettyObject):
             cache_key = default_cache_key
         return self.get_state_trace(cache_key).get_write_states()
 
+    def _check_input(self, x):
+        if isinstance(x, State):
+            raise ValueError(
+                'Inputs for brainstate transformations cannot be an instance of State. '
+                f'But we got {x}'
+            )
+
     def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
         """
         Get the static arguments from the arguments.
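The new `_check_input` guard means transformed functions must receive raw values; `State` objects are expected to be captured from the enclosing scope rather than passed as arguments. A minimal sketch of the pattern, using a stand-in State class so the snippet runs without brainstate installed:

    import jax

    class State:  # stand-in for brainstate.State
        def __init__(self, value):
            self.value = value

    def _check_input(x):
        if isinstance(x, State):
            raise ValueError(
                'Inputs for brainstate transformations cannot be an instance of State. '
                f'But we got {x}'
            )

    args, kwargs = (1.0, State(2.0)), {}
    try:
        # is_leaf stops tree traversal at State instances so whole objects reach the checker
        jax.tree.map(_check_input, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
    except ValueError as e:
        print(e)  # the guard rejects State-typed positional or keyword arguments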
@@ -370,22 +377,35 @@
                     static_args.append(arg)
                 else:
                     dyn_args.append(arg)
-            dyn_args = jax.tree.map(shaped_abstractify,
+            dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
             static_kwargs, dyn_kwargs = [], []
             for k, v in kwargs.items():
                 if k in self.static_argnames:
                     static_kwargs.append((k, v))
                 else:
                     dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
-
+
+            static_args = make_hashable(tuple(static_args))
+            dyn_args = make_hashable(tuple(dyn_args))
+            static_kwargs = make_hashable(static_kwargs)
+            dyn_kwargs = make_hashable(dyn_kwargs)
+
+            cache_key = (static_args, dyn_args, static_kwargs, dyn_kwargs)
         elif self.cache_type is None:
             num_arg = len(args)
             static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
             static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
-
+
+            # Make everything hashable
+            static_args = make_hashable(static_args)
+            static_kwargs = make_hashable(static_kwargs)
+
+            cache_key = (static_args, static_kwargs)
         else:
             raise ValueError(f"Invalid cache type: {self.cache_type}")
 
+        return cache_key
+
     def compile_function_and_get_states(self, *args, **kwargs) -> Tuple[State, ...]:
         """
         Compile the function, and get the states that are read and written by this function.
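The motivation for the `make_hashable` calls: cache keys serve as dictionary keys, and dict- or list-valued arguments cannot be hashed directly. A minimal pure-Python illustration of the problem and the conversion:

    cache = {}
    try:
        cache[({'mode': 'train'},)] = 'entry'   # hashing the tuple hashes its elements
    except TypeError as e:
        print(e)                                # unhashable type: 'dict'

    key = tuple(sorted({'mode': 'train'}.items()))  # the dict as a sorted item tuple
    cache[(key,)] = 'entry'                         # now usable as a cache key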
@@ -480,6 +500,9 @@ class StatefulFunction(PrettyObject):
         # static args
         cache_key = self.get_arg_cache_key(*args, **kwargs)
 
+        # check input types
+        jax.tree.map(self._check_input, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
+
         if cache_key not in self._cached_state_trace:
             try:
                 # jaxpr
@@ -637,15 +660,15 @@ def make_jaxpr(
     instead give a few examples.
 
     >>> import jax
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>>
     >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
     >>> print(f(3.0))
     -0.83602
-    >>> jaxpr, states =
+    >>> jaxpr, states = brainstate.compile.make_jaxpr(f)(3.0)
     >>> jaxpr
     { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
-    >>> jaxpr, states =
+    >>> jaxpr, states = brainstate.compile.make_jaxpr(jax.grad(f))(3.0)
     >>> jaxpr
     { lambda ; a:f32[]. let
         b:f32[] = cos a
@@ -844,3 +867,22 @@ def _make_jaxpr(
     if hasattr(fun, "__name__"):
         make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
     return make_jaxpr_f
+
+
+def make_hashable(obj):
+    """Convert a pytree into a hashable representation."""
+    if isinstance(obj, (list, tuple)):
+        return tuple(make_hashable(item) for item in obj)
+    elif isinstance(obj, dict):
+        return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
+    elif isinstance(obj, set):
+        return frozenset(make_hashable(item) for item in obj)
+    else:
+        # # Use JAX's tree_util for any other pytree structures
+        # try:
+        #     leaves, treedef = jax.tree_util.tree_flatten(obj)
+        #     hashable_leaves = tuple(make_hashable(leaf) for leaf in leaves)
+        #     return (str(treedef), hashable_leaves)
+        # except:
+        #     # Assume obj is already hashable
+        return obj
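For reference, the new `make_hashable` collapses nested containers into hashable equivalents. A runnable copy of the committed logic (the commented-out tree_util branch elided), with an example key:

    def make_hashable(obj):
        """Convert a pytree into a hashable representation."""
        if isinstance(obj, (list, tuple)):
            return tuple(make_hashable(item) for item in obj)
        elif isinstance(obj, dict):
            return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
        elif isinstance(obj, set):
            return frozenset(make_hashable(item) for item in obj)
        else:
            return obj

    key = make_hashable({'lr': 1e-3, 'layers': [2, 3]})
    print(key)         # (('layers', (2, 3)), ('lr', 0.001))
    print(hash(key))   # hashable, so it can serve as a cache dictionary key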
brainstate/compile/_progress_bar.py
CHANGED
@@ -53,7 +53,7 @@ class ProgressBar(object):
 
     .. code-block:: python
 
-       a =
+       a = brainstate.State(1.)
        def loop_fn(x):
           a.value = x.value + 1.
           return jnp.sum(x ** 2)
@@ -61,7 +61,7 @@ class ProgressBar(object):
        pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
                                 lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
 
-
+       brainstate.compile.for_loop(loop_fn, xs, pbar=pbar)
 
     In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
     the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
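Assembled from the repaired docstring pieces above, a hedged runnable sketch (it assumes ProgressBar and for_loop are exported from brainstate.compile as the docstring implies; xs is illustrative, and the docstring's a.value = x.value + 1. is written here as a.value = x + 1. since the loop body receives plain array elements):

    import jax.numpy as jnp
    import brainstate

    a = brainstate.State(1.)

    def loop_fn(x):
        a.value = x + 1.          # write into the captured State each iteration
        return jnp.sum(x ** 2)

    xs = jnp.arange(10.)
    pbar = brainstate.compile.ProgressBar(
        desc=("Running {i} iterations, loss = {loss}",
              lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
    losses = brainstate.compile.for_loop(loop_fn, xs, pbar=pbar)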
brainstate/environ.py
CHANGED
@@ -76,9 +76,9 @@ def context(**kwargs):
 
     For instance::
 
-      >>> import brainstate as
-      >>> with
-      ...   dt =
+      >>> import brainstate as brainstate
+      >>> with brainstate.environ.context(dt=0.1) as env:
+      ...   dt = brainstate.environ.get('dt')
       ...   print(env)
 
     """
@@ -424,10 +424,10 @@ def dftype() -> DTypeLike:
 
     For example, if the precision is set to 32, the default floating data type is ``np.float32``.
 
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with
-    ...   a = np.zeros(1, dtype=
+    >>> with brainstate.environ.context(precision=32):
+    ...   a = np.zeros(1, dtype=brainstate.environ.dftype())
     >>> print(a.dtype)
 
     Returns
@@ -448,10 +448,10 @@ def ditype() -> DTypeLike:
 
     For example, if the precision is set to 32, the default integer data type is ``np.int32``.
 
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with
-    ...   a = np.zeros(1, dtype=
+    >>> with brainstate.environ.context(precision=32):
+    ...   a = np.zeros(1, dtype=brainstate.environ.ditype())
     >>> print(a.dtype)
     int32
 
@@ -474,10 +474,10 @@ def dutype() -> DTypeLike:
 
     For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.
 
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with
-    ...   a = np.zeros(1, dtype=
+    >>> with brainstate.environ.context(precision=32):
+    ...   a = np.zeros(1, dtype=brainstate.environ.dutype())
     >>> print(a.dtype)
     uint32
 
@@ -499,10 +499,10 @@ def dctype() -> DTypeLike:
 
     For example, if the precision is set to 32, the default complex data type is ``np.complex64``.
 
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with
-    ...   a = np.zeros(1, dtype=
+    >>> with brainstate.environ.context(precision=32):
+    ...   a = np.zeros(1, dtype=brainstate.environ.dctype())
     >>> print(a.dtype)
     complex64
 
@@ -529,19 +529,19 @@ def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bo
 
     For example, you can register a default behavior for the key 'dt' by::
 
-      >>> import brainstate as
+      >>> import brainstate as brainstate
       >>> def dt_behavior(dt):
       ...   print(f'Set the default dt to {dt}.')
       ...
-      >>>
+      >>> brainstate.environ.register_default_behavior('dt', dt_behavior)
 
     Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
     `dt_behavior` will be called with
     `dt_behavior(0.1)`.
 
-      >>>
+      >>> brainstate.environ.set(dt=0.1)
       Set the default dt to 0.1.
-      >>> with
+      >>> with brainstate.environ.context(dt=0.2):
       ...   pass
       Set the default dt to 0.2.
       Set the default dt to 0.1.
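Taken together, the repaired doctests document the set/context/get semantics: a context temporarily overrides a setting and restores it on exit. A short runnable sketch (API names exactly as in the doctests above):

    import brainstate

    brainstate.environ.set(dt=0.1)
    with brainstate.environ.context(dt=0.2):
        print(brainstate.environ.get('dt'))   # 0.2 while the context is active
    print(brainstate.environ.get('dt'))       # 0.1 again after exit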
brainstate/functional/_activations_test.py
CHANGED
@@ -70,39 +70,39 @@ class NNFunctionsTest(jtu.JaxTestCase):
                         check_dtypes=False)
 
    # def testSquareplusGrad(self):
-    #   check_grads(
+    #   check_grads(brainstate.functional.squareplus, (1e-8,), order=4,
    #               )
 
    # def testSquareplusGradZero(self):
-    #   check_grads(
+    #   check_grads(brainstate.functional.squareplus, (0.,), order=1,
    #               )
 
    # def testSquareplusGradNegInf(self):
-    #   check_grads(
+    #   check_grads(brainstate.functional.squareplus, (-float('inf'),), order=1,
    #               )
 
    # def testSquareplusGradNan(self):
-    #   check_grads(
+    #   check_grads(brainstate.functional.squareplus, (float('nan'),), order=1,
    #               )
 
    # @parameterized.parameters([float] + jtu.dtypes.floating)
    # def testSquareplusZero(self, dtype):
-    #   self.assertEqual(dtype(1),
+    #   self.assertEqual(dtype(1), brainstate.functional.squareplus(dtype(0), dtype(4)))
    #
    # def testMishGrad(self):
-    #   check_grads(
+    #   check_grads(brainstate.functional.mish, (1e-8,), order=4,
    #               )
    #
    # def testMishGradZero(self):
-    #   check_grads(
+    #   check_grads(brainstate.functional.mish, (0.,), order=1,
    #               )
    #
    # def testMishGradNegInf(self):
-    #   check_grads(
+    #   check_grads(brainstate.functional.mish, (-float('inf'),), order=1,
    #               )
    #
    # def testMishGradNan(self):
-    #   check_grads(
+    #   check_grads(brainstate.functional.mish, (float('nan'),), order=1,
    #               )
 
    @parameterized.parameters([float] + jtu.dtypes.floating)
@@ -137,7 +137,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
        self.assertAllClose(brainstate.functional.sparse_sigmoid(0.), .5, check_dtypes=False)
 
    # def testSquareplusValue(self):
-    #   val =
+    #   val = brainstate.functional.squareplus(1e3)
    #   self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
 
    def testMishValue(self):
|
|
177
177
|
brainstate.functional.softplus,
|
178
178
|
brainstate.functional.sparse_plus,
|
179
179
|
brainstate.functional.sigmoid,
|
180
|
-
#
|
180
|
+
# brainstate.functional.squareplus,
|
181
181
|
brainstate.functional.mish)))
|
182
182
|
def testDtypeMatchesInput(self, dtype, fn):
|
183
183
|
x = jnp.zeros((), dtype=dtype)
|
@@ -306,7 +306,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
306
306
|
|
307
307
|
def testCustomJVPLeak2(self):
|
308
308
|
# https://github.com/google/jax/issues/8171
|
309
|
-
# The above test uses jax.
|
309
|
+
# The above test uses jax.brainstate.functional.sigmoid, as in the original #8171, but that
|
310
310
|
# function no longer actually has a custom_jvp! So we inline the old def.
|
311
311
|
|
312
312
|
@jax.custom_jvp
|
brainstate/graph/_graph_operation.py
CHANGED
@@ -473,8 +473,8 @@ def flatten(
 
    Example::
 
-      >>> import brainstate as
-      >>> node =
+      >>> import brainstate as brainstate
+      >>> node = brainstate.graph.Node()
      >>> graph_def, state_mapping = flatten(node)
      >>> print(graph_def)
      >>> print(state_mapping)
@@ -709,15 +709,15 @@ def unflatten(
 
    Example::
 
-      >>> import brainstate as
-      >>> class MyNode(
+      >>> import brainstate as brainstate
+      >>> class MyNode(brainstate.graph.Node):
      ...   def __init__(self):
-      ...     self.a =
-      ...     self.b =
-      ...     self.c = [
-      ...     self.d = {'x':
+      ...     self.a = brainstate.nn.Linear(2, 3)
+      ...     self.b = brainstate.nn.Linear(3, 4)
+      ...     self.c = [brainstate.nn.Linear(4, 5), brainstate.nn.Linear(5, 6)]
+      ...     self.d = {'x': brainstate.nn.Linear(6, 7), 'y': brainstate.nn.Linear(7, 8)}
      ...
-      >>> graphdef, statetree =
+      >>> graphdef, statetree = brainstate.graph.flatten(MyNode())
      >>> statetree
      NestedDict({
        'a': {
@@ -764,7 +764,7 @@ def unflatten(
          }
        }
      })
-      >>> node =
+      >>> node = brainstate.graph.unflatten(graphdef, statetree)
      >>> node
      MyNode(
        a=Linear(
@@ -942,21 +942,21 @@ def pop_states(
 
    Example usage::
 
-      >>> import brainstate as
+      >>> import brainstate as brainstate
      >>> import jax.numpy as jnp
 
-      >>> class Model(
+      >>> class Model(brainstate.nn.Module):
      ...   def __init__(self):
      ...     super().__init__()
-      ...     self.a =
-      ...     self.b =
+      ...     self.a = brainstate.nn.Linear(2, 3)
+      ...     self.b = brainstate.nn.LIF([10, 2])
 
      >>> model = Model()
-      >>> with
-      ...
+      >>> with brainstate.catch_new_states('new'):
+      ...   brainstate.nn.init_all_states(model)
 
      >>> assert len(model.states()) == 2
-      >>> model_states =
+      >>> model_states = brainstate.graph.pop_states(model, 'new')
      >>> model_states
      NestedDict({
        'b': {
@@ -1046,16 +1046,16 @@ def treefy_split(
 
    Example usage::
 
-      >>> from joblib.testing import param >>> import brainstate as
+      >>> from joblib.testing import param >>> import brainstate as brainstate
      >>> import jax, jax.numpy as jnp
      ...
-      >>> class Foo(
+      >>> class Foo(brainstate.graph.Node):
      ...   def __init__(self):
-      ...     self.a =
-      ...     self.b =
+      ...     self.a = brainstate.nn.BatchNorm1d([10, 2])
+      ...     self.b = brainstate.nn.Linear(2, 3)
      ...
      >>> node = Foo()
-      >>> graphdef, params, others =
+      >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
      ...
      >>> params
      NestedDict({
@@ -1119,21 +1119,21 @@ def treefy_merge(
 
    Example usage::
 
-      >>> import brainstate as
+      >>> import brainstate as brainstate
      >>> import jax, jax.numpy as jnp
      ...
-      >>> class Foo(
+      >>> class Foo(brainstate.graph.Node):
      ...   def __init__(self):
-      ...     self.a =
-      ...     self.b =
+      ...     self.a = brainstate.nn.BatchNorm1d([10, 2])
+      ...     self.b = brainstate.nn.Linear(2, 3)
      ...
      >>> node = Foo()
-      >>> graphdef, params, others =
+      >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
      ...
-      >>> new_node =
+      >>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
      >>> assert isinstance(new_node, Foo)
-      >>> assert isinstance(new_node.b,
-      >>> assert isinstance(new_node.a,
+      >>> assert isinstance(new_node.b, brainstate.nn.BatchNorm1d)
+      >>> assert isinstance(new_node.a, brainstate.nn.Linear)
 
    :func:`split` and :func:`merge` are primarily used to interact directly with JAX
    transformations, see
@@ -1302,22 +1302,22 @@ def treefy_states(
 
    Example usage::
 
-      >>> import brainstate as
-      >>> class Model(
+      >>> import brainstate as brainstate
+      >>> class Model(brainstate.nn.Module):
      ...   def __init__(self):
      ...     super().__init__()
-      ...     self.l1 =
-      ...     self.l2 =
+      ...     self.l1 = brainstate.nn.Linear(2, 3)
+      ...     self.l2 = brainstate.nn.Linear(3, 4)
      ...   def __call__(self, x):
      ...     return self.l2(self.l1(x))
 
      >>> model = Model()
      >>> # get the learnable parameters from the batch norm and linear layer
-      >>> params =
+      >>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
      >>> # get them separately
-      >>> params, others =
+      >>> params, others = brainstate.graph.treefy_states(model, brainstate.ParamState, brainstate.ShortTermState)
      >>> # get them together
-      >>> states =
+      >>> states = brainstate.graph.treefy_states(model)
 
    Args:
      node: A graph node object.
@@ -1403,11 +1403,11 @@ def graphdef(node: Any, /) -> GraphDef[Any]:
 
    Example usage::
 
-      >>> import brainstate as
+      >>> import brainstate as brainstate
 
-      >>> model =
-      >>> graphdef, _ =
-      >>> assert graphdef ==
+      >>> model = brainstate.nn.Linear(2, 3)
+      >>> graphdef, _ = brainstate.graph.treefy_split(model)
+      >>> assert graphdef == brainstate.graph.graphdef(model)
 
    Args:
      node: A graph node object.
@@ -1426,8 +1426,8 @@ def clone(node: Node) -> Node:
 
    Example usage::
 
-      >>> import brainstate as
-      >>> model =
+      >>> import brainstate as brainstate
+      >>> model = brainstate.nn.Linear(2, 3)
      >>> cloned_model = clone(model)
      >>> model.weight.value['bias'] += 1
      >>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
@@ -1456,15 +1456,15 @@ def call(
 
    Example::
 
-      >>> import brainstate as
+      >>> import brainstate as brainstate
      >>> import jax
      >>> import jax.numpy as jnp
      ...
-      >>> class StatefulLinear(
+      >>> class StatefulLinear(brainstate.graph.Node):
      ...   def __init__(self, din, dout):
-      ...     self.w =
-      ...     self.b =
-      ...     self.count =
+      ...     self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
+      ...     self.b = brainstate.ParamState(jnp.zeros((dout,)))
+      ...     self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
      ...
      ...   def increment(self):
      ...     self.count.value += 1
@@ -1474,18 +1474,18 @@ def call(
      ...     return x @ self.w.value + self.b.value
      ...
      >>> linear = StatefulLinear(3, 2)
-      >>> linear_state =
+      >>> linear_state = brainstate.graph.treefy_split(linear)
      ...
      >>> @jax.jit
      ... def forward(x, linear_state):
-      ...   y, linear_state =
+      ...   y, linear_state = brainstate.graph.call(linear_state)(x)
      ...   return y, linear_state
      ...
      >>> x = jnp.ones((1, 3))
      >>> y, linear_state = forward(x, linear_state)
      >>> y, linear_state = forward(x, linear_state)
      ...
-      >>> linear =
+      >>> linear = brainstate.graph.treefy_merge(*linear_state)
      >>> linear.count.value
      Array(2, dtype=uint32)
 
@@ -1494,11 +1494,11 @@ def call(
    is used to call the ``increment`` method of the ``StatefulLinear`` module
    at the ``b`` key of a ``nodes`` dictionary.
 
-      >>> class StatefulLinear(
+      >>> class StatefulLinear(brainstate.graph.Node):
      ...   def __init__(self, din, dout):
-      ...     self.w =
-      ...     self.b =
-      ...     self.count =
+      ...     self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
+      ...     self.b = brainstate.ParamState(jnp.zeros((dout,)))
+      ...     self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
      ...
      ...   def increment(self):
      ...     self.count.value += 1
@@ -1514,7 +1514,7 @@ def call(
      ...
      >>> node_state = treefy_split(nodes)
      >>> # use attribute access
-      >>> _, node_state =
+      >>> _, node_state = brainstate.graph.call(node_state)['b'].increment()
      ...
      >>> nodes = treefy_merge(*node_state)
      >>> nodes['a'].count.value
@@ -1544,19 +1544,19 @@ def iter_leaf(
    root. Repeated nodes are visited only once. Leaves include static values.
 
    Example::
-      >>> import brainstate as
+      >>> import brainstate as brainstate
      >>> import jax.numpy as jnp
      ...
-      >>> class Linear(
+      >>> class Linear(brainstate.nn.Module):
      ...   def __init__(self, din, dout):
      ...     super().__init__()
-      ...     self.weight =
-      ...     self.bias =
+      ...     self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
+      ...     self.bias = brainstate.ParamState(brainstate.random.randn(dout))
      ...     self.a = 1
      ...
      >>> module = Linear(3, 4)
      ...
-      >>> for path, value in
+      >>> for path, value in brainstate.graph.iter_leaf([module, module]):
      ...   print(path, type(value).__name__)
      ...
      (0, 'a') int
@@ -1616,21 +1616,21 @@ def iter_node(
    root. Repeated nodes are visited only once. Leaves include static values.
 
    Example::
-      >>> import brainstate as
+      >>> import brainstate as brainstate
      >>> import jax.numpy as jnp
      ...
-      >>> class Model(
+      >>> class Model(brainstate.nn.Module):
      ...   def __init__(self):
      ...     super().__init__()
-      ...     self.a =
-      ...     self.b =
-      ...     self.c = [
-      ...     self.d = {'x':
-      ...     self.b.a =
+      ...     self.a = brainstate.nn.Linear(1, 2)
+      ...     self.b = brainstate.nn.Linear(2, 3)
+      ...     self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
+      ...     self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
+      ...     self.b.a = brainstate.nn.LIF(2)
      ...
      >>> model = Model()
      ...
-      >>> for path, node in
+      >>> for path, node in brainstate.graph.iter_node([model, model]):
      ...   print(path, node.__class__.__name__)
      ...
      (0, 'a') Linear
brainstate/graph/_graph_operation_test.py
CHANGED
@@ -443,7 +443,7 @@ class TestGraphOperation(unittest.TestCase):
        graphdef, statetree = brainstate.graph.flatten(MyNode())
        # print(graphdef)
        print(statetree)
-        # print(
+        # print(brainstate.graph.unflatten(graphdef, statetree))
 
    def test_split(self):
        class Foo(brainstate.graph.Node):
@@ -530,7 +530,7 @@ class TestGraphOperation(unittest.TestCase):
        print(graph_def)
        print(treefy_states)
 
-        # states =
+        # states = brainstate.graph.states(model)
        # print(states)
        # nest_states = states.to_nest()
        # print(nest_states)
brainstate/mixin.py
CHANGED
@@ -27,7 +27,6 @@ __all__ = [
    'ParamDescriber',
    'AlignPost',
    'BindCondData',
-    'UpdateReturn',
 
    # types
    'JointTypes',
@@ -171,22 +170,6 @@ def not_implemented(func):
    return wrapper
 
 
-
-class UpdateReturn(Mixin):
-    @not_implemented
-    def update_return(self) -> PyTree:
-        """
-        The update function return of the model.
-
-        This function requires no parameters and must return a PyTree.
-
-        It is usually used for delay initialization, for example, ``Dynamics.output_delay`` relies on this function to
-        initialize the output delay.
-
-        """
-        raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
-
-
class _MetaUnionType(type):
    def __new__(cls, name, bases, dct):
        if isinstance(bases, type):
brainstate/nn/_collective_ops.py
CHANGED
@@ -56,10 +56,10 @@ def call_order(level: int = 0, check_order_boundary: bool = True):
 
    The lower the level, the earlier the function is called.
 
-    >>> import brainstate as
-    >>>
-    >>>
-    >>>
+    >>> import brainstate as brainstate
+    >>> brainstate.nn.call_order(0)
+    >>> brainstate.nn.call_order(-1)
+    >>> brainstate.nn.call_order(-2)
 
    Parameters
    ----------
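The doctest above shows bare brainstate.nn.call_order(...) calls; a hedged sketch of decorator-style usage, assuming call_order(level) returns a decorator that tags methods so collective operations such as brainstate.nn.init_all_states run lower levels first (the class and method names here are illustrative, not from the diff):

    import brainstate

    class Net(brainstate.nn.Module):
        def __init__(self):
            super().__init__()

        @brainstate.nn.call_order(0)
        def init_state(self):      # level 0: intended to run before higher levels
            pass

        @brainstate.nn.call_order(1)
        def finalize_state(self):  # level 1: intended to run after level 0
            pass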
brainstate/nn/_dropout_test.py
CHANGED
@@ -60,9 +60,9 @@ class TestDropout(unittest.TestCase):
        np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
 
    # def test_Dropout1d(self):
-    #   dropout_layer =
+    #   dropout_layer = brainstate.nn.Dropout1d(prob=0.5)
    #   input_data = np.random.randn(2, 3, 4)
-    #   with
+    #   with brainstate.environ.context(fit=True):
    #     output_data = dropout_layer(input_data)
    #   self.assertEqual(input_data.shape, output_data.shape)
    #   self.assertTrue(np.any(output_data == 0))