arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
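The file-level listing above can be reproduced from the two wheels themselves. The sketch below is illustrative only and is not part of either package; it assumes both wheel files have already been downloaded to the working directory under their standard filenames.

import zipfile

def list_removed_and_added(old_wheel: str, new_wheel: str) -> None:
    """Print which archive members exist only in one of the two wheels."""
    with zipfile.ZipFile(old_wheel) as old, zipfile.ZipFile(new_wheel) as new:
        old_names, new_names = set(old.namelist()), set(new.namelist())
        for name in sorted(old_names - new_names):
            print("removed:", name)
        for name in sorted(new_names - old_names):
            print("added:  ", name)

# assumed local filenames for the two releases compared above
list_removed_and_added("arviz-0.23.3-py3-none-any.whl", "arviz-1.0.0rc0-py3-none-any.whl")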
arviz/tests/external_tests/test_data_numpyro.py (deleted)
@@ -1,434 +0,0 @@
-# pylint: disable=no-member, invalid-name, redefined-outer-name, too-many-public-methods
-from collections import namedtuple
-import numpy as np
-import pytest
-
-from ...data.io_numpyro import from_numpyro  # pylint: disable=wrong-import-position
-from ..helpers import (  # pylint: disable=unused-import, wrong-import-position
-    chains,
-    check_multiple_attrs,
-    draws,
-    eight_schools_params,
-    importorskip,
-    load_cached_models,
-)
-
-# Skip all tests if jax or numpyro not installed
-jax = importorskip("jax")
-PRNGKey = jax.random.PRNGKey
-numpyro = importorskip("numpyro")
-Predictive = numpyro.infer.Predictive
-
-
-class TestDataNumPyro:
-    @pytest.fixture(scope="class")
-    def data(self, eight_schools_params, draws, chains):
-        class Data:
-            obj = load_cached_models(eight_schools_params, draws, chains, "numpyro")["numpyro"]
-
-        return Data
-
-    @pytest.fixture(scope="class")
-    def predictions_params(self):
-        """Predictions data for eight schools."""
-        return {
-            "J": 8,
-            "sigma": np.array([5.0, 7.0, 12.0, 4.0, 6.0, 10.0, 3.0, 9.0]),
-        }
-
-    @pytest.fixture(scope="class")
-    def predictions_data(self, data, predictions_params):
-        """Generate predictions for predictions_params"""
-        posterior_samples = data.obj.get_samples()
-        model = data.obj.sampler.model
-        predictions = Predictive(model, posterior_samples)(
-            PRNGKey(2), predictions_params["J"], predictions_params["sigma"]
-        )
-        return predictions
-
-    def get_inference_data(
-        self, data, eight_schools_params, predictions_data, predictions_params, infer_dims=False
-    ):
-        posterior_samples = data.obj.get_samples()
-        model = data.obj.sampler.model
-        posterior_predictive = Predictive(model, posterior_samples)(
-            PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"]
-        )
-        prior = Predictive(model, num_samples=500)(
-            PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
-        )
-        dims = {"theta": ["school"], "eta": ["school"], "obs": ["school"]}
-        pred_dims = {"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]}
-        if infer_dims:
-            dims = None
-            pred_dims = None
-
-        predictions = predictions_data
-        return from_numpyro(
-            posterior=data.obj,
-            prior=prior,
-            posterior_predictive=posterior_predictive,
-            predictions=predictions,
-            coords={
-                "school": np.arange(eight_schools_params["J"]),
-                "school_pred": np.arange(predictions_params["J"]),
-            },
-            dims=dims,
-            pred_dims=pred_dims,
-        )
-
-    def test_inference_data_namedtuple(self, data):
-        samples = data.obj.get_samples()
-        Samples = namedtuple("Samples", samples)
-        data_namedtuple = Samples(**samples)
-        _old_fn = data.obj.get_samples
-        data.obj.get_samples = lambda *args, **kwargs: data_namedtuple
-        inference_data = from_numpyro(
-            posterior=data.obj,
-            dims={},  # This mock test needs to turn off autodims like so or mock group_by_chain
-        )
-        assert isinstance(data.obj.get_samples(), Samples)
-        data.obj.get_samples = _old_fn
-        for key in samples:
-            assert key in inference_data.posterior
-
-    def test_inference_data(self, data, eight_schools_params, predictions_data, predictions_params):
-        inference_data = self.get_inference_data(
-            data, eight_schools_params, predictions_data, predictions_params
-        )
-        test_dict = {
-            "posterior": ["mu", "tau", "eta"],
-            "sample_stats": ["diverging"],
-            "log_likelihood": ["obs"],
-            "posterior_predictive": ["obs"],
-            "predictions": ["obs"],
-            "prior": ["mu", "tau", "eta"],
-            "prior_predictive": ["obs"],
-            "observed_data": ["obs"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails
-
-        # test dims
-        dims = inference_data.posterior_predictive.sizes["school"]
-        pred_dims = inference_data.predictions.sizes["school_pred"]
-        assert dims == 8
-        assert pred_dims == 8
-
-    def test_inference_data_no_posterior(
-        self, data, eight_schools_params, predictions_data, predictions_params
-    ):
-        posterior_samples = data.obj.get_samples()
-        model = data.obj.sampler.model
-        posterior_predictive = Predictive(model, posterior_samples)(
-            PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"]
-        )
-        prior = Predictive(model, num_samples=500)(
-            PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
-        )
-        predictions = predictions_data
-        constant_data = {"J": 8, "sigma": eight_schools_params["sigma"]}
-        predictions_constant_data = predictions_params
-        # only prior
-        inference_data = from_numpyro(prior=prior)
-        test_dict = {"prior": ["mu", "tau", "eta"]}
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails, f"only prior: {fails}"
-        # only posterior_predictive
-        inference_data = from_numpyro(posterior_predictive=posterior_predictive)
-        test_dict = {"posterior_predictive": ["obs"]}
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails, f"only posterior_predictive: {fails}"
-        # only predictions
-        inference_data = from_numpyro(predictions=predictions)
-        test_dict = {"predictions": ["obs"]}
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails, f"only predictions: {fails}"
-        # only constant_data
-        inference_data = from_numpyro(constant_data=constant_data)
-        test_dict = {"constant_data": ["J", "sigma"]}
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails, f"only constant_data: {fails}"
-        # only predictions_constant_data
-        inference_data = from_numpyro(predictions_constant_data=predictions_constant_data)
-        test_dict = {"predictions_constant_data": ["J", "sigma"]}
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails, f"only predictions_constant_data: {fails}"
-        # prior and posterior_predictive
-        idata = from_numpyro(
-            prior=prior,
-            posterior_predictive=posterior_predictive,
-            coords={"school": np.arange(eight_schools_params["J"])},
-            dims={"theta": ["school"], "eta": ["school"]},
-        )
-        test_dict = {"posterior_predictive": ["obs"], "prior": ["mu", "tau", "eta", "obs"]}
-        fails = check_multiple_attrs(test_dict, idata)
-        assert not fails, f"prior and posterior_predictive: {fails}"
-
-    def test_inference_data_only_posterior(self, data):
-        idata = from_numpyro(data.obj)
-        test_dict = {
-            "posterior": ["mu", "tau", "eta"],
-            "sample_stats": ["diverging"],
-            "log_likelihood": ["obs"],
-        }
-        fails = check_multiple_attrs(test_dict, idata)
-        assert not fails
-
-    def test_multiple_observed_rv(self):
-        import numpyro
-        import numpyro.distributions as dist
-        from numpyro.infer import MCMC, NUTS
-
-        y1 = np.random.randn(10)
-        y2 = np.random.randn(100)
-
-        def model_example_multiple_obs(y1=None, y2=None):
-            x = numpyro.sample("x", dist.Normal(1, 3))
-            numpyro.sample("y1", dist.Normal(x, 1), obs=y1)
-            numpyro.sample("y2", dist.Normal(x, 1), obs=y2)
-
-        nuts_kernel = NUTS(model_example_multiple_obs)
-        mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
-        mcmc.run(PRNGKey(0), y1=y1, y2=y2)
-        inference_data = from_numpyro(mcmc)
-        test_dict = {
-            "posterior": ["x"],
-            "sample_stats": ["diverging"],
-            "log_likelihood": ["y1", "y2"],
-            "observed_data": ["y1", "y2"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data)
-        # from ..stats import waic
-        # waic_results = waic(inference_data)
-        # print(waic_results)
-        # print(waic_results.keys())
-        # print(waic_results.waic, waic_results.waic_se)
-        assert not fails
-        assert not hasattr(inference_data.sample_stats, "log_likelihood")
-
-    def test_inference_data_constant_data(self):
-        import numpyro
-        import numpyro.distributions as dist
-        from numpyro.infer import MCMC, NUTS
-
-        x1 = 10
-        x2 = 12
-        y1 = np.random.randn(10)
-
-        def model_constant_data(x, y1=None):
-            _x = numpyro.sample("x", dist.Normal(1, 3))
-            numpyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)
-
-        nuts_kernel = NUTS(model_constant_data)
-        mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
-        mcmc.run(PRNGKey(0), x=x1, y1=y1)
-        posterior = mcmc.get_samples()
-        posterior_predictive = Predictive(model_constant_data, posterior)(PRNGKey(1), x1)
-        predictions = Predictive(model_constant_data, posterior)(PRNGKey(2), x2)
-        inference_data = from_numpyro(
-            mcmc,
-            posterior_predictive=posterior_predictive,
-            predictions=predictions,
-            constant_data={"x1": x1},
-            predictions_constant_data={"x2": x2},
-        )
-        test_dict = {
-            "posterior": ["x"],
-            "posterior_predictive": ["y1"],
-            "sample_stats": ["diverging"],
-            "log_likelihood": ["y1"],
-            "predictions": ["y1"],
-            "observed_data": ["y1"],
-            "constant_data": ["x1"],
-            "predictions_constant_data": ["x2"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails
-
-    def test_inference_data_num_chains(self, predictions_data, chains):
-        predictions = predictions_data
-        inference_data = from_numpyro(predictions=predictions, num_chains=chains)
-        nchains = inference_data.predictions.sizes["chain"]
-        assert nchains == chains
-
-    @pytest.mark.parametrize("nchains", [1, 2])
-    @pytest.mark.parametrize("thin", [1, 2, 3, 5, 10])
-    def test_mcmc_with_thinning(self, nchains, thin):
-        import numpyro
-        import numpyro.distributions as dist
-        from numpyro.infer import MCMC, NUTS
-
-        x = np.random.normal(10, 3, size=100)
-
-        def model(x):
-            numpyro.sample(
-                "x",
-                dist.Normal(
-                    numpyro.sample("loc", dist.Uniform(0, 20)),
-                    numpyro.sample("scale", dist.Uniform(0, 20)),
-                ),
-                obs=x,
-            )
-
-        nuts_kernel = NUTS(model)
-        mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=400, num_chains=nchains, thinning=thin)
-        mcmc.run(PRNGKey(0), x=x)
-
-        inference_data = from_numpyro(mcmc)
-        assert inference_data.posterior["loc"].shape == (nchains, 400 // thin)
-
-    def test_mcmc_improper_uniform(self):
-        import numpyro
-        import numpyro.distributions as dist
-        from numpyro.infer import MCMC, NUTS
-
-        def model():
-            x = numpyro.sample("x", dist.ImproperUniform(dist.constraints.positive, (), ()))
-            return numpyro.sample("y", dist.Normal(x, 1), obs=1.0)
-
-        mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
-        mcmc.run(PRNGKey(0))
-        inference_data = from_numpyro(mcmc)
-        assert inference_data.observed_data
-
-    def test_mcmc_infer_dims(self):
-        import numpyro
-        import numpyro.distributions as dist
-        from numpyro.infer import MCMC, NUTS
-
-        def model():
-            # note: group2 gets assigned dim=-1 and group1 is assigned dim=-2
-            with numpyro.plate("group2", 5), numpyro.plate("group1", 10):
-                _ = numpyro.sample("param", dist.Normal(0, 1))
-
-        mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
-        mcmc.run(PRNGKey(0))
-        inference_data = from_numpyro(
-            mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
-        )
-        assert inference_data.posterior.param.dims == ("chain", "draw", "group1", "group2")
-        assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
-
-    def test_mcmc_infer_unsorted_dims(self):
-        import numpyro
-        import numpyro.distributions as dist
-        from numpyro.infer import MCMC, NUTS
-
-        def model():
-            group1_plate = numpyro.plate("group1", 10, dim=-1)
-            group2_plate = numpyro.plate("group2", 5, dim=-2)
-
-            # the plate contexts are entered in a different order than the pre-defined dims
-            # we should make sure this still works because the trace has all of the info it needs
-            with group2_plate, group1_plate:
-                _ = numpyro.sample("param", dist.Normal(0, 1))
-
-        mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
-        mcmc.run(PRNGKey(0))
-        inference_data = from_numpyro(
-            mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
-        )
-        assert inference_data.posterior.param.dims == ("chain", "draw", "group2", "group1")
-        assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
-
-    def test_mcmc_infer_dims_no_coords(self):
-        import numpyro
-        import numpyro.distributions as dist
-        from numpyro.infer import MCMC, NUTS
-
-        def model():
-            with numpyro.plate("group", 5):
-                _ = numpyro.sample("param", dist.Normal(0, 1))
-
-        mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
-        mcmc.run(PRNGKey(0))
-        inference_data = from_numpyro(mcmc)
-        assert inference_data.posterior.param.dims == ("chain", "draw", "group")
-
-    def test_mcmc_event_dims(self):
-        import numpyro
-        import numpyro.distributions as dist
-        from numpyro.infer import MCMC, NUTS
-
-        def model():
-            _ = numpyro.sample(
-                "gamma", dist.ZeroSumNormal(1, event_shape=(10,)), infer={"event_dims": ["groups"]}
-            )
-
-        mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
-        mcmc.run(PRNGKey(0))
-        inference_data = from_numpyro(mcmc, coords={"groups": np.arange(10)})
-        assert inference_data.posterior.gamma.dims == ("chain", "draw", "groups")
-        assert "groups" in inference_data.posterior.gamma.coords
-
-    @pytest.mark.xfail
-    def test_mcmc_inferred_dims_univariate(self):
-        import numpyro
-        import numpyro.distributions as dist
-        from numpyro.infer import MCMC, NUTS
-        import jax.numpy as jnp
-
-        def model():
-            alpha = numpyro.sample("alpha", dist.Normal(0, 1))
-            sigma = numpyro.sample("sigma", dist.HalfNormal(1))
-            with numpyro.plate("obs_idx", 3):
-                # mu is plated by obs_idx, but isnt broadcasted to the plate shape
-                # the expected behavior is that this should cause a failure
-                mu = numpyro.deterministic("mu", alpha)
-                return numpyro.sample("y", dist.Normal(mu, sigma), obs=jnp.array([-1, 0, 1]))
-
-        mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
-        mcmc.run(PRNGKey(0))
-        inference_data = from_numpyro(mcmc, coords={"obs_idx": np.arange(3)})
-        assert inference_data.posterior.mu.dims == ("chain", "draw", "obs_idx")
-        assert "obs_idx" in inference_data.posterior.mu.coords
-
-    def test_mcmc_extra_event_dims(self):
-        import numpyro
-        import numpyro.distributions as dist
-        from numpyro.infer import MCMC, NUTS
-
-        def model():
-            gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(10,)))
-            _ = numpyro.deterministic("gamma_plus1", gamma + 1)
-
-        mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
-        mcmc.run(PRNGKey(0))
-        inference_data = from_numpyro(
-            mcmc, coords={"groups": np.arange(10)}, extra_event_dims={"gamma_plus1": ["groups"]}
-        )
-        assert inference_data.posterior.gamma_plus1.dims == ("chain", "draw", "groups")
-        assert "groups" in inference_data.posterior.gamma_plus1.coords
-
-    def test_mcmc_predictions_infer_dims(
-        self, data, eight_schools_params, predictions_data, predictions_params
-    ):
-        inference_data = self.get_inference_data(
-            data, eight_schools_params, predictions_data, predictions_params, infer_dims=True
-        )
-        assert inference_data.predictions.obs.dims == ("chain", "draw", "J")
-        assert "J" in inference_data.predictions.obs.coords
-
-    def test_potential_energy_sign_conversion(self):
-        """Test that potential energy is converted to log probability (lp) with correct sign."""
-        import numpyro
-        import numpyro.distributions as dist
-        from numpyro.infer import MCMC, NUTS
-
-        num_samples = 10
-
-        def simple_model():
-            numpyro.sample("x", dist.Normal(0, 1))
-
-        nuts_kernel = NUTS(simple_model)
-        mcmc = MCMC(nuts_kernel, num_samples=num_samples, num_warmup=5)
-        mcmc.run(PRNGKey(0), extra_fields=["potential_energy"])
-
-        # Get the raw extra fields from NumPyro
-        extra_fields = mcmc.get_extra_fields(group_by_chain=True)
-        # Convert to ArviZ InferenceData
-        inference_data = from_numpyro(mcmc)
-        arviz_lp = inference_data["sample_stats"]["lp"].values
-
-        np.testing.assert_array_equal(arviz_lp, -extra_fields["potential_energy"])
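The removed module above is the 0.23.3 test suite for the from_numpyro converter. As a compact reference for the call pattern those tests exercise, here is a minimal sketch distilled from test_multiple_observed_rv, using the top-level arviz namespace available in 0.23.3; the model and variable names are illustrative, not part of the package.

import numpy as np
import numpyro
import numpyro.distributions as dist
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS

from arviz import from_numpyro

y1 = np.random.randn(10)

def model(y1=None):
    # one latent variable and one observed site, as in the removed test
    x = numpyro.sample("x", dist.Normal(1, 3))
    numpyro.sample("y1", dist.Normal(x, 1), obs=y1)

mcmc = MCMC(NUTS(model), num_warmup=2, num_samples=10)
mcmc.run(PRNGKey(0), y1=y1)

# the tests above check that this produces posterior, sample_stats,
# log_likelihood and observed_data groups
idata = from_numpyro(mcmc)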
arviz/tests/external_tests/test_data_pyjags.py (deleted)
@@ -1,119 +0,0 @@
-# pylint: disable=no-member, invalid-name, redefined-outer-name, unused-import
-import sys
-import typing as tp
-
-import numpy as np
-import pytest
-
-from ... import InferenceData, from_pyjags, waic
-from ...data.io_pyjags import (
-    _convert_arviz_dict_to_pyjags_dict,
-    _convert_pyjags_dict_to_arviz_dict,
-    _extract_arviz_dict_from_inference_data,
-)
-from ..helpers import check_multiple_attrs, eight_schools_params
-
-pytest.skip("Uses deprecated numpy C-api", allow_module_level=True)
-
-PYJAGS_POSTERIOR_DICT = {
-    "b": np.random.randn(3, 10, 3),
-    "int": np.random.randn(1, 10, 3),
-    "log_like": np.random.randn(1, 10, 3),
-}
-PYJAGS_PRIOR_DICT = {"b": np.random.randn(3, 10, 3), "int": np.random.randn(1, 10, 3)}
-
-
-PARAMETERS = ("mu", "tau", "theta_tilde")
-VARIABLES = tuple(list(PARAMETERS) + ["log_like"])
-
-NUMBER_OF_WARMUP_SAMPLES = 1000
-NUMBER_OF_POST_WARMUP_SAMPLES = 5000
-
-
-def verify_equality_of_numpy_values_dictionaries(
-    dict_1: tp.Mapping[tp.Any, np.ndarray], dict_2: tp.Mapping[tp.Any, np.ndarray]
-) -> bool:
-    if dict_1.keys() != dict_2.keys():
-        return False
-
-    for key in dict_1.keys():
-        if not np.all(dict_1[key] == dict_2[key]):
-            return False
-
-    return True
-
-
-class TestDataPyJAGSWithoutEstimation:
-    def test_convert_pyjags_samples_dictionary_to_arviz_samples_dictionary(self):
-        arviz_samples_dict_from_pyjags_samples_dict = _convert_pyjags_dict_to_arviz_dict(
-            PYJAGS_POSTERIOR_DICT
-        )
-
-        pyjags_dict_from_arviz_dict_from_pyjags_dict = _convert_arviz_dict_to_pyjags_dict(
-            arviz_samples_dict_from_pyjags_samples_dict
-        )
-
-        assert verify_equality_of_numpy_values_dictionaries(
-            PYJAGS_POSTERIOR_DICT,
-            pyjags_dict_from_arviz_dict_from_pyjags_dict,
-        )
-
-    def test_extract_samples_dictionary_from_arviz_inference_data(self):
-        arviz_samples_dict_from_pyjags_samples_dict = _convert_pyjags_dict_to_arviz_dict(
-            PYJAGS_POSTERIOR_DICT
-        )
-
-        arviz_inference_data_from_pyjags_samples_dict = from_pyjags(PYJAGS_POSTERIOR_DICT)
-        arviz_dict_from_idata_from_pyjags_dict = _extract_arviz_dict_from_inference_data(
-            arviz_inference_data_from_pyjags_samples_dict
-        )
-
-        assert verify_equality_of_numpy_values_dictionaries(
-            arviz_samples_dict_from_pyjags_samples_dict,
-            arviz_dict_from_idata_from_pyjags_dict,
-        )
-
-    def test_roundtrip_from_pyjags_via_arviz_to_pyjags(self):
-        arviz_inference_data_from_pyjags_samples_dict = from_pyjags(PYJAGS_POSTERIOR_DICT)
-        arviz_dict_from_idata_from_pyjags_dict = _extract_arviz_dict_from_inference_data(
-            arviz_inference_data_from_pyjags_samples_dict
-        )
-
-        pyjags_dict_from_arviz_idata = _convert_arviz_dict_to_pyjags_dict(
-            arviz_dict_from_idata_from_pyjags_dict
-        )
-
-        assert verify_equality_of_numpy_values_dictionaries(
-            PYJAGS_POSTERIOR_DICT, pyjags_dict_from_arviz_idata
-        )
-
-    @pytest.mark.parametrize("posterior", [None, PYJAGS_POSTERIOR_DICT])
-    @pytest.mark.parametrize("prior", [None, PYJAGS_PRIOR_DICT])
-    @pytest.mark.parametrize("save_warmup", [True, False])
-    @pytest.mark.parametrize("warmup_iterations", [0, 5])
-    def test_inference_data_attrs(self, posterior, prior, save_warmup, warmup_iterations: int):
-        arviz_inference_data_from_pyjags_samples_dict = from_pyjags(
-            posterior=posterior,
-            prior=prior,
-            log_likelihood={"y": "log_like"},
-            save_warmup=save_warmup,
-            warmup_iterations=warmup_iterations,
-        )
-        posterior_warmup_prefix = (
-            "" if save_warmup and warmup_iterations > 0 and posterior is not None else "~"
-        )
-        prior_warmup_prefix = (
-            "" if save_warmup and warmup_iterations > 0 and prior is not None else "~"
-        )
-        print(f'posterior_warmup_prefix="{posterior_warmup_prefix}"')
-        test_dict = {
-            f'{"~" if posterior is None else ""}posterior': ["b", "int"],
-            f'{"~" if prior is None else ""}prior': ["b", "int"],
-            f'{"~" if posterior is None else ""}log_likelihood': ["y"],
-            f"{posterior_warmup_prefix}warmup_posterior": ["b", "int"],
-            f"{prior_warmup_prefix}warmup_prior": ["b", "int"],
-            f"{posterior_warmup_prefix}warmup_log_likelihood": ["y"],
-        }
-
-        fails = check_multiple_attrs(test_dict, arviz_inference_data_from_pyjags_samples_dict)
-        assert not fails
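Similarly, the removed PyJAGS module tests from_pyjags on plain dictionaries of NumPy arrays. A minimal sketch of that call pattern, reusing the array shapes and keyword arguments from the tests above and the 0.23.3 top-level import; the variable names are illustrative only.

import numpy as np
from arviz import from_pyjags

# dictionaries of samples shaped like PYJAGS_POSTERIOR_DICT / PYJAGS_PRIOR_DICT above
posterior = {
    "b": np.random.randn(3, 10, 3),
    "int": np.random.randn(1, 10, 3),
    "log_like": np.random.randn(1, 10, 3),
}
prior = {"b": np.random.randn(3, 10, 3), "int": np.random.randn(1, 10, 3)}

idata = from_pyjags(
    posterior=posterior,
    prior=prior,
    log_likelihood={"y": "log_like"},  # map the observed name to the log-likelihood variable
)
print(idata)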