brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.0.2.post20241010__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,6 +15,8 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
+ from collections import namedtuple
19
+ from functools import partial
18
20
  from operator import index
19
21
  from typing import Optional
20
22
 
@@ -23,7 +25,8 @@ import jax
23
25
  import jax.numpy as jnp
24
26
  import jax.random as jr
25
27
  import numpy as np
26
- from jax import lax, core
28
+ from jax import jit, vmap
29
+ from jax import lax, core, dtypes
27
30
 
28
31
  from brainstate import environ
29
32
  from brainstate._state import State
@@ -75,6 +78,10 @@ class RandomState(State):
75
78
  def clone(self):
76
79
  return type(self)(self.split_key())
77
80
 
81
+ def set_key(self, key: SeedOrKey):
82
+ self.value = key
83
+
84
+
78
85
  def seed(self, seed_or_key: Optional[SeedOrKey] = None):
79
86
  """Sets a new random seed.
80
87
 
@@ -492,7 +499,7 @@ class RandomState(State):
492
499
  # Get upper and lower cdf values
493
500
  sqrt2 = np.array(np.sqrt(2), dtype=dtype)
494
501
  l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
495
- u = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
502
+ u_ = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
496
503
 
497
504
  # Uniformly fill tensor with values from [l, u], then translate to
498
505
  # [2l-1, 2u-1].
@@ -500,7 +507,7 @@ class RandomState(State):
500
507
  out = uniform_for_unit(
501
508
  key, size, dtype,
502
509
  minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
503
- maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype))
510
+ maxval=lax.nextafter(2 * u_- 1, np.array(-np.inf, dtype=dtype))
504
511
  )
505
512
 
506
513
  # Use inverse cdf transform for normal distribution to get truncated
@@ -617,8 +624,8 @@ class RandomState(State):
617
624
  size = jnp.shape(p)
618
625
  key = self.split_key() if key is None else _formalize_key(key)
619
626
  dtype = dtype or environ.dftype()
620
- u = uniform_for_unit(key, size, dtype=dtype)
621
- r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p))
627
+ u_ = uniform_for_unit(key, size, dtype=dtype)
628
+ r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
622
629
  return r
623
630
 
624
631
  def _check_p2(self, p):
@@ -680,8 +687,8 @@ class RandomState(State):
680
687
  _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
681
688
 
682
689
  if method == 'svd':
683
- (u, s, _) = jnp.linalg.svd(cov)
684
- factor = u * jnp.sqrt(s[..., None, :])
690
+ (u_, s, _) = jnp.linalg.svd(cov)
691
+ factor = u_ * jnp.sqrt(s[..., None, :])
685
692
  elif method == 'eigh':
686
693
  (w, v) = jnp.linalg.eigh(cov)
687
694
  factor = v * jnp.sqrt(w[..., None, :])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.0.2.post20241009
3
+ Version: 0.0.2.post20241010
4
4
  Summary: A State-based Transformation System for Brain Dynamics Programming.
5
5
  Home-page: https://github.com/brainpy/brainstate
6
6
  Author: BDP
@@ -61,7 +61,7 @@ brainstate/optim/_sgd_optimizer.py,sha256=JiK_AVGregL0wn8uHhRQvK9Qq7Qja7dEyLW6Aa
61
61
  brainstate/random/__init__.py,sha256=c5q-RC3grRIjx-HBb2IhKZpi_xzbFmUUxzRAzqfREic,1045
62
62
  brainstate/random/_rand_funs.py,sha256=mIoENR3iEVeVR-qCQ2UQVP0SEosWry4xhzhYr0UXPAI,132072
63
63
  brainstate/random/_rand_seed.py,sha256=jRXP4zsQde1XyiOdx4aWpHXDvz5PAB1pogqQw0ywYnk,4633
64
- brainstate/random/_rand_state.py,sha256=D3-cvu-0mZZjnB2Wy4xmeqfm0oTvDHHPvV6SqXb7Of8,53716
64
+ brainstate/random/_rand_state.py,sha256=LDt3p7JCEidg7tBQ8GNUJ8aAIlj-PWUh21T5-yutmlI,53887
65
65
  brainstate/random/_random_for_unit.py,sha256=Nm02GmMFiPx5LdfEDmbdK9hqLsebPdVd0bAQduGhASI,2017
66
66
  brainstate/random/random_test.py,sha256=vw45gTn-39tu2WUuamR6rr0gb2h1eeClvcaJDj16Vz8,18402
67
67
  brainstate/transform/__init__.py,sha256=hqef3a4sLQ_Oihuqs8E5IghSLr9o2bS7CWmwRL8jX6E,1887
@@ -80,8 +80,8 @@ brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2
80
80
  brainstate/transform/_mapping.py,sha256=G9XUsD1xKLCprwwE0wv0gSXS0NYZ-ZIsv-PKKRlOoTA,3821
81
81
  brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
82
82
  brainstate/transform/_unvmap.py,sha256=8Se_23QrwDdcJpFcUnnMgD6EP-4XylbhP9K5TDhW358,3311
83
- brainstate-0.0.2.post20241009.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
84
- brainstate-0.0.2.post20241009.dist-info/METADATA,sha256=2RH74ehNlwwLBu-0oRYZdpmpEVVsGKZcrAv3BHXVPcg,3311
85
- brainstate-0.0.2.post20241009.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
86
- brainstate-0.0.2.post20241009.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
87
- brainstate-0.0.2.post20241009.dist-info/RECORD,,
83
+ brainstate-0.0.2.post20241010.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
84
+ brainstate-0.0.2.post20241010.dist-info/METADATA,sha256=yfpoOk0ZcT8Pd3_kC6AGm4CfGb6IrMbIgZKBgAxgBsQ,3311
85
+ brainstate-0.0.2.post20241010.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
86
+ brainstate-0.0.2.post20241010.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
87
+ brainstate-0.0.2.post20241010.dist-info/RECORD,,