brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +4 -5
- brainstate/_module.py +191 -48
- brainstate/_module_test.py +95 -21
- brainstate/_state.py +17 -0
- brainstate/environ.py +2 -2
- brainstate/functional/__init__.py +3 -2
- brainstate/functional/_activations.py +7 -26
- brainstate/functional/_normalization.py +3 -0
- brainstate/functional/_others.py +49 -0
- brainstate/functional/_spikes.py +0 -1
- brainstate/mixin.py +2 -2
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_base.py +10 -7
- brainstate/nn/_dynamics.py +20 -0
- brainstate/nn/_elementwise.py +5 -4
- brainstate/nn/_embedding.py +66 -0
- brainstate/nn/_misc.py +4 -3
- brainstate/nn/_others.py +3 -2
- brainstate/nn/_poolings.py +21 -20
- brainstate/nn/_poolings_test.py +4 -4
- brainstate/nn/_rate_rnns.py +17 -0
- brainstate/nn/_readout.py +6 -0
- brainstate/optim/__init__.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +13 -0
- brainstate/optim/_sgd_optimizer.py +18 -17
- brainstate/transform/__init__.py +2 -3
- brainstate/transform/_autograd.py +1 -1
- brainstate/transform/_autograd_test.py +0 -2
- brainstate/transform/_jit.py +47 -21
- brainstate/transform/_jit_test.py +0 -3
- brainstate/transform/_make_jaxpr.py +164 -3
- brainstate/transform/_make_jaxpr_test.py +0 -2
- brainstate/transform/_progress_bar.py +1 -3
- brainstate/util.py +0 -1
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +9 -17
- brainstate-0.0.1.post20240622.dist-info/RECORD +64 -0
- brainstate/math/__init__.py +0 -21
- brainstate/math/_einops.py +0 -787
- brainstate/math/_einops_parsing.py +0 -169
- brainstate/math/_einops_parsing_test.py +0 -126
- brainstate/math/_einops_test.py +0 -346
- brainstate/math/_misc.py +0 -298
- brainstate/math/_misc_test.py +0 -58
- brainstate/nn/functional/__init__.py +0 -25
- brainstate/nn/functional/_activations.py +0 -754
- brainstate/nn/functional/_normalization.py +0 -69
- brainstate/nn/functional/_spikes.py +0 -90
- brainstate/nn/init/__init__.py +0 -26
- brainstate/nn/init/_base.py +0 -36
- brainstate/nn/init/_generic.py +0 -175
- brainstate/nn/init/_random_inits.py +0 -489
- brainstate/nn/init/_regular_inits.py +0 -109
- brainstate/nn/surrogate.py +0 -1740
- brainstate-0.0.1.dist-info/RECORD +0 -79
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
brainstate/nn/_rate_rnns.py
CHANGED
@@ -90,6 +90,9 @@ class ValinaRNNCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, self.num_out, batch_size))

+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, self.num_out, batch_size)
+
   def update(self, x):
     xh = jnp.concatenate([x, self.h.value], axis=-1)
     h = self.W(xh)
@@ -147,6 +150,9 @@ class GRUCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))

+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     old_h = self.h.value
     xh = jnp.concatenate([x, old_h], axis=-1)
@@ -224,6 +230,9 @@ class MGUCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))

+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     old_h = self.h.value
     xh = jnp.concatenate([x, old_h], axis=-1)
@@ -327,6 +336,10 @@ class LSTMCell(RNNCell):
     self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))

+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     h, c = self.h.value, self.c.value
     xh = jnp.concat([x, h], axis=-1)
@@ -379,6 +392,10 @@ class URLSTMCell(RNNCell):
     self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))

+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x: ArrayLike) -> ArrayLike:
     h, c = self.h.value, self.c.value
     xh = jnp.concat([x, h], axis=-1)
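Each recurrent cell now pairs `init_state` with a `reset_state` hook: `init_state` allocates the hidden `ShortTermState`, while `reset_state` re-fills its `.value` in place from the same initializer. A minimal usage sketch, assuming `GRUCell` is re-exported from `brainstate.nn` and takes `(num_in, num_out)` (the constructor signature is not shown in this diff):

```python
import brainstate as bst

# Hypothetical sizes; the (num_in, num_out) constructor signature is an assumption.
cell = bst.nn.GRUCell(4, 8)

cell.init_state(batch_size=32)    # allocates cell.h as a ShortTermState
# ... run a sequence through cell.update(x) ...
cell.reset_state(batch_size=32)   # overwrites h.value, keeping the same State object
```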
brainstate/nn/_readout.py
CHANGED
@@ -66,6 +66,9 @@ class LeakyRateReadout(DnnLayer):
   def init_state(self, batch_size=None, **kwargs):
     self.r = ShortTermState(init.param(init.Constant(0.), self.out_size, batch_size))

+  def reset_state(self, batch_size=None, **kwargs):
+    self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)
+
   def update(self, x):
     r = self.decay * self.r.value + x @ self.weight.value
     self.r.value = r
@@ -109,6 +112,9 @@ class LeakySpikeReadout(Neuron):
   def init_state(self, batch_size, **kwargs):
     self.V = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))

+  def reset_state(self, batch_size, **kwargs):
+    self.V.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
   @property
   def spike(self):
     return self.get_spike(self.V.value)
brainstate/optim/_lr_scheduler_test.py
CHANGED
@@ -34,3 +34,16 @@ class TestMultiStepLR(unittest.TestCase):
         self.assertTrue(jnp.allclose(r, 0.001))
       else:
         self.assertTrue(jnp.allclose(r, 0.0001))
+
+  def test2(self):
+    lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
+    for i in range(40):
+      r = lr(i)
+      if i < 10:
+        self.assertEqual(r, 0.1)
+      elif i < 20:
+        self.assertTrue(jnp.allclose(r, 0.01))
+      elif i < 30:
+        self.assertTrue(jnp.allclose(r, 0.001))
+      else:
+        self.assertTrue(jnp.allclose(r, 0.0001))
brainstate/optim/_sgd_optimizer.py
CHANGED
@@ -18,11 +18,12 @@
 import functools
 from typing import Union, Dict, Optional, Tuple, Any, TypeVar

+import brainunit as bu
 import jax
 import jax.numpy as jnp

 from ._lr_scheduler import make_schedule, LearningRateScheduler
-from .. import environ
+from .. import environ
 from .._module import Module
 from .._state import State, LongTermState, StateDictManager, visible_state_dict

@@ -282,7 +283,7 @@ class Momentum(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

   def update(self, grads: dict):
     lr = self.lr()
@@ -349,7 +350,7 @@ class MomentumNesterov(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

   def update(self, grads: dict):
     lr = self.lr()
@@ -417,7 +418,7 @@ class Adagrad(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

   def update(self, grads: dict):
     lr = self.lr()
@@ -500,8 +501,8 @@ class Adadelta(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.delta_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.delta_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

   def update(self, grads: dict):
     weight_values, grad_values, cache_values, delta_values = to_same_dict_tree(
@@ -574,7 +575,7 @@ class RMSProp(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

   def update(self, grads: dict):
     lr = self.lr()
@@ -647,8 +648,8 @@ class Adam(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

   def update(self, grads: dict):
     lr = self.lr()
@@ -730,7 +731,7 @@ class LARS(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

   def update(self, grads: dict):
     lr = self.lr()
@@ -835,10 +836,10 @@ class Adan(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.exp_avg_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.exp_avg_sq_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.exp_avg_diff_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.pre_grad_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.exp_avg_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.exp_avg_sq_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.exp_avg_diff_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.pre_grad_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

   def update(self, grads: dict):
     lr = self.lr()
@@ -989,10 +990,10 @@ class AdamW(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
       if self.amsgrad:
-        self.vmax_states[k] = OptimState(math.tree_zeros_like(v.value))
+        self.vmax_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

   def update(self, grads: dict):
     lr_old = self.lr()
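The optimizer state buffers are now zero-initialized with `brainunit`'s `bu.math.tree_zeros_like` instead of the removed `brainstate.math` helper, so buffers keep the structure (and any physical units) of the parameters they shadow. A small illustrative sketch, assuming `brainunit` exposes a millisecond unit as `bu.ms` (not something shown in this diff):

```python
import brainunit as bu
import jax.numpy as jnp

# Toy parameter pytree; attaching a unit to one leaf is an illustrative assumption
# and is exactly the case a plain jnp.zeros_like map would not handle.
params = {
    'w': jnp.ones((3, 2)),
    'tau': jnp.full((3,), 10.0) * bu.ms,  # assumes brainunit exposes `ms`
}

zeros = bu.math.tree_zeros_like(params)  # same tree structure, shapes, and units; all-zero values
```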
brainstate/transform/__init__.py
CHANGED
@@ -17,10 +17,10 @@
 This module contains the functions for the transformation of the brain data.
 """

-from ._control import *
-from ._control import __all__ as _controls_all
 from ._autograd import *
 from ._autograd import __all__ as _gradients_all
+from ._control import *
+from ._control import __all__ as _controls_all
 from ._jit import *
 from ._jit import __all__ as _jit_all
 from ._jit_error import *
@@ -33,4 +33,3 @@ from ._progress_bar import __all__ as _progress_bar_all
 __all__ = _gradients_all + _jit_error_all + _controls_all + _make_jaxpr_all + _jit_all + _progress_bar_all

 del _gradients_all, _jit_error_all, _controls_all, _make_jaxpr_all, _jit_all, _progress_bar_all
-
brainstate/transform/_autograd.py
CHANGED
@@ -25,8 +25,8 @@ from jax._src.api import _vjp
 from jax.api_util import argnums_partial
 from jax.extend import linear_util

-from brainstate._utils import set_module_as
 from brainstate._state import State, StateTrace, StateDictManager
+from brainstate._utils import set_module_as

 __all__ = [
   'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
brainstate/transform/_autograd_test.py
CHANGED
@@ -537,7 +537,6 @@ class TestPureFuncJacobian(unittest.TestCase):

   def test_jacrev_return_aux1(self):
     with bc.environ.context(precision=64):
-
       def f1(x, y):
         a = 4 * x[1] ** 2 - 2 * x[2]
         r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
@@ -564,7 +563,6 @@ class TestPureFuncJacobian(unittest.TestCase):
     assert (vec == _r).all()


-
 class TestClassFuncJacobian(unittest.TestCase):
   def test_jacrev1(self):
     def f1(x, y):
brainstate/transform/_jit.py
CHANGED
@@ -23,8 +23,8 @@ import jax
 from jax._src import sharding_impls
 from jax.lib import xla_client as xc

-from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
 from brainstate._utils import set_module_as
+from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values

 __all__ = ['jit']

@@ -33,10 +33,13 @@ class JittedFunction(Callable):
   """
   A wrapped version of ``fun``, set up for just-in-time compilation.
   """
-  origin_fun: Callable
+  origin_fun: Callable  # the original function
   stateful_fun: StatefulFunction  # the stateful function for extracting states
   jitted_fun: jax.stages.Wrapped  # the jitted function
-  clear_cache: Callable
+  clear_cache: Callable  # clear the cache of the jitted function
+
+  def __call__(self, *args, **kwargs):
+    pass


 def _get_jitted_fun(
@@ -85,12 +88,16 @@ def _get_jitted_fun(
     jit_fun.clear_cache()

   jitted_fun: JittedFunction
+
   # the original function
   jitted_fun.origin_fun = fun.fun
+
   # the stateful function for extracting states
   jitted_fun.stateful_fun = fun
+
   # the jitted function
   jitted_fun.jitted_fun = jit_fun
+
   # clear cache
   jitted_fun.clear_cache = clear_cache

@@ -99,18 +106,18 @@

 @set_module_as('brainstate.transform')
 def jit(
-
-
-
-
-
-
-
-
-
-
-
-
+    fun: Callable = None,
+    in_shardings=sharding_impls.UNSPECIFIED,
+    out_shardings=sharding_impls.UNSPECIFIED,
+    static_argnums: int | Sequence[int] | None = None,
+    donate_argnums: int | Sequence[int] | None = None,
+    donate_argnames: str | Iterable[str] | None = None,
+    keep_unused: bool = False,
+    device: xc.Device | None = None,
+    backend: str | None = None,
+    inline: bool = False,
+    abstracted_axes: Any | None = None,
+    **kwargs
 ) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
   """
   Sets up ``fun`` for just-in-time compilation with XLA.
@@ -228,12 +235,31 @@

   if fun is None:
     def wrapper(fun_again: Callable) -> JittedFunction:
-      return _get_jitted_fun(fun_again,
-
-
+      return _get_jitted_fun(fun_again,
+                             in_shardings,
+                             out_shardings,
+                             static_argnums,
+                             donate_argnums,
+                             donate_argnames,
+                             keep_unused,
+                             device,
+                             backend,
+                             inline,
+                             abstracted_axes,
+                             **kwargs)
+
     return wrapper

   else:
-    return _get_jitted_fun(fun,
-
-
+    return _get_jitted_fun(fun,
+                           in_shardings,
+                           out_shardings,
+                           static_argnums,
+                           donate_argnums,
+                           donate_argnames,
+                           keep_unused,
+                           device,
+                           backend,
+                           inline,
+                           abstracted_axes,
+                           **kwargs)
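The rewritten `jit` spells out the full `jax.jit`-style keyword surface and documents the attributes attached to the returned `JittedFunction` (`origin_fun`, `stateful_fun`, `jitted_fun`, `clear_cache`), all of which are exercised in `_jit_test.py`. A minimal usage sketch (the function body is illustrative only):

```python
import brainstate as bst
import jax.numpy as jnp

@bst.transform.jit(static_argnums=(1,))   # decorator-with-arguments path (fun is None on the first call)
def scale(x, factor):
    return x * factor

y = scale(jnp.arange(4.0), 2)             # traced, states extracted, then compiled

print(scale.origin_fun)                   # the original Python function
print(scale.stateful_fun)                 # StatefulFunction used to extract State reads/writes
scale.clear_cache()                       # drops the cached traces/compiled versions
```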
brainstate/transform/_jit_test.py
CHANGED
@@ -16,7 +16,6 @@
 import unittest

 import jax.numpy as jnp
-import jax.stages

 import brainstate as bc

@@ -90,7 +89,6 @@ class TestJIT(unittest.TestCase):
     self.assertTrue(len(compiling) == 2)

   def test_jit_attribute_origin_fun(self):
-
     def fun1(x):
       return x

@@ -99,4 +97,3 @@ class TestJIT(unittest.TestCase):
     self.assertTrue(isinstance(jitted_fun.stateful_fun, bc.transform.StatefulFunction))
     self.assertTrue(callable(jitted_fun.jitted_fun))
     self.assertTrue(callable(jitted_fun.clear_cache))
-
brainstate/transform/_make_jaxpr.py
CHANGED
@@ -54,20 +54,26 @@ function.
 from __future__ import annotations

 import functools
+import inspect
 import operator
 from collections.abc import Hashable, Iterable, Sequence
+from contextlib import ExitStack
 from typing import Any, Callable, Tuple, Union, Dict, Optional

 import jax
 from jax._src import source_info_util
+from jax._src.linear_util import annotate
+from jax._src.traceback_util import api_boundary
+from jax.extend.linear_util import transformation_with_aux, wrap_init
 from jax.interpreters import partial_eval as pe
-from jax.util import wraps
 from jax.interpreters.xla import abstractify
+from jax.util import wraps

 from brainstate._state import State, StateTrace
 from brainstate._utils import set_module_as

 PyTree = Any
+AxisName = Hashable

 __all__ = [
   "StatefulFunction",
@@ -393,7 +399,8 @@ class StatefulFunction(object):
     if cache_key not in self._state_trace:
       try:
         # jaxpr
-        jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
+        # jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
+        jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
           functools.partial(self._wrapped_fun_to_eval, cache_key),
           static_argnums=self.static_argnums,
           axis_env=self.axis_env,
@@ -474,7 +481,8 @@ def make_jaxpr(
   state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
 ) -> Callable[..., (Tuple[jax.core.ClosedJaxpr, Tuple[State, ...]] |
                     Tuple[jax.core.ClosedJaxpr, Tuple[State, ...], PyTree])]:
-  """
+  """
+  Creates a function that produces its jaxpr given example args.

   Args:
     fun: The function whose ``jaxpr`` is to be computed. Its positional
@@ -571,3 +579,156 @@ def make_jaxpr(
   if hasattr(fun, "__name__"):
     make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
   return make_jaxpr_f
+
+
+def _check_callable(fun):
+  # In Python 3.10+, the only thing stopping us from supporting staticmethods
+  # is that we can't take weak references to them, which the C++ JIT requires.
+  if isinstance(fun, staticmethod):
+    raise TypeError(f"staticmethod arguments are not supported, got {fun}")
+  if not callable(fun):
+    raise TypeError(f"Expected a callable value, got {fun}")
+  if inspect.isgeneratorfunction(fun):
+    raise TypeError(f"Expected a function, got a generator function: {fun}")
+
+
+def _broadcast_prefix(
+    prefix_tree: Any,
+    full_tree: Any,
+    is_leaf: Callable[[Any], bool] | None = None
+) -> list[Any]:
+  # If prefix_tree is not a tree prefix of full_tree, this code can raise a
+  # ValueError; use prefix_errors to find disagreements and raise more precise
+  # error messages.
+  result = []
+  num_leaves = lambda t: jax.tree.structure(t).num_leaves
+  add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
+  jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
+  return result
+
+
+def _flat_axes_specs(
+    abstracted_axes, *args, **kwargs
+) -> list[pe.AbstractedAxesSpec]:
+  if kwargs:
+    raise NotImplementedError
+
+  def ax_leaf(l):
+    return (isinstance(l, dict) and jax.tree_util.all_leaves(l.values()) or
+            isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None))
+
+  return _broadcast_prefix(abstracted_axes, args, ax_leaf)
+
+
+@transformation_with_aux
+def _flatten_fun(in_tree, *args_flat):
+  py_args, py_kwargs = jax.tree.unflatten(in_tree, args_flat)
+  ans = yield py_args, py_kwargs
+  yield jax.tree.flatten(ans)
+
+
+def _make_jaxpr(
+    fun: Callable,
+    static_argnums: int | Iterable[int] = (),
+    axis_env: Sequence[tuple[AxisName, int]] | None = None,
+    return_shape: bool = False,
+    abstracted_axes: Any | None = None,
+) -> Callable[..., (jax.core.ClosedJaxpr | tuple[jax.core.ClosedJaxpr, Any])]:
+  """Creates a function that produces its jaxpr given example args.
+
+  Args:
+    fun: The function whose ``jaxpr`` is to be computed. Its positional
+      arguments and return value should be arrays, scalars, or standard Python
+      containers (tuple/list/dict) thereof.
+    static_argnums: See the :py:func:`jax.jit` docstring.
+    axis_env: Optional, a sequence of pairs where the first element is an axis
+      name and the second element is a positive integer representing the size of
+      the mapped axis with that name. This parameter is useful when lowering
+      functions that involve parallel communication collectives, and it
+      specifies the axis name/size environment that would be set up by
+      applications of :py:func:`jax.pmap`.
+    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
+      wrapped function returns a pair where the first element is the
+      ``ClosedJaxpr`` representation of ``fun`` and the second element is a
+      pytree with the same structure as the output of ``fun`` and where the
+      leaves are objects with ``shape``, ``dtype``, and ``named_shape``
+      attributes representing the corresponding types of the output leaves.
+
+  Returns:
+    A wrapped version of ``fun`` that when applied to example arguments returns
+    a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
+    argument ``return_shape`` is ``True``, then the returned function instead
+    returns a pair where the first element is the ``ClosedJaxpr``
+    representation of ``fun`` and the second element is a pytree representing
+    the structure, shape, dtypes, and named shapes of the output of ``fun``.
+
+  A ``jaxpr`` is JAX's intermediate representation for program traces. The
+  ``jaxpr`` language is based on the simply-typed first-order lambda calculus
+  with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
+  ``jaxpr``, which we can inspect to understand what JAX is doing internally.
+  The ``jaxpr`` returned is a trace of ``fun`` abstracted to
+  :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
+
+  We do not describe the semantics of the ``jaxpr`` language in detail here, but
+  instead give a few examples.
+
+  >>> import jax
+  >>>
+  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
+  >>> print(f(3.0))
+  -0.83602
+  >>> _make_jaxpr(f)(3.0)
+  { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
+  >>> _make_jaxpr(jax.grad(f))(3.0)
+  { lambda ; a:f32[]. let
+      b:f32[] = cos a
+      c:f32[] = sin a
+      _:f32[] = sin b
+      d:f32[] = cos b
+      e:f32[] = mul 1.0 d
+      f:f32[] = neg e
+      g:f32[] = mul f c
+    in (g,) }
+  """
+  _check_callable(fun)
+  static_argnums = _ensure_index_tuple(static_argnums)
+
+  def _abstractify(args, kwargs):
+    flat_args, in_tree = jax.tree.flatten((args, kwargs))
+    if abstracted_axes is None:
+      return map(jax.api_util.shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
+    else:
+      axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
+      in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
+      in_avals, keep_inputs = jax.util.unzip2(in_type)
+      return in_avals, in_tree, keep_inputs
+
+  @wraps(fun)
+  @api_boundary
+  def make_jaxpr_f(*args, **kwargs):
+    f = wrap_init(fun)
+    if static_argnums:
+      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
+      f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
+    in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
+    in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
+    f, out_tree = _flatten_fun(f, in_tree)
+    f = annotate(f, in_type)
+    debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
+    with ExitStack() as stack:
+      for axis_name, size in axis_env or []:
+        stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
+      jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
+    closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
+    if return_shape:
+      out_avals, _ = jax.util.unzip2(out_type)
+      out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
+      return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
+    return closed_jaxpr
+
+  make_jaxpr_f.__module__ = "brainstate.transform"
+  if hasattr(fun, "__qualname__"):
+    make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
+  if hasattr(fun, "__name__"):
+    make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
+  return make_jaxpr_f
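With this change, `StatefulFunction` traces through the vendored `_make_jaxpr` instead of `jax.make_jaxpr`, and the public `make_jaxpr` wrapper (per its annotated return type above) yields a `ClosedJaxpr` together with the `State` objects touched during tracing. A minimal sketch, assuming `make_jaxpr` is exported from `brainstate.transform` and that `bst.State(value)` wraps a raw array; both are assumptions based on the signatures above rather than code shown in full here:

```python
import brainstate as bst
import jax.numpy as jnp

weight = bst.State(jnp.ones(3))       # assumes State(value) construction

def step(x):
    weight.value = weight.value * x   # a state write, captured during tracing
    return weight.value.sum()

closed_jaxpr, states = bst.transform.make_jaxpr(step)(2.0)
print(closed_jaxpr)                   # jaxpr with the state value threaded through
print(states)                         # the State objects read/written while tracing
```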
brainstate/transform/_progress_bar.py
CHANGED
@@ -14,13 +14,12 @@
 # ==============================================================================

 from __future__ import annotations
+
 import copy
 from typing import Optional

 import jax

-from brainstate import environ
-
 try:
   from tqdm.auto import tqdm
 except (ImportError, ModuleNotFoundError):
@@ -95,7 +94,6 @@ class ProgressBarRunner(object):
     self.tqdm_bars[0].close()

   def __call__(self, iter_num, *args, **kwargs):
-
     _ = jax.lax.cond(
       iter_num == 0,
       lambda: jax.debug.callback(self._define_tqdm),
{brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA
CHANGED
@@ -1,8 +1,8 @@
 Metadata-Version: 2.1
 Name: brainstate
-Version: 0.0.1
+Version: 0.0.1.post20240622
 Summary: A State-based Transformation System for Brain Dynamics Programming.
-Home-page: https://github.com/brainpy/
+Home-page: https://github.com/brainpy/brainstate
 Author: BrainPy Team
 Author-email: BrainPy Team <chao.brain@qq.com>
 License: Apache-2.0 license
@@ -31,21 +31,11 @@ License-File: LICENSE
 Requires-Dist: jax
 Requires-Dist: jaxlib
 Requires-Dist: numpy
+Requires-Dist: brainunit
 Provides-Extra: cpu
 Requires-Dist: jaxlib ; extra == 'cpu'
-Requires-Dist: brainpylib ; extra == 'cpu'
-Provides-Extra: cpu_mini
-Requires-Dist: jaxlib ; extra == 'cpu_mini'
-Provides-Extra: cuda11
-Requires-Dist: jaxlib[cuda11_pip] ; extra == 'cuda11'
-Requires-Dist: brainpylib ; extra == 'cuda11'
-Provides-Extra: cuda11_mini
-Requires-Dist: jaxlib[cuda11_pip] ; extra == 'cuda11_mini'
 Provides-Extra: cuda12
 Requires-Dist: jaxlib[cuda12_pip] ; extra == 'cuda12'
-Requires-Dist: brainpylib ; extra == 'cuda12'
-Provides-Extra: cuda12_mini
-Requires-Dist: jaxlib[cuda12_pip] ; extra == 'cuda12_mini'
 Provides-Extra: testing
 Requires-Dist: pytest ; extra == 'testing'
 Provides-Extra: tpu
@@ -90,12 +80,14 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt

 ## See also the BDP ecosystem

-- [``
+- [``brainstate``](https://github.com/brainpy/brainstate): A ``State``-based transformation system for brain dynamics programming.

-- [``
+- [``brainunit``](https://github.com/brainpy/brainunit): The unit system for brain dynamics programming.

-- [``
+- [``braintaichi``](https://github.com/brainpy/braintaichi): Leveraging Taichi Lang to customize brain dynamics operators.

-- [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning for biological
+- [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning framework for biological neural networks.
+
+- [``braintools``](https://github.com/brainpy/braintools): The toolbox for the brain dynamics simulation, training and analysis.
