brainstate 0.1.5__py2.py3-none-any.whl → 0.1.7__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_state.py +5 -5
- brainstate/augment/_autograd.py +31 -12
- brainstate/augment/_autograd_test.py +46 -46
- brainstate/augment/_eval_shape.py +4 -4
- brainstate/augment/_mapping.py +13 -8
- brainstate/compile/_conditions.py +2 -2
- brainstate/compile/_make_jaxpr.py +48 -6
- brainstate/compile/_progress_bar.py +2 -2
- brainstate/environ.py +19 -19
- brainstate/functional/_activations_test.py +12 -12
- brainstate/graph/_graph_operation.py +69 -69
- brainstate/graph/_graph_operation_test.py +2 -2
- brainstate/mixin.py +0 -17
- brainstate/nn/_collective_ops.py +4 -4
- brainstate/nn/_dropout_test.py +2 -2
- brainstate/nn/_dynamics.py +53 -35
- brainstate/nn/_elementwise.py +30 -30
- brainstate/nn/_linear.py +4 -4
- brainstate/nn/_module.py +6 -6
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +11 -11
- brainstate/nn/_normalizations_test.py +6 -6
- brainstate/nn/_poolings.py +24 -24
- brainstate/nn/_synapse.py +1 -12
- brainstate/nn/_utils.py +1 -1
- brainstate/nn/metrics.py +4 -4
- brainstate/optim/_optax_optimizer.py +8 -8
- brainstate/random/_rand_funs.py +37 -37
- brainstate/random/_rand_funs_test.py +3 -3
- brainstate/random/_rand_seed.py +7 -7
- brainstate/surrogate.py +40 -40
- brainstate/util/pretty_pytree.py +10 -10
- brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
- brainstate/util/struct.py +7 -7
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/METADATA +12 -12
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/RECORD +40 -40
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/WHEEL +1 -1
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/LICENSE +0 -0
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/top_level.txt +0 -0
brainstate/surrogate.py
CHANGED
@@ -106,11 +106,11 @@ class Surrogate(PrettyObject):
|
|
106
106
|
Examples
|
107
107
|
--------
|
108
108
|
|
109
|
-
>>> import brainstate as
|
109
|
+
>>> import brainstate as brainstate
|
110
110
|
>>> import brainstate.nn as nn
|
111
111
|
>>> import jax.numpy as jnp
|
112
112
|
|
113
|
-
>>> class MySurrogate(
|
113
|
+
>>> class MySurrogate(brainstate.surrogate.Surrogate):
|
114
114
|
... def __init__(self, alpha=1.):
|
115
115
|
... super().__init__()
|
116
116
|
... self.alpha = alpha
|
@@ -236,11 +236,11 @@ def sigmoid(
|
|
236
236
|
|
237
237
|
>>> import jax
|
238
238
|
>>> import brainstate.nn as nn
|
239
|
-
>>> import brainstate as
|
239
|
+
>>> import brainstate as brainstate
|
240
240
|
>>> import matplotlib.pyplot as plt
|
241
241
|
>>> xs = jax.numpy.linspace(-2, 2, 1000)
|
242
242
|
>>> for alpha in [1., 2., 4.]:
|
243
|
-
>>> grads =
|
243
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.sigmoid)(xs, alpha)
|
244
244
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
245
245
|
>>> plt.legend()
|
246
246
|
>>> plt.show()
|
@@ -355,11 +355,11 @@ def piecewise_quadratic(
|
|
355
355
|
|
356
356
|
>>> import jax
|
357
357
|
>>> import brainstate.nn as nn
|
358
|
-
>>> import brainstate as
|
358
|
+
>>> import brainstate as brainstate
|
359
359
|
>>> import matplotlib.pyplot as plt
|
360
360
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
361
361
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
362
|
-
>>> grads =
|
362
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_quadratic)(xs, alpha)
|
363
363
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
364
364
|
>>> plt.legend()
|
365
365
|
>>> plt.show()
|
@@ -522,11 +522,11 @@ def piecewise_exp(
|
|
522
522
|
|
523
523
|
>>> import jax
|
524
524
|
>>> import brainstate.nn as nn
|
525
|
-
>>> import brainstate as
|
525
|
+
>>> import brainstate as brainstate
|
526
526
|
>>> import matplotlib.pyplot as plt
|
527
527
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
528
528
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
529
|
-
>>> grads =
|
529
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_exp)(xs, alpha)
|
530
530
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
531
531
|
>>> plt.legend()
|
532
532
|
>>> plt.show()
|
@@ -621,11 +621,11 @@ def soft_sign(
|
|
621
621
|
|
622
622
|
>>> import jax
|
623
623
|
>>> import brainstate.nn as nn
|
624
|
-
>>> import brainstate as
|
624
|
+
>>> import brainstate as brainstate
|
625
625
|
>>> import matplotlib.pyplot as plt
|
626
626
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
627
627
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
628
|
-
>>> grads =
|
628
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.soft_sign)(xs, alpha)
|
629
629
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
630
630
|
>>> plt.legend()
|
631
631
|
>>> plt.show()
|
@@ -706,11 +706,11 @@ def arctan(
|
|
706
706
|
|
707
707
|
>>> import jax
|
708
708
|
>>> import brainstate.nn as nn
|
709
|
-
>>> import brainstate as
|
709
|
+
>>> import brainstate as brainstate
|
710
710
|
>>> import matplotlib.pyplot as plt
|
711
711
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
712
712
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
713
|
-
>>> grads =
|
713
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.arctan)(xs, alpha)
|
714
714
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
715
715
|
>>> plt.legend()
|
716
716
|
>>> plt.show()
|
@@ -804,11 +804,11 @@ def nonzero_sign_log(
|
|
804
804
|
|
805
805
|
>>> import jax
|
806
806
|
>>> import brainstate.nn as nn
|
807
|
-
>>> import brainstate as
|
807
|
+
>>> import brainstate as brainstate
|
808
808
|
>>> import matplotlib.pyplot as plt
|
809
809
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
810
810
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
811
|
-
>>> grads =
|
811
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
|
812
812
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
813
813
|
>>> plt.legend()
|
814
814
|
>>> plt.show()
|
@@ -893,11 +893,11 @@ def erf(
|
|
893
893
|
|
894
894
|
>>> import jax
|
895
895
|
>>> import brainstate.nn as nn
|
896
|
-
>>> import brainstate as
|
896
|
+
>>> import brainstate as brainstate
|
897
897
|
>>> import matplotlib.pyplot as plt
|
898
898
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
899
899
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
900
|
-
>>> grads =
|
900
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
|
901
901
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
902
902
|
>>> plt.legend()
|
903
903
|
>>> plt.show()
|
@@ -1000,12 +1000,12 @@ def piecewise_leaky_relu(
|
|
1000
1000
|
|
1001
1001
|
>>> import jax
|
1002
1002
|
>>> import brainstate.nn as nn
|
1003
|
-
>>> import brainstate as
|
1003
|
+
>>> import brainstate as brainstate
|
1004
1004
|
>>> import matplotlib.pyplot as plt
|
1005
1005
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1006
1006
|
>>> for c in [0.01, 0.05, 0.1]:
|
1007
1007
|
>>> for w in [1., 2.]:
|
1008
|
-
>>> grads1 =
|
1008
|
+
>>> grads1 = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
|
1009
1009
|
>>> plt.plot(xs, grads1, label=f'x={c}, w={w}')
|
1010
1010
|
>>> plt.legend()
|
1011
1011
|
>>> plt.show()
|
@@ -1113,12 +1113,12 @@ def squarewave_fourier_series(
|
|
1113
1113
|
|
1114
1114
|
>>> import jax
|
1115
1115
|
>>> import brainstate.nn as nn
|
1116
|
-
>>> import brainstate as
|
1116
|
+
>>> import brainstate as brainstate
|
1117
1117
|
>>> import matplotlib.pyplot as plt
|
1118
1118
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1119
1119
|
>>> for n in [2, 4, 8]:
|
1120
|
-
>>> f =
|
1121
|
-
>>> grads1 =
|
1120
|
+
>>> f = brainstate.surrogate.SquarewaveFourierSeries(n=n)
|
1121
|
+
>>> grads1 = brainstate.augment.vector_grad(f)(xs)
|
1122
1122
|
>>> plt.plot(xs, grads1, label=f'n={n}')
|
1123
1123
|
>>> plt.legend()
|
1124
1124
|
>>> plt.show()
|
@@ -1214,12 +1214,12 @@ def s2nn(
|
|
1214
1214
|
|
1215
1215
|
>>> import jax
|
1216
1216
|
>>> import brainstate.nn as nn
|
1217
|
-
>>> import brainstate as
|
1217
|
+
>>> import brainstate as brainstate
|
1218
1218
|
>>> import matplotlib.pyplot as plt
|
1219
1219
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1220
|
-
>>> grads =
|
1220
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 4., 1.)
|
1221
1221
|
>>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
|
1222
|
-
>>> grads =
|
1222
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 8., 2.)
|
1223
1223
|
>>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
|
1224
1224
|
>>> plt.legend()
|
1225
1225
|
>>> plt.show()
|
@@ -1315,11 +1315,11 @@ def q_pseudo_spike(
|
|
1315
1315
|
|
1316
1316
|
>>> import jax
|
1317
1317
|
>>> import brainstate.nn as nn
|
1318
|
-
>>> import brainstate as
|
1318
|
+
>>> import brainstate as brainstate
|
1319
1319
|
>>> import matplotlib.pyplot as plt
|
1320
1320
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1321
1321
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
1322
|
-
>>> grads =
|
1322
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.q_pseudo_spike)(xs, alpha)
|
1323
1323
|
>>> plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
|
1324
1324
|
>>> plt.legend()
|
1325
1325
|
>>> plt.show()
|
@@ -1413,10 +1413,10 @@ def leaky_relu(
|
|
1413
1413
|
|
1414
1414
|
>>> import jax
|
1415
1415
|
>>> import brainstate.nn as nn
|
1416
|
-
>>> import brainstate as
|
1416
|
+
>>> import brainstate as brainstate
|
1417
1417
|
>>> import matplotlib.pyplot as plt
|
1418
1418
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1419
|
-
>>> grads =
|
1419
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
|
1420
1420
|
>>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
|
1421
1421
|
>>> plt.legend()
|
1422
1422
|
>>> plt.show()
|
@@ -1517,10 +1517,10 @@ def log_tailed_relu(
|
|
1517
1517
|
|
1518
1518
|
>>> import jax
|
1519
1519
|
>>> import brainstate.nn as nn
|
1520
|
-
>>> import brainstate as
|
1520
|
+
>>> import brainstate as brainstate
|
1521
1521
|
>>> import matplotlib.pyplot as plt
|
1522
1522
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1523
|
-
>>> grads =
|
1523
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
|
1524
1524
|
>>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
|
1525
1525
|
>>> plt.legend()
|
1526
1526
|
>>> plt.show()
|
@@ -1596,12 +1596,12 @@ def relu_grad(
|
|
1596
1596
|
|
1597
1597
|
>>> import jax
|
1598
1598
|
>>> import brainstate.nn as nn
|
1599
|
-
>>> import brainstate as
|
1599
|
+
>>> import brainstate as brainstate
|
1600
1600
|
>>> import matplotlib.pyplot as plt
|
1601
1601
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1602
1602
|
>>> for s in [0.5, 1.]:
|
1603
1603
|
>>> for w in [1, 2.]:
|
1604
|
-
>>> grads =
|
1604
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.relu_grad)(xs, s, w)
|
1605
1605
|
>>> plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
|
1606
1606
|
>>> plt.legend()
|
1607
1607
|
>>> plt.show()
|
@@ -1678,11 +1678,11 @@ def gaussian_grad(
|
|
1678
1678
|
|
1679
1679
|
>>> import jax
|
1680
1680
|
>>> import brainstate.nn as nn
|
1681
|
-
>>> import brainstate as
|
1681
|
+
>>> import brainstate as brainstate
|
1682
1682
|
>>> import matplotlib.pyplot as plt
|
1683
1683
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1684
1684
|
>>> for s in [0.5, 1., 2.]:
|
1685
|
-
>>> grads =
|
1685
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.gaussian_grad)(xs, s, 0.5)
|
1686
1686
|
>>> plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
|
1687
1687
|
>>> plt.legend()
|
1688
1688
|
>>> plt.show()
|
@@ -1773,10 +1773,10 @@ def multi_gaussian_grad(
|
|
1773
1773
|
|
1774
1774
|
>>> import jax
|
1775
1775
|
>>> import brainstate.nn as nn
|
1776
|
-
>>> import brainstate as
|
1776
|
+
>>> import brainstate as brainstate
|
1777
1777
|
>>> import matplotlib.pyplot as plt
|
1778
1778
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1779
|
-
>>> grads =
|
1779
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.multi_gaussian_grad)(xs)
|
1780
1780
|
>>> plt.plot(xs, grads)
|
1781
1781
|
>>> plt.show()
|
1782
1782
|
|
@@ -1855,11 +1855,11 @@ def inv_square_grad(
|
|
1855
1855
|
|
1856
1856
|
>>> import jax
|
1857
1857
|
>>> import brainstate.nn as nn
|
1858
|
-
>>> import brainstate as
|
1858
|
+
>>> import brainstate as brainstate
|
1859
1859
|
>>> import matplotlib.pyplot as plt
|
1860
1860
|
>>> xs = jax.numpy.linspace(-1, 1, 1000)
|
1861
1861
|
>>> for alpha in [1., 10., 100.]:
|
1862
|
-
>>> grads =
|
1862
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.inv_square_grad)(xs, alpha)
|
1863
1863
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
1864
1864
|
>>> plt.legend()
|
1865
1865
|
>>> plt.show()
|
@@ -1929,11 +1929,11 @@ def slayer_grad(
|
|
1929
1929
|
|
1930
1930
|
>>> import jax
|
1931
1931
|
>>> import brainstate.nn as nn
|
1932
|
-
>>> import brainstate as
|
1932
|
+
>>> import brainstate as brainstate
|
1933
1933
|
>>> import matplotlib.pyplot as plt
|
1934
1934
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1935
1935
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
1936
|
-
>>> grads =
|
1936
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.slayer_grad)(xs, alpha)
|
1937
1937
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
1938
1938
|
>>> plt.legend()
|
1939
1939
|
>>> plt.show()
|
brainstate/util/pretty_pytree.py
CHANGED
@@ -373,19 +373,19 @@ class NestedDict(PrettyDict):
|
|
373
373
|
|
374
374
|
Example usage::
|
375
375
|
|
376
|
-
>>> import brainstate as
|
376
|
+
>>> import brainstate as brainstate
|
377
377
|
|
378
|
-
>>> class Model(
|
378
|
+
>>> class Model(brainstate.nn.Module):
|
379
379
|
... def __init__(self):
|
380
380
|
... super().__init__()
|
381
|
-
... self.batchnorm =
|
382
|
-
... self.linear =
|
381
|
+
... self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
|
382
|
+
... self.linear = brainstate.nn.Linear(2, 3)
|
383
383
|
... def __call__(self, x):
|
384
384
|
... return self.linear(self.batchnorm(x))
|
385
385
|
|
386
386
|
>>> model = Model()
|
387
|
-
>>> state_map =
|
388
|
-
>>> param, others = state_map.treefy_split(
|
387
|
+
>>> state_map = brainstate.graph.treefy_states(model)
|
388
|
+
>>> param, others = state_map.treefy_split(brainstate.ParamState, ...)
|
389
389
|
|
390
390
|
Arguments:
|
391
391
|
first: The first filter
|
@@ -495,14 +495,14 @@ class FlattedDict(PrettyDict):
|
|
495
495
|
|
496
496
|
Example usage::
|
497
497
|
|
498
|
-
>>> import brainstate as
|
498
|
+
>>> import brainstate as brainstate
|
499
499
|
>>> import jax.numpy as jnp
|
500
500
|
>>>
|
501
|
-
>>> class Model(
|
501
|
+
>>> class Model(brainstate.nn.Module):
|
502
502
|
... def __init__(self):
|
503
503
|
... super().__init__()
|
504
|
-
... self.batchnorm =
|
505
|
-
... self.linear =
|
504
|
+
... self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
|
505
|
+
... self.linear = brainstate.nn.Linear(2, 3)
|
506
506
|
... def __call__(self, x):
|
507
507
|
... return self.linear(self.batchnorm(x))
|
508
508
|
>>>
|
brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py}
RENAMED
@@ -13,31 +13,30 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
20
19
|
import jax
|
21
20
|
from absl.testing import absltest
|
22
21
|
|
23
|
-
import brainstate as bst
|
22
|
+
import brainstate
|
24
23
|
|
25
24
|
|
26
25
|
class TestNestedMapping(absltest.TestCase):
|
27
26
|
def test_create_state(self):
|
28
|
-
state =
|
27
|
+
state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
|
29
28
|
|
30
29
|
assert state['a'].value == 1
|
31
30
|
assert state['b']['c'].value == 2
|
32
31
|
|
33
32
|
def test_get_attr(self):
|
34
|
-
state =
|
33
|
+
state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
|
35
34
|
|
36
35
|
assert state.a.value == 1
|
37
36
|
assert state.b['c'].value == 2
|
38
37
|
|
39
38
|
def test_set_attr(self):
|
40
|
-
state =
|
39
|
+
state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
|
41
40
|
|
42
41
|
state.a.value = 3
|
43
42
|
state.b['c'].value = 4
|
@@ -46,36 +45,36 @@ class TestNestedMapping(absltest.TestCase):
|
|
46
45
|
assert state['b']['c'].value == 4
|
47
46
|
|
48
47
|
def test_set_attr_variables(self):
|
49
|
-
state =
|
48
|
+
state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
|
50
49
|
|
51
50
|
state.a.value = 3
|
52
51
|
state.b['c'].value = 4
|
53
52
|
|
54
|
-
assert isinstance(state.a,
|
53
|
+
assert isinstance(state.a, brainstate.ParamState)
|
55
54
|
assert state.a.value == 3
|
56
|
-
assert isinstance(state.b['c'],
|
55
|
+
assert isinstance(state.b['c'], brainstate.ParamState)
|
57
56
|
assert state.b['c'].value == 4
|
58
57
|
|
59
58
|
def test_add_nested_attr(self):
|
60
|
-
state =
|
61
|
-
state.b['d'] =
|
59
|
+
state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
|
60
|
+
state.b['d'] = brainstate.ParamState(5)
|
62
61
|
|
63
62
|
assert state['b']['d'].value == 5
|
64
63
|
|
65
64
|
def test_delete_nested_attr(self):
|
66
|
-
state =
|
65
|
+
state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
|
67
66
|
del state['b']['c']
|
68
67
|
|
69
68
|
assert 'c' not in state['b']
|
70
69
|
|
71
70
|
def test_integer_access(self):
|
72
|
-
class Foo(
|
71
|
+
class Foo(brainstate.nn.Module):
|
73
72
|
def __init__(self):
|
74
73
|
super().__init__()
|
75
|
-
self.layers = [
|
74
|
+
self.layers = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]
|
76
75
|
|
77
76
|
module = Foo()
|
78
|
-
state_refs =
|
77
|
+
state_refs = brainstate.graph.treefy_states(module)
|
79
78
|
|
80
79
|
assert module.layers[0].weight.value['weight'].shape == (1, 2)
|
81
80
|
assert state_refs.layers[0]['weight'].value['weight'].shape == (1, 2)
|
@@ -83,8 +82,8 @@ class TestNestedMapping(absltest.TestCase):
|
|
83
82
|
assert state_refs.layers[1]['weight'].value['weight'].shape == (2, 3)
|
84
83
|
|
85
84
|
def test_pure_dict(self):
|
86
|
-
module =
|
87
|
-
state_map =
|
85
|
+
module = brainstate.nn.Linear(4, 5)
|
86
|
+
state_map = brainstate.graph.treefy_states(module)
|
88
87
|
pure_dict = state_map.to_pure_dict()
|
89
88
|
assert isinstance(pure_dict, dict)
|
90
89
|
assert isinstance(pure_dict['weight'].value['weight'], jax.Array)
|
@@ -93,27 +92,27 @@ class TestNestedMapping(absltest.TestCase):
|
|
93
92
|
|
94
93
|
class TestSplit(unittest.TestCase):
|
95
94
|
def test_split(self):
|
96
|
-
class Model(
|
95
|
+
class Model(brainstate.nn.Module):
|
97
96
|
def __init__(self):
|
98
97
|
super().__init__()
|
99
|
-
self.batchnorm =
|
100
|
-
self.linear =
|
98
|
+
self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
|
99
|
+
self.linear = brainstate.nn.Linear([10, 3], [10, 4])
|
101
100
|
|
102
101
|
def __call__(self, x):
|
103
102
|
return self.linear(self.batchnorm(x))
|
104
103
|
|
105
|
-
with
|
104
|
+
with brainstate.environ.context(fit=True):
|
106
105
|
model = Model()
|
107
|
-
x =
|
106
|
+
x = brainstate.random.randn(1, 10, 3)
|
108
107
|
y = model(x)
|
109
108
|
self.assertEqual(y.shape, (1, 10, 4))
|
110
109
|
|
111
|
-
state_map =
|
110
|
+
state_map = brainstate.graph.treefy_states(model)
|
112
111
|
|
113
112
|
with self.assertRaises(ValueError):
|
114
|
-
params, others = state_map.split(
|
113
|
+
params, others = state_map.split(brainstate.ParamState)
|
115
114
|
|
116
|
-
params, others = state_map.split(
|
115
|
+
params, others = state_map.split(brainstate.ParamState, ...)
|
117
116
|
print()
|
118
117
|
print(params)
|
119
118
|
print(others)
|
@@ -124,37 +123,37 @@ class TestSplit(unittest.TestCase):
|
|
124
123
|
|
125
124
|
class TestStateMap2(unittest.TestCase):
|
126
125
|
def test1(self):
|
127
|
-
class Model(
|
126
|
+
class Model(brainstate.nn.Module):
|
128
127
|
def __init__(self):
|
129
128
|
super().__init__()
|
130
|
-
self.batchnorm =
|
131
|
-
self.linear =
|
129
|
+
self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
|
130
|
+
self.linear = brainstate.nn.Linear([10, 3], [10, 4])
|
132
131
|
|
133
132
|
def __call__(self, x):
|
134
133
|
return self.linear(self.batchnorm(x))
|
135
134
|
|
136
|
-
with
|
135
|
+
with brainstate.environ.context(fit=True):
|
137
136
|
model = Model()
|
138
|
-
state_map =
|
139
|
-
state_map =
|
137
|
+
state_map = brainstate.graph.treefy_states(model).to_flat()
|
138
|
+
state_map = brainstate.util.NestedDict(state_map)
|
140
139
|
|
141
140
|
|
142
141
|
class TestFlattedMapping(unittest.TestCase):
|
143
142
|
def test1(self):
|
144
|
-
class Model(
|
143
|
+
class Model(brainstate.nn.Module):
|
145
144
|
def __init__(self):
|
146
145
|
super().__init__()
|
147
|
-
self.batchnorm =
|
148
|
-
self.linear =
|
146
|
+
self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
|
147
|
+
self.linear = brainstate.nn.Linear([10, 3], [10, 4])
|
149
148
|
|
150
149
|
def __call__(self, x):
|
151
150
|
return self.linear(self.batchnorm(x))
|
152
151
|
|
153
152
|
model = Model()
|
154
153
|
# print(model.states())
|
155
|
-
# print(
|
156
|
-
self.assertTrue(model.states() ==
|
154
|
+
# print(brainstate.graph.states(model))
|
155
|
+
self.assertTrue(model.states() == brainstate.graph.states(model))
|
157
156
|
|
158
157
|
print(model.nodes())
|
159
|
-
# print(
|
160
|
-
self.assertTrue(model.nodes() ==
|
158
|
+
# print(brainstate.graph.nodes(model))
|
159
|
+
self.assertTrue(model.nodes() == brainstate.graph.nodes(model))
|
brainstate/util/struct.py
CHANGED
@@ -56,16 +56,16 @@ def dataclass(clz: T, **kwargs) -> T:
|
|
56
56
|
The ``dataclass`` decorator makes it easy to define custom classes that can be
|
57
57
|
passed safely to Jax. For example::
|
58
58
|
|
59
|
-
>>> import brainstate as
|
59
|
+
>>> import brainstate as brainstate
|
60
60
|
>>> import jax
|
61
61
|
>>> from typing import Any, Callable
|
62
62
|
|
63
|
-
>>> @
|
63
|
+
>>> @brainstate.util.dataclass
|
64
64
|
... class Model:
|
65
65
|
... params: Any
|
66
66
|
... # use pytree_node=False to indicate an attribute should not be touched
|
67
67
|
... # by Jax transformations.
|
68
|
-
... apply_fn: Callable =
|
68
|
+
... apply_fn: Callable = brainstate.util.field(pytree_node=False)
|
69
69
|
|
70
70
|
... def __apply__(self, *args):
|
71
71
|
... return self.apply_fn(*args)
|
@@ -97,7 +97,7 @@ def dataclass(clz: T, **kwargs) -> T:
|
|
97
97
|
This way the simple constructor used by ``jax.tree_util`` is
|
98
98
|
preserved. Consider the following example::
|
99
99
|
|
100
|
-
>>> @
|
100
|
+
>>> @brainstate.util.dataclass
|
101
101
|
... class DirectionAndScaleKernel:
|
102
102
|
... direction: jax.Array
|
103
103
|
... scale: jax.Array
|
@@ -189,15 +189,15 @@ class PyTreeNode:
|
|
189
189
|
|
190
190
|
Example::
|
191
191
|
|
192
|
-
>>> import brainstate as
|
192
|
+
>>> import brainstate as brainstate
|
193
193
|
>>> import jax
|
194
194
|
>>> from typing import Any, Callable
|
195
195
|
|
196
|
-
>>> class Model(
|
196
|
+
>>> class Model(brainstate.util.PyTreeNode):
|
197
197
|
... params: Any
|
198
198
|
... # use pytree_node=False to indicate an attribute should not be touched
|
199
199
|
... # by Jax transformations.
|
200
|
-
... apply_fn: Callable =
|
200
|
+
... apply_fn: Callable = brainstate.util.field(pytree_node=False)
|
201
201
|
|
202
202
|
... def __apply__(self, *args):
|
203
203
|
... return self.apply_fn(*args)
|
{brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.5
|
3
|
+
Version: 0.1.7
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -31,22 +31,22 @@ License-File: LICENSE
|
|
31
31
|
Requires-Dist: jax
|
32
32
|
Requires-Dist: jaxlib
|
33
33
|
Requires-Dist: numpy
|
34
|
-
Requires-Dist: brainunit
|
34
|
+
Requires-Dist: brainunit>=0.0.4
|
35
35
|
Requires-Dist: brainevent
|
36
36
|
Provides-Extra: cpu
|
37
|
-
Requires-Dist: jax[cpu]
|
38
|
-
Requires-Dist: brainunit[cpu]
|
39
|
-
Requires-Dist: brainevent[cpu]
|
37
|
+
Requires-Dist: jax[cpu]; extra == "cpu"
|
38
|
+
Requires-Dist: brainunit[cpu]; extra == "cpu"
|
39
|
+
Requires-Dist: brainevent[cpu]; extra == "cpu"
|
40
40
|
Provides-Extra: cuda12
|
41
|
-
Requires-Dist: jax[cuda12]
|
42
|
-
Requires-Dist: brainunit[cuda12]
|
43
|
-
Requires-Dist: brainevent[cuda12]
|
41
|
+
Requires-Dist: jax[cuda12]; extra == "cuda12"
|
42
|
+
Requires-Dist: brainunit[cuda12]; extra == "cuda12"
|
43
|
+
Requires-Dist: brainevent[cuda12]; extra == "cuda12"
|
44
44
|
Provides-Extra: testing
|
45
|
-
Requires-Dist: pytest
|
45
|
+
Requires-Dist: pytest; extra == "testing"
|
46
46
|
Provides-Extra: tpu
|
47
|
-
Requires-Dist: jax[tpu]
|
48
|
-
Requires-Dist: brainunit[tpu]
|
49
|
-
Requires-Dist: brainevent[tpu]
|
47
|
+
Requires-Dist: jax[tpu]; extra == "tpu"
|
48
|
+
Requires-Dist: brainunit[tpu]; extra == "tpu"
|
49
|
+
Requires-Dist: brainevent[tpu]; extra == "tpu"
|
50
50
|
|
51
51
|
|
52
52
|
# A ``State``-based Transformation System for Program Compilation and Augmentation
|