bayinx 0.2.12__tar.gz → 0.2.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {bayinx-0.2.12 → bayinx-0.2.13}/PKG-INFO +1 -1
  2. {bayinx-0.2.12 → bayinx-0.2.13}/pyproject.toml +1 -1
  3. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/core/variational.py +1 -1
  4. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/dists/normal.py +9 -14
  5. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/mhx/vi/normalizing_flow.py +3 -0
  6. {bayinx-0.2.12 → bayinx-0.2.13}/.github/workflows/release_and_publish.yml +0 -0
  7. {bayinx-0.2.12 → bayinx-0.2.13}/.gitignore +0 -0
  8. {bayinx-0.2.12 → bayinx-0.2.13}/README.md +0 -0
  9. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/__init__.py +0 -0
  10. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/core/__init__.py +0 -0
  11. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/core/flow.py +0 -0
  12. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/core/model.py +0 -0
  13. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/core/utils.py +0 -0
  14. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/dists/__init__.py +0 -0
  15. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/dists/bernoulli.py +0 -0
  16. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/dists/binomial.py +0 -0
  17. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/dists/gamma.py +0 -0
  18. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/dists/gamma2.py +0 -0
  19. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/mhx/__init__.py +0 -0
  20. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/mhx/vi/__init__.py +0 -0
  21. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  22. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
  23. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/mhx/vi/flows/planar.py +0 -0
  24. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  25. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  26. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/mhx/vi/meanfield.py +0 -0
  27. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/mhx/vi/standard.py +0 -0
  28. {bayinx-0.2.12 → bayinx-0.2.13}/src/bayinx/py.typed +0 -0
  29. {bayinx-0.2.12 → bayinx-0.2.13}/tests/__init__.py +0 -0
  30. {bayinx-0.2.12 → bayinx-0.2.13}/tests/test_variational.py +0 -0
  31. {bayinx-0.2.12 → bayinx-0.2.13}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.12
3
+ Version: 0.2.13
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "bayinx"
3
- version = "0.2.12"
3
+ version = "0.2.13"
4
4
  description = "Bayesian Inference with JAX"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -115,7 +115,7 @@ class Variational(eqx.Module):
115
115
 
116
116
  # Initialize optimizer
117
117
  optim: GradientTransformation = opx.chain(
118
- opx.scale(-1.0), opx.nadamw(schedule,weight_decay=weight_decay)
118
+ opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
119
119
  )
120
120
  opt_state: OptState = optim.init(dyn)
121
121
 
@@ -1,17 +1,12 @@
1
- # MARK: Imports ----
2
1
  import jax.lax as _lax
2
+ from jaxtyping import Array, ArrayLike, Float, Real
3
3
 
4
- ## Typing
5
- from jaxtyping import Array, Real
6
-
7
- # MARK: Constants
8
4
  _PI = 3.141592653589793
9
5
 
10
6
 
11
- # MARK: Functions ----
12
7
  def prob(
13
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
14
- ) -> Real[Array, "..."]:
8
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
9
+ ) -> Float[Array, "..."]:
15
10
  """
16
11
  The probability density function (PDF) for a Normal distribution.
17
12
 
@@ -30,8 +25,8 @@ def prob(
30
25
 
31
26
 
32
27
  def logprob(
33
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
34
- ) -> Real[Array, "..."]:
28
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
29
+ ) -> Float[Array, "..."]:
35
30
  """
36
31
  The log of the probability density function (log PDF) for a Normal distribution.
37
32
 
@@ -48,8 +43,8 @@ def logprob(
48
43
 
49
44
 
50
45
  def uprob(
51
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
52
- ) -> Real[Array, "..."]:
46
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
47
+ ) -> Float[Array, "..."]:
53
48
  """
54
49
  The unnormalized probability density function (uPDF) for a Normal distribution.
55
50
 
@@ -66,8 +61,8 @@ def uprob(
66
61
 
67
62
 
68
63
  def ulogprob(
69
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
70
- ) -> Real[Array, "..."]:
64
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
65
+ ) -> Float[Array, "..."]:
71
66
  """
72
67
  The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
73
68
 
@@ -1,6 +1,8 @@
1
+ from functools import partial
1
2
  from typing import Any, Callable, Self, Tuple
2
3
 
3
4
  import equinox as eqx
5
+ import jax
4
6
  import jax.flatten_util as jfu
5
7
  import jax.numpy as jnp
6
8
  import jax.random as jr
@@ -59,6 +61,7 @@ class NormalizingFlow(Variational):
59
61
  return draws
60
62
 
61
63
  @eqx.filter_jit
64
+ @partial(jax.vmap, in_axes=(None, 0))
62
65
  def eval(self, draws: Array) -> Array:
63
66
  # Evaluate base density
64
67
  variational_evals: Array = self.base.eval(draws)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes