brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,756 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ This module implements how to create a JAX Jaxpr from a given function by considering the states that are read and
18
+ written by the function. These state transformations are foundational for the BrainCore library. These utilities
19
+ include two basic functions: `StatefulFunction` and `make_jaxpr`.
20
+
21
+
22
+ ``StatefulFunction``
23
+ --------------------
24
+
25
+ The module provides a class called ``StatefulFunction`` that wraps a function and provides methods to get the
26
+ JAX Jaxpr, the output shapes, the states that are read and written by the function, and the output of the function.
27
+ The class provides the following methods:
28
+
29
+ - `make_jaxpr`: creates the JAX Jaxpr of the function.
30
+ - `jaxpr_call`: calls the function at the JAX Jaxpr level.
31
+ - `jaxpr_call_without_states`: calls the function at the JAX Jaxpr level without considering the states.
32
+ - `get_states`: returns the states that are read and written by the function.
33
+ - `get_read_states`: returns the states that are read by the function.
34
+ - `get_write_states`: returns the states that are written by the function.
35
+ - `get_static_args`: returns the static arguments from the arguments.
36
+ - `compile_and_get_states_by_static_args`: compiles the function and returns the states that are read and
37
+ written by the function.
38
+ - `get_jaxpr`: returns the JAX Jaxpr of the function.
39
+ - `get_out_shapes`: returns the output shapes of the function.
40
+ - `get_out_treedef`: returns the output tree of the function.
41
+
42
+ ``make_jaxpr``
43
+ --------------
44
+
45
+ The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
46
+ arguments. The function returns a wrapped version of the function that when applied to example arguments returns a
47
+ `ClosedJaxpr` representation of the function on those arguments. If the argument `return_shape` is `True`, then the
48
+ returned function instead returns a pair where the first element is the `ClosedJaxpr` representation of the function
49
+ and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of the
50
+ function.
51
+
52
+ """
53
+
54
+ from __future__ import annotations
55
+
56
+ import functools
57
+ import inspect
58
+ import operator
59
+ from collections.abc import Hashable, Iterable, Sequence
60
+ from contextlib import ExitStack
61
+ from typing import Any, Callable, Tuple, Union, Dict, Optional
62
+
63
+ import jax
64
+ from jax._src import source_info_util
65
+ from jax._src.linear_util import annotate
66
+ from jax._src.traceback_util import api_boundary
67
+ from jax.api_util import shaped_abstractify
68
+ from jax.extend.linear_util import transformation_with_aux, wrap_init
69
+ from jax.interpreters import partial_eval as pe
70
+ from jax.util import wraps
71
+
72
+ from brainstate._state import State, StateTraceStack
73
+ from brainstate._utils import set_module_as
74
+ from brainstate.typing import PyTree
75
+ from brainstate.util._tracers import new_jax_trace
76
+
77
+ AxisName = Hashable
78
+
79
+ __all__ = [
80
+ "StatefulFunction",
81
+ "make_jaxpr",
82
+ ]
83
+
84
+
85
+ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
86
+ """Convert x to a tuple of indices."""
87
+ x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
88
+ try:
89
+ return (operator.index(x),)
90
+ except TypeError:
91
+ return tuple(jax.util.safe_map(operator.index, x))
92
+
93
+
94
def _new_arg_fn(frame, trace, aval):
    """
    Create a tracer for a new argument and register it as a jaxpr input.

    Modified from ``jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()``.

    Args:
        frame: The jaxpr-building frame on which the tracer is registered.
        trace: The dynamic jaxpr trace that owns the new tracer.
        aval: The abstract value (shape/dtype) of the new argument.

    Returns:
        The newly created ``DynamicJaxprTracer``.
    """
    tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
    # Register the tracer on the frame and bind it to a fresh jaxpr variable.
    frame.tracers.append(tracer)
    frame.tracer_to_var[id(tracer)] = var = frame.newvar(aval)
    # Appending to ``invars`` makes the variable an input of the jaxpr being
    # built, so values discovered mid-trace become explicit inputs.
    frame.invars.append(var)
    return tracer
113
+
114
+
115
def _init_state_trace() -> StateTraceStack:
    """Create a ``StateTraceStack`` wired to the current JAX trace frame.

    NOTE(review): this captures the *current* dynamic trace frame, so it is
    only meaningful while tracing is active (i.e. within a
    ``jax.make_jaxpr()`` call) — confirm callers respect this.

    Returns:
        A fresh ``StateTraceStack`` whose new-argument hook registers state
        values as inputs of the jaxpr under construction.
    """
    # Should be within the calling of ``jax.make_jaxpr()``
    frame, trace = new_jax_trace()
    state_trace: StateTraceStack = StateTraceStack()
    # Set the function to transform the new argument to a tracer
    state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
    return state_trace
122
+
123
+
124
+ # def wrapped_abstractify(x: Any) -> Any:
125
+ # """
126
+ # Abstractify the input.
127
+ #
128
+ # Args:
129
+ # x: The input.
130
+ #
131
+ # Returns:
132
+ # The abstractified input.
133
+ # """
134
+ # if isinstance(x, pe.DynamicJaxprTracer):
135
+ # return jax.core.ShapedArray(x.aval.shape, x.aval.dtype, weak_type=x.aval.weak_type)
136
+ # return shaped_abstractify(x)
137
+
138
+
139
class StatefulFunction(object):
    """
    A wrapper class for a function that collects the states that are read and
    written by the function. The states are collected during tracing and are
    exposed through the accessor methods below, keyed by a cache key derived
    from the static arguments.

    Args:
        fun: The function whose ``jaxpr`` is to be computed. Its positional
            arguments and return value should be arrays, scalars, or standard
            Python containers (tuple/list/dict) thereof.
        static_argnums: See the :py:func:`jax.jit` docstring.
        axis_env: Optional, a sequence of pairs where the first element is an
            axis name and the second element is a positive integer
            representing the size of the mapped axis with that name. This
            parameter is useful when lowering functions that involve parallel
            communication collectives, and it specifies the axis name/size
            environment that would be set up by applications of
            :py:func:`jax.pmap`.
        abstracted_axes: Optional, a pytree with the same structure as the
            input arguments to ``fun``. The leaves of the pytree can be either
            None or a dict with axis names as keys and integers as values. If
            the leaf is None, then the corresponding axis is not abstracted.
            If the leaf is a dict, then the corresponding axis is abstracted,
            and the dict specifies the axis name and size. The abstracted axes
            are used to infer the input type of the function. If None, then
            all axes are abstracted.
        state_returns: Optional, a string or a tuple of strings. The default
            is ``('read', 'write')``. The strings specify the categories of
            states to be returned by the wrapped function: ``'read'`` returns
            the states read by the function, ``'write'`` the states written by
            it, and both together return both categories.
        cache_type: Optional, either ``None`` or ``'jit'``. Controls how the
            per-call cache key is computed (see ``get_arg_cache_key``).

    Raises:
        ValueError: If ``cache_type`` is not ``None`` or ``'jit'``.
    """
    __module__ = "brainstate.compile"

    def __init__(
        self,
        fun: Callable,
        static_argnums: Union[int, Iterable[int]] = (),
        axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
        abstracted_axes: Optional[Any] = None,
        state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
        cache_type: Optional[str] = None,
    ):
        # explicit parameters
        self.fun = fun
        self.static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
        self.axis_env = axis_env
        self.abstracted_axes = abstracted_axes
        self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
        # Validate with a real exception: a bare ``assert`` here would be
        # silently stripped when Python runs with the ``-O`` flag.
        if cache_type not in (None, 'jit'):
            raise ValueError(f"Invalid cache type: {cache_type}, expected None or 'jit'.")

        # implicit parameters
        self.cache_type = cache_type
        # Per-cache-key results of tracing the function.
        self._cached_jaxpr: Dict[Any, jax.core.ClosedJaxpr] = dict()
        self._cached_out_shapes: Dict[Any, PyTree] = dict()
        self._cached_jaxpr_out_tree: Dict[Any, PyTree] = dict()
        self._cached_state_trace: Dict[Any, StateTraceStack] = dict()

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}({self.fun}, "
                f"static_argnums={self.static_argnums}, "
                f"axis_env={self.axis_env}, "
                f"abstracted_axes={self.abstracted_axes}, "
                f"state_returns={self.state_returns})")

    def get_jaxpr(self, cache_key: Hashable = ()) -> jax.core.ClosedJaxpr:
        """
        Read the JAX Jaxpr representation of the function.

        Args:
            cache_key: The hashable key.

        Returns:
            The JAX Jaxpr representation of the function.

        Raises:
            ValueError: If the function has not been traced with this key.
        """
        if cache_key not in self._cached_jaxpr:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._cached_jaxpr[cache_key]

    def get_out_shapes(self, cache_key: Hashable = ()) -> PyTree:
        """
        Read the output shapes of the function.

        Args:
            cache_key: The hashable key.

        Returns:
            The output shapes of the function.

        Raises:
            ValueError: If the function has not been traced with this key.
        """
        if cache_key not in self._cached_out_shapes:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._cached_out_shapes[cache_key]

    def get_out_treedef(self, cache_key: Hashable = ()) -> PyTree:
        """
        Read the output tree of the function.

        Args:
            cache_key: The hashable key.

        Returns:
            The output tree of the function.

        Raises:
            ValueError: If the function has not been traced with this key.
        """
        if cache_key not in self._cached_jaxpr_out_tree:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._cached_jaxpr_out_tree[cache_key]

    def get_state_trace(self, cache_key: Hashable = ()) -> StateTraceStack:
        """
        Read the state trace of the function.

        Args:
            cache_key: The hashable key.

        Returns:
            The state trace of the function.

        Raises:
            ValueError: If the function has not been traced with this key.
        """
        if cache_key not in self._cached_state_trace:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._cached_state_trace[cache_key]

    def get_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
        """
        Read the states that are read and written by the function.

        Args:
            cache_key: The hashable key.

        Returns:
            The states that are read and written by the function.
        """
        return tuple(self.get_state_trace(cache_key).states)

    def get_read_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
        """
        Read the states that are read by the function.

        Args:
            cache_key: The hashable key.

        Returns:
            The states that are read by the function.
        """
        return self.get_state_trace(cache_key).get_read_states()

    def get_write_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
        """
        Read the states that are written by the function.

        Args:
            cache_key: The hashable key.

        Returns:
            The states that are written by the function.
        """
        return self.get_state_trace(cache_key).get_write_states()

    def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
        """
        Compute the cache key for this call from the static arguments.

        Args:
            *args: The arguments to the function.
            **kwargs: The keyword arguments to the function.

        Returns:
            The hashable cache key.

        Raises:
            ValueError: If ``self.cache_type`` is not supported.
        """
        if self.cache_type == 'jit':
            # Key on static argument values plus the abstract (shape/dtype)
            # signature of every dynamic argument, mirroring jax.jit caching.
            static_args, dyn_args = [], []
            for i, arg in enumerate(args):
                if i in self.static_argnums:
                    static_args.append(arg)
                else:
                    dyn_args.append(arg)
            dyn_args = jax.tree.map(shaped_abstractify, jax.tree.leaves(dyn_args))
            dyn_kwargs = jax.tree.map(shaped_abstractify, jax.tree.leaves(kwargs))
            return tuple([tuple(static_args), tuple(dyn_args), tuple(dyn_kwargs)])
        elif self.cache_type is None:
            # NOTE(review): with cache_type=None the key covers only static
            # positional args — kwargs and dynamic-arg shapes are ignored, so
            # calls differing only in those share one cache entry. Confirm
            # this is intended.
            num_arg = len(args)
            return tuple(args[i] for i in self.static_argnums if i < num_arg)
        else:
            raise ValueError(f"Invalid cache type: {self.cache_type}")

    def compile_function_and_get_states(self, *args, **kwargs) -> Tuple[State, ...]:
        """
        Compile the function, and get the states that are read and written by this function.

        Args:
            *args: The arguments to the function.
            **kwargs: The keyword arguments to the function.

        Returns:
            The states that are read and written by the function.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        if cache_key not in self._cached_state_trace:
            self.make_jaxpr(*args, **kwargs)
        return self.get_states(cache_key)

    def compile_function_and_get_state_trace(
        self, *args, return_only_write: bool = False, **kwargs
    ) -> StateTraceStack:
        """
        Compile the function, and get the state trace that records the states
        read and written by this function.

        Args:
            *args: The arguments to the function.
            return_only_write: If True, only return the states that are
                written by the function.
            **kwargs: The keyword arguments to the function.

        Returns:
            The state trace stack.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        if cache_key not in self._cached_state_trace:
            self.make_jaxpr(*args, **kwargs, return_only_write=return_only_write)
        return self.get_state_trace(cache_key)

    def clear_cache(self) -> None:
        """
        Clear the compilation cache.
        """
        self._cached_jaxpr.clear()
        self._cached_out_shapes.clear()
        self._cached_jaxpr_out_tree.clear()
        self._cached_state_trace.clear()

    def _wrapped_fun_to_eval(
        self, cache_key, *args, return_only_write: bool = False, **kwargs,
    ) -> Tuple[Any, Tuple[State, ...]]:
        """
        Wrap the function and return its output together with the values of
        the states it touched.

        Args:
            cache_key: The cache key under which the state trace is stored.
            *args: The arguments to the function.
            return_only_write: If True, collect only written-state values.
            **kwargs: The keyword arguments to the function.

        Returns:
            A tuple ``(out, state_values)`` of the function output and the
            collected state values.

        Raises:
            ValueError: If the function returns a ``State`` instance.
        """
        # state trace: cached eagerly so accessor methods can find it even
        # while tracing is still in progress.
        state_trace = _init_state_trace()
        self._cached_state_trace[cache_key] = state_trace
        with state_trace:
            out = self.fun(*args, **kwargs)
            state_values = state_trace.get_write_state_values(
                True) if return_only_write else state_trace.get_state_values()
        # Undo any in-place state mutations performed during tracing.
        state_trace.recovery_original_values()

        # State instance as functional returns is not allowed.
        # Checking whether the states are returned.
        for leaf in jax.tree.leaves(out):
            if isinstance(leaf, State):
                leaf._raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
        return out, state_values

    def make_jaxpr(self, *args, return_only_write: bool = False, **kwargs):
        """Trace the function and cache its jaxpr for these example args.

        Produces (and caches) a ``ClosedJaxpr`` representation of ``fun`` on
        these arguments, together with the output shapes, the output treedef,
        and the state trace. Results are retrievable through ``get_jaxpr``,
        ``get_out_shapes``, ``get_out_treedef``, and ``get_state_trace``.

        Args:
            *args: The arguments to the function.
            return_only_write: If True, collect only written-state values.
            **kwargs: The keyword arguments to the function.

        Returns:
            ``self``, to allow call chaining.
        """

        # static args
        cache_key = self.get_arg_cache_key(*args, **kwargs)

        if cache_key not in self._cached_state_trace:
            try:
                # jaxpr
                jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
                    functools.partial(self._wrapped_fun_to_eval, cache_key, return_only_write=return_only_write),
                    static_argnums=self.static_argnums,
                    axis_env=self.axis_env,
                    return_shape=True,
                    abstracted_axes=self.abstracted_axes
                )(*args, **kwargs)

                # returns
                self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
                self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
                self._cached_jaxpr[cache_key] = jaxpr
            except Exception:
                # A failed trace may have left a partial state trace behind
                # (``_wrapped_fun_to_eval`` caches it before tracing
                # completes); drop it so a later retry starts clean.
                self._cached_state_trace.pop(cache_key, None)
                # Bare ``raise`` re-raises with the original traceback intact.
                raise

        return self

    def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
        """
        Call the function at the JAX Jaxpr level.

        Args:
            state_vals: The state values.
            *args: The arguments to the function.
            **kwargs: The keyword arguments to the function.

        Returns:
            State values and the function output.
        """
        # state checking
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        states: Sequence[State] = self.get_states(cache_key)
        assert len(state_vals) == len(states), 'State length mismatch.'

        # parameters: drop static args, then flatten everything into the
        # positional input list expected by the jaxpr.
        args = tuple(arg for i, arg in enumerate(args) if i not in self.static_argnums)
        args = jax.tree.flatten((args, kwargs, state_vals))[0]

        # calling the function,
        # note that this function always returns state values
        # that both write and read by the function
        closed_jaxpr = self.get_jaxpr(cache_key)
        out_treedef = self.get_out_treedef(cache_key)
        jaxpr_outs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)

        # output processing
        out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
        assert len(new_state_vals) == len(state_vals), 'State length mismatch.'
        return new_state_vals, out

    def jaxpr_call_auto(self, *args, **kwargs) -> Any:
        """
        Call the function at the JAX Jaxpr level with automatic state management.

        Args:
            *args: The arguments to the function.
            **kwargs: The keyword arguments to the function.

        Returns:
            The output of the function.
        """
        state_trace = self.get_state_trace(self.get_arg_cache_key(*args, **kwargs))
        state_vals, out = self.jaxpr_call([st.value for st in state_trace.states], *args, **kwargs)
        # Write back only the states the traced function actually wrote;
        # merely-read states are restored to their incoming values.
        for st, written, val in zip(state_trace.states, state_trace.been_writen, state_vals):
            if written:
                st.write_value(val)
            else:
                st.restore_value(val)
        return out
494
+
495
+
496
+ @set_module_as("brainstate.compile")
497
+ def make_jaxpr(
498
+ fun: Callable,
499
+ static_argnums: Union[int, Iterable[int]] = (),
500
+ axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
501
+ return_shape: bool = False,
502
+ abstracted_axes: Optional[Any] = None,
503
+ state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
504
+ ) -> Callable[..., (Tuple[jax.core.ClosedJaxpr, Tuple[State, ...]] |
505
+ Tuple[jax.core.ClosedJaxpr, Tuple[State, ...], PyTree])]:
506
+ """
507
+ Creates a function that produces its jaxpr given example args.
508
+
509
+ Args:
510
+ fun: The function whose ``jaxpr`` is to be computed. Its positional
511
+ arguments and return value should be arrays, scalars, or standard Python
512
+ containers (tuple/list/dict) thereof.
513
+ static_argnums: See the :py:func:`jax.jit` docstring.
514
+ axis_env: Optional, a sequence of pairs where the first element is an axis
515
+ name and the second element is a positive integer representing the size of
516
+ the mapped axis with that name. This parameter is useful when lowering
517
+ functions that involve parallel communication collectives, and it
518
+ specifies the axis name/size environment that would be set up by
519
+ applications of :py:func:`jax.pmap`.
520
+ return_shape: Optional boolean, defaults to ``False``. If ``True``, the
521
+ wrapped function returns a pair where the first element is the XLA
522
+ computation and the second element is a pytree with the same structure as
523
+ the output of ``fun`` and where the leaves are objects with ``shape``,
524
+ ``dtype``, and ``named_shape`` attributes representing the corresponding
525
+ types of the output leaves.
526
+ abstracted_axes: Optional, a pytree with the same structure as the input
527
+ arguments to ``fun``. The leaves of the pytree can be either None or a
528
+ dict with axis names as keys and integers as values. If the leaf is None,
529
+ then the corresponding axis is not abstracted. If the leaf is a dict, then
530
+ the corresponding axis is abstracted, and the dict specifies the axis name
531
+ and size. The abstracted axes are used to infer the input type of the
532
+ function. If None, then all axes are abstracted.
533
+ state_returns: Optional, a string or a tuple of strings. The default is
534
+ ``('read', 'write')``. The strings specify the categories of states to be
535
+ returned by the wrapped function. The categories are ``'read'`` and
536
+ ``'write'``. If the category is ``'read'``, then the wrapped function
537
+ returns the states that are read by the function. If the category is
538
+ ``'write'``, then the wrapped function returns the states that are written
539
+ by the function. If the category is ``'read'`` and ``'write'``, then the
540
+ wrapped function returns both the read and write states.
541
+
542
+
543
+ Returns:
544
+ A wrapped version of ``fun`` that when applied to example arguments returns
545
+ a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
546
+ argument ``return_shape`` is ``True``, then the returned function instead
547
+ returns a pair where the first element is the ``ClosedJaxpr``
548
+ representation of ``fun`` and the second element is a pytree representing
549
+ the structure, shape, dtypes, and named shapes of the output of ``fun``.
550
+
551
+ A ``jaxpr`` is JAX's intermediate representation for program traces. The
552
+ ``jaxpr`` language is based on the simply-typed first-order lambda calculus
553
+ with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
554
+ ``jaxpr``, which we can inspect to understand what JAX is doing internally.
555
+ The ``jaxpr`` returned is a trace of ``fun`` abstracted to
556
+ :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
557
+
558
+ We do not describe the semantics of the ``jaxpr`` language in detail here, but
559
+ instead give a few examples.
560
+
561
+ >>> import jax
562
+ >>> import brainstate as bst
563
+ >>>
564
+ >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
565
+ >>> print(f(3.0))
566
+ -0.83602
567
+ >>> jaxpr, states = bst.compile.make_jaxpr(f)(3.0)
568
+ >>> jaxpr
569
+ { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
570
+ >>> jaxpr, states = bst.compile.make_jaxpr(jax.grad(f))(3.0)
571
+ >>> jaxpr
572
+ { lambda ; a:f32[]. let
573
+ b:f32[] = cos a
574
+ c:f32[] = sin a
575
+ _:f32[] = sin b
576
+ d:f32[] = cos b
577
+ e:f32[] = mul 1.0 d
578
+ f:f32[] = neg e
579
+ g:f32[] = mul f c
580
+ in (g,) }
581
+ """
582
+
583
+ stateful_fun = StatefulFunction(fun, static_argnums, axis_env, abstracted_axes, state_returns)
584
+
585
+ @wraps(fun)
586
+ def make_jaxpr_f(*args, **kwargs):
587
+ stateful_fun.make_jaxpr(*args, **kwargs)
588
+ cache_key = stateful_fun.get_arg_cache_key(*args, **kwargs)
589
+ if return_shape:
590
+ return (stateful_fun.get_jaxpr(cache_key),
591
+ stateful_fun.get_states(cache_key),
592
+ stateful_fun.get_out_shapes(cache_key)[0])
593
+ else:
594
+ return (stateful_fun.get_jaxpr(cache_key),
595
+ stateful_fun.get_states(cache_key))
596
+
597
+ # wrapped jaxpr builder function
598
+ make_jaxpr_f.__module__ = "brainstate.compile"
599
+ if hasattr(fun, "__qualname__"):
600
+ make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
601
+ if hasattr(fun, "__name__"):
602
+ make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
603
+ return make_jaxpr_f
604
+
605
+
606
+ def _check_callable(fun):
607
+ # In Python 3.10+, the only thing stopping us from supporting staticmethods
608
+ # is that we can't take weak references to them, which the C++ JIT requires.
609
+ if isinstance(fun, staticmethod):
610
+ raise TypeError(f"staticmethod arguments are not supported, got {fun}")
611
+ if not callable(fun):
612
+ raise TypeError(f"Expected a callable value, got {fun}")
613
+ if inspect.isgeneratorfunction(fun):
614
+ raise TypeError(f"Expected a function, got a generator function: {fun}")
615
+
616
+
617
+ def _broadcast_prefix(
618
+ prefix_tree: Any,
619
+ full_tree: Any,
620
+ is_leaf: Callable[[Any], bool] | None = None
621
+ ) -> list[Any]:
622
+ # If prefix_tree is not a tree prefix of full_tree, this code can raise a
623
+ # ValueError; use prefix_errors to find disagreements and raise more precise
624
+ # error messages.
625
+ result = []
626
+ num_leaves = lambda t: jax.tree.structure(t).num_leaves
627
+ add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
628
+ jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
629
+ return result
630
+
631
+
632
+ def _flat_axes_specs(
633
+ abstracted_axes, *args, **kwargs
634
+ ) -> list[pe.AbstractedAxesSpec]:
635
+ if kwargs:
636
+ raise NotImplementedError
637
+
638
+ def ax_leaf(l):
639
+ return (isinstance(l, dict) and jax.tree_util.all_leaves(l.values()) or
640
+ isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None))
641
+
642
+ return _broadcast_prefix(abstracted_axes, args, ax_leaf)
643
+
644
+
645
@transformation_with_aux
def _flatten_fun(in_tree, *args_flat):
    """linear_util transformation: unflatten the inputs, run the wrapped
    function, then yield its flattened output (the treedef is the aux)."""
    args, kwargs = jax.tree.unflatten(in_tree, args_flat)
    out = yield args, kwargs
    yield jax.tree.flatten(out)
650
+
651
+
652
def _make_jaxpr(
    fun: Callable,
    static_argnums: int | Iterable[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: bool = False,
    abstracted_axes: Any | None = None,
) -> Callable[..., (jax.core.ClosedJaxpr | tuple[jax.core.ClosedJaxpr, Any])]:
    """Creates a function that produces its jaxpr given example args.

    Args:
      fun: The function whose ``jaxpr`` is to be computed. Its positional
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof.
      static_argnums: See the :py:func:`jax.jit` docstring.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
        functions that involve parallel communication collectives, and it
        specifies the axis name/size environment that would be set up by
        applications of :py:func:`jax.pmap`.
      return_shape: Optional boolean, defaults to ``False``. If ``True``, the
        wrapped function returns a pair where the first element is the
        ``ClosedJaxpr`` representation of ``fun`` and the second element is a
        pytree with the same structure as the output of ``fun`` and where the
        leaves are objects with ``shape``, ``dtype``, and ``named_shape``
        attributes representing the corresponding types of the output leaves.

    Returns:
      A wrapped version of ``fun`` that when applied to example arguments returns
      a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
      argument ``return_shape`` is ``True``, then the returned function instead
      returns a pair where the first element is the ``ClosedJaxpr``
      representation of ``fun`` and the second element is a pytree representing
      the structure, shape, dtypes, and named shapes of the output of ``fun``.

    A ``jaxpr`` is JAX's intermediate representation for program traces. The
    ``jaxpr`` language is based on the simply-typed first-order lambda calculus
    with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
    ``jaxpr``, which we can inspect to understand what JAX is doing internally.
    The ``jaxpr`` returned is a trace of ``fun`` abstracted to
    :py:class:`ShapedArray` level. Other levels of abstraction exist internally.

    We do not describe the semantics of the ``jaxpr`` language in detail here, but
    instead give a few examples.

    >>> import jax
    >>>
    >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
    >>> print(f(3.0))
    -0.83602
    >>> _make_jaxpr(f)(3.0)
    { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
    >>> _make_jaxpr(jax.grad(f))(3.0)
    { lambda ; a:f32[]. let
        b:f32[] = cos a
        c:f32[] = sin a
        _:f32[] = sin b
        d:f32[] = cos b
        e:f32[] = mul 1.0 d
        f:f32[] = neg e
        g:f32[] = mul f c
      in (g,) }
    """
    _check_callable(fun)
    # Normalize static_argnums into a canonical tuple of indices.
    static_argnums = _ensure_index_tuple(static_argnums)

    def _abstractify(args, kwargs):
        # Flatten the call arguments and compute an abstract value for each
        # leaf. Returns (avals, input treedef, keep_inputs flags).
        flat_args, in_tree = jax.tree.flatten((args, kwargs))
        if abstracted_axes is None:
            # No abstracted axes: keep every input, abstracted to its shaped aval.
            return map(jax.api_util.shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
        else:
            # Broadcast the abstracted-axes prefix over the flat args and let
            # JAX infer the (possibly dynamically shaped) lambda input type.
            axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
            in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
            in_avals, keep_inputs = jax.util.unzip2(in_type)
            return in_avals, in_tree, keep_inputs

    @wraps(fun)
    @api_boundary
    def make_jaxpr_f(*args, **kwargs):
        f = wrap_init(fun)
        if static_argnums:
            # Partially apply the static arguments so only the dynamic ones
            # are traced.
            dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
            f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
        in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
        in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
        # Flatten the wrapped function's inputs/outputs; out_tree is a thunk
        # usable only after tracing completes.
        f, out_tree = _flatten_fun(f, in_tree)
        f = annotate(f, in_type)
        debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
        with ExitStack() as stack:
            # Install the requested named-axis environment for collectives
            # before tracing.
            for axis_name, size in axis_env or []:
                stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
            jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
        closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
        if return_shape:
            # Reassemble the output shape/dtype structs into the caller's
            # original pytree structure.
            out_avals, _ = jax.util.unzip2(out_type)
            out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
            return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
        return closed_jaxpr

    # Present the wrapper under a descriptive name for debuggers/reprs.
    make_jaxpr_f.__module__ = "brainstate.compile"
    if hasattr(fun, "__qualname__"):
        make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
    if hasattr(fun, "__name__"):
        make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
    return make_jaxpr_f