bayinx 0.2.33__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bayinx/__init__.py CHANGED
@@ -1 +1,2 @@
1
- from bayinx.core.model import Model as Model
1
+ from bayinx.core import Model as Model
2
+ from bayinx.core import Parameter as Parameter
@@ -1,9 +1,12 @@
1
1
  from typing import Tuple
2
2
 
3
+ import equinox as eqx
3
4
  import jax.numpy as jnp
4
- from jaxtyping import Array, ArrayLike, Scalar, ScalarLike
5
+ import jax.tree as jt
6
+ from jaxtyping import PyTree, Scalar, ScalarLike
5
7
 
6
8
  from bayinx.core.constraint import Constraint
9
+ from bayinx.core.parameter import Parameter
7
10
 
8
11
 
9
12
  class Lower(Constraint):
@@ -11,27 +14,38 @@ class Lower(Constraint):
11
14
  Enforces a lower bound on the parameter.
12
15
  """
13
16
 
14
- lb: ScalarLike
17
+ lb: Scalar
15
18
 
16
19
  def __init__(self, lb: ScalarLike):
17
- self.lb = lb
20
+ self.lb = jnp.array(lb)
18
21
 
19
- def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
22
+ @eqx.filter_jit
23
+ def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
20
24
  """
21
- Applies the lower bound constraint and adjusts the posterior density.
25
+ Enforces a lower bound on the parameter and adjusts the posterior density.
22
26
 
23
27
  # Parameters
24
- - `x`: The unconstrained JAX Array-like input.
28
+ - `x`: The unconstrained `Parameter`.
25
29
 
26
30
  # Parameters
27
31
  A tuple containing:
28
- - The constrained JAX Array (x > self.lb).
29
- - A scalar JAX Array representing the log-absolute-Jacobian of the transformation.
32
+ - A modified `Parameter` with relevant leaves satisfying the constraint.
33
+ - A scalar Array representing the log-absolute-Jacobian of the transformation.
30
34
  """
31
- # Compute transformation adjustment
32
- laj: Scalar = jnp.sum(x)
35
+ # Extract relevant filter specification
36
+ filter_spec = x.filter_spec
37
+
38
+ # Extract relevant parameters(all Array)
39
+ dyn_params, static_params = eqx.partition(x, filter_spec)
40
+
41
+ # Compute density adjustment
42
+ laj: PyTree = jt.map(jnp.sum, dyn_params) # pyright: ignore
43
+ laj: Scalar = jt.reduce(lambda a,b: a + b, laj)
33
44
 
34
45
  # Compute transformation
35
- x = jnp.exp(x) + self.lb
46
+ dyn_params = jt.map(lambda v: jnp.exp(v) + self.lb, dyn_params)
47
+
48
+ # Combine into full parameter object
49
+ x = eqx.combine(dyn_params, static_params)
36
50
 
37
51
  return x, laj
bayinx/core/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
1
  from bayinx.core.flow import Flow as Flow
2
2
  from bayinx.core.model import Model as Model
3
+ from bayinx.core.parameter import Parameter as Parameter
3
4
  from bayinx.core.variational import Variational as Variational
bayinx/core/constraint.py CHANGED
@@ -2,7 +2,9 @@ from abc import abstractmethod
2
2
  from typing import Tuple
3
3
 
4
4
  import equinox as eqx
5
- from jaxtyping import Array, ArrayLike, Scalar
5
+ from jaxtyping import Scalar
6
+
7
+ from bayinx.core.parameter import Parameter
6
8
 
7
9
 
8
10
  class Constraint(eqx.Module):
@@ -11,16 +13,16 @@ class Constraint(eqx.Module):
11
13
  """
12
14
 
13
15
  @abstractmethod
14
- def constrain(self, x: ArrayLike) -> Tuple[Array, Scalar]:
16
+ def constrain(self, x: Parameter) -> Tuple[Parameter, Scalar]:
15
17
  """
16
- Applies the constraining transformation to an unconstrained input and computes the log-absolute-Jacobian of the transformation.
18
+ Applies the constraining transformation to a parameter and computes the log-absolute-Jacobian of the transformation.
17
19
 
18
20
  # Parameters
19
- - `x`: The unconstrained JAX Array-like input.
21
+ - `x`: The unconstrained `Parameter`.
20
22
 
21
23
  # Returns
22
24
  A tuple containing:
23
- - The constrained JAX Array.
24
- - A scalar JAX Array representing the log-absolute-Jacobian of the transformation.
25
+ - The constrained `Parameter`.
26
+ - A scalar Array representing the log-absolute-Jacobian of the transformation.
25
27
  """
26
28
  pass
bayinx/core/flow.py CHANGED
@@ -31,11 +31,13 @@ class Flow(eqx.Module):
31
31
  Computes the log-absolute-Jacobian at `draws` and applies the forward transformation.
32
32
 
33
33
  # Returns
34
- A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
34
+ A tuple of JAX Arrays containing the transformed draws and log-absolute-Jacobians.
35
35
  """
36
36
  pass
37
37
 
38
38
  # Default filter specification
39
+ @property
40
+ @eqx.filter_jit
39
41
  def filter_spec(self):
40
42
  """
41
43
  Generates a filter specification to subset relevant parameters for the flow.
@@ -53,7 +55,7 @@ class Flow(eqx.Module):
53
55
  return filter_spec
54
56
 
55
57
  @eqx.filter_jit
56
- def constrain_pars(self: Self):
58
+ def constrain_params(self: Self):
57
59
  """
58
60
  Constrain `params` to the appropriate domain.
59
61
 
@@ -68,11 +70,11 @@ class Flow(eqx.Module):
68
70
  return t_params
69
71
 
70
72
  @eqx.filter_jit
71
- def transform_pars(self: Self) -> Dict[str, Array]:
73
+ def transform_params(self: Self) -> Dict[str, Array]:
72
74
  """
73
75
  Apply a custom transformation to `params` if needed.
74
76
 
75
77
  # Returns
76
78
  A dictionary of transformed JAX Arrays representing the transformed parameters.
77
79
  """
78
- return self.constrain_pars()
80
+ return self.constrain_params()
bayinx/core/model.py CHANGED
@@ -1,24 +1,25 @@
1
1
  from abc import abstractmethod
2
- from typing import Any, Dict, Tuple
2
+ from typing import Any, Dict, Generic, Tuple, TypeVar
3
3
 
4
4
  import equinox as eqx
5
5
  import jax.numpy as jnp
6
6
  import jax.tree as jt
7
- from jaxtyping import Array, PyTree, Scalar
7
+ from jaxtyping import PyTree, Scalar
8
8
 
9
9
  from bayinx.core.constraint import Constraint
10
+ from bayinx.core.parameter import Parameter
10
11
 
11
-
12
- class Model(eqx.Module):
12
+ T = TypeVar('T', bound=PyTree)
13
+ class Model(eqx.Module, Generic[T]):
13
14
  """
14
15
  An abstract base class used to define probabilistic models.
15
16
 
16
17
  # Attributes
17
- - `params`: A dictionary of JAX Arrays representing parameters of the model.
18
+ - `params`: A dictionary of parameters.
18
19
  - `constraints`: A dictionary of constraints.
19
20
  """
20
21
 
21
- params: Dict[str, PyTree]
22
+ params: Dict[str, Parameter[T]]
22
23
  constraints: Dict[str, Constraint]
23
24
 
24
25
  @abstractmethod
@@ -26,6 +27,8 @@ class Model(eqx.Module):
26
27
  pass
27
28
 
28
29
  # Default filter specification
30
+ @property
31
+ @eqx.filter_jit
29
32
  def filter_spec(self):
30
33
  """
31
34
  Generates a filter specification to subset relevant parameters for the model.
@@ -33,25 +36,25 @@ class Model(eqx.Module):
33
36
  # Generate empty specification
34
37
  filter_spec = jt.map(lambda _: False, self)
35
38
 
36
- # Specify JAX Array parameters
39
+ # Specify relevant parameters
37
40
  filter_spec = eqx.tree_at(
38
41
  lambda model: model.params,
39
42
  filter_spec,
40
- replace=jt.map(eqx.is_array, self.params),
43
+ replace={key: param.filter_spec for key, param in self.params.items()}
41
44
  )
42
45
 
43
46
  return filter_spec
44
47
 
45
48
  # Add constrain method
46
49
  @eqx.filter_jit
47
- def constrain_pars(self) -> Tuple[Dict[str, Array], Scalar]:
50
+ def constrain_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
48
51
  """
49
52
  Constrain `params` to the appropriate domain.
50
53
 
51
54
  # Returns
52
- A dictionary of transformed JAX Arrays representing the constrained parameters and the adjustment to the posterior density.
55
+ A dictionary of PyTrees representing the constrained parameters and the adjustment to the posterior density.
53
56
  """
54
- t_params: Dict[str, Array] = self.params
57
+ t_params: Dict[str, Parameter[T]] = self.params
55
58
  target: Scalar = jnp.array(0.0)
56
59
 
57
60
  for par, map in self.constraints.items():
@@ -64,11 +67,12 @@ class Model(eqx.Module):
64
67
  return t_params, target
65
68
 
66
69
  # Add default transform method
67
- def transform_pars(self) -> Tuple[Dict[str, Array], Scalar]:
70
+ @eqx.filter_jit
71
+ def transform_params(self) -> Tuple[Dict[str, Parameter[T]], Scalar]:
68
72
  """
69
73
  Apply a custom transformation to `params` if needed.
70
74
 
71
75
  # Returns
72
76
  A dictionary of transformed JAX Arrays representing the transformed parameters.
73
77
  """
74
- return self.constrain_pars()
78
+ return self.constrain_params()
@@ -0,0 +1,41 @@
1
+ from typing import Generic, Self, TypeVar
2
+
3
+ import equinox as eqx
4
+ import jax.tree as jt
5
+ from jaxtyping import PyTree
6
+
7
+ T = TypeVar('T', bound=PyTree)
8
+ class Parameter(eqx.Module, Generic[T]):
9
+ """
10
+ A container for a parameter of a `Model`.
11
+
12
+ Subclasses can be constructed for custom filter specifications(`filter_spec`).
13
+
14
+ # Attributes
15
+ - `vals`: The parameter's value(s).
16
+ """
17
+ vals: T
18
+
19
+
20
+ def __init__(self, values: T):
21
+ # Insert parameter values
22
+ self.vals = values
23
+
24
+ # Default filter specification
25
+ @property
26
+ @eqx.filter_jit
27
+ def filter_spec(self) -> Self:
28
+ """
29
+ Generates a filter specification to filter out static parameters.
30
+ """
31
+ # Generate empty specification
32
+ filter_spec = jt.map(lambda _: False, self)
33
+
34
+ # Specify Array leaves
35
+ filter_spec = eqx.tree_at(
36
+ lambda params: params.vals,
37
+ filter_spec,
38
+ replace=jt.map(eqx.is_array_like, self.vals),
39
+ )
40
+
41
+ return filter_spec
@@ -103,7 +103,7 @@ class Variational(eqx.Module):
103
103
  - `key`: A PRNG key.
104
104
  """
105
105
  # Partition variational
106
- dyn, static = eqx.partition(self, self.filter_spec())
106
+ dyn, static = eqx.partition(self, self.filter_spec)
107
107
 
108
108
  # Construct scheduler
109
109
  schedule: Schedule = opx.cosine_decay_schedule(
@@ -143,7 +143,7 @@ class Variational(eqx.Module):
143
143
 
144
144
  # Compute updates
145
145
  updates, opt_state = optim.update(
146
- updates, opt_state, eqx.filter(dyn, dyn.filter_spec())
146
+ updates, opt_state, eqx.filter(dyn, dyn.filter_spec)
147
147
  )
148
148
 
149
149
  # Update variational distribution
@@ -46,7 +46,7 @@ class FullAffine(Flow):
46
46
 
47
47
  @eqx.filter_jit
48
48
  def forward(self, draws: Array) -> Array:
49
- params = self.transform_pars()
49
+ params = self.transform_params()
50
50
 
51
51
  # Extract parameters
52
52
  shift: Array = params["shift"]
@@ -60,7 +60,7 @@ class FullAffine(Flow):
60
60
  @eqx.filter_jit
61
61
  @partial(jax.vmap, in_axes=(None, 0))
62
62
  def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
63
- params = self.transform_pars()
63
+ params = self.transform_params()
64
64
 
65
65
  # Extract parameters
66
66
  shift: Array = params["shift"]
@@ -39,7 +39,7 @@ class Planar(Flow):
39
39
  @eqx.filter_jit
40
40
  @partial(jax.vmap, in_axes=(None, 0))
41
41
  def forward(self, draws: Array) -> Array:
42
- params = self.transform_pars()
42
+ params = self.transform_params()
43
43
 
44
44
  # Extract parameters
45
45
  w: Array = params["w"]
@@ -54,7 +54,7 @@ class Planar(Flow):
54
54
  @eqx.filter_jit
55
55
  @partial(jax.vmap, in_axes=(None, 0))
56
56
  def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
57
- params = self.transform_pars()
57
+ params = self.transform_params()
58
58
 
59
59
  # Extract parameters
60
60
  w: Array = params["w"]
@@ -49,7 +49,7 @@ class Radial(Flow):
49
49
  # Returns
50
50
  The transformed samples.
51
51
  """
52
- params = self.transform_pars()
52
+ params = self.transform_params()
53
53
 
54
54
  # Extract parameters
55
55
  alpha = params["alpha"]
@@ -67,7 +67,7 @@ class Radial(Flow):
67
67
  @partial(jax.vmap, in_axes=(None, 0))
68
68
  @eqx.filter_jit
69
69
  def adjust_density(self, draws: Array) -> Tuple[Array, Scalar]:
70
- params = self.transform_pars()
70
+ params = self.transform_params()
71
71
 
72
72
  # Extract parameters
73
73
  alpha = params["alpha"]
@@ -29,7 +29,7 @@ class MeanField(Variational):
29
29
  - `model`: A probabilistic `Model` object.
30
30
  """
31
31
  # Partition model
32
- params, self._constraints = eqx.partition(model, model.filter_spec())
32
+ params, self._constraints = eqx.partition(model, model.filter_spec)
33
33
 
34
34
  # Flatten params component
35
35
  params, self._unflatten = ravel_pytree(params)
@@ -40,6 +40,22 @@ class MeanField(Variational):
40
40
  "log_std": jnp.zeros(params.size, dtype=params.dtype),
41
41
  }
42
42
 
43
+ @property
44
+ @eqx.filter_jit
45
+ def filter_spec(self):
46
+ # Generate empty specification
47
+ filter_spec = jtu.tree_map(lambda _: False, self)
48
+
49
+ # Specify variational parameters
50
+ filter_spec = eqx.tree_at(
51
+ lambda mf: mf.var_params,
52
+ filter_spec,
53
+ replace=True,
54
+ )
55
+
56
+ return filter_spec
57
+
58
+
43
59
  @eqx.filter_jit
44
60
  def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
45
61
  # Sample variational draws
@@ -59,23 +75,9 @@ class MeanField(Variational):
59
75
  sigma=jnp.exp(self.var_params["log_std"]),
60
76
  ).sum(axis=1)
61
77
 
62
- @eqx.filter_jit
63
- def filter_spec(self):
64
- # Generate empty specification
65
- filter_spec = jtu.tree_map(lambda _: False, self)
66
-
67
- # Specify variational parameters
68
- filter_spec = eqx.tree_at(
69
- lambda mf: mf.var_params,
70
- filter_spec,
71
- replace=True,
72
- )
73
-
74
- return filter_spec
75
-
76
78
  @eqx.filter_jit
77
79
  def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
78
- dyn, static = eqx.partition(self, self.filter_spec())
80
+ dyn, static = eqx.partition(self, self.filter_spec)
79
81
 
80
82
  @eqx.filter_jit
81
83
  def elbo(dyn: Self, n: int, key: Key, data: Any = None) -> Scalar:
@@ -97,7 +99,7 @@ class MeanField(Variational):
97
99
 
98
100
  @eqx.filter_jit
99
101
  def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
100
- dyn, static = eqx.partition(self, self.filter_spec())
102
+ dyn, static = eqx.partition(self, self.filter_spec)
101
103
 
102
104
  @eqx.filter_grad
103
105
  @eqx.filter_jit
@@ -33,7 +33,7 @@ class NormalizingFlow(Variational):
33
33
  - `model`: A probabilistic `Model` object.
34
34
  """
35
35
  # Partition model
36
- params, self._constraints = eqx.partition(model, eqx.is_array)
36
+ params, self._constraints = eqx.partition(model, model.filter_spec)
37
37
 
38
38
  # Flatten params component
39
39
  _, self._unflatten = jfu.ravel_pytree(params)
@@ -41,6 +41,8 @@ class NormalizingFlow(Variational):
41
41
  self.base = base
42
42
  self.flows = flows
43
43
 
44
+ @property
45
+ @eqx.filter_jit
44
46
  def filter_spec(self):
45
47
  # Generate empty specification
46
48
  filter_spec = jtu.tree_map(lambda _: False, self)
@@ -49,7 +51,7 @@ class NormalizingFlow(Variational):
49
51
  filter_spec = eqx.tree_at(
50
52
  lambda vari: vari.flows,
51
53
  filter_spec,
52
- replace=[flow.filter_spec() for flow in self.flows],
54
+ replace=[flow.filter_spec for flow in self.flows],
53
55
  )
54
56
 
55
57
  return filter_spec
@@ -112,7 +114,7 @@ class NormalizingFlow(Variational):
112
114
 
113
115
  @eqx.filter_jit
114
116
  def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
115
- dyn, static = eqx.partition(self, self.filter_spec())
117
+ dyn, static = eqx.partition(self, self.filter_spec)
116
118
 
117
119
  @eqx.filter_jit
118
120
  def elbo(dyn: Self, n: int, key: Key, data: Any = None):
@@ -129,7 +131,7 @@ class NormalizingFlow(Variational):
129
131
 
130
132
  @eqx.filter_jit
131
133
  def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
132
- dyn, static = eqx.partition(self, self.filter_spec())
134
+ dyn, static = eqx.partition(self, self.filter_spec)
133
135
 
134
136
  @eqx.filter_grad
135
137
  @eqx.filter_jit
bayinx/mhx/vi/standard.py CHANGED
@@ -19,7 +19,7 @@ class Standard(Variational):
19
19
  - `dim`: Dimension of the parameter space.
20
20
  """
21
21
 
22
- dim: int = eqx.field(static=True)
22
+ dim: int
23
23
  _unflatten: Callable[[Float[Array, "..."]], Model]
24
24
  _constraints: Model
25
25
 
@@ -31,7 +31,7 @@ class Standard(Variational):
31
31
  - `model`: A probabilistic `Model` object.
32
32
  """
33
33
  # Partition model
34
- params, self._constraints = eqx.partition(model, model.filter_spec())
34
+ params, self._constraints = eqx.partition(model, model.filter_spec)
35
35
 
36
36
  # Flatten params component
37
37
  params, self._unflatten = ravel_pytree(params)
@@ -54,7 +54,7 @@ class Standard(Variational):
54
54
  sigma=jnp.array(1.0),
55
55
  ).sum(axis=1, keepdims=True)
56
56
 
57
- @eqx.filter_jit
57
+ @property
58
58
  def filter_spec(self):
59
59
  filter_spec = jtu.tree_map(lambda _: False, self)
60
60
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.33
3
+ Version: 0.3.1
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -0,0 +1,30 @@
1
+ bayinx/__init__.py,sha256=htihTsJ54k-ljBLzt4ye8DR7ORwZhxv-bLMcEaDQeqY,86
2
+ bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ bayinx/constraints/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ bayinx/constraints/lower.py,sha256=wkYnWjaAEGQeXKfBo_gY0pcK9ElJUMkzGdAmWI8ykCk,1488
5
+ bayinx/core/__init__.py,sha256=jSwEFdXqi-Bj_X8_H-YuaXp5ebEQpZTG2T18zpquzPo,207
6
+ bayinx/core/constraint.py,sha256=F6-TXQjzt-tcNm8bHkRcGEtyE9bZQf2RbAh_MKDuM20,760
7
+ bayinx/core/flow.py,sha256=lAPJdQnrIxC3JoowTp77Gvm0p0v_xQn8FMjFJWMnWbc,2340
8
+ bayinx/core/model.py,sha256=RKAtsLXc6xDnXWz5upmx6Vz6JOoorw4WTfxTA7B7Lmg,2294
9
+ bayinx/core/parameter.py,sha256=oxCCZcZ-DDBvfWzexfhQkSJPxNQnE1vYXtBhiEZG2bM,1025
10
+ bayinx/core/variational.py,sha256=lqENISRrKY8ODLtl0D-D7TAA2gD7HGh37BnROM7p5hI,4783
11
+ bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
+ bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
13
+ bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
14
+ bayinx/dists/normal.py,sha256=mvm6EoAlORy-yivuhMcExYCZUo0vJzMKMOWH-9iQBZU,2634
15
+ bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
16
+ bayinx/dists/censored/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
+ bayinx/dists/censored/gamma2/r.py,sha256=3brRCKhE-74mRXyIyPcnyaWY2OJv8CZyUWPP9T1t09Y,2274
18
+ bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
19
+ bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
20
+ bayinx/mhx/vi/meanfield.py,sha256=M4QrOuHaIMLTuQSD6JNF9vELnTm370tXV68JPB7B67M,3652
21
+ bayinx/mhx/vi/normalizing_flow.py,sha256=9c5ayMJ_Wgq6pUb1GYHIFIzw3Bf1AsIdOjcerLoYMrA,4655
22
+ bayinx/mhx/vi/standard.py,sha256=DfSV0r9oXzp9UM8OsZBpoJPRUhiDoAq_X2_2z_M83lA,1685
23
+ bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
24
+ bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
25
+ bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
26
+ bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
27
+ bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
28
+ bayinx-0.3.1.dist-info/METADATA,sha256=Slp2nxR8HISCwCqIXY2El3GgqsO1v9_UbVeJq726w7k,3057
29
+ bayinx-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
30
+ bayinx-0.3.1.dist-info/RECORD,,
@@ -1,29 +0,0 @@
1
- bayinx/__init__.py,sha256=l20JdkSsE_XGZlZFNEtySXf4NIlbjrao14vXPB-H6aQ,45
2
- bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- bayinx/constraints/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- bayinx/constraints/lower.py,sha256=MAAsWpZhqu1ySMskQ0fPVkCzW6CVDCU67q2bkCyzbLc,936
5
- bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
6
- bayinx/core/constraint.py,sha256=60KzDILVLQTCY3jt4YEnNKJ5gnG8IHTv_nNqrdwt_3c,751
7
- bayinx/core/flow.py,sha256=A5Vw5t76LPasnMgghjw6ulBkIm5L2jBprusVt-tuwko,2296
8
- bayinx/core/model.py,sha256=7Gt7HkFLzSUbRY9PxTDp6CrXzmld25NL9aQo34ePeno,2135
9
- bayinx/core/variational.py,sha256=2stsYKZDri1rLP7mrz7X2GWehBXNESdlWtmF2N9CEas,4787
10
- bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
12
- bayinx/dists/gamma2.py,sha256=8XYaOtcYJCrr5q1yHWfZaMJmASpLOrfyhrH_J06ksj8,1333
13
- bayinx/dists/normal.py,sha256=mvm6EoAlORy-yivuhMcExYCZUo0vJzMKMOWH-9iQBZU,2634
14
- bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
15
- bayinx/dists/censored/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
- bayinx/dists/censored/gamma2/r.py,sha256=3brRCKhE-74mRXyIyPcnyaWY2OJv8CZyUWPP9T1t09Y,2274
17
- bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
18
- bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
19
- bayinx/mhx/vi/meanfield.py,sha256=BobfTagVGA5R-dclv-E0jSA80KZg1X6GGjiw7XR61vE,3643
20
- bayinx/mhx/vi/normalizing_flow.py,sha256=DYhvTiu2Fr5x8KpWAMZVUaio7ctG2X2SMUO0l8zfZ5g,4622
21
- bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
22
- bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
23
- bayinx/mhx/vi/flows/fullaffine.py,sha256=Z_G2Cg90Asgvqel8buMx5kEVsiIxDDh8rc-L_nP9OCY,1950
24
- bayinx/mhx/vi/flows/planar.py,sha256=WVj-oxcRctuoRA6KJjU63ek1ZgKNG2vI-TLN0QqjtKA,1916
25
- bayinx/mhx/vi/flows/radial.py,sha256=Obj3SraliawIHmP14F9wRpWt34y3kscY--Izy24eCvM,2499
26
- bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
27
- bayinx-0.2.33.dist-info/METADATA,sha256=8d-BDtz7NrXSs5kJd-9Yr5zHTzEPtQvhgZGD-3VX7FI,3058
28
- bayinx-0.2.33.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
29
- bayinx-0.2.33.dist-info/RECORD,,