arviz 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +49 -4
- arviz/data/converters.py +11 -0
- arviz/data/inference_data.py +46 -24
- arviz/data/io_datatree.py +2 -2
- arviz/data/io_numpyro.py +116 -5
- arviz/data/io_pyjags.py +1 -1
- arviz/plots/autocorrplot.py +12 -2
- arviz/plots/backends/bokeh/hdiplot.py +7 -6
- arviz/plots/backends/bokeh/lmplot.py +19 -3
- arviz/plots/backends/bokeh/pairplot.py +18 -48
- arviz/plots/backends/matplotlib/khatplot.py +8 -1
- arviz/plots/backends/matplotlib/lmplot.py +13 -7
- arviz/plots/backends/matplotlib/pairplot.py +14 -22
- arviz/plots/bpvplot.py +1 -1
- arviz/plots/dotplot.py +2 -0
- arviz/plots/forestplot.py +16 -4
- arviz/plots/kdeplot.py +4 -4
- arviz/plots/lmplot.py +41 -14
- arviz/plots/pairplot.py +10 -3
- arviz/plots/ppcplot.py +1 -1
- arviz/preview.py +31 -21
- arviz/rcparams.py +2 -2
- arviz/stats/density_utils.py +1 -1
- arviz/stats/stats.py +31 -34
- arviz/tests/base_tests/test_data.py +25 -4
- arviz/tests/base_tests/test_plots_bokeh.py +60 -2
- arviz/tests/base_tests/test_plots_matplotlib.py +94 -1
- arviz/tests/base_tests/test_stats.py +42 -1
- arviz/tests/base_tests/test_stats_ecdf_utils.py +2 -2
- arviz/tests/external_tests/test_data_numpyro.py +154 -4
- arviz/wrappers/base.py +1 -1
- arviz/wrappers/wrap_stan.py +1 -1
- {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/METADATA +20 -9
- {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/RECORD +37 -37
- {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/WHEEL +1 -1
- {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info/licenses}/LICENSE +0 -0
- {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/top_level.txt +0 -0
|
@@ -8,6 +8,7 @@ from pandas import DataFrame # pylint: disable=wrong-import-position
|
|
|
8
8
|
from scipy.stats import norm # pylint: disable=wrong-import-position
|
|
9
9
|
|
|
10
10
|
from ...data import from_dict, load_arviz_data # pylint: disable=wrong-import-position
|
|
11
|
+
from ...labels import MapLabeller # pylint: disable=wrong-import-position
|
|
11
12
|
from ...plots import ( # pylint: disable=wrong-import-position
|
|
12
13
|
plot_autocorr,
|
|
13
14
|
plot_bpv,
|
|
@@ -773,7 +774,6 @@ def test_plot_mcse_no_divergences(models):
|
|
|
773
774
|
{"divergences": True, "var_names": ["theta", "mu"]},
|
|
774
775
|
{"kind": "kde", "var_names": ["theta"]},
|
|
775
776
|
{"kind": "hexbin", "var_names": ["theta"]},
|
|
776
|
-
{"kind": "hexbin", "var_names": ["theta"]},
|
|
777
777
|
{
|
|
778
778
|
"kind": "hexbin",
|
|
779
779
|
"var_names": ["theta"],
|
|
@@ -785,6 +785,21 @@ def test_plot_mcse_no_divergences(models):
|
|
|
785
785
|
"reference_values": {"mu": 0, "tau": 0},
|
|
786
786
|
"reference_values_kwargs": {"line_color": "blue"},
|
|
787
787
|
},
|
|
788
|
+
{
|
|
789
|
+
"var_names": ["mu", "tau"],
|
|
790
|
+
"reference_values": {"mu": 0, "tau": 0},
|
|
791
|
+
"labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
|
|
792
|
+
},
|
|
793
|
+
{
|
|
794
|
+
"var_names": ["theta"],
|
|
795
|
+
"reference_values": {"theta": [0.0] * 8},
|
|
796
|
+
"labeller": MapLabeller({"theta": r"$\theta$"}),
|
|
797
|
+
},
|
|
798
|
+
{
|
|
799
|
+
"var_names": ["theta"],
|
|
800
|
+
"reference_values": {"theta": np.zeros(8)},
|
|
801
|
+
"labeller": MapLabeller({"theta": r"$\theta$"}),
|
|
802
|
+
},
|
|
788
803
|
],
|
|
789
804
|
)
|
|
790
805
|
def test_plot_pair(models, kwargs):
|
|
@@ -1201,7 +1216,7 @@ def test_plot_dot_rotated(continuous_model, kwargs):
|
|
|
1201
1216
|
},
|
|
1202
1217
|
],
|
|
1203
1218
|
)
|
|
1204
|
-
def
|
|
1219
|
+
def test_plot_lm_1d(models, kwargs):
|
|
1205
1220
|
"""Test functionality for 1D data."""
|
|
1206
1221
|
idata = models.model_1
|
|
1207
1222
|
if "constant_data" not in idata.groups():
|
|
@@ -1228,3 +1243,46 @@ def test_plot_lm_list():
|
|
|
1228
1243
|
"""Test the plots when input data is list or ndarray."""
|
|
1229
1244
|
y = [1, 2, 3, 4, 5]
|
|
1230
1245
|
assert plot_lm(y=y, x=np.arange(len(y)), show=False, backend="bokeh")
|
|
1246
|
+
|
|
1247
|
+
|
|
1248
|
+
def generate_lm_1d_data():
|
|
1249
|
+
rng = np.random.default_rng()
|
|
1250
|
+
return from_dict(
|
|
1251
|
+
observed_data={"y": rng.normal(size=7)},
|
|
1252
|
+
posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
|
|
1253
|
+
posterior={"y_model": rng.normal(size=(4, 1000, 7))},
|
|
1254
|
+
dims={"y": ["dim1"]},
|
|
1255
|
+
coords={"dim1": range(7)},
|
|
1256
|
+
)
|
|
1257
|
+
|
|
1258
|
+
|
|
1259
|
+
def generate_lm_2d_data():
|
|
1260
|
+
rng = np.random.default_rng()
|
|
1261
|
+
return from_dict(
|
|
1262
|
+
observed_data={"y": rng.normal(size=(5, 7))},
|
|
1263
|
+
posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
|
|
1264
|
+
posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
|
|
1265
|
+
dims={"y": ["dim1", "dim2"]},
|
|
1266
|
+
coords={"dim1": range(5), "dim2": range(7)},
|
|
1267
|
+
)
|
|
1268
|
+
|
|
1269
|
+
|
|
1270
|
+
@pytest.mark.parametrize("data", ("1d", "2d"))
|
|
1271
|
+
@pytest.mark.parametrize("kind", ("lines", "hdi"))
|
|
1272
|
+
@pytest.mark.parametrize("use_y_model", (True, False))
|
|
1273
|
+
def test_plot_lm(data, kind, use_y_model):
|
|
1274
|
+
if data == "1d":
|
|
1275
|
+
idata = generate_lm_1d_data()
|
|
1276
|
+
else:
|
|
1277
|
+
idata = generate_lm_2d_data()
|
|
1278
|
+
|
|
1279
|
+
kwargs = {"idata": idata, "y": "y", "kind_model": kind, "backend": "bokeh", "show": False}
|
|
1280
|
+
if data == "2d":
|
|
1281
|
+
kwargs["plot_dim"] = "dim1"
|
|
1282
|
+
if use_y_model:
|
|
1283
|
+
kwargs["y_model"] = "y_model"
|
|
1284
|
+
if kind == "lines":
|
|
1285
|
+
kwargs["num_samples"] = 50
|
|
1286
|
+
|
|
1287
|
+
ax = plot_lm(**kwargs)
|
|
1288
|
+
assert ax is not None
|
|
@@ -14,6 +14,7 @@ from pandas import DataFrame
|
|
|
14
14
|
from scipy.stats import gaussian_kde, norm
|
|
15
15
|
|
|
16
16
|
from ...data import from_dict, load_arviz_data
|
|
17
|
+
from ...labels import MapLabeller
|
|
17
18
|
from ...plots import (
|
|
18
19
|
plot_autocorr,
|
|
19
20
|
plot_bf,
|
|
@@ -599,6 +600,21 @@ def test_plot_kde_inference_data(models):
|
|
|
599
600
|
"reference_values": {"mu": 0, "tau": 0},
|
|
600
601
|
"reference_values_kwargs": {"c": "C0", "marker": "*"},
|
|
601
602
|
},
|
|
603
|
+
{
|
|
604
|
+
"var_names": ["mu", "tau"],
|
|
605
|
+
"reference_values": {"mu": 0, "tau": 0},
|
|
606
|
+
"labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
|
|
607
|
+
},
|
|
608
|
+
{
|
|
609
|
+
"var_names": ["theta"],
|
|
610
|
+
"reference_values": {"theta": [0.0] * 8},
|
|
611
|
+
"labeller": MapLabeller({"theta": r"$\theta$"}),
|
|
612
|
+
},
|
|
613
|
+
{
|
|
614
|
+
"var_names": ["theta"],
|
|
615
|
+
"reference_values": {"theta": np.zeros(8)},
|
|
616
|
+
"labeller": MapLabeller({"theta": r"$\theta$"}),
|
|
617
|
+
},
|
|
602
618
|
],
|
|
603
619
|
)
|
|
604
620
|
def test_plot_pair(models, kwargs):
|
|
@@ -1914,7 +1930,7 @@ def test_wilkinson_algorithm(continuous_model):
|
|
|
1914
1930
|
},
|
|
1915
1931
|
],
|
|
1916
1932
|
)
|
|
1917
|
-
def
|
|
1933
|
+
def test_plot_lm_1d(models, kwargs):
|
|
1918
1934
|
"""Test functionality for 1D data."""
|
|
1919
1935
|
idata = models.model_1
|
|
1920
1936
|
if "constant_data" not in idata.groups():
|
|
@@ -2102,3 +2118,80 @@ def test_plot_bf():
|
|
|
2102
2118
|
)
|
|
2103
2119
|
_, bf_plot = plot_bf(idata, var_name="a", ref_val=0)
|
|
2104
2120
|
assert bf_plot is not None
|
|
2121
|
+
|
|
2122
|
+
|
|
2123
|
+
def generate_lm_1d_data():
|
|
2124
|
+
rng = np.random.default_rng()
|
|
2125
|
+
return from_dict(
|
|
2126
|
+
observed_data={"y": rng.normal(size=7)},
|
|
2127
|
+
posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
|
|
2128
|
+
posterior={"y_model": rng.normal(size=(4, 1000, 7))},
|
|
2129
|
+
dims={"y": ["dim1"]},
|
|
2130
|
+
coords={"dim1": range(7)},
|
|
2131
|
+
)
|
|
2132
|
+
|
|
2133
|
+
|
|
2134
|
+
def generate_lm_2d_data():
|
|
2135
|
+
rng = np.random.default_rng()
|
|
2136
|
+
return from_dict(
|
|
2137
|
+
observed_data={"y": rng.normal(size=(5, 7))},
|
|
2138
|
+
posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
|
|
2139
|
+
posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
|
|
2140
|
+
dims={"y": ["dim1", "dim2"]},
|
|
2141
|
+
coords={"dim1": range(5), "dim2": range(7)},
|
|
2142
|
+
)
|
|
2143
|
+
|
|
2144
|
+
|
|
2145
|
+
@pytest.mark.parametrize("data", ("1d", "2d"))
|
|
2146
|
+
@pytest.mark.parametrize("kind", ("lines", "hdi"))
|
|
2147
|
+
@pytest.mark.parametrize("use_y_model", (True, False))
|
|
2148
|
+
def test_plot_lm(data, kind, use_y_model):
|
|
2149
|
+
if data == "1d":
|
|
2150
|
+
idata = generate_lm_1d_data()
|
|
2151
|
+
else:
|
|
2152
|
+
idata = generate_lm_2d_data()
|
|
2153
|
+
|
|
2154
|
+
kwargs = {"idata": idata, "y": "y", "kind_model": kind}
|
|
2155
|
+
if data == "2d":
|
|
2156
|
+
kwargs["plot_dim"] = "dim1"
|
|
2157
|
+
if use_y_model:
|
|
2158
|
+
kwargs["y_model"] = "y_model"
|
|
2159
|
+
if kind == "lines":
|
|
2160
|
+
kwargs["num_samples"] = 50
|
|
2161
|
+
|
|
2162
|
+
ax = plot_lm(**kwargs)
|
|
2163
|
+
assert ax is not None
|
|
2164
|
+
|
|
2165
|
+
|
|
2166
|
+
@pytest.mark.parametrize(
|
|
2167
|
+
"coords, expected_vars",
|
|
2168
|
+
[
|
|
2169
|
+
({"school": ["Choate"]}, ["theta"]),
|
|
2170
|
+
({"school": ["Lawrenceville"]}, ["theta"]),
|
|
2171
|
+
({}, ["theta"]),
|
|
2172
|
+
],
|
|
2173
|
+
)
|
|
2174
|
+
def test_plot_autocorr_coords(coords, expected_vars):
|
|
2175
|
+
"""Test plot_autocorr with coords kwarg."""
|
|
2176
|
+
idata = load_arviz_data("centered_eight")
|
|
2177
|
+
|
|
2178
|
+
axes = plot_autocorr(idata, var_names=expected_vars, coords=coords, show=False)
|
|
2179
|
+
assert axes is not None
|
|
2180
|
+
|
|
2181
|
+
|
|
2182
|
+
def test_plot_forest_with_transform():
|
|
2183
|
+
"""Test if plot_forest runs successfully with a transform dictionary."""
|
|
2184
|
+
data = xr.Dataset(
|
|
2185
|
+
{
|
|
2186
|
+
"var1": (["chain", "draw"], np.array([[1, 2, 3], [4, 5, 6]])),
|
|
2187
|
+
"var2": (["chain", "draw"], np.array([[7, 8, 9], [10, 11, 12]])),
|
|
2188
|
+
},
|
|
2189
|
+
coords={"chain": [0, 1], "draw": [0, 1, 2]},
|
|
2190
|
+
)
|
|
2191
|
+
transform_dict = {
|
|
2192
|
+
"var1": lambda x: x + 1,
|
|
2193
|
+
"var2": lambda x: x * 2,
|
|
2194
|
+
}
|
|
2195
|
+
|
|
2196
|
+
axes = plot_forest(data, transform=transform_dict, show=False)
|
|
2197
|
+
assert axes is not None
|
|
@@ -14,7 +14,7 @@ from scipy.stats import linregress, norm, halfcauchy
|
|
|
14
14
|
from xarray import DataArray, Dataset
|
|
15
15
|
from xarray_einstats.stats import XrContinuousRV
|
|
16
16
|
|
|
17
|
-
from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
|
|
17
|
+
from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data, InferenceData
|
|
18
18
|
from ...rcparams import rcParams
|
|
19
19
|
from ...stats import (
|
|
20
20
|
apply_test_function,
|
|
@@ -882,3 +882,44 @@ def test_bayes_factor():
|
|
|
882
882
|
bf_dict1 = bayes_factor(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
|
|
883
883
|
assert bf_dict0["BF10"] > bf_dict0["BF01"]
|
|
884
884
|
assert bf_dict1["BF10"] < bf_dict1["BF01"]
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
def test_compare_sorting_consistency():
|
|
888
|
+
chains, draws = 4, 1000
|
|
889
|
+
|
|
890
|
+
# Model 1 - good fit
|
|
891
|
+
log_lik1 = np.random.normal(-2, 1, size=(chains, draws))
|
|
892
|
+
posterior1 = Dataset(
|
|
893
|
+
{"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
|
|
894
|
+
coords={"chain": range(chains), "draw": range(draws)},
|
|
895
|
+
)
|
|
896
|
+
log_like1 = Dataset(
|
|
897
|
+
{"y": (("chain", "draw"), log_lik1)},
|
|
898
|
+
coords={"chain": range(chains), "draw": range(draws)},
|
|
899
|
+
)
|
|
900
|
+
data1 = InferenceData(posterior=posterior1, log_likelihood=log_like1)
|
|
901
|
+
|
|
902
|
+
# Model 2 - poor fit (higher variance)
|
|
903
|
+
log_lik2 = np.random.normal(-5, 2, size=(chains, draws))
|
|
904
|
+
posterior2 = Dataset(
|
|
905
|
+
{"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
|
|
906
|
+
coords={"chain": range(chains), "draw": range(draws)},
|
|
907
|
+
)
|
|
908
|
+
log_like2 = Dataset(
|
|
909
|
+
{"y": (("chain", "draw"), log_lik2)},
|
|
910
|
+
coords={"chain": range(chains), "draw": range(draws)},
|
|
911
|
+
)
|
|
912
|
+
data2 = InferenceData(posterior=posterior2, log_likelihood=log_like2)
|
|
913
|
+
|
|
914
|
+
# Compare models in different orders
|
|
915
|
+
comp_dict1 = {"M1": data1, "M2": data2}
|
|
916
|
+
comp_dict2 = {"M2": data2, "M1": data1}
|
|
917
|
+
|
|
918
|
+
comparison1 = compare(comp_dict1, method="bb-pseudo-bma")
|
|
919
|
+
comparison2 = compare(comp_dict2, method="bb-pseudo-bma")
|
|
920
|
+
|
|
921
|
+
assert comparison1.index.tolist() == comparison2.index.tolist()
|
|
922
|
+
|
|
923
|
+
se1 = comparison1["se"].values
|
|
924
|
+
se2 = comparison2["se"].values
|
|
925
|
+
np.testing.assert_array_almost_equal(se1, se2)
|
|
@@ -13,9 +13,9 @@ from ...stats.ecdf_utils import (
|
|
|
13
13
|
try:
|
|
14
14
|
import numba # pylint: disable=unused-import
|
|
15
15
|
|
|
16
|
-
numba_options = [True, False]
|
|
16
|
+
numba_options = [True, False] # pylint: disable=invalid-name
|
|
17
17
|
except ImportError:
|
|
18
|
-
numba_options = [False]
|
|
18
|
+
numba_options = [False] # pylint: disable=invalid-name
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
def test_compute_ecdf():
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# pylint: disable=no-member, invalid-name, redefined-outer-name
|
|
1
|
+
# pylint: disable=no-member, invalid-name, redefined-outer-name, too-many-public-methods
|
|
2
2
|
from collections import namedtuple
|
|
3
3
|
import numpy as np
|
|
4
4
|
import pytest
|
|
@@ -46,7 +46,9 @@ class TestDataNumPyro:
|
|
|
46
46
|
)
|
|
47
47
|
return predictions
|
|
48
48
|
|
|
49
|
-
def get_inference_data(
|
|
49
|
+
def get_inference_data(
|
|
50
|
+
self, data, eight_schools_params, predictions_data, predictions_params, infer_dims=False
|
|
51
|
+
):
|
|
50
52
|
posterior_samples = data.obj.get_samples()
|
|
51
53
|
model = data.obj.sampler.model
|
|
52
54
|
posterior_predictive = Predictive(model, posterior_samples)(
|
|
@@ -55,6 +57,12 @@ class TestDataNumPyro:
|
|
|
55
57
|
prior = Predictive(model, num_samples=500)(
|
|
56
58
|
PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
|
|
57
59
|
)
|
|
60
|
+
dims = {"theta": ["school"], "eta": ["school"], "obs": ["school"]}
|
|
61
|
+
pred_dims = {"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]}
|
|
62
|
+
if infer_dims:
|
|
63
|
+
dims = None
|
|
64
|
+
pred_dims = None
|
|
65
|
+
|
|
58
66
|
predictions = predictions_data
|
|
59
67
|
return from_numpyro(
|
|
60
68
|
posterior=data.obj,
|
|
@@ -65,8 +73,8 @@ class TestDataNumPyro:
|
|
|
65
73
|
"school": np.arange(eight_schools_params["J"]),
|
|
66
74
|
"school_pred": np.arange(predictions_params["J"]),
|
|
67
75
|
},
|
|
68
|
-
dims=
|
|
69
|
-
pred_dims=
|
|
76
|
+
dims=dims,
|
|
77
|
+
pred_dims=pred_dims,
|
|
70
78
|
)
|
|
71
79
|
|
|
72
80
|
def test_inference_data_namedtuple(self, data):
|
|
@@ -77,6 +85,7 @@ class TestDataNumPyro:
|
|
|
77
85
|
data.obj.get_samples = lambda *args, **kwargs: data_namedtuple
|
|
78
86
|
inference_data = from_numpyro(
|
|
79
87
|
posterior=data.obj,
|
|
88
|
+
dims={}, # This mock test needs to turn off autodims like so or mock group_by_chain
|
|
80
89
|
)
|
|
81
90
|
assert isinstance(data.obj.get_samples(), Samples)
|
|
82
91
|
data.obj.get_samples = _old_fn
|
|
@@ -282,3 +291,144 @@ class TestDataNumPyro:
|
|
|
282
291
|
mcmc.run(PRNGKey(0))
|
|
283
292
|
inference_data = from_numpyro(mcmc)
|
|
284
293
|
assert inference_data.observed_data
|
|
294
|
+
|
|
295
|
+
def test_mcmc_infer_dims(self):
|
|
296
|
+
import numpyro
|
|
297
|
+
import numpyro.distributions as dist
|
|
298
|
+
from numpyro.infer import MCMC, NUTS
|
|
299
|
+
|
|
300
|
+
def model():
|
|
301
|
+
# note: group2 gets assigned dim=-1 and group1 is assigned dim=-2
|
|
302
|
+
with numpyro.plate("group2", 5), numpyro.plate("group1", 10):
|
|
303
|
+
_ = numpyro.sample("param", dist.Normal(0, 1))
|
|
304
|
+
|
|
305
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
306
|
+
mcmc.run(PRNGKey(0))
|
|
307
|
+
inference_data = from_numpyro(
|
|
308
|
+
mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
|
|
309
|
+
)
|
|
310
|
+
assert inference_data.posterior.param.dims == ("chain", "draw", "group1", "group2")
|
|
311
|
+
assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
|
|
312
|
+
|
|
313
|
+
def test_mcmc_infer_unsorted_dims(self):
|
|
314
|
+
import numpyro
|
|
315
|
+
import numpyro.distributions as dist
|
|
316
|
+
from numpyro.infer import MCMC, NUTS
|
|
317
|
+
|
|
318
|
+
def model():
|
|
319
|
+
group1_plate = numpyro.plate("group1", 10, dim=-1)
|
|
320
|
+
group2_plate = numpyro.plate("group2", 5, dim=-2)
|
|
321
|
+
|
|
322
|
+
# the plate contexts are entered in a different order than the pre-defined dims
|
|
323
|
+
# we should make sure this still works because the trace has all of the info it needs
|
|
324
|
+
with group2_plate, group1_plate:
|
|
325
|
+
_ = numpyro.sample("param", dist.Normal(0, 1))
|
|
326
|
+
|
|
327
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
328
|
+
mcmc.run(PRNGKey(0))
|
|
329
|
+
inference_data = from_numpyro(
|
|
330
|
+
mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
|
|
331
|
+
)
|
|
332
|
+
assert inference_data.posterior.param.dims == ("chain", "draw", "group2", "group1")
|
|
333
|
+
assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
|
|
334
|
+
|
|
335
|
+
def test_mcmc_infer_dims_no_coords(self):
|
|
336
|
+
import numpyro
|
|
337
|
+
import numpyro.distributions as dist
|
|
338
|
+
from numpyro.infer import MCMC, NUTS
|
|
339
|
+
|
|
340
|
+
def model():
|
|
341
|
+
with numpyro.plate("group", 5):
|
|
342
|
+
_ = numpyro.sample("param", dist.Normal(0, 1))
|
|
343
|
+
|
|
344
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
345
|
+
mcmc.run(PRNGKey(0))
|
|
346
|
+
inference_data = from_numpyro(mcmc)
|
|
347
|
+
assert inference_data.posterior.param.dims == ("chain", "draw", "group")
|
|
348
|
+
|
|
349
|
+
def test_mcmc_event_dims(self):
|
|
350
|
+
import numpyro
|
|
351
|
+
import numpyro.distributions as dist
|
|
352
|
+
from numpyro.infer import MCMC, NUTS
|
|
353
|
+
|
|
354
|
+
def model():
|
|
355
|
+
_ = numpyro.sample(
|
|
356
|
+
"gamma", dist.ZeroSumNormal(1, event_shape=(10,)), infer={"event_dims": ["groups"]}
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
360
|
+
mcmc.run(PRNGKey(0))
|
|
361
|
+
inference_data = from_numpyro(mcmc, coords={"groups": np.arange(10)})
|
|
362
|
+
assert inference_data.posterior.gamma.dims == ("chain", "draw", "groups")
|
|
363
|
+
assert "groups" in inference_data.posterior.gamma.coords
|
|
364
|
+
|
|
365
|
+
@pytest.mark.xfail
|
|
366
|
+
def test_mcmc_inferred_dims_univariate(self):
|
|
367
|
+
import numpyro
|
|
368
|
+
import numpyro.distributions as dist
|
|
369
|
+
from numpyro.infer import MCMC, NUTS
|
|
370
|
+
import jax.numpy as jnp
|
|
371
|
+
|
|
372
|
+
def model():
|
|
373
|
+
alpha = numpyro.sample("alpha", dist.Normal(0, 1))
|
|
374
|
+
sigma = numpyro.sample("sigma", dist.HalfNormal(1))
|
|
375
|
+
with numpyro.plate("obs_idx", 3):
|
|
376
|
+
# mu is plated by obs_idx, but isnt broadcasted to the plate shape
|
|
377
|
+
# the expected behavior is that this should cause a failure
|
|
378
|
+
mu = numpyro.deterministic("mu", alpha)
|
|
379
|
+
return numpyro.sample("y", dist.Normal(mu, sigma), obs=jnp.array([-1, 0, 1]))
|
|
380
|
+
|
|
381
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
382
|
+
mcmc.run(PRNGKey(0))
|
|
383
|
+
inference_data = from_numpyro(mcmc, coords={"obs_idx": np.arange(3)})
|
|
384
|
+
assert inference_data.posterior.mu.dims == ("chain", "draw", "obs_idx")
|
|
385
|
+
assert "obs_idx" in inference_data.posterior.mu.coords
|
|
386
|
+
|
|
387
|
+
def test_mcmc_extra_event_dims(self):
|
|
388
|
+
import numpyro
|
|
389
|
+
import numpyro.distributions as dist
|
|
390
|
+
from numpyro.infer import MCMC, NUTS
|
|
391
|
+
|
|
392
|
+
def model():
|
|
393
|
+
gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(10,)))
|
|
394
|
+
_ = numpyro.deterministic("gamma_plus1", gamma + 1)
|
|
395
|
+
|
|
396
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
397
|
+
mcmc.run(PRNGKey(0))
|
|
398
|
+
inference_data = from_numpyro(
|
|
399
|
+
mcmc, coords={"groups": np.arange(10)}, extra_event_dims={"gamma_plus1": ["groups"]}
|
|
400
|
+
)
|
|
401
|
+
assert inference_data.posterior.gamma_plus1.dims == ("chain", "draw", "groups")
|
|
402
|
+
assert "groups" in inference_data.posterior.gamma_plus1.coords
|
|
403
|
+
|
|
404
|
+
def test_mcmc_predictions_infer_dims(
|
|
405
|
+
self, data, eight_schools_params, predictions_data, predictions_params
|
|
406
|
+
):
|
|
407
|
+
inference_data = self.get_inference_data(
|
|
408
|
+
data, eight_schools_params, predictions_data, predictions_params, infer_dims=True
|
|
409
|
+
)
|
|
410
|
+
assert inference_data.predictions.obs.dims == ("chain", "draw", "J")
|
|
411
|
+
assert "J" in inference_data.predictions.obs.coords
|
|
412
|
+
|
|
413
|
+
def test_potential_energy_sign_conversion(self):
|
|
414
|
+
"""Test that potential energy is converted to log probability (lp) with correct sign."""
|
|
415
|
+
import numpyro
|
|
416
|
+
import numpyro.distributions as dist
|
|
417
|
+
from numpyro.infer import MCMC, NUTS
|
|
418
|
+
|
|
419
|
+
num_samples = 10
|
|
420
|
+
|
|
421
|
+
def simple_model():
|
|
422
|
+
numpyro.sample("x", dist.Normal(0, 1))
|
|
423
|
+
|
|
424
|
+
nuts_kernel = NUTS(simple_model)
|
|
425
|
+
mcmc = MCMC(nuts_kernel, num_samples=num_samples, num_warmup=5)
|
|
426
|
+
mcmc.run(PRNGKey(0), extra_fields=["potential_energy"])
|
|
427
|
+
|
|
428
|
+
# Get the raw extra fields from NumPyro
|
|
429
|
+
extra_fields = mcmc.get_extra_fields(group_by_chain=True)
|
|
430
|
+
# Convert to ArviZ InferenceData
|
|
431
|
+
inference_data = from_numpyro(mcmc)
|
|
432
|
+
arviz_lp = inference_data["sample_stats"]["lp"].values
|
|
433
|
+
|
|
434
|
+
np.testing.assert_array_equal(arviz_lp, -extra_fields["potential_energy"])
|
arviz/wrappers/base.py
CHANGED
|
@@ -197,7 +197,7 @@ class SamplingWrapper:
|
|
|
197
197
|
"""Check that all methods listed are implemented.
|
|
198
198
|
|
|
199
199
|
Not all functions that require refitting need to have all the methods implemented in
|
|
200
|
-
order to work properly. This function
|
|
200
|
+
order to work properly. This function should be used before using the SamplingWrapper and
|
|
201
201
|
its subclasses to get informative error messages.
|
|
202
202
|
|
|
203
203
|
Parameters
|
arviz/wrappers/wrap_stan.py
CHANGED
|
@@ -44,7 +44,7 @@ class StanSamplingWrapper(SamplingWrapper):
|
|
|
44
44
|
excluded_observed_data : str
|
|
45
45
|
Variable name containing the pointwise log likelihood data of the excluded
|
|
46
46
|
data. As PyStan cannot call C++ functions and log_likelihood__i is already
|
|
47
|
-
calculated *during* the
|
|
47
|
+
calculated *during* the simulation, instead of the value on which to evaluate
|
|
48
48
|
the likelihood, ``log_likelihood__i`` expects a string so it can extract the
|
|
49
49
|
corresponding data from the InferenceData object.
|
|
50
50
|
"""
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: arviz
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.23.0
|
|
4
4
|
Summary: Exploratory analysis of Bayesian models
|
|
5
5
|
Home-page: http://github.com/arviz-devs/arviz
|
|
6
6
|
Author: ArviZ Developers
|
|
@@ -22,14 +22,14 @@ Requires-Python: >=3.10
|
|
|
22
22
|
Description-Content-Type: text/markdown
|
|
23
23
|
License-File: LICENSE
|
|
24
24
|
Requires-Dist: setuptools>=60.0.0
|
|
25
|
-
Requires-Dist: matplotlib>=3.
|
|
26
|
-
Requires-Dist: numpy>=1.
|
|
27
|
-
Requires-Dist: scipy>=1.
|
|
25
|
+
Requires-Dist: matplotlib>=3.8
|
|
26
|
+
Requires-Dist: numpy>=1.26.0
|
|
27
|
+
Requires-Dist: scipy>=1.11.0
|
|
28
28
|
Requires-Dist: packaging
|
|
29
|
-
Requires-Dist: pandas>=1.
|
|
30
|
-
Requires-Dist: xarray>=
|
|
29
|
+
Requires-Dist: pandas>=2.1.0
|
|
30
|
+
Requires-Dist: xarray>=2023.7.0
|
|
31
31
|
Requires-Dist: h5netcdf>=1.0.2
|
|
32
|
-
Requires-Dist:
|
|
32
|
+
Requires-Dist: typing_extensions>=4.1.0
|
|
33
33
|
Requires-Dist: xarray-einstats>=0.3
|
|
34
34
|
Provides-Extra: all
|
|
35
35
|
Requires-Dist: numba; extra == "all"
|
|
@@ -39,12 +39,23 @@ Requires-Dist: contourpy; extra == "all"
|
|
|
39
39
|
Requires-Dist: ujson; extra == "all"
|
|
40
40
|
Requires-Dist: dask[distributed]; extra == "all"
|
|
41
41
|
Requires-Dist: zarr<3,>=2.5.0; extra == "all"
|
|
42
|
-
Requires-Dist: xarray
|
|
42
|
+
Requires-Dist: xarray>=2024.11.0; extra == "all"
|
|
43
43
|
Requires-Dist: dm-tree>=0.1.8; extra == "all"
|
|
44
44
|
Provides-Extra: preview
|
|
45
45
|
Requires-Dist: arviz-base[h5netcdf]; extra == "preview"
|
|
46
46
|
Requires-Dist: arviz-stats[xarray]; extra == "preview"
|
|
47
47
|
Requires-Dist: arviz-plots; extra == "preview"
|
|
48
|
+
Dynamic: author
|
|
49
|
+
Dynamic: classifier
|
|
50
|
+
Dynamic: description
|
|
51
|
+
Dynamic: description-content-type
|
|
52
|
+
Dynamic: home-page
|
|
53
|
+
Dynamic: license
|
|
54
|
+
Dynamic: license-file
|
|
55
|
+
Dynamic: provides-extra
|
|
56
|
+
Dynamic: requires-dist
|
|
57
|
+
Dynamic: requires-python
|
|
58
|
+
Dynamic: summary
|
|
48
59
|
|
|
49
60
|
<img src="https://raw.githubusercontent.com/arviz-devs/arviz-project/main/arviz_logos/ArviZ.png#gh-light-mode-only" width=200></img>
|
|
50
61
|
<img src="https://raw.githubusercontent.com/arviz-devs/arviz-project/main/arviz_logos/ArviZ_white.png#gh-dark-mode-only" width=200></img>
|