bayinx 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,7 @@ from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
6
6
  from bayinx.core.constraint import Constraint
7
7
 
8
8
 
9
- class LowerBound(Constraint):
9
+ class Lower(Constraint):
10
10
  """
11
11
  Enforces a lower bound on the parameter.
12
12
  """
@@ -18,7 +18,7 @@ class LowerBound(Constraint):
18
18
 
19
19
  def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
20
20
  """
21
- Applies the lower bound constraint and computes the laj.
21
+ Applies the lower bound constraint and adjusts the posterior density.
22
22
 
23
23
  # Parameters
24
24
  - `x`: The unconstrained JAX Array-like input.
@@ -26,12 +26,12 @@ class LowerBound(Constraint):
26
26
  # Parameters
27
27
  A tuple containing:
28
28
  - The constrained JAX Array (x > self.lb).
29
- - A scalar JAX Array representing the laj of the transformation.
29
+ - A scalar JAX Array representing the log-absolute-Jacobian of the transformation.
30
30
  """
31
31
  # Compute transformation adjustment
32
- ladj: Scalar = jnp.sum(x)
32
+ laj: Scalar = jnp.sum(x)
33
33
 
34
34
  # Compute transformation
35
35
  x = jnp.exp(x) + self.lb
36
36
 
37
- return x, ladj
37
+ return x, laj
bayinx/core/constraint.py CHANGED
@@ -13,7 +13,7 @@ class Constraint(eqx.Module):
13
13
  @abstractmethod
14
14
  def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
15
15
  """
16
- Applies the constraining transformation to an unconstrained input and computes the log-absolute-jacobian of the transformation.
16
+ Applies the constraining transformation to an unconstrained input and computes the log-absolute-Jacobian of the transformation.
17
17
 
18
18
  # Parameters
19
19
  - `x`: The unconstrained JAX Array-like input.
@@ -21,6 +21,6 @@ class Constraint(eqx.Module):
21
21
  # Returns
22
22
  A tuple containing:
23
23
  - The constrained JAX Array.
24
- - A scalar JAX Array representing the laj of the transformation.
24
+ - A scalar JAX Array representing the log-absolute-Jacobian of the transformation.
25
25
  """
26
26
  pass
@@ -40,7 +40,7 @@ class FullAffine(Flow):
40
40
  diag: Array = jnp.exp(jnp.diag(scale))
41
41
 
42
42
  # Return matrix with modified diagonal
43
- return jnp.fill_diagonal(scale, diag, inplace=False)
43
+ return jnp.fill_diagonal(jnp.tril(scale), diag, inplace=False)
44
44
 
45
45
  self.constraints = {"scale": constrain_scale}
46
46
 
@@ -75,7 +75,7 @@ class NormalizingFlow(Variational):
75
75
 
76
76
  for map in self.flows:
77
77
  # Compute adjustment
78
- laj, draws = map.adjust_density(draws)
78
+ draws, laj = map.adjust_density(draws)
79
79
 
80
80
  # Adjust variational density
81
81
  variational_evals = variational_evals - laj
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.30
3
+ Version: 0.2.32
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -1,9 +1,9 @@
1
1
  bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
2
2
  bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  bayinx/constraints/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- bayinx/constraints/lower.py,sha256=O37qJ6ojRKbKGJlnQ7Vv7P2VGARcnKrOifwyUNBHma8,912
4
+ bayinx/constraints/lower.py,sha256=MAAsWpZhqu1ySMskQ0fPVkCzW6CVDCU67q2bkCyzbLc,936
5
5
  bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
6
- bayinx/core/constraint.py,sha256=T1QP3WV9RU1nPW-9H7loc6DyBzElHVyDp1c4mmZpnmI,733
6
+ bayinx/core/constraint.py,sha256=60KzDILVLQTCY3jt4YEnNKJ5gnG8IHTv_nNqrdwt_3c,751
7
7
  bayinx/core/flow.py,sha256=A5Vw5t76LPasnMgghjw6ulBkIm5L2jBprusVt-tuwko,2296
8
8
  bayinx/core/model.py,sha256=vfEnqBpHE2MtuJPIDgKvVYIv5n53E2e-KAAXEtqEy0c,2126
9
9
  bayinx/core/variational.py,sha256=2stsYKZDri1rLP7mrz7X2GWehBXNESdlWtmF2N9CEas,4787
@@ -17,13 +17,13 @@ bayinx/dists/censored/gamma2/r.py,sha256=3brRCKhE-74mRXyIyPcnyaWY2OJv8CZyUWPP9T1
17
17
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
18
18
  bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
19
19
  bayinx/mhx/vi/meanfield.py,sha256=BobfTagVGA5R-dclv-E0jSA80KZg1X6GGjiw7XR61vE,3643
20
- bayinx/mhx/vi/normalizing_flow.py,sha256=FvxDtqGRtaEeeF-bXCYnIEAvOOXVHKUK0oCTF9ma02Y,4622
20
+ bayinx/mhx/vi/normalizing_flow.py,sha256=DYhvTiu2Fr5x8KpWAMZVUaio7ctG2X2SMUO0l8zfZ5g,4622
21
21
  bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
22
22
  bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
23
- bayinx/mhx/vi/flows/fullaffine.py,sha256=s-fxgzv84BEqNMnmLt6vtvwJqBzixCC2OwWXLz3IK-w,1940
23
+ bayinx/mhx/vi/flows/fullaffine.py,sha256=Z_G2Cg90Asgvqel8buMx5kEVsiIxDDh8rc-L_nP9OCY,1950
24
24
  bayinx/mhx/vi/flows/planar.py,sha256=WVj-oxcRctuoRA6KJjU63ek1ZgKNG2vI-TLN0QqjtKA,1916
25
25
  bayinx/mhx/vi/flows/radial.py,sha256=Obj3SraliawIHmP14F9wRpWt34y3kscY--Izy24eCvM,2499
26
26
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
27
- bayinx-0.2.30.dist-info/METADATA,sha256=VVDldLMllMzTkx_tphb6k2n4u-PSwBbmFOGN-PG1BWc,3058
28
- bayinx-0.2.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
29
- bayinx-0.2.30.dist-info/RECORD,,
27
+ bayinx-0.2.32.dist-info/METADATA,sha256=C6xrmvyGJ573nlEObTgIb_CWiPDOY0zAtQj75qbgCc4,3058
28
+ bayinx-0.2.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
29
+ bayinx-0.2.32.dist-info/RECORD,,