arviz 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +3 -2
- arviz/data/__init__.py +5 -2
- arviz/data/base.py +102 -11
- arviz/data/converters.py +5 -0
- arviz/data/datasets.py +1 -0
- arviz/data/example_data/data_remote.json +10 -3
- arviz/data/inference_data.py +26 -25
- arviz/data/io_cmdstan.py +1 -3
- arviz/data/io_datatree.py +1 -0
- arviz/data/io_dict.py +5 -3
- arviz/data/io_emcee.py +1 -0
- arviz/data/io_numpyro.py +1 -0
- arviz/data/io_pyjags.py +1 -0
- arviz/data/io_pyro.py +1 -0
- arviz/data/io_pystan.py +1 -2
- arviz/data/utils.py +1 -0
- arviz/plots/__init__.py +1 -0
- arviz/plots/autocorrplot.py +1 -0
- arviz/plots/backends/bokeh/autocorrplot.py +1 -0
- arviz/plots/backends/bokeh/bpvplot.py +8 -2
- arviz/plots/backends/bokeh/compareplot.py +8 -4
- arviz/plots/backends/bokeh/densityplot.py +1 -0
- arviz/plots/backends/bokeh/distplot.py +1 -0
- arviz/plots/backends/bokeh/dotplot.py +1 -0
- arviz/plots/backends/bokeh/ecdfplot.py +1 -0
- arviz/plots/backends/bokeh/elpdplot.py +1 -0
- arviz/plots/backends/bokeh/energyplot.py +1 -0
- arviz/plots/backends/bokeh/forestplot.py +2 -4
- arviz/plots/backends/bokeh/hdiplot.py +1 -0
- arviz/plots/backends/bokeh/kdeplot.py +3 -3
- arviz/plots/backends/bokeh/khatplot.py +1 -0
- arviz/plots/backends/bokeh/lmplot.py +1 -0
- arviz/plots/backends/bokeh/loopitplot.py +1 -0
- arviz/plots/backends/bokeh/mcseplot.py +1 -0
- arviz/plots/backends/bokeh/pairplot.py +1 -0
- arviz/plots/backends/bokeh/parallelplot.py +1 -0
- arviz/plots/backends/bokeh/posteriorplot.py +1 -0
- arviz/plots/backends/bokeh/ppcplot.py +1 -0
- arviz/plots/backends/bokeh/rankplot.py +1 -0
- arviz/plots/backends/bokeh/separationplot.py +1 -0
- arviz/plots/backends/bokeh/traceplot.py +1 -0
- arviz/plots/backends/bokeh/violinplot.py +1 -0
- arviz/plots/backends/matplotlib/autocorrplot.py +1 -0
- arviz/plots/backends/matplotlib/bpvplot.py +1 -0
- arviz/plots/backends/matplotlib/compareplot.py +2 -1
- arviz/plots/backends/matplotlib/densityplot.py +1 -0
- arviz/plots/backends/matplotlib/distcomparisonplot.py +2 -3
- arviz/plots/backends/matplotlib/distplot.py +1 -0
- arviz/plots/backends/matplotlib/dotplot.py +1 -0
- arviz/plots/backends/matplotlib/ecdfplot.py +1 -0
- arviz/plots/backends/matplotlib/elpdplot.py +1 -0
- arviz/plots/backends/matplotlib/energyplot.py +1 -0
- arviz/plots/backends/matplotlib/essplot.py +6 -5
- arviz/plots/backends/matplotlib/forestplot.py +3 -4
- arviz/plots/backends/matplotlib/hdiplot.py +1 -0
- arviz/plots/backends/matplotlib/kdeplot.py +5 -3
- arviz/plots/backends/matplotlib/khatplot.py +1 -0
- arviz/plots/backends/matplotlib/lmplot.py +1 -0
- arviz/plots/backends/matplotlib/loopitplot.py +1 -0
- arviz/plots/backends/matplotlib/mcseplot.py +11 -10
- arviz/plots/backends/matplotlib/pairplot.py +2 -1
- arviz/plots/backends/matplotlib/parallelplot.py +1 -0
- arviz/plots/backends/matplotlib/posteriorplot.py +1 -0
- arviz/plots/backends/matplotlib/ppcplot.py +1 -0
- arviz/plots/backends/matplotlib/rankplot.py +1 -0
- arviz/plots/backends/matplotlib/separationplot.py +1 -0
- arviz/plots/backends/matplotlib/traceplot.py +2 -1
- arviz/plots/backends/matplotlib/tsplot.py +1 -0
- arviz/plots/backends/matplotlib/violinplot.py +2 -1
- arviz/plots/bfplot.py +7 -6
- arviz/plots/bpvplot.py +3 -2
- arviz/plots/compareplot.py +3 -2
- arviz/plots/densityplot.py +1 -0
- arviz/plots/distcomparisonplot.py +1 -0
- arviz/plots/dotplot.py +1 -0
- arviz/plots/ecdfplot.py +38 -112
- arviz/plots/elpdplot.py +2 -1
- arviz/plots/energyplot.py +1 -0
- arviz/plots/essplot.py +3 -2
- arviz/plots/forestplot.py +1 -0
- arviz/plots/hdiplot.py +1 -0
- arviz/plots/khatplot.py +1 -0
- arviz/plots/lmplot.py +1 -0
- arviz/plots/loopitplot.py +1 -0
- arviz/plots/mcseplot.py +1 -0
- arviz/plots/pairplot.py +2 -1
- arviz/plots/parallelplot.py +1 -0
- arviz/plots/plot_utils.py +1 -0
- arviz/plots/posteriorplot.py +1 -0
- arviz/plots/ppcplot.py +11 -5
- arviz/plots/rankplot.py +1 -0
- arviz/plots/separationplot.py +1 -0
- arviz/plots/traceplot.py +1 -0
- arviz/plots/tsplot.py +1 -0
- arviz/plots/violinplot.py +1 -0
- arviz/rcparams.py +1 -0
- arviz/sel_utils.py +1 -0
- arviz/static/css/style.css +2 -1
- arviz/stats/density_utils.py +4 -3
- arviz/stats/diagnostics.py +4 -4
- arviz/stats/ecdf_utils.py +166 -0
- arviz/stats/stats.py +16 -32
- arviz/stats/stats_refitting.py +1 -0
- arviz/stats/stats_utils.py +6 -2
- arviz/tests/base_tests/test_data.py +18 -4
- arviz/tests/base_tests/test_diagnostics.py +1 -0
- arviz/tests/base_tests/test_diagnostics_numba.py +1 -0
- arviz/tests/base_tests/test_labels.py +1 -0
- arviz/tests/base_tests/test_plots_matplotlib.py +6 -5
- arviz/tests/base_tests/test_stats.py +4 -4
- arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
- arviz/tests/base_tests/test_stats_utils.py +4 -3
- arviz/tests/base_tests/test_utils.py +3 -2
- arviz/tests/external_tests/test_data_numpyro.py +3 -3
- arviz/tests/external_tests/test_data_pyro.py +3 -3
- arviz/tests/helpers.py +1 -1
- arviz/wrappers/__init__.py +1 -0
- {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/METADATA +10 -9
- arviz-0.18.0.dist-info/RECORD +182 -0
- {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/WHEEL +1 -1
- arviz-0.17.0.dist-info/RECORD +0 -180
- {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/LICENSE +0 -0
- {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/top_level.txt +0 -0
arviz/stats/stats.py
CHANGED
|
@@ -146,6 +146,7 @@ def compare(
|
|
|
146
146
|
Compare the centered and non centered models of the eight school problem:
|
|
147
147
|
|
|
148
148
|
.. ipython::
|
|
149
|
+
:okwarning:
|
|
149
150
|
|
|
150
151
|
In [1]: import arviz as az
|
|
151
152
|
...: data1 = az.load_arviz_data("non_centered_eight")
|
|
@@ -157,6 +158,7 @@ def compare(
|
|
|
157
158
|
weights using the stacking method.
|
|
158
159
|
|
|
159
160
|
.. ipython::
|
|
161
|
+
:okwarning:
|
|
160
162
|
|
|
161
163
|
In [1]: az.compare(compare_dict, ic="loo", method="stacking", scale="log")
|
|
162
164
|
|
|
@@ -180,37 +182,19 @@ def compare(
|
|
|
180
182
|
except Exception as e:
|
|
181
183
|
raise e.__class__("Encountered error in ELPD computation of compare.") from e
|
|
182
184
|
names = list(ics_dict.keys())
|
|
183
|
-
if ic
|
|
185
|
+
if ic in {"loo", "waic"}:
|
|
184
186
|
df_comp = pd.DataFrame(
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
"
|
|
188
|
-
"
|
|
189
|
-
"
|
|
190
|
-
"
|
|
191
|
-
"
|
|
192
|
-
"
|
|
193
|
-
"
|
|
194
|
-
"
|
|
195
|
-
|
|
196
|
-
],
|
|
197
|
-
dtype=np.float_,
|
|
198
|
-
)
|
|
199
|
-
elif ic == "waic":
|
|
200
|
-
df_comp = pd.DataFrame(
|
|
201
|
-
index=names,
|
|
202
|
-
columns=[
|
|
203
|
-
"rank",
|
|
204
|
-
"elpd_waic",
|
|
205
|
-
"p_waic",
|
|
206
|
-
"elpd_diff",
|
|
207
|
-
"weight",
|
|
208
|
-
"se",
|
|
209
|
-
"dse",
|
|
210
|
-
"warning",
|
|
211
|
-
"scale",
|
|
212
|
-
],
|
|
213
|
-
dtype=np.float_,
|
|
187
|
+
{
|
|
188
|
+
"rank": pd.Series(index=names, dtype="int"),
|
|
189
|
+
f"elpd_{ic}": pd.Series(index=names, dtype="float"),
|
|
190
|
+
f"p_{ic}": pd.Series(index=names, dtype="float"),
|
|
191
|
+
"elpd_diff": pd.Series(index=names, dtype="float"),
|
|
192
|
+
"weight": pd.Series(index=names, dtype="float"),
|
|
193
|
+
"se": pd.Series(index=names, dtype="float"),
|
|
194
|
+
"dse": pd.Series(index=names, dtype="float"),
|
|
195
|
+
"warning": pd.Series(index=names, dtype="boolean"),
|
|
196
|
+
"scale": pd.Series(index=names, dtype="str"),
|
|
197
|
+
}
|
|
214
198
|
)
|
|
215
199
|
else:
|
|
216
200
|
raise NotImplementedError(f"The information criterion {ic} is not supported.")
|
|
@@ -632,7 +616,7 @@ def _hdi(ary, hdi_prob, circular, skipna):
|
|
|
632
616
|
ary = np.sort(ary)
|
|
633
617
|
interval_idx_inc = int(np.floor(hdi_prob * n))
|
|
634
618
|
n_intervals = n - interval_idx_inc
|
|
635
|
-
interval_width = np.subtract(ary[interval_idx_inc:], ary[:n_intervals], dtype=np.
|
|
619
|
+
interval_width = np.subtract(ary[interval_idx_inc:], ary[:n_intervals], dtype=np.float64)
|
|
636
620
|
|
|
637
621
|
if len(interval_width) == 0:
|
|
638
622
|
raise ValueError("Too few elements for interval calculation. ")
|
|
@@ -2096,7 +2080,7 @@ def weight_predictions(idatas, weights=None):
|
|
|
2096
2080
|
weights /= weights.sum()
|
|
2097
2081
|
|
|
2098
2082
|
len_idatas = [
|
|
2099
|
-
idata.posterior_predictive.
|
|
2083
|
+
idata.posterior_predictive.sizes["chain"] * idata.posterior_predictive.sizes["draw"]
|
|
2100
2084
|
for idata in idatas
|
|
2101
2085
|
]
|
|
2102
2086
|
|
arviz/stats/stats_refitting.py
CHANGED
arviz/stats/stats_utils.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Stats-utility functions for ArviZ."""
|
|
2
|
+
|
|
2
3
|
import warnings
|
|
3
4
|
from collections.abc import Sequence
|
|
4
5
|
from copy import copy as _copy
|
|
@@ -134,7 +135,10 @@ def make_ufunc(
|
|
|
134
135
|
raise TypeError(msg)
|
|
135
136
|
for idx in np.ndindex(out.shape[:n_dims_out]):
|
|
136
137
|
arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
|
|
137
|
-
|
|
138
|
+
out_idx = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
|
|
139
|
+
if n_dims_out is None:
|
|
140
|
+
out_idx = out_idx.item()
|
|
141
|
+
out[idx] = out_idx
|
|
138
142
|
return out
|
|
139
143
|
|
|
140
144
|
def _multi_ufunc(*args, out=None, out_shape=None, **kwargs):
|
|
@@ -484,7 +488,7 @@ class ELPDData(pd.Series): # pylint: disable=too-many-ancestors
|
|
|
484
488
|
base += "\n\nThere has been a warning during the calculation. Please check the results."
|
|
485
489
|
|
|
486
490
|
if kind == "loo" and "pareto_k" in self:
|
|
487
|
-
bins = np.asarray([-np.
|
|
491
|
+
bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
|
|
488
492
|
counts, *_ = _histogram(self.pareto_k.values, bins)
|
|
489
493
|
extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
|
|
490
494
|
extended = extended.format(
|
|
@@ -1077,6 +1077,20 @@ def test_dict_to_dataset():
|
|
|
1077
1077
|
assert set(dataset.b.coords) == {"chain", "draw", "c"}
|
|
1078
1078
|
|
|
1079
1079
|
|
|
1080
|
+
def test_nested_dict_to_dataset():
|
|
1081
|
+
datadict = {
|
|
1082
|
+
"top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
|
|
1083
|
+
"d": np.random.randn(100),
|
|
1084
|
+
}
|
|
1085
|
+
dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={("top", "b"): ["c"]})
|
|
1086
|
+
assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"}
|
|
1087
|
+
assert set(dataset.coords) == {"chain", "draw", "c"}
|
|
1088
|
+
|
|
1089
|
+
assert set(dataset[("top", "a")].coords) == {"chain", "draw"}
|
|
1090
|
+
assert set(dataset[("top", "b")].coords) == {"chain", "draw", "c"}
|
|
1091
|
+
assert set(dataset.d.coords) == {"chain", "draw"}
|
|
1092
|
+
|
|
1093
|
+
|
|
1080
1094
|
def test_dict_to_dataset_event_dims_error():
|
|
1081
1095
|
datadict = {"a": np.random.randn(1, 100, 10)}
|
|
1082
1096
|
coords = {"b": np.arange(10), "c": ["x", "y", "z"]}
|
|
@@ -1241,7 +1255,7 @@ class TestDataDict:
|
|
|
1241
1255
|
self.check_var_names_coords_dims(inference_data.prior_predictive)
|
|
1242
1256
|
self.check_var_names_coords_dims(inference_data.sample_stats_prior)
|
|
1243
1257
|
|
|
1244
|
-
pred_dims = inference_data.predictions.
|
|
1258
|
+
pred_dims = inference_data.predictions.sizes["school_pred"]
|
|
1245
1259
|
assert pred_dims == 8
|
|
1246
1260
|
|
|
1247
1261
|
def test_inference_data_warmup(self, data, eight_schools_params):
|
|
@@ -1586,8 +1600,8 @@ class TestExtractDataset:
|
|
|
1586
1600
|
idata = load_arviz_data("centered_eight")
|
|
1587
1601
|
post = extract(idata, combined=False)
|
|
1588
1602
|
assert "sample" not in post.dims
|
|
1589
|
-
assert post.
|
|
1590
|
-
assert post.
|
|
1603
|
+
assert post.sizes["chain"] == 4
|
|
1604
|
+
assert post.sizes["draw"] == 500
|
|
1591
1605
|
|
|
1592
1606
|
def test_var_name_group(self):
|
|
1593
1607
|
idata = load_arviz_data("centered_eight")
|
|
@@ -1607,5 +1621,5 @@ class TestExtractDataset:
|
|
|
1607
1621
|
def test_subset_samples(self):
|
|
1608
1622
|
idata = load_arviz_data("centered_eight")
|
|
1609
1623
|
post = extract(idata, num_samples=10)
|
|
1610
|
-
assert post.
|
|
1624
|
+
assert post.sizes["sample"] == 10
|
|
1611
1625
|
assert post.attrs == idata.posterior.attrs
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Tests use the default backend."""
|
|
2
|
+
|
|
2
3
|
# pylint: disable=redefined-outer-name,too-many-lines
|
|
3
4
|
import os
|
|
4
5
|
from copy import deepcopy
|
|
@@ -54,7 +55,7 @@ from ..helpers import ( # pylint: disable=unused-import
|
|
|
54
55
|
eight_schools_params,
|
|
55
56
|
models,
|
|
56
57
|
multidim_models,
|
|
57
|
-
|
|
58
|
+
RandomVariableTestClass,
|
|
58
59
|
)
|
|
59
60
|
|
|
60
61
|
rcParams["data.load"] = "eager"
|
|
@@ -168,9 +169,9 @@ def test_plot_density_no_subset():
|
|
|
168
169
|
|
|
169
170
|
def test_plot_density_nonstring_varnames():
|
|
170
171
|
"""Test plot_density works when variables are not strings."""
|
|
171
|
-
rv1 =
|
|
172
|
-
rv2 =
|
|
173
|
-
rv3 =
|
|
172
|
+
rv1 = RandomVariableTestClass("a")
|
|
173
|
+
rv2 = RandomVariableTestClass("b")
|
|
174
|
+
rv3 = RandomVariableTestClass("c")
|
|
174
175
|
model_ab = from_dict(
|
|
175
176
|
{
|
|
176
177
|
rv1: np.random.normal(size=200),
|
|
@@ -752,7 +753,7 @@ def test_plot_ppc_transposed():
|
|
|
752
753
|
)
|
|
753
754
|
x, y = ax.get_lines()[2].get_data()
|
|
754
755
|
assert not np.isclose(y[0], 0)
|
|
755
|
-
assert np.all(np.array([
|
|
756
|
+
assert np.all(np.array([47, 44, 15, 11]) == x)
|
|
756
757
|
|
|
757
758
|
|
|
758
759
|
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
|
|
@@ -89,25 +89,25 @@ def test_hdi_idata(centered_eight):
|
|
|
89
89
|
data = centered_eight.posterior
|
|
90
90
|
result = hdi(data)
|
|
91
91
|
assert isinstance(result, Dataset)
|
|
92
|
-
assert dict(result.
|
|
92
|
+
assert dict(result.sizes) == {"school": 8, "hdi": 2}
|
|
93
93
|
|
|
94
94
|
result = hdi(data, input_core_dims=[["chain"]])
|
|
95
95
|
assert isinstance(result, Dataset)
|
|
96
|
-
assert result.
|
|
96
|
+
assert result.sizes == {"draw": 500, "hdi": 2, "school": 8}
|
|
97
97
|
|
|
98
98
|
|
|
99
99
|
def test_hdi_idata_varnames(centered_eight):
|
|
100
100
|
data = centered_eight.posterior
|
|
101
101
|
result = hdi(data, var_names=["mu", "theta"])
|
|
102
102
|
assert isinstance(result, Dataset)
|
|
103
|
-
assert result.
|
|
103
|
+
assert result.sizes == {"hdi": 2, "school": 8}
|
|
104
104
|
assert list(result.data_vars.keys()) == ["mu", "theta"]
|
|
105
105
|
|
|
106
106
|
|
|
107
107
|
def test_hdi_idata_group(centered_eight):
|
|
108
108
|
result_posterior = hdi(centered_eight, group="posterior", var_names="mu")
|
|
109
109
|
result_prior = hdi(centered_eight, group="prior", var_names="mu")
|
|
110
|
-
assert result_prior.
|
|
110
|
+
assert result_prior.sizes == {"hdi": 2}
|
|
111
111
|
range_posterior = result_posterior.mu.values[1] - result_posterior.mu.values[0]
|
|
112
112
|
range_prior = result_prior.mu.values[1] - result_prior.mu.values[0]
|
|
113
113
|
assert range_posterior < range_prior
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import scipy.stats
|
|
5
|
+
from ...stats.ecdf_utils import (
|
|
6
|
+
compute_ecdf,
|
|
7
|
+
ecdf_confidence_band,
|
|
8
|
+
_get_ecdf_points,
|
|
9
|
+
_simulate_ecdf,
|
|
10
|
+
_get_pointwise_confidence_band,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def test_compute_ecdf():
|
|
15
|
+
"""Test compute_ecdf function."""
|
|
16
|
+
sample = np.array([1, 2, 3, 3, 4, 5])
|
|
17
|
+
eval_points = np.arange(0, 7, 0.1)
|
|
18
|
+
ecdf_expected = (sample[:, None] <= eval_points).mean(axis=0)
|
|
19
|
+
assert np.allclose(compute_ecdf(sample, eval_points), ecdf_expected)
|
|
20
|
+
assert np.allclose(compute_ecdf(sample / 2 + 10, eval_points / 2 + 10), ecdf_expected)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@pytest.mark.parametrize("difference", [True, False])
|
|
24
|
+
def test_get_ecdf_points(difference):
|
|
25
|
+
"""Test _get_ecdf_points."""
|
|
26
|
+
# if first point already outside support, no need to insert it
|
|
27
|
+
sample = np.array([1, 2, 3, 3, 4, 5, 5])
|
|
28
|
+
eval_points = np.arange(-1, 7, 0.1)
|
|
29
|
+
x, y = _get_ecdf_points(sample, eval_points, difference)
|
|
30
|
+
assert np.array_equal(x, eval_points)
|
|
31
|
+
assert np.array_equal(y, compute_ecdf(sample, eval_points))
|
|
32
|
+
|
|
33
|
+
# if first point is inside support, insert it if not in difference mode
|
|
34
|
+
eval_points = np.arange(1, 6, 0.1)
|
|
35
|
+
x, y = _get_ecdf_points(sample, eval_points, difference)
|
|
36
|
+
assert len(x) == len(eval_points) + 1 - difference
|
|
37
|
+
assert len(y) == len(eval_points) + 1 - difference
|
|
38
|
+
|
|
39
|
+
# if not in difference mode, first point should be (eval_points[0], 0)
|
|
40
|
+
if not difference:
|
|
41
|
+
assert x[0] == eval_points[0]
|
|
42
|
+
assert y[0] == 0
|
|
43
|
+
assert np.allclose(x[1:], eval_points)
|
|
44
|
+
assert np.allclose(y[1:], compute_ecdf(sample, eval_points))
|
|
45
|
+
assert x[-1] == eval_points[-1]
|
|
46
|
+
assert y[-1] == 1
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@pytest.mark.parametrize(
|
|
50
|
+
"dist", [scipy.stats.norm(3, 10), scipy.stats.binom(10, 0.5)], ids=["continuous", "discrete"]
|
|
51
|
+
)
|
|
52
|
+
@pytest.mark.parametrize("seed", [32, 87])
|
|
53
|
+
def test_simulate_ecdf(dist, seed):
|
|
54
|
+
"""Test _simulate_ecdf."""
|
|
55
|
+
ndraws = 1000
|
|
56
|
+
eval_points = np.arange(0, 1, 0.1)
|
|
57
|
+
|
|
58
|
+
rvs = dist.rvs
|
|
59
|
+
|
|
60
|
+
random_state = np.random.default_rng(seed)
|
|
61
|
+
ecdf = _simulate_ecdf(ndraws, eval_points, rvs, random_state=random_state)
|
|
62
|
+
random_state = np.random.default_rng(seed)
|
|
63
|
+
ecdf_expected = compute_ecdf(np.sort(rvs(ndraws, random_state=random_state)), eval_points)
|
|
64
|
+
|
|
65
|
+
assert np.allclose(ecdf, ecdf_expected)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@pytest.mark.parametrize("prob", [0.8, 0.9])
|
|
69
|
+
@pytest.mark.parametrize(
|
|
70
|
+
"dist", [scipy.stats.norm(3, 10), scipy.stats.poisson(100)], ids=["continuous", "discrete"]
|
|
71
|
+
)
|
|
72
|
+
@pytest.mark.parametrize("ndraws", [10_000])
|
|
73
|
+
def test_get_pointwise_confidence_band(dist, prob, ndraws, num_trials=1_000, seed=57):
|
|
74
|
+
"""Test _get_pointwise_confidence_band."""
|
|
75
|
+
eval_points = np.linspace(*dist.interval(0.99), 10)
|
|
76
|
+
cdf_at_eval_points = dist.cdf(eval_points)
|
|
77
|
+
|
|
78
|
+
ecdf_lower, ecdf_upper = _get_pointwise_confidence_band(prob, ndraws, cdf_at_eval_points)
|
|
79
|
+
|
|
80
|
+
# check basic properties
|
|
81
|
+
assert np.all(ecdf_lower >= 0)
|
|
82
|
+
assert np.all(ecdf_upper <= 1)
|
|
83
|
+
assert np.all(ecdf_lower <= ecdf_upper)
|
|
84
|
+
|
|
85
|
+
# use simulation to estimate lower and upper bounds on pointwise probability
|
|
86
|
+
in_interval = []
|
|
87
|
+
random_state = np.random.default_rng(seed)
|
|
88
|
+
for _ in range(num_trials):
|
|
89
|
+
ecdf = _simulate_ecdf(ndraws, eval_points, dist.rvs, random_state=random_state)
|
|
90
|
+
in_interval.append((ecdf_lower <= ecdf) & (ecdf < ecdf_upper))
|
|
91
|
+
asymptotic_dist = scipy.stats.norm(
|
|
92
|
+
np.mean(in_interval, axis=0), scipy.stats.sem(in_interval, axis=0)
|
|
93
|
+
)
|
|
94
|
+
prob_lower, prob_upper = asymptotic_dist.interval(0.999)
|
|
95
|
+
|
|
96
|
+
# check target probability within all bounds
|
|
97
|
+
assert np.all(prob_lower <= prob)
|
|
98
|
+
assert np.all(prob <= prob_upper)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@pytest.mark.parametrize("prob", [0.8, 0.9])
|
|
102
|
+
@pytest.mark.parametrize(
|
|
103
|
+
"dist, rvs",
|
|
104
|
+
[
|
|
105
|
+
(scipy.stats.norm(3, 10), scipy.stats.norm(3, 10).rvs),
|
|
106
|
+
(scipy.stats.norm(3, 10), None),
|
|
107
|
+
(scipy.stats.poisson(100), scipy.stats.poisson(100).rvs),
|
|
108
|
+
],
|
|
109
|
+
ids=["continuous", "continuous default rvs", "discrete"],
|
|
110
|
+
)
|
|
111
|
+
@pytest.mark.parametrize("ndraws", [10_000])
|
|
112
|
+
@pytest.mark.parametrize("method", ["pointwise", "simulated"])
|
|
113
|
+
def test_ecdf_confidence_band(dist, rvs, prob, ndraws, method, num_trials=1_000, seed=57):
|
|
114
|
+
"""Test test_ecdf_confidence_band."""
|
|
115
|
+
eval_points = np.linspace(*dist.interval(0.99), 10)
|
|
116
|
+
cdf_at_eval_points = dist.cdf(eval_points)
|
|
117
|
+
random_state = np.random.default_rng(seed)
|
|
118
|
+
|
|
119
|
+
ecdf_lower, ecdf_upper = ecdf_confidence_band(
|
|
120
|
+
ndraws,
|
|
121
|
+
eval_points,
|
|
122
|
+
cdf_at_eval_points,
|
|
123
|
+
prob=prob,
|
|
124
|
+
rvs=rvs,
|
|
125
|
+
random_state=random_state,
|
|
126
|
+
method=method,
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
if method == "pointwise":
|
|
130
|
+
# these values tested elsewhere, we just make sure they're the same
|
|
131
|
+
ecdf_lower_pointwise, ecdf_upper_pointwise = _get_pointwise_confidence_band(
|
|
132
|
+
prob, ndraws, cdf_at_eval_points
|
|
133
|
+
)
|
|
134
|
+
assert np.array_equal(ecdf_lower, ecdf_lower_pointwise)
|
|
135
|
+
assert np.array_equal(ecdf_upper, ecdf_upper_pointwise)
|
|
136
|
+
return
|
|
137
|
+
|
|
138
|
+
# check basic properties
|
|
139
|
+
assert np.all(ecdf_lower >= 0)
|
|
140
|
+
assert np.all(ecdf_upper <= 1)
|
|
141
|
+
assert np.all(ecdf_lower <= ecdf_upper)
|
|
142
|
+
|
|
143
|
+
# use simulation to estimate lower and upper bounds on simultaneous probability
|
|
144
|
+
in_envelope = []
|
|
145
|
+
random_state = np.random.default_rng(seed)
|
|
146
|
+
for _ in range(num_trials):
|
|
147
|
+
ecdf = _simulate_ecdf(ndraws, eval_points, dist.rvs, random_state=random_state)
|
|
148
|
+
in_envelope.append(np.all(ecdf_lower <= ecdf) & np.all(ecdf < ecdf_upper))
|
|
149
|
+
asymptotic_dist = scipy.stats.norm(np.mean(in_envelope), scipy.stats.sem(in_envelope))
|
|
150
|
+
prob_lower, prob_upper = asymptotic_dist.interval(0.999)
|
|
151
|
+
|
|
152
|
+
# check target probability within bounds
|
|
153
|
+
assert prob_lower <= prob <= prob_upper
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Tests for stats_utils."""
|
|
2
|
+
|
|
2
3
|
# pylint: disable=no-member
|
|
3
4
|
import numpy as np
|
|
4
5
|
import pytest
|
|
@@ -344,9 +345,9 @@ def test_variance_bad_data():
|
|
|
344
345
|
|
|
345
346
|
def test_histogram():
|
|
346
347
|
school = load_arviz_data("non_centered_eight").posterior["mu"].values
|
|
347
|
-
k_count_az, k_dens_az, _ = histogram(school, bins=np.asarray([-np.
|
|
348
|
-
k_dens_np, *_ = np.histogram(school, bins=[-np.
|
|
349
|
-
k_count_np, *_ = np.histogram(school, bins=[-np.
|
|
348
|
+
k_count_az, k_dens_az, _ = histogram(school, bins=np.asarray([-np.inf, 0.5, 0.7, 1, np.inf]))
|
|
349
|
+
k_dens_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=True)
|
|
350
|
+
k_count_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=False)
|
|
350
351
|
assert np.allclose(k_count_az, k_count_np)
|
|
351
352
|
assert np.allclose(k_dens_az, k_dens_np)
|
|
352
353
|
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Tests for arviz.utils."""
|
|
2
|
+
|
|
2
3
|
# pylint: disable=redefined-outer-name, no-member
|
|
3
4
|
from unittest.mock import Mock
|
|
4
5
|
|
|
@@ -17,7 +18,7 @@ from ...utils import (
|
|
|
17
18
|
one_de,
|
|
18
19
|
two_de,
|
|
19
20
|
)
|
|
20
|
-
from ..helpers import
|
|
21
|
+
from ..helpers import RandomVariableTestClass
|
|
21
22
|
|
|
22
23
|
|
|
23
24
|
@pytest.fixture(scope="session")
|
|
@@ -123,7 +124,7 @@ def test_var_names_filter(var_args):
|
|
|
123
124
|
|
|
124
125
|
def test_nonstring_var_names():
|
|
125
126
|
"""Check that non-string variables are preserved"""
|
|
126
|
-
mu =
|
|
127
|
+
mu = RandomVariableTestClass("mu")
|
|
127
128
|
samples = np.random.randn(10)
|
|
128
129
|
data = dict_to_dataset({mu: samples})
|
|
129
130
|
assert _var_names([mu], data) == [mu]
|
|
@@ -101,8 +101,8 @@ class TestDataNumPyro:
|
|
|
101
101
|
assert not fails
|
|
102
102
|
|
|
103
103
|
# test dims
|
|
104
|
-
dims = inference_data.posterior_predictive.
|
|
105
|
-
pred_dims = inference_data.predictions.
|
|
104
|
+
dims = inference_data.posterior_predictive.sizes["school"]
|
|
105
|
+
pred_dims = inference_data.predictions.sizes["school_pred"]
|
|
106
106
|
assert dims == 8
|
|
107
107
|
assert pred_dims == 8
|
|
108
108
|
|
|
@@ -240,7 +240,7 @@ class TestDataNumPyro:
|
|
|
240
240
|
def test_inference_data_num_chains(self, predictions_data, chains):
|
|
241
241
|
predictions = predictions_data
|
|
242
242
|
inference_data = from_numpyro(predictions=predictions, num_chains=chains)
|
|
243
|
-
nchains = inference_data.predictions.
|
|
243
|
+
nchains = inference_data.predictions.sizes["chain"]
|
|
244
244
|
assert nchains == chains
|
|
245
245
|
|
|
246
246
|
@pytest.mark.parametrize("nchains", [1, 2])
|
|
@@ -83,8 +83,8 @@ class TestDataPyro:
|
|
|
83
83
|
assert not fails
|
|
84
84
|
|
|
85
85
|
# test dims
|
|
86
|
-
dims = inference_data.posterior_predictive.
|
|
87
|
-
pred_dims = inference_data.predictions.
|
|
86
|
+
dims = inference_data.posterior_predictive.sizes["school"]
|
|
87
|
+
pred_dims = inference_data.predictions.sizes["school_pred"]
|
|
88
88
|
assert dims == 8
|
|
89
89
|
assert pred_dims == 8
|
|
90
90
|
|
|
@@ -225,7 +225,7 @@ class TestDataPyro:
|
|
|
225
225
|
def test_inference_data_num_chains(self, predictions_data, chains):
|
|
226
226
|
predictions = predictions_data
|
|
227
227
|
inference_data = from_pyro(predictions=predictions, num_chains=chains)
|
|
228
|
-
nchains = inference_data.predictions.
|
|
228
|
+
nchains = inference_data.predictions.sizes["chain"]
|
|
229
229
|
assert nchains == chains
|
|
230
230
|
|
|
231
231
|
@pytest.mark.parametrize("log_likelihood", [True, False])
|
arviz/tests/helpers.py
CHANGED
arviz/wrappers/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: arviz
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.18.0
|
|
4
4
|
Summary: Exploratory analysis of Bayesian models
|
|
5
5
|
Home-page: http://github.com/arviz-devs/arviz
|
|
6
6
|
Author: ArviZ Developers
|
|
@@ -12,22 +12,23 @@ Classifier: Intended Audience :: Education
|
|
|
12
12
|
Classifier: License :: OSI Approved :: Apache Software License
|
|
13
13
|
Classifier: Programming Language :: Python
|
|
14
14
|
Classifier: Programming Language :: Python :: 3
|
|
15
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
16
15
|
Classifier: Programming Language :: Python :: 3.10
|
|
17
16
|
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
18
|
Classifier: Topic :: Scientific/Engineering
|
|
19
19
|
Classifier: Topic :: Scientific/Engineering :: Visualization
|
|
20
20
|
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
21
|
-
Requires-Python: >=3.
|
|
21
|
+
Requires-Python: >=3.10
|
|
22
22
|
Description-Content-Type: text/markdown
|
|
23
23
|
License-File: LICENSE
|
|
24
24
|
Requires-Dist: setuptools >=60.0.0
|
|
25
25
|
Requires-Dist: matplotlib >=3.5
|
|
26
|
-
Requires-Dist: numpy <2.0,>=1.
|
|
27
|
-
Requires-Dist: scipy >=1.
|
|
26
|
+
Requires-Dist: numpy <2.0,>=1.23.0
|
|
27
|
+
Requires-Dist: scipy >=1.9.0
|
|
28
28
|
Requires-Dist: packaging
|
|
29
|
-
Requires-Dist: pandas >=1.
|
|
30
|
-
Requires-Dist:
|
|
29
|
+
Requires-Dist: pandas >=1.5.0
|
|
30
|
+
Requires-Dist: dm-tree >=0.1.8
|
|
31
|
+
Requires-Dist: xarray >=2022.6.0
|
|
31
32
|
Requires-Dist: h5netcdf >=1.0.2
|
|
32
33
|
Requires-Dist: typing-extensions >=4.1.0
|
|
33
34
|
Requires-Dist: xarray-einstats >=0.3
|
|
@@ -52,8 +53,7 @@ Requires-Dist: xarray-datatree ; extra == 'all'
|
|
|
52
53
|
[](https://doi.org/10.21105/joss.01143) [](https://doi.org/10.5281/zenodo.2540945)
|
|
53
54
|
[](https://numfocus.org)
|
|
54
55
|
|
|
55
|
-
ArviZ (pronounced "AR-_vees_") is a Python package for exploratory analysis of Bayesian models.
|
|
56
|
-
Includes functions for posterior analysis, data storage, model checking, comparison and diagnostics.
|
|
56
|
+
ArviZ (pronounced "AR-_vees_") is a Python package for exploratory analysis of Bayesian models. It includes functions for posterior analysis, data storage, model checking, comparison and diagnostics.
|
|
57
57
|
|
|
58
58
|
### ArviZ in other languages
|
|
59
59
|
ArviZ also has a Julia wrapper available [ArviZ.jl](https://julia.arviz.org/).
|
|
@@ -202,6 +202,7 @@ python setup.py install
|
|
|
202
202
|
|
|
203
203
|
<a href="https://python.arviz.org/en/latest/examples/index.html">And more...</a>
|
|
204
204
|
</div>
|
|
205
|
+
|
|
205
206
|
## Dependencies
|
|
206
207
|
|
|
207
208
|
ArviZ is tested on Python 3.10, 3.11 and 3.12, and depends on NumPy, SciPy, xarray, and Matplotlib.
|