brainstate 0.0.1.1.post20240804__py2.py3-none-any.whl → 0.0.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_random_for_unit.py +5 -4
- brainstate/random.py +34 -30
- {brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.dist-info}/METADATA +1 -1
- {brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.dist-info}/RECORD +8 -8
- {brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
brainstate/_random_for_unit.py
CHANGED
@@ -30,13 +30,14 @@ def uniform_for_unit(
|
|
30
30
|
maxval: ArrayLike = 1.
|
31
31
|
) -> jax.Array | bu.Quantity:
|
32
32
|
if isinstance(minval, bu.Quantity) and isinstance(maxval, bu.Quantity):
|
33
|
-
|
33
|
+
maxval = maxval.in_unit(minval.unit)
|
34
|
+
return bu.Quantity(jr.uniform(key, shape, dtype, minval.mantissa, maxval.mantissa), unit=minval.unit)
|
34
35
|
elif isinstance(minval, bu.Quantity):
|
35
36
|
assert minval.is_unitless, f'minval must be unitless when maxval is not a Quantity, got {minval}'
|
36
|
-
minval = minval.
|
37
|
+
minval = minval.mantissa
|
37
38
|
elif isinstance(maxval, bu.Quantity):
|
38
39
|
assert maxval.is_unitless, f'maxval must be unitless when minval is not a Quantity, got {maxval}'
|
39
|
-
maxval = maxval.
|
40
|
+
maxval = maxval.mantissa
|
40
41
|
return jr.uniform(key, shape, dtype, minval, maxval)
|
41
42
|
|
42
43
|
|
@@ -47,5 +48,5 @@ def permutation_for_unit(
|
|
47
48
|
independent: bool = False
|
48
49
|
) -> jax.Array | bu.Quantity:
|
49
50
|
if isinstance(x, bu.Quantity):
|
50
|
-
return bu.Quantity(jr.permutation(key, x.
|
51
|
+
return bu.Quantity(jr.permutation(key, x.mantissa, axis, independent=independent), unit=x.unit)
|
51
52
|
return jr.permutation(key, x, axis, independent=independent)
|
brainstate/random.py
CHANGED
@@ -490,14 +490,13 @@ class RandomState(State):
|
|
490
490
|
upper = bu.math.asarray(upper, dtype=dtype)
|
491
491
|
loc = bu.math.asarray(loc, dtype=dtype)
|
492
492
|
scale = bu.math.asarray(scale, dtype=dtype)
|
493
|
-
bu.
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
scale = scale.value if isinstance(scale, bu.Quantity) else scale
|
493
|
+
unit = bu.get_unit(lower)
|
494
|
+
lower, upper, loc, scale = (
|
495
|
+
lower.mantissa if isinstance(lower, bu.Quantity) else lower,
|
496
|
+
bu.Quantity(upper).in_unit(unit).mantissa,
|
497
|
+
bu.Quantity(loc).in_unit(unit).mantissa,
|
498
|
+
bu.Quantity(scale).in_unit(unit).mantissa
|
499
|
+
)
|
501
500
|
|
502
501
|
jit_error(
|
503
502
|
bu.math.any(bu.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
|
@@ -535,10 +534,12 @@ class RandomState(State):
|
|
535
534
|
out = out * scale * sqrt2 + loc
|
536
535
|
|
537
536
|
# Clamp to ensure it's in the proper range
|
538
|
-
out = jnp.clip(
|
539
|
-
|
540
|
-
|
541
|
-
|
537
|
+
out = jnp.clip(
|
538
|
+
out,
|
539
|
+
lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
|
540
|
+
lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
|
541
|
+
)
|
542
|
+
return out if unit.is_unitless else bu.Quantity(out, unit=unit)
|
542
543
|
|
543
544
|
def _check_p(self, p):
|
544
545
|
raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
|
@@ -555,30 +556,33 @@ class RandomState(State):
|
|
555
556
|
r = jr.bernoulli(key, p=p, shape=_size2shape(size))
|
556
557
|
return r
|
557
558
|
|
558
|
-
def lognormal(
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
559
|
+
def lognormal(
|
560
|
+
self,
|
561
|
+
mean=None,
|
562
|
+
sigma=None,
|
563
|
+
size: Optional[Size] = None,
|
564
|
+
key: Optional[SeedOrKey] = None,
|
565
|
+
dtype: DTypeLike = None
|
566
|
+
):
|
564
567
|
mean = _check_py_seq(mean)
|
565
568
|
sigma = _check_py_seq(sigma)
|
566
569
|
mean = bu.math.asarray(mean, dtype=dtype)
|
567
570
|
sigma = bu.math.asarray(sigma, dtype=dtype)
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
sigma = sigma.value if isinstance(sigma, bu.Quantity) else sigma
|
571
|
+
unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
|
572
|
+
mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
|
573
|
+
sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, bu.Quantity) else sigma
|
572
574
|
|
573
575
|
if size is None:
|
574
|
-
size = jnp.broadcast_shapes(
|
575
|
-
|
576
|
+
size = jnp.broadcast_shapes(
|
577
|
+
jnp.shape(mean),
|
578
|
+
jnp.shape(sigma)
|
579
|
+
)
|
576
580
|
key = self.split_key() if key is None else _formalize_key(key)
|
577
581
|
dtype = dtype or environ.dftype()
|
578
582
|
samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
|
579
583
|
samples = _loc_scale(mean, sigma, samples)
|
580
584
|
samples = jnp.exp(samples)
|
581
|
-
return samples if
|
585
|
+
return samples if unit.is_unitless else bu.Quantity(samples, unit=unit)
|
582
586
|
|
583
587
|
def binomial(self,
|
584
588
|
n,
|
@@ -678,10 +682,10 @@ class RandomState(State):
|
|
678
682
|
cov = bu.math.asarray(_check_py_seq(cov), dtype=dtype)
|
679
683
|
if isinstance(mean, bu.Quantity):
|
680
684
|
assert isinstance(cov, bu.Quantity)
|
681
|
-
assert mean.
|
682
|
-
mean = mean.
|
683
|
-
cov = cov.
|
684
|
-
|
685
|
+
assert mean.unit ** 2 == cov.unit
|
686
|
+
mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
|
687
|
+
cov = cov.mantissa if isinstance(cov, bu.Quantity) else cov
|
688
|
+
unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
|
685
689
|
|
686
690
|
key = self.split_key() if key is None else _formalize_key(key)
|
687
691
|
if not jnp.ndim(mean) >= 1:
|
@@ -708,7 +712,7 @@ class RandomState(State):
|
|
708
712
|
factor = jnp.linalg.cholesky(cov)
|
709
713
|
normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
|
710
714
|
r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
|
711
|
-
return r if
|
715
|
+
return r if unit.is_unitless else bu.Quantity(r, unit=unit)
|
712
716
|
|
713
717
|
def rayleigh(self,
|
714
718
|
scale=1.0,
|
@@ -1,14 +1,14 @@
|
|
1
|
-
brainstate/__init__.py,sha256=
|
1
|
+
brainstate/__init__.py,sha256=zipNSih9Tyvi4-5cXqNPGsDF7VeestkLp-lcjJ4-dA0,1408
|
2
2
|
brainstate/_module.py,sha256=YJDp9aD38wBa_lY6BojWjWV9LJ2aFMAMYh-KZe5a4eM,52443
|
3
3
|
brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
|
4
|
-
brainstate/_random_for_unit.py,sha256=
|
4
|
+
brainstate/_random_for_unit.py,sha256=1rHr7gfH_bYrJfpxbDhQUk_j00Yosx-GzyZCXrLxsd0,2007
|
5
5
|
brainstate/_state.py,sha256=C0widCOj_ca6zfqh95jzFXf_G5vi0hJyuQ5GIqEqOUs,12102
|
6
6
|
brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
|
7
7
|
brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
|
8
8
|
brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
|
9
9
|
brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
|
10
10
|
brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
|
11
|
-
brainstate/random.py,sha256=
|
11
|
+
brainstate/random.py,sha256=BqEBYVD9TGe8dSzp8U0suK0O4r6Ox59GCq0mwfUndVQ,188073
|
12
12
|
brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
|
13
13
|
brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
|
14
14
|
brainstate/typing.py,sha256=6BlkLSN5TiaNO49q8b0OYyzcuSxmdoG3noIJTbyhE3s,7895
|
@@ -59,8 +59,8 @@ brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJ
|
|
59
59
|
brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
|
60
60
|
brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
|
61
61
|
brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
|
62
|
-
brainstate-0.0.
|
63
|
-
brainstate-0.0.
|
64
|
-
brainstate-0.0.
|
65
|
-
brainstate-0.0.
|
66
|
-
brainstate-0.0.
|
62
|
+
brainstate-0.0.2.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
63
|
+
brainstate-0.0.2.dist-info/METADATA,sha256=K6yiVOqGj3Qs_vKGgQmFXZtlu8cS4r7EZXl_iyCjwh0,3792
|
64
|
+
brainstate-0.0.2.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
65
|
+
brainstate-0.0.2.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
66
|
+
brainstate-0.0.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|