brainstate 0.0.1.1.post20240804__py2.py3-none-any.whl → 0.0.2.post20240814__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
17
17
  A ``State``-based Transformation System for Brain Dynamics Programming
18
18
  """
19
19
 
20
- __version__ = "0.0.1.1"
20
+ __version__ = "0.0.2"
21
21
 
22
22
  from . import environ
23
23
  from . import functional
@@ -30,13 +30,14 @@ def uniform_for_unit(
30
30
  maxval: ArrayLike = 1.
31
31
  ) -> jax.Array | bu.Quantity:
32
32
  if isinstance(minval, bu.Quantity) and isinstance(maxval, bu.Quantity):
33
- return bu.Quantity(jr.uniform(key, shape, dtype, minval.value, maxval.value), dim=minval.dim)
33
+ maxval = maxval.in_unit(minval.unit)
34
+ return bu.Quantity(jr.uniform(key, shape, dtype, minval.mantissa, maxval.mantissa), unit=minval.unit)
34
35
  elif isinstance(minval, bu.Quantity):
35
36
  assert minval.is_unitless, f'minval must be unitless when maxval is not a Quantity, got {minval}'
36
- minval = minval.value
37
+ minval = minval.mantissa
37
38
  elif isinstance(maxval, bu.Quantity):
38
39
  assert maxval.is_unitless, f'maxval must be unitless when minval is not a Quantity, got {maxval}'
39
- maxval = maxval.value
40
+ maxval = maxval.mantissa
40
41
  return jr.uniform(key, shape, dtype, minval, maxval)
41
42
 
42
43
 
@@ -47,5 +48,5 @@ def permutation_for_unit(
47
48
  independent: bool = False
48
49
  ) -> jax.Array | bu.Quantity:
49
50
  if isinstance(x, bu.Quantity):
50
- return bu.Quantity(jr.permutation(key, x.value, axis, independent=independent), dim=x.dim)
51
+ return bu.Quantity(jr.permutation(key, x.mantissa, axis, independent=independent), unit=x.unit)
51
52
  return jr.permutation(key, x, axis, independent=independent)
@@ -289,8 +289,8 @@ class VarianceScaling(Initializer):
289
289
  denominator = (fan_in + fan_out) / 2
290
290
  else:
291
291
  raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
292
- scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
293
- dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
292
+ scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
293
+ unit = bu.get_unit(self.scale)
294
294
  variance = (scale / denominator).astype(self.dtype)
295
295
  if self.distribution == "truncated_normal":
296
296
  stddev = (jnp.sqrt(variance) / .87962566103423978).astype(self.dtype)
@@ -302,7 +302,7 @@ class VarianceScaling(Initializer):
302
302
  jnp.sqrt(3 * variance).astype(self.dtype))
303
303
  else:
304
304
  raise ValueError("invalid distribution for variance scaling initializer")
305
- return res if dim == bu.DIMENSIONLESS else res * dim
305
+ return res if unit.is_unitless else bu.Quantity(res, unit=unit)
306
306
 
307
307
  def __repr__(self):
308
308
  name = self.__class__.__name__
@@ -445,8 +445,8 @@ class Orthogonal(Initializer):
445
445
  matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
446
446
  norm_dst = random.normal(size=matrix_shape, dtype=self.dtype)
447
447
 
448
- scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
449
- dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
448
+ scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
449
+ unit = bu.get_unit(self.scale)
450
450
  q_mat, r_mat = jnp.linalg.qr(norm_dst)
451
451
  # Enforce Q is uniformly distributed
452
452
  q_mat *= jnp.sign(jnp.diag(r_mat))
@@ -455,7 +455,7 @@ class Orthogonal(Initializer):
455
455
  q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
456
456
  q_mat = jnp.moveaxis(q_mat, 0, self.axis)
457
457
  r = jnp.asarray(scale, dtype=self.dtype) * q_mat
458
- return r if dim == bu.DIMENSIONLESS else r * dim
458
+ return r if unit.is_unitless else bu.Quantity(r, unit=unit)
459
459
 
460
460
  def __repr__(self):
461
461
  return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
@@ -480,8 +480,8 @@ class DeltaOrthogonal(Initializer):
480
480
  raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
481
481
  if shape[-1] < shape[-2]:
482
482
  raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
483
- scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
484
- dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
483
+ scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
484
+ unit = bu.get_unit(self.scale)
485
485
  ortho_matrix = Orthogonal(scale=scale, axis=self.axis, dtype=self.dtype)(*shape[-2:])
486
486
  W = jnp.zeros(shape, dtype=self.dtype)
487
487
  if len(shape) == 3:
@@ -493,7 +493,7 @@ class DeltaOrthogonal(Initializer):
493
493
  else:
494
494
  k1, k2, k3 = shape[:3]
495
495
  W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
496
- return W if dim == bu.DIMENSIONLESS else W * dim
496
+ return W if unit.is_unitless else bu.Quantity(W, unit=unit)
497
497
 
498
498
  def __repr__(self):
499
499
  return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
brainstate/random.py CHANGED
@@ -490,14 +490,13 @@ class RandomState(State):
490
490
  upper = bu.math.asarray(upper, dtype=dtype)
491
491
  loc = bu.math.asarray(loc, dtype=dtype)
492
492
  scale = bu.math.asarray(scale, dtype=dtype)
493
- bu.fail_for_dimension_mismatch(lower, upper)
494
- bu.fail_for_dimension_mismatch(lower, loc)
495
- bu.fail_for_dimension_mismatch(lower, scale)
496
- dim = lower.dim if isinstance(lower, bu.Quantity) else bu.DIMENSIONLESS
497
- lower = lower.value if isinstance(lower, bu.Quantity) else lower
498
- upper = upper.value if isinstance(upper, bu.Quantity) else upper
499
- loc = loc.value if isinstance(loc, bu.Quantity) else loc
500
- scale = scale.value if isinstance(scale, bu.Quantity) else scale
493
+ unit = bu.get_unit(lower)
494
+ lower, upper, loc, scale = (
495
+ lower.mantissa if isinstance(lower, bu.Quantity) else lower,
496
+ bu.Quantity(upper).in_unit(unit).mantissa,
497
+ bu.Quantity(loc).in_unit(unit).mantissa,
498
+ bu.Quantity(scale).in_unit(unit).mantissa
499
+ )
501
500
 
502
501
  jit_error(
503
502
  bu.math.any(bu.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
@@ -535,10 +534,12 @@ class RandomState(State):
535
534
  out = out * scale * sqrt2 + loc
536
535
 
537
536
  # Clamp to ensure it's in the proper range
538
- out = jnp.clip(out,
539
- lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
540
- lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))
541
- return out if dim == bu.DIMENSIONLESS else bu.Quantity(out, dim=dim)
537
+ out = jnp.clip(
538
+ out,
539
+ lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
540
+ lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
541
+ )
542
+ return out if unit.is_unitless else bu.Quantity(out, unit=unit)
542
543
 
543
544
  def _check_p(self, p):
544
545
  raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
@@ -555,30 +556,33 @@ class RandomState(State):
555
556
  r = jr.bernoulli(key, p=p, shape=_size2shape(size))
556
557
  return r
557
558
 
558
- def lognormal(self,
559
- mean=None,
560
- sigma=None,
561
- size: Optional[Size] = None,
562
- key: Optional[SeedOrKey] = None,
563
- dtype: DTypeLike = None):
559
+ def lognormal(
560
+ self,
561
+ mean=None,
562
+ sigma=None,
563
+ size: Optional[Size] = None,
564
+ key: Optional[SeedOrKey] = None,
565
+ dtype: DTypeLike = None
566
+ ):
564
567
  mean = _check_py_seq(mean)
565
568
  sigma = _check_py_seq(sigma)
566
569
  mean = bu.math.asarray(mean, dtype=dtype)
567
570
  sigma = bu.math.asarray(sigma, dtype=dtype)
568
- bu.fail_for_dimension_mismatch(mean, sigma)
569
- dim = mean.dim if isinstance(mean, bu.Quantity) else bu.DIMENSIONLESS
570
- mean = mean.value if isinstance(mean, bu.Quantity) else mean
571
- sigma = sigma.value if isinstance(sigma, bu.Quantity) else sigma
571
+ unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
572
+ mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
573
+ sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, bu.Quantity) else sigma
572
574
 
573
575
  if size is None:
574
- size = jnp.broadcast_shapes(jnp.shape(mean),
575
- jnp.shape(sigma))
576
+ size = jnp.broadcast_shapes(
577
+ jnp.shape(mean),
578
+ jnp.shape(sigma)
579
+ )
576
580
  key = self.split_key() if key is None else _formalize_key(key)
577
581
  dtype = dtype or environ.dftype()
578
582
  samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
579
583
  samples = _loc_scale(mean, sigma, samples)
580
584
  samples = jnp.exp(samples)
581
- return samples if dim == bu.DIMENSIONLESS else bu.Quantity(samples, dim=dim)
585
+ return samples if unit.is_unitless else bu.Quantity(samples, unit=unit)
582
586
 
583
587
  def binomial(self,
584
588
  n,
@@ -678,10 +682,10 @@ class RandomState(State):
678
682
  cov = bu.math.asarray(_check_py_seq(cov), dtype=dtype)
679
683
  if isinstance(mean, bu.Quantity):
680
684
  assert isinstance(cov, bu.Quantity)
681
- assert mean.dim ** 2 == cov.dim
682
- mean = mean.value if isinstance(mean, bu.Quantity) else mean
683
- cov = cov.value if isinstance(cov, bu.Quantity) else cov
684
- dim = mean.dim if isinstance(mean, bu.Quantity) else bu.DIMENSIONLESS
685
+ assert mean.unit ** 2 == cov.unit
686
+ mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
687
+ cov = cov.mantissa if isinstance(cov, bu.Quantity) else cov
688
+ unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
685
689
 
686
690
  key = self.split_key() if key is None else _formalize_key(key)
687
691
  if not jnp.ndim(mean) >= 1:
@@ -708,7 +712,7 @@ class RandomState(State):
708
712
  factor = jnp.linalg.cholesky(cov)
709
713
  normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
710
714
  r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
711
- return r if dim == bu.DIMENSIONLESS else bu.Quantity(r, dim=dim)
715
+ return r if unit.is_unitless else bu.Quantity(r, unit=unit)
712
716
 
713
717
  def rayleigh(self,
714
718
  scale=1.0,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.0.1.1.post20240804
3
+ Version: 0.0.2.post20240814
4
4
  Summary: A State-based Transformation System for Brain Dynamics Programming.
5
5
  Home-page: https://github.com/brainpy/brainstate
6
6
  Author: BDP
@@ -1,14 +1,14 @@
1
- brainstate/__init__.py,sha256=oxslZrm6wxtBQDqwJFb2BaAKZFmnp4d_esDkaeuGMWE,1410
1
+ brainstate/__init__.py,sha256=zipNSih9Tyvi4-5cXqNPGsDF7VeestkLp-lcjJ4-dA0,1408
2
2
  brainstate/_module.py,sha256=YJDp9aD38wBa_lY6BojWjWV9LJ2aFMAMYh-KZe5a4eM,52443
3
3
  brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
4
- brainstate/_random_for_unit.py,sha256=eW4NJkX27VCCNWUwAlyt2otkeEthGKOpUoX6XJ6i95Y,1946
4
+ brainstate/_random_for_unit.py,sha256=1rHr7gfH_bYrJfpxbDhQUk_j00Yosx-GzyZCXrLxsd0,2007
5
5
  brainstate/_state.py,sha256=C0widCOj_ca6zfqh95jzFXf_G5vi0hJyuQ5GIqEqOUs,12102
6
6
  brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
7
7
  brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
8
8
  brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
9
9
  brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
10
10
  brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
11
- brainstate/random.py,sha256=rqwSsiUoeZwxhk9ot8NnOJA8iWMdZB0HaHOVuweJdZQ,188387
11
+ brainstate/random.py,sha256=BqEBYVD9TGe8dSzp8U0suK0O4r6Ox59GCq0mwfUndVQ,188073
12
12
  brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
13
13
  brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
14
14
  brainstate/typing.py,sha256=6BlkLSN5TiaNO49q8b0OYyzcuSxmdoG3noIJTbyhE3s,7895
@@ -21,7 +21,7 @@ brainstate/functional/_spikes.py,sha256=70qGvo4B--QtxfJMjLwGmk9pVsf2x2YNEEgjT-il
21
21
  brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
22
22
  brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
23
23
  brainstate/init/_generic.py,sha256=LB7IQfswOG6X-q0QX5N8T5vZmxdygetsSBQ6iXlZ0oU,7324
24
- brainstate/init/_random_inits.py,sha256=LsfvKSX4wsR7Kh5jgKgdyXTCEEa5Nn_iYcp_2GgLQKY,16030
24
+ brainstate/init/_random_inits.py,sha256=vNUVDdUOCXTx2i3i1enzxgg1USCzugYd56r0-2lBL-0,15919
25
25
  brainstate/init/_regular_inits.py,sha256=u77aSM0BkK9VULFJQZ1lIEYA_sJJzEZBTEttBSJ79RI,3090
26
26
  brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
27
27
  brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
@@ -59,8 +59,8 @@ brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJ
59
59
  brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
60
60
  brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
61
61
  brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
62
- brainstate-0.0.1.1.post20240804.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
- brainstate-0.0.1.1.post20240804.dist-info/METADATA,sha256=RTuqQrR0-syn5SyxoyEbfbdAUpXBRxNMzpaqnVM2cqQ,3807
64
- brainstate-0.0.1.1.post20240804.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
- brainstate-0.0.1.1.post20240804.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
- brainstate-0.0.1.1.post20240804.dist-info/RECORD,,
62
+ brainstate-0.0.2.post20240814.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
+ brainstate-0.0.2.post20240814.dist-info/METADATA,sha256=skMWlfxiGaJHxzQS7dY95V91umhXVjN_HuhGO0xHP1M,3805
64
+ brainstate-0.0.2.post20240814.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
+ brainstate-0.0.2.post20240814.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
+ brainstate-0.0.2.post20240814.dist-info/RECORD,,