bayinx 0.2.12__tar.gz → 0.2.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {bayinx-0.2.12 → bayinx-0.2.18}/PKG-INFO +1 -1
  2. {bayinx-0.2.12 → bayinx-0.2.18}/pyproject.toml +1 -1
  3. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/core/variational.py +1 -1
  4. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/normal.py +9 -14
  5. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/standard.py +1 -1
  6. {bayinx-0.2.12 → bayinx-0.2.18}/.github/workflows/release_and_publish.yml +0 -0
  7. {bayinx-0.2.12 → bayinx-0.2.18}/.gitignore +0 -0
  8. {bayinx-0.2.12 → bayinx-0.2.18}/README.md +0 -0
  9. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/__init__.py +0 -0
  10. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/core/__init__.py +0 -0
  11. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/core/flow.py +0 -0
  12. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/core/model.py +0 -0
  13. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/core/utils.py +0 -0
  14. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/__init__.py +0 -0
  15. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/bernoulli.py +0 -0
  16. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/binomial.py +0 -0
  17. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/gamma.py +0 -0
  18. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/gamma2.py +0 -0
  19. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/__init__.py +0 -0
  20. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/__init__.py +0 -0
  21. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
  22. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
  23. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/flows/planar.py +0 -0
  24. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/flows/radial.py +0 -0
  25. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
  26. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/meanfield.py +0 -0
  27. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/normalizing_flow.py +0 -0
  28. {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/py.typed +0 -0
  29. {bayinx-0.2.12 → bayinx-0.2.18}/tests/__init__.py +0 -0
  30. {bayinx-0.2.12 → bayinx-0.2.18}/tests/test_variational.py +0 -0
  31. {bayinx-0.2.12 → bayinx-0.2.18}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.12
3
+ Version: 0.2.18
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "bayinx"
3
- version = "0.2.12"
3
+ version = "0.2.18"
4
4
  description = "Bayesian Inference with JAX"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -115,7 +115,7 @@ class Variational(eqx.Module):
115
115
 
116
116
  # Initialize optimizer
117
117
  optim: GradientTransformation = opx.chain(
118
- opx.scale(-1.0), opx.nadamw(schedule,weight_decay=weight_decay)
118
+ opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
119
119
  )
120
120
  opt_state: OptState = optim.init(dyn)
121
121
 
@@ -1,17 +1,12 @@
1
- # MARK: Imports ----
2
1
  import jax.lax as _lax
2
+ from jaxtyping import Array, ArrayLike, Float, Real
3
3
 
4
- ## Typing
5
- from jaxtyping import Array, Real
6
-
7
- # MARK: Constants
8
4
  _PI = 3.141592653589793
9
5
 
10
6
 
11
- # MARK: Functions ----
12
7
  def prob(
13
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
14
- ) -> Real[Array, "..."]:
8
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
9
+ ) -> Float[Array, "..."]:
15
10
  """
16
11
  The probability density function (PDF) for a Normal distribution.
17
12
 
@@ -30,8 +25,8 @@ def prob(
30
25
 
31
26
 
32
27
  def logprob(
33
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
34
- ) -> Real[Array, "..."]:
28
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
29
+ ) -> Float[Array, "..."]:
35
30
  """
36
31
  The log of the probability density function (log PDF) for a Normal distribution.
37
32
 
@@ -48,8 +43,8 @@ def logprob(
48
43
 
49
44
 
50
45
  def uprob(
51
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
52
- ) -> Real[Array, "..."]:
46
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
47
+ ) -> Float[Array, "..."]:
53
48
  """
54
49
  The unnormalized probability density function (uPDF) for a Normal distribution.
55
50
 
@@ -66,8 +61,8 @@ def uprob(
66
61
 
67
62
 
68
63
  def ulogprob(
69
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
70
- ) -> Real[Array, "..."]:
64
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
65
+ ) -> Float[Array, "..."]:
71
66
  """
72
67
  The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
73
68
 
@@ -52,7 +52,7 @@ class Standard(Variational):
52
52
  x=draws,
53
53
  mu=jnp.array(0.0),
54
54
  sigma=jnp.array(1.0),
55
- ).sum(axis=1)
55
+ ).sum(axis=1, keepdims=True)
56
56
 
57
57
  @eqx.filter_jit
58
58
  def filter_spec(self):
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes