brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -1,265 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from __future__ import annotations
17
-
18
- import functools
19
- from collections.abc import Iterable, Sequence
20
- from typing import (Any, Callable, Union)
21
-
22
- import jax
23
- from jax._src import sharding_impls
24
- from jax.lib import xla_client as xc
25
-
26
- from brainstate._utils import set_module_as
27
- from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
28
-
29
# Names exported by ``from ... import *``: only ``jit`` is public in this module.
__all__ = ['jit']
30
-
31
-
32
class JittedFunction(Callable):
    """
    Typing description of the callable produced by :func:`jit`.

    ``_get_jitted_fun`` returns a plain Python function and attaches the
    attributes listed below to it; this class only exists so that the return
    value can be annotated. Calling it directly does nothing.
    """

    # The user function that was handed to ``jit``.
    origin_fun: Callable
    # The ``StatefulFunction`` wrapper used to extract the states of ``origin_fun``.
    stateful_fun: StatefulFunction
    # The ``jax.jit``-compiled callable.
    jitted_fun: jax.stages.Wrapped
    # Callable that clears the cache(s) of the jitted function.
    clear_cache: Callable

    def __call__(self, *args, **kwargs):
        """Placeholder call signature; the body intentionally does nothing."""
43
-
44
-
45
def _shift_argnums(argnums):
    """
    Re-index user-facing positional argument numbers for ``jaxpr_call``.

    The compiled ``StatefulFunction.jaxpr_call`` receives the list of state
    values as an extra leading positional argument. Every non-negative user
    index ``i`` therefore becomes ``i + 1``. Negative indices count from the end
    of the argument list, so they are unaffected by the extra leading argument
    and are kept as-is.

    Args:
      argnums: ``None``, a single integer, or an iterable of integers.

    Returns:
      ``None`` when ``argnums`` is ``None``; otherwise a tuple of re-indexed ints.

    Raises:
      TypeError: if ``argnums`` (or one of its items) is not an integer.
    """
    if argnums is None:
        return None
    import operator
    try:
        indices = (operator.index(argnums),)
    except TypeError:
        indices = tuple(operator.index(i) for i in argnums)
    return tuple(i + 1 if i >= 0 else i for i in indices)


def _get_jitted_fun(
    fun: Callable,
    in_shardings,
    out_shardings,
    static_argnums,
    donate_argnums,
    donate_argnames,
    keep_unused,
    device,
    backend,
    inline,
    abstracted_axes,
    **kwargs
) -> JittedFunction:
    """
    Build the state-aware jitted callable returned by :func:`jit`.

    ``fun`` is wrapped in a :class:`StatefulFunction`, and its ``jaxpr_call``
    method is compiled with :func:`jax.jit`. ``jaxpr_call`` takes the list of
    state values as its first positional argument and returns
    ``(state_vals, outs)``.

    Args:
      fun: The user function.
      in_shardings: Forwarded unchanged to :func:`jax.jit`.
      out_shardings: Forwarded unchanged to :func:`jax.jit`.
      static_argnums: ``None``, an int or a sequence of ints indexing the
        positional arguments of ``fun``. It is passed as-is to
        ``StatefulFunction``. For ``jax.jit`` it is re-indexed past the leading
        state-value argument.
      donate_argnums: ``None``, an int or a sequence of ints indexing the
        positional arguments of ``fun``. It is re-indexed the same way as
        ``static_argnums`` before reaching :func:`jax.jit`.
      donate_argnames: Forwarded unchanged to :func:`jax.jit`.
      keep_unused: Forwarded unchanged to :func:`jax.jit`.
      device: Forwarded unchanged to :func:`jax.jit`.
      backend: Forwarded unchanged to :func:`jax.jit`.
      inline: Forwarded unchanged to :func:`jax.jit`.
      abstracted_axes: Forwarded to both ``StatefulFunction`` and :func:`jax.jit`.
      **kwargs: Extra keyword arguments for :func:`jax.jit`.

    Returns:
      A plain function with the signature of ``fun``. It has the attributes
      ``origin_fun``, ``stateful_fun``, ``jitted_fun`` and ``clear_cache``
      attached (see :class:`JittedFunction`).
    """
    static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
    # TODO: add to cache stack for clear_cache
    stateful = StatefulFunction(fun, static_argnums=static_argnums, abstracted_axes=abstracted_axes, cache_type='jit')
    # NOTE(review): in_shardings/out_shardings are forwarded unchanged, but
    # ``jaxpr_call`` takes the state-value list as an extra leading argument
    # and returns ``(state_vals, outs)``. A per-argument sharding tuple written
    # against ``fun``'s signature may therefore be misaligned -- confirm.
    jit_fun = jax.jit(stateful.jaxpr_call,
                      # User indices are shifted past the leading state-value argument.
                      static_argnums=_shift_argnums(static_argnums),
                      donate_argnums=_shift_argnums(donate_argnums),
                      donate_argnames=donate_argnames,
                      keep_unused=keep_unused,
                      device=device,
                      backend=backend,
                      inline=inline,
                      in_shardings=in_shardings,
                      out_shardings=out_shardings,
                      abstracted_axes=abstracted_axes,
                      **kwargs)

    @functools.wraps(stateful.fun)
    def jitted_fun(*args, **params):
        # When JIT is globally disabled, run the original function eagerly.
        if jax.config.jax_disable_jit:
            return stateful.fun(*args, **params)
        states = stateful.compile_and_get_states_by_static_args(*args, **params)
        state_vals, outs = jit_fun([st.value for st in states], *args, **params)
        # Hand the state values produced by the compiled call back to ``states``.
        _assign_state_values(states, state_vals)
        return outs

    def clear_cache():
        # clear the cache of the stateful function
        stateful.clear_cache()
        # clear the cache of the jitted function
        jit_fun.clear_cache()

    jitted_fun: JittedFunction

    # the original function
    jitted_fun.origin_fun = stateful.fun

    # the stateful function for extracting states
    jitted_fun.stateful_fun = stateful

    # the jitted function
    jitted_fun.jitted_fun = jit_fun

    # clear cache
    jitted_fun.clear_cache = clear_cache

    return jitted_fun
105
-
106
-
107
@set_module_as('brainstate.transform')
def jit(
    fun: Callable | None = None,
    in_shardings=sharding_impls.UNSPECIFIED,
    out_shardings=sharding_impls.UNSPECIFIED,
    static_argnums: int | Sequence[int] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: xc.Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    abstracted_axes: Any | None = None,
    **kwargs
) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
    """
    Sets up ``fun`` for just-in-time compilation with XLA.

    ``fun`` may use ``State`` objects. On each call, the states it uses are
    collected through a :py:class:`StatefulFunction`, and their current values
    are passed into the compiled computation. The values returned by that
    computation are then assigned back to those states. When JAX's
    ``jax_disable_jit`` flag is set, the returned function simply calls ``fun``
    eagerly.

    ``jit`` can be applied directly (``jit(f)``) or used as a decorator factory
    (``@jit(static_argnums=0)``). In the latter case ``fun`` is ``None`` and a
    decorator is returned.

    Unlike :py:func:`jax.jit`, ``static_argnames`` is not supported. Static
    arguments can only be selected by position through ``static_argnums``.

    Args:
      fun: Function to be jitted. If ``None``, a decorator is returned that
        applies :func:`jit` with the remaining options.
      in_shardings: Pytree of structure matching that of arguments to ``fun``,
        with all actual arguments replaced by resource assignment specifications.
        It is also valid to specify a pytree prefix (e.g. one value in place of a
        whole subtree), in which case the leaves get broadcast to all values in
        that subtree.

        The ``in_shardings`` argument is optional. JAX will infer the shardings
        from the input :py:class:`jax.Array`'s and defaults to replicating the input
        if the sharding cannot be inferred.

        The valid resource assignment specifications are:
          - :py:class:`XLACompatibleSharding`, which will decide how the value
            will be partitioned. With this, using a mesh context manager is not
            required.
          - :py:obj:`None`, will give JAX the freedom to choose whatever sharding
            it wants.
            For in_shardings, JAX will mark it as replicated but this behavior
            can change in the future.
            For out_shardings, we will rely on the XLA GSPMD partitioner to
            determine the output shardings.

        The size of every dimension has to be a multiple of the total number of
        resources assigned to it. This is similar to pjit's in_shardings.
      out_shardings: Like ``in_shardings``, but specifies resource
        assignment for function outputs. This is similar to pjit's
        out_shardings.

        The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
        will use GSPMD's sharding propagation to figure out what the sharding of the
        output(s) should be.
      static_argnums: An optional int or collection of ints that specify which
        positional arguments to treat as static (compile-time constant).
        Operations that only depend on static arguments will be constant-folded in
        Python (during tracing), and so the corresponding argument values can be
        any Python object.

        Static arguments should be hashable, meaning both ``__hash__`` and
        ``__eq__`` are implemented, and immutable. Calling the jitted function
        with different values for these constants will trigger recompilation.
        Arguments that are not arrays or containers thereof must be marked as
        static.

        If ``static_argnums`` is not provided, no arguments are treated as static.
      donate_argnums: Specify which positional argument buffers are "donated" to
        the computation. It is safe to donate argument buffers if you no longer
        need them once the computation has finished. In some cases XLA can make
        use of donated buffers to reduce the amount of memory needed to perform a
        computation, for example recycling one of your input buffers to store a
        result. You should not reuse buffers that you donate to a computation, JAX
        will raise an error if you try to. By default, no argument buffers are
        donated.

        If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
        arguments are donated. If ``donate_argnums`` is not provided but
        ``donate_argnames`` is, or vice versa, JAX uses
        :code:`inspect.signature(fun)` to find any positional arguments that
        correspond to ``donate_argnames``
        (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
        provided, ``inspect.signature`` is not used, and only actual
        parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
        be donated.

        For more details on buffer donation see the
        `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
      donate_argnames: An optional string or collection of strings specifying
        which named arguments are donated to the computation. See the
        comment on ``donate_argnums`` for details. If not
        provided but ``donate_argnums`` is set, the default is based on calling
        ``inspect.signature(fun)`` to find corresponding named arguments.
      keep_unused: If `False` (the default), arguments that JAX determines to be
        unused by `fun` *may* be dropped from resulting compiled XLA executables.
        Such arguments will not be transferred to the device nor provided to the
        underlying executable. If `True`, unused arguments will not be pruned.
      device: This is an experimental feature and the API is likely to change.
        Optional, the Device the jitted function will run on. (Available devices
        can be retrieved via :py:func:`jax.devices`.) The default is inherited
        from XLA's DeviceAssignment logic and is usually to use
        ``jax.devices()[0]``.
      backend: This is an experimental feature and the API is likely to change.
        Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
        ``'tpu'``.
      inline: Specify whether this function should be inlined into enclosing
        jaxprs (rather than being represented as an application of a call
        primitive with its own sub-jaxpr). Default False.
      abstracted_axes: Optional abstracted-axes specification. It is forwarded
        unchanged to both :py:class:`StatefulFunction` and
        :py:func:`jax.jit`; see the JAX documentation for its semantics.
      **kwargs: Additional keyword arguments forwarded to :py:func:`jax.jit`.

    Returns:
      If ``fun`` is given, a wrapped version of ``fun`` set up for
      just-in-time compilation; otherwise a decorator that produces one.
      The wrapped function (described by :py:class:`JittedFunction`) is called
      with the same arguments as ``fun`` and has the following attributes:

      - ``origin_fun``: the original function ``fun``.
      - ``stateful_fun``: the :py:class:`StatefulFunction` used for extracting states.
      - ``jitted_fun``: the :py:func:`jax.jit`-compiled ``StatefulFunction.jaxpr_call``,
        which takes the list of state values as its first positional argument.
      - ``clear_cache()``: takes no arguments; clears the caches of both
        ``stateful_fun`` and ``jitted_fun``.
    """

    def _build(target: Callable) -> JittedFunction:
        # Single place where all options are forwarded, shared by both call forms.
        return _get_jitted_fun(target,
                               in_shardings,
                               out_shardings,
                               static_argnums,
                               donate_argnums,
                               donate_argnames,
                               keep_unused,
                               device,
                               backend,
                               inline,
                               abstracted_axes,
                               **kwargs)

    if fun is None:
        # Used as ``@jit(...)``: return a decorator that will receive the function.
        return _build
    return _build(fun)
@@ -1,118 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import unittest
17
-
18
- import jax.numpy as jnp
19
-
20
- import brainstate as bc
21
-
22
-
23
class TestJIT(unittest.TestCase):
    """Tests for ``brainstate.transform.jit``."""

    def test_inner_state_are_not_catched(self):
        # ``a`` is created outside the jitted function; ``b`` is created inside
        # it and only mutated within the inner ``for_loop``.
        a = bc.State(bc.random.randn(10))

        @bc.transform.jit
        def fun1(inp):
            a.value += inp

            b = bc.State(bc.random.randn(1))

            def inner_fun(x):
                b.value += x

            bc.transform.for_loop(inner_fun, bc.random.randn(100))

            return a.value + b.value

        # Scalar input.
        fun1(1.)
        key = fun1.stateful_fun.get_arg_cache_key(1.)
        # NOTE(review): exactly two states are expected to be collected;
        # presumably ``a`` plus the global random state (not ``b``) -- confirm.
        self.assertEqual(len(fun1.stateful_fun.get_states(key)), 2)

        # Vector input: the same number of states is expected.
        x = bc.random.randn(10)
        fun1(x)
        key = fun1.stateful_fun.get_arg_cache_key(x)
        self.assertEqual(len(fun1.stateful_fun.get_states(key)), 2)

    def test_kwargs(self):
        a = bc.State(bc.random.randn(10))

        @bc.transform.jit
        def fun1(inp):
            a.value += inp

            b = bc.State(bc.random.randn(1))

            def inner_fun(x):
                b.value += x

            bc.transform.for_loop(inner_fun, bc.random.randn(100))

            return a.value + b.value

        # The jitted function must accept its argument by keyword as well.
        fun1(inp=bc.random.randn(10))

    def test_jit_compile_sensitive_to_input_shape(self):
        # Counts how many times the Python body runs (i.e. is traced).
        global_data = [0]

        @bc.transform.jit
        def fun1(inp):
            global_data[0] += 1
            return inp

        fun1(1.)
        self.assertEqual(global_data[0], 1)

        # Same kind of input as before: no re-trace expected.
        fun1(2.)
        self.assertEqual(global_data[0], 1)

        # New shapes trigger a new trace each time.
        fun1(bc.random.randn(10))
        self.assertEqual(global_data[0], 2)

        fun1(bc.random.randn(10, 10))
        self.assertEqual(global_data[0], 3)

    def test_jit_clear_cache(self):
        a = bc.State(bc.random.randn(1))
        # One entry is appended every time the body is traced.
        compiling = []

        @bc.transform.jit
        def log2(x):
            compiling.append(1)
            ln_x = jnp.log(x)
            ln_2 = jnp.log(2.0) + a.value
            return ln_x / ln_2

        x = bc.random.randn(1)
        log2(x)  # compiling
        self.assertEqual(len(compiling), 1)
        log2(x)  # no compiling
        self.assertEqual(len(compiling), 1)

        # After clearing the cache the same call must be traced again.
        log2.clear_cache()
        log2(x)  # compiling
        self.assertEqual(len(compiling), 2)

    def test_jit_attribute_origin_fun(self):
        def fun1(x):
            return x

        jitted_fun = bc.transform.jit(fun1)
        self.assertIs(jitted_fun.origin_fun, fun1)
        self.assertIsInstance(jitted_fun.stateful_fun, bc.transform.StatefulFunction)
        self.assertTrue(callable(jitted_fun.jitted_fun))
        self.assertTrue(callable(jitted_fun.clear_cache))