bayinx 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.
@@ -115,7 +115,7 @@ class Variational(eqx.Module):
115
115
 
116
116
  # Initialize optimizer
117
117
  optim: GradientTransformation = opx.chain(
118
- opx.scale(-1.0), opx.nadamw(schedule,weight_decay=weight_decay)
118
+ opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
119
119
  )
120
120
  opt_state: OptState = optim.init(dyn)
121
121
 
bayinx/dists/normal.py CHANGED
@@ -1,17 +1,12 @@
1
- # MARK: Imports ----
2
1
  import jax.lax as _lax
2
+ from jaxtyping import Array, ArrayLike, Float, Real
3
3
 
4
- ## Typing
5
- from jaxtyping import Array, Real
6
-
7
- # MARK: Constants
8
4
  _PI = 3.141592653589793
9
5
 
10
6
 
11
- # MARK: Functions ----
12
7
  def prob(
13
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
14
- ) -> Real[Array, "..."]:
8
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
9
+ ) -> Float[Array, "..."]:
15
10
  """
16
11
  The probability density function (PDF) for a Normal distribution.
17
12
 
@@ -30,8 +25,8 @@ def prob(
30
25
 
31
26
 
32
27
  def logprob(
33
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
34
- ) -> Real[Array, "..."]:
28
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
29
+ ) -> Float[Array, "..."]:
35
30
  """
36
31
  The log of the probability density function (log PDF) for a Normal distribution.
37
32
 
@@ -48,8 +43,8 @@ def logprob(
48
43
 
49
44
 
50
45
  def uprob(
51
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
52
- ) -> Real[Array, "..."]:
46
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
47
+ ) -> Float[Array, "..."]:
53
48
  """
54
49
  The unnormalized probability density function (uPDF) for a Normal distribution.
55
50
 
@@ -66,8 +61,8 @@ def uprob(
66
61
 
67
62
 
68
63
  def ulogprob(
69
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
70
- ) -> Real[Array, "..."]:
64
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
65
+ ) -> Float[Array, "..."]:
71
66
  """
72
67
  The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
73
68
 
@@ -3,7 +3,6 @@ from typing import Callable, Dict, Tuple
3
3
 
4
4
  import equinox as eqx
5
5
  import jax
6
- import jax.nn as jnn
7
6
  import jax.numpy as jnp
8
7
  import jax.random as jr
9
8
  from jaxtyping import Array, Float, Scalar
@@ -31,25 +30,12 @@ class Planar(Flow):
31
30
  - `dim`: The dimension of the parameter space.
32
31
  """
33
32
  self.params = {
34
- "u": jr.normal(key, (dim,)),
35
- "w": jr.normal(key, (dim,)),
36
- "b": jr.normal(key, (1,)),
33
+ "u": jnp.ones(dim),
34
+ "w": jnp.ones(dim),
35
+ "b": jnp.zeros(1),
37
36
  }
38
37
  self.constraints = {}
39
38
 
40
- def transform_pars(self):
41
- params = self.params
42
-
43
- u = params['u']
44
- w = params['w']
45
- b = params['b']
46
-
47
- m = jnn.softplus(w.dot(u)) - 1.0
48
-
49
- u = u + (m - w.dot(u)) * w / (w**2).sum()
50
-
51
- return {'u': u, 'w': w, 'b': b}
52
-
53
39
  @eqx.filter_jit
54
40
  @partial(jax.vmap, in_axes=(None, 0))
55
41
  def forward(self, draws: Array) -> Array:
@@ -1,6 +1,8 @@
1
+ from functools import partial
1
2
  from typing import Any, Callable, Self, Tuple
2
3
 
3
4
  import equinox as eqx
5
+ import jax
4
6
  import jax.flatten_util as jfu
5
7
  import jax.numpy as jnp
6
8
  import jax.random as jr
@@ -59,6 +61,7 @@ class NormalizingFlow(Variational):
59
61
  return draws
60
62
 
61
63
  @eqx.filter_jit
64
+ @partial(jax.vmap, in_axes=(None, 0))
62
65
  def eval(self, draws: Array) -> Array:
63
66
  # Evaluate base density
64
67
  variational_evals: Array = self.base.eval(draws)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.11
3
+ Version: 0.2.13
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -4,23 +4,23 @@ bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
4
4
  bayinx/core/flow.py,sha256=oZE0OHCninIHjp-WVLFWd1DaN0-qXxNWFAUAdgIDmRU,2423
5
5
  bayinx/core/model.py,sha256=-rT3NHjxqGB0lDBMi0Mr9XNOz1_TUnJWtd4ITj0rsus,2257
6
6
  bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
7
- bayinx/core/variational.py,sha256=3CsDyQkq1XgV2ZBLzGrm5XgUFoJBnT6glHDgxHNcbTc,5250
7
+ bayinx/core/variational.py,sha256=k9wWn7Tnj3eET-qK1pZtzDyPZVvQTRUexJUBVSdGXOA,5251
8
8
  bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
10
10
  bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
- bayinx/dists/normal.py,sha256=e9gXXAHeZQKjBndW2TnMvP3gtmvpfYGG7kehcpGeAoU,2590
13
+ bayinx/dists/normal.py,sha256=OOKg46y5hHFP76ydbRjEXaDkgefZcj9sd0XAl7yokww,2587
14
14
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
15
15
  bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
16
16
  bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
17
- bayinx/mhx/vi/normalizing_flow.py,sha256=XBUWYZpm_Ipi6X9oTnGhqIs3ARY-5QFiuxM7uAWFRps,4790
17
+ bayinx/mhx/vi/normalizing_flow.py,sha256=O9U40Z7ANAh4Weqs7jaNHqmG5UkdlNooUP8Vx1u7hwg,4873
18
18
  bayinx/mhx/vi/standard.py,sha256=m5gtcHfrYzV28h-Red3Zn6SxEgJlndeIXiIG5gDPecU,1703
19
19
  bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
20
20
  bayinx/mhx/vi/flows/fullaffine.py,sha256=2QbOtA1Jmu-yRcJeFmCKc8N1atm8G7JXYMLEZaEXKV0,1711
21
- bayinx/mhx/vi/flows/planar.py,sha256=u4heNqxpmfXsACsE3RH8XkBd04Emd8G674M3sSd7CxM,2232
21
+ bayinx/mhx/vi/flows/planar.py,sha256=qmtWpIBXRct2seI78pkmtF0X7cASUBELqmZmf2QS5Gs,1918
22
22
  bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
23
23
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
24
- bayinx-0.2.11.dist-info/METADATA,sha256=TkIGXb5baSFeP5K4s1bynQ3QqNsR-DKePlAOV_vV0Nc,3058
25
- bayinx-0.2.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
26
- bayinx-0.2.11.dist-info/RECORD,,
24
+ bayinx-0.2.13.dist-info/METADATA,sha256=3Do4je5M1N3-DKmo0tG1a0ULG91ECcCQe-B60E-EuPA,3058
25
+ bayinx-0.2.13.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
26
+ bayinx-0.2.13.dist-info/RECORD,,