bayinx 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bayinx/core/flow.py CHANGED
@@ -5,10 +5,8 @@ import equinox as eqx
5
5
  import jax.tree_util as jtu
6
6
  from jaxtyping import Array, Float
7
7
 
8
- from bayinx.core.utils import __MyMeta
9
8
 
10
-
11
- class Flow(eqx.Module, metaclass=__MyMeta):
9
+ class Flow(eqx.Module):
12
10
  """
13
11
  A superclass used to define continuously parameterized diffeomorphisms for normalizing flows.
14
12
 
bayinx/core/model.py CHANGED
@@ -7,10 +7,9 @@ import jax.tree_util as jtu
7
7
  from jaxtyping import Array, Scalar
8
8
 
9
9
  from bayinx.core.constraints import Constraint
10
- from bayinx.core.utils import __MyMeta
11
10
 
12
11
 
13
- class Model(eqx.Module, metaclass=__MyMeta):
12
+ class Model(eqx.Module):
14
13
  """
15
14
  A superclass used to define probabilistic models.
16
15
 
bayinx/core/utils.py CHANGED
@@ -1,54 +1 @@
1
- from typing import Callable, Dict
2
1
 
3
- import equinox as eqx
4
- from jaxtyping import Array
5
-
6
-
7
- class __MyMeta(type(eqx.Module)):
8
- """
9
- Metaclass to ensure attribute types are respected.
10
- """
11
-
12
- def __call__(cls, *args, **kwargs):
13
- obj = super().__call__(*args, **kwargs)
14
-
15
- # Check parameters are a Dict of JAX Arrays
16
- if not isinstance(obj.params, Dict):
17
- raise ValueError(
18
- f"Model {cls.__name__} must initialize 'params' as a dictionary."
19
- )
20
-
21
- for key, value in obj.params.items():
22
- if not isinstance(value, Array):
23
- raise TypeError(f"Parameter '{key}' must be a JAX Array.")
24
-
25
- # Check constraints are a Dict of functions
26
- if not isinstance(obj.constraints, Dict):
27
- raise ValueError(
28
- f"Model {cls.__name__} must initialize 'constraints' as a dictionary."
29
- )
30
-
31
- for key, value in obj.constraints.items():
32
- if not isinstance(value, Callable):
33
- raise TypeError(f"Constraint for parameter '{key}' must be a function.")
34
-
35
- # Check that the constrain method returns a dict equivalent to `params`
36
- t_params: Dict[str, Array] = obj.constrain_pars()
37
-
38
- if not isinstance(t_params, Dict):
39
- raise ValueError(
40
- f"The 'constrain' method of {cls.__name__} must return a Dict of JAX Arrays."
41
- )
42
-
43
- for key, value in t_params.items():
44
- if not isinstance(value, Array):
45
- raise TypeError(f"Constrained parameter '{key}' must be a JAX Array.")
46
-
47
- if not value.shape == obj.params[key].shape:
48
- raise ValueError(
49
- f"Constrained parameter '{key}' must have same shape as unconstrained counterpart."
50
- )
51
-
52
- # Check transform_pars
53
-
54
- return obj
@@ -33,7 +33,22 @@ class FullAffine(Flow):
33
33
  "scale": jnp.zeros((dim, dim)),
34
34
  }
35
35
 
36
- self.constraints = {"scale": lambda m: jnp.tril(jnp.exp(m))}
36
+ self.constraints = {"scale": lambda m: jnp.tril(m)}
37
+
38
+ def transform_pars(self):
39
+ # Get constrained parameters
40
+ params = self.constrain_pars()
41
+
42
+ # Extract diagonal and apply exponential
43
+ diag: Array = jnp.exp(jnp.diag(params['scale']))
44
+
45
+ # Fill diagonal
46
+ params['scale'] = jnp.fill_diagonal(params['scale'], diag)
47
+
48
+
49
+ return params
50
+
51
+
37
52
 
38
53
  @eqx.filter_jit
39
54
  def forward(self, draws: Array) -> Array:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.24
3
+ Version: 0.2.26
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -2,9 +2,9 @@ bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
2
2
  bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
4
4
  bayinx/core/constraints.py,sha256=Y8FJX3CkgnLQ3HXuTPGuzvLtXlKs0B7z0-YymoHgdfg,1682
5
- bayinx/core/flow.py,sha256=oZE0OHCninIHjp-WVLFWd1DaN0-qXxNWFAUAdgIDmRU,2423
6
- bayinx/core/model.py,sha256=t7s5Yt4E3iC_MPvynJnk6lb4OLal7Gnk59tIZ6e5M4I,2203
7
- bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
5
+ bayinx/core/flow.py,sha256=9swS5wh7AsIZWgG_IhQS-upcPlr7G-juaP_5rsbX6G0,2363
6
+ bayinx/core/model.py,sha256=U1xBnAXnIvFJjWF-XIT8BYjP9PtoRZY_PwyhRwJf-HA,2144
7
+ bayinx/core/utils.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
8
8
  bayinx/core/variational.py,sha256=vUZ6u5CXCHfs6ZrA8PF4PHfmUXHTK2RJGHyZ3afFfsg,4820
9
9
  bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
@@ -19,10 +19,10 @@ bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,38
19
19
  bayinx/mhx/vi/normalizing_flow.py,sha256=nj7bpIoMJl6GTOXPxQCAsPArchbHd5vwwqMm3cLbMII,4791
20
20
  bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
21
21
  bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
22
- bayinx/mhx/vi/flows/fullaffine.py,sha256=2QbOtA1Jmu-yRcJeFmCKc8N1atm8G7JXYMLEZaEXKV0,1711
22
+ bayinx/mhx/vi/flows/fullaffine.py,sha256=TUcjXuDeFLHL9SYOLEU6kQSkEiyijztfBY2AAis7Pn0,2034
23
23
  bayinx/mhx/vi/flows/planar.py,sha256=u9ZVwEeOv4fMfwiORlseCz463atsWMuid_9crRg05Z8,1919
24
24
  bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
25
25
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
26
- bayinx-0.2.24.dist-info/METADATA,sha256=sR0C0Pk5vrAmdvAtB3faXZO-hIDpKzqLjnXcfMsikjw,3058
27
- bayinx-0.2.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
28
- bayinx-0.2.24.dist-info/RECORD,,
26
+ bayinx-0.2.26.dist-info/METADATA,sha256=EReEQKQXgy71MtuqSgFRkn0YhVOUKT1hFECt12YEXxE,3058
27
+ bayinx-0.2.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
28
+ bayinx-0.2.26.dist-info/RECORD,,