bayinx 0.2.12__tar.gz → 0.2.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bayinx-0.2.12 → bayinx-0.2.18}/PKG-INFO +1 -1
- {bayinx-0.2.12 → bayinx-0.2.18}/pyproject.toml +1 -1
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/core/variational.py +1 -1
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/normal.py +9 -14
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/standard.py +1 -1
- {bayinx-0.2.12 → bayinx-0.2.18}/.github/workflows/release_and_publish.yml +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/.gitignore +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/README.md +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/__init__.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/core/__init__.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/core/flow.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/core/model.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/core/utils.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/__init__.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/bernoulli.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/binomial.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/gamma.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/dists/gamma2.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/__init__.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/__init__.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/flows/__init__.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/flows/fullaffine.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/flows/planar.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/flows/radial.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/flows/sylvester.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/meanfield.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/mhx/vi/normalizing_flow.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/src/bayinx/py.typed +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/tests/__init__.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/tests/test_variational.py +0 -0
- {bayinx-0.2.12 → bayinx-0.2.18}/uv.lock +0 -0
@@ -115,7 +115,7 @@ class Variational(eqx.Module):
|
|
115
115
|
|
116
116
|
# Initialize optimizer
|
117
117
|
optim: GradientTransformation = opx.chain(
|
118
|
-
opx.scale(-1.0), opx.nadamw(schedule,weight_decay=weight_decay)
|
118
|
+
opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
|
119
119
|
)
|
120
120
|
opt_state: OptState = optim.init(dyn)
|
121
121
|
|
@@ -1,17 +1,12 @@
|
|
1
|
-
# MARK: Imports ----
|
2
1
|
import jax.lax as _lax
|
2
|
+
from jaxtyping import Array, ArrayLike, Float, Real
|
3
3
|
|
4
|
-
## Typing
|
5
|
-
from jaxtyping import Array, Real
|
6
|
-
|
7
|
-
# MARK: Constants
|
8
4
|
_PI = 3.141592653589793
|
9
5
|
|
10
6
|
|
11
|
-
# MARK: Functions ----
|
12
7
|
def prob(
|
13
|
-
x: Real[
|
14
|
-
) ->
|
8
|
+
x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
|
9
|
+
) -> Float[Array, "..."]:
|
15
10
|
"""
|
16
11
|
The probability density function (PDF) for a Normal distribution.
|
17
12
|
|
@@ -30,8 +25,8 @@ def prob(
|
|
30
25
|
|
31
26
|
|
32
27
|
def logprob(
|
33
|
-
x: Real[
|
34
|
-
) ->
|
28
|
+
x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
|
29
|
+
) -> Float[Array, "..."]:
|
35
30
|
"""
|
36
31
|
The log of the probability density function (log PDF) for a Normal distribution.
|
37
32
|
|
@@ -48,8 +43,8 @@ def logprob(
|
|
48
43
|
|
49
44
|
|
50
45
|
def uprob(
|
51
|
-
x: Real[
|
52
|
-
) ->
|
46
|
+
x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
|
47
|
+
) -> Float[Array, "..."]:
|
53
48
|
"""
|
54
49
|
The unnormalized probability density function (uPDF) for a Normal distribution.
|
55
50
|
|
@@ -66,8 +61,8 @@ def uprob(
|
|
66
61
|
|
67
62
|
|
68
63
|
def ulogprob(
|
69
|
-
x: Real[
|
70
|
-
) ->
|
64
|
+
x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
|
65
|
+
) -> Float[Array, "..."]:
|
71
66
|
"""
|
72
67
|
The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
|
73
68
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|