brainstate 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,102 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import jax.numpy as jnp
19
+ import jax.stages
20
+
21
+ import brainstate as bc
22
+
23
+
24
class TestJIT(unittest.TestCase):
    """Tests for ``bc.transform.jit``: state collection, retracing, and cache control."""

    def test_inner_state_are_not_catched(self):
        # The jitted function updates ``a`` (defined outside) and creates ``b`` internally,
        # updating it through an inner ``for_loop``. For every input signature the
        # underlying ``StatefulFunction`` is expected to report exactly two states.
        a = bc.State(bc.random.randn(10))

        @bc.transform.jit
        def fun1(inp):
            a.value += inp

            b = bc.State(bc.random.randn(1))

            def inner_fun(x):
                b.value += x

            bc.transform.for_loop(inner_fun, bc.random.randn(100))

            return a.value + b.value

        # Scalar input.
        fun1(1.)
        key = fun1.stateful_fun.get_arg_cache_key(1.)
        self.assertEqual(len(fun1.stateful_fun.get_states(key)), 2)

        # Array input: a different signature, hence a separate cache entry.
        x = bc.random.randn(10)
        fun1(x)
        key = fun1.stateful_fun.get_arg_cache_key(x)
        self.assertEqual(len(fun1.stateful_fun.get_states(key)), 2)

    def test_jit_compile_sensitive_to_input_shape(self):
        # ``global_data[0]`` counts how many times the Python body is traced.
        global_data = [0]

        @bc.transform.jit
        def fun1(inp):
            global_data[0] += 1
            return inp

        fun1(1.)
        self.assertEqual(global_data[0], 1)

        # Same shape/dtype as the previous call: no retrace expected.
        fun1(2.)
        self.assertEqual(global_data[0], 1)

        # New shapes: one retrace each.
        fun1(bc.random.randn(10))
        self.assertEqual(global_data[0], 2)

        fun1(bc.random.randn(10, 10))
        self.assertEqual(global_data[0], 3)

    def test_jit_clear_cache(self):
        a = bc.State(bc.random.randn(1))
        # One entry is appended every time ``log2`` is traced.
        compiling = []

        @bc.transform.jit
        def log2(x):
            compiling.append(1)
            ln_x = jnp.log(x)
            ln_2 = jnp.log(2.0) + a.value
            return ln_x / ln_2

        x = bc.random.randn(1)
        log2(x)  # traced
        self.assertEqual(len(compiling), 1)
        log2(x)  # cached, not traced again
        self.assertEqual(len(compiling), 1)

        # After clearing the cache the next call must trace again.
        log2.clear_cache()
        log2(x)
        self.assertEqual(len(compiling), 2)

    def test_jit_attribute_origin_fun(self):
        # The object returned by ``jit`` exposes the original function and its helpers.

        def fun1(x):
            return x

        jitted_fun = bc.transform.jit(fun1)
        self.assertIs(jitted_fun.origin_fun, fun1)
        self.assertIsInstance(jitted_fun.stateful_fun, bc.transform.StatefulFunction)
        self.assertTrue(callable(jitted_fun.jitted_fun))
        self.assertTrue(callable(jitted_fun.clear_cache))
102
+
@@ -0,0 +1,573 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ This module implements how to create a JAX Jaxpr from a given function by considering the states that are read and
18
+ written by the function. These state transformations are foundational for the ``brainstate`` library. These utilities
18
+ include two basic tools: the `StatefulFunction` class and the `make_jaxpr` function.
20
+
21
+
22
+ ``StatefulFunction``
23
+ --------------------
24
+
25
+ The module provides a class called ``StatefulFunction`` that wraps a function and provides methods to get the
26
+ JAX Jaxpr, the output shapes, the states that are read and written by the function, and the output of the function.
27
+ The class provides the following methods:
28
+
29
+ - `make_jaxpr`: creates the JAX Jaxpr of the function.
30
+ - `jaxpr_call`: calls the function at the JAX Jaxpr level.
31
+ - `jaxpr_call_auto`: calls the function at the JAX Jaxpr level, reading the current state values and writing the updated values back to the states.
32
+ - `get_states`: returns the states that are read and written by the function.
33
+ - `get_read_states`: returns the states that are read by the function.
34
+ - `get_write_states`: returns the states that are written by the function.
35
+ - `get_arg_cache_key`: returns the cache key computed from the arguments, which the other getters take as input.
36
+ - `compile_and_get_states_by_static_args`: compiles the function and returns the states that are read and
37
+ written by the function.
38
+ - `get_jaxpr`: returns the JAX Jaxpr of the function.
39
+ - `get_out_shapes`: returns the output shapes of the function.
40
+ - `get_out_treedef`: returns the output tree of the function.
41
+
42
+ ``make_jaxpr``
43
+ --------------
44
+
45
+ The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
46
+ arguments. The function returns a wrapped version of the function that when applied to example arguments returns a
47
+ pair: the `ClosedJaxpr` representation of the function on those arguments and the tuple of states the function
47
+ reads or writes. If the argument `return_shape` is `True`, then the returned function instead returns a triple whose
48
+ third element is a pytree representing the structure, shapes, and dtypes of the output of the
50
+ function.
51
+
52
+ """
53
+
54
+ from __future__ import annotations
55
+
56
+ import functools
57
+ import operator
58
+ from collections.abc import Hashable, Iterable, Sequence
59
+ from typing import Any, Callable, Tuple, Union, Dict, Optional
60
+
61
+ import jax
62
+ from jax._src import source_info_util
63
+ from jax.interpreters import partial_eval as pe
64
+ from jax.util import wraps
65
+ from jax.interpreters.xla import abstractify
66
+
67
+ from brainstate._state import State, StateTrace
68
+ from brainstate._utils import set_module_as
69
+
70
# Type alias used in annotations for arbitrary JAX pytrees (nested containers of arrays).
PyTree = Any

# Names exported by ``from ... import *``; everything else in this module is internal.
__all__ = ["StatefulFunction", "make_jaxpr"]
76
+
77
+
78
def _assign_state_values(states, state_vals) -> None:
    """
    Write each new value into the state at the same position.

    Args:
      states: Objects with a writable ``value`` attribute (``State`` instances).
      state_vals: New values, aligned one-to-one with ``states``.

    Raises:
      AssertionError: if the two sequences have different lengths.
    """
    assert len(states) == len(state_vals), (
        f'State length mismatch. {len(states)} != {len(state_vals)}.'
    )
    for target, new_value in zip(states, state_vals):
        setattr(target, 'value', new_value)
89
+
90
+
91
def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
    """Normalize ``x`` into a tuple of Python integer indices.

    ``x`` must be concrete (not a tracer). A single index-like value becomes a
    one-element tuple; otherwise ``x`` is iterated and each element is converted
    with ``operator.index``.
    """
    concrete = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
    try:
        single = operator.index(concrete)
    except TypeError:
        return tuple(operator.index(item) for item in concrete)
    return (single,)
98
+
99
+
100
def _new_arg(frame, trace, aval):
    """
    Create a new input tracer for ``aval`` and register it as an input variable of ``frame``.

    Adapted from ``jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()``; it is used
    to add extra jaxpr inputs (for state values) while a jaxpr is already being traced.

    Args:
      frame: The jaxpr frame under construction (JAX internal object).
      trace: The ``DynamicJaxprTrace`` the tracer belongs to.
      aval: The abstract value (shape/dtype) of the new input.

    Returns:
      The newly created ``DynamicJaxprTracer``.
    """
    new_tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
    frame.tracers.append(new_tracer)
    # Allocate the jaxpr variable, map the tracer to it, and expose it as an input.
    new_var = frame.newvar(aval)
    frame.tracer_to_var[id(new_tracer)] = new_var
    frame.invars.append(new_var)
    return new_tracer
119
+
120
+
121
def wrapped_abstractify(x: Any) -> Any:
    """
    Return the abstract value (shape/dtype description) of ``x``.

    Tracers of the current jaxpr trace are converted to a plain ``ShapedArray`` that
    keeps their shape, dtype and weak-type flag; any other value is delegated to
    ``abstractify``.

    Args:
      x: A concrete value or a ``DynamicJaxprTracer``.

    Returns:
      The abstract value of ``x``.
    """
    if not isinstance(x, pe.DynamicJaxprTracer):
        return abstractify(x)
    aval = x.aval
    return jax.core.ShapedArray(aval.shape, aval.dtype, weak_type=aval.weak_type)
134
+
135
+
136
class StatefulFunction(object):
    """
    A wrapper around a function that records the :class:`State` objects it reads and writes.

    Calling :meth:`make_jaxpr` with example arguments traces the wrapped function with
    :func:`jax.make_jaxpr`. While tracing, every ``State`` accessed by the function is
    recorded in a ``StateTrace``, and the values the states hold at the end of the function
    are returned as extra outputs of the jaxpr. The resulting ``ClosedJaxpr``, the output
    shapes, the output tree structure and the state trace are cached under the key returned
    by :meth:`get_arg_cache_key`. They can be inspected with :meth:`get_jaxpr`,
    :meth:`get_out_shapes`, :meth:`get_out_treedef`, :meth:`get_states`,
    :meth:`get_read_states` and :meth:`get_write_states`. The cached jaxpr can be evaluated
    with :meth:`jaxpr_call` (state values passed explicitly) or :meth:`jaxpr_call_auto`
    (state values read from, and written back to, the states).

    Args:
      fun: The function whose ``jaxpr`` is to be computed. Its positional
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof. It must not return ``State`` objects.
      static_argnums: See the :py:func:`jax.jit` docstring.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
        functions that involve parallel communication collectives, and it
        specifies the axis name/size environment that would be set up by
        applications of :py:func:`jax.pmap`.
      abstracted_axes: Optional; passed through unchanged to :func:`jax.make_jaxpr`.
        See its documentation.
      state_returns: Optional, a string or a tuple of strings, default
        ``('read', 'write')``. Stored on the instance as a tuple. Note that the
        tracing code of this class currently collects both read and written states
        regardless of this value.
      cache_type: ``None`` (default) or ``'jit'``. With ``None``, the cache key is
        built from the static positional arguments only. With ``'jit'``, the key also
        contains the tree structure and the abstract values (shape, dtype, weak type)
        of the dynamic positional and keyword arguments, so that every new input
        signature is traced separately.

    Raises:
      ValueError: if ``cache_type`` is neither ``None`` nor ``'jit'``.
    """
    __module__ = "brainstate.transform"

    def __init__(
        self,
        fun: Callable,
        static_argnums: Union[int, Iterable[int]] = (),
        axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
        abstracted_axes: Optional[Any] = None,
        state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
        cache_type: Optional[str] = None,
    ):
        # explicit parameters
        self.fun = fun
        self.static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
        self.axis_env = axis_env
        self.abstracted_axes = abstracted_axes
        self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
        # Raise instead of assert so the check survives ``python -O``.
        if cache_type not in (None, 'jit'):
            raise ValueError(f"Invalid cache type: {cache_type!r}. Expected None or 'jit'.")
        self.cache_type = cache_type

        # Caches, all keyed by the value returned from ``get_arg_cache_key``.
        self._jaxpr: Dict[Any, jax.core.ClosedJaxpr] = dict()
        self._out_shapes: Dict[Any, PyTree] = dict()
        self._jaxpr_out_tree: Dict[Any, PyTree] = dict()
        self._state_trace: Dict[Any, StateTrace] = dict()

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}({self.fun}, "
                f"static_argnums={self.static_argnums}, "
                f"axis_env={self.axis_env}, "
                f"abstracted_axes={self.abstracted_axes}, "
                f"state_returns={self.state_returns})")

    @staticmethod
    def _lookup(cache: Dict[Any, Any], cache_key: Hashable) -> Any:
        """Return ``cache[cache_key]``, raising the ``ValueError`` shared by all getters if absent."""
        if cache_key not in cache:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return cache[cache_key]

    def get_jaxpr(self, cache_key: Hashable = ()) -> jax.core.ClosedJaxpr:
        """
        Read the JAX Jaxpr representation of the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The JAX Jaxpr representation of the function.

        Raises:
          ValueError: if the function has not been traced for ``cache_key``.
        """
        return self._lookup(self._jaxpr, cache_key)

    def get_out_shapes(self, cache_key: Hashable = ()) -> PyTree:
        """
        Read the output shapes of the function.

        Args:
          cache_key: The hashable key.

        Returns:
          A pair ``(out_shapes, state_shapes)``: the shapes of the function output and
          of the collected state values.

        Raises:
          ValueError: if the function has not been traced for ``cache_key``.
        """
        return self._lookup(self._out_shapes, cache_key)

    def get_out_treedef(self, cache_key: Hashable = ()) -> PyTree:
        """
        Read the output tree of the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The tree structure of ``(out, state_values)`` produced by the traced function.

        Raises:
          ValueError: if the function has not been traced for ``cache_key``.
        """
        return self._lookup(self._jaxpr_out_tree, cache_key)

    def get_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
        """
        Read the states that are read and written by the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The states that are read and written by the function.

        Raises:
          ValueError: if the function has not been traced for ``cache_key``.
        """
        return tuple(self._lookup(self._state_trace, cache_key).states)

    def get_read_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
        """
        Read the states that are only read by the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The states whose recorded access type is ``'read'``.

        Raises:
          ValueError: if the function has not been traced for ``cache_key``.
        """
        state_trace = self._lookup(self._state_trace, cache_key)
        return tuple(st for st, ty in zip(state_trace.states, state_trace.types) if ty == 'read')

    def get_write_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
        """
        Read the states that are written by the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The states whose recorded access type is ``'write'``.

        Raises:
          ValueError: if the function has not been traced for ``cache_key``.
        """
        state_trace = self._lookup(self._state_trace, cache_key)
        return tuple(st for st, ty in zip(state_trace.states, state_trace.types) if ty == 'write')

    def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
        """
        Compute the cache key for a call with the given arguments.

        Args:
          *args: The positional arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          With ``cache_type=None``: the tuple of static positional arguments.
          With ``cache_type='jit'``: ``(static_args, (args_treedef, args_avals),
          (kwargs_treedef, kwargs_avals))``, i.e. the static arguments plus the tree
          structure and abstract values of the dynamic positional and keyword arguments.

        Raises:
          ValueError: if ``cache_type`` is invalid.
        """
        if self.cache_type == 'jit':
            static_args, dyn_args = [], []
            for i, arg in enumerate(args):
                if i in self.static_argnums:
                    static_args.append(arg)
                else:
                    dyn_args.append(arg)
            # Keep the tree structures (which include keyword names) in the key: flattening
            # to leaves alone would make e.g. ``f(a=x)`` and ``f(b=x)`` share one cache entry.
            dyn_leaves, dyn_treedef = jax.tree.flatten(dyn_args)
            kw_leaves, kw_treedef = jax.tree.flatten(kwargs)
            dyn_avals = tuple(wrapped_abstractify(leaf) for leaf in dyn_leaves)
            kw_avals = tuple(wrapped_abstractify(leaf) for leaf in kw_leaves)
            return tuple(static_args), (dyn_treedef, dyn_avals), (kw_treedef, kw_avals)
        elif self.cache_type is None:
            num_arg = len(args)
            return tuple(args[i] for i in self.static_argnums if i < num_arg)
        else:
            raise ValueError(f"Invalid cache type: {self.cache_type}")

    def compile_and_get_states_by_static_args(self, *args, **kwargs) -> Tuple[State, ...]:
        """
        Trace the function if needed and return the states it reads and writes.

        Args:
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          The states that are read and written by the function.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        if cache_key not in self._state_trace:
            self.make_jaxpr(*args, **kwargs)
        return self.get_states(cache_key)

    def clear_cache(self) -> None:
        """
        Clear the compilation cache.
        """
        self._jaxpr.clear()
        self._out_shapes.clear()
        self._jaxpr_out_tree.clear()
        self._state_trace.clear()

    @staticmethod
    def _init_trace_and_newarg() -> StateTrace:
        # Must be called while ``jax.make_jaxpr()`` is tracing: it takes the innermost
        # trace from JAX's (internal) trace stack and its current jaxpr frame, and lets
        # the ``StateTrace`` add new jaxpr inputs through ``_new_arg``.
        state_trace: StateTrace = StateTrace()
        main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
        frame = main.jaxpr_stack[-1]
        trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
        state_trace.set_new_arg(functools.partial(_new_arg, frame, trace))
        return state_trace

    def _wrapped_fun_to_eval(self, cache_key, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
        """
        Run the function under a ``StateTrace`` and return its output and the state values.

        Args:
          cache_key: The key under which the state trace is registered.
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          A tuple ``(out, state_values)``: the function output and the values of the
          read/written states at the end of the function.

        Raises:
          ValueError: if the function returns a ``State`` object.
        """
        state_trace = self._init_trace_and_newarg()
        self._state_trace[cache_key] = state_trace
        with state_trace:
            try:
                out = self.fun(*args, **kwargs)
                state_values = state_trace.collect_values('read', 'write')
            finally:
                # Restore the original state values even if ``self.fun`` raises; otherwise
                # the states could keep the values produced during tracing.
                state_trace.recovery_original_values()

        # Returning states is not allowed.
        for leaf in jax.tree.leaves(out):
            if isinstance(leaf, State):
                leaf._raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
        return out, state_values

    def make_jaxpr(self, *args, **kwargs):
        """
        Trace the function on example arguments and cache its jaxpr (if not cached yet).

        The ``ClosedJaxpr``, the output shapes ``(out_shapes, state_shapes)``, the output
        tree structure and the state trace are stored under the key from
        :meth:`get_arg_cache_key`.

        Args:
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          ``self``, to allow chaining.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)

        if cache_key not in self._state_trace:
            try:
                jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
                    functools.partial(self._wrapped_fun_to_eval, cache_key),
                    static_argnums=self.static_argnums,
                    axis_env=self.axis_env,
                    return_shape=True,
                    abstracted_axes=self.abstracted_axes
                )(*args, **kwargs)

                self._jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
                self._out_shapes[cache_key] = (out_shapes, state_shapes)
                self._jaxpr[cache_key] = jaxpr
            except Exception:
                # ``_wrapped_fun_to_eval`` registers the state trace before tracing; drop it
                # so a failed trace is not mistaken for a cached one.
                self._state_trace.pop(cache_key, None)
                raise

        return self

    def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
        """
        Evaluate the cached jaxpr with explicit state values.

        Args:
          state_vals: The state values, in the order of :meth:`get_states`.
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          A pair ``(new_state_vals, out)``.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        states = self.get_states(cache_key)
        assert len(state_vals) == len(states), 'State length mismatch.'

        # Static arguments are baked into the jaxpr; only dynamic values become inputs.
        dyn_args = tuple(arg for i, arg in enumerate(args) if i not in self.static_argnums)
        flat_args = jax.tree.flatten((dyn_args, kwargs, state_vals))[0]

        closed_jaxpr = self.get_jaxpr(cache_key)
        out_treedef = self.get_out_treedef(cache_key)
        jaxpr_outs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *flat_args)

        out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
        assert len(new_state_vals) == len(state_vals), 'State length mismatch.'
        return new_state_vals, out

    def jaxpr_call_auto(self, *args, **kwargs) -> Any:
        """
        Evaluate the cached jaxpr, reading and writing back the state values automatically.

        Args:
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          The output of the function.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        states = self.get_states(cache_key)
        state_vals, out = self.jaxpr_call([st.value for st in states], *args, **kwargs)
        _assign_state_values(states, state_vals)
        return out
465
+
466
+
467
@set_module_as("brainstate.transform")
def make_jaxpr(
    fun: Callable,
    static_argnums: Union[int, Iterable[int]] = (),
    axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
    return_shape: bool = False,
    abstracted_axes: Optional[Any] = None,
    state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
) -> Callable[..., (Tuple[jax.core.ClosedJaxpr, Tuple[State, ...]] |
                    Tuple[jax.core.ClosedJaxpr, Tuple[State, ...], PyTree])]:
    """Creates a function that produces its jaxpr, and the states it uses, given example args.

    Args:
      fun: The function whose ``jaxpr`` is to be computed. Its positional
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof.
      static_argnums: See the :py:func:`jax.jit` docstring.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
        functions that involve parallel communication collectives, and it
        specifies the axis name/size environment that would be set up by
        applications of :py:func:`jax.pmap`.
      return_shape: Optional boolean, defaults to ``False``. If ``True``, the
        wrapped function additionally returns, as a third element, a pytree with the
        same structure as the output of ``fun`` whose leaves are objects with
        ``shape`` and ``dtype`` attributes describing the output leaves.
      abstracted_axes: Optional; passed through unchanged to :func:`jax.make_jaxpr`.
        See its documentation.
      state_returns: Optional, a string or a tuple of strings, default
        ``('read', 'write')``. Forwarded to :class:`StatefulFunction`.

    Returns:
      A wrapped version of ``fun`` that, when applied to example arguments, returns
      a pair ``(jaxpr, states)``: the ``ClosedJaxpr`` representation of ``fun`` on
      those arguments and the tuple of ``State`` objects read or written by ``fun``.
      If ``return_shape`` is ``True``, it returns a triple
      ``(jaxpr, states, out_shapes)`` where ``out_shapes`` describes the structure,
      shapes and dtypes of the output of ``fun`` (not including the state values).

    A ``jaxpr`` is JAX's intermediate representation for program traces. The
    ``jaxpr`` language is based on the simply-typed first-order lambda calculus
    with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
    ``jaxpr``, which we can inspect to understand what JAX is doing internally.
    The ``jaxpr`` returned is a trace of ``fun`` abstracted to
    :py:class:`ShapedArray` level. Other levels of abstraction exist internally.

    We do not describe the semantics of the ``jaxpr`` language in detail here, but
    instead give a few examples.

    >>> import jax
    >>> import brainstate as bst
    >>>
    >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
    >>> print(f(3.0))
    -0.83602
    >>> jaxpr, states = bst.transform.make_jaxpr(f)(3.0)
    >>> jaxpr
    { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
    >>> jaxpr, states = bst.transform.make_jaxpr(jax.grad(f))(3.0)
    >>> jaxpr
    { lambda ; a:f32[]. let
        b:f32[] = cos a
        c:f32[] = sin a
        _:f32[] = sin b
        d:f32[] = cos b
        e:f32[] = mul 1.0 d
        f:f32[] = neg e
        g:f32[] = mul f c
      in (g,) }
    """

    stateful_fun = StatefulFunction(fun, static_argnums, axis_env, abstracted_axes, state_returns)

    @wraps(fun)
    def make_jaxpr_f(*args, **kwargs):
        # Trace (or reuse the cached trace) and look the results up by the same key.
        stateful_fun.make_jaxpr(*args, **kwargs)
        cache_key = stateful_fun.get_arg_cache_key(*args, **kwargs)
        jaxpr = stateful_fun.get_jaxpr(cache_key)
        states = stateful_fun.get_states(cache_key)
        if return_shape:
            # ``get_out_shapes`` returns ``(out_shapes, state_shapes)``; expose only the former.
            return jaxpr, states, stateful_fun.get_out_shapes(cache_key)[0]
        return jaxpr, states

    # wrapped jaxpr builder function
    make_jaxpr_f.__module__ = "brainstate.transform"
    if hasattr(fun, "__qualname__"):
        make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
    if hasattr(fun, "__name__"):
        make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
    return make_jaxpr_f