brainstate 0.0.1.1.post20240804__py2.py3-none-any.whl → 0.0.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
17
17
  A ``State``-based Transformation System for Brain Dynamics Programming
18
18
  """
19
19
 
20
- __version__ = "0.0.1.1"
20
+ __version__ = "0.0.2"
21
21
 
22
22
  from . import environ
23
23
  from . import functional
@@ -30,13 +30,14 @@ def uniform_for_unit(
30
30
  maxval: ArrayLike = 1.
31
31
  ) -> jax.Array | bu.Quantity:
32
32
  if isinstance(minval, bu.Quantity) and isinstance(maxval, bu.Quantity):
33
- return bu.Quantity(jr.uniform(key, shape, dtype, minval.value, maxval.value), dim=minval.dim)
33
+ maxval = maxval.in_unit(minval.unit)
34
+ return bu.Quantity(jr.uniform(key, shape, dtype, minval.mantissa, maxval.mantissa), unit=minval.unit)
34
35
  elif isinstance(minval, bu.Quantity):
35
36
  assert minval.is_unitless, f'minval must be unitless when maxval is not a Quantity, got {minval}'
36
- minval = minval.value
37
+ minval = minval.mantissa
37
38
  elif isinstance(maxval, bu.Quantity):
38
39
  assert maxval.is_unitless, f'maxval must be unitless when minval is not a Quantity, got {maxval}'
39
- maxval = maxval.value
40
+ maxval = maxval.mantissa
40
41
  return jr.uniform(key, shape, dtype, minval, maxval)
41
42
 
42
43
 
@@ -47,5 +48,5 @@ def permutation_for_unit(
47
48
  independent: bool = False
48
49
  ) -> jax.Array | bu.Quantity:
49
50
  if isinstance(x, bu.Quantity):
50
- return bu.Quantity(jr.permutation(key, x.value, axis, independent=independent), dim=x.dim)
51
+ return bu.Quantity(jr.permutation(key, x.mantissa, axis, independent=independent), unit=x.unit)
51
52
  return jr.permutation(key, x, axis, independent=independent)
brainstate/random.py CHANGED
@@ -490,14 +490,13 @@ class RandomState(State):
490
490
  upper = bu.math.asarray(upper, dtype=dtype)
491
491
  loc = bu.math.asarray(loc, dtype=dtype)
492
492
  scale = bu.math.asarray(scale, dtype=dtype)
493
- bu.fail_for_dimension_mismatch(lower, upper)
494
- bu.fail_for_dimension_mismatch(lower, loc)
495
- bu.fail_for_dimension_mismatch(lower, scale)
496
- dim = lower.dim if isinstance(lower, bu.Quantity) else bu.DIMENSIONLESS
497
- lower = lower.value if isinstance(lower, bu.Quantity) else lower
498
- upper = upper.value if isinstance(upper, bu.Quantity) else upper
499
- loc = loc.value if isinstance(loc, bu.Quantity) else loc
500
- scale = scale.value if isinstance(scale, bu.Quantity) else scale
493
+ unit = bu.get_unit(lower)
494
+ lower, upper, loc, scale = (
495
+ lower.mantissa if isinstance(lower, bu.Quantity) else lower,
496
+ bu.Quantity(upper).in_unit(unit).mantissa,
497
+ bu.Quantity(loc).in_unit(unit).mantissa,
498
+ bu.Quantity(scale).in_unit(unit).mantissa
499
+ )
501
500
 
502
501
  jit_error(
503
502
  bu.math.any(bu.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
@@ -535,10 +534,12 @@ class RandomState(State):
535
534
  out = out * scale * sqrt2 + loc
536
535
 
537
536
  # Clamp to ensure it's in the proper range
538
- out = jnp.clip(out,
539
- lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
540
- lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))
541
- return out if dim == bu.DIMENSIONLESS else bu.Quantity(out, dim=dim)
537
+ out = jnp.clip(
538
+ out,
539
+ lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
540
+ lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
541
+ )
542
+ return out if unit.is_unitless else bu.Quantity(out, unit=unit)
542
543
 
543
544
  def _check_p(self, p):
544
545
  raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
@@ -555,30 +556,33 @@ class RandomState(State):
555
556
  r = jr.bernoulli(key, p=p, shape=_size2shape(size))
556
557
  return r
557
558
 
558
- def lognormal(self,
559
- mean=None,
560
- sigma=None,
561
- size: Optional[Size] = None,
562
- key: Optional[SeedOrKey] = None,
563
- dtype: DTypeLike = None):
559
+ def lognormal(
560
+ self,
561
+ mean=None,
562
+ sigma=None,
563
+ size: Optional[Size] = None,
564
+ key: Optional[SeedOrKey] = None,
565
+ dtype: DTypeLike = None
566
+ ):
564
567
  mean = _check_py_seq(mean)
565
568
  sigma = _check_py_seq(sigma)
566
569
  mean = bu.math.asarray(mean, dtype=dtype)
567
570
  sigma = bu.math.asarray(sigma, dtype=dtype)
568
- bu.fail_for_dimension_mismatch(mean, sigma)
569
- dim = mean.dim if isinstance(mean, bu.Quantity) else bu.DIMENSIONLESS
570
- mean = mean.value if isinstance(mean, bu.Quantity) else mean
571
- sigma = sigma.value if isinstance(sigma, bu.Quantity) else sigma
571
+ unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
572
+ mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
573
+ sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, bu.Quantity) else sigma
572
574
 
573
575
  if size is None:
574
- size = jnp.broadcast_shapes(jnp.shape(mean),
575
- jnp.shape(sigma))
576
+ size = jnp.broadcast_shapes(
577
+ jnp.shape(mean),
578
+ jnp.shape(sigma)
579
+ )
576
580
  key = self.split_key() if key is None else _formalize_key(key)
577
581
  dtype = dtype or environ.dftype()
578
582
  samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
579
583
  samples = _loc_scale(mean, sigma, samples)
580
584
  samples = jnp.exp(samples)
581
- return samples if dim == bu.DIMENSIONLESS else bu.Quantity(samples, dim=dim)
585
+ return samples if unit.is_unitless else bu.Quantity(samples, unit=unit)
582
586
 
583
587
  def binomial(self,
584
588
  n,
@@ -678,10 +682,10 @@ class RandomState(State):
678
682
  cov = bu.math.asarray(_check_py_seq(cov), dtype=dtype)
679
683
  if isinstance(mean, bu.Quantity):
680
684
  assert isinstance(cov, bu.Quantity)
681
- assert mean.dim ** 2 == cov.dim
682
- mean = mean.value if isinstance(mean, bu.Quantity) else mean
683
- cov = cov.value if isinstance(cov, bu.Quantity) else cov
684
- dim = mean.dim if isinstance(mean, bu.Quantity) else bu.DIMENSIONLESS
685
+ assert mean.unit ** 2 == cov.unit
686
+ mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
687
+ cov = cov.mantissa if isinstance(cov, bu.Quantity) else cov
688
+ unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
685
689
 
686
690
  key = self.split_key() if key is None else _formalize_key(key)
687
691
  if not jnp.ndim(mean) >= 1:
@@ -708,7 +712,7 @@ class RandomState(State):
708
712
  factor = jnp.linalg.cholesky(cov)
709
713
  normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
710
714
  r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
711
- return r if dim == bu.DIMENSIONLESS else bu.Quantity(r, dim=dim)
715
+ return r if unit.is_unitless else bu.Quantity(r, unit=unit)
712
716
 
713
717
  def rayleigh(self,
714
718
  scale=1.0,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.0.1.1.post20240804
3
+ Version: 0.0.2
4
4
  Summary: A State-based Transformation System for Brain Dynamics Programming.
5
5
  Home-page: https://github.com/brainpy/brainstate
6
6
  Author: BDP
@@ -1,14 +1,14 @@
1
- brainstate/__init__.py,sha256=oxslZrm6wxtBQDqwJFb2BaAKZFmnp4d_esDkaeuGMWE,1410
1
+ brainstate/__init__.py,sha256=zipNSih9Tyvi4-5cXqNPGsDF7VeestkLp-lcjJ4-dA0,1408
2
2
  brainstate/_module.py,sha256=YJDp9aD38wBa_lY6BojWjWV9LJ2aFMAMYh-KZe5a4eM,52443
3
3
  brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
4
- brainstate/_random_for_unit.py,sha256=eW4NJkX27VCCNWUwAlyt2otkeEthGKOpUoX6XJ6i95Y,1946
4
+ brainstate/_random_for_unit.py,sha256=1rHr7gfH_bYrJfpxbDhQUk_j00Yosx-GzyZCXrLxsd0,2007
5
5
  brainstate/_state.py,sha256=C0widCOj_ca6zfqh95jzFXf_G5vi0hJyuQ5GIqEqOUs,12102
6
6
  brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
7
7
  brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
8
8
  brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
9
9
  brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
10
10
  brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
11
- brainstate/random.py,sha256=rqwSsiUoeZwxhk9ot8NnOJA8iWMdZB0HaHOVuweJdZQ,188387
11
+ brainstate/random.py,sha256=BqEBYVD9TGe8dSzp8U0suK0O4r6Ox59GCq0mwfUndVQ,188073
12
12
  brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
13
13
  brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
14
14
  brainstate/typing.py,sha256=6BlkLSN5TiaNO49q8b0OYyzcuSxmdoG3noIJTbyhE3s,7895
@@ -59,8 +59,8 @@ brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJ
59
59
  brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
60
60
  brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
61
61
  brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
62
- brainstate-0.0.1.1.post20240804.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
- brainstate-0.0.1.1.post20240804.dist-info/METADATA,sha256=RTuqQrR0-syn5SyxoyEbfbdAUpXBRxNMzpaqnVM2cqQ,3807
64
- brainstate-0.0.1.1.post20240804.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
- brainstate-0.0.1.1.post20240804.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
- brainstate-0.0.1.1.post20240804.dist-info/RECORD,,
62
+ brainstate-0.0.2.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
+ brainstate-0.0.2.dist-info/METADATA,sha256=K6yiVOqGj3Qs_vKGgQmFXZtlu8cS4r7EZXl_iyCjwh0,3792
64
+ brainstate-0.0.2.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
+ brainstate-0.0.2.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
+ brainstate-0.0.2.dist-info/RECORD,,