brainstate 0.1.4__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_state.py +6 -5
- brainstate/augment/_autograd.py +31 -12
- brainstate/augment/_autograd_test.py +46 -46
- brainstate/augment/_eval_shape.py +4 -4
- brainstate/augment/_mapping.py +22 -17
- brainstate/augment/_mapping_test.py +162 -0
- brainstate/compile/_conditions.py +2 -2
- brainstate/compile/_make_jaxpr.py +59 -6
- brainstate/compile/_progress_bar.py +2 -2
- brainstate/environ.py +19 -19
- brainstate/functional/_activations_test.py +12 -12
- brainstate/graph/_graph_operation.py +69 -69
- brainstate/graph/_graph_operation_test.py +2 -2
- brainstate/mixin.py +0 -17
- brainstate/nn/_collective_ops.py +4 -4
- brainstate/nn/_common.py +7 -19
- brainstate/nn/_dropout_test.py +2 -2
- brainstate/nn/_dynamics.py +53 -35
- brainstate/nn/_elementwise.py +30 -30
- brainstate/nn/_exp_euler.py +13 -16
- brainstate/nn/_inputs.py +1 -1
- brainstate/nn/_linear.py +4 -4
- brainstate/nn/_module.py +6 -6
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +11 -11
- brainstate/nn/_normalizations_test.py +6 -6
- brainstate/nn/_poolings.py +24 -24
- brainstate/nn/_synapse.py +1 -12
- brainstate/nn/_utils.py +1 -1
- brainstate/nn/metrics.py +4 -4
- brainstate/optim/_optax_optimizer.py +8 -8
- brainstate/random/_rand_funs.py +37 -37
- brainstate/random/_rand_funs_test.py +3 -3
- brainstate/random/_rand_seed.py +7 -7
- brainstate/random/_rand_state.py +13 -7
- brainstate/surrogate.py +40 -40
- brainstate/util/pretty_pytree.py +10 -10
- brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
- brainstate/util/struct.py +7 -7
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/RECORD +45 -45
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
- {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
brainstate/augment/_mapping_test.py
CHANGED
@@ -19,6 +19,8 @@ import unittest
 import jax
 import jax.numpy as jnp
 import numpy as np
+from jax import vmap
+from jax.lax import psum, pmean, pmax

 import brainstate
 import brainstate.augment
@@ -433,3 +435,163 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
             foo.c = brainstate.State(jnp.arange(3)) # Original expected shape is (4,)

         faulty_init()
+
+
+class TestAxisName:
+    def test1(self):
+        def compute_stats_with_axis_name(x):
+            """Compute statistics using named axis operations"""
+            # Sum across the named axis 'batch'
+            total_sum = psum(x, axis_name='batch')
+
+            # Mean across the named axis 'batch'
+            mean_val = pmean(x, axis_name='batch')
+
+            # Max across the named axis 'batch'
+            max_val = pmax(x, axis_name='batch')
+
+            return {
+                'sum': total_sum,
+                'mean': mean_val,
+                'max': max_val,
+                'original': x
+            }
+
+        batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
+        print("Input batch data:", batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
+        result_jax = vectorized_stats_jax(batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+        result = vectorized_stats(batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats_v2 = brainstate.transform.jit(
+            brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+        )
+        result_v2 = vectorized_stats_v2(batch_data)
+
+        for key in result_jax.keys():
+            print(f" {key}: {result_jax[key]}")
+            assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
+            assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"
+
+    def test_nested_vmap(self):
+        def nested_computation(x):
+            """Computation with multiple named axes"""
+            # Sum over 'inner' axis, then mean over 'outer' axis
+            inner_sum = psum(x, axis_name='inner')
+            outer_mean = pmean(inner_sum, axis_name='outer')
+            return outer_mean
+
+        # Create 2D batch data
+        data_2d = jnp.arange(12.0).reshape(3, 4) # Shape: [outer_batch=3, inner_batch=4]
+        print("Input 2D data shape:", data_2d.shape)
+        print("Input 2D data:\n", data_2d)
+
+        # Nested vmap: first over inner dimension, then outer dimension
+        inner_vmap = vmap(nested_computation, axis_name='inner')
+        nested_vmap = vmap(inner_vmap, axis_name='outer')
+
+        result_2d = nested_vmap(data_2d)
+        print("Result after nested vmap:", result_2d)
+
+        inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
+        nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
+        result_2d_bst = nested_vmap_bst(data_2d)
+        print("Result after nested vmap:", result_2d_bst)
+
+        assert jnp.allclose(result_2d, result_2d_bst)
+
+    def _gradient_averaging_simulation_bst(self):
+        def loss_function(params, x, y):
+            """Simple quadratic loss"""
+            pred = params * x
+            return (pred - y) ** 2
+
+        def compute_gradients_with_averaging(params, batch_x, batch_y):
+            """Compute gradients and average them across the batch"""
+            # Compute per-sample gradients
+            grad_fn = jax.grad(loss_function, argnums=0)
+            per_sample_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+            # Average gradients across batch using named axis
+            def average_grads(grads):
+                return pmean(grads, axis_name='batch')
+
+            # Apply averaging with named axis
+            averaged_grads = vmap(average_grads, axis_name='batch')(per_sample_grads)
+            return averaged_grads
+
+        # Example data
+        params = 2.0
+        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+        print("Parameters:", params)
+        print("Batch X:", batch_x)
+        print("Batch Y:", batch_y)
+
+        # Compute individual gradients first
+        grad_fn = jax.grad(loss_function, argnums=0)
+        individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+        print("Individual gradients:", individual_grads)
+
+        # Now compute averaged gradients using axis names
+        averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+        print("Averaged gradients:", averaged_grads)
+
+        return individual_grads, averaged_grads
+
+    def _gradient_averaging_simulation_jax(self):
+        def loss_function(params, x, y):
+            """Simple quadratic loss"""
+            pred = params * x
+            return (pred - y) ** 2
+
+        def compute_gradients_with_averaging(params, batch_x, batch_y):
+            """Compute gradients and average them across the batch"""
+            # Compute per-sample gradients
+            grad_fn = jax.grad(loss_function, argnums=0)
+            per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+            # Average gradients across batch using named axis
+            def average_grads(grads):
+                return pmean(grads, axis_name='batch')
+
+            # Apply averaging with named axis
+            averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
+            return averaged_grads
+
+        # Example data
+        params = 2.0
+        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+        print("Parameters:", params)
+        print("Batch X:", batch_x)
+        print("Batch Y:", batch_y)
+
+        # Compute individual gradients first
+        grad_fn = jax.grad(loss_function, argnums=0)
+        individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+        print("Individual gradients:", individual_grads)
+
+        # Now compute averaged gradients using axis names
+        averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+        print("Averaged gradients:", averaged_grads)
+
+        return individual_grads, averaged_grads
+
+    def test_gradient_averaging_simulation(self):
+        individual_grads, averaged_grads = self._gradient_averaging_simulation_bst()
+        individual_grads_jax, averaged_grads_jax = self._gradient_averaging_simulation_jax()
+        assert jnp.allclose(individual_grads, individual_grads_jax)
+        assert jnp.allclose(averaged_grads, averaged_grads_jax)
+
+
+
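The new TestAxisName tests above exercise jax.lax collectives (psum, pmean, pmax) under brainstate.transform.vmap with a named axis. As a rough usage sketch distilled from those tests (the printed result is an assumption, not taken from the diff):

import jax.numpy as jnp
from jax.lax import pmean
import brainstate

def center(x):
    # inside the vmapped function, pmean reduces across the named 'batch' axis
    return x - pmean(x, axis_name='batch')

centered = brainstate.transform.vmap(center, axis_name='batch')
print(centered(jnp.array([1.0, 2.0, 3.0])))  # roughly [-1., 0., 1.]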
brainstate/compile/_conditions.py
CHANGED
@@ -203,9 +203,9 @@ def ifelse(conditions, branches, *operands, check_cond: bool = True):
     Examples
     --------

-    >>> import brainstate
+    >>> import brainstate
     >>> def f(a):
-    >>>    return
+    >>>    return brainstate.compile.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
     >>> branches=[lambda: 1,
     >>> lambda: 2,
     >>> lambda: 3,
brainstate/compile/_make_jaxpr.py
CHANGED
@@ -352,6 +352,13 @@ class StatefulFunction(PrettyObject):
             cache_key = default_cache_key
         return self.get_state_trace(cache_key).get_write_states()

+    def _check_input(self, x):
+        if isinstance(x, State):
+            raise ValueError(
+                'Inputs for brainstate transformations cannot be an instance of State. '
+                f'But we got {x}'
+            )
+
     def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
         """
         Get the static arguments from the arguments.
@@ -370,22 +377,35 @@ class StatefulFunction(PrettyObject):
                     static_args.append(arg)
                 else:
                     dyn_args.append(arg)
-            dyn_args = jax.tree.map(shaped_abstractify,
+            dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
             static_kwargs, dyn_kwargs = [], []
             for k, v in kwargs.items():
                 if k in self.static_argnames:
                     static_kwargs.append((k, v))
                 else:
                     dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
-
+
+            static_args = make_hashable(tuple(static_args))
+            dyn_args = make_hashable(tuple(dyn_args))
+            static_kwargs = make_hashable(static_kwargs)
+            dyn_kwargs = make_hashable(dyn_kwargs)
+
+            cache_key = (static_args, dyn_args, static_kwargs, dyn_kwargs)
         elif self.cache_type is None:
             num_arg = len(args)
             static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
             static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
-
+
+            # Make everything hashable
+            static_args = make_hashable(static_args)
+            static_kwargs = make_hashable(static_kwargs)
+
+            cache_key = (static_args, static_kwargs)
         else:
             raise ValueError(f"Invalid cache type: {self.cache_type}")

+        return cache_key
+
     def compile_function_and_get_states(self, *args, **kwargs) -> Tuple[State, ...]:
         """
         Compile the function, and get the states that are read and written by this function.
@@ -480,6 +500,9 @@ class StatefulFunction(PrettyObject):
         # static args
         cache_key = self.get_arg_cache_key(*args, **kwargs)

+        # check input types
+        jax.tree.map(self._check_input, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
+
         if cache_key not in self._cached_state_trace:
             try:
                 # jaxpr
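The _check_input guard added above is mapped over the call arguments before tracing, so a State instance passed directly as an argument should now fail fast instead of being silently traced. A hedged sketch of the intended behavior (the exact call path is inferred from this hunk, not verified against the release):

import jax.numpy as jnp
import brainstate

def f(x):
    return jnp.sin(x)

st = brainstate.State(jnp.ones(3))
brainstate.compile.make_jaxpr(f)(st.value)  # fine: pass the underlying array
brainstate.compile.make_jaxpr(f)(st)        # expected to raise the ValueError from _check_input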
@@ -637,15 +660,15 @@ def make_jaxpr(
     instead give a few examples.

     >>> import jax
-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>>
     >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
     >>> print(f(3.0))
     -0.83602
-    >>> jaxpr, states =
+    >>> jaxpr, states = brainstate.compile.make_jaxpr(f)(3.0)
     >>> jaxpr
     { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
-    >>> jaxpr, states =
+    >>> jaxpr, states = brainstate.compile.make_jaxpr(jax.grad(f))(3.0)
     >>> jaxpr
     { lambda ; a:f32[]. let
         b:f32[] = cos a
@@ -844,3 +867,33 @@ def _make_jaxpr(
     if hasattr(fun, "__name__"):
         make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
     return make_jaxpr_f
+
+
+def make_hashable(obj):
+    """Convert a pytree into a hashable representation."""
+    if isinstance(obj, (list, tuple)):
+        return tuple(make_hashable(item) for item in obj)
+    elif isinstance(obj, dict):
+        return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
+    elif isinstance(obj, set):
+        return frozenset(make_hashable(item) for item in obj)
+    elif hasattr(obj, '__dict__'): # Handle custom objects
+        return (
+            obj.__class__.__name__,
+            tuple(
+                sorted(
+                    (k, make_hashable(v))
+                    for k, v in obj.__dict__.items()
+                    if not k.startswith('_')
+                )
+            )
+        )
+    else:
+        # # Use JAX's tree_util for any other pytree structures
+        # try:
+        # leaves, treedef = jax.tree_util.tree_flatten(obj)
+        # hashable_leaves = tuple(make_hashable(leaf) for leaf in leaves)
+        # return (str(treedef), hashable_leaves)
+        # except:
+        # # Assume obj is already hashable
+        return obj
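For orientation, the make_hashable helper introduced above canonicalizes containers into hashable tuples so that get_arg_cache_key can use them as dictionary keys. A small illustration of the intended behavior (the import path is assumed from the file this hunk appears to belong to; make_hashable is a private helper, not public API):

from brainstate.compile._make_jaxpr import make_hashable  # assumed internal location

key = make_hashable({'lr': 0.1, 'layers': [4, 8]})
# dicts become sorted (key, value) tuples and lists become tuples:
# (('layers', (4, 8)), ('lr', 0.1))
cache = {key: 'cached trace'}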
brainstate/compile/_progress_bar.py
CHANGED
@@ -53,7 +53,7 @@ class ProgressBar(object):

     .. code-block:: python

-        a =
+        a = brainstate.State(1.)
         def loop_fn(x):
             a.value = x.value + 1.
             return jnp.sum(x ** 2)
@@ -61,7 +61,7 @@ class ProgressBar(object):
         pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
                                  lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))

-
+        brainstate.compile.for_loop(loop_fn, xs, pbar=pbar)

     In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
     the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
brainstate/environ.py
CHANGED
@@ -76,9 +76,9 @@ def context(**kwargs):

     For instance::

-    >>> import brainstate as
-    >>> with
-    ... dt =
+    >>> import brainstate as brainstate
+    >>> with brainstate.environ.context(dt=0.1) as env:
+    ... dt = brainstate.environ.get('dt')
     ... print(env)

     """
@@ -424,10 +424,10 @@ def dftype() -> DTypeLike:

     For example, if the precision is set to 32, the default floating data type is ``np.float32``.

-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with
-    ... a = np.zeros(1, dtype=
+    >>> with brainstate.environ.context(precision=32):
+    ... a = np.zeros(1, dtype=brainstate.environ.dftype())
     >>> print(a.dtype)

     Returns
@@ -448,10 +448,10 @@ def ditype() -> DTypeLike:

     For example, if the precision is set to 32, the default integer data type is ``np.int32``.

-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with
-    ... a = np.zeros(1, dtype=
+    >>> with brainstate.environ.context(precision=32):
+    ... a = np.zeros(1, dtype=brainstate.environ.ditype())
     >>> print(a.dtype)
     int32

@@ -474,10 +474,10 @@ def dutype() -> DTypeLike:

     For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.

-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with
-    ... a = np.zeros(1, dtype=
+    >>> with brainstate.environ.context(precision=32):
+    ... a = np.zeros(1, dtype=brainstate.environ.dutype())
     >>> print(a.dtype)
     uint32

@@ -499,10 +499,10 @@ def dctype() -> DTypeLike:

     For example, if the precision is set to 32, the default complex data type is ``np.complex64``.

-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> import numpy as np
-    >>> with
-    ... a = np.zeros(1, dtype=
+    >>> with brainstate.environ.context(precision=32):
+    ... a = np.zeros(1, dtype=brainstate.environ.dctype())
     >>> print(a.dtype)
     complex64

@@ -529,19 +529,19 @@ def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bo

     For example, you can register a default behavior for the key 'dt' by::

-    >>> import brainstate as
+    >>> import brainstate as brainstate
     >>> def dt_behavior(dt):
     ... print(f'Set the default dt to {dt}.')
     ...
-    >>>
+    >>> brainstate.environ.register_default_behavior('dt', dt_behavior)

     Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
     `dt_behavior` will be called with
     `dt_behavior(0.1)`.

-    >>>
+    >>> brainstate.environ.set(dt=0.1)
     Set the default dt to 0.1.
-    >>> with
+    >>> with brainstate.environ.context(dt=0.2):
     ... pass
     Set the default dt to 0.2.
     Set the default dt to 0.1.
brainstate/functional/_activations_test.py
CHANGED
@@ -70,39 +70,39 @@ class NNFunctionsTest(jtu.JaxTestCase):
         check_dtypes=False)

     # def testSquareplusGrad(self):
-    # check_grads(
+    # check_grads(brainstate.functional.squareplus, (1e-8,), order=4,
     # )

     # def testSquareplusGradZero(self):
-    # check_grads(
+    # check_grads(brainstate.functional.squareplus, (0.,), order=1,
     # )

     # def testSquareplusGradNegInf(self):
-    # check_grads(
+    # check_grads(brainstate.functional.squareplus, (-float('inf'),), order=1,
     # )

     # def testSquareplusGradNan(self):
-    # check_grads(
+    # check_grads(brainstate.functional.squareplus, (float('nan'),), order=1,
     # )

     # @parameterized.parameters([float] + jtu.dtypes.floating)
     # def testSquareplusZero(self, dtype):
-    # self.assertEqual(dtype(1),
+    # self.assertEqual(dtype(1), brainstate.functional.squareplus(dtype(0), dtype(4)))
     #
     # def testMishGrad(self):
-    # check_grads(
+    # check_grads(brainstate.functional.mish, (1e-8,), order=4,
     # )
     #
     # def testMishGradZero(self):
-    # check_grads(
+    # check_grads(brainstate.functional.mish, (0.,), order=1,
     # )
     #
     # def testMishGradNegInf(self):
-    # check_grads(
+    # check_grads(brainstate.functional.mish, (-float('inf'),), order=1,
     # )
     #
     # def testMishGradNan(self):
-    # check_grads(
+    # check_grads(brainstate.functional.mish, (float('nan'),), order=1,
     # )

     @parameterized.parameters([float] + jtu.dtypes.floating)
@@ -137,7 +137,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
         self.assertAllClose(brainstate.functional.sparse_sigmoid(0.), .5, check_dtypes=False)

     # def testSquareplusValue(self):
-    # val =
+    # val = brainstate.functional.squareplus(1e3)
     # self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)

     def testMishValue(self):
@@ -177,7 +177,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
          brainstate.functional.softplus,
          brainstate.functional.sparse_plus,
          brainstate.functional.sigmoid,
-         #
+         # brainstate.functional.squareplus,
          brainstate.functional.mish)))
     def testDtypeMatchesInput(self, dtype, fn):
         x = jnp.zeros((), dtype=dtype)
@@ -306,7 +306,7 @@ class NNFunctionsTest(jtu.JaxTestCase):

     def testCustomJVPLeak2(self):
         # https://github.com/google/jax/issues/8171
-        # The above test uses jax.
+        # The above test uses jax.brainstate.functional.sigmoid, as in the original #8171, but that
         # function no longer actually has a custom_jvp! So we inline the old def.

         @jax.custom_jvp