brainstate 0.0.1.1.post20240804__py2.py3-none-any.whl → 0.0.2.post20240814__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_random_for_unit.py +5 -4
- brainstate/init/_random_inits.py +9 -9
- brainstate/random.py +34 -30
- {brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.post20240814.dist-info}/METADATA +1 -1
- {brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.post20240814.dist-info}/RECORD +9 -9
- {brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.post20240814.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.post20240814.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.post20240814.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
brainstate/_random_for_unit.py
CHANGED
@@ -30,13 +30,14 @@ def uniform_for_unit(
|
|
30
30
|
maxval: ArrayLike = 1.
|
31
31
|
) -> jax.Array | bu.Quantity:
|
32
32
|
if isinstance(minval, bu.Quantity) and isinstance(maxval, bu.Quantity):
|
33
|
-
|
33
|
+
maxval = maxval.in_unit(minval.unit)
|
34
|
+
return bu.Quantity(jr.uniform(key, shape, dtype, minval.mantissa, maxval.mantissa), unit=minval.unit)
|
34
35
|
elif isinstance(minval, bu.Quantity):
|
35
36
|
assert minval.is_unitless, f'minval must be unitless when maxval is not a Quantity, got {minval}'
|
36
|
-
minval = minval.
|
37
|
+
minval = minval.mantissa
|
37
38
|
elif isinstance(maxval, bu.Quantity):
|
38
39
|
assert maxval.is_unitless, f'maxval must be unitless when minval is not a Quantity, got {maxval}'
|
39
|
-
maxval = maxval.
|
40
|
+
maxval = maxval.mantissa
|
40
41
|
return jr.uniform(key, shape, dtype, minval, maxval)
|
41
42
|
|
42
43
|
|
@@ -47,5 +48,5 @@ def permutation_for_unit(
|
|
47
48
|
independent: bool = False
|
48
49
|
) -> jax.Array | bu.Quantity:
|
49
50
|
if isinstance(x, bu.Quantity):
|
50
|
-
return bu.Quantity(jr.permutation(key, x.
|
51
|
+
return bu.Quantity(jr.permutation(key, x.mantissa, axis, independent=independent), unit=x.unit)
|
51
52
|
return jr.permutation(key, x, axis, independent=independent)
|
brainstate/init/_random_inits.py
CHANGED
@@ -289,8 +289,8 @@ class VarianceScaling(Initializer):
|
|
289
289
|
denominator = (fan_in + fan_out) / 2
|
290
290
|
else:
|
291
291
|
raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
|
292
|
-
scale = self.scale.
|
293
|
-
|
292
|
+
scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
|
293
|
+
unit = bu.get_unit(self.scale)
|
294
294
|
variance = (scale / denominator).astype(self.dtype)
|
295
295
|
if self.distribution == "truncated_normal":
|
296
296
|
stddev = (jnp.sqrt(variance) / .87962566103423978).astype(self.dtype)
|
@@ -302,7 +302,7 @@ class VarianceScaling(Initializer):
|
|
302
302
|
jnp.sqrt(3 * variance).astype(self.dtype))
|
303
303
|
else:
|
304
304
|
raise ValueError("invalid distribution for variance scaling initializer")
|
305
|
-
return res if
|
305
|
+
return res if unit.is_unitless else bu.Quantity(res, unit=unit)
|
306
306
|
|
307
307
|
def __repr__(self):
|
308
308
|
name = self.__class__.__name__
|
@@ -445,8 +445,8 @@ class Orthogonal(Initializer):
|
|
445
445
|
matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
|
446
446
|
norm_dst = random.normal(size=matrix_shape, dtype=self.dtype)
|
447
447
|
|
448
|
-
scale = self.scale.
|
449
|
-
|
448
|
+
scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
|
449
|
+
unit = bu.get_unit(self.scale)
|
450
450
|
q_mat, r_mat = jnp.linalg.qr(norm_dst)
|
451
451
|
# Enforce Q is uniformly distributed
|
452
452
|
q_mat *= jnp.sign(jnp.diag(r_mat))
|
@@ -455,7 +455,7 @@ class Orthogonal(Initializer):
|
|
455
455
|
q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
|
456
456
|
q_mat = jnp.moveaxis(q_mat, 0, self.axis)
|
457
457
|
r = jnp.asarray(scale, dtype=self.dtype) * q_mat
|
458
|
-
return r if
|
458
|
+
return r if unit.is_unitless else bu.Quantity(r, unit=unit)
|
459
459
|
|
460
460
|
def __repr__(self):
|
461
461
|
return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
|
@@ -480,8 +480,8 @@ class DeltaOrthogonal(Initializer):
|
|
480
480
|
raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
|
481
481
|
if shape[-1] < shape[-2]:
|
482
482
|
raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
|
483
|
-
scale = self.scale.
|
484
|
-
|
483
|
+
scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
|
484
|
+
unit = bu.get_unit(self.scale)
|
485
485
|
ortho_matrix = Orthogonal(scale=scale, axis=self.axis, dtype=self.dtype)(*shape[-2:])
|
486
486
|
W = jnp.zeros(shape, dtype=self.dtype)
|
487
487
|
if len(shape) == 3:
|
@@ -493,7 +493,7 @@ class DeltaOrthogonal(Initializer):
|
|
493
493
|
else:
|
494
494
|
k1, k2, k3 = shape[:3]
|
495
495
|
W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
|
496
|
-
return W if
|
496
|
+
return W if unit.is_unitless else bu.Quantity(W, unit=unit)
|
497
497
|
|
498
498
|
def __repr__(self):
|
499
499
|
return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
|
brainstate/random.py
CHANGED
@@ -490,14 +490,13 @@ class RandomState(State):
|
|
490
490
|
upper = bu.math.asarray(upper, dtype=dtype)
|
491
491
|
loc = bu.math.asarray(loc, dtype=dtype)
|
492
492
|
scale = bu.math.asarray(scale, dtype=dtype)
|
493
|
-
bu.
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
scale = scale.value if isinstance(scale, bu.Quantity) else scale
|
493
|
+
unit = bu.get_unit(lower)
|
494
|
+
lower, upper, loc, scale = (
|
495
|
+
lower.mantissa if isinstance(lower, bu.Quantity) else lower,
|
496
|
+
bu.Quantity(upper).in_unit(unit).mantissa,
|
497
|
+
bu.Quantity(loc).in_unit(unit).mantissa,
|
498
|
+
bu.Quantity(scale).in_unit(unit).mantissa
|
499
|
+
)
|
501
500
|
|
502
501
|
jit_error(
|
503
502
|
bu.math.any(bu.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
|
@@ -535,10 +534,12 @@ class RandomState(State):
|
|
535
534
|
out = out * scale * sqrt2 + loc
|
536
535
|
|
537
536
|
# Clamp to ensure it's in the proper range
|
538
|
-
out = jnp.clip(
|
539
|
-
|
540
|
-
|
541
|
-
|
537
|
+
out = jnp.clip(
|
538
|
+
out,
|
539
|
+
lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
|
540
|
+
lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
|
541
|
+
)
|
542
|
+
return out if unit.is_unitless else bu.Quantity(out, unit=unit)
|
542
543
|
|
543
544
|
def _check_p(self, p):
|
544
545
|
raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
|
@@ -555,30 +556,33 @@ class RandomState(State):
|
|
555
556
|
r = jr.bernoulli(key, p=p, shape=_size2shape(size))
|
556
557
|
return r
|
557
558
|
|
558
|
-
def lognormal(
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
559
|
+
def lognormal(
|
560
|
+
self,
|
561
|
+
mean=None,
|
562
|
+
sigma=None,
|
563
|
+
size: Optional[Size] = None,
|
564
|
+
key: Optional[SeedOrKey] = None,
|
565
|
+
dtype: DTypeLike = None
|
566
|
+
):
|
564
567
|
mean = _check_py_seq(mean)
|
565
568
|
sigma = _check_py_seq(sigma)
|
566
569
|
mean = bu.math.asarray(mean, dtype=dtype)
|
567
570
|
sigma = bu.math.asarray(sigma, dtype=dtype)
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
sigma = sigma.value if isinstance(sigma, bu.Quantity) else sigma
|
571
|
+
unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
|
572
|
+
mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
|
573
|
+
sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, bu.Quantity) else sigma
|
572
574
|
|
573
575
|
if size is None:
|
574
|
-
size = jnp.broadcast_shapes(
|
575
|
-
|
576
|
+
size = jnp.broadcast_shapes(
|
577
|
+
jnp.shape(mean),
|
578
|
+
jnp.shape(sigma)
|
579
|
+
)
|
576
580
|
key = self.split_key() if key is None else _formalize_key(key)
|
577
581
|
dtype = dtype or environ.dftype()
|
578
582
|
samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
|
579
583
|
samples = _loc_scale(mean, sigma, samples)
|
580
584
|
samples = jnp.exp(samples)
|
581
|
-
return samples if
|
585
|
+
return samples if unit.is_unitless else bu.Quantity(samples, unit=unit)
|
582
586
|
|
583
587
|
def binomial(self,
|
584
588
|
n,
|
@@ -678,10 +682,10 @@ class RandomState(State):
|
|
678
682
|
cov = bu.math.asarray(_check_py_seq(cov), dtype=dtype)
|
679
683
|
if isinstance(mean, bu.Quantity):
|
680
684
|
assert isinstance(cov, bu.Quantity)
|
681
|
-
assert mean.
|
682
|
-
mean = mean.
|
683
|
-
cov = cov.
|
684
|
-
|
685
|
+
assert mean.unit ** 2 == cov.unit
|
686
|
+
mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
|
687
|
+
cov = cov.mantissa if isinstance(cov, bu.Quantity) else cov
|
688
|
+
unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
|
685
689
|
|
686
690
|
key = self.split_key() if key is None else _formalize_key(key)
|
687
691
|
if not jnp.ndim(mean) >= 1:
|
@@ -708,7 +712,7 @@ class RandomState(State):
|
|
708
712
|
factor = jnp.linalg.cholesky(cov)
|
709
713
|
normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
|
710
714
|
r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
|
711
|
-
return r if
|
715
|
+
return r if unit.is_unitless else bu.Quantity(r, unit=unit)
|
712
716
|
|
713
717
|
def rayleigh(self,
|
714
718
|
scale=1.0,
|
{brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.post20240814.dist-info}/RECORD
RENAMED
@@ -1,14 +1,14 @@
|
|
1
|
-
brainstate/__init__.py,sha256=
|
1
|
+
brainstate/__init__.py,sha256=zipNSih9Tyvi4-5cXqNPGsDF7VeestkLp-lcjJ4-dA0,1408
|
2
2
|
brainstate/_module.py,sha256=YJDp9aD38wBa_lY6BojWjWV9LJ2aFMAMYh-KZe5a4eM,52443
|
3
3
|
brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
|
4
|
-
brainstate/_random_for_unit.py,sha256=
|
4
|
+
brainstate/_random_for_unit.py,sha256=1rHr7gfH_bYrJfpxbDhQUk_j00Yosx-GzyZCXrLxsd0,2007
|
5
5
|
brainstate/_state.py,sha256=C0widCOj_ca6zfqh95jzFXf_G5vi0hJyuQ5GIqEqOUs,12102
|
6
6
|
brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
|
7
7
|
brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
|
8
8
|
brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
|
9
9
|
brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
|
10
10
|
brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
|
11
|
-
brainstate/random.py,sha256=
|
11
|
+
brainstate/random.py,sha256=BqEBYVD9TGe8dSzp8U0suK0O4r6Ox59GCq0mwfUndVQ,188073
|
12
12
|
brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
|
13
13
|
brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
|
14
14
|
brainstate/typing.py,sha256=6BlkLSN5TiaNO49q8b0OYyzcuSxmdoG3noIJTbyhE3s,7895
|
@@ -21,7 +21,7 @@ brainstate/functional/_spikes.py,sha256=70qGvo4B--QtxfJMjLwGmk9pVsf2x2YNEEgjT-il
|
|
21
21
|
brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
|
22
22
|
brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
|
23
23
|
brainstate/init/_generic.py,sha256=LB7IQfswOG6X-q0QX5N8T5vZmxdygetsSBQ6iXlZ0oU,7324
|
24
|
-
brainstate/init/_random_inits.py,sha256=
|
24
|
+
brainstate/init/_random_inits.py,sha256=vNUVDdUOCXTx2i3i1enzxgg1USCzugYd56r0-2lBL-0,15919
|
25
25
|
brainstate/init/_regular_inits.py,sha256=u77aSM0BkK9VULFJQZ1lIEYA_sJJzEZBTEttBSJ79RI,3090
|
26
26
|
brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
|
27
27
|
brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
|
@@ -59,8 +59,8 @@ brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJ
|
|
59
59
|
brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
|
60
60
|
brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
|
61
61
|
brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
|
62
|
-
brainstate-0.0.
|
63
|
-
brainstate-0.0.
|
64
|
-
brainstate-0.0.
|
65
|
-
brainstate-0.0.
|
66
|
-
brainstate-0.0.
|
62
|
+
brainstate-0.0.2.post20240814.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
63
|
+
brainstate-0.0.2.post20240814.dist-info/METADATA,sha256=skMWlfxiGaJHxzQS7dY95V91umhXVjN_HuhGO0xHP1M,3805
|
64
|
+
brainstate-0.0.2.post20240814.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
65
|
+
brainstate-0.0.2.post20240814.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
66
|
+
brainstate-0.0.2.post20240814.dist-info/RECORD,,
|
{brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.post20240814.dist-info}/LICENSE
RENAMED
File without changes
|
File without changes
|
{brainstate-0.0.1.1.post20240804.dist-info → brainstate-0.0.2.post20240814.dist-info}/top_level.txt
RENAMED
File without changes
|