bayinx 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
bayinx/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
1
  from bayinx.core import Model as Model
2
2
  from bayinx.core import Parameter as Parameter
3
+ from bayinx.core.model import constrain as constrain
@@ -0,0 +1 @@
1
+ from bayinx.constraints.lower import Lower as Lower
bayinx/core/flow.py CHANGED
@@ -11,7 +11,7 @@ class Flow(eqx.Module):
11
11
  An abstract base class for a flow (of a normalizing flow).
12
12
 
13
13
  # Attributes
14
- - `pars`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
14
+ - `params`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
15
15
  - `constraints`: A dictionary of simple functions that constrain their corresponding parameter.
16
16
  """
17
17
 
bayinx/core/model.py CHANGED
@@ -1,26 +1,29 @@
1
1
  from abc import abstractmethod
2
- from typing import Any, Dict, Generic, Tuple, TypeVar
2
+ from dataclasses import field, fields
3
+ from typing import Any, Self, Tuple
3
4
 
4
5
  import equinox as eqx
5
6
  import jax.numpy as jnp
6
7
  import jax.tree as jt
7
- from jaxtyping import PyTree, Scalar
8
+ from jaxtyping import Scalar
8
9
 
9
10
  from bayinx.core.constraint import Constraint
10
11
  from bayinx.core.parameter import Parameter
11
12
 
12
- T = TypeVar('T', bound=PyTree)
13
- class Model(eqx.Module, Generic[T]):
13
+
14
+ def constrain(constraint: Constraint):
15
+ """Defines constraint metadata."""
16
+ return field(metadata={'constraint': constraint})
17
+
18
+
19
+ class Model(eqx.Module):
14
20
  """
15
21
  An abstract base class used to define probabilistic models.
16
22
 
17
- # Attributes
18
- - `params`: A dictionary of parameters.
19
- - `constraints`: A dictionary of constraints.
20
- """
23
+ Annotate parameter attributes with `Parameter`.
21
24
 
22
- params: Dict[str, Parameter[T]]
23
- constraints: Dict[str, Constraint]
25
+ Include constraints by setting them equal to `constrain(Constraint)`.
26
+ """
24
27
 
25
28
  @abstractmethod
26
29
  def eval(self, data: Any) -> Scalar:
@@ -29,50 +32,71 @@ class Model(eqx.Module, Generic[T]):
29
32
  # Default filter specification
30
33
  @property
31
34
  @eqx.filter_jit
32
- def filter_spec(self):
35
+ def filter_spec(self) -> Self:
33
36
  """
34
37
  Generates a filter specification to subset relevant parameters for the model.
35
38
  """
36
39
  # Generate empty specification
37
- filter_spec = jt.map(lambda _: False, self)
40
+ filter_spec: Self = jt.map(lambda _: False, self)
41
+
42
+ for f in fields(self):
43
+ # Extract attribute from field
44
+ attr = getattr(self, f.name)
38
45
 
39
- # Specify relevant parameters
40
- filter_spec = eqx.tree_at(
41
- lambda model: model.params,
42
- filter_spec,
43
- replace={key: param.filter_spec for key, param in self.params.items()}
44
- )
46
+ # Check if attribute is a parameter
47
+ if isinstance(attr, Parameter):
48
+ # Update filter specification for parameter
49
+ filter_spec = eqx.tree_at(
50
+ lambda model: getattr(model, f.name),
51
+ filter_spec,
52
+ replace=attr.filter_spec
53
+ )
45
54
 
46
55
  return filter_spec
47
56
 
48
- # Add constrain method
57
+
49
58
  @eqx.filter_jit
50
- def constrain_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
59
+ def constrain_params(self) -> Tuple[Self, Scalar]:
51
60
  """
52
- Constrain `params` to the appropriate domain.
61
+ Constrain parameters to the appropriate domain.
53
62
 
54
63
  # Returns
55
- A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
64
+ A constrained `Model` object and the adjustment to the posterior.
56
65
  """
57
- t_params: Dict[str, Parameter[T]] = self.params
66
+ constrained: Self = self
58
67
  target: Scalar = jnp.array(0.0)
59
68
 
60
- for par, map in self.constraints.items():
61
- # Apply transformation
62
- t_params[par], ladj = map.constrain(t_params[par])
69
+ for f in fields(self):
70
+ # Extract attribute
71
+ attr = getattr(self, f.name)
72
+
73
+ # Check if constrained parameter
74
+ if isinstance(attr, Parameter) and 'constraint' in f.metadata:
75
+ param = attr
76
+ constraint = f.metadata['constraint']
77
+
78
+ # Apply constraint
79
+ param, laj = constraint.constrain(param)
80
+
81
+ # Update parameters for constrained model
82
+ constrained = eqx.tree_at(
83
+ lambda model: getattr(model, f.name),
84
+ constrained,
85
+ replace=param
86
+ )
87
+
88
+ # Adjust posterior density
89
+ target += laj
63
90
 
64
- # Adjust posterior density
65
- target -= ladj
91
+ return constrained, target
66
92
 
67
- return t_params, target
68
93
 
69
- # Add default transform method
70
94
  @eqx.filter_jit
71
- def transform_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
95
+ def transform_params(self) -> Tuple[Self, Scalar]:
72
96
  """
73
- Apply a custom transformation to `params` if needed.
97
+ Apply a custom transformation to parameters if needed (defaults to constrained parameters).
74
98
 
75
99
  # Returns
76
- A dictionary of transformed JAX Arrays representing the transformed parameters.
100
+ A transformed `Model` object and the adjustment to the posterior.
77
101
  """
78
102
  return self.constrain_params()
bayinx/core/parameter.py CHANGED
@@ -21,6 +21,9 @@ class Parameter(eqx.Module, Generic[T]):
21
21
  # Insert parameter values
22
22
  self.vals = values
23
23
 
24
+ def __call__(self) -> T:
25
+ return self.vals
26
+
24
27
  # Default filter specification
25
28
  @property
26
29
  @eqx.filter_jit
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -1,12 +1,12 @@
1
- bayinx/__init__.py,sha256=htihTsJ54k-ljBLzt4ye8DR7ORwZhxv-bLMcEaDQeqY,86
1
+ bayinx/__init__.py,sha256=5fb_tGeEVnrNt6IQqu7gZaJskBJHqjcg08JRPrY2ANo,139
2
2
  bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- bayinx/constraints/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ bayinx/constraints/__init__.py,sha256=PSxvcuSox2JL61AG1iag2PTNKPcid_DbOQzHpYdj5RE,52
4
4
  bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1488
5
5
  bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
6
6
  bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
7
- bayinx/core/flow.py,sha256=lAPJdQnrIxC3JoowTp77Gvm0p0v_xQn8FMjFJWMnWbc,2340
8
- bayinx/core/model.py,sha256=RKAtsLXc6xDnXWz5upmx6Vz6JOoorw4WTfxTA7B7Lmg,2294
9
- bayinx/core/parameter.py,sha256=oxCCZcZ-DDBvfWzexfhQkSJPxNQnE1vYXtBhiEZG2bM,1025
7
+ bayinx/core/flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
8
+ bayinx/core/model.py,sha256=1vQPVjE0ebCdW7mLuabgQcCTi95o8n8CC6GuzJdNL1s,2956
9
+ bayinx/core/parameter.py,sha256=eECqvfMNWSU8_CkGYaAfOCneMMQGZI21kF0mErsh2Rc,1080
10
10
  bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
11
11
  bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
@@ -25,6 +25,6 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvAD
25
25
  bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
26
26
  bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
27
27
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
28
- bayinx-0.3.1.dist-info/METADATA,sha256=Slp2nxR8HISCwCqIXY2El3GgqsO1v9_UbVeJq726w7k,3057
29
- bayinx-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
30
- bayinx-0.3.1.dist-info/RECORD,,
28
+ bayinx-0.3.3.dist-info/METADATA,sha256=5tZVPxDvYVajnoDrRIF-KskSMWBix7Zq4h_glyJa-_M,3057
29
+ bayinx-0.3.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
30
+ bayinx-0.3.3.dist-info/RECORD,,
File without changes