bayinx 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bayinx/core/model.py
CHANGED
@@ -1,25 +1,25 @@
|
|
1
1
|
from abc import abstractmethod
|
2
|
-
from typing import Any, Dict, Tuple
|
2
|
+
from typing import Any, Dict, Generic, Tuple, TypeVar
|
3
3
|
|
4
4
|
import equinox as eqx
|
5
5
|
import jax.numpy as jnp
|
6
6
|
import jax.tree as jt
|
7
|
-
from jaxtyping import
|
7
|
+
from jaxtyping import PyTree, Scalar
|
8
8
|
|
9
9
|
from bayinx.core.constraint import Constraint
|
10
10
|
from bayinx.core.parameter import Parameter
|
11
11
|
|
12
|
-
|
13
|
-
class Model(eqx.Module):
|
12
|
+
T = TypeVar('T', bound=PyTree)
|
13
|
+
class Model(eqx.Module, Generic[T]):
|
14
14
|
"""
|
15
15
|
An abstract base class used to define probabilistic models.
|
16
16
|
|
17
17
|
# Attributes
|
18
|
-
- `params`: A dictionary of
|
18
|
+
- `params`: A dictionary of parameters.
|
19
19
|
- `constraints`: A dictionary of constraints.
|
20
20
|
"""
|
21
21
|
|
22
|
-
params: Dict[str, Parameter]
|
22
|
+
params: Dict[str, Parameter[T]]
|
23
23
|
constraints: Dict[str, Constraint]
|
24
24
|
|
25
25
|
@abstractmethod
|
@@ -47,14 +47,14 @@ class Model(eqx.Module):
|
|
47
47
|
|
48
48
|
# Add constrain method
|
49
49
|
@eqx.filter_jit
|
50
|
-
def constrain_params(self) -> Tuple[Dict[str, Parameter], Scalar]:
|
50
|
+
def constrain_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
|
51
51
|
"""
|
52
52
|
Constrain `params` to the appropriate domain.
|
53
53
|
|
54
54
|
# Returns
|
55
55
|
A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
|
56
56
|
"""
|
57
|
-
t_params: Dict[str,
|
57
|
+
t_params: Dict[str, Parameter[T]] = self.params
|
58
58
|
target: Scalar = jnp.array(0.0)
|
59
59
|
|
60
60
|
for par, map in self.constraints.items():
|
@@ -67,7 +67,8 @@ class Model(eqx.Module):
|
|
67
67
|
return t_params, target
|
68
68
|
|
69
69
|
# Add default transform method
|
70
|
-
|
70
|
+
@eqx.filter_jit
|
71
|
+
def transform_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
|
71
72
|
"""
|
72
73
|
Apply a custom transformation to `params` if needed.
|
73
74
|
|
bayinx/core/parameter.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
|
-
from typing import Self
|
1
|
+
from typing import Generic, Self, TypeVar
|
2
2
|
|
3
3
|
import equinox as eqx
|
4
4
|
import jax.tree as jt
|
5
|
-
from jaxtyping import
|
5
|
+
from jaxtyping import PyTree
|
6
6
|
|
7
|
-
|
8
|
-
class Parameter(eqx.Module):
|
7
|
+
T = TypeVar('T', bound=PyTree)
|
8
|
+
class Parameter(eqx.Module, Generic[T]):
|
9
9
|
"""
|
10
10
|
A container for a parameter of a `Model`.
|
11
11
|
|
@@ -14,10 +14,10 @@ class Parameter(eqx.Module):
|
|
14
14
|
# Attributes
|
15
15
|
- `vals`: The parameter's value(s).
|
16
16
|
"""
|
17
|
-
vals:
|
17
|
+
vals: T
|
18
18
|
|
19
19
|
|
20
|
-
def __init__(self, values:
|
20
|
+
def __init__(self, values: T):
|
21
21
|
# Insert parameter values
|
22
22
|
self.vals = values
|
23
23
|
|
@@ -5,8 +5,8 @@ bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1
|
|
5
5
|
bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
|
6
6
|
bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
|
7
7
|
bayinx/core/flow.py,sha256=lAPJdQnrIxC3JoowTp77Gvm0p0v_xQn8FMjFJWMnWbc,2340
|
8
|
-
bayinx/core/model.py,sha256=
|
9
|
-
bayinx/core/parameter.py,sha256=
|
8
|
+
bayinx/core/model.py,sha256=RKAtsLXc6xDnXWz5upmx6Vz6JOoorw4WTfxTA7B7Lmg,2294
|
9
|
+
bayinx/core/parameter.py,sha256=oxCCZcZ-DDBvfWzexfhQkSJPxNQnE1vYXtBhiEZG2bM,1025
|
10
10
|
bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
|
11
11
|
bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
12
|
bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
|
@@ -25,6 +25,6 @@ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvAD
|
|
25
25
|
bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
|
26
26
|
bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
|
27
27
|
bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
|
28
|
-
bayinx-0.3.
|
29
|
-
bayinx-0.3.
|
30
|
-
bayinx-0.3.
|
28
|
+
bayinx-0.3.1.dist-info/METADATA,sha256=Slp2nxR8HISCwCqIXY2El3GgqsO1v9_UbVeJq726w7k,3057
|
29
|
+
bayinx-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
30
|
+
bayinx-0.3.1.dist-info/RECORD,,
|
File without changes
|