brainstate 0.1.4__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45):
  1. brainstate/__init__.py +1 -1
  2. brainstate/_state.py +6 -5
  3. brainstate/augment/_autograd.py +31 -12
  4. brainstate/augment/_autograd_test.py +46 -46
  5. brainstate/augment/_eval_shape.py +4 -4
  6. brainstate/augment/_mapping.py +22 -17
  7. brainstate/augment/_mapping_test.py +162 -0
  8. brainstate/compile/_conditions.py +2 -2
  9. brainstate/compile/_make_jaxpr.py +59 -6
  10. brainstate/compile/_progress_bar.py +2 -2
  11. brainstate/environ.py +19 -19
  12. brainstate/functional/_activations_test.py +12 -12
  13. brainstate/graph/_graph_operation.py +69 -69
  14. brainstate/graph/_graph_operation_test.py +2 -2
  15. brainstate/mixin.py +0 -17
  16. brainstate/nn/_collective_ops.py +4 -4
  17. brainstate/nn/_common.py +7 -19
  18. brainstate/nn/_dropout_test.py +2 -2
  19. brainstate/nn/_dynamics.py +53 -35
  20. brainstate/nn/_elementwise.py +30 -30
  21. brainstate/nn/_exp_euler.py +13 -16
  22. brainstate/nn/_inputs.py +1 -1
  23. brainstate/nn/_linear.py +4 -4
  24. brainstate/nn/_module.py +6 -6
  25. brainstate/nn/_module_test.py +1 -1
  26. brainstate/nn/_normalizations.py +11 -11
  27. brainstate/nn/_normalizations_test.py +6 -6
  28. brainstate/nn/_poolings.py +24 -24
  29. brainstate/nn/_synapse.py +1 -12
  30. brainstate/nn/_utils.py +1 -1
  31. brainstate/nn/metrics.py +4 -4
  32. brainstate/optim/_optax_optimizer.py +8 -8
  33. brainstate/random/_rand_funs.py +37 -37
  34. brainstate/random/_rand_funs_test.py +3 -3
  35. brainstate/random/_rand_seed.py +7 -7
  36. brainstate/random/_rand_state.py +13 -7
  37. brainstate/surrogate.py +40 -40
  38. brainstate/util/pretty_pytree.py +10 -10
  39. brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
  40. brainstate/util/struct.py +7 -7
  41. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
  42. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/RECORD +45 -45
  43. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
  44. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
  45. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
@@ -61,7 +61,7 @@ class Threshold(ElementWiseBlock):
61
61
  Examples::
62
62
 
63
63
  >>> import brainstate.nn as nn
64
- >>> import brainstate as bst
64
+ >>> import brainstate
65
65
  >>> m = nn.Threshold(0.1, 20)
66
66
  >>> x = random.randn(2)
67
67
  >>> output = m(x)
@@ -97,7 +97,7 @@ class ReLU(ElementWiseBlock):
97
97
  Examples::
98
98
 
99
99
  >>> import brainstate.nn as nn
100
- >>> import brainstate as bst
100
+ >>> import brainstate as brainstate
101
101
  >>> m = nn.ReLU()
102
102
  >>> x = random.randn(2)
103
103
  >>> output = m(x)
@@ -106,7 +106,7 @@ class ReLU(ElementWiseBlock):
106
106
  An implementation of CReLU - https://arxiv.org/abs/1603.05201
107
107
 
108
108
  >>> import brainstate.nn as nn
109
- >>> import brainstate as bst
109
+ >>> import brainstate as brainstate
110
110
  >>> m = nn.ReLU()
111
111
  >>> x = random.randn(2).unsqueeze(0)
112
112
  >>> output = jax.numpy.concat((m(x), m(-x)))
@@ -151,7 +151,7 @@ class RReLU(ElementWiseBlock):
151
151
  Examples::
152
152
 
153
153
  >>> import brainstate.nn as nn
154
- >>> import brainstate as bst
154
+ >>> import brainstate as brainstate
155
155
  >>> m = nn.RReLU(0.1, 0.3)
156
156
  >>> x = random.randn(2)
157
157
  >>> output = m(x)
@@ -205,7 +205,7 @@ class Hardtanh(ElementWiseBlock):
205
205
  Examples::
206
206
 
207
207
  >>> import brainstate.nn as nn
208
- >>> import brainstate as bst
208
+ >>> import brainstate as brainstate
209
209
  >>> m = nn.Hardtanh(-2, 2)
210
210
  >>> x = random.randn(2)
211
211
  >>> output = m(x)
@@ -244,7 +244,7 @@ class ReLU6(Hardtanh, ElementWiseBlock):
244
244
  Examples::
245
245
 
246
246
  >>> import brainstate.nn as nn
247
- >>> import brainstate as bst
247
+ >>> import brainstate as brainstate
248
248
  >>> m = nn.ReLU6()
249
249
  >>> x = random.randn(2)
250
250
  >>> output = m(x)
@@ -269,7 +269,7 @@ class Sigmoid(ElementWiseBlock):
269
269
  Examples::
270
270
 
271
271
  >>> import brainstate.nn as nn
272
- >>> import brainstate as bst
272
+ >>> import brainstate as brainstate
273
273
  >>> m = nn.Sigmoid()
274
274
  >>> x = random.randn(2)
275
275
  >>> output = m(x)
@@ -299,7 +299,7 @@ class Hardsigmoid(ElementWiseBlock):
299
299
  Examples::
300
300
 
301
301
  >>> import brainstate.nn as nn
302
- >>> import brainstate as bst
302
+ >>> import brainstate as brainstate
303
303
  >>> m = nn.Hardsigmoid()
304
304
  >>> x = random.randn(2)
305
305
  >>> output = m(x)
@@ -325,7 +325,7 @@ class Tanh(ElementWiseBlock):
325
325
  Examples::
326
326
 
327
327
  >>> import brainstate.nn as nn
328
- >>> import brainstate as bst
328
+ >>> import brainstate as brainstate
329
329
  >>> m = nn.Tanh()
330
330
  >>> x = random.randn(2)
331
331
  >>> output = m(x)
@@ -386,7 +386,7 @@ class Mish(ElementWiseBlock):
386
386
  Examples::
387
387
 
388
388
  >>> import brainstate.nn as nn
389
- >>> import brainstate as bst
389
+ >>> import brainstate as brainstate
390
390
  >>> m = nn.Mish()
391
391
  >>> x = random.randn(2)
392
392
  >>> output = m(x)
@@ -418,7 +418,7 @@ class Hardswish(ElementWiseBlock):
418
418
  Examples::
419
419
 
420
420
  >>> import brainstate.nn as nn
421
- >>> import brainstate as bst
421
+ >>> import brainstate as brainstate
422
422
  >>> m = nn.Hardswish()
423
423
  >>> x = random.randn(2)
424
424
  >>> output = m(x)
@@ -452,7 +452,7 @@ class ELU(ElementWiseBlock):
452
452
  Examples::
453
453
 
454
454
  >>> import brainstate.nn as nn
455
- >>> import brainstate as bst
455
+ >>> import brainstate as brainstate
456
456
  >>> m = nn.ELU()
457
457
  >>> x = random.randn(2)
458
458
  >>> output = m(x)
@@ -489,7 +489,7 @@ class CELU(ElementWiseBlock):
489
489
  Examples::
490
490
 
491
491
  >>> import brainstate.nn as nn
492
- >>> import brainstate as bst
492
+ >>> import brainstate as brainstate
493
493
  >>> m = nn.CELU()
494
494
  >>> x = random.randn(2)
495
495
  >>> output = m(x)
@@ -530,7 +530,7 @@ class SELU(ElementWiseBlock):
530
530
  Examples::
531
531
 
532
532
  >>> import brainstate.nn as nn
533
- >>> import brainstate as bst
533
+ >>> import brainstate as brainstate
534
534
  >>> m = nn.SELU()
535
535
  >>> x = random.randn(2)
536
536
  >>> output = m(x)
@@ -559,7 +559,7 @@ class GLU(ElementWiseBlock):
559
559
  Examples::
560
560
 
561
561
  >>> import brainstate.nn as nn
562
- >>> import brainstate as bst
562
+ >>> import brainstate as brainstate
563
563
  >>> m = nn.GLU()
564
564
  >>> x = random.randn(4, 2)
565
565
  >>> output = m(x)
@@ -600,7 +600,7 @@ class GELU(ElementWiseBlock):
600
600
  Examples::
601
601
 
602
602
  >>> import brainstate.nn as nn
603
- >>> import brainstate as bst
603
+ >>> import brainstate as brainstate
604
604
  >>> m = nn.GELU()
605
605
  >>> x = random.randn(2)
606
606
  >>> output = m(x)
@@ -642,7 +642,7 @@ class Hardshrink(ElementWiseBlock):
642
642
  Examples::
643
643
 
644
644
  >>> import brainstate.nn as nn
645
- >>> import brainstate as bst
645
+ >>> import brainstate as brainstate
646
646
  >>> m = nn.Hardshrink()
647
647
  >>> x = random.randn(2)
648
648
  >>> output = m(x)
@@ -689,7 +689,7 @@ class LeakyReLU(ElementWiseBlock):
689
689
  Examples::
690
690
 
691
691
  >>> import brainstate.nn as nn
692
- >>> import brainstate as bst
692
+ >>> import brainstate as brainstate
693
693
  >>> m = nn.LeakyReLU(0.1)
694
694
  >>> x = random.randn(2)
695
695
  >>> output = m(x)
@@ -721,7 +721,7 @@ class LogSigmoid(ElementWiseBlock):
721
721
  Examples::
722
722
 
723
723
  >>> import brainstate.nn as nn
724
- >>> import brainstate as bst
724
+ >>> import brainstate as brainstate
725
725
  >>> m = nn.LogSigmoid()
726
726
  >>> x = random.randn(2)
727
727
  >>> output = m(x)
@@ -749,7 +749,7 @@ class Softplus(ElementWiseBlock):
749
749
  Examples::
750
750
 
751
751
  >>> import brainstate.nn as nn
752
- >>> import brainstate as bst
752
+ >>> import brainstate as brainstate
753
753
  >>> m = nn.Softplus()
754
754
  >>> x = random.randn(2)
755
755
  >>> output = m(x)
@@ -781,7 +781,7 @@ class Softshrink(ElementWiseBlock):
781
781
  Examples::
782
782
 
783
783
  >>> import brainstate.nn as nn
784
- >>> import brainstate as bst
784
+ >>> import brainstate as brainstate
785
785
  >>> m = nn.Softshrink()
786
786
  >>> x = random.randn(2)
787
787
  >>> output = m(x)
@@ -843,9 +843,9 @@ class PReLU(ElementWiseBlock):
843
843
 
844
844
  Examples::
845
845
 
846
- >>> import brainstate as bst
847
- >>> m = bst.nn.PReLU()
848
- >>> x = bst.random.randn(2)
846
+ >>> import brainstate as brainstate
847
+ >>> m = brainstate.nn.PReLU()
848
+ >>> x = brainstate.random.randn(2)
849
849
  >>> output = m(x)
850
850
  """
851
851
  __module__ = 'brainstate.nn'
@@ -876,7 +876,7 @@ class Softsign(ElementWiseBlock):
876
876
  Examples::
877
877
 
878
878
  >>> import brainstate.nn as nn
879
- >>> import brainstate as bst
879
+ >>> import brainstate as brainstate
880
880
  >>> m = nn.Softsign()
881
881
  >>> x = random.randn(2)
882
882
  >>> output = m(x)
@@ -900,7 +900,7 @@ class Tanhshrink(ElementWiseBlock):
900
900
  Examples::
901
901
 
902
902
  >>> import brainstate.nn as nn
903
- >>> import brainstate as bst
903
+ >>> import brainstate as brainstate
904
904
  >>> m = nn.Tanhshrink()
905
905
  >>> x = random.randn(2)
906
906
  >>> output = m(x)
@@ -937,7 +937,7 @@ class Softmin(ElementWiseBlock):
937
937
  Examples::
938
938
 
939
939
  >>> import brainstate.nn as nn
940
- >>> import brainstate as bst
940
+ >>> import brainstate as brainstate
941
941
  >>> m = nn.Softmin(dim=1)
942
942
  >>> x = random.randn(2, 3)
943
943
  >>> output = m(x)
@@ -990,7 +990,7 @@ class Softmax(ElementWiseBlock):
990
990
  Examples::
991
991
 
992
992
  >>> import brainstate.nn as nn
993
- >>> import brainstate as bst
993
+ >>> import brainstate as brainstate
994
994
  >>> m = nn.Softmax(dim=1)
995
995
  >>> x = random.randn(2, 3)
996
996
  >>> output = m(x)
@@ -1027,7 +1027,7 @@ class Softmax2d(ElementWiseBlock):
1027
1027
  Examples::
1028
1028
 
1029
1029
  >>> import brainstate.nn as nn
1030
- >>> import brainstate as bst
1030
+ >>> import brainstate as brainstate
1031
1031
  >>> m = nn.Softmax2d()
1032
1032
  >>> # you softmax over the 2nd dimension
1033
1033
  >>> x = random.randn(2, 3, 12, 13)
@@ -1062,7 +1062,7 @@ class LogSoftmax(ElementWiseBlock):
1062
1062
  Examples::
1063
1063
 
1064
1064
  >>> import brainstate.nn as nn
1065
- >>> import brainstate as bst
1065
+ >>> import brainstate as brainstate
1066
1066
  >>> m = nn.LogSoftmax(dim=1)
1067
1067
  >>> x = random.randn(2, 3)
1068
1068
  >>> output = m(x)
@@ -69,27 +69,24 @@ def exp_euler_step(
69
69
  f'The input data type should be float64, float32, float16, or bfloat16 '
70
70
  f'when using Exponential Euler method. But we got {args[0].dtype}.'
71
71
  )
72
+
73
+ # drift
72
74
  dt = environ.get('dt')
73
75
  linear, derivative = vector_grad(fn, argnums=0, return_value=True)(*args, **kwargs)
74
76
  linear = u.Quantity(u.get_mantissa(linear), u.get_unit(derivative) / u.get_unit(linear))
75
77
  phi = u.math.exprel(dt * linear)
76
78
  x_next = args[0] + dt * phi * derivative
77
79
 
80
+ # diffusion
78
81
  if diffusion is not None:
79
- # unit checking
80
- diffusion = diffusion(*args, **kwargs)
81
- time_unit = u.get_unit(dt)
82
- drift_unit = u.get_unit(derivative)
83
- diffusion_unit = u.get_unit(diffusion)
84
- # if drift_unit.is_unitless:
85
- # assert diffusion_unit.is_unitless, 'The diffusion term should be unitless when the drift term is unitless.'
86
- # else:
87
- # u.fail_for_dimension_mismatch(
88
- # drift_unit, diffusion_unit * time_unit ** 0.5,
89
- # "Drift unit is {drift}, diffusion unit is {diffusion}, ",
90
- # drift=drift_unit, diffusion=diffusion_unit * time_unit ** 0.5
91
- # )
92
-
93
- # diffusion
94
- x_next += diffusion * u.math.sqrt(dt) * random.randn_like(args[0])
82
+ diffusion_part = diffusion(*args, **kwargs) * u.math.sqrt(dt) * random.randn_like(args[0])
83
+ if u.get_dim(x_next) != u.get_dim(diffusion_part):
84
+ drift_unit = u.get_unit(x_next)
85
+ time_unit = u.get_unit(dt)
86
+ raise ValueError(
87
+ f"Drift unit is {drift_unit}, "
88
+ f"expected diffusion unit is {drift_unit / time_unit ** 0.5}, "
89
+ f"but we got {u.get_unit(diffusion_part)}."
90
+ )
91
+ x_next += diffusion_part
95
92
  return x_next
brainstate/nn/_inputs.py CHANGED
@@ -547,7 +547,7 @@ def poisson_input(
547
547
  num_input,
548
548
  p,
549
549
  tar[indices].shape,
550
- # check_valid=False,
550
+ check_valid=False,
551
551
  dtype=tar.dtype
552
552
  ),
553
553
  tar_val,
brainstate/nn/_linear.py CHANGED
@@ -361,9 +361,9 @@ class LoRA(Module):
361
361
 
362
362
  Example usage::
363
363
 
364
- >>> import brainstate as bst
364
+ >>> import brainstate as brainstate
365
365
  >>> import jax, jax.numpy as jnp
366
- >>> layer = bst.nn.LoRA(3, 2, 4)
366
+ >>> layer = brainstate.nn.LoRA(3, 2, 4)
367
367
  >>> layer.weight.value
368
368
  {'lora_a': Array([[ 0.25141352, -0.09826107],
369
369
  [ 0.2328382 , 0.38869813],
@@ -371,8 +371,8 @@ class LoRA(Module):
371
371
  'lora_b': Array([[-0.8372317 , 0.21012013, -0.52999765, -0.31939325],
372
372
  [ 0.64234126, -0.42980042, 1.2549229 , -0.47134295]], dtype=float32)}
373
373
  >>> # Wrap around existing layer
374
- >>> linear = bst.nn.Linear(3, 4)
375
- >>> wrapper = bst.nn.LoRA(3, 2, 4, base_module=linear)
374
+ >>> linear = brainstate.nn.Linear(3, 4)
375
+ >>> wrapper = brainstate.nn.LoRA(3, 2, 4, base_module=linear)
376
376
  >>> assert wrapper.base_module == linear
377
377
  >>> y = layer(jnp.ones((16, 3)))
378
378
  >>> y.shape
brainstate/nn/_module.py CHANGED
@@ -128,9 +128,9 @@ class Module(Node, ParamDesc):
128
128
  Examples
129
129
  --------
130
130
 
131
- >>> import brainstate as bst
132
- >>> x = bst.random.rand((10, 10))
133
- >>> l = bst.nn.Dropout(0.5)
131
+ >>> import brainstate as brainstate
132
+ >>> x = brainstate.random.rand((10, 10))
133
+ >>> l = brainstate.nn.Dropout(0.5)
134
134
  >>> y = x >> l
135
135
  """
136
136
  return self.__call__(other)
@@ -230,7 +230,7 @@ class Module(Node, ParamDesc):
230
230
  pass
231
231
 
232
232
  def __pretty_repr_item__(self, name, value):
233
- if name in ['_in_size', '_out_size', '_name']:
233
+ if name.startswith('_'):
234
234
  return None if value is None else (name[1:], value) # skip the first `_`
235
235
  return name, value
236
236
 
@@ -266,14 +266,14 @@ class Sequential(Module):
266
266
  --------
267
267
 
268
268
  >>> import jax
269
- >>> import brainstate as bst
269
+ >>> import brainstate as brainstate
270
270
  >>> import brainstate.nn as nn
271
271
  >>>
272
272
  >>> # composing ANN models
273
273
  >>> l = nn.Sequential(nn.Linear(100, 10),
274
274
  >>> jax.nn.relu,
275
275
  >>> nn.Linear(10, 2))
276
- >>> l(bst.random.random((256, 100)))
276
+ >>> l(brainstate.random.random((256, 100)))
277
277
 
278
278
  Args:
279
279
  modules_as_tuple: The children modules.
@@ -77,7 +77,7 @@ class TestDelay(unittest.TestCase):
77
77
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
78
78
  self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
79
79
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
80
- # bst.util.clear_buffer_memory()
80
+ # brainstate.util.clear_buffer_memory()
81
81
 
82
82
  def test_jit_erro(self):
83
83
  rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
@@ -488,9 +488,9 @@ class LayerNorm(Module):
488
488
 
489
489
  Example usage::
490
490
 
491
- >>> import brainstate as bst
492
- >>> x = bst.random.normal(size=(3, 4, 5, 6))
493
- >>> layer = bst.nn.LayerNorm(x.shape)
491
+ >>> import brainstate as brainstate
492
+ >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
493
+ >>> layer = brainstate.nn.LayerNorm(x.shape)
494
494
  >>> layer.states()
495
495
  >>> y = layer(x)
496
496
 
@@ -616,9 +616,9 @@ class RMSNorm(Module):
616
616
 
617
617
  Example usage::
618
618
 
619
- >>> import brainstate as bst
620
- >>> x = bst.random.normal(size=(5, 6))
621
- >>> layer = bst.nn.RMSNorm(num_features=6)
619
+ >>> import brainstate as brainstate
620
+ >>> x = brainstate.random.normal(size=(5, 6))
621
+ >>> layer = brainstate.nn.RMSNorm(num_features=6)
622
622
  >>> layer.states()
623
623
  >>> y = layer(x)
624
624
 
@@ -739,14 +739,14 @@ class GroupNorm(Module):
739
739
  Example usage::
740
740
 
741
741
  >>> import numpy as np
742
- >>> import brainstate as bst
742
+ >>> import brainstate as brainstate
743
743
  ...
744
- >>> x = bst.random.normal(size=(3, 4, 5, 6))
745
- >>> layer = bst.nn.GroupNorm(x.shape, num_groups=3)
744
+ >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
745
+ >>> layer = brainstate.nn.GroupNorm(x.shape, num_groups=3)
746
746
  >>> layer.states()
747
747
  >>> y = layer(x)
748
- >>> y = bst.nn.GroupNorm(x.shape, num_groups=1)(x)
749
- >>> y2 = bst.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
748
+ >>> y = brainstate.nn.GroupNorm(x.shape, num_groups=1)(x)
749
+ >>> y2 = brainstate.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
750
750
  >>> np.testing.assert_allclose(y, y2)
751
751
 
752
752
  Attributes:
@@ -51,21 +51,21 @@ class Test_Normalization(parameterized.TestCase):
51
51
  # normalized_shape=(10, [5, 10])
52
52
  # )
53
53
  # def test_LayerNorm(self, normalized_shape):
54
- # net = bst.nn.LayerNorm(normalized_shape, )
55
- # input = bst.random.randn(20, 5, 10)
54
+ # net = brainstate.nn.LayerNorm(normalized_shape, )
55
+ # input = brainstate.random.randn(20, 5, 10)
56
56
  # output = net(input)
57
57
  #
58
58
  # @parameterized.product(
59
59
  # num_groups=[1, 2, 3, 6]
60
60
  # )
61
61
  # def test_GroupNorm(self, num_groups):
62
- # input = bst.random.randn(20, 10, 10, 6)
63
- # net = bst.nn.GroupNorm(num_groups=num_groups, num_channels=6, )
62
+ # input = brainstate.random.randn(20, 10, 10, 6)
63
+ # net = brainstate.nn.GroupNorm(num_groups=num_groups, num_channels=6, )
64
64
  # output = net(input)
65
65
  #
66
66
  # def test_InstanceNorm(self):
67
- # input = bst.random.randn(20, 10, 10, 6)
68
- # net = bst.nn.InstanceNorm(num_channels=6, )
67
+ # input = brainstate.random.randn(20, 10, 10, 6)
68
+ # net = brainstate.nn.InstanceNorm(num_channels=6, )
69
69
  # output = net(input)
70
70
 
71
71
 
@@ -55,8 +55,8 @@ class Flatten(Module):
55
55
  end_axis: last dim to flatten (default = -1).
56
56
 
57
57
  Examples::
58
- >>> import brainstate as bst
59
- >>> inp = bst.random.randn(32, 1, 5, 5)
58
+ >>> import brainstate as brainstate
59
+ >>> inp = brainstate.random.randn(32, 1, 5, 5)
60
60
  >>> # With default parameters
61
61
  >>> m = Flatten()
62
62
  >>> output = m(inp)
@@ -334,10 +334,10 @@ class MaxPool1d(_MaxPool):
334
334
 
335
335
  Examples::
336
336
 
337
- >>> import brainstate as bst
337
+ >>> import brainstate as brainstate
338
338
  >>> # pool of size=3, stride=2
339
339
  >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
340
- >>> input = bst.random.randn(20, 50, 16)
340
+ >>> input = brainstate.random.randn(20, 50, 16)
341
341
  >>> output = m(input)
342
342
  >>> output.shape
343
343
  (20, 24, 16)
@@ -418,12 +418,12 @@ class MaxPool2d(_MaxPool):
418
418
 
419
419
  Examples::
420
420
 
421
- >>> import brainstate as bst
421
+ >>> import brainstate as brainstate
422
422
  >>> # pool of square window of size=3, stride=2
423
423
  >>> m = MaxPool2d(3, stride=2)
424
424
  >>> # pool of non-square window
425
425
  >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
426
- >>> input = bst.random.randn(20, 50, 32, 16)
426
+ >>> input = brainstate.random.randn(20, 50, 32, 16)
427
427
  >>> output = m(input)
428
428
  >>> output.shape
429
429
  (20, 24, 31, 16)
@@ -509,12 +509,12 @@ class MaxPool3d(_MaxPool):
509
509
 
510
510
  Examples::
511
511
 
512
- >>> import brainstate as bst
512
+ >>> import brainstate as brainstate
513
513
  >>> # pool of square window of size=3, stride=2
514
514
  >>> m = MaxPool3d(3, stride=2)
515
515
  >>> # pool of non-square window
516
516
  >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
517
- >>> input = bst.random.randn(20, 50, 44, 31, 16)
517
+ >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
518
518
  >>> output = m(input)
519
519
  >>> output.shape
520
520
  (20, 24, 43, 15, 16)
@@ -588,10 +588,10 @@ class AvgPool1d(_AvgPool):
588
588
 
589
589
  Examples::
590
590
 
591
- >>> import brainstate as bst
591
+ >>> import brainstate as brainstate
592
592
  >>> # pool with window of size=3, stride=2
593
593
  >>> m = AvgPool1d(3, stride=2)
594
- >>> input = bst.random.randn(20, 50, 16)
594
+ >>> input = brainstate.random.randn(20, 50, 16)
595
595
  >>> m(input).shape
596
596
  (20, 24, 16)
597
597
 
@@ -665,12 +665,12 @@ class AvgPool2d(_AvgPool):
665
665
 
666
666
  Examples::
667
667
 
668
- >>> import brainstate as bst
668
+ >>> import brainstate as brainstate
669
669
  >>> # pool of square window of size=3, stride=2
670
670
  >>> m = AvgPool2d(3, stride=2)
671
671
  >>> # pool of non-square window
672
672
  >>> m = AvgPool2d((3, 2), stride=(2, 1))
673
- >>> input = bst.random.randn(20, 50, 32, , 16)
673
+ >>> input = brainstate.random.randn(20, 50, 32, , 16)
674
674
  >>> output = m(input)
675
675
  >>> output.shape
676
676
  (20, 24, 31, 16)
@@ -753,12 +753,12 @@ class AvgPool3d(_AvgPool):
753
753
 
754
754
  Examples::
755
755
 
756
- >>> import brainstate as bst
756
+ >>> import brainstate as brainstate
757
757
  >>> # pool of square window of size=3, stride=2
758
758
  >>> m = AvgPool3d(3, stride=2)
759
759
  >>> # pool of non-square window
760
760
  >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
761
- >>> input = bst.random.randn(20, 50, 44, 31, 16)
761
+ >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
762
762
  >>> output = m(input)
763
763
  >>> output.shape
764
764
  (20, 24, 43, 15, 16)
@@ -931,10 +931,10 @@ class AdaptiveAvgPool1d(_AdaptivePool):
931
931
 
932
932
  Examples:
933
933
 
934
- >>> import brainstate as bst
934
+ >>> import brainstate as brainstate
935
935
  >>> # target output size of 5
936
936
  >>> m = AdaptiveMaxPool1d(5)
937
- >>> input = bst.random.randn(1, 64, 8)
937
+ >>> input = brainstate.random.randn(1, 64, 8)
938
938
  >>> output = m(input)
939
939
  >>> output.shape
940
940
  (1, 5, 8)
@@ -979,22 +979,22 @@ class AdaptiveAvgPool2d(_AdaptivePool):
979
979
 
980
980
  Examples:
981
981
 
982
- >>> import brainstate as bst
982
+ >>> import brainstate as brainstate
983
983
  >>> # target output size of 5x7
984
984
  >>> m = AdaptiveMaxPool2d((5, 7))
985
- >>> input = bst.random.randn(1, 8, 9, 64)
985
+ >>> input = brainstate.random.randn(1, 8, 9, 64)
986
986
  >>> output = m(input)
987
987
  >>> output.shape
988
988
  (1, 5, 7, 64)
989
989
  >>> # target output size of 7x7 (square)
990
990
  >>> m = AdaptiveMaxPool2d(7)
991
- >>> input = bst.random.randn(1, 10, 9, 64)
991
+ >>> input = brainstate.random.randn(1, 10, 9, 64)
992
992
  >>> output = m(input)
993
993
  >>> output.shape
994
994
  (1, 7, 7, 64)
995
995
  >>> # target output size of 10x7
996
996
  >>> m = AdaptiveMaxPool2d((None, 7))
997
- >>> input = bst.random.randn(1, 10, 9, 64)
997
+ >>> input = brainstate.random.randn(1, 10, 9, 64)
998
998
  >>> output = m(input)
999
999
  >>> output.shape
1000
1000
  (1, 10, 7, 64)
@@ -1040,22 +1040,22 @@ class AdaptiveAvgPool3d(_AdaptivePool):
1040
1040
 
1041
1041
  Examples:
1042
1042
 
1043
- >>> import brainstate as bst
1043
+ >>> import brainstate as brainstate
1044
1044
  >>> # target output size of 5x7x9
1045
1045
  >>> m = AdaptiveMaxPool3d((5, 7, 9))
1046
- >>> input = bst.random.randn(1, 8, 9, 10, 64)
1046
+ >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
1047
1047
  >>> output = m(input)
1048
1048
  >>> output.shape
1049
1049
  (1, 5, 7, 9, 64)
1050
1050
  >>> # target output size of 7x7x7 (cube)
1051
1051
  >>> m = AdaptiveMaxPool3d(7)
1052
- >>> input = bst.random.randn(1, 10, 9, 8, 64)
1052
+ >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1053
1053
  >>> output = m(input)
1054
1054
  >>> output.shape
1055
1055
  (1, 7, 7, 7, 64)
1056
1056
  >>> # target output size of 7x9x8
1057
1057
  >>> m = AdaptiveMaxPool3d((7, None, None))
1058
- >>> input = bst.random.randn(1, 10, 9, 8, 64)
1058
+ >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1059
1059
  >>> output = m(input)
1060
1060
  >>> output.shape
1061
1061
  (1, 7, 9, 8, 64)
brainstate/nn/_synapse.py CHANGED
@@ -123,9 +123,6 @@ class Expon(Synapse, AlignPost):
123
123
  g = exp_euler_step(lambda g: self.sum_current_inputs(-g) / self.tau, self.g.value)
124
124
  self.g.value = self.sum_delta_inputs(g)
125
125
  if x is not None: self.g.value += x
126
- return self.update_return()
127
-
128
- def update_return(self) -> PyTree:
129
126
  return self.g.value
130
127
 
131
128
 
@@ -232,9 +229,6 @@ class DualExpon(Synapse, AlignPost):
232
229
  if x is not None:
233
230
  self.g_rise.value += x
234
231
  self.g_decay.value += x
235
- return self.update_return()
236
-
237
- def update_return(self) -> PyTree:
238
232
  return self.a * (self.g_decay.value - self.g_rise.value)
239
233
 
240
234
 
@@ -414,12 +408,8 @@ class AMPA(Synapse):
414
408
  t = environ.get('t')
415
409
  self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
416
410
  TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
417
- dg = lambda g: self.alpha * TT * (1 - g) - self.beta * g
411
+ dg = lambda g: self.alpha * TT * (1 * u.get_unit(g) - g) - self.beta * g
418
412
  self.g.value = exp_euler_step(dg, self.g.value)
419
- return self.update_return()
420
-
421
- def update_return(self) -> PyTree:
422
- """Return the synaptic conductance value."""
423
413
  return self.g.value
424
414
 
425
415
 
@@ -513,4 +503,3 @@ class GABAa(AMPA):
513
503
  in_size=in_size,
514
504
  g_initializer=g_initializer
515
505
  )
516
-
brainstate/nn/_utils.py CHANGED
@@ -62,7 +62,7 @@ def count_parameters(
62
62
 
63
63
  Parameters:
64
64
  -----------
65
- model : bst.nn.Module
65
+ model : brainstate.nn.Module
66
66
  The neural network model for which to count parameters.
67
67
 
68
68
  Returns: