bayinx 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayinx/__init__.py +1 -0
- bayinx/core/flow.py +1 -1
- bayinx/core/model.py +57 -33
- bayinx/core/parameter.py +3 -0
- bayinx/dists/censored/gamma2/__init__.py +1 -0
- {bayinx-0.3.2.dist-info → bayinx-0.3.4.dist-info}/METADATA +1 -1
- {bayinx-0.3.2.dist-info → bayinx-0.3.4.dist-info}/RECORD +8 -7
- {bayinx-0.3.2.dist-info → bayinx-0.3.4.dist-info}/WHEEL +0 -0
bayinx/__init__.py
CHANGED
bayinx/core/flow.py
CHANGED
@@ -11,7 +11,7 @@ class Flow(eqx.Module):
|
|
11
11
|
An abstract base class for a flow(of a normalizing flow).
|
12
12
|
|
13
13
|
# Attributes
|
14
|
-
- `
|
14
|
+
- `params`: A dictionary of JAX Arrays representing parameters of the diffeomorphism.
|
15
15
|
- `constraints`: A dictionary of simple functions that constrain their corresponding parameter.
|
16
16
|
"""
|
17
17
|
|
bayinx/core/model.py
CHANGED
@@ -1,26 +1,29 @@
|
|
1
1
|
from abc import abstractmethod
|
2
|
-
from
|
2
|
+
from dataclasses import field, fields
|
3
|
+
from typing import Any, Self, Tuple
|
3
4
|
|
4
5
|
import equinox as eqx
|
5
6
|
import jax.numpy as jnp
|
6
7
|
import jax.tree as jt
|
7
|
-
from jaxtyping import
|
8
|
+
from jaxtyping import Scalar
|
8
9
|
|
9
10
|
from bayinx.core.constraint import Constraint
|
10
11
|
from bayinx.core.parameter import Parameter
|
11
12
|
|
12
|
-
|
13
|
-
|
13
|
+
|
14
|
+
def constrain(constraint: Constraint):
|
15
|
+
"""Defines constraint metadata."""
|
16
|
+
return field(metadata={'constraint': constraint})
|
17
|
+
|
18
|
+
|
19
|
+
class Model(eqx.Module):
|
14
20
|
"""
|
15
21
|
An abstract base class used to define probabilistic models.
|
16
22
|
|
17
|
-
|
18
|
-
- `params`: A dictionary of parameters.
|
19
|
-
- `constraints`: A dictionary of constraints.
|
20
|
-
"""
|
23
|
+
Annotate parameter attributes with `Parameter`.
|
21
24
|
|
22
|
-
|
23
|
-
|
25
|
+
Include constraints by setting them equal to `constrain(Constraint)`.
|
26
|
+
"""
|
24
27
|
|
25
28
|
@abstractmethod
|
26
29
|
def eval(self, data: Any) -> Scalar:
|
@@ -29,50 +32,71 @@ class Model(eqx.Module, Generic[P]):
|
|
29
32
|
# Default filter specification
|
30
33
|
@property
|
31
34
|
@eqx.filter_jit
|
32
|
-
def filter_spec(self):
|
35
|
+
def filter_spec(self) -> Self:
|
33
36
|
"""
|
34
37
|
Generates a filter specification to subset relevant parameters for the model.
|
35
38
|
"""
|
36
39
|
# Generate empty specification
|
37
|
-
filter_spec = jt.map(lambda _: False, self)
|
40
|
+
filter_spec: Self = jt.map(lambda _: False, self)
|
41
|
+
|
42
|
+
for f in fields(self):
|
43
|
+
# Extract attribute from field
|
44
|
+
attr = getattr(self, f.name)
|
38
45
|
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
46
|
+
# Check if attribute is a parameter
|
47
|
+
if isinstance(attr, Parameter):
|
48
|
+
# Update filter specification for parameter
|
49
|
+
filter_spec = eqx.tree_at(
|
50
|
+
lambda model: getattr(model, f.name),
|
51
|
+
filter_spec,
|
52
|
+
replace=attr.filter_spec
|
53
|
+
)
|
45
54
|
|
46
55
|
return filter_spec
|
47
56
|
|
48
|
-
|
57
|
+
|
49
58
|
@eqx.filter_jit
|
50
|
-
def constrain_params(self) -> Tuple[
|
59
|
+
def constrain_params(self) -> Tuple[Self, Scalar]:
|
51
60
|
"""
|
52
|
-
Constrain
|
61
|
+
Constrain parameters to the appropriate domain.
|
53
62
|
|
54
63
|
# Returns
|
55
|
-
A
|
64
|
+
A constrained `Model` object and the adjustment to the posterior.
|
56
65
|
"""
|
57
|
-
|
66
|
+
constrained: Self = self
|
58
67
|
target: Scalar = jnp.array(0.0)
|
59
68
|
|
60
|
-
for
|
61
|
-
#
|
62
|
-
|
69
|
+
for f in fields(self):
|
70
|
+
# Extract attribute
|
71
|
+
attr = getattr(self, f.name)
|
72
|
+
|
73
|
+
# Check if constrained parameter
|
74
|
+
if isinstance(attr, Parameter) and 'constraint' in f.metadata:
|
75
|
+
param = attr
|
76
|
+
constraint = f.metadata['constraint']
|
77
|
+
|
78
|
+
# Apply constraint
|
79
|
+
param, laj = constraint.constrain(param)
|
80
|
+
|
81
|
+
# Update parameters for constrained model
|
82
|
+
constrained = eqx.tree_at(
|
83
|
+
lambda model: getattr(model, f.name),
|
84
|
+
constrained,
|
85
|
+
replace=param
|
86
|
+
)
|
87
|
+
|
88
|
+
# Adjust posterior density
|
89
|
+
target += laj
|
63
90
|
|
64
|
-
|
65
|
-
target -= ladj
|
91
|
+
return constrained, target
|
66
92
|
|
67
|
-
return t_params, target
|
68
93
|
|
69
|
-
# Add default transform method
|
70
94
|
@eqx.filter_jit
|
71
|
-
def transform_params(self) -> Tuple[
|
95
|
+
def transform_params(self) -> Tuple[Self, Scalar]:
|
72
96
|
"""
|
73
|
-
Apply a custom transformation to
|
97
|
+
Apply a custom transformation to parameters if needed(defaults to constrained parameters).
|
74
98
|
|
75
99
|
# Returns
|
76
|
-
A
|
100
|
+
A transformed `Model` object and the adjustment to the posterior.
|
77
101
|
"""
|
78
102
|
return self.constrain_params()
|
bayinx/core/parameter.py
CHANGED
bayinx/dists/censored/gamma2/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
from . import r as r
|
{bayinx-0.3.2.dist-info → bayinx-0.3.4.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
|
|
1
|
-
bayinx/__init__.py,sha256=
|
1
|
+
bayinx/__init__.py,sha256=5fb_tGeEVnrNt6IQqu7gZaJskBJHqjcg08JRPrY2ANo,139
|
2
2
|
bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
bayinx/constraints/__init__.py,sha256=PSxvcuSox2JL61AG1iag2PTNKPcid_DbOQzHpYdj5RE,52
|
4
4
|
bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1488
|
5
5
|
bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
|
6
6
|
bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
|
7
|
-
bayinx/core/flow.py,sha256=
|
8
|
-
bayinx/core/model.py,sha256=
|
9
|
-
bayinx/core/parameter.py,sha256=
|
7
|
+
bayinx/core/flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
|
8
|
+
bayinx/core/model.py,sha256=1vQPVjE0ebCdW7mLuabgQcCTi95o8n8CC6GuzJdNL1s,2956
|
9
|
+
bayinx/core/parameter.py,sha256=eECqvfMNWSU8_CkGYaAfOCneMMQGZI21kF0mErsh2Rc,1080
|
10
10
|
bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
|
11
11
|
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
12
|
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
@@ -14,6 +14,7 @@ bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
|
|
14
14
|
bayinx/dists/normal.py,sha256=mvm6EoAlORy-yivuhMcExYCZUo0vJzMKMOWH-9iQBZU,2634
|
15
15
|
bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
|
16
16
|
bayinx/dists/censored/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
+
bayinx/dists/censored/gamma2/__init__.py,sha256=2EaQcgCXEwaRoHChVlD02ZMfgiwQAqey6uLPov1lcwE,21
|
17
18
|
bayinx/dists/censored/gamma2/r.py,sha256=3brRCKhE-74mRXyIyPcnyaWY2OJv8CZyUWPP9T1t09Y,2274
|
18
19
|
bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
19
20
|
bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
|
@@ -25,6 +26,6 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvAD
|
|
25
26
|
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
26
27
|
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
27
28
|
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
28
|
-
bayinx-0.3.
|
29
|
-
bayinx-0.3.
|
30
|
-
bayinx-0.3.
|
29
|
+
bayinx-0.3.4.dist-info/METADATA,sha256=EpVIXPifXNloZfCCWNuNaVhWO_dMEujN3V_kVZz2Q6Y,3057
|
30
|
+
bayinx-0.3.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
31
|
+
bayinx-0.3.4.dist-info/RECORD,,
|
{bayinx-0.3.2.dist-info → bayinx-0.3.4.dist-info}/WHEEL
File without changes
|