brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -24,7 +24,6 @@ from brainstate._compatible_import import Device
|
|
24
24
|
from brainstate._utils import set_module_as
|
25
25
|
from brainstate.typing import Missing
|
26
26
|
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
|
27
|
-
from ._util import write_back_state_values
|
28
27
|
|
29
28
|
__all__ = ['jit']
|
30
29
|
|
@@ -33,6 +32,8 @@ class JittedFunction(Callable):
|
|
33
32
|
"""
|
34
33
|
A wrapped version of ``fun``, set up for just-in-time compilation.
|
35
34
|
"""
|
35
|
+
__module__ = 'brainstate.transform'
|
36
|
+
|
36
37
|
origin_fun: Callable # the original function
|
37
38
|
stateful_fun: StatefulFunction # the stateful function for extracting states
|
38
39
|
jitted_fun: jax.stages.Wrapped # the jitted function
|
@@ -67,8 +68,8 @@ def _get_jitted_fun(
|
|
67
68
|
static_argnums=static_argnums,
|
68
69
|
static_argnames=static_argnames,
|
69
70
|
abstracted_axes=abstracted_axes,
|
70
|
-
|
71
|
-
|
71
|
+
name='jit',
|
72
|
+
return_only_write=True
|
72
73
|
)
|
73
74
|
jit_fun = jax.jit(
|
74
75
|
fun.jaxpr_call,
|
@@ -92,14 +93,14 @@ def _get_jitted_fun(
|
|
92
93
|
return fun.fun(*args, **params)
|
93
94
|
|
94
95
|
# compile the function and get the state trace
|
95
|
-
state_trace = fun.
|
96
|
+
state_trace = fun.get_state_trace(*args, **params, compile_if_miss=True)
|
96
97
|
read_state_vals = state_trace.get_read_state_values(True)
|
97
98
|
|
98
99
|
# call the jitted function
|
99
100
|
write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
|
100
101
|
|
101
102
|
# write the state values back to the states
|
102
|
-
|
103
|
+
state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
|
103
104
|
return outs
|
104
105
|
|
105
106
|
def clear_cache():
|
@@ -139,7 +140,7 @@ def _get_jitted_fun(
|
|
139
140
|
A ``Lowered`` instance representing the lowering.
|
140
141
|
"""
|
141
142
|
# compile the function and get the state trace
|
142
|
-
state_trace = fun.
|
143
|
+
state_trace = fun.get_state_trace(*args, **params, compile_if_miss=True)
|
143
144
|
read_state_vals = state_trace.get_read_state_values(replace_writen=True)
|
144
145
|
write_state_vals = state_trace.get_write_state_values(replace_read=True)
|
145
146
|
|
@@ -147,7 +148,7 @@ def _get_jitted_fun(
|
|
147
148
|
ret = jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()
|
148
149
|
|
149
150
|
# write the state values back to the states
|
150
|
-
|
151
|
+
state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
|
151
152
|
return ret
|
152
153
|
|
153
154
|
jitted_fun: JittedFunction
|
@@ -176,7 +177,7 @@ def _get_jitted_fun(
|
|
176
177
|
return jitted_fun
|
177
178
|
|
178
179
|
|
179
|
-
@set_module_as('brainstate.
|
180
|
+
@set_module_as('brainstate.transform')
|
180
181
|
def jit(
|
181
182
|
fun: Callable | Missing = Missing(),
|
182
183
|
in_shardings=sharding_impls.UNSPECIFIED,
|
@@ -195,9 +196,12 @@ def jit(
|
|
195
196
|
"""
|
196
197
|
Sets up ``fun`` for just-in-time compilation with XLA.
|
197
198
|
|
198
|
-
|
199
|
-
|
200
|
-
|
199
|
+
Parameters
|
200
|
+
----------
|
201
|
+
fun : callable or Missing, optional
|
202
|
+
Function to be jitted.
|
203
|
+
in_shardings : pytree, optional
|
204
|
+
Pytree of structure matching that of arguments to ``fun``,
|
201
205
|
with all actual arguments replaced by resource assignment specifications.
|
202
206
|
It is also valid to specify a pytree prefix (e.g. one value in place of a
|
203
207
|
whole subtree), in which case the leaves get broadcast to all values in
|
@@ -208,26 +212,29 @@ def jit(
|
|
208
212
|
if the sharding cannot be inferred.
|
209
213
|
|
210
214
|
The valid resource assignment specifications are:
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
215
|
+
|
216
|
+
- :py:class:`XLACompatibleSharding`, which will decide how the value
|
217
|
+
will be partitioned. With this, using a mesh context manager is not
|
218
|
+
required.
|
219
|
+
- :py:obj:`None`, will give JAX the freedom to choose whatever sharding
|
220
|
+
it wants.
|
221
|
+
For in_shardings, JAX will mark it as replicated but this behavior
|
222
|
+
can change in the future.
|
223
|
+
For out_shardings, we will rely on the XLA GSPMD partitioner to
|
224
|
+
determine the output shardings.
|
220
225
|
|
221
226
|
The size of every dimension has to be a multiple of the total number of
|
222
227
|
resources assigned to it. This is similar to pjit's in_shardings.
|
223
|
-
|
228
|
+
out_shardings : pytree, optional
|
229
|
+
Like ``in_shardings``, but specifies resource
|
224
230
|
assignment for function outputs. This is similar to pjit's
|
225
231
|
out_shardings.
|
226
232
|
|
227
233
|
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
|
228
234
|
will use GSPMD's sharding propagation to figure out what the sharding of the
|
229
235
|
output(s) should be.
|
230
|
-
|
236
|
+
static_argnums : int or sequence of int, optional
|
237
|
+
An optional int or collection of ints that specify which
|
231
238
|
positional arguments to treat as static (compile-time constant).
|
232
239
|
Operations that only depend on static arguments will be constant-folded in
|
233
240
|
Python (during tracing), and so the corresponding argument values can be
|
@@ -248,12 +255,8 @@ def jit(
|
|
248
255
|
provided, ``inspect.signature`` is not used, and only actual
|
249
256
|
parameters listed in either ``static_argnums`` or ``static_argnames`` will
|
250
257
|
be treated as static.
|
251
|
-
|
252
|
-
which
|
253
|
-
Operations that only depend on static arguments will be constant-folded in
|
254
|
-
Python (during tracing), and so the corresponding argument values can be
|
255
|
-
any Python object.
|
256
|
-
donate_argnums: Specify which positional argument buffers are "donated" to
|
258
|
+
donate_argnums : int or sequence of int, optional
|
259
|
+
Specify which positional argument buffers are "donated" to
|
257
260
|
the computation. It is safe to donate argument buffers if you no longer
|
258
261
|
need them once the computation has finished. In some cases XLA can make
|
259
262
|
use of donated buffers to reduce the amount of memory needed to perform a
|
@@ -274,38 +277,88 @@ def jit(
|
|
274
277
|
|
275
278
|
For more details on buffer donation see the
|
276
279
|
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
|
277
|
-
|
280
|
+
static_argnames : str or sequence of str, optional
|
281
|
+
An optional string or collection of strings specifying
|
282
|
+
which named arguments are treated as static (compile-time constant).
|
283
|
+
Operations that only depend on static arguments will be constant-folded in
|
284
|
+
Python (during tracing), and so the corresponding argument values can be
|
285
|
+
any Python object.
|
286
|
+
donate_argnames : str or iterable of str, optional
|
287
|
+
An optional string or collection of strings specifying
|
278
288
|
which named arguments are donated to the computation. See the
|
279
289
|
comment on ``donate_argnums`` for details. If not
|
280
290
|
provided but ``donate_argnums`` is set, the default is based on calling
|
281
291
|
``inspect.signature(fun)`` to find corresponding named arguments.
|
282
|
-
|
292
|
+
keep_unused : bool, default False
|
293
|
+
If `False` (the default), arguments that JAX determines to be
|
283
294
|
unused by `fun` *may* be dropped from resulting compiled XLA executables.
|
284
295
|
Such arguments will not be transferred to the device nor provided to the
|
285
296
|
underlying executable. If `True`, unused arguments will not be pruned.
|
286
|
-
|
297
|
+
device : Device, optional
|
298
|
+
This is an experimental feature and the API is likely to change.
|
287
299
|
Optional, the Device the jitted function will run on. (Available devices
|
288
300
|
can be retrieved via :py:func:`jax.devices`.) The default is inherited
|
289
301
|
from XLA's DeviceAssignment logic and is usually to use
|
290
302
|
``jax.devices()[0]``.
|
291
|
-
|
303
|
+
backend : str, optional
|
304
|
+
This is an experimental feature and the API is likely to change.
|
292
305
|
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
|
293
306
|
``'tpu'``.
|
294
|
-
|
307
|
+
inline : bool, default False
|
308
|
+
Specify whether this function should be inlined into enclosing
|
295
309
|
jaxprs (rather than being represented as an application of the xla_call
|
296
310
|
primitive with its own subjaxpr). Default False.
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
311
|
+
abstracted_axes : Any, optional
|
312
|
+
Abstracted axes specification.
|
313
|
+
**kwargs
|
314
|
+
Additional keyword arguments passed to the underlying JAX jit function.
|
315
|
+
|
316
|
+
Returns
|
317
|
+
-------
|
318
|
+
JittedFunction or callable
|
319
|
+
A wrapped version of ``fun``, set up for just-in-time compilation.
|
320
|
+
The returned object is a :py:class:`JittedFunction` that can be called with the same arguments
|
321
|
+
and has the following attributes and methods:
|
322
|
+
|
323
|
+
- ``stateful_fun`` : the stateful function for extracting states, an instance of :py:class:`StatefulFunction`.
|
324
|
+
- ``origin_fun(*args, **kwargs)`` : the original function
|
325
|
+
- ``jitted_fun(*args, **kwargs)`` : the jitted function
|
326
|
+
- ``clear_cache(*args, **kwargs)`` : clear the cache of the jitted function
|
327
|
+
|
328
|
+
Examples
|
329
|
+
--------
|
330
|
+
Basic usage with a simple function:
|
331
|
+
|
332
|
+
.. code-block:: python
|
333
|
+
|
334
|
+
>>> import brainstate
|
335
|
+
>>> import jax.numpy as jnp
|
336
|
+
>>>
|
337
|
+
>>> @brainstate.transform.jit
|
338
|
+
... def f(x):
|
339
|
+
... return x ** 2
|
340
|
+
>>>
|
341
|
+
>>> result = f(jnp.array([1, 2, 3]))
|
342
|
+
|
343
|
+
Using static arguments:
|
344
|
+
|
345
|
+
.. code-block:: python
|
346
|
+
|
347
|
+
>>> @brainstate.transform.jit(static_argnums=(1,))
|
348
|
+
... def g(x, n):
|
349
|
+
... return x ** n
|
350
|
+
>>>
|
351
|
+
>>> result = g(jnp.array([1, 2, 3]), 2)
|
352
|
+
|
353
|
+
Manual jitting:
|
354
|
+
|
355
|
+
.. code-block:: python
|
356
|
+
|
357
|
+
>>> def h(x):
|
358
|
+
... return x * 2
|
359
|
+
>>>
|
360
|
+
>>> jitted_h = brainstate.transform.jit(h)
|
361
|
+
>>> result = jitted_h(jnp.array([1, 2, 3]))
|
309
362
|
"""
|
310
363
|
|
311
364
|
if isinstance(fun, Missing):
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -41,12 +41,12 @@ class TestJIT(unittest.TestCase):
|
|
41
41
|
|
42
42
|
print(fun1(1.))
|
43
43
|
key = fun1.stateful_fun.get_arg_cache_key(1.)
|
44
|
-
self.assertTrue(len(fun1.stateful_fun.
|
44
|
+
self.assertTrue(len(fun1.stateful_fun.get_states_by_cache(key)) == 2)
|
45
45
|
|
46
46
|
x = bst.random.randn(10)
|
47
47
|
print(fun1(x))
|
48
48
|
key = fun1.stateful_fun.get_arg_cache_key(x)
|
49
|
-
self.assertTrue(len(fun1.stateful_fun.
|
49
|
+
self.assertTrue(len(fun1.stateful_fun.get_states_by_cache(key)) == 2)
|
50
50
|
|
51
51
|
def test_kwargs(self):
|
52
52
|
a = bst.State(bst.random.randn(10))
|