arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registry.
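To reproduce a comparison like this locally, a minimal sketch is shown below. It is an assumption about how one might check the diff rather than part of the registry tooling; the version pins come from the header above and the wheel filename patterns are inferred from them. The sketch downloads both wheels from PyPI with pip and lists the archive members removed between the two versions.

# Sketch (assumption, not the registry's own tooling): download both wheels
# and report which archive members disappeared between the two releases.
import pathlib
import subprocess
import sys
import zipfile

# Exact pre-release pins (e.g. ==1.0.0rc0) are allowed to resolve to
# pre-releases, so no --pre flag is needed here.
for spec in ("arviz==0.23.1", "arviz==1.0.0rc0"):
    subprocess.run(
        [sys.executable, "-m", "pip", "download", "--no-deps",
         "--only-binary", ":all:", "-d", ".", spec],
        check=True,
    )

# Locate the downloaded wheel files by version prefix.
old_whl = next(pathlib.Path(".").glob("arviz-0.23.1-*.whl"))
new_whl = next(pathlib.Path(".").glob("arviz-1.0.0rc0-*.whl"))

old_members = set(zipfile.ZipFile(old_whl).namelist())
new_members = set(zipfile.ZipFile(new_whl).namelist())

print(f"removed: {len(old_members - new_members)} files, "
      f"added: {len(new_members - old_members)} files")
for name in sorted(old_members - new_members):
    print("-", name)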
- arviz/__init__.py +52 -357
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.1.dist-info/METADATA +0 -263
- arviz-0.23.1.dist-info/RECORD +0 -183
- arviz-0.23.1.dist-info/top_level.txt +0 -1
--- arviz/tests/external_tests/test_data_pyro.py
+++ /dev/null
@@ -1,260 +0,0 @@
-# pylint: disable=no-member, invalid-name, redefined-outer-name
-import numpy as np
-import packaging
-import pytest
-
-from ...data.io_pyro import from_pyro  # pylint: disable=wrong-import-position
-from ..helpers import (  # pylint: disable=unused-import, wrong-import-position
-    chains,
-    check_multiple_attrs,
-    draws,
-    eight_schools_params,
-    importorskip,
-    load_cached_models,
-)
-
-# Skip all tests if pyro or pytorch not installed
-torch = importorskip("torch")
-pyro = importorskip("pyro")
-Predictive = pyro.infer.Predictive
-dist = pyro.distributions
-
-
-class TestDataPyro:
-    @pytest.fixture(scope="class")
-    def data(self, eight_schools_params, draws, chains):
-        class Data:
-            obj = load_cached_models(eight_schools_params, draws, chains, "pyro")["pyro"]
-
-        return Data
-
-    @pytest.fixture(scope="class")
-    def predictions_params(self):
-        """Predictions data for eight schools."""
-        return {
-            "J": 8,
-            "sigma": np.array([5.0, 7.0, 12.0, 4.0, 6.0, 10.0, 3.0, 9.0]),
-        }
-
-    @pytest.fixture(scope="class")
-    def predictions_data(self, data, predictions_params):
-        """Generate predictions for predictions_params"""
-        posterior_samples = data.obj.get_samples()
-        model = data.obj.kernel.model
-        predictions = Predictive(model, posterior_samples)(
-            predictions_params["J"], torch.from_numpy(predictions_params["sigma"]).float()
-        )
-        return predictions
-
-    def get_inference_data(self, data, eight_schools_params, predictions_data):
-        posterior_samples = data.obj.get_samples()
-        model = data.obj.kernel.model
-        posterior_predictive = Predictive(model, posterior_samples)(
-            eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
-        )
-        prior = Predictive(model, num_samples=500)(
-            eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
-        )
-        predictions = predictions_data
-        return from_pyro(
-            posterior=data.obj,
-            prior=prior,
-            posterior_predictive=posterior_predictive,
-            predictions=predictions,
-            coords={
-                "school": np.arange(eight_schools_params["J"]),
-                "school_pred": np.arange(eight_schools_params["J"]),
-            },
-            dims={"theta": ["school"], "eta": ["school"], "obs": ["school"]},
-            pred_dims={"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]},
-        )
-
-    def test_inference_data(self, data, eight_schools_params, predictions_data):
-        inference_data = self.get_inference_data(data, eight_schools_params, predictions_data)
-        test_dict = {
-            "posterior": ["mu", "tau", "eta"],
-            "sample_stats": ["diverging"],
-            "posterior_predictive": ["obs"],
-            "predictions": ["obs"],
-            "prior": ["mu", "tau", "eta"],
-            "prior_predictive": ["obs"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails
-
-        # test dims
-        dims = inference_data.posterior_predictive.sizes["school"]
-        pred_dims = inference_data.predictions.sizes["school_pred"]
-        assert dims == 8
-        assert pred_dims == 8
-
-    @pytest.mark.skipif(
-        packaging.version.parse(pyro.__version__) < packaging.version.parse("1.0.0"),
-        reason="requires pyro 1.0.0 or higher",
-    )
-    def test_inference_data_has_log_likelihood_and_observed_data(self, data):
-        idata = from_pyro(data.obj)
-        test_dict = {"log_likelihood": ["obs"], "observed_data": ["obs"]}
-        fails = check_multiple_attrs(test_dict, idata)
-        assert not fails
-
-    def test_inference_data_no_posterior(
-        self, data, eight_schools_params, predictions_data, predictions_params
-    ):
-        posterior_samples = data.obj.get_samples()
-        model = data.obj.kernel.model
-        posterior_predictive = Predictive(model, posterior_samples)(
-            eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
-        )
-        prior = Predictive(model, num_samples=500)(
-            eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
-        )
-        predictions = predictions_data
-        constant_data = {"J": 8, "sigma": eight_schools_params["sigma"]}
-        predictions_constant_data = predictions_params
-        # only prior
-        inference_data = from_pyro(prior=prior)
-        test_dict = {"prior": ["mu", "tau", "eta"]}
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails, f"only prior: {fails}"
-        # only posterior_predictive
-        inference_data = from_pyro(posterior_predictive=posterior_predictive)
-        test_dict = {"posterior_predictive": ["obs"]}
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails, f"only posterior_predictive: {fails}"
-        # only predictions
-        inference_data = from_pyro(predictions=predictions)
-        test_dict = {"predictions": ["obs"]}
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails, f"only predictions: {fails}"
-        # only constant_data
-        inference_data = from_pyro(constant_data=constant_data)
-        test_dict = {"constant_data": ["J", "sigma"]}
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails, f"only constant_data: {fails}"
-        # only predictions_constant_data
-        inference_data = from_pyro(predictions_constant_data=predictions_constant_data)
-        test_dict = {"predictions_constant_data": ["J", "sigma"]}
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails, f"only predictions_constant_data: {fails}"
-        # prior and posterior_predictive
-        idata = from_pyro(
-            prior=prior,
-            posterior_predictive=posterior_predictive,
-            coords={"school": np.arange(eight_schools_params["J"])},
-            dims={"theta": ["school"], "eta": ["school"]},
-        )
-        test_dict = {"posterior_predictive": ["obs"], "prior": ["mu", "tau", "eta", "obs"]}
-        fails = check_multiple_attrs(test_dict, idata)
-        assert not fails, f"prior and posterior_predictive: {fails}"
-
-    def test_inference_data_only_posterior(self, data):
-        idata = from_pyro(data.obj)
-        test_dict = {"posterior": ["mu", "tau", "eta"], "sample_stats": ["diverging"]}
-        fails = check_multiple_attrs(test_dict, idata)
-        assert not fails
-
-    @pytest.mark.skipif(
-        packaging.version.parse(pyro.__version__) < packaging.version.parse("1.0.0"),
-        reason="requires pyro 1.0.0 or higher",
-    )
-    def test_inference_data_only_posterior_has_log_likelihood(self, data):
-        idata = from_pyro(data.obj)
-        test_dict = {"log_likelihood": ["obs"]}
-        fails = check_multiple_attrs(test_dict, idata)
-        assert not fails
-
-    def test_multiple_observed_rv(self):
-        y1 = torch.randn(10)
-        y2 = torch.randn(10)
-
-        def model_example_multiple_obs(y1=None, y2=None):
-            x = pyro.sample("x", dist.Normal(1, 3))
-            pyro.sample("y1", dist.Normal(x, 1), obs=y1)
-            pyro.sample("y2", dist.Normal(x, 1), obs=y2)
-
-        nuts_kernel = pyro.infer.NUTS(model_example_multiple_obs)
-        mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=10)
-        mcmc.run(y1=y1, y2=y2)
-        inference_data = from_pyro(mcmc)
-        test_dict = {
-            "posterior": ["x"],
-            "sample_stats": ["diverging"],
-            "log_likelihood": ["y1", "y2"],
-            "observed_data": ["y1", "y2"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails
-        assert not hasattr(inference_data.sample_stats, "log_likelihood")
-
-    def test_inference_data_constant_data(self):
-        x1 = 10
-        x2 = 12
-        y1 = torch.randn(10)
-
-        def model_constant_data(x, y1=None):
-            _x = pyro.sample("x", dist.Normal(1, 3))
-            pyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)
-
-        nuts_kernel = pyro.infer.NUTS(model_constant_data)
-        mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=10)
-        mcmc.run(x=x1, y1=y1)
-        posterior = mcmc.get_samples()
-        posterior_predictive = Predictive(model_constant_data, posterior)(x1)
-        predictions = Predictive(model_constant_data, posterior)(x2)
-        inference_data = from_pyro(
-            mcmc,
-            posterior_predictive=posterior_predictive,
-            predictions=predictions,
-            constant_data={"x1": x1},
-            predictions_constant_data={"x2": x2},
-        )
-        test_dict = {
-            "posterior": ["x"],
-            "posterior_predictive": ["y1"],
-            "sample_stats": ["diverging"],
-            "log_likelihood": ["y1"],
-            "predictions": ["y1"],
-            "observed_data": ["y1"],
-            "constant_data": ["x1"],
-            "predictions_constant_data": ["x2"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails
-
-    def test_inference_data_num_chains(self, predictions_data, chains):
-        predictions = predictions_data
-        inference_data = from_pyro(predictions=predictions, num_chains=chains)
-        nchains = inference_data.predictions.sizes["chain"]
-        assert nchains == chains
-
-    @pytest.mark.parametrize("log_likelihood", [True, False])
-    def test_log_likelihood(self, log_likelihood):
-        """Test behaviour when log likelihood cannot be retrieved.
-
-        If log_likelihood=True there is a warning to say log_likelihood group is skipped,
-        if log_likelihood=False there is no warning and log_likelihood is skipped.
-        """
-        x = torch.randn((10, 2))
-        y = torch.randn(10)
-
-        def model_constant_data(x, y=None):
-            beta = pyro.sample("beta", dist.Normal(torch.ones(2), 3))
-            pyro.sample("y", dist.Normal(x.matmul(beta), 1), obs=y)
-
-        nuts_kernel = pyro.infer.NUTS(model_constant_data)
-        mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=10)
-        mcmc.run(x=x, y=y)
-        if log_likelihood:
-            with pytest.warns(UserWarning, match="Could not get vectorized trace"):
-                inference_data = from_pyro(mcmc, log_likelihood=log_likelihood)
-        else:
-            inference_data = from_pyro(mcmc, log_likelihood=log_likelihood)
-        test_dict = {
-            "posterior": ["beta"],
-            "sample_stats": ["diverging"],
-            "~log_likelihood": [""],
-            "observed_data": ["y"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails
--- arviz/tests/external_tests/test_data_pystan.py
+++ /dev/null
@@ -1,307 +0,0 @@
-# pylint: disable=no-member, invalid-name, redefined-outer-name, too-many-function-args
-import importlib
-from collections import OrderedDict
-import os
-
-import numpy as np
-import pytest
-
-from ... import from_pystan
-
-from ...data.io_pystan import get_draws, get_draws_stan3  # pylint: disable=unused-import
-from ..helpers import (  # pylint: disable=unused-import
-    chains,
-    check_multiple_attrs,
-    draws,
-    eight_schools_params,
-    importorskip,
-    load_cached_models,
-    pystan_version,
-)
-
-# Check if either pystan or pystan3 is installed
-pystan_installed = (importlib.util.find_spec("pystan") is not None) or (
-    importlib.util.find_spec("stan") is not None
-)
-
-
-@pytest.mark.skipif(
-    not (pystan_installed or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
-    reason="test requires pystan/pystan3 which is not installed",
-)
-class TestDataPyStan:
-    @pytest.fixture(scope="class")
-    def data(self, eight_schools_params, draws, chains):
-        class Data:
-            model, obj = load_cached_models(eight_schools_params, draws, chains, "pystan")["pystan"]
-
-        return Data
-
-    def get_inference_data(self, data, eight_schools_params):
-        """vars as str."""
-        return from_pystan(
-            posterior=data.obj,
-            posterior_predictive="y_hat",
-            predictions="y_hat",  # wrong, but fine for testing
-            prior=data.obj,
-            prior_predictive="y_hat",
-            observed_data="y",
-            constant_data="sigma",
-            predictions_constant_data="sigma",  # wrong, but fine for testing
-            log_likelihood={"y": "log_lik"},
-            coords={"school": np.arange(eight_schools_params["J"])},
-            dims={
-                "theta": ["school"],
-                "y": ["school"],
-                "sigma": ["school"],
-                "y_hat": ["school"],
-                "eta": ["school"],
-            },
-            posterior_model=data.model,
-            prior_model=data.model,
-        )
-
-    def get_inference_data2(self, data, eight_schools_params):
-        """vars as lists."""
-        return from_pystan(
-            posterior=data.obj,
-            posterior_predictive=["y_hat"],
-            predictions=["y_hat"],  # wrong, but fine for testing
-            prior=data.obj,
-            prior_predictive=["y_hat"],
-            observed_data=["y"],
-            log_likelihood="log_lik",
-            coords={
-                "school": np.arange(eight_schools_params["J"]),
-                "log_likelihood_dim": np.arange(eight_schools_params["J"]),
-            },
-            dims={
-                "theta": ["school"],
-                "y": ["school"],
-                "y_hat": ["school"],
-                "eta": ["school"],
-                "log_lik": ["log_likelihood_dim"],
-            },
-            posterior_model=data.model,
-            prior_model=data.model,
-        )
-
-    def get_inference_data3(self, data, eight_schools_params):
-        """multiple vars as lists."""
-        return from_pystan(
-            posterior=data.obj,
-            posterior_predictive=["y_hat", "log_lik"],  # wrong, but fine for testing
-            predictions=["y_hat", "log_lik"],  # wrong, but fine for testing
-            prior=data.obj,
-            prior_predictive=["y_hat", "log_lik"],  # wrong, but fine for testing
-            constant_data=["sigma", "y"],  # wrong, but fine for testing
-            predictions_constant_data=["sigma", "y"],  # wrong, but fine for testing
-            coords={"school": np.arange(eight_schools_params["J"])},
-            dims={
-                "theta": ["school"],
-                "y": ["school"],
-                "sigma": ["school"],
-                "y_hat": ["school"],
-                "eta": ["school"],
-            },
-            posterior_model=data.model,
-            prior_model=data.model,
-        )
-
-    def get_inference_data4(self, data):
-        """minimal input."""
-        return from_pystan(
-            posterior=data.obj,
-            posterior_predictive=None,
-            prior=data.obj,
-            prior_predictive=None,
-            coords=None,
-            dims=None,
-            posterior_model=data.model,
-            log_likelihood=[],
-            prior_model=data.model,
-            save_warmup=True,
-        )
-
-    def get_inference_data5(self, data):
-        """minimal input."""
-        return from_pystan(
-            posterior=data.obj,
-            posterior_predictive=None,
-            prior=data.obj,
-            prior_predictive=None,
-            coords=None,
-            dims=None,
-            posterior_model=data.model,
-            log_likelihood=False,
-            prior_model=data.model,
-            save_warmup=True,
-            dtypes={"eta": int},
-        )
-
-    def test_sampler_stats(self, data, eight_schools_params):
-        inference_data = self.get_inference_data(data, eight_schools_params)
-        test_dict = {"sample_stats": ["diverging"]}
-        fails = check_multiple_attrs(test_dict, inference_data)
-        assert not fails
-
-    def test_inference_data(self, data, eight_schools_params):
-        inference_data1 = self.get_inference_data(data, eight_schools_params)
-        inference_data2 = self.get_inference_data2(data, eight_schools_params)
-        inference_data3 = self.get_inference_data3(data, eight_schools_params)
-        inference_data4 = self.get_inference_data4(data)
-        inference_data5 = self.get_inference_data5(data)
-        # inference_data 1
-        test_dict = {
-            "posterior": ["theta", "~log_lik"],
-            "posterior_predictive": ["y_hat"],
-            "predictions": ["y_hat"],
-            "observed_data": ["y"],
-            "constant_data": ["sigma"],
-            "predictions_constant_data": ["sigma"],
-            "sample_stats": ["diverging", "lp"],
-            "log_likelihood": ["y", "~log_lik"],
-            "prior": ["theta"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data1)
-        assert not fails
-        # inference_data 2
-        test_dict = {
-            "posterior_predictive": ["y_hat"],
-            "predictions": ["y_hat"],
-            "observed_data": ["y"],
-            "sample_stats_prior": ["diverging"],
-            "sample_stats": ["diverging", "lp"],
-            "log_likelihood": ["log_lik"],
-            "prior_predictive": ["y_hat"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data2)
-        assert not fails
-        assert any(
-            item in inference_data2.posterior.attrs for item in ["stan_code", "program_code"]
-        )
-        assert any(
-            item in inference_data2.sample_stats.attrs for item in ["stan_code", "program_code"]
-        )
-        # inference_data 3
-        test_dict = {
-            "posterior_predictive": ["y_hat", "log_lik"],
-            "predictions": ["y_hat", "log_lik"],
-            "constant_data": ["sigma", "y"],
-            "predictions_constant_data": ["sigma", "y"],
-            "sample_stats_prior": ["diverging"],
-            "sample_stats": ["diverging", "lp"],
-            "log_likelihood": ["log_lik"],
-            "prior_predictive": ["y_hat", "log_lik"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data3)
-        assert not fails
-        # inference_data 4
-        test_dict = {
-            "posterior": ["theta"],
-            "prior": ["theta"],
-            "sample_stats": ["diverging", "lp"],
-            "~log_likelihood": [""],
-            "warmup_posterior": ["theta"],
-            "warmup_sample_stats": ["diverging", "lp"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data4)
-        assert not fails
-        # inference_data 5
-        test_dict = {
-            "posterior": ["theta"],
-            "prior": ["theta"],
-            "sample_stats": ["diverging", "lp"],
-            "~log_likelihood": [""],
-            "warmup_posterior": ["theta"],
-            "warmup_sample_stats": ["diverging", "lp"],
-        }
-        fails = check_multiple_attrs(test_dict, inference_data5)
-        assert not fails
-        assert inference_data5.posterior.eta.dtype.kind == "i"
-
-    def test_invalid_fit(self, data):
-        if pystan_version() == 2:
-            model = data.model
-            model_data = {
-                "J": 8,
-                "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
-                "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
-            }
-            fit_test_grad = model.sampling(
-                data=model_data, test_grad=True, check_hmc_diagnostics=False
-            )
-            with pytest.raises(AttributeError):
-                _ = from_pystan(posterior=fit_test_grad)
-            fit = model.sampling(data=model_data, iter=100, chains=1, check_hmc_diagnostics=False)
-            del fit.sim["samples"]
-            with pytest.raises(AttributeError):
-                _ = from_pystan(posterior=fit)
-
-    def test_empty_parameter(self):
-        model_code = """
-            parameters {
-                real y;
-                vector[3] x;
-                vector[0] a;
-                vector[2] z;
-            }
-            model {
-                y ~ normal(0,1);
-            }
-        """
-        if pystan_version() == 2:
-            from pystan import StanModel  # pylint: disable=import-error
-
-            model = StanModel(model_code=model_code)
-            fit = model.sampling(iter=500, chains=2, check_hmc_diagnostics=False)
-        else:
-            import stan  # pylint: disable=import-error
-
-            model = stan.build(model_code)
-            fit = model.sample(num_samples=500, num_chains=2)
-
-        posterior = from_pystan(posterior=fit)
-        test_dict = {"posterior": ["y", "x", "z", "~a"], "sample_stats": ["diverging"]}
-        fails = check_multiple_attrs(test_dict, posterior)
-        assert not fails
-
-    def test_get_draws(self, data):
-        fit = data.obj
-        if pystan_version() == 2:
-            draws, _ = get_draws(fit, variables=["theta", "theta"])
-        else:
-            draws, _ = get_draws_stan3(fit, variables=["theta", "theta"])
-        assert draws.get("theta") is not None
-
-    @pytest.mark.skipif(pystan_version() != 2, reason="PyStan 2.x required")
-    def test_index_order(self, data, eight_schools_params):
-        """Test 0-indexed data."""
-        # Skip test if pystan not installed
-        pystan = importorskip("pystan")  # pylint: disable=import-error
-
-        fit = data.model.sampling(data=eight_schools_params)
-        if pystan.__version__ >= "2.18":
-            # make 1-indexed to 0-indexed
-            for holder in fit.sim["samples"]:
-                new_chains = OrderedDict()
-                for i, (key, values) in enumerate(holder.chains.items()):
-                    if "[" in key:
-                        name, *shape = key.replace("]", "").split("[")
-                        shape = [str(int(item) - 1) for items in shape for item in items.split(",")]
-                        key = f"{name}[{','.join(shape)}]"
-                    new_chains[key] = np.full_like(values, fill_value=float(i))
-                setattr(holder, "chains", new_chains)
-            fit.sim["fnames_oi"] = list(fit.sim["samples"][0].chains.keys())
-        idata = from_pystan(posterior=fit)
-        assert idata is not None
-        for j, fpar in enumerate(fit.sim["fnames_oi"]):
-            par, *shape = fpar.replace("]", "").split("[")
-            if par in {"lp__", "log_lik"}:
-                continue
-            assert hasattr(idata.posterior, par), (par, list(idata.posterior.data_vars))
-            if shape:
-                shape = [slice(None), slice(None)] + list(map(int, shape))
-                assert idata.posterior[par][tuple(shape)].values.mean() == float(j)
-            else:
-                assert idata.posterior[par].values.mean() == float(j)