brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.0.2.post20241010__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/random/_rand_state.py +14 -7
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.0.2.post20241010.dist-info}/METADATA +1 -1
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.0.2.post20241010.dist-info}/RECORD +6 -6
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.0.2.post20241010.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.0.2.post20241010.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.0.2.post20241010.dist-info}/top_level.txt +0 -0
brainstate/random/_rand_state.py
CHANGED
@@ -15,6 +15,8 @@
|
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
|
+
from collections import namedtuple
|
19
|
+
from functools import partial
|
18
20
|
from operator import index
|
19
21
|
from typing import Optional
|
20
22
|
|
@@ -23,7 +25,8 @@ import jax
|
|
23
25
|
import jax.numpy as jnp
|
24
26
|
import jax.random as jr
|
25
27
|
import numpy as np
|
26
|
-
from jax import
|
28
|
+
from jax import jit, vmap
|
29
|
+
from jax import lax, core, dtypes
|
27
30
|
|
28
31
|
from brainstate import environ
|
29
32
|
from brainstate._state import State
|
@@ -75,6 +78,10 @@ class RandomState(State):
|
|
75
78
|
def clone(self):
|
76
79
|
return type(self)(self.split_key())
|
77
80
|
|
81
|
+
def set_key(self, key: SeedOrKey):
|
82
|
+
self.value = key
|
83
|
+
|
84
|
+
|
78
85
|
def seed(self, seed_or_key: Optional[SeedOrKey] = None):
|
79
86
|
"""Sets a new random seed.
|
80
87
|
|
@@ -492,7 +499,7 @@ class RandomState(State):
|
|
492
499
|
# Get upper and lower cdf values
|
493
500
|
sqrt2 = np.array(np.sqrt(2), dtype=dtype)
|
494
501
|
l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
|
495
|
-
|
502
|
+
u_ = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
|
496
503
|
|
497
504
|
# Uniformly fill tensor with values from [l, u], then translate to
|
498
505
|
# [2l-1, 2u-1].
|
@@ -500,7 +507,7 @@ class RandomState(State):
|
|
500
507
|
out = uniform_for_unit(
|
501
508
|
key, size, dtype,
|
502
509
|
minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
|
503
|
-
maxval=lax.nextafter(2 *
|
510
|
+
maxval=lax.nextafter(2 * u_- 1, np.array(-np.inf, dtype=dtype))
|
504
511
|
)
|
505
512
|
|
506
513
|
# Use inverse cdf transform for normal distribution to get truncated
|
@@ -617,8 +624,8 @@ class RandomState(State):
|
|
617
624
|
size = jnp.shape(p)
|
618
625
|
key = self.split_key() if key is None else _formalize_key(key)
|
619
626
|
dtype = dtype or environ.dftype()
|
620
|
-
|
621
|
-
r = jnp.floor(jnp.log1p(-
|
627
|
+
u_ = uniform_for_unit(key, size, dtype=dtype)
|
628
|
+
r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
|
622
629
|
return r
|
623
630
|
|
624
631
|
def _check_p2(self, p):
|
@@ -680,8 +687,8 @@ class RandomState(State):
|
|
680
687
|
_check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
|
681
688
|
|
682
689
|
if method == 'svd':
|
683
|
-
(
|
684
|
-
factor =
|
690
|
+
(u_, s, _) = jnp.linalg.svd(cov)
|
691
|
+
factor = u_ * jnp.sqrt(s[..., None, :])
|
685
692
|
elif method == 'eigh':
|
686
693
|
(w, v) = jnp.linalg.eigh(cov)
|
687
694
|
factor = v * jnp.sqrt(w[..., None, :])
|
@@ -61,7 +61,7 @@ brainstate/optim/_sgd_optimizer.py,sha256=JiK_AVGregL0wn8uHhRQvK9Qq7Qja7dEyLW6Aa
|
|
61
61
|
brainstate/random/__init__.py,sha256=c5q-RC3grRIjx-HBb2IhKZpi_xzbFmUUxzRAzqfREic,1045
|
62
62
|
brainstate/random/_rand_funs.py,sha256=mIoENR3iEVeVR-qCQ2UQVP0SEosWry4xhzhYr0UXPAI,132072
|
63
63
|
brainstate/random/_rand_seed.py,sha256=jRXP4zsQde1XyiOdx4aWpHXDvz5PAB1pogqQw0ywYnk,4633
|
64
|
-
brainstate/random/_rand_state.py,sha256=
|
64
|
+
brainstate/random/_rand_state.py,sha256=LDt3p7JCEidg7tBQ8GNUJ8aAIlj-PWUh21T5-yutmlI,53887
|
65
65
|
brainstate/random/_random_for_unit.py,sha256=Nm02GmMFiPx5LdfEDmbdK9hqLsebPdVd0bAQduGhASI,2017
|
66
66
|
brainstate/random/random_test.py,sha256=vw45gTn-39tu2WUuamR6rr0gb2h1eeClvcaJDj16Vz8,18402
|
67
67
|
brainstate/transform/__init__.py,sha256=hqef3a4sLQ_Oihuqs8E5IghSLr9o2bS7CWmwRL8jX6E,1887
|
@@ -80,8 +80,8 @@ brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2
|
|
80
80
|
brainstate/transform/_mapping.py,sha256=G9XUsD1xKLCprwwE0wv0gSXS0NYZ-ZIsv-PKKRlOoTA,3821
|
81
81
|
brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
|
82
82
|
brainstate/transform/_unvmap.py,sha256=8Se_23QrwDdcJpFcUnnMgD6EP-4XylbhP9K5TDhW358,3311
|
83
|
-
brainstate-0.0.2.
|
84
|
-
brainstate-0.0.2.
|
85
|
-
brainstate-0.0.2.
|
86
|
-
brainstate-0.0.2.
|
87
|
-
brainstate-0.0.2.
|
83
|
+
brainstate-0.0.2.post20241010.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
84
|
+
brainstate-0.0.2.post20241010.dist-info/METADATA,sha256=yfpoOk0ZcT8Pd3_kC6AGm4CfGb6IrMbIgZKBgAxgBsQ,3311
|
85
|
+
brainstate-0.0.2.post20241010.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
86
|
+
brainstate-0.0.2.post20241010.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
87
|
+
brainstate-0.0.2.post20241010.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.0.2.post20241009.dist-info → brainstate-0.0.2.post20241010.dist-info}/top_level.txt
RENAMED
File without changes
|