brainstate 0.1.4__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_state.py +6 -5
  3. brainstate/augment/_autograd.py +31 -12
  4. brainstate/augment/_autograd_test.py +46 -46
  5. brainstate/augment/_eval_shape.py +4 -4
  6. brainstate/augment/_mapping.py +22 -17
  7. brainstate/augment/_mapping_test.py +162 -0
  8. brainstate/compile/_conditions.py +2 -2
  9. brainstate/compile/_make_jaxpr.py +59 -6
  10. brainstate/compile/_progress_bar.py +2 -2
  11. brainstate/environ.py +19 -19
  12. brainstate/functional/_activations_test.py +12 -12
  13. brainstate/graph/_graph_operation.py +69 -69
  14. brainstate/graph/_graph_operation_test.py +2 -2
  15. brainstate/mixin.py +0 -17
  16. brainstate/nn/_collective_ops.py +4 -4
  17. brainstate/nn/_common.py +7 -19
  18. brainstate/nn/_dropout_test.py +2 -2
  19. brainstate/nn/_dynamics.py +53 -35
  20. brainstate/nn/_elementwise.py +30 -30
  21. brainstate/nn/_exp_euler.py +13 -16
  22. brainstate/nn/_inputs.py +1 -1
  23. brainstate/nn/_linear.py +4 -4
  24. brainstate/nn/_module.py +6 -6
  25. brainstate/nn/_module_test.py +1 -1
  26. brainstate/nn/_normalizations.py +11 -11
  27. brainstate/nn/_normalizations_test.py +6 -6
  28. brainstate/nn/_poolings.py +24 -24
  29. brainstate/nn/_synapse.py +1 -12
  30. brainstate/nn/_utils.py +1 -1
  31. brainstate/nn/metrics.py +4 -4
  32. brainstate/optim/_optax_optimizer.py +8 -8
  33. brainstate/random/_rand_funs.py +37 -37
  34. brainstate/random/_rand_funs_test.py +3 -3
  35. brainstate/random/_rand_seed.py +7 -7
  36. brainstate/random/_rand_state.py +13 -7
  37. brainstate/surrogate.py +40 -40
  38. brainstate/util/pretty_pytree.py +10 -10
  39. brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
  40. brainstate/util/struct.py +7 -7
  41. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
  42. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/RECORD +45 -45
  43. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
  44. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
  45. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
brainstate/augment/_mapping_test.py CHANGED
@@ -19,6 +19,8 @@ import unittest
  import jax
  import jax.numpy as jnp
  import numpy as np
+ from jax import vmap
+ from jax.lax import psum, pmean, pmax

  import brainstate
  import brainstate.augment
@@ -433,3 +435,163 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
              foo.c = brainstate.State(jnp.arange(3))  # Original expected shape is (4,)

          faulty_init()
+
+
+ class TestAxisName:
+     def test1(self):
+         def compute_stats_with_axis_name(x):
+             """Compute statistics using named axis operations"""
+             # Sum across the named axis 'batch'
+             total_sum = psum(x, axis_name='batch')
+
+             # Mean across the named axis 'batch'
+             mean_val = pmean(x, axis_name='batch')
+
+             # Max across the named axis 'batch'
+             max_val = pmax(x, axis_name='batch')
+
+             return {
+                 'sum': total_sum,
+                 'mean': mean_val,
+                 'max': max_val,
+                 'original': x
+             }
+
+         batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
+         print("Input batch data:", batch_data)
+
+         # vmap with axis name 'batch'
+         vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
+         result_jax = vectorized_stats_jax(batch_data)
+
+         # vmap with axis name 'batch'
+         vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+         result = vectorized_stats(batch_data)
+
+         # vmap with axis name 'batch'
+         vectorized_stats_v2 = brainstate.transform.jit(
+             brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+         )
+         result_v2 = vectorized_stats_v2(batch_data)
+
+         for key in result_jax.keys():
+             print(f" {key}: {result_jax[key]}")
+             assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
+             assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"
+
+     def test_nested_vmap(self):
+         def nested_computation(x):
+             """Computation with multiple named axes"""
+             # Sum over 'inner' axis, then mean over 'outer' axis
+             inner_sum = psum(x, axis_name='inner')
+             outer_mean = pmean(inner_sum, axis_name='outer')
+             return outer_mean
+
+         # Create 2D batch data
+         data_2d = jnp.arange(12.0).reshape(3, 4)  # Shape: [outer_batch=3, inner_batch=4]
+         print("Input 2D data shape:", data_2d.shape)
+         print("Input 2D data:\n", data_2d)
+
+         # Nested vmap: first over inner dimension, then outer dimension
+         inner_vmap = vmap(nested_computation, axis_name='inner')
+         nested_vmap = vmap(inner_vmap, axis_name='outer')
+
+         result_2d = nested_vmap(data_2d)
+         print("Result after nested vmap:", result_2d)
+
+         inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
+         nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
+         result_2d_bst = nested_vmap_bst(data_2d)
+         print("Result after nested vmap:", result_2d_bst)
+
+         assert jnp.allclose(result_2d, result_2d_bst)
+
+     def _gradient_averaging_simulation_bst(self):
+         def loss_function(params, x, y):
+             """Simple quadratic loss"""
+             pred = params * x
+             return (pred - y) ** 2
+
+         def compute_gradients_with_averaging(params, batch_x, batch_y):
+             """Compute gradients and average them across the batch"""
+             # Compute per-sample gradients
+             grad_fn = jax.grad(loss_function, argnums=0)
+             per_sample_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+             # Average gradients across batch using named axis
+             def average_grads(grads):
+                 return pmean(grads, axis_name='batch')
+
+             # Apply averaging with named axis
+             averaged_grads = vmap(average_grads, axis_name='batch')(per_sample_grads)
+             return averaged_grads
+
+         # Example data
+         params = 2.0
+         batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+         batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+         print("Parameters:", params)
+         print("Batch X:", batch_x)
+         print("Batch Y:", batch_y)
+
+         # Compute individual gradients first
+         grad_fn = jax.grad(loss_function, argnums=0)
+         individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+         print("Individual gradients:", individual_grads)
+
+         # Now compute averaged gradients using axis names
+         averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+         print("Averaged gradients:", averaged_grads)
+
+         return individual_grads, averaged_grads
+
+     def _gradient_averaging_simulation_jax(self):
+         def loss_function(params, x, y):
+             """Simple quadratic loss"""
+             pred = params * x
+             return (pred - y) ** 2
+
+         def compute_gradients_with_averaging(params, batch_x, batch_y):
+             """Compute gradients and average them across the batch"""
+             # Compute per-sample gradients
+             grad_fn = jax.grad(loss_function, argnums=0)
+             per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+             # Average gradients across batch using named axis
+             def average_grads(grads):
+                 return pmean(grads, axis_name='batch')
+
+             # Apply averaging with named axis
+             averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
+             return averaged_grads
+
+         # Example data
+         params = 2.0
+         batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+         batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+         print("Parameters:", params)
+         print("Batch X:", batch_x)
+         print("Batch Y:", batch_y)
+
+         # Compute individual gradients first
+         grad_fn = jax.grad(loss_function, argnums=0)
+         individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+         print("Individual gradients:", individual_grads)
+
+         # Now compute averaged gradients using axis names
+         averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+         print("Averaged gradients:", averaged_grads)
+
+         return individual_grads, averaged_grads
+
+     def test_gradient_averaging_simulation(self):
+         individual_grads, averaged_grads = self._gradient_averaging_simulation_bst()
+         individual_grads_jax, averaged_grads_jax = self._gradient_averaging_simulation_jax()
+         assert jnp.allclose(individual_grads, individual_grads_jax)
+         assert jnp.allclose(averaged_grads, averaged_grads_jax)
+
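Note: for readers unfamiliar with JAX's named-axis collectives, here is a minimal standalone sketch of the behaviour the new TestAxisName tests pin down (pure JAX; the tests above additionally assert that brainstate.transform.vmap matches jax.vmap). The normalize helper and its sample values are illustrative only and not part of the package:

    import jax.numpy as jnp
    from jax import vmap
    from jax.lax import pmean, psum

    # Inside vmap(..., axis_name='batch'), collectives such as psum/pmean
    # reduce across the mapped axis named 'batch'.
    def normalize(x):
        batch_size = psum(jnp.ones_like(x), axis_name='batch')   # == 4.0 here
        return (x - pmean(x, axis_name='batch')) / batch_size

    out = vmap(normalize, axis_name='batch')(jnp.array([1.0, 2.0, 3.0, 4.0]))
    print(out)   # [-0.375 -0.125  0.125  0.375]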
brainstate/compile/_conditions.py CHANGED
@@ -203,9 +203,9 @@ def ifelse(conditions, branches, *operands, check_cond: bool = True):
      Examples
      --------

-     >>> import brainstate as bst
+     >>> import brainstate
      >>> def f(a):
-     >>>   return bst.compile.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
+     >>>   return brainstate.compile.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
      >>>                              branches=[lambda: 1,
      >>>                                        lambda: 2,
      >>>                                        lambda: 3,
brainstate/compile/_make_jaxpr.py CHANGED
@@ -352,6 +352,13 @@ class StatefulFunction(PrettyObject):
              cache_key = default_cache_key
          return self.get_state_trace(cache_key).get_write_states()

+     def _check_input(self, x):
+         if isinstance(x, State):
+             raise ValueError(
+                 'Inputs for brainstate transformations cannot be an instance of State. '
+                 f'But we got {x}'
+             )
+
      def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
          """
          Get the static arguments from the arguments.
@@ -370,22 +377,35 @@
                      static_args.append(arg)
                  else:
                      dyn_args.append(arg)
-             dyn_args = jax.tree.map(shaped_abstractify, jax.tree.leaves(dyn_args))
+             dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
              static_kwargs, dyn_kwargs = [], []
              for k, v in kwargs.items():
                  if k in self.static_argnames:
                      static_kwargs.append((k, v))
                  else:
                      dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
-             return tuple([tuple(static_args), tuple(dyn_args), tuple(static_kwargs), tuple(dyn_kwargs)])
+
+             static_args = make_hashable(tuple(static_args))
+             dyn_args = make_hashable(tuple(dyn_args))
+             static_kwargs = make_hashable(static_kwargs)
+             dyn_kwargs = make_hashable(dyn_kwargs)
+
+             cache_key = (static_args, dyn_args, static_kwargs, dyn_kwargs)
          elif self.cache_type is None:
              num_arg = len(args)
              static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
              static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
-             return tuple([static_args, static_kwargs])
+
+             # Make everything hashable
+             static_args = make_hashable(static_args)
+             static_kwargs = make_hashable(static_kwargs)
+
+             cache_key = (static_args, static_kwargs)
          else:
              raise ValueError(f"Invalid cache type: {self.cache_type}")

+         return cache_key
+
      def compile_function_and_get_states(self, *args, **kwargs) -> Tuple[State, ...]:
          """
          Compile the function, and get the states that are read and written by this function.
@@ -480,6 +500,9 @@
          # static args
          cache_key = self.get_arg_cache_key(*args, **kwargs)

+         # check input types
+         jax.tree.map(self._check_input, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
+
          if cache_key not in self._cached_state_trace:
              try:
                  # jaxpr
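The guard added above works by mapping a checker over the argument pytree with is_leaf, so State instances are intercepted whole before tracing begins. A standalone sketch of that pattern, using a stand-in State class rather than brainstate's own:

    import jax

    class State:  # stand-in for brainstate.State, for illustration only
        def __init__(self, value):
            self.value = value

    def _check_input(x):
        if isinstance(x, State):
            raise ValueError(
                'Inputs for brainstate transformations cannot be an instance of State. '
                f'But we got {x}'
            )

    args, kwargs = (1.0, State(2.0)), {'gain': 3.0}
    try:
        # is_leaf stops the traversal at State objects so the checker sees them intact
        jax.tree.map(_check_input, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
    except ValueError as e:
        print(e)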
@@ -637,15 +660,15 @@ def make_jaxpr(
      instead give a few examples.

      >>> import jax
-     >>> import brainstate as bst
+     >>> import brainstate as brainstate
      >>>
      >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
      >>> print(f(3.0))
      -0.83602
-     >>> jaxpr, states = bst.compile.make_jaxpr(f)(3.0)
+     >>> jaxpr, states = brainstate.compile.make_jaxpr(f)(3.0)
      >>> jaxpr
      { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
-     >>> jaxpr, states = bst.compile.make_jaxpr(jax.grad(f))(3.0)
+     >>> jaxpr, states = brainstate.compile.make_jaxpr(jax.grad(f))(3.0)
      >>> jaxpr
      { lambda ; a:f32[]. let
          b:f32[] = cos a
@@ -844,3 +867,33 @@ def _make_jaxpr(
      if hasattr(fun, "__name__"):
          make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
      return make_jaxpr_f
+
+
+ def make_hashable(obj):
+     """Convert a pytree into a hashable representation."""
+     if isinstance(obj, (list, tuple)):
+         return tuple(make_hashable(item) for item in obj)
+     elif isinstance(obj, dict):
+         return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
+     elif isinstance(obj, set):
+         return frozenset(make_hashable(item) for item in obj)
+     elif hasattr(obj, '__dict__'):  # Handle custom objects
+         return (
+             obj.__class__.__name__,
+             tuple(
+                 sorted(
+                     (k, make_hashable(v))
+                     for k, v in obj.__dict__.items()
+                     if not k.startswith('_')
+                 )
+             )
+         )
+     else:
+         # # Use JAX's tree_util for any other pytree structures
+         # try:
+         #     leaves, treedef = jax.tree_util.tree_flatten(obj)
+         #     hashable_leaves = tuple(make_hashable(leaf) for leaf in leaves)
+         #     return (str(treedef), hashable_leaves)
+         # except:
+         #     # Assume obj is already hashable
+         return obj
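The get_arg_cache_key changes above route every component of the cache key through make_hashable because raw lists and dicts cannot serve as dictionary keys. A trimmed standalone sketch of the same recursion (container cases only, omitting the __dict__ branch shown above; sample values are illustrative):

    def make_hashable(obj):
        """Convert nested containers into a hashable representation."""
        if isinstance(obj, (list, tuple)):
            return tuple(make_hashable(item) for item in obj)
        elif isinstance(obj, dict):
            return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
        elif isinstance(obj, set):
            return frozenset(make_hashable(item) for item in obj)
        else:
            return obj

    static_kwargs = [('mode', 'train'), ('shape', [3, 4])]   # illustrative values
    key = make_hashable(static_kwargs)
    print(key)        # (('mode', 'train'), ('shape', (3, 4)))
    print(hash(key))  # hashable, so it can index the compilation cache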
brainstate/compile/_progress_bar.py CHANGED
@@ -53,7 +53,7 @@ class ProgressBar(object):

      .. code-block:: python

-         a = bst.State(1.)
+         a = brainstate.State(1.)
          def loop_fn(x):
              a.value = x.value + 1.
              return jnp.sum(x ** 2)
@@ -61,7 +61,7 @@ class ProgressBar(object):
          pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
                                   lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))

-         bst.compile.for_loop(loop_fn, xs, pbar=pbar)
+         brainstate.compile.for_loop(loop_fn, xs, pbar=pbar)

      In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
      the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
brainstate/environ.py CHANGED
@@ -76,9 +76,9 @@ def context(**kwargs):

      For instance::

-       >>> import brainstate as bst
-       >>> with bst.environ.context(dt=0.1) as env:
-       ...     dt = bst.environ.get('dt')
+       >>> import brainstate as brainstate
+       >>> with brainstate.environ.context(dt=0.1) as env:
+       ...     dt = brainstate.environ.get('dt')
        ...     print(env)

      """
@@ -424,10 +424,10 @@ def dftype() -> DTypeLike:

      For example, if the precision is set to 32, the default floating data type is ``np.float32``.

-     >>> import brainstate as bst
+     >>> import brainstate as brainstate
      >>> import numpy as np
-     >>> with bst.environ.context(precision=32):
-     ...     a = np.zeros(1, dtype=bst.environ.dftype())
+     >>> with brainstate.environ.context(precision=32):
+     ...     a = np.zeros(1, dtype=brainstate.environ.dftype())
      >>> print(a.dtype)

      Returns
@@ -448,10 +448,10 @@ def ditype() -> DTypeLike:

      For example, if the precision is set to 32, the default integer data type is ``np.int32``.

-     >>> import brainstate as bst
+     >>> import brainstate as brainstate
      >>> import numpy as np
-     >>> with bst.environ.context(precision=32):
-     ...     a = np.zeros(1, dtype=bst.environ.ditype())
+     >>> with brainstate.environ.context(precision=32):
+     ...     a = np.zeros(1, dtype=brainstate.environ.ditype())
      >>> print(a.dtype)
      int32

@@ -474,10 +474,10 @@ def dutype() -> DTypeLike:

      For example, if the precision is set to 32, the default unsigned integer data type is ``np.uint32``.

-     >>> import brainstate as bst
+     >>> import brainstate as brainstate
      >>> import numpy as np
-     >>> with bst.environ.context(precision=32):
-     ...     a = np.zeros(1, dtype=bst.environ.dutype())
+     >>> with brainstate.environ.context(precision=32):
+     ...     a = np.zeros(1, dtype=brainstate.environ.dutype())
      >>> print(a.dtype)
      uint32

@@ -499,10 +499,10 @@ def dctype() -> DTypeLike:

      For example, if the precision is set to 32, the default complex data type is ``np.complex64``.

-     >>> import brainstate as bst
+     >>> import brainstate as brainstate
      >>> import numpy as np
-     >>> with bst.environ.context(precision=32):
-     ...     a = np.zeros(1, dtype=bst.environ.dctype())
+     >>> with brainstate.environ.context(precision=32):
+     ...     a = np.zeros(1, dtype=brainstate.environ.dctype())
      >>> print(a.dtype)
      complex64

@@ -529,19 +529,19 @@ def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bo

      For example, you can register a default behavior for the key 'dt' by::

-       >>> import brainstate as bst
+       >>> import brainstate as brainstate
        >>> def dt_behavior(dt):
        ...     print(f'Set the default dt to {dt}.')
        ...
-       >>> bst.environ.register_default_behavior('dt', dt_behavior)
+       >>> brainstate.environ.register_default_behavior('dt', dt_behavior)

      Then, when you set the default dt by `brainstate.environ.set(dt=0.1)`, the behavior
      `dt_behavior` will be called with
      `dt_behavior(0.1)`.

-       >>> bst.environ.set(dt=0.1)
+       >>> brainstate.environ.set(dt=0.1)
        Set the default dt to 0.1.
-       >>> with bst.environ.context(dt=0.2):
+       >>> with brainstate.environ.context(dt=0.2):
        ...     pass
        Set the default dt to 0.2.
        Set the default dt to 0.1.
brainstate/functional/_activations_test.py CHANGED
@@ -70,39 +70,39 @@ class NNFunctionsTest(jtu.JaxTestCase):
                           check_dtypes=False)

      # def testSquareplusGrad(self):
-     #     check_grads(bst.functional.squareplus, (1e-8,), order=4,
+     #     check_grads(brainstate.functional.squareplus, (1e-8,), order=4,
      #                 )

      # def testSquareplusGradZero(self):
-     #     check_grads(bst.functional.squareplus, (0.,), order=1,
+     #     check_grads(brainstate.functional.squareplus, (0.,), order=1,
      #                 )

      # def testSquareplusGradNegInf(self):
-     #     check_grads(bst.functional.squareplus, (-float('inf'),), order=1,
+     #     check_grads(brainstate.functional.squareplus, (-float('inf'),), order=1,
      #                 )

      # def testSquareplusGradNan(self):
-     #     check_grads(bst.functional.squareplus, (float('nan'),), order=1,
+     #     check_grads(brainstate.functional.squareplus, (float('nan'),), order=1,
      #                 )

      # @parameterized.parameters([float] + jtu.dtypes.floating)
      # def testSquareplusZero(self, dtype):
-     #     self.assertEqual(dtype(1), bst.functional.squareplus(dtype(0), dtype(4)))
+     #     self.assertEqual(dtype(1), brainstate.functional.squareplus(dtype(0), dtype(4)))
      #
      # def testMishGrad(self):
-     #     check_grads(bst.functional.mish, (1e-8,), order=4,
+     #     check_grads(brainstate.functional.mish, (1e-8,), order=4,
      #                 )
      #
      # def testMishGradZero(self):
-     #     check_grads(bst.functional.mish, (0.,), order=1,
+     #     check_grads(brainstate.functional.mish, (0.,), order=1,
      #                 )
      #
      # def testMishGradNegInf(self):
-     #     check_grads(bst.functional.mish, (-float('inf'),), order=1,
+     #     check_grads(brainstate.functional.mish, (-float('inf'),), order=1,
      #                 )
      #
      # def testMishGradNan(self):
-     #     check_grads(bst.functional.mish, (float('nan'),), order=1,
+     #     check_grads(brainstate.functional.mish, (float('nan'),), order=1,
      #                 )

      @parameterized.parameters([float] + jtu.dtypes.floating)
@@ -137,7 +137,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
          self.assertAllClose(brainstate.functional.sparse_sigmoid(0.), .5, check_dtypes=False)

      # def testSquareplusValue(self):
-     #     val = bst.functional.squareplus(1e3)
+     #     val = brainstate.functional.squareplus(1e3)
      #     self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)

      def testMishValue(self):
@@ -177,7 +177,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
                               brainstate.functional.softplus,
                               brainstate.functional.sparse_plus,
                               brainstate.functional.sigmoid,
-                              # bst.functional.squareplus,
+                              # brainstate.functional.squareplus,
                               brainstate.functional.mish)))
      def testDtypeMatchesInput(self, dtype, fn):
          x = jnp.zeros((), dtype=dtype)
@@ -306,7 +306,7 @@ class NNFunctionsTest(jtu.JaxTestCase):

      def testCustomJVPLeak2(self):
          # https://github.com/google/jax/issues/8171
-         # The above test uses jax.bst.functional.sigmoid, as in the original #8171, but that
+         # The above test uses jax.brainstate.functional.sigmoid, as in the original #8171, but that
          # function no longer actually has a custom_jvp! So we inline the old def.

          @jax.custom_jvp