brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +2 -4
- brainstate/_deprecation_test.py +2 -24
- brainstate/_state.py +540 -35
- brainstate/_state_test.py +1085 -8
- brainstate/graph/_operation.py +1 -5
- brainstate/mixin.py +14 -0
- brainstate/nn/__init__.py +42 -33
- brainstate/nn/_collective_ops.py +2 -0
- brainstate/nn/_common_test.py +0 -20
- brainstate/nn/_delay.py +1 -1
- brainstate/nn/_dropout_test.py +9 -6
- brainstate/nn/_dynamics.py +67 -464
- brainstate/nn/_dynamics_test.py +0 -14
- brainstate/nn/_embedding.py +7 -7
- brainstate/nn/_exp_euler.py +9 -9
- brainstate/nn/_linear.py +21 -21
- brainstate/nn/_module.py +25 -18
- brainstate/nn/_normalizations.py +27 -27
- brainstate/random/__init__.py +6 -6
- brainstate/random/{_rand_funs.py → _fun.py} +1 -1
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +1 -1
- brainstate/random/{_rand_state.py → _state.py} +121 -418
- brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
- brainstate/transform/__init__.py +6 -9
- brainstate/transform/_conditions.py +2 -2
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_make_jaxpr.py +221 -61
- brainstate/transform/_make_jaxpr_test.py +125 -1
- brainstate/transform/_mapping.py +287 -209
- brainstate/transform/_mapping_test.py +94 -184
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -55,6 +55,7 @@ import functools
|
|
55
55
|
import inspect
|
56
56
|
import operator
|
57
57
|
import threading
|
58
|
+
import warnings
|
58
59
|
from collections import OrderedDict, defaultdict
|
59
60
|
from collections.abc import Hashable, Iterable, Sequence
|
60
61
|
from collections.abc import MutableSet
|
@@ -66,7 +67,6 @@ import jax.numpy as jnp
|
|
66
67
|
from jax._src import source_info_util
|
67
68
|
from jax._src.linear_util import annotate
|
68
69
|
from jax._src.traceback_util import api_boundary
|
69
|
-
from jax._src.util import memoize
|
70
70
|
from jax.api_util import shaped_abstractify
|
71
71
|
from jax.extend.linear_util import transformation_with_aux
|
72
72
|
from jax.interpreters import partial_eval as pe
|
@@ -75,6 +75,7 @@ from brainstate._compatible_import import (
|
|
75
75
|
ClosedJaxpr, extend_axis_env_nd, safe_map, safe_zip, unzip2, wraps, wrap_init,
|
76
76
|
Literal, Var, Jaxpr, make_iota, to_elt, BatchTracer, BatchTrace,
|
77
77
|
)
|
78
|
+
from brainstate._error import BatchAxisError
|
78
79
|
from brainstate._state import State, StateTraceStack
|
79
80
|
from brainstate._utils import set_module_as
|
80
81
|
from brainstate.random import RandomState
|
@@ -1409,6 +1410,122 @@ def make_jaxpr(
|
|
1409
1410
|
|
1410
1411
|
|
1411
1412
|
class StatefulMapping(StatefulFunction):
|
1413
|
+
"""
|
1414
|
+
Vectorized wrapper that preserves BrainState state semantics during mapping.
|
1415
|
+
|
1416
|
+
``StatefulMapping`` extends JAX mapping transforms (such as :func:`jax.vmap`
|
1417
|
+
and :func:`jax.pmap`) with awareness of :class:`~brainstate.State`
|
1418
|
+
instances. It tracks state reads and writes across the mapped axis,
|
1419
|
+
ensures deterministic random-number handling, and restores side effects
|
1420
|
+
after each batched execution. The helper is typically constructed by
|
1421
|
+
:func:`brainstate.transform.vmap` or :func:`brainstate.transform.pmap`, but
|
1422
|
+
it can also be instantiated directly for custom mapping primitives.
|
1423
|
+
|
1424
|
+
Parameters
|
1425
|
+
----------
|
1426
|
+
fun : callable
|
1427
|
+
Stateless callable to be wrapped. The callable may close over
|
1428
|
+
:class:`~brainstate.State` objects that should be tracked during the
|
1429
|
+
mapping transform.
|
1430
|
+
in_axes : int, tuple of int, or None, default 0
|
1431
|
+
Alignment of the mapped axis per positional argument, following the
|
1432
|
+
semantics of :func:`jax.vmap`. Arguments mapped with ``None`` are treated
|
1433
|
+
as static.
|
1434
|
+
out_axes : int, tuple of int, or None, default 0
|
1435
|
+
Placement of the mapped axis in the return value, consistent with JAX
|
1436
|
+
mapping primitives.
|
1437
|
+
state_in_axes : dict[AxisName, Filter] or Filter, optional
|
1438
|
+
Specification of input states that participate in the mapped axis. A
|
1439
|
+
dictionary maps axis identifiers to :mod:`brainstate.util.filter`
|
1440
|
+
predicates; passing a single filter applies it to axis ``0``. Values are
|
1441
|
+
normalized via :func:`brainstate.util.filter.to_predicate`.
|
1442
|
+
state_out_axes : dict[AxisName, Filter] or Filter, optional
|
1443
|
+
Specification of state outputs to scatter back along the mapped axis.
|
1444
|
+
Uses the same semantics and normalization as ``state_in_axes``.
|
1445
|
+
unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
|
1446
|
+
Strategy for handling states written during the mapped call that are not
|
1447
|
+
captured by ``state_out_axes``.
|
1448
|
+
axis_size : int, optional
|
1449
|
+
Explicit size of the mapped axis. When omitted, the size is inferred
|
1450
|
+
from the mapped arguments.
|
1451
|
+
axis_name : hashable, optional
|
1452
|
+
Name for the mapped axis so that collective primitives can target it.
|
1453
|
+
name : str, optional
|
1454
|
+
Human-readable identifier for diagnostics and debugging.
|
1455
|
+
mapping_fn : callable, default ``jax.vmap``
|
1456
|
+
Mapping primitive that executes ``fun``. The callable must accept the
|
1457
|
+
``in_axes`` and ``out_axes`` keyword arguments used by :func:`jax.vmap`.
|
1458
|
+
|
1459
|
+
Attributes
|
1460
|
+
----------
|
1461
|
+
origin_fun : callable
|
1462
|
+
Original Python callable wrapped by the mapping helper.
|
1463
|
+
in_axes : int, tuple of int, or None
|
1464
|
+
Mapping specification for positional arguments.
|
1465
|
+
out_axes : int, tuple of int, or None
|
1466
|
+
Mapping specification for the return value.
|
1467
|
+
state_in_axes : dict[AxisName, Predicate]
|
1468
|
+
Normalized predicates describing which states to batch on input.
|
1469
|
+
state_out_axes : dict[AxisName, Predicate]
|
1470
|
+
Normalized predicates describing which states to batch on output.
|
1471
|
+
axis_size : int or None
|
1472
|
+
Size of the mapped axis, if explicitly provided.
|
1473
|
+
axis_name : hashable or None
|
1474
|
+
Axis identifier forwarded to collective primitives.
|
1475
|
+
mapping_fn : callable
|
1476
|
+
Mapping primitive responsible for executing ``fun``.
|
1477
|
+
|
1478
|
+
Raises
|
1479
|
+
------
|
1480
|
+
TypeError
|
1481
|
+
If ``in_axes`` has an unsupported type.
|
1482
|
+
ValueError
|
1483
|
+
If batch dimensions are inconsistent or cannot be inferred.
|
1484
|
+
RuntimeError
|
1485
|
+
If tracing or executing the mapped function fails.
|
1486
|
+
|
1487
|
+
Notes
|
1488
|
+
-----
|
1489
|
+
Random states (for example :class:`~brainstate.RandomState`) encountered
|
1490
|
+
during execution are automatically split along the mapped axis and restored
|
1491
|
+
afterwards; this behaviour cannot be disabled. The wrapper caches inferred
|
1492
|
+
state placements, batch sizes, and trace stacks keyed by abstract argument
|
1493
|
+
signatures so repeated calls with the same structure avoid re-tracing.
|
1494
|
+
|
1495
|
+
Examples
|
1496
|
+
--------
|
1497
|
+
.. code-block:: python
|
1498
|
+
|
1499
|
+
>>> import brainstate
|
1500
|
+
>>> import jax.numpy as jnp
|
1501
|
+
>>> from brainstate.util.filter import OfType
|
1502
|
+
>>>
|
1503
|
+
>>> counter = brainstate.ShortTermState(jnp.array(0.0))
|
1504
|
+
>>>
|
1505
|
+
>>> def accumulate(x):
|
1506
|
+
... counter.value = counter.value + x
|
1507
|
+
... return counter.value
|
1508
|
+
>>>
|
1509
|
+
>>> batched_accumulate = brainstate.transform.StatefulMapping(
|
1510
|
+
... accumulate,
|
1511
|
+
... in_axes=0,
|
1512
|
+
... out_axes=0,
|
1513
|
+
... state_in_axes={0: OfType(brainstate.ShortTermState)},
|
1514
|
+
... state_out_axes={0: OfType(brainstate.ShortTermState)},
|
1515
|
+
... name="batched_accumulate",
|
1516
|
+
... )
|
1517
|
+
>>>
|
1518
|
+
>>> xs = jnp.ones((3,))
|
1519
|
+
>>> batched_accumulate(xs)
|
1520
|
+
Array([1., 2., 3.], dtype=float32)
|
1521
|
+
>>> counter.value
|
1522
|
+
Array(3., dtype=float32)
|
1523
|
+
|
1524
|
+
See Also
|
1525
|
+
--------
|
1526
|
+
brainstate.transform.vmap : Convenience API returning a ``StatefulMapping``.
|
1527
|
+
brainstate.transform.pmap : Device-mapped variant aware of BrainState states.
|
1528
|
+
"""
|
1412
1529
|
__module__ = "brainstate.transform"
|
1413
1530
|
|
1414
1531
|
def __init__(
|
@@ -1418,11 +1535,13 @@ class StatefulMapping(StatefulFunction):
|
|
1418
1535
|
out_axes: Union[int, Tuple[int, ...], None] = 0,
|
1419
1536
|
state_in_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
|
1420
1537
|
state_out_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
|
1421
|
-
|
1538
|
+
unexpected_out_state_mapping: str = 'raise',
|
1539
|
+
# JIT specific parameters
|
1422
1540
|
static_argnums: Union[int, Iterable[int]] = (),
|
1423
1541
|
static_argnames: Union[str, Iterable[str]] = (),
|
1424
1542
|
axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
|
1425
1543
|
abstracted_axes: Optional[Any] = None,
|
1544
|
+
return_only_write: bool = True,
|
1426
1545
|
# mapping specific parameters
|
1427
1546
|
axis_size: Optional[int] = None,
|
1428
1547
|
axis_name: AxisName | None = None,
|
@@ -1430,16 +1549,18 @@ class StatefulMapping(StatefulFunction):
|
|
1430
1549
|
# mapping function
|
1431
1550
|
mapping_fn: Callable = jax.vmap,
|
1432
1551
|
):
|
1433
|
-
self.origin_fun = fun
|
1434
1552
|
super().__init__(
|
1435
|
-
fun=self.
|
1553
|
+
fun=self.__wrapped_fun,
|
1436
1554
|
static_argnums=static_argnums,
|
1437
1555
|
static_argnames=static_argnames,
|
1438
1556
|
axis_env=axis_env,
|
1439
1557
|
abstracted_axes=abstracted_axes,
|
1558
|
+
return_only_write=return_only_write,
|
1440
1559
|
name=name,
|
1441
|
-
return_only_write=False,
|
1442
1560
|
)
|
1561
|
+
|
1562
|
+
self.name = name
|
1563
|
+
self.origin_fun = fun
|
1443
1564
|
self.in_axes = in_axes
|
1444
1565
|
self.out_axes = out_axes
|
1445
1566
|
if state_in_axes is None:
|
@@ -1459,6 +1580,7 @@ class StatefulMapping(StatefulFunction):
|
|
1459
1580
|
self.axis_size = axis_size
|
1460
1581
|
self.axis_name = axis_name
|
1461
1582
|
self.mapping_fn = mapping_fn
|
1583
|
+
self.unexpected_out_state_mapping = unexpected_out_state_mapping
|
1462
1584
|
|
1463
1585
|
# Cache for discovered state-to-axis mappings
|
1464
1586
|
self._cached_map_dim_to_in_states = _BoundedCache(maxsize=128)
|
@@ -1466,12 +1588,7 @@ class StatefulMapping(StatefulFunction):
|
|
1466
1588
|
self._cached_map_state_trace = _BoundedCache(maxsize=128)
|
1467
1589
|
self._cached_map_batch_size = _BoundedCache(maxsize=128)
|
1468
1590
|
|
1469
|
-
def
|
1470
|
-
if in_axes is None:
|
1471
|
-
raise ValueError("Cannot infer batch size when in_axes is None")
|
1472
|
-
|
1473
|
-
batch_sizes = []
|
1474
|
-
|
1591
|
+
def __infer_batch_size(self, args, in_axes):
|
1475
1592
|
def get_batch_size_from_arg(arg_, axis_):
|
1476
1593
|
if axis_ is None:
|
1477
1594
|
return None
|
@@ -1490,6 +1607,7 @@ class StatefulMapping(StatefulFunction):
|
|
1490
1607
|
sizes = [s for s in jax.tree.leaves(jax.tree.map(_get_size, arg_)) if s is not None]
|
1491
1608
|
return sizes[0] if sizes else None
|
1492
1609
|
|
1610
|
+
batch_sizes = []
|
1493
1611
|
if isinstance(in_axes, int):
|
1494
1612
|
# All args batched along the same axis
|
1495
1613
|
for arg in args:
|
@@ -1506,6 +1624,8 @@ class StatefulMapping(StatefulFunction):
|
|
1506
1624
|
size = get_batch_size_from_arg(arg, axis)
|
1507
1625
|
if size is not None:
|
1508
1626
|
batch_sizes.append(size)
|
1627
|
+
elif in_axes is None:
|
1628
|
+
pass
|
1509
1629
|
else:
|
1510
1630
|
raise TypeError(f"Unsupported in_axes type: {type(in_axes)}")
|
1511
1631
|
|
@@ -1523,17 +1643,14 @@ class StatefulMapping(StatefulFunction):
|
|
1523
1643
|
|
1524
1644
|
return batch_sizes[0]
|
1525
1645
|
|
1526
|
-
def __new_batch_arg(self, batch_size: int, dim_to_states: dict):
|
1527
|
-
trace = jax.core.trace_ctx.trace
|
1528
|
-
assert isinstance(trace, BatchTrace), f"Expected to be called within a BatchTrace context, but got {trace}"
|
1529
|
-
|
1646
|
+
def __new_batch_arg(self, trace, batch_size: int, dim_to_states: dict):
|
1530
1647
|
def wrapper(x):
|
1531
1648
|
if isinstance(x, RandomState):
|
1532
|
-
idx =
|
1649
|
+
idx = lambda: BatchTracer(trace, make_iota(batch_size), 0, source_info_util.current())
|
1533
1650
|
dim_to_states['random'].append(x)
|
1534
|
-
return to_elt(trace, idx,
|
1651
|
+
return to_elt(trace, idx, self._rand_value, 0)
|
1535
1652
|
for dim, filter_ in self.state_in_axes.items():
|
1536
|
-
idx =
|
1653
|
+
idx = lambda: BatchTracer(trace, make_iota(batch_size), dim, source_info_util.current())
|
1537
1654
|
if filter_(tuple(), x):
|
1538
1655
|
dim_to_states[dim].append(x)
|
1539
1656
|
return jax.tree.map(lambda xx: to_elt(trace, idx, xx, dim), x._value)
|
@@ -1541,41 +1658,81 @@ class StatefulMapping(StatefulFunction):
|
|
1541
1658
|
|
1542
1659
|
return wrapper
|
1543
1660
|
|
1544
|
-
def
|
1545
|
-
|
1546
|
-
|
1547
|
-
|
1548
|
-
|
1549
|
-
|
1661
|
+
def __find_batch_dim(self, st):
|
1662
|
+
leaves = jax.tree.leaves(st._value)
|
1663
|
+
batch_dims = set([leaf.batch_dim if isinstance(leaf, BatchTracer) else None for leaf in leaves])
|
1664
|
+
if len(batch_dims) != 1:
|
1665
|
+
raise ValueError(
|
1666
|
+
f"State {st} has inconsistent batch dimensions in its leaves: {batch_dims}. "
|
1667
|
+
"All leaves must have the same batch dimension."
|
1550
1668
|
)
|
1551
|
-
|
1552
|
-
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1557
|
-
|
1558
|
-
|
1559
|
-
|
1560
|
-
|
1561
|
-
|
1562
|
-
|
1563
|
-
|
1564
|
-
|
1565
|
-
|
1566
|
-
|
1567
|
-
|
1568
|
-
|
1569
|
-
|
1570
|
-
|
1669
|
+
dim = batch_dims.pop()
|
1670
|
+
return dim
|
1671
|
+
|
1672
|
+
def __fn_to_eval(self, cache_key, *new_args, **new_kwargs):
|
1673
|
+
# state trace
|
1674
|
+
trace = jax.core.trace_ctx.trace
|
1675
|
+
assert isinstance(trace, BatchTrace), f"Expected to be called within a BatchTrace context, but got {trace}"
|
1676
|
+
dim_to_in_states = defaultdict(list)
|
1677
|
+
state_trace = StateTraceStack(name=self.name)
|
1678
|
+
state_trace.set_new_arg(
|
1679
|
+
self.__new_batch_arg(trace, self._cached_map_batch_size.get(cache_key), dim_to_in_states)
|
1680
|
+
)
|
1681
|
+
self._cached_map_state_trace.set(cache_key, state_trace)
|
1682
|
+
|
1683
|
+
# call functions
|
1684
|
+
with state_trace:
|
1685
|
+
out_ = self.origin_fun(*new_args, **new_kwargs)
|
1686
|
+
|
1687
|
+
# cache vmapped in states
|
1688
|
+
self._cached_map_dim_to_in_states.set(cache_key, dim_to_in_states.copy())
|
1689
|
+
mapped_in_states = set([id(v) for vv in dim_to_in_states.values() for v in vv])
|
1690
|
+
|
1691
|
+
# vmapped out states
|
1692
|
+
out_states = defaultdict(list)
|
1693
|
+
out_states['random'] = [st for st in state_trace.states if isinstance(st, RandomState)]
|
1694
|
+
for st in state_trace.states:
|
1695
|
+
if isinstance(st, RandomState):
|
1696
|
+
continue
|
1697
|
+
find = False
|
1698
|
+
for dim, filter_ in self.state_out_axes.items():
|
1699
|
+
if filter_(tuple(), st):
|
1700
|
+
out_states[dim].append(st)
|
1701
|
+
find = True
|
1702
|
+
break
|
1703
|
+
if find:
|
1704
|
+
continue
|
1705
|
+
dim = self.__find_batch_dim(st)
|
1706
|
+
if dim is None or id(st) in mapped_in_states:
|
1707
|
+
out_states[dim].append(st)
|
1708
|
+
else:
|
1709
|
+
if self.unexpected_out_state_mapping == 'raise':
|
1710
|
+
st.raise_error_with_source_info(
|
1711
|
+
BatchAxisError(
|
1712
|
+
f'State\n {st} \n was not expected to be batched on output. '
|
1713
|
+
'Please adjust state_out_axes or set unexpected_out_state_mapping to "warn" or "ignore".'
|
1571
1714
|
)
|
1572
|
-
|
1573
|
-
|
1574
|
-
|
1715
|
+
)
|
1716
|
+
elif self.unexpected_out_state_mapping == 'warn':
|
1717
|
+
warnings.warn(
|
1718
|
+
f'State\n {st} \n was not expected to be batched on output. '
|
1719
|
+
f'Please adjust state_out_axes or set unexpected_out_state_mapping to "ignore".',
|
1720
|
+
UserWarning,
|
1721
|
+
)
|
1722
|
+
out_states[dim].append(st)
|
1723
|
+
elif self.unexpected_out_state_mapping == 'ignore':
|
1724
|
+
out_states[dim].append(st)
|
1725
|
+
else:
|
1726
|
+
raise ValueError(
|
1727
|
+
'Invalid value for unexpected_out_state_mapping: '
|
1728
|
+
f'{self.unexpected_out_state_mapping}. Must be "raise", "warn", or "ignore".'
|
1729
|
+
)
|
1730
|
+
self._cached_map_dim_to_out_states.set(cache_key, out_states)
|
1575
1731
|
|
1732
|
+
def __eval(self, cache_key, *args, **kwargs):
|
1576
1733
|
try:
|
1577
1734
|
jax.vmap(
|
1578
|
-
|
1735
|
+
functools.partial(self.__fn_to_eval, cache_key),
|
1579
1736
|
in_axes=self.in_axes,
|
1580
1737
|
out_axes=self.out_axes,
|
1581
1738
|
axis_name=self.axis_name,
|
@@ -1589,7 +1746,7 @@ class StatefulMapping(StatefulFunction):
|
|
1589
1746
|
self._cached_map_dim_to_in_states.pop(cache_key, None)
|
1590
1747
|
self._cached_map_dim_to_out_states.pop(cache_key, None)
|
1591
1748
|
self._cached_map_batch_size.pop(cache_key, None)
|
1592
|
-
raise
|
1749
|
+
raise e
|
1593
1750
|
|
1594
1751
|
def __assign_vals_from_in_states(self, cache_key, rand_st, *other_st):
|
1595
1752
|
in_states = self._cached_map_dim_to_in_states.get(cache_key)
|
@@ -1647,16 +1804,22 @@ class StatefulMapping(StatefulFunction):
|
|
1647
1804
|
for st, val in zip(rand_states, rand_recover_vals):
|
1648
1805
|
st.restore_value(val)
|
1649
1806
|
|
1650
|
-
def
|
1651
|
-
|
1807
|
+
def __wrapped_fun(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
|
1808
|
+
if len(kwargs):
|
1809
|
+
raise NotImplementedError(
|
1810
|
+
'StatefulMapping currently does not support keyword arguments.'
|
1811
|
+
)
|
1812
|
+
|
1813
|
+
batch_size = self.__infer_batch_size(args, self.in_axes)
|
1652
1814
|
cache_key = self.get_arg_cache_key(*args, **kwargs)
|
1653
|
-
self._cached_map_batch_size.set(cache_key, batch_size)
|
1654
1815
|
if cache_key not in self._cached_map_state_trace:
|
1816
|
+
self._rand_value = RandomState._batch_keys(batch_size)
|
1817
|
+
self._cached_map_batch_size.set(cache_key, batch_size)
|
1655
1818
|
self.__eval(cache_key, *args, **kwargs)
|
1656
1819
|
|
1657
1820
|
def fn_to_map(origin_args, rand_st, *non_rand_st):
|
1658
1821
|
self.__assign_vals_from_in_states(cache_key, rand_st, *non_rand_st)
|
1659
|
-
out = self.origin_fun(*origin_args
|
1822
|
+
out = self.origin_fun(*origin_args)
|
1660
1823
|
return out, *self.__get_out_state_vals(cache_key)[1]
|
1661
1824
|
|
1662
1825
|
in_axes, in_state_vals = self.__get_in_state_vals(cache_key)
|
@@ -1664,12 +1827,12 @@ class StatefulMapping(StatefulFunction):
|
|
1664
1827
|
rand_vals, rand_recover_vals = self.__get_rand_state_vals(cache_key)
|
1665
1828
|
mapped_fn = self.mapping_fn(
|
1666
1829
|
fn_to_map,
|
1667
|
-
in_axes=(self.in_axes, 0) + in_axes,
|
1830
|
+
in_axes=(self.in_axes, 0 if len(rand_vals) else None) + in_axes,
|
1668
1831
|
out_axes=(self.out_axes,) + out_axes,
|
1669
1832
|
axis_size=self.axis_size,
|
1670
1833
|
axis_name=self.axis_name,
|
1671
1834
|
)
|
1672
|
-
out_, *out_state_vals = mapped_fn(
|
1835
|
+
out_, *out_state_vals = mapped_fn(args, rand_vals, *in_state_vals)
|
1673
1836
|
self.__assign_vals_from_out_states(cache_key, rand_recover_vals, *out_state_vals)
|
1674
1837
|
return out_
|
1675
1838
|
|
@@ -1924,6 +2087,9 @@ def constant_fold_jaxpr(jaxpr: Jaxpr):
|
|
1924
2087
|
return _partial_eval_jaxpr(jaxpr, {})
|
1925
2088
|
|
1926
2089
|
|
2090
|
+
_constant_fold_blacklist = {'broadcast_in_dim', 'broadcast'}
|
2091
|
+
|
2092
|
+
|
1927
2093
|
def _partial_eval_jaxpr(jaxpr, env):
|
1928
2094
|
env = env.copy()
|
1929
2095
|
new_eqns = []
|
@@ -2008,9 +2174,3 @@ def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jax.Array]:
|
|
2008
2174
|
else:
|
2009
2175
|
out = eqn.primitive.bind(*vals, **eqn.params)
|
2010
2176
|
return out
|
2011
|
-
|
2012
|
-
|
2013
|
-
_constant_fold_blacklist = {
|
2014
|
-
'broadcast_in_dim',
|
2015
|
-
'broadcast',
|
2016
|
-
}
|
@@ -16,15 +16,18 @@
|
|
16
16
|
|
17
17
|
import threading
|
18
18
|
import unittest
|
19
|
+
import warnings
|
19
20
|
|
20
21
|
import jax
|
21
22
|
import jax.numpy as jnp
|
23
|
+
import jax.random as jr
|
22
24
|
import pytest
|
23
25
|
|
24
26
|
import brainstate
|
25
|
-
from brainstate._error import BatchAxisError
|
26
27
|
from brainstate._compatible_import import jaxpr_as_fun
|
28
|
+
from brainstate._error import BatchAxisError
|
27
29
|
from brainstate.transform._make_jaxpr import _BoundedCache, make_hashable
|
30
|
+
from brainstate.util import filter as state_filter
|
28
31
|
|
29
32
|
|
30
33
|
class TestMakeJaxpr(unittest.TestCase):
|
@@ -1508,3 +1511,124 @@ class TestStatefulFunctionRecompilation(unittest.TestCase):
|
|
1508
1511
|
stats_after_recompile = sf.get_cache_stats()
|
1509
1512
|
self.assertGreater(stats_after_recompile['jaxpr_cache']['size'], 0)
|
1510
1513
|
|
1514
|
+
|
1515
|
+
class TestStatefulMapping(unittest.TestCase):
|
1516
|
+
def test_state_filters_and_caching(self):
|
1517
|
+
counter = brainstate.ShortTermState(jnp.zeros(3))
|
1518
|
+
|
1519
|
+
def accumulate(x):
|
1520
|
+
counter.value = counter.value + x
|
1521
|
+
return counter.value
|
1522
|
+
|
1523
|
+
mapper = brainstate.transform.StatefulMapping(
|
1524
|
+
accumulate,
|
1525
|
+
in_axes=0,
|
1526
|
+
out_axes=0,
|
1527
|
+
state_in_axes={0: state_filter.OfType(brainstate.ShortTermState)},
|
1528
|
+
state_out_axes={0: state_filter.OfType(brainstate.ShortTermState)},
|
1529
|
+
)
|
1530
|
+
|
1531
|
+
xs = jnp.asarray([1.0, 2.0, 3.0])
|
1532
|
+
result = mapper(xs)
|
1533
|
+
self.assertTrue(jnp.allclose(result, xs))
|
1534
|
+
self.assertTrue(jnp.allclose(counter.value, xs))
|
1535
|
+
|
1536
|
+
def test_random_state_restoration(self):
|
1537
|
+
rng_state = brainstate.random.RandomState(0)
|
1538
|
+
|
1539
|
+
def draw(_):
|
1540
|
+
key = rng_state.split_key()
|
1541
|
+
return jr.normal(key, ())
|
1542
|
+
|
1543
|
+
mapper = brainstate.transform.StatefulMapping(
|
1544
|
+
draw,
|
1545
|
+
in_axes=0,
|
1546
|
+
out_axes=0,
|
1547
|
+
)
|
1548
|
+
|
1549
|
+
xs = jnp.ones((4,))
|
1550
|
+
before = rng_state.value
|
1551
|
+
samples = mapper(xs)
|
1552
|
+
self.assertEqual(samples.shape, xs.shape)
|
1553
|
+
self.assertFalse(jnp.allclose(samples, jnp.repeat(samples[0], xs.shape[0])))
|
1554
|
+
self.assertTrue(jnp.array_equal(rng_state.value.shape, before.shape))
|
1555
|
+
|
1556
|
+
def test_inconsistent_batch_sizes_raise(self):
|
1557
|
+
tracker = brainstate.ShortTermState(jnp.array(0.0))
|
1558
|
+
|
1559
|
+
def combine(x, y):
|
1560
|
+
tracker.value = tracker.value + x + y
|
1561
|
+
return tracker.value
|
1562
|
+
|
1563
|
+
mapper = brainstate.transform.StatefulMapping(
|
1564
|
+
combine,
|
1565
|
+
in_axes=(0, 0),
|
1566
|
+
out_axes=0,
|
1567
|
+
state_in_axes={0: state_filter.OfType(brainstate.ShortTermState)},
|
1568
|
+
state_out_axes={0: state_filter.OfType(brainstate.ShortTermState)},
|
1569
|
+
)
|
1570
|
+
|
1571
|
+
with self.assertRaisesRegex(ValueError, "Inconsistent batch sizes"):
|
1572
|
+
mapper(jnp.ones((3,)), jnp.ones((4,)))
|
1573
|
+
|
1574
|
+
def test_unexpected_out_state_mapping_raise(self):
|
1575
|
+
leak = brainstate.ShortTermState(jnp.array(0.0))
|
1576
|
+
|
1577
|
+
def mutate(x):
|
1578
|
+
leak.value = leak.value + x
|
1579
|
+
return x
|
1580
|
+
|
1581
|
+
mapper = brainstate.transform.StatefulMapping(
|
1582
|
+
mutate,
|
1583
|
+
in_axes=0,
|
1584
|
+
out_axes=0,
|
1585
|
+
state_in_axes={},
|
1586
|
+
state_out_axes={},
|
1587
|
+
unexpected_out_state_mapping='raise',
|
1588
|
+
)
|
1589
|
+
|
1590
|
+
with self.assertRaises(BatchAxisError):
|
1591
|
+
mapper(jnp.ones((2,)))
|
1592
|
+
|
1593
|
+
def test_unexpected_out_state_mapping_warn(self):
|
1594
|
+
leak = brainstate.ShortTermState(jnp.array(0.0))
|
1595
|
+
|
1596
|
+
def mutate(x):
|
1597
|
+
leak.value = leak.value + x
|
1598
|
+
return x
|
1599
|
+
|
1600
|
+
mapper = brainstate.transform.StatefulMapping(
|
1601
|
+
mutate,
|
1602
|
+
in_axes=0,
|
1603
|
+
out_axes=0,
|
1604
|
+
state_in_axes={},
|
1605
|
+
state_out_axes={},
|
1606
|
+
unexpected_out_state_mapping='warn',
|
1607
|
+
)
|
1608
|
+
|
1609
|
+
with pytest.warns(UserWarning):
|
1610
|
+
mapper(jnp.ones((2,)))
|
1611
|
+
self.assertTrue(jnp.allclose(leak.value, 1.0))
|
1612
|
+
|
1613
|
+
def test_unexpected_out_state_mapping_ignore(self):
|
1614
|
+
leak = brainstate.ShortTermState(jnp.array(0.0))
|
1615
|
+
|
1616
|
+
def mutate(x):
|
1617
|
+
leak.value = leak.value + x
|
1618
|
+
return x
|
1619
|
+
|
1620
|
+
mapper = brainstate.transform.StatefulMapping(
|
1621
|
+
mutate,
|
1622
|
+
in_axes=0,
|
1623
|
+
out_axes=0,
|
1624
|
+
state_in_axes={},
|
1625
|
+
state_out_axes={},
|
1626
|
+
unexpected_out_state_mapping='ignore',
|
1627
|
+
)
|
1628
|
+
|
1629
|
+
with warnings.catch_warnings(record=True) as caught:
|
1630
|
+
warnings.simplefilter('always')
|
1631
|
+
mapper(jnp.ones((2,)))
|
1632
|
+
self.assertEqual(len(caught), 0)
|
1633
|
+
self.assertTrue(jnp.allclose(leak.value, 1.0))
|
1634
|
+
|