brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. brainstate/__init__.py +2 -4
  2. brainstate/_deprecation_test.py +2 -24
  3. brainstate/_state.py +540 -35
  4. brainstate/_state_test.py +1085 -8
  5. brainstate/graph/_operation.py +1 -5
  6. brainstate/mixin.py +14 -0
  7. brainstate/nn/__init__.py +42 -33
  8. brainstate/nn/_collective_ops.py +2 -0
  9. brainstate/nn/_common_test.py +0 -20
  10. brainstate/nn/_delay.py +1 -1
  11. brainstate/nn/_dropout_test.py +9 -6
  12. brainstate/nn/_dynamics.py +67 -464
  13. brainstate/nn/_dynamics_test.py +0 -14
  14. brainstate/nn/_embedding.py +7 -7
  15. brainstate/nn/_exp_euler.py +9 -9
  16. brainstate/nn/_linear.py +21 -21
  17. brainstate/nn/_module.py +25 -18
  18. brainstate/nn/_normalizations.py +27 -27
  19. brainstate/random/__init__.py +6 -6
  20. brainstate/random/{_rand_funs.py → _fun.py} +1 -1
  21. brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
  22. brainstate/random/_impl.py +672 -0
  23. brainstate/random/{_rand_seed.py → _seed.py} +1 -1
  24. brainstate/random/{_rand_state.py → _state.py} +121 -418
  25. brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
  26. brainstate/transform/__init__.py +6 -9
  27. brainstate/transform/_conditions.py +2 -2
  28. brainstate/transform/_find_state.py +200 -0
  29. brainstate/transform/_find_state_test.py +84 -0
  30. brainstate/transform/_make_jaxpr.py +221 -61
  31. brainstate/transform/_make_jaxpr_test.py +125 -1
  32. brainstate/transform/_mapping.py +287 -209
  33. brainstate/transform/_mapping_test.py +94 -184
  34. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
  35. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
  36. brainstate/transform/_eval_shape.py +0 -145
  37. brainstate/transform/_eval_shape_test.py +0 -38
  38. brainstate/transform/_random.py +0 -171
  39. brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
  40. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  41. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
  42. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/transform/_make_jaxpr.py
@@ -55,6 +55,7 @@ import functools
 import inspect
 import operator
 import threading
+import warnings
 from collections import OrderedDict, defaultdict
 from collections.abc import Hashable, Iterable, Sequence
 from collections.abc import MutableSet
@@ -66,7 +67,6 @@ import jax.numpy as jnp
 from jax._src import source_info_util
 from jax._src.linear_util import annotate
 from jax._src.traceback_util import api_boundary
-from jax._src.util import memoize
 from jax.api_util import shaped_abstractify
 from jax.extend.linear_util import transformation_with_aux
 from jax.interpreters import partial_eval as pe
@@ -75,6 +75,7 @@ from brainstate._compatible_import import (
     ClosedJaxpr, extend_axis_env_nd, safe_map, safe_zip, unzip2, wraps, wrap_init,
     Literal, Var, Jaxpr, make_iota, to_elt, BatchTracer, BatchTrace,
 )
+from brainstate._error import BatchAxisError
 from brainstate._state import State, StateTraceStack
 from brainstate._utils import set_module_as
 from brainstate.random import RandomState
@@ -1409,6 +1410,122 @@ def make_jaxpr(
 
 
 class StatefulMapping(StatefulFunction):
+    """
+    Vectorized wrapper that preserves BrainState state semantics during mapping.
+
+    ``StatefulMapping`` extends JAX mapping transforms (such as :func:`jax.vmap`
+    and :func:`jax.pmap`) with awareness of :class:`~brainstate.State`
+    instances. It tracks state reads and writes across the mapped axis,
+    ensures deterministic random-number handling, and restores side effects
+    after each batched execution. The helper is typically constructed by
+    :func:`brainstate.transform.vmap` or :func:`brainstate.transform.pmap`, but
+    it can also be instantiated directly for custom mapping primitives.
+
+    Parameters
+    ----------
+    fun : callable
+        Stateless callable to be wrapped. The callable may close over
+        :class:`~brainstate.State` objects that should be tracked during the
+        mapping transform.
+    in_axes : int, tuple of int, or None, default 0
+        Alignment of the mapped axis per positional argument, following the
+        semantics of :func:`jax.vmap`. Arguments mapped with ``None`` are
+        treated as static.
+    out_axes : int, tuple of int, or None, default 0
+        Placement of the mapped axis in the return value, consistent with JAX
+        mapping primitives.
+    state_in_axes : dict[AxisName, Filter] or Filter, optional
+        Specification of input states that participate in the mapped axis. A
+        dictionary maps axis identifiers to :mod:`brainstate.util.filter`
+        predicates; passing a single filter applies it to axis ``0``. Values
+        are normalized via :func:`brainstate.util.filter.to_predicate`.
+    state_out_axes : dict[AxisName, Filter] or Filter, optional
+        Specification of state outputs to scatter back along the mapped axis.
+        Uses the same semantics and normalization as ``state_in_axes``.
+    unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
+        Strategy for handling states written during the mapped call that are
+        not captured by ``state_out_axes``.
+    axis_size : int, optional
+        Explicit size of the mapped axis. When omitted, the size is inferred
+        from the mapped arguments.
+    axis_name : hashable, optional
+        Name for the mapped axis so that collective primitives can target it.
+    name : str, optional
+        Human-readable identifier for diagnostics and debugging.
+    mapping_fn : callable, default ``jax.vmap``
+        Mapping primitive that executes ``fun``. The callable must accept the
+        ``in_axes`` and ``out_axes`` keyword arguments used by :func:`jax.vmap`.
+
+    Attributes
+    ----------
+    origin_fun : callable
+        Original Python callable wrapped by the mapping helper.
+    in_axes : int, tuple of int, or None
+        Mapping specification for positional arguments.
+    out_axes : int, tuple of int, or None
+        Mapping specification for the return value.
+    state_in_axes : dict[AxisName, Predicate]
+        Normalized predicates describing which states to batch on input.
+    state_out_axes : dict[AxisName, Predicate]
+        Normalized predicates describing which states to batch on output.
+    axis_size : int or None
+        Size of the mapped axis, if explicitly provided.
+    axis_name : hashable or None
+        Axis identifier forwarded to collective primitives.
+    mapping_fn : callable
+        Mapping primitive responsible for executing ``fun``.
+
+    Raises
+    ------
+    TypeError
+        If ``in_axes`` has an unsupported type.
+    ValueError
+        If batch dimensions are inconsistent or cannot be inferred.
+    RuntimeError
+        If tracing or executing the mapped function fails.
+
+    Notes
+    -----
+    Random states (for example :class:`~brainstate.RandomState`) encountered
+    during execution are automatically split along the mapped axis and restored
+    afterwards; this behaviour cannot be disabled. The wrapper caches inferred
+    state placements, batch sizes, and trace stacks keyed by abstract argument
+    signatures so repeated calls with the same structure avoid re-tracing.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> import jax.numpy as jnp
+        >>> from brainstate.util.filter import OfType
+        >>>
+        >>> counter = brainstate.ShortTermState(jnp.zeros(3))
+        >>>
+        >>> def accumulate(x):
+        ...     counter.value = counter.value + x
+        ...     return counter.value
+        >>>
+        >>> batched_accumulate = brainstate.transform.StatefulMapping(
+        ...     accumulate,
+        ...     in_axes=0,
+        ...     out_axes=0,
+        ...     state_in_axes={0: OfType(brainstate.ShortTermState)},
+        ...     state_out_axes={0: OfType(brainstate.ShortTermState)},
+        ...     name="batched_accumulate",
+        ... )
+        >>>
+        >>> xs = jnp.asarray([1., 2., 3.])
+        >>> batched_accumulate(xs)
+        Array([1., 2., 3.], dtype=float32)
+        >>> counter.value
+        Array([1., 2., 3.], dtype=float32)
+
+    See Also
+    --------
+    brainstate.transform.vmap : Convenience API returning a ``StatefulMapping``.
+    brainstate.transform.pmap : Device-mapped variant aware of BrainState states.
+    """
     __module__ = "brainstate.transform"
 
     def __init__(
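
A note on the ``state_in_axes`` / ``state_out_axes`` normalization documented above: a bare filter is shorthand for mapping axis 0, with values passed through ``brainstate.util.filter.to_predicate``. A minimal sketch of that equivalence, assuming only the public filter API referenced in this diff:

    import brainstate
    from brainstate.util.filter import OfType, to_predicate

    # A bare filter is shorthand for {0: filter}; both forms below should
    # normalize to the same {axis: predicate} mapping inside StatefulMapping.
    short_form = OfType(brainstate.ShortTermState)
    dict_form = {0: to_predicate(OfType(brainstate.ShortTermState))}
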
@@ -1418,11 +1535,13 @@ class StatefulMapping(StatefulFunction):
         out_axes: Union[int, Tuple[int, ...], None] = 0,
         state_in_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
         state_out_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
-        # jit specific parameters
+        unexpected_out_state_mapping: str = 'raise',
+        # JIT specific parameters
         static_argnums: Union[int, Iterable[int]] = (),
         static_argnames: Union[str, Iterable[str]] = (),
         axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
         abstracted_axes: Optional[Any] = None,
+        return_only_write: bool = True,
         # mapping specific parameters
         axis_size: Optional[int] = None,
         axis_name: AxisName | None = None,
@@ -1430,16 +1549,18 @@ class StatefulMapping(StatefulFunction):
         # mapping function
         mapping_fn: Callable = jax.vmap,
     ):
-        self.origin_fun = fun
         super().__init__(
-            fun=self._wrapped_fun,
+            fun=self.__wrapped_fun,
             static_argnums=static_argnums,
             static_argnames=static_argnames,
             axis_env=axis_env,
             abstracted_axes=abstracted_axes,
+            return_only_write=return_only_write,
             name=name,
-            return_only_write=False,
         )
+
+        self.name = name
+        self.origin_fun = fun
         self.in_axes = in_axes
         self.out_axes = out_axes
         if state_in_axes is None:
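
The rename from ``_wrapped_fun`` to ``__wrapped_fun`` (and likewise ``__infer_batch_size``, ``__eval`` below) opts into Python name mangling, which shields these hooks from accidental override in subclasses. A quick illustration of the mechanic:

    class Base:
        def __hook(self):          # mangled to _Base__hook
            return 'base'

        def call(self):
            return self.__hook()   # always resolves to _Base__hook

    class Child(Base):
        def __hook(self):          # mangled to _Child__hook; no override
            return 'child'

    assert Child().call() == 'base'
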
@@ -1459,6 +1580,7 @@ class StatefulMapping(StatefulFunction):
         self.axis_size = axis_size
         self.axis_name = axis_name
         self.mapping_fn = mapping_fn
+        self.unexpected_out_state_mapping = unexpected_out_state_mapping
 
         # Cache for discovered state-to-axis mappings
         self._cached_map_dim_to_in_states = _BoundedCache(maxsize=128)
@@ -1466,12 +1588,7 @@ class StatefulMapping(StatefulFunction):
         self._cached_map_state_trace = _BoundedCache(maxsize=128)
         self._cached_map_batch_size = _BoundedCache(maxsize=128)
 
-    def _infer_batch_size(self, args, in_axes):
-        if in_axes is None:
-            raise ValueError("Cannot infer batch size when in_axes is None")
-
-        batch_sizes = []
-
+    def __infer_batch_size(self, args, in_axes):
         def get_batch_size_from_arg(arg_, axis_):
             if axis_ is None:
                 return None
@@ -1490,6 +1607,7 @@ class StatefulMapping(StatefulFunction):
             sizes = [s for s in jax.tree.leaves(jax.tree.map(_get_size, arg_)) if s is not None]
             return sizes[0] if sizes else None
 
+        batch_sizes = []
         if isinstance(in_axes, int):
             # All args batched along the same axis
             for arg in args:
@@ -1506,6 +1624,8 @@ class StatefulMapping(StatefulFunction):
                 size = get_batch_size_from_arg(arg, axis)
                 if size is not None:
                     batch_sizes.append(size)
+        elif in_axes is None:
+            pass
         else:
             raise TypeError(f"Unsupported in_axes type: {type(in_axes)}")
 
@@ -1523,17 +1643,14 @@ class StatefulMapping(StatefulFunction):
 
         return batch_sizes[0]
 
-    def __new_batch_arg(self, batch_size: int, dim_to_states: dict):
-        trace = jax.core.trace_ctx.trace
-        assert isinstance(trace, BatchTrace), f"Expected to be called within a BatchTrace context, but got {trace}"
-
+    def __new_batch_arg(self, trace, batch_size: int, dim_to_states: dict):
         def wrapper(x):
             if isinstance(x, RandomState):
-                idx = memoize(lambda: BatchTracer(trace, make_iota(batch_size), 0, source_info_util.current()))
+                idx = lambda: BatchTracer(trace, make_iota(batch_size), 0, source_info_util.current())
                 dim_to_states['random'].append(x)
-                return to_elt(trace, idx, jnp.ones((batch_size,) + x._value.shape, x._value.dtype), 0)
+                return to_elt(trace, idx, self._rand_value, 0)
             for dim, filter_ in self.state_in_axes.items():
-                idx = memoize(lambda: BatchTracer(trace, make_iota(batch_size), dim, source_info_util.current()))
+                idx = lambda: BatchTracer(trace, make_iota(batch_size), dim, source_info_util.current())
                 if filter_(tuple(), x):
                     dim_to_states[dim].append(x)
                     return jax.tree.map(lambda xx: to_elt(trace, idx, xx, dim), x._value)
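
The new ``elif in_axes is None: pass`` branch above lets a fully static call fall through to the empty-``batch_sizes`` path instead of failing up front, matching :func:`jax.vmap` semantics where ``None`` marks an argument as broadcast. A standalone sketch of the same inference over pytrees (illustrative names only, not the module's private helpers):

    import jax

    def infer_batch_size(args, in_axes):
        # Collect the mapped-axis size implied by each (arg, axis) pair,
        # vmap-style: int applies one axis to all args, a tuple aligns
        # per argument, and None marks an argument as static.
        def axis_size(arg, axis):
            if axis is None:
                return None
            sizes = [x.shape[axis] for x in jax.tree.leaves(arg) if hasattr(x, 'shape')]
            return sizes[0] if sizes else None

        if isinstance(in_axes, int):
            axes = (in_axes,) * len(args)
        elif in_axes is None:
            axes = (None,) * len(args)
        else:
            axes = tuple(in_axes)
        found = {s for arg, ax in zip(args, axes) if (s := axis_size(arg, ax)) is not None}
        if len(found) > 1:
            raise ValueError(f"Inconsistent batch sizes: {found}")
        return found.pop() if found else None
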
@@ -1541,41 +1658,81 @@ class StatefulMapping(StatefulFunction):
 
         return wrapper
 
-    def __eval(self, cache_key, *args, **kwargs):
-        def fn_to_eval(*new_args, **new_kwargs):
-            dim_to_in_states = defaultdict(list)
-            state_trace = StateTraceStack(name=self.name)
-            state_trace.set_new_arg(
-                self.__new_batch_arg(self._cached_map_batch_size.get(cache_key), dim_to_in_states)
+    def __find_batch_dim(self, st):
+        leaves = jax.tree.leaves(st._value)
+        batch_dims = set([leaf.batch_dim if isinstance(leaf, BatchTracer) else None for leaf in leaves])
+        if len(batch_dims) != 1:
+            raise ValueError(
+                f"State {st} has inconsistent batch dimensions in its leaves: {batch_dims}. "
+                "All leaves must have the same batch dimension."
             )
-            self._cached_map_state_trace.set(cache_key, state_trace)
-
-            # call functions
-            with state_trace:
-                out_ = self.origin_fun(*new_args, **new_kwargs)
-
-            # cache
-            self._cached_map_dim_to_in_states.set(cache_key, dim_to_in_states)
-
-            # vmapped state values
-            out_states = defaultdict(list)
-            out_states['random'] = [st for st in state_trace.states if isinstance(st, RandomState)]
-            for st in state_trace.states:
-                if not isinstance(st, RandomState):
-                    leaves = jax.tree.leaves(st._value)
-                    batch_dims = set([leaf.batch_dim if isinstance(leaf, BatchTracer) else None for leaf in leaves])
-                    if len(batch_dims) != 1:
-                        raise ValueError(
-                            f"State {st} has inconsistent batch dimensions in its leaves: {batch_dims}. "
-                            "All leaves must have the same batch dimension."
+        dim = batch_dims.pop()
+        return dim
+
+    def __fn_to_eval(self, cache_key, *new_args, **new_kwargs):
+        # state trace
+        trace = jax.core.trace_ctx.trace
+        assert isinstance(trace, BatchTrace), f"Expected to be called within a BatchTrace context, but got {trace}"
+        dim_to_in_states = defaultdict(list)
+        state_trace = StateTraceStack(name=self.name)
+        state_trace.set_new_arg(
+            self.__new_batch_arg(trace, self._cached_map_batch_size.get(cache_key), dim_to_in_states)
+        )
+        self._cached_map_state_trace.set(cache_key, state_trace)
+
+        # call functions
+        with state_trace:
+            out_ = self.origin_fun(*new_args, **new_kwargs)
+
+        # cache vmapped in states
+        self._cached_map_dim_to_in_states.set(cache_key, dim_to_in_states.copy())
+        mapped_in_states = set([id(v) for vv in dim_to_in_states.values() for v in vv])
+
+        # vmapped out states
+        out_states = defaultdict(list)
+        out_states['random'] = [st for st in state_trace.states if isinstance(st, RandomState)]
+        for st in state_trace.states:
+            if isinstance(st, RandomState):
+                continue
+            find = False
+            for dim, filter_ in self.state_out_axes.items():
+                if filter_(tuple(), st):
+                    out_states[dim].append(st)
+                    find = True
+                    break
+            if find:
+                continue
+            dim = self.__find_batch_dim(st)
+            if dim is None or id(st) in mapped_in_states:
+                out_states[dim].append(st)
+            else:
+                if self.unexpected_out_state_mapping == 'raise':
+                    st.raise_error_with_source_info(
+                        BatchAxisError(
+                            f'State\n {st} \n was not expected to be batched on output. '
+                            'Please adjust state_out_axes or set unexpected_out_state_mapping to "warn" or "ignore".'
                         )
-                    batch_dim = batch_dims.pop()
-                    out_states[batch_dim].append(st)
-            self._cached_map_dim_to_out_states.set(cache_key, out_states)
+                    )
+                elif self.unexpected_out_state_mapping == 'warn':
+                    warnings.warn(
+                        f'State\n {st} \n was not expected to be batched on output. '
+                        f'Please adjust state_out_axes or set unexpected_out_state_mapping to "ignore".',
+                        UserWarning,
+                    )
+                    out_states[dim].append(st)
+                elif self.unexpected_out_state_mapping == 'ignore':
+                    out_states[dim].append(st)
+                else:
+                    raise ValueError(
+                        'Invalid value for unexpected_out_state_mapping: '
+                        f'{self.unexpected_out_state_mapping}. Must be "raise", "warn", or "ignore".'
+                    )
+        self._cached_map_dim_to_out_states.set(cache_key, out_states)
 
+    def __eval(self, cache_key, *args, **kwargs):
         try:
             jax.vmap(
-                fn_to_eval,
+                functools.partial(self.__fn_to_eval, cache_key),
                 in_axes=self.in_axes,
                 out_axes=self.out_axes,
                 axis_name=self.axis_name,
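
The raise/warn/ignore ladder above is the whole contract of ``unexpected_out_state_mapping``: a written state that no ``state_out_axes`` filter claims either aborts with a :class:`BatchAxisError`, warns and keeps the inferred batch dimension, or is kept silently. A condensed sketch of the same policy dispatch (``RuntimeError`` stands in for the real exception plumbing through ``raise_error_with_source_info``):

    import warnings

    def handle_unexpected_out_state(state, dim, policy, out_states):
        msg = (f'State {state} was not expected to be batched on output; '
               'adjust state_out_axes or relax unexpected_out_state_mapping.')
        if policy == 'raise':
            raise RuntimeError(msg)  # BatchAxisError in the real implementation
        elif policy == 'warn':
            warnings.warn(msg, UserWarning)
        elif policy != 'ignore':
            raise ValueError(f'Invalid policy: {policy!r}')
        out_states[dim].append(state)  # 'warn' and 'ignore' both keep the state
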
@@ -1589,7 +1746,7 @@ class StatefulMapping(StatefulFunction):
             self._cached_map_dim_to_in_states.pop(cache_key, None)
             self._cached_map_dim_to_out_states.pop(cache_key, None)
             self._cached_map_batch_size.pop(cache_key, None)
-            raise RuntimeError(f"Failed to evaluate {self}") from e
+            raise e
 
     def __assign_vals_from_in_states(self, cache_key, rand_st, *other_st):
         in_states = self._cached_map_dim_to_in_states.get(cache_key)
@@ -1647,16 +1804,22 @@ class StatefulMapping(StatefulFunction):
         for st, val in zip(rand_states, rand_recover_vals):
             st.restore_value(val)
 
-    def _wrapped_fun(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
-        batch_size = self._infer_batch_size(args, self.in_axes)
+    def __wrapped_fun(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
+        if len(kwargs):
+            raise NotImplementedError(
+                'StatefulMapping currently does not support keyword arguments.'
+            )
+
+        batch_size = self.__infer_batch_size(args, self.in_axes)
         cache_key = self.get_arg_cache_key(*args, **kwargs)
-        self._cached_map_batch_size.set(cache_key, batch_size)
         if cache_key not in self._cached_map_state_trace:
+            self._rand_value = RandomState._batch_keys(batch_size)
+            self._cached_map_batch_size.set(cache_key, batch_size)
             self.__eval(cache_key, *args, **kwargs)
 
         def fn_to_map(origin_args, rand_st, *non_rand_st):
             self.__assign_vals_from_in_states(cache_key, rand_st, *non_rand_st)
-            out = self.origin_fun(*origin_args[0], **origin_args[1])
+            out = self.origin_fun(*origin_args)
             return out, *self.__get_out_state_vals(cache_key)[1]
 
         in_axes, in_state_vals = self.__get_in_state_vals(cache_key)
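
``__wrapped_fun`` now materializes a batch of RNG keys once per cache entry via ``RandomState._batch_keys(batch_size)`` and threads them through ``self._rand_value``, so every mapped lane draws from an independent key. The observable behaviour matches splitting a key per lane in stock JAX; a behavioural sketch (not the internal implementation):

    import jax
    import jax.random as jr

    key = jr.PRNGKey(0)
    lane_keys = jr.split(key, 3)  # one independent key per mapped lane
    samples = jax.vmap(lambda k: jr.normal(k, ()))(lane_keys)
    assert samples.shape == (3,)  # distinct draws, as test_random_state_restoration checks
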
@@ -1664,12 +1827,12 @@ class StatefulMapping(StatefulFunction):
         rand_vals, rand_recover_vals = self.__get_rand_state_vals(cache_key)
         mapped_fn = self.mapping_fn(
             fn_to_map,
-            in_axes=(self.in_axes, 0) + in_axes,
+            in_axes=(self.in_axes, 0 if len(rand_vals) else None) + in_axes,
             out_axes=(self.out_axes,) + out_axes,
             axis_size=self.axis_size,
             axis_name=self.axis_name,
         )
-        out_, *out_state_vals = mapped_fn((args, kwargs), rand_vals, *in_state_vals)
+        out_, *out_state_vals = mapped_fn(args, rand_vals, *in_state_vals)
         self.__assign_vals_from_out_states(cache_key, rand_recover_vals, *out_state_vals)
         return out_
 
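
The switch from a hard-coded ``0`` to ``0 if len(rand_vals) else None`` matters when no :class:`RandomState` is traced: the empty random bundle is then passed as a broadcast (static) argument instead of a mapped one. In :func:`jax.vmap`, ``in_axes=None`` means exactly that:

    import jax
    import jax.numpy as jnp

    xs = jnp.arange(3.0)
    # xs is mapped along axis 0; c is broadcast unchanged to every lane.
    out = jax.vmap(lambda x, c: x + c, in_axes=(0, None))(xs, 10.0)
    assert out.tolist() == [10.0, 11.0, 12.0]
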
@@ -1924,6 +2087,9 @@ def constant_fold_jaxpr(jaxpr: Jaxpr):
     return _partial_eval_jaxpr(jaxpr, {})
 
 
+_constant_fold_blacklist = {'broadcast_in_dim', 'broadcast'}
+
+
 def _partial_eval_jaxpr(jaxpr, env):
     env = env.copy()
     new_eqns = []
@@ -2008,9 +2174,3 @@ def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jax.Array]:
     else:
         out = eqn.primitive.bind(*vals, **eqn.params)
     return out
-
-
-_constant_fold_blacklist = {
-    'broadcast_in_dim',
-    'broadcast',
-}
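
Moving ``_constant_fold_blacklist`` ahead of ``_partial_eval_jaxpr`` groups the lookup table with the folding pass that consults it. The guard presumably reads along these lines (a hypothetical sketch; the real check operates on jaxpr equations inside ``_partial_eval_jaxpr``):

    _constant_fold_blacklist = {'broadcast_in_dim', 'broadcast'}

    def can_constant_fold(primitive_name: str, all_inputs_known: bool) -> bool:
        # Even with fully known inputs, folding a broadcast would bake a
        # large materialized array into the jaxpr, so those primitives
        # are skipped.
        return all_inputs_known and primitive_name not in _constant_fold_blacklist
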
brainstate/transform/_make_jaxpr_test.py
@@ -16,15 +16,18 @@
 
 import threading
 import unittest
+import warnings
 
 import jax
 import jax.numpy as jnp
+import jax.random as jr
 import pytest
 
 import brainstate
-from brainstate._error import BatchAxisError
 from brainstate._compatible_import import jaxpr_as_fun
+from brainstate._error import BatchAxisError
 from brainstate.transform._make_jaxpr import _BoundedCache, make_hashable
+from brainstate.util import filter as state_filter
 
 
 class TestMakeJaxpr(unittest.TestCase):
@@ -1508,3 +1511,124 @@ class TestStatefulFunctionRecompilation(unittest.TestCase):
         stats_after_recompile = sf.get_cache_stats()
         self.assertGreater(stats_after_recompile['jaxpr_cache']['size'], 0)
 
+
+class TestStatefulMapping(unittest.TestCase):
+    def test_state_filters_and_caching(self):
+        counter = brainstate.ShortTermState(jnp.zeros(3))
+
+        def accumulate(x):
+            counter.value = counter.value + x
+            return counter.value
+
+        mapper = brainstate.transform.StatefulMapping(
+            accumulate,
+            in_axes=0,
+            out_axes=0,
+            state_in_axes={0: state_filter.OfType(brainstate.ShortTermState)},
+            state_out_axes={0: state_filter.OfType(brainstate.ShortTermState)},
+        )
+
+        xs = jnp.asarray([1.0, 2.0, 3.0])
+        result = mapper(xs)
+        self.assertTrue(jnp.allclose(result, xs))
+        self.assertTrue(jnp.allclose(counter.value, xs))
+
+    def test_random_state_restoration(self):
+        rng_state = brainstate.random.RandomState(0)
+
+        def draw(_):
+            key = rng_state.split_key()
+            return jr.normal(key, ())
+
+        mapper = brainstate.transform.StatefulMapping(
+            draw,
+            in_axes=0,
+            out_axes=0,
+        )
+
+        xs = jnp.ones((4,))
+        before = rng_state.value
+        samples = mapper(xs)
+        self.assertEqual(samples.shape, xs.shape)
+        self.assertFalse(jnp.allclose(samples, jnp.repeat(samples[0], xs.shape[0])))
+        self.assertTrue(jnp.array_equal(rng_state.value.shape, before.shape))
+
+    def test_inconsistent_batch_sizes_raise(self):
+        tracker = brainstate.ShortTermState(jnp.array(0.0))
+
+        def combine(x, y):
+            tracker.value = tracker.value + x + y
+            return tracker.value
+
+        mapper = brainstate.transform.StatefulMapping(
+            combine,
+            in_axes=(0, 0),
+            out_axes=0,
+            state_in_axes={0: state_filter.OfType(brainstate.ShortTermState)},
+            state_out_axes={0: state_filter.OfType(brainstate.ShortTermState)},
+        )
+
+        with self.assertRaisesRegex(ValueError, "Inconsistent batch sizes"):
+            mapper(jnp.ones((3,)), jnp.ones((4,)))
+
+    def test_unexpected_out_state_mapping_raise(self):
+        leak = brainstate.ShortTermState(jnp.array(0.0))
+
+        def mutate(x):
+            leak.value = leak.value + x
+            return x
+
+        mapper = brainstate.transform.StatefulMapping(
+            mutate,
+            in_axes=0,
+            out_axes=0,
+            state_in_axes={},
+            state_out_axes={},
+            unexpected_out_state_mapping='raise',
+        )
+
+        with self.assertRaises(BatchAxisError):
+            mapper(jnp.ones((2,)))
+
+    def test_unexpected_out_state_mapping_warn(self):
+        leak = brainstate.ShortTermState(jnp.array(0.0))
+
+        def mutate(x):
+            leak.value = leak.value + x
+            return x
+
+        mapper = brainstate.transform.StatefulMapping(
+            mutate,
+            in_axes=0,
+            out_axes=0,
+            state_in_axes={},
+            state_out_axes={},
+            unexpected_out_state_mapping='warn',
+        )
+
+        with pytest.warns(UserWarning):
+            mapper(jnp.ones((2,)))
+        self.assertTrue(jnp.allclose(leak.value, 1.0))
+
+    def test_unexpected_out_state_mapping_ignore(self):
+        leak = brainstate.ShortTermState(jnp.array(0.0))
+
+        def mutate(x):
+            leak.value = leak.value + x
+            return x
+
+        mapper = brainstate.transform.StatefulMapping(
+            mutate,
+            in_axes=0,
+            out_axes=0,
+            state_in_axes={},
+            state_out_axes={},
+            unexpected_out_state_mapping='ignore',
+        )
+
+        with warnings.catch_warnings(record=True) as caught:
+            warnings.simplefilter('always')
+            mapper(jnp.ones((2,)))
+        self.assertEqual(len(caught), 0)
+        self.assertTrue(jnp.allclose(leak.value, 1.0))