arviz 0.20.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +8 -3
- arviz/data/base.py +2 -2
- arviz/data/inference_data.py +57 -26
- arviz/data/io_datatree.py +2 -2
- arviz/data/io_numpyro.py +112 -4
- arviz/plots/autocorrplot.py +12 -2
- arviz/plots/backends/__init__.py +8 -7
- arviz/plots/backends/bokeh/bpvplot.py +4 -3
- arviz/plots/backends/bokeh/densityplot.py +5 -1
- arviz/plots/backends/bokeh/dotplot.py +5 -2
- arviz/plots/backends/bokeh/essplot.py +4 -2
- arviz/plots/backends/bokeh/forestplot.py +11 -4
- arviz/plots/backends/bokeh/hdiplot.py +7 -6
- arviz/plots/backends/bokeh/khatplot.py +4 -2
- arviz/plots/backends/bokeh/lmplot.py +28 -6
- arviz/plots/backends/bokeh/mcseplot.py +2 -2
- arviz/plots/backends/bokeh/pairplot.py +27 -52
- arviz/plots/backends/bokeh/ppcplot.py +2 -1
- arviz/plots/backends/bokeh/rankplot.py +2 -1
- arviz/plots/backends/bokeh/traceplot.py +2 -1
- arviz/plots/backends/bokeh/violinplot.py +2 -1
- arviz/plots/backends/matplotlib/bpvplot.py +2 -1
- arviz/plots/backends/matplotlib/khatplot.py +8 -1
- arviz/plots/backends/matplotlib/lmplot.py +13 -7
- arviz/plots/backends/matplotlib/pairplot.py +14 -22
- arviz/plots/bfplot.py +9 -26
- arviz/plots/bpvplot.py +10 -1
- arviz/plots/hdiplot.py +5 -0
- arviz/plots/kdeplot.py +4 -4
- arviz/plots/lmplot.py +41 -14
- arviz/plots/pairplot.py +10 -3
- arviz/plots/plot_utils.py +5 -3
- arviz/preview.py +36 -5
- arviz/stats/__init__.py +1 -0
- arviz/stats/density_utils.py +1 -1
- arviz/stats/diagnostics.py +18 -14
- arviz/stats/stats.py +105 -7
- arviz/tests/base_tests/test_data.py +31 -11
- arviz/tests/base_tests/test_diagnostics.py +5 -4
- arviz/tests/base_tests/test_plots_bokeh.py +60 -2
- arviz/tests/base_tests/test_plots_matplotlib.py +103 -11
- arviz/tests/base_tests/test_stats.py +53 -1
- arviz/tests/external_tests/test_data_numpyro.py +130 -3
- arviz/utils.py +4 -0
- arviz/wrappers/base.py +1 -1
- arviz/wrappers/wrap_stan.py +1 -1
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/METADATA +7 -7
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/RECORD +51 -51
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/WHEEL +1 -1
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/LICENSE +0 -0
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/top_level.txt +0 -0
|
@@ -2,21 +2,23 @@
|
|
|
2
2
|
|
|
3
3
|
# pylint: disable=redefined-outer-name,too-many-lines
|
|
4
4
|
import os
|
|
5
|
+
import re
|
|
5
6
|
from copy import deepcopy
|
|
6
7
|
|
|
7
8
|
import matplotlib.pyplot as plt
|
|
8
9
|
import numpy as np
|
|
9
10
|
import pytest
|
|
11
|
+
import xarray as xr
|
|
10
12
|
from matplotlib import animation
|
|
11
13
|
from pandas import DataFrame
|
|
12
14
|
from scipy.stats import gaussian_kde, norm
|
|
13
|
-
import xarray as xr
|
|
14
15
|
|
|
15
16
|
from ...data import from_dict, load_arviz_data
|
|
17
|
+
from ...labels import MapLabeller
|
|
16
18
|
from ...plots import (
|
|
17
19
|
plot_autocorr,
|
|
18
|
-
plot_bpv,
|
|
19
20
|
plot_bf,
|
|
21
|
+
plot_bpv,
|
|
20
22
|
plot_compare,
|
|
21
23
|
plot_density,
|
|
22
24
|
plot_dist,
|
|
@@ -43,20 +45,20 @@ from ...plots import (
|
|
|
43
45
|
plot_ts,
|
|
44
46
|
plot_violin,
|
|
45
47
|
)
|
|
48
|
+
from ...plots.dotplot import wilkinson_algorithm
|
|
49
|
+
from ...plots.plot_utils import plot_point_interval
|
|
46
50
|
from ...rcparams import rc_context, rcParams
|
|
47
51
|
from ...stats import compare, hdi, loo, waic
|
|
48
52
|
from ...stats.density_utils import kde as _kde
|
|
49
|
-
from ...utils import
|
|
50
|
-
from ...plots.plot_utils import plot_point_interval
|
|
51
|
-
from ...plots.dotplot import wilkinson_algorithm
|
|
53
|
+
from ...utils import BehaviourChangeWarning, _cov
|
|
52
54
|
from ..helpers import ( # pylint: disable=unused-import
|
|
55
|
+
RandomVariableTestClass,
|
|
53
56
|
create_model,
|
|
54
57
|
create_multidimensional_model,
|
|
55
58
|
does_not_warn,
|
|
56
59
|
eight_schools_params,
|
|
57
60
|
models,
|
|
58
61
|
multidim_models,
|
|
59
|
-
RandomVariableTestClass,
|
|
60
62
|
)
|
|
61
63
|
|
|
62
64
|
rcParams["data.load"] = "eager"
|
|
@@ -598,6 +600,21 @@ def test_plot_kde_inference_data(models):
|
|
|
598
600
|
"reference_values": {"mu": 0, "tau": 0},
|
|
599
601
|
"reference_values_kwargs": {"c": "C0", "marker": "*"},
|
|
600
602
|
},
|
|
603
|
+
{
|
|
604
|
+
"var_names": ["mu", "tau"],
|
|
605
|
+
"reference_values": {"mu": 0, "tau": 0},
|
|
606
|
+
"labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
|
|
607
|
+
},
|
|
608
|
+
{
|
|
609
|
+
"var_names": ["theta"],
|
|
610
|
+
"reference_values": {"theta": [0.0] * 8},
|
|
611
|
+
"labeller": MapLabeller({"theta": r"$\theta$"}),
|
|
612
|
+
},
|
|
613
|
+
{
|
|
614
|
+
"var_names": ["theta"],
|
|
615
|
+
"reference_values": {"theta": np.zeros(8)},
|
|
616
|
+
"labeller": MapLabeller({"theta": r"$\theta$"}),
|
|
617
|
+
},
|
|
601
618
|
],
|
|
602
619
|
)
|
|
603
620
|
def test_plot_pair(models, kwargs):
|
|
@@ -1236,6 +1253,23 @@ def test_plot_hdi_dataset_error(models):
|
|
|
1236
1253
|
plot_hdi(np.arange(8), hdi_data=hdi_data)
|
|
1237
1254
|
|
|
1238
1255
|
|
|
1256
|
+
def test_plot_hdi_string_error():
|
|
1257
|
+
"""Check x as type string raises an error."""
|
|
1258
|
+
x_data = ["a", "b", "c", "d"]
|
|
1259
|
+
y_data = np.random.normal(0, 5, (1, 200, len(x_data)))
|
|
1260
|
+
hdi_data = hdi(y_data)
|
|
1261
|
+
with pytest.raises(
|
|
1262
|
+
NotImplementedError,
|
|
1263
|
+
match=re.escape(
|
|
1264
|
+
(
|
|
1265
|
+
"The `arviz.plot_hdi()` function does not support categorical data. "
|
|
1266
|
+
"Consider using `arviz.plot_forest()`."
|
|
1267
|
+
)
|
|
1268
|
+
),
|
|
1269
|
+
):
|
|
1270
|
+
plot_hdi(x=x_data, y=y_data, hdi_data=hdi_data)
|
|
1271
|
+
|
|
1272
|
+
|
|
1239
1273
|
def test_plot_hdi_datetime_error():
|
|
1240
1274
|
"""Check x as datetime raises an error."""
|
|
1241
1275
|
x_data = np.arange(start="2022-01-01", stop="2022-03-01", dtype=np.datetime64)
|
|
@@ -1896,7 +1930,7 @@ def test_wilkinson_algorithm(continuous_model):
|
|
|
1896
1930
|
},
|
|
1897
1931
|
],
|
|
1898
1932
|
)
|
|
1899
|
-
def
|
|
1933
|
+
def test_plot_lm_1d(models, kwargs):
|
|
1900
1934
|
"""Test functionality for 1D data."""
|
|
1901
1935
|
idata = models.model_1
|
|
1902
1936
|
if "constant_data" not in idata.groups():
|
|
@@ -2082,7 +2116,65 @@ def test_plot_bf():
|
|
|
2082
2116
|
idata = from_dict(
|
|
2083
2117
|
posterior={"a": np.random.normal(1, 0.5, 5000)}, prior={"a": np.random.normal(0, 1, 5000)}
|
|
2084
2118
|
)
|
|
2085
|
-
|
|
2086
|
-
|
|
2087
|
-
|
|
2088
|
-
|
|
2119
|
+
_, bf_plot = plot_bf(idata, var_name="a", ref_val=0)
|
|
2120
|
+
assert bf_plot is not None
|
|
2121
|
+
|
|
2122
|
+
|
|
2123
|
+
def generate_lm_1d_data():
|
|
2124
|
+
rng = np.random.default_rng()
|
|
2125
|
+
return from_dict(
|
|
2126
|
+
observed_data={"y": rng.normal(size=7)},
|
|
2127
|
+
posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
|
|
2128
|
+
posterior={"y_model": rng.normal(size=(4, 1000, 7))},
|
|
2129
|
+
dims={"y": ["dim1"]},
|
|
2130
|
+
coords={"dim1": range(7)},
|
|
2131
|
+
)
|
|
2132
|
+
|
|
2133
|
+
|
|
2134
|
+
def generate_lm_2d_data():
|
|
2135
|
+
rng = np.random.default_rng()
|
|
2136
|
+
return from_dict(
|
|
2137
|
+
observed_data={"y": rng.normal(size=(5, 7))},
|
|
2138
|
+
posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
|
|
2139
|
+
posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
|
|
2140
|
+
dims={"y": ["dim1", "dim2"]},
|
|
2141
|
+
coords={"dim1": range(5), "dim2": range(7)},
|
|
2142
|
+
)
|
|
2143
|
+
|
|
2144
|
+
|
|
2145
|
+
@pytest.mark.parametrize("data", ("1d", "2d"))
|
|
2146
|
+
@pytest.mark.parametrize("kind", ("lines", "hdi"))
|
|
2147
|
+
@pytest.mark.parametrize("use_y_model", (True, False))
|
|
2148
|
+
def test_plot_lm(data, kind, use_y_model):
|
|
2149
|
+
if data == "1d":
|
|
2150
|
+
idata = generate_lm_1d_data()
|
|
2151
|
+
else:
|
|
2152
|
+
idata = generate_lm_2d_data()
|
|
2153
|
+
|
|
2154
|
+
kwargs = {"idata": idata, "y": "y", "kind_model": kind}
|
|
2155
|
+
if data == "2d":
|
|
2156
|
+
kwargs["plot_dim"] = "dim1"
|
|
2157
|
+
if use_y_model:
|
|
2158
|
+
kwargs["y_model"] = "y_model"
|
|
2159
|
+
if kind == "lines":
|
|
2160
|
+
kwargs["num_samples"] = 50
|
|
2161
|
+
|
|
2162
|
+
ax = plot_lm(**kwargs)
|
|
2163
|
+
assert ax is not None
|
|
2164
|
+
|
|
2165
|
+
|
|
2166
|
+
@pytest.mark.parametrize(
|
|
2167
|
+
"coords, expected_vars",
|
|
2168
|
+
[
|
|
2169
|
+
({"school": ["Choate"]}, ["theta"]),
|
|
2170
|
+
({"school": ["Lawrenceville"]}, ["theta"]),
|
|
2171
|
+
({}, ["theta"]),
|
|
2172
|
+
],
|
|
2173
|
+
)
|
|
2174
|
+
def test_plot_autocorr_coords(coords, expected_vars):
|
|
2175
|
+
"""Test plot_autocorr with coords kwarg."""
|
|
2176
|
+
idata = load_arviz_data("centered_eight")
|
|
2177
|
+
|
|
2178
|
+
axes = plot_autocorr(idata, var_names=expected_vars, coords=coords, show=False)
|
|
2179
|
+
|
|
2180
|
+
assert axes is not None
|
|
@@ -14,10 +14,11 @@ from scipy.stats import linregress, norm, halfcauchy
|
|
|
14
14
|
from xarray import DataArray, Dataset
|
|
15
15
|
from xarray_einstats.stats import XrContinuousRV
|
|
16
16
|
|
|
17
|
-
from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
|
|
17
|
+
from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data, InferenceData
|
|
18
18
|
from ...rcparams import rcParams
|
|
19
19
|
from ...stats import (
|
|
20
20
|
apply_test_function,
|
|
21
|
+
bayes_factor,
|
|
21
22
|
compare,
|
|
22
23
|
ess,
|
|
23
24
|
hdi,
|
|
@@ -871,3 +872,54 @@ def test_priorsens_coords(psens_data):
|
|
|
871
872
|
assert "mu" in result
|
|
872
873
|
assert "theta" in result
|
|
873
874
|
assert "school" in result.theta_t.dims
|
|
875
|
+
|
|
876
|
+
|
|
877
|
+
def test_bayes_factor():
|
|
878
|
+
idata = from_dict(
|
|
879
|
+
posterior={"a": np.random.normal(1, 0.5, 5000)}, prior={"a": np.random.normal(0, 1, 5000)}
|
|
880
|
+
)
|
|
881
|
+
bf_dict0 = bayes_factor(idata, var_name="a", ref_val=0)
|
|
882
|
+
bf_dict1 = bayes_factor(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
|
|
883
|
+
assert bf_dict0["BF10"] > bf_dict0["BF01"]
|
|
884
|
+
assert bf_dict1["BF10"] < bf_dict1["BF01"]
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
def test_compare_sorting_consistency():
|
|
888
|
+
chains, draws = 4, 1000
|
|
889
|
+
|
|
890
|
+
# Model 1 - good fit
|
|
891
|
+
log_lik1 = np.random.normal(-2, 1, size=(chains, draws))
|
|
892
|
+
posterior1 = Dataset(
|
|
893
|
+
{"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
|
|
894
|
+
coords={"chain": range(chains), "draw": range(draws)},
|
|
895
|
+
)
|
|
896
|
+
log_like1 = Dataset(
|
|
897
|
+
{"y": (("chain", "draw"), log_lik1)},
|
|
898
|
+
coords={"chain": range(chains), "draw": range(draws)},
|
|
899
|
+
)
|
|
900
|
+
data1 = InferenceData(posterior=posterior1, log_likelihood=log_like1)
|
|
901
|
+
|
|
902
|
+
# Model 2 - poor fit (higher variance)
|
|
903
|
+
log_lik2 = np.random.normal(-5, 2, size=(chains, draws))
|
|
904
|
+
posterior2 = Dataset(
|
|
905
|
+
{"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
|
|
906
|
+
coords={"chain": range(chains), "draw": range(draws)},
|
|
907
|
+
)
|
|
908
|
+
log_like2 = Dataset(
|
|
909
|
+
{"y": (("chain", "draw"), log_lik2)},
|
|
910
|
+
coords={"chain": range(chains), "draw": range(draws)},
|
|
911
|
+
)
|
|
912
|
+
data2 = InferenceData(posterior=posterior2, log_likelihood=log_like2)
|
|
913
|
+
|
|
914
|
+
# Compare models in different orders
|
|
915
|
+
comp_dict1 = {"M1": data1, "M2": data2}
|
|
916
|
+
comp_dict2 = {"M2": data2, "M1": data1}
|
|
917
|
+
|
|
918
|
+
comparison1 = compare(comp_dict1, method="bb-pseudo-bma")
|
|
919
|
+
comparison2 = compare(comp_dict2, method="bb-pseudo-bma")
|
|
920
|
+
|
|
921
|
+
assert comparison1.index.tolist() == comparison2.index.tolist()
|
|
922
|
+
|
|
923
|
+
se1 = comparison1["se"].values
|
|
924
|
+
se2 = comparison2["se"].values
|
|
925
|
+
np.testing.assert_array_almost_equal(se1, se2)
|
|
@@ -46,7 +46,9 @@ class TestDataNumPyro:
|
|
|
46
46
|
)
|
|
47
47
|
return predictions
|
|
48
48
|
|
|
49
|
-
def get_inference_data(
|
|
49
|
+
def get_inference_data(
|
|
50
|
+
self, data, eight_schools_params, predictions_data, predictions_params, infer_dims=False
|
|
51
|
+
):
|
|
50
52
|
posterior_samples = data.obj.get_samples()
|
|
51
53
|
model = data.obj.sampler.model
|
|
52
54
|
posterior_predictive = Predictive(model, posterior_samples)(
|
|
@@ -55,6 +57,12 @@ class TestDataNumPyro:
|
|
|
55
57
|
prior = Predictive(model, num_samples=500)(
|
|
56
58
|
PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
|
|
57
59
|
)
|
|
60
|
+
dims = {"theta": ["school"], "eta": ["school"], "obs": ["school"]}
|
|
61
|
+
pred_dims = {"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]}
|
|
62
|
+
if infer_dims:
|
|
63
|
+
dims = None
|
|
64
|
+
pred_dims = None
|
|
65
|
+
|
|
58
66
|
predictions = predictions_data
|
|
59
67
|
return from_numpyro(
|
|
60
68
|
posterior=data.obj,
|
|
@@ -65,8 +73,8 @@ class TestDataNumPyro:
|
|
|
65
73
|
"school": np.arange(eight_schools_params["J"]),
|
|
66
74
|
"school_pred": np.arange(predictions_params["J"]),
|
|
67
75
|
},
|
|
68
|
-
dims=
|
|
69
|
-
pred_dims=
|
|
76
|
+
dims=dims,
|
|
77
|
+
pred_dims=pred_dims,
|
|
70
78
|
)
|
|
71
79
|
|
|
72
80
|
def test_inference_data_namedtuple(self, data):
|
|
@@ -77,6 +85,7 @@ class TestDataNumPyro:
|
|
|
77
85
|
data.obj.get_samples = lambda *args, **kwargs: data_namedtuple
|
|
78
86
|
inference_data = from_numpyro(
|
|
79
87
|
posterior=data.obj,
|
|
88
|
+
dims={}, # This mock test needs to turn off autodims like so or mock group_by_chain
|
|
80
89
|
)
|
|
81
90
|
assert isinstance(data.obj.get_samples(), Samples)
|
|
82
91
|
data.obj.get_samples = _old_fn
|
|
@@ -282,3 +291,121 @@ class TestDataNumPyro:
|
|
|
282
291
|
mcmc.run(PRNGKey(0))
|
|
283
292
|
inference_data = from_numpyro(mcmc)
|
|
284
293
|
assert inference_data.observed_data
|
|
294
|
+
|
|
295
|
+
def test_mcmc_infer_dims(self):
|
|
296
|
+
import numpyro
|
|
297
|
+
import numpyro.distributions as dist
|
|
298
|
+
from numpyro.infer import MCMC, NUTS
|
|
299
|
+
|
|
300
|
+
def model():
|
|
301
|
+
# note: group2 gets assigned dim=-1 and group1 is assigned dim=-2
|
|
302
|
+
with numpyro.plate("group2", 5), numpyro.plate("group1", 10):
|
|
303
|
+
_ = numpyro.sample("param", dist.Normal(0, 1))
|
|
304
|
+
|
|
305
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
306
|
+
mcmc.run(PRNGKey(0))
|
|
307
|
+
inference_data = from_numpyro(
|
|
308
|
+
mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
|
|
309
|
+
)
|
|
310
|
+
assert inference_data.posterior.param.dims == ("chain", "draw", "group1", "group2")
|
|
311
|
+
assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
|
|
312
|
+
|
|
313
|
+
def test_mcmc_infer_unsorted_dims(self):
|
|
314
|
+
import numpyro
|
|
315
|
+
import numpyro.distributions as dist
|
|
316
|
+
from numpyro.infer import MCMC, NUTS
|
|
317
|
+
|
|
318
|
+
def model():
|
|
319
|
+
group1_plate = numpyro.plate("group1", 10, dim=-1)
|
|
320
|
+
group2_plate = numpyro.plate("group2", 5, dim=-2)
|
|
321
|
+
|
|
322
|
+
# the plate contexts are entered in a different order than the pre-defined dims
|
|
323
|
+
# we should make sure this still works because the trace has all of the info it needs
|
|
324
|
+
with group2_plate, group1_plate:
|
|
325
|
+
_ = numpyro.sample("param", dist.Normal(0, 1))
|
|
326
|
+
|
|
327
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
328
|
+
mcmc.run(PRNGKey(0))
|
|
329
|
+
inference_data = from_numpyro(
|
|
330
|
+
mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
|
|
331
|
+
)
|
|
332
|
+
assert inference_data.posterior.param.dims == ("chain", "draw", "group2", "group1")
|
|
333
|
+
assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
|
|
334
|
+
|
|
335
|
+
def test_mcmc_infer_dims_no_coords(self):
|
|
336
|
+
import numpyro
|
|
337
|
+
import numpyro.distributions as dist
|
|
338
|
+
from numpyro.infer import MCMC, NUTS
|
|
339
|
+
|
|
340
|
+
def model():
|
|
341
|
+
with numpyro.plate("group", 5):
|
|
342
|
+
_ = numpyro.sample("param", dist.Normal(0, 1))
|
|
343
|
+
|
|
344
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
345
|
+
mcmc.run(PRNGKey(0))
|
|
346
|
+
inference_data = from_numpyro(mcmc)
|
|
347
|
+
assert inference_data.posterior.param.dims == ("chain", "draw", "group")
|
|
348
|
+
|
|
349
|
+
def test_mcmc_event_dims(self):
|
|
350
|
+
import numpyro
|
|
351
|
+
import numpyro.distributions as dist
|
|
352
|
+
from numpyro.infer import MCMC, NUTS
|
|
353
|
+
|
|
354
|
+
def model():
|
|
355
|
+
_ = numpyro.sample(
|
|
356
|
+
"gamma", dist.ZeroSumNormal(1, event_shape=(10,)), infer={"event_dims": ["groups"]}
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
360
|
+
mcmc.run(PRNGKey(0))
|
|
361
|
+
inference_data = from_numpyro(mcmc, coords={"groups": np.arange(10)})
|
|
362
|
+
assert inference_data.posterior.gamma.dims == ("chain", "draw", "groups")
|
|
363
|
+
assert "groups" in inference_data.posterior.gamma.coords
|
|
364
|
+
|
|
365
|
+
@pytest.mark.xfail
|
|
366
|
+
def test_mcmc_inferred_dims_univariate(self):
|
|
367
|
+
import numpyro
|
|
368
|
+
import numpyro.distributions as dist
|
|
369
|
+
from numpyro.infer import MCMC, NUTS
|
|
370
|
+
import jax.numpy as jnp
|
|
371
|
+
|
|
372
|
+
def model():
|
|
373
|
+
alpha = numpyro.sample("alpha", dist.Normal(0, 1))
|
|
374
|
+
sigma = numpyro.sample("sigma", dist.HalfNormal(1))
|
|
375
|
+
with numpyro.plate("obs_idx", 3):
|
|
376
|
+
# mu is plated by obs_idx, but isnt broadcasted to the plate shape
|
|
377
|
+
# the expected behavior is that this should cause a failure
|
|
378
|
+
mu = numpyro.deterministic("mu", alpha)
|
|
379
|
+
return numpyro.sample("y", dist.Normal(mu, sigma), obs=jnp.array([-1, 0, 1]))
|
|
380
|
+
|
|
381
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
382
|
+
mcmc.run(PRNGKey(0))
|
|
383
|
+
inference_data = from_numpyro(mcmc, coords={"obs_idx": np.arange(3)})
|
|
384
|
+
assert inference_data.posterior.mu.dims == ("chain", "draw", "obs_idx")
|
|
385
|
+
assert "obs_idx" in inference_data.posterior.mu.coords
|
|
386
|
+
|
|
387
|
+
def test_mcmc_extra_event_dims(self):
|
|
388
|
+
import numpyro
|
|
389
|
+
import numpyro.distributions as dist
|
|
390
|
+
from numpyro.infer import MCMC, NUTS
|
|
391
|
+
|
|
392
|
+
def model():
|
|
393
|
+
gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(10,)))
|
|
394
|
+
_ = numpyro.deterministic("gamma_plus1", gamma + 1)
|
|
395
|
+
|
|
396
|
+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
|
|
397
|
+
mcmc.run(PRNGKey(0))
|
|
398
|
+
inference_data = from_numpyro(
|
|
399
|
+
mcmc, coords={"groups": np.arange(10)}, extra_event_dims={"gamma_plus1": ["groups"]}
|
|
400
|
+
)
|
|
401
|
+
assert inference_data.posterior.gamma_plus1.dims == ("chain", "draw", "groups")
|
|
402
|
+
assert "groups" in inference_data.posterior.gamma_plus1.coords
|
|
403
|
+
|
|
404
|
+
def test_mcmc_predictions_infer_dims(
|
|
405
|
+
self, data, eight_schools_params, predictions_data, predictions_params
|
|
406
|
+
):
|
|
407
|
+
inference_data = self.get_inference_data(
|
|
408
|
+
data, eight_schools_params, predictions_data, predictions_params, infer_dims=True
|
|
409
|
+
)
|
|
410
|
+
assert inference_data.predictions.obs.dims == ("chain", "draw", "J")
|
|
411
|
+
assert "J" in inference_data.predictions.obs.coords
|
arviz/utils.py
CHANGED
|
@@ -330,6 +330,7 @@ class Numba:
|
|
|
330
330
|
"""A class to toggle numba states."""
|
|
331
331
|
|
|
332
332
|
numba_flag = numba_check()
|
|
333
|
+
"""bool: Indicates whether Numba optimizations are enabled. Defaults to False."""
|
|
333
334
|
|
|
334
335
|
@classmethod
|
|
335
336
|
def disable_numba(cls):
|
|
@@ -732,7 +733,10 @@ class Dask:
|
|
|
732
733
|
"""
|
|
733
734
|
|
|
734
735
|
dask_flag = False
|
|
736
|
+
"""bool: Enables Dask parallelization when set to True. Defaults to False."""
|
|
735
737
|
dask_kwargs = None
|
|
738
|
+
"""dict: Additional keyword arguments for Dask configuration.
|
|
739
|
+
Defaults to an empty dictionary."""
|
|
736
740
|
|
|
737
741
|
@classmethod
|
|
738
742
|
def enable_dask(cls, dask_kwargs=None):
|
arviz/wrappers/base.py
CHANGED
|
@@ -197,7 +197,7 @@ class SamplingWrapper:
|
|
|
197
197
|
"""Check that all methods listed are implemented.
|
|
198
198
|
|
|
199
199
|
Not all functions that require refitting need to have all the methods implemented in
|
|
200
|
-
order to work properly. This function
|
|
200
|
+
order to work properly. This function should be used before using the SamplingWrapper and
|
|
201
201
|
its subclasses to get informative error messages.
|
|
202
202
|
|
|
203
203
|
Parameters
|
arviz/wrappers/wrap_stan.py
CHANGED
|
@@ -44,7 +44,7 @@ class StanSamplingWrapper(SamplingWrapper):
|
|
|
44
44
|
excluded_observed_data : str
|
|
45
45
|
Variable name containing the pointwise log likelihood data of the excluded
|
|
46
46
|
data. As PyStan cannot call C++ functions and log_likelihood__i is already
|
|
47
|
-
calculated *during* the
|
|
47
|
+
calculated *during* the simulation, instead of the value on which to evaluate
|
|
48
48
|
the likelihood, ``log_likelihood__i`` expects a string so it can extract the
|
|
49
49
|
corresponding data from the InferenceData object.
|
|
50
50
|
"""
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: arviz
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.22.0
|
|
4
4
|
Summary: Exploratory analysis of Bayesian models
|
|
5
5
|
Home-page: http://github.com/arviz-devs/arviz
|
|
6
6
|
Author: ArviZ Developers
|
|
@@ -22,12 +22,12 @@ Requires-Python: >=3.10
|
|
|
22
22
|
Description-Content-Type: text/markdown
|
|
23
23
|
License-File: LICENSE
|
|
24
24
|
Requires-Dist: setuptools>=60.0.0
|
|
25
|
-
Requires-Dist: matplotlib>=3.
|
|
26
|
-
Requires-Dist: numpy>=1.
|
|
27
|
-
Requires-Dist: scipy>=1.
|
|
25
|
+
Requires-Dist: matplotlib>=3.8
|
|
26
|
+
Requires-Dist: numpy>=1.26.0
|
|
27
|
+
Requires-Dist: scipy>=1.11.0
|
|
28
28
|
Requires-Dist: packaging
|
|
29
|
-
Requires-Dist: pandas>=1.
|
|
30
|
-
Requires-Dist: xarray>=
|
|
29
|
+
Requires-Dist: pandas>=2.1.0
|
|
30
|
+
Requires-Dist: xarray>=2023.7.0
|
|
31
31
|
Requires-Dist: h5netcdf>=1.0.2
|
|
32
32
|
Requires-Dist: typing-extensions>=4.1.0
|
|
33
33
|
Requires-Dist: xarray-einstats>=0.3
|
|
@@ -39,7 +39,7 @@ Requires-Dist: contourpy; extra == "all"
|
|
|
39
39
|
Requires-Dist: ujson; extra == "all"
|
|
40
40
|
Requires-Dist: dask[distributed]; extra == "all"
|
|
41
41
|
Requires-Dist: zarr<3,>=2.5.0; extra == "all"
|
|
42
|
-
Requires-Dist: xarray
|
|
42
|
+
Requires-Dist: xarray>=2024.11.0; extra == "all"
|
|
43
43
|
Requires-Dist: dm-tree>=0.1.8; extra == "all"
|
|
44
44
|
Provides-Extra: preview
|
|
45
45
|
Requires-Dist: arviz-base[h5netcdf]; extra == "preview"
|