brainstate 0.0.1.post20240623__py2.py3-none-any.whl → 0.0.1.1.post20240708__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_module.py +6 -11
- brainstate/_module_test.py +1 -1
- brainstate/_random_for_unit.py +51 -0
- brainstate/_state.py +12 -6
- brainstate/init/_generic.py +97 -32
- brainstate/init/_random_inits.py +17 -7
- brainstate/init/_regular_inits.py +8 -7
- brainstate/mixin.py +3 -3
- brainstate/mixin_test.py +9 -9
- brainstate/nn/_projection/_align_post.py +11 -11
- brainstate/nn/_projection/_align_pre.py +3 -3
- brainstate/random.py +66 -36
- brainstate/transform/_jit_error.py +71 -49
- brainstate/transform/_jit_error_test.py +55 -0
- brainstate/transform/_make_jaxpr.py +10 -5
- brainstate/typing.py +2 -0
- {brainstate-0.0.1.post20240623.dist-info → brainstate-0.0.1.1.post20240708.dist-info}/METADATA +1 -1
- {brainstate-0.0.1.post20240623.dist-info → brainstate-0.0.1.1.post20240708.dist-info}/RECORD +22 -20
- {brainstate-0.0.1.post20240623.dist-info → brainstate-0.0.1.1.post20240708.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.post20240623.dist-info → brainstate-0.0.1.1.post20240708.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.post20240623.dist-info → brainstate-0.0.1.1.post20240708.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ from brainstate._module import (register_delay_of_target,
|
|
22
22
|
ReceiveInputProj,
|
23
23
|
ExtendedUpdateWithBA)
|
24
24
|
from brainstate._utils import set_module_as
|
25
|
-
from brainstate.mixin import (Mode,
|
25
|
+
from brainstate.mixin import (Mode, JointTypes, DelayedInitializer, BindCondData, AlignPost, UpdateReturn)
|
26
26
|
from ._utils import is_instance
|
27
27
|
|
28
28
|
__all__ = [
|
@@ -39,7 +39,7 @@ def align_post_add_bef_update(
|
|
39
39
|
out_label: str,
|
40
40
|
syn_desc,
|
41
41
|
out_desc,
|
42
|
-
post:
|
42
|
+
post: JointTypes[ReceiveInputProj, ExtendedUpdateWithBA],
|
43
43
|
proj_name: str
|
44
44
|
):
|
45
45
|
# synapse and output initialization
|
@@ -60,7 +60,7 @@ class _AlignPost(Module):
|
|
60
60
|
def __init__(
|
61
61
|
self,
|
62
62
|
syn: Module,
|
63
|
-
out:
|
63
|
+
out: JointTypes[Dynamics, BindCondData]
|
64
64
|
):
|
65
65
|
super().__init__()
|
66
66
|
self.syn = syn
|
@@ -140,7 +140,7 @@ class HalfProjAlignPostMg(Projection):
|
|
140
140
|
comm: Module,
|
141
141
|
syn: DelayedInitializer[AlignPost],
|
142
142
|
out: DelayedInitializer[BindCondData],
|
143
|
-
post:
|
143
|
+
post: JointTypes[ReceiveInputProj, ExtendedUpdateWithBA],
|
144
144
|
out_label: Optional[str] = None,
|
145
145
|
name: Optional[str] = None,
|
146
146
|
mode: Optional[Mode] = None,
|
@@ -150,7 +150,7 @@ class HalfProjAlignPostMg(Projection):
|
|
150
150
|
# synaptic models
|
151
151
|
is_instance(syn, DelayedInitializer[AlignPost])
|
152
152
|
is_instance(out, DelayedInitializer[BindCondData])
|
153
|
-
is_instance(post,
|
153
|
+
is_instance(post, JointTypes[ReceiveInputProj, ExtendedUpdateWithBA])
|
154
154
|
|
155
155
|
# synapse and output initialization
|
156
156
|
syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name)
|
@@ -257,12 +257,12 @@ class FullProjAlignPostMg(Projection):
|
|
257
257
|
|
258
258
|
def __init__(
|
259
259
|
self,
|
260
|
-
pre:
|
260
|
+
pre: JointTypes[ExtendedUpdateWithBA, UpdateReturn],
|
261
261
|
delay: Union[None, int, float],
|
262
262
|
comm: Module,
|
263
263
|
syn: DelayedInitializer[AlignPost],
|
264
264
|
out: DelayedInitializer[BindCondData],
|
265
|
-
post:
|
265
|
+
post: JointTypes[ReceiveInputProj, ExtendedUpdateWithBA],
|
266
266
|
out_label: Optional[str] = None,
|
267
267
|
name: Optional[str] = None,
|
268
268
|
mode: Optional[Mode] = None,
|
@@ -270,11 +270,11 @@ class FullProjAlignPostMg(Projection):
|
|
270
270
|
super().__init__(name=name, mode=mode)
|
271
271
|
|
272
272
|
# synaptic models
|
273
|
-
is_instance(pre,
|
273
|
+
is_instance(pre, JointTypes[ExtendedUpdateWithBA, UpdateReturn])
|
274
274
|
is_instance(comm, Module)
|
275
275
|
is_instance(syn, DelayedInitializer[AlignPost])
|
276
276
|
is_instance(out, DelayedInitializer[BindCondData])
|
277
|
-
is_instance(post,
|
277
|
+
is_instance(post, JointTypes[ReceiveInputProj, ExtendedUpdateWithBA])
|
278
278
|
self.comm = comm
|
279
279
|
|
280
280
|
# delay initialization
|
@@ -494,8 +494,8 @@ class FullProjAlignPost(Projection):
|
|
494
494
|
# synaptic models
|
495
495
|
is_instance(pre, UpdateReturn)
|
496
496
|
is_instance(comm, Module)
|
497
|
-
is_instance(syn,
|
498
|
-
is_instance(out,
|
497
|
+
is_instance(syn, JointTypes[Dynamics, AlignPost])
|
498
|
+
is_instance(out, JointTypes[Dynamics, BindCondData])
|
499
499
|
is_instance(post, ReceiveInputProj)
|
500
500
|
self.comm = comm
|
501
501
|
self.syn = syn
|
@@ -20,7 +20,7 @@ from brainstate._module import (Module, DelayAccess, Projection,
|
|
20
20
|
ExtendedUpdateWithBA, ReceiveInputProj,
|
21
21
|
register_delay_of_target)
|
22
22
|
from brainstate._utils import set_module_as
|
23
|
-
from brainstate.mixin import (DelayedInitializer, BindCondData, UpdateReturn, Mode,
|
23
|
+
from brainstate.mixin import (DelayedInitializer, BindCondData, UpdateReturn, Mode, JointTypes)
|
24
24
|
from ._utils import is_instance
|
25
25
|
|
26
26
|
__all__ = [
|
@@ -284,7 +284,7 @@ class FullProjAlignPreDSMg(Projection):
|
|
284
284
|
|
285
285
|
def __init__(
|
286
286
|
self,
|
287
|
-
pre:
|
287
|
+
pre: JointTypes[ExtendedUpdateWithBA, UpdateReturn],
|
288
288
|
delay: Union[None, int, float],
|
289
289
|
syn: DelayedInitializer[UpdateReturn],
|
290
290
|
comm: Module,
|
@@ -297,7 +297,7 @@ class FullProjAlignPreDSMg(Projection):
|
|
297
297
|
super().__init__(name=name, mode=mode)
|
298
298
|
|
299
299
|
# synaptic models
|
300
|
-
is_instance(pre,
|
300
|
+
is_instance(pre, JointTypes[ExtendedUpdateWithBA, UpdateReturn])
|
301
301
|
is_instance(syn, DelayedInitializer[Module])
|
302
302
|
is_instance(comm, Module)
|
303
303
|
is_instance(out, BindCondData)
|
brainstate/random.py
CHANGED
@@ -22,11 +22,16 @@ from functools import partial
|
|
22
22
|
from operator import index
|
23
23
|
from typing import Optional
|
24
24
|
|
25
|
+
import brainunit as bu
|
25
26
|
import jax
|
27
|
+
import jax.numpy as jnp
|
28
|
+
import jax.random as jr
|
26
29
|
import numpy as np
|
27
|
-
from jax import
|
30
|
+
from jax import jit, vmap
|
31
|
+
from jax import lax, core, dtypes
|
28
32
|
|
29
33
|
from brainstate import environ
|
34
|
+
from ._random_for_unit import uniform_for_unit, permutation_for_unit
|
30
35
|
from ._state import State
|
31
36
|
from .transform._jit_error import jit_error
|
32
37
|
from .typing import DTypeLike, Size, SeedOrKey
|
@@ -144,7 +149,7 @@ class RandomState(State):
|
|
144
149
|
def rand(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
145
150
|
key = self.split_key() if key is None else _formalize_key(key)
|
146
151
|
dtype = dtype or environ.dftype()
|
147
|
-
r =
|
152
|
+
r = uniform_for_unit(key, shape=dn, minval=0., maxval=1., dtype=dtype)
|
148
153
|
return r
|
149
154
|
|
150
155
|
def randint(
|
@@ -207,7 +212,7 @@ class RandomState(State):
|
|
207
212
|
dtype: DTypeLike = None):
|
208
213
|
dtype = dtype or environ.dftype()
|
209
214
|
key = self.split_key() if key is None else _formalize_key(key)
|
210
|
-
r =
|
215
|
+
r = uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
|
211
216
|
return r
|
212
217
|
|
213
218
|
def random_sample(self,
|
@@ -250,7 +255,7 @@ class RandomState(State):
|
|
250
255
|
key: Optional[SeedOrKey] = None):
|
251
256
|
x = _check_py_seq(x)
|
252
257
|
key = self.split_key() if key is None else _formalize_key(key)
|
253
|
-
r =
|
258
|
+
r = permutation_for_unit(key, x, axis=axis, independent=independent)
|
254
259
|
return r
|
255
260
|
|
256
261
|
def shuffle(self,
|
@@ -258,7 +263,7 @@ class RandomState(State):
|
|
258
263
|
axis=0,
|
259
264
|
key: Optional[SeedOrKey] = None):
|
260
265
|
key = self.split_key() if key is None else _formalize_key(key)
|
261
|
-
x =
|
266
|
+
x = permutation_for_unit(key, x, axis=axis)
|
262
267
|
return x
|
263
268
|
|
264
269
|
def beta(self,
|
@@ -458,7 +463,7 @@ class RandomState(State):
|
|
458
463
|
size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
|
459
464
|
key = self.split_key() if key is None else _formalize_key(key)
|
460
465
|
dtype = dtype or environ.dftype()
|
461
|
-
r =
|
466
|
+
r = uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
|
462
467
|
return r
|
463
468
|
|
464
469
|
def __norm_cdf(self, x, sqrt2, dtype):
|
@@ -481,36 +486,46 @@ class RandomState(State):
|
|
481
486
|
scale = _check_py_seq(scale)
|
482
487
|
dtype = dtype or environ.dftype()
|
483
488
|
|
484
|
-
lower =
|
485
|
-
upper =
|
486
|
-
loc =
|
487
|
-
scale =
|
489
|
+
lower = bu.math.asarray(lower, dtype=dtype)
|
490
|
+
upper = bu.math.asarray(upper, dtype=dtype)
|
491
|
+
loc = bu.math.asarray(loc, dtype=dtype)
|
492
|
+
scale = bu.math.asarray(scale, dtype=dtype)
|
493
|
+
bu.fail_for_dimension_mismatch(lower, upper)
|
494
|
+
bu.fail_for_dimension_mismatch(lower, loc)
|
495
|
+
bu.fail_for_dimension_mismatch(lower, scale)
|
496
|
+
dim = lower.dim if isinstance(lower, bu.Quantity) else bu.DIMENSIONLESS
|
497
|
+
lower = lower.value if isinstance(lower, bu.Quantity) else lower
|
498
|
+
upper = upper.value if isinstance(upper, bu.Quantity) else upper
|
499
|
+
loc = loc.value if isinstance(loc, bu.Quantity) else loc
|
500
|
+
scale = scale.value if isinstance(scale, bu.Quantity) else scale
|
488
501
|
|
489
502
|
jit_error(
|
490
|
-
|
503
|
+
bu.math.any(bu.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
|
491
504
|
"mean is more than 2 std from [lower, upper] in truncated_normal. "
|
492
505
|
"The distribution of values may be incorrect."
|
493
506
|
)
|
494
507
|
|
495
508
|
if size is None:
|
496
|
-
size =
|
497
|
-
|
498
|
-
|
499
|
-
|
509
|
+
size = bu.math.broadcast_shapes(jnp.shape(lower),
|
510
|
+
jnp.shape(upper),
|
511
|
+
jnp.shape(loc),
|
512
|
+
jnp.shape(scale))
|
500
513
|
|
501
514
|
# Values are generated by using a truncated uniform distribution and
|
502
515
|
# then using the inverse CDF for the normal distribution.
|
503
516
|
# Get upper and lower cdf values
|
504
|
-
sqrt2 = np.array(np.sqrt(2), dtype)
|
517
|
+
sqrt2 = np.array(np.sqrt(2), dtype=dtype)
|
505
518
|
l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
|
506
519
|
u = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
|
507
520
|
|
508
521
|
# Uniformly fill tensor with values from [l, u], then translate to
|
509
522
|
# [2l-1, 2u-1].
|
510
523
|
key = self.split_key() if key is None else _formalize_key(key)
|
511
|
-
out =
|
512
|
-
|
513
|
-
|
524
|
+
out = uniform_for_unit(
|
525
|
+
key, size, dtype,
|
526
|
+
minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
|
527
|
+
maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype))
|
528
|
+
)
|
514
529
|
|
515
530
|
# Use inverse cdf transform for normal distribution to get truncated
|
516
531
|
# standard normal
|
@@ -523,7 +538,7 @@ class RandomState(State):
|
|
523
538
|
out = jnp.clip(out,
|
524
539
|
lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
|
525
540
|
lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))
|
526
|
-
return out
|
541
|
+
return out if dim == bu.DIMENSIONLESS else bu.Quantity(out, dim=dim)
|
527
542
|
|
528
543
|
def _check_p(self, p):
|
529
544
|
raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
|
@@ -548,6 +563,13 @@ class RandomState(State):
|
|
548
563
|
dtype: DTypeLike = None):
|
549
564
|
mean = _check_py_seq(mean)
|
550
565
|
sigma = _check_py_seq(sigma)
|
566
|
+
mean = bu.math.asarray(mean, dtype=dtype)
|
567
|
+
sigma = bu.math.asarray(sigma, dtype=dtype)
|
568
|
+
bu.fail_for_dimension_mismatch(mean, sigma)
|
569
|
+
dim = mean.dim if isinstance(mean, bu.Quantity) else bu.DIMENSIONLESS
|
570
|
+
mean = mean.value if isinstance(mean, bu.Quantity) else mean
|
571
|
+
sigma = sigma.value if isinstance(sigma, bu.Quantity) else sigma
|
572
|
+
|
551
573
|
if size is None:
|
552
574
|
size = jnp.broadcast_shapes(jnp.shape(mean),
|
553
575
|
jnp.shape(sigma))
|
@@ -556,7 +578,7 @@ class RandomState(State):
|
|
556
578
|
samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
|
557
579
|
samples = _loc_scale(mean, sigma, samples)
|
558
580
|
samples = jnp.exp(samples)
|
559
|
-
return samples
|
581
|
+
return samples if dim == bu.DIMENSIONLESS else bu.Quantity(samples, dim=dim)
|
560
582
|
|
561
583
|
def binomial(self,
|
562
584
|
n,
|
@@ -614,7 +636,7 @@ class RandomState(State):
|
|
614
636
|
size = jnp.shape(p)
|
615
637
|
key = self.split_key() if key is None else _formalize_key(key)
|
616
638
|
dtype = dtype or environ.dftype()
|
617
|
-
u =
|
639
|
+
u = uniform_for_unit(key, size, dtype=dtype)
|
618
640
|
r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p))
|
619
641
|
return r
|
620
642
|
|
@@ -640,20 +662,28 @@ class RandomState(State):
|
|
640
662
|
dtype = dtype or environ.ditype()
|
641
663
|
return jnp.asarray(r, dtype=dtype)
|
642
664
|
|
643
|
-
def multivariate_normal(
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
665
|
+
def multivariate_normal(
|
666
|
+
self,
|
667
|
+
mean,
|
668
|
+
cov,
|
669
|
+
size: Optional[Size] = None,
|
670
|
+
method: str = 'cholesky',
|
671
|
+
key: Optional[SeedOrKey] = None,
|
672
|
+
dtype: DTypeLike = None
|
673
|
+
):
|
650
674
|
if method not in {'svd', 'eigh', 'cholesky'}:
|
651
675
|
raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
|
652
676
|
dtype = dtype or environ.dftype()
|
653
|
-
mean =
|
654
|
-
cov =
|
655
|
-
|
677
|
+
mean = bu.math.asarray(_check_py_seq(mean), dtype=dtype)
|
678
|
+
cov = bu.math.asarray(_check_py_seq(cov), dtype=dtype)
|
679
|
+
if isinstance(mean, bu.Quantity):
|
680
|
+
assert isinstance(cov, bu.Quantity)
|
681
|
+
assert mean.dim ** 2 == cov.dim
|
682
|
+
mean = mean.value if isinstance(mean, bu.Quantity) else mean
|
683
|
+
cov = cov.value if isinstance(cov, bu.Quantity) else cov
|
684
|
+
dim = mean.dim if isinstance(mean, bu.Quantity) else bu.DIMENSIONLESS
|
656
685
|
|
686
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
657
687
|
if not jnp.ndim(mean) >= 1:
|
658
688
|
raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
|
659
689
|
if not jnp.ndim(cov) >= 2:
|
@@ -678,7 +708,7 @@ class RandomState(State):
|
|
678
708
|
factor = jnp.linalg.cholesky(cov)
|
679
709
|
normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
|
680
710
|
r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
|
681
|
-
return r
|
711
|
+
return r if dim == bu.DIMENSIONLESS else bu.Quantity(r, dim=dim)
|
682
712
|
|
683
713
|
def rayleigh(self,
|
684
714
|
scale=1.0,
|
@@ -690,7 +720,7 @@ class RandomState(State):
|
|
690
720
|
size = jnp.shape(scale)
|
691
721
|
key = self.split_key() if key is None else _formalize_key(key)
|
692
722
|
dtype = dtype or environ.dftype()
|
693
|
-
x = jnp.sqrt(-2. * jnp.log(
|
723
|
+
x = jnp.sqrt(-2. * jnp.log(uniform_for_unit(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
|
694
724
|
r = x * scale
|
695
725
|
return r
|
696
726
|
|
@@ -734,7 +764,7 @@ class RandomState(State):
|
|
734
764
|
raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
|
735
765
|
size = _size2shape(size)
|
736
766
|
dtype = dtype or environ.dftype()
|
737
|
-
random_uniform =
|
767
|
+
random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
|
738
768
|
r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
|
739
769
|
return r
|
740
770
|
|
@@ -754,7 +784,7 @@ class RandomState(State):
|
|
754
784
|
raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
|
755
785
|
size = _size2shape(size)
|
756
786
|
dtype = dtype or environ.dftype()
|
757
|
-
random_uniform =
|
787
|
+
random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
|
758
788
|
r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
|
759
789
|
if scale is not None:
|
760
790
|
r /= scale
|
@@ -15,14 +15,14 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
-
|
18
|
+
import functools
|
19
|
+
from functools import partial
|
19
20
|
from typing import Callable, Union
|
20
21
|
|
21
22
|
import jax
|
22
23
|
from jax import numpy as jnp
|
23
24
|
from jax.core import Primitive, ShapedArray
|
24
|
-
from jax.interpreters import batching, mlir
|
25
|
-
from jax.lax import cond
|
25
|
+
from jax.interpreters import batching, mlir
|
26
26
|
|
27
27
|
from brainstate._utils import set_module_as
|
28
28
|
|
@@ -32,16 +32,39 @@ __all__ = [
|
|
32
32
|
|
33
33
|
|
34
34
|
@set_module_as('brainstate.transform')
|
35
|
-
def remove_vmap(x, op='any'):
|
35
|
+
def remove_vmap(x, op: str = 'any'):
|
36
36
|
if op == 'any':
|
37
37
|
return _any_without_vmap(x)
|
38
38
|
elif op == 'all':
|
39
39
|
return _all_without_vmap(x)
|
40
|
+
elif op == 'none':
|
41
|
+
return _without_vmap(x)
|
40
42
|
else:
|
41
43
|
raise ValueError(f'Do not support type: {op}')
|
42
44
|
|
43
45
|
|
44
|
-
|
46
|
+
def _without_vmap(x):
|
47
|
+
return _no_vmap_prim.bind(x)
|
48
|
+
|
49
|
+
|
50
|
+
def _without_vmap_imp(x):
|
51
|
+
return x
|
52
|
+
|
53
|
+
|
54
|
+
def _without_vmap_abs(x):
|
55
|
+
return x
|
56
|
+
|
57
|
+
|
58
|
+
def _without_vmap_batch(x, batch_axes):
|
59
|
+
(x,) = x
|
60
|
+
return _without_vmap(x), batching.not_mapped
|
61
|
+
|
62
|
+
|
63
|
+
_no_vmap_prim = Primitive('no_vmap')
|
64
|
+
_no_vmap_prim.def_impl(_without_vmap_imp)
|
65
|
+
_no_vmap_prim.def_abstract_eval(_without_vmap_abs)
|
66
|
+
batching.primitive_batchers[_no_vmap_prim] = _without_vmap_batch
|
67
|
+
mlir.register_lowering(_no_vmap_prim, mlir.lower_fun(_without_vmap_imp, multiple_results=False))
|
45
68
|
|
46
69
|
|
47
70
|
def _any_without_vmap(x):
|
@@ -61,16 +84,12 @@ def _any_without_vmap_batch(x, batch_axes):
|
|
61
84
|
return _any_without_vmap(x), batching.not_mapped
|
62
85
|
|
63
86
|
|
87
|
+
_any_no_vmap_prim = Primitive('any_no_vmap')
|
64
88
|
_any_no_vmap_prim.def_impl(_any_without_vmap_imp)
|
65
89
|
_any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs)
|
66
90
|
batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch
|
67
|
-
if hasattr(xla, "lower_fun"):
|
68
|
-
xla.register_translation(_any_no_vmap_prim,
|
69
|
-
xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True))
|
70
91
|
mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False))
|
71
92
|
|
72
|
-
_all_no_vmap_prim = Primitive('all_no_vmap')
|
73
|
-
|
74
93
|
|
75
94
|
def _all_without_vmap(x):
|
76
95
|
return _all_no_vmap_prim.bind(x)
|
@@ -89,50 +108,54 @@ def _all_without_vmap_batch(x, batch_axes):
|
|
89
108
|
return _all_without_vmap(x), batching.not_mapped
|
90
109
|
|
91
110
|
|
111
|
+
_all_no_vmap_prim = Primitive('all_no_vmap')
|
92
112
|
_all_no_vmap_prim.def_impl(_all_without_vmap_imp)
|
93
113
|
_all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs)
|
94
114
|
batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch
|
95
|
-
if hasattr(xla, "lower_fun"):
|
96
|
-
xla.register_translation(_all_no_vmap_prim,
|
97
|
-
xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True))
|
98
115
|
mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False))
|
99
116
|
|
100
117
|
|
101
|
-
def _err_jit_true_branch(err_fun,
|
102
|
-
jax.debug.callback(err_fun,
|
103
|
-
return
|
118
|
+
def _err_jit_true_branch(err_fun, args, kwargs):
|
119
|
+
jax.debug.callback(err_fun, *args, **kwargs)
|
104
120
|
|
105
121
|
|
106
|
-
def _err_jit_false_branch(
|
107
|
-
|
122
|
+
def _err_jit_false_branch(args, kwargs):
|
123
|
+
pass
|
108
124
|
|
109
125
|
|
110
|
-
def
|
111
|
-
|
112
|
-
|
113
|
-
|
126
|
+
def _error_msg(msg, *arg, **kwargs):
|
127
|
+
if len(arg):
|
128
|
+
msg = msg % arg
|
129
|
+
if len(kwargs):
|
130
|
+
msg = msg.format(**kwargs)
|
131
|
+
raise ValueError(msg)
|
114
132
|
|
115
|
-
cond(pred,
|
116
|
-
partial(_err_jit_true_branch, true_err_fun),
|
117
|
-
_err_jit_false_branch,
|
118
|
-
err_arg)
|
119
133
|
|
134
|
+
@set_module_as('brainstate.transform')
|
135
|
+
def jit_error(
|
136
|
+
pred,
|
137
|
+
err_fun: Union[Callable, str],
|
138
|
+
*err_args,
|
139
|
+
**err_kwargs,
|
140
|
+
):
|
141
|
+
"""
|
142
|
+
Check errors in a jit function.
|
120
143
|
|
121
|
-
|
122
|
-
|
123
|
-
raise ValueError(msg)
|
124
|
-
else:
|
125
|
-
raise ValueError(msg.format(arg))
|
144
|
+
Examples
|
145
|
+
--------
|
126
146
|
|
147
|
+
It can give a function which receive arguments that passed from the JIT variables and raise errors.
|
127
148
|
|
128
|
-
|
129
|
-
|
130
|
-
|
149
|
+
>>> def error(x):
|
150
|
+
>>> raise ValueError(f'error {x}')
|
151
|
+
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
|
152
|
+
>>> jit_error(x.sum() < 5., error, x)
|
153
|
+
|
154
|
+
Or, it can be a simple string message.
|
131
155
|
|
132
|
-
>>> def error(arg):
|
133
|
-
>>> raise ValueError(f'error {arg}')
|
134
156
|
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
|
135
|
-
>>> jit_error(x.sum() < 5.,
|
157
|
+
>>> jit_error(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
|
158
|
+
|
136
159
|
|
137
160
|
Parameters
|
138
161
|
----------
|
@@ -140,19 +163,18 @@ def jit_error(pred, err_fun: Union[Callable, str], err_arg=None, scope: str = 'a
|
|
140
163
|
The boolean prediction.
|
141
164
|
err_fun: callable
|
142
165
|
The error function, which raise errors.
|
143
|
-
|
166
|
+
err_args:
|
144
167
|
The arguments which passed into `err_f`.
|
145
|
-
|
146
|
-
The
|
168
|
+
err_kwargs:
|
169
|
+
The keywords which passed into `err_f`.
|
147
170
|
"""
|
148
171
|
if isinstance(err_fun, str):
|
149
172
|
err_fun = partial(_error_msg, err_fun)
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
_cond(err_fun, pred, err_arg)
|
173
|
+
|
174
|
+
jax.lax.cond(
|
175
|
+
remove_vmap(pred, op='any'),
|
176
|
+
partial(_err_jit_true_branch, err_fun),
|
177
|
+
_err_jit_false_branch,
|
178
|
+
jax.tree.map(functools.partial(remove_vmap, op='none'), err_args),
|
179
|
+
jax.tree.map(functools.partial(remove_vmap, op='none'), err_kwargs),
|
180
|
+
)
|
@@ -0,0 +1,55 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import jax
|
19
|
+
import jaxlib.xla_extension
|
20
|
+
import jax.numpy as jnp
|
21
|
+
|
22
|
+
import brainstate as bst
|
23
|
+
|
24
|
+
|
25
|
+
class TestJitError(unittest.TestCase):
|
26
|
+
def test1(self):
|
27
|
+
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
|
28
|
+
bst.transform.jit_error(True, 'error')
|
29
|
+
|
30
|
+
def err_f(x):
|
31
|
+
raise ValueError(f'error: {x}')
|
32
|
+
|
33
|
+
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
|
34
|
+
bst.transform.jit_error(True, err_f, 1.)
|
35
|
+
|
36
|
+
def test_vmap(self):
|
37
|
+
|
38
|
+
def f(x):
|
39
|
+
bst.transform.jit_error(x, 'error: {x}', x=x)
|
40
|
+
|
41
|
+
jax.vmap(f)(jnp.array([False, False, False]))
|
42
|
+
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
|
43
|
+
jax.vmap(f)(jnp.array([True, False, False]))
|
44
|
+
|
45
|
+
def test_vmap_vmap(self):
|
46
|
+
|
47
|
+
def f(x):
|
48
|
+
bst.transform.jit_error(x, 'error: {x}', x=x)
|
49
|
+
|
50
|
+
jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
|
51
|
+
[False, False, False]]))
|
52
|
+
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
|
53
|
+
jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
|
54
|
+
[True, False, False]]))
|
55
|
+
|
@@ -369,11 +369,11 @@ class StatefulFunction(object):
|
|
369
369
|
self._state_trace[cache_key] = _state_trace
|
370
370
|
with _state_trace:
|
371
371
|
out = self.fun(*args, **kwargs)
|
372
|
-
state_values = _state_trace.collect_values('read', 'write')
|
372
|
+
state_values = _state_trace.collect_values('read', 'write', check_val_tree=True)
|
373
373
|
_state_trace.recovery_original_values()
|
374
374
|
|
375
|
-
#
|
376
|
-
#
|
375
|
+
# State instance as functional returns is not allowed.
|
376
|
+
# Checking whether the states are returned.
|
377
377
|
for leaf in jax.tree.leaves(out):
|
378
378
|
if isinstance(leaf, State):
|
379
379
|
leaf._raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
|
@@ -399,7 +399,6 @@ class StatefulFunction(object):
|
|
399
399
|
if cache_key not in self._state_trace:
|
400
400
|
try:
|
401
401
|
# jaxpr
|
402
|
-
# jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
|
403
402
|
jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
|
404
403
|
functools.partial(self._wrapped_fun_to_eval, cache_key),
|
405
404
|
static_argnums=self.static_argnums,
|
@@ -435,8 +434,11 @@ class StatefulFunction(object):
|
|
435
434
|
"""
|
436
435
|
# state checking
|
437
436
|
cache_key = self.get_arg_cache_key(*args, **kwargs)
|
438
|
-
states = self.get_states(cache_key)
|
437
|
+
states: Sequence[State] = self.get_states(cache_key)
|
439
438
|
assert len(state_vals) == len(states), 'State length mismatch.'
|
439
|
+
# # No need to check, because the make_jaxpr() has been checked whether the value's tree is correct.
|
440
|
+
# for val, st in zip(state_vals, states): # check state's value tree structure
|
441
|
+
# st._check_value_tree(val)
|
440
442
|
|
441
443
|
# parameters
|
442
444
|
args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
|
@@ -450,6 +452,9 @@ class StatefulFunction(object):
|
|
450
452
|
# output processing
|
451
453
|
out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
|
452
454
|
assert len(new_state_vals) == len(state_vals), 'State length mismatch.'
|
455
|
+
# # No need to check, because the make_jaxpr() has been checked whether the value's tree is correct.
|
456
|
+
# for val, st in zip(new_state_vals, states): # check state's value tree structure
|
457
|
+
# st._check_value_tree(val)
|
453
458
|
return new_state_vals, out
|
454
459
|
|
455
460
|
def jaxpr_call_auto(self, *args, **kwargs) -> Any:
|
brainstate/typing.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
|
17
17
|
from typing import Any, Sequence, Protocol, Union
|
18
18
|
|
19
|
+
import brainunit as bu
|
19
20
|
import jax
|
20
21
|
import numpy as np
|
21
22
|
|
@@ -43,6 +44,7 @@ ArrayLike = Union[
|
|
43
44
|
np.ndarray, # NumPy array type
|
44
45
|
np.bool_, np.number, # NumPy scalar types
|
45
46
|
bool, int, float, complex, # Python scalar types
|
47
|
+
bu.Quantity, # quantity
|
46
48
|
]
|
47
49
|
|
48
50
|
# --- Dtype --- #
|