blayers 0.1.0a1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blayers-0.1.0a1/LICENSE +21 -0
- blayers-0.1.0a1/PKG-INFO +141 -0
- blayers-0.1.0a1/README.md +113 -0
- blayers-0.1.0a1/blayers/blayers.egg-info/PKG-INFO +141 -0
- blayers-0.1.0a1/blayers/blayers.egg-info/SOURCES.txt +9 -0
- blayers-0.1.0a1/blayers/blayers.egg-info/dependency_links.txt +1 -0
- blayers-0.1.0a1/blayers/blayers.egg-info/requires.txt +17 -0
- blayers-0.1.0a1/blayers/blayers.egg-info/top_level.txt +1 -0
- blayers-0.1.0a1/pyproject.toml +57 -0
- blayers-0.1.0a1/setup.cfg +4 -0
- blayers-0.1.0a1/tests/test_layers.py +388 -0
blayers-0.1.0a1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 George Berry
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
blayers-0.1.0a1/PKG-INFO
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: blayers
|
|
3
|
+
Version: 0.1.0a1
|
|
4
|
+
Summary: Bayesian layers for NumPyro and Jax
|
|
5
|
+
Author-email: George Berry <george.e.berry@gmail.com>
|
|
6
|
+
Project-URL: Homepage, https://github.com/georgeberry/blayers
|
|
7
|
+
Project-URL: Documentation, https://georgeberry.github.io/blayers/
|
|
8
|
+
Requires-Python: >=3.9
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Requires-Dist: numpy
|
|
12
|
+
Requires-Dist: jax
|
|
13
|
+
Requires-Dist: numpyro
|
|
14
|
+
Provides-Extra: dev
|
|
15
|
+
Requires-Dist: optax; extra == "dev"
|
|
16
|
+
Requires-Dist: pytest; extra == "dev"
|
|
17
|
+
Requires-Dist: pytest-check; extra == "dev"
|
|
18
|
+
Requires-Dist: mypy; extra == "dev"
|
|
19
|
+
Requires-Dist: black; extra == "dev"
|
|
20
|
+
Requires-Dist: isort; extra == "dev"
|
|
21
|
+
Requires-Dist: autoflake; extra == "dev"
|
|
22
|
+
Requires-Dist: sphinx; extra == "dev"
|
|
23
|
+
Requires-Dist: furo; extra == "dev"
|
|
24
|
+
Requires-Dist: coverage; extra == "dev"
|
|
25
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
26
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
27
|
+
Dynamic: license-file
|
|
28
|
+
|
|
29
|
+
[](https://codecov.io/gh/georgeberry/blayers) [](LICENSE) 
|
|
30
|
+
|
|
31
|
+
# BLayers
|
|
32
|
+
|
|
33
|
+
**NOTE: BLayers is in alpha. Expect changes. Feedback welcome.**
|
|
34
|
+
|
|
35
|
+
The missing layers package for Bayesian inference. Inspiration from Keras and
|
|
36
|
+
Tensorflow Probability, but made specifically for Numpyro + Jax.
|
|
37
|
+
|
|
38
|
+
Easily build Bayesian models from parts, abstract away the boilerplate, and
|
|
39
|
+
tweak priors as you wish.
|
|
40
|
+
|
|
41
|
+
Fit models either using Variational Inference (VI) or your sampling method of
|
|
42
|
+
choice. Use BLayer's ELBO implementation to do either batched VI or sampling
|
|
43
|
+
without having to rewrite models.
|
|
44
|
+
|
|
45
|
+
BLayers helps you write pure Numpyro, so you can integrate it with any Numpyro
|
|
46
|
+
code to build models of arbitrary complexity. It also gives you a recipe to
|
|
47
|
+
build more complex layers as you wish.
|
|
48
|
+
|
|
49
|
+
## the starting point
|
|
50
|
+
|
|
51
|
+
The simplest non-trivial (and most important!) Bayesian regression model form is
|
|
52
|
+
the adaptive prior,
|
|
53
|
+
|
|
54
|
+
```
|
|
55
|
+
lmbda ~ HalfNormal(1)
|
|
56
|
+
beta ~ Normal(0, lmbda)
|
|
57
|
+
y ~ Normal(beta * x, 1)
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
BLayers takes this as its starting point and most fundamental building block,
|
|
61
|
+
providing the flexible `AdaptiveLayer`.
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
from blayers import AdaptiveLayer, gaussian_link_exp
|
|
65
|
+
def model(x, y):
|
|
66
|
+
mu = AdaptiveLayer()('mu', x)
|
|
67
|
+
return gaussian_link_exp(mu, y)
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
For you purists out there, we also provide a `FixedPriorLayer` for standard
|
|
71
|
+
L1/L2 regression.
|
|
72
|
+
|
|
73
|
+
```python
|
|
74
|
+
from blayers import FixedPriorLayer, gaussian_link_exp
|
|
75
|
+
def model(x, y):
|
|
76
|
+
mu = FixedPriorLayer()('mu', x)
|
|
77
|
+
return gaussian_link_exp(mu, y)
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
## additional layers
|
|
81
|
+
|
|
82
|
+
### factorization machines
|
|
83
|
+
|
|
84
|
+
Developed in [Rendle 2010](https://jame-zhang.github.io/assets/algo/Factorization-Machines-Rendle2010.pdf) and [Rendle 2011](https://www.ismll.uni-hildesheim.de/pub/pdfs/FreudenthalerRendle_BayesianFactorizationMachines.pdf), FMs provide a low-rank approximation to the `x`-by-`x` interaction matrix. For those familiar with R syntax, it is an approximation to `y ~ x:x`, excluding the x^2 terms.
|
|
85
|
+
|
|
86
|
+
To fit the equivalent of an r model like `y ~ x*x` (all main effects, x^2 terms, and one-way interaction effects), you'd do
|
|
87
|
+
|
|
88
|
+
```python
|
|
89
|
+
from blayers import FMLayer, gaussian_link_exp
|
|
90
|
+
def model(x, y):
|
|
91
|
+
mu = (
|
|
92
|
+
AdaptiveLayer('x', x) +
|
|
93
|
+
AdaptiveLayer('x2', x**2) +
|
|
94
|
+
FMLayer(low_rank_dim=3)('xx', x)
|
|
95
|
+
)
|
|
96
|
+
return gaussian_link_exp(mu, y)
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
### uv decomp
|
|
100
|
+
|
|
101
|
+
We also provide a standard UV decomposition for low-rank interaction terms
|
|
102
|
+
|
|
103
|
+
```python
|
|
104
|
+
from blayers import LowRankInteractionLayer, gaussian_link_exp
|
|
105
|
+
def model(x, z, y):
|
|
106
|
+
mu = (
|
|
107
|
+
AdaptiveLayer('x', x) +
|
|
108
|
+
AdaptiveLayer('z', z) +
|
|
109
|
+
LowRankInteractionLayer(low_rank_dim=3)('xz', x, z)
|
|
110
|
+
)
|
|
111
|
+
return gaussian_link_exp(mu, y)
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
## links
|
|
115
|
+
|
|
116
|
+
We provide link functions as a convenience to abstract away a bit more Numpyro
|
|
117
|
+
boilerplate.
|
|
118
|
+
|
|
119
|
+
We currently provide
|
|
120
|
+
|
|
121
|
+
* `gaussian_link_exp`
|
|
122
|
+
|
|
123
|
+
## batched loss
|
|
124
|
+
|
|
125
|
+
The default Numpyro way to fit batched VI models is to use `plate`, which confuses
|
|
126
|
+
me a lot. Instead, BLayers provides `Batched_Trace_ELBO` which does not require
|
|
127
|
+
you to use `plate` to batch in VI. Just drop your model in.
|
|
128
|
+
|
|
129
|
+
```python
|
|
130
|
+
from blayers.infer import Batched_Trace_ELBO, svi_run_batched
|
|
131
|
+
|
|
132
|
+
svi = SVI(model_fn, guide, optax.adam(schedule), loss=loss_instance)
|
|
133
|
+
|
|
134
|
+
svi_result = svi_run_batched(
|
|
135
|
+
svi,
|
|
136
|
+
rng_key,
|
|
137
|
+
num_steps,
|
|
138
|
+
batch_size=1000,
|
|
139
|
+
**model_data,
|
|
140
|
+
)
|
|
141
|
+
```
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
[](https://codecov.io/gh/georgeberry/blayers) [](LICENSE) 
|
|
2
|
+
|
|
3
|
+
# BLayers
|
|
4
|
+
|
|
5
|
+
**NOTE: BLayers is in alpha. Expect changes. Feedback welcome.**
|
|
6
|
+
|
|
7
|
+
The missing layers package for Bayesian inference. Inspiration from Keras and
|
|
8
|
+
Tensorflow Probability, but made specifically for Numpyro + Jax.
|
|
9
|
+
|
|
10
|
+
Easily build Bayesian models from parts, abstract away the boilerplate, and
|
|
11
|
+
tweak priors as you wish.
|
|
12
|
+
|
|
13
|
+
Fit models either using Variational Inference (VI) or your sampling method of
|
|
14
|
+
choice. Use BLayer's ELBO implementation to do either batched VI or sampling
|
|
15
|
+
without having to rewrite models.
|
|
16
|
+
|
|
17
|
+
BLayers helps you write pure Numpyro, so you can integrate it with any Numpyro
|
|
18
|
+
code to build models of arbitrary complexity. It also gives you a recipe to
|
|
19
|
+
build more complex layers as you wish.
|
|
20
|
+
|
|
21
|
+
## the starting point
|
|
22
|
+
|
|
23
|
+
The simplest non-trivial (and most important!) Bayesian regression model form is
|
|
24
|
+
the adaptive prior,
|
|
25
|
+
|
|
26
|
+
```
|
|
27
|
+
lmbda ~ HalfNormal(1)
|
|
28
|
+
beta ~ Normal(0, lmbda)
|
|
29
|
+
y ~ Normal(beta * x, 1)
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
BLayers takes this as its starting point and most fundamental building block,
|
|
33
|
+
providing the flexible `AdaptiveLayer`.
|
|
34
|
+
|
|
35
|
+
```python
|
|
36
|
+
from blayers import AdaptiveLayer, gaussian_link_exp
|
|
37
|
+
def model(x, y):
|
|
38
|
+
mu = AdaptiveLayer()('mu', x)
|
|
39
|
+
return gaussian_link_exp(mu, y)
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
For you purists out there, we also provide a `FixedPriorLayer` for standard
|
|
43
|
+
L1/L2 regression.
|
|
44
|
+
|
|
45
|
+
```python
|
|
46
|
+
from blayers import FixedPriorLayer, gaussian_link_exp
|
|
47
|
+
def model(x, y):
|
|
48
|
+
mu = FixedPriorLayer()('mu', x)
|
|
49
|
+
return gaussian_link_exp(mu, y)
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
## additional layers
|
|
53
|
+
|
|
54
|
+
### factorization machines
|
|
55
|
+
|
|
56
|
+
Developed in [Rendle 2010](https://jame-zhang.github.io/assets/algo/Factorization-Machines-Rendle2010.pdf) and [Rendle 2011](https://www.ismll.uni-hildesheim.de/pub/pdfs/FreudenthalerRendle_BayesianFactorizationMachines.pdf), FMs provide a low-rank approximation to the `x`-by-`x` interaction matrix. For those familiar with R syntax, it is an approximation to `y ~ x:x`, excluding the x^2 terms.
|
|
57
|
+
|
|
58
|
+
To fit the equivalent of an r model like `y ~ x*x` (all main effects, x^2 terms, and one-way interaction effects), you'd do
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
from blayers import FMLayer, gaussian_link_exp
|
|
62
|
+
def model(x, y):
|
|
63
|
+
mu = (
|
|
64
|
+
AdaptiveLayer('x', x) +
|
|
65
|
+
AdaptiveLayer('x2', x**2) +
|
|
66
|
+
FMLayer(low_rank_dim=3)('xx', x)
|
|
67
|
+
)
|
|
68
|
+
return gaussian_link_exp(mu, y)
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
### uv decomp
|
|
72
|
+
|
|
73
|
+
We also provide a standard UV decomposition for low-rank interaction terms
|
|
74
|
+
|
|
75
|
+
```python
|
|
76
|
+
from blayers import LowRankInteractionLayer, gaussian_link_exp
|
|
77
|
+
def model(x, z, y):
|
|
78
|
+
mu = (
|
|
79
|
+
AdaptiveLayer('x', x) +
|
|
80
|
+
AdaptiveLayer('z', z) +
|
|
81
|
+
LowRankInteractionLayer(low_rank_dim=3)('xz', x, z)
|
|
82
|
+
)
|
|
83
|
+
return gaussian_link_exp(mu, y)
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
## links
|
|
87
|
+
|
|
88
|
+
We provide link functions as a convenience to abstract away a bit more Numpyro
|
|
89
|
+
boilerplate.
|
|
90
|
+
|
|
91
|
+
We currently provide
|
|
92
|
+
|
|
93
|
+
* `gaussian_link_exp`
|
|
94
|
+
|
|
95
|
+
## batched loss
|
|
96
|
+
|
|
97
|
+
The default Numpyro way to fit batched VI models is to use `plate`, which confuses
|
|
98
|
+
me a lot. Instead, BLayers provides `Batched_Trace_ELBO` which does not require
|
|
99
|
+
you to use `plate` to batch in VI. Just drop your model in.
|
|
100
|
+
|
|
101
|
+
```python
|
|
102
|
+
from blayers.infer import Batched_Trace_ELBO, svi_run_batched
|
|
103
|
+
|
|
104
|
+
svi = SVI(model_fn, guide, optax.adam(schedule), loss=loss_instance)
|
|
105
|
+
|
|
106
|
+
svi_result = svi_run_batched(
|
|
107
|
+
svi,
|
|
108
|
+
rng_key,
|
|
109
|
+
num_steps,
|
|
110
|
+
batch_size=1000,
|
|
111
|
+
**model_data,
|
|
112
|
+
)
|
|
113
|
+
```
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: blayers
|
|
3
|
+
Version: 0.1.0a1
|
|
4
|
+
Summary: Bayesian layers for NumPyro and Jax
|
|
5
|
+
Author-email: George Berry <george.e.berry@gmail.com>
|
|
6
|
+
Project-URL: Homepage, https://github.com/georgeberry/blayers
|
|
7
|
+
Project-URL: Documentation, https://georgeberry.github.io/blayers/
|
|
8
|
+
Requires-Python: >=3.9
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Requires-Dist: numpy
|
|
12
|
+
Requires-Dist: jax
|
|
13
|
+
Requires-Dist: numpyro
|
|
14
|
+
Provides-Extra: dev
|
|
15
|
+
Requires-Dist: optax; extra == "dev"
|
|
16
|
+
Requires-Dist: pytest; extra == "dev"
|
|
17
|
+
Requires-Dist: pytest-check; extra == "dev"
|
|
18
|
+
Requires-Dist: mypy; extra == "dev"
|
|
19
|
+
Requires-Dist: black; extra == "dev"
|
|
20
|
+
Requires-Dist: isort; extra == "dev"
|
|
21
|
+
Requires-Dist: autoflake; extra == "dev"
|
|
22
|
+
Requires-Dist: sphinx; extra == "dev"
|
|
23
|
+
Requires-Dist: furo; extra == "dev"
|
|
24
|
+
Requires-Dist: coverage; extra == "dev"
|
|
25
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
26
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
27
|
+
Dynamic: license-file
|
|
28
|
+
|
|
29
|
+
[](https://codecov.io/gh/georgeberry/blayers) [](LICENSE) 
|
|
30
|
+
|
|
31
|
+
# BLayers
|
|
32
|
+
|
|
33
|
+
**NOTE: BLayers is in alpha. Expect changes. Feedback welcome.**
|
|
34
|
+
|
|
35
|
+
The missing layers package for Bayesian inference. Inspiration from Keras and
|
|
36
|
+
Tensorflow Probability, but made specifically for Numpyro + Jax.
|
|
37
|
+
|
|
38
|
+
Easily build Bayesian models from parts, abstract away the boilerplate, and
|
|
39
|
+
tweak priors as you wish.
|
|
40
|
+
|
|
41
|
+
Fit models either using Variational Inference (VI) or your sampling method of
|
|
42
|
+
choice. Use BLayer's ELBO implementation to do either batched VI or sampling
|
|
43
|
+
without having to rewrite models.
|
|
44
|
+
|
|
45
|
+
BLayers helps you write pure Numpyro, so you can integrate it with any Numpyro
|
|
46
|
+
code to build models of arbitrary complexity. It also gives you a recipe to
|
|
47
|
+
build more complex layers as you wish.
|
|
48
|
+
|
|
49
|
+
## the starting point
|
|
50
|
+
|
|
51
|
+
The simplest non-trivial (and most important!) Bayesian regression model form is
|
|
52
|
+
the adaptive prior,
|
|
53
|
+
|
|
54
|
+
```
|
|
55
|
+
lmbda ~ HalfNormal(1)
|
|
56
|
+
beta ~ Normal(0, lmbda)
|
|
57
|
+
y ~ Normal(beta * x, 1)
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
BLayers takes this as its starting point and most fundamental building block,
|
|
61
|
+
providing the flexible `AdaptiveLayer`.
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
from blayers import AdaptiveLayer, gaussian_link_exp
|
|
65
|
+
def model(x, y):
|
|
66
|
+
mu = AdaptiveLayer()('mu', x)
|
|
67
|
+
return gaussian_link_exp(mu, y)
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
For you purists out there, we also provide a `FixedPriorLayer` for standard
|
|
71
|
+
L1/L2 regression.
|
|
72
|
+
|
|
73
|
+
```python
|
|
74
|
+
from blayers import FixedPriorLayer, gaussian_link_exp
|
|
75
|
+
def model(x, y):
|
|
76
|
+
mu = FixedPriorLayer()('mu', x)
|
|
77
|
+
return gaussian_link_exp(mu, y)
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
## additional layers
|
|
81
|
+
|
|
82
|
+
### factorization machines
|
|
83
|
+
|
|
84
|
+
Developed in [Rendle 2010](https://jame-zhang.github.io/assets/algo/Factorization-Machines-Rendle2010.pdf) and [Rendle 2011](https://www.ismll.uni-hildesheim.de/pub/pdfs/FreudenthalerRendle_BayesianFactorizationMachines.pdf), FMs provide a low-rank approximation to the `x`-by-`x` interaction matrix. For those familiar with R syntax, it is an approximation to `y ~ x:x`, excluding the x^2 terms.
|
|
85
|
+
|
|
86
|
+
To fit the equivalent of an r model like `y ~ x*x` (all main effects, x^2 terms, and one-way interaction effects), you'd do
|
|
87
|
+
|
|
88
|
+
```python
|
|
89
|
+
from blayers import FMLayer, gaussian_link_exp
|
|
90
|
+
def model(x, y):
|
|
91
|
+
mu = (
|
|
92
|
+
AdaptiveLayer('x', x) +
|
|
93
|
+
AdaptiveLayer('x2', x**2) +
|
|
94
|
+
FMLayer(low_rank_dim=3)('xx', x)
|
|
95
|
+
)
|
|
96
|
+
return gaussian_link_exp(mu, y)
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
### uv decomp
|
|
100
|
+
|
|
101
|
+
We also provide a standard UV decomposition for low-rank interaction terms
|
|
102
|
+
|
|
103
|
+
```python
|
|
104
|
+
from blayers import LowRankInteractionLayer, gaussian_link_exp
|
|
105
|
+
def model(x, z, y):
|
|
106
|
+
mu = (
|
|
107
|
+
AdaptiveLayer('x', x) +
|
|
108
|
+
AdaptiveLayer('z', z) +
|
|
109
|
+
LowRankInteractionLayer(low_rank_dim=3)('xz', x, z)
|
|
110
|
+
)
|
|
111
|
+
return gaussian_link_exp(mu, y)
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
## links
|
|
115
|
+
|
|
116
|
+
We provide link functions as a convenience to abstract away a bit more Numpyro
|
|
117
|
+
boilerplate.
|
|
118
|
+
|
|
119
|
+
We currently provide
|
|
120
|
+
|
|
121
|
+
* `gaussian_link_exp`
|
|
122
|
+
|
|
123
|
+
## batched loss
|
|
124
|
+
|
|
125
|
+
The default Numpyro way to fit batched VI models is to use `plate`, which confuses
|
|
126
|
+
me a lot. Instead, BLayers provides `Batched_Trace_ELBO` which does not require
|
|
127
|
+
you to use `plate` to batch in VI. Just drop your model in.
|
|
128
|
+
|
|
129
|
+
```python
|
|
130
|
+
from blayers.infer import Batched_Trace_ELBO, svi_run_batched
|
|
131
|
+
|
|
132
|
+
svi = SVI(model_fn, guide, optax.adam(schedule), loss=loss_instance)
|
|
133
|
+
|
|
134
|
+
svi_result = svi_run_batched(
|
|
135
|
+
svi,
|
|
136
|
+
rng_key,
|
|
137
|
+
num_steps,
|
|
138
|
+
batch_size=1000,
|
|
139
|
+
**model_data,
|
|
140
|
+
)
|
|
141
|
+
```
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "blayers"
|
|
7
|
+
version = "0.1.0a1"
|
|
8
|
+
description = "Bayesian layers for NumPyro and Jax"
|
|
9
|
+
authors = [{ name = "George Berry", email = "george.e.berry@gmail.com" }]
|
|
10
|
+
readme = "README.md"
|
|
11
|
+
license = { file = "LICENSE" }
|
|
12
|
+
requires-python = ">=3.9"
|
|
13
|
+
dependencies = [
|
|
14
|
+
"numpy",
|
|
15
|
+
"jax",
|
|
16
|
+
"numpyro",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
[project.urls]
|
|
20
|
+
Homepage = "https://github.com/georgeberry/blayers"
|
|
21
|
+
Documentation = "https://georgeberry.github.io/blayers/"
|
|
22
|
+
|
|
23
|
+
[project.optional-dependencies]
|
|
24
|
+
dev = [
|
|
25
|
+
"optax",
|
|
26
|
+
"pytest",
|
|
27
|
+
"pytest-check",
|
|
28
|
+
"mypy",
|
|
29
|
+
"black",
|
|
30
|
+
"isort",
|
|
31
|
+
"autoflake",
|
|
32
|
+
"sphinx",
|
|
33
|
+
"furo",
|
|
34
|
+
"coverage",
|
|
35
|
+
"pytest-cov",
|
|
36
|
+
"pre-commit",
|
|
37
|
+
]
|
|
38
|
+
|
|
39
|
+
[tool.setuptools.packages.find]
|
|
40
|
+
where = ["blayers"]
|
|
41
|
+
|
|
42
|
+
[tool.mypy]
|
|
43
|
+
strict = true
|
|
44
|
+
ignore_missing_imports = true
|
|
45
|
+
explicit_package_bases = true
|
|
46
|
+
disable_error_code = ["misc"]
|
|
47
|
+
|
|
48
|
+
[tool.isort]
|
|
49
|
+
profile = "black"
|
|
50
|
+
line_length = 80
|
|
51
|
+
multi_line_output = 3
|
|
52
|
+
include_trailing_comma = true
|
|
53
|
+
force_grid_wrap = 0
|
|
54
|
+
|
|
55
|
+
[tool.black]
|
|
56
|
+
line-length = 80
|
|
57
|
+
target-version = ["py311"]
|
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
from typing import Any, Callable
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import jax.random as random
|
|
6
|
+
import numpyro.distributions as dist
|
|
7
|
+
import optax
|
|
8
|
+
import pytest
|
|
9
|
+
import pytest_check
|
|
10
|
+
from _pytest.fixtures import SubRequest
|
|
11
|
+
from numpyro import sample
|
|
12
|
+
from numpyro.infer import SVI, Predictive, Trace_ELBO
|
|
13
|
+
from numpyro.infer.autoguide import AutoDiagonalNormal
|
|
14
|
+
|
|
15
|
+
from blayers import (
|
|
16
|
+
AdaptiveLayer,
|
|
17
|
+
EmbeddingLayer,
|
|
18
|
+
FixedPriorLayer,
|
|
19
|
+
FMLayer,
|
|
20
|
+
LowRankInteractionLayer,
|
|
21
|
+
)
|
|
22
|
+
from blayers.fit_tools import (
|
|
23
|
+
identity,
|
|
24
|
+
outer_product,
|
|
25
|
+
outer_product_upper_tril_no_diag,
|
|
26
|
+
rmse,
|
|
27
|
+
)
|
|
28
|
+
from blayers.infer import Batched_Trace_ELBO, svi_run_batched
|
|
29
|
+
from blayers.links import gaussian_link_exp
|
|
30
|
+
|
|
31
|
+
N_OBS = 10000
|
|
32
|
+
LOW_RANK_DIM = 3
|
|
33
|
+
EMB_DIM = 1
|
|
34
|
+
N_EMB_CATEGORIES = 10
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# ---- Data Generating Processes --------------------------------------------- #
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def dgp_simple(n_obs: int, k: int) -> dict[str, jax.Array]:
    """Simulate a linear-Gaussian dataset with a hierarchical (adaptive) prior.

    Draws a half-normal prior scale ``lambda1``, coefficients ``beta``,
    standard-normal covariates ``x1``, and Gaussian outcomes ``y``.
    Returns the data together with the latent parameters so tests can
    check parameter recovery.
    """
    lambda1 = sample("lambda1", dist.HalfNormal(1.0))
    beta = sample("beta", dist.Normal(0, lambda1).expand([k]))

    x1 = sample("x1", dist.Normal(0, 1).expand([n_obs, k]))

    sigma = sample("sigma", dist.HalfNormal(1.0))
    y = sample("y", dist.Normal(jnp.dot(x1, beta), sigma))

    return dict(x1=x1, y=y, beta=beta, lambda1=lambda1, sigma=sigma)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def dgp_fm(n_obs: int, k: int) -> dict[str, jax.Array]:
    """Simulate data whose mean is a factorization-machine interaction term.

    ``theta`` is a ``[k, LOW_RANK_DIM]`` low-rank factor matrix; the mean is
    produced by ``FMLayer.matmul`` so model fits can be compared against the
    true factors.
    """
    x1 = sample("x1", dist.Normal(0, 1).expand([n_obs, k]))
    lmbda = sample("lambda", dist.HalfNormal(1.0))
    theta = sample("theta", dist.Normal(0.0, lmbda).expand([k, LOW_RANK_DIM]))

    sigma = sample("sigma", dist.HalfNormal(1.0))
    y = sample("y", dist.Normal(FMLayer.matmul(theta, x1), sigma))

    result = dict(x1=x1, y=y, theta=theta, sigma=sigma)
    result["lambda"] = lmbda  # "lambda" is a keyword, so it can't be a kwarg
    return result
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def dgp_emb(n_obs: int, k: int, n_categories: int) -> dict[str, jax.Array]:
    """Simulate data for an embedding model over a uniform categorical input.

    Each observation draws a category id; the mean is the sum over the
    ``k``-dimensional embedding row for that category.
    """
    lmbda = sample("lambda", dist.HalfNormal(1.0))
    beta = sample("beta", dist.Normal(0, lmbda).expand([n_categories, k]))

    uniform = jnp.ones(n_categories) / n_categories
    x1 = sample("x1", dist.Categorical(probs=uniform).expand([n_obs]))

    sigma = sample("sigma", dist.HalfNormal(1.0))
    y = sample("y", dist.Normal(jnp.sum(beta[x1], axis=1), sigma))

    result = dict(x1=x1, y=y, beta=beta, sigma=sigma)
    result["lambda"] = lmbda  # "lambda" is a keyword, so it can't be a kwarg
    return result
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def dgp_lowrank(n_obs: int, k: int) -> dict[str, jax.Array]:
    """Simulate data from a low-rank (UV) interaction between two inputs.

    ``x2`` has one extra column (``offset``) so non-square factor shapes are
    exercised. Fixes relative to the original:

    * ``theta2`` is now drawn with its own scale ``lambda2``; previously it
      reused ``lambda1``, leaving the sampled ``lambda2`` unused.
    * ``"x2"`` is included in the returned dict; the low-rank model under
      test takes ``x2`` as an input, and ``test_models`` filters the data
      for it, so omitting it broke the low-rank test path.
    """
    offset = 1  # make x2 one column wider than x1

    x1 = sample("x1", dist.Normal(0, 1).expand([n_obs, k]))
    x2 = sample("x2", dist.Normal(0, 1).expand([n_obs, k + offset]))

    lambda1 = sample("lambda1", dist.HalfNormal(1.0))
    theta1_lowrank = sample(
        "theta1", dist.Normal(0.0, lambda1).expand([k, LOW_RANK_DIM])
    )

    lambda2 = sample("lambda2", dist.HalfNormal(1.0))
    theta2_lowrank = sample(
        "theta2", dist.Normal(0.0, lambda2).expand([k + offset, LOW_RANK_DIM])
    )

    sigma = sample("sigma", dist.HalfNormal(1.0))
    mu = LowRankInteractionLayer.matmul(
        theta1=theta1_lowrank,
        theta2=theta2_lowrank,
        x=x1,
        z=x2,
    )

    y = sample("y", dist.Normal(mu, sigma))
    return {
        "x1": x1,
        "x2": x2,
        "y": y,
        "theta1": theta1_lowrank,
        "theta2": theta2_lowrank,
        "lambda1": lambda1,
        "lambda2": lambda2,
        "sigma": sigma,
    }
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def simulated_data(
    dgp: Callable[..., Any],
    **kwargs: Any,
) -> dict[str, jax.Array]:
    """Run a data-generating process once and drop the sample dimension.

    Uses ``Predictive`` with a fixed PRNG key so datasets are reproducible,
    then squeezes the leading ``num_samples=1`` axis off every site.
    """
    predictive = Predictive(dgp, num_samples=1)
    samples = predictive(random.PRNGKey(0), **kwargs)
    return {name: jnp.squeeze(draw, axis=0) for name, draw in samples.items()}
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@pytest.fixture
def simulated_data_simple() -> dict[str, jax.Array]:
    """Dataset drawn once from the simple linear DGP."""
    dataset = simulated_data(dgp_simple, n_obs=N_OBS, k=2)
    return dataset
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@pytest.fixture
def simulated_data_fm() -> dict[str, jax.Array]:
    """Dataset drawn once from the factorization-machine DGP."""
    dataset = simulated_data(dgp_fm, n_obs=N_OBS, k=10)
    return dataset
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@pytest.fixture
def simulated_data_emb() -> dict[str, jax.Array]:
    """Dataset drawn once from the embedding DGP."""
    dataset = simulated_data(
        dgp_emb,
        n_obs=N_OBS,
        k=EMB_DIM,
        n_categories=N_EMB_CATEGORIES,
    )
    return dataset
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
@pytest.fixture
def simulated_data_lowrank() -> dict[str, jax.Array]:
    """Dataset drawn once from the low-rank interaction DGP."""
    dataset = simulated_data(
        dgp_lowrank,
        n_obs=N_OBS,
        k=10,
    )
    return dataset
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
# ---- Models ---------------------------------------------------------------- #
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
@pytest.fixture
def linear_regression_adaptive_model() -> (
    tuple[Callable[..., Any], list[tuple[list[str], Callable[..., jax.Array]]]]
):
    """Linear regression with an adaptive prior, plus its coefficient spec.

    Returns ``(model_fn, coef_groups)`` where each coef group pairs guide
    parameter names with a transform applied before RMSE comparison.
    """

    def model(x1: jax.Array, y: jax.Array | None = None) -> Any:
        mu = AdaptiveLayer()("beta", x1)
        return gaussian_link_exp(mu, y)

    coef_groups = [(["AdaptiveLayer_beta_beta"], identity)]
    return model, coef_groups
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
@pytest.fixture
def linear_regression_fixed_model() -> (
    tuple[Callable[..., Any], list[tuple[list[str], Callable[..., jax.Array]]]]
):
    """Linear regression with a fixed prior, plus its coefficient spec."""

    def model(x1: jax.Array, y: jax.Array | None = None) -> Any:
        mu = FixedPriorLayer()("beta", x1)
        return gaussian_link_exp(mu, y)

    coef_groups = [(["FixedPriorLayer_beta_beta"], identity)]
    return model, coef_groups
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@pytest.fixture
def fm_regression_model() -> (
    tuple[Callable[..., Any], list[tuple[list[str], Callable[..., jax.Array]]]]
):
    """Factorization-machine regression model, plus its coefficient spec.

    Recovered factors are compared via the upper-triangular (no diagonal)
    outer product, since the factorization itself is only identified up to
    rotation.
    """

    def model(x1: jax.Array, y: jax.Array | None = None) -> Any:
        mu = FMLayer(low_rank_dim=LOW_RANK_DIM)("theta", x1)
        return gaussian_link_exp(mu, y)

    coef_groups = [
        (["FMLayer_theta_theta"], outer_product_upper_tril_no_diag),
    ]
    return model, coef_groups
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@pytest.fixture
def emb_model() -> (
    tuple[Callable[..., Any], list[tuple[list[str], Callable[..., jax.Array]]]]
):
    """Embedding-layer regression model, plus its coefficient spec."""

    def model(x1: jax.Array, y: jax.Array | None = None) -> Any:
        mu = EmbeddingLayer()(
            "beta",
            x1,
            n_categories=N_EMB_CATEGORIES,
            embedding_dim=EMB_DIM,
        )
        return gaussian_link_exp(mu, y)

    coef_groups = [
        (["EmbeddingLayer_beta_beta"], identity),
    ]
    return model, coef_groups
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
@pytest.fixture
def lowrank_model() -> (
    tuple[Callable[..., Any], list[tuple[list[str], Callable[..., jax.Array]]]]
):
    """Low-rank (UV) interaction model over two inputs, plus its coef spec.

    The two factor matrices are compared jointly through their outer
    product, since each factor alone is only identified up to rotation.
    """

    def model(x1: jax.Array, x2: jax.Array, y: jax.Array | None = None) -> Any:
        mu = LowRankInteractionLayer(low_rank_dim=LOW_RANK_DIM)(
            "lowrank",
            x1,
            x2,
        )
        return gaussian_link_exp(mu, y)

    coef_groups = [
        (
            [
                "LowRankInteractionLayer_lowrank_theta1",
                "LowRankInteractionLayer_lowrank_theta2",
            ],
            outer_product,
        ),
    ]
    return model, coef_groups
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
# ---- Loss classes ---------------------------------------------------------- #
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@pytest.fixture
def trace_elbo() -> Trace_ELBO:
    """Plain (non-batched) ELBO loss instance."""
    loss = Trace_ELBO()
    return loss
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
@pytest.fixture
def trace_elbo_batched() -> Batched_Trace_ELBO:
    """Batched ELBO loss instance, scaled to the full dataset size."""
    loss = Batched_Trace_ELBO(n_obs=N_OBS)
    return loss
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
# ---- Dispatchers ----------------------------------------------------------- #
|
|
279
|
+
|
|
280
|
+
"""
|
|
281
|
+
These are pytest helpers that let us cycle through fixtures. This setup is a
|
|
282
|
+
little wonky and I'm sure we could come up with something better in the long
|
|
283
|
+
run, but it works for now. Just make one with the name for the thing you want
|
|
284
|
+
to pass to the ultimate test function.
|
|
285
|
+
"""
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
@pytest.fixture
def model(request: SubRequest) -> Any:
    """Indirect dispatcher: resolve a model fixture by name."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
@pytest.fixture
def data(request: SubRequest) -> Any:
    """Indirect dispatcher: resolve a dataset fixture by name."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
@pytest.fixture
def loss_instance(request: SubRequest) -> Any:
    """Indirect dispatcher: resolve a loss fixture by name."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
# ---- Test functions -------------------------------------------------------- #
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
@pytest.mark.parametrize(
    "loss_instance",
    [
        "trace_elbo",
        "trace_elbo_batched",
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    ("model", "data"),
    [
        ("linear_regression_adaptive_model", "simulated_data_simple"),
        ("linear_regression_fixed_model", "simulated_data_simple"),
        ("fm_regression_model", "simulated_data_fm"),
        ("emb_model", "simulated_data_emb"),
        ("lowrank_model", "simulated_data_lowrank"),
    ],
    indirect=True,
)
def test_models(
    data: Any,
    model: Any,
    loss_instance: Any,
) -> Any:
    """Fit each (model, dataset) pair with SVI and check parameter recovery.

    For every combination of loss (plain vs. batched ELBO) and model/data
    fixture pair, runs SVI with an AutoDiagonalNormal guide, then asserts
    that posterior-mean coefficients and sigma are close (RMSE) to the true
    simulated values.
    """
    model_fn, coef_groups = model
    # Keep only the sites the model function accepts as inputs/observations.
    model_data = {k: v for k, v in data.items() if k in ("y", "x1", "x2")}

    guide = AutoDiagonalNormal(model_fn)

    num_steps = 30000

    # One-cycle LR schedule: warm up for 10% of steps, peak at 5e-2.
    schedule = optax.cosine_onecycle_schedule(
        transition_steps=num_steps,
        peak_value=5e-2,
        pct_start=0.1,
        div_factor=25,
    )

    svi = SVI(model_fn, guide, optax.adam(schedule), loss=loss_instance)

    rng_key = random.PRNGKey(2)

    # NOTE(review): these two branches are not mutually exclusive — if
    # Batched_Trace_ELBO subclasses Trace_ELBO, both fits run and the
    # second result overwrites the first; if it subclasses neither,
    # svi_result would be unbound below. Confirm the class hierarchy.
    if isinstance(loss_instance, Trace_ELBO):
        svi_result = svi.run(
            rng_key,
            num_steps=num_steps,
            **model_data,
        )
    if isinstance(loss_instance, Batched_Trace_ELBO):
        # Batched path: BLayers' runner handles mini-batching without plate.
        svi_result = svi_run_batched(
            svi,
            rng_key,
            num_steps,
            batch_size=1000,
            **model_data,
        )
    # Draw posterior samples from the fitted guide (y excluded: it is the
    # observed site, not a model input here).
    guide_predicitive = Predictive(
        guide,
        params=svi_result.params,
        num_samples=1000,
    )
    guide_samples = guide_predicitive(
        random.PRNGKey(1),
        **{k: v for k, v in model_data.items() if k != "y"},
    )
    # Posterior means per latent site.
    guide_means = {k: jnp.mean(v, axis=0) for k, v in guide_samples.items()}

    # Each coef group maps guide param names -> transform; the true value's
    # key is the third "_"-separated token of the guide param name
    # (e.g. "AdaptiveLayer_beta_beta" -> "beta").
    for coef_list, coef_fn in coef_groups:
        with pytest_check.check:
            val = rmse(
                coef_fn(*[guide_means[x] for x in coef_list]),
                coef_fn(*[data[x.split("_")[2]] for x in coef_list]),
            )
            assert val < 0.1

    # Noise scale should be recovered tightly.
    with pytest_check.check:
        assert (
            rmse(
                guide_means["sigma"],
                data["sigma"],
            )
            < 0.03
        )
|