brainstate 0.1.5__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_state.py +5 -5
  3. brainstate/augment/_autograd.py +31 -12
  4. brainstate/augment/_autograd_test.py +46 -46
  5. brainstate/augment/_eval_shape.py +4 -4
  6. brainstate/augment/_mapping.py +13 -8
  7. brainstate/compile/_conditions.py +2 -2
  8. brainstate/compile/_make_jaxpr.py +59 -6
  9. brainstate/compile/_progress_bar.py +2 -2
  10. brainstate/environ.py +19 -19
  11. brainstate/functional/_activations_test.py +12 -12
  12. brainstate/graph/_graph_operation.py +69 -69
  13. brainstate/graph/_graph_operation_test.py +2 -2
  14. brainstate/mixin.py +0 -17
  15. brainstate/nn/_collective_ops.py +4 -4
  16. brainstate/nn/_dropout_test.py +2 -2
  17. brainstate/nn/_dynamics.py +53 -35
  18. brainstate/nn/_elementwise.py +30 -30
  19. brainstate/nn/_linear.py +4 -4
  20. brainstate/nn/_module.py +6 -6
  21. brainstate/nn/_module_test.py +1 -1
  22. brainstate/nn/_normalizations.py +11 -11
  23. brainstate/nn/_normalizations_test.py +6 -6
  24. brainstate/nn/_poolings.py +24 -24
  25. brainstate/nn/_synapse.py +1 -12
  26. brainstate/nn/_utils.py +1 -1
  27. brainstate/nn/metrics.py +4 -4
  28. brainstate/optim/_optax_optimizer.py +8 -8
  29. brainstate/random/_rand_funs.py +37 -37
  30. brainstate/random/_rand_funs_test.py +3 -3
  31. brainstate/random/_rand_seed.py +7 -7
  32. brainstate/surrogate.py +40 -40
  33. brainstate/util/pretty_pytree.py +10 -10
  34. brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
  35. brainstate/util/struct.py +7 -7
  36. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
  37. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/RECORD +40 -40
  38. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
  39. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
  40. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
brainstate/surrogate.py CHANGED
@@ -106,11 +106,11 @@ class Surrogate(PrettyObject):
106
106
  Examples
107
107
  --------
108
108
 
109
- >>> import brainstate as bst
109
+ >>> import brainstate as brainstate
110
110
  >>> import brainstate.nn as nn
111
111
  >>> import jax.numpy as jnp
112
112
 
113
- >>> class MySurrogate(bst.surrogate.Surrogate):
113
+ >>> class MySurrogate(brainstate.surrogate.Surrogate):
114
114
  ... def __init__(self, alpha=1.):
115
115
  ... super().__init__()
116
116
  ... self.alpha = alpha
@@ -236,11 +236,11 @@ def sigmoid(
236
236
 
237
237
  >>> import jax
238
238
  >>> import brainstate.nn as nn
239
- >>> import brainstate as bst
239
+ >>> import brainstate as brainstate
240
240
  >>> import matplotlib.pyplot as plt
241
241
  >>> xs = jax.numpy.linspace(-2, 2, 1000)
242
242
  >>> for alpha in [1., 2., 4.]:
243
- >>> grads = bst.augment.vector_grad(bst.surrogate.sigmoid)(xs, alpha)
243
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.sigmoid)(xs, alpha)
244
244
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
245
245
  >>> plt.legend()
246
246
  >>> plt.show()
@@ -355,11 +355,11 @@ def piecewise_quadratic(
355
355
 
356
356
  >>> import jax
357
357
  >>> import brainstate.nn as nn
358
- >>> import brainstate as bst
358
+ >>> import brainstate as brainstate
359
359
  >>> import matplotlib.pyplot as plt
360
360
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
361
361
  >>> for alpha in [0.5, 1., 2., 4.]:
362
- >>> grads = bst.augment.vector_grad(bst.surrogate.piecewise_quadratic)(xs, alpha)
362
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_quadratic)(xs, alpha)
363
363
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
364
364
  >>> plt.legend()
365
365
  >>> plt.show()
@@ -522,11 +522,11 @@ def piecewise_exp(
522
522
 
523
523
  >>> import jax
524
524
  >>> import brainstate.nn as nn
525
- >>> import brainstate as bst
525
+ >>> import brainstate as brainstate
526
526
  >>> import matplotlib.pyplot as plt
527
527
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
528
528
  >>> for alpha in [0.5, 1., 2., 4.]:
529
- >>> grads = bst.augment.vector_grad(bst.surrogate.piecewise_exp)(xs, alpha)
529
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_exp)(xs, alpha)
530
530
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
531
531
  >>> plt.legend()
532
532
  >>> plt.show()
@@ -621,11 +621,11 @@ def soft_sign(
621
621
 
622
622
  >>> import jax
623
623
  >>> import brainstate.nn as nn
624
- >>> import brainstate as bst
624
+ >>> import brainstate as brainstate
625
625
  >>> import matplotlib.pyplot as plt
626
626
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
627
627
  >>> for alpha in [0.5, 1., 2., 4.]:
628
- >>> grads = bst.augment.vector_grad(bst.surrogate.soft_sign)(xs, alpha)
628
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.soft_sign)(xs, alpha)
629
629
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
630
630
  >>> plt.legend()
631
631
  >>> plt.show()
@@ -706,11 +706,11 @@ def arctan(
706
706
 
707
707
  >>> import jax
708
708
  >>> import brainstate.nn as nn
709
- >>> import brainstate as bst
709
+ >>> import brainstate as brainstate
710
710
  >>> import matplotlib.pyplot as plt
711
711
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
712
712
  >>> for alpha in [0.5, 1., 2., 4.]:
713
- >>> grads = bst.augment.vector_grad(bst.surrogate.arctan)(xs, alpha)
713
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.arctan)(xs, alpha)
714
714
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
715
715
  >>> plt.legend()
716
716
  >>> plt.show()
@@ -804,11 +804,11 @@ def nonzero_sign_log(
804
804
 
805
805
  >>> import jax
806
806
  >>> import brainstate.nn as nn
807
- >>> import brainstate as bst
807
+ >>> import brainstate as brainstate
808
808
  >>> import matplotlib.pyplot as plt
809
809
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
810
810
  >>> for alpha in [0.5, 1., 2., 4.]:
811
- >>> grads = bst.augment.vector_grad(bst.surrogate.nonzero_sign_log)(xs, alpha)
811
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
812
812
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
813
813
  >>> plt.legend()
814
814
  >>> plt.show()
@@ -893,11 +893,11 @@ def erf(
893
893
 
894
894
  >>> import jax
895
895
  >>> import brainstate.nn as nn
896
- >>> import brainstate as bst
896
+ >>> import brainstate as brainstate
897
897
  >>> import matplotlib.pyplot as plt
898
898
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
899
899
  >>> for alpha in [0.5, 1., 2., 4.]:
900
- >>> grads = bst.augment.vector_grad(bst.surrogate.nonzero_sign_log)(xs, alpha)
900
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
901
901
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
902
902
  >>> plt.legend()
903
903
  >>> plt.show()
@@ -1000,12 +1000,12 @@ def piecewise_leaky_relu(
1000
1000
 
1001
1001
  >>> import jax
1002
1002
  >>> import brainstate.nn as nn
1003
- >>> import brainstate as bst
1003
+ >>> import brainstate as brainstate
1004
1004
  >>> import matplotlib.pyplot as plt
1005
1005
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1006
1006
  >>> for c in [0.01, 0.05, 0.1]:
1007
1007
  >>> for w in [1., 2.]:
1008
- >>> grads1 = bst.augment.vector_grad(bst.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
1008
+ >>> grads1 = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
1009
1009
  >>> plt.plot(xs, grads1, label=f'x={c}, w={w}')
1010
1010
  >>> plt.legend()
1011
1011
  >>> plt.show()
@@ -1113,12 +1113,12 @@ def squarewave_fourier_series(
1113
1113
 
1114
1114
  >>> import jax
1115
1115
  >>> import brainstate.nn as nn
1116
- >>> import brainstate as bst
1116
+ >>> import brainstate as brainstate
1117
1117
  >>> import matplotlib.pyplot as plt
1118
1118
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1119
1119
  >>> for n in [2, 4, 8]:
1120
- >>> f = bst.surrogate.SquarewaveFourierSeries(n=n)
1121
- >>> grads1 = bst.augment.vector_grad(f)(xs)
1120
+ >>> f = brainstate.surrogate.SquarewaveFourierSeries(n=n)
1121
+ >>> grads1 = brainstate.augment.vector_grad(f)(xs)
1122
1122
  >>> plt.plot(xs, grads1, label=f'n={n}')
1123
1123
  >>> plt.legend()
1124
1124
  >>> plt.show()
@@ -1214,12 +1214,12 @@ def s2nn(
1214
1214
 
1215
1215
  >>> import jax
1216
1216
  >>> import brainstate.nn as nn
1217
- >>> import brainstate as bst
1217
+ >>> import brainstate as brainstate
1218
1218
  >>> import matplotlib.pyplot as plt
1219
1219
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1220
- >>> grads = bst.augment.vector_grad(bst.surrogate.s2nn)(xs, 4., 1.)
1220
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 4., 1.)
1221
1221
  >>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
1222
- >>> grads = bst.augment.vector_grad(bst.surrogate.s2nn)(xs, 8., 2.)
1222
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 8., 2.)
1223
1223
  >>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
1224
1224
  >>> plt.legend()
1225
1225
  >>> plt.show()
@@ -1315,11 +1315,11 @@ def q_pseudo_spike(
1315
1315
 
1316
1316
  >>> import jax
1317
1317
  >>> import brainstate.nn as nn
1318
- >>> import brainstate as bst
1318
+ >>> import brainstate as brainstate
1319
1319
  >>> import matplotlib.pyplot as plt
1320
1320
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1321
1321
  >>> for alpha in [0.5, 1., 2., 4.]:
1322
- >>> grads = bst.augment.vector_grad(bst.surrogate.q_pseudo_spike)(xs, alpha)
1322
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.q_pseudo_spike)(xs, alpha)
1323
1323
  >>> plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
1324
1324
  >>> plt.legend()
1325
1325
  >>> plt.show()
@@ -1413,10 +1413,10 @@ def leaky_relu(
1413
1413
 
1414
1414
  >>> import jax
1415
1415
  >>> import brainstate.nn as nn
1416
- >>> import brainstate as bst
1416
+ >>> import brainstate as brainstate
1417
1417
  >>> import matplotlib.pyplot as plt
1418
1418
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1419
- >>> grads = bst.augment.vector_grad(bst.surrogate.leaky_relu)(xs, 0., 1.)
1419
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
1420
1420
  >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
1421
1421
  >>> plt.legend()
1422
1422
  >>> plt.show()
@@ -1517,10 +1517,10 @@ def log_tailed_relu(
1517
1517
 
1518
1518
  >>> import jax
1519
1519
  >>> import brainstate.nn as nn
1520
- >>> import brainstate as bst
1520
+ >>> import brainstate as brainstate
1521
1521
  >>> import matplotlib.pyplot as plt
1522
1522
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1523
- >>> grads = bst.augment.vector_grad(bst.surrogate.leaky_relu)(xs, 0., 1.)
1523
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
1524
1524
  >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
1525
1525
  >>> plt.legend()
1526
1526
  >>> plt.show()
@@ -1596,12 +1596,12 @@ def relu_grad(
1596
1596
 
1597
1597
  >>> import jax
1598
1598
  >>> import brainstate.nn as nn
1599
- >>> import brainstate as bst
1599
+ >>> import brainstate as brainstate
1600
1600
  >>> import matplotlib.pyplot as plt
1601
1601
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1602
1602
  >>> for s in [0.5, 1.]:
1603
1603
  >>> for w in [1, 2.]:
1604
- >>> grads = bst.augment.vector_grad(bst.surrogate.relu_grad)(xs, s, w)
1604
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.relu_grad)(xs, s, w)
1605
1605
  >>> plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
1606
1606
  >>> plt.legend()
1607
1607
  >>> plt.show()
@@ -1678,11 +1678,11 @@ def gaussian_grad(
1678
1678
 
1679
1679
  >>> import jax
1680
1680
  >>> import brainstate.nn as nn
1681
- >>> import brainstate as bst
1681
+ >>> import brainstate as brainstate
1682
1682
  >>> import matplotlib.pyplot as plt
1683
1683
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1684
1684
  >>> for s in [0.5, 1., 2.]:
1685
- >>> grads = bst.augment.vector_grad(bst.surrogate.gaussian_grad)(xs, s, 0.5)
1685
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.gaussian_grad)(xs, s, 0.5)
1686
1686
  >>> plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
1687
1687
  >>> plt.legend()
1688
1688
  >>> plt.show()
@@ -1773,10 +1773,10 @@ def multi_gaussian_grad(
1773
1773
 
1774
1774
  >>> import jax
1775
1775
  >>> import brainstate.nn as nn
1776
- >>> import brainstate as bst
1776
+ >>> import brainstate as brainstate
1777
1777
  >>> import matplotlib.pyplot as plt
1778
1778
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1779
- >>> grads = bst.augment.vector_grad(bst.surrogate.multi_gaussian_grad)(xs)
1779
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.multi_gaussian_grad)(xs)
1780
1780
  >>> plt.plot(xs, grads)
1781
1781
  >>> plt.show()
1782
1782
 
@@ -1855,11 +1855,11 @@ def inv_square_grad(
1855
1855
 
1856
1856
  >>> import jax
1857
1857
  >>> import brainstate.nn as nn
1858
- >>> import brainstate as bst
1858
+ >>> import brainstate as brainstate
1859
1859
  >>> import matplotlib.pyplot as plt
1860
1860
  >>> xs = jax.numpy.linspace(-1, 1, 1000)
1861
1861
  >>> for alpha in [1., 10., 100.]:
1862
- >>> grads = bst.augment.vector_grad(bst.surrogate.inv_square_grad)(xs, alpha)
1862
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.inv_square_grad)(xs, alpha)
1863
1863
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
1864
1864
  >>> plt.legend()
1865
1865
  >>> plt.show()
@@ -1929,11 +1929,11 @@ def slayer_grad(
1929
1929
 
1930
1930
  >>> import jax
1931
1931
  >>> import brainstate.nn as nn
1932
- >>> import brainstate as bst
1932
+ >>> import brainstate as brainstate
1933
1933
  >>> import matplotlib.pyplot as plt
1934
1934
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1935
1935
  >>> for alpha in [0.5, 1., 2., 4.]:
1936
- >>> grads = bst.augment.vector_grad(bst.surrogate.slayer_grad)(xs, alpha)
1936
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.slayer_grad)(xs, alpha)
1937
1937
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
1938
1938
  >>> plt.legend()
1939
1939
  >>> plt.show()
brainstate/util/pretty_pytree.py CHANGED
@@ -373,19 +373,19 @@ class NestedDict(PrettyDict):
373
373
 
374
374
  Example usage::
375
375
 
376
- >>> import brainstate as bst
376
+ >>> import brainstate as brainstate
377
377
 
378
- >>> class Model(bst.nn.Module):
378
+ >>> class Model(brainstate.nn.Module):
379
379
  ... def __init__(self):
380
380
  ... super().__init__()
381
- ... self.batchnorm = bst.nn.BatchNorm1d([10, 3])
382
- ... self.linear = bst.nn.Linear(2, 3)
381
+ ... self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
382
+ ... self.linear = brainstate.nn.Linear(2, 3)
383
383
  ... def __call__(self, x):
384
384
  ... return self.linear(self.batchnorm(x))
385
385
 
386
386
  >>> model = Model()
387
- >>> state_map = bst.graph.treefy_states(model)
388
- >>> param, others = state_map.treefy_split(bst.ParamState, ...)
387
+ >>> state_map = brainstate.graph.treefy_states(model)
388
+ >>> param, others = state_map.treefy_split(brainstate.ParamState, ...)
389
389
 
390
390
  Arguments:
391
391
  first: The first filter
@@ -495,14 +495,14 @@ class FlattedDict(PrettyDict):
495
495
 
496
496
  Example usage::
497
497
 
498
- >>> import brainstate as bst
498
+ >>> import brainstate as brainstate
499
499
  >>> import jax.numpy as jnp
500
500
  >>>
501
- >>> class Model(bst.nn.Module):
501
+ >>> class Model(brainstate.nn.Module):
502
502
  ... def __init__(self):
503
503
  ... super().__init__()
504
- ... self.batchnorm = bst.nn.BatchNorm1d([10, 3])
505
- ... self.linear = bst.nn.Linear(2, 3)
504
+ ... self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
505
+ ... self.linear = brainstate.nn.Linear(2, 3)
506
506
  ... def __call__(self, x):
507
507
  ... return self.linear(self.batchnorm(x))
508
508
  >>>
@@ -13,31 +13,30 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
20
19
  import jax
21
20
  from absl.testing import absltest
22
21
 
23
- import brainstate as bst
22
+ import brainstate
24
23
 
25
24
 
26
25
  class TestNestedMapping(absltest.TestCase):
27
26
  def test_create_state(self):
28
- state = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})
27
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
29
28
 
30
29
  assert state['a'].value == 1
31
30
  assert state['b']['c'].value == 2
32
31
 
33
32
  def test_get_attr(self):
34
- state = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})
33
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
35
34
 
36
35
  assert state.a.value == 1
37
36
  assert state.b['c'].value == 2
38
37
 
39
38
  def test_set_attr(self):
40
- state = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})
39
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
41
40
 
42
41
  state.a.value = 3
43
42
  state.b['c'].value = 4
@@ -46,36 +45,36 @@ class TestNestedMapping(absltest.TestCase):
46
45
  assert state['b']['c'].value == 4
47
46
 
48
47
  def test_set_attr_variables(self):
49
- state = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})
48
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
50
49
 
51
50
  state.a.value = 3
52
51
  state.b['c'].value = 4
53
52
 
54
- assert isinstance(state.a, bst.ParamState)
53
+ assert isinstance(state.a, brainstate.ParamState)
55
54
  assert state.a.value == 3
56
- assert isinstance(state.b['c'], bst.ParamState)
55
+ assert isinstance(state.b['c'], brainstate.ParamState)
57
56
  assert state.b['c'].value == 4
58
57
 
59
58
  def test_add_nested_attr(self):
60
- state = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})
61
- state.b['d'] = bst.ParamState(5)
59
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
60
+ state.b['d'] = brainstate.ParamState(5)
62
61
 
63
62
  assert state['b']['d'].value == 5
64
63
 
65
64
  def test_delete_nested_attr(self):
66
- state = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})
65
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
67
66
  del state['b']['c']
68
67
 
69
68
  assert 'c' not in state['b']
70
69
 
71
70
  def test_integer_access(self):
72
- class Foo(bst.nn.Module):
71
+ class Foo(brainstate.nn.Module):
73
72
  def __init__(self):
74
73
  super().__init__()
75
- self.layers = [bst.nn.Linear(1, 2), bst.nn.Linear(2, 3)]
74
+ self.layers = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]
76
75
 
77
76
  module = Foo()
78
- state_refs = bst.graph.treefy_states(module)
77
+ state_refs = brainstate.graph.treefy_states(module)
79
78
 
80
79
  assert module.layers[0].weight.value['weight'].shape == (1, 2)
81
80
  assert state_refs.layers[0]['weight'].value['weight'].shape == (1, 2)
@@ -83,8 +82,8 @@ class TestNestedMapping(absltest.TestCase):
83
82
  assert state_refs.layers[1]['weight'].value['weight'].shape == (2, 3)
84
83
 
85
84
  def test_pure_dict(self):
86
- module = bst.nn.Linear(4, 5)
87
- state_map = bst.graph.treefy_states(module)
85
+ module = brainstate.nn.Linear(4, 5)
86
+ state_map = brainstate.graph.treefy_states(module)
88
87
  pure_dict = state_map.to_pure_dict()
89
88
  assert isinstance(pure_dict, dict)
90
89
  assert isinstance(pure_dict['weight'].value['weight'], jax.Array)
@@ -93,27 +92,27 @@ class TestNestedMapping(absltest.TestCase):
93
92
 
94
93
  class TestSplit(unittest.TestCase):
95
94
  def test_split(self):
96
- class Model(bst.nn.Module):
95
+ class Model(brainstate.nn.Module):
97
96
  def __init__(self):
98
97
  super().__init__()
99
- self.batchnorm = bst.nn.BatchNorm1d([10, 3])
100
- self.linear = bst.nn.Linear([10, 3], [10, 4])
98
+ self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
99
+ self.linear = brainstate.nn.Linear([10, 3], [10, 4])
101
100
 
102
101
  def __call__(self, x):
103
102
  return self.linear(self.batchnorm(x))
104
103
 
105
- with bst.environ.context(fit=True):
104
+ with brainstate.environ.context(fit=True):
106
105
  model = Model()
107
- x = bst.random.randn(1, 10, 3)
106
+ x = brainstate.random.randn(1, 10, 3)
108
107
  y = model(x)
109
108
  self.assertEqual(y.shape, (1, 10, 4))
110
109
 
111
- state_map = bst.graph.treefy_states(model)
110
+ state_map = brainstate.graph.treefy_states(model)
112
111
 
113
112
  with self.assertRaises(ValueError):
114
- params, others = state_map.split(bst.ParamState)
113
+ params, others = state_map.split(brainstate.ParamState)
115
114
 
116
- params, others = state_map.split(bst.ParamState, ...)
115
+ params, others = state_map.split(brainstate.ParamState, ...)
117
116
  print()
118
117
  print(params)
119
118
  print(others)
@@ -124,37 +123,37 @@ class TestSplit(unittest.TestCase):
124
123
 
125
124
  class TestStateMap2(unittest.TestCase):
126
125
  def test1(self):
127
- class Model(bst.nn.Module):
126
+ class Model(brainstate.nn.Module):
128
127
  def __init__(self):
129
128
  super().__init__()
130
- self.batchnorm = bst.nn.BatchNorm1d([10, 3])
131
- self.linear = bst.nn.Linear([10, 3], [10, 4])
129
+ self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
130
+ self.linear = brainstate.nn.Linear([10, 3], [10, 4])
132
131
 
133
132
  def __call__(self, x):
134
133
  return self.linear(self.batchnorm(x))
135
134
 
136
- with bst.environ.context(fit=True):
135
+ with brainstate.environ.context(fit=True):
137
136
  model = Model()
138
- state_map = bst.graph.treefy_states(model).to_flat()
139
- state_map = bst.util.NestedDict(state_map)
137
+ state_map = brainstate.graph.treefy_states(model).to_flat()
138
+ state_map = brainstate.util.NestedDict(state_map)
140
139
 
141
140
 
142
141
  class TestFlattedMapping(unittest.TestCase):
143
142
  def test1(self):
144
- class Model(bst.nn.Module):
143
+ class Model(brainstate.nn.Module):
145
144
  def __init__(self):
146
145
  super().__init__()
147
- self.batchnorm = bst.nn.BatchNorm1d([10, 3])
148
- self.linear = bst.nn.Linear([10, 3], [10, 4])
146
+ self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
147
+ self.linear = brainstate.nn.Linear([10, 3], [10, 4])
149
148
 
150
149
  def __call__(self, x):
151
150
  return self.linear(self.batchnorm(x))
152
151
 
153
152
  model = Model()
154
153
  # print(model.states())
155
- # print(bst.graph.states(model))
156
- self.assertTrue(model.states() == bst.graph.states(model))
154
+ # print(brainstate.graph.states(model))
155
+ self.assertTrue(model.states() == brainstate.graph.states(model))
157
156
 
158
157
  print(model.nodes())
159
- # print(bst.graph.nodes(model))
160
- self.assertTrue(model.nodes() == bst.graph.nodes(model))
158
+ # print(brainstate.graph.nodes(model))
159
+ self.assertTrue(model.nodes() == brainstate.graph.nodes(model))
brainstate/util/struct.py CHANGED
@@ -56,16 +56,16 @@ def dataclass(clz: T, **kwargs) -> T:
56
56
  The ``dataclass`` decorator makes it easy to define custom classes that can be
57
57
  passed safely to Jax. For example::
58
58
 
59
- >>> import brainstate as bst
59
+ >>> import brainstate as brainstate
60
60
  >>> import jax
61
61
  >>> from typing import Any, Callable
62
62
 
63
- >>> @bst.util.dataclass
63
+ >>> @brainstate.util.dataclass
64
64
  ... class Model:
65
65
  ... params: Any
66
66
  ... # use pytree_node=False to indicate an attribute should not be touched
67
67
  ... # by Jax transformations.
68
- ... apply_fn: Callable = bst.util.field(pytree_node=False)
68
+ ... apply_fn: Callable = brainstate.util.field(pytree_node=False)
69
69
 
70
70
  ... def __apply__(self, *args):
71
71
  ... return self.apply_fn(*args)
@@ -97,7 +97,7 @@ def dataclass(clz: T, **kwargs) -> T:
97
97
  This way the simple constructor used by ``jax.tree_util`` is
98
98
  preserved. Consider the following example::
99
99
 
100
- >>> @bst.util.dataclass
100
+ >>> @brainstate.util.dataclass
101
101
  ... class DirectionAndScaleKernel:
102
102
  ... direction: jax.Array
103
103
  ... scale: jax.Array
@@ -189,15 +189,15 @@ class PyTreeNode:
189
189
 
190
190
  Example::
191
191
 
192
- >>> import brainstate as bst
192
+ >>> import brainstate as brainstate
193
193
  >>> import jax
194
194
  >>> from typing import Any, Callable
195
195
 
196
- >>> class Model(bst.util.PyTreeNode):
196
+ >>> class Model(brainstate.util.PyTreeNode):
197
197
  ... params: Any
198
198
  ... # use pytree_node=False to indicate an attribute should not be touched
199
199
  ... # by Jax transformations.
200
- ... apply_fn: Callable = bst.util.field(pytree_node=False)
200
+ ... apply_fn: Callable = brainstate.util.field(pytree_node=False)
201
201
 
202
202
  ... def __apply__(self, *args):
203
203
  ... return self.apply_fn(*args)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -31,22 +31,22 @@ License-File: LICENSE
31
31
  Requires-Dist: jax
32
32
  Requires-Dist: jaxlib
33
33
  Requires-Dist: numpy
34
- Requires-Dist: brainunit (>=0.0.4)
34
+ Requires-Dist: brainunit>=0.0.4
35
35
  Requires-Dist: brainevent
36
36
  Provides-Extra: cpu
37
- Requires-Dist: jax[cpu] ; extra == 'cpu'
38
- Requires-Dist: brainunit[cpu] ; extra == 'cpu'
39
- Requires-Dist: brainevent[cpu] ; extra == 'cpu'
37
+ Requires-Dist: jax[cpu]; extra == "cpu"
38
+ Requires-Dist: brainunit[cpu]; extra == "cpu"
39
+ Requires-Dist: brainevent[cpu]; extra == "cpu"
40
40
  Provides-Extra: cuda12
41
- Requires-Dist: jax[cuda12] ; extra == 'cuda12'
42
- Requires-Dist: brainunit[cuda12] ; extra == 'cuda12'
43
- Requires-Dist: brainevent[cuda12] ; extra == 'cuda12'
41
+ Requires-Dist: jax[cuda12]; extra == "cuda12"
42
+ Requires-Dist: brainunit[cuda12]; extra == "cuda12"
43
+ Requires-Dist: brainevent[cuda12]; extra == "cuda12"
44
44
  Provides-Extra: testing
45
- Requires-Dist: pytest ; extra == 'testing'
45
+ Requires-Dist: pytest; extra == "testing"
46
46
  Provides-Extra: tpu
47
- Requires-Dist: jax[tpu] ; extra == 'tpu'
48
- Requires-Dist: brainunit[tpu] ; extra == 'tpu'
49
- Requires-Dist: brainevent[tpu] ; extra == 'tpu'
47
+ Requires-Dist: jax[tpu]; extra == "tpu"
48
+ Requires-Dist: brainunit[tpu]; extra == "tpu"
49
+ Requires-Dist: brainevent[tpu]; extra == "tpu"
50
50
 
51
51
 
52
52
  # A ``State``-based Transformation System for Program Compilation and Augmentation