arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/tests/base_tests/test_plots_matplotlib.py
@@ -1,2197 +0,0 @@
|
|
|
1
|
-
"""Tests use the default backend."""
|
|
2
|
-
|
|
3
|
-
# pylint: disable=redefined-outer-name,too-many-lines
|
|
4
|
-
import os
|
|
5
|
-
import re
|
|
6
|
-
from copy import deepcopy
|
|
7
|
-
|
|
8
|
-
import matplotlib.pyplot as plt
|
|
9
|
-
import numpy as np
|
|
10
|
-
import pytest
|
|
11
|
-
import xarray as xr
|
|
12
|
-
from matplotlib import animation
|
|
13
|
-
from pandas import DataFrame
|
|
14
|
-
from scipy.stats import gaussian_kde, norm
|
|
15
|
-
|
|
16
|
-
from ...data import from_dict, load_arviz_data
|
|
17
|
-
from ...labels import MapLabeller
|
|
18
|
-
from ...plots import (
|
|
19
|
-
plot_autocorr,
|
|
20
|
-
plot_bf,
|
|
21
|
-
plot_bpv,
|
|
22
|
-
plot_compare,
|
|
23
|
-
plot_density,
|
|
24
|
-
plot_dist,
|
|
25
|
-
plot_dist_comparison,
|
|
26
|
-
plot_dot,
|
|
27
|
-
plot_ecdf,
|
|
28
|
-
plot_elpd,
|
|
29
|
-
plot_energy,
|
|
30
|
-
plot_ess,
|
|
31
|
-
plot_forest,
|
|
32
|
-
plot_hdi,
|
|
33
|
-
plot_kde,
|
|
34
|
-
plot_khat,
|
|
35
|
-
plot_lm,
|
|
36
|
-
plot_loo_pit,
|
|
37
|
-
plot_mcse,
|
|
38
|
-
plot_pair,
|
|
39
|
-
plot_parallel,
|
|
40
|
-
plot_posterior,
|
|
41
|
-
plot_ppc,
|
|
42
|
-
plot_rank,
|
|
43
|
-
plot_separation,
|
|
44
|
-
plot_trace,
|
|
45
|
-
plot_ts,
|
|
46
|
-
plot_violin,
|
|
47
|
-
)
|
|
48
|
-
from ...plots.dotplot import wilkinson_algorithm
|
|
49
|
-
from ...plots.plot_utils import plot_point_interval
|
|
50
|
-
from ...rcparams import rc_context, rcParams
|
|
51
|
-
from ...stats import compare, hdi, loo, waic
|
|
52
|
-
from ...stats.density_utils import kde as _kde
|
|
53
|
-
from ...utils import BehaviourChangeWarning, _cov
|
|
54
|
-
from ..helpers import ( # pylint: disable=unused-import
|
|
55
|
-
RandomVariableTestClass,
|
|
56
|
-
create_model,
|
|
57
|
-
create_multidimensional_model,
|
|
58
|
-
does_not_warn,
|
|
59
|
-
eight_schools_params,
|
|
60
|
-
models,
|
|
61
|
-
multidim_models,
|
|
62
|
-
)
|
|
63
|
-
|
|
64
|
-
rcParams["data.load"] = "eager"
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
@pytest.fixture(scope="function", autouse=True)
|
|
68
|
-
def clean_plots(request, save_figs):
|
|
69
|
-
"""Close plots after each test, optionally save if --save is specified during test invocation"""
|
|
70
|
-
|
|
71
|
-
def fin():
|
|
72
|
-
if save_figs is not None:
|
|
73
|
-
plt.savefig(f"{os.path.join(save_figs, request.node.name)}.png")
|
|
74
|
-
plt.close("all")
|
|
75
|
-
|
|
76
|
-
request.addfinalizer(fin)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
@pytest.fixture(scope="module")
|
|
80
|
-
def data(eight_schools_params):
|
|
81
|
-
data = eight_schools_params
|
|
82
|
-
return data
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
@pytest.fixture(scope="module")
|
|
86
|
-
def df_trace():
|
|
87
|
-
return DataFrame({"a": np.random.poisson(2.3, 100)})
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
@pytest.fixture(scope="module")
|
|
91
|
-
def discrete_model():
|
|
92
|
-
"""Simple fixture for random discrete model"""
|
|
93
|
-
return {"x": np.random.randint(10, size=100), "y": np.random.randint(10, size=100)}
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
@pytest.fixture(scope="module")
|
|
97
|
-
def discrete_multidim_model():
|
|
98
|
-
"""Simple fixture for random discrete model"""
|
|
99
|
-
idata = from_dict(
|
|
100
|
-
{"x": np.random.randint(10, size=(2, 50, 3)), "y": np.random.randint(10, size=(2, 50))},
|
|
101
|
-
dims={"x": ["school"]},
|
|
102
|
-
)
|
|
103
|
-
return idata
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
@pytest.fixture(scope="module")
|
|
107
|
-
def continuous_model():
|
|
108
|
-
"""Simple fixture for random continuous model"""
|
|
109
|
-
return {"x": np.random.beta(2, 5, size=100), "y": np.random.beta(2, 5, size=100)}
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
@pytest.fixture(scope="function")
|
|
113
|
-
def fig_ax():
|
|
114
|
-
fig, ax = plt.subplots(1, 1)
|
|
115
|
-
return fig, ax
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
@pytest.fixture(scope="module")
|
|
119
|
-
def data_random():
|
|
120
|
-
return np.random.randint(1, 100, size=20)
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
@pytest.fixture(scope="module")
|
|
124
|
-
def data_list():
|
|
125
|
-
return list(range(11, 31))
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
@pytest.mark.parametrize(
|
|
129
|
-
"kwargs",
|
|
130
|
-
[
|
|
131
|
-
{"point_estimate": "mean"},
|
|
132
|
-
{"point_estimate": "median"},
|
|
133
|
-
{"hdi_prob": 0.94},
|
|
134
|
-
{"hdi_prob": 1},
|
|
135
|
-
{"outline": True},
|
|
136
|
-
{"colors": ["g", "b", "r", "y"]},
|
|
137
|
-
{"colors": "k"},
|
|
138
|
-
{"hdi_markers": ["v"]},
|
|
139
|
-
{"shade": 1},
|
|
140
|
-
{"transform": lambda x: x + 1},
|
|
141
|
-
{"ax": plt.subplots(6, 3)[1]},
|
|
142
|
-
],
|
|
143
|
-
)
|
|
144
|
-
def test_plot_density_float(models, kwargs):
|
|
145
|
-
obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]]
|
|
146
|
-
axes = plot_density(obj, **kwargs)
|
|
147
|
-
assert axes.shape == (6, 3)
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
def test_plot_density_discrete(discrete_model):
|
|
151
|
-
axes = plot_density(discrete_model, shade=0.9)
|
|
152
|
-
assert axes.size == 2
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
def test_plot_density_no_subset():
|
|
156
|
-
"""Test plot_density works when variables are not subset of one another (#1093)."""
|
|
157
|
-
model_ab = from_dict(
|
|
158
|
-
{
|
|
159
|
-
"a": np.random.normal(size=200),
|
|
160
|
-
"b": np.random.normal(size=200),
|
|
161
|
-
}
|
|
162
|
-
)
|
|
163
|
-
model_bc = from_dict(
|
|
164
|
-
{
|
|
165
|
-
"b": np.random.normal(size=200),
|
|
166
|
-
"c": np.random.normal(size=200),
|
|
167
|
-
}
|
|
168
|
-
)
|
|
169
|
-
axes = plot_density([model_ab, model_bc])
|
|
170
|
-
assert axes.size == 3
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
def test_plot_density_nonstring_varnames():
|
|
174
|
-
"""Test plot_density works when variables are not strings."""
|
|
175
|
-
rv1 = RandomVariableTestClass("a")
|
|
176
|
-
rv2 = RandomVariableTestClass("b")
|
|
177
|
-
rv3 = RandomVariableTestClass("c")
|
|
178
|
-
model_ab = from_dict(
|
|
179
|
-
{
|
|
180
|
-
rv1: np.random.normal(size=200),
|
|
181
|
-
rv2: np.random.normal(size=200),
|
|
182
|
-
}
|
|
183
|
-
)
|
|
184
|
-
model_bc = from_dict(
|
|
185
|
-
{
|
|
186
|
-
rv2: np.random.normal(size=200),
|
|
187
|
-
rv3: np.random.normal(size=200),
|
|
188
|
-
}
|
|
189
|
-
)
|
|
190
|
-
axes = plot_density([model_ab, model_bc])
|
|
191
|
-
assert axes.size == 3
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
def test_plot_density_bad_kwargs(models):
|
|
195
|
-
obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]]
|
|
196
|
-
with pytest.raises(ValueError):
|
|
197
|
-
plot_density(obj, point_estimate="bad_value")
|
|
198
|
-
|
|
199
|
-
with pytest.raises(ValueError):
|
|
200
|
-
plot_density(obj, data_labels=[f"bad_value_{i}" for i in range(len(obj) + 10)])
|
|
201
|
-
|
|
202
|
-
with pytest.raises(ValueError):
|
|
203
|
-
plot_density(obj, hdi_prob=2)
|
|
204
|
-
|
|
205
|
-
with pytest.raises(ValueError):
|
|
206
|
-
plot_density(obj, filter_vars="bad_value")
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
def test_plot_density_discrete_combinedims(discrete_model):
|
|
210
|
-
axes = plot_density(discrete_model, combine_dims={"school"}, shade=0.9)
|
|
211
|
-
assert axes.size == 2
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
@pytest.mark.parametrize(
|
|
215
|
-
"kwargs",
|
|
216
|
-
[
|
|
217
|
-
{},
|
|
218
|
-
{"y_hat_line": True},
|
|
219
|
-
{"expected_events": True},
|
|
220
|
-
{"y_hat_line_kwargs": {"linestyle": "dotted"}},
|
|
221
|
-
{"exp_events_kwargs": {"marker": "o"}},
|
|
222
|
-
],
|
|
223
|
-
)
|
|
224
|
-
def test_plot_separation(kwargs):
|
|
225
|
-
idata = load_arviz_data("classification10d")
|
|
226
|
-
ax = plot_separation(idata=idata, y="outcome", **kwargs)
|
|
227
|
-
assert ax
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
@pytest.mark.parametrize(
|
|
231
|
-
"kwargs",
|
|
232
|
-
[
|
|
233
|
-
{},
|
|
234
|
-
{"var_names": "mu"},
|
|
235
|
-
{"var_names": ["mu", "tau"]},
|
|
236
|
-
{"combined": True},
|
|
237
|
-
{"compact": True},
|
|
238
|
-
{"combined": True, "compact": True, "legend": True},
|
|
239
|
-
{"divergences": "top", "legend": True},
|
|
240
|
-
{"divergences": False},
|
|
241
|
-
{"kind": "rank_vlines"},
|
|
242
|
-
{"kind": "rank_bars"},
|
|
243
|
-
{"lines": [("mu", {}, [1, 2])]},
|
|
244
|
-
{"lines": [("mu", {}, 8)]},
|
|
245
|
-
{"circ_var_names": ["mu"]},
|
|
246
|
-
{"circ_var_names": ["mu"], "circ_var_units": "degrees"},
|
|
247
|
-
],
|
|
248
|
-
)
|
|
249
|
-
def test_plot_trace(models, kwargs):
|
|
250
|
-
axes = plot_trace(models.model_1, **kwargs)
|
|
251
|
-
assert axes.shape
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
@pytest.mark.parametrize(
|
|
255
|
-
"compact",
|
|
256
|
-
[True, False],
|
|
257
|
-
)
|
|
258
|
-
@pytest.mark.parametrize(
|
|
259
|
-
"combined",
|
|
260
|
-
[True, False],
|
|
261
|
-
)
|
|
262
|
-
def test_plot_trace_legend(compact, combined):
|
|
263
|
-
idata = load_arviz_data("rugby")
|
|
264
|
-
axes = plot_trace(
|
|
265
|
-
idata, var_names=["home", "atts_star"], compact=compact, combined=combined, legend=True
|
|
266
|
-
)
|
|
267
|
-
assert axes[0, 1].get_legend()
|
|
268
|
-
compact_legend = axes[1, 0].get_legend()
|
|
269
|
-
if compact:
|
|
270
|
-
assert axes.shape == (2, 2)
|
|
271
|
-
assert compact_legend
|
|
272
|
-
else:
|
|
273
|
-
assert axes.shape == (7, 2)
|
|
274
|
-
assert not compact_legend
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
def test_plot_trace_discrete(discrete_model):
|
|
278
|
-
axes = plot_trace(discrete_model)
|
|
279
|
-
assert axes.shape
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
def test_plot_trace_max_subplots_warning(models):
|
|
283
|
-
with pytest.warns(UserWarning):
|
|
284
|
-
with rc_context(rc={"plot.max_subplots": 6}):
|
|
285
|
-
axes = plot_trace(models.model_1)
|
|
286
|
-
assert axes.shape == (3, 2)
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
def test_plot_dist_comparison_warning(models):
|
|
290
|
-
with pytest.warns(UserWarning):
|
|
291
|
-
with rc_context(rc={"plot.max_subplots": 6}):
|
|
292
|
-
axes = plot_dist_comparison(models.model_1)
|
|
293
|
-
assert axes.shape == (2, 3)
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
@pytest.mark.parametrize("kwargs", [{"var_names": ["mu", "tau"], "lines": [("hey", {}, [1])]}])
|
|
297
|
-
def test_plot_trace_invalid_varname_warning(models, kwargs):
|
|
298
|
-
with pytest.warns(UserWarning, match="valid var.+should be provided"):
|
|
299
|
-
axes = plot_trace(models.model_1, **kwargs)
|
|
300
|
-
assert axes.shape
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
def test_plot_trace_diverging_correctly_transposed():
|
|
304
|
-
idata = load_arviz_data("centered_eight")
|
|
305
|
-
idata.sample_stats["diverging"] = idata.sample_stats.diverging.T
|
|
306
|
-
plot_trace(idata, divergences="bottom")
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
@pytest.mark.parametrize(
|
|
310
|
-
"bad_kwargs", [{"var_names": ["mu", "tau"], "lines": [("mu", {}, ["hey"])]}]
|
|
311
|
-
)
|
|
312
|
-
def test_plot_trace_bad_lines_value(models, bad_kwargs):
|
|
313
|
-
with pytest.raises(ValueError, match="line-positions should be numeric"):
|
|
314
|
-
plot_trace(models.model_1, **bad_kwargs)
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
@pytest.mark.parametrize("prop", ["chain_prop", "compact_prop"])
|
|
318
|
-
def test_plot_trace_futurewarning(models, prop):
|
|
319
|
-
with pytest.warns(FutureWarning, match=f"{prop} as a tuple.+deprecated"):
|
|
320
|
-
ax = plot_trace(models.model_1, **{prop: ("ls", ("-", "--"))})
|
|
321
|
-
assert ax.shape
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
@pytest.mark.parametrize("model_fits", [["model_1"], ["model_1", "model_2"]])
|
|
325
|
-
@pytest.mark.parametrize(
|
|
326
|
-
"args_expected",
|
|
327
|
-
[
|
|
328
|
-
({}, 1),
|
|
329
|
-
({"var_names": "mu", "transform": lambda x: x + 1}, 1),
|
|
330
|
-
({"var_names": "mu", "rope": (-1, 1), "combine_dims": {"school"}}, 1),
|
|
331
|
-
({"r_hat": True, "quartiles": False}, 2),
|
|
332
|
-
({"var_names": ["mu"], "colors": "C0", "ess": True, "combined": True}, 2),
|
|
333
|
-
(
|
|
334
|
-
{
|
|
335
|
-
"kind": "ridgeplot",
|
|
336
|
-
"ridgeplot_truncate": False,
|
|
337
|
-
"ridgeplot_quantiles": [0.25, 0.5, 0.75],
|
|
338
|
-
},
|
|
339
|
-
1,
|
|
340
|
-
),
|
|
341
|
-
({"kind": "ridgeplot", "r_hat": True, "ess": True}, 3),
|
|
342
|
-
({"kind": "ridgeplot", "r_hat": True, "ess": True}, 3),
|
|
343
|
-
({"kind": "ridgeplot", "r_hat": True, "ess": True, "ridgeplot_alpha": 0}, 3),
|
|
344
|
-
(
|
|
345
|
-
{
|
|
346
|
-
"var_names": ["mu", "theta"],
|
|
347
|
-
"rope": {
|
|
348
|
-
"mu": [{"rope": (-0.1, 0.1)}],
|
|
349
|
-
"theta": [{"school": "Choate", "rope": (0.2, 0.5)}],
|
|
350
|
-
},
|
|
351
|
-
},
|
|
352
|
-
1,
|
|
353
|
-
),
|
|
354
|
-
],
|
|
355
|
-
)
|
|
356
|
-
def test_plot_forest(models, model_fits, args_expected):
|
|
357
|
-
obj = [getattr(models, model_fit) for model_fit in model_fits]
|
|
358
|
-
args, expected = args_expected
|
|
359
|
-
axes = plot_forest(obj, **args)
|
|
360
|
-
assert axes.size == expected
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
def test_plot_forest_rope_exception():
|
|
364
|
-
with pytest.raises(ValueError) as err:
|
|
365
|
-
plot_forest({"x": [1]}, rope="not_correct_format")
|
|
366
|
-
assert "Argument `rope` must be None, a dictionary like" in str(err.value)
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
def test_plot_forest_single_value():
|
|
370
|
-
axes = plot_forest({"x": [1]})
|
|
371
|
-
assert axes.shape
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
def test_plot_forest_ridge_discrete(discrete_model):
|
|
375
|
-
axes = plot_forest(discrete_model, kind="ridgeplot")
|
|
376
|
-
assert axes.shape
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
@pytest.mark.parametrize("model_fits", [["model_1"], ["model_1", "model_2"]])
|
|
380
|
-
def test_plot_forest_bad(models, model_fits):
|
|
381
|
-
obj = [getattr(models, model_fit) for model_fit in model_fits]
|
|
382
|
-
with pytest.raises(TypeError):
|
|
383
|
-
plot_forest(obj, kind="bad_kind")
|
|
384
|
-
|
|
385
|
-
with pytest.raises(ValueError):
|
|
386
|
-
plot_forest(obj, model_names=[f"model_name_{i}" for i in range(len(obj) + 10)])
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
@pytest.mark.parametrize("kind", ["kde", "hist"])
|
|
390
|
-
def test_plot_energy(models, kind):
|
|
391
|
-
assert plot_energy(models.model_1, kind=kind)
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
def test_plot_energy_bad(models):
|
|
395
|
-
with pytest.raises(ValueError):
|
|
396
|
-
plot_energy(models.model_1, kind="bad_kind")
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
def test_plot_energy_correctly_transposed():
|
|
400
|
-
idata = load_arviz_data("centered_eight")
|
|
401
|
-
idata.sample_stats["energy"] = idata.sample_stats.energy.T
|
|
402
|
-
ax = plot_energy(idata)
|
|
403
|
-
# legend has one entry for each KDE and 1 BFMI for each chain
|
|
404
|
-
assert len(ax.legend_.texts) == 2 + len(idata.sample_stats.chain)
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
def test_plot_parallel_raises_valueerror(df_trace): # pylint: disable=invalid-name
|
|
408
|
-
with pytest.raises(ValueError):
|
|
409
|
-
plot_parallel(df_trace)
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
@pytest.mark.parametrize("norm_method", [None, "normal", "minmax", "rank"])
|
|
413
|
-
def test_plot_parallel(models, norm_method):
|
|
414
|
-
assert plot_parallel(models.model_1, var_names=["mu", "tau"], norm_method=norm_method)
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
@pytest.mark.parametrize("var_names", [None, "mu", ["mu", "tau"]])
|
|
418
|
-
def test_plot_parallel_exception(models, var_names):
|
|
419
|
-
"""Ensure that correct exception is raised when one variable is passed."""
|
|
420
|
-
with pytest.raises(ValueError):
|
|
421
|
-
assert plot_parallel(models.model_1, var_names=var_names, norm_method="foo")
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
@pytest.mark.parametrize(
|
|
425
|
-
"kwargs",
|
|
426
|
-
[
|
|
427
|
-
{"plot_kwargs": {"linestyle": "-"}},
|
|
428
|
-
{"contour": True, "fill_last": False},
|
|
429
|
-
{
|
|
430
|
-
"contour": True,
|
|
431
|
-
"contourf_kwargs": {"cmap": "plasma"},
|
|
432
|
-
"contour_kwargs": {"linewidths": 1},
|
|
433
|
-
},
|
|
434
|
-
{"contour": False},
|
|
435
|
-
{"contour": False, "pcolormesh_kwargs": {"cmap": "plasma"}},
|
|
436
|
-
{"is_circular": False},
|
|
437
|
-
{"is_circular": True},
|
|
438
|
-
{"is_circular": "radians"},
|
|
439
|
-
{"is_circular": "degrees"},
|
|
440
|
-
{"adaptive": True},
|
|
441
|
-
{"hdi_probs": [0.3, 0.9, 0.6]},
|
|
442
|
-
{"hdi_probs": [0.3, 0.6, 0.9], "contourf_kwargs": {"cmap": "Blues"}},
|
|
443
|
-
{"hdi_probs": [0.9, 0.6, 0.3], "contour_kwargs": {"alpha": 0}},
|
|
444
|
-
],
|
|
445
|
-
)
|
|
446
|
-
def test_plot_kde(continuous_model, kwargs):
|
|
447
|
-
axes = plot_kde(continuous_model["x"], continuous_model["y"], **kwargs)
|
|
448
|
-
axes1 = plot_kde(continuous_model["x"], continuous_model["y"], **kwargs)
|
|
449
|
-
assert axes
|
|
450
|
-
assert axes is axes1
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
@pytest.mark.parametrize(
|
|
454
|
-
"kwargs",
|
|
455
|
-
[
|
|
456
|
-
{"hdi_probs": [1, 2, 3]},
|
|
457
|
-
{"hdi_probs": [-0.3, 0.6, 0.9]},
|
|
458
|
-
{"hdi_probs": [0, 0.3, 0.6]},
|
|
459
|
-
{"hdi_probs": [0.3, 0.6, 1]},
|
|
460
|
-
],
|
|
461
|
-
)
|
|
462
|
-
def test_plot_kde_hdi_probs_bad(continuous_model, kwargs):
|
|
463
|
-
"""Ensure invalid hdi probabilities are rejected."""
|
|
464
|
-
with pytest.raises(ValueError):
|
|
465
|
-
plot_kde(continuous_model["x"], continuous_model["y"], **kwargs)
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
@pytest.mark.parametrize(
|
|
469
|
-
"kwargs",
|
|
470
|
-
[
|
|
471
|
-
{"hdi_probs": [0.3, 0.6, 0.9], "contourf_kwargs": {"levels": [0, 0.5, 1]}},
|
|
472
|
-
{"hdi_probs": [0.3, 0.6, 0.9], "contour_kwargs": {"levels": [0, 0.5, 1]}},
|
|
473
|
-
],
|
|
474
|
-
)
|
|
475
|
-
def test_plot_kde_hdi_probs_warning(continuous_model, kwargs):
|
|
476
|
-
"""Ensure warning is raised when too many keywords are specified."""
|
|
477
|
-
with pytest.warns(UserWarning):
|
|
478
|
-
axes = plot_kde(continuous_model["x"], continuous_model["y"], **kwargs)
|
|
479
|
-
assert axes
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
@pytest.mark.parametrize("shape", [(8,), (8, 8), (8, 8, 8)])
|
|
483
|
-
def test_cov(shape):
|
|
484
|
-
x = np.random.randn(*shape)
|
|
485
|
-
if x.ndim <= 2:
|
|
486
|
-
assert np.allclose(_cov(x), np.cov(x))
|
|
487
|
-
else:
|
|
488
|
-
with pytest.raises(ValueError):
|
|
489
|
-
_cov(x)
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
@pytest.mark.parametrize(
|
|
493
|
-
"kwargs",
|
|
494
|
-
[
|
|
495
|
-
{"cumulative": True},
|
|
496
|
-
{"cumulative": True, "plot_kwargs": {"linestyle": "--"}},
|
|
497
|
-
{"rug": True},
|
|
498
|
-
{"rug": True, "rug_kwargs": {"alpha": 0.2}, "rotated": True},
|
|
499
|
-
],
|
|
500
|
-
)
|
|
501
|
-
def test_plot_kde_cumulative(continuous_model, kwargs):
|
|
502
|
-
axes = plot_kde(continuous_model["x"], quantiles=[0.25, 0.5, 0.75], **kwargs)
|
|
503
|
-
assert axes
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
@pytest.mark.parametrize(
|
|
507
|
-
"kwargs",
|
|
508
|
-
[
|
|
509
|
-
{"kind": "hist"},
|
|
510
|
-
{"kind": "kde"},
|
|
511
|
-
{"is_circular": False},
|
|
512
|
-
{"is_circular": False, "kind": "hist"},
|
|
513
|
-
{"is_circular": True},
|
|
514
|
-
{"is_circular": True, "kind": "hist"},
|
|
515
|
-
{"is_circular": "radians"},
|
|
516
|
-
{"is_circular": "radians", "kind": "hist"},
|
|
517
|
-
{"is_circular": "degrees"},
|
|
518
|
-
{"is_circular": "degrees", "kind": "hist"},
|
|
519
|
-
],
|
|
520
|
-
)
|
|
521
|
-
def test_plot_dist(continuous_model, kwargs):
|
|
522
|
-
axes = plot_dist(continuous_model["x"], **kwargs)
|
|
523
|
-
axes1 = plot_dist(continuous_model["x"], **kwargs)
|
|
524
|
-
assert axes
|
|
525
|
-
assert axes is axes1
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
def test_plot_dist_hist(data_random):
|
|
529
|
-
axes = plot_dist(data_random, hist_kwargs=dict(bins=30))
|
|
530
|
-
assert axes
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
def test_list_conversion(data_list):
|
|
534
|
-
axes = plot_dist(data_list, hist_kwargs=dict(bins=30))
|
|
535
|
-
assert axes
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
@pytest.mark.parametrize(
|
|
539
|
-
"kwargs",
|
|
540
|
-
[
|
|
541
|
-
{"plot_kwargs": {"linestyle": "-"}},
|
|
542
|
-
{"contour": True, "fill_last": False},
|
|
543
|
-
{"contour": False},
|
|
544
|
-
],
|
|
545
|
-
)
|
|
546
|
-
def test_plot_dist_2d_kde(continuous_model, kwargs):
|
|
547
|
-
axes = plot_dist(continuous_model["x"], continuous_model["y"], **kwargs)
|
|
548
|
-
assert axes
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
@pytest.mark.parametrize(
|
|
552
|
-
"kwargs", [{"plot_kwargs": {"linestyle": "-"}}, {"cumulative": True}, {"rug": True}]
|
|
553
|
-
)
|
|
554
|
-
def test_plot_kde_quantiles(continuous_model, kwargs):
|
|
555
|
-
axes = plot_kde(continuous_model["x"], quantiles=[0.05, 0.5, 0.95], **kwargs)
|
|
556
|
-
assert axes
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
def test_plot_kde_inference_data(models):
|
|
560
|
-
"""
|
|
561
|
-
Ensure that an exception is raised when plot_kde
|
|
562
|
-
is used with an inference data or Xarray dataset object.
|
|
563
|
-
"""
|
|
564
|
-
with pytest.raises(ValueError, match="Inference Data"):
|
|
565
|
-
plot_kde(models.model_1)
|
|
566
|
-
with pytest.raises(ValueError, match="Xarray"):
|
|
567
|
-
plot_kde(models.model_1.posterior)
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
@pytest.mark.slow
|
|
571
|
-
@pytest.mark.parametrize(
|
|
572
|
-
"kwargs",
|
|
573
|
-
[
|
|
574
|
-
{
|
|
575
|
-
"var_names": "theta",
|
|
576
|
-
"divergences": True,
|
|
577
|
-
"coords": {"school": [0, 1]},
|
|
578
|
-
"scatter_kwargs": {"marker": "x", "c": "C0"},
|
|
579
|
-
"divergences_kwargs": {"marker": "*", "c": "C0"},
|
|
580
|
-
},
|
|
581
|
-
{
|
|
582
|
-
"divergences": True,
|
|
583
|
-
"scatter_kwargs": {"marker": "x", "c": "C0"},
|
|
584
|
-
"divergences_kwargs": {"marker": "*", "c": "C0"},
|
|
585
|
-
"var_names": ["theta", "mu"],
|
|
586
|
-
},
|
|
587
|
-
{"kind": "kde", "var_names": ["theta"]},
|
|
588
|
-
{"kind": "hexbin", "colorbar": False, "var_names": ["theta"]},
|
|
589
|
-
{"kind": "hexbin", "colorbar": True, "var_names": ["theta"]},
|
|
590
|
-
{
|
|
591
|
-
"kind": "hexbin",
|
|
592
|
-
"var_names": ["theta"],
|
|
593
|
-
"coords": {"school": [0, 1]},
|
|
594
|
-
"colorbar": True,
|
|
595
|
-
"hexbin_kwargs": {"cmap": "viridis"},
|
|
596
|
-
"textsize": 20,
|
|
597
|
-
},
|
|
598
|
-
{
|
|
599
|
-
"point_estimate": "mean",
|
|
600
|
-
"reference_values": {"mu": 0, "tau": 0},
|
|
601
|
-
"reference_values_kwargs": {"c": "C0", "marker": "*"},
|
|
602
|
-
},
|
|
603
|
-
{
|
|
604
|
-
"var_names": ["mu", "tau"],
|
|
605
|
-
"reference_values": {"mu": 0, "tau": 0},
|
|
606
|
-
"labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
|
|
607
|
-
},
|
|
608
|
-
{
|
|
609
|
-
"var_names": ["theta"],
|
|
610
|
-
"reference_values": {"theta": [0.0] * 8},
|
|
611
|
-
"labeller": MapLabeller({"theta": r"$\theta$"}),
|
|
612
|
-
},
|
|
613
|
-
{
|
|
614
|
-
"var_names": ["theta"],
|
|
615
|
-
"reference_values": {"theta": np.zeros(8)},
|
|
616
|
-
"labeller": MapLabeller({"theta": r"$\theta$"}),
|
|
617
|
-
},
|
|
618
|
-
],
|
|
619
|
-
)
|
|
620
|
-
def test_plot_pair(models, kwargs):
|
|
621
|
-
ax = plot_pair(models.model_1, **kwargs)
|
|
622
|
-
assert np.all(ax)
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
@pytest.mark.parametrize(
|
|
626
|
-
"kwargs", [{"kind": "scatter"}, {"kind": "kde"}, {"kind": "hexbin", "colorbar": True}]
|
|
627
|
-
)
|
|
628
|
-
def test_plot_pair_2var(discrete_model, fig_ax, kwargs):
|
|
629
|
-
_, ax = fig_ax
|
|
630
|
-
ax = plot_pair(discrete_model, ax=ax, **kwargs)
|
|
631
|
-
assert ax
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
def test_plot_pair_bad(models):
|
|
635
|
-
with pytest.raises(ValueError):
|
|
636
|
-
plot_pair(models.model_1, kind="bad_kind")
|
|
637
|
-
with pytest.raises(Exception):
|
|
638
|
-
plot_pair(models.model_1, var_names=["mu"])
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
@pytest.mark.parametrize("has_sample_stats", [True, False])
|
|
642
|
-
def test_plot_pair_divergences_warning(has_sample_stats):
|
|
643
|
-
data = load_arviz_data("centered_eight")
|
|
644
|
-
if has_sample_stats:
|
|
645
|
-
# sample_stats present, diverging field missing
|
|
646
|
-
data.sample_stats = data.sample_stats.rename({"diverging": "diverging_missing"})
|
|
647
|
-
else:
|
|
648
|
-
# sample_stats missing
|
|
649
|
-
data = data.posterior # pylint: disable=no-member
|
|
650
|
-
with pytest.warns(UserWarning):
|
|
651
|
-
ax = plot_pair(data, divergences=True)
|
|
652
|
-
assert np.all(ax)
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
@pytest.mark.parametrize(
|
|
656
|
-
"kwargs", [{}, {"marginals": True}, {"marginals": True, "var_names": ["mu", "tau"]}]
|
|
657
|
-
)
|
|
658
|
-
def test_plot_pair_overlaid(models, kwargs):
|
|
659
|
-
ax = plot_pair(models.model_1, **kwargs)
|
|
660
|
-
ax2 = plot_pair(models.model_2, ax=ax, **kwargs)
|
|
661
|
-
assert ax is ax2
|
|
662
|
-
assert ax.shape
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
@pytest.mark.parametrize("marginals", [True, False])
|
|
666
|
-
def test_plot_pair_combinedims(models, marginals):
|
|
667
|
-
ax = plot_pair(
|
|
668
|
-
models.model_1, var_names=["eta", "theta"], combine_dims={"school"}, marginals=marginals
|
|
669
|
-
)
|
|
670
|
-
if marginals:
|
|
671
|
-
assert ax.shape == (2, 2)
|
|
672
|
-
else:
|
|
673
|
-
assert not isinstance(ax, np.ndarray)
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
@pytest.mark.parametrize("marginals", [True, False])
|
|
677
|
-
@pytest.mark.parametrize("max_subplots", [True, False])
|
|
678
|
-
def test_plot_pair_shapes(marginals, max_subplots):
|
|
679
|
-
rng = np.random.default_rng()
|
|
680
|
-
idata = from_dict({"a": rng.standard_normal((4, 500, 5))})
|
|
681
|
-
if max_subplots:
|
|
682
|
-
with rc_context({"plot.max_subplots": 6}):
|
|
683
|
-
with pytest.warns(UserWarning, match="3x3 grid"):
|
|
684
|
-
ax = plot_pair(idata, marginals=marginals)
|
|
685
|
-
else:
|
|
686
|
-
ax = plot_pair(idata, marginals=marginals)
|
|
687
|
-
side = 3 if max_subplots else (4 + marginals)
|
|
688
|
-
assert ax.shape == (side, side)
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
@pytest.mark.parametrize("sharex", ["col", None])
|
|
692
|
-
@pytest.mark.parametrize("sharey", ["row", None])
|
|
693
|
-
@pytest.mark.parametrize("marginals", [True, False])
|
|
694
|
-
def test_plot_pair_shared(sharex, sharey, marginals):
|
|
695
|
-
# Generate fake data and plot
|
|
696
|
-
rng = np.random.default_rng()
|
|
697
|
-
idata = from_dict({"a": rng.standard_normal((4, 500, 5))})
|
|
698
|
-
numvars = 5 - (not marginals)
|
|
699
|
-
if sharex is None and sharey is None:
|
|
700
|
-
ax = plot_pair(idata, marginals=marginals)
|
|
701
|
-
else:
|
|
702
|
-
backend_kwargs = {}
|
|
703
|
-
if sharex is not None:
|
|
704
|
-
backend_kwargs["sharex"] = sharex
|
|
705
|
-
if sharey is not None:
|
|
706
|
-
backend_kwargs["sharey"] = sharey
|
|
707
|
-
with pytest.warns(UserWarning):
|
|
708
|
-
ax = plot_pair(idata, marginals=marginals, backend_kwargs=backend_kwargs)
|
|
709
|
-
|
|
710
|
-
# Check x axes shared correctly
|
|
711
|
-
for i in range(numvars):
|
|
712
|
-
num_shared_x = numvars - i
|
|
713
|
-
assert len(ax[-1, i].get_shared_x_axes().get_siblings(ax[-1, i])) == num_shared_x
|
|
714
|
-
|
|
715
|
-
# Check y axes shared correctly
|
|
716
|
-
for j in range(numvars):
|
|
717
|
-
if marginals:
|
|
718
|
-
num_shared_y = j
|
|
719
|
-
|
|
720
|
-
# Check diagonal has unshared axis
|
|
721
|
-
assert len(ax[j, j].get_shared_y_axes().get_siblings(ax[j, j])) == 1
|
|
722
|
-
|
|
723
|
-
if j == 0:
|
|
724
|
-
continue
|
|
725
|
-
else:
|
|
726
|
-
num_shared_y = j + 1
|
|
727
|
-
assert len(ax[j, 0].get_shared_y_axes().get_siblings(ax[j, 0])) == num_shared_y
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
|
|
731
|
-
@pytest.mark.parametrize("alpha", [None, 0.2, 1])
|
|
732
|
-
@pytest.mark.parametrize("animated", [False, True])
|
|
733
|
-
@pytest.mark.parametrize("observed", [True, False])
|
|
734
|
-
@pytest.mark.parametrize("observed_rug", [False, True])
|
|
735
|
-
def test_plot_ppc(models, kind, alpha, animated, observed, observed_rug):
|
|
736
|
-
if animation and not animation.writers.is_available("ffmpeg"):
|
|
737
|
-
pytest.skip("matplotlib animations within ArviZ require ffmpeg")
|
|
738
|
-
animation_kwargs = {"blit": False}
|
|
739
|
-
axes = plot_ppc(
|
|
740
|
-
models.model_1,
|
|
741
|
-
kind=kind,
|
|
742
|
-
alpha=alpha,
|
|
743
|
-
observed=observed,
|
|
744
|
-
observed_rug=observed_rug,
|
|
745
|
-
animated=animated,
|
|
746
|
-
animation_kwargs=animation_kwargs,
|
|
747
|
-
random_seed=3,
|
|
748
|
-
)
|
|
749
|
-
if animated:
|
|
750
|
-
assert axes[0]
|
|
751
|
-
assert axes[1]
|
|
752
|
-
assert axes
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
def test_plot_ppc_transposed():
|
|
756
|
-
idata = load_arviz_data("rugby")
|
|
757
|
-
idata.map(
|
|
758
|
-
lambda ds: ds.assign(points=xr.concat((ds.home_points, ds.away_points), "field")),
|
|
759
|
-
groups="observed_vars",
|
|
760
|
-
inplace=True,
|
|
761
|
-
)
|
|
762
|
-
assert idata.posterior_predictive.points.dims == ("field", "chain", "draw", "match")
|
|
763
|
-
ax = plot_ppc(
|
|
764
|
-
idata,
|
|
765
|
-
kind="scatter",
|
|
766
|
-
var_names="points",
|
|
767
|
-
flatten=["field"],
|
|
768
|
-
coords={"match": ["Wales Italy"]},
|
|
769
|
-
random_seed=3,
|
|
770
|
-
num_pp_samples=8,
|
|
771
|
-
)
|
|
772
|
-
x, y = ax.get_lines()[2].get_data()
|
|
773
|
-
assert not np.isclose(y[0], 0)
|
|
774
|
-
assert np.all(np.array([47, 44, 15, 11]) == x)
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
|
|
778
|
-
@pytest.mark.parametrize("jitter", [None, 0, 0.1, 1, 3])
|
|
779
|
-
@pytest.mark.parametrize("animated", [False, True])
|
|
780
|
-
def test_plot_ppc_multichain(kind, jitter, animated):
|
|
781
|
-
if animation and not animation.writers.is_available("ffmpeg"):
|
|
782
|
-
pytest.skip("matplotlib animations within ArviZ require ffmpeg")
|
|
783
|
-
data = from_dict(
|
|
784
|
-
posterior_predictive={
|
|
785
|
-
"x": np.random.randn(4, 100, 30),
|
|
786
|
-
"y_hat": np.random.randn(4, 100, 3, 10),
|
|
787
|
-
},
|
|
788
|
-
observed_data={"x": np.random.randn(30), "y": np.random.randn(3, 10)},
|
|
789
|
-
)
|
|
790
|
-
animation_kwargs = {"blit": False}
|
|
791
|
-
axes = plot_ppc(
|
|
792
|
-
data,
|
|
793
|
-
kind=kind,
|
|
794
|
-
data_pairs={"y": "y_hat"},
|
|
795
|
-
jitter=jitter,
|
|
796
|
-
animated=animated,
|
|
797
|
-
animation_kwargs=animation_kwargs,
|
|
798
|
-
random_seed=3,
|
|
799
|
-
)
|
|
800
|
-
if animated:
|
|
801
|
-
assert np.all(axes[0])
|
|
802
|
-
assert np.all(axes[1])
|
|
803
|
-
else:
|
|
804
|
-
assert np.all(axes)
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
|
|
808
|
-
@pytest.mark.parametrize("animated", [False, True])
|
|
809
|
-
def test_plot_ppc_discrete(kind, animated):
|
|
810
|
-
if animation and not animation.writers.is_available("ffmpeg"):
|
|
811
|
-
pytest.skip("matplotlib animations within ArviZ require ffmpeg")
|
|
812
|
-
data = from_dict(
|
|
813
|
-
observed_data={"obs": np.random.randint(1, 100, 15)},
|
|
814
|
-
posterior_predictive={"obs": np.random.randint(1, 300, (1, 20, 15))},
|
|
815
|
-
)
|
|
816
|
-
|
|
817
|
-
animation_kwargs = {"blit": False}
|
|
818
|
-
axes = plot_ppc(data, kind=kind, animated=animated, animation_kwargs=animation_kwargs)
|
|
819
|
-
if animated:
|
|
820
|
-
assert np.all(axes[0])
|
|
821
|
-
assert np.all(axes[1])
|
|
822
|
-
assert axes
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
@pytest.mark.skipif(
|
|
826
|
-
not animation.writers.is_available("ffmpeg"),
|
|
827
|
-
reason="matplotlib animations within ArviZ require ffmpeg",
|
|
828
|
-
)
|
|
829
|
-
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
|
|
830
|
-
def test_plot_ppc_save_animation(models, kind):
|
|
831
|
-
animation_kwargs = {"blit": False}
|
|
832
|
-
axes, anim = plot_ppc(
|
|
833
|
-
models.model_1,
|
|
834
|
-
kind=kind,
|
|
835
|
-
animated=True,
|
|
836
|
-
animation_kwargs=animation_kwargs,
|
|
837
|
-
num_pp_samples=5,
|
|
838
|
-
random_seed=3,
|
|
839
|
-
)
|
|
840
|
-
assert axes
|
|
841
|
-
assert anim
|
|
842
|
-
animations_folder = "../saved_animations"
|
|
843
|
-
os.makedirs(animations_folder, exist_ok=True)
|
|
844
|
-
path = os.path.join(animations_folder, f"ppc_{kind}_animation.mp4")
|
|
845
|
-
anim.save(path)
|
|
846
|
-
assert os.path.exists(path)
|
|
847
|
-
assert os.path.getsize(path)
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
@pytest.mark.skipif(
|
|
851
|
-
not animation.writers.is_available("ffmpeg"),
|
|
852
|
-
reason="matplotlib animations within ArviZ require ffmpeg",
|
|
853
|
-
)
|
|
854
|
-
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
|
|
855
|
-
def test_plot_ppc_discrete_save_animation(kind):
|
|
856
|
-
data = from_dict(
|
|
857
|
-
observed_data={"obs": np.random.randint(1, 100, 15)},
|
|
858
|
-
posterior_predictive={"obs": np.random.randint(1, 300, (1, 20, 15))},
|
|
859
|
-
)
|
|
860
|
-
animation_kwargs = {"blit": False}
|
|
861
|
-
axes, anim = plot_ppc(
|
|
862
|
-
data,
|
|
863
|
-
kind=kind,
|
|
864
|
-
animated=True,
|
|
865
|
-
animation_kwargs=animation_kwargs,
|
|
866
|
-
num_pp_samples=5,
|
|
867
|
-
random_seed=3,
|
|
868
|
-
)
|
|
869
|
-
assert axes
|
|
870
|
-
assert anim
|
|
871
|
-
animations_folder = "../saved_animations"
|
|
872
|
-
os.makedirs(animations_folder, exist_ok=True)
|
|
873
|
-
path = os.path.join(animations_folder, f"ppc_discrete_{kind}_animation.mp4")
|
|
874
|
-
anim.save(path)
|
|
875
|
-
assert os.path.exists(path)
|
|
876
|
-
assert os.path.getsize(path)
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
@pytest.mark.skipif(
|
|
880
|
-
not animation.writers.is_available("ffmpeg"),
|
|
881
|
-
reason="matplotlib animations within ArviZ require ffmpeg",
|
|
882
|
-
)
|
|
883
|
-
@pytest.mark.parametrize("system", ["Windows", "Darwin"])
|
|
884
|
-
def test_non_linux_blit(models, monkeypatch, system, caplog):
|
|
885
|
-
import platform
|
|
886
|
-
|
|
887
|
-
def mock_system():
|
|
888
|
-
return system
|
|
889
|
-
|
|
890
|
-
monkeypatch.setattr(platform, "system", mock_system)
|
|
891
|
-
|
|
892
|
-
animation_kwargs = {"blit": True}
|
|
893
|
-
axes, anim = plot_ppc(
|
|
894
|
-
models.model_1,
|
|
895
|
-
kind="kde",
|
|
896
|
-
animated=True,
|
|
897
|
-
animation_kwargs=animation_kwargs,
|
|
898
|
-
num_pp_samples=5,
|
|
899
|
-
random_seed=3,
|
|
900
|
-
)
|
|
901
|
-
records = caplog.records
|
|
902
|
-
assert len(records) == 1
|
|
903
|
-
assert records[0].levelname == "WARNING"
|
|
904
|
-
assert axes
|
|
905
|
-
assert anim
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
@pytest.mark.parametrize(
|
|
909
|
-
"kwargs",
|
|
910
|
-
[
|
|
911
|
-
{"flatten": []},
|
|
912
|
-
{"flatten": [], "coords": {"obs_dim": [1, 2, 3]}},
|
|
913
|
-
{"flatten": ["obs_dim"], "coords": {"obs_dim": [1, 2, 3]}},
|
|
914
|
-
],
|
|
915
|
-
)
|
|
916
|
-
def test_plot_ppc_grid(models, kwargs):
|
|
917
|
-
axes = plot_ppc(models.model_1, kind="scatter", **kwargs)
|
|
918
|
-
if not kwargs.get("flatten") and not kwargs.get("coords"):
|
|
919
|
-
assert axes.size == 8
|
|
920
|
-
elif not kwargs.get("flatten"):
|
|
921
|
-
assert axes.size == 3
|
|
922
|
-
else:
|
|
923
|
-
assert not isinstance(axes, np.ndarray)
|
|
924
|
-
assert np.ravel(axes).size == 1
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
|
|
928
|
-
def test_plot_ppc_bad(models, kind):
|
|
929
|
-
data = from_dict(posterior={"mu": np.random.randn()})
|
|
930
|
-
with pytest.raises(TypeError):
|
|
931
|
-
plot_ppc(data, kind=kind)
|
|
932
|
-
with pytest.raises(TypeError):
|
|
933
|
-
plot_ppc(models.model_1, kind="bad_val")
|
|
934
|
-
with pytest.raises(TypeError):
|
|
935
|
-
plot_ppc(models.model_1, num_pp_samples="bad_val")
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
|
|
939
|
-
def test_plot_ppc_ax(models, kind, fig_ax):
|
|
940
|
-
"""Test ax argument of plot_ppc."""
|
|
941
|
-
_, ax = fig_ax
|
|
942
|
-
axes = plot_ppc(models.model_1, kind=kind, ax=ax)
|
|
943
|
-
assert np.asarray(axes).item(0) is ax
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
@pytest.mark.skipif(
|
|
947
|
-
not animation.writers.is_available("ffmpeg"),
|
|
948
|
-
reason="matplotlib animations within ArviZ require ffmpeg",
|
|
949
|
-
)
|
|
950
|
-
def test_plot_ppc_bad_ax(models, fig_ax):
|
|
951
|
-
_, ax = fig_ax
|
|
952
|
-
_, ax2 = plt.subplots(1, 2)
|
|
953
|
-
with pytest.raises(ValueError, match="same figure"):
|
|
954
|
-
plot_ppc(
|
|
955
|
-
models.model_1, ax=[ax, *ax2], flatten=[], coords={"obs_dim": [1, 2, 3]}, animated=True
|
|
956
|
-
)
|
|
957
|
-
with pytest.raises(ValueError, match="2 axes"):
|
|
958
|
-
plot_ppc(models.model_1, ax=ax2)
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
def test_plot_legend(models):
|
|
962
|
-
axes = plot_ppc(models.model_1)
|
|
963
|
-
legend_texts = axes.get_legend().get_texts()
|
|
964
|
-
result = [i.get_text() for i in legend_texts]
|
|
965
|
-
expected = ["Posterior predictive", "Observed", "Posterior predictive mean"]
|
|
966
|
-
assert result == expected
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
@pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
|
|
970
|
-
@pytest.mark.parametrize("side", ["both", "left", "right"])
|
|
971
|
-
@pytest.mark.parametrize("rug", [True])
|
|
972
|
-
def test_plot_violin(models, var_names, side, rug):
|
|
973
|
-
axes = plot_violin(models.model_1, var_names=var_names, side=side, rug=rug)
|
|
974
|
-
assert axes.shape
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
def test_plot_violin_ax(models):
|
|
978
|
-
_, ax = plt.subplots(1)
|
|
979
|
-
axes = plot_violin(models.model_1, var_names="mu", ax=ax)
|
|
980
|
-
assert axes.shape
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
def test_plot_violin_layout(models):
|
|
984
|
-
axes = plot_violin(models.model_1, var_names=["mu", "tau"], sharey=False)
|
|
985
|
-
assert axes.shape
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
def test_plot_violin_discrete(discrete_model):
|
|
989
|
-
axes = plot_violin(discrete_model)
|
|
990
|
-
assert axes.shape
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
@pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
|
|
994
|
-
def test_plot_violin_combinedims(models, var_names):
|
|
995
|
-
axes = plot_violin(models.model_1, var_names=var_names, combine_dims={"school"})
|
|
996
|
-
assert axes.shape
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
def test_plot_violin_ax_combinedims(models):
|
|
1000
|
-
_, ax = plt.subplots(1)
|
|
1001
|
-
axes = plot_violin(models.model_1, var_names="mu", combine_dims={"school"}, ax=ax)
|
|
1002
|
-
assert axes.shape
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
def test_plot_violin_layout_combinedims(models):
|
|
1006
|
-
axes = plot_violin(
|
|
1007
|
-
models.model_1, var_names=["mu", "tau"], combine_dims={"school"}, sharey=False
|
|
1008
|
-
)
|
|
1009
|
-
assert axes.shape
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
def test_plot_violin_discrete_combinedims(discrete_model):
|
|
1013
|
-
axes = plot_violin(discrete_model, combine_dims={"school"})
|
|
1014
|
-
assert axes.shape
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
def test_plot_autocorr_short_chain():
|
|
1018
|
-
"""Check that logic for small chain defaulting doesn't cause exception"""
|
|
1019
|
-
chain = np.arange(10)
|
|
1020
|
-
axes = plot_autocorr(chain)
|
|
1021
|
-
assert axes
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
def test_plot_autocorr_uncombined(models):
|
|
1025
|
-
axes = plot_autocorr(models.model_1, combined=False)
|
|
1026
|
-
assert axes.size
|
|
1027
|
-
max_subplots = (
|
|
1028
|
-
np.inf if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"]
|
|
1029
|
-
)
|
|
1030
|
-
assert axes.size == min(72, max_subplots)
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
def test_plot_autocorr_combined(models):
|
|
1034
|
-
axes = plot_autocorr(models.model_1, combined=True)
|
|
1035
|
-
assert axes.size == 18
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
@pytest.mark.parametrize("var_names", (None, "mu", ["mu"], ["mu", "tau"]))
|
|
1039
|
-
def test_plot_autocorr_var_names(models, var_names):
|
|
1040
|
-
axes = plot_autocorr(models.model_1, var_names=var_names, combined=True)
|
|
1041
|
-
if (isinstance(var_names, list) and len(var_names) == 1) or isinstance(var_names, str):
|
|
1042
|
-
assert not isinstance(axes, np.ndarray)
|
|
1043
|
-
else:
|
|
1044
|
-
assert axes.shape
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
@pytest.mark.parametrize(
|
|
1048
|
-
"kwargs",
|
|
1049
|
-
[
|
|
1050
|
-
{},
|
|
1051
|
-
{"var_names": "mu"},
|
|
1052
|
-
{"var_names": ("mu", "tau"), "coords": {"school": [0, 1]}},
|
|
1053
|
-
{"var_names": "mu", "ref_line": True},
|
|
1054
|
-
{
|
|
1055
|
-
"var_names": "mu",
|
|
1056
|
-
"ref_line_kwargs": {"lw": 2, "color": "C2"},
|
|
1057
|
-
"bar_kwargs": {"width": 0.7},
|
|
1058
|
-
},
|
|
1059
|
-
{"var_names": "mu", "ref_line": False},
|
|
1060
|
-
{"var_names": "mu", "kind": "vlines"},
|
|
1061
|
-
{
|
|
1062
|
-
"var_names": "mu",
|
|
1063
|
-
"kind": "vlines",
|
|
1064
|
-
"vlines_kwargs": {"lw": 0},
|
|
1065
|
-
"marker_vlines_kwargs": {"lw": 3},
|
|
1066
|
-
},
|
|
1067
|
-
],
|
|
1068
|
-
)
|
|
1069
|
-
def test_plot_rank(models, kwargs):
|
|
1070
|
-
axes = plot_rank(models.model_1, **kwargs)
|
|
1071
|
-
var_names = kwargs.get("var_names", [])
|
|
1072
|
-
if isinstance(var_names, str):
|
|
1073
|
-
assert not isinstance(axes, np.ndarray)
|
|
1074
|
-
else:
|
|
1075
|
-
assert axes.shape
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
@pytest.mark.parametrize(
|
|
1079
|
-
"kwargs",
|
|
1080
|
-
[
|
|
1081
|
-
{},
|
|
1082
|
-
{"var_names": "mu"},
|
|
1083
|
-
{"var_names": ("mu", "tau")},
|
|
1084
|
-
{"rope": (-2, 2)},
|
|
1085
|
-
{"rope": {"mu": [{"rope": (-2, 2)}], "theta": [{"school": "Choate", "rope": (2, 4)}]}},
|
|
1086
|
-
{"point_estimate": "mode"},
|
|
1087
|
-
{"point_estimate": "median"},
|
|
1088
|
-
{"hdi_prob": "hide", "label": ""},
|
|
1089
|
-
{"point_estimate": None},
|
|
1090
|
-
{"ref_val": 0},
|
|
1091
|
-
{"ref_val": None},
|
|
1092
|
-
{"ref_val": {"mu": [{"ref_val": 1}]}},
|
|
1093
|
-
{"bins": None, "kind": "hist"},
|
|
1094
|
-
{
|
|
1095
|
-
"ref_val": {
|
|
1096
|
-
"theta": [
|
|
1097
|
-
# {"school": ["Choate", "Deerfield"], "ref_val": -1}, this is not working
|
|
1098
|
-
{"school": "Lawrenceville", "ref_val": 3}
|
|
1099
|
-
]
|
|
1100
|
-
}
|
|
1101
|
-
},
|
|
1102
|
-
],
|
|
1103
|
-
)
|
|
1104
|
-
def test_plot_posterior(models, kwargs):
|
|
1105
|
-
axes = plot_posterior(models.model_1, **kwargs)
|
|
1106
|
-
if isinstance(kwargs.get("var_names"), str):
|
|
1107
|
-
assert not isinstance(axes, np.ndarray)
|
|
1108
|
-
else:
|
|
1109
|
-
assert axes.shape
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
def test_plot_posterior_boolean():
|
|
1113
|
-
data = np.random.choice(a=[False, True], size=(4, 100))
|
|
1114
|
-
axes = plot_posterior(data)
|
|
1115
|
-
assert axes
|
|
1116
|
-
plt.draw()
|
|
1117
|
-
labels = [label.get_text() for label in axes.get_xticklabels()]
|
|
1118
|
-
assert all(item in labels for item in ("True", "False"))
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
@pytest.mark.parametrize("kwargs", [{}, {"point_estimate": "mode"}, {"bins": None, "kind": "hist"}])
|
|
1122
|
-
def test_plot_posterior_discrete(discrete_model, kwargs):
|
|
1123
|
-
axes = plot_posterior(discrete_model, **kwargs)
|
|
1124
|
-
assert axes.shape
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
def test_plot_posterior_bad_type():
|
|
1128
|
-
with pytest.raises(TypeError):
|
|
1129
|
-
plot_posterior(np.array(["a", "b", "c"]))
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
def test_plot_posterior_bad(models):
|
|
1133
|
-
with pytest.raises(ValueError):
|
|
1134
|
-
plot_posterior(models.model_1, rope="bad_value")
|
|
1135
|
-
with pytest.raises(ValueError):
|
|
1136
|
-
plot_posterior(models.model_1, ref_val="bad_value")
|
|
1137
|
-
with pytest.raises(ValueError):
|
|
1138
|
-
plot_posterior(models.model_1, point_estimate="bad_value")
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
@pytest.mark.parametrize("point_estimate", ("mode", "mean", "median"))
|
|
1142
|
-
def test_plot_posterior_point_estimates(models, point_estimate):
|
|
1143
|
-
axes = plot_posterior(models.model_1, var_names=("mu", "tau"), point_estimate=point_estimate)
|
|
1144
|
-
assert axes.size == 2
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
def test_plot_posterior_skipna():
|
|
1148
|
-
sample = np.linspace(0, 1)
|
|
1149
|
-
sample[:10] = np.nan
|
|
1150
|
-
plot_posterior({"a": sample}, skipna=True)
|
|
1151
|
-
with pytest.raises(ValueError):
|
|
1152
|
-
plot_posterior({"a": sample}, skipna=False)
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
@pytest.mark.parametrize("kwargs", [{"var_names": ["mu", "theta"]}])
|
|
1156
|
-
def test_plot_posterior_combinedims(models, kwargs):
|
|
1157
|
-
axes = plot_posterior(models.model_1, combine_dims={"school"}, **kwargs)
|
|
1158
|
-
if isinstance(kwargs.get("var_names"), str):
|
|
1159
|
-
assert not isinstance(axes, np.ndarray)
|
|
1160
|
-
else:
|
|
1161
|
-
assert axes.shape
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
@pytest.mark.parametrize("kwargs", [{}, {"point_estimate": "mode"}, {"bins": None, "kind": "hist"}])
|
|
1165
|
-
def test_plot_posterior_discrete_combinedims(discrete_multidim_model, kwargs):
|
|
1166
|
-
axes = plot_posterior(discrete_multidim_model, combine_dims={"school"}, **kwargs)
|
|
1167
|
-
assert axes.size == 2
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
@pytest.mark.parametrize("point_estimate", ("mode", "mean", "median"))
|
|
1171
|
-
def test_plot_posterior_point_estimates_combinedims(models, point_estimate):
|
|
1172
|
-
axes = plot_posterior(
|
|
1173
|
-
models.model_1,
|
|
1174
|
-
var_names=("mu", "tau"),
|
|
1175
|
-
combine_dims={"school"},
|
|
1176
|
-
point_estimate=point_estimate,
|
|
1177
|
-
)
|
|
1178
|
-
assert axes.size == 2
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
def test_plot_posterior_skipna_combinedims():
|
|
1182
|
-
idata = load_arviz_data("centered_eight")
|
|
1183
|
-
idata.posterior["theta"].loc[dict(school="Deerfield")] = np.nan
|
|
1184
|
-
with pytest.raises(ValueError):
|
|
1185
|
-
plot_posterior(idata, var_names="theta", combine_dims={"school"}, skipna=False)
|
|
1186
|
-
ax = plot_posterior(idata, var_names="theta", combine_dims={"school"}, skipna=True)
|
|
1187
|
-
assert not isinstance(ax, np.ndarray)
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
@pytest.mark.parametrize(
|
|
1191
|
-
"kwargs", [{"insample_dev": True}, {"plot_standard_error": False}, {"plot_ic_diff": False}]
|
|
1192
|
-
)
|
|
1193
|
-
def test_plot_compare(models, kwargs):
|
|
1194
|
-
model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2})
|
|
1195
|
-
|
|
1196
|
-
axes = plot_compare(model_compare, **kwargs)
|
|
1197
|
-
assert axes
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
def test_plot_compare_no_ic(models):
|
|
1201
|
-
"""Check exception is raised if model_compare doesn't contain a valid information criterion"""
|
|
1202
|
-
model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2})
|
|
1203
|
-
|
|
1204
|
-
# Drop column needed for plotting
|
|
1205
|
-
model_compare = model_compare.drop("elpd_loo", axis=1)
|
|
1206
|
-
with pytest.raises(ValueError) as err:
|
|
1207
|
-
plot_compare(model_compare)
|
|
1208
|
-
|
|
1209
|
-
assert "comp_df must contain one of the following" in str(err.value)
|
|
1210
|
-
assert "['elpd_loo', 'elpd_waic']" in str(err.value)
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
@pytest.mark.parametrize(
    "kwargs",
    [
        {"color": "0.5", "circular": True},
        {"hdi_data": True, "fill_kwargs": {"alpha": 0}},
        {"plot_kwargs": {"alpha": 0}},
        {"smooth_kwargs": {"window_length": 33, "polyorder": 5, "mode": "mirror"}},
        {"hdi_data": True, "smooth": False},
    ],
)
def test_plot_hdi(models, data, kwargs):
    hdi_data = kwargs.pop("hdi_data", None)
    if hdi_data:
        hdi_data = hdi(models.model_1.posterior["theta"])
        ax = plot_hdi(data["y"], hdi_data=hdi_data, **kwargs)
    else:
        ax = plot_hdi(data["y"], models.model_1.posterior["theta"], **kwargs)
    assert ax


def test_plot_hdi_warning():
    """Check that passing both y and hdi_data emits a warning."""
    x_data = np.random.normal(0, 1, 100)
    y_data = np.random.normal(2 + x_data * 0.5, 0.5, (1, 200, 100))
    hdi_data = hdi(y_data)
    with pytest.warns(UserWarning, match="Both y and hdi_data"):
        ax = plot_hdi(x_data, y=y_data, hdi_data=hdi_data)
    assert ax


def test_plot_hdi_missing_arg_error():
    """Check that omitting both y and hdi_data raises an error."""
    with pytest.raises(ValueError, match="One of {y, hdi_data"):
        plot_hdi(np.arange(20))


def test_plot_hdi_dataset_error(models):
    """Check that a multiple variable Dataset passed as hdi_data raises an error."""
    hdi_data = hdi(models.model_1)
    with pytest.raises(ValueError, match="Only single variable Dataset"):
        plot_hdi(np.arange(8), hdi_data=hdi_data)


def test_plot_hdi_string_error():
    """Check that x of type string raises an error."""
    x_data = ["a", "b", "c", "d"]
    y_data = np.random.normal(0, 5, (1, 200, len(x_data)))
    hdi_data = hdi(y_data)
    with pytest.raises(
        NotImplementedError,
        match=re.escape(
            (
                "The `arviz.plot_hdi()` function does not support categorical data. "
                "Consider using `arviz.plot_forest()`."
            )
        ),
    ):
        plot_hdi(x=x_data, y=y_data, hdi_data=hdi_data)


def test_plot_hdi_datetime_error():
    """Check that x as datetime raises an error."""
    x_data = np.arange(start="2022-01-01", stop="2022-03-01", dtype=np.datetime64)
    y_data = np.random.normal(0, 5, (1, 200, x_data.shape[0]))
    hdi_data = hdi(y_data)
    with pytest.raises(TypeError, match="Cannot deal with x as type datetime."):
        plot_hdi(x=x_data, y=y_data, hdi_data=hdi_data)


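# Illustrative sketch (added for clarity; not part of the original test module). It shows
# the two input modes exercised above: raw posterior draws, or an interval precomputed
# with `hdi`. `_demo_plot_hdi`, `x_grid` and `y_draws` are hypothetical names.
def _demo_plot_hdi():
    x_grid = np.linspace(0, 1, 50)
    y_draws = np.random.normal(x_grid, 0.1, size=(4, 250, 50))  # (chain, draw, x)
    plot_hdi(x_grid, y_draws)  # compute the band from the draws directly
    plot_hdi(x_grid, hdi_data=hdi(y_draws))  # or reuse a precomputed interval

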
@pytest.mark.parametrize("limits", [(-10.0, 10.0), (-5, 5), (None, None)])
def test_kde_scipy(limits):
    """
    Check that the summed density from our implementation
    matches the one computed by scipy.
    """
    data = np.random.normal(0, 1, 10000)
    grid, density_own = _kde(data, custom_lims=limits)
    density_sp = gaussian_kde(data).evaluate(grid)
    np.testing.assert_almost_equal(density_own.sum(), density_sp.sum(), 1)


@pytest.mark.parametrize("limits", [(-10.0, 10.0), (-5, 5), (None, None)])
def test_kde_cumulative(limits):
    """
    Check that the last value of the cumulative density is 1.
    """
    data = np.random.normal(0, 1, 1000)
    density = _kde(data, custom_lims=limits, cumulative=True)[1]
    np.testing.assert_almost_equal(round(density[-1], 3), 1)


def test_plot_ecdf_basic():
    data = np.random.randn(4, 1000)
    axes = plot_ecdf(data)
    assert axes is not None


def test_plot_ecdf_eval_points():
    """Check that BehaviourChangeWarning is raised if eval_points is not specified."""
    data = np.random.randn(4, 1000)
    eval_points = np.linspace(-3, 3, 100)
    with pytest.warns(BehaviourChangeWarning):
        axes = plot_ecdf(data)
    assert axes is not None
    with does_not_warn(BehaviourChangeWarning):
        axes = plot_ecdf(data, eval_points=eval_points)
    assert axes is not None


@pytest.mark.parametrize("confidence_bands", [True, "pointwise", "optimized", "simulated"])
@pytest.mark.parametrize("ndraws", [100, 10_000])
def test_plot_ecdf_confidence_bands(confidence_bands, ndraws):
    """Check that all confidence_bands values are correctly accepted."""
    data = np.random.randn(4, ndraws // 4)
    axes = plot_ecdf(data, confidence_bands=confidence_bands, cdf=norm(0, 1).cdf)
    assert axes is not None


def test_plot_ecdf_values2():
    data = np.random.randn(4, 1000)
    data2 = np.random.randn(4, 1000)
    axes = plot_ecdf(data, data2)
    assert axes is not None


def test_plot_ecdf_cdf():
    data = np.random.randn(4, 1000)
    cdf = norm(0, 1).cdf
    axes = plot_ecdf(data, cdf=cdf)
    assert axes is not None


def test_plot_ecdf_error():
    """Check that all error conditions are correctly raised."""
    dist = norm(0, 1)
    data = dist.rvs(1000)

    # cdf not specified
    with pytest.raises(ValueError):
        plot_ecdf(data, confidence_bands=True)
    plot_ecdf(data, confidence_bands=True, cdf=dist.cdf)
    with pytest.raises(ValueError):
        plot_ecdf(data, difference=True)
    plot_ecdf(data, difference=True, cdf=dist.cdf)
    with pytest.raises(ValueError):
        plot_ecdf(data, pit=True)
    plot_ecdf(data, pit=True, cdf=dist.cdf)

    # contradictory confidence band types
    with pytest.raises(ValueError):
        plot_ecdf(data, cdf=dist.cdf, confidence_bands="simulated", pointwise=True)
    with pytest.raises(ValueError):
        plot_ecdf(data, cdf=dist.cdf, confidence_bands="optimized", pointwise=True)
    plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, pointwise=True)
    plot_ecdf(data, cdf=dist.cdf, confidence_bands="pointwise")

    # contradictory band probabilities
    with pytest.raises(ValueError):
        plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, ci_prob=0.9, fpr=0.1)
    plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, ci_prob=0.9)
    plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, fpr=0.1)

    # contradictory reference
    data2 = dist.rvs(200)
    with pytest.raises(ValueError):
        plot_ecdf(data, data2, cdf=dist.cdf, difference=True)
    plot_ecdf(data, data2, difference=True)
    plot_ecdf(data, cdf=dist.cdf, difference=True)


def test_plot_ecdf_deprecations():
    """Check that deprecation warnings are raised correctly."""
    dist = norm(0, 1)
    data = dist.rvs(1000)
    # base case, no deprecations
    with does_not_warn(FutureWarning):
        axes = plot_ecdf(data, cdf=dist.cdf, confidence_bands=True)
    assert axes is not None

    # values2 is deprecated
    data2 = dist.rvs(200)
    with pytest.warns(FutureWarning):
        axes = plot_ecdf(data, values2=data2, difference=True)

    # pit is deprecated
    with pytest.warns(FutureWarning):
        axes = plot_ecdf(data, cdf=dist.cdf, pit=True)
    assert axes is not None

    # fpr is deprecated
    with does_not_warn(FutureWarning):
        axes = plot_ecdf(data, cdf=dist.cdf, ci_prob=0.9)
    with pytest.warns(FutureWarning):
        axes = plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, fpr=0.1)
    assert axes is not None

    # pointwise is deprecated
    with does_not_warn(FutureWarning):
        axes = plot_ecdf(data, cdf=dist.cdf, confidence_bands="pointwise")
    with pytest.warns(FutureWarning):
        axes = plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, pointwise=True)


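# Illustrative sketch (added for clarity; not part of the original test module). It combines
# the options exercised above into the typical calibration check: an ECDF-difference plot
# against a reference cdf with confidence bands. `_demo_plot_ecdf` and `sample` are
# hypothetical names; `norm` is the scipy.stats distribution already used by the tests.
def _demo_plot_ecdf():
    sample = norm(0, 1).rvs(size=(4, 250))
    plot_ecdf(sample, cdf=norm(0, 1).cdf, confidence_bands=True, difference=True)

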
@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"ic": "loo"},
        {"xlabels": True, "scale": "log"},
        {"color": "obs_dim", "xlabels": True},
        {"color": "obs_dim", "legend": True},
        {"ic": "loo", "color": "blue", "coords": {"obs_dim": slice(2, 5)}},
        {"color": np.random.uniform(size=8), "threshold": 0.1},
        {"threshold": 2},
    ],
)
@pytest.mark.parametrize("add_model", [False, True])
@pytest.mark.parametrize("use_elpddata", [False, True])
def test_plot_elpd(models, add_model, use_elpddata, kwargs):
    model_dict = {"Model 1": models.model_1, "Model 2": models.model_2}
    if add_model:
        model_dict["Model 3"] = create_model(seed=12)

    if use_elpddata:
        ic = kwargs.get("ic", "waic")
        scale = kwargs.get("scale", "deviance")
        if ic == "waic":
            model_dict = {k: waic(v, scale=scale, pointwise=True) for k, v in model_dict.items()}
        else:
            model_dict = {k: loo(v, scale=scale, pointwise=True) for k, v in model_dict.items()}

    axes = plot_elpd(model_dict, **kwargs)
    assert np.all(axes)
    if add_model:
        assert axes.shape[0] == axes.shape[1]
        assert axes.shape[0] == len(model_dict) - 1


@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"ic": "loo"},
        {"xlabels": True, "scale": "log"},
        {"color": "dim1", "xlabels": True},
        {"color": "dim2", "legend": True},
        {"ic": "loo", "color": "blue", "coords": {"dim2": slice(2, 4)}},
        {"color": np.random.uniform(size=35), "threshold": 0.1},
    ],
)
@pytest.mark.parametrize("add_model", [False, True])
@pytest.mark.parametrize("use_elpddata", [False, True])
def test_plot_elpd_multidim(multidim_models, add_model, use_elpddata, kwargs):
    model_dict = {"Model 1": multidim_models.model_1, "Model 2": multidim_models.model_2}
    if add_model:
        model_dict["Model 3"] = create_multidimensional_model(seed=12)

    if use_elpddata:
        ic = kwargs.get("ic", "waic")
        scale = kwargs.get("scale", "deviance")
        if ic == "waic":
            model_dict = {k: waic(v, scale=scale, pointwise=True) for k, v in model_dict.items()}
        else:
            model_dict = {k: loo(v, scale=scale, pointwise=True) for k, v in model_dict.items()}

    axes = plot_elpd(model_dict, **kwargs)
    assert np.all(axes)
    if add_model:
        assert axes.shape[0] == axes.shape[1]
        assert axes.shape[0] == len(model_dict) - 1


def test_plot_elpd_bad_ic(models):
    model_dict = {
        "Model 1": waic(models.model_1, pointwise=True),
        "Model 2": loo(models.model_2, pointwise=True),
    }
    with pytest.raises(ValueError):
        plot_elpd(model_dict, ic="bad_ic")


def test_plot_elpd_ic_error(models):
    model_dict = {
        "Model 1": waic(models.model_1, pointwise=True),
        "Model 2": loo(models.model_2, pointwise=True),
    }
    with pytest.raises(ValueError):
        plot_elpd(model_dict)


def test_plot_elpd_scale_error(models):
    model_dict = {
        "Model 1": waic(models.model_1, pointwise=True, scale="log"),
        "Model 2": waic(models.model_2, pointwise=True, scale="deviance"),
    }
    with pytest.raises(ValueError):
        plot_elpd(model_dict)


def test_plot_elpd_one_model(models):
    model_dict = {"Model 1": models.model_1}
    with pytest.raises(Exception):
        plot_elpd(model_dict)


@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"xlabels": True},
        {"color": "obs_dim", "xlabels": True, "show_bins": True, "bin_format": "{0}"},
        {"color": "obs_dim", "legend": True, "hover_label": True},
        {"color": "blue", "coords": {"obs_dim": slice(2, 4)}},
        {"color": np.random.uniform(size=8), "show_bins": True},
        {
            "color": np.random.uniform(size=(8, 3)),
            "show_bins": True,
            "show_hlines": True,
            "threshold": 1,
        },
    ],
)
@pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
def test_plot_khat(models, input_type, kwargs):
    khats_data = loo(models.model_1, pointwise=True)

    if input_type == "data_array":
        khats_data = khats_data.pareto_k
    elif input_type == "array":
        khats_data = khats_data.pareto_k.values
    if "color" in kwargs and isinstance(kwargs["color"], str) and kwargs["color"] == "obs_dim":
        kwargs["color"] = None

    axes = plot_khat(khats_data, **kwargs)
    assert axes


@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"xlabels": True},
        {"color": "dim1", "xlabels": True, "show_bins": True, "bin_format": "{0}"},
        {"color": "dim2", "legend": True, "hover_label": True},
        {"color": "blue", "coords": {"dim2": slice(2, 4)}},
        {"color": np.random.uniform(size=35), "show_bins": True},
        {
            "color": np.random.uniform(size=(35, 3)),
            "show_bins": True,
            "show_hlines": True,
            "threshold": 1,
        },
    ],
)
@pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
def test_plot_khat_multidim(multidim_models, input_type, kwargs):
    khats_data = loo(multidim_models.model_1, pointwise=True)

    if input_type == "data_array":
        khats_data = khats_data.pareto_k
    elif input_type == "array":
        khats_data = khats_data.pareto_k.values
    if (
        "color" in kwargs
        and isinstance(kwargs["color"], str)
        and kwargs["color"] in ("dim1", "dim2")
    ):
        kwargs["color"] = None

    axes = plot_khat(khats_data, **kwargs)
    assert axes


def test_plot_khat_threshold():
    khats = np.array([0, 0, 0.6, 0.6, 0.8, 0.9, 0.9, 2, 3, 4, 1.5])
    axes = plot_khat(khats, threshold=1)
    assert axes


def test_plot_khat_bad_input(models):
    with pytest.raises(ValueError):
        plot_khat(models.model_1.sample_stats)


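# Illustrative sketch (added for clarity; not part of the original test module). It shows the
# workflow behind the tests above: run PSIS-LOO pointwise and inspect the Pareto k-hat
# diagnostics. `_demo_plot_khat` and `loo_result` are hypothetical names.
def _demo_plot_khat():
    idata = load_arviz_data("centered_eight")
    loo_result = loo(idata, pointwise=True)  # ELPDData with a pointwise pareto_k variable
    plot_khat(loo_result.pareto_k, show_bins=True)

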
@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"var_names": ["theta"], "relative": True, "color": "r"},
        {"coords": {"school": slice(4)}, "n_points": 10},
        {"min_ess": 600, "hline_kwargs": {"color": "r"}},
    ],
)
@pytest.mark.parametrize("kind", ["local", "quantile", "evolution"])
def test_plot_ess(models, kind, kwargs):
    """Test plot_ess arguments common to all kinds of plots."""
    idata = models.model_1
    ax = plot_ess(idata, kind=kind, **kwargs)
    assert np.all(ax)


@pytest.mark.parametrize(
    "kwargs",
    [
        {"rug": True},
        {"rug": True, "rug_kind": "max_depth", "rug_kwargs": {"color": "c"}},
        {"extra_methods": True},
        {"extra_methods": True, "extra_kwargs": {"ls": ":"}, "text_kwargs": {"x": 0, "ha": "left"}},
        {"extra_methods": True, "rug": True},
    ],
)
@pytest.mark.parametrize("kind", ["local", "quantile"])
def test_plot_ess_local_quantile(models, kind, kwargs):
    """Test arguments specific to the local and quantile kinds of plot_ess."""
    idata = models.model_1
    ax = plot_ess(idata, kind=kind, **kwargs)
    assert np.all(ax)


def test_plot_ess_evolution(models):
    """Test arguments specific to the evolution kind of plot_ess."""
    idata = models.model_1
    ax = plot_ess(idata, kind="evolution", extra_kwargs={"linestyle": "--"}, color="b")
    assert np.all(ax)


def test_plot_ess_bad_kind(models):
    """Test error when plot_ess receives an invalid kind."""
    idata = models.model_1
    with pytest.raises(ValueError, match="Invalid kind"):
        plot_ess(idata, kind="bad kind")


@pytest.mark.parametrize("dim", ["chain", "draw"])
def test_plot_ess_bad_coords(models, dim):
    """Test error when chain or draw are used as coords to select a data subset."""
    idata = models.model_1
    with pytest.raises(ValueError, match="invalid coordinates"):
        plot_ess(idata, coords={dim: slice(3)})


def test_plot_ess_no_sample_stats(models):
    """Test error when rug=True but the sample_stats group is not present."""
    idata = models.model_1
    with pytest.raises(ValueError, match="must contain sample_stats"):
        plot_ess(idata.posterior, rug=True)


def test_plot_ess_no_divergences(models):
    """Test error when rug=True but the variable defined by rug_kind is missing."""
    idata = deepcopy(models.model_1)
    idata.sample_stats = idata.sample_stats.rename({"diverging": "diverging_missing"})
    with pytest.raises(ValueError, match="not contain diverging"):
        plot_ess(idata, rug=True)


@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"n_unif": 50, "legend": False},
        {"use_hdi": True, "color": "gray"},
        {"use_hdi": True, "hdi_prob": 0.68},
        {"use_hdi": True, "hdi_kwargs": {"fill": 0.1}},
        {"ecdf": True},
        {"ecdf": True, "ecdf_fill": False, "plot_unif_kwargs": {"ls": "--"}},
        {"ecdf": True, "hdi_prob": 0.97, "fill_kwargs": {"hatch": "/"}},
    ],
)
def test_plot_loo_pit(models, kwargs):
    axes = plot_loo_pit(idata=models.model_1, y="y", **kwargs)
    assert axes


def test_plot_loo_pit_incompatible_args(models):
    """Test error when both ecdf and use_hdi are True."""
    with pytest.raises(ValueError, match="incompatible"):
        plot_loo_pit(idata=models.model_1, y="y", ecdf=True, use_hdi=True)


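# Illustrative sketch (added for clarity; not part of the original test module). It shows the
# LOO-PIT calibration check exercised above, drawn as an ECDF; `_demo_plot_loo_pit` is a
# hypothetical name and it assumes the example dataset ships a posterior_predictive group.
def _demo_plot_loo_pit():
    idata = load_arviz_data("centered_eight")
    plot_loo_pit(idata=idata, y="y", ecdf=True)

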
@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"var_names": ["theta"], "color": "r"},
        {"rug": True, "rug_kwargs": {"color": "r"}},
        {"errorbar": True, "rug": True, "rug_kind": "max_depth"},
        {"errorbar": True, "coords": {"school": slice(4)}, "n_points": 10},
        {"extra_methods": True, "rug": True},
        {"extra_methods": True, "extra_kwargs": {"ls": ":"}, "text_kwargs": {"x": 0, "ha": "left"}},
    ],
)
def test_plot_mcse(models, kwargs):
    idata = models.model_1
    ax = plot_mcse(idata, **kwargs)
    assert np.all(ax)


@pytest.mark.parametrize("dim", ["chain", "draw"])
def test_plot_mcse_bad_coords(models, dim):
    """Test error when chain or draw are used as coords to select a data subset."""
    idata = models.model_1
    with pytest.raises(ValueError, match="invalid coordinates"):
        plot_mcse(idata, coords={dim: slice(3)})


def test_plot_mcse_no_sample_stats(models):
    """Test error when rug=True but the sample_stats group is not present."""
    idata = models.model_1
    with pytest.raises(ValueError, match="must contain sample_stats"):
        plot_mcse(idata.posterior, rug=True)


def test_plot_mcse_no_divergences(models):
    """Test error when rug=True but the variable defined by rug_kind is missing."""
    idata = deepcopy(models.model_1)
    idata.sample_stats = idata.sample_stats.rename({"diverging": "diverging_missing"})
    with pytest.raises(ValueError, match="not contain diverging"):
        plot_mcse(idata, rug=True)


@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"var_names": ["theta"]},
        {"var_names": ["theta"], "coords": {"school": [0, 1]}},
        {"var_names": ["eta"], "posterior_kwargs": {"rug": True, "rug_kwargs": {"color": "r"}}},
        {"var_names": ["mu"], "prior_kwargs": {"fill_kwargs": {"alpha": 0.5}}},
        {
            "var_names": ["tau"],
            "prior_kwargs": {"plot_kwargs": {"color": "r"}},
            "posterior_kwargs": {"plot_kwargs": {"color": "b"}},
        },
        {"var_names": ["y"], "kind": "observed"},
    ],
)
def test_plot_dist_comparison(models, kwargs):
    idata = models.model_1
    ax = plot_dist_comparison(idata, **kwargs)
    assert np.all(ax)


def test_plot_dist_comparison_different_vars():
    data = from_dict(
        posterior={
            "x": np.random.randn(4, 100, 30),
        },
        prior={"x_hat": np.random.randn(4, 100, 30)},
    )
    with pytest.raises(KeyError):
        plot_dist_comparison(data, var_names="x")
    ax = plot_dist_comparison(data, var_names=[["x_hat"], ["x"]])
    assert np.all(ax)


def test_plot_dist_comparison_combinedims(models):
    idata = models.model_1
    ax = plot_dist_comparison(idata, combine_dims={"school"})
    assert np.all(ax)


def test_plot_dist_comparison_different_vars_combinedims():
    data = from_dict(
        posterior={
            "x": np.random.randn(4, 100, 30),
        },
        prior={"x_hat": np.random.randn(4, 100, 30)},
        dims={"x": ["3rd_dim"], "x_hat": ["3rd_dim"]},
    )
    with pytest.raises(KeyError):
        plot_dist_comparison(data, var_names="x", combine_dims={"3rd_dim"})
    ax = plot_dist_comparison(data, var_names=[["x_hat"], ["x"]], combine_dims={"3rd_dim"})
    assert np.all(ax)
    assert ax.size == 3


@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"reference": "analytical"},
        {"kind": "p_value"},
        {"kind": "t_stat", "t_stat": "std"},
        {"kind": "t_stat", "t_stat": 0.5, "bpv": True},
    ],
)
def test_plot_bpv(models, kwargs):
    axes = plot_bpv(models.model_1, **kwargs)
    assert not isinstance(axes, np.ndarray)


def test_plot_bpv_discrete():
    fake_obs = {"a": np.random.poisson(2.5, 100)}
    fake_pp = {"a": np.random.poisson(2.5, (1, 10, 100))}
    fake_model = from_dict(posterior_predictive=fake_pp, observed_data=fake_obs)
    axes = plot_bpv(fake_model)
    assert not isinstance(axes, np.ndarray)


@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {
            "binwidth": 0.5,
            "stackratio": 2,
            "nquantiles": 20,
        },
        {"point_interval": True},
        {
            "point_interval": True,
            "dotsize": 1.2,
            "point_estimate": "median",
            "plot_kwargs": {"color": "grey"},
        },
        {
            "point_interval": True,
            "plot_kwargs": {"color": "grey"},
            "nquantiles": 100,
            "hdi_prob": 0.95,
            "intervalcolor": "green",
        },
        {
            "point_interval": True,
            "plot_kwargs": {"color": "grey"},
            "quartiles": False,
            "linewidth": 2,
        },
    ],
)
def test_plot_dot(continuous_model, kwargs):
    data = continuous_model["x"]
    ax = plot_dot(data, **kwargs)
    assert ax


@pytest.mark.parametrize(
    "kwargs",
    [
        {"rotated": True},
        {
            "point_interval": True,
            "rotated": True,
            "dotcolor": "grey",
            "binwidth": 0.5,
        },
        {
            "rotated": True,
            "point_interval": True,
            "plot_kwargs": {"color": "grey"},
            "nquantiles": 100,
            "dotsize": 0.8,
            "hdi_prob": 0.95,
            "intervalcolor": "green",
        },
    ],
)
def test_plot_dot_rotated(continuous_model, kwargs):
    data = continuous_model["x"]
    ax = plot_dot(data, **kwargs)
    assert ax


@pytest.mark.parametrize(
    "kwargs",
    [
        {
            "point_estimate": "mean",
            "hdi_prob": 0.95,
            "quartiles": False,
            "linewidth": 2,
            "markersize": 2,
            "markercolor": "red",
            "marker": "o",
            "rotated": False,
            "intervalcolor": "green",
        },
    ],
)
def test_plot_point_interval(continuous_model, kwargs):
    _, ax = plt.subplots()
    data = continuous_model["x"]
    values = np.sort(data)
    ax = plot_point_interval(ax, values, **kwargs)
    assert ax


def test_wilkinson_algorithm(continuous_model):
    data = continuous_model["x"]
    values = np.sort(data)
    _, stack_counts = wilkinson_algorithm(values, 0.5)
    assert np.sum(stack_counts) == len(values)
    stack_locs, stack_counts = wilkinson_algorithm([0.0, 1.0, 1.8, 3.0, 5.0], 1.0)
    assert stack_locs == [0.0, 1.4, 3.0, 5.0]
    assert stack_counts == [1, 2, 1, 1]


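# Worked note (added for clarity) on the expected values in test_wilkinson_algorithm above:
# with binwidth 1.0 the sweep over [0.0, 1.0, 1.8, 3.0, 5.0] starts a stack at 0.0, then a
# stack at 1.0 that also captures 1.8 (since 1.8 - 1.0 < 1.0) and is therefore centred at
# their mean (1.0 + 1.8) / 2 = 1.4 with count 2, while 3.0 and 5.0 each start their own
# stack. Hence stack_locs == [0.0, 1.4, 3.0, 5.0] and stack_counts == [1, 2, 1, 1].

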
@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"y_hat": "bad_name"},
        {"x": "x1"},
        {"x": ("x1", "x2")},
        {
            "x": ("x1", "x2"),
            "y_kwargs": {"color": "blue", "marker": "^"},
            "y_hat_plot_kwargs": {"color": "cyan"},
        },
        {"x": ("x1", "x2"), "y_model_plot_kwargs": {"color": "red"}},
        {
            "x": ("x1", "x2"),
            "kind_pp": "hdi",
            "kind_model": "hdi",
            "y_model_fill_kwargs": {"color": "red"},
            "y_hat_fill_kwargs": {"color": "cyan"},
        },
    ],
)
def test_plot_lm_1d(models, kwargs):
    """Test functionality for 1D data."""
    idata = models.model_1
    if "constant_data" not in idata.groups():
        y = idata.observed_data["y"]
        x1data = y.coords[y.dims[0]]
        idata.add_groups({"constant_data": {"_": x1data}})
        idata.constant_data["x1"] = x1data
        idata.constant_data["x2"] = x1data

    axes = plot_lm(idata=idata, y="y", y_model="eta", xjitter=True, **kwargs)
    assert np.all(axes)


def test_plot_lm_multidim(multidim_models):
    """Test functionality for multidimensional data."""
    idata = multidim_models.model_1
    axes = plot_lm(
        idata=idata,
        x=idata.observed_data["y"].coords["dim1"].values,
        y="y",
        xjitter=True,
        plot_dim="dim1",
        show=False,
        figsize=(4, 16),
    )
    assert np.all(axes)


@pytest.mark.parametrize(
    "val_err_kwargs",
    [
        {},
        {"kind_pp": "bad_kind"},
        {"kind_model": "bad_kind"},
    ],
)
def test_plot_lm_valueerror(multidim_models, val_err_kwargs):
    """Test errors when plot_dim is missing for multidim data or kind_... args are invalid."""
    idata2 = multidim_models.model_1
    with pytest.raises(ValueError):
        plot_lm(idata=idata2, y="y", **val_err_kwargs)


@pytest.mark.parametrize(
    "warn_kwargs",
    [
        {"y_hat": "bad_name"},
        {"y_model": "bad_name"},
    ],
)
def test_plot_lm_warning(models, warn_kwargs):
    """Test warning when needed groups or variables are missing from idata."""
    idata1 = models.model_1
    with pytest.warns(UserWarning):
        plot_lm(
            idata=from_dict(observed_data={"y": idata1.observed_data["y"].values}),
            y="y",
            **warn_kwargs,
        )
    with pytest.warns(UserWarning):
        plot_lm(idata=idata1, y="y", **warn_kwargs)


def test_plot_lm_typeerror(models):
    """Test error when an invalid value is passed to num_samples."""
    idata1 = models.model_1
    with pytest.raises(TypeError):
        plot_lm(idata=idata1, y="y", num_samples=-1)


def test_plot_lm_list():
    """Test the plots when input data is a list or ndarray."""
    y = [1, 2, 3, 4, 5]
    assert plot_lm(y=y, x=np.arange(len(y)), show=False)


-
@pytest.mark.parametrize(
|
|
2011
|
-
"kwargs",
|
|
2012
|
-
[
|
|
2013
|
-
{},
|
|
2014
|
-
{"y_hat": "bad_name"},
|
|
2015
|
-
{"x": "x"},
|
|
2016
|
-
{"x": ("x", "x")},
|
|
2017
|
-
{"y_holdout": "z"},
|
|
2018
|
-
{"y_holdout": "z", "x_holdout": "x_pred"},
|
|
2019
|
-
{"x": ("x", "x"), "y_holdout": "z", "x_holdout": ("x_pred", "x_pred")},
|
|
2020
|
-
{"y_forecasts": "z"},
|
|
2021
|
-
{"y_holdout": "z", "y_forecasts": "bad_name"},
|
|
2022
|
-
],
|
|
2023
|
-
)
|
|
2024
|
-
def test_plot_ts(kwargs):
|
|
2025
|
-
"""Test timeseries plots basic functionality."""
|
|
2026
|
-
nchains = 4
|
|
2027
|
-
ndraws = 500
|
|
2028
|
-
obs_data = {
|
|
2029
|
-
"y": 2 * np.arange(1, 9) + 3,
|
|
2030
|
-
"z": 2 * np.arange(8, 12) + 3,
|
|
2031
|
-
}
|
|
2032
|
-
|
|
2033
|
-
posterior_predictive = {
|
|
2034
|
-
"y": np.random.normal(
|
|
2035
|
-
(obs_data["y"] * 1.2) - 3, size=(nchains, ndraws, len(obs_data["y"]))
|
|
2036
|
-
),
|
|
2037
|
-
"z": np.random.normal(
|
|
2038
|
-
(obs_data["z"] * 1.2) - 3, size=(nchains, ndraws, len(obs_data["z"]))
|
|
2039
|
-
),
|
|
2040
|
-
}
|
|
2041
|
-
|
|
2042
|
-
const_data = {"x": np.arange(1, 9), "x_pred": np.arange(8, 12)}
|
|
2043
|
-
|
|
2044
|
-
idata = from_dict(
|
|
2045
|
-
observed_data=obs_data,
|
|
2046
|
-
posterior_predictive=posterior_predictive,
|
|
2047
|
-
constant_data=const_data,
|
|
2048
|
-
coords={"obs_dim": np.arange(1, 9), "pred_dim": np.arange(8, 12)},
|
|
2049
|
-
dims={"y": ["obs_dim"], "z": ["pred_dim"]},
|
|
2050
|
-
)
|
|
2051
|
-
|
|
2052
|
-
ax = plot_ts(idata=idata, y="y", **kwargs)
|
|
2053
|
-
assert np.all(ax)
|
|
2054
|
-
|
|
2055
|
-
|
|
2056
|
-
@pytest.mark.parametrize(
|
|
2057
|
-
"kwargs",
|
|
2058
|
-
[
|
|
2059
|
-
{},
|
|
2060
|
-
{
|
|
2061
|
-
"y_holdout": "z",
|
|
2062
|
-
"holdout_dim": "holdout_dim1",
|
|
2063
|
-
"x": ("x", "x"),
|
|
2064
|
-
"x_holdout": ("x_pred", "x_pred"),
|
|
2065
|
-
},
|
|
2066
|
-
{"y_forecasts": "z", "holdout_dim": "holdout_dim1"},
|
|
2067
|
-
],
|
|
2068
|
-
)
|
|
2069
|
-
def test_plot_ts_multidim(kwargs):
|
|
2070
|
-
"""Test timeseries plots multidim functionality."""
|
|
2071
|
-
nchains = 4
|
|
2072
|
-
ndraws = 500
|
|
2073
|
-
ndim1 = 5
|
|
2074
|
-
ndim2 = 7
|
|
2075
|
-
data = {
|
|
2076
|
-
"y": np.random.normal(size=(ndim1, ndim2)),
|
|
2077
|
-
"z": np.random.normal(size=(ndim1, ndim2)),
|
|
2078
|
-
}
|
|
2079
|
-
|
|
2080
|
-
posterior_predictive = {
|
|
2081
|
-
"y": np.random.randn(nchains, ndraws, ndim1, ndim2),
|
|
2082
|
-
"z": np.random.randn(nchains, ndraws, ndim1, ndim2),
|
|
2083
|
-
}
|
|
2084
|
-
|
|
2085
|
-
const_data = {"x": np.arange(1, 6), "x_pred": np.arange(5, 10)}
|
|
2086
|
-
|
|
2087
|
-
idata = from_dict(
|
|
2088
|
-
observed_data=data,
|
|
2089
|
-
posterior_predictive=posterior_predictive,
|
|
2090
|
-
constant_data=const_data,
|
|
2091
|
-
dims={
|
|
2092
|
-
"y": ["dim1", "dim2"],
|
|
2093
|
-
"z": ["holdout_dim1", "holdout_dim2"],
|
|
2094
|
-
},
|
|
2095
|
-
coords={
|
|
2096
|
-
"dim1": range(ndim1),
|
|
2097
|
-
"dim2": range(ndim2),
|
|
2098
|
-
"holdout_dim1": range(ndim1 - 1, ndim1 + 4),
|
|
2099
|
-
"holdout_dim2": range(ndim2 - 1, ndim2 + 6),
|
|
2100
|
-
},
|
|
2101
|
-
)
|
|
2102
|
-
|
|
2103
|
-
ax = plot_ts(idata=idata, y="y", plot_dim="dim1", **kwargs)
|
|
2104
|
-
assert np.all(ax)
|
|
2105
|
-
|
|
2106
|
-
|
|
2107
|
-
@pytest.mark.parametrize("val_err_kwargs", [{}, {"plot_dim": "dim1", "y_holdout": "y"}])
|
|
2108
|
-
def test_plot_ts_valueerror(multidim_models, val_err_kwargs):
|
|
2109
|
-
"""Test error plot_dim gets no value for multidim data and wrong value in kind_... args."""
|
|
2110
|
-
idata2 = multidim_models.model_1
|
|
2111
|
-
with pytest.raises(ValueError):
|
|
2112
|
-
plot_ts(idata=idata2, y="y", **val_err_kwargs)
|
|
2113
|
-
|
|
2114
|
-
|
|
2115
|
-
def test_plot_bf():
    idata = from_dict(
        posterior={"a": np.random.normal(1, 0.5, 5000)}, prior={"a": np.random.normal(0, 1, 5000)}
    )
    _, bf_plot = plot_bf(idata, var_name="a", ref_val=0)
    assert bf_plot is not None


def generate_lm_1d_data():
    rng = np.random.default_rng()
    return from_dict(
        observed_data={"y": rng.normal(size=7)},
        posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
        posterior={"y_model": rng.normal(size=(4, 1000, 7))},
        dims={"y": ["dim1"]},
        coords={"dim1": range(7)},
    )


def generate_lm_2d_data():
    rng = np.random.default_rng()
    return from_dict(
        observed_data={"y": rng.normal(size=(5, 7))},
        posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
        posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
        dims={"y": ["dim1", "dim2"]},
        coords={"dim1": range(5), "dim2": range(7)},
    )


@pytest.mark.parametrize("data", ("1d", "2d"))
@pytest.mark.parametrize("kind", ("lines", "hdi"))
@pytest.mark.parametrize("use_y_model", (True, False))
def test_plot_lm(data, kind, use_y_model):
    if data == "1d":
        idata = generate_lm_1d_data()
    else:
        idata = generate_lm_2d_data()

    kwargs = {"idata": idata, "y": "y", "kind_model": kind}
    if data == "2d":
        kwargs["plot_dim"] = "dim1"
    if use_y_model:
        kwargs["y_model"] = "y_model"
        if kind == "lines":
            kwargs["num_samples"] = 50

    ax = plot_lm(**kwargs)
    assert ax is not None


@pytest.mark.parametrize(
    "coords, expected_vars",
    [
        ({"school": ["Choate"]}, ["theta"]),
        ({"school": ["Lawrenceville"]}, ["theta"]),
        ({}, ["theta"]),
    ],
)
def test_plot_autocorr_coords(coords, expected_vars):
    """Test plot_autocorr with coords kwarg."""
    idata = load_arviz_data("centered_eight")

    axes = plot_autocorr(idata, var_names=expected_vars, coords=coords, show=False)
    assert axes is not None


def test_plot_forest_with_transform():
    """Test if plot_forest runs successfully with a transform dictionary."""
    data = xr.Dataset(
        {
            "var1": (["chain", "draw"], np.array([[1, 2, 3], [4, 5, 6]])),
            "var2": (["chain", "draw"], np.array([[7, 8, 9], [10, 11, 12]])),
        },
        coords={"chain": [0, 1], "draw": [0, 1, 2]},
    )
    transform_dict = {
        "var1": lambda x: x + 1,
        "var2": lambda x: x * 2,
    }

    axes = plot_forest(data, transform=transform_dict, show=False)
    assert axes is not None