brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +588 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
  127. brainstate-0.1.10.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
@@ -1,888 +1,888 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- This module implements how to create a JAX Jaxpr from a given function by considering the states that are read and
18
- written by the function. These state transformations are foundational for the BrainCore library. These utilities
19
- include two basic functions: `StatefulFunction` and `make_jaxpr`.
20
-
21
-
22
- ``StatefulFunction``
23
- --------------------
24
-
25
- The module provides a class called ``StatefulFunction`` that wraps a function and provides methods to get the
26
- JAX Jaxpr, the output shapes, the states that are read and written by the function, and the output of the function.
27
- The class provides the following methods:
28
-
29
- - `make_jaxpr`: creates the JAX Jaxpr of the function.
30
- - `jaxpr_call`: calls the function at the JAX Jaxpr level.
31
- - `jaxpr_call_without_states`: calls the function at the JAX Jaxpr level without considering the states.
32
- - `get_states`: returns the states that are read and written by the function.
33
- - `get_read_states`: returns the states that are read by the function.
34
- - `get_write_states`: returns the states that are written by the function.
35
- - `get_arg_cache_key`: returns the cache key computed from the call arguments.
36
- - `compile_function_and_get_states`: compiles the function and returns the states that are read and
37
- written by the function.
38
- - `get_jaxpr`: returns the JAX Jaxpr of the function.
39
- - `get_out_shapes`: returns the output shapes of the function.
40
- - `get_out_treedef`: returns the output tree of the function.
41
-
42
- ``make_jaxpr``
43
- --------------
44
-
45
- The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
46
- arguments. The function returns a wrapped version of the function that when applied to example arguments returns a
47
- `ClosedJaxpr` representation of the function on those arguments. If the argument `return_shape` is `True`, then the
48
- returned function instead returns a pair where the first element is the `ClosedJaxpr` representation of the function
49
- and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of the
50
- function.
51
-
52
- """
53
-
54
- import functools
55
- import inspect
56
- import operator
57
- from collections.abc import Hashable, Iterable, Sequence
58
- from contextlib import ExitStack
59
- from typing import Any, Callable, Tuple, Union, Dict, Optional
60
-
61
- import jax
62
- from jax._src import source_info_util
63
- from jax._src.linear_util import annotate
64
- from jax._src.traceback_util import api_boundary
65
- from jax.api_util import shaped_abstractify
66
- from jax.extend.linear_util import transformation_with_aux
67
- from jax.interpreters import partial_eval as pe
68
-
69
- from brainstate._compatible_import import (
70
- ClosedJaxpr,
71
- extend_axis_env_nd,
72
- safe_map,
73
- safe_zip,
74
- unzip2,
75
- wraps,
76
- wrap_init,
77
- )
78
- from brainstate._state import State, StateTraceStack
79
- from brainstate._utils import set_module_as
80
- from brainstate.typing import PyTree
81
- from brainstate.util import PrettyObject
82
-
83
- AxisName = Hashable
84
-
85
- __all__ = [
86
- "StatefulFunction",
87
- "make_jaxpr",
88
- ]
89
-
90
-
91
- def _ensure_str(x: str) -> str:
92
- if not isinstance(x, str):
93
- raise TypeError(f"argument is not a string: {x}")
94
- return x
95
-
96
-
97
- def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
98
- """Convert x to a tuple of indices."""
99
- x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
100
- try:
101
- return (operator.index(x),)
102
- except TypeError:
103
- return tuple(safe_map(operator.index, x))
104
-
105
-
106
- def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
107
- """Convert x to a tuple of strings."""
108
- if isinstance(x, str):
109
- return (x,)
110
- else:
111
- return tuple(safe_map(_ensure_str, x))
112
-
113
-
114
- def _jax_v04_new_arg_fn(frame, trace, aval):
115
- """
116
- Transform a new argument to a tracer.
117
-
118
- Modified from jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()
119
-
120
- Args:
121
- frame: The frame.
122
- trace: The trace.
123
- aval: The abstract value.
124
-
125
- Returns:
126
- The tracer.
127
- """
128
- tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
129
- frame.tracers.append(tracer)
130
- frame.tracer_to_var[id(tracer)] = var = frame.newvar(aval)
131
- frame.invars.append(var)
132
- return tracer
133
-
134
-
135
- def _jax_v04_new_jax_trace():
136
- main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
137
- frame = main.jaxpr_stack[-1]
138
- trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
139
- return frame, trace
140
-
141
-
142
- def _jax_v04_new_arg():
143
- # Should be within the calling of ``jax.make_jaxpr()``
144
- frame, trace = _jax_v04_new_jax_trace()
145
- # Set the function to transform the new argument to a tracer
146
- fn = functools.partial(_jax_v04_new_arg_fn, frame, trace)
147
- return fn
148
-
149
-
150
- def _jax_new_version_new_arg():
151
- trace = jax.core.trace_ctx.trace
152
-
153
- def wrapper(x):
154
- if jax.__version_info__ < (0, 6, 1):
155
- return trace.new_arg(shaped_abstractify(x))
156
- else:
157
- return trace.new_arg(shaped_abstractify(x), source_info=source_info_util.current())
158
-
159
- return wrapper
160
-
161
-
162
- def _init_state_trace_stack(name) -> StateTraceStack:
163
- state_trace: StateTraceStack = StateTraceStack(name=name)
164
-
165
- if jax.__version_info__ < (0, 4, 36):
166
- state_trace.set_new_arg(_jax_v04_new_arg())
167
- else:
168
- state_trace.set_new_arg(_jax_new_version_new_arg())
169
- return state_trace
170
-
171
-
172
- default_cache_key = ((), ())
173
-
174
-
175
- class StatefulFunction(PrettyObject):
176
- """
177
- A wrapper class for a function that collects the states that are read and written by the function. The states are
178
- collected by the function and returned as a StateDictManager instance. The StateDictManager instance can be used to
179
- manage the states in the JAX program. The class provides a function called `states` that returns the states
180
- that are read and written by the function. The class provides a function called `to_state_manager` that returns
181
- a StateDictManager instance that contains the states that are read and written by the function. The class provides
182
- a function called `__call__` that wraps the function and returns the states that are read and written by the
183
- function and the output of the function.
184
-
185
- Args:
186
- fun: The function whose ``jaxpr`` is to be computed. Its positional
187
- arguments and return value should be arrays, scalars, or standard Python
188
- containers (tuple/list/dict) thereof.
189
- static_argnums: See the :py:func:`jax.jit` docstring.
190
- static_argnames: See the :py:func:`jax.jit` docstring.
191
- axis_env: Optional, a sequence of pairs where the first element is an axis
192
- name and the second element is a positive integer representing the size of
193
- the mapped axis with that name. This parameter is useful when lowering
194
- functions that involve parallel communication collectives, and it
195
- specifies the axis name/size environment that would be set up by
196
- applications of :py:func:`jax.pmap`.
197
- abstracted_axes: Optional, a pytree with the same structure as the input
198
- arguments to ``fun``. The leaves of the pytree can be either None or a
199
- dict with axis names as keys and integers as values. If the leaf is None,
200
- then the corresponding axis is not abstracted. If the leaf is a dict, then
201
- the corresponding axis is abstracted, and the dict specifies the axis name
202
- and size. The abstracted axes are used to infer the input type of the
203
- function. If None, then all axes are abstracted.
204
- state_returns: Optional, a string or a tuple of strings. The default is
205
- ``('read', 'write')``. The strings specify the categories of states to be
206
- returned by the wrapped function. The categories are ``'read'`` and
207
- ``'write'``. If the category is ``'read'``, then the wrapped function
208
- returns the states that are read by the function. If the category is
209
- ``'write'``, then the wrapped function returns the states that are written
210
- by the function. If the category is ``'read'`` and ``'write'``, then the
211
- wrapped function returns both the read and write states.
212
-
213
- """
214
- __module__ = "brainstate.compile"
215
-
216
- def __init__(
217
- self,
218
- fun: Callable,
219
- static_argnums: Union[int, Iterable[int]] = (),
220
- static_argnames: Union[str, Iterable[str]] = (),
221
- axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
222
- abstracted_axes: Optional[Any] = None,
223
- state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
224
- cache_type: Optional[str] = None,
225
- name: Optional[str] = None,
226
- ):
227
- # explicit parameters
228
- self.fun = fun
229
- self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
230
- self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
231
- self.axis_env = axis_env
232
- self.abstracted_axes = abstracted_axes
233
- self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
234
- assert cache_type in [None, 'jit'], f"Invalid cache type: {cache_type}"
235
- self.name = name
236
-
237
- # implicit parameters
238
- self.cache_type = cache_type
239
- self._cached_jaxpr: Dict[Any, ClosedJaxpr] = dict()
240
- self._cached_out_shapes: Dict[Any, PyTree] = dict()
241
- self._cached_jaxpr_out_tree: Dict[Any, PyTree] = dict()
242
- self._cached_state_trace: Dict[Any, StateTraceStack] = dict()
243
-
244
- def __pretty_repr_item__(self, k, v):
245
- if k.startswith('_'):
246
- return None
247
- return k, v
248
-
249
- def get_jaxpr(self, cache_key: Hashable = None) -> ClosedJaxpr:
250
- """
251
- Read the JAX Jaxpr representation of the function.
252
-
253
- Args:
254
- cache_key: The hashable key.
255
-
256
- Returns:
257
- The JAX Jaxpr representation of the function.
258
- """
259
- if cache_key is None:
260
- cache_key = default_cache_key
261
- if cache_key not in self._cached_jaxpr:
262
- raise ValueError(f"the function is not called with the static arguments: {cache_key}")
263
- return self._cached_jaxpr[cache_key]
264
-
265
- def get_out_shapes(self, cache_key: Hashable = None) -> PyTree:
266
- """
267
- Read the output shapes of the function.
268
-
269
- Args:
270
- cache_key: The hashable key.
271
-
272
- Returns:
273
- The output shapes of the function.
274
- """
275
- if cache_key is None:
276
- cache_key = default_cache_key
277
- if cache_key not in self._cached_out_shapes:
278
- raise ValueError(f"the function is not called with the static arguments: {cache_key}")
279
- return self._cached_out_shapes[cache_key]
280
-
281
- def get_out_treedef(self, cache_key: Hashable = None) -> PyTree:
282
- """
283
- Read the output tree of the function.
284
-
285
- Args:
286
- cache_key: The hashable key.
287
-
288
- Returns:
289
- The output tree of the function.
290
- """
291
- if cache_key is None:
292
- cache_key = default_cache_key
293
- if cache_key not in self._cached_jaxpr_out_tree:
294
- raise ValueError(f"the function is not called with the static arguments: {cache_key}")
295
- return self._cached_jaxpr_out_tree[cache_key]
296
-
297
- def get_state_trace(self, cache_key: Hashable = None) -> StateTraceStack:
298
- """
299
- Read the state trace of the function.
300
-
301
- Args:
302
- cache_key: The hashable key.
303
-
304
- Returns:
305
- The state trace of the function.
306
- """
307
- if cache_key is None:
308
- cache_key = default_cache_key
309
- if cache_key not in self._cached_state_trace:
310
- raise ValueError(f"the function is not called with the static arguments: {cache_key}")
311
- return self._cached_state_trace[cache_key]
312
-
313
- def get_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
314
- """
315
- Read the states that are read and written by the function.
316
-
317
- Args:
318
- cache_key: The hashable key.
319
-
320
- Returns:
321
- The states that are read and written by the function.
322
- """
323
- if cache_key is None:
324
- cache_key = default_cache_key
325
- return tuple(self.get_state_trace(cache_key).states)
326
-
327
- def get_read_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
328
- """
329
- Read the states that are read by the function.
330
-
331
- Args:
332
- cache_key: The hashable key.
333
-
334
- Returns:
335
- The states that are read by the function.
336
- """
337
- if cache_key is None:
338
- cache_key = default_cache_key
339
- return self.get_state_trace(cache_key).get_read_states()
340
-
341
- def get_write_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
342
- """
343
- Read the states that are written by the function.
344
-
345
- Args:
346
- cache_key: The hashable key.
347
-
348
- Returns:
349
- The states that are written by the function.
350
- """
351
- if cache_key is None:
352
- cache_key = default_cache_key
353
- return self.get_state_trace(cache_key).get_write_states()
354
-
355
- def _check_input_ouput(self, x):
356
- if isinstance(x, State):
357
- x.raise_error_with_source_info(
358
- ValueError(
359
- 'Inputs/outputs for brainstate transformations cannot be an instance of State. '
360
- f'But we got {x}'
361
- )
362
- )
363
-
364
- def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
365
- """
366
- Get the static arguments from the arguments.
367
-
368
- Args:
369
- *args: The arguments to the function.
370
- **kwargs: The keyword arguments to the function.
371
-
372
- Returns:
373
- The static arguments and keyword arguments as a tuple.
374
- """
375
- if self.cache_type == 'jit':
376
- static_args, dyn_args = [], []
377
- for i, arg in enumerate(args):
378
- if i in self.static_argnums:
379
- static_args.append(arg)
380
- else:
381
- dyn_args.append(arg)
382
- dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
383
- static_kwargs, dyn_kwargs = [], []
384
- for k, v in kwargs.items():
385
- if k in self.static_argnames:
386
- static_kwargs.append((k, v))
387
- else:
388
- dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
389
-
390
- static_args = make_hashable(tuple(static_args))
391
- dyn_args = make_hashable(tuple(dyn_args))
392
- static_kwargs = make_hashable(static_kwargs)
393
- dyn_kwargs = make_hashable(dyn_kwargs)
394
-
395
- cache_key = (static_args, dyn_args, static_kwargs, dyn_kwargs)
396
- elif self.cache_type is None:
397
- num_arg = len(args)
398
- static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
399
- static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
400
-
401
- # Make everything hashable
402
- static_args = make_hashable(static_args)
403
- static_kwargs = make_hashable(static_kwargs)
404
-
405
- cache_key = (static_args, static_kwargs)
406
- else:
407
- raise ValueError(f"Invalid cache type: {self.cache_type}")
408
-
409
- return cache_key
410
-
411
- def compile_function_and_get_states(self, *args, **kwargs) -> Tuple[State, ...]:
412
- """
413
- Compile the function, and get the states that are read and written by this function.
414
-
415
- Args:
416
- *args: The arguments to the function.
417
- **kwargs: The keyword arguments to the function.
418
-
419
- Returns:
420
- The states that are read and written by the function.
421
- """
422
- cache_key = self.get_arg_cache_key(*args, **kwargs)
423
- if cache_key not in self._cached_state_trace:
424
- self.make_jaxpr(*args, **kwargs)
425
- return self.get_states(cache_key)
426
-
427
- def compile_function_and_get_state_trace(
428
- self, *args, return_only_write: bool = False, **kwargs
429
- ) -> StateTraceStack:
430
- """
431
- Compile the function, and get the state trace that records the states read and written by this function.
432
-
433
- Args:
434
- *args: The arguments to the function.
435
- **kwargs: The keyword arguments to the function.
436
- return_only_write: If True, only return the states that are written by the function.
437
-
438
- Returns:
439
- The state trace stack.
440
- """
441
- cache_key = self.get_arg_cache_key(*args, **kwargs)
442
- if cache_key not in self._cached_state_trace:
443
- self.make_jaxpr(*args, **kwargs, return_only_write=return_only_write)
444
- return self.get_state_trace(cache_key)
445
-
446
- def clear_cache(self) -> None:
447
- """
448
- Clear the compilation cache.
449
- """
450
- self._cached_jaxpr.clear()
451
- self._cached_out_shapes.clear()
452
- self._cached_jaxpr_out_tree.clear()
453
- self._cached_state_trace.clear()
454
-
455
- def _wrapped_fun_to_eval(
456
- self, cache_key, static_kwargs: dict, *args, return_only_write: bool = False, **dyn_kwargs,
457
- ) -> Tuple[Any, Tuple[State, ...]]:
458
- """
459
- Wrap the function and return the states that are read and written by the function and the output of the function.
460
-
461
- Args:
462
- *args: The arguments to the function.
463
- **kwargs: The keyword arguments to the function.
464
-
465
- Returns:
466
- A tuple of the output of the function and the state values collected by the state trace (the write-state values when ``return_only_write`` is True).
467
- """
468
- # state trace
469
- state_trace = _init_state_trace_stack(self.name)
470
- self._cached_state_trace[cache_key] = state_trace
471
- with state_trace:
472
- out = self.fun(*args, **dyn_kwargs, **static_kwargs)
473
- state_values = (
474
- state_trace.get_write_state_values(True)
475
- if return_only_write else
476
- state_trace.get_state_values()
477
- )
478
- state_trace.recovery_original_values()
479
-
480
- # State instance as functional returns is not allowed.
481
- # Checking whether the states are returned.
482
- jax.tree.map(self._check_input_ouput, out, is_leaf=lambda x: isinstance(x, State))
483
- return out, state_values
484
-
485
- def make_jaxpr(self, *args, return_only_write: bool = False, **kwargs):
486
- """Creates a function that produces its jaxpr given example args.
487
-
488
- A ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
489
- argument ``return_shape`` is ``True``, then the returned function instead
490
- returns a pair where the first element is the ``ClosedJaxpr``
491
- representation of ``fun`` and the second element is a pytree representing
492
- the structure, shape, dtypes, and named shapes of the output of ``fun``.
493
-
494
- Args:
495
- *args: The arguments to the function.
496
- **kwargs: The keyword arguments to the function.
497
- return_only_write: If True, only return the states that are written by the function.
498
- """
499
-
500
- # static args
501
- cache_key = self.get_arg_cache_key(*args, **kwargs)
502
-
503
- # check input types
504
- jax.tree.map(self._check_input_ouput, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
505
-
506
- if cache_key not in self._cached_state_trace:
507
- try:
508
- # jaxpr
509
- static_kwargs, dyn_kwargs = {}, {}
510
- for k, v in kwargs.items():
511
- if k in self.static_argnames:
512
- static_kwargs[k] = v
513
- else:
514
- dyn_kwargs[k] = v
515
- jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
516
- functools.partial(
517
- self._wrapped_fun_to_eval,
518
- cache_key,
519
- static_kwargs,
520
- return_only_write=return_only_write
521
- ),
522
- static_argnums=self.static_argnums,
523
- axis_env=self.axis_env,
524
- return_shape=True,
525
- abstracted_axes=self.abstracted_axes
526
- )(*args, **dyn_kwargs)
527
- # returns
528
- self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
529
- self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
530
- self._cached_jaxpr[cache_key] = jaxpr
531
-
532
- except Exception as e:
533
- try:
534
- self._cached_state_trace.pop(cache_key)
535
- except KeyError:
536
- pass
537
- raise e
538
-
539
- return self
540
-
541
- def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
542
- """
543
- Call the function at the JAX Jaxpr level.
544
-
545
- Args:
546
- state_vals: The state values.
547
- *args: The arguments to the function.
548
- **kwargs: The keyword arguments to the function.
549
-
550
- Returns:
551
- State values and the function output.
552
- """
553
- # state checking
554
- cache_key = self.get_arg_cache_key(*args, **kwargs)
555
- states: Sequence[State] = self.get_states(cache_key)
556
- assert len(state_vals) == len(states), 'State length mismatch.'
557
-
558
- # parameters
559
- kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames} # remove static kwargs
560
- args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
561
- args = jax.tree.flatten((args, kwargs, state_vals))[0]
562
-
563
- # calling the function,
564
- # note that this function always returns state values
565
- # that are both written and read by the function
566
- closed_jaxpr = self.get_jaxpr(cache_key)
567
- out_treedef = self.get_out_treedef(cache_key)
568
- jaxpr_outs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
569
-
570
- # output processing
571
- out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
572
- assert len(new_state_vals) == len(state_vals), 'State length mismatch.'
573
- return new_state_vals, out
574
-
575
- def jaxpr_call_auto(self, *args, **kwargs) -> Any:
576
- """
577
- Call the function at the JAX Jaxpr level with automatic state management.
578
-
579
- Args:
580
- *args: The arguments to the function.
581
- **kwargs: The keyword arguments to the function.
582
-
583
- Returns:
584
- The output of the function.
585
- """
586
- state_trace = self.get_state_trace(self.get_arg_cache_key(*args, **kwargs))
587
- state_vals, out = self.jaxpr_call([st.value for st in state_trace.states], *args, **kwargs)
588
- state_trace.assign_state_vals(state_vals)
589
- return out
590
-
591
-
592
- @set_module_as("brainstate.compile")
593
- def make_jaxpr(
594
- fun: Callable,
595
- static_argnums: Union[int, Iterable[int]] = (),
596
- static_argnames: Union[str, Iterable[str]] = (),
597
- axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
598
- return_shape: bool = False,
599
- abstracted_axes: Optional[Any] = None,
600
- state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
601
- ) -> Callable[
602
- ...,
603
- (Tuple[ClosedJaxpr, Tuple[State, ...]] |
604
- Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
605
- ]:
606
- """
607
- Creates a function that produces its jaxpr given example args.
608
-
609
- Args:
610
- fun: The function whose ``jaxpr`` is to be computed. Its positional
611
- arguments and return value should be arrays, scalars, or standard Python
612
- containers (tuple/list/dict) thereof.
613
- static_argnums: See the :py:func:`jax.jit` docstring.
614
- static_argnames: See the :py:func:`jax.jit` docstring.
615
- axis_env: Optional, a sequence of pairs where the first element is an axis
616
- name and the second element is a positive integer representing the size of
617
- the mapped axis with that name. This parameter is useful when lowering
618
- functions that involve parallel communication collectives, and it
619
- specifies the axis name/size environment that would be set up by
620
- applications of :py:func:`jax.pmap`.
621
- return_shape: Optional boolean, defaults to ``False``. If ``True``, the
622
- wrapped function returns a pair where the first element is the XLA
623
- computation and the second element is a pytree with the same structure as
624
- the output of ``fun`` and where the leaves are objects with ``shape``,
625
- ``dtype``, and ``named_shape`` attributes representing the corresponding
626
- types of the output leaves.
627
- abstracted_axes: Optional, a pytree with the same structure as the input
628
- arguments to ``fun``. The leaves of the pytree can be either None or a
629
- dict with axis names as keys and integers as values. If the leaf is None,
630
- then the corresponding axis is not abstracted. If the leaf is a dict, then
631
- the corresponding axis is abstracted, and the dict specifies the axis name
632
- and size. The abstracted axes are used to infer the input type of the
633
- function. If None, then all axes are abstracted.
634
- state_returns: Optional, a string or a tuple of strings. The default is
635
- ``('read', 'write')``. The strings specify the categories of states to be
636
- returned by the wrapped function. The categories are ``'read'`` and
637
- ``'write'``. If the category is ``'read'``, then the wrapped function
638
- returns the states that are read by the function. If the category is
639
- ``'write'``, then the wrapped function returns the states that are written
640
- by the function. If the category is ``'read'`` and ``'write'``, then the
641
- wrapped function returns both the read and write states.
642
-
643
-
644
- Returns:
645
- A wrapped version of ``fun`` that when applied to example arguments returns
646
- a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
647
- argument ``return_shape`` is ``True``, then the returned function instead
648
- returns a pair where the first element is the ``ClosedJaxpr``
649
- representation of ``fun`` and the second element is a pytree representing
650
- the structure, shape, dtypes, and named shapes of the output of ``fun``.
651
-
652
- A ``jaxpr`` is JAX's intermediate representation for program traces. The
653
- ``jaxpr`` language is based on the simply-typed first-order lambda calculus
654
- with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
655
- ``jaxpr``, which we can inspect to understand what JAX is doing internally.
656
- The ``jaxpr`` returned is a trace of ``fun`` abstracted to
657
- :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
658
-
659
- We do not describe the semantics of the ``jaxpr`` language in detail here, but
660
- instead give a few examples.
661
-
662
- >>> import jax
663
- >>> import brainstate as brainstate
664
- >>>
665
- >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
666
- >>> print(f(3.0))
667
- -0.83602
668
- >>> jaxpr, states = brainstate.compile.make_jaxpr(f)(3.0)
669
- >>> jaxpr
670
- { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
671
- >>> jaxpr, states = brainstate.compile.make_jaxpr(jax.grad(f))(3.0)
672
- >>> jaxpr
673
- { lambda ; a:f32[]. let
674
- b:f32[] = cos a
675
- c:f32[] = sin a
676
- _:f32[] = sin b
677
- d:f32[] = cos b
678
- e:f32[] = mul 1.0 d
679
- f:f32[] = neg e
680
- g:f32[] = mul f c
681
- in (g,) }
682
- """
683
-
684
- stateful_fun = StatefulFunction(
685
- fun,
686
- static_argnums=static_argnums,
687
- static_argnames=static_argnames,
688
- axis_env=axis_env,
689
- abstracted_axes=abstracted_axes,
690
- state_returns=state_returns,
691
- name='make_jaxpr'
692
- )
693
-
694
- @wraps(fun)
695
- def make_jaxpr_f(*args, **kwargs):
696
- stateful_fun.make_jaxpr(*args, **kwargs)
697
- cache_key = stateful_fun.get_arg_cache_key(*args, **kwargs)
698
- if return_shape:
699
- return (stateful_fun.get_jaxpr(cache_key),
700
- stateful_fun.get_states(cache_key),
701
- stateful_fun.get_out_shapes(cache_key)[0])
702
- else:
703
- return (stateful_fun.get_jaxpr(cache_key),
704
- stateful_fun.get_states(cache_key))
705
-
706
- # wrapped jaxpr builder function
707
- make_jaxpr_f.__module__ = "brainstate.compile"
708
- if hasattr(fun, "__qualname__"):
709
- make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
710
- if hasattr(fun, "__name__"):
711
- make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
712
- return make_jaxpr_f
713
-
714
-
715
- def _check_callable(fun):
716
- # In Python 3.10+, the only thing stopping us from supporting staticmethods
717
- # is that we can't take weak references to them, which the C++ JIT requires.
718
- if isinstance(fun, staticmethod):
719
- raise TypeError(f"staticmethod arguments are not supported, got {fun}")
720
- if not callable(fun):
721
- raise TypeError(f"Expected a callable value, got {fun}")
722
- if inspect.isgeneratorfunction(fun):
723
- raise TypeError(f"Expected a function, got a generator function: {fun}")
724
-
725
-
726
- def _broadcast_prefix(
727
- prefix_tree: Any,
728
- full_tree: Any,
729
- is_leaf: Callable[[Any], bool] | None = None
730
- ) -> list[Any]:
731
- # If prefix_tree is not a tree prefix of full_tree, this code can raise a
732
- # ValueError; use prefix_errors to find disagreements and raise more precise
733
- # error messages.
734
- result = []
735
- num_leaves = lambda t: jax.tree.structure(t).num_leaves
736
- add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
737
- jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
738
- return result
739
-
740
-
741
def _flat_axes_specs(
    abstracted_axes, *args, **kwargs
) -> list[pe.AbstractedAxesSpec]:
    """Broadcast ``abstracted_axes`` over the positional arguments, one spec per argument leaf.

    Raises:
      NotImplementedError: if any keyword arguments are supplied.
    """
    if kwargs:
        raise NotImplementedError

    def ax_leaf(l):
        # A spec leaf is either a dict whose values are all leaves, or a tuple
        # whose entries are all leaves when ``None`` is treated as a leaf.
        if isinstance(l, dict) and jax.tree_util.all_leaves(l.values()):
            return True
        return isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None)

    return _broadcast_prefix(abstracted_axes, args, ax_leaf)
752
-
753
-
754
@transformation_with_aux
def _flatten_fun(in_tree, *args_flat):
    """Adapt a pytree-taking function to flat inputs and outputs.

    The flat arguments are rebuilt into ``(args, kwargs)`` using ``in_tree`` and
    yielded to the wrapped function; its answer is then yielded back as the
    pair produced by ``jax.tree.flatten`` (flat leaves and tree structure).
    """
    positional, keyword = jax.tree.unflatten(in_tree, args_flat)
    answer = yield positional, keyword
    yield jax.tree.flatten(answer)
759
-
760
-
761
def _make_jaxpr(
    fun: Callable,
    static_argnums: int | Iterable[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: bool = False,
    abstracted_axes: Any | None = None,
) -> Callable[..., (ClosedJaxpr | tuple[ClosedJaxpr, Any])]:
    """Build a function that traces ``fun`` to a ``ClosedJaxpr`` for given example arguments.

    Args:
      fun: The function to trace. It is validated with ``_check_callable``.
      static_argnums: Index or indices of positional arguments that are not
        traced; see the :py:func:`jax.jit` docstring.
      axis_env: Optional sequence of ``(axis name, size)`` pairs. When given, the
        trace runs inside ``extend_axis_env_nd(axis_env)``.
      return_shape: If ``True``, the returned function yields a pair
        ``(closed_jaxpr, out_shapes)`` where ``out_shapes`` mirrors the output
        pytree of ``fun`` with ``jax.ShapeDtypeStruct`` leaves.
      abstracted_axes: Optional axes specification. When ``None`` every flat
        argument is abstracted with ``shaped_abstractify`` and kept as an input;
        otherwise the input type is inferred with
        ``pe.infer_lambda_input_type``.

    Returns:
      A wrapped version of ``fun`` that returns the ``ClosedJaxpr`` (or the pair
      described under ``return_shape``) instead of running the computation.
    """
    _check_callable(fun)
    static_argnums = _ensure_index_tuple(static_argnums)

    def _abstractify(args, kwargs):
        leaves, in_tree = jax.tree.flatten((args, kwargs))
        if abstracted_axes is not None:
            specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
            inferred_type = pe.infer_lambda_input_type(specs, leaves)
            avals, keep = unzip2(inferred_type)
            return avals, in_tree, keep
        return map(shaped_abstractify, leaves), in_tree, [True] * len(leaves)

    @wraps(fun)
    @api_boundary
    def make_jaxpr_f(*args, **kwargs):
        f = wrap_init(fun, (), {}, 'brainstate.compile.make_jaxpr')
        if static_argnums:
            # Only the non-static positional arguments are traced.
            dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
            f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
        in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
        in_type = tuple(safe_zip(in_avals, keep_inputs))
        f, out_tree = _flatten_fun(f, in_tree)
        f = annotate(f, in_type)

        # JAX releases before 0.5.0 take an explicit ``debug_info`` argument.
        trace_kwargs = {}
        if jax.__version_info__ < (0, 5, 0):
            trace_kwargs['debug_info'] = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')

        with ExitStack() as stack:
            if axis_env is not None:
                stack.enter_context(extend_axis_env_nd(axis_env))
            jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, **trace_kwargs)

        closed_jaxpr = ClosedJaxpr(jaxpr, consts)
        if not return_shape:
            return closed_jaxpr
        out_avals, _ = unzip2(out_type)
        out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
        return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)

    # Present the wrapper as ``make_jaxpr(<fun>)`` under the public module path.
    make_jaxpr_f.__module__ = "brainstate.compile"
    if hasattr(fun, "__qualname__"):
        make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
    if hasattr(fun, "__name__"):
        make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
    return make_jaxpr_f
870
-
871
-
872
def make_hashable(obj):
    """Convert a nested container into a hashable representation.

    Lists and tuples become tuples, dicts become tuples of ``(key, value)``
    pairs ordered by key, and sets become frozensets; the conversion is applied
    recursively to the contained values. Any other object is returned unchanged
    and is assumed to be hashable already.

    Args:
      obj: The object to convert.

    Returns:
      A hashable object built from ``obj``.
    """
    if isinstance(obj, (list, tuple)):
        return tuple(make_hashable(item) for item in obj)
    if isinstance(obj, dict):
        pairs = [(k, make_hashable(v)) for k, v in obj.items()]
        by_key = lambda pair: pair[0]
        try:
            # Sort on the key alone: dict keys are unique, so the values never
            # need to be (and may not be able to be) compared.
            pairs.sort(key=by_key)
        except TypeError:
            # Keys of mixed, mutually unorderable types (e.g. int and str):
            # fall back to a deterministic order that does not compare the keys.
            pairs.sort(key=lambda pair: (type(pair[0]).__qualname__, repr(pair[0])))
        return tuple(pairs)
    if isinstance(obj, set):
        return frozenset(make_hashable(item) for item in obj)
    return obj
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ This module implements how to create a JAX Jaxpr from a given function by considering the states that are read and
18
+ written by the function. These state transformations are foundational for the BrainCore library. These utilities
19
+ include two basic functions: `StatefulFunction` and `make_jaxpr`.
20
+
21
+
22
+ ``StatefulFunction``
23
+ --------------------
24
+
25
+ The module provides a class called ``StatefulFunction`` that wraps a function and provides methods to get the
26
+ JAX Jaxpr, the output shapes, the states that are read and written by the function, and the output of the function.
27
+ The class provides the following methods:
28
+
29
+ - `make_jaxpr`: creates the JAX Jaxpr of the function.
30
+ - `jaxpr_call`: calls the function at the JAX Jaxpr level.
31
+ - `jaxpr_call_auto`: calls the function at the JAX Jaxpr level with automatic state management.
32
+ - `get_states`: returns the states that are read and written by the function.
33
+ - `get_read_states`: returns the states that are read by the function.
34
+ - `get_write_states`: returns the states that are written by the function.
35
+ - `get_arg_cache_key`: returns the cache key derived from the call arguments.
36
+ - `compile_function_and_get_states`: compiles the function and returns the states that are read and
37
+ written by the function.
38
+ - `get_jaxpr`: returns the JAX Jaxpr of the function.
39
+ - `get_out_shapes`: returns the output shapes of the function.
40
+ - `get_out_treedef`: returns the output tree of the function.
41
+
42
+ ``make_jaxpr``
43
+ --------------
44
+
45
+ The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
46
+ arguments. The function returns a wrapped version of the function that when applied to example arguments returns a
47
+ `ClosedJaxpr` representation of the function on those arguments. If the argument `return_shape` is `True`, then the
48
+ returned function instead returns a pair where the first element is the `ClosedJaxpr` representation of the function
49
+ and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of the
50
+ function.
51
+
52
+ """
53
+
54
+ import functools
55
+ import inspect
56
+ import operator
57
+ from collections.abc import Hashable, Iterable, Sequence
58
+ from contextlib import ExitStack
59
+ from typing import Any, Callable, Tuple, Union, Dict, Optional
60
+
61
+ import jax
62
+ from jax._src import source_info_util
63
+ from jax._src.linear_util import annotate
64
+ from jax._src.traceback_util import api_boundary
65
+ from jax.api_util import shaped_abstractify
66
+ from jax.extend.linear_util import transformation_with_aux
67
+ from jax.interpreters import partial_eval as pe
68
+
69
+ from brainstate._compatible_import import (
70
+ ClosedJaxpr,
71
+ extend_axis_env_nd,
72
+ safe_map,
73
+ safe_zip,
74
+ unzip2,
75
+ wraps,
76
+ wrap_init,
77
+ )
78
+ from brainstate._state import State, StateTraceStack
79
+ from brainstate._utils import set_module_as
80
+ from brainstate.typing import PyTree
81
+ from brainstate.util import PrettyObject
82
+
83
+ AxisName = Hashable
84
+
85
+ __all__ = [
86
+ "StatefulFunction",
87
+ "make_jaxpr",
88
+ ]
89
+
90
+
91
+ def _ensure_str(x: str) -> str:
92
+ if not isinstance(x, str):
93
+ raise TypeError(f"argument is not a string: {x}")
94
+ return x
95
+
96
+
97
def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
    """Normalise a single index or a sequence of indices to a tuple of ints.

    ``x`` is first passed through ``jax.core.concrete_or_error``. A value
    accepted by ``operator.index`` yields a one-element tuple; anything else is
    treated as an iterable whose items are each converted with
    ``operator.index``.
    """
    concrete = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
    try:
        single = operator.index(concrete)
    except TypeError:
        # Not a scalar index: convert element by element.
        return tuple(safe_map(operator.index, concrete))
    return (single,)
104
+
105
+
106
+ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
107
+ """Convert x to a tuple of strings."""
108
+ if isinstance(x, str):
109
+ return (x,)
110
+ else:
111
+ return tuple(safe_map(_ensure_str, x))
112
+
113
+
114
def _jax_v04_new_arg_fn(frame, trace, aval):
    """
    Create a tracer for a new argument and register it as an input of ``frame``.

    Modified from jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()

    Args:
      frame: The frame that receives the new tracer and input variable.
      trace: The trace the tracer belongs to.
      aval: The abstract value of the new argument.

    Returns:
      The newly created tracer.
    """
    new_tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
    frame.tracers.append(new_tracer)
    # Allocate the variable, then record it both in the tracer lookup table and
    # in the frame's list of input variables.
    new_var = frame.newvar(aval)
    frame.tracer_to_var[id(new_tracer)] = new_var
    frame.invars.append(new_var)
    return new_tracer
133
+
134
+
135
def _jax_v04_new_jax_trace():
    """Return ``(frame, trace)`` for the innermost entry of JAX's thread-local trace stack.

    ``frame`` is the last entry of that main trace's ``jaxpr_stack`` and
    ``trace`` is a fresh ``DynamicJaxprTrace`` at the current sublevel.
    """
    trace_stack = jax.core.thread_local_state.trace_state.trace_stack
    main = trace_stack.stack[-1]
    return main.jaxpr_stack[-1], pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
140
+
141
+
142
def _jax_v04_new_arg():
    """Return a callable mapping an abstract value to a new tracer in the current frame.

    Should be called while ``jax.make_jaxpr()`` is tracing, because the frame
    and trace are captured at call time.
    """
    # Bind the current frame and trace so the caller only supplies the aval.
    return functools.partial(_jax_v04_new_arg_fn, *_jax_v04_new_jax_trace())
148
+
149
+
150
def _jax_new_version_new_arg():
    """Return a callable that turns a value into a new argument tracer of the current trace.

    The trace is captured from ``jax.core.trace_ctx`` when this function is
    called, not when the returned callable runs.
    """
    current_trace = jax.core.trace_ctx.trace

    def wrapper(x):
        aval = shaped_abstractify(x)
        # From JAX 0.6.1 on, ``new_arg`` is also given the source info.
        if jax.__version_info__ >= (0, 6, 1):
            return current_trace.new_arg(aval, source_info=source_info_util.current())
        return current_trace.new_arg(aval)

    return wrapper
160
+
161
+
162
def _init_state_trace_stack(name) -> StateTraceStack:
    """Create a ``StateTraceStack`` named ``name`` and install its new-argument function.

    The new-argument function is chosen by JAX version: the legacy helper for
    releases before 0.4.36, the ``trace_ctx``-based helper otherwise.
    """
    stack: StateTraceStack = StateTraceStack(name=name)
    is_legacy_jax = jax.__version_info__ < (0, 4, 36)
    make_new_arg = _jax_v04_new_arg if is_legacy_jax else _jax_new_version_new_arg
    stack.set_new_arg(make_new_arg())
    return stack
170
+
171
+
172
# Cache key used by the ``StatefulFunction`` getters when no key is supplied. It is
# the key ``get_arg_cache_key`` produces when ``cache_type`` is None and the call
# has no static positional or static keyword arguments.
default_cache_key = ((), ())
173
+
174
+
175
+ class StatefulFunction(PrettyObject):
176
+ """
177
+ A wrapper class for a function that collects the states that are read and written by the function. The states are
178
+ collected by the function and returned as a StateDictManager instance. The StateDictManager instance can be used to
179
+ manage the states in the JAX program. The class provides a function called `states` that returns the states
180
+ that are read and written by the function. The class provides a function called `to_state_manager` that returns
181
+ a StateDictManager instance that contains the states that are read and written by the function. The class provides
182
+ a function called `__call__` that wraps the function and returns the states that are read and written by the
183
+ function and the output of the function.
184
+
185
+ Args:
186
+ fun: The function whose ``jaxpr`` is to be computed. Its positional
187
+ arguments and return value should be arrays, scalars, or standard Python
188
+ containers (tuple/list/dict) thereof.
189
+ static_argnums: See the :py:func:`jax.jit` docstring.
190
+ static_argnames: See the :py:func:`jax.jit` docstring.
191
+ axis_env: Optional, a sequence of pairs where the first element is an axis
192
+ name and the second element is a positive integer representing the size of
193
+ the mapped axis with that name. This parameter is useful when lowering
194
+ functions that involve parallel communication collectives, and it
195
+ specifies the axis name/size environment that would be set up by
196
+ applications of :py:func:`jax.pmap`.
197
+ abstracted_axes: Optional, a pytree with the same structure as the input
198
+ arguments to ``fun``. The leaves of the pytree can be either None or a
199
+ dict with axis names as keys and integers as values. If the leaf is None,
200
+ then the corresponding axis is not abstracted. If the leaf is a dict, then
201
+ the corresponding axis is abstracted, and the dict specifies the axis name
202
+ and size. The abstracted axes are used to infer the input type of the
203
+ function. If None, then all axes are abstracted.
204
+ state_returns: Optional, a string or a tuple of strings. The default is
205
+ ``('read', 'write')``. The strings specify the categories of states to be
206
+ returned by the wrapped function. The categories are ``'read'`` and
207
+ ``'write'``. If the category is ``'read'``, then the wrapped function
208
+ returns the states that are read by the function. If the category is
209
+ ``'write'``, then the wrapped function returns the states that are written
210
+ by the function. If the category is ``'read'`` and ``'write'``, then the
211
+ wrapped function returns both the read and write states.
212
+
213
+ """
214
+ __module__ = "brainstate.compile"
215
+
216
+ def __init__(
217
+ self,
218
+ fun: Callable,
219
+ static_argnums: Union[int, Iterable[int]] = (),
220
+ static_argnames: Union[str, Iterable[str]] = (),
221
+ axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
222
+ abstracted_axes: Optional[Any] = None,
223
+ state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
224
+ cache_type: Optional[str] = None,
225
+ name: Optional[str] = None,
226
+ ):
227
+ # explicit parameters
228
+ self.fun = fun
229
+ self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
230
+ self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
231
+ self.axis_env = axis_env
232
+ self.abstracted_axes = abstracted_axes
233
+ self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
234
+ assert cache_type in [None, 'jit'], f"Invalid cache type: {cache_type}"
235
+ self.name = name
236
+
237
+ # implicit parameters
238
+ self.cache_type = cache_type
239
+ self._cached_jaxpr: Dict[Any, ClosedJaxpr] = dict()
240
+ self._cached_out_shapes: Dict[Any, PyTree] = dict()
241
+ self._cached_jaxpr_out_tree: Dict[Any, PyTree] = dict()
242
+ self._cached_state_trace: Dict[Any, StateTraceStack] = dict()
243
+
244
+ def __pretty_repr_item__(self, k, v):
245
+ if k.startswith('_'):
246
+ return None
247
+ return k, v
248
+
249
+ def get_jaxpr(self, cache_key: Hashable = None) -> ClosedJaxpr:
250
+ """
251
+ Read the JAX Jaxpr representation of the function.
252
+
253
+ Args:
254
+ cache_key: The hashable key.
255
+
256
+ Returns:
257
+ The JAX Jaxpr representation of the function.
258
+ """
259
+ if cache_key is None:
260
+ cache_key = default_cache_key
261
+ if cache_key not in self._cached_jaxpr:
262
+ raise ValueError(f"the function is not called with the static arguments: {cache_key}")
263
+ return self._cached_jaxpr[cache_key]
264
+
265
+ def get_out_shapes(self, cache_key: Hashable = None) -> PyTree:
266
+ """
267
+ Read the output shapes of the function.
268
+
269
+ Args:
270
+ cache_key: The hashable key.
271
+
272
+ Returns:
273
+ The output shapes of the function.
274
+ """
275
+ if cache_key is None:
276
+ cache_key = default_cache_key
277
+ if cache_key not in self._cached_out_shapes:
278
+ raise ValueError(f"the function is not called with the static arguments: {cache_key}")
279
+ return self._cached_out_shapes[cache_key]
280
+
281
+ def get_out_treedef(self, cache_key: Hashable = None) -> PyTree:
282
+ """
283
+ Read the output tree of the function.
284
+
285
+ Args:
286
+ cache_key: The hashable key.
287
+
288
+ Returns:
289
+ The output tree of the function.
290
+ """
291
+ if cache_key is None:
292
+ cache_key = default_cache_key
293
+ if cache_key not in self._cached_jaxpr_out_tree:
294
+ raise ValueError(f"the function is not called with the static arguments: {cache_key}")
295
+ return self._cached_jaxpr_out_tree[cache_key]
296
+
297
+ def get_state_trace(self, cache_key: Hashable = None) -> StateTraceStack:
298
+ """
299
+ Read the state trace of the function.
300
+
301
+ Args:
302
+ cache_key: The hashable key.
303
+
304
+ Returns:
305
+ The state trace of the function.
306
+ """
307
+ if cache_key is None:
308
+ cache_key = default_cache_key
309
+ if cache_key not in self._cached_state_trace:
310
+ raise ValueError(f"the function is not called with the static arguments: {cache_key}")
311
+ return self._cached_state_trace[cache_key]
312
+
313
+ def get_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
314
+ """
315
+ Read the states that are read and written by the function.
316
+
317
+ Args:
318
+ cache_key: The hashable key.
319
+
320
+ Returns:
321
+ The states that are read and written by the function.
322
+ """
323
+ if cache_key is None:
324
+ cache_key = default_cache_key
325
+ return tuple(self.get_state_trace(cache_key).states)
326
+
327
+ def get_read_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
328
+ """
329
+ Read the states that are read by the function.
330
+
331
+ Args:
332
+ cache_key: The hashable key.
333
+
334
+ Returns:
335
+ The states that are read by the function.
336
+ """
337
+ if cache_key is None:
338
+ cache_key = default_cache_key
339
+ return self.get_state_trace(cache_key).get_read_states()
340
+
341
+ def get_write_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
342
+ """
343
+ Read the states that are written by the function.
344
+
345
+ Args:
346
+ cache_key: The hashable key.
347
+
348
+ Returns:
349
+ The states that are written by the function.
350
+ """
351
+ if cache_key is None:
352
+ cache_key = default_cache_key
353
+ return self.get_state_trace(cache_key).get_write_states()
354
+
355
+ def _check_input_ouput(self, x):
356
+ if isinstance(x, State):
357
+ x.raise_error_with_source_info(
358
+ ValueError(
359
+ 'Inputs/outputs for brainstate transformations cannot be an instance of State. '
360
+ f'But we got {x}'
361
+ )
362
+ )
363
+
364
+ def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
365
+ """
366
+ Get the static arguments from the arguments.
367
+
368
+ Args:
369
+ *args: The arguments to the function.
370
+ **kwargs: The keyword arguments to the function.
371
+
372
+ Returns:
373
+ The static arguments and keyword arguments as a tuple.
374
+ """
375
+ if self.cache_type == 'jit':
376
+ static_args, dyn_args = [], []
377
+ for i, arg in enumerate(args):
378
+ if i in self.static_argnums:
379
+ static_args.append(arg)
380
+ else:
381
+ dyn_args.append(arg)
382
+ dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
383
+ static_kwargs, dyn_kwargs = [], []
384
+ for k, v in kwargs.items():
385
+ if k in self.static_argnames:
386
+ static_kwargs.append((k, v))
387
+ else:
388
+ dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
389
+
390
+ static_args = make_hashable(tuple(static_args))
391
+ dyn_args = make_hashable(tuple(dyn_args))
392
+ static_kwargs = make_hashable(static_kwargs)
393
+ dyn_kwargs = make_hashable(dyn_kwargs)
394
+
395
+ cache_key = (static_args, dyn_args, static_kwargs, dyn_kwargs)
396
+ elif self.cache_type is None:
397
+ num_arg = len(args)
398
+ static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
399
+ static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
400
+
401
+ # Make everything hashable
402
+ static_args = make_hashable(static_args)
403
+ static_kwargs = make_hashable(static_kwargs)
404
+
405
+ cache_key = (static_args, static_kwargs)
406
+ else:
407
+ raise ValueError(f"Invalid cache type: {self.cache_type}")
408
+
409
+ return cache_key
410
+
411
+ def compile_function_and_get_states(self, *args, **kwargs) -> Tuple[State, ...]:
412
+ """
413
+ Compile the function, and get the states that are read and written by this function.
414
+
415
+ Args:
416
+ *args: The arguments to the function.
417
+ **kwargs: The keyword arguments to the function.
418
+
419
+ Returns:
420
+ The states that are read and written by the function.
421
+ """
422
+ cache_key = self.get_arg_cache_key(*args, **kwargs)
423
+ if cache_key not in self._cached_state_trace:
424
+ self.make_jaxpr(*args, **kwargs)
425
+ return self.get_states(cache_key)
426
+
427
+ def compile_function_and_get_state_trace(
428
+ self, *args, return_only_write: bool = False, **kwargs
429
+ ) -> StateTraceStack:
430
+ """
431
+ Compile the function, and get the states that are read and written by this function.
432
+
433
+ Args:
434
+ *args: The arguments to the function.
435
+ **kwargs: The keyword arguments to the function.
436
+ return_only_write: If True, only return the states that are written by the function.
437
+
438
+ Returns:
439
+ The state trace stack.
440
+ """
441
+ cache_key = self.get_arg_cache_key(*args, **kwargs)
442
+ if cache_key not in self._cached_state_trace:
443
+ self.make_jaxpr(*args, **kwargs, return_only_write=return_only_write)
444
+ return self.get_state_trace(cache_key)
445
+
446
+ def clear_cache(self) -> None:
447
+ """
448
+ Clear the compilation cache.
449
+ """
450
+ self._cached_jaxpr.clear()
451
+ self._cached_out_shapes.clear()
452
+ self._cached_jaxpr_out_tree.clear()
453
+ self._cached_state_trace.clear()
454
+
455
+ def _wrapped_fun_to_eval(
456
+ self, cache_key, static_kwargs: dict, *args, return_only_write: bool = False, **dyn_kwargs,
457
+ ) -> Tuple[Any, Tuple[State, ...]]:
458
+ """
459
+ Wrap the function and return the states that are read and written by the function and the output of the function.
460
+
461
+ Args:
462
+ *args: The arguments to the function.
463
+ **kwargs: The keyword arguments to the function.
464
+
465
+ Returns:
466
+ A tuple of the states that are read and written by the function and the output of the function.
467
+ """
468
+ # state trace
469
+ state_trace = _init_state_trace_stack(self.name)
470
+ self._cached_state_trace[cache_key] = state_trace
471
+ with state_trace:
472
+ out = self.fun(*args, **dyn_kwargs, **static_kwargs)
473
+ state_values = (
474
+ state_trace.get_write_state_values(True)
475
+ if return_only_write else
476
+ state_trace.get_state_values()
477
+ )
478
+ state_trace.recovery_original_values()
479
+
480
+ # State instance as functional returns is not allowed.
481
+ # Checking whether the states are returned.
482
+ jax.tree.map(self._check_input_ouput, out, is_leaf=lambda x: isinstance(x, State))
483
+ return out, state_values
484
+
485
+ def make_jaxpr(self, *args, return_only_write: bool = False, **kwargs):
486
+ """Creates a function that produces its jaxpr given example args.
487
+
488
+ A ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
489
+ argument ``return_shape`` is ``True``, then the returned function instead
490
+ returns a pair where the first element is the ``ClosedJaxpr``
491
+ representation of ``fun`` and the second element is a pytree representing
492
+ the structure, shape, dtypes, and named shapes of the output of ``fun``.
493
+
494
+ Args:
495
+ *args: The arguments to the function.
496
+ **kwargs: The keyword arguments to the function.
497
+ return_only_write: If True, only return the states that are written by the function.
498
+ """
499
+
500
+ # static args
501
+ cache_key = self.get_arg_cache_key(*args, **kwargs)
502
+
503
+ # check input types
504
+ jax.tree.map(self._check_input_ouput, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
505
+
506
+ if cache_key not in self._cached_state_trace:
507
+ try:
508
+ # jaxpr
509
+ static_kwargs, dyn_kwargs = {}, {}
510
+ for k, v in kwargs.items():
511
+ if k in self.static_argnames:
512
+ static_kwargs[k] = v
513
+ else:
514
+ dyn_kwargs[k] = v
515
+ jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
516
+ functools.partial(
517
+ self._wrapped_fun_to_eval,
518
+ cache_key,
519
+ static_kwargs,
520
+ return_only_write=return_only_write
521
+ ),
522
+ static_argnums=self.static_argnums,
523
+ axis_env=self.axis_env,
524
+ return_shape=True,
525
+ abstracted_axes=self.abstracted_axes
526
+ )(*args, **dyn_kwargs)
527
+ # returns
528
+ self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
529
+ self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
530
+ self._cached_jaxpr[cache_key] = jaxpr
531
+
532
+ except Exception as e:
533
+ try:
534
+ self._cached_state_trace.pop(cache_key)
535
+ except KeyError:
536
+ pass
537
+ raise e
538
+
539
+ return self
540
+
541
def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
    """
    Evaluate the cached jaxpr of the function with explicit state values.

    Args:
        state_vals: The values to use for the states recorded for this call
            signature. There must be exactly one value per recorded state.
        *args: The arguments to the function.
        **kwargs: The keyword arguments to the function.

    Returns:
        A pair ``(new_state_vals, out)``: the state values produced by the
        jaxpr, followed by the function output.
    """
    # the cache key selects the trace that belongs to this call signature
    cache_key = self.get_arg_cache_key(*args, **kwargs)
    states = self.get_states(cache_key)
    assert len(state_vals) == len(states), 'State length mismatch.'

    # static keyword and positional arguments are not inputs of the jaxpr
    dyn_kwargs = {name: value for name, value in kwargs.items() if name not in self.static_argnames}
    dyn_args = tuple(arg for index, arg in enumerate(args) if index not in self.static_argnums)
    flat_inputs, _ = jax.tree.flatten((dyn_args, dyn_kwargs, state_vals))

    # evaluate the jaxpr; the returned state values must match the given
    # ``state_vals`` one-to-one (asserted below)
    closed_jaxpr = self.get_jaxpr(cache_key)
    out_treedef = self.get_out_treedef(cache_key)
    flat_outputs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *flat_inputs)

    # rebuild the (output, state values) structure from the flat results
    out, new_state_vals = out_treedef.unflatten(flat_outputs)
    assert len(new_state_vals) == len(state_vals), 'State length mismatch.'
    return new_state_vals, out
574
+
575
def jaxpr_call_auto(self, *args, **kwargs) -> Any:
    """
    Evaluate the cached jaxpr, reading and writing the state values automatically.

    The current values of the traced states are passed to :meth:`jaxpr_call`
    and the state values it returns are assigned back through the state trace.

    Args:
        *args: The arguments to the function.
        **kwargs: The keyword arguments to the function.

    Returns:
        The output of the function.
    """
    cache_key = self.get_arg_cache_key(*args, **kwargs)
    trace = self.get_state_trace(cache_key)
    current_vals = [state.value for state in trace.states]
    new_vals, result = self.jaxpr_call(current_vals, *args, **kwargs)
    trace.assign_state_vals(new_vals)
    return result
590
+
591
+
592
@set_module_as("brainstate.compile")
def make_jaxpr(
    fun: Callable,
    static_argnums: Union[int, Iterable[int]] = (),
    static_argnames: Union[str, Iterable[str]] = (),
    axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
    return_shape: bool = False,
    abstracted_axes: Optional[Any] = None,
    state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
) -> Callable[
    ...,
    (Tuple[ClosedJaxpr, Tuple[State, ...]] |
     Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
]:
    """
    Creates a function that produces its jaxpr and its states given example args.

    Args:
        fun: The function whose ``jaxpr`` is to be computed. Its positional
            arguments and return value should be arrays, scalars, or standard
            Python containers (tuple/list/dict) thereof.
        static_argnums: See the :py:func:`jax.jit` docstring.
        static_argnames: See the :py:func:`jax.jit` docstring.
        axis_env: Optional, a sequence of pairs where the first element is an
            axis name and the second element is a positive integer representing
            the size of the mapped axis with that name. This parameter is useful
            when lowering functions that involve parallel communication
            collectives, and it specifies the axis name/size environment that
            would be set up by applications of :py:func:`jax.pmap`.
        return_shape: Optional boolean, defaults to ``False``. If ``True``, the
            wrapped function returns a third element: the output shapes
            recorded for ``fun``.
        abstracted_axes: Optional, a pytree with the same structure as the
            input arguments to ``fun`` describing which axes are abstracted.
            It is forwarded unchanged to ``StatefulFunction``.
        state_returns: Optional, a string or a tuple of strings. The default is
            ``('read', 'write')``. The strings specify the categories of states
            to be returned by the wrapped function: ``'read'`` for the states
            that are read by the function and ``'write'`` for the states that
            are written by the function. It is forwarded unchanged to
            ``StatefulFunction``.

    Returns:
        A wrapped version of ``fun`` that, when applied to example arguments,
        returns the pair ``(jaxpr, states)``. If ``return_shape`` is ``True``,
        it returns the triple ``(jaxpr, states, out_shapes)`` instead.

    A ``jaxpr`` is JAX's intermediate representation for program traces. The
    ``jaxpr`` language is based on the simply-typed first-order lambda calculus
    with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
    ``jaxpr``, which we can inspect to understand what JAX is doing internally.

    We do not describe the semantics of the ``jaxpr`` language in detail here,
    but instead give a few examples.

    >>> import jax
    >>> import brainstate as brainstate
    >>>
    >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
    >>> print(f(3.0))
    -0.83602
    >>> jaxpr, states = brainstate.compile.make_jaxpr(f)(3.0)
    >>> jaxpr
    { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
    >>> jaxpr, states = brainstate.compile.make_jaxpr(jax.grad(f))(3.0)
    >>> jaxpr
    { lambda ; a:f32[]. let
        b:f32[] = cos a
        c:f32[] = sin a
        _:f32[] = sin b
        d:f32[] = cos b
        e:f32[] = mul 1.0 d
        f:f32[] = neg e
        g:f32[] = mul f c
      in (g,) }
    """

    stateful_fun = StatefulFunction(
        fun,
        static_argnums=static_argnums,
        static_argnames=static_argnames,
        axis_env=axis_env,
        abstracted_axes=abstracted_axes,
        state_returns=state_returns,
        name='make_jaxpr'
    )

    @wraps(fun)
    def make_jaxpr_f(*args, **kwargs):
        # trace (or reuse the cached trace) for this call signature
        stateful_fun.make_jaxpr(*args, **kwargs)
        key = stateful_fun.get_arg_cache_key(*args, **kwargs)
        result = (stateful_fun.get_jaxpr(key), stateful_fun.get_states(key))
        if return_shape:
            # element 0 of the cached shapes is the function output part
            result += (stateful_fun.get_out_shapes(key)[0],)
        return result

    # present the wrapper as ``make_jaxpr(<fun>)`` inside ``brainstate.compile``
    make_jaxpr_f.__module__ = "brainstate.compile"
    for attr in ("__qualname__", "__name__"):
        if hasattr(fun, attr):
            setattr(make_jaxpr_f, attr, f"make_jaxpr({getattr(fun, attr)})")
    return make_jaxpr_f
713
+
714
+
715
+ def _check_callable(fun):
716
+ # In Python 3.10+, the only thing stopping us from supporting staticmethods
717
+ # is that we can't take weak references to them, which the C++ JIT requires.
718
+ if isinstance(fun, staticmethod):
719
+ raise TypeError(f"staticmethod arguments are not supported, got {fun}")
720
+ if not callable(fun):
721
+ raise TypeError(f"Expected a callable value, got {fun}")
722
+ if inspect.isgeneratorfunction(fun):
723
+ raise TypeError(f"Expected a function, got a generator function: {fun}")
724
+
725
+
726
+ def _broadcast_prefix(
727
+ prefix_tree: Any,
728
+ full_tree: Any,
729
+ is_leaf: Callable[[Any], bool] | None = None
730
+ ) -> list[Any]:
731
+ # If prefix_tree is not a tree prefix of full_tree, this code can raise a
732
+ # ValueError; use prefix_errors to find disagreements and raise more precise
733
+ # error messages.
734
+ result = []
735
+ num_leaves = lambda t: jax.tree.structure(t).num_leaves
736
+ add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
737
+ jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
738
+ return result
739
+
740
+
741
def _flat_axes_specs(
    abstracted_axes, *args, **kwargs
) -> list[pe.AbstractedAxesSpec]:
    """Broadcast ``abstracted_axes`` over the leaves of the positional ``args``.

    Keyword arguments are not supported: passing any raises
    ``NotImplementedError``.
    """
    if kwargs:
        raise NotImplementedError

    def _is_axes_spec(node):
        # a spec is a dict whose values are all leaves, or a tuple whose
        # entries are all leaves when ``None`` is also counted as a leaf
        if isinstance(node, dict):
            return jax.tree_util.all_leaves(node.values())
        if isinstance(node, tuple):
            return jax.tree_util.all_leaves(node, lambda x: x is None)
        return False

    return _broadcast_prefix(abstracted_axes, args, _is_axes_spec)
752
+
753
+
754
@transformation_with_aux
def _flatten_fun(in_tree, *args_flat):
    """Adapt a pytree-based call to flat inputs and flat outputs.

    The flat inputs are rebuilt into ``(args, kwargs)`` with ``in_tree`` and
    yielded; the value sent back is flattened and yielded as
    ``(flat leaves, treedef)``.
    """
    # NOTE(review): presumably ``transformation_with_aux`` treats the second
    # yielded pair as (outputs, auxiliary data) — confirm against its definition.
    positional, keyword = jax.tree.unflatten(in_tree, args_flat)
    result = yield positional, keyword
    yield jax.tree.flatten(result)
759
+
760
+
761
def _make_jaxpr(
    fun: Callable,
    static_argnums: int | Iterable[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: bool = False,
    abstracted_axes: Any | None = None,
) -> Callable[..., (ClosedJaxpr | tuple[ClosedJaxpr, Any])]:
    """Creates a function that produces its jaxpr given example args.

    Args:
      fun: The function whose ``jaxpr`` is to be computed. Its positional
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof.
      static_argnums: See the :py:func:`jax.jit` docstring.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
        functions that involve parallel communication collectives, and it
        specifies the axis name/size environment that would be set up by
        applications of :py:func:`jax.pmap`.
      return_shape: Optional boolean, defaults to ``False``. If ``True``, the
        wrapped function returns a pair where the first element is the
        ``ClosedJaxpr`` representation of ``fun`` and the second element is a
        pytree with the same structure as the output of ``fun`` and where the
        leaves are ``jax.ShapeDtypeStruct`` objects built from the ``shape`` and
        ``dtype`` of the corresponding output leaves.
      abstracted_axes: Optional. When ``None``, every flattened input is
        abstracted with ``shaped_abstractify``. Otherwise it is broadcast over
        the positional arguments and used to infer the input types; in that
        case keyword arguments raise ``NotImplementedError``.

    Returns:
      A wrapped version of ``fun`` that when applied to example arguments returns
      a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
      argument ``return_shape`` is ``True``, then the returned function instead
      returns a pair where the first element is the ``ClosedJaxpr``
      representation of ``fun`` and the second element is a pytree representing
      the structure, shapes and dtypes of the output of ``fun``.

    Raises:
      TypeError: If ``fun`` is a ``staticmethod`` object, is not callable, or is
        a generator function.

    A ``jaxpr`` is JAX's intermediate representation for program traces. The
    ``jaxpr`` language is based on the simply-typed first-order lambda calculus
    with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
    ``jaxpr``, which we can inspect to understand what JAX is doing internally.
    The ``jaxpr`` returned is a trace of ``fun`` abstracted to
    :py:class:`ShapedArray` level. Other levels of abstraction exist internally.

    We do not describe the semantics of the ``jaxpr`` language in detail here, but
    instead give a few examples.

    >>> import jax
    >>>
    >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
    >>> print(f(3.0))
    -0.83602
    >>> _make_jaxpr(f)(3.0)
    { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
    >>> _make_jaxpr(jax.grad(f))(3.0)
    { lambda ; a:f32[]. let
        b:f32[] = cos a
        c:f32[] = sin a
        _:f32[] = sin b
        d:f32[] = cos b
        e:f32[] = mul 1.0 d
        f:f32[] = neg e
        g:f32[] = mul f c
      in (g,) }
    """
    _check_callable(fun)
    static_argnums = _ensure_index_tuple(static_argnums)

    def _abstractify(args, kwargs):
        # Returns (input avals, input treedef, keep-input flags) for the call.
        flat_args, in_tree = jax.tree.flatten((args, kwargs))
        if abstracted_axes is None:
            # the avals come back as a lazy ``map`` object; every input is kept
            return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
        else:
            axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
            in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
            in_avals, keep_inputs = unzip2(in_type)
            return in_avals, in_tree, keep_inputs

    @wraps(fun)
    @api_boundary
    def make_jaxpr_f(*args, **kwargs):
        f = wrap_init(fun, (), {}, 'brainstate.compile.make_jaxpr')
        if static_argnums:
            # only the non-static positional arguments are passed on as ``args``
            dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
            f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
        in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
        in_type = tuple(safe_zip(in_avals, keep_inputs))
        f, out_tree = _flatten_fun(f, in_tree)
        f = annotate(f, in_type)
        # ``debug_info_`` is only built for jax < 0.5.0; the same version check
        # guards its only use below
        if jax.__version_info__ < (0, 5, 0):
            debug_info_ = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
        with ExitStack() as stack:
            # trace inside the requested axis environment, if one was given
            if axis_env is not None:
                stack.enter_context(extend_axis_env_nd(axis_env))
            if jax.__version_info__ < (0, 5, 0):
                jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info_)
            else:
                jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
        closed_jaxpr = ClosedJaxpr(jaxpr, consts)
        if return_shape:
            # rebuild the output pytree with shape/dtype placeholders as leaves
            out_avals, _ = unzip2(out_type)
            out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
            return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
        return closed_jaxpr

    # present the wrapper as ``make_jaxpr(<fun>)`` inside ``brainstate.compile``
    make_jaxpr_f.__module__ = "brainstate.compile"
    if hasattr(fun, "__qualname__"):
        make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
    if hasattr(fun, "__name__"):
        make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
    return make_jaxpr_f
870
+
871
+
872
def make_hashable(obj):
    """Convert a pytree into a hashable representation.

    Lists and tuples become tuples, dicts become tuples of ``(key, value)``
    pairs ordered by key, and sets become frozensets; the conversion is applied
    recursively to the contained items. Any other object is returned unchanged
    and is assumed to be hashable already.

    Args:
        obj: The object to convert.

    Returns:
        A representation of ``obj`` built only from tuples, frozensets and the
        original non-container objects.
    """
    if isinstance(obj, (list, tuple)):
        return tuple(make_hashable(item) for item in obj)
    if isinstance(obj, dict):
        items = [(key, make_hashable(value)) for key, value in obj.items()]
        try:
            # order by key only, so that the result does not depend on the
            # insertion order of the dict
            items.sort(key=lambda pair: pair[0])
        except TypeError:
            # Keys of mutually unorderable types (e.g. ``int`` and ``str``)
            # cannot be compared directly; fall back to a deterministic order
            # based on the key's type name and repr instead of failing.
            items.sort(key=lambda pair: (type(pair[0]).__qualname__, repr(pair[0])))
        return tuple(items)
    if isinstance(obj, set):
        return frozenset(make_hashable(item) for item in obj)
    return obj