brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,102 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import functools
19
+ from typing import Any, TypeVar, Callable, Sequence, Union
20
+
21
+ import jax
22
+
23
+ from brainstate.graph import graph_to_tree, tree_to_graph
24
+ from brainstate.random import DEFAULT, RandomState
25
+ from ._random import restore_rngs
26
+
27
# Public API of this module.
__all__ = ['eval_shape']

# Generic return type of the function whose output shape is evaluated.
A = TypeVar('A')
32
+
33
+
34
def eval_shape(
    fn: Callable[..., A],
    *args: Any,
    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
    **kwargs: Any,
) -> A:
    """
    Compute the shape/dtype of ``fn`` without any FLOPs.

    Graph nodes among ``args``/``kwargs`` (e.g. :class:`brainstate.nn.Module`
    instances) are converted to pytrees with ``graph_to_tree``. ``fn`` is then
    abstractly evaluated with :func:`jax.eval_shape`, and the result is converted
    back into graph nodes with ``tree_to_graph``. The function is wrapped with
    ``restore_rngs(rngs=rngs)``, so the given random number generators have their
    random keys restored after ``fn`` is traced.

    Here's an example::

        >>> import brainstate as bst
        >>> class MLP(bst.nn.Module):
        ...     def __init__(self, n_in, n_mid, n_out):
        ...         super().__init__()
        ...         self.dense1 = bst.nn.Linear(n_in, n_mid)
        ...         self.dense2 = bst.nn.Linear(n_mid, n_out)

        >>> r = bst.augment.eval_shape(lambda: MLP(1, 2, 3))
        >>> r
        MLP(
          dense1=Linear(
            in_size=(1,),
            out_size=(2,),
            w_mask=None,
            weight=ParamState(
              value={'bias': ShapeDtypeStruct(shape=(2,), dtype=float32), 'weight': ShapeDtypeStruct(shape=(1, 2), dtype=float32)}
            )
          ),
          dense2=Linear(
            in_size=(2,),
            out_size=(3,),
            w_mask=None,
            weight=ParamState(
              value={'bias': ShapeDtypeStruct(shape=(3,), dtype=float32), 'weight': ShapeDtypeStruct(shape=(2, 3), dtype=float32)}
            )
          )
        )

    Args:
        fn: The function whose output shape should be evaluated.
        *args: a positional argument tuple of arrays, scalars, graph nodes, or
            (nested) standard Python containers (tuples, lists, dicts,
            namedtuples, i.e. pytrees) of those types. Since only the ``shape``
            and ``dtype`` attributes of arrays are accessed, one can use
            :class:`jax.ShapeDtypeStruct` or another container that duck-types as
            ndarrays (note however that duck-typed objects cannot be namedtuples
            because those are treated as standard Python containers).
        **kwargs: a keyword argument dict of arrays, scalars, graph nodes, or
            (nested) standard Python containers (pytrees) of those types. As in
            ``args``, array values need only be duck-typed to have ``shape`` and
            ``dtype`` attributes.
        rngs: a :class:`RandomState` or a sequence of :class:`RandomState` objects
            whose random keys are restored after tracing. Defaults to the default
            random number generator.

    Returns:
        out: the output of ``fn`` converted back to graph nodes, with
        :class:`jax.ShapeDtypeStruct` objects in place of array leaves.
    """

    @functools.wraps(fn)
    @restore_rngs(rngs=rngs)
    def _eval_shape_fn(*args_, **kwargs_):
        # Inside the trace: rebuild the graph nodes from their pytree form.
        args_, kwargs_ = tree_to_graph((args_, kwargs_))
        out = fn(*args_, **kwargs_)
        # Flatten any graph nodes in the output so jax.eval_shape can handle it.
        return graph_to_tree(out)

    # Convert graph nodes to pytrees so JAX can trace the arguments.
    args, kwargs = graph_to_tree((args, kwargs))
    out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
    return tree_to_graph(out)
@@ -0,0 +1,40 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from __future__ import annotations
18
+
19
+ import unittest
20
+
21
+ import brainstate as bst
22
+
23
+
24
class TestEvalShape(unittest.TestCase):
    """Tests for :func:`brainstate.augment.eval_shape`."""

    def test1(self):
        class MLP(bst.nn.Module):
            def __init__(self, n_in, n_mid, n_out):
                super().__init__()
                self.dense1 = bst.nn.Linear(n_in, n_mid)
                self.dense2 = bst.nn.Linear(n_mid, n_out)

            def __call__(self, x):
                x = self.dense1(x)
                x = bst.functional.relu(x)
                x = self.dense2(x)
                return x

        r = bst.augment.eval_shape(lambda: MLP(1, 2, 3))

        # The abstract module mirrors the real one. Each Linear's ``weight``
        # state holds a dict of abstract 'weight' and 'bias' placeholders whose
        # shapes follow (n_in, n_out) and (n_out,).
        self.assertEqual(r.dense1.weight.value['weight'].shape, (1, 2))
        self.assertEqual(r.dense1.weight.value['bias'].shape, (2,))
        self.assertEqual(r.dense2.weight.value['weight'].shape, (2, 3))
        self.assertEqual(r.dense2.weight.value['bias'].shape, (3,))
@@ -0,0 +1,525 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import dataclasses
19
+ import functools
20
+ from typing import Any, TypeVar, Callable, Hashable, Sequence, Iterable, Mapping, Tuple, Union, Optional
21
+
22
+ import jax
23
+
24
+ from brainstate.graph import (NodeStates, graph_to_tree, tree_to_graph, update_context)
25
+ from brainstate.graph._graph_convert import clear_non_graph_nodes
26
+ from brainstate.random import DEFAULT, RandomState
27
+ from brainstate.typing import Missing, Filter
28
+ from brainstate.util import NestedDict
29
+ from ._random import restore_rngs
30
+
31
# Public API of this module.
__all__ = ['StateAxes', 'vmap', 'pmap']

# Type aliases used in the signatures below.
AxisName = Hashable  # any hashable object usable as an axis name
F = TypeVar("F", bound=Callable)  # the function being transformed
Index = int  # an integer axis index
Carry = TypeVar("Carry")
41
+
42
+
43
class StateAxes:
    """
    A class to represent the axes of a state.

    This class is used to control how graph nodes like Modules are vectorized or
    parallelized by specifying the axes to be applied to substates of the graph
    node given a Filter.

    Args:
        filter_axes: A mapping from filters to axes, or an iterable of
            ``(filter, axis)`` pairs (any iterable, including a one-shot
            generator). The axes can be an index, a carry or None.

    Instances compare equal, and hash equally, when their filters and axes are
    equal in the same order.
    """

    def __init__(
        self,
        filter_axes: Union[Mapping[Filter, Index | Carry | None], Iterable[Tuple[Filter, Index | Carry | None]]],
    ):
        iterable = filter_axes.items() if isinstance(filter_axes, Mapping) else filter_axes
        # Materialize once: iterating a one-shot iterator (e.g. a generator)
        # twice would leave ``axes`` empty after ``filters`` consumed it.
        pairs = tuple(iterable)
        self._filters = tuple(filter_ for filter_, _ in pairs)
        self._axes = tuple(axis for _, axis in pairs)

    @property
    def filters(self) -> Tuple[Filter, ...]:
        """The filters, in declaration order."""
        return self._filters

    @property
    def axes(self) -> Tuple[Index | Carry | None, ...]:
        """The axis associated with each filter, aligned with ``filters``."""
        return self._axes

    def __repr__(self):
        return f'StateAxes({dict(self.items())})'

    def items(self):
        """Return an iterator over ``(filter, axis)`` pairs."""
        return zip(self.filters, self.axes)

    def __eq__(self, other):
        return isinstance(other, StateAxes) and self.filters == other.filters and self.axes == other.axes

    def __hash__(self):
        return hash((self.filters, self.axes))
+
84
+
85
def _map_split_fn(ctx, path, prefix, x):
    """
    Split function passed to ``graph_to_tree`` by the map transforms.

    Splits the graph node ``x`` with ``ctx.treefy_split``. When ``prefix`` is a
    :class:`StateAxes`, the split is done by its filters. The pieces are wrapped
    in ``NodeStates`` with ``prefix`` attached as metadata. ``path`` is accepted
    for the split-function signature but is not used.
    """
    split_filters = prefix.filters if isinstance(prefix, StateAxes) else ()
    return NodeStates.from_split(*ctx.treefy_split(x, *split_filters), metadata=prefix)
+
90
+
91
@dataclasses.dataclass(eq=False)
class MapFn:
    """
    Callable executed inside the underlying JAX map transform.

    It rebuilds graph nodes from the pure pytree arguments, calls ``f``, and
    converts both the (possibly updated) arguments and the output back into
    pytrees. The arguments use ``in_axes`` as their prefix and the output uses
    ``out_axes``.
    """

    f: Callable[..., Any]  # the user function being mapped
    in_axes: Any  # user-level in_axes (may contain StateAxes)
    out_axes: Any  # user-level out_axes (may contain StateAxes)
    ctxtag: str  # graph-context tag shared with the outer wrapper

    def __post_init__(self):
        # Make this wrapper carry the metadata (__name__, __doc__, ...) of ``f``.
        functools.update_wrapper(self, self.f)

    def __call__(self, *pure_args: Tuple[Any, ...]):
        graph_args = tree_to_graph(pure_args, ctxtag=self.ctxtag)
        result = self.f(*graph_args)
        pure_args_out, pure_out = graph_to_tree(
            (clear_non_graph_nodes(graph_args), result),
            prefix=(self.in_axes, self.out_axes),
            split_fn=_map_split_fn,
            ctxtag=self.ctxtag,
        )
        return pure_args_out, pure_out
+
118
+
119
def _map_transform(
    ctxtag,
    transform,
    f: F,
    *,
    in_axes: Optional[int | Sequence[Any]] = 0,
    out_axes: Any = 0,
    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
    **transform_kwargs,
):
    """
    Shared implementation of :func:`vmap` and :func:`pmap`.

    Args:
        ctxtag: tag identifying the graph context (``'vmap'`` or ``'pmap'``).
        transform: the JAX transform to apply (``jax.vmap`` or ``jax.pmap``).
        f: the user function to map.
        in_axes: user-level input axes; may contain :class:`StateAxes`.
        out_axes: user-level output axes; may contain :class:`StateAxes`.
        rngs: random number generator(s) passed to ``restore_rngs``.
        **transform_kwargs: forwarded unchanged to ``transform``.

    Returns:
        A wrapper that accepts positional arguments only and returns the output
        of the mapped function.
    """

    def _to_jax_axes(axes_spec):
        # Replace every StateAxes leaf with a NodeStates prefix JAX can consume.
        return jax.tree.map(
            lambda leaf: (
                NodeStates.from_prefixes(leaf.axes, metadata=leaf)
                if isinstance(leaf, StateAxes)
                else leaf
            ),
            axes_spec,
        )

    jax_in_axes = _to_jax_axes(in_axes)
    jax_out_axes = _to_jax_axes(out_axes)

    # MapFn returns (updated args, output); the updated args are mapped back
    # along the same axes they came in on.
    mapped_fn = transform(
        MapFn(f, in_axes, out_axes, ctxtag),
        in_axes=jax_in_axes,
        out_axes=(jax_in_axes, jax_out_axes),
        **transform_kwargs,
    )

    @functools.wraps(f)
    @restore_rngs(rngs=rngs)  # restore the random key of default random number generator
    @update_context(ctxtag)
    def map_wrapper(*args):
        pure_args = graph_to_tree(args, prefix=in_axes, split_fn=_map_split_fn, ctxtag=ctxtag)
        pure_args_out, pure_out = mapped_fn(*pure_args)
        # The converted args are discarded here; presumably the conversion
        # propagates state updates through the shared ``ctxtag`` context — confirm.
        _, out = tree_to_graph((pure_args_out, pure_out), ctxtag=ctxtag)
        return out

    return map_wrapper  # type: ignore
+
165
+
166
def vmap(
    fn: F | Missing = Missing(),
    *,
    in_axes: int | None | Sequence[Any] = 0,
    out_axes: Any = 0,
    axis_name: AxisName | None = None,
    axis_size: int | None = None,
    spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
    # brainstate specific arguments
    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
) -> F | Callable[[F], F]:
    """
    Vectorizing map. Creates a function which maps ``fn`` over argument axes.

    The transformation :func:`vmap` is designed to work with the ``pygraph``
    structure defined in the ``brainstate`` library. It vectorizes functions by
    pushing the mapped axis down into primitive operations. Graph nodes passed
    as arguments (e.g. :class:`brainstate.nn.Module` instances) are converted to
    pytrees before the underlying :func:`jax.vmap` is applied, and converted
    back afterwards.

    For more information, see
    `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.

    When ``fn`` is omitted, :func:`vmap` returns a decorator, so it can be used
    as ``@vmap(in_axes=...)``.

    Here are several usage examples::

        >>> import brainstate as bst
        >>> import jax.numpy as jnp

        >>> model = bst.nn.Linear(2, 3)
        >>> x = jnp.ones((5, 2))

        >>> @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
        ... def forward(model, x):
        ...     return model(x)

        >>> y = forward(model, x)
        >>> print(y.shape)
        (5, 3)

    Another example with a more complex model::

        >>> class LinearEnsemble(bst.nn.Module):
        ...     def __init__(self, n: int):
        ...         super().__init__()
        ...         self.n = n
        ...         self.w = bst.ParamState(bst.random.random((n, 2, 3)))

        >>> model = LinearEnsemble(5)
        >>> x = jnp.ones((2,))

        >>> @bst.augment.vmap(in_axes=(0, None), out_axes=0)
        ... def forward(model, x):
        ...     return jnp.dot(x, model.w.value)

        >>> y = forward(model, x)
        >>> print(y.shape)
        (5, 3)

    To control how different types of states are vectorized, ``StateAxes``
    can be passed to ``in_axes`` and ``out_axes``. It specifies the axis applied
    to each substate selected by a filter. The following example maps the
    ``ParamState`` along axis 0 while broadcasting the ``ShortTermState``
    unchanged (axis ``None``) to every mapped call::

        >>> class Foo(bst.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.a = bst.ParamState(jnp.arange(4))
        ...         self.b = bst.ShortTermState(jnp.arange(4))

        >>> state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
        >>> @bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
        ... def mul(foo):
        ...     return foo.a.value * foo.b.value

        >>> model = Foo()
        >>> y = mul(model)
        >>> print(y.shape)
        (4, 4)

    Args:
        fn: Function to be mapped over additional axes.
        in_axes: An integer, None, or sequence of values specifying which input
            array axes to map over.

            If each positional argument to ``fn`` is an array, then ``in_axes`` can
            be an integer, a None, or a tuple of integers and Nones with length equal
            to the number of positional arguments to ``fn``. An integer or ``None``
            indicates which array axis to map over for all arguments (with ``None``
            indicating not to map any axis), and a tuple indicates which axis to map
            for each corresponding positional argument. Axis integers must be in the
            range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
            dimensions (axes) of the corresponding input array.

            If the positional arguments to ``fn`` are container (pytree) types, ``in_axes``
            must be a sequence with length equal to the number of positional arguments to
            ``fn``, and for each argument the corresponding element of ``in_axes`` can
            be a container with a matching pytree structure specifying the mapping of its
            container elements. In other words, ``in_axes`` must be a container tree prefix
            of the positional argument tuple passed to ``fn``. See this link for more detail:
            https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees

            A :class:`StateAxes` may appear in place of an axis for a graph node
            to choose the axis per state type.

            Either ``axis_size`` must be provided explicitly, or at least one
            positional argument must have ``in_axes`` not None. The sizes of the
            mapped input axes for all mapped positional arguments must all be equal.

            Only positional arguments are supported: the returned function does
            not accept keyword arguments.

            See above for examples.

        out_axes: An integer, None, or (nested) standard Python container
            (tuple/list/dict) thereof indicating where the mapped axis should appear
            in the output. All outputs with a mapped axis must have a non-None
            ``out_axes`` specification. Axis integers must be in the range ``[-ndim,
            ndim)`` for each output array, where ``ndim`` is the number of dimensions
            (axes) of the array returned by the :func:`vmap`-ed function, which is one
            more than the number of dimensions (axes) of the corresponding array
            returned by ``fn``.
        axis_name: Optional, a hashable Python object used to identify the mapped
            axis so that parallel collectives can be applied.
        axis_size: Optional, an integer indicating the size of the axis to be
            mapped. If not provided, the mapped axis size is inferred from arguments.
        spmd_axis_name: Optional, a hashable Python object or tuple of hashable
            Python objects. Forwarded unchanged to :func:`jax.vmap`; see its
            documentation for the semantics.
        rngs: Optional, a random number generator or sequence of random number
            generators used in the mapped function. Their random keys are
            restored after the mapped function is executed.

    Returns:
        Batched/vectorized version of ``fn`` with arguments that correspond to
        those of ``fn``, but with extra array axes at positions indicated by
        ``in_axes``, and a return value that corresponds to that of ``fn``, but
        with extra array axes at positions indicated by ``out_axes``.
    """
    if isinstance(fn, Missing):
        # Used as ``@vmap(...)``: bind the options and wait for the function.
        return functools.partial(
            vmap,
            in_axes=in_axes,
            out_axes=out_axes,
            axis_name=axis_name,
            axis_size=axis_size,
            spmd_axis_name=spmd_axis_name,
            rngs=rngs,
        )  # type: ignore[return-value]

    return _map_transform(
        'vmap',  # ctxtag
        jax.vmap,
        fn,
        in_axes=in_axes,
        out_axes=out_axes,
        axis_name=axis_name,
        axis_size=axis_size,
        spmd_axis_name=spmd_axis_name,
        rngs=rngs
    )
+
330
+
331
def pmap(
    fn: F | Missing = Missing(),
    axis_name: Optional[AxisName] = None,
    *,
    in_axes: Any = 0,
    out_axes: Any = 0,
    static_broadcasted_argnums: int | Iterable[int] = (),
    devices: Optional[Sequence[jax.Device]] = None,  # noqa: F811
    backend: Optional[str] = None,
    axis_size: Optional[int] = None,
    donate_argnums: int | Iterable[int] = (),
    global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
    # brainstate specific arguments
    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
) -> Callable[[F], F] | F:
    """
    Parallel map with support for collective operations.

    The purpose of :py:func:`pmap` is to express single-program multiple-data
    (SPMD) programs. Applying :py:func:`pmap` to a function will compile the
    function with XLA (similarly to :py:func:`jit`), then execute it in parallel
    on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it
    is comparable to :py:func:`vmap` because both transformations map a function
    over array axes. However, :py:func:`vmap` vectorizes functions by pushing the
    mapped axis down into primitive operations, whereas :py:func:`pmap` replicates
    the function and executes each replica on its own XLA device in parallel.
    Graph nodes passed as arguments are converted to pytrees before the
    underlying :func:`jax.pmap` is applied, and converted back afterwards.

    The mapped axis size must be less than or equal to the number of local XLA
    devices available, as returned by :py:func:`jax.local_device_count()` (unless
    ``devices`` is specified, see below). For nested :py:func:`pmap` calls, the
    product of the mapped axis sizes must be less than or equal to the number of
    XLA devices.

    For more information, see
    `jax.pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`__.

    When ``fn`` is omitted, :func:`pmap` returns a decorator, so it can be used
    as ``@pmap(...)``.

    If there are 4 XLA devices available, the following example will execute
    the function in parallel on each device::

        >>> import brainstate as bst
        >>> import jax.numpy as jnp

        >>> model = bst.nn.Linear(2, 3)
        >>> x = jnp.ones((4, 2))

        >>> @bst.augment.pmap(in_axes=(None, 0), out_axes=0)
        ... def forward(model, x):
        ...     return model(x)

        >>> y = forward(model, x)
        >>> print(y.shape)
        (4, 3)

    Another example with a more complex model::

        >>> class LinearEnsemble(bst.nn.Module):
        ...     def __init__(self, n: int):
        ...         super().__init__()
        ...         self.n = n
        ...         self.w = bst.ParamState(bst.random.random((n, 2, 3)))

        >>> model = LinearEnsemble(4)
        >>> x = jnp.ones((2,))

        >>> @bst.augment.pmap(in_axes=(0, None), out_axes=0)
        ... def forward(model, x):
        ...     return jnp.dot(x, model.w.value)

        >>> y = forward(model, x)
        >>> print(y.shape)
        (4, 3)

    To control how different types of states are parallelized, ``StateAxes``
    can be passed to ``in_axes`` and ``out_axes``. It specifies the axis applied
    to each substate selected by a filter. The following example maps the
    ``ParamState`` along axis 0 while broadcasting the ``ShortTermState``
    unchanged (axis ``None``) to every device::

        >>> class Foo(bst.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.a = bst.ParamState(jnp.arange(4))
        ...         self.b = bst.ShortTermState(jnp.arange(4))

        >>> state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
        >>> @bst.augment.pmap(in_axes=(state_axes,), out_axes=0)
        ... def mul(foo):
        ...     return foo.a.value * foo.b.value

        >>> model = Foo()
        >>> y = mul(model)
        >>> print(y.shape)
        (4, 4)

    Args:
        fn: Function to be mapped over argument axes. Its arguments and return
            value should be arrays, scalars, graph nodes, or (nested) standard
            Python containers (tuple/list/dict) thereof. Positional arguments
            indicated by ``static_broadcasted_argnums`` can be anything at all,
            provided they are hashable and have an equality operation defined.
        axis_name: Optional, a hashable Python object used to identify the mapped
            axis so that parallel collectives can be applied.
        in_axes: A non-negative integer, None, :class:`StateAxes`, or nested
            Python container thereof that specifies which axes of positional
            arguments to map over. Only positional arguments are supported: the
            returned function does not accept keyword arguments. See
            :py:func:`vmap` for details.
        out_axes: A non-negative integer, None, or nested Python container thereof
            indicating where the mapped axis should appear in the output. All outputs
            with a mapped axis must have a non-None ``out_axes`` specification
            (see :py:func:`vmap`).
        static_broadcasted_argnums: An int or collection of ints specifying which
            positional arguments to treat as static (compile-time constant).
            Operations that only depend on static arguments will be constant-folded.
            Calling the pmapped function with different values for these constants
            will trigger recompilation. If the pmapped function is called with fewer
            positional arguments than indicated by ``static_broadcasted_argnums`` then
            an error is raised. Each of the static arguments will be broadcasted to
            all devices. Arguments that are not arrays or containers thereof must be
            marked as static. Defaults to ().

            Static arguments must be hashable, meaning both ``__hash__`` and
            ``__eq__`` are implemented, and should be immutable.

        devices: This is an experimental feature and the API is likely to change.
            Optional, a sequence of Devices to map over. (Available devices can be
            retrieved via jax.devices()). Must be given identically for each process
            in multi-process settings (and will therefore include devices across
            processes). If specified, the size of the mapped axis must be equal to
            the number of devices in the sequence local to the given process. Nested
            :py:func:`pmap` s with ``devices`` specified in either the inner or outer
            :py:func:`pmap` are not yet supported.
        backend: This is an experimental feature and the API is likely to change.
            Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'.
        axis_size: Optional; the size of the mapped axis.
        donate_argnums: Specify which positional argument buffers are "donated" to
            the computation. It is safe to donate argument buffers if you no longer need
            them once the computation has finished. In some cases XLA can make use of
            donated buffers to reduce the amount of memory needed to perform a
            computation, for example recycling one of your input buffers to store a
            result. You should not reuse buffers that you donate to a computation, JAX
            will raise an error if you try to.

            For more details on buffer donation see the
            `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
        global_arg_shapes: Optional. Forwarded unchanged to :func:`jax.pmap`; see
            its documentation for whether and how it is supported by the
            installed JAX version.
        rngs: Optional, a random number generator or sequence of random number
            generators used in the mapped function. Their random keys are
            restored after the mapped function is executed.

    Returns:
        A parallelized version of ``fn`` with arguments that correspond to those of
        ``fn`` but with extra array axes at positions indicated by ``in_axes`` and
        with output that has an additional leading array axis (with the same size).
    """

    if isinstance(fn, Missing):
        # Used as ``@pmap(...)``: bind the options and wait for the function.
        return functools.partial(
            pmap,
            axis_name=axis_name,
            in_axes=in_axes,
            out_axes=out_axes,
            static_broadcasted_argnums=static_broadcasted_argnums,
            devices=devices,
            backend=backend,
            axis_size=axis_size,
            donate_argnums=donate_argnums,
            global_arg_shapes=global_arg_shapes,
            rngs=rngs,
        )  # type: ignore[return-value]

    return _map_transform(
        'pmap',  # ctxtag
        jax.pmap,
        fn,
        in_axes=in_axes,
        out_axes=out_axes,
        axis_name=axis_name,
        static_broadcasted_argnums=static_broadcasted_argnums,
        devices=devices,
        backend=backend,
        axis_size=axis_size,
        donate_argnums=donate_argnums,
        global_arg_shapes=global_arg_shapes,
        rngs=rngs,
    )