brainstate 0.1.5__py2.py3-none-any.whl → 0.1.7__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_state.py +5 -5
  3. brainstate/augment/_autograd.py +31 -12
  4. brainstate/augment/_autograd_test.py +46 -46
  5. brainstate/augment/_eval_shape.py +4 -4
  6. brainstate/augment/_mapping.py +13 -8
  7. brainstate/compile/_conditions.py +2 -2
  8. brainstate/compile/_make_jaxpr.py +48 -6
  9. brainstate/compile/_progress_bar.py +2 -2
  10. brainstate/environ.py +19 -19
  11. brainstate/functional/_activations_test.py +12 -12
  12. brainstate/graph/_graph_operation.py +69 -69
  13. brainstate/graph/_graph_operation_test.py +2 -2
  14. brainstate/mixin.py +0 -17
  15. brainstate/nn/_collective_ops.py +4 -4
  16. brainstate/nn/_dropout_test.py +2 -2
  17. brainstate/nn/_dynamics.py +53 -35
  18. brainstate/nn/_elementwise.py +30 -30
  19. brainstate/nn/_linear.py +4 -4
  20. brainstate/nn/_module.py +6 -6
  21. brainstate/nn/_module_test.py +1 -1
  22. brainstate/nn/_normalizations.py +11 -11
  23. brainstate/nn/_normalizations_test.py +6 -6
  24. brainstate/nn/_poolings.py +24 -24
  25. brainstate/nn/_synapse.py +1 -12
  26. brainstate/nn/_utils.py +1 -1
  27. brainstate/nn/metrics.py +4 -4
  28. brainstate/optim/_optax_optimizer.py +8 -8
  29. brainstate/random/_rand_funs.py +37 -37
  30. brainstate/random/_rand_funs_test.py +3 -3
  31. brainstate/random/_rand_seed.py +7 -7
  32. brainstate/surrogate.py +40 -40
  33. brainstate/util/pretty_pytree.py +10 -10
  34. brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
  35. brainstate/util/struct.py +7 -7
  36. {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/METADATA +12 -12
  37. {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/RECORD +40 -40
  38. {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/WHEEL +1 -1
  39. {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/LICENSE +0 -0
  40. {brainstate-0.1.5.dist-info → brainstate-0.1.7.dist-info}/top_level.txt +0 -0
@@ -42,7 +42,7 @@ import numpy as np
42
42
  from brainstate import environ
43
43
  from brainstate._state import State
44
44
  from brainstate.graph import Node
45
- from brainstate.mixin import ParamDescriber, UpdateReturn
45
+ from brainstate.mixin import ParamDescriber
46
46
  from brainstate.typing import Size, ArrayLike, PyTree
47
47
  from ._delay import StateWithDelay, Delay
48
48
  from ._module import Module
@@ -101,7 +101,7 @@ class Projection(Module):
101
101
  raise ValueError('Do not implement the update() function.')
102
102
 
103
103
 
104
- class Dynamics(Module, UpdateReturn):
104
+ class Dynamics(Module):
105
105
  """
106
106
  Base class for implementing neural dynamics models in BrainState.
107
107
 
@@ -214,13 +214,13 @@ class Dynamics(Module, UpdateReturn):
214
214
  # in-/out- size of neuron population
215
215
  self.out_size = self.in_size
216
216
 
217
- def __pretty_repr_item__(self, name, value):
218
- if name in [
219
- '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
220
- '_in_size', '_out_size', '_name', '_mode',
221
- ]:
222
- return (name, value) if value is None else (name[1:], value) # skip the first `_`
223
- return super().__pretty_repr_item__(name, value)
217
+ # def __pretty_repr_item__(self, name, value):
218
+ # if name in [
219
+ # '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
220
+ # '_in_size', '_out_size', '_name', '_mode',
221
+ # ]:
222
+ # return (name, value) if value is None else (name[1:], value) # skip the first `_`
223
+ # return super().__pretty_repr_item__(name, value)
224
224
 
225
225
  @property
226
226
  def varshape(self):
@@ -470,21 +470,30 @@ class Dynamics(Module, UpdateReturn):
470
470
  if self._current_inputs is None:
471
471
  return init
472
472
  if label is None:
473
- # no label
474
- for key in tuple(self._current_inputs.keys()):
475
- out = self._current_inputs[key]
476
- init = init + (out(*args, **kwargs) if callable(out) else out)
477
- if not callable(out):
478
- self._current_inputs.pop(key)
473
+ filter_fn = lambda k: True
479
474
  else:
480
- # has label
481
475
  label_repr = _input_label_start(label)
482
- for key in tuple(self._current_inputs.keys()):
483
- if key.startswith(label_repr):
484
- out = self._current_inputs[key]
485
- init = init + (out(*args, **kwargs) if callable(out) else out)
486
- if not callable(out):
487
- self._current_inputs.pop(key)
476
+ filter_fn = lambda k: k.startswith(label_repr)
477
+ for key in tuple(self._current_inputs.keys()):
478
+ if filter_fn(key):
479
+ out = self._current_inputs[key]
480
+ if callable(out):
481
+ try:
482
+ init = init + out(*args, **kwargs)
483
+ except Exception as e:
484
+ raise ValueError(
485
+ f'Error in delta input value {key}: {out}\n'
486
+ f'Error: {e}'
487
+ ) from e
488
+ else:
489
+ try:
490
+ init = init + out
491
+ except Exception as e:
492
+ raise ValueError(
493
+ f'Error in delta input value {key}: {out}\n'
494
+ f'Error: {e}'
495
+ ) from e
496
+ self._current_inputs.pop(key)
488
497
  return init
489
498
 
490
499
  def sum_delta_inputs(
@@ -529,21 +538,30 @@ class Dynamics(Module, UpdateReturn):
529
538
  if self._delta_inputs is None:
530
539
  return init
531
540
  if label is None:
532
- # no label
533
- for key in tuple(self._delta_inputs.keys()):
534
- out = self._delta_inputs[key]
535
- init = init + (out(*args, **kwargs) if callable(out) else out)
536
- if not callable(out):
537
- self._delta_inputs.pop(key)
541
+ filter_fn = lambda k: True
538
542
  else:
539
- # has label
540
543
  label_repr = _input_label_start(label)
541
- for key in tuple(self._delta_inputs.keys()):
542
- if key.startswith(label_repr):
543
- out = self._delta_inputs[key]
544
- init = init + (out(*args, **kwargs) if callable(out) else out)
545
- if not callable(out):
546
- self._delta_inputs.pop(key)
544
+ filter_fn = lambda k: k.startswith(label_repr)
545
+ for key in tuple(self._delta_inputs.keys()):
546
+ if filter_fn(key):
547
+ out = self._delta_inputs[key]
548
+ if callable(out):
549
+ try:
550
+ init = init + out(*args, **kwargs)
551
+ except Exception as e:
552
+ raise ValueError(
553
+ f'Error in delta input function {key}: {out}\n'
554
+ f'Error: {e}'
555
+ ) from e
556
+ else:
557
+ try:
558
+ init = init + out
559
+ except Exception as e:
560
+ raise ValueError(
561
+ f'Error in delta input value {key}: {out}\n'
562
+ f'Error: {e}'
563
+ ) from e
564
+ self._delta_inputs.pop(key)
547
565
  return init
548
566
 
549
567
  @property
@@ -61,7 +61,7 @@ class Threshold(ElementWiseBlock):
61
61
  Examples::
62
62
 
63
63
  >>> import brainstate.nn as nn
64
- >>> import brainstate as bst
64
+ >>> import brainstate
65
65
  >>> m = nn.Threshold(0.1, 20)
66
66
  >>> x = random.randn(2)
67
67
  >>> output = m(x)
@@ -97,7 +97,7 @@ class ReLU(ElementWiseBlock):
97
97
  Examples::
98
98
 
99
99
  >>> import brainstate.nn as nn
100
- >>> import brainstate as bst
100
+ >>> import brainstate as brainstate
101
101
  >>> m = nn.ReLU()
102
102
  >>> x = random.randn(2)
103
103
  >>> output = m(x)
@@ -106,7 +106,7 @@ class ReLU(ElementWiseBlock):
106
106
  An implementation of CReLU - https://arxiv.org/abs/1603.05201
107
107
 
108
108
  >>> import brainstate.nn as nn
109
- >>> import brainstate as bst
109
+ >>> import brainstate as brainstate
110
110
  >>> m = nn.ReLU()
111
111
  >>> x = random.randn(2).unsqueeze(0)
112
112
  >>> output = jax.numpy.concat((m(x), m(-x)))
@@ -151,7 +151,7 @@ class RReLU(ElementWiseBlock):
151
151
  Examples::
152
152
 
153
153
  >>> import brainstate.nn as nn
154
- >>> import brainstate as bst
154
+ >>> import brainstate as brainstate
155
155
  >>> m = nn.RReLU(0.1, 0.3)
156
156
  >>> x = random.randn(2)
157
157
  >>> output = m(x)
@@ -205,7 +205,7 @@ class Hardtanh(ElementWiseBlock):
205
205
  Examples::
206
206
 
207
207
  >>> import brainstate.nn as nn
208
- >>> import brainstate as bst
208
+ >>> import brainstate as brainstate
209
209
  >>> m = nn.Hardtanh(-2, 2)
210
210
  >>> x = random.randn(2)
211
211
  >>> output = m(x)
@@ -244,7 +244,7 @@ class ReLU6(Hardtanh, ElementWiseBlock):
244
244
  Examples::
245
245
 
246
246
  >>> import brainstate.nn as nn
247
- >>> import brainstate as bst
247
+ >>> import brainstate as brainstate
248
248
  >>> m = nn.ReLU6()
249
249
  >>> x = random.randn(2)
250
250
  >>> output = m(x)
@@ -269,7 +269,7 @@ class Sigmoid(ElementWiseBlock):
269
269
  Examples::
270
270
 
271
271
  >>> import brainstate.nn as nn
272
- >>> import brainstate as bst
272
+ >>> import brainstate as brainstate
273
273
  >>> m = nn.Sigmoid()
274
274
  >>> x = random.randn(2)
275
275
  >>> output = m(x)
@@ -299,7 +299,7 @@ class Hardsigmoid(ElementWiseBlock):
299
299
  Examples::
300
300
 
301
301
  >>> import brainstate.nn as nn
302
- >>> import brainstate as bst
302
+ >>> import brainstate as brainstate
303
303
  >>> m = nn.Hardsigmoid()
304
304
  >>> x = random.randn(2)
305
305
  >>> output = m(x)
@@ -325,7 +325,7 @@ class Tanh(ElementWiseBlock):
325
325
  Examples::
326
326
 
327
327
  >>> import brainstate.nn as nn
328
- >>> import brainstate as bst
328
+ >>> import brainstate as brainstate
329
329
  >>> m = nn.Tanh()
330
330
  >>> x = random.randn(2)
331
331
  >>> output = m(x)
@@ -386,7 +386,7 @@ class Mish(ElementWiseBlock):
386
386
  Examples::
387
387
 
388
388
  >>> import brainstate.nn as nn
389
- >>> import brainstate as bst
389
+ >>> import brainstate as brainstate
390
390
  >>> m = nn.Mish()
391
391
  >>> x = random.randn(2)
392
392
  >>> output = m(x)
@@ -418,7 +418,7 @@ class Hardswish(ElementWiseBlock):
418
418
  Examples::
419
419
 
420
420
  >>> import brainstate.nn as nn
421
- >>> import brainstate as bst
421
+ >>> import brainstate as brainstate
422
422
  >>> m = nn.Hardswish()
423
423
  >>> x = random.randn(2)
424
424
  >>> output = m(x)
@@ -452,7 +452,7 @@ class ELU(ElementWiseBlock):
452
452
  Examples::
453
453
 
454
454
  >>> import brainstate.nn as nn
455
- >>> import brainstate as bst
455
+ >>> import brainstate as brainstate
456
456
  >>> m = nn.ELU()
457
457
  >>> x = random.randn(2)
458
458
  >>> output = m(x)
@@ -489,7 +489,7 @@ class CELU(ElementWiseBlock):
489
489
  Examples::
490
490
 
491
491
  >>> import brainstate.nn as nn
492
- >>> import brainstate as bst
492
+ >>> import brainstate as brainstate
493
493
  >>> m = nn.CELU()
494
494
  >>> x = random.randn(2)
495
495
  >>> output = m(x)
@@ -530,7 +530,7 @@ class SELU(ElementWiseBlock):
530
530
  Examples::
531
531
 
532
532
  >>> import brainstate.nn as nn
533
- >>> import brainstate as bst
533
+ >>> import brainstate as brainstate
534
534
  >>> m = nn.SELU()
535
535
  >>> x = random.randn(2)
536
536
  >>> output = m(x)
@@ -559,7 +559,7 @@ class GLU(ElementWiseBlock):
559
559
  Examples::
560
560
 
561
561
  >>> import brainstate.nn as nn
562
- >>> import brainstate as bst
562
+ >>> import brainstate as brainstate
563
563
  >>> m = nn.GLU()
564
564
  >>> x = random.randn(4, 2)
565
565
  >>> output = m(x)
@@ -600,7 +600,7 @@ class GELU(ElementWiseBlock):
600
600
  Examples::
601
601
 
602
602
  >>> import brainstate.nn as nn
603
- >>> import brainstate as bst
603
+ >>> import brainstate as brainstate
604
604
  >>> m = nn.GELU()
605
605
  >>> x = random.randn(2)
606
606
  >>> output = m(x)
@@ -642,7 +642,7 @@ class Hardshrink(ElementWiseBlock):
642
642
  Examples::
643
643
 
644
644
  >>> import brainstate.nn as nn
645
- >>> import brainstate as bst
645
+ >>> import brainstate as brainstate
646
646
  >>> m = nn.Hardshrink()
647
647
  >>> x = random.randn(2)
648
648
  >>> output = m(x)
@@ -689,7 +689,7 @@ class LeakyReLU(ElementWiseBlock):
689
689
  Examples::
690
690
 
691
691
  >>> import brainstate.nn as nn
692
- >>> import brainstate as bst
692
+ >>> import brainstate as brainstate
693
693
  >>> m = nn.LeakyReLU(0.1)
694
694
  >>> x = random.randn(2)
695
695
  >>> output = m(x)
@@ -721,7 +721,7 @@ class LogSigmoid(ElementWiseBlock):
721
721
  Examples::
722
722
 
723
723
  >>> import brainstate.nn as nn
724
- >>> import brainstate as bst
724
+ >>> import brainstate as brainstate
725
725
  >>> m = nn.LogSigmoid()
726
726
  >>> x = random.randn(2)
727
727
  >>> output = m(x)
@@ -749,7 +749,7 @@ class Softplus(ElementWiseBlock):
749
749
  Examples::
750
750
 
751
751
  >>> import brainstate.nn as nn
752
- >>> import brainstate as bst
752
+ >>> import brainstate as brainstate
753
753
  >>> m = nn.Softplus()
754
754
  >>> x = random.randn(2)
755
755
  >>> output = m(x)
@@ -781,7 +781,7 @@ class Softshrink(ElementWiseBlock):
781
781
  Examples::
782
782
 
783
783
  >>> import brainstate.nn as nn
784
- >>> import brainstate as bst
784
+ >>> import brainstate as brainstate
785
785
  >>> m = nn.Softshrink()
786
786
  >>> x = random.randn(2)
787
787
  >>> output = m(x)
@@ -843,9 +843,9 @@ class PReLU(ElementWiseBlock):
843
843
 
844
844
  Examples::
845
845
 
846
- >>> import brainstate as bst
847
- >>> m = bst.nn.PReLU()
848
- >>> x = bst.random.randn(2)
846
+ >>> import brainstate as brainstate
847
+ >>> m = brainstate.nn.PReLU()
848
+ >>> x = brainstate.random.randn(2)
849
849
  >>> output = m(x)
850
850
  """
851
851
  __module__ = 'brainstate.nn'
@@ -876,7 +876,7 @@ class Softsign(ElementWiseBlock):
876
876
  Examples::
877
877
 
878
878
  >>> import brainstate.nn as nn
879
- >>> import brainstate as bst
879
+ >>> import brainstate as brainstate
880
880
  >>> m = nn.Softsign()
881
881
  >>> x = random.randn(2)
882
882
  >>> output = m(x)
@@ -900,7 +900,7 @@ class Tanhshrink(ElementWiseBlock):
900
900
  Examples::
901
901
 
902
902
  >>> import brainstate.nn as nn
903
- >>> import brainstate as bst
903
+ >>> import brainstate as brainstate
904
904
  >>> m = nn.Tanhshrink()
905
905
  >>> x = random.randn(2)
906
906
  >>> output = m(x)
@@ -937,7 +937,7 @@ class Softmin(ElementWiseBlock):
937
937
  Examples::
938
938
 
939
939
  >>> import brainstate.nn as nn
940
- >>> import brainstate as bst
940
+ >>> import brainstate as brainstate
941
941
  >>> m = nn.Softmin(dim=1)
942
942
  >>> x = random.randn(2, 3)
943
943
  >>> output = m(x)
@@ -990,7 +990,7 @@ class Softmax(ElementWiseBlock):
990
990
  Examples::
991
991
 
992
992
  >>> import brainstate.nn as nn
993
- >>> import brainstate as bst
993
+ >>> import brainstate as brainstate
994
994
  >>> m = nn.Softmax(dim=1)
995
995
  >>> x = random.randn(2, 3)
996
996
  >>> output = m(x)
@@ -1027,7 +1027,7 @@ class Softmax2d(ElementWiseBlock):
1027
1027
  Examples::
1028
1028
 
1029
1029
  >>> import brainstate.nn as nn
1030
- >>> import brainstate as bst
1030
+ >>> import brainstate as brainstate
1031
1031
  >>> m = nn.Softmax2d()
1032
1032
  >>> # you softmax over the 2nd dimension
1033
1033
  >>> x = random.randn(2, 3, 12, 13)
@@ -1062,7 +1062,7 @@ class LogSoftmax(ElementWiseBlock):
1062
1062
  Examples::
1063
1063
 
1064
1064
  >>> import brainstate.nn as nn
1065
- >>> import brainstate as bst
1065
+ >>> import brainstate as brainstate
1066
1066
  >>> m = nn.LogSoftmax(dim=1)
1067
1067
  >>> x = random.randn(2, 3)
1068
1068
  >>> output = m(x)
brainstate/nn/_linear.py CHANGED
@@ -361,9 +361,9 @@ class LoRA(Module):
361
361
 
362
362
  Example usage::
363
363
 
364
- >>> import brainstate as bst
364
+ >>> import brainstate as brainstate
365
365
  >>> import jax, jax.numpy as jnp
366
- >>> layer = bst.nn.LoRA(3, 2, 4)
366
+ >>> layer = brainstate.nn.LoRA(3, 2, 4)
367
367
  >>> layer.weight.value
368
368
  {'lora_a': Array([[ 0.25141352, -0.09826107],
369
369
  [ 0.2328382 , 0.38869813],
@@ -371,8 +371,8 @@ class LoRA(Module):
371
371
  'lora_b': Array([[-0.8372317 , 0.21012013, -0.52999765, -0.31939325],
372
372
  [ 0.64234126, -0.42980042, 1.2549229 , -0.47134295]], dtype=float32)}
373
373
  >>> # Wrap around existing layer
374
- >>> linear = bst.nn.Linear(3, 4)
375
- >>> wrapper = bst.nn.LoRA(3, 2, 4, base_module=linear)
374
+ >>> linear = brainstate.nn.Linear(3, 4)
375
+ >>> wrapper = brainstate.nn.LoRA(3, 2, 4, base_module=linear)
376
376
  >>> assert wrapper.base_module == linear
377
377
  >>> y = layer(jnp.ones((16, 3)))
378
378
  >>> y.shape
brainstate/nn/_module.py CHANGED
@@ -128,9 +128,9 @@ class Module(Node, ParamDesc):
128
128
  Examples
129
129
  --------
130
130
 
131
- >>> import brainstate as bst
132
- >>> x = bst.random.rand((10, 10))
133
- >>> l = bst.nn.Dropout(0.5)
131
+ >>> import brainstate as brainstate
132
+ >>> x = brainstate.random.rand((10, 10))
133
+ >>> l = brainstate.nn.Dropout(0.5)
134
134
  >>> y = x >> l
135
135
  """
136
136
  return self.__call__(other)
@@ -230,7 +230,7 @@ class Module(Node, ParamDesc):
230
230
  pass
231
231
 
232
232
  def __pretty_repr_item__(self, name, value):
233
- if name in ['_in_size', '_out_size', '_name']:
233
+ if name.startswith('_'):
234
234
  return None if value is None else (name[1:], value) # skip the first `_`
235
235
  return name, value
236
236
 
@@ -266,14 +266,14 @@ class Sequential(Module):
266
266
  --------
267
267
 
268
268
  >>> import jax
269
- >>> import brainstate as bst
269
+ >>> import brainstate as brainstate
270
270
  >>> import brainstate.nn as nn
271
271
  >>>
272
272
  >>> # composing ANN models
273
273
  >>> l = nn.Sequential(nn.Linear(100, 10),
274
274
  >>> jax.nn.relu,
275
275
  >>> nn.Linear(10, 2))
276
- >>> l(bst.random.random((256, 100)))
276
+ >>> l(brainstate.random.random((256, 100)))
277
277
 
278
278
  Args:
279
279
  modules_as_tuple: The children modules.
@@ -77,7 +77,7 @@ class TestDelay(unittest.TestCase):
77
77
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
78
78
  self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
79
79
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
80
- # bst.util.clear_buffer_memory()
80
+ # brainstate.util.clear_buffer_memory()
81
81
 
82
82
  def test_jit_erro(self):
83
83
  rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
@@ -488,9 +488,9 @@ class LayerNorm(Module):
488
488
 
489
489
  Example usage::
490
490
 
491
- >>> import brainstate as bst
492
- >>> x = bst.random.normal(size=(3, 4, 5, 6))
493
- >>> layer = bst.nn.LayerNorm(x.shape)
491
+ >>> import brainstate as brainstate
492
+ >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
493
+ >>> layer = brainstate.nn.LayerNorm(x.shape)
494
494
  >>> layer.states()
495
495
  >>> y = layer(x)
496
496
 
@@ -616,9 +616,9 @@ class RMSNorm(Module):
616
616
 
617
617
  Example usage::
618
618
 
619
- >>> import brainstate as bst
620
- >>> x = bst.random.normal(size=(5, 6))
621
- >>> layer = bst.nn.RMSNorm(num_features=6)
619
+ >>> import brainstate as brainstate
620
+ >>> x = brainstate.random.normal(size=(5, 6))
621
+ >>> layer = brainstate.nn.RMSNorm(num_features=6)
622
622
  >>> layer.states()
623
623
  >>> y = layer(x)
624
624
 
@@ -739,14 +739,14 @@ class GroupNorm(Module):
739
739
  Example usage::
740
740
 
741
741
  >>> import numpy as np
742
- >>> import brainstate as bst
742
+ >>> import brainstate as brainstate
743
743
  ...
744
- >>> x = bst.random.normal(size=(3, 4, 5, 6))
745
- >>> layer = bst.nn.GroupNorm(x.shape, num_groups=3)
744
+ >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
745
+ >>> layer = brainstate.nn.GroupNorm(x.shape, num_groups=3)
746
746
  >>> layer.states()
747
747
  >>> y = layer(x)
748
- >>> y = bst.nn.GroupNorm(x.shape, num_groups=1)(x)
749
- >>> y2 = bst.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
748
+ >>> y = brainstate.nn.GroupNorm(x.shape, num_groups=1)(x)
749
+ >>> y2 = brainstate.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
750
750
  >>> np.testing.assert_allclose(y, y2)
751
751
 
752
752
  Attributes:
@@ -51,21 +51,21 @@ class Test_Normalization(parameterized.TestCase):
51
51
  # normalized_shape=(10, [5, 10])
52
52
  # )
53
53
  # def test_LayerNorm(self, normalized_shape):
54
- # net = bst.nn.LayerNorm(normalized_shape, )
55
- # input = bst.random.randn(20, 5, 10)
54
+ # net = brainstate.nn.LayerNorm(normalized_shape, )
55
+ # input = brainstate.random.randn(20, 5, 10)
56
56
  # output = net(input)
57
57
  #
58
58
  # @parameterized.product(
59
59
  # num_groups=[1, 2, 3, 6]
60
60
  # )
61
61
  # def test_GroupNorm(self, num_groups):
62
- # input = bst.random.randn(20, 10, 10, 6)
63
- # net = bst.nn.GroupNorm(num_groups=num_groups, num_channels=6, )
62
+ # input = brainstate.random.randn(20, 10, 10, 6)
63
+ # net = brainstate.nn.GroupNorm(num_groups=num_groups, num_channels=6, )
64
64
  # output = net(input)
65
65
  #
66
66
  # def test_InstanceNorm(self):
67
- # input = bst.random.randn(20, 10, 10, 6)
68
- # net = bst.nn.InstanceNorm(num_channels=6, )
67
+ # input = brainstate.random.randn(20, 10, 10, 6)
68
+ # net = brainstate.nn.InstanceNorm(num_channels=6, )
69
69
  # output = net(input)
70
70
 
71
71