brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1360 -1318
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/transform/_jit.py
DELETED
@@ -1,265 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
|
-
import functools
|
19
|
-
from collections.abc import Iterable, Sequence
|
20
|
-
from typing import (Any, Callable, Union)
|
21
|
-
|
22
|
-
import jax
|
23
|
-
from jax._src import sharding_impls
|
24
|
-
from jax.lib import xla_client as xc
|
25
|
-
|
26
|
-
from brainstate._utils import set_module_as
|
27
|
-
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
|
28
|
-
|
29
|
-
# Public API of this module: only ``jit`` is exported.
__all__ = ['jit']
|
30
|
-
|
31
|
-
|
32
|
-
class JittedFunction(Callable):
    """
    Static type describing the object produced by :py:func:`jit`.

    Calling it behaves like calling the wrapped ``fun`` with the same arguments,
    except that the work runs through ``jax.jit``. In addition to being callable,
    the object exposes the attributes declared below.
    """

    # The undecorated user function.
    origin_fun: Callable

    # The ``StatefulFunction`` wrapper used to extract the states the function uses.
    stateful_fun: StatefulFunction

    # The underlying ``jax.jit``-compiled callable.
    jitted_fun: jax.stages.Wrapped

    # Zero-argument callable that clears the caches of the stateful and jitted functions.
    clear_cache: Callable

    def __call__(self, *args, **kwargs):
        """Typing placeholder: ``jit`` returns a plain function annotated with this type."""
|
43
|
-
|
44
|
-
|
45
|
-
def _get_jitted_fun(
    fun: Callable,
    in_shardings,
    out_shardings,
    static_argnums,
    donate_argnums,
    donate_argnames,
    keep_unused,
    device,
    backend,
    inline,
    abstracted_axes,
    **kwargs
) -> JittedFunction:
    """
    Build the state-aware jitted wrapper returned by :py:func:`jit`.

    ``fun`` is wrapped in a :py:class:`StatefulFunction`, and that wrapper's
    ``jaxpr_call`` is compiled with ``jax.jit``. ``jaxpr_call`` takes the list of
    current state values as its first positional argument, followed by the
    arguments of ``fun``. It returns ``(new_state_values, outputs)``.

    Because of that leading argument, every user-facing positional index
    (``static_argnums`` and ``donate_argnums``) is shifted by one before it is
    handed to ``jax.jit``.

    Args:
        fun: The user function to compile.
        in_shardings, out_shardings, static_argnums, donate_argnums, donate_argnames,
        keep_unused, device, backend, inline, abstracted_axes, **kwargs:
            See :py:func:`jit`. Positions in ``static_argnums`` / ``donate_argnums``
            refer to the positional arguments of ``fun``.

    Returns:
        A plain function decorated with ``functools.wraps`` of the original
        function. It carries the ``origin_fun``, ``stateful_fun``, ``jitted_fun``
        and ``clear_cache`` attributes declared on :py:class:`JittedFunction`.
    """
    static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)

    def _skip_state_arg(indices):
        # ``jaxpr_call`` receives the state values first, so non-negative positions
        # move right by one. Negative positions count from the end of the argument
        # list and already refer to the intended user argument.
        return tuple(i + 1 if i >= 0 else i for i in indices)

    # Keep ``None`` as ``None`` rather than ``()``: jax.jit treats "not given"
    # differently from "empty" when reconciling ``donate_argnums`` with
    # ``donate_argnames``. Previously the user indices were forwarded unshifted,
    # so donation targeted the wrong buffers (index 0 hit the state values).
    if donate_argnums is not None:
        donate_argnums = _skip_state_arg(_ensure_index_tuple(donate_argnums))

    # TODO: add to cache stack for clear_cache
    fun = StatefulFunction(fun, static_argnums=static_argnums, abstracted_axes=abstracted_axes, cache_type='jit')

    # NOTE(review): ``in_shardings`` is forwarded unchanged. A per-argument tuple is
    # therefore matched against ``(state_values, *args)`` rather than ``args``.
    # Confirm whether a leading entry for the states should be inserted.
    jit_fun = jax.jit(fun.jaxpr_call,
                      static_argnums=_skip_state_arg(static_argnums),
                      donate_argnums=donate_argnums,
                      donate_argnames=donate_argnames,
                      keep_unused=keep_unused,
                      device=device,
                      backend=backend,
                      inline=inline,
                      in_shardings=in_shardings,
                      out_shardings=out_shardings,
                      abstracted_axes=abstracted_axes,
                      **kwargs)

    @functools.wraps(fun.fun)
    def jitted_fun(*args, **params):
        # With JIT globally disabled, simply run the original Python function.
        if jax.config.jax_disable_jit:
            return fun.fun(*args, **params)
        # Obtain the states ``fun`` uses for these arguments.
        states = fun.compile_and_get_states_by_static_args(*args, **params)
        state_vals, outs = jit_fun([st.value for st in states], *args, **params)
        # Assign the returned state values back to ``states``.
        _assign_state_values(states, state_vals)
        return outs

    def clear_cache():
        # clear the cache of the stateful function
        fun.clear_cache()
        # clear the cache of the jitted function
        jit_fun.clear_cache()

    jitted_fun: JittedFunction

    # the original function
    jitted_fun.origin_fun = fun.fun

    # the stateful function for extracting states
    jitted_fun.stateful_fun = fun

    # the jitted function
    jitted_fun.jitted_fun = jit_fun

    # clear cache
    jitted_fun.clear_cache = clear_cache

    return jitted_fun
|
105
|
-
|
106
|
-
|
107
|
-
@set_module_as('brainstate.transform')
def jit(
    fun: Callable | None = None,
    in_shardings=sharding_impls.UNSPECIFIED,
    out_shardings=sharding_impls.UNSPECIFIED,
    static_argnums: int | Sequence[int] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: xc.Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    abstracted_axes: Any | None = None,
    **kwargs
) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
    """
    Sets up ``fun`` for just-in-time compilation with XLA, keeping ``State`` objects in sync.

    This is a state-aware counterpart of :py:func:`jax.jit`. The states used by
    ``fun`` are collected for each call's arguments, their current values are
    passed into the compiled computation, and the returned values are assigned
    back to the states after the call.

    When ``jax.config.jax_disable_jit`` is set, calls go straight to the original
    ``fun``.

    Can be used as ``@jit``, ``@jit(...)`` or ``jit(fun, ...)``.

    Unlike ``jax.jit()``, ``static_argnames`` is not supported. Mark static
    arguments by position with ``static_argnums``.

    Args:
      fun: Function to be jitted. If ``None``, a decorator is returned that
        accepts the function and applies the remaining options to it.
      in_shardings: Pytree of resource assignment specifications for the
        arguments. A pytree prefix (e.g. one value in place of a whole subtree)
        is allowed, in which case the leaves get broadcast to all values in that
        subtree.

        The ``in_shardings`` argument is optional. JAX will infer the shardings
        from the input :py:class:`jax.Array`'s and defaults to replicating the
        input if the sharding cannot be inferred.

        The valid resource assignment specifications are:

        - :py:class:`jax.sharding.Sharding`, which decides how the value will be
          partitioned. With this, using a mesh context manager is not required.
        - :py:obj:`None`, which gives JAX the freedom to choose whatever sharding
          it wants. For ``in_shardings``, JAX will mark it as replicated, but this
          behavior can change in the future. For ``out_shardings``, the XLA GSPMD
          partitioner determines the output shardings.

        The size of every dimension has to be a multiple of the total number of
        resources assigned to it.

        NOTE(review): the value is forwarded unchanged to ``jax.jit`` of an
        internal function whose first positional argument is the list of state
        values. A per-argument tuple may therefore need a leading entry for the
        states. TODO confirm.
      out_shardings: Like ``in_shardings``, but specifies resource assignment for
        the function outputs. Optional; if not specified, GSPMD's sharding
        propagation determines the output sharding.
      static_argnums: An optional int or collection of ints specifying which
        positional arguments of ``fun`` to treat as static (compile-time constant).

        Operations that only depend on static arguments are constant-folded in
        Python during tracing, so the corresponding argument values can be any
        Python object. Static arguments should be hashable (both ``__hash__`` and
        ``__eq__`` implemented) and immutable. Calling the jitted function with
        different values for these constants triggers recompilation. Arguments
        that are not arrays or containers thereof must be marked as static.

        If not provided, no arguments are treated as static.
      donate_argnums: Specify which positional argument buffers of ``fun`` are
        "donated" to the computation. It is safe to donate argument buffers if you
        no longer need them once the computation has finished. In some cases XLA
        can make use of donated buffers to reduce the memory needed to perform a
        computation, for example by recycling an input buffer to store a result.
        You should not reuse buffers that you donate to a computation; JAX will
        raise an error if you try to. By default, no argument buffers are donated.
        See :py:func:`jax.jit` for how ``donate_argnums`` and ``donate_argnames``
        are reconciled, and the
        `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_
        for more on buffer donation.
      donate_argnames: An optional string or collection of strings specifying
        which named arguments are donated to the computation. See
        ``donate_argnums``.
      keep_unused: If ``False`` (the default), arguments that JAX determines to be
        unused by ``fun`` *may* be dropped from the resulting compiled XLA
        executables. Such arguments are not transferred to the device nor
        provided to the underlying executable. If ``True``, unused arguments are
        not pruned.
      device: Experimental; the API is likely to change. Optional, the Device the
        jitted function will run on (see :py:func:`jax.devices`). The default is
        inherited from XLA's DeviceAssignment logic and is usually
        ``jax.devices()[0]``.
      backend: Experimental; the API is likely to change. Optional, a string
        naming the XLA backend: ``'cpu'``, ``'gpu'``, or ``'tpu'``.
      inline: Specify whether this function should be inlined into enclosing
        jaxprs rather than being represented as a separate call with its own
        sub-jaxpr. Default ``False``.
      abstracted_axes: Experimental. Forwarded unchanged to both
        :py:class:`StatefulFunction` and :py:func:`jax.jit`.
      **kwargs: Any other keyword arguments are forwarded to :py:func:`jax.jit`.

    Returns:
      A wrapped version of ``fun``, set up for just-in-time compilation. When
      ``fun`` is ``None``, a decorator producing such a wrapper is returned
      instead.

      The wrapper is typed as :py:class:`JittedFunction`. It accepts the same
      arguments as ``fun`` and has the following attributes:

      - ``origin_fun``: the original function.
      - ``stateful_fun``: the :py:class:`StatefulFunction` used for extracting states.
      - ``jitted_fun``: the underlying :py:func:`jax.jit`-compiled function. It
        takes the list of state values as its first argument, followed by the
        arguments of ``fun``, and returns ``(new_state_values, outputs)``.
      - ``clear_cache()``: clears the caches of both the stateful function and
        the jitted function.
    """

    def wrapper(fun_again: Callable) -> JittedFunction:
        # Single forwarding point shared by the decorator-factory and direct forms.
        return _get_jitted_fun(fun_again,
                               in_shardings,
                               out_shardings,
                               static_argnums,
                               donate_argnums,
                               donate_argnames,
                               keep_unused,
                               device,
                               backend,
                               inline,
                               abstracted_axes,
                               **kwargs)

    if fun is None:
        # ``@jit(...)``: return the decorator and wait for the function.
        return wrapper
    # ``@jit`` or ``jit(fun, ...)``.
    return wrapper(fun)
|
brainstate/transform/_jit_test.py
DELETED
@@ -1,118 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
|
18
|
-
import jax.numpy as jnp
|
19
|
-
|
20
|
-
import brainstate as bc
|
21
|
-
|
22
|
-
|
23
|
-
class TestJIT(unittest.TestCase):
    """Behavioral tests for ``bc.transform.jit``."""

    def test_inner_state_are_not_catched(self):
        # ``b`` is created inside the jitted function and updated by an inner
        # ``for_loop``. Exactly two states are expected to be recorded for each
        # argument signature.
        a = bc.State(bc.random.randn(10))

        @bc.transform.jit
        def fun1(inp):
            a.value += inp

            b = bc.State(bc.random.randn(1))

            def inner_fun(x):
                b.value += x

            bc.transform.for_loop(inner_fun, bc.random.randn(100))

            return a.value + b.value

        # Scalar argument.
        fun1(1.)
        key = fun1.stateful_fun.get_arg_cache_key(1.)
        self.assertEqual(len(fun1.stateful_fun.get_states(key)), 2)

        # Array argument with a different shape.
        x = bc.random.randn(10)
        fun1(x)
        key = fun1.stateful_fun.get_arg_cache_key(x)
        self.assertEqual(len(fun1.stateful_fun.get_states(key)), 2)

    def test_kwargs(self):
        a = bc.State(bc.random.randn(10))

        @bc.transform.jit
        def fun1(inp):
            a.value += inp

            b = bc.State(bc.random.randn(1))

            def inner_fun(x):
                b.value += x

            bc.transform.for_loop(inner_fun, bc.random.randn(100))

            return a.value + b.value

        # Passing the argument by keyword must work as well.
        out = fun1(inp=bc.random.randn(10))
        # ``a`` has shape (10,) and ``b`` has shape (1,), so the sum broadcasts to (10,).
        self.assertEqual(out.shape, (10,))

    def test_jit_compile_sensitive_to_input_shape(self):
        trace_count = [0]

        @bc.transform.jit
        def fun1(inp):
            # Python side effect: only executes while the function is being traced.
            trace_count[0] += 1
            return inp

        fun1(1.)
        self.assertEqual(trace_count[0], 1)

        # Same shape and dtype: no retrace.
        fun1(2.)
        self.assertEqual(trace_count[0], 1)

        fun1(bc.random.randn(10))
        self.assertEqual(trace_count[0], 2)

        fun1(bc.random.randn(10, 10))
        self.assertEqual(trace_count[0], 3)

    def test_jit_clear_cache(self):
        a = bc.State(bc.random.randn(1))
        compiling = []

        @bc.transform.jit
        def log2(x):
            # Records one entry per trace.
            compiling.append(1)
            ln_x = jnp.log(x)
            ln_2 = jnp.log(2.0) + a.value
            return ln_x / ln_2

        x = bc.random.randn(1)
        log2(x)  # compiling
        self.assertEqual(len(compiling), 1)
        log2(x)  # cached, no compiling
        self.assertEqual(len(compiling), 1)

        log2.clear_cache()
        log2(x)  # compiling again after the cache was cleared
        self.assertEqual(len(compiling), 2)

    def test_jit_attribute_origin_fun(self):
        def fun1(x):
            return x

        jitted_fun = bc.transform.jit(fun1)
        self.assertIs(jitted_fun.origin_fun, fun1)
        self.assertIsInstance(jitted_fun.stateful_fun, bc.transform.StatefulFunction)
        self.assertTrue(callable(jitted_fun.jitted_fun))
        self.assertTrue(callable(jitted_fun.clear_cache))
|