brainstate 0.1.5__py2.py3-none-any.whl → 0.1.7__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_state.py +5 -5
- brainstate/augment/_autograd.py +31 -12
- brainstate/augment/_autograd_test.py +46 -46
- brainstate/augment/_eval_shape.py +4 -4
- brainstate/augment/_mapping.py +13 -8
- brainstate/compile/_conditions.py +2 -2
- brainstate/compile/_make_jaxpr.py +48 -6
- brainstate/compile/_progress_bar.py +2 -2
- brainstate/environ.py +19 -19
- brainstate/functional/_activations_test.py +12 -12
- brainstate/graph/_graph_operation.py +69 -69
- brainstate/graph/_graph_operation_test.py +2 -2
- brainstate/mixin.py +0 -17
- brainstate/nn/_collective_ops.py +4 -4
- brainstate/nn/_dropout_test.py +2 -2
- brainstate/nn/_dynamics.py +53 -35
- brainstate/nn/_elementwise.py +30 -30
- brainstate/nn/_linear.py +4 -4
- brainstate/nn/_module.py +6 -6
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +11 -11
- brainstate/nn/_normalizations_test.py +6 -6
- brainstate/nn/_poolings.py +24 -24
- brainstate/nn/_synapse.py +1 -12
- brainstate/nn/_utils.py +1 -1
- brainstate/nn/metrics.py +4 -4
- brainstate/optim/_optax_optimizer.py +8 -8
- brainstate/random/_rand_funs.py +37 -37
- brainstate/random/_rand_funs_test.py +3 -3
- brainstate/random/_rand_seed.py +7 -7
- brainstate/surrogate.py +40 -40
- brainstate/util/pretty_pytree.py +10 -10
- brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
- brainstate/util/struct.py +7 -7
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/METADATA +12 -12
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/RECORD +40 -40
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/WHEEL +1 -1
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/LICENSE +0 -0
- {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/top_level.txt +0 -0
brainstate/nn/_dynamics.py
CHANGED
@@ -42,7 +42,7 @@ import numpy as np
 from brainstate import environ
 from brainstate._state import State
 from brainstate.graph import Node
-from brainstate.mixin import ParamDescriber
+from brainstate.mixin import ParamDescriber
 from brainstate.typing import Size, ArrayLike, PyTree
 from ._delay import StateWithDelay, Delay
 from ._module import Module
@@ -101,7 +101,7 @@ class Projection(Module):
         raise ValueError('Do not implement the update() function.')
 
 
-class Dynamics(Module, UpdateReturn):
+class Dynamics(Module):
     """
     Base class for implementing neural dynamics models in BrainState.
 
@@ -214,13 +214,13 @@ class Dynamics(Module, UpdateReturn):
         # in-/out- size of neuron population
         self.out_size = self.in_size
 
-    def __pretty_repr_item__(self, name, value):
-        if name in [
-            '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
-            '_in_size', '_out_size', '_name', '_mode',
-        ]:
-            return (name, value) if value is None else (name[1:], value)  # skip the first `_`
-        return super().__pretty_repr_item__(name, value)
+    # def __pretty_repr_item__(self, name, value):
+    #     if name in [
+    #         '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
+    #         '_in_size', '_out_size', '_name', '_mode',
+    #     ]:
+    #         return (name, value) if value is None else (name[1:], value)  # skip the first `_`
+    #     return super().__pretty_repr_item__(name, value)
 
     @property
     def varshape(self):
@@ -470,21 +470,30 @@ class Dynamics(Module, UpdateReturn):
         if self._current_inputs is None:
             return init
         if label is None:
-
-            for key in tuple(self._current_inputs.keys()):
-                out = self._current_inputs[key]
-                init = init + (out(*args, **kwargs) if callable(out) else out)
-                if not callable(out):
-                    self._current_inputs.pop(key)
+            filter_fn = lambda k: True
         else:
-            # has label
             label_repr = _input_label_start(label)
-
-
-
-
-
-
+            filter_fn = lambda k: k.startswith(label_repr)
+        for key in tuple(self._current_inputs.keys()):
+            if filter_fn(key):
+                out = self._current_inputs[key]
+                if callable(out):
+                    try:
+                        init = init + out(*args, **kwargs)
+                    except Exception as e:
+                        raise ValueError(
+                            f'Error in delta input value {key}: {out}\n'
+                            f'Error: {e}'
+                        ) from e
+                else:
+                    try:
+                        init = init + out
+                    except Exception as e:
+                        raise ValueError(
+                            f'Error in delta input value {key}: {out}\n'
+                            f'Error: {e}'
+                        ) from e
+                    self._current_inputs.pop(key)
         return init
 
     def sum_delta_inputs(
@@ -529,21 +538,30 @@ class Dynamics(Module, UpdateReturn):
         if self._delta_inputs is None:
            return init
         if label is None:
-
-            for key in tuple(self._delta_inputs.keys()):
-                out = self._delta_inputs[key]
-                init = init + (out(*args, **kwargs) if callable(out) else out)
-                if not callable(out):
-                    self._delta_inputs.pop(key)
+            filter_fn = lambda k: True
         else:
-            # has label
             label_repr = _input_label_start(label)
-
-
-
-
-
-
+            filter_fn = lambda k: k.startswith(label_repr)
+        for key in tuple(self._delta_inputs.keys()):
+            if filter_fn(key):
+                out = self._delta_inputs[key]
+                if callable(out):
+                    try:
+                        init = init + out(*args, **kwargs)
+                    except Exception as e:
+                        raise ValueError(
+                            f'Error in delta input function {key}: {out}\n'
+                            f'Error: {e}'
+                        ) from e
+                else:
+                    try:
+                        init = init + out
+                    except Exception as e:
+                        raise ValueError(
+                            f'Error in delta input value {key}: {out}\n'
+                            f'Error: {e}'
+                        ) from e
+                    self._delta_inputs.pop(key)
         return init
 
     @property
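In both `sum_current_inputs` and `sum_delta_inputs`, 0.1.7 replaces the duplicated label/no-label loops with a single loop guarded by a `filter_fn`, and wraps every addition in a `try/except` that names the offending input key. A minimal standalone sketch of that pattern follows; the `sum_inputs` helper and the `'<label> // '` key prefix are illustrative assumptions, not the actual `Dynamics` API or the real `_input_label_start` format:

```python
from typing import Any, Callable, Dict, Optional

def sum_inputs(inputs: Dict[str, Any], init, *args, label: Optional[str] = None, **kwargs):
    # With no label every key passes; with a label only matching key prefixes pass.
    if label is None:
        filter_fn: Callable[[str], bool] = lambda k: True
    else:
        label_repr = f'{label} // '          # stand-in for _input_label_start(label)
        filter_fn = lambda k: k.startswith(label_repr)
    for key in tuple(inputs.keys()):
        if not filter_fn(key):
            continue
        out = inputs[key]
        try:
            # Callable inputs are evaluated; constant inputs are added directly.
            init = init + (out(*args, **kwargs) if callable(out) else out)
        except Exception as e:
            # Re-raise with the offending key so the failing input is easy to locate.
            raise ValueError(f'Error in input {key}: {out}\nError: {e}') from e
        if not callable(out):
            inputs.pop(key)                  # one-shot constant values are consumed
    return init
```

For example, `sum_inputs({'exc // syn': 1.0, 'inh // syn': -0.5}, 0.0, label='exc')` adds only the `exc`-labeled value and returns `1.0`.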
brainstate/nn/_elementwise.py
CHANGED
@@ -61,7 +61,7 @@ class Threshold(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate
+    >>> import brainstate
     >>> m = nn.Threshold(0.1, 20)
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -97,7 +97,7 @@ class ReLU(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.ReLU()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -106,7 +106,7 @@ class ReLU(ElementWiseBlock):
     An implementation of CReLU - https://arxiv.org/abs/1603.05201
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.ReLU()
     >>> x = random.randn(2).unsqueeze(0)
     >>> output = jax.numpy.concat((m(x), m(-x)))
@@ -151,7 +151,7 @@ class RReLU(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.RReLU(0.1, 0.3)
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -205,7 +205,7 @@ class Hardtanh(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Hardtanh(-2, 2)
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -244,7 +244,7 @@ class ReLU6(Hardtanh, ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.ReLU6()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -269,7 +269,7 @@ class Sigmoid(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Sigmoid()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -299,7 +299,7 @@ class Hardsigmoid(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Hardsigmoid()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -325,7 +325,7 @@ class Tanh(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Tanh()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -386,7 +386,7 @@ class Mish(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Mish()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -418,7 +418,7 @@ class Hardswish(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Hardswish()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -452,7 +452,7 @@ class ELU(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.ELU()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -489,7 +489,7 @@ class CELU(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.CELU()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -530,7 +530,7 @@ class SELU(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.SELU()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -559,7 +559,7 @@ class GLU(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.GLU()
     >>> x = random.randn(4, 2)
     >>> output = m(x)
@@ -600,7 +600,7 @@ class GELU(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.GELU()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -642,7 +642,7 @@ class Hardshrink(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Hardshrink()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -689,7 +689,7 @@ class LeakyReLU(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.LeakyReLU(0.1)
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -721,7 +721,7 @@ class LogSigmoid(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.LogSigmoid()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -749,7 +749,7 @@ class Softplus(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Softplus()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -781,7 +781,7 @@ class Softshrink(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Softshrink()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -843,9 +843,9 @@ class PReLU(ElementWiseBlock):
 
     Examples::
 
-    >>> import brainstate as
-    >>> m =
-    >>> x =
+    >>> import brainstate as brainstate
+    >>> m = brainstate.nn.PReLU()
+    >>> x = brainstate.random.randn(2)
     >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -876,7 +876,7 @@ class Softsign(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Softsign()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -900,7 +900,7 @@ class Tanhshrink(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Tanhshrink()
     >>> x = random.randn(2)
     >>> output = m(x)
@@ -937,7 +937,7 @@ class Softmin(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Softmin(dim=1)
     >>> x = random.randn(2, 3)
     >>> output = m(x)
@@ -990,7 +990,7 @@ class Softmax(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Softmax(dim=1)
     >>> x = random.randn(2, 3)
     >>> output = m(x)
@@ -1027,7 +1027,7 @@ class Softmax2d(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.Softmax2d()
     >>> # you softmax over the 2nd dimension
     >>> x = random.randn(2, 3, 12, 13)
@@ -1062,7 +1062,7 @@ class LogSoftmax(ElementWiseBlock):
     Examples::
 
     >>> import brainstate.nn as nn
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> m = nn.LogSoftmax(dim=1)
     >>> x = random.randn(2, 3)
     >>> output = m(x)
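Every `_elementwise.py` hunk applies the same docstring repair: the truncated `>>> import brainstate as` lines become `>>> import brainstate as brainstate`. The examples still call `random.randn` without qualifying the module; a self-contained variant of the ReLU example, assuming `brainstate.random.randn` is the helper these docstrings refer to, looks like this:

```python
import brainstate
import brainstate.nn as nn

# Self-contained form of the ReLU docstring example; brainstate.random.randn
# is assumed to be the random helper written as `random.randn` in the docstrings.
m = nn.ReLU()
x = brainstate.random.randn(2)
output = m(x)
print(output.shape)  # (2,)
```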
brainstate/nn/_linear.py
CHANGED
@@ -361,9 +361,9 @@ class LoRA(Module):
 
     Example usage::
 
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> import jax, jax.numpy as jnp
-    >>> layer =
+    >>> layer = brainstate.nn.LoRA(3, 2, 4)
     >>> layer.weight.value
     {'lora_a': Array([[ 0.25141352, -0.09826107],
     [ 0.2328382 , 0.38869813],
@@ -371,8 +371,8 @@ class LoRA(Module):
     'lora_b': Array([[-0.8372317 , 0.21012013, -0.52999765, -0.31939325],
     [ 0.64234126, -0.42980042, 1.2549229 , -0.47134295]], dtype=float32)}
     >>> # Wrap around existing layer
-    >>> linear =
-    >>> wrapper =
+    >>> linear = brainstate.nn.Linear(3, 4)
+    >>> wrapper = brainstate.nn.LoRA(3, 2, 4, base_module=linear)
     >>> assert wrapper.base_module == linear
     >>> y = layer(jnp.ones((16, 3)))
     >>> y.shape
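The completed LoRA example reads as `LoRA(in_features, lora_rank, out_features)`: here `LoRA(3, 2, 4)` gives a 3x2 `lora_a` and a 2x4 `lora_b`, with an optional `base_module` to wrap. A rough sketch of the standard LoRA forward pass this implies (an assumption about the usual low-rank formulation, not a copy of `brainstate.nn.LoRA` internals):

```python
import jax.numpy as jnp

# Rough sketch of a LoRA-style forward pass; names and the exact combination
# rule are assumptions, not brainstate.nn.LoRA internals.
def lora_forward(x, lora_a, lora_b, base_module=None):
    out = x @ lora_a @ lora_b            # low-rank update: (N, in) @ (in, r) @ (r, out)
    if base_module is not None:
        out = out + base_module(x)       # add the wrapped base layer's output
    return out

x = jnp.ones((16, 3))
lora_a = jnp.zeros((3, 2))
lora_b = jnp.zeros((2, 4))
print(lora_forward(x, lora_a, lora_b).shape)  # (16, 4)
```

Passing `base_module=linear`, as in the docstring, corresponds to adding the low-rank update on top of the wrapped layer's output.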
brainstate/nn/_module.py
CHANGED
@@ -128,9 +128,9 @@ class Module(Node, ParamDesc):
        Examples
        --------
 
-       >>> import brainstate as
-       >>> x =
-       >>> l =
+       >>> import brainstate as brainstate
+       >>> x = brainstate.random.rand((10, 10))
+       >>> l = brainstate.nn.Dropout(0.5)
        >>> y = x >> l
        """
        return self.__call__(other)
@@ -230,7 +230,7 @@ class Module(Node, ParamDesc):
         pass
 
     def __pretty_repr_item__(self, name, value):
-        if name
+        if name.startswith('_'):
             return None if value is None else (name[1:], value)  # skip the first `_`
         return name, value
 
@@ -266,14 +266,14 @@ class Sequential(Module):
     --------
 
     >>> import jax
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> import brainstate.nn as nn
     >>>
     >>> # composing ANN models
     >>> l = nn.Sequential(nn.Linear(100, 10),
    >>>                    jax.nn.relu,
    >>>                    nn.Linear(10, 2))
-    >>> l(
+    >>> l(brainstate.random.random((256, 100)))
 
     Args:
       modules_as_tuple: The children modules.
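With the `Sequential` docstring example restored, an end-to-end version of it (only the shape check is added here) is:

```python
import jax
import brainstate
import brainstate.nn as nn

# End-to-end form of the Sequential docstring example shown above.
l = nn.Sequential(
    nn.Linear(100, 10),   # 100 -> 10 features
    jax.nn.relu,          # plain functions can be used as layers
    nn.Linear(10, 2),     # 10 -> 2 features
)
y = l(brainstate.random.random((256, 100)))
print(y.shape)  # expected: (256, 2)
```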
brainstate/nn/_module_test.py
CHANGED
@@ -77,7 +77,7 @@ class TestDelay(unittest.TestCase):
            self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
            self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
            self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
-        #
+        # brainstate.util.clear_buffer_memory()
 
     def test_jit_erro(self):
         rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
brainstate/nn/_normalizations.py
CHANGED
@@ -488,9 +488,9 @@ class LayerNorm(Module):
 
     Example usage::
 
-    >>> import brainstate as
-    >>> x =
-    >>> layer =
+    >>> import brainstate as brainstate
+    >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
+    >>> layer = brainstate.nn.LayerNorm(x.shape)
     >>> layer.states()
     >>> y = layer(x)
 
@@ -616,9 +616,9 @@ class RMSNorm(Module):
 
     Example usage::
 
-    >>> import brainstate as
-    >>> x =
-    >>> layer =
+    >>> import brainstate as brainstate
+    >>> x = brainstate.random.normal(size=(5, 6))
+    >>> layer = brainstate.nn.RMSNorm(num_features=6)
     >>> layer.states()
     >>> y = layer(x)
 
@@ -739,14 +739,14 @@ class GroupNorm(Module):
     Example usage::
 
     >>> import numpy as np
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     ...
-    >>> x =
-    >>> layer =
+    >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
+    >>> layer = brainstate.nn.GroupNorm(x.shape, num_groups=3)
     >>> layer.states()
     >>> y = layer(x)
-    >>> y =
-    >>> y2 =
+    >>> y = brainstate.nn.GroupNorm(x.shape, num_groups=1)(x)
+    >>> y2 = brainstate.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
     >>> np.testing.assert_allclose(y, y2)
 
     Attributes:
brainstate/nn/_normalizations_test.py
CHANGED
@@ -51,21 +51,21 @@ class Test_Normalization(parameterized.TestCase):
     #     normalized_shape=(10, [5, 10])
     # )
     # def test_LayerNorm(self, normalized_shape):
-    #     net =
-    #     input =
+    #     net = brainstate.nn.LayerNorm(normalized_shape, )
+    #     input = brainstate.random.randn(20, 5, 10)
     #     output = net(input)
     #
     # @parameterized.product(
     #     num_groups=[1, 2, 3, 6]
     # )
     # def test_GroupNorm(self, num_groups):
-    #     input =
-    #     net =
+    #     input = brainstate.random.randn(20, 10, 10, 6)
+    #     net = brainstate.nn.GroupNorm(num_groups=num_groups, num_channels=6, )
     #     output = net(input)
     #
     # def test_InstanceNorm(self):
-    #     input =
-    #     net =
+    #     input = brainstate.random.randn(20, 10, 10, 6)
+    #     net = brainstate.nn.InstanceNorm(num_channels=6, )
     #     output = net(input)
 
 
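The restored `GroupNorm` example also states the equivalence the docstring relies on: over a `(N, H, W, C)` input, `GroupNorm` with `num_groups=1` matches `LayerNorm` with `reduction_axes=(1, 2, 3)`. A runnable form of that check, assembled from the docstring lines above:

```python
import numpy as np
import brainstate

# GroupNorm(num_groups=1) vs LayerNorm(reduction_axes=(1, 2, 3)) equivalence
# asserted in the restored GroupNorm docstring example.
x = brainstate.random.normal(size=(3, 4, 5, 6))
y1 = brainstate.nn.GroupNorm(x.shape, num_groups=1)(x)
y2 = brainstate.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
np.testing.assert_allclose(y1, y2)
```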