brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -1,739 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- This module implements the creation of a JAX Jaxpr from a given function, taking into account the states that are read and
18
- written by the function. These state transformations are foundational for the BrainState library. These utilities
19
- include two basic building blocks: the `StatefulFunction` class and the `make_jaxpr` function.
20
-
21
-
22
- ``StatefulFunction``
23
- --------------------
24
-
25
- The module provides a class called ``StatefulFunction`` that wraps a function and provides methods to get the
26
- JAX Jaxpr, the output shapes, the states that are read and written by the function, and the output of the function.
27
- The class provides the following methods:
28
-
29
- - `make_jaxpr`: creates the JAX Jaxpr of the function.
30
- - `jaxpr_call`: calls the function at the JAX Jaxpr level.
31
- - `jaxpr_call_auto`: calls the function at the JAX Jaxpr level, reading the current state values and writing the updated values back automatically.
32
- - `get_states`: returns the states that are read and written by the function.
33
- - `get_read_states`: returns the states that are read by the function.
34
- - `get_write_states`: returns the states that are written by the function.
35
- - `get_arg_cache_key`: returns the cache key derived from the arguments (the static arguments, plus abstract values of the dynamic ones when `cache_type='jit'`).
36
- - `compile_and_get_states_by_static_args`: compiles the function and returns the states that are read and
37
- written by the function.
38
- - `get_jaxpr`: returns the JAX Jaxpr of the function.
39
- - `get_out_shapes`: returns the output shapes of the function.
40
- - `get_out_treedef`: returns the output tree of the function.
41
-
42
- ``make_jaxpr``
43
- --------------
44
-
45
- The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
46
- arguments. The function returns a wrapped version of the function that when applied to example arguments returns a
47
- `ClosedJaxpr` representation of the function on those arguments. If the argument `return_shape` is `True`, then the
48
- returned function instead returns a pair where the first element is the `ClosedJaxpr` representation of the function
49
- and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of the
50
- function.
51
-
52
- """
53
-
54
- from __future__ import annotations
55
-
56
- import functools
57
- import inspect
58
- import operator
59
- from collections.abc import Hashable, Iterable, Sequence
60
- from contextlib import ExitStack
61
- from typing import Any, Callable, Tuple, Union, Dict, Optional
62
-
63
- import jax
64
- from jax._src import source_info_util
65
- from jax._src.linear_util import annotate
66
- from jax._src.traceback_util import api_boundary
67
- from jax.extend.linear_util import transformation_with_aux, wrap_init
68
- from jax.interpreters import partial_eval as pe
69
- from jax.interpreters.xla import abstractify
70
- from jax.util import wraps
71
-
72
- from brainstate._state import State, StateTrace
73
- from brainstate._utils import set_module_as
74
- from brainstate.typing import PyTree
75
-
76
# Type alias: axis names are represented as arbitrary hashable values.
AxisName = Hashable

# Names exported by ``from ... import *``.
__all__ = ["StatefulFunction", "make_jaxpr"]
82
-
83
-
84
- def _assign_state_values(states, state_vals) -> None:
85
- """
86
- Assign the state values to the states.
87
-
88
- Args:
89
- states: The states.
90
- state_vals: The state values.
91
- """
92
- assert len(states) == len(state_vals), f'State length mismatch. {len(states)} != {len(state_vals)}.'
93
- for st, val in zip(states, state_vals):
94
- st.value = val
95
-
96
-
97
- def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
98
- """Convert x to a tuple of indices."""
99
- x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
100
- try:
101
- return (operator.index(x),)
102
- except TypeError:
103
- return tuple(jax.util.safe_map(operator.index, x))
104
-
105
-
106
def _new_arg(frame, trace, aval):
    """
    Create a fresh input tracer for ``aval`` and register it as a new jaxpr input.

    Adapted from ``jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()``.

    Args:
      frame: The jaxpr frame under construction; its ``tracers``,
        ``tracer_to_var`` and ``invars`` are extended in place.
      trace: The ``DynamicJaxprTrace`` the new tracer belongs to.
      aval: The abstract value describing the new argument.

    Returns:
      The newly created ``DynamicJaxprTracer``.
    """
    new_tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
    frame.tracers.append(new_tracer)
    # Allocate a jaxpr variable, bind the tracer to it, and expose it as an input.
    in_var = frame.newvar(aval)
    frame.tracer_to_var[id(new_tracer)] = in_var
    frame.invars.append(in_var)
    return new_tracer
125
-
126
-
127
def wrapped_abstractify(x: Any) -> Any:
    """
    Return the abstract value (shape/dtype description) of ``x``.

    A ``DynamicJaxprTracer`` is converted into a fresh ``jax.core.ShapedArray``
    built from its ``aval``'s shape, dtype and ``weak_type``. Any other input is
    delegated to ``jax.interpreters.xla.abstractify``.

    Args:
      x: The value to abstractify.

    Returns:
      The abstract value of ``x``.
    """
    if not isinstance(x, pe.DynamicJaxprTracer):
        return abstractify(x)
    tracer_aval = x.aval
    return jax.core.ShapedArray(tracer_aval.shape, tracer_aval.dtype, weak_type=tracer_aval.weak_type)
140
-
141
-
142
class StatefulFunction(object):
    """
    Wrap a function so it can be traced into a JAX ``ClosedJaxpr`` while
    recording the :class:`State` objects it reads and writes.

    :meth:`make_jaxpr` traces ``fun`` once per cache key (see
    :meth:`get_arg_cache_key`). For that key it stores the closed jaxpr, the
    output shapes, the output tree definition, and the ``StateTrace`` that
    recorded the states touched during tracing. These artifacts are read back
    with :meth:`get_jaxpr`, :meth:`get_out_shapes`, :meth:`get_out_treedef`,
    :meth:`get_states`, :meth:`get_read_states` and :meth:`get_write_states`.
    :meth:`jaxpr_call` and :meth:`jaxpr_call_auto` re-execute the traced
    program. Returning a ``State`` object from ``fun`` is rejected during
    tracing.

    Args:
      fun: The function whose ``jaxpr`` is to be computed. Its positional
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof.
      static_argnums: See the :py:func:`jax.jit` docstring. Normalized to a
        tuple of ints; ``None`` is treated as ``()``.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size
        of the mapped axis with that name. This parameter is useful when
        lowering functions that involve parallel communication collectives, and
        it specifies the axis name/size environment that would be set up by
        applications of :py:func:`jax.pmap`.
      abstracted_axes: Optional. Forwarded unchanged to the underlying jaxpr
        tracer; a pytree mirroring the input arguments that describes which
        axes are abstracted.
      state_returns: Optional, a string or a tuple/list of strings (default
        ``('read', 'write')``), stored as a tuple in ``self.state_returns``.
        NOTE(review): within this class the value is only stored and shown by
        ``repr``; tracing always collects both read and written state values.
      cache_type: ``None`` (default) or ``'jit'``. Controls how
        :meth:`get_arg_cache_key` builds the cache key.
    """
    __module__ = "brainstate.transform"

    def __init__(
        self,
        fun: Callable,
        static_argnums: Union[int, Iterable[int]] = (),
        axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
        abstracted_axes: Optional[Any] = None,
        state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
        cache_type: Optional[str] = None,
    ):
        # explicit parameters
        self.fun = fun
        self.static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
        self.axis_env = axis_env
        self.abstracted_axes = abstracted_axes
        self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
        assert cache_type in [None, 'jit']
        self.cache_type = cache_type

        # Per-cache-key artifacts filled by ``make_jaxpr``.
        self._jaxpr: Dict[Any, jax.core.ClosedJaxpr] = dict()
        self._out_shapes: Dict[Any, PyTree] = dict()
        self._jaxpr_out_tree: Dict[Any, PyTree] = dict()
        self._state_trace: Dict[Any, StateTrace] = dict()

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}({self.fun}, "
                f"static_argnums={self.static_argnums}, "
                f"axis_env={self.axis_env}, "
                f"abstracted_axes={self.abstracted_axes}, "
                f"state_returns={self.state_returns})")

    def get_jaxpr(self, cache_key: Hashable = ()) -> jax.core.ClosedJaxpr:
        """
        Read the cached JAX Jaxpr representation of the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The JAX Jaxpr representation of the function.

        Raises:
          ValueError: If nothing has been traced for ``cache_key``.
        """
        if cache_key not in self._jaxpr:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._jaxpr[cache_key]

    def get_out_shapes(self, cache_key: Hashable = ()) -> PyTree:
        """
        Read the cached output shapes of the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The pair ``(out_shapes, state_shapes)`` recorded while tracing.

        Raises:
          ValueError: If nothing has been traced for ``cache_key``.
        """
        if cache_key not in self._out_shapes:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._out_shapes[cache_key]

    def get_out_treedef(self, cache_key: Hashable = ()) -> PyTree:
        """
        Read the cached tree definition of ``(out_shapes, state_shapes)``.

        Args:
          cache_key: The hashable key.

        Returns:
          The output tree definition of the function.

        Raises:
          ValueError: If nothing has been traced for ``cache_key``.
        """
        if cache_key not in self._jaxpr_out_tree:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._jaxpr_out_tree[cache_key]

    def _get_state_trace(self, cache_key: Hashable) -> StateTrace:
        """Return the cached ``StateTrace`` for ``cache_key``, raising ``ValueError`` if absent."""
        if cache_key not in self._state_trace:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._state_trace[cache_key]

    def get_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
        """
        Read the states that are read and written by the function.

        Args:
          cache_key: The hashable key.

        Returns:
          All states recorded by the trace, in trace order.

        Raises:
          ValueError: If nothing has been traced for ``cache_key``.
        """
        return tuple(self._get_state_trace(cache_key).states)

    def get_read_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
        """
        Read the states whose recorded access type is ``'read'``.

        Args:
          cache_key: The hashable key.

        Returns:
          The states that are read by the function.

        Raises:
          ValueError: If nothing has been traced for ``cache_key`` (consistent
            with the other getters).
        """
        state_trace = self._get_state_trace(cache_key)
        return tuple(st for st, ty in zip(state_trace.states, state_trace.types) if ty == 'read')

    def get_write_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
        """
        Read the states whose recorded access type is ``'write'``.

        Args:
          cache_key: The hashable key.

        Returns:
          The states that are written by the function.

        Raises:
          ValueError: If nothing has been traced for ``cache_key`` (consistent
            with the other getters).
        """
        state_trace = self._get_state_trace(cache_key)
        return tuple(st for st, ty in zip(state_trace.states, state_trace.types) if ty == 'write')

    def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
        """
        Compute the cache key for a call with ``args`` and ``kwargs``.

        With ``cache_type=None`` the key is the tuple of static positional
        arguments (those whose index is in ``static_argnums``) that are actually
        present. With ``cache_type='jit'`` the key is
        ``(static_args, abstract_dyn_arg_leaves, abstract_kwarg_leaves)``, where
        the leaves are mapped through ``wrapped_abstractify``.

        NOTE(review): the ``'jit'`` key keeps only the flattened leaves, so the
        pytree structure and keyword names do not participate in the key.

        Args:
          *args: The positional arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          The hashable cache key.

        Raises:
          ValueError: If ``cache_type`` is not one of the supported values.
        """
        if self.cache_type == 'jit':
            static_args, dyn_args = [], []
            for i, arg in enumerate(args):
                if i in self.static_argnums:
                    static_args.append(arg)
                else:
                    dyn_args.append(arg)
            dyn_args = jax.tree.map(wrapped_abstractify, jax.tree.leaves(dyn_args))
            dyn_kwargs = jax.tree.map(wrapped_abstractify, jax.tree.leaves(kwargs))
            return tuple([tuple(static_args), tuple(dyn_args), tuple(dyn_kwargs)])
        elif self.cache_type is None:
            num_arg = len(args)
            return tuple(args[i] for i in self.static_argnums if i < num_arg)
        else:
            raise ValueError(f"Invalid cache type: {self.cache_type}")

    def compile_and_get_states_by_static_args(self, *args, **kwargs) -> Tuple[State, ...]:
        """
        Trace the function if needed and return the states it reads and writes.

        Args:
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          The states that are read and written by the function.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        if cache_key not in self._state_trace:
            self.make_jaxpr(*args, **kwargs)
        return self.get_states(cache_key)

    def clear_cache(self) -> None:
        """
        Clear every per-cache-key artifact (jaxprs, shapes, treedefs, state traces).
        """
        self._jaxpr.clear()
        self._out_shapes.clear()
        self._jaxpr_out_tree.clear()
        self._state_trace.clear()

    @staticmethod
    def _init_trace_and_newarg() -> StateTrace:
        """
        Create a ``StateTrace`` whose new-argument hook adds inputs to the
        innermost jaxpr frame.

        Must be called while a jaxpr trace is active: it reads the top of JAX's
        internal trace stack and that trace's innermost jaxpr frame.
        """
        # Should be within the calling of ``jax.make_jaxpr()``
        state_trace: StateTrace = StateTrace()
        main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
        frame = main.jaxpr_stack[-1]
        trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
        state_trace.set_new_arg(functools.partial(_new_arg, frame, trace))
        return state_trace

    def _wrapped_fun_to_eval(self, cache_key, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
        """
        Run ``self.fun`` under a fresh ``StateTrace``; this is the traced body.

        The trace is registered in ``self._state_trace[cache_key]`` before the
        function runs. Whether or not the function succeeds, the states'
        original values are restored afterwards. Otherwise values produced
        during tracing (presumably tracers) would be left in the states.

        Args:
          cache_key: The key under which the ``StateTrace`` is cached.
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          ``(out, state_values)``, where ``state_values`` are the values
          collected from the read and written states.

        Raises:
          Whatever ``self.fun`` raises. Also, ``out`` must not contain a
          ``State``; such a leaf is reported via
          ``State._raise_error_with_source_info``.
        """
        state_trace = self._init_trace_and_newarg()
        self._state_trace[cache_key] = state_trace
        try:
            with state_trace:
                out = self.fun(*args, **kwargs)
                state_values = state_trace.collect_values('read', 'write', check_val_tree=True)
        finally:
            # Restore even when ``fun`` (or value collection) fails.
            state_trace.recovery_original_values()

        # State instances are not allowed as functional returns.
        for leaf in jax.tree.leaves(out):
            if isinstance(leaf, State):
                leaf._raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
        return out, state_values

    def make_jaxpr(self, *args, **kwargs):
        """
        Trace the function on example arguments and cache the results.

        Tracing happens only if no ``StateTrace`` is cached for the cache key of
        ``args``/``kwargs``. On success, the following are cached under that key:
        the closed jaxpr, the output shapes ``(out_shapes, state_shapes)``, and
        their tree definition. If tracing fails, the ``StateTrace`` registered
        for the key is discarded so a later call retries, and the exception is
        re-raised.

        Args:
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          ``self``, to allow chaining.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)

        if cache_key not in self._state_trace:
            try:
                jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
                    functools.partial(self._wrapped_fun_to_eval, cache_key),
                    static_argnums=self.static_argnums,
                    axis_env=self.axis_env,
                    return_shape=True,
                    abstracted_axes=self.abstracted_axes,
                )(*args, **kwargs)

                self._jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
                self._out_shapes[cache_key] = (out_shapes, state_shapes)
                self._jaxpr[cache_key] = jaxpr
            except Exception:
                self._state_trace.pop(cache_key, None)
                raise

        return self

    def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
        """
        Evaluate the cached jaxpr with explicit state values.

        Static positional arguments are dropped. The remaining arguments,
        ``kwargs`` and ``state_vals`` are flattened and fed to
        ``jax.core.eval_jaxpr``.

        Args:
          state_vals: The state values, aligned with :meth:`get_states`.
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          ``(new_state_vals, out)``.

        Raises:
          ValueError: If nothing has been traced for these arguments.
          AssertionError: If the number of state values does not match.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        states: Sequence[State] = self.get_states(cache_key)
        assert len(state_vals) == len(states), 'State length mismatch.'
        # No per-state value-tree check is needed here: value trees were checked
        # during tracing (``collect_values(..., check_val_tree=True)``).

        dyn_args = tuple(arg for i, arg in enumerate(args) if i not in self.static_argnums)
        flat_args = jax.tree.flatten((dyn_args, kwargs, state_vals))[0]

        closed_jaxpr = self.get_jaxpr(cache_key)
        out_treedef = self.get_out_treedef(cache_key)
        jaxpr_outs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *flat_args)

        out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
        assert len(new_state_vals) == len(state_vals), 'State length mismatch.'
        return new_state_vals, out

    def jaxpr_call_auto(self, *args, **kwargs) -> Any:
        """
        Evaluate the cached jaxpr using, and then updating, the current state values.

        Args:
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          The output of the function.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        states = self.get_states(cache_key)
        state_vals, out = self.jaxpr_call([st.value for st in states], *args, **kwargs)
        for st, val in zip(states, state_vals):
            st.value = val
        return out
477
-
478
-
479
- @set_module_as("brainstate.transform")
480
- def make_jaxpr(
481
- fun: Callable,
482
- static_argnums: Union[int, Iterable[int]] = (),
483
- axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
484
- return_shape: bool = False,
485
- abstracted_axes: Optional[Any] = None,
486
- state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
487
- ) -> Callable[..., (Tuple[jax.core.ClosedJaxpr, Tuple[State, ...]] |
488
- Tuple[jax.core.ClosedJaxpr, Tuple[State, ...], PyTree])]:
489
- """
490
- Creates a function that produces its jaxpr given example args.
491
-
492
- Args:
493
- fun: The function whose ``jaxpr`` is to be computed. Its positional
494
- arguments and return value should be arrays, scalars, or standard Python
495
- containers (tuple/list/dict) thereof.
496
- static_argnums: See the :py:func:`jax.jit` docstring.
497
- axis_env: Optional, a sequence of pairs where the first element is an axis
498
- name and the second element is a positive integer representing the size of
499
- the mapped axis with that name. This parameter is useful when lowering
500
- functions that involve parallel communication collectives, and it
501
- specifies the axis name/size environment that would be set up by
502
- applications of :py:func:`jax.pmap`.
503
- return_shape: Optional boolean, defaults to ``False``. If ``True``, the
504
- wrapped function returns a pair where the first element is the XLA
505
- computation and the second element is a pytree with the same structure as
506
- the output of ``fun`` and where the leaves are objects with ``shape``,
507
- ``dtype``, and ``named_shape`` attributes representing the corresponding
508
- types of the output leaves.
509
- abstracted_axes: Optional, a pytree with the same structure as the input
510
- arguments to ``fun``. The leaves of the pytree can be either None or a
511
- dict with axis names as keys and integers as values. If the leaf is None,
512
- then the corresponding axis is not abstracted. If the leaf is a dict, then
513
- the corresponding axis is abstracted, and the dict specifies the axis name
514
- and size. The abstracted axes are used to infer the input type of the
515
- function. If None, then all axes are abstracted.
516
- state_returns: Optional, a string or a tuple of strings. The default is
517
- ``('read', 'write')``. The strings specify the categories of states to be
518
- returned by the wrapped function. The categories are ``'read'`` and
519
- ``'write'``. If the category is ``'read'``, then the wrapped function
520
- returns the states that are read by the function. If the category is
521
- ``'write'``, then the wrapped function returns the states that are written
522
- by the function. If the category is ``'read'`` and ``'write'``, then the
523
- wrapped function returns both the read and write states.
524
-
525
-
526
- Returns:
527
- A wrapped version of ``fun`` that when applied to example arguments returns
528
- a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
529
- argument ``return_shape`` is ``True``, then the returned function instead
530
- returns a pair where the first element is the ``ClosedJaxpr``
531
- representation of ``fun`` and the second element is a pytree representing
532
- the structure, shape, dtypes, and named shapes of the output of ``fun``.
533
-
534
- A ``jaxpr`` is JAX's intermediate representation for program traces. The
535
- ``jaxpr`` language is based on the simply-typed first-order lambda calculus
536
- with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
537
- ``jaxpr``, which we can inspect to understand what JAX is doing internally.
538
- The ``jaxpr`` returned is a trace of ``fun`` abstracted to
539
- :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
540
-
541
- We do not describe the semantics of the ``jaxpr`` language in detail here, but
542
- instead give a few examples.
543
-
544
- >>> import jax
545
- >>> import brainstate as bst
546
- >>>
547
- >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
548
- >>> print(f(3.0))
549
- -0.83602
550
- >>> jaxpr, states = bst.transform.make_jaxpr(f)(3.0)
551
- >>> jaxpr
552
- { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
553
- >>> jaxpr, states = bst.transform.make_jaxpr(jax.grad(f))(3.0)
554
- >>> jaxpr
555
- { lambda ; a:f32[]. let
556
- b:f32[] = cos a
557
- c:f32[] = sin a
558
- _:f32[] = sin b
559
- d:f32[] = cos b
560
- e:f32[] = mul 1.0 d
561
- f:f32[] = neg e
562
- g:f32[] = mul f c
563
- in (g,) }
564
- """
565
-
566
- stateful_fun = StatefulFunction(fun, static_argnums, axis_env, abstracted_axes, state_returns)
567
-
568
- @wraps(fun)
569
- def make_jaxpr_f(*args, **kwargs):
570
- stateful_fun.make_jaxpr(*args, **kwargs)
571
- cache_key = stateful_fun.get_arg_cache_key(*args, **kwargs)
572
- if return_shape:
573
- return (stateful_fun.get_jaxpr(cache_key),
574
- stateful_fun.get_states(cache_key),
575
- stateful_fun.get_out_shapes(cache_key)[0])
576
- else:
577
- return (stateful_fun.get_jaxpr(cache_key),
578
- stateful_fun.get_states(cache_key))
579
-
580
- # wrapped jaxpr builder function
581
- make_jaxpr_f.__module__ = "brainstate.transform"
582
- if hasattr(fun, "__qualname__"):
583
- make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
584
- if hasattr(fun, "__name__"):
585
- make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
586
- return make_jaxpr_f
587
-
588
-
589
- def _check_callable(fun):
590
- # In Python 3.10+, the only thing stopping us from supporting staticmethods
591
- # is that we can't take weak references to them, which the C++ JIT requires.
592
- if isinstance(fun, staticmethod):
593
- raise TypeError(f"staticmethod arguments are not supported, got {fun}")
594
- if not callable(fun):
595
- raise TypeError(f"Expected a callable value, got {fun}")
596
- if inspect.isgeneratorfunction(fun):
597
- raise TypeError(f"Expected a function, got a generator function: {fun}")
598
-
599
-
600
- def _broadcast_prefix(
601
- prefix_tree: Any,
602
- full_tree: Any,
603
- is_leaf: Callable[[Any], bool] | None = None
604
- ) -> list[Any]:
605
- # If prefix_tree is not a tree prefix of full_tree, this code can raise a
606
- # ValueError; use prefix_errors to find disagreements and raise more precise
607
- # error messages.
608
- result = []
609
- num_leaves = lambda t: jax.tree.structure(t).num_leaves
610
- add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
611
- jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
612
- return result
613
-
614
-
615
def _flat_axes_specs(
    abstracted_axes, *args, **kwargs
) -> list[pe.AbstractedAxesSpec]:
    """Broadcast an ``abstracted_axes`` prefix tree to one spec per flat argument.

    Args:
      abstracted_axes: A pytree that is a tree prefix of ``args``. Each of its
        leaves is an axes spec: a dict whose values are all pytree leaves, or a
        tuple whose entries are all leaves (``None`` counts as a leaf here).
      *args: The positional arguments whose leaves each receive a spec.
      **kwargs: Not supported; must be empty.

    Returns:
      A list with one axes spec per leaf of ``args``, in flattening order.

    Raises:
      NotImplementedError: If any keyword arguments are given.
    """
    if kwargs:
        # Previously a bare NotImplementedError with no explanation was raised.
        raise NotImplementedError(
            "abstracted_axes is not supported together with keyword arguments; "
            f"got keyword arguments {list(kwargs)}."
        )

    def _is_axes_spec_leaf(node) -> bool:
        # A whole axes spec must be treated as one leaf so that it is broadcast,
        # not descended into. For tuples, None entries are treated as leaves
        # instead of JAX's default "empty subtree" handling of None.
        if isinstance(node, dict):
            return jax.tree_util.all_leaves(node.values())
        if isinstance(node, tuple):
            return jax.tree_util.all_leaves(node, lambda x: x is None)
        return False

    return _broadcast_prefix(abstracted_axes, args, _is_axes_spec_leaf)
626
-
627
-
628
@transformation_with_aux
def _flatten_fun(in_tree, *args_flat):
    """Adapt a function of ``(args, kwargs)`` to take and return flat leaves.

    This is a generator-based transformation. The first ``yield`` hands the
    rebuilt positional and keyword arguments to the wrapped function and
    receives its output. The second ``yield`` emits the flattened output leaves
    together with the output treedef. Callers read the treedef back through
    the auxiliary thunk, e.g. ``out_tree()`` in ``_make_jaxpr``.

    Args:
      in_tree: Treedef of the ``(args, kwargs)`` pair that ``args_flat`` came
        from.
      *args_flat: Flat input leaves.
    """
    # Reconstruct the user-facing call structure from the flat leaves.
    positional, keyword = jax.tree.unflatten(in_tree, args_flat)
    result = yield positional, keyword
    leaves, out_treedef = jax.tree.flatten(result)
    yield leaves, out_treedef
633
-
634
-
635
def _make_jaxpr(
    fun: Callable,
    static_argnums: int | Iterable[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: bool = False,
    abstracted_axes: Any | None = None,
) -> Callable[..., (jax.core.ClosedJaxpr | tuple[jax.core.ClosedJaxpr, Any])]:
    """Creates a function that produces its jaxpr given example args.

    Args:
      fun: The function whose ``jaxpr`` is to be computed. Its positional
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof.
      static_argnums: See the :py:func:`jax.jit` docstring. Positional arguments
        at these indices are closed over and are not abstracted or traced.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
        functions that involve parallel communication collectives, and it
        specifies the axis name/size environment that would be set up by
        applications of :py:func:`jax.pmap`.
      return_shape: Optional boolean, defaults to ``False``. If ``True``, the
        wrapped function returns a pair where the first element is the
        ``ClosedJaxpr`` representation of ``fun`` and the second element is a
        pytree with the same structure as the output of ``fun``. Its leaves are
        ``jax.ShapeDtypeStruct`` objects carrying the ``shape`` and ``dtype``
        of the corresponding output leaves.
      abstracted_axes: Optional pytree that is a tree prefix of the positional
        arguments and holds axes specs. When given, the input types are
        inferred with ``pe.infer_lambda_input_type`` from those specs. In that
        case keyword arguments are not supported and raise
        ``NotImplementedError``.

    Returns:
      A wrapped version of ``fun`` that, when applied to example arguments,
      returns a ``ClosedJaxpr`` representation of ``fun`` on those arguments.
      If the argument ``return_shape`` is ``True``, the returned function
      instead returns a pair where the first element is the ``ClosedJaxpr``
      representation of ``fun`` and the second element is a pytree describing
      the structure, shapes, and dtypes of the output of ``fun``.

    A ``jaxpr`` is JAX's intermediate representation for program traces. The
    ``jaxpr`` language is based on the simply-typed first-order lambda calculus
    with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
    ``jaxpr``, which we can inspect to understand what JAX is doing internally.
    The ``jaxpr`` returned is a trace of ``fun`` abstracted to
    :py:class:`ShapedArray` level. Other levels of abstraction exist internally.

    We do not describe the semantics of the ``jaxpr`` language in detail here, but
    instead give a few examples.

    >>> import jax
    >>>
    >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
    >>> print(f(3.0))
    -0.83602
    >>> _make_jaxpr(f)(3.0)
    { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
    >>> _make_jaxpr(jax.grad(f))(3.0)
    { lambda ; a:f32[]. let
        b:f32[] = cos a
        c:f32[] = sin a
        _:f32[] = sin b
        d:f32[] = cos b
        e:f32[] = mul 1.0 d
        f:f32[] = neg e
        g:f32[] = mul f c
      in (g,) }
    """
    _check_callable(fun)
    static_argnums = _ensure_index_tuple(static_argnums)

    def _abstractify(args, kwargs):
        # Flatten (args, kwargs). Return one abstract value per leaf, the input
        # treedef, and a mask of which inputs the jaxpr keeps as parameters.
        flat_args, in_tree = jax.tree.flatten((args, kwargs))
        if abstracted_axes is None:
            # Build a concrete list rather than a lazy ``map`` iterator: the
            # result is later passed to ``jax.util.safe_zip``, which is meant
            # to compare the lengths of its arguments.
            in_avals = [jax.api_util.shaped_abstractify(x) for x in flat_args]
            return in_avals, in_tree, [True] * len(flat_args)
        axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
        in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
        in_avals, keep_inputs = jax.util.unzip2(in_type)
        return in_avals, in_tree, keep_inputs

    @wraps(fun)
    @api_boundary
    def make_jaxpr_f(*args, **kwargs):
        f = wrap_init(fun)
        if static_argnums:
            # Close over the static positional arguments so that only the
            # dynamic ones are abstracted and traced.
            dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
            f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
        in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
        in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
        f, out_tree = _flatten_fun(f, in_tree)
        f = annotate(f, in_type)
        debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
        with ExitStack() as stack:
            # Install the requested named axes for the duration of tracing.
            for axis_name, size in axis_env or []:
                stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
            jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
        closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
        if return_shape:
            out_avals, _ = jax.util.unzip2(out_type)
            out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
            # out_tree is a thunk that is only populated once tracing finished.
            return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
        return closed_jaxpr

    make_jaxpr_f.__module__ = "brainstate.transform"
    if hasattr(fun, "__qualname__"):
        make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
    if hasattr(fun, "__name__"):
        make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
    return make_jaxpr_f