brainstate 0.1.4__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_state.py +6 -5
- brainstate/augment/_autograd.py +31 -12
- brainstate/augment/_autograd_test.py +46 -46
- brainstate/augment/_eval_shape.py +4 -4
- brainstate/augment/_mapping.py +22 -17
- brainstate/augment/_mapping_test.py +162 -0
- brainstate/compile/_conditions.py +2 -2
- brainstate/compile/_make_jaxpr.py +59 -6
- brainstate/compile/_progress_bar.py +2 -2
- brainstate/environ.py +19 -19
- brainstate/functional/_activations_test.py +12 -12
- brainstate/graph/_graph_operation.py +69 -69
- brainstate/graph/_graph_operation_test.py +2 -2
- brainstate/mixin.py +0 -17
- brainstate/nn/_collective_ops.py +4 -4
- brainstate/nn/_common.py +7 -19
- brainstate/nn/_dropout_test.py +2 -2
- brainstate/nn/_dynamics.py +53 -35
- brainstate/nn/_elementwise.py +30 -30
- brainstate/nn/_exp_euler.py +13 -16
- brainstate/nn/_inputs.py +1 -1
- brainstate/nn/_linear.py +4 -4
- brainstate/nn/_module.py +6 -6
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +11 -11
- brainstate/nn/_normalizations_test.py +6 -6
- brainstate/nn/_poolings.py +24 -24
- brainstate/nn/_synapse.py +1 -12
- brainstate/nn/_utils.py +1 -1
- brainstate/nn/metrics.py +4 -4
- brainstate/optim/_optax_optimizer.py +8 -8
- brainstate/random/_rand_funs.py +37 -37
- brainstate/random/_rand_funs_test.py +3 -3
- brainstate/random/_rand_seed.py +7 -7
- brainstate/random/_rand_state.py +13 -7
- brainstate/surrogate.py +40 -40
- brainstate/util/pretty_pytree.py +10 -10
- brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
- brainstate/util/struct.py +7 -7
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/RECORD +45 -45
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
brainstate/nn/_elementwise.py
CHANGED
@@ -61,7 +61,7 @@ class Threshold(ElementWiseBlock):
|
|
61
61
|
Examples::
|
62
62
|
|
63
63
|
>>> import brainstate.nn as nn
|
64
|
-
>>> import brainstate
|
64
|
+
>>> import brainstate
|
65
65
|
>>> m = nn.Threshold(0.1, 20)
|
66
66
|
>>> x = random.randn(2)
|
67
67
|
>>> output = m(x)
|
@@ -97,7 +97,7 @@ class ReLU(ElementWiseBlock):
|
|
97
97
|
Examples::
|
98
98
|
|
99
99
|
>>> import brainstate.nn as nn
|
100
|
-
>>> import brainstate as
|
100
|
+
>>> import brainstate as brainstate
|
101
101
|
>>> m = nn.ReLU()
|
102
102
|
>>> x = random.randn(2)
|
103
103
|
>>> output = m(x)
|
@@ -106,7 +106,7 @@ class ReLU(ElementWiseBlock):
|
|
106
106
|
An implementation of CReLU - https://arxiv.org/abs/1603.05201
|
107
107
|
|
108
108
|
>>> import brainstate.nn as nn
|
109
|
-
>>> import brainstate as
|
109
|
+
>>> import brainstate as brainstate
|
110
110
|
>>> m = nn.ReLU()
|
111
111
|
>>> x = random.randn(2).unsqueeze(0)
|
112
112
|
>>> output = jax.numpy.concat((m(x), m(-x)))
|
@@ -151,7 +151,7 @@ class RReLU(ElementWiseBlock):
|
|
151
151
|
Examples::
|
152
152
|
|
153
153
|
>>> import brainstate.nn as nn
|
154
|
-
>>> import brainstate as
|
154
|
+
>>> import brainstate as brainstate
|
155
155
|
>>> m = nn.RReLU(0.1, 0.3)
|
156
156
|
>>> x = random.randn(2)
|
157
157
|
>>> output = m(x)
|
@@ -205,7 +205,7 @@ class Hardtanh(ElementWiseBlock):
|
|
205
205
|
Examples::
|
206
206
|
|
207
207
|
>>> import brainstate.nn as nn
|
208
|
-
>>> import brainstate as
|
208
|
+
>>> import brainstate as brainstate
|
209
209
|
>>> m = nn.Hardtanh(-2, 2)
|
210
210
|
>>> x = random.randn(2)
|
211
211
|
>>> output = m(x)
|
@@ -244,7 +244,7 @@ class ReLU6(Hardtanh, ElementWiseBlock):
|
|
244
244
|
Examples::
|
245
245
|
|
246
246
|
>>> import brainstate.nn as nn
|
247
|
-
>>> import brainstate as
|
247
|
+
>>> import brainstate as brainstate
|
248
248
|
>>> m = nn.ReLU6()
|
249
249
|
>>> x = random.randn(2)
|
250
250
|
>>> output = m(x)
|
@@ -269,7 +269,7 @@ class Sigmoid(ElementWiseBlock):
|
|
269
269
|
Examples::
|
270
270
|
|
271
271
|
>>> import brainstate.nn as nn
|
272
|
-
>>> import brainstate as
|
272
|
+
>>> import brainstate as brainstate
|
273
273
|
>>> m = nn.Sigmoid()
|
274
274
|
>>> x = random.randn(2)
|
275
275
|
>>> output = m(x)
|
@@ -299,7 +299,7 @@ class Hardsigmoid(ElementWiseBlock):
|
|
299
299
|
Examples::
|
300
300
|
|
301
301
|
>>> import brainstate.nn as nn
|
302
|
-
>>> import brainstate as
|
302
|
+
>>> import brainstate as brainstate
|
303
303
|
>>> m = nn.Hardsigmoid()
|
304
304
|
>>> x = random.randn(2)
|
305
305
|
>>> output = m(x)
|
@@ -325,7 +325,7 @@ class Tanh(ElementWiseBlock):
|
|
325
325
|
Examples::
|
326
326
|
|
327
327
|
>>> import brainstate.nn as nn
|
328
|
-
>>> import brainstate as
|
328
|
+
>>> import brainstate as brainstate
|
329
329
|
>>> m = nn.Tanh()
|
330
330
|
>>> x = random.randn(2)
|
331
331
|
>>> output = m(x)
|
@@ -386,7 +386,7 @@ class Mish(ElementWiseBlock):
|
|
386
386
|
Examples::
|
387
387
|
|
388
388
|
>>> import brainstate.nn as nn
|
389
|
-
>>> import brainstate as
|
389
|
+
>>> import brainstate as brainstate
|
390
390
|
>>> m = nn.Mish()
|
391
391
|
>>> x = random.randn(2)
|
392
392
|
>>> output = m(x)
|
@@ -418,7 +418,7 @@ class Hardswish(ElementWiseBlock):
|
|
418
418
|
Examples::
|
419
419
|
|
420
420
|
>>> import brainstate.nn as nn
|
421
|
-
>>> import brainstate as
|
421
|
+
>>> import brainstate as brainstate
|
422
422
|
>>> m = nn.Hardswish()
|
423
423
|
>>> x = random.randn(2)
|
424
424
|
>>> output = m(x)
|
@@ -452,7 +452,7 @@ class ELU(ElementWiseBlock):
|
|
452
452
|
Examples::
|
453
453
|
|
454
454
|
>>> import brainstate.nn as nn
|
455
|
-
>>> import brainstate as
|
455
|
+
>>> import brainstate as brainstate
|
456
456
|
>>> m = nn.ELU()
|
457
457
|
>>> x = random.randn(2)
|
458
458
|
>>> output = m(x)
|
@@ -489,7 +489,7 @@ class CELU(ElementWiseBlock):
|
|
489
489
|
Examples::
|
490
490
|
|
491
491
|
>>> import brainstate.nn as nn
|
492
|
-
>>> import brainstate as
|
492
|
+
>>> import brainstate as brainstate
|
493
493
|
>>> m = nn.CELU()
|
494
494
|
>>> x = random.randn(2)
|
495
495
|
>>> output = m(x)
|
@@ -530,7 +530,7 @@ class SELU(ElementWiseBlock):
|
|
530
530
|
Examples::
|
531
531
|
|
532
532
|
>>> import brainstate.nn as nn
|
533
|
-
>>> import brainstate as
|
533
|
+
>>> import brainstate as brainstate
|
534
534
|
>>> m = nn.SELU()
|
535
535
|
>>> x = random.randn(2)
|
536
536
|
>>> output = m(x)
|
@@ -559,7 +559,7 @@ class GLU(ElementWiseBlock):
|
|
559
559
|
Examples::
|
560
560
|
|
561
561
|
>>> import brainstate.nn as nn
|
562
|
-
>>> import brainstate as
|
562
|
+
>>> import brainstate as brainstate
|
563
563
|
>>> m = nn.GLU()
|
564
564
|
>>> x = random.randn(4, 2)
|
565
565
|
>>> output = m(x)
|
@@ -600,7 +600,7 @@ class GELU(ElementWiseBlock):
|
|
600
600
|
Examples::
|
601
601
|
|
602
602
|
>>> import brainstate.nn as nn
|
603
|
-
>>> import brainstate as
|
603
|
+
>>> import brainstate as brainstate
|
604
604
|
>>> m = nn.GELU()
|
605
605
|
>>> x = random.randn(2)
|
606
606
|
>>> output = m(x)
|
@@ -642,7 +642,7 @@ class Hardshrink(ElementWiseBlock):
|
|
642
642
|
Examples::
|
643
643
|
|
644
644
|
>>> import brainstate.nn as nn
|
645
|
-
>>> import brainstate as
|
645
|
+
>>> import brainstate as brainstate
|
646
646
|
>>> m = nn.Hardshrink()
|
647
647
|
>>> x = random.randn(2)
|
648
648
|
>>> output = m(x)
|
@@ -689,7 +689,7 @@ class LeakyReLU(ElementWiseBlock):
|
|
689
689
|
Examples::
|
690
690
|
|
691
691
|
>>> import brainstate.nn as nn
|
692
|
-
>>> import brainstate as
|
692
|
+
>>> import brainstate as brainstate
|
693
693
|
>>> m = nn.LeakyReLU(0.1)
|
694
694
|
>>> x = random.randn(2)
|
695
695
|
>>> output = m(x)
|
@@ -721,7 +721,7 @@ class LogSigmoid(ElementWiseBlock):
|
|
721
721
|
Examples::
|
722
722
|
|
723
723
|
>>> import brainstate.nn as nn
|
724
|
-
>>> import brainstate as
|
724
|
+
>>> import brainstate as brainstate
|
725
725
|
>>> m = nn.LogSigmoid()
|
726
726
|
>>> x = random.randn(2)
|
727
727
|
>>> output = m(x)
|
@@ -749,7 +749,7 @@ class Softplus(ElementWiseBlock):
|
|
749
749
|
Examples::
|
750
750
|
|
751
751
|
>>> import brainstate.nn as nn
|
752
|
-
>>> import brainstate as
|
752
|
+
>>> import brainstate as brainstate
|
753
753
|
>>> m = nn.Softplus()
|
754
754
|
>>> x = random.randn(2)
|
755
755
|
>>> output = m(x)
|
@@ -781,7 +781,7 @@ class Softshrink(ElementWiseBlock):
|
|
781
781
|
Examples::
|
782
782
|
|
783
783
|
>>> import brainstate.nn as nn
|
784
|
-
>>> import brainstate as
|
784
|
+
>>> import brainstate as brainstate
|
785
785
|
>>> m = nn.Softshrink()
|
786
786
|
>>> x = random.randn(2)
|
787
787
|
>>> output = m(x)
|
@@ -843,9 +843,9 @@ class PReLU(ElementWiseBlock):
|
|
843
843
|
|
844
844
|
Examples::
|
845
845
|
|
846
|
-
>>> import brainstate as
|
847
|
-
>>> m =
|
848
|
-
>>> x =
|
846
|
+
>>> import brainstate as brainstate
|
847
|
+
>>> m = brainstate.nn.PReLU()
|
848
|
+
>>> x = brainstate.random.randn(2)
|
849
849
|
>>> output = m(x)
|
850
850
|
"""
|
851
851
|
__module__ = 'brainstate.nn'
|
@@ -876,7 +876,7 @@ class Softsign(ElementWiseBlock):
|
|
876
876
|
Examples::
|
877
877
|
|
878
878
|
>>> import brainstate.nn as nn
|
879
|
-
>>> import brainstate as
|
879
|
+
>>> import brainstate as brainstate
|
880
880
|
>>> m = nn.Softsign()
|
881
881
|
>>> x = random.randn(2)
|
882
882
|
>>> output = m(x)
|
@@ -900,7 +900,7 @@ class Tanhshrink(ElementWiseBlock):
|
|
900
900
|
Examples::
|
901
901
|
|
902
902
|
>>> import brainstate.nn as nn
|
903
|
-
>>> import brainstate as
|
903
|
+
>>> import brainstate as brainstate
|
904
904
|
>>> m = nn.Tanhshrink()
|
905
905
|
>>> x = random.randn(2)
|
906
906
|
>>> output = m(x)
|
@@ -937,7 +937,7 @@ class Softmin(ElementWiseBlock):
|
|
937
937
|
Examples::
|
938
938
|
|
939
939
|
>>> import brainstate.nn as nn
|
940
|
-
>>> import brainstate as
|
940
|
+
>>> import brainstate as brainstate
|
941
941
|
>>> m = nn.Softmin(dim=1)
|
942
942
|
>>> x = random.randn(2, 3)
|
943
943
|
>>> output = m(x)
|
@@ -990,7 +990,7 @@ class Softmax(ElementWiseBlock):
|
|
990
990
|
Examples::
|
991
991
|
|
992
992
|
>>> import brainstate.nn as nn
|
993
|
-
>>> import brainstate as
|
993
|
+
>>> import brainstate as brainstate
|
994
994
|
>>> m = nn.Softmax(dim=1)
|
995
995
|
>>> x = random.randn(2, 3)
|
996
996
|
>>> output = m(x)
|
@@ -1027,7 +1027,7 @@ class Softmax2d(ElementWiseBlock):
|
|
1027
1027
|
Examples::
|
1028
1028
|
|
1029
1029
|
>>> import brainstate.nn as nn
|
1030
|
-
>>> import brainstate as
|
1030
|
+
>>> import brainstate as brainstate
|
1031
1031
|
>>> m = nn.Softmax2d()
|
1032
1032
|
>>> # you softmax over the 2nd dimension
|
1033
1033
|
>>> x = random.randn(2, 3, 12, 13)
|
@@ -1062,7 +1062,7 @@ class LogSoftmax(ElementWiseBlock):
|
|
1062
1062
|
Examples::
|
1063
1063
|
|
1064
1064
|
>>> import brainstate.nn as nn
|
1065
|
-
>>> import brainstate as
|
1065
|
+
>>> import brainstate as brainstate
|
1066
1066
|
>>> m = nn.LogSoftmax(dim=1)
|
1067
1067
|
>>> x = random.randn(2, 3)
|
1068
1068
|
>>> output = m(x)
|
brainstate/nn/_exp_euler.py
CHANGED
@@ -69,27 +69,24 @@ def exp_euler_step(
|
|
69
69
|
f'The input data type should be float64, float32, float16, or bfloat16 '
|
70
70
|
f'when using Exponential Euler method. But we got {args[0].dtype}.'
|
71
71
|
)
|
72
|
+
|
73
|
+
# drift
|
72
74
|
dt = environ.get('dt')
|
73
75
|
linear, derivative = vector_grad(fn, argnums=0, return_value=True)(*args, **kwargs)
|
74
76
|
linear = u.Quantity(u.get_mantissa(linear), u.get_unit(derivative) / u.get_unit(linear))
|
75
77
|
phi = u.math.exprel(dt * linear)
|
76
78
|
x_next = args[0] + dt * phi * derivative
|
77
79
|
|
80
|
+
# diffusion
|
78
81
|
if diffusion is not None:
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
# "Drift unit is {drift}, diffusion unit is {diffusion}, ",
|
90
|
-
# drift=drift_unit, diffusion=diffusion_unit * time_unit ** 0.5
|
91
|
-
# )
|
92
|
-
|
93
|
-
# diffusion
|
94
|
-
x_next += diffusion * u.math.sqrt(dt) * random.randn_like(args[0])
|
82
|
+
diffusion_part = diffusion(*args, **kwargs) * u.math.sqrt(dt) * random.randn_like(args[0])
|
83
|
+
if u.get_dim(x_next) != u.get_dim(diffusion_part):
|
84
|
+
drift_unit = u.get_unit(x_next)
|
85
|
+
time_unit = u.get_unit(dt)
|
86
|
+
raise ValueError(
|
87
|
+
f"Drift unit is {drift_unit}, "
|
88
|
+
f"expected diffusion unit is {drift_unit / time_unit ** 0.5}, "
|
89
|
+
f"but we got {u.get_unit(diffusion_part)}."
|
90
|
+
)
|
91
|
+
x_next += diffusion_part
|
95
92
|
return x_next
|
brainstate/nn/_inputs.py
CHANGED
brainstate/nn/_linear.py
CHANGED
@@ -361,9 +361,9 @@ class LoRA(Module):
|
|
361
361
|
|
362
362
|
Example usage::
|
363
363
|
|
364
|
-
>>> import brainstate as
|
364
|
+
>>> import brainstate as brainstate
|
365
365
|
>>> import jax, jax.numpy as jnp
|
366
|
-
>>> layer =
|
366
|
+
>>> layer = brainstate.nn.LoRA(3, 2, 4)
|
367
367
|
>>> layer.weight.value
|
368
368
|
{'lora_a': Array([[ 0.25141352, -0.09826107],
|
369
369
|
[ 0.2328382 , 0.38869813],
|
@@ -371,8 +371,8 @@ class LoRA(Module):
|
|
371
371
|
'lora_b': Array([[-0.8372317 , 0.21012013, -0.52999765, -0.31939325],
|
372
372
|
[ 0.64234126, -0.42980042, 1.2549229 , -0.47134295]], dtype=float32)}
|
373
373
|
>>> # Wrap around existing layer
|
374
|
-
>>> linear =
|
375
|
-
>>> wrapper =
|
374
|
+
>>> linear = brainstate.nn.Linear(3, 4)
|
375
|
+
>>> wrapper = brainstate.nn.LoRA(3, 2, 4, base_module=linear)
|
376
376
|
>>> assert wrapper.base_module == linear
|
377
377
|
>>> y = layer(jnp.ones((16, 3)))
|
378
378
|
>>> y.shape
|
brainstate/nn/_module.py
CHANGED
@@ -128,9 +128,9 @@ class Module(Node, ParamDesc):
|
|
128
128
|
Examples
|
129
129
|
--------
|
130
130
|
|
131
|
-
>>> import brainstate as
|
132
|
-
>>> x =
|
133
|
-
>>> l =
|
131
|
+
>>> import brainstate as brainstate
|
132
|
+
>>> x = brainstate.random.rand((10, 10))
|
133
|
+
>>> l = brainstate.nn.Dropout(0.5)
|
134
134
|
>>> y = x >> l
|
135
135
|
"""
|
136
136
|
return self.__call__(other)
|
@@ -230,7 +230,7 @@ class Module(Node, ParamDesc):
|
|
230
230
|
pass
|
231
231
|
|
232
232
|
def __pretty_repr_item__(self, name, value):
|
233
|
-
if name
|
233
|
+
if name.startswith('_'):
|
234
234
|
return None if value is None else (name[1:], value) # skip the first `_`
|
235
235
|
return name, value
|
236
236
|
|
@@ -266,14 +266,14 @@ class Sequential(Module):
|
|
266
266
|
--------
|
267
267
|
|
268
268
|
>>> import jax
|
269
|
-
>>> import brainstate as
|
269
|
+
>>> import brainstate as brainstate
|
270
270
|
>>> import brainstate.nn as nn
|
271
271
|
>>>
|
272
272
|
>>> # composing ANN models
|
273
273
|
>>> l = nn.Sequential(nn.Linear(100, 10),
|
274
274
|
>>> jax.nn.relu,
|
275
275
|
>>> nn.Linear(10, 2))
|
276
|
-
>>> l(
|
276
|
+
>>> l(brainstate.random.random((256, 100)))
|
277
277
|
|
278
278
|
Args:
|
279
279
|
modules_as_tuple: The children modules.
|
brainstate/nn/_module_test.py
CHANGED
@@ -77,7 +77,7 @@ class TestDelay(unittest.TestCase):
|
|
77
77
|
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
78
78
|
self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
|
79
79
|
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
80
|
-
#
|
80
|
+
# brainstate.util.clear_buffer_memory()
|
81
81
|
|
82
82
|
def test_jit_erro(self):
|
83
83
|
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
|
brainstate/nn/_normalizations.py
CHANGED
@@ -488,9 +488,9 @@ class LayerNorm(Module):
|
|
488
488
|
|
489
489
|
Example usage::
|
490
490
|
|
491
|
-
>>> import brainstate as
|
492
|
-
>>> x =
|
493
|
-
>>> layer =
|
491
|
+
>>> import brainstate as brainstate
|
492
|
+
>>> x = brainstate.random.normal(size=(3, 4, 5, 6))
|
493
|
+
>>> layer = brainstate.nn.LayerNorm(x.shape)
|
494
494
|
>>> layer.states()
|
495
495
|
>>> y = layer(x)
|
496
496
|
|
@@ -616,9 +616,9 @@ class RMSNorm(Module):
|
|
616
616
|
|
617
617
|
Example usage::
|
618
618
|
|
619
|
-
>>> import brainstate as
|
620
|
-
>>> x =
|
621
|
-
>>> layer =
|
619
|
+
>>> import brainstate as brainstate
|
620
|
+
>>> x = brainstate.random.normal(size=(5, 6))
|
621
|
+
>>> layer = brainstate.nn.RMSNorm(num_features=6)
|
622
622
|
>>> layer.states()
|
623
623
|
>>> y = layer(x)
|
624
624
|
|
@@ -739,14 +739,14 @@ class GroupNorm(Module):
|
|
739
739
|
Example usage::
|
740
740
|
|
741
741
|
>>> import numpy as np
|
742
|
-
>>> import brainstate as
|
742
|
+
>>> import brainstate as brainstate
|
743
743
|
...
|
744
|
-
>>> x =
|
745
|
-
>>> layer =
|
744
|
+
>>> x = brainstate.random.normal(size=(3, 4, 5, 6))
|
745
|
+
>>> layer = brainstate.nn.GroupNorm(x.shape, num_groups=3)
|
746
746
|
>>> layer.states()
|
747
747
|
>>> y = layer(x)
|
748
|
-
>>> y =
|
749
|
-
>>> y2 =
|
748
|
+
>>> y = brainstate.nn.GroupNorm(x.shape, num_groups=1)(x)
|
749
|
+
>>> y2 = brainstate.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
|
750
750
|
>>> np.testing.assert_allclose(y, y2)
|
751
751
|
|
752
752
|
Attributes:
|
@@ -51,21 +51,21 @@ class Test_Normalization(parameterized.TestCase):
|
|
51
51
|
# normalized_shape=(10, [5, 10])
|
52
52
|
# )
|
53
53
|
# def test_LayerNorm(self, normalized_shape):
|
54
|
-
# net =
|
55
|
-
# input =
|
54
|
+
# net = brainstate.nn.LayerNorm(normalized_shape, )
|
55
|
+
# input = brainstate.random.randn(20, 5, 10)
|
56
56
|
# output = net(input)
|
57
57
|
#
|
58
58
|
# @parameterized.product(
|
59
59
|
# num_groups=[1, 2, 3, 6]
|
60
60
|
# )
|
61
61
|
# def test_GroupNorm(self, num_groups):
|
62
|
-
# input =
|
63
|
-
# net =
|
62
|
+
# input = brainstate.random.randn(20, 10, 10, 6)
|
63
|
+
# net = brainstate.nn.GroupNorm(num_groups=num_groups, num_channels=6, )
|
64
64
|
# output = net(input)
|
65
65
|
#
|
66
66
|
# def test_InstanceNorm(self):
|
67
|
-
# input =
|
68
|
-
# net =
|
67
|
+
# input = brainstate.random.randn(20, 10, 10, 6)
|
68
|
+
# net = brainstate.nn.InstanceNorm(num_channels=6, )
|
69
69
|
# output = net(input)
|
70
70
|
|
71
71
|
|
brainstate/nn/_poolings.py
CHANGED
@@ -55,8 +55,8 @@ class Flatten(Module):
|
|
55
55
|
end_axis: last dim to flatten (default = -1).
|
56
56
|
|
57
57
|
Examples::
|
58
|
-
>>> import brainstate as
|
59
|
-
>>> inp =
|
58
|
+
>>> import brainstate as brainstate
|
59
|
+
>>> inp = brainstate.random.randn(32, 1, 5, 5)
|
60
60
|
>>> # With default parameters
|
61
61
|
>>> m = Flatten()
|
62
62
|
>>> output = m(inp)
|
@@ -334,10 +334,10 @@ class MaxPool1d(_MaxPool):
|
|
334
334
|
|
335
335
|
Examples::
|
336
336
|
|
337
|
-
>>> import brainstate as
|
337
|
+
>>> import brainstate as brainstate
|
338
338
|
>>> # pool of size=3, stride=2
|
339
339
|
>>> m = MaxPool1d(3, stride=2, channel_axis=-1)
|
340
|
-
>>> input =
|
340
|
+
>>> input = brainstate.random.randn(20, 50, 16)
|
341
341
|
>>> output = m(input)
|
342
342
|
>>> output.shape
|
343
343
|
(20, 24, 16)
|
@@ -418,12 +418,12 @@ class MaxPool2d(_MaxPool):
|
|
418
418
|
|
419
419
|
Examples::
|
420
420
|
|
421
|
-
>>> import brainstate as
|
421
|
+
>>> import brainstate as brainstate
|
422
422
|
>>> # pool of square window of size=3, stride=2
|
423
423
|
>>> m = MaxPool2d(3, stride=2)
|
424
424
|
>>> # pool of non-square window
|
425
425
|
>>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
|
426
|
-
>>> input =
|
426
|
+
>>> input = brainstate.random.randn(20, 50, 32, 16)
|
427
427
|
>>> output = m(input)
|
428
428
|
>>> output.shape
|
429
429
|
(20, 24, 31, 16)
|
@@ -509,12 +509,12 @@ class MaxPool3d(_MaxPool):
|
|
509
509
|
|
510
510
|
Examples::
|
511
511
|
|
512
|
-
>>> import brainstate as
|
512
|
+
>>> import brainstate as brainstate
|
513
513
|
>>> # pool of square window of size=3, stride=2
|
514
514
|
>>> m = MaxPool3d(3, stride=2)
|
515
515
|
>>> # pool of non-square window
|
516
516
|
>>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
|
517
|
-
>>> input =
|
517
|
+
>>> input = brainstate.random.randn(20, 50, 44, 31, 16)
|
518
518
|
>>> output = m(input)
|
519
519
|
>>> output.shape
|
520
520
|
(20, 24, 43, 15, 16)
|
@@ -588,10 +588,10 @@ class AvgPool1d(_AvgPool):
|
|
588
588
|
|
589
589
|
Examples::
|
590
590
|
|
591
|
-
>>> import brainstate as
|
591
|
+
>>> import brainstate as brainstate
|
592
592
|
>>> # pool with window of size=3, stride=2
|
593
593
|
>>> m = AvgPool1d(3, stride=2)
|
594
|
-
>>> input =
|
594
|
+
>>> input = brainstate.random.randn(20, 50, 16)
|
595
595
|
>>> m(input).shape
|
596
596
|
(20, 24, 16)
|
597
597
|
|
@@ -665,12 +665,12 @@ class AvgPool2d(_AvgPool):
|
|
665
665
|
|
666
666
|
Examples::
|
667
667
|
|
668
|
-
>>> import brainstate as
|
668
|
+
>>> import brainstate as brainstate
|
669
669
|
>>> # pool of square window of size=3, stride=2
|
670
670
|
>>> m = AvgPool2d(3, stride=2)
|
671
671
|
>>> # pool of non-square window
|
672
672
|
>>> m = AvgPool2d((3, 2), stride=(2, 1))
|
673
|
-
>>> input =
|
673
|
+
>>> input = brainstate.random.randn(20, 50, 32, , 16)
|
674
674
|
>>> output = m(input)
|
675
675
|
>>> output.shape
|
676
676
|
(20, 24, 31, 16)
|
@@ -753,12 +753,12 @@ class AvgPool3d(_AvgPool):
|
|
753
753
|
|
754
754
|
Examples::
|
755
755
|
|
756
|
-
>>> import brainstate as
|
756
|
+
>>> import brainstate as brainstate
|
757
757
|
>>> # pool of square window of size=3, stride=2
|
758
758
|
>>> m = AvgPool3d(3, stride=2)
|
759
759
|
>>> # pool of non-square window
|
760
760
|
>>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
|
761
|
-
>>> input =
|
761
|
+
>>> input = brainstate.random.randn(20, 50, 44, 31, 16)
|
762
762
|
>>> output = m(input)
|
763
763
|
>>> output.shape
|
764
764
|
(20, 24, 43, 15, 16)
|
@@ -931,10 +931,10 @@ class AdaptiveAvgPool1d(_AdaptivePool):
|
|
931
931
|
|
932
932
|
Examples:
|
933
933
|
|
934
|
-
>>> import brainstate as
|
934
|
+
>>> import brainstate as brainstate
|
935
935
|
>>> # target output size of 5
|
936
936
|
>>> m = AdaptiveMaxPool1d(5)
|
937
|
-
>>> input =
|
937
|
+
>>> input = brainstate.random.randn(1, 64, 8)
|
938
938
|
>>> output = m(input)
|
939
939
|
>>> output.shape
|
940
940
|
(1, 5, 8)
|
@@ -979,22 +979,22 @@ class AdaptiveAvgPool2d(_AdaptivePool):
|
|
979
979
|
|
980
980
|
Examples:
|
981
981
|
|
982
|
-
>>> import brainstate as
|
982
|
+
>>> import brainstate as brainstate
|
983
983
|
>>> # target output size of 5x7
|
984
984
|
>>> m = AdaptiveMaxPool2d((5, 7))
|
985
|
-
>>> input =
|
985
|
+
>>> input = brainstate.random.randn(1, 8, 9, 64)
|
986
986
|
>>> output = m(input)
|
987
987
|
>>> output.shape
|
988
988
|
(1, 5, 7, 64)
|
989
989
|
>>> # target output size of 7x7 (square)
|
990
990
|
>>> m = AdaptiveMaxPool2d(7)
|
991
|
-
>>> input =
|
991
|
+
>>> input = brainstate.random.randn(1, 10, 9, 64)
|
992
992
|
>>> output = m(input)
|
993
993
|
>>> output.shape
|
994
994
|
(1, 7, 7, 64)
|
995
995
|
>>> # target output size of 10x7
|
996
996
|
>>> m = AdaptiveMaxPool2d((None, 7))
|
997
|
-
>>> input =
|
997
|
+
>>> input = brainstate.random.randn(1, 10, 9, 64)
|
998
998
|
>>> output = m(input)
|
999
999
|
>>> output.shape
|
1000
1000
|
(1, 10, 7, 64)
|
@@ -1040,22 +1040,22 @@ class AdaptiveAvgPool3d(_AdaptivePool):
|
|
1040
1040
|
|
1041
1041
|
Examples:
|
1042
1042
|
|
1043
|
-
>>> import brainstate as
|
1043
|
+
>>> import brainstate as brainstate
|
1044
1044
|
>>> # target output size of 5x7x9
|
1045
1045
|
>>> m = AdaptiveMaxPool3d((5, 7, 9))
|
1046
|
-
>>> input =
|
1046
|
+
>>> input = brainstate.random.randn(1, 8, 9, 10, 64)
|
1047
1047
|
>>> output = m(input)
|
1048
1048
|
>>> output.shape
|
1049
1049
|
(1, 5, 7, 9, 64)
|
1050
1050
|
>>> # target output size of 7x7x7 (cube)
|
1051
1051
|
>>> m = AdaptiveMaxPool3d(7)
|
1052
|
-
>>> input =
|
1052
|
+
>>> input = brainstate.random.randn(1, 10, 9, 8, 64)
|
1053
1053
|
>>> output = m(input)
|
1054
1054
|
>>> output.shape
|
1055
1055
|
(1, 7, 7, 7, 64)
|
1056
1056
|
>>> # target output size of 7x9x8
|
1057
1057
|
>>> m = AdaptiveMaxPool3d((7, None, None))
|
1058
|
-
>>> input =
|
1058
|
+
>>> input = brainstate.random.randn(1, 10, 9, 8, 64)
|
1059
1059
|
>>> output = m(input)
|
1060
1060
|
>>> output.shape
|
1061
1061
|
(1, 7, 9, 8, 64)
|
brainstate/nn/_synapse.py
CHANGED
@@ -123,9 +123,6 @@ class Expon(Synapse, AlignPost):
|
|
123
123
|
g = exp_euler_step(lambda g: self.sum_current_inputs(-g) / self.tau, self.g.value)
|
124
124
|
self.g.value = self.sum_delta_inputs(g)
|
125
125
|
if x is not None: self.g.value += x
|
126
|
-
return self.update_return()
|
127
|
-
|
128
|
-
def update_return(self) -> PyTree:
|
129
126
|
return self.g.value
|
130
127
|
|
131
128
|
|
@@ -232,9 +229,6 @@ class DualExpon(Synapse, AlignPost):
|
|
232
229
|
if x is not None:
|
233
230
|
self.g_rise.value += x
|
234
231
|
self.g_decay.value += x
|
235
|
-
return self.update_return()
|
236
|
-
|
237
|
-
def update_return(self) -> PyTree:
|
238
232
|
return self.a * (self.g_decay.value - self.g_rise.value)
|
239
233
|
|
240
234
|
|
@@ -414,12 +408,8 @@ class AMPA(Synapse):
|
|
414
408
|
t = environ.get('t')
|
415
409
|
self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
|
416
410
|
TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
|
417
|
-
dg = lambda g: self.alpha * TT * (1 - g) - self.beta * g
|
411
|
+
dg = lambda g: self.alpha * TT * (1 * u.get_unit(g) - g) - self.beta * g
|
418
412
|
self.g.value = exp_euler_step(dg, self.g.value)
|
419
|
-
return self.update_return()
|
420
|
-
|
421
|
-
def update_return(self) -> PyTree:
|
422
|
-
"""Return the synaptic conductance value."""
|
423
413
|
return self.g.value
|
424
414
|
|
425
415
|
|
@@ -513,4 +503,3 @@ class GABAa(AMPA):
|
|
513
503
|
in_size=in_size,
|
514
504
|
g_initializer=g_initializer
|
515
505
|
)
|
516
|
-
|