bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. bayinx/__init__.py +3 -3
  2. bayinx/constraints/__init__.py +4 -3
  3. bayinx/constraints/identity.py +26 -0
  4. bayinx/constraints/interval.py +62 -0
  5. bayinx/constraints/lower.py +31 -24
  6. bayinx/constraints/upper.py +57 -0
  7. bayinx/core/__init__.py +0 -7
  8. bayinx/core/constraint.py +32 -0
  9. bayinx/core/context.py +42 -0
  10. bayinx/core/distribution.py +34 -0
  11. bayinx/core/flow.py +99 -0
  12. bayinx/core/model.py +228 -0
  13. bayinx/core/node.py +201 -0
  14. bayinx/core/types.py +17 -0
  15. bayinx/core/utils.py +109 -0
  16. bayinx/core/variational.py +170 -0
  17. bayinx/dists/__init__.py +5 -3
  18. bayinx/dists/bernoulli.py +180 -11
  19. bayinx/dists/binomial.py +215 -0
  20. bayinx/dists/exponential.py +211 -0
  21. bayinx/dists/normal.py +131 -59
  22. bayinx/dists/poisson.py +203 -0
  23. bayinx/flows/__init__.py +5 -0
  24. bayinx/flows/diagaffine.py +120 -0
  25. bayinx/flows/fullaffine.py +123 -0
  26. bayinx/flows/lowrankaffine.py +165 -0
  27. bayinx/flows/planar.py +155 -0
  28. bayinx/flows/radial.py +1 -0
  29. bayinx/flows/sylvester.py +225 -0
  30. bayinx/nodes/__init__.py +3 -0
  31. bayinx/nodes/continuous.py +64 -0
  32. bayinx/nodes/observed.py +36 -0
  33. bayinx/nodes/stochastic.py +25 -0
  34. bayinx/ops.py +104 -0
  35. bayinx/posterior.py +220 -0
  36. bayinx/vi/__init__.py +0 -0
  37. bayinx/{mhx/vi → vi}/meanfield.py +33 -29
  38. bayinx/vi/normalizing_flow.py +246 -0
  39. bayinx/vi/standard.py +95 -0
  40. bayinx-0.5.3.dist-info/METADATA +93 -0
  41. bayinx-0.5.3.dist-info/RECORD +44 -0
  42. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
  43. bayinx/core/_constraint.py +0 -28
  44. bayinx/core/_flow.py +0 -80
  45. bayinx/core/_model.py +0 -98
  46. bayinx/core/_parameter.py +0 -44
  47. bayinx/core/_variational.py +0 -181
  48. bayinx/dists/censored/__init__.py +0 -3
  49. bayinx/dists/censored/gamma2/__init__.py +0 -3
  50. bayinx/dists/censored/gamma2/r.py +0 -68
  51. bayinx/dists/censored/posnormal/__init__.py +0 -3
  52. bayinx/dists/censored/posnormal/r.py +0 -116
  53. bayinx/dists/gamma2.py +0 -49
  54. bayinx/dists/posnormal.py +0 -260
  55. bayinx/dists/uniform.py +0 -75
  56. bayinx/mhx/__init__.py +0 -1
  57. bayinx/mhx/vi/__init__.py +0 -5
  58. bayinx/mhx/vi/flows/__init__.py +0 -3
  59. bayinx/mhx/vi/flows/fullaffine.py +0 -75
  60. bayinx/mhx/vi/flows/planar.py +0 -74
  61. bayinx/mhx/vi/flows/radial.py +0 -94
  62. bayinx/mhx/vi/flows/sylvester.py +0 -19
  63. bayinx/mhx/vi/normalizing_flow.py +0 -149
  64. bayinx/mhx/vi/standard.py +0 -63
  65. bayinx-0.3.10.dist-info/METADATA +0 -39
  66. bayinx-0.3.10.dist-info/RECORD +0 -35
  67. /bayinx/{py.typed → flows/otflow.py} +0 -0
  68. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/mhx/vi/normalizing_flow.py DELETED
@@ -1,149 +0,0 @@
1
- from typing import Any, Generic, Self, Tuple, TypeVar
2
-
3
- import equinox as eqx
4
- import jax.flatten_util as jfu
5
- import jax.numpy as jnp
6
- import jax.random as jr
7
- import jax.tree_util as jtu
8
- from jaxtyping import Array, Key, Scalar
9
-
10
- from bayinx.core import Flow, Model, Variational
11
-
12
- M = TypeVar('M', bound=Model)
13
- class NormalizingFlow(Variational, Generic[M]):
14
- """
15
- An ordered collection of diffeomorphisms that map a base distribution to a
16
- normalized approximation of a posterior distribution.
17
-
18
- # Attributes
19
- - `base`: A base variational distribution.
20
- - `flows`: An ordered collection of continuously parameterized diffeomorphisms.
21
- """
22
-
23
- flows: list[Flow]
24
- base: Variational
25
-
26
- def __init__(self, base: Variational, flows: list[Flow], model: M):
27
- """
28
- Constructs an unoptimized normalizing flow posterior approximation.
29
-
30
- # Parameters
31
- - `base`: The base variational distribution.
32
- - `flows`: A list of diffeomorphisms.
33
- - `model`: A probabilistic `Model` object.
34
- """
35
- # Partition model
36
- params, self._constraints = eqx.partition(model, model.filter_spec)
37
-
38
- # Flatten params component
39
- _, self._unflatten = jfu.ravel_pytree(params)
40
-
41
- self.base = base
42
- self.flows = flows
43
-
44
- @property
45
- @eqx.filter_jit
46
- def filter_spec(self):
47
- # Generate empty specification
48
- filter_spec = jtu.tree_map(lambda _: False, self)
49
-
50
- # Specify variational parameters based on each flow's filter spec.
51
- filter_spec = eqx.tree_at(
52
- lambda vari: vari.flows,
53
- filter_spec,
54
- replace=[flow.filter_spec for flow in self.flows],
55
- )
56
-
57
- return filter_spec
58
-
59
- @eqx.filter_jit
60
- def sample(self, n: int, key: Key = jr.PRNGKey(0)):
61
- """
62
- Sample from the variational distribution `n` times.
63
- """
64
- # Sample from the base distribution
65
- draws: Array = self.base.sample(n, key)
66
-
67
- # Apply forward transformations
68
- for map in self.flows:
69
- draws = map.forward(draws)
70
-
71
- return draws
72
-
73
- @eqx.filter_jit
74
- def eval(self, draws: Array) -> Array:
75
- # Evaluate base density
76
- variational_evals: Array = self.base.eval(draws)
77
-
78
- for map in self.flows:
79
- # Compute adjustment
80
- draws, laj = map.adjust_density(draws)
81
-
82
- # Adjust variational density
83
- variational_evals = variational_evals - laj
84
-
85
- return variational_evals
86
-
87
- @eqx.filter_jit
88
- def __eval(self, draws: Array, data=None) -> Tuple[Array, Array]:
89
- """
90
- Evaluate the posterior and variational densities together at the
91
- transformed `draws` to avoid extra compute.
92
-
93
- # Parameters
94
- - `draws`: Draws from the base variational distribution.
95
- - `data`: Any data required to evaluate the posterior density.
96
-
97
- # Returns
98
- The posterior and variational densities as JAX Arrays.
99
- """
100
- # Evaluate base density
101
- variational_evals: Array = self.base.eval(draws)
102
-
103
- for map in self.flows:
104
- # Compute adjustment
105
- draws, laj = map.adjust_density(draws)
106
-
107
- # Adjust variational density
108
- variational_evals = variational_evals - laj
109
-
110
- # Evaluate posterior at final variational draws
111
- posterior_evals = self.eval_model(draws, data)
112
-
113
- return posterior_evals, variational_evals
114
-
115
- @eqx.filter_jit
116
- def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
117
- dyn, static = eqx.partition(self, self.filter_spec)
118
-
119
- @eqx.filter_jit
120
- def elbo(dyn: Self, n: int, key: Key, data: Any = None):
121
- self = eqx.combine(dyn, static)
122
-
123
- # Sample draws from variational distribution
124
- draws: Array = self.base.sample(n, key)
125
-
126
- posterior_evals, variational_evals = self.__eval(draws, data)
127
- # Evaluate ELBO
128
- return jnp.mean(posterior_evals - variational_evals)
129
-
130
- return elbo(dyn, n, key, data)
131
-
132
- @eqx.filter_jit
133
- def elbo_grad(self, n: int, key: Key, data: Any = None) -> Self:
134
- dyn, static = eqx.partition(self, self.filter_spec)
135
-
136
- @eqx.filter_grad
137
- @eqx.filter_jit
138
- def elbo_grad(dyn: Self, n: int, key: Key, data: Any = None):
139
- self = eqx.combine(dyn, static)
140
-
141
- # Sample draws from variational distribution
142
- draws: Array = self.base.sample(n, key)
143
-
144
- posterior_evals, variational_evals = self.__eval(draws, data)
145
-
146
- # Evaluate ELBO
147
- return jnp.mean(posterior_evals - variational_evals)
148
-
149
- return elbo_grad(dyn, n, key, data)
bayinx/mhx/vi/standard.py DELETED
@@ -1,63 +0,0 @@
1
-
2
- import equinox as eqx
3
- import jax.numpy as jnp
4
- import jax.random as jr
5
- import jax.tree_util as jtu
6
- from jax.flatten_util import ravel_pytree
7
- from jaxtyping import Array, Key
8
-
9
- from bayinx.core._variational import M, Variational
10
- from bayinx.dists import normal
11
-
12
-
13
- class Standard(Variational[M]):
14
- """
15
- A standard normal approximation to a posterior distribution.
16
-
17
- # Attributes
18
- - `dim`: Dimension of the parameter space.
19
- """
20
- dim: int
21
-
22
- def __init__(self, model: M):
23
- """
24
- Constructs a standard normal approximation to a posterior distribution.
25
-
26
- # Parameters
27
- - `model`: A probabilistic `Model` object.
28
- """
29
- # Partition model
30
- params, self._constraints = eqx.partition(model, model.filter_spec)
31
-
32
- # Flatten params component
33
- params, self._unflatten = ravel_pytree(params)
34
-
35
- # Store dimension of parameter space
36
- self.dim = jnp.size(params)
37
-
38
- @eqx.filter_jit
39
- def sample(self, n: int, key: Key = jr.PRNGKey(0)) -> Array:
40
- # Sample variational draws
41
- draws: Array = jr.normal(key=key, shape=(n, self.dim))
42
-
43
- return draws
44
-
45
- @eqx.filter_jit
46
- def eval(self, draws: Array) -> Array:
47
- return normal.logprob(
48
- x=draws,
49
- mu=jnp.array(0.0),
50
- sigma=jnp.array(1.0),
51
- ).sum(axis=1, keepdims=True)
52
-
53
- @property
54
- def filter_spec(self):
55
- filter_spec = jtu.tree_map(lambda _: False, self)
56
-
57
- return filter_spec
58
-
59
- def elbo(self):
60
- return None
61
-
62
- def elbo_grad(self):
63
- return None
bayinx-0.3.10.dist-info/METADATA DELETED
@@ -1,39 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: bayinx
3
- Version: 0.3.10
4
- Summary: Bayesian Inference with JAX
5
- License-File: LICENSE
6
- Requires-Python: >=3.12
7
- Requires-Dist: equinox>=0.11.12
8
- Requires-Dist: jax>=0.4.38
9
- Requires-Dist: jaxtyping>=0.2.36
10
- Requires-Dist: optax>=0.2.4
11
- Description-Content-Type: text/markdown
12
-
13
- # <ins>Bay</ins>esian <ins>In</ins>ference with JA<ins>X</ins>
14
-
15
- The endgoal of this project is to build a Bayesian inference library that is similar in feel to `Stan`(where you can define a probabilistic model with syntax that is equivalent to how you would write it out on a chalkboard) but allows for arbitrary models(e.g., ones with discrete parameters) and offers a suite of "machinery" to fit the model; this means I want to expand upon `Stan`'s existing toolbox of methods for estimation(point optimization, variational methods, MCMC) while keeping everything performant(hence using `JAX`).
16
-
17
- In the short-term, I'm going to focus on:
18
- 1) Implementing as much machinery as I feel is enough.
19
- 2) Figuring out how to design the `Model` superclass to have something like the `transformed pars {}` block but unifies transformations and constraints.
20
- 3) Figuring out how to design the library to automatically recognize what kind of machinery is amenable to a given probabilistic model.
21
-
22
- In the long-term, I'm going to focus on:
23
- 1) How to get `Stan`-like declarative syntax in Python with minimal syntactic overhead(to get as close as possible to statements like `X ~ Normal(mu, 1)`), while also allowing users to work with `target` directly when needed(same as `Stan` does).
24
- 2) How to make working with the posterior as easy as possible.
25
- - That's a vague goal but practically it means how to easily evaluate statements like $P(\theta \in [-1, 1] | \mathcal{D}, \mathcal{M})$, or set up contrasts and evaluate $P(\mu_1 - \mu_2 > 0 | \mathcal{D}, \mathcal{M})$, or simulate the posterior predictive to generate plots, etc.
26
-
27
- Although this is somewhat separate from the goals of the project, if this does pan out how I'm invisioning it I'd like an R formula-like syntax to shorten model construction in scenarios where the model is just a GLMM or similar(think `brms`).
28
-
29
- Additionally, when I get around to it I'd like the package documentation to also include theoretical and implementation details for all machinery implemented(with overthinking boxes because I do like that design from McElreath's book).
30
-
31
-
32
- # TODO
33
- - Find some way to discern between models with all floating-point parameters and weirder models with integer parameters. Useful for restricting variational methods like `MeanField` to `Model`s that only have floating-point parameters.
34
- - Look into adaptively tuning ADAM hyperparameters.
35
- - Control variates for meanfield VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
36
- - Low-rank affine flow?
37
- - https://arxiv.org/pdf/1803.05649 implement sylvester flows.
38
- - Learn how to generate documentation lol.
39
- - Figure out how to make transform_pars for flows such that there is no performance loss. Noticing some weird behaviour when adding constraints.
bayinx-0.3.10.dist-info/RECORD DELETED
@@ -1,35 +0,0 @@
1
- bayinx/__init__.py,sha256=TM-aoRaPX6jSYtCM7Jv59TPV-H6bcDk1-VMttYP1KME,99
2
- bayinx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- bayinx/constraints/__init__.py,sha256=PiWXZKi7YdbTMKvw-OE5f-t87jJT893uAFrwWWBfOdg,64
4
- bayinx/constraints/lower.py,sha256=30y0l6PF-tbS9LR_tto9AvwmsvXq1ExU-v8DLrJD4g4,1446
5
- bayinx/core/__init__.py,sha256=bZvQITgW0DWuPKl3wCLKt6WHKogYKx8Zz36g8z9Aung,253
6
- bayinx/core/_constraint.py,sha256=Gx07ZT66VE2y-qZCmBDm3_y0wO4xQyslZW10Lec1_lM,761
7
- bayinx/core/_flow.py,sha256=3q4rKvATnbUpuj4ICUd4lIxu_3z7GRDuNujVhAku1X0,2342
8
- bayinx/core/_model.py,sha256=FJUyYVE9e2uTFamxtSMKY_VV2stiU2QF68Wl_7EAKEU,2895
9
- bayinx/core/_parameter.py,sha256=r20JedTW2lY0miNNh9y6LeIVAsGX1kP_rlGxphW_jZg,1080
10
- bayinx/core/_variational.py,sha256=X8o81b8vyU7vJxw8pZYH_nxc3u990tRUZgRhMNodNI4,5484
11
- bayinx/dists/__init__.py,sha256=9DdPea7HAnBOzaV_4gM5noPX8YCb_p06d8PJvGfFy3Y,118
12
- bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
13
- bayinx/dists/gamma2.py,sha256=MuFudL2UTfk8HgWVofNaR36JTmUpmtxvg1Mifu98MvM,1567
14
- bayinx/dists/normal.py,sha256=Yc2X8F7JoLYwprtK8bA2BPva1tAY7MEs3oSk5pMortI,3822
15
- bayinx/dists/posnormal.py,sha256=w9plA1EctXwXOiY0doc4ZndjnwptbEZBHHCGdc4gviY,7292
16
- bayinx/dists/uniform.py,sha256=7XgVvOrzINEFA6HJTYUOFwlWhEtrQQQ1aPJ_ZLOzLEc,2365
17
- bayinx/dists/censored/__init__.py,sha256=UVihMbQgAzCoOk_Zt5wrumPv5-acuTzV3TYMB-U1gOc,49
18
- bayinx/dists/censored/gamma2/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
19
- bayinx/dists/censored/gamma2/r.py,sha256=dKAOYstufwgDwibQZHrJxA1d2gawj-7K3IkaCRCzNTg,2446
20
- bayinx/dists/censored/posnormal/__init__.py,sha256=GO3jIF1En0ZxYF5JqvC0helLAL6yv8-LG6Ih2NOUYQc,33
21
- bayinx/dists/censored/posnormal/r.py,sha256=Ypi6w_t53pAzRVzjcStx2RhozkAlCDLnQmgKykhpQQ4,3426
22
- bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
23
- bayinx/mhx/vi/__init__.py,sha256=2woNB5oZxfs8pZCkOfzriGahRFLzkLdkTj8_keTN0I0,205
24
- bayinx/mhx/vi/meanfield.py,sha256=Z7kGQAyp5iB8rEdjbwAbVTFH4GwxlTKDZFbdJ-FN5Vs,3739
25
- bayinx/mhx/vi/normalizing_flow.py,sha256=8pLMDdZPIt5wlgbhHWSFY1ChSWM9pvSD2bQx3zgz1F8,4710
26
- bayinx/mhx/vi/standard.py,sha256=W-ZvigJkUpqVlREgiFm9io8ansT1XpZwq5AqSmdv--E,1578
27
- bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
28
- bayinx/mhx/vi/flows/fullaffine.py,sha256=11y_A0oO3bkKDSz-EQ6Sf4Ec2M7vHZxw94EdvADwVYQ,1954
29
- bayinx/mhx/vi/flows/planar.py,sha256=2I2WzIskl8MRpJkK13FQE3vSF-077qo8gRed_EL1Pn8,1920
30
- bayinx/mhx/vi/flows/radial.py,sha256=e0GfuO-CL8SVr3YnEfsxStpyKcJlFpzMyjMp3sa38hg,2503
31
- bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
32
- bayinx-0.3.10.dist-info/METADATA,sha256=y99n8rNP62ezyDKJSmBQvUQx52WKUEgyHtw3QnnVrvs,3080
33
- bayinx-0.3.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
34
- bayinx-0.3.10.dist-info/licenses/LICENSE,sha256=VMhLhj5hx6VAENZBaNfXrmsNl7ov9uRh0jZ6D3ltgv4,1070
35
- bayinx-0.3.10.dist-info/RECORD,,
File without changes