brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. brainstate/__init__.py +2 -4
  2. brainstate/_deprecation_test.py +2 -24
  3. brainstate/_state.py +540 -35
  4. brainstate/_state_test.py +1085 -8
  5. brainstate/graph/_operation.py +1 -5
  6. brainstate/mixin.py +14 -0
  7. brainstate/nn/__init__.py +42 -33
  8. brainstate/nn/_collective_ops.py +2 -0
  9. brainstate/nn/_common_test.py +0 -20
  10. brainstate/nn/_delay.py +1 -1
  11. brainstate/nn/_dropout_test.py +9 -6
  12. brainstate/nn/_dynamics.py +67 -464
  13. brainstate/nn/_dynamics_test.py +0 -14
  14. brainstate/nn/_embedding.py +7 -7
  15. brainstate/nn/_exp_euler.py +9 -9
  16. brainstate/nn/_linear.py +21 -21
  17. brainstate/nn/_module.py +25 -18
  18. brainstate/nn/_normalizations.py +27 -27
  19. brainstate/random/__init__.py +6 -6
  20. brainstate/random/{_rand_funs.py → _fun.py} +1 -1
  21. brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
  22. brainstate/random/_impl.py +672 -0
  23. brainstate/random/{_rand_seed.py → _seed.py} +1 -1
  24. brainstate/random/{_rand_state.py → _state.py} +121 -418
  25. brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
  26. brainstate/transform/__init__.py +6 -9
  27. brainstate/transform/_conditions.py +2 -2
  28. brainstate/transform/_find_state.py +200 -0
  29. brainstate/transform/_find_state_test.py +84 -0
  30. brainstate/transform/_make_jaxpr.py +221 -61
  31. brainstate/transform/_make_jaxpr_test.py +125 -1
  32. brainstate/transform/_mapping.py +287 -209
  33. brainstate/transform/_mapping_test.py +94 -184
  34. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
  35. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
  36. brainstate/transform/_eval_shape.py +0 -145
  37. brainstate/transform/_eval_shape_test.py +0 -38
  38. brainstate/transform/_random.py +0 -171
  39. brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
  40. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  41. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
  42. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -42,20 +42,6 @@ class TestDynamics(unittest.TestCase):
42
42
  with self.assertRaises(ValueError):
43
43
  brainstate.nn.Dynamics(in_size="invalid")
44
44
 
45
- def test_input_handling(self):
46
- dyn = brainstate.nn.Dynamics(in_size=10)
47
- dyn.add_current_input("test_current", lambda: np.random.rand(10))
48
- dyn.add_delta_input("test_delta", lambda: np.random.rand(10))
49
-
50
- self.assertIn("test_current", dyn.current_inputs)
51
- self.assertIn("test_delta", dyn.delta_inputs)
52
-
53
- def test_duplicate_input_key(self):
54
- dyn = brainstate.nn.Dynamics(in_size=10)
55
- dyn.add_current_input("test", lambda: np.random.rand(10))
56
- with self.assertRaises(ValueError):
57
- dyn.add_current_input("test", lambda: np.random.rand(10))
58
-
59
45
  def test_varshape(self):
60
46
  dyn = brainstate.nn.Dynamics(in_size=(2, 3))
61
47
  self.assertEqual(dyn.varshape, (2, 3))
@@ -161,8 +161,8 @@ class Embedding(Module):
161
161
 
162
162
  .. code-block:: python
163
163
 
164
- >>> import brainstate as bst
165
- >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3)
164
+ >>> import brainstate as brainstate
165
+ >>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3)
166
166
  >>> embedding.weight.value.shape
167
167
  (10, 3)
168
168
 
@@ -191,7 +191,7 @@ class Embedding(Module):
191
191
 
192
192
  .. code-block:: python
193
193
 
194
- >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3, padding_idx=0)
194
+ >>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3, padding_idx=0)
195
195
  >>> # The embedding at index 0 will remain zeros and not be updated during training
196
196
  >>> indices = jnp.array([0, 2, 0, 5])
197
197
  >>> output = embedding(indices)
@@ -202,7 +202,7 @@ class Embedding(Module):
202
202
 
203
203
  .. code-block:: python
204
204
 
205
- >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3, max_norm=1.0)
205
+ >>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3, max_norm=1.0)
206
206
  >>> # All embeddings accessed in a forward pass are renormalized to have norm <= 1.0
207
207
 
208
208
  Load pretrained embeddings:
@@ -214,7 +214,7 @@ class Embedding(Module):
214
214
  >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
215
215
  ... [4.0, 5.0, 6.0],
216
216
  ... [7.0, 8.0, 9.0]])
217
- >>> embedding = bst.nn.Embedding.from_pretrained(pretrained, param_type=brainstate.FakeState)
217
+ >>> embedding = brainstate.nn.Embedding.from_pretrained(pretrained, param_type=brainstate.FakeState)
218
218
  >>> embedding.weight.value.shape
219
219
  (3, 3)
220
220
  """
@@ -310,11 +310,11 @@ class Embedding(Module):
310
310
  .. code-block:: python
311
311
 
312
312
  >>> import jax.numpy as jnp
313
- >>> import brainstate as bst
313
+ >>> import brainstate as brainstate
314
314
  >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
315
315
  ... [4.0, 5.0, 6.0],
316
316
  ... [7.0, 8.0, 9.0]])
317
- >>> embedding = bst.nn.Embedding.from_pretrained(pretrained)
317
+ >>> embedding = brainstate.nn.Embedding.from_pretrained(pretrained)
318
318
  >>> embedding.weight.value.shape
319
319
  (3, 3)
320
320
  >>> indices = jnp.array([1])
@@ -112,11 +112,11 @@ def exp_euler_step(
112
112
 
113
113
  .. code-block:: python
114
114
 
115
- >>> import brainstate as bst
115
+ >>> import brainstate as brainstate
116
116
  >>> import jax.numpy as jnp
117
117
  >>>
118
118
  >>> # Set time step in environment
119
- >>> bst.environ.set(dt=0.01)
119
+ >>> brainstate.environ.set(dt=0.01)
120
120
  >>>
121
121
  >>> # Define drift function
122
122
  >>> def drift(x, t):
@@ -126,7 +126,7 @@ def exp_euler_step(
126
126
  >>> x0 = jnp.array(1.0)
127
127
  >>>
128
128
  >>> # Single integration step
129
- >>> x1 = bst.nn.exp_euler_step(drift, x0, None)
129
+ >>> x1 = brainstate.nn.exp_euler_step(drift, x0, None)
130
130
  >>> print(x1) # Should be close to exp(-0.01) ≈ 0.99
131
131
 
132
132
  **SDE Integration:**
@@ -135,11 +135,11 @@ def exp_euler_step(
135
135
 
136
136
  .. code-block:: python
137
137
 
138
- >>> import brainstate as bst
138
+ >>> import brainstate as brainstate
139
139
  >>> import jax.numpy as jnp
140
140
  >>>
141
141
  >>> # Set time step
142
- >>> bst.environ.set(dt=0.01)
142
+ >>> brainstate.environ.set(dt=0.01)
143
143
  >>>
144
144
  >>> # Define drift and diffusion
145
145
  >>> theta = 0.5
@@ -155,16 +155,16 @@ def exp_euler_step(
155
155
  >>> x0 = jnp.array(1.0)
156
156
  >>>
157
157
  >>> # Single SDE integration step
158
- >>> x1 = bst.nn.exp_euler_step(drift, diffusion, x0, None)
158
+ >>> x1 = brainstate.nn.exp_euler_step(drift, diffusion, x0, None)
159
159
 
160
160
  **Multi-dimensional system:**
161
161
 
162
162
  .. code-block:: python
163
163
 
164
- >>> import brainstate as bst
164
+ >>> import brainstate as brainstate
165
165
  >>> import jax.numpy as jnp
166
166
  >>>
167
- >>> bst.environ.set(dt=0.01)
167
+ >>> brainstate.environ.set(dt=0.01)
168
168
  >>>
169
169
  >>> # Coupled oscillator system
170
170
  >>> def drift(x, t):
@@ -172,7 +172,7 @@ def exp_euler_step(
172
172
  ... return jnp.array([-x1 + x2, -x2 - x1])
173
173
  >>>
174
174
  >>> x0 = jnp.array([1.0, 0.0])
175
- >>> x1 = bst.nn.exp_euler_step(drift, x0, None)
175
+ >>> x1 = brainstate.nn.exp_euler_step(drift, x0, None)
176
176
 
177
177
  See Also
178
178
  --------
brainstate/nn/_linear.py CHANGED
@@ -76,18 +76,18 @@ class Linear(Module):
76
76
  --------
77
77
  .. code-block:: python
78
78
 
79
- >>> import brainstate as bst
79
+ >>> import brainstate as brainstate
80
80
  >>> import jax.numpy as jnp
81
81
  >>>
82
82
  >>> # Create a linear layer
83
- >>> layer = bst.nn.Linear((10,), (5,))
83
+ >>> layer = brainstate.nn.Linear((10,), (5,))
84
84
  >>> x = jnp.ones((32, 10))
85
85
  >>> y = layer(x)
86
86
  >>> y.shape
87
87
  (32, 5)
88
88
  >>>
89
89
  >>> # Linear layer without bias
90
- >>> layer = bst.nn.Linear((10,), (5,), b_init=None)
90
+ >>> layer = brainstate.nn.Linear((10,), (5,), b_init=None)
91
91
  >>> y = layer(x)
92
92
  >>> y.shape
93
93
  (32, 5)
@@ -171,11 +171,11 @@ class SignedWLinear(Module):
171
171
  --------
172
172
  .. code-block:: python
173
173
 
174
- >>> import brainstate as bst
174
+ >>> import brainstate as brainstate
175
175
  >>> import jax.numpy as jnp
176
176
  >>>
177
177
  >>> # Create a signed weight linear layer with all positive weights
178
- >>> layer = bst.nn.SignedWLinear((10,), (5,))
178
+ >>> layer = brainstate.nn.SignedWLinear((10,), (5,))
179
179
  >>> x = jnp.ones((32, 10))
180
180
  >>> y = layer(x)
181
181
  >>> y.shape
@@ -183,7 +183,7 @@ class SignedWLinear(Module):
183
183
  >>>
184
184
  >>> # With custom sign matrix (e.g., inhibitory connections)
185
185
  >>> w_sign = jnp.ones((10, 5)) * -1.0 # all negative
186
- >>> layer = bst.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
186
+ >>> layer = brainstate.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
187
187
  >>> y = layer(x)
188
188
  >>> y.shape
189
189
  (32, 5)
@@ -274,18 +274,18 @@ class ScaledWSLinear(Module):
274
274
  --------
275
275
  .. code-block:: python
276
276
 
277
- >>> import brainstate as bst
277
+ >>> import brainstate as brainstate
278
278
  >>> import jax.numpy as jnp
279
279
  >>>
280
280
  >>> # Create a weight-standardized linear layer
281
- >>> layer = bst.nn.ScaledWSLinear((10,), (5,))
281
+ >>> layer = brainstate.nn.ScaledWSLinear((10,), (5,))
282
282
  >>> x = jnp.ones((32, 10))
283
283
  >>> y = layer(x)
284
284
  >>> y.shape
285
285
  (32, 5)
286
286
  >>>
287
287
  >>> # Without learnable gain
288
- >>> layer = bst.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
288
+ >>> layer = brainstate.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
289
289
  >>> y = layer(x)
290
290
  >>> y.shape
291
291
  (32, 5)
@@ -375,7 +375,7 @@ class SparseLinear(Module):
375
375
  --------
376
376
  .. code-block:: python
377
377
 
378
- >>> import brainstate as bst
378
+ >>> import brainstate as brainstate
379
379
  >>> import brainunit as u
380
380
  >>> import jax.numpy as jnp
381
381
  >>>
@@ -384,7 +384,7 @@ class SparseLinear(Module):
384
384
  >>> values = jnp.array([1.0, 2.0, 3.0])
385
385
  >>> spar_mat = u.sparse.CSR((values, indices[:, 1], indices[:, 0]),
386
386
  ... shape=(3, 3))
387
- >>> layer = bst.nn.SparseLinear(spar_mat, in_size=(3,))
387
+ >>> layer = brainstate.nn.SparseLinear(spar_mat, in_size=(3,))
388
388
  >>> x = jnp.ones((5, 3))
389
389
  >>> y = layer(x)
390
390
  >>> y.shape
@@ -468,18 +468,18 @@ class AllToAll(Module):
468
468
  --------
469
469
  .. code-block:: python
470
470
 
471
- >>> import brainstate as bst
471
+ >>> import brainstate as brainstate
472
472
  >>> import jax.numpy as jnp
473
473
  >>>
474
474
  >>> # All-to-all with self-connections
475
- >>> layer = bst.nn.AllToAll((10,), (10,), include_self=True)
475
+ >>> layer = brainstate.nn.AllToAll((10,), (10,), include_self=True)
476
476
  >>> x = jnp.ones((32, 10))
477
477
  >>> y = layer(x)
478
478
  >>> y.shape
479
479
  (32, 10)
480
480
  >>>
481
481
  >>> # All-to-all without self-connections (recurrent layer)
482
- >>> layer = bst.nn.AllToAll((10,), (10,), include_self=False)
482
+ >>> layer = brainstate.nn.AllToAll((10,), (10,), include_self=False)
483
483
  >>> y = layer(x)
484
484
  >>> y.shape
485
485
  (32, 10)
@@ -584,18 +584,18 @@ class OneToOne(Module):
584
584
  --------
585
585
  .. code-block:: python
586
586
 
587
- >>> import brainstate as bst
587
+ >>> import brainstate as brainstate
588
588
  >>> import jax.numpy as jnp
589
589
  >>>
590
590
  >>> # One-to-one connection
591
- >>> layer = bst.nn.OneToOne((10,))
591
+ >>> layer = brainstate.nn.OneToOne((10,))
592
592
  >>> x = jnp.ones((32, 10))
593
593
  >>> y = layer(x)
594
594
  >>> y.shape
595
595
  (32, 10)
596
596
  >>>
597
597
  >>> # With bias
598
- >>> layer = bst.nn.OneToOne((10,), b_init=bst.init.Constant(0.1))
598
+ >>> layer = brainstate.nn.OneToOne((10,), b_init=brainstate.init.Constant(0.1))
599
599
  >>> y = layer(x)
600
600
  >>> y.shape
601
601
  (32, 10)
@@ -677,19 +677,19 @@ class LoRA(Module):
677
677
  --------
678
678
  .. code-block:: python
679
679
 
680
- >>> import brainstate as bst
680
+ >>> import brainstate as brainstate
681
681
  >>> import jax.numpy as jnp
682
682
  >>>
683
683
  >>> # Standalone LoRA layer
684
- >>> layer = bst.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
684
+ >>> layer = brainstate.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
685
685
  >>> x = jnp.ones((32, 10))
686
686
  >>> y = layer(x)
687
687
  >>> y.shape
688
688
  (32, 5)
689
689
  >>>
690
690
  >>> # Wrap around existing linear layer
691
- >>> base = bst.nn.Linear((10,), (5,))
692
- >>> lora_layer = bst.nn.LoRA(in_features=10, lora_rank=2,
691
+ >>> base = brainstate.nn.Linear((10,), (5,))
692
+ >>> lora_layer = brainstate.nn.LoRA(in_features=10, lora_rank=2,
693
693
  ... out_features=5, base_module=base)
694
694
  >>> y = lora_layer(x)
695
695
  >>> y.shape
brainstate/nn/_module.py CHANGED
@@ -363,22 +363,29 @@ class Sequential(Module):
363
363
  self.out_size = in_size
364
364
 
365
365
  def _format_module(self, module, in_size):
366
- if isinstance(module, ParamDescriber):
367
- if in_size is None:
368
- raise ValueError(
369
- 'The input size should be specified. '
370
- f'Please set the in_size attribute of the previous module: \n'
371
- f'{self.layers[-1]}'
372
- )
373
- module = module(in_size=in_size)
374
- assert isinstance(module, Module), 'The module should be an instance of Module.'
375
- out_size = module.out_size
376
- elif isinstance(module, ElementWiseBlock):
377
- out_size = in_size
378
- elif isinstance(module, Module):
379
- out_size = module.out_size
380
- elif callable(module):
381
- out_size = in_size
382
- else:
383
- raise TypeError(f"Unsupported type {type(module)}. ")
366
+ try:
367
+ if isinstance(module, ParamDescriber):
368
+ if in_size is None:
369
+ raise ValueError(
370
+ 'The input size should be specified. '
371
+ f'Please set the in_size attribute of the previous module: \n'
372
+ f'{self.layers[-1]}'
373
+ )
374
+ module = module(in_size=in_size)
375
+ assert isinstance(module, Module), 'The module should be an instance of Module.'
376
+ out_size = module.out_size
377
+ elif isinstance(module, ElementWiseBlock):
378
+ out_size = in_size
379
+ elif isinstance(module, Module):
380
+ out_size = module.out_size
381
+ elif callable(module):
382
+ out_size = in_size
383
+ else:
384
+ raise TypeError(f"Unsupported type {type(module)}. ")
385
+ except Exception as e:
386
+ raise BrainStateError(
387
+ f'Failed to format the module: \n'
388
+ f'{module}\n'
389
+ f'with input size: {in_size}\n'
390
+ ) from e
384
391
  return module, out_size
@@ -77,16 +77,16 @@ def weight_standardization(
77
77
  --------
78
78
  .. code-block:: python
79
79
 
80
- >>> import brainstate as bst
80
+ >>> import brainstate as brainstate
81
81
  >>> import jax.numpy as jnp
82
82
  >>>
83
83
  >>> # Standardize a weight matrix
84
84
  >>> w = jnp.ones((3, 4))
85
- >>> w_std = bst.nn.weight_standardization(w)
85
+ >>> w_std = brainstate.nn.weight_standardization(w)
86
86
  >>>
87
87
  >>> # With custom gain
88
88
  >>> gain = jnp.ones((4,))
89
- >>> w_std = bst.nn.weight_standardization(w, gain=gain)
89
+ >>> w_std = brainstate.nn.weight_standardization(w, gain=gain)
90
90
  """
91
91
  w = u.maybe_custom_array(w)
92
92
  if out_axis < 0:
@@ -551,11 +551,11 @@ class BatchNorm0d(_BatchNorm):
551
551
  --------
552
552
  .. code-block:: python
553
553
 
554
- >>> import brainstate as bst
554
+ >>> import brainstate as brainstate
555
555
  >>> import jax.numpy as jnp
556
556
  >>>
557
557
  >>> # Create a BatchNorm0d layer
558
- >>> layer = bst.nn.BatchNorm0d(in_size=(10,))
558
+ >>> layer = brainstate.nn.BatchNorm0d(in_size=(10,))
559
559
  >>>
560
560
  >>> # Apply normalization to a batch of data
561
561
  >>> x = jnp.ones((32, 10)) # batch_size=32, features=10
@@ -623,11 +623,11 @@ class BatchNorm1d(_BatchNorm):
623
623
  --------
624
624
  .. code-block:: python
625
625
 
626
- >>> import brainstate as bst
626
+ >>> import brainstate as brainstate
627
627
  >>> import jax.numpy as jnp
628
628
  >>>
629
629
  >>> # Create a BatchNorm1d layer for sequence data
630
- >>> layer = bst.nn.BatchNorm1d(in_size=(100, 64)) # length=100, channels=64
630
+ >>> layer = brainstate.nn.BatchNorm1d(in_size=(100, 64)) # length=100, channels=64
631
631
  >>>
632
632
  >>> # Apply normalization
633
633
  >>> x = jnp.ones((8, 100, 64)) # batch_size=8
@@ -693,11 +693,11 @@ class BatchNorm2d(_BatchNorm):
693
693
  --------
694
694
  .. code-block:: python
695
695
 
696
- >>> import brainstate as bst
696
+ >>> import brainstate as brainstate
697
697
  >>> import jax.numpy as jnp
698
698
  >>>
699
699
  >>> # Create a BatchNorm2d layer for image data
700
- >>> layer = bst.nn.BatchNorm2d(in_size=(28, 28, 3)) # 28x28 RGB images
700
+ >>> layer = brainstate.nn.BatchNorm2d(in_size=(28, 28, 3)) # 28x28 RGB images
701
701
  >>>
702
702
  >>> # Apply normalization
703
703
  >>> x = jnp.ones((16, 28, 28, 3)) # batch_size=16
@@ -763,11 +763,11 @@ class BatchNorm3d(_BatchNorm):
763
763
  --------
764
764
  .. code-block:: python
765
765
 
766
- >>> import brainstate as bst
766
+ >>> import brainstate as brainstate
767
767
  >>> import jax.numpy as jnp
768
768
  >>>
769
769
  >>> # Create a BatchNorm3d layer for volumetric data
770
- >>> layer = bst.nn.BatchNorm3d(in_size=(32, 32, 32, 1)) # 32x32x32 volumes
770
+ >>> layer = brainstate.nn.BatchNorm3d(in_size=(32, 32, 32, 1)) # 32x32x32 volumes
771
771
  >>>
772
772
  >>> # Apply normalization
773
773
  >>> x = jnp.ones((4, 32, 32, 32, 1)) # batch_size=4
@@ -841,11 +841,11 @@ class LayerNorm(Module):
841
841
  --------
842
842
  .. code-block:: python
843
843
 
844
- >>> import brainstate as bst
844
+ >>> import brainstate as brainstate
845
845
  >>>
846
846
  >>> # Create a LayerNorm layer
847
- >>> x = bst.random.normal(size=(3, 4, 5, 6))
848
- >>> layer = bst.nn.LayerNorm(x.shape)
847
+ >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
848
+ >>> layer = brainstate.nn.LayerNorm(x.shape)
849
849
  >>>
850
850
  >>> # Apply normalization
851
851
  >>> y = layer(x)
@@ -853,8 +853,8 @@ class LayerNorm(Module):
853
853
  (3, 4, 5, 6)
854
854
  >>>
855
855
  >>> # Normalize only the last dimension
856
- >>> layer = bst.nn.LayerNorm((10, 20), reduction_axes=-1, feature_axes=-1)
857
- >>> x = bst.random.normal((5, 10, 20))
856
+ >>> layer = brainstate.nn.LayerNorm((10, 20), reduction_axes=-1, feature_axes=-1)
857
+ >>> x = brainstate.random.normal((5, 10, 20))
858
858
  >>> y = layer(x)
859
859
  """
860
860
 
@@ -1005,11 +1005,11 @@ class RMSNorm(Module):
1005
1005
  --------
1006
1006
  .. code-block:: python
1007
1007
 
1008
- >>> import brainstate as bst
1008
+ >>> import brainstate as brainstate
1009
1009
  >>>
1010
1010
  >>> # Create an RMSNorm layer
1011
- >>> x = bst.random.normal(size=(5, 6))
1012
- >>> layer = bst.nn.RMSNorm(in_size=(6,))
1011
+ >>> x = brainstate.random.normal(size=(5, 6))
1012
+ >>> layer = brainstate.nn.RMSNorm(in_size=(6,))
1013
1013
  >>>
1014
1014
  >>> # Apply normalization
1015
1015
  >>> y = layer(x)
@@ -1017,8 +1017,8 @@ class RMSNorm(Module):
1017
1017
  (5, 6)
1018
1018
  >>>
1019
1019
  >>> # Without scaling
1020
- >>> layer = bst.nn.RMSNorm(in_size=(10,), use_scale=False)
1021
- >>> x = bst.random.normal((3, 10))
1020
+ >>> layer = brainstate.nn.RMSNorm(in_size=(10,), use_scale=False)
1021
+ >>> x = brainstate.random.normal((3, 10))
1022
1022
  >>> y = layer(x)
1023
1023
  """
1024
1024
 
@@ -1179,20 +1179,20 @@ class GroupNorm(Module):
1179
1179
  .. code-block:: python
1180
1180
 
1181
1181
  >>> import numpy as np
1182
- >>> import brainstate as bst
1182
+ >>> import brainstate as brainstate
1183
1183
  >>>
1184
1184
  >>> # Create a GroupNorm layer with 3 groups
1185
- >>> x = bst.random.normal(size=(3, 4, 5, 6))
1186
- >>> layer = bst.nn.GroupNorm(x.shape, num_groups=3)
1185
+ >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
1186
+ >>> layer = brainstate.nn.GroupNorm(x.shape, num_groups=3)
1187
1187
  >>> y = layer(x)
1188
1188
  >>>
1189
1189
  >>> # GroupNorm with num_groups=1 is equivalent to LayerNorm
1190
- >>> y1 = bst.nn.GroupNorm(x.shape, num_groups=1)(x)
1191
- >>> y2 = bst.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
1190
+ >>> y1 = brainstate.nn.GroupNorm(x.shape, num_groups=1)(x)
1191
+ >>> y2 = brainstate.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
1192
1192
  >>> np.testing.assert_allclose(y1, y2, rtol=1e-5)
1193
1193
  >>>
1194
1194
  >>> # Specify group_size instead of num_groups
1195
- >>> layer = bst.nn.GroupNorm((12,), num_groups=None, group_size=4)
1195
+ >>> layer = brainstate.nn.GroupNorm((12,), num_groups=None, group_size=4)
1196
1196
  """
1197
1197
 
1198
1198
  def __init__(
@@ -259,12 +259,12 @@ References
259
259
 
260
260
  """
261
261
 
262
- from ._rand_funs import *
263
- from ._rand_funs import __all__ as __all_random__
264
- from ._rand_seed import *
265
- from ._rand_seed import __all__ as __all_seed__
266
- from ._rand_state import *
267
- from ._rand_state import __all__ as __all_state__
262
+ from ._fun import *
263
+ from ._fun import __all__ as __all_random__
264
+ from ._seed import *
265
+ from ._seed import __all__ as __all_seed__
266
+ from ._state import *
267
+ from ._state import __all__ as __all_state__
268
268
 
269
269
  __all__ = __all_random__ + __all_state__ + __all_seed__
270
270
  del __all_random__, __all_state__, __all_seed__
@@ -21,7 +21,7 @@ from typing import Optional
21
21
  import numpy as np
22
22
 
23
23
  from brainstate.typing import DTypeLike, Size, SeedOrKey
24
- from ._rand_state import RandomState, DEFAULT
24
+ from ._state import RandomState, DEFAULT
25
25
 
26
26
  __all__ = [
27
27
  # numpy compatibility
@@ -437,13 +437,11 @@ class TestRandom(unittest.TestCase):
437
437
  a = brainstate.random.hypergeometric(10, 10, 10, 20)
438
438
  self.assertTupleEqual(a.shape, (20,))
439
439
 
440
- @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
441
440
  def test_hypergeometric2(self):
442
441
  brainstate.random.seed()
443
442
  a = brainstate.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]])
444
443
  self.assertTupleEqual(a.shape, (2, 2))
445
444
 
446
- @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
447
445
  def test_hypergeometric3(self):
448
446
  brainstate.random.seed()
449
447
  a = brainstate.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2))