brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +2 -4
- brainstate/_deprecation_test.py +2 -24
- brainstate/_state.py +540 -35
- brainstate/_state_test.py +1085 -8
- brainstate/graph/_operation.py +1 -5
- brainstate/mixin.py +14 -0
- brainstate/nn/__init__.py +42 -33
- brainstate/nn/_collective_ops.py +2 -0
- brainstate/nn/_common_test.py +0 -20
- brainstate/nn/_delay.py +1 -1
- brainstate/nn/_dropout_test.py +9 -6
- brainstate/nn/_dynamics.py +67 -464
- brainstate/nn/_dynamics_test.py +0 -14
- brainstate/nn/_embedding.py +7 -7
- brainstate/nn/_exp_euler.py +9 -9
- brainstate/nn/_linear.py +21 -21
- brainstate/nn/_module.py +25 -18
- brainstate/nn/_normalizations.py +27 -27
- brainstate/random/__init__.py +6 -6
- brainstate/random/{_rand_funs.py → _fun.py} +1 -1
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +1 -1
- brainstate/random/{_rand_state.py → _state.py} +121 -418
- brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
- brainstate/transform/__init__.py +6 -9
- brainstate/transform/_conditions.py +2 -2
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_make_jaxpr.py +221 -61
- brainstate/transform/_make_jaxpr_test.py +125 -1
- brainstate/transform/_mapping.py +287 -209
- brainstate/transform/_mapping_test.py +94 -184
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_dynamics_test.py
CHANGED
@@ -42,20 +42,6 @@ class TestDynamics(unittest.TestCase):
|
|
42
42
|
with self.assertRaises(ValueError):
|
43
43
|
brainstate.nn.Dynamics(in_size="invalid")
|
44
44
|
|
45
|
-
def test_input_handling(self):
|
46
|
-
dyn = brainstate.nn.Dynamics(in_size=10)
|
47
|
-
dyn.add_current_input("test_current", lambda: np.random.rand(10))
|
48
|
-
dyn.add_delta_input("test_delta", lambda: np.random.rand(10))
|
49
|
-
|
50
|
-
self.assertIn("test_current", dyn.current_inputs)
|
51
|
-
self.assertIn("test_delta", dyn.delta_inputs)
|
52
|
-
|
53
|
-
def test_duplicate_input_key(self):
|
54
|
-
dyn = brainstate.nn.Dynamics(in_size=10)
|
55
|
-
dyn.add_current_input("test", lambda: np.random.rand(10))
|
56
|
-
with self.assertRaises(ValueError):
|
57
|
-
dyn.add_current_input("test", lambda: np.random.rand(10))
|
58
|
-
|
59
45
|
def test_varshape(self):
|
60
46
|
dyn = brainstate.nn.Dynamics(in_size=(2, 3))
|
61
47
|
self.assertEqual(dyn.varshape, (2, 3))
|
brainstate/nn/_embedding.py
CHANGED
@@ -161,8 +161,8 @@ class Embedding(Module):
|
|
161
161
|
|
162
162
|
.. code-block:: python
|
163
163
|
|
164
|
-
>>> import brainstate as
|
165
|
-
>>> embedding =
|
164
|
+
>>> import brainstate as brainstate
|
165
|
+
>>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3)
|
166
166
|
>>> embedding.weight.value.shape
|
167
167
|
(10, 3)
|
168
168
|
|
@@ -191,7 +191,7 @@ class Embedding(Module):
|
|
191
191
|
|
192
192
|
.. code-block:: python
|
193
193
|
|
194
|
-
>>> embedding =
|
194
|
+
>>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3, padding_idx=0)
|
195
195
|
>>> # The embedding at index 0 will remain zeros and not be updated during training
|
196
196
|
>>> indices = jnp.array([0, 2, 0, 5])
|
197
197
|
>>> output = embedding(indices)
|
@@ -202,7 +202,7 @@ class Embedding(Module):
|
|
202
202
|
|
203
203
|
.. code-block:: python
|
204
204
|
|
205
|
-
>>> embedding =
|
205
|
+
>>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3, max_norm=1.0)
|
206
206
|
>>> # All embeddings accessed in a forward pass are renormalized to have norm <= 1.0
|
207
207
|
|
208
208
|
Load pretrained embeddings:
|
@@ -214,7 +214,7 @@ class Embedding(Module):
|
|
214
214
|
>>> pretrained = jnp.array([[1.0, 2.0, 3.0],
|
215
215
|
... [4.0, 5.0, 6.0],
|
216
216
|
... [7.0, 8.0, 9.0]])
|
217
|
-
>>> embedding =
|
217
|
+
>>> embedding = brainstate.nn.Embedding.from_pretrained(pretrained, param_type=brainstate.FakeState)
|
218
218
|
>>> embedding.weight.value.shape
|
219
219
|
(3, 3)
|
220
220
|
"""
|
@@ -310,11 +310,11 @@ class Embedding(Module):
|
|
310
310
|
.. code-block:: python
|
311
311
|
|
312
312
|
>>> import jax.numpy as jnp
|
313
|
-
>>> import brainstate as
|
313
|
+
>>> import brainstate as brainstate
|
314
314
|
>>> pretrained = jnp.array([[1.0, 2.0, 3.0],
|
315
315
|
... [4.0, 5.0, 6.0],
|
316
316
|
... [7.0, 8.0, 9.0]])
|
317
|
-
>>> embedding =
|
317
|
+
>>> embedding = brainstate.nn.Embedding.from_pretrained(pretrained)
|
318
318
|
>>> embedding.weight.value.shape
|
319
319
|
(3, 3)
|
320
320
|
>>> indices = jnp.array([1])
|
brainstate/nn/_exp_euler.py
CHANGED
@@ -112,11 +112,11 @@ def exp_euler_step(
|
|
112
112
|
|
113
113
|
.. code-block:: python
|
114
114
|
|
115
|
-
>>> import brainstate as
|
115
|
+
>>> import brainstate as brainstate
|
116
116
|
>>> import jax.numpy as jnp
|
117
117
|
>>>
|
118
118
|
>>> # Set time step in environment
|
119
|
-
>>>
|
119
|
+
>>> brainstate.environ.set(dt=0.01)
|
120
120
|
>>>
|
121
121
|
>>> # Define drift function
|
122
122
|
>>> def drift(x, t):
|
@@ -126,7 +126,7 @@ def exp_euler_step(
|
|
126
126
|
>>> x0 = jnp.array(1.0)
|
127
127
|
>>>
|
128
128
|
>>> # Single integration step
|
129
|
-
>>> x1 =
|
129
|
+
>>> x1 = brainstate.nn.exp_euler_step(drift, x0, None)
|
130
130
|
>>> print(x1) # Should be close to exp(-0.01) ≈ 0.99
|
131
131
|
|
132
132
|
**SDE Integration:**
|
@@ -135,11 +135,11 @@ def exp_euler_step(
|
|
135
135
|
|
136
136
|
.. code-block:: python
|
137
137
|
|
138
|
-
>>> import brainstate as
|
138
|
+
>>> import brainstate as brainstate
|
139
139
|
>>> import jax.numpy as jnp
|
140
140
|
>>>
|
141
141
|
>>> # Set time step
|
142
|
-
>>>
|
142
|
+
>>> brainstate.environ.set(dt=0.01)
|
143
143
|
>>>
|
144
144
|
>>> # Define drift and diffusion
|
145
145
|
>>> theta = 0.5
|
@@ -155,16 +155,16 @@ def exp_euler_step(
|
|
155
155
|
>>> x0 = jnp.array(1.0)
|
156
156
|
>>>
|
157
157
|
>>> # Single SDE integration step
|
158
|
-
>>> x1 =
|
158
|
+
>>> x1 = brainstate.nn.exp_euler_step(drift, diffusion, x0, None)
|
159
159
|
|
160
160
|
**Multi-dimensional system:**
|
161
161
|
|
162
162
|
.. code-block:: python
|
163
163
|
|
164
|
-
>>> import brainstate as
|
164
|
+
>>> import brainstate as brainstate
|
165
165
|
>>> import jax.numpy as jnp
|
166
166
|
>>>
|
167
|
-
>>>
|
167
|
+
>>> brainstate.environ.set(dt=0.01)
|
168
168
|
>>>
|
169
169
|
>>> # Coupled oscillator system
|
170
170
|
>>> def drift(x, t):
|
@@ -172,7 +172,7 @@ def exp_euler_step(
|
|
172
172
|
... return jnp.array([-x1 + x2, -x2 - x1])
|
173
173
|
>>>
|
174
174
|
>>> x0 = jnp.array([1.0, 0.0])
|
175
|
-
>>> x1 =
|
175
|
+
>>> x1 = brainstate.nn.exp_euler_step(drift, x0, None)
|
176
176
|
|
177
177
|
See Also
|
178
178
|
--------
|
brainstate/nn/_linear.py
CHANGED
@@ -76,18 +76,18 @@ class Linear(Module):
|
|
76
76
|
--------
|
77
77
|
.. code-block:: python
|
78
78
|
|
79
|
-
>>> import brainstate as
|
79
|
+
>>> import brainstate as brainstate
|
80
80
|
>>> import jax.numpy as jnp
|
81
81
|
>>>
|
82
82
|
>>> # Create a linear layer
|
83
|
-
>>> layer =
|
83
|
+
>>> layer = brainstate.nn.Linear((10,), (5,))
|
84
84
|
>>> x = jnp.ones((32, 10))
|
85
85
|
>>> y = layer(x)
|
86
86
|
>>> y.shape
|
87
87
|
(32, 5)
|
88
88
|
>>>
|
89
89
|
>>> # Linear layer without bias
|
90
|
-
>>> layer =
|
90
|
+
>>> layer = brainstate.nn.Linear((10,), (5,), b_init=None)
|
91
91
|
>>> y = layer(x)
|
92
92
|
>>> y.shape
|
93
93
|
(32, 5)
|
@@ -171,11 +171,11 @@ class SignedWLinear(Module):
|
|
171
171
|
--------
|
172
172
|
.. code-block:: python
|
173
173
|
|
174
|
-
>>> import brainstate as
|
174
|
+
>>> import brainstate as brainstate
|
175
175
|
>>> import jax.numpy as jnp
|
176
176
|
>>>
|
177
177
|
>>> # Create a signed weight linear layer with all positive weights
|
178
|
-
>>> layer =
|
178
|
+
>>> layer = brainstate.nn.SignedWLinear((10,), (5,))
|
179
179
|
>>> x = jnp.ones((32, 10))
|
180
180
|
>>> y = layer(x)
|
181
181
|
>>> y.shape
|
@@ -183,7 +183,7 @@ class SignedWLinear(Module):
|
|
183
183
|
>>>
|
184
184
|
>>> # With custom sign matrix (e.g., inhibitory connections)
|
185
185
|
>>> w_sign = jnp.ones((10, 5)) * -1.0 # all negative
|
186
|
-
>>> layer =
|
186
|
+
>>> layer = brainstate.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
|
187
187
|
>>> y = layer(x)
|
188
188
|
>>> y.shape
|
189
189
|
(32, 5)
|
@@ -274,18 +274,18 @@ class ScaledWSLinear(Module):
|
|
274
274
|
--------
|
275
275
|
.. code-block:: python
|
276
276
|
|
277
|
-
>>> import brainstate as
|
277
|
+
>>> import brainstate as brainstate
|
278
278
|
>>> import jax.numpy as jnp
|
279
279
|
>>>
|
280
280
|
>>> # Create a weight-standardized linear layer
|
281
|
-
>>> layer =
|
281
|
+
>>> layer = brainstate.nn.ScaledWSLinear((10,), (5,))
|
282
282
|
>>> x = jnp.ones((32, 10))
|
283
283
|
>>> y = layer(x)
|
284
284
|
>>> y.shape
|
285
285
|
(32, 5)
|
286
286
|
>>>
|
287
287
|
>>> # Without learnable gain
|
288
|
-
>>> layer =
|
288
|
+
>>> layer = brainstate.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
|
289
289
|
>>> y = layer(x)
|
290
290
|
>>> y.shape
|
291
291
|
(32, 5)
|
@@ -375,7 +375,7 @@ class SparseLinear(Module):
|
|
375
375
|
--------
|
376
376
|
.. code-block:: python
|
377
377
|
|
378
|
-
>>> import brainstate as
|
378
|
+
>>> import brainstate as brainstate
|
379
379
|
>>> import brainunit as u
|
380
380
|
>>> import jax.numpy as jnp
|
381
381
|
>>>
|
@@ -384,7 +384,7 @@ class SparseLinear(Module):
|
|
384
384
|
>>> values = jnp.array([1.0, 2.0, 3.0])
|
385
385
|
>>> spar_mat = u.sparse.CSR((values, indices[:, 1], indices[:, 0]),
|
386
386
|
... shape=(3, 3))
|
387
|
-
>>> layer =
|
387
|
+
>>> layer = brainstate.nn.SparseLinear(spar_mat, in_size=(3,))
|
388
388
|
>>> x = jnp.ones((5, 3))
|
389
389
|
>>> y = layer(x)
|
390
390
|
>>> y.shape
|
@@ -468,18 +468,18 @@ class AllToAll(Module):
|
|
468
468
|
--------
|
469
469
|
.. code-block:: python
|
470
470
|
|
471
|
-
>>> import brainstate as
|
471
|
+
>>> import brainstate as brainstate
|
472
472
|
>>> import jax.numpy as jnp
|
473
473
|
>>>
|
474
474
|
>>> # All-to-all with self-connections
|
475
|
-
>>> layer =
|
475
|
+
>>> layer = brainstate.nn.AllToAll((10,), (10,), include_self=True)
|
476
476
|
>>> x = jnp.ones((32, 10))
|
477
477
|
>>> y = layer(x)
|
478
478
|
>>> y.shape
|
479
479
|
(32, 10)
|
480
480
|
>>>
|
481
481
|
>>> # All-to-all without self-connections (recurrent layer)
|
482
|
-
>>> layer =
|
482
|
+
>>> layer = brainstate.nn.AllToAll((10,), (10,), include_self=False)
|
483
483
|
>>> y = layer(x)
|
484
484
|
>>> y.shape
|
485
485
|
(32, 10)
|
@@ -584,18 +584,18 @@ class OneToOne(Module):
|
|
584
584
|
--------
|
585
585
|
.. code-block:: python
|
586
586
|
|
587
|
-
>>> import brainstate as
|
587
|
+
>>> import brainstate as brainstate
|
588
588
|
>>> import jax.numpy as jnp
|
589
589
|
>>>
|
590
590
|
>>> # One-to-one connection
|
591
|
-
>>> layer =
|
591
|
+
>>> layer = brainstate.nn.OneToOne((10,))
|
592
592
|
>>> x = jnp.ones((32, 10))
|
593
593
|
>>> y = layer(x)
|
594
594
|
>>> y.shape
|
595
595
|
(32, 10)
|
596
596
|
>>>
|
597
597
|
>>> # With bias
|
598
|
-
>>> layer =
|
598
|
+
>>> layer = brainstate.nn.OneToOne((10,), b_init=brainstate.init.Constant(0.1))
|
599
599
|
>>> y = layer(x)
|
600
600
|
>>> y.shape
|
601
601
|
(32, 10)
|
@@ -677,19 +677,19 @@ class LoRA(Module):
|
|
677
677
|
--------
|
678
678
|
.. code-block:: python
|
679
679
|
|
680
|
-
>>> import brainstate as
|
680
|
+
>>> import brainstate as brainstate
|
681
681
|
>>> import jax.numpy as jnp
|
682
682
|
>>>
|
683
683
|
>>> # Standalone LoRA layer
|
684
|
-
>>> layer =
|
684
|
+
>>> layer = brainstate.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
|
685
685
|
>>> x = jnp.ones((32, 10))
|
686
686
|
>>> y = layer(x)
|
687
687
|
>>> y.shape
|
688
688
|
(32, 5)
|
689
689
|
>>>
|
690
690
|
>>> # Wrap around existing linear layer
|
691
|
-
>>> base =
|
692
|
-
>>> lora_layer =
|
691
|
+
>>> base = brainstate.nn.Linear((10,), (5,))
|
692
|
+
>>> lora_layer = brainstate.nn.LoRA(in_features=10, lora_rank=2,
|
693
693
|
... out_features=5, base_module=base)
|
694
694
|
>>> y = lora_layer(x)
|
695
695
|
>>> y.shape
|
brainstate/nn/_module.py
CHANGED
@@ -363,22 +363,29 @@ class Sequential(Module):
|
|
363
363
|
self.out_size = in_size
|
364
364
|
|
365
365
|
def _format_module(self, module, in_size):
|
366
|
-
|
367
|
-
if
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
366
|
+
try:
|
367
|
+
if isinstance(module, ParamDescriber):
|
368
|
+
if in_size is None:
|
369
|
+
raise ValueError(
|
370
|
+
'The input size should be specified. '
|
371
|
+
f'Please set the in_size attribute of the previous module: \n'
|
372
|
+
f'{self.layers[-1]}'
|
373
|
+
)
|
374
|
+
module = module(in_size=in_size)
|
375
|
+
assert isinstance(module, Module), 'The module should be an instance of Module.'
|
376
|
+
out_size = module.out_size
|
377
|
+
elif isinstance(module, ElementWiseBlock):
|
378
|
+
out_size = in_size
|
379
|
+
elif isinstance(module, Module):
|
380
|
+
out_size = module.out_size
|
381
|
+
elif callable(module):
|
382
|
+
out_size = in_size
|
383
|
+
else:
|
384
|
+
raise TypeError(f"Unsupported type {type(module)}. ")
|
385
|
+
except Exception as e:
|
386
|
+
raise BrainStateError(
|
387
|
+
f'Failed to format the module: \n'
|
388
|
+
f'{module}\n'
|
389
|
+
f'with input size: {in_size}\n'
|
390
|
+
) from e
|
384
391
|
return module, out_size
|
brainstate/nn/_normalizations.py
CHANGED
@@ -77,16 +77,16 @@ def weight_standardization(
|
|
77
77
|
--------
|
78
78
|
.. code-block:: python
|
79
79
|
|
80
|
-
>>> import brainstate as
|
80
|
+
>>> import brainstate as brainstate
|
81
81
|
>>> import jax.numpy as jnp
|
82
82
|
>>>
|
83
83
|
>>> # Standardize a weight matrix
|
84
84
|
>>> w = jnp.ones((3, 4))
|
85
|
-
>>> w_std =
|
85
|
+
>>> w_std = brainstate.nn.weight_standardization(w)
|
86
86
|
>>>
|
87
87
|
>>> # With custom gain
|
88
88
|
>>> gain = jnp.ones((4,))
|
89
|
-
>>> w_std =
|
89
|
+
>>> w_std = brainstate.nn.weight_standardization(w, gain=gain)
|
90
90
|
"""
|
91
91
|
w = u.maybe_custom_array(w)
|
92
92
|
if out_axis < 0:
|
@@ -551,11 +551,11 @@ class BatchNorm0d(_BatchNorm):
|
|
551
551
|
--------
|
552
552
|
.. code-block:: python
|
553
553
|
|
554
|
-
>>> import brainstate as
|
554
|
+
>>> import brainstate as brainstate
|
555
555
|
>>> import jax.numpy as jnp
|
556
556
|
>>>
|
557
557
|
>>> # Create a BatchNorm0d layer
|
558
|
-
>>> layer =
|
558
|
+
>>> layer = brainstate.nn.BatchNorm0d(in_size=(10,))
|
559
559
|
>>>
|
560
560
|
>>> # Apply normalization to a batch of data
|
561
561
|
>>> x = jnp.ones((32, 10)) # batch_size=32, features=10
|
@@ -623,11 +623,11 @@ class BatchNorm1d(_BatchNorm):
|
|
623
623
|
--------
|
624
624
|
.. code-block:: python
|
625
625
|
|
626
|
-
>>> import brainstate as
|
626
|
+
>>> import brainstate as brainstate
|
627
627
|
>>> import jax.numpy as jnp
|
628
628
|
>>>
|
629
629
|
>>> # Create a BatchNorm1d layer for sequence data
|
630
|
-
>>> layer =
|
630
|
+
>>> layer = brainstate.nn.BatchNorm1d(in_size=(100, 64)) # length=100, channels=64
|
631
631
|
>>>
|
632
632
|
>>> # Apply normalization
|
633
633
|
>>> x = jnp.ones((8, 100, 64)) # batch_size=8
|
@@ -693,11 +693,11 @@ class BatchNorm2d(_BatchNorm):
|
|
693
693
|
--------
|
694
694
|
.. code-block:: python
|
695
695
|
|
696
|
-
>>> import brainstate as
|
696
|
+
>>> import brainstate as brainstate
|
697
697
|
>>> import jax.numpy as jnp
|
698
698
|
>>>
|
699
699
|
>>> # Create a BatchNorm2d layer for image data
|
700
|
-
>>> layer =
|
700
|
+
>>> layer = brainstate.nn.BatchNorm2d(in_size=(28, 28, 3)) # 28x28 RGB images
|
701
701
|
>>>
|
702
702
|
>>> # Apply normalization
|
703
703
|
>>> x = jnp.ones((16, 28, 28, 3)) # batch_size=16
|
@@ -763,11 +763,11 @@ class BatchNorm3d(_BatchNorm):
|
|
763
763
|
--------
|
764
764
|
.. code-block:: python
|
765
765
|
|
766
|
-
>>> import brainstate as
|
766
|
+
>>> import brainstate as brainstate
|
767
767
|
>>> import jax.numpy as jnp
|
768
768
|
>>>
|
769
769
|
>>> # Create a BatchNorm3d layer for volumetric data
|
770
|
-
>>> layer =
|
770
|
+
>>> layer = brainstate.nn.BatchNorm3d(in_size=(32, 32, 32, 1)) # 32x32x32 volumes
|
771
771
|
>>>
|
772
772
|
>>> # Apply normalization
|
773
773
|
>>> x = jnp.ones((4, 32, 32, 32, 1)) # batch_size=4
|
@@ -841,11 +841,11 @@ class LayerNorm(Module):
|
|
841
841
|
--------
|
842
842
|
.. code-block:: python
|
843
843
|
|
844
|
-
>>> import brainstate as
|
844
|
+
>>> import brainstate as brainstate
|
845
845
|
>>>
|
846
846
|
>>> # Create a LayerNorm layer
|
847
|
-
>>> x =
|
848
|
-
>>> layer =
|
847
|
+
>>> x = brainstate.random.normal(size=(3, 4, 5, 6))
|
848
|
+
>>> layer = brainstate.nn.LayerNorm(x.shape)
|
849
849
|
>>>
|
850
850
|
>>> # Apply normalization
|
851
851
|
>>> y = layer(x)
|
@@ -853,8 +853,8 @@ class LayerNorm(Module):
|
|
853
853
|
(3, 4, 5, 6)
|
854
854
|
>>>
|
855
855
|
>>> # Normalize only the last dimension
|
856
|
-
>>> layer =
|
857
|
-
>>> x =
|
856
|
+
>>> layer = brainstate.nn.LayerNorm((10, 20), reduction_axes=-1, feature_axes=-1)
|
857
|
+
>>> x = brainstate.random.normal((5, 10, 20))
|
858
858
|
>>> y = layer(x)
|
859
859
|
"""
|
860
860
|
|
@@ -1005,11 +1005,11 @@ class RMSNorm(Module):
|
|
1005
1005
|
--------
|
1006
1006
|
.. code-block:: python
|
1007
1007
|
|
1008
|
-
>>> import brainstate as
|
1008
|
+
>>> import brainstate as brainstate
|
1009
1009
|
>>>
|
1010
1010
|
>>> # Create an RMSNorm layer
|
1011
|
-
>>> x =
|
1012
|
-
>>> layer =
|
1011
|
+
>>> x = brainstate.random.normal(size=(5, 6))
|
1012
|
+
>>> layer = brainstate.nn.RMSNorm(in_size=(6,))
|
1013
1013
|
>>>
|
1014
1014
|
>>> # Apply normalization
|
1015
1015
|
>>> y = layer(x)
|
@@ -1017,8 +1017,8 @@ class RMSNorm(Module):
|
|
1017
1017
|
(5, 6)
|
1018
1018
|
>>>
|
1019
1019
|
>>> # Without scaling
|
1020
|
-
>>> layer =
|
1021
|
-
>>> x =
|
1020
|
+
>>> layer = brainstate.nn.RMSNorm(in_size=(10,), use_scale=False)
|
1021
|
+
>>> x = brainstate.random.normal((3, 10))
|
1022
1022
|
>>> y = layer(x)
|
1023
1023
|
"""
|
1024
1024
|
|
@@ -1179,20 +1179,20 @@ class GroupNorm(Module):
|
|
1179
1179
|
.. code-block:: python
|
1180
1180
|
|
1181
1181
|
>>> import numpy as np
|
1182
|
-
>>> import brainstate as
|
1182
|
+
>>> import brainstate as brainstate
|
1183
1183
|
>>>
|
1184
1184
|
>>> # Create a GroupNorm layer with 3 groups
|
1185
|
-
>>> x =
|
1186
|
-
>>> layer =
|
1185
|
+
>>> x = brainstate.random.normal(size=(3, 4, 5, 6))
|
1186
|
+
>>> layer = brainstate.nn.GroupNorm(x.shape, num_groups=3)
|
1187
1187
|
>>> y = layer(x)
|
1188
1188
|
>>>
|
1189
1189
|
>>> # GroupNorm with num_groups=1 is equivalent to LayerNorm
|
1190
|
-
>>> y1 =
|
1191
|
-
>>> y2 =
|
1190
|
+
>>> y1 = brainstate.nn.GroupNorm(x.shape, num_groups=1)(x)
|
1191
|
+
>>> y2 = brainstate.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
|
1192
1192
|
>>> np.testing.assert_allclose(y1, y2, rtol=1e-5)
|
1193
1193
|
>>>
|
1194
1194
|
>>> # Specify group_size instead of num_groups
|
1195
|
-
>>> layer =
|
1195
|
+
>>> layer = brainstate.nn.GroupNorm((12,), num_groups=None, group_size=4)
|
1196
1196
|
"""
|
1197
1197
|
|
1198
1198
|
def __init__(
|
brainstate/random/__init__.py
CHANGED
@@ -259,12 +259,12 @@ References
|
|
259
259
|
|
260
260
|
"""
|
261
261
|
|
262
|
-
from .
|
263
|
-
from .
|
264
|
-
from .
|
265
|
-
from .
|
266
|
-
from .
|
267
|
-
from .
|
262
|
+
from ._fun import *
|
263
|
+
from ._fun import __all__ as __all_random__
|
264
|
+
from ._seed import *
|
265
|
+
from ._seed import __all__ as __all_seed__
|
266
|
+
from ._state import *
|
267
|
+
from ._state import __all__ as __all_state__
|
268
268
|
|
269
269
|
__all__ = __all_random__ + __all_state__ + __all_seed__
|
270
270
|
del __all_random__, __all_state__, __all_seed__
|
@@ -437,13 +437,11 @@ class TestRandom(unittest.TestCase):
|
|
437
437
|
a = brainstate.random.hypergeometric(10, 10, 10, 20)
|
438
438
|
self.assertTupleEqual(a.shape, (20,))
|
439
439
|
|
440
|
-
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
|
441
440
|
def test_hypergeometric2(self):
|
442
441
|
brainstate.random.seed()
|
443
442
|
a = brainstate.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]])
|
444
443
|
self.assertTupleEqual(a.shape, (2, 2))
|
445
444
|
|
446
|
-
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
|
447
445
|
def test_hypergeometric3(self):
|
448
446
|
brainstate.random.seed()
|
449
447
|
a = brainstate.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2))
|