brainstate 0.1.4__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_state.py +6 -5
- brainstate/augment/_autograd.py +31 -12
- brainstate/augment/_autograd_test.py +46 -46
- brainstate/augment/_eval_shape.py +4 -4
- brainstate/augment/_mapping.py +22 -17
- brainstate/augment/_mapping_test.py +162 -0
- brainstate/compile/_conditions.py +2 -2
- brainstate/compile/_make_jaxpr.py +59 -6
- brainstate/compile/_progress_bar.py +2 -2
- brainstate/environ.py +19 -19
- brainstate/functional/_activations_test.py +12 -12
- brainstate/graph/_graph_operation.py +69 -69
- brainstate/graph/_graph_operation_test.py +2 -2
- brainstate/mixin.py +0 -17
- brainstate/nn/_collective_ops.py +4 -4
- brainstate/nn/_common.py +7 -19
- brainstate/nn/_dropout_test.py +2 -2
- brainstate/nn/_dynamics.py +53 -35
- brainstate/nn/_elementwise.py +30 -30
- brainstate/nn/_exp_euler.py +13 -16
- brainstate/nn/_inputs.py +1 -1
- brainstate/nn/_linear.py +4 -4
- brainstate/nn/_module.py +6 -6
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +11 -11
- brainstate/nn/_normalizations_test.py +6 -6
- brainstate/nn/_poolings.py +24 -24
- brainstate/nn/_synapse.py +1 -12
- brainstate/nn/_utils.py +1 -1
- brainstate/nn/metrics.py +4 -4
- brainstate/optim/_optax_optimizer.py +8 -8
- brainstate/random/_rand_funs.py +37 -37
- brainstate/random/_rand_funs_test.py +3 -3
- brainstate/random/_rand_seed.py +7 -7
- brainstate/random/_rand_state.py +13 -7
- brainstate/surrogate.py +40 -40
- brainstate/util/pretty_pytree.py +10 -10
- brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
- brainstate/util/struct.py +7 -7
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/RECORD +45 -45
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
brainstate/nn/metrics.py
CHANGED
@@ -61,12 +61,12 @@ class Average(Metric):
|
|
61
61
|
Example usage::
|
62
62
|
|
63
63
|
>>> import jax.numpy as jnp
|
64
|
-
>>> import brainstate as
|
64
|
+
>>> import brainstate as brainstate
|
65
65
|
|
66
66
|
>>> batch_loss = jnp.array([1, 2, 3, 4])
|
67
67
|
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
|
68
68
|
|
69
|
-
>>> metrics =
|
69
|
+
>>> metrics = brainstate.nn.metrics.Average()
|
70
70
|
>>> metrics.compute()
|
71
71
|
Array(nan, dtype=float32)
|
72
72
|
>>> metrics.update(values=batch_loss)
|
@@ -223,7 +223,7 @@ class Accuracy(Average):
|
|
223
223
|
|
224
224
|
Example usage::
|
225
225
|
|
226
|
-
>>> import brainstate as
|
226
|
+
>>> import brainstate as brainstate
|
227
227
|
>>> import jax, jax.numpy as jnp
|
228
228
|
|
229
229
|
>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
|
@@ -231,7 +231,7 @@ class Accuracy(Average):
|
|
231
231
|
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
|
232
232
|
>>> labels2 = jnp.array([0, 1, 1, 1, 1])
|
233
233
|
|
234
|
-
>>> metrics =
|
234
|
+
>>> metrics = brainstate.nn.metrics.Accuracy()
|
235
235
|
>>> metrics.compute()
|
236
236
|
Array(nan, dtype=float32)
|
237
237
|
>>> metrics.update(logits=logits, labels=labels)
|
@@ -36,29 +36,29 @@ class OptaxOptimizer(Optimizer):
|
|
36
36
|
|
37
37
|
>>> import jax
|
38
38
|
>>> import jax.numpy as jnp
|
39
|
-
>>> import brainstate as
|
39
|
+
>>> import brainstate as brainstate
|
40
40
|
>>> import optax
|
41
41
|
...
|
42
|
-
>>> class Model(
|
42
|
+
>>> class Model(brainstate.nn.Module):
|
43
43
|
... def __init__(self):
|
44
44
|
... super().__init__()
|
45
|
-
... self.linear1 =
|
46
|
-
... self.linear2 =
|
45
|
+
... self.linear1 = brainstate.nn.Linear(2, 3)
|
46
|
+
... self.linear2 = brainstate.nn.Linear(3, 4)
|
47
47
|
... def __call__(self, x):
|
48
48
|
... return self.linear2(self.linear1(x))
|
49
49
|
...
|
50
|
-
>>> x =
|
50
|
+
>>> x = brainstate.random.randn(1, 2)
|
51
51
|
>>> y = jnp.ones((1, 4))
|
52
52
|
...
|
53
53
|
>>> model = Model()
|
54
54
|
>>> tx = optax.adam(1e-3)
|
55
|
-
>>> optimizer =
|
56
|
-
>>> optimizer.register_trainable_weights(model.states(
|
55
|
+
>>> optimizer = brainstate.optim.OptaxOptimizer(tx)
|
56
|
+
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
|
57
57
|
...
|
58
58
|
>>> loss_fn = lambda: ((model(x) - y) ** 2).mean()
|
59
59
|
>>> loss_fn()
|
60
60
|
Array(1.7055722, dtype=float32)
|
61
|
-
>>> grads =
|
61
|
+
>>> grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
|
62
62
|
>>> optimizer.update(grads)
|
63
63
|
>>> loss_fn()
|
64
64
|
Array(1.6925814, dtype=float32)
|
brainstate/random/_rand_funs.py
CHANGED
@@ -78,8 +78,8 @@ def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
|
78
78
|
|
79
79
|
Examples
|
80
80
|
--------
|
81
|
-
>>> import brainstate as
|
82
|
-
>>>
|
81
|
+
>>> import brainstate as brainstate
|
82
|
+
>>> brainstate.random.rand(3,2)
|
83
83
|
array([[ 0.14022471, 0.96360618], #random
|
84
84
|
[ 0.37601032, 0.25528411], #random
|
85
85
|
[ 0.49313049, 0.94909878]]) #random
|
@@ -135,31 +135,31 @@ def randint(low, high=None, size: Optional[Size] = None,
|
|
135
135
|
|
136
136
|
Examples
|
137
137
|
--------
|
138
|
-
>>> import brainstate as
|
139
|
-
>>>
|
138
|
+
>>> import brainstate as brainstate
|
139
|
+
>>> brainstate.random.randint(2, size=10)
|
140
140
|
array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) # random
|
141
|
-
>>>
|
141
|
+
>>> brainstate.random.randint(1, size=10)
|
142
142
|
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
143
143
|
|
144
144
|
Generate a 2 x 4 array of ints between 0 and 4, inclusive:
|
145
145
|
|
146
|
-
>>>
|
146
|
+
>>> brainstate.random.randint(5, size=(2, 4))
|
147
147
|
array([[4, 0, 2, 1], # random
|
148
148
|
[3, 2, 2, 0]])
|
149
149
|
|
150
150
|
Generate a 1 x 3 array with 3 different upper bounds
|
151
151
|
|
152
|
-
>>>
|
152
|
+
>>> brainstate.random.randint(1, [3, 5, 10])
|
153
153
|
array([2, 2, 9]) # random
|
154
154
|
|
155
155
|
Generate a 1 by 3 array with 3 different lower bounds
|
156
156
|
|
157
|
-
>>>
|
157
|
+
>>> brainstate.random.randint([1, 5, 7], 10)
|
158
158
|
array([9, 8, 7]) # random
|
159
159
|
|
160
160
|
Generate a 2 by 4 array using broadcasting with dtype of uint8
|
161
161
|
|
162
|
-
>>>
|
162
|
+
>>> brainstate.random.randint([1, 3, 5, 7], [[10], [20]], dtype=np.uint8)
|
163
163
|
array([[ 8, 6, 9, 7], # random
|
164
164
|
[ 1, 16, 9, 12]], dtype=uint8)
|
165
165
|
"""
|
@@ -219,12 +219,12 @@ def random_integers(low,
|
|
219
219
|
|
220
220
|
Examples
|
221
221
|
--------
|
222
|
-
>>> import brainstate as
|
223
|
-
>>>
|
222
|
+
>>> import brainstate as brainstate
|
223
|
+
>>> brainstate.random.random_integers(5)
|
224
224
|
4 # random
|
225
|
-
>>> type(
|
225
|
+
>>> type(brainstate.random.random_integers(5))
|
226
226
|
<class 'numpy.int64'>
|
227
|
-
>>>
|
227
|
+
>>> brainstate.random.random_integers(5, size=(3,2))
|
228
228
|
array([[5, 4], # random
|
229
229
|
[3, 3],
|
230
230
|
[4, 5]])
|
@@ -233,13 +233,13 @@ def random_integers(low,
|
|
233
233
|
numbers between 0 and 2.5, inclusive (*i.e.*, from the set
|
234
234
|
:math:`{0, 5/8, 10/8, 15/8, 20/8}`):
|
235
235
|
|
236
|
-
>>> 2.5 * (
|
236
|
+
>>> 2.5 * (brainstate.random.random_integers(5, size=(5,)) - 1) / 4.
|
237
237
|
array([ 0.625, 1.25 , 0.625, 0.625, 2.5 ]) # random
|
238
238
|
|
239
239
|
Roll two six sided dice 1000 times and sum the results:
|
240
240
|
|
241
|
-
>>> d1 =
|
242
|
-
>>> d2 =
|
241
|
+
>>> d1 = brainstate.random.random_integers(1, 6, 1000)
|
242
|
+
>>> d2 = brainstate.random.random_integers(1, 6, 1000)
|
243
243
|
>>> dsums = d1 + d2
|
244
244
|
|
245
245
|
Display results as a histogram:
|
@@ -301,13 +301,13 @@ def randn(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
|
301
301
|
|
302
302
|
Examples
|
303
303
|
--------
|
304
|
-
>>> import brainstate as
|
305
|
-
>>>
|
304
|
+
>>> import brainstate as brainstate
|
305
|
+
>>> brainstate.random.randn()
|
306
306
|
2.1923875335537315 # random
|
307
307
|
|
308
308
|
Two-by-four array of samples from N(3, 6.25):
|
309
309
|
|
310
|
-
>>> 3 + 2.5 *
|
310
|
+
>>> 3 + 2.5 * brainstate.random.randn(2, 4)
|
311
311
|
array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random
|
312
312
|
[ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random
|
313
313
|
"""
|
@@ -359,17 +359,17 @@ def random_sample(size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
|
|
359
359
|
|
360
360
|
Examples
|
361
361
|
--------
|
362
|
-
>>> import brainstate as
|
363
|
-
>>>
|
362
|
+
>>> import brainstate as brainstate
|
363
|
+
>>> brainstate.random.random_sample()
|
364
364
|
0.47108547995356098 # random
|
365
|
-
>>> type(
|
365
|
+
>>> type(brainstate.random.random_sample())
|
366
366
|
<class 'float'>
|
367
|
-
>>>
|
367
|
+
>>> brainstate.random.random_sample((5,))
|
368
368
|
array([ 0.30220482, 0.86820401, 0.1654503 , 0.11659149, 0.54323428]) # random
|
369
369
|
|
370
370
|
Three-by-two array of random numbers from [-5, 0):
|
371
371
|
|
372
|
-
>>> 5 *
|
372
|
+
>>> 5 * brainstate.random.random_sample((3, 2)) - 5
|
373
373
|
array([[-3.99149989, -0.52338984], # random
|
374
374
|
[-2.99091858, -0.79479508],
|
375
375
|
[-1.23204345, -1.75224494]])
|
@@ -450,34 +450,34 @@ def choice(a, size: Optional[Size] = None, replace=True, p=None,
|
|
450
450
|
--------
|
451
451
|
Generate a uniform random sample from np.arange(5) of size 3:
|
452
452
|
|
453
|
-
>>> import brainstate as
|
454
|
-
>>>
|
453
|
+
>>> import brainstate as brainstate
|
454
|
+
>>> brainstate.random.choice(5, 3)
|
455
455
|
array([0, 3, 4]) # random
|
456
456
|
>>> #This is equivalent to brainpy.math.random.randint(0,5,3)
|
457
457
|
|
458
458
|
Generate a non-uniform random sample from np.arange(5) of size 3:
|
459
459
|
|
460
|
-
>>>
|
460
|
+
>>> brainstate.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
|
461
461
|
array([3, 3, 0]) # random
|
462
462
|
|
463
463
|
Generate a uniform random sample from np.arange(5) of size 3 without
|
464
464
|
replacement:
|
465
465
|
|
466
|
-
>>>
|
466
|
+
>>> brainstate.random.choice(5, 3, replace=False)
|
467
467
|
array([3,1,0]) # random
|
468
468
|
>>> #This is equivalent to brainpy.math.random.permutation(np.arange(5))[:3]
|
469
469
|
|
470
470
|
Generate a non-uniform random sample from np.arange(5) of size
|
471
471
|
3 without replacement:
|
472
472
|
|
473
|
-
>>>
|
473
|
+
>>> brainstate.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
|
474
474
|
array([2, 3, 0]) # random
|
475
475
|
|
476
476
|
Any of the above can be repeated with an arbitrary array-like
|
477
477
|
instead of just integers. For instance:
|
478
478
|
|
479
479
|
>>> aa_milne_arr = ['pooh', 'rabbit', 'piglet', 'Christopher']
|
480
|
-
>>>
|
480
|
+
>>> brainstate.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3])
|
481
481
|
array(['pooh', 'pooh', 'pooh', 'Christopher', 'piglet'], # random
|
482
482
|
dtype='<U11')
|
483
483
|
"""
|
@@ -519,15 +519,15 @@ def permutation(x,
|
|
519
519
|
|
520
520
|
Examples
|
521
521
|
--------
|
522
|
-
>>> import brainstate as
|
523
|
-
>>>
|
522
|
+
>>> import brainstate as brainstate
|
523
|
+
>>> brainstate.random.permutation(10)
|
524
524
|
array([1, 7, 4, 3, 0, 9, 2, 5, 8, 6]) # random
|
525
525
|
|
526
|
-
>>>
|
526
|
+
>>> brainstate.random.permutation([1, 4, 9, 12, 15])
|
527
527
|
array([15, 1, 9, 4, 12]) # random
|
528
528
|
|
529
529
|
>>> arr = np.arange(9).reshape((3, 3))
|
530
|
-
>>>
|
530
|
+
>>> brainstate.random.permutation(arr)
|
531
531
|
array([[6, 7, 8], # random
|
532
532
|
[0, 1, 2],
|
533
533
|
[3, 4, 5]])
|
@@ -557,16 +557,16 @@ def shuffle(x, axis=0, key: Optional[SeedOrKey] = None):
|
|
557
557
|
|
558
558
|
Examples
|
559
559
|
--------
|
560
|
-
>>> import brainstate as
|
560
|
+
>>> import brainstate as brainstate
|
561
561
|
>>> arr = np.arange(10)
|
562
|
-
>>>
|
562
|
+
>>> brainstate.random.shuffle(arr)
|
563
563
|
>>> arr
|
564
564
|
[1 7 5 2 9 4 3 6 0 8] # random
|
565
565
|
|
566
566
|
Multi-dimensional arrays are only shuffled along the first axis:
|
567
567
|
|
568
568
|
>>> arr = np.arange(9).reshape((3, 3))
|
569
|
-
>>>
|
569
|
+
>>> brainstate.random.shuffle(arr)
|
570
570
|
>>> arr
|
571
571
|
array([[3, 4, 5], # random
|
572
572
|
[6, 7, 8],
|
@@ -562,6 +562,6 @@ class TestRandom(unittest.TestCase):
|
|
562
562
|
|
563
563
|
# class TestRandomKey(unittest.TestCase):
|
564
564
|
# def test_clear_memory(self):
|
565
|
-
#
|
566
|
-
# print(
|
567
|
-
# self.assertTrue(isinstance(
|
565
|
+
# brainstate.random.split_key()
|
566
|
+
# print(brainstate.random.DEFAULT.value)
|
567
|
+
# self.assertTrue(isinstance(brainstate.random.DEFAULT.value, np.ndarray))
|
brainstate/random/_rand_seed.py
CHANGED
@@ -183,16 +183,16 @@ def seed_context(seed_or_key: SeedOrKey):
|
|
183
183
|
|
184
184
|
Examples:
|
185
185
|
|
186
|
-
>>> import brainstate as
|
187
|
-
>>> print(
|
186
|
+
>>> import brainstate as brainstate
|
187
|
+
>>> print(brainstate.random.rand(2))
|
188
188
|
[0.57721865 0.9820676 ]
|
189
|
-
>>> print(
|
189
|
+
>>> print(brainstate.random.rand(2))
|
190
190
|
[0.8511752 0.95312667]
|
191
|
-
>>> with
|
192
|
-
... print(
|
191
|
+
>>> with brainstate.random.seed_context(42):
|
192
|
+
... print(brainstate.random.rand(2))
|
193
193
|
[0.95598125 0.4032725 ]
|
194
|
-
>>> with
|
195
|
-
... print(
|
194
|
+
>>> with brainstate.random.seed_context(42):
|
195
|
+
... print(brainstate.random.rand(2))
|
196
196
|
[0.95598125 0.4032725 ]
|
197
197
|
|
198
198
|
Args:
|
brainstate/random/_rand_state.py
CHANGED
@@ -384,7 +384,10 @@ class RandomState(State):
|
|
384
384
|
loc = _check_py_seq(loc)
|
385
385
|
scale = _check_py_seq(scale)
|
386
386
|
if size is None:
|
387
|
-
size = lax.broadcast_shapes(
|
387
|
+
size = lax.broadcast_shapes(
|
388
|
+
jnp.shape(loc) if loc is not None else (),
|
389
|
+
jnp.shape(scale) if scale is not None else ()
|
390
|
+
)
|
388
391
|
key = self.split_key() if key is None else _formalize_key(key)
|
389
392
|
dtype = dtype or environ.dftype()
|
390
393
|
r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
|
@@ -399,7 +402,10 @@ class RandomState(State):
|
|
399
402
|
loc = _check_py_seq(loc)
|
400
403
|
scale = _check_py_seq(scale)
|
401
404
|
if size is None:
|
402
|
-
size = lax.broadcast_shapes(
|
405
|
+
size = lax.broadcast_shapes(
|
406
|
+
jnp.shape(scale) if scale is not None else (),
|
407
|
+
jnp.shape(loc) if loc is not None else ()
|
408
|
+
)
|
403
409
|
key = self.split_key() if key is None else _formalize_key(key)
|
404
410
|
dtype = dtype or environ.dftype()
|
405
411
|
r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
|
@@ -456,7 +462,7 @@ class RandomState(State):
|
|
456
462
|
dtype: DTypeLike = None):
|
457
463
|
shape = _check_py_seq(shape)
|
458
464
|
if size is None:
|
459
|
-
size = jnp.shape(shape)
|
465
|
+
size = jnp.shape(shape) if shape is not None else ()
|
460
466
|
key = self.split_key() if key is None else _formalize_key(key)
|
461
467
|
dtype = dtype or environ.dftype()
|
462
468
|
r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
|
@@ -477,7 +483,7 @@ class RandomState(State):
|
|
477
483
|
dtype: DTypeLike = None):
|
478
484
|
df = _check_py_seq(df)
|
479
485
|
if size is None:
|
480
|
-
size = jnp.shape(size)
|
486
|
+
size = jnp.shape(size) if size is not None else ()
|
481
487
|
key = self.split_key() if key is None else _formalize_key(key)
|
482
488
|
dtype = dtype or environ.dftype()
|
483
489
|
r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
|
@@ -606,8 +612,8 @@ class RandomState(State):
|
|
606
612
|
|
607
613
|
if size is None:
|
608
614
|
size = jnp.broadcast_shapes(
|
609
|
-
jnp.shape(mean),
|
610
|
-
jnp.shape(sigma)
|
615
|
+
jnp.shape(mean) if mean is not None else (),
|
616
|
+
jnp.shape(sigma) if sigma is not None else ()
|
611
617
|
)
|
612
618
|
key = self.split_key() if key is None else _formalize_key(key)
|
613
619
|
dtype = dtype or environ.dftype()
|
@@ -822,7 +828,7 @@ class RandomState(State):
|
|
822
828
|
a = _check_py_seq(a)
|
823
829
|
scale = _check_py_seq(scale)
|
824
830
|
if size is None:
|
825
|
-
size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale))
|
831
|
+
size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale) if scale is not None else ())
|
826
832
|
else:
|
827
833
|
if jnp.size(a) > 1:
|
828
834
|
raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
|
brainstate/surrogate.py
CHANGED
@@ -106,11 +106,11 @@ class Surrogate(PrettyObject):
|
|
106
106
|
Examples
|
107
107
|
--------
|
108
108
|
|
109
|
-
>>> import brainstate as
|
109
|
+
>>> import brainstate as brainstate
|
110
110
|
>>> import brainstate.nn as nn
|
111
111
|
>>> import jax.numpy as jnp
|
112
112
|
|
113
|
-
>>> class MySurrogate(
|
113
|
+
>>> class MySurrogate(brainstate.surrogate.Surrogate):
|
114
114
|
... def __init__(self, alpha=1.):
|
115
115
|
... super().__init__()
|
116
116
|
... self.alpha = alpha
|
@@ -236,11 +236,11 @@ def sigmoid(
|
|
236
236
|
|
237
237
|
>>> import jax
|
238
238
|
>>> import brainstate.nn as nn
|
239
|
-
>>> import brainstate as
|
239
|
+
>>> import brainstate as brainstate
|
240
240
|
>>> import matplotlib.pyplot as plt
|
241
241
|
>>> xs = jax.numpy.linspace(-2, 2, 1000)
|
242
242
|
>>> for alpha in [1., 2., 4.]:
|
243
|
-
>>> grads =
|
243
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.sigmoid)(xs, alpha)
|
244
244
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
245
245
|
>>> plt.legend()
|
246
246
|
>>> plt.show()
|
@@ -355,11 +355,11 @@ def piecewise_quadratic(
|
|
355
355
|
|
356
356
|
>>> import jax
|
357
357
|
>>> import brainstate.nn as nn
|
358
|
-
>>> import brainstate as
|
358
|
+
>>> import brainstate as brainstate
|
359
359
|
>>> import matplotlib.pyplot as plt
|
360
360
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
361
361
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
362
|
-
>>> grads =
|
362
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_quadratic)(xs, alpha)
|
363
363
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
364
364
|
>>> plt.legend()
|
365
365
|
>>> plt.show()
|
@@ -522,11 +522,11 @@ def piecewise_exp(
|
|
522
522
|
|
523
523
|
>>> import jax
|
524
524
|
>>> import brainstate.nn as nn
|
525
|
-
>>> import brainstate as
|
525
|
+
>>> import brainstate as brainstate
|
526
526
|
>>> import matplotlib.pyplot as plt
|
527
527
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
528
528
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
529
|
-
>>> grads =
|
529
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_exp)(xs, alpha)
|
530
530
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
531
531
|
>>> plt.legend()
|
532
532
|
>>> plt.show()
|
@@ -621,11 +621,11 @@ def soft_sign(
|
|
621
621
|
|
622
622
|
>>> import jax
|
623
623
|
>>> import brainstate.nn as nn
|
624
|
-
>>> import brainstate as
|
624
|
+
>>> import brainstate as brainstate
|
625
625
|
>>> import matplotlib.pyplot as plt
|
626
626
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
627
627
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
628
|
-
>>> grads =
|
628
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.soft_sign)(xs, alpha)
|
629
629
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
630
630
|
>>> plt.legend()
|
631
631
|
>>> plt.show()
|
@@ -706,11 +706,11 @@ def arctan(
|
|
706
706
|
|
707
707
|
>>> import jax
|
708
708
|
>>> import brainstate.nn as nn
|
709
|
-
>>> import brainstate as
|
709
|
+
>>> import brainstate as brainstate
|
710
710
|
>>> import matplotlib.pyplot as plt
|
711
711
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
712
712
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
713
|
-
>>> grads =
|
713
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.arctan)(xs, alpha)
|
714
714
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
715
715
|
>>> plt.legend()
|
716
716
|
>>> plt.show()
|
@@ -804,11 +804,11 @@ def nonzero_sign_log(
|
|
804
804
|
|
805
805
|
>>> import jax
|
806
806
|
>>> import brainstate.nn as nn
|
807
|
-
>>> import brainstate as
|
807
|
+
>>> import brainstate as brainstate
|
808
808
|
>>> import matplotlib.pyplot as plt
|
809
809
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
810
810
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
811
|
-
>>> grads =
|
811
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
|
812
812
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
813
813
|
>>> plt.legend()
|
814
814
|
>>> plt.show()
|
@@ -893,11 +893,11 @@ def erf(
|
|
893
893
|
|
894
894
|
>>> import jax
|
895
895
|
>>> import brainstate.nn as nn
|
896
|
-
>>> import brainstate as
|
896
|
+
>>> import brainstate as brainstate
|
897
897
|
>>> import matplotlib.pyplot as plt
|
898
898
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
899
899
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
900
|
-
>>> grads =
|
900
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
|
901
901
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
902
902
|
>>> plt.legend()
|
903
903
|
>>> plt.show()
|
@@ -1000,12 +1000,12 @@ def piecewise_leaky_relu(
|
|
1000
1000
|
|
1001
1001
|
>>> import jax
|
1002
1002
|
>>> import brainstate.nn as nn
|
1003
|
-
>>> import brainstate as
|
1003
|
+
>>> import brainstate as brainstate
|
1004
1004
|
>>> import matplotlib.pyplot as plt
|
1005
1005
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1006
1006
|
>>> for c in [0.01, 0.05, 0.1]:
|
1007
1007
|
>>> for w in [1., 2.]:
|
1008
|
-
>>> grads1 =
|
1008
|
+
>>> grads1 = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
|
1009
1009
|
>>> plt.plot(xs, grads1, label=f'x={c}, w={w}')
|
1010
1010
|
>>> plt.legend()
|
1011
1011
|
>>> plt.show()
|
@@ -1113,12 +1113,12 @@ def squarewave_fourier_series(
|
|
1113
1113
|
|
1114
1114
|
>>> import jax
|
1115
1115
|
>>> import brainstate.nn as nn
|
1116
|
-
>>> import brainstate as
|
1116
|
+
>>> import brainstate as brainstate
|
1117
1117
|
>>> import matplotlib.pyplot as plt
|
1118
1118
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1119
1119
|
>>> for n in [2, 4, 8]:
|
1120
|
-
>>> f =
|
1121
|
-
>>> grads1 =
|
1120
|
+
>>> f = brainstate.surrogate.SquarewaveFourierSeries(n=n)
|
1121
|
+
>>> grads1 = brainstate.augment.vector_grad(f)(xs)
|
1122
1122
|
>>> plt.plot(xs, grads1, label=f'n={n}')
|
1123
1123
|
>>> plt.legend()
|
1124
1124
|
>>> plt.show()
|
@@ -1214,12 +1214,12 @@ def s2nn(
|
|
1214
1214
|
|
1215
1215
|
>>> import jax
|
1216
1216
|
>>> import brainstate.nn as nn
|
1217
|
-
>>> import brainstate as
|
1217
|
+
>>> import brainstate as brainstate
|
1218
1218
|
>>> import matplotlib.pyplot as plt
|
1219
1219
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1220
|
-
>>> grads =
|
1220
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 4., 1.)
|
1221
1221
|
>>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
|
1222
|
-
>>> grads =
|
1222
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 8., 2.)
|
1223
1223
|
>>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
|
1224
1224
|
>>> plt.legend()
|
1225
1225
|
>>> plt.show()
|
@@ -1315,11 +1315,11 @@ def q_pseudo_spike(
|
|
1315
1315
|
|
1316
1316
|
>>> import jax
|
1317
1317
|
>>> import brainstate.nn as nn
|
1318
|
-
>>> import brainstate as
|
1318
|
+
>>> import brainstate as brainstate
|
1319
1319
|
>>> import matplotlib.pyplot as plt
|
1320
1320
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1321
1321
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
1322
|
-
>>> grads =
|
1322
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.q_pseudo_spike)(xs, alpha)
|
1323
1323
|
>>> plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
|
1324
1324
|
>>> plt.legend()
|
1325
1325
|
>>> plt.show()
|
@@ -1413,10 +1413,10 @@ def leaky_relu(
|
|
1413
1413
|
|
1414
1414
|
>>> import jax
|
1415
1415
|
>>> import brainstate.nn as nn
|
1416
|
-
>>> import brainstate as
|
1416
|
+
>>> import brainstate as brainstate
|
1417
1417
|
>>> import matplotlib.pyplot as plt
|
1418
1418
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1419
|
-
>>> grads =
|
1419
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
|
1420
1420
|
>>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
|
1421
1421
|
>>> plt.legend()
|
1422
1422
|
>>> plt.show()
|
@@ -1517,10 +1517,10 @@ def log_tailed_relu(
|
|
1517
1517
|
|
1518
1518
|
>>> import jax
|
1519
1519
|
>>> import brainstate.nn as nn
|
1520
|
-
>>> import brainstate as
|
1520
|
+
>>> import brainstate as brainstate
|
1521
1521
|
>>> import matplotlib.pyplot as plt
|
1522
1522
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1523
|
-
>>> grads =
|
1523
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
|
1524
1524
|
>>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
|
1525
1525
|
>>> plt.legend()
|
1526
1526
|
>>> plt.show()
|
@@ -1596,12 +1596,12 @@ def relu_grad(
|
|
1596
1596
|
|
1597
1597
|
>>> import jax
|
1598
1598
|
>>> import brainstate.nn as nn
|
1599
|
-
>>> import brainstate as
|
1599
|
+
>>> import brainstate as brainstate
|
1600
1600
|
>>> import matplotlib.pyplot as plt
|
1601
1601
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1602
1602
|
>>> for s in [0.5, 1.]:
|
1603
1603
|
>>> for w in [1, 2.]:
|
1604
|
-
>>> grads =
|
1604
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.relu_grad)(xs, s, w)
|
1605
1605
|
>>> plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
|
1606
1606
|
>>> plt.legend()
|
1607
1607
|
>>> plt.show()
|
@@ -1678,11 +1678,11 @@ def gaussian_grad(
|
|
1678
1678
|
|
1679
1679
|
>>> import jax
|
1680
1680
|
>>> import brainstate.nn as nn
|
1681
|
-
>>> import brainstate as
|
1681
|
+
>>> import brainstate as brainstate
|
1682
1682
|
>>> import matplotlib.pyplot as plt
|
1683
1683
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1684
1684
|
>>> for s in [0.5, 1., 2.]:
|
1685
|
-
>>> grads =
|
1685
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.gaussian_grad)(xs, s, 0.5)
|
1686
1686
|
>>> plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
|
1687
1687
|
>>> plt.legend()
|
1688
1688
|
>>> plt.show()
|
@@ -1773,10 +1773,10 @@ def multi_gaussian_grad(
|
|
1773
1773
|
|
1774
1774
|
>>> import jax
|
1775
1775
|
>>> import brainstate.nn as nn
|
1776
|
-
>>> import brainstate as
|
1776
|
+
>>> import brainstate as brainstate
|
1777
1777
|
>>> import matplotlib.pyplot as plt
|
1778
1778
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1779
|
-
>>> grads =
|
1779
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.multi_gaussian_grad)(xs)
|
1780
1780
|
>>> plt.plot(xs, grads)
|
1781
1781
|
>>> plt.show()
|
1782
1782
|
|
@@ -1855,11 +1855,11 @@ def inv_square_grad(
|
|
1855
1855
|
|
1856
1856
|
>>> import jax
|
1857
1857
|
>>> import brainstate.nn as nn
|
1858
|
-
>>> import brainstate as
|
1858
|
+
>>> import brainstate as brainstate
|
1859
1859
|
>>> import matplotlib.pyplot as plt
|
1860
1860
|
>>> xs = jax.numpy.linspace(-1, 1, 1000)
|
1861
1861
|
>>> for alpha in [1., 10., 100.]:
|
1862
|
-
>>> grads =
|
1862
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.inv_square_grad)(xs, alpha)
|
1863
1863
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
1864
1864
|
>>> plt.legend()
|
1865
1865
|
>>> plt.show()
|
@@ -1929,11 +1929,11 @@ def slayer_grad(
|
|
1929
1929
|
|
1930
1930
|
>>> import jax
|
1931
1931
|
>>> import brainstate.nn as nn
|
1932
|
-
>>> import brainstate as
|
1932
|
+
>>> import brainstate as brainstate
|
1933
1933
|
>>> import matplotlib.pyplot as plt
|
1934
1934
|
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1935
1935
|
>>> for alpha in [0.5, 1., 2., 4.]:
|
1936
|
-
>>> grads =
|
1936
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.slayer_grad)(xs, alpha)
|
1937
1937
|
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
1938
1938
|
>>> plt.legend()
|
1939
1939
|
>>> plt.show()
|