brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
--- /dev/null
+++ b/brainstate/augment/_eval_shape.py
@@ -0,0 +1,102 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+import functools
+from typing import Any, TypeVar, Callable, Sequence, Union
+
+import jax
+
+from brainstate.graph import graph_to_tree, tree_to_graph
+from brainstate.random import DEFAULT, RandomState
+from ._random import restore_rngs
+
+__all__ = [
+    'eval_shape',
+]
+
+A = TypeVar('A')
+
+
+def eval_shape(
+    fn: Callable[..., A],
+    *args: Any,
+    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
+    **kwargs: Any,
+) -> A:
+    """
+    Compute the shape/dtype of ``fn`` without any FLOPs.
+
+    Here's an example::
+
+        >>> import brainstate as bst
+        >>> class MLP:
+        ...     def __init__(self, n_in, n_mid, n_out):
+        ...         self.dense1 = bst.nn.Linear(n_in, n_mid)
+        ...         self.dense2 = bst.nn.Linear(n_mid, n_out)
+
+        >>> r = bst.augment.eval_shape(lambda: MLP(1, 2, 3))
+        >>> r
+        MLP(
+          dense1=Linear(
+            in_size=(1,),
+            out_size=(2,),
+            w_mask=None,
+            weight=ParamState(
+              value={'bias': ShapeDtypeStruct(shape=(2,), dtype=float32), 'weight': ShapeDtypeStruct(shape=(1, 2), dtype=float32)}
+            )
+          ),
+          dense2=Linear(
+            in_size=(2,),
+            out_size=(3,),
+            w_mask=None,
+            weight=ParamState(
+              value={'bias': ShapeDtypeStruct(shape=(3,), dtype=float32), 'weight': ShapeDtypeStruct(shape=(2, 3), dtype=float32)}
+            )
+          )
+        )
+
+    Args:
+        fn: The function whose output shape should be evaluated.
+        *args: a positional argument tuple of arrays, scalars, or (nested) standard
+            Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
+            those types. Since only the ``shape`` and ``dtype`` attributes are
+            accessed, one can use :class:`jax.ShapeDtypeStruct` or another container
+            that duck-types as ndarrays (note however that duck-typed objects cannot
+            be namedtuples because those are treated as standard Python containers).
+        **kwargs: a keyword argument dict of arrays, scalars, or (nested) standard
+            Python containers (pytrees) of those types. As in ``args``, array values
+            need only be duck-typed to have ``shape`` and ``dtype`` attributes.
+        rngs: a :class:`RandomState` or a sequence of :class:`RandomState` objects
+            representing the random number generators to use. If not provided, the
+            default random number generator will be used.
+
+    Returns:
+        out: a nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves.
+    """
+
+    @functools.wraps(fn)
+    @restore_rngs(rngs=rngs)
+    def _eval_shape_fn(*args_, **kwargs_):
+        args_, kwargs_ = tree_to_graph((args_, kwargs_))
+        out = fn(*args_, **kwargs_)
+        return graph_to_tree(out)
+
+    args, kwargs = graph_to_tree((args, kwargs))
+    out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
+    return tree_to_graph(out)
--- /dev/null
+++ b/brainstate/augment/_eval_shape_test.py
@@ -0,0 +1,40 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from __future__ import annotations
+
+import unittest
+
+import brainstate as bst
+
+
+class TestEvalShape(unittest.TestCase):
+    def test1(self):
+        class MLP(bst.nn.Module):
+            def __init__(self, n_in, n_mid, n_out):
+                super().__init__()
+                self.dense1 = bst.nn.Linear(n_in, n_mid)
+                self.dense2 = bst.nn.Linear(n_mid, n_out)
+
+            def __call__(self, x):
+                x = self.dense1(x)
+                x = bst.functional.relu(x)
+                x = self.dense2(x)
+                return x
+
+        r = bst.augment.eval_shape(lambda: MLP(1, 2, 3))
+        print(r)
+        print(bst.random.DEFAULT)
--- /dev/null
+++ b/brainstate/augment/_mapping.py
@@ -0,0 +1,525 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+import dataclasses
+import functools
+from typing import Any, TypeVar, Callable, Hashable, Sequence, Iterable, Mapping, Tuple, Union, Optional
+
+import jax
+
+from brainstate.graph import (NodeStates, graph_to_tree, tree_to_graph, update_context)
+from brainstate.graph._graph_convert import clear_non_graph_nodes
+from brainstate.random import DEFAULT, RandomState
+from brainstate.typing import Missing, Filter
+from brainstate.util import NestedDict
+from ._random import restore_rngs
+
+__all__ = [
+    'StateAxes',
+    'vmap',
+    'pmap',
+]
+
+AxisName = Hashable
+F = TypeVar("F", bound=Callable)
+Index = int
+Carry = TypeVar("Carry")
+
+
+class StateAxes:
+    """
+    A class to represent the axes of a state.
+
+    This class is used to control how graph nodes like Modules are vectorized or
+    parallelized, by specifying the axes to be applied to substates of the graph
+    node given a Filter.
+
+    Args:
+        filter_axes: A mapping from filters to axes. Each axis can be an index, a carry, or None.
+    """
+
+    def __init__(
+        self,
+        filter_axes: Union[Mapping[Filter, Index | Carry | None], Iterable[Tuple[Filter, Index | Carry | None]]],
+    ):
+        iterable = filter_axes.items() if isinstance(filter_axes, Mapping) else filter_axes
+        self._filters = tuple(filter_ for filter_, _ in iterable)
+        self._axes = tuple(axis for _, axis in iterable)
+
+    @property
+    def filters(self) -> Tuple[Filter, ...]:
+        return self._filters
+
+    @property
+    def axes(self) -> Tuple[Index | Carry | None, ...]:
+        return self._axes
+
+    def __repr__(self):
+        return f'StateAxes({dict(self.items())})'
+
+    def items(self):
+        return zip(self.filters, self.axes)
+
+    def __eq__(self, other):
+        return isinstance(other, StateAxes) and self.filters == other.filters and self.axes == other.axes
+
+    def __hash__(self):
+        return hash((self.filters, self.axes))
+
+
+def _map_split_fn(ctx, path, prefix, x):
+    if isinstance(prefix, StateAxes):
+        return NodeStates.from_split(*ctx.treefy_split(x, *prefix.filters), metadata=prefix)
+    return NodeStates.from_split(*ctx.treefy_split(x), metadata=prefix)
+
+
+@dataclasses.dataclass(eq=False)
+class MapFn:
+    f: Callable[..., Any]
+    in_axes: Any
+    out_axes: Any
+    ctxtag: str
+
+    def __post_init__(self):
+        functools.update_wrapper(self, self.f)
+
+    def __call__(self, *pure_args: Tuple[Any, ...]):
+        # pytree to graph
+        args = tree_to_graph(pure_args, ctxtag=self.ctxtag)
+
+        # call the function
+        out = self.f(*args)
+
+        # graph to pytree
+        args_out = clear_non_graph_nodes(args)
+        pure_args_out, pure_out = graph_to_tree(
+            (args_out, out),
+            prefix=(self.in_axes, self.out_axes),
+            split_fn=_map_split_fn,
+            ctxtag=self.ctxtag,
+        )
+        return pure_args_out, pure_out
+
+
+def _map_transform(
+    ctxtag,
+    transform,
+    f: F,
+    *,
+    in_axes: Optional[int | Sequence[Any]] = 0,
+    out_axes: Any = 0,
+    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
+    **transform_kwargs,
+):
+    # jax in axes
+    jax_in_axes = jax.tree.map(
+        lambda x: NodeStates.from_prefixes(x.axes, metadata=x) if isinstance(x, StateAxes) else x,
+        in_axes,
+    )
+
+    # jax out axes
+    jax_out_axes = jax.tree.map(
+        lambda x: NodeStates.from_prefixes(x.axes, metadata=x) if isinstance(x, StateAxes) else x,
+        out_axes,
+    )
+
+    # mapped function
+    mapped_fn = transform(
+        MapFn(f, in_axes, out_axes, ctxtag),
+        in_axes=jax_in_axes,
+        out_axes=(jax_in_axes, jax_out_axes),
+        **transform_kwargs
+    )
+
+    @functools.wraps(f)
+    @restore_rngs(rngs=rngs)  # restore the random key of the default random number generator
+    @update_context(ctxtag)
+    def map_wrapper(*args):
+        # graph to pytree
+        pure_args = graph_to_tree(args, prefix=in_axes, split_fn=_map_split_fn, ctxtag=ctxtag)
+
+        # vmap/pmap with the pytree
+        pure_args_out, pure_out = mapped_fn(*pure_args)
+
+        # pytree to graph
+        _args_out, out = tree_to_graph((pure_args_out, pure_out), ctxtag=ctxtag)
+        return out
+
+    return map_wrapper  # type: ignore
+
+
+def vmap(
+    fn: F | Missing = Missing(),
+    *,
+    in_axes: int | None | Sequence[Any] = 0,
+    out_axes: Any = 0,
+    axis_name: AxisName | None = None,
+    axis_size: int | None = None,
+    spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
+    # brainstate-specific arguments
+    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
+) -> F | Callable[[F], F]:
+    """
+    Vectorizing map. Creates a function which maps ``fun`` over argument axes.
+
+    The transformation :func:`vmap` is designed to work with the ``pygraph`` structure
+    defined in the ``brainstate`` library. It vectorizes functions by
+    pushing the mapped axis down into primitive operations.
+
+    For more information, please see `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.
+
+    Here are several usage examples::
+
+        >>> import brainstate as bst
+        >>> import jax.numpy as jnp
+
+        >>> model = bst.nn.Linear(2, 3)
+        >>> x = jnp.ones((5, 2))
+
+        >>> @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
+        ... def forward(model, x):
+        ...     return model(x)
+
+        >>> y = forward(model, x)
+        >>> print(y.shape)
+        (5, 3)
+
+    Another example with a more complex model::
+
+        >>> class LinearEnsemble(bst.nn.Module):
+        ...     def __init__(self, n: int):
+        ...         super().__init__()
+        ...         self.n = n
+        ...         self.w = bst.ParamState(bst.random.random((n, 2, 3)))
+
+        >>> model = LinearEnsemble(5)
+        >>> x = jnp.ones((2,))
+
+        >>> @bst.augment.vmap(in_axes=(0, None), out_axes=0)
+        ... def forward(model, x):
+        ...     return jnp.dot(x, model.w.value)
+
+        >>> y = forward(model, x)
+        >>> print(y.shape)
+        (5, 3)
+
+    To control how different types of states are vectorized, a ``StateAxes``
+    can be passed to ``in_axes`` and ``out_axes``, specifying the axes to be
+    applied to each substate given a filter. The following example shows how to
+    share the parameters between the ensemble members while keeping different
+    batch statistics and dropout random state::
+
+        >>> class Foo(bst.nn.Module):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.a = bst.ParamState(jnp.arange(4))
+        ...         self.b = bst.ShortTermState(jnp.arange(4))
+
+        >>> state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
+        >>> @bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
+        ... def mul(foo):
+        ...     return foo.a.value * foo.b.value
+
+        >>> model = Foo()
+        >>> y = mul(model)
+        >>> print(y.shape)
+        (4, 4)
+
+    Args:
+        fn: Function to be mapped over additional axes.
+        in_axes: An integer, None, or sequence of values specifying which input
+            array axes to map over.
+
+            If each positional argument to ``fun`` is an array, then ``in_axes`` can
+            be an integer, a None, or a tuple of integers and Nones with length equal
+            to the number of positional arguments to ``fun``. An integer or ``None``
+            indicates which array axis to map over for all arguments (with ``None``
+            indicating not to map any axis), and a tuple indicates which axis to map
+            for each corresponding positional argument. Axis integers must be in the
+            range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
+            dimensions (axes) of the corresponding input array.
+
+            If the positional arguments to ``fun`` are container (pytree) types, ``in_axes``
+            must be a sequence with length equal to the number of positional arguments to
+            ``fun``, and for each argument the corresponding element of ``in_axes`` can
+            be a container with a matching pytree structure specifying the mapping of its
+            container elements. In other words, ``in_axes`` must be a container tree prefix
+            of the positional argument tuple passed to ``fun``. See this link for more detail:
+            https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees
+
+            Either ``axis_size`` must be provided explicitly, or at least one
+            positional argument must have ``in_axes`` not None. The sizes of the
+            mapped input axes for all mapped positional arguments must all be equal.
+
+            Arguments passed as keywords are always mapped over their leading axis
+            (i.e. axis index 0).
+
+            See above for examples.
+
+        out_axes: An integer, None, or (nested) standard Python container
+            (tuple/list/dict) thereof indicating where the mapped axis should appear
+            in the output. All outputs with a mapped axis must have a non-None
+            ``out_axes`` specification. Axis integers must be in the range ``[-ndim,
+            ndim)`` for each output array, where ``ndim`` is the number of dimensions
+            (axes) of the array returned by the :func:`vmap`-ed function, which is one
+            more than the number of dimensions (axes) of the corresponding array
+            returned by ``fun``.
+        axis_name: Optional, a hashable Python object used to identify the mapped
+            axis so that parallel collectives can be applied.
+        axis_size: Optional, an integer indicating the size of the axis to be
+            mapped. If not provided, the mapped axis size is inferred from arguments.
+        spmd_axis_name: Optional, a hashable Python object or tuple of hashable
+            Python objects used to identify the mapped axis so that parallel collectives
+            can be applied. This is used to specify multiple axes to be mapped over
+            in a nested :func:`vmap` call. The length of the tuple must match the
+            number of nested :func:`vmap` calls. The first element of the tuple
+            corresponds to the outermost :func:`vmap` call, the second element to
+            the next outermost, and so on. If the tuple is not provided, the
+            ``axis_name`` is used for all nested :func:`vmap` calls.
+        rngs: Optional, a random number generator or sequence of random number
+            generators to be used in the mapped function. These random number
+            generators have their random keys restored after the mapped function
+            is executed.
+
+    Returns:
+        Batched/vectorized version of ``fun`` with arguments that correspond to
+        those of ``fun``, but with extra array axes at positions indicated by
+        ``in_axes``, and a return value that corresponds to that of ``fun``, but
+        with extra array axes at positions indicated by ``out_axes``.
+    """
+    if isinstance(fn, Missing):
+        return functools.partial(
+            vmap,
+            in_axes=in_axes,
+            out_axes=out_axes,
+            axis_name=axis_name,
+            axis_size=axis_size,
+            spmd_axis_name=spmd_axis_name,
+            rngs=rngs,
+        )  # type: ignore[return-value]
+
+    return _map_transform(
+        'vmap',  # ctxtag
+        jax.vmap,
+        fn,
+        in_axes=in_axes,
+        out_axes=out_axes,
+        axis_name=axis_name,
+        axis_size=axis_size,
+        spmd_axis_name=spmd_axis_name,
+        rngs=rngs
+    )
+
+
+def pmap(
+    fn: Callable[[NestedDict, ...], Any] | Missing = Missing(),
+    axis_name: Optional[AxisName] = None,
+    *,
+    in_axes: Any = 0,
+    out_axes: Any = 0,
+    static_broadcasted_argnums: int | Iterable[int] = (),
+    devices: Optional[Sequence[jax.Device]] = None,  # noqa: F811
+    backend: Optional[str] = None,
+    axis_size: Optional[int] = None,
+    donate_argnums: int | Iterable[int] = (),
+    global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
+    # brainstate-specific arguments
+    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
+) -> Callable[[F], F] | F:
+    """
+    Parallel map with support for collective operations.
+
+    The purpose of :py:func:`pmap` is to express single-program multiple-data
+    (SPMD) programs. Applying :py:func:`pmap` to a function will compile the
+    function with XLA (similarly to :py:func:`jit`), then execute it in parallel
+    on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it
+    is comparable to :py:func:`vmap` because both transformations map a function
+    over array axes, but where :py:func:`vmap` vectorizes functions by pushing the
+    mapped axis down into primitive operations, :py:func:`pmap` instead replicates
+    the function and executes each replica on its own XLA device in parallel.
+
+    The mapped axis size must be less than or equal to the number of local XLA
+    devices available, as returned by :py:func:`jax.local_device_count()` (unless
+    ``devices`` is specified, see below). For nested :py:func:`pmap` calls, the
+    product of the mapped axis sizes must be less than or equal to the number of
+    XLA devices.
+
+    For more information, please see `jax.pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`__.
+
+    If there are 4 XLA devices available, the following example will execute
+    the function in parallel on each device::
+
+        >>> import brainstate as bst
+        >>> import jax.numpy as jnp
+
+        >>> model = bst.nn.Linear(2, 3)
+        >>> x = jnp.ones((4, 2))
+
+        >>> @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
+        ... def forward(model, x):
+        ...     return model(x)
+
+        >>> y = forward(model, x)
+        >>> print(y.shape)
+        (4, 3)
+
+    Another example with a more complex model::
+
+        >>> class LinearEnsemble(bst.nn.Module):
+        ...     def __init__(self, n: int):
+        ...         super().__init__()
+        ...         self.n = n
+        ...         self.w = bst.ParamState(bst.random.random((n, 2, 3)))
+
+        >>> model = LinearEnsemble(4)
+        >>> x = jnp.ones((2,))
+
+        >>> @bst.augment.vmap(in_axes=(0, None), out_axes=0)
+        ... def forward(model, x):
+        ...     return jnp.dot(x, model.w.value)
+
+        >>> y = forward(model, x)
+        >>> print(y.shape)
+        (4, 3)
+
+    To control how different types of states are vectorized, a ``StateAxes``
+    can be passed to ``in_axes`` and ``out_axes``, specifying the axes to be
+    applied to each substate given a filter. The following example shows how to
+    share the parameters between the ensemble members while keeping different
+    batch statistics and dropout random state::
+
+        >>> class Foo(bst.nn.Module):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.a = bst.ParamState(jnp.arange(4))
+        ...         self.b = bst.ShortTermState(jnp.arange(4))
+
+        >>> state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
+        >>> @bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
+        ... def mul(foo):
+        ...     return foo.a.value * foo.b.value
+
+        >>> model = Foo()
+        >>> y = mul(model)
+        >>> print(y.shape)
+        (4, 4)
+
+    Args:
+        fn: Function to be mapped over argument axes. Its arguments and return
+            value should be arrays, scalars, or (nested) standard Python containers
+            (tuple/list/dict) thereof. Positional arguments indicated by
+            ``static_broadcasted_argnums`` can be anything at all, provided they are
+            hashable and have an equality operation defined.
+        axis_name: Optional, a hashable Python object used to identify the mapped
+            axis so that parallel collectives can be applied.
+        in_axes: A non-negative integer, None, or nested Python container thereof
+            that specifies which axes of positional arguments to map over. Arguments
+            passed as keywords are always mapped over their leading axis (i.e. axis
+            index 0). See :py:func:`vmap` for details.
+        out_axes: A non-negative integer, None, or nested Python container thereof
+            indicating where the mapped axis should appear in the output. All outputs
+            with a mapped axis must have a non-None ``out_axes`` specification
+            (see :py:func:`vmap`).
+        static_broadcasted_argnums: An int or collection of ints specifying which
+            positional arguments to treat as static (compile-time constant).
+            Operations that only depend on static arguments will be constant-folded.
+            Calling the pmapped function with different values for these constants
+            will trigger recompilation. If the pmapped function is called with fewer
+            positional arguments than indicated by ``static_broadcasted_argnums``, then
+            an error is raised. Each of the static arguments will be broadcasted to
+            all devices. Arguments that are not arrays or containers thereof must be
+            marked as static. Defaults to ().
+
+            Static arguments must be hashable, meaning both ``__hash__`` and
+            ``__eq__`` are implemented, and should be immutable.
+
+        devices: This is an experimental feature and the API is likely to change.
+            Optional, a sequence of Devices to map over. (Available devices can be
+            retrieved via jax.devices()). Must be given identically for each process
+            in multi-process settings (and will therefore include devices across
+            processes). If specified, the size of the mapped axis must be equal to
+            the number of devices in the sequence local to the given process. Nested
+            :py:func:`pmap` s with ``devices`` specified in either the inner or outer
+            :py:func:`pmap` are not yet supported.
+        backend: This is an experimental feature and the API is likely to change.
+            Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
+        axis_size: Optional; the size of the mapped axis.
+        donate_argnums: Specify which positional argument buffers are "donated" to
+            the computation. It is safe to donate argument buffers if you no longer need
+            them once the computation has finished. In some cases XLA can make use of
+            donated buffers to reduce the amount of memory needed to perform a
+            computation, for example recycling one of your input buffers to store a
+            result. You should not reuse buffers that you donate to a computation;
+            JAX will raise an error if you try to.
+            Note that donate_argnums only works for positional arguments; keyword
+            arguments will not be donated.
+
+            For more details on buffer donation see the
+            `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
+        global_arg_shapes: Optional; a tuple of tuples of integers representing the
+            shapes of the global arguments. These are arguments that are not replicated
+            across devices, but are broadcasted to all devices. The tuple should have
+            the same length as the number of global arguments, and each inner tuple
+            should have the same length as the corresponding argument. The shapes of
+            the global arguments must be the same on all devices.
+        rngs: Optional, a random number generator or sequence of random number
+            generators to be used in the mapped function. These random number
+            generators have their random keys restored after the mapped function
+            is executed.
+
+    Returns:
+        A parallelized version of ``fun`` with arguments that correspond to those of
+        ``fun`` but with extra array axes at positions indicated by ``in_axes`` and
+        with output that has an additional leading array axis (with the same size).
+    """
+
+    if isinstance(fn, Missing):
+        return functools.partial(
+            pmap,
+            axis_name=axis_name,
+            in_axes=in_axes,
+            out_axes=out_axes,
+            static_broadcasted_argnums=static_broadcasted_argnums,
+            devices=devices,
+            backend=backend,
+            axis_size=axis_size,
+            donate_argnums=donate_argnums,
+            global_arg_shapes=global_arg_shapes,
+            rngs=rngs,
+        )  # type: ignore[return-value]
+
+    return _map_transform(
+        'pmap',  # ctxtag
+        jax.pmap,
+        fn,
+        in_axes=in_axes,
+        out_axes=out_axes,
+        axis_name=axis_name,
+        static_broadcasted_argnums=static_broadcasted_argnums,
+        devices=devices,
+        backend=backend,
+        axis_size=axis_size,
+        donate_argnums=donate_argnums,
+        global_arg_shapes=global_arg_shapes,
+        rngs=rngs,
+    )