blayers 0.1.0a1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 George Berry
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,141 @@
1
+ Metadata-Version: 2.4
2
+ Name: blayers
3
+ Version: 0.1.0a1
4
+ Summary: Bayesian layers for NumPyro and Jax
5
+ Author-email: George Berry <george.e.berry@gmail.com>
6
+ Project-URL: Homepage, https://github.com/georgeberry/blayers
7
+ Project-URL: Documentation, https://georgeberry.github.io/blayers/
8
+ Requires-Python: >=3.9
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+ Requires-Dist: numpy
12
+ Requires-Dist: jax
13
+ Requires-Dist: numpyro
14
+ Provides-Extra: dev
15
+ Requires-Dist: optax; extra == "dev"
16
+ Requires-Dist: pytest; extra == "dev"
17
+ Requires-Dist: pytest-check; extra == "dev"
18
+ Requires-Dist: mypy; extra == "dev"
19
+ Requires-Dist: black; extra == "dev"
20
+ Requires-Dist: isort; extra == "dev"
21
+ Requires-Dist: autoflake; extra == "dev"
22
+ Requires-Dist: sphinx; extra == "dev"
23
+ Requires-Dist: furo; extra == "dev"
24
+ Requires-Dist: coverage; extra == "dev"
25
+ Requires-Dist: pytest-cov; extra == "dev"
26
+ Requires-Dist: pre-commit; extra == "dev"
27
+ Dynamic: license-file
28
+
29
+ [![codecov](https://codecov.io/gh/georgeberry/blayers/graph/badge.svg?token=ZDGT0C39QM)](https://codecov.io/gh/georgeberry/blayers) [![License](https://img.shields.io/github/license/georgeberry/blayers)](LICENSE) ![version](https://img.shields.io/badge/version-0.1.0a1-blue)
30
+
31
+ # BLayers
32
+
33
+ **NOTE: BLayers is in alpha. Expect changes. Feedback welcome.**
34
+
35
+ The missing layers package for Bayesian inference. Inspiration from Keras and
36
+ Tensorflow Probability, but made specifically for Numpyro + Jax.
37
+
38
+ Easily build Bayesian models from parts, abstract away the boilerplate, and
39
+ tweak priors as you wish.
40
+
41
+ Fit models either using Variational Inference (VI) or your sampling method of
42
+ choice. Use BLayer's ELBO implementation to do either batched VI or sampling
43
+ without having to rewrite models.
44
+
45
+ BLayers helps you write pure Numpyro, so you can integrate it with any Numpyro
46
+ code to build models of arbitrary complexity. It also gives you a recipe to
47
+ build more complex layers as you wish.
48
+
49
+ ## the starting point
50
+
51
+ The simplest non-trivial (and most important!) Bayesian regression model form is
52
+ the adaptive prior,
53
+
54
+ ```
55
+ lmbda ~ HalfNormal(1)
56
+ beta ~ Normal(0, lmbda)
57
+ y ~ Normal(beta * x, 1)
58
+ ```
59
+
60
+ BLayers takes this as its starting point and most fundamental building block,
61
+ providing the flexible `AdaptiveLayer`.
62
+
63
+ ```python
64
+ from blayers import AdaptiveLayer, gaussian_link_exp
65
+ def model(x, y):
66
+ mu = AdaptiveLayer()('mu', x)
67
+ return gaussian_link_exp(mu, y)
68
+ ```
69
+
70
+ For you purists out there, we also provide a `FixedPriorLayer` for standard
71
+ L1/L2 regression.
72
+
73
+ ```python
74
+ from blayers import FixedPriorLayer, gaussian_link_exp
75
+ def model(x, y):
76
+ mu = FixedPriorLayer()('mu', x)
77
+ return gaussian_link_exp(mu, y)
78
+ ```
79
+
80
+ ## additional layers
81
+
82
+ ### factorization machines
83
+
84
+ Developed in [Rendle 2010](https://jame-zhang.github.io/assets/algo/Factorization-Machines-Rendle2010.pdf) and [Rendle 2011](https://www.ismll.uni-hildesheim.de/pub/pdfs/FreudenthalerRendle_BayesianFactorizationMachines.pdf), FMs provide a low-rank approximation to the `x`-by-`x` interaction matrix. For those familiar with R syntax, it is an approximation to `y ~ x:x`, excluding the x^2 terms.
85
+
86
+ To fit the equivalent of an r model like `y ~ x*x` (all main effects, x^2 terms, and one-way interaction effects), you'd do
87
+
88
+ ```python
89
+ from blayers import FMLayer, gaussian_link_exp
90
+ def model(x, y):
91
+ mu = (
92
+ AdaptiveLayer()('x', x) +
93
+ AdaptiveLayer()('x2', x**2) +
94
+ FMLayer(low_rank_dim=3)('xx', x)
95
+ )
96
+ return gaussian_link_exp(mu, y)
97
+ ```
98
+
99
+ ### uv decomp
100
+
101
+ We also provide a standard UV decomp for low-rank interaction terms.
102
+
103
+ ```python
104
+ from blayers import LowRankInteractionLayer, gaussian_link_exp
105
+ def model(x, z, y):
106
+ mu = (
107
+ AdaptiveLayer()('x', x) +
108
+ AdaptiveLayer()('z', z) +
109
+ LowRankInteractionLayer(low_rank_dim=3)('xz', x, z)
110
+ )
111
+ return gaussian_link_exp(mu, y)
112
+ ```
113
+
114
+ ## links
115
+
116
+ We provide link functions as a convenience to abstract away a bit more Numpyro
117
+ boilerplate.
118
+
119
+ We currently provide
120
+
121
+ * `gaussian_link_exp`
122
+
123
+ ## batched loss
124
+
125
+ The default Numpyro way to fit batched VI models is to use `plate`, which confuses
126
+ me a lot. Instead, BLayers provides `Batched_Trace_ELBO` which does not require
127
+ you to use `plate` to batch in VI. Just drop your model in.
128
+
129
+ ```python
130
+ from blayers.infer import Batched_Trace_ELBO, svi_run_batched
131
+
132
+ svi = SVI(model_fn, guide, optax.adam(schedule), loss=loss_instance)
133
+
134
+ svi_result = svi_run_batched(
135
+ svi,
136
+ rng_key,
137
+ num_steps,
138
+ batch_size=1000,
139
+ **model_data,
140
+ )
141
+ ```
@@ -0,0 +1,113 @@
1
+ [![codecov](https://codecov.io/gh/georgeberry/blayers/graph/badge.svg?token=ZDGT0C39QM)](https://codecov.io/gh/georgeberry/blayers) [![License](https://img.shields.io/github/license/georgeberry/blayers)](LICENSE) ![version](https://img.shields.io/badge/version-0.1.0a1-blue)
2
+
3
+ # BLayers
4
+
5
+ **NOTE: BLayers is in alpha. Expect changes. Feedback welcome.**
6
+
7
+ The missing layers package for Bayesian inference. Inspiration from Keras and
8
+ Tensorflow Probability, but made specifically for Numpyro + Jax.
9
+
10
+ Easily build Bayesian models from parts, abstract away the boilerplate, and
11
+ tweak priors as you wish.
12
+
13
+ Fit models either using Variational Inference (VI) or your sampling method of
14
+ choice. Use BLayer's ELBO implementation to do either batched VI or sampling
15
+ without having to rewrite models.
16
+
17
+ BLayers helps you write pure Numpyro, so you can integrate it with any Numpyro
18
+ code to build models of arbitrary complexity. It also gives you a recipe to
19
+ build more complex layers as you wish.
20
+
21
+ ## the starting point
22
+
23
+ The simplest non-trivial (and most important!) Bayesian regression model form is
24
+ the adaptive prior,
25
+
26
+ ```
27
+ lmbda ~ HalfNormal(1)
28
+ beta ~ Normal(0, lmbda)
29
+ y ~ Normal(beta * x, 1)
30
+ ```
31
+
32
+ BLayers takes this as its starting point and most fundamental building block,
33
+ providing the flexible `AdaptiveLayer`.
34
+
35
+ ```python
36
+ from blayers import AdaptiveLayer, gaussian_link_exp
37
+ def model(x, y):
38
+ mu = AdaptiveLayer()('mu', x)
39
+ return gaussian_link_exp(mu, y)
40
+ ```
41
+
42
+ For you purists out there, we also provide a `FixedPriorLayer` for standard
43
+ L1/L2 regression.
44
+
45
+ ```python
46
+ from blayers import FixedPriorLayer, gaussian_link_exp
47
+ def model(x, y):
48
+ mu = FixedPriorLayer()('mu', x)
49
+ return gaussian_link_exp(mu, y)
50
+ ```
51
+
52
+ ## additional layers
53
+
54
+ ### factorization machines
55
+
56
+ Developed in [Rendle 2010](https://jame-zhang.github.io/assets/algo/Factorization-Machines-Rendle2010.pdf) and [Rendle 2011](https://www.ismll.uni-hildesheim.de/pub/pdfs/FreudenthalerRendle_BayesianFactorizationMachines.pdf), FMs provide a low-rank approximation to the `x`-by-`x` interaction matrix. For those familiar with R syntax, it is an approximation to `y ~ x:x`, excluding the x^2 terms.
57
+
58
+ To fit the equivalent of an r model like `y ~ x*x` (all main effects, x^2 terms, and one-way interaction effects), you'd do
59
+
60
+ ```python
61
+ from blayers import FMLayer, gaussian_link_exp
62
+ def model(x, y):
63
+ mu = (
64
+ AdaptiveLayer()('x', x) +
65
+ AdaptiveLayer()('x2', x**2) +
66
+ FMLayer(low_rank_dim=3)('xx', x)
67
+ )
68
+ return gaussian_link_exp(mu, y)
69
+ ```
70
+
71
+ ### uv decomp
72
+
73
+ We also provide a standard UV decomp for low-rank interaction terms.
74
+
75
+ ```python
76
+ from blayers import LowRankInteractionLayer, gaussian_link_exp
77
+ def model(x, z, y):
78
+ mu = (
79
+ AdaptiveLayer('x', x) +
80
+ AdaptiveLayer('z', z) +
81
+ LowRankInteractionLayer(low_rank_dim=3)('xz', x, z)
82
+ )
83
+ return gaussian_link_exp(mu, y)
84
+ ```
85
+
86
+ ## links
87
+
88
+ We provide link functions as a convenience to abstract away a bit more Numpyro
89
+ boilerplate.
90
+
91
+ We currently provide
92
+
93
+ * `gaussian_link_exp`
94
+
95
+ ## batched loss
96
+
97
+ The default Numpyro way to fit batched VI models is to use `plate`, which confuses
98
+ me a lot. Instead, BLayers provides `Batched_Trace_ELBO` which does not require
99
+ you to use `plate` to batch in VI. Just drop your model in.
100
+
101
+ ```python
102
+ from blayers.infer import Batched_Trace_ELBO, svi_run_batched
103
+
104
+ svi = SVI(model_fn, guide, optax.adam(schedule), loss=loss_instance)
105
+
106
+ svi_result = svi_run_batched(
107
+ svi,
108
+ rng_key,
109
+ num_steps,
110
+ batch_size=1000,
111
+ **model_data,
112
+ )
113
+ ```
@@ -0,0 +1,141 @@
1
+ Metadata-Version: 2.4
2
+ Name: blayers
3
+ Version: 0.1.0a1
4
+ Summary: Bayesian layers for NumPyro and Jax
5
+ Author-email: George Berry <george.e.berry@gmail.com>
6
+ Project-URL: Homepage, https://github.com/georgeberry/blayers
7
+ Project-URL: Documentation, https://georgeberry.github.io/blayers/
8
+ Requires-Python: >=3.9
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+ Requires-Dist: numpy
12
+ Requires-Dist: jax
13
+ Requires-Dist: numpyro
14
+ Provides-Extra: dev
15
+ Requires-Dist: optax; extra == "dev"
16
+ Requires-Dist: pytest; extra == "dev"
17
+ Requires-Dist: pytest-check; extra == "dev"
18
+ Requires-Dist: mypy; extra == "dev"
19
+ Requires-Dist: black; extra == "dev"
20
+ Requires-Dist: isort; extra == "dev"
21
+ Requires-Dist: autoflake; extra == "dev"
22
+ Requires-Dist: sphinx; extra == "dev"
23
+ Requires-Dist: furo; extra == "dev"
24
+ Requires-Dist: coverage; extra == "dev"
25
+ Requires-Dist: pytest-cov; extra == "dev"
26
+ Requires-Dist: pre-commit; extra == "dev"
27
+ Dynamic: license-file
28
+
29
+ [![codecov](https://codecov.io/gh/georgeberry/blayers/graph/badge.svg?token=ZDGT0C39QM)](https://codecov.io/gh/georgeberry/blayers) [![License](https://img.shields.io/github/license/georgeberry/blayers)](LICENSE) ![version](https://img.shields.io/badge/version-0.1.0a1-blue)
30
+
31
+ # BLayers
32
+
33
+ **NOTE: BLayers is in alpha. Expect changes. Feedback welcome.**
34
+
35
+ The missing layers package for Bayesian inference. Inspiration from Keras and
36
+ Tensorflow Probability, but made specifically for Numpyro + Jax.
37
+
38
+ Easily build Bayesian models from parts, abstract away the boilerplate, and
39
+ tweak priors as you wish.
40
+
41
+ Fit models either using Variational Inference (VI) or your sampling method of
42
+ choice. Use BLayer's ELBO implementation to do either batched VI or sampling
43
+ without having to rewrite models.
44
+
45
+ BLayers helps you write pure Numpyro, so you can integrate it with any Numpyro
46
+ code to build models of arbitrary complexity. It also gives you a recipe to
47
+ build more complex layers as you wish.
48
+
49
+ ## the starting point
50
+
51
+ The simplest non-trivial (and most important!) Bayesian regression model form is
52
+ the adaptive prior,
53
+
54
+ ```
55
+ lmbda ~ HalfNormal(1)
56
+ beta ~ Normal(0, lmbda)
57
+ y ~ Normal(beta * x, 1)
58
+ ```
59
+
60
+ BLayers takes this as its starting point and most fundamental building block,
61
+ providing the flexible `AdaptiveLayer`.
62
+
63
+ ```python
64
+ from blayers import AdaptiveLayer, gaussian_link_exp
65
+ def model(x, y):
66
+ mu = AdaptiveLayer()('mu', x)
67
+ return gaussian_link_exp(mu, y)
68
+ ```
69
+
70
+ For you purists out there, we also provide a `FixedPriorLayer` for standard
71
+ L1/L2 regression.
72
+
73
+ ```python
74
+ from blayers import FixedPriorLayer, gaussian_link_exp
75
+ def model(x, y):
76
+ mu = FixedPriorLayer()('mu', x)
77
+ return gaussian_link_exp(mu, y)
78
+ ```
79
+
80
+ ## additional layers
81
+
82
+ ### factorization machines
83
+
84
+ Developed in [Rendle 2010](https://jame-zhang.github.io/assets/algo/Factorization-Machines-Rendle2010.pdf) and [Rendle 2011](https://www.ismll.uni-hildesheim.de/pub/pdfs/FreudenthalerRendle_BayesianFactorizationMachines.pdf), FMs provide a low-rank approximation to the `x`-by-`x` interaction matrix. For those familiar with R syntax, it is an approximation to `y ~ x:x`, excluding the x^2 terms.
85
+
86
+ To fit the equivalent of an r model like `y ~ x*x` (all main effects, x^2 terms, and one-way interaction effects), you'd do
87
+
88
+ ```python
89
+ from blayers import FMLayer, gaussian_link_exp
90
+ def model(x, y):
91
+ mu = (
92
+ AdaptiveLayer()('x', x) +
93
+ AdaptiveLayer()('x2', x**2) +
94
+ FMLayer(low_rank_dim=3)('xx', x)
95
+ )
96
+ return gaussian_link_exp(mu, y)
97
+ ```
98
+
99
+ ### uv decomp
100
+
101
+ We also provide a standard UV decomp for low-rank interaction terms.
102
+
103
+ ```python
104
+ from blayers import LowRankInteractionLayer, gaussian_link_exp
105
+ def model(x, z, y):
106
+ mu = (
107
+ AdaptiveLayer()('x', x) +
108
+ AdaptiveLayer()('z', z) +
109
+ LowRankInteractionLayer(low_rank_dim=3)('xz', x, z)
110
+ )
111
+ return gaussian_link_exp(mu, y)
112
+ ```
113
+
114
+ ## links
115
+
116
+ We provide link functions as a convenience to abstract away a bit more Numpyro
117
+ boilerplate.
118
+
119
+ We currently provide
120
+
121
+ * `gaussian_link_exp`
122
+
123
+ ## batched loss
124
+
125
+ The default Numpyro way to fit batched VI models is to use `plate`, which confuses
126
+ me a lot. Instead, BLayers provides `Batched_Trace_ELBO` which does not require
127
+ you to use `plate` to batch in VI. Just drop your model in.
128
+
129
+ ```python
130
+ from blayers.infer import Batched_Trace_ELBO, svi_run_batched
131
+
132
+ svi = SVI(model_fn, guide, optax.adam(schedule), loss=loss_instance)
133
+
134
+ svi_result = svi_run_batched(
135
+ svi,
136
+ rng_key,
137
+ num_steps,
138
+ batch_size=1000,
139
+ **model_data,
140
+ )
141
+ ```
@@ -0,0 +1,9 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ blayers/blayers.egg-info/PKG-INFO
5
+ blayers/blayers.egg-info/SOURCES.txt
6
+ blayers/blayers.egg-info/dependency_links.txt
7
+ blayers/blayers.egg-info/requires.txt
8
+ blayers/blayers.egg-info/top_level.txt
9
+ tests/test_layers.py
@@ -0,0 +1,17 @@
1
+ numpy
2
+ jax
3
+ numpyro
4
+
5
+ [dev]
6
+ optax
7
+ pytest
8
+ pytest-check
9
+ mypy
10
+ black
11
+ isort
12
+ autoflake
13
+ sphinx
14
+ furo
15
+ coverage
16
+ pytest-cov
17
+ pre-commit
@@ -0,0 +1,57 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "blayers"
7
+ version = "0.1.0a1"
8
+ description = "Bayesian layers for NumPyro and Jax"
9
+ authors = [{ name = "George Berry", email = "george.e.berry@gmail.com" }]
10
+ readme = "README.md"
11
+ license = { file = "LICENSE" }
12
+ requires-python = ">=3.9"
13
+ dependencies = [
14
+ "numpy",
15
+ "jax",
16
+ "numpyro",
17
+ ]
18
+
19
+ [project.urls]
20
+ Homepage = "https://github.com/georgeberry/blayers"
21
+ Documentation = "https://georgeberry.github.io/blayers/"
22
+
23
+ [project.optional-dependencies]
24
+ dev = [
25
+ "optax",
26
+ "pytest",
27
+ "pytest-check",
28
+ "mypy",
29
+ "black",
30
+ "isort",
31
+ "autoflake",
32
+ "sphinx",
33
+ "furo",
34
+ "coverage",
35
+ "pytest-cov",
36
+ "pre-commit",
37
+ ]
38
+
39
+ [tool.setuptools.packages.find]
40
+ where = ["blayers"]
41
+
42
+ [tool.mypy]
43
+ strict = true
44
+ ignore_missing_imports = true
45
+ explicit_package_bases = true
46
+ disable_error_code = ["misc"]
47
+
48
+ [tool.isort]
49
+ profile = "black"
50
+ line_length = 80
51
+ multi_line_output = 3
52
+ include_trailing_comma = true
53
+ force_grid_wrap = 0
54
+
55
+ [tool.black]
56
+ line-length = 80
57
+ target-version = ["py311"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,388 @@
1
+ from typing import Any, Callable
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import jax.random as random
6
+ import numpyro.distributions as dist
7
+ import optax
8
+ import pytest
9
+ import pytest_check
10
+ from _pytest.fixtures import SubRequest
11
+ from numpyro import sample
12
+ from numpyro.infer import SVI, Predictive, Trace_ELBO
13
+ from numpyro.infer.autoguide import AutoDiagonalNormal
14
+
15
+ from blayers import (
16
+ AdaptiveLayer,
17
+ EmbeddingLayer,
18
+ FixedPriorLayer,
19
+ FMLayer,
20
+ LowRankInteractionLayer,
21
+ )
22
+ from blayers.fit_tools import (
23
+ identity,
24
+ outer_product,
25
+ outer_product_upper_tril_no_diag,
26
+ rmse,
27
+ )
28
+ from blayers.infer import Batched_Trace_ELBO, svi_run_batched
29
+ from blayers.links import gaussian_link_exp
30
+
31
+ N_OBS = 10000
32
+ LOW_RANK_DIM = 3
33
+ EMB_DIM = 1
34
+ N_EMB_CATEGORIES = 10
35
+
36
+
37
+ # ---- Data Generating Processes --------------------------------------------- #
38
+
39
+
40
+ def dgp_simple(n_obs: int, k: int) -> dict[str, jax.Array]:
41
+ lambda1 = sample("lambda1", dist.HalfNormal(1.0))
42
+ beta = sample("beta", dist.Normal(0, lambda1).expand([k]))
43
+
44
+ x1 = sample("x1", dist.Normal(0, 1).expand([n_obs, k]))
45
+
46
+ sigma = sample("sigma", dist.HalfNormal(1.0))
47
+ mu = jnp.dot(x1, beta)
48
+ y = sample("y", dist.Normal(mu, sigma))
49
+ return {
50
+ "x1": x1,
51
+ "y": y,
52
+ "beta": beta,
53
+ "lambda1": lambda1,
54
+ "sigma": sigma,
55
+ }
56
+
57
+
58
+ def dgp_fm(n_obs: int, k: int) -> dict[str, jax.Array]:
59
+ x1 = sample("x1", dist.Normal(0, 1).expand([n_obs, k]))
60
+ lmbda = sample("lambda", dist.HalfNormal(1.0))
61
+ theta = sample("theta", dist.Normal(0.0, lmbda).expand([k, LOW_RANK_DIM]))
62
+
63
+ sigma = sample("sigma", dist.HalfNormal(1.0))
64
+ mu = FMLayer.matmul(theta, x1)
65
+ y = sample("y", dist.Normal(mu, sigma))
66
+ return {
67
+ "x1": x1,
68
+ "y": y,
69
+ "theta": theta,
70
+ "lambda": lmbda,
71
+ "sigma": sigma,
72
+ }
73
+
74
+
75
+ def dgp_emb(n_obs: int, k: int, n_categories: int) -> dict[str, jax.Array]:
76
+ lmbda = sample("lambda", dist.HalfNormal(1.0))
77
+ beta = sample("beta", dist.Normal(0, lmbda).expand([n_categories, k]))
78
+ x1 = sample(
79
+ "x1",
80
+ dist.Categorical(probs=jnp.ones(n_categories) / n_categories).expand(
81
+ [n_obs]
82
+ ),
83
+ )
84
+
85
+ sigma = sample("sigma", dist.HalfNormal(1.0))
86
+
87
+ mu = jnp.sum(beta[x1], axis=1)
88
+ y = sample("y", dist.Normal(mu, sigma))
89
+ return {
90
+ "x1": x1,
91
+ "y": y,
92
+ "beta": beta,
93
+ "lambda": lmbda,
94
+ "sigma": sigma,
95
+ }
96
+
97
+
98
+ def dgp_lowrank(n_obs: int, k: int) -> dict[str, jax.Array]:
99
+ offset = 1
100
+
101
+ x1 = sample("x1", dist.Normal(0, 1).expand([n_obs, k]))
102
+ x2 = sample("x2", dist.Normal(0, 1).expand([n_obs, k + offset]))
103
+
104
+ lambda1 = sample("lambda1", dist.HalfNormal(1.0))
105
+ theta1_lowrank = sample(
106
+ "theta1", dist.Normal(0.0, lambda1).expand([k, LOW_RANK_DIM])
107
+ )
108
+
109
+ lambda2 = sample("lambda2", dist.HalfNormal(1.0))
110
+ theta2_lowrank = sample(
111
+ "theta2", dist.Normal(0.0, lambda1).expand([k + offset, LOW_RANK_DIM])
112
+ )
113
+
114
+ sigma = sample("sigma", dist.HalfNormal(1.0))
115
+ mu = LowRankInteractionLayer.matmul(
116
+ theta1=theta1_lowrank,
117
+ theta2=theta2_lowrank,
118
+ x=x1,
119
+ z=x2,
120
+ )
121
+
122
+ y = sample("y", dist.Normal(mu, sigma))
123
+ return {
124
+ "x1": x1,
125
+ "y": y,
126
+ "theta1": theta1_lowrank,
127
+ "theta2": theta2_lowrank,
128
+ "lambda1": lambda1,
129
+ "lambda2": lambda2,
130
+ "sigma": sigma,
131
+ }
132
+
133
+
134
+ def simulated_data(
135
+ dgp: Callable[..., Any],
136
+ **kwargs: Any,
137
+ ) -> dict[str, jax.Array]:
138
+ rng_key = random.PRNGKey(0)
139
+ predictive = Predictive(dgp, num_samples=1)
140
+ samples = predictive(
141
+ rng_key,
142
+ **kwargs,
143
+ )
144
+ res = {k: jnp.squeeze(v, axis=0) for k, v in samples.items()}
145
+ return res
146
+
147
+
148
+ @pytest.fixture
149
+ def simulated_data_simple() -> dict[str, jax.Array]:
150
+ return simulated_data(dgp_simple, n_obs=N_OBS, k=2)
151
+
152
+
153
+ @pytest.fixture
154
+ def simulated_data_fm() -> dict[str, jax.Array]:
155
+ return simulated_data(dgp_fm, n_obs=N_OBS, k=10)
156
+
157
+
158
+ @pytest.fixture
159
+ def simulated_data_emb() -> dict[str, jax.Array]:
160
+ return simulated_data(
161
+ dgp_emb,
162
+ n_obs=N_OBS,
163
+ k=EMB_DIM,
164
+ n_categories=N_EMB_CATEGORIES,
165
+ )
166
+
167
+
168
+ @pytest.fixture
169
+ def simulated_data_lowrank() -> dict[str, jax.Array]:
170
+ return simulated_data(
171
+ dgp_lowrank,
172
+ n_obs=N_OBS,
173
+ k=10,
174
+ )
175
+
176
+
177
+ # ---- Models ---------------------------------------------------------------- #
178
+
179
+
180
+ @pytest.fixture
181
+ def linear_regression_adaptive_model() -> (
182
+ tuple[Callable[..., Any], list[tuple[list[str], Callable[..., jax.Array]]]]
183
+ ):
184
+ def model(x1: jax.Array, y: jax.Array | None = None) -> Any:
185
+ beta = AdaptiveLayer()("beta", x1)
186
+ return gaussian_link_exp(beta, y)
187
+
188
+ return model, [(["AdaptiveLayer_beta_beta"], identity)]
189
+
190
+
191
+ @pytest.fixture
192
+ def linear_regression_fixed_model() -> (
193
+ tuple[Callable[..., Any], list[tuple[list[str], Callable[..., jax.Array]]]]
194
+ ):
195
+ def model(x1: jax.Array, y: jax.Array | None = None) -> Any:
196
+ beta = FixedPriorLayer()("beta", x1)
197
+ return gaussian_link_exp(beta, y)
198
+
199
+ return model, [(["FixedPriorLayer_beta_beta"], identity)]
200
+
201
+
202
+ @pytest.fixture
203
+ def fm_regression_model() -> (
204
+ tuple[Callable[..., Any], list[tuple[list[str], Callable[..., jax.Array]]]]
205
+ ):
206
+ def model(x1: jax.Array, y: jax.Array | None = None) -> Any:
207
+ theta = FMLayer(low_rank_dim=LOW_RANK_DIM)("theta", x1)
208
+ return gaussian_link_exp(theta, y)
209
+
210
+ return (
211
+ model,
212
+ [
213
+ (["FMLayer_theta_theta"], outer_product_upper_tril_no_diag),
214
+ ],
215
+ )
216
+
217
+
218
+ @pytest.fixture
219
+ def emb_model() -> (
220
+ tuple[Callable[..., Any], list[tuple[list[str], Callable[..., jax.Array]]]]
221
+ ):
222
+ def model(x1: jax.Array, y: jax.Array | None = None) -> Any:
223
+ beta = EmbeddingLayer()(
224
+ "beta",
225
+ x1,
226
+ n_categories=N_EMB_CATEGORIES,
227
+ embedding_dim=EMB_DIM,
228
+ )
229
+ return gaussian_link_exp(beta, y)
230
+
231
+ return (
232
+ model,
233
+ [
234
+ (["EmbeddingLayer_beta_beta"], identity),
235
+ ],
236
+ )
237
+
238
+
239
+ @pytest.fixture
240
+ def lowrank_model() -> (
241
+ tuple[Callable[..., Any], list[tuple[list[str], Callable[..., jax.Array]]]]
242
+ ):
243
+ def model(x1: jax.Array, x2: jax.Array, y: jax.Array | None = None) -> Any:
244
+ beta1 = LowRankInteractionLayer(low_rank_dim=LOW_RANK_DIM)(
245
+ "lowrank",
246
+ x1,
247
+ x2,
248
+ )
249
+ return gaussian_link_exp(beta1, y)
250
+
251
+ return (
252
+ model,
253
+ [
254
+ (
255
+ [
256
+ "LowRankInteractionLayer_lowrank_theta1",
257
+ "LowRankInteractionLayer_lowrank_theta2",
258
+ ],
259
+ outer_product,
260
+ ),
261
+ ],
262
+ )
263
+
264
+
265
+ # ---- Loss classes ---------------------------------------------------------- #
266
+
267
+
268
+ @pytest.fixture
269
+ def trace_elbo() -> Trace_ELBO:
270
+ return Trace_ELBO()
271
+
272
+
273
+ @pytest.fixture
274
+ def trace_elbo_batched() -> Batched_Trace_ELBO:
275
+ return Batched_Trace_ELBO(n_obs=N_OBS)
276
+
277
+
278
+ # ---- Dispatchers ----------------------------------------------------------- #
279
+
280
+ """
281
+ These are pytest helpers that let us cycle through fixtures. This setup is a
282
+ little wonky and I'm sure we could come up with something better in the long
283
+ run, but it works for now. Just make one with the name for the thing you want
284
+ to pass to the ultimate test function.
285
+ """
286
+
287
+
288
+ @pytest.fixture
289
+ def model(request: SubRequest) -> Any:
290
+ return request.getfixturevalue(request.param)
291
+
292
+
293
+ @pytest.fixture
294
+ def data(request: SubRequest) -> Any:
295
+ return request.getfixturevalue(request.param)
296
+
297
+
298
+ @pytest.fixture
299
+ def loss_instance(request: SubRequest) -> Any:
300
+ return request.getfixturevalue(request.param)
301
+
302
+
303
+ # ---- Test functions -------------------------------------------------------- #
304
+
305
+
306
+ @pytest.mark.parametrize(
307
+ "loss_instance",
308
+ [
309
+ "trace_elbo",
310
+ "trace_elbo_batched",
311
+ ],
312
+ indirect=True,
313
+ )
314
+ @pytest.mark.parametrize(
315
+ ("model", "data"),
316
+ [
317
+ ("linear_regression_adaptive_model", "simulated_data_simple"),
318
+ ("linear_regression_fixed_model", "simulated_data_simple"),
319
+ ("fm_regression_model", "simulated_data_fm"),
320
+ ("emb_model", "simulated_data_emb"),
321
+ ("lowrank_model", "simulated_data_lowrank"),
322
+ ],
323
+ indirect=True,
324
+ )
325
+ def test_models(
326
+ data: Any,
327
+ model: Any,
328
+ loss_instance: Any,
329
+ ) -> Any:
330
+ model_fn, coef_groups = model
331
+ model_data = {k: v for k, v in data.items() if k in ("y", "x1", "x2")}
332
+
333
+ guide = AutoDiagonalNormal(model_fn)
334
+
335
+ num_steps = 30000
336
+
337
+ schedule = optax.cosine_onecycle_schedule(
338
+ transition_steps=num_steps,
339
+ peak_value=5e-2,
340
+ pct_start=0.1,
341
+ div_factor=25,
342
+ )
343
+
344
+ svi = SVI(model_fn, guide, optax.adam(schedule), loss=loss_instance)
345
+
346
+ rng_key = random.PRNGKey(2)
347
+
348
+ if isinstance(loss_instance, Trace_ELBO):
349
+ svi_result = svi.run(
350
+ rng_key,
351
+ num_steps=num_steps,
352
+ **model_data,
353
+ )
354
+ if isinstance(loss_instance, Batched_Trace_ELBO):
355
+ svi_result = svi_run_batched(
356
+ svi,
357
+ rng_key,
358
+ num_steps,
359
+ batch_size=1000,
360
+ **model_data,
361
+ )
362
+ guide_predicitive = Predictive(
363
+ guide,
364
+ params=svi_result.params,
365
+ num_samples=1000,
366
+ )
367
+ guide_samples = guide_predicitive(
368
+ random.PRNGKey(1),
369
+ **{k: v for k, v in model_data.items() if k != "y"},
370
+ )
371
+ guide_means = {k: jnp.mean(v, axis=0) for k, v in guide_samples.items()}
372
+
373
+ for coef_list, coef_fn in coef_groups:
374
+ with pytest_check.check:
375
+ val = rmse(
376
+ coef_fn(*[guide_means[x] for x in coef_list]),
377
+ coef_fn(*[data[x.split("_")[2]] for x in coef_list]),
378
+ )
379
+ assert val < 0.1
380
+
381
+ with pytest_check.check:
382
+ assert (
383
+ rmse(
384
+ guide_means["sigma"],
385
+ data["sigma"],
386
+ )
387
+ < 0.03
388
+ )