arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
|
@@ -1,925 +0,0 @@
|
|
|
1
|
-
# pylint: disable=redefined-outer-name, no-member
|
|
2
|
-
from copy import deepcopy
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
import pytest
|
|
6
|
-
from numpy.testing import (
|
|
7
|
-
assert_allclose,
|
|
8
|
-
assert_array_almost_equal,
|
|
9
|
-
assert_almost_equal,
|
|
10
|
-
assert_array_equal,
|
|
11
|
-
)
|
|
12
|
-
from scipy.special import logsumexp
|
|
13
|
-
from scipy.stats import linregress, norm, halfcauchy
|
|
14
|
-
from xarray import DataArray, Dataset
|
|
15
|
-
from xarray_einstats.stats import XrContinuousRV
|
|
16
|
-
|
|
17
|
-
from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data, InferenceData
|
|
18
|
-
from ...rcparams import rcParams
|
|
19
|
-
from ...stats import (
|
|
20
|
-
apply_test_function,
|
|
21
|
-
bayes_factor,
|
|
22
|
-
compare,
|
|
23
|
-
ess,
|
|
24
|
-
hdi,
|
|
25
|
-
loo,
|
|
26
|
-
loo_pit,
|
|
27
|
-
psens,
|
|
28
|
-
psislw,
|
|
29
|
-
r2_score,
|
|
30
|
-
summary,
|
|
31
|
-
waic,
|
|
32
|
-
weight_predictions,
|
|
33
|
-
_calculate_ics,
|
|
34
|
-
)
|
|
35
|
-
from ...stats.stats import _gpinv
|
|
36
|
-
from ...stats.stats_utils import get_log_likelihood
|
|
37
|
-
from ..helpers import check_multiple_attrs, multidim_models # pylint: disable=unused-import
|
|
38
|
-
|
|
39
|
-
# Module-level side effect: set ArviZ's "data.load" option to "eager" (instead of
# lazy loading) for every test in this module.
rcParams["data.load"] = "eager"
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
@pytest.fixture(scope="session")
def centered_eight():
    """Session-scoped ``centered_eight`` example InferenceData."""
    return load_arviz_data("centered_eight")
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
@pytest.fixture(scope="session")
def non_centered_eight():
    """Session-scoped ``non_centered_eight`` example InferenceData."""
    return load_arviz_data("non_centered_eight")
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
@pytest.fixture(scope="module")
def multivariable_log_likelihood(centered_eight):
    """Copy of ``centered_eight`` whose log_likelihood group has two variables.

    A zero-filled ``decoy`` variable is added next to ``obs``. Code that expects
    exactly one log-likelihood array can then be tested on the ambiguous case.
    """
    idata = centered_eight.copy()
    obs_shape = idata.log_likelihood["obs"].values.shape
    decoy = DataArray(
        np.zeros(obs_shape),
        dims=["chain", "draw", "school"],
        coords=idata.log_likelihood.coords,
    )
    idata.log_likelihood["decoy"] = decoy
    return idata
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def test_hdp():
    """The default HDI of a large standard-normal sample is about [-1.88, 1.88]."""
    draws = np.random.randn(5000000)
    assert_array_almost_equal(hdi(draws), [-1.88, 1.88], 2)
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def test_hdp_2darray():
    """2D input is read as (draw, shape) and a FutureWarning announces the change."""
    draws = np.random.randn(12000, 5)
    expected_msg = (
        r"hdi currently interprets 2d data as \(draw, shape\) but this will "
        r"change in a future release to \(chain, draw\) for coherence with other functions"
    )
    with pytest.warns(FutureWarning, match=expected_msg):
        interval = hdi(draws)
    assert interval.shape == (5, 2)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def test_hdi_multidimension():
    """A 3D array reduces its two leading axes, giving one interval per trailing element."""
    interval_shape = hdi(np.random.randn(12000, 10, 3)).shape
    assert interval_shape == (3, 2)
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
def test_hdi_idata(centered_eight):
    """hdi on a posterior Dataset, with default and with custom core dimensions."""
    posterior = centered_eight.posterior

    default_result = hdi(posterior)
    assert isinstance(default_result, Dataset)
    assert dict(default_result.sizes) == {"school": 8, "hdi": 2}

    # Reducing only over "chain" leaves "draw" in the output.
    per_draw = hdi(posterior, input_core_dims=[["chain"]])
    assert isinstance(per_draw, Dataset)
    assert per_draw.sizes == {"draw": 500, "hdi": 2, "school": 8}
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
def test_hdi_idata_varnames(centered_eight):
    """Only the requested variables are returned, in the requested order."""
    result = hdi(centered_eight.posterior, var_names=["mu", "theta"])
    assert isinstance(result, Dataset)
    assert result.sizes == {"hdi": 2, "school": 8}
    returned_vars = list(result.data_vars.keys())
    assert returned_vars == ["mu", "theta"]
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def test_hdi_idata_group(centered_eight):
    """``group`` selects the group used; the prior HDI of mu is wider than the posterior one."""
    posterior_hdi = hdi(centered_eight, group="posterior", var_names="mu")
    prior_hdi = hdi(centered_eight, group="prior", var_names="mu")
    assert prior_hdi.sizes == {"hdi": 2}

    posterior_width = posterior_hdi.mu.values[1] - posterior_hdi.mu.values[0]
    prior_width = prior_hdi.mu.values[1] - prior_hdi.mu.values[0]
    assert posterior_width < prior_width
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
def test_hdi_coords(centered_eight):
    """``coords`` subsets the chains before per-chain intervals are computed."""
    posterior = centered_eight.posterior
    per_chain = hdi(
        posterior,
        coords={"chain": [0, 1, 3]},
        input_core_dims=[["draw"]],
    )
    assert_array_equal(per_chain.coords["chain"], [0, 1, 3])
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
def test_hdi_multimodal():
    """A bimodal mixture with ``multimodal=True`` gives one interval per mode."""
    left_mode = np.random.normal(-4, 1, 2500000)
    right_mode = np.random.normal(2, 0.5, 2500000)
    mixture = np.concatenate((left_mode, right_mode))
    intervals = hdi(mixture, multimodal=True)
    assert_array_almost_equal(intervals, [[-5.8, -2.2], [0.9, 3.1]], 1)
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
def test_hdi_multimodal_multivars():
    """Multimodal HDI on a Dataset: the unimodal variable's extra row is NaN."""
    n_draws = 2500000
    bimodal = np.concatenate(
        (np.random.normal(-4, 1, n_draws), np.random.normal(2, 0.5, n_draws))
    )
    unimodal = np.random.normal(8, 1, n_draws * 2)

    dims = ("chain", "draw")
    sample = Dataset(
        {
            "var1": (dims, bimodal[np.newaxis, :]),
            "var2": (dims, unimodal[np.newaxis, :]),
        },
        coords={"chain": [0], "draw": np.arange(n_draws * 2)},
    )

    intervals = hdi(sample, multimodal=True)
    assert_array_almost_equal(intervals.var1, [[-5.8, -2.2], [0.9, 3.1]], 1)
    assert_array_almost_equal(intervals.var2, [[6.1, 9.9], [np.nan, np.nan]], 1)
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
def test_hdi_circular():
    """The circular HDI of a von Mises sample centred at pi wraps around (lower > upper)."""
    angles = np.random.vonmises(np.pi, 1, 5000000)
    wrapped_interval = hdi(angles, circular=True)
    assert_array_almost_equal(wrapped_interval, [0.6, -0.6], 1)
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
def test_hdi_bad_ci():
    """An ``hdi_prob`` greater than 1 is rejected with ValueError."""
    sample = np.random.randn(10)
    with pytest.raises(ValueError):
        hdi(sample, hdi_prob=2)
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
def test_hdi_skipna():
    """With ``skipna=True``, leading NaNs give the same HDI as the clean slice."""
    data = np.random.randn(500)
    # Compute the reference before the NaNs are written into the array.
    reference = hdi(data[10:])
    data[:10] = np.nan
    with_nans = hdi(data, skipna=True)
    assert_array_almost_equal(reference, with_nans)
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
def test_r2_score():
    """Check that r2_score runs on (samples, observations) predictions.

    Its R^2 is compared against the least-squares R^2 of the same data, but only
    at a very loose tolerance (see the comment on the final assertion).
    """
    x = np.linspace(0, 1, 100)
    y = np.random.normal(x, 1)
    y_pred = x + np.random.randn(300, 100)
    res = linregress(x, y)
    # ``assert_allclose``'s third positional parameter is ``rtol``, not a number
    # of decimals. ``rtol=2`` (200% relative tolerance) is what the bare ``2``
    # always meant, so it is spelled out here to keep the looseness visible.
    # NOTE(review): y_pred adds noise independent of y, so the two R^2 values
    # presumably differ substantially; confirm before tightening this check.
    assert_allclose(res.rvalue**2, r2_score(y, y_pred).r2, rtol=2)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
@pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
@pytest.mark.parametrize("multidim", [True, False])
def test_compare_same(centered_eight, multidim_models, method, multidim):
    """Comparing a model with itself yields equal weights that sum to one."""
    model = multidim_models.model_1 if multidim else centered_eight
    comparison = compare({"first": model, "second": model}, method=method)
    weights = comparison["weight"].to_numpy()
    assert_allclose(weights[0], weights[1])
    assert_allclose(np.sum(weights), 1.0)
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
def test_compare_unknown_ic_and_method(centered_eight, non_centered_eight):
    """An unknown ``ic`` or an unknown ``method`` raises ValueError."""
    models = {"centered": centered_eight, "non_centered": non_centered_eight}
    bad_options = (
        {"ic": "Unknown", "method": "stacking"},
        {"ic": "loo", "method": "Unknown"},
    )
    for options in bad_options:
        with pytest.raises(ValueError):
            compare(models, **options)
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
@pytest.mark.parametrize("ic", ["loo", "waic"])
@pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
def test_compare_different(centered_eight, non_centered_eight, ic, method, scale):
    """The non-centered model outweighs the centered one for every ic/method/scale."""
    models = {"centered": centered_eight, "non_centered": non_centered_eight}
    weights = compare(models, ic=ic, method=method, scale=scale)["weight"]
    assert weights["non_centered"] > weights["centered"]
    assert_allclose(np.sum(weights), 1.0)
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
@pytest.mark.parametrize("ic", ["loo", "waic"])
@pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
def test_compare_different_multidim(multidim_models, ic, method):
    """model_1 outweighs model_2 for every ic/method combination."""
    models = {
        "model_1": multidim_models.model_1,
        "model_2": multidim_models.model_2,
    }
    weights = compare(models, ic=ic, method=method)["weight"]
    # The ordering is stable only because these models are always built from
    # the same random seed.
    assert weights["model_1"] > weights["model_2"]
    assert_allclose(np.sum(weights), 1.0)
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
def test_compare_different_size(centered_eight, non_centered_eight):
    """Models with a different number of observations cannot be compared."""
    centered_eight = deepcopy(centered_eight)
    # Remove the "Choate" school from every group of the centered model so its
    # pointwise results no longer line up with the non-centered model.
    # ``drop_sel`` replaces the deprecated ``Dataset.drop(labels, dim)`` form.
    centered_eight.posterior = centered_eight.posterior.drop_sel(school="Choate")
    centered_eight.log_likelihood = centered_eight.log_likelihood.drop_sel(school="Choate")
    centered_eight.posterior_predictive = centered_eight.posterior_predictive.drop_sel(
        school="Choate"
    )
    centered_eight.prior = centered_eight.prior.drop_sel(school="Choate")
    centered_eight.observed_data = centered_eight.observed_data.drop_sel(school="Choate")
    model_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
    with pytest.raises(ValueError):
        compare(model_dict, ic="waic", method="stacking")
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
@pytest.mark.parametrize("ic", ["loo", "waic"])
def test_compare_multiple_obs(multivariable_log_likelihood, centered_eight, non_centered_eight, ic):
    """A model with several log-likelihood variables needs an explicit ``var_name``."""
    models = {
        "centered_eight": centered_eight,
        "non_centered_eight": non_centered_eight,
        "problematic": multivariable_log_likelihood,
    }
    with pytest.raises(TypeError, match="several log likelihood arrays"):
        get_log_likelihood(multivariable_log_likelihood)
    with pytest.raises(TypeError, match="error in ELPD computation"):
        compare(models, ic=ic)
    # Naming the variable resolves the ambiguity.
    assert compare(models, ic=ic, var_name="obs") is not None
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
@pytest.mark.parametrize("ic", ["loo", "waic"])
def test_calculate_ics(centered_eight, non_centered_eight, ic):
    """_calculate_ics agrees for InferenceData, precomputed ELPD data, and a mix."""
    ic_func = loo if ic == "loo" else waic
    idata_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
    elpddata_dict = {name: ic_func(idata) for name, idata in idata_dict.items()}
    mixed_dict = {
        "centered": idata_dict["centered"],
        "non_centered": elpddata_dict["non_centered"],
    }

    def first_output(inputs):
        # Keep the strict 3-way unpacking so a changed return arity still fails.
        output, _, _ = _calculate_ics(inputs, ic=ic)
        return output

    reference = first_output(idata_dict)
    candidates = (first_output(elpddata_dict), first_output(mixed_dict))
    checked_keys = (f"elpd_{ic}", f"p_{ic}")
    for model in idata_dict:
        for key in checked_keys:
            for candidate in candidates:
                assert reference[model][key] == candidate[model][key]
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
def test_calculate_ics_ic_error(centered_eight, non_centered_eight):
    """Mixing precomputed loo and waic results raises ValueError."""
    precomputed = {
        "centered": loo(centered_eight),
        "non_centered": waic(non_centered_eight),
    }
    with pytest.raises(ValueError, match="found both loo and waic"):
        _calculate_ics(precomputed)
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
def test_calculate_ics_ic_override(centered_eight, non_centered_eight):
    """A precomputed waic result overrides ``ic="loo"`` and emits a warning."""
    inputs = {"centered": centered_eight, "non_centered": waic(non_centered_eight)}
    with pytest.warns(UserWarning, match="precomputed elpddata: waic"):
        results, _, used_ic = _calculate_ics(inputs, ic="loo")
    assert used_ic == "waic"
    expected_elpd = waic(centered_eight)["elpd_waic"]
    assert results["centered"]["elpd_waic"] == expected_elpd
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
def test_summary_ndarray():
    """summary accepts a raw 3D numpy array and returns a non-empty table."""
    draws = np.random.randn(4, 100, 2)
    assert summary(draws).shape
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
@pytest.mark.parametrize("var_names_expected", ((None, 10), ("mu", 1), (["mu", "tau"], 2)))
def test_summary_var_names(centered_eight, var_names_expected):
    """``var_names`` controls how many rows the summary contains."""
    var_names, n_rows = var_names_expected
    table = summary(centered_eight, var_names=var_names)
    assert len(table.index) == n_rows
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
@pytest.mark.parametrize("missing_groups", (None, "posterior", "prior"))
def test_summary_groups(centered_eight, missing_groups):
    """summary still works without a posterior, and warns when the prior is gone too."""
    if missing_groups in ("posterior", "prior"):
        centered_eight = deepcopy(centered_eight)
        del centered_eight.posterior
        if missing_groups == "prior":
            del centered_eight.prior

    if missing_groups != "prior":
        summary_df = summary(centered_eight)
    else:
        with pytest.warns(UserWarning):
            summary_df = summary(centered_eight)
    assert summary_df.shape
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
def test_summary_group_argument(centered_eight):
    """Posterior and prior summaries have different row labels."""
    posterior_table = summary(centered_eight, group="posterior")
    prior_table = summary(centered_eight, group="prior")
    posterior_rows = list(posterior_table.index)
    prior_rows = list(prior_table.index)
    assert posterior_rows != prior_rows
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
def test_summary_wrong_group(centered_eight):
    """Requesting a group the InferenceData lacks raises TypeError."""
    expected_msg = r"InferenceData does not contain group: InvalidGroup"
    with pytest.raises(TypeError, match=expected_msg):
        summary(centered_eight, group="InvalidGroup")
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
# Expected summary column names. The parametrization of test_summary_focus_kind
# slices this list: indices 0-8 for stat_focus="mean" and 9-16 for
# stat_focus="median". Within each part, the stats columns come before the
# diagnostics columns.
METRICS_NAMES = [
    # stat_focus="mean": stats (0-3), diagnostics (4-8)
    "mean", "sd", "hdi_3%", "hdi_97%",
    "mcse_mean", "mcse_sd", "ess_bulk", "ess_tail", "r_hat",
] + [
    # stat_focus="median": stats (9-12), diagnostics (13-16)
    "median", "mad", "eti_3%", "eti_97%",
    "mcse_median", "ess_median", "ess_tail", "r_hat",
]
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
@pytest.mark.parametrize(
    "params",
    (
        ("mean", "all", METRICS_NAMES[:9]),
        ("mean", "stats", METRICS_NAMES[:4]),
        ("mean", "diagnostics", METRICS_NAMES[4:9]),
        ("median", "all", METRICS_NAMES[9:17]),
        ("median", "stats", METRICS_NAMES[9:13]),
        ("median", "diagnostics", METRICS_NAMES[13:17]),
    ),
)
def test_summary_focus_kind(centered_eight, params):
    """``stat_focus`` and ``kind`` together select the summary's columns."""
    stat_focus, kind, expected_columns = params
    table = summary(centered_eight, stat_focus=stat_focus, kind=kind)
    assert_array_equal(table.columns, expected_columns)
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
def test_summary_wrong_focus(centered_eight):
    """An unsupported ``stat_focus`` raises TypeError."""
    expected_msg = r"Invalid format: 'WrongFocus'.*"
    with pytest.raises(TypeError, match=expected_msg):
        summary(centered_eight, stat_focus="WrongFocus")
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
@pytest.mark.parametrize("fmt", ["wide", "long", "xarray"])
def test_summary_fmt(centered_eight, fmt):
    """Every supported output format returns a result."""
    result = summary(centered_eight, fmt=fmt)
    assert result is not None
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
def test_summary_labels():
    """Wide-format row labels are ``a[dim1, dim2]``, with dim2 varying fastest."""
    coords1 = list("abcd")
    coords2 = np.arange(1, 6)
    data = from_dict(
        {"a": np.random.randn(4, 100, 4, 5)},
        coords={"dim1": coords1, "dim2": coords2},
        dims={"a": ["dim1", "dim2"]},
    )
    az_summary = summary(data, fmt="wide")
    assert az_summary is not None
    expected_labels = [f"a[{coord1}, {coord2}]" for coord1 in coords1 for coord2 in coords2]
    # Compare whole lists: ``zip`` would stop at the shorter sequence and so
    # hide missing or extra rows.
    assert list(az_summary.index) == expected_labels
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
@pytest.mark.parametrize(
    "stat_funcs", [[np.var], {"var": np.var, "var2": lambda x: np.var(x) ** 2}]
)
def test_summary_stat_func(centered_eight, stat_funcs):
    """A custom ``stat_funcs`` entry adds a ``var`` column to the summary."""
    arviz_summary = summary(centered_eight, stat_funcs=stat_funcs)
    assert arviz_summary is not None
    # ``hasattr(arviz_summary, "var")`` is always true because ``DataFrame.var``
    # is a method, so check the columns instead.
    # NOTE(review): assumes the column is named after the function (``np.var``)
    # or after the dict key "var"; confirm against ``summary``.
    assert "var" in arviz_summary.columns
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
def test_summary_nan(centered_eight):
    """An all-NaN variable gives an all-null row; every other row stays non-null."""
    centered_eight = deepcopy(centered_eight)
    centered_eight.posterior["theta"].loc[{"school": "Deerfield"}] = np.nan
    summary_df = summary(centered_eight)
    assert summary_df is not None

    nan_row = "theta[Deerfield]"
    assert summary_df.loc[nan_row].isnull().all()
    other_rows = [label for label in summary_df.index if label != nan_row]
    assert summary_df.loc[other_rows].notnull().all().all()
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
def test_summary_skip_nan(centered_eight):
    """Check the summary row when only some draws of a variable are NaN.

    Not all of the first four columns of that row are null, while every
    remaining column is.
    """
    centered_eight = deepcopy(centered_eight)
    centered_eight.posterior["theta"].loc[{"draw": slice(10), "school": "Deerfield"}] = np.nan
    summary_df = summary(centered_eight)
    is_null = summary_df.loc["theta[Deerfield]"].isnull()
    assert summary_df is not None
    # Use ``not`` rather than bitwise ``~``: for a plain Python bool ``~True``
    # is -2 (truthy), which would make the assertion pass vacuously.
    # ``.iloc`` makes the positional slicing explicit.
    assert not is_null.iloc[:4].all()
    assert is_null.iloc[4:].all()
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
@pytest.mark.parametrize("fmt", [1, "bad_fmt"])
def test_summary_bad_fmt(centered_eight, fmt):
    """A non-string or unknown ``fmt`` raises TypeError."""
    expected_msg = "Invalid format"
    with pytest.raises(TypeError, match=expected_msg):
        summary(centered_eight, fmt=fmt)
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
def test_summary_order_deprecation(centered_eight):
    """Passing ``order`` emits a DeprecationWarning."""
    expected_msg = "order"
    with pytest.warns(DeprecationWarning, match=expected_msg):
        summary(centered_eight, order="C")
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
def test_summary_index_origin_deprecation(centered_eight):
    """Passing ``index_origin`` emits a DeprecationWarning."""
    expected_msg = "index_origin"
    with pytest.warns(DeprecationWarning, match=expected_msg):
        summary(centered_eight, index_origin=1)
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
@pytest.mark.parametrize("multidim", (True, False))
def test_waic(centered_eight, multidim_models, scale, multidim):
    """WAIC (widely applicable information criterion), aggregate and pointwise."""
    idata = multidim_models.model_1 if multidim else centered_eight
    assert waic(idata, scale=scale) is not None
    pointwise_result = waic(idata, pointwise=True, scale=scale)
    assert pointwise_result is not None
    assert "waic_i" in pointwise_result
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
def test_waic_bad(centered_eight):
    """waic raises TypeError when the log_likelihood group is missing."""
    no_loglik = deepcopy(centered_eight)
    delattr(no_loglik, "log_likelihood")
    with pytest.raises(TypeError):
        waic(no_loglik)
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
def test_waic_bad_scale(centered_eight):
    """waic rejects an unknown ``scale`` with TypeError."""
    invalid_scale = "bad_value"
    with pytest.raises(TypeError):
        waic(centered_eight, scale=invalid_scale)
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
def test_waic_warning(centered_eight):
    """waic emits a UserWarning for artificially constant log-likelihood values."""
    centered_eight = deepcopy(centered_eight)

    # Constant values over the first 250 draws of observation 1.
    centered_eight.log_likelihood["obs"][:, :250, 1] = 10
    with pytest.warns(UserWarning):
        assert waic(centered_eight, pointwise=True) is not None

    # Constant values everywhere.
    # NOTE(review): a previous comment said this case "should throw a warning,
    # but due to numerical issues it fails", yet a UserWarning is asserted
    # below; confirm which statement is current.
    centered_eight.log_likelihood["obs"][:, :, :] = 0
    with pytest.warns(UserWarning):
        assert waic(centered_eight, pointwise=True) is not None
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
def test_waic_print(centered_eight, scale):
    """waic's repr is identical with and without pointwise results."""
    plain_repr = repr(waic(centered_eight, scale=scale))
    pointwise_repr = repr(waic(centered_eight, scale=scale, pointwise=True))
    assert plain_repr is not None
    assert pointwise_repr is not None
    assert plain_repr == pointwise_repr
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
@pytest.mark.parametrize("multidim", (True, False))
def test_loo(centered_eight, multidim_models, scale, multidim):
    """Approximate leave-one-out cross-validation, aggregate and pointwise."""
    idata = multidim_models.model_1 if multidim else centered_eight
    assert loo(idata, scale=scale) is not None
    pointwise_result = loo(idata, pointwise=True, scale=scale)
    assert pointwise_result is not None
    for field in ("loo_i", "pareto_k", "scale"):
        assert field in pointwise_result
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
def test_loo_one_chain(centered_eight):
    """loo runs when chains 1-3 are removed from posterior and sample_stats."""
    centered_eight = deepcopy(centered_eight)
    # ``drop_sel`` replaces the deprecated ``Dataset.drop(labels, dim)`` form.
    centered_eight.posterior = centered_eight.posterior.drop_sel(chain=[1, 2, 3])
    centered_eight.sample_stats = centered_eight.sample_stats.drop_sel(chain=[1, 2, 3])
    assert loo(centered_eight) is not None
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
def test_loo_bad(centered_eight):
    """loo rejects raw arrays and InferenceData lacking a log_likelihood group."""
    with pytest.raises(TypeError):
        loo(np.random.randn(2, 10))

    no_loglik = deepcopy(centered_eight)
    delattr(no_loglik, "log_likelihood")
    with pytest.raises(TypeError):
        loo(no_loglik)
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
def test_loo_bad_scale(centered_eight):
    """loo rejects an unknown ``scale`` with TypeError."""
    invalid_scale = "bad_scale"
    with pytest.raises(TypeError):
        loo(centered_eight, scale=invalid_scale)
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
def test_loo_bad_no_posterior_reff(centered_eight):
    """``reff=None`` fails without a posterior group; an explicit reff works."""
    # With the posterior available, reff=None is accepted.
    loo(centered_eight, reff=None)

    idata_no_posterior = deepcopy(centered_eight)
    del idata_no_posterior.posterior
    # Without a posterior group, reff=None is expected to raise ...
    with pytest.raises(TypeError):
        loo(idata_no_posterior, reff=None)
    # ... whereas an explicitly supplied reff is still accepted.
    loo(idata_no_posterior, reff=0.7)
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
def test_loo_warning(centered_eight):
    """Degenerate log likelihood values trigger the Pareto k warning in ``loo``."""
    idata = deepcopy(centered_eight)

    def _assert_shape_parameter_warning():
        # loo must still return a result while emitting the khat warning.
        with pytest.warns(UserWarning) as captured:
            assert loo(idata, pointwise=True) is not None
        assert any("Estimated shape parameter" in str(item.message) for item in captured)

    # make one of the khats infinity
    idata.log_likelihood["obs"][:, :, 1] = 10
    _assert_shape_parameter_warning()

    # make all of the khats infinity
    idata.log_likelihood["obs"][:, :, :] = 1
    _assert_shape_parameter_warning()
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
def test_loo_print(centered_eight, scale):
    """The pointwise ``loo`` repr is longer than the summary-only repr."""
    summary_text, pointwise_text = (
        repr(loo(centered_eight, scale=scale, pointwise=flag)) for flag in (False, True)
    )
    assert summary_text is not None
    assert pointwise_text is not None
    assert len(summary_text) < len(pointwise_text)
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
def test_psislw(centered_eight):
    """Pareto k values reported by ``loo`` equal those returned by ``psislw``."""
    r_eff = 0.7
    loo_khat = loo(centered_eight, pointwise=True, reff=r_eff)["pareto_k"]
    stacked_loglik = get_log_likelihood(centered_eight).stack(__sample__=("chain", "draw"))
    _, psis_khat = psislw(-stacked_loglik, r_eff)
    assert_allclose(loo_khat, psis_khat)
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
def test_psislw_smooths_for_low_k():
    """Log-weights are smoothed even when the estimated k is below 1/3.

    Regression test for https://github.com/arviz-devs/arviz/issues/2010
    """
    draws = np.random.default_rng(44).normal(size=100)
    # Pass a copy so ``draws`` stays untouched for the comparison below.
    smoothed, khat = psislw(draws.copy())
    assert khat < 1 / 3

    normalized = draws - logsumexp(draws)
    assert not np.allclose(normalized, smoothed)
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
@pytest.mark.parametrize("probs", [True, False])
@pytest.mark.parametrize("kappa", [-1, -0.5, 1e-30, 0.5, 1])
@pytest.mark.parametrize("sigma", [0, 2])
def test_gpinv(probs, kappa, sigma):
    """``_gpinv`` returns one value per input probability.

    ``probs=True`` uses only non-negative probabilities; ``probs=False``
    swaps the first entry for a negative (invalid) probability.
    """
    leading = 0.1 if probs else -0.1
    prob_values = np.array([leading, 0.1, 0.1, 0.2, 0.3])
    assert len(_gpinv(prob_values, kappa, sigma)) == len(prob_values)
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
@pytest.mark.parametrize("func", [loo, waic])
def test_multidimensional_log_likelihood(func):
    """``loo``/``waic`` give the same summary for 2-D and flattened observations.

    Identical log likelihood values are stored once with two observation
    dimensions ``(a, b)`` and once reshaped into a single dimension ``v``;
    the non-pointwise results of ``func`` must agree.
    """
    # A seeded generator keeps the test reproducible between runs.
    rng = np.random.default_rng(0)
    llm = rng.random((4, 23, 15, 2))
    ll1 = llm.reshape(4, 23, 15 * 2)
    statsm = Dataset({"log_likelihood": DataArray(llm, dims=["chain", "draw", "a", "b"])})
    stats1 = Dataset({"log_likelihood": DataArray(ll1, dims=["chain", "draw", "v"])})
    post = Dataset({"mu": DataArray(rng.random((4, 23, 2)), dims=["chain", "draw", "v"])})

    dsm = convert_to_inference_data(statsm, group="sample_stats")
    ds1 = convert_to_inference_data(stats1, group="sample_stats")
    dsp = convert_to_inference_data(post, group="posterior")

    dsm = concat(dsp, dsm)
    ds1 = concat(dsp, ds1)

    frm = func(dsm)
    fr1 = func(ds1)

    # Pointwise entries presumably differ in layout between the two
    # representations, so only the scalar entries are compared exactly.
    pointwise_keys = {"loo_i", "waic_i", "pareto_k"}
    assert all(fr1[key] == frm[key] for key in fr1.index if key not in pointwise_keys)
    # Compare the leading four entries by position; ``.iloc`` makes the
    # positional intent explicit instead of relying on ``Series[:4]``.
    assert_array_almost_equal(frm.iloc[:4], fr1.iloc[:4])
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
@pytest.mark.parametrize(
    "args",
    [
        {"y": "obs"},
        {"y": "obs", "y_hat": "obs"},
        {"y": "arr", "y_hat": "obs"},
        {"y": "obs", "y_hat": "arr"},
        {"y": "arr", "y_hat": "arr"},
        {"y": "obs", "y_hat": "obs", "log_weights": "arr"},
        {"y": "arr", "y_hat": "obs", "log_weights": "arr"},
        {"y": "obs", "y_hat": "arr", "log_weights": "arr"},
        {"idata": False},
    ],
)
def test_loo_pit(centered_eight, args):
    """LOO-PIT values stay in [0, 1] for each way of supplying the inputs.

    An ``"arr"`` entry in ``args`` is replaced by a precomputed array; any
    other value (a variable name or None) is passed to ``loo_pit`` as is.
    ``{"idata": False}`` calls ``loo_pit`` with arrays only and no idata.
    """
    sample_dims = ("chain", "draw")
    y_arr = centered_eight.observed_data.obs
    y_hat_arr = centered_eight.posterior_predictive.obs.stack(__sample__=sample_dims)
    log_like = get_log_likelihood(centered_eight).stack(__sample__=sample_dims)
    n_samples = len(log_like.__sample__)
    ess_p = ess(centered_eight.posterior, method="mean")
    reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
    log_weights_arr = psislw(-log_like, reff=reff)[0]

    if not args.get("idata", True):
        loo_pit_data = loo_pit(idata=None, y=y_arr, y_hat=y_hat_arr, log_weights=log_weights_arr)
    else:
        precomputed = {"y": y_arr, "y_hat": y_hat_arr, "log_weights": log_weights_arr}
        call_kwargs = {
            name: (array if args.get(name, None) == "arr" else args.get(name, None))
            for name, array in precomputed.items()
        }
        loo_pit_data = loo_pit(idata=centered_eight, **call_kwargs)
    assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
@pytest.mark.parametrize(
    "args",
    [
        {"y": "y"},
        {"y": "y", "y_hat": "y"},
        {"y": "arr", "y_hat": "y"},
        {"y": "y", "y_hat": "arr"},
        {"y": "arr", "y_hat": "arr"},
        {"y": "y", "y_hat": "y", "log_weights": "arr"},
        {"y": "arr", "y_hat": "y", "log_weights": "arr"},
        {"y": "y", "y_hat": "arr", "log_weights": "arr"},
        {"idata": False},
    ],
)
def test_loo_pit_multidim(multidim_models, args):
    """LOO-PIT on ``multidim_models.model_1`` stays within [0, 1].

    An ``"arr"`` entry in ``args`` is replaced by a precomputed array; any
    other value (a variable name or None) is forwarded unchanged.
    ``{"idata": False}`` calls ``loo_pit`` with arrays only and no idata.
    """
    idata = multidim_models.model_1
    sample_dims = ("chain", "draw")
    y_arr = idata.observed_data.y
    y_hat_arr = idata.posterior_predictive.y.stack(__sample__=sample_dims)
    log_like = get_log_likelihood(idata).stack(__sample__=sample_dims)
    n_samples = len(log_like.__sample__)
    ess_p = ess(idata.posterior, method="mean")
    reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
    log_weights_arr = psislw(-log_like, reff=reff)[0]

    if not args.get("idata", True):
        loo_pit_data = loo_pit(idata=None, y=y_arr, y_hat=y_hat_arr, log_weights=log_weights_arr)
    else:
        precomputed = {"y": y_arr, "y_hat": y_hat_arr, "log_weights": log_weights_arr}
        call_kwargs = {}
        for name, array in precomputed.items():
            requested = args.get(name, None)
            call_kwargs[name] = array if requested == "arr" else requested
        loo_pit_data = loo_pit(idata=idata, **call_kwargs)
    assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
def test_loo_pit_multi_lik():
    """``loo_pit`` with ``y="y"`` works when log_likelihood has extra variables.

    The ``log_likelihood`` group also contains a ``decoy`` variable of the
    same shape; presumably this checks that the ``y`` variable is the one
    used. TODO(review): confirm that intent against ``loo_pit``.
    """
    rng = np.random.default_rng(0)
    post_pred = rng.standard_normal(size=(4, 100, 10))
    obs = np.quantile(post_pred, np.linspace(0, 1, 10))
    # Perturb the two extreme observations.
    obs[0] *= 0.9
    obs[-1] *= 1.1
    idata = from_dict(
        # Draw the posterior from the same seeded generator (previously the
        # unseeded global ``np.random``) so the whole test is reproducible.
        posterior={"a": rng.standard_normal((4, 100))},
        posterior_predictive={"y": post_pred},
        observed_data={"y": obs},
        log_likelihood={"y": -(post_pred**2), "decoy": np.zeros_like(post_pred)},
    )
    loo_pit_data = loo_pit(idata, y="y")
    assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
@pytest.mark.parametrize("input_type", ["idataarray", "idatanone_ystr", "yarr_yhatnone"])
def test_loo_pit_bad_input(centered_eight, input_type):
    """Incompatible idata / y / y_hat combinations make ``loo_pit`` raise ValueError."""
    arr = np.random.random((8, 200))
    # input_type -> (expected error regex, keyword arguments for loo_pit)
    bad_calls = {
        "idataarray": (r"type InferenceData or None", {"idata": arr, "y": "obs"}),
        "idatanone_ystr": (r"all 3.+must be array or DataArray", {"idata": None, "y": "obs"}),
        "yarr_yhatnone": (
            r"y_hat.+None.+y.+str",
            {"idata": centered_eight, "y": arr, "y_hat": None},
        ),
    }
    if input_type not in bad_calls:
        return
    pattern, call_kwargs = bad_calls[input_type]
    with pytest.raises(ValueError, match=pattern):
        loo_pit(**call_kwargs)
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
@pytest.mark.parametrize("arg", ["y", "y_hat", "log_weights"])
def test_loo_pit_bad_input_type(centered_eight, arg):
    """Test wrong input type (not None, str nor DataArray).

    Each of ``y``, ``y_hat`` and ``log_weights`` in turn is given an int,
    and ``loo_pit`` must raise a ValueError naming the offending type.
    """
    import re

    kwargs = {"y": "obs", "y_hat": "obs", "log_weights": None}
    kwargs[arg] = 2  # use int instead of array-like
    # ``match`` is interpreted as a regular expression, so escape the literal
    # type repr rather than relying on it containing no regex metacharacters.
    with pytest.raises(ValueError, match=re.escape(f"not {type(2)}")):
        loo_pit(idata=centered_eight, **kwargs)
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
@pytest.mark.parametrize("incompatibility", ["y-y_hat1", "y-y_hat2", "y_hat-log_weights"])
def test_loo_pit_bad_input_shape(incompatibility):
    """Mismatched shapes of y, y_hat and log_weights raise ValueError."""
    y = np.random.random(8)
    y_hat = np.random.random((8, 200))
    log_weights = np.random.random((8, 200))
    # incompatibility -> (expected error regex, malformed y_hat)
    bad_y_hats = {
        # extra leading axis on y_hat
        "y-y_hat1": ("1 more dimension", y_hat[None, :]),
        # only 2 of the 8 observations kept
        "y-y_hat2": ("y has shape", y_hat[1:3, :]),
        # half the draws of log_weights
        "y_hat-log_weights": ("must have the same shape", y_hat[:, :100]),
    }
    if incompatibility not in bad_y_hats:
        return
    expected_msg, bad_y_hat = bad_y_hats[incompatibility]
    with pytest.raises(ValueError, match=expected_msg):
        loo_pit(y=y, y_hat=bad_y_hat, log_weights=log_weights)
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
@pytest.mark.parametrize("pointwise", [True, False])
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"group": "posterior_predictive", "var_names": {"posterior_predictive": "obs"}},
        {"group": "observed_data", "var_names": {"both": "obs"}, "out_data_shape": "shape"},
        {"var_names": {"both": "obs", "posterior": ["theta", "mu"]}},
        {"group": "observed_data", "out_name_data": "T_name"},
    ],
)
def test_apply_test_function(centered_eight, pointwise, inplace, kwargs):
    """Test some usual call cases of apply_test_function.

    ``inplace`` is forwarded to ``apply_test_function``. The returned object
    must be the input object when ``inplace`` is true and a different object
    otherwise. The output variable must exist in every targeted group.
    """
    group = kwargs.get("group", "both")
    var_names = kwargs.get("var_names", None)
    out_data_shape = kwargs.get("out_data_shape", None)
    out_pp_shape = kwargs.get("out_pp_shape", None)
    out_name_data = kwargs.get("out_name_data", "T")
    if out_data_shape == "shape":
        out_data_shape = (8,) if pointwise else ()
    if out_pp_shape == "shape":
        out_pp_shape = (4, 500, 8) if pointwise else (4, 500)
    # A single copy is enough to keep in-place writes out of the fixture.
    idata = deepcopy(centered_eight)
    idata_out = apply_test_function(
        idata,
        lambda y, theta: np.mean(y),
        group=group,
        var_names=var_names,
        pointwise=pointwise,
        out_name_data=out_name_data,
        out_data_shape=out_data_shape,
        out_pp_shape=out_pp_shape,
        # Previously not forwarded, so the inplace=False cases were never exercised.
        inplace=inplace,
    )
    if inplace:
        assert idata is idata_out
    else:
        assert idata is not idata_out

    if group == "both":
        test_dict = {"observed_data": ["T"], "posterior_predictive": ["T"]}
    else:
        test_dict = {group: [out_name_data]}

    fails = check_multiple_attrs(test_dict, idata_out)
    assert not fails
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
def test_apply_test_function_bad_group(centered_eight):
    """An unknown ``group`` name makes ``apply_test_function`` raise ValueError."""

    def identity(y, theta):
        return y

    with pytest.raises(ValueError, match="Invalid group argument"):
        apply_test_function(centered_eight, identity, group="bad_group")
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
def test_apply_test_function_missing_group():
    """Test error when InferenceData object is missing a required group.

    The function cannot work if group="both" but InferenceData object has no
    posterior_predictive group.
    """
    idata = from_dict(
        posterior={"a": np.random.random((4, 500, 30))}, observed_data={"y": np.random.random(30)}
    )
    with pytest.raises(ValueError, match="must have posterior_predictive"):
        # Use a well-formed test function that actually computes a value from
        # ``y``. The previous ``lambda y, theta: np.mean`` returned the
        # function object itself instead of calling it.
        apply_test_function(idata, lambda y, theta: np.mean(y), group="both")
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
def test_apply_test_function_should_overwrite_error(centered_eight):
    """Writing to an already existing variable without overwrite raises ValueError."""
    existing_name = "obs"  # an existing observed_data variable of centered_eight
    with pytest.raises(ValueError, match="Should overwrite"):
        apply_test_function(centered_eight, lambda y, theta: y, out_name_data=existing_name)
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
def test_weight_predictions():
    """Weighted posterior predictive samples follow the requested weights.

    With default weights, the mean of the mixture must lie strictly between
    the means of the two inputs. With explicit weights, the mean must
    approximate the weighted average of -1 and 1.
    """
    idata0 = from_dict(
        posterior_predictive={"a": np.random.normal(-1, 1, 1000)}, observed_data={"a": [1]}
    )
    idata1 = from_dict(
        posterior_predictive={"a": np.random.normal(1, 1, 1000)}, observed_data={"a": [1]}
    )

    new = weight_predictions([idata0, idata1])
    # Compare scalar means of the variable itself. Comparing whole ``Dataset``
    # objects yields another Dataset whose truthiness only reflects whether it
    # has data variables, so the old chained comparison could never fail.
    mean1 = float(idata1.posterior_predictive["a"].mean())
    mean_new = float(new.posterior_predictive["a"].mean())
    mean0 = float(idata0.posterior_predictive["a"].mean())
    assert mean1 > mean_new > mean0
    assert "posterior_predictive" in new
    assert "observed_data" in new

    new = weight_predictions([idata0, idata1], weights=[0.5, 0.5])
    assert_almost_equal(new.posterior_predictive["a"].mean(), 0, decimal=1)
    new = weight_predictions([idata0, idata1], weights=[0.9, 0.1])
    assert_almost_equal(new.posterior_predictive["a"].mean(), -0.8, decimal=1)
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
@pytest.fixture(scope="module")
def psens_data():
    """Load ``non_centered_eight`` and attach a ``log_prior`` group.

    The log prior of ``mu``, ``tau`` and ``theta_t`` is evaluated at the
    posterior draws using normal / half-Cauchy priors.
    """
    idata = load_arviz_data("non_centered_eight")
    posterior = idata.posterior
    priors = {
        "mu": XrContinuousRV(norm, 0, 5),
        "tau": XrContinuousRV(halfcauchy, scale=5),
        "theta_t": XrContinuousRV(norm, 0, 1),
    }
    log_prior = {name: dist.logpdf(posterior[name]) for name, dist in priors.items()}
    idata.add_groups({"log_prior": log_prior})
    return idata
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
@pytest.mark.parametrize("component", ("prior", "likelihood"))
def test_priorsens_global(psens_data, component):
    """``psens`` over a whole component reports ``mu`` and ``theta``.

    ``theta_t`` must keep its ``school`` dimension.
    """
    sensitivity = psens(psens_data, component=component)
    for name in ("mu", "theta"):
        assert name in sensitivity
    assert "school" in sensitivity.theta_t.dims
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
def test_priorsens_var_names(psens_data):
    """``var_names`` limits the output and ``component_var_names`` changes the values.

    Both calls report only ``mu`` and ``tau``. Restricting the prior
    component to ``mu`` and ``tau`` must change the ``mu`` sensitivity.
    """
    restricted = psens(
        psens_data, component="prior", component_var_names=["mu", "tau"], var_names=["mu", "tau"]
    )
    whole_prior = psens(psens_data, component="prior", var_names=["mu", "tau"])
    for result in (restricted, whole_prior):
        assert "theta" not in result
        assert "mu" in result
        assert "tau" in result
    assert not np.isclose(restricted.mu, whole_prior.mu)
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
def test_priorsens_coords(psens_data):
    """Restricting the likelihood component to one school still reports every variable."""
    subset = {"school": "Choate"}
    sensitivity = psens(psens_data, component="likelihood", component_coords=subset)
    for name in ("mu", "theta"):
        assert name in sensitivity
    assert "school" in sensitivity.theta_t.dims
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
def test_bayes_factor():
    """BF10 exceeds BF01 with the stored N(0, 1) prior but not with an N(0, 10) prior."""
    posterior_draws = np.random.normal(1, 0.5, 5000)
    prior_draws = np.random.normal(0, 1, 5000)
    idata = from_dict(posterior={"a": posterior_draws}, prior={"a": prior_draws})

    bf_stored_prior = bayes_factor(idata, var_name="a", ref_val=0)
    wide_prior = np.random.normal(0, 10, 5000)
    bf_wide_prior = bayes_factor(idata, prior=wide_prior, var_name="a", ref_val=0)

    assert bf_stored_prior["BF10"] > bf_stored_prior["BF01"]
    assert bf_wide_prior["BF10"] < bf_wide_prior["BF01"]
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
def test_compare_sorting_consistency():
    """``compare`` output does not depend on the insertion order of the models.

    Two models are compared with their dict entries in opposite orders; the
    resulting ranking and standard errors must be identical.
    """
    chains, draws = 4, 1000

    def _build_model(loc, scale):
        # The log likelihood is drawn before theta, both from the global RNG.
        log_lik = np.random.normal(loc, scale, size=(chains, draws))
        posterior = Dataset(
            {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
            coords={"chain": range(chains), "draw": range(draws)},
        )
        log_like = Dataset(
            {"y": (("chain", "draw"), log_lik)},
            coords={"chain": range(chains), "draw": range(draws)},
        )
        return InferenceData(posterior=posterior, log_likelihood=log_like)

    good_fit = _build_model(-2, 1)  # Model 1 - good fit
    poor_fit = _build_model(-5, 2)  # Model 2 - poor fit (higher variance)

    forward = compare({"M1": good_fit, "M2": poor_fit}, method="bb-pseudo-bma")
    backward = compare({"M2": poor_fit, "M1": good_fit}, method="bb-pseudo-bma")

    assert forward.index.tolist() == backward.index.tolist()
    np.testing.assert_array_almost_equal(forward["se"].values, backward["se"].values)
|