arviz 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- arviz/__init__.py +8 -3
- arviz/data/inference_data.py +37 -19
- arviz/data/io_datatree.py +2 -2
- arviz/data/io_numpyro.py +112 -4
- arviz/plots/autocorrplot.py +12 -2
- arviz/plots/backends/bokeh/hdiplot.py +7 -6
- arviz/plots/backends/bokeh/lmplot.py +19 -3
- arviz/plots/backends/bokeh/pairplot.py +18 -48
- arviz/plots/backends/matplotlib/khatplot.py +8 -1
- arviz/plots/backends/matplotlib/lmplot.py +13 -7
- arviz/plots/backends/matplotlib/pairplot.py +14 -22
- arviz/plots/kdeplot.py +4 -4
- arviz/plots/lmplot.py +41 -14
- arviz/plots/pairplot.py +10 -3
- arviz/stats/density_utils.py +1 -1
- arviz/stats/stats.py +19 -7
- arviz/tests/base_tests/test_data.py +0 -4
- arviz/tests/base_tests/test_plots_bokeh.py +60 -2
- arviz/tests/base_tests/test_plots_matplotlib.py +77 -1
- arviz/tests/base_tests/test_stats.py +42 -1
- arviz/tests/external_tests/test_data_numpyro.py +130 -3
- arviz/wrappers/base.py +1 -1
- arviz/wrappers/wrap_stan.py +1 -1
- {arviz-0.21.0.dist-info → arviz-0.22.0.dist-info}/METADATA +7 -7
- {arviz-0.21.0.dist-info → arviz-0.22.0.dist-info}/RECORD +28 -28
- {arviz-0.21.0.dist-info → arviz-0.22.0.dist-info}/LICENSE +0 -0
- {arviz-0.21.0.dist-info → arviz-0.22.0.dist-info}/WHEEL +0 -0
- {arviz-0.21.0.dist-info → arviz-0.22.0.dist-info}/top_level.txt +0 -0
arviz/plots/backends/matplotlib/pairplot.py
CHANGED

@@ -30,6 +30,8 @@ def plot_pair(
     diverging_mask,
     divergences_kwargs,
     flat_var_names,
+    flat_ref_slices,
+    flat_var_labels,
     backend_kwargs,
     marginal_kwargs,
     show,
@@ -77,24 +79,12 @@ def plot_pair(
         kde_kwargs["contour_kwargs"].setdefault("colors", "k")
 
     if reference_values:
-        reference_values_copy = {}
-        label = []
-        for variable in list(reference_values.keys()):
-            if " " in variable:
-                variable_copy = variable.replace(" ", "\n", 1)
-            else:
-                variable_copy = variable
-
-            label.append(variable_copy)
-            reference_values_copy[variable_copy] = reference_values[variable]
-
-        difference = set(flat_var_names).difference(set(label))
+        difference = set(flat_var_names).difference(set(reference_values.keys()))
 
         if difference:
-            warn = [diff.replace("\n", " ", 1) for diff in difference]
             warnings.warn(
                 "Argument reference_values does not include reference value for: {}".format(
-                    ", ".join(warn)
+                    ", ".join(difference)
                 ),
                 UserWarning,
             )
@@ -211,12 +201,12 @@ def plot_pair(
 
         if reference_values:
             ax.plot(
-                reference_values_copy[flat_var_names[0]],
-                reference_values_copy[flat_var_names[1]],
+                np.array(reference_values[flat_var_names[0]])[flat_ref_slices[0]],
+                np.array(reference_values[flat_var_names[1]])[flat_ref_slices[1]],
                 **reference_values_kwargs,
             )
-        ax.set_xlabel(f"{flat_var_names[0]}", fontsize=ax_labelsize, wrap=True)
-        ax.set_ylabel(f"{flat_var_names[1]}", fontsize=ax_labelsize, wrap=True)
+        ax.set_xlabel(f"{flat_var_labels[0]}", fontsize=ax_labelsize, wrap=True)
+        ax.set_ylabel(f"{flat_var_labels[1]}", fontsize=ax_labelsize, wrap=True)
         ax.tick_params(labelsize=xt_labelsize)
 
     else:
@@ -336,20 +326,22 @@ def plot_pair(
                     y_name = flat_var_names[j + not_marginals]
                     if (x_name not in difference) and (y_name not in difference):
                         ax[j, i].plot(
-                            reference_values_copy[x_name],
-                            reference_values_copy[y_name],
+                            np.array(reference_values[x_name])[flat_ref_slices[i]],
+                            np.array(reference_values[y_name])[
+                                flat_ref_slices[j + not_marginals]
+                            ],
                             **reference_values_kwargs,
                         )
 
                 if j != vars_to_plot - 1:
                     plt.setp(ax[j, i].get_xticklabels(), visible=False)
                 else:
-                    ax[j, i].set_xlabel(f"{flat_var_names[i]}", fontsize=ax_labelsize, wrap=True)
+                    ax[j, i].set_xlabel(f"{flat_var_labels[i]}", fontsize=ax_labelsize, wrap=True)
                 if i != 0:
                     plt.setp(ax[j, i].get_yticklabels(), visible=False)
                 else:
                     ax[j, i].set_ylabel(
-                        f"{flat_var_names[j + not_marginals]}",
+                        f"{flat_var_labels[j + not_marginals]}",
                         fontsize=ax_labelsize,
                         wrap=True,
                     )
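
These two pairplot hunks are the backend half of a new feature: `reference_values` entries may now be scalars or arrays, indexed per panel through `flat_ref_slices`, while axis labels come from the labeller via `flat_var_labels`. A minimal sketch of the call this enables, mirroring the new test cases further down (data and kwargs are illustrative):

```python
import numpy as np
import arviz as az
from arviz.labels import MapLabeller

idata = az.load_arviz_data("centered_eight")

# theta has one entry per school, so an 8-element array of reference
# values can now be passed; each scatter panel selects its own entry
# through the per-variable slices built in plot_pair.
az.plot_pair(
    idata,
    var_names=["theta"],
    reference_values={"theta": np.zeros(8)},
    labeller=MapLabeller({"theta": r"$\theta$"}),
    show=False,
)
```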
arviz/plots/kdeplot.py
CHANGED

@@ -255,6 +255,10 @@ def plot_kde(
         "or plot_pair instead of plot_kde"
     )
 
+    if backend is None:
+        backend = rcParams["plot.backend"]
+    backend = backend.lower()
+
     if values2 is None:
         if bw == "default":
             bw = "taylor" if is_circular else "experimental"
@@ -346,10 +350,6 @@
         **kwargs,
     )
 
-    if backend is None:
-        backend = rcParams["plot.backend"]
-    backend = backend.lower()
-
     # TODO: Add backend kwargs
     plot = get_plotting_function("plot_kde", "kdeplot", backend)
     ax = plot(**kde_plot_args)
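
The relocated block only changes when the backend name is resolved: it now happens at the top of `plot_kde`, before any estimation or plotting branch runs, so every later code path sees a concrete, lowercased backend string. From the caller's side nothing changes; a minimal sanity check (values are illustrative):

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)

# With backend=None, plot_kde falls back to rcParams["plot.backend"];
# after this change the fallback is applied before the density is computed.
ax = az.plot_kde(rng.normal(size=500), show=False)
```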
arviz/plots/lmplot.py
CHANGED

@@ -300,20 +300,47 @@ def plot_lm(
     # Filter out the required values to generate plotters
     if y_model is not None:
         if kind_model == "lines":
-            [old lines 303-316: previous implementation, 14 lines; content not preserved in this diff view]
+            var_name = y_model.name if y_model.name else "y_model"
+            data = y_model.values
+
+            total_samples = data.shape[0] * data.shape[1]
+            data = data.reshape(total_samples, *data.shape[2:])
+
+            if pp_sample_ix is not None:
+                data = data[pp_sample_ix]
+
+            if plot_dim is not None:
+                # For plot_dim case, transpose to get dimension first
+                data = data.transpose(1, 0, 2)[..., 0]
+
+            # Create plotter tuple(s)
+            if plot_dim is not None:
+                y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
+            else:
+                y_model = [(var_name, {}, {}, data)]
+                y_model = _repeat_flatten_list(y_model, len_x)
+
+        elif kind_model == "hdi":
+            var_name = y_model.name if y_model.name else "y_model"
+            data = y_model.values
+
+            if plot_dim is not None:
+                # First transpose to get plot_dim first
+                data = data.transpose(2, 0, 1, 3)
+                # For plot_dim case, we just want HDI for first dimension
+                data = data[..., 0]
+
+                # Reshape to (samples, points)
+                data = data.transpose(1, 2, 0).reshape(-1, data.shape[0])
+                y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
+
+            else:
+                data = data.reshape(-1, data.shape[-1])
+                y_model = [(var_name, {}, {}, data)]
+                y_model = _repeat_flatten_list(y_model, len_x)
+
+        if len(y_model) == 1:
+            y_model = _repeat_flatten_list(y_model, len_x)
 
     rows, cols = default_grid(length_plotters)
 
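
The rewritten block builds the `y_model` plotter tuples directly from the underlying array, handling both the `lines` and `hdi` model kinds as well as the `plot_dim` case for multidimensional data. A short sketch of the calls this supports, modeled on the new `test_plot_lm` cases added below (shapes and names are illustrative):

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
idata = az.from_dict(
    observed_data={"y": rng.normal(size=7)},
    posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
    posterior={"y_model": rng.normal(size=(4, 1000, 7))},
    dims={"y": ["dim1"]},
    coords={"dim1": range(7)},
)

# kind_model="lines" draws num_samples posterior draws of y_model;
# kind_model="hdi" summarizes the same draws as an HDI band instead.
az.plot_lm(idata=idata, y="y", y_model="y_model", kind_model="lines", num_samples=50, show=False)
az.plot_lm(idata=idata, y="y", y_model="y_model", kind_model="hdi", show=False)
```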
arviz/plots/pairplot.py
CHANGED

@@ -196,9 +196,14 @@ def plot_pair(
             get_coords(dataset, coords), var_names=var_names, skip_dims=combine_dims, combined=True
         )
     )
-    flat_var_names = [
-        labeller.make_label_vert(var_name, sel, isel) for var_name, sel, isel, _ in plotters
-    ]
+    flat_var_names = []
+    flat_ref_slices = []
+    flat_var_labels = []
+    for var_name, sel, isel, _ in plotters:
+        dims = [dim for dim in dataset[var_name].dims if dim not in ["chain", "draw"]]
+        flat_var_names.append(var_name)
+        flat_ref_slices.append(tuple(isel[dim] if dim in isel else slice(None) for dim in dims))
+        flat_var_labels.append(labeller.make_label_vert(var_name, sel, isel))
 
     divergent_data = None
     diverging_mask = None
@@ -253,6 +258,8 @@ def plot_pair(
         diverging_mask=diverging_mask,
        divergences_kwargs=divergences_kwargs,
        flat_var_names=flat_var_names,
+       flat_ref_slices=flat_ref_slices,
+       flat_var_labels=flat_var_labels,
        backend_kwargs=backend_kwargs,
        marginal_kwargs=marginal_kwargs,
        show=show,
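
On the frontend side, the single list comprehension is split into three parallel lists: bare variable names, per-variable indexers over the non-sample dimensions, and display labels. A toy illustration of how the reference slice works (names and values are illustrative):

```python
import numpy as np

# For a variable with dims ("chain", "draw", "school") and a plotter that
# selects school index 2, the non-sample dims are ["school"] and the
# indexer dict is {"school": 2}, so the reference slice is (2,).
dims = ["school"]
isel = {"school": 2}
flat_ref_slice = tuple(isel[dim] if dim in isel else slice(None) for dim in dims)
assert flat_ref_slice == (2,)

# Indexing an array-valued reference value with it picks out one entry:
assert np.array([0.1, 0.2, 0.3, 0.4])[flat_ref_slice] == 0.3
```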
arviz/stats/density_utils.py
CHANGED

@@ -635,7 +635,7 @@ def _kde_circular(
     cumulative: bool, optional
         Whether return the PDF or the cumulative PDF. Defaults to False.
     grid_len: int, optional
-        The number of intervals used to bin the data
+        The number of intervals used to bin the data point i.e. the length of the grid used in the
         estimation. Defaults to 512.
     """
     # All values between -pi and pi
arviz/stats/stats.py
CHANGED

@@ -869,7 +869,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
     )
 
 
-def psislw(log_weights, reff=1.0):
+def psislw(log_weights, reff=1.0, normalize=True):
     """
     Pareto smoothed importance sampling (PSIS).
 
@@ -887,11 +887,13 @@ def psislw(log_weights, reff=1.0):
         Array of size (n_observations, n_samples)
     reff : float, default 1
         relative MCMC efficiency, ``ess / n``
+    normalize : bool, default True
+        return normalized log weights
 
     Returns
     -------
     lw_out : DataArray or (..., N) ndarray
-        Smoothed, truncated and normalized log weights.
+        Smoothed, truncated and possibly normalized log weights.
     kss : DataArray or (...) ndarray
         Estimates of the shape parameter *k* of the generalized Pareto
         distribution.
@@ -936,7 +938,12 @@ def psislw(log_weights, reff=1.0):
     out = np.empty_like(log_weights), np.empty(shape)
 
     # define kwargs
-    func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out}
+    func_kwargs = {
+        "cutoff_ind": cutoff_ind,
+        "cutoffmin": cutoffmin,
+        "out": out,
+        "normalize": normalize,
+    }
     ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
     kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
     log_weights, pareto_shape = _wrap_xarray_ufunc(
@@ -953,7 +960,7 @@ def psislw(log_weights, reff=1.0):
     return log_weights, pareto_shape
 
 
-def _psislw(log_weights, cutoff_ind, cutoffmin):
+def _psislw(log_weights, cutoff_ind, cutoffmin, normalize):
     """
     Pareto smoothed importance sampling (PSIS) for a 1D vector.
 
@@ -963,7 +970,7 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
         Array of length n_observations
     cutoff_ind: int
     cutoffmin: float
-
+    normalize: bool
 
     Returns
     -------
@@ -975,7 +982,8 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
     x = np.asarray(log_weights)
 
     # improve numerical accuracy
-    x -= np.max(x)
+    max_x = np.max(x)
+    x -= max_x
     # sort the array
     x_sort_ind = np.argsort(x)
     # divide log weights into body and right tail
@@ -1007,8 +1015,12 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
     x[tailinds[x_tail_si]] = smoothed_tail
     # truncate smoothed values to the largest raw weight 0
     x[x > 0] = 0
+
     # renormalize weights
-    x -= _logsumexp(x)
+    if normalize:
+        x -= _logsumexp(x)
+    else:
+        x += max_x
 
     return x, k
 
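
The new `normalize` flag (default `True`, preserving the old behaviour) lets `psislw` return the smoothed log weights on their original scale: instead of subtracting the log-normalizer, the maximum that was removed for numerical stability is added back. A quick sketch (random inputs, illustrative shapes):

```python
import numpy as np
from arviz import psislw

rng = np.random.default_rng(0)
log_weights = rng.normal(size=(8, 1000))  # (n_observations, n_samples)

lw_norm, k = psislw(log_weights)                  # normalized (old behaviour)
lw_raw, _ = psislw(log_weights, normalize=False)  # unnormalized, original scale

# Normalized rows sum to one on the exponential scale.
np.testing.assert_allclose(np.exp(lw_norm).sum(axis=-1), 1.0)
```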
arviz/tests/base_tests/test_data.py
CHANGED

@@ -1501,10 +1501,6 @@ class TestJSON:
         assert not os.path.exists(filepath)
 
 
-@pytest.mark.skipif(
-    not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
-    reason="test requires xarray-datatree library",
-)
 class TestDataTree:
     def test_datatree(self):
         idata = load_arviz_data("centered_eight")
arviz/tests/base_tests/test_plots_bokeh.py
CHANGED

@@ -8,6 +8,7 @@ from pandas import DataFrame  # pylint: disable=wrong-import-position
 from scipy.stats import norm  # pylint: disable=wrong-import-position
 
 from ...data import from_dict, load_arviz_data  # pylint: disable=wrong-import-position
+from ...labels import MapLabeller  # pylint: disable=wrong-import-position
 from ...plots import (  # pylint: disable=wrong-import-position
     plot_autocorr,
     plot_bpv,
@@ -773,7 +774,6 @@ def test_plot_mcse_no_divergences(models):
         {"divergences": True, "var_names": ["theta", "mu"]},
         {"kind": "kde", "var_names": ["theta"]},
         {"kind": "hexbin", "var_names": ["theta"]},
-        {"kind": "hexbin", "var_names": ["theta"]},
         {
             "kind": "hexbin",
             "var_names": ["theta"],
@@ -785,6 +785,21 @@ def test_plot_mcse_no_divergences(models):
             "reference_values": {"mu": 0, "tau": 0},
             "reference_values_kwargs": {"line_color": "blue"},
         },
+        {
+            "var_names": ["mu", "tau"],
+            "reference_values": {"mu": 0, "tau": 0},
+            "labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
+        },
+        {
+            "var_names": ["theta"],
+            "reference_values": {"theta": [0.0] * 8},
+            "labeller": MapLabeller({"theta": r"$\theta$"}),
+        },
+        {
+            "var_names": ["theta"],
+            "reference_values": {"theta": np.zeros(8)},
+            "labeller": MapLabeller({"theta": r"$\theta$"}),
+        },
     ],
 )
 def test_plot_pair(models, kwargs):
@@ -1201,7 +1216,7 @@ def test_plot_dot_rotated(continuous_model, kwargs):
         },
     ],
 )
-def test_plot_lm(models, kwargs):
+def test_plot_lm_1d(models, kwargs):
     """Test functionality for 1D data."""
     idata = models.model_1
     if "constant_data" not in idata.groups():
@@ -1228,3 +1243,46 @@ def test_plot_lm_list():
     """Test the plots when input data is list or ndarray."""
     y = [1, 2, 3, 4, 5]
     assert plot_lm(y=y, x=np.arange(len(y)), show=False, backend="bokeh")
+
+
+def generate_lm_1d_data():
+    rng = np.random.default_rng()
+    return from_dict(
+        observed_data={"y": rng.normal(size=7)},
+        posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
+        posterior={"y_model": rng.normal(size=(4, 1000, 7))},
+        dims={"y": ["dim1"]},
+        coords={"dim1": range(7)},
+    )
+
+
+def generate_lm_2d_data():
+    rng = np.random.default_rng()
+    return from_dict(
+        observed_data={"y": rng.normal(size=(5, 7))},
+        posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
+        posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
+        dims={"y": ["dim1", "dim2"]},
+        coords={"dim1": range(5), "dim2": range(7)},
+    )
+
+
+@pytest.mark.parametrize("data", ("1d", "2d"))
+@pytest.mark.parametrize("kind", ("lines", "hdi"))
+@pytest.mark.parametrize("use_y_model", (True, False))
+def test_plot_lm(data, kind, use_y_model):
+    if data == "1d":
+        idata = generate_lm_1d_data()
+    else:
+        idata = generate_lm_2d_data()
+
+    kwargs = {"idata": idata, "y": "y", "kind_model": kind, "backend": "bokeh", "show": False}
+    if data == "2d":
+        kwargs["plot_dim"] = "dim1"
+    if use_y_model:
+        kwargs["y_model"] = "y_model"
+    if kind == "lines":
+        kwargs["num_samples"] = 50
+
+    ax = plot_lm(**kwargs)
+    assert ax is not None
arviz/tests/base_tests/test_plots_matplotlib.py
CHANGED

@@ -14,6 +14,7 @@ from pandas import DataFrame
 from scipy.stats import gaussian_kde, norm
 
 from ...data import from_dict, load_arviz_data
+from ...labels import MapLabeller
 from ...plots import (
     plot_autocorr,
     plot_bf,
@@ -599,6 +600,21 @@ def test_plot_kde_inference_data(models):
             "reference_values": {"mu": 0, "tau": 0},
             "reference_values_kwargs": {"c": "C0", "marker": "*"},
         },
+        {
+            "var_names": ["mu", "tau"],
+            "reference_values": {"mu": 0, "tau": 0},
+            "labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
+        },
+        {
+            "var_names": ["theta"],
+            "reference_values": {"theta": [0.0] * 8},
+            "labeller": MapLabeller({"theta": r"$\theta$"}),
+        },
+        {
+            "var_names": ["theta"],
+            "reference_values": {"theta": np.zeros(8)},
+            "labeller": MapLabeller({"theta": r"$\theta$"}),
+        },
     ],
 )
 def test_plot_pair(models, kwargs):
@@ -1914,7 +1930,7 @@ def test_wilkinson_algorithm(continuous_model):
         },
     ],
 )
-def test_plot_lm(models, kwargs):
+def test_plot_lm_1d(models, kwargs):
     """Test functionality for 1D data."""
     idata = models.model_1
     if "constant_data" not in idata.groups():
@@ -2102,3 +2118,63 @@ def test_plot_bf():
     )
     _, bf_plot = plot_bf(idata, var_name="a", ref_val=0)
     assert bf_plot is not None
+
+
+def generate_lm_1d_data():
+    rng = np.random.default_rng()
+    return from_dict(
+        observed_data={"y": rng.normal(size=7)},
+        posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
+        posterior={"y_model": rng.normal(size=(4, 1000, 7))},
+        dims={"y": ["dim1"]},
+        coords={"dim1": range(7)},
+    )
+
+
+def generate_lm_2d_data():
+    rng = np.random.default_rng()
+    return from_dict(
+        observed_data={"y": rng.normal(size=(5, 7))},
+        posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
+        posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
+        dims={"y": ["dim1", "dim2"]},
+        coords={"dim1": range(5), "dim2": range(7)},
+    )
+
+
+@pytest.mark.parametrize("data", ("1d", "2d"))
+@pytest.mark.parametrize("kind", ("lines", "hdi"))
+@pytest.mark.parametrize("use_y_model", (True, False))
+def test_plot_lm(data, kind, use_y_model):
+    if data == "1d":
+        idata = generate_lm_1d_data()
+    else:
+        idata = generate_lm_2d_data()
+
+    kwargs = {"idata": idata, "y": "y", "kind_model": kind}
+    if data == "2d":
+        kwargs["plot_dim"] = "dim1"
+    if use_y_model:
+        kwargs["y_model"] = "y_model"
+    if kind == "lines":
+        kwargs["num_samples"] = 50
+
+    ax = plot_lm(**kwargs)
+    assert ax is not None
+
+
+@pytest.mark.parametrize(
+    "coords, expected_vars",
+    [
+        ({"school": ["Choate"]}, ["theta"]),
+        ({"school": ["Lawrenceville"]}, ["theta"]),
+        ({}, ["theta"]),
+    ],
+)
+def test_plot_autocorr_coords(coords, expected_vars):
+    """Test plot_autocorr with coords kwarg."""
+    idata = load_arviz_data("centered_eight")
+
+    axes = plot_autocorr(idata, var_names=expected_vars, coords=coords, show=False)
+
+    assert axes is not None
arviz/tests/base_tests/test_stats.py
CHANGED

@@ -14,7 +14,7 @@ from scipy.stats import linregress, norm, halfcauchy
 from xarray import DataArray, Dataset
 from xarray_einstats.stats import XrContinuousRV
 
-from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
+from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data, InferenceData
 from ...rcparams import rcParams
 from ...stats import (
     apply_test_function,
@@ -882,3 +882,44 @@ def test_bayes_factor():
     bf_dict1 = bayes_factor(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
     assert bf_dict0["BF10"] > bf_dict0["BF01"]
     assert bf_dict1["BF10"] < bf_dict1["BF01"]
+
+
+def test_compare_sorting_consistency():
+    chains, draws = 4, 1000
+
+    # Model 1 - good fit
+    log_lik1 = np.random.normal(-2, 1, size=(chains, draws))
+    posterior1 = Dataset(
+        {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
+        coords={"chain": range(chains), "draw": range(draws)},
+    )
+    log_like1 = Dataset(
+        {"y": (("chain", "draw"), log_lik1)},
+        coords={"chain": range(chains), "draw": range(draws)},
+    )
+    data1 = InferenceData(posterior=posterior1, log_likelihood=log_like1)
+
+    # Model 2 - poor fit (higher variance)
+    log_lik2 = np.random.normal(-5, 2, size=(chains, draws))
+    posterior2 = Dataset(
+        {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
+        coords={"chain": range(chains), "draw": range(draws)},
+    )
+    log_like2 = Dataset(
+        {"y": (("chain", "draw"), log_lik2)},
+        coords={"chain": range(chains), "draw": range(draws)},
+    )
+    data2 = InferenceData(posterior=posterior2, log_likelihood=log_like2)
+
+    # Compare models in different orders
+    comp_dict1 = {"M1": data1, "M2": data2}
+    comp_dict2 = {"M2": data2, "M1": data1}
+
+    comparison1 = compare(comp_dict1, method="bb-pseudo-bma")
+    comparison2 = compare(comp_dict2, method="bb-pseudo-bma")
+
+    assert comparison1.index.tolist() == comparison2.index.tolist()
+
+    se1 = comparison1["se"].values
+    se2 = comparison2["se"].values
+    np.testing.assert_array_almost_equal(se1, se2)