brainstate-0.1.8-py2.py3-none-any.whl → brainstate-0.1.9-py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
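A diff like this can also be reproduced locally from the published wheels. Below is a minimal sketch using only the Python standard library; it assumes both wheels have already been downloaded into the working directory (for example with `pip download brainstate==0.1.8 --no-deps`), and it diffs the one file shown in full later on this page.

    # Minimal sketch: diff one file between the two published wheels.
    # Assumes the wheels are already present in the current directory.
    import difflib
    import zipfile

    OLD = "brainstate-0.1.8-py2.py3-none-any.whl"
    NEW = "brainstate-0.1.9-py2.py3-none-any.whl"
    MEMBER = "brainstate/compile/_jit.py"  # the file diffed below

    with zipfile.ZipFile(OLD) as old_whl, zipfile.ZipFile(NEW) as new_whl:
        old_src = old_whl.read(MEMBER).decode("utf-8").splitlines(keepends=True)
        new_src = new_whl.read(MEMBER).decode("utf-8").splitlines(keepends=True)
        print("".join(difflib.unified_diff(old_src, new_src, f"a/{MEMBER}", f"b/{MEMBER}")))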
Files changed (133)
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
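Note the `{functional → nn}` renames above (items 50, 51, and 82): the activation and utility functions moved from `brainstate.functional` into `brainstate.nn` in 0.1.9, while the remaining `brainstate/functional` modules `_normalization.py` and `_spikes.py` were deleted (+0 lines). A hedged sketch of what the move implies for imports follows; `relu` is an assumed example name from the relocated `_activations.py`, not verified against the release.

    # Hedged sketch of the module move implied by the file list above.
    # `relu` is an assumed example from the relocated _activations.py.
    import jax.numpy as jnp
    from brainstate.nn import relu  # 0.1.9 location; previously brainstate.functional.relu

    print(relu(jnp.array([-1.0, 0.0, 2.0])))  # [0. 0. 2.]

The per-file diff below covers `brainstate/compile/_jit.py` (item 21 above).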
brainstate/compile/_jit.py
@@ -1,346 +1,346 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import functools
- from collections.abc import Iterable, Sequence
- from typing import (Any, Callable, Union)
-
- import jax
- from jax._src import sharding_impls
-
- from brainstate._compatible_import import Device
- from brainstate._utils import set_module_as
- from brainstate.typing import Missing
- from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
- from ._util import write_back_state_values
-
- __all__ = ['jit']
-
-
- class JittedFunction(Callable):
-     """
-     A wrapped version of ``fun``, set up for just-in-time compilation.
-     """
-     origin_fun: Callable  # the original function
-     stateful_fun: StatefulFunction  # the stateful function for extracting states
-     jitted_fun: jax.stages.Wrapped  # the jitted function
-     clear_cache: Callable  # clear the cache of the jitted function
-     eval_shape: Callable  # evaluate the shape of the jitted function
-     compile: Callable  # lower the jitted function
-     trace: Callable  # trace the jitted
-
-     def __call__(self, *args, **kwargs):
-         pass
-
-
- def _get_jitted_fun(
-     fun: Callable,
-     in_shardings,
-     out_shardings,
-     static_argnums,
-     donate_argnums,
-     static_argnames,
-     donate_argnames,
-     keep_unused,
-     device,
-     backend,
-     inline,
-     abstracted_axes,
-     **kwargs
- ) -> JittedFunction:
-     static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
-     donate_argnums = tuple() if donate_argnums is None else _ensure_index_tuple(donate_argnums)
-     fun = StatefulFunction(
-         fun,
-         static_argnums=static_argnums,
-         static_argnames=static_argnames,
-         abstracted_axes=abstracted_axes,
-         cache_type='jit',
-         name='jit'
-     )
-     jit_fun = jax.jit(
-         fun.jaxpr_call,
-         static_argnums=tuple(i + 1 for i in static_argnums),
-         static_argnames=static_argnames,
-         donate_argnums=tuple(i + 1 for i in donate_argnums),
-         donate_argnames=donate_argnames,
-         keep_unused=keep_unused,
-         device=device,
-         backend=backend,
-         inline=inline,
-         in_shardings=in_shardings,
-         out_shardings=out_shardings,
-         abstracted_axes=abstracted_axes,
-         **kwargs
-     )
-
-     @functools.wraps(fun.fun)
-     def jitted_fun(*args, **params):
-         if jax.config.jax_disable_jit:
-             return fun.fun(*args, **params)
-
-         # compile the function and get the state trace
-         state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
-         read_state_vals = state_trace.get_read_state_values(True)
-
-         # call the jitted function
-         write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
-
-         # write the state values back to the states
-         write_back_state_values(state_trace, read_state_vals, write_state_vals)
-         return outs
-
-     def clear_cache():
-         """
-         Clear the cache of the jitted function.
-         """
-         # clear the cache of the stateful function
-         fun.clear_cache()
-         try:
-             # clear the cache of the jitted function
-             jit_fun.clear_cache()
-         except AttributeError:
-             pass
-
-     def eval_shape():
-         raise NotImplementedError
-
-     def trace():
-         """Trace this function explicitly for the given arguments.
-
-         A traced function is staged out of Python and translated to a jaxpr. It is
-         ready for lowering but not yet lowered.
-
-         Returns:
-           A ``Traced`` instance representing the tracing.
-         """
-         raise NotImplementedError
-
-     def compile(*args, **params):
-         """Lower this function explicitly for the given arguments.
-
-         A lowered function is staged out of Python and translated to a
-         compiler's input language, possibly in a backend-dependent
-         manner. It is ready for compilation but not yet compiled.
-
-         Returns:
-           A ``Lowered`` instance representing the lowering.
-         """
-         # compile the function and get the state trace
-         state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
-         read_state_vals = state_trace.get_read_state_values(replace_writen=True)
-         write_state_vals = state_trace.get_write_state_values(replace_read=True)
-
-         # compile the model
-         ret = jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()
-
-         # write the state values back to the states
-         write_back_state_values(state_trace, read_state_vals, write_state_vals)
-         return ret
-
-     jitted_fun: JittedFunction
-
-     # the original function
-     jitted_fun.origin_fun = fun.fun
-
-     # the stateful function for extracting states
-     jitted_fun.stateful_fun = fun
-
-     # the jitted function
-     jitted_fun.jitted_fun = jit_fun
-
-     # clear cache
-     jitted_fun.clear_cache = clear_cache
-
-     # evaluate the shape of the jitted function
-     jitted_fun.eval_shape = eval_shape
-
-     # compile the jitted function
-     jitted_fun.compile = compile
-
-     # trace the jitted function
-     jitted_fun.trace = trace
-
-     return jitted_fun
-
-
- @set_module_as('brainstate.compile')
- def jit(
-     fun: Callable | Missing = Missing(),
-     in_shardings=sharding_impls.UNSPECIFIED,
-     out_shardings=sharding_impls.UNSPECIFIED,
-     static_argnums: int | Sequence[int] | None = None,
-     donate_argnums: int | Sequence[int] | None = None,
-     static_argnames: str | Sequence[str] | None = None,
-     donate_argnames: str | Iterable[str] | None = None,
-     keep_unused: bool = False,
-     device: Device | None = None,
-     backend: str | None = None,
-     inline: bool = False,
-     abstracted_axes: Any | None = None,
-     **kwargs
- ) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
-     """
-     Sets up ``fun`` for just-in-time compilation with XLA.
-
-     Args:
-       fun: Function to be jitted.
-       in_shardings: Pytree of structure matching that of arguments to ``fun``,
-         with all actual arguments replaced by resource assignment specifications.
-         It is also valid to specify a pytree prefix (e.g. one value in place of a
-         whole subtree), in which case the leaves get broadcast to all values in
-         that subtree.
-
-         The ``in_shardings`` argument is optional. JAX will infer the shardings
-         from the input :py:class:`jax.Array`'s and defaults to replicating the input
-         if the sharding cannot be inferred.
-
-         The valid resource assignment specifications are:
-           - :py:class:`XLACompatibleSharding`, which will decide how the value
-             will be partitioned. With this, using a mesh context manager is not
-             required.
-           - :py:obj:`None`, will give JAX the freedom to choose whatever sharding
-             it wants.
-             For in_shardings, JAX will mark is as replicated but this behavior
-             can change in the future.
-             For out_shardings, we will rely on the XLA GSPMD partitioner to
-             determine the output shardings.
-
-         The size of every dimension has to be a multiple of the total number of
-         resources assigned to it. This is similar to pjit's in_shardings.
-       out_shardings: Like ``in_shardings``, but specifies resource
-         assignment for function outputs. This is similar to pjit's
-         out_shardings.
-
-         The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
-         will use GSPMD's sharding propagation to figure out what the sharding of the
-         output(s) should be.
-       static_argnums: An optional int or collection of ints that specify which
-         positional arguments to treat as static (compile-time constant).
-         Operations that only depend on static arguments will be constant-folded in
-         Python (during tracing), and so the corresponding argument values can be
-         any Python object.
-
-         Static arguments should be hashable, meaning both ``__hash__`` and
-         ``__eq__`` are implemented, and immutable. Calling the jitted function
-         with different values for these constants will trigger recompilation.
-         Arguments that are not arrays or containers thereof must be marked as
-         static.
-
-         If neither ``static_argnums`` nor ``static_argnames`` is provided, no
-         arguments are treated as static. If ``static_argnums`` is not provided but
-         ``static_argnames`` is, or vice versa, JAX uses
-         :code:`inspect.signature(fun)` to find any positional arguments that
-         correspond to ``static_argnames``
-         (or vice versa). If both ``static_argnums`` and ``static_argnames`` are
-         provided, ``inspect.signature`` is not used, and only actual
-         parameters listed in either ``static_argnums`` or ``static_argnames`` will
-         be treated as static.
-       static_argnames: An optional string or collection of strings specifying
-         which named arguments are treated as static (compile-time constant).
-         Operations that only depend on static arguments will be constant-folded in
-         Python (during tracing), and so the corresponding argument values can be
-         any Python object.
-       donate_argnums: Specify which positional argument buffers are "donated" to
-         the computation. It is safe to donate argument buffers if you no longer
-         need them once the computation has finished. In some cases XLA can make
-         use of donated buffers to reduce the amount of memory needed to perform a
-         computation, for example recycling one of your input buffers to store a
-         result. You should not reuse buffers that you donate to a computation, JAX
-         will raise an error if you try to. By default, no argument buffers are
-         donated.
-
-         If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
-         arguments are donated. If ``donate_argnums`` is not provided but
-         ``donate_argnames`` is, or vice versa, JAX uses
-         :code:`inspect.signature(fun)` to find any positional arguments that
-         correspond to ``donate_argnames``
-         (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
-         provided, ``inspect.signature`` is not used, and only actual
-         parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
-         be donated.
-
-         For more details on buffer donation see the
-         `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
-       donate_argnames: An optional string or collection of strings specifying
-         which named arguments are donated to the computation. See the
-         comment on ``donate_argnums`` for details. If not
-         provided but ``donate_argnums`` is set, the default is based on calling
-         ``inspect.signature(fun)`` to find corresponding named arguments.
-       keep_unused: If `False` (the default), arguments that JAX determines to be
-         unused by `fun` *may* be dropped from resulting compiled XLA executables.
-         Such arguments will not be transferred to the device nor provided to the
-         underlying executable. If `True`, unused arguments will not be pruned.
-       device: This is an experimental feature and the API is likely to change.
-         Optional, the Device the jitted function will run on. (Available devices
-         can be retrieved via :py:func:`jax.devices`.) The default is inherited
-         from XLA's DeviceAssignment logic and is usually to use
-         ``jax.devices()[0]``.
-       backend: This is an experimental feature and the API is likely to change.
-         Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
-         ``'tpu'``.
-       inline: Specify whether this function should be inlined into enclosing
-         jaxprs (rather than being represented as an application of the xla_call
-         primitive with its own subjaxpr). Default False.
-       abstracted_axes:
-
-     Returns:
-       A wrapped version of ``fun``, set up for just-in-time compilation.
-       The returned object is a :py:class:`JittedFunction` that can be called with the same arguments
-       and has the following attributes and methods:
-
-       - ``stateful_fun``: the stateful function for extracting states, an instance of :py:class:`StatefulFunction`.
-       - ``origin_fun(*args, **kwargs)``: the original function
-       - ``jitted_fun(*args, **kwargs)``: the jitted function
-       - ``clear_cache(*args, **kwargs)``: clear the cache of the jitted function
-
-     """
-
-     if isinstance(fun, Missing):
-         def wrapper(fun_again: Callable) -> JittedFunction:
-             return _get_jitted_fun(
-                 fun_again,
-                 in_shardings=in_shardings,
-                 out_shardings=out_shardings,
-                 static_argnums=static_argnums,
-                 donate_argnums=donate_argnums,
-                 static_argnames=static_argnames,
-                 donate_argnames=donate_argnames,
-                 keep_unused=keep_unused,
-                 device=device,
-                 backend=backend,
-                 inline=inline,
-                 abstracted_axes=abstracted_axes,
-                 **kwargs
-             )
-
-         return wrapper
-
-     else:
-         return _get_jitted_fun(
-             fun,
-             in_shardings,
-             out_shardings,
-             static_argnums,
-             donate_argnums,
-             static_argnames,
-             donate_argnames,
-             keep_unused,
-             device,
-             backend,
-             inline,
-             abstracted_axes,
-             **kwargs
-         )
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import functools
+ from collections.abc import Iterable, Sequence
+ from typing import (Any, Callable, Union)
+
+ import jax
+ from jax._src import sharding_impls
+
+ from brainstate._compatible_import import Device
+ from brainstate._utils import set_module_as
+ from brainstate.typing import Missing
+ from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
+ from ._util import write_back_state_values
+
+ __all__ = ['jit']
+
+
+ class JittedFunction(Callable):
+     """
+     A wrapped version of ``fun``, set up for just-in-time compilation.
+     """
+     origin_fun: Callable  # the original function
+     stateful_fun: StatefulFunction  # the stateful function for extracting states
+     jitted_fun: jax.stages.Wrapped  # the jitted function
+     clear_cache: Callable  # clear the cache of the jitted function
+     eval_shape: Callable  # evaluate the shape of the jitted function
+     compile: Callable  # lower the jitted function
+     trace: Callable  # trace the jitted
+
+     def __call__(self, *args, **kwargs):
+         pass
+
+
+ def _get_jitted_fun(
+     fun: Callable,
+     in_shardings,
+     out_shardings,
+     static_argnums,
+     donate_argnums,
+     static_argnames,
+     donate_argnames,
+     keep_unused,
+     device,
+     backend,
+     inline,
+     abstracted_axes,
+     **kwargs
+ ) -> JittedFunction:
+     static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
+     donate_argnums = tuple() if donate_argnums is None else _ensure_index_tuple(donate_argnums)
+     fun = StatefulFunction(
+         fun,
+         static_argnums=static_argnums,
+         static_argnames=static_argnames,
+         abstracted_axes=abstracted_axes,
+         cache_type='jit',
+         name='jit'
+     )
+     jit_fun = jax.jit(
+         fun.jaxpr_call,
+         static_argnums=tuple(i + 1 for i in static_argnums),
+         static_argnames=static_argnames,
+         donate_argnums=tuple(i + 1 for i in donate_argnums),
+         donate_argnames=donate_argnames,
+         keep_unused=keep_unused,
+         device=device,
+         backend=backend,
+         inline=inline,
+         in_shardings=in_shardings,
+         out_shardings=out_shardings,
+         abstracted_axes=abstracted_axes,
+         **kwargs
+     )
+
+     @functools.wraps(fun.fun)
+     def jitted_fun(*args, **params):
+         if jax.config.jax_disable_jit:
+             return fun.fun(*args, **params)
+
+         # compile the function and get the state trace
+         state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
+         read_state_vals = state_trace.get_read_state_values(True)
+
+         # call the jitted function
+         write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
+
+         # write the state values back to the states
+         write_back_state_values(state_trace, read_state_vals, write_state_vals)
+         return outs
+
+     def clear_cache():
+         """
+         Clear the cache of the jitted function.
+         """
+         # clear the cache of the stateful function
+         fun.clear_cache()
+         try:
+             # clear the cache of the jitted function
+             jit_fun.clear_cache()
+         except AttributeError:
+             pass
+
+     def eval_shape():
+         raise NotImplementedError
+
+     def trace():
+         """Trace this function explicitly for the given arguments.
+
+         A traced function is staged out of Python and translated to a jaxpr. It is
+         ready for lowering but not yet lowered.
+
+         Returns:
+           A ``Traced`` instance representing the tracing.
+         """
+         raise NotImplementedError
+
+     def compile(*args, **params):
+         """Lower this function explicitly for the given arguments.
+
+         A lowered function is staged out of Python and translated to a
+         compiler's input language, possibly in a backend-dependent
+         manner. It is ready for compilation but not yet compiled.
+
+         Returns:
+           A ``Lowered`` instance representing the lowering.
+         """
+         # compile the function and get the state trace
+         state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
+         read_state_vals = state_trace.get_read_state_values(replace_writen=True)
+         write_state_vals = state_trace.get_write_state_values(replace_read=True)
+
+         # compile the model
+         ret = jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()
+
+         # write the state values back to the states
+         write_back_state_values(state_trace, read_state_vals, write_state_vals)
+         return ret
+
+     jitted_fun: JittedFunction
+
+     # the original function
+     jitted_fun.origin_fun = fun.fun
+
+     # the stateful function for extracting states
+     jitted_fun.stateful_fun = fun
+
+     # the jitted function
+     jitted_fun.jitted_fun = jit_fun
+
+     # clear cache
+     jitted_fun.clear_cache = clear_cache
+
+     # evaluate the shape of the jitted function
+     jitted_fun.eval_shape = eval_shape
+
+     # compile the jitted function
+     jitted_fun.compile = compile
+
+     # trace the jitted function
+     jitted_fun.trace = trace
+
+     return jitted_fun
+
+
+ @set_module_as('brainstate.compile')
+ def jit(
+     fun: Callable | Missing = Missing(),
+     in_shardings=sharding_impls.UNSPECIFIED,
+     out_shardings=sharding_impls.UNSPECIFIED,
+     static_argnums: int | Sequence[int] | None = None,
+     donate_argnums: int | Sequence[int] | None = None,
+     static_argnames: str | Sequence[str] | None = None,
+     donate_argnames: str | Iterable[str] | None = None,
+     keep_unused: bool = False,
+     device: Device | None = None,
+     backend: str | None = None,
+     inline: bool = False,
+     abstracted_axes: Any | None = None,
+     **kwargs
+ ) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
+     """
+     Sets up ``fun`` for just-in-time compilation with XLA.
+
+     Args:
+       fun: Function to be jitted.
+       in_shardings: Pytree of structure matching that of arguments to ``fun``,
+         with all actual arguments replaced by resource assignment specifications.
+         It is also valid to specify a pytree prefix (e.g. one value in place of a
+         whole subtree), in which case the leaves get broadcast to all values in
+         that subtree.
+
+         The ``in_shardings`` argument is optional. JAX will infer the shardings
+         from the input :py:class:`jax.Array`'s and defaults to replicating the input
+         if the sharding cannot be inferred.
+
+         The valid resource assignment specifications are:
+           - :py:class:`XLACompatibleSharding`, which will decide how the value
+             will be partitioned. With this, using a mesh context manager is not
+             required.
+           - :py:obj:`None`, will give JAX the freedom to choose whatever sharding
+             it wants.
+             For in_shardings, JAX will mark is as replicated but this behavior
+             can change in the future.
+             For out_shardings, we will rely on the XLA GSPMD partitioner to
+             determine the output shardings.
+
+         The size of every dimension has to be a multiple of the total number of
+         resources assigned to it. This is similar to pjit's in_shardings.
+       out_shardings: Like ``in_shardings``, but specifies resource
+         assignment for function outputs. This is similar to pjit's
+         out_shardings.
+
+         The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
+         will use GSPMD's sharding propagation to figure out what the sharding of the
+         output(s) should be.
+       static_argnums: An optional int or collection of ints that specify which
+         positional arguments to treat as static (compile-time constant).
+         Operations that only depend on static arguments will be constant-folded in
+         Python (during tracing), and so the corresponding argument values can be
+         any Python object.
+
+         Static arguments should be hashable, meaning both ``__hash__`` and
+         ``__eq__`` are implemented, and immutable. Calling the jitted function
+         with different values for these constants will trigger recompilation.
+         Arguments that are not arrays or containers thereof must be marked as
+         static.
+
+         If neither ``static_argnums`` nor ``static_argnames`` is provided, no
+         arguments are treated as static. If ``static_argnums`` is not provided but
+         ``static_argnames`` is, or vice versa, JAX uses
+         :code:`inspect.signature(fun)` to find any positional arguments that
+         correspond to ``static_argnames``
+         (or vice versa). If both ``static_argnums`` and ``static_argnames`` are
+         provided, ``inspect.signature`` is not used, and only actual
+         parameters listed in either ``static_argnums`` or ``static_argnames`` will
+         be treated as static.
+       static_argnames: An optional string or collection of strings specifying
+         which named arguments are treated as static (compile-time constant).
+         Operations that only depend on static arguments will be constant-folded in
+         Python (during tracing), and so the corresponding argument values can be
+         any Python object.
+       donate_argnums: Specify which positional argument buffers are "donated" to
+         the computation. It is safe to donate argument buffers if you no longer
+         need them once the computation has finished. In some cases XLA can make
+         use of donated buffers to reduce the amount of memory needed to perform a
+         computation, for example recycling one of your input buffers to store a
+         result. You should not reuse buffers that you donate to a computation, JAX
+         will raise an error if you try to. By default, no argument buffers are
+         donated.
+
+         If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
+         arguments are donated. If ``donate_argnums`` is not provided but
+         ``donate_argnames`` is, or vice versa, JAX uses
+         :code:`inspect.signature(fun)` to find any positional arguments that
+         correspond to ``donate_argnames``
+         (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
+         provided, ``inspect.signature`` is not used, and only actual
+         parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
+         be donated.
+
+         For more details on buffer donation see the
+         `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
+       donate_argnames: An optional string or collection of strings specifying
+         which named arguments are donated to the computation. See the
+         comment on ``donate_argnums`` for details. If not
+         provided but ``donate_argnums`` is set, the default is based on calling
+         ``inspect.signature(fun)`` to find corresponding named arguments.
+       keep_unused: If `False` (the default), arguments that JAX determines to be
+         unused by `fun` *may* be dropped from resulting compiled XLA executables.
+         Such arguments will not be transferred to the device nor provided to the
+         underlying executable. If `True`, unused arguments will not be pruned.
+       device: This is an experimental feature and the API is likely to change.
+         Optional, the Device the jitted function will run on. (Available devices
+         can be retrieved via :py:func:`jax.devices`.) The default is inherited
+         from XLA's DeviceAssignment logic and is usually to use
+         ``jax.devices()[0]``.
+       backend: This is an experimental feature and the API is likely to change.
+         Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
+         ``'tpu'``.
+       inline: Specify whether this function should be inlined into enclosing
+         jaxprs (rather than being represented as an application of the xla_call
+         primitive with its own subjaxpr). Default False.
+       abstracted_axes:
+
+     Returns:
+       A wrapped version of ``fun``, set up for just-in-time compilation.
+       The returned object is a :py:class:`JittedFunction` that can be called with the same arguments
+       and has the following attributes and methods:
+
+       - ``stateful_fun``: the stateful function for extracting states, an instance of :py:class:`StatefulFunction`.
+       - ``origin_fun(*args, **kwargs)``: the original function
+       - ``jitted_fun(*args, **kwargs)``: the jitted function
+       - ``clear_cache(*args, **kwargs)``: clear the cache of the jitted function
+
+     """
+
+     if isinstance(fun, Missing):
+         def wrapper(fun_again: Callable) -> JittedFunction:
+             return _get_jitted_fun(
+                 fun_again,
+                 in_shardings=in_shardings,
+                 out_shardings=out_shardings,
+                 static_argnums=static_argnums,
+                 donate_argnums=donate_argnums,
+                 static_argnames=static_argnames,
+                 donate_argnames=donate_argnames,
+                 keep_unused=keep_unused,
+                 device=device,
+                 backend=backend,
+                 inline=inline,
+                 abstracted_axes=abstracted_axes,
+                 **kwargs
+             )
+
+         return wrapper
+
+     else:
+         return _get_jitted_fun(
+             fun,
+             in_shardings,
+             out_shardings,
+             static_argnums,
+             donate_argnums,
+             static_argnames,
+             donate_argnames,
+             keep_unused,
+             device,
+             backend,
+             inline,
+             abstracted_axes,
+             **kwargs
+         )
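For context, here is a minimal usage sketch of the `jit` wrapper defined in this file. Only the `brainstate.compile.jit` entry point and the `JittedFunction.clear_cache` attribute come from the code above; the counter `State` and the `step` function are illustrative assumptions.

    # Minimal sketch of brainstate.compile.jit from the file above.
    # The counter State and the step function are illustrative assumptions.
    import brainstate
    import jax.numpy as jnp

    counter = brainstate.State(jnp.array(0.0))  # state read and written inside `step`

    @brainstate.compile.jit  # bare decorator form; jit(**kwargs) instead returns a decorator via `wrapper`
    def step(x):
        counter.value = counter.value + 1.0  # traced as a state write; written back after each call
        return x * counter.value

    print(step(jnp.array(2.0)))  # first call traces via StatefulFunction, then runs the jitted jaxpr
    step.clear_cache()           # JittedFunction attribute attached in _get_jitted_fun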