brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,314 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import functools
19
+ from collections.abc import Iterable, Sequence
20
+ from typing import (Any, Callable, Union)
21
+
22
+ import jax
23
+ from jax._src import sharding_impls
24
+ from jax.lib import xla_client as xc
25
+
26
+ from brainstate._utils import set_module_as
27
+ from brainstate.typing import Missing
28
+ from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
29
+ from ._util import write_back_state_values
30
+
31
+ __all__ = ['jit']
32
+
33
+
34
class JittedFunction(Callable):
    """
    A wrapped version of ``fun``, set up for just-in-time compilation.

    This class is used as a structural description of the callable returned by
    :func:`jit`: the attributes below are attached to a plain function object
    by ``_get_jitted_fun`` rather than by instantiating this class directly.
    """
    origin_fun: Callable  # the original function
    stateful_fun: StatefulFunction  # the stateful function for extracting states
    jitted_fun: jax.stages.Wrapped  # the jitted function
    clear_cache: Callable  # clear the cache of the jitted function
    eval_shape: Callable  # evaluate the shape of the jitted function
    lower: Callable  # lower the jitted function
    trace: Callable  # trace the jitted

    def __call__(self, *args, **kwargs):
        # Placeholder only: real call behavior comes from the wrapped function
        # produced in ``_get_jitted_fun``.
        pass
48
+
49
+
50
+ def _get_jitted_fun(
51
+ fun: Callable,
52
+ in_shardings,
53
+ out_shardings,
54
+ static_argnums,
55
+ donate_argnums,
56
+ donate_argnames,
57
+ keep_unused,
58
+ device,
59
+ backend,
60
+ inline,
61
+ abstracted_axes,
62
+ **kwargs
63
+ ) -> JittedFunction:
64
+ static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
65
+ # TODO: add to cache stack for clear_cache
66
+ fun = StatefulFunction(fun, static_argnums=static_argnums, abstracted_axes=abstracted_axes, cache_type='jit')
67
+ jit_fun = jax.jit(fun.jaxpr_call,
68
+ static_argnums=tuple(i + 1 for i in static_argnums),
69
+ donate_argnums=donate_argnums,
70
+ donate_argnames=donate_argnames,
71
+ keep_unused=keep_unused,
72
+ device=device,
73
+ backend=backend,
74
+ inline=inline,
75
+ in_shardings=in_shardings,
76
+ out_shardings=out_shardings,
77
+ abstracted_axes=abstracted_axes,
78
+ **kwargs)
79
+
80
+ @functools.wraps(fun.fun)
81
+ def jitted_fun(*args, **params):
82
+ if jax.config.jax_disable_jit:
83
+ return fun.fun(*args, **params)
84
+
85
+ # compile the function and get the state trace
86
+ with jax.ensure_compile_time_eval():
87
+ state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
88
+ read_state_vals = state_trace.get_read_state_values(True)
89
+ # call the jitted function
90
+ write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
91
+ # write the state values back to the states
92
+ write_back_state_values(state_trace, read_state_vals, write_state_vals)
93
+ return outs
94
+
95
+ def clear_cache():
96
+ """
97
+ Clear the cache of the jitted function.
98
+ """
99
+ # clear the cache of the stateful function
100
+ fun.clear_cache()
101
+ # clear the cache of the jitted function
102
+ jit_fun.clear_cache()
103
+
104
+ def eval_shape():
105
+ raise NotImplementedError
106
+
107
+ def lower():
108
+ """Lower this function explicitly for the given arguments.
109
+
110
+ A lowered function is staged out of Python and translated to a
111
+ compiler's input language, possibly in a backend-dependent
112
+ manner. It is ready for compilation but not yet compiled.
113
+
114
+ Returns:
115
+ A ``Lowered`` instance representing the lowering.
116
+ """
117
+ raise NotImplementedError
118
+
119
+ def trace():
120
+ """Trace this function explicitly for the given arguments.
121
+
122
+ A traced function is staged out of Python and translated to a jaxpr. It is
123
+ ready for lowering but not yet lowered.
124
+
125
+ Returns:
126
+ A ``Traced`` instance representing the tracing.
127
+ """
128
+ raise NotImplementedError
129
+
130
+ jitted_fun: JittedFunction
131
+
132
+ # the original function
133
+ jitted_fun.origin_fun = fun.fun
134
+
135
+ # the stateful function for extracting states
136
+ jitted_fun.stateful_fun = fun
137
+
138
+ # the jitted function
139
+ jitted_fun.jitted_fun = jit_fun
140
+
141
+ # clear cache
142
+ jitted_fun.clear_cache = clear_cache
143
+
144
+ # evaluate the shape of the jitted function
145
+ jitted_fun.eval_shape = eval_shape
146
+
147
+ # lower the jitted function
148
+ jitted_fun.lower = lower
149
+
150
+ # trace the jitted
151
+ jitted_fun.trace = trace
152
+
153
+ return jitted_fun
154
+
155
+
156
@set_module_as('brainstate.compile')
def jit(
    fun: Callable | Missing = Missing(),
    in_shardings=sharding_impls.UNSPECIFIED,
    out_shardings=sharding_impls.UNSPECIFIED,
    static_argnums: int | Sequence[int] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: xc.Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    abstracted_axes: Any | None = None,
    **kwargs
) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
    """
    Sets up ``fun`` for just-in-time compilation with XLA.

    This is a state-aware counterpart of :func:`jax.jit`: the wrapped function
    tracks the ``State`` objects it reads and writes, threads their values
    through the compiled computation, and writes the results back after the
    call. ``static_argnames`` (as in ``jax.jit``) is not supported.

    The function can be used either directly, ``jit(f)``, or as a decorator
    factory, ``@jit(static_argnums=...)``.

    Args:
        fun: Function to be jitted. When omitted, a decorator is returned that
            accepts the function later.
        in_shardings: Pytree of sharding specifications matching the structure
            of the arguments to ``fun`` (a pytree prefix is also accepted, in
            which case leaves are broadcast to the whole subtree). Optional;
            when unspecified, JAX infers shardings from the input arrays and
            replicates inputs whose sharding cannot be inferred. Valid leaves
            are an ``XLACompatibleSharding`` or ``None`` (JAX chooses freely).
        out_shardings: Like ``in_shardings``, but for the outputs of ``fun``.
            Optional; when unspecified, output shardings are determined by
            GSPMD's sharding propagation.
        static_argnums: Optional int or collection of ints selecting positional
            arguments to treat as static (compile-time constant). Static
            arguments must be hashable and immutable; calling with different
            static values triggers recompilation.
        donate_argnums: Optional int or collection of ints selecting positional
            argument buffers that are donated to the computation. Donated
            buffers must not be reused after the call; JAX raises an error if
            they are. By default no buffers are donated. See the JAX FAQ on
            buffer donation for details.
        donate_argnames: Optional string or collection of strings naming
            arguments to donate; see ``donate_argnums``.
        keep_unused: If ``False`` (default), arguments JAX determines to be
            unused may be dropped from the compiled executable; if ``True``,
            unused arguments are kept.
        device: Experimental. Optional device the jitted function will run on
            (see :py:func:`jax.devices`); defaults to XLA's device assignment.
        backend: Experimental. Optional XLA backend string: ``'cpu'``,
            ``'gpu'``, or ``'tpu'``.
        inline: Whether the function should be inlined into enclosing jaxprs.
            Default ``False``.
        abstracted_axes: Forwarded to :func:`jax.jit`.

    Returns:
        A :py:class:`JittedFunction` callable with the same signature as
        ``fun``, exposing ``stateful_fun``, ``origin_fun``, ``jitted_fun`` and
        ``clear_cache``; or, when ``fun`` is omitted, a decorator producing
        such an object.
    """
    # Collect every compilation option once so both call forms share it.
    options = dict(
        in_shardings=in_shardings,
        out_shardings=out_shardings,
        static_argnums=static_argnums,
        donate_argnums=donate_argnums,
        donate_argnames=donate_argnames,
        keep_unused=keep_unused,
        device=device,
        backend=backend,
        inline=inline,
        abstracted_axes=abstracted_axes,
        **kwargs,
    )

    if isinstance(fun, Missing):
        # Decorator-factory form: ``@jit(...)``.
        def wrapper(fun_again: Callable) -> JittedFunction:
            return _get_jitted_fun(fun_again, **options)

        return wrapper

    # Direct form: ``jit(f, ...)``.
    return _get_jitted_fun(fun, **options)
@@ -0,0 +1,143 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import unittest
19
+
20
+ import jax.numpy as jnp
21
+
22
+ import brainstate as bst
23
+
24
+
25
class TestJIT(unittest.TestCase):
    """Tests for ``brainstate.compile.jit``."""

    def test_inner_state_are_not_catched(self):
        # States created *inside* the jitted function (``b``) should not be
        # exposed as external states; only two states are recorded for each
        # cache key (presumably ``a`` plus the RNG state — TODO confirm).
        a = bst.State(bst.random.randn(10))

        @bst.compile.jit
        def fun1(inp):
            a.value += inp

            b = bst.State(bst.random.randn(1))

            def inner_fun(x):
                b.value += x

            bst.compile.for_loop(inner_fun, bst.random.randn(100))

            return a.value + b.value

        print(fun1(1.))
        key = fun1.stateful_fun.get_arg_cache_key(1.)
        self.assertTrue(len(fun1.stateful_fun.get_states(key)) == 2)

        # A different input shape creates a new cache entry with the same
        # number of tracked states.
        x = bst.random.randn(10)
        print(fun1(x))
        key = fun1.stateful_fun.get_arg_cache_key(x)
        self.assertTrue(len(fun1.stateful_fun.get_states(key)) == 2)

    def test_kwargs(self):
        # The jitted wrapper must accept arguments passed by keyword.
        a = bst.State(bst.random.randn(10))

        @bst.compile.jit
        def fun1(inp):
            a.value += inp

            b = bst.State(bst.random.randn(1))

            def inner_fun(x):
                b.value += x

            bst.compile.for_loop(inner_fun, bst.random.randn(100))

            return a.value + b.value

        # test kwargs
        print(fun1(inp=bst.random.randn(10)))

    def test_jit_compile_sensitive_to_input_shape(self):
        # Recompilation (tracked via ``global_data``) happens once per input
        # shape, not per call.
        global_data = [0]

        @bst.compile.jit
        def fun1(inp):
            global_data[0] += 1
            return inp

        print(fun1(1.))
        self.assertTrue(global_data[0] == 1)

        # Same shape: cached, no recompile.
        print(fun1(2.))
        self.assertTrue(global_data[0] == 1)

        # New shape: recompile.
        print(fun1(bst.random.randn(10)))
        self.assertTrue(global_data[0] == 2)

        print(fun1(bst.random.randn(10, 10)))
        self.assertTrue(global_data[0] == 3)

    def test_jit_clear_cache(self):
        # ``clear_cache()`` forces retracing on the next call, even for an
        # input that was previously compiled.
        a = bst.State(bst.random.randn(1))
        compiling = []

        @bst.compile.jit
        def log2(x):
            print('compiling')
            compiling.append(1)
            ln_x = jnp.log(x)
            ln_2 = jnp.log(2.0) + a.value
            return ln_x / ln_2

        x = bst.random.randn(1)
        print(log2(x))  # compiling
        self.assertTrue(len(compiling) == 1)
        print(log2(x))  # no compiling
        self.assertTrue(len(compiling) == 1)

        log2.clear_cache()
        print(log2(x))  # compiling
        self.assertTrue(len(compiling) == 2)

    def test_jit_attribute_origin_fun(self):
        # The returned wrapper exposes the attributes documented on
        # ``JittedFunction``.
        def fun1(x):
            return x

        jitted_fun = bst.compile.jit(fun1)
        self.assertTrue(jitted_fun.origin_fun is fun1)
        self.assertTrue(isinstance(jitted_fun.stateful_fun, bst.compile.StatefulFunction))
        self.assertTrue(callable(jitted_fun.jitted_fun))
        self.assertTrue(callable(jitted_fun.clear_cache))

    def test_clear_cache(self):
        # Smoke test: clearing the cache and calling again must not raise,
        # including for a function that only writes a state.
        a = bst.State(bst.random.randn(1))

        @bst.compile.jit
        def f_jit(x, y):
            print('Compiling')
            a.value = jnp.sin(x) + jnp.cos(y)

        f_jit(0.5, 1.0)
        f_jit.clear_cache()
        f_jit(0.5, 1.0)

    def test_cache(self):
        # Nested ``jit`` decorators should still produce a callable that
        # returns the correct value.
        @bst.compile.jit
        @bst.compile.jit
        @bst.compile.jit
        def f(a):
            print('Compiling')
            print(a)
            return a + 1

        print(f(1.))