arviz 0.20.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +8 -3
- arviz/data/base.py +2 -2
- arviz/data/inference_data.py +57 -26
- arviz/data/io_datatree.py +2 -2
- arviz/data/io_numpyro.py +112 -4
- arviz/plots/autocorrplot.py +12 -2
- arviz/plots/backends/__init__.py +8 -7
- arviz/plots/backends/bokeh/bpvplot.py +4 -3
- arviz/plots/backends/bokeh/densityplot.py +5 -1
- arviz/plots/backends/bokeh/dotplot.py +5 -2
- arviz/plots/backends/bokeh/essplot.py +4 -2
- arviz/plots/backends/bokeh/forestplot.py +11 -4
- arviz/plots/backends/bokeh/hdiplot.py +7 -6
- arviz/plots/backends/bokeh/khatplot.py +4 -2
- arviz/plots/backends/bokeh/lmplot.py +28 -6
- arviz/plots/backends/bokeh/mcseplot.py +2 -2
- arviz/plots/backends/bokeh/pairplot.py +27 -52
- arviz/plots/backends/bokeh/ppcplot.py +2 -1
- arviz/plots/backends/bokeh/rankplot.py +2 -1
- arviz/plots/backends/bokeh/traceplot.py +2 -1
- arviz/plots/backends/bokeh/violinplot.py +2 -1
- arviz/plots/backends/matplotlib/bpvplot.py +2 -1
- arviz/plots/backends/matplotlib/khatplot.py +8 -1
- arviz/plots/backends/matplotlib/lmplot.py +13 -7
- arviz/plots/backends/matplotlib/pairplot.py +14 -22
- arviz/plots/bfplot.py +9 -26
- arviz/plots/bpvplot.py +10 -1
- arviz/plots/hdiplot.py +5 -0
- arviz/plots/kdeplot.py +4 -4
- arviz/plots/lmplot.py +41 -14
- arviz/plots/pairplot.py +10 -3
- arviz/plots/plot_utils.py +5 -3
- arviz/preview.py +36 -5
- arviz/stats/__init__.py +1 -0
- arviz/stats/density_utils.py +1 -1
- arviz/stats/diagnostics.py +18 -14
- arviz/stats/stats.py +105 -7
- arviz/tests/base_tests/test_data.py +31 -11
- arviz/tests/base_tests/test_diagnostics.py +5 -4
- arviz/tests/base_tests/test_plots_bokeh.py +60 -2
- arviz/tests/base_tests/test_plots_matplotlib.py +103 -11
- arviz/tests/base_tests/test_stats.py +53 -1
- arviz/tests/external_tests/test_data_numpyro.py +130 -3
- arviz/utils.py +4 -0
- arviz/wrappers/base.py +1 -1
- arviz/wrappers/wrap_stan.py +1 -1
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/METADATA +7 -7
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/RECORD +51 -51
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/WHEEL +1 -1
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/LICENSE +0 -0
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/top_level.txt +0 -0
arviz/plots/lmplot.py
CHANGED
@@ -300,20 +300,47 @@ def plot_lm(
     # Filter out the required values to generate plotters
     if y_model is not None:
         if kind_model == "lines":
-            … (14 removed lines not shown in the source diff)
+            var_name = y_model.name if y_model.name else "y_model"
+            data = y_model.values
+
+            total_samples = data.shape[0] * data.shape[1]
+            data = data.reshape(total_samples, *data.shape[2:])
+
+            if pp_sample_ix is not None:
+                data = data[pp_sample_ix]
+
+            if plot_dim is not None:
+                # For plot_dim case, transpose to get dimension first
+                data = data.transpose(1, 0, 2)[..., 0]
+
+            # Create plotter tuple(s)
+            if plot_dim is not None:
+                y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
+            else:
+                y_model = [(var_name, {}, {}, data)]
+                y_model = _repeat_flatten_list(y_model, len_x)
+
+        elif kind_model == "hdi":
+            var_name = y_model.name if y_model.name else "y_model"
+            data = y_model.values
+
+            if plot_dim is not None:
+                # First transpose to get plot_dim first
+                data = data.transpose(2, 0, 1, 3)
+                # For plot_dim case, we just want HDI for first dimension
+                data = data[..., 0]
+
+                # Reshape to (samples, points)
+                data = data.transpose(1, 2, 0).reshape(-1, data.shape[0])
+                y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
+
+            else:
+                data = data.reshape(-1, data.shape[-1])
+                y_model = [(var_name, {}, {}, data)]
+                y_model = _repeat_flatten_list(y_model, len_x)
+
+        if len(y_model) == 1:
+            y_model = _repeat_flatten_list(y_model, len_x)
 
     rows, cols = default_grid(length_plotters)
 
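Note: the new "lines" branch collapses the chain and draw dimensions into a single sample axis before optionally thinning to num_samples draws. A minimal numpy sketch of that reshaping (shapes and names are illustrative, not part of the diff):

    import numpy as np

    rng = np.random.default_rng(0)
    # Illustrative y_model posterior: 4 chains, 1000 draws, 7 observations.
    data = rng.normal(size=(4, 1000, 7))

    # Collapse (chain, draw) into one sample axis, as the diff does.
    total_samples = data.shape[0] * data.shape[1]
    data = data.reshape(total_samples, *data.shape[2:])  # -> (4000, 7)

    # Optional thinning to a random subset of posterior draws.
    pp_sample_ix = rng.choice(total_samples, size=50, replace=False)
    data = data[pp_sample_ix]                            # -> (50, 7)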
arviz/plots/pairplot.py
CHANGED
@@ -196,9 +196,14 @@ def plot_pair(
             get_coords(dataset, coords), var_names=var_names, skip_dims=combine_dims, combined=True
         )
     )
-    flat_var_names = [
-        … (1 removed line not shown)
-    ]
+    flat_var_names = []
+    flat_ref_slices = []
+    flat_var_labels = []
+    for var_name, sel, isel, _ in plotters:
+        dims = [dim for dim in dataset[var_name].dims if dim not in ["chain", "draw"]]
+        flat_var_names.append(var_name)
+        flat_ref_slices.append(tuple(isel[dim] if dim in isel else slice(None) for dim in dims))
+        flat_var_labels.append(labeller.make_label_vert(var_name, sel, isel))
 
     divergent_data = None
     diverging_mask = None
@@ -253,6 +258,8 @@ def plot_pair(
         diverging_mask=diverging_mask,
         divergences_kwargs=divergences_kwargs,
         flat_var_names=flat_var_names,
+        flat_ref_slices=flat_ref_slices,
+        flat_var_labels=flat_var_labels,
         backend_kwargs=backend_kwargs,
         marginal_kwargs=marginal_kwargs,
         show=show,
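Note: the new flat_ref_slices bookkeeping records, for each flattened marginal, which entry of an array-valued reference value belongs to it. A small illustrative sketch of the indexing expression used above (variable and dimension names are hypothetical):

    import numpy as np

    reference_values = {"theta": np.zeros(8)}  # one reference value per coord
    isel = {"school": 3}                       # integer selection for this marginal
    dims = ["school"]                          # non chain/draw dims of "theta"

    ref_slice = tuple(isel[dim] if dim in isel else slice(None) for dim in dims)
    print(reference_values["theta"][ref_slice])  # 0.0 for the fourth marginal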
arviz/plots/plot_utils.py
CHANGED
@@ -482,16 +482,18 @@ def plot_point_interval(
     if point_estimate:
         point_value = calculate_point_estimate(point_estimate, values)
         if rotated:
-            ax.…
+            ax.scatter(
                 x=0,
                 y=point_value,
+                marker="circle",
                 size=markersize,
                 fill_color=markercolor,
             )
         else:
-            ax.…
+            ax.scatter(
                 x=point_value,
                 y=0,
+                marker="circle",
                 size=markersize,
                 fill_color=markercolor,
             )
@@ -534,7 +536,7 @@ def set_bokeh_circular_ticks_labels(ax, hist, labels):
     )
 
     radii_circles = np.linspace(0, np.max(hist) * 1.1, 4)
-    ax.…
+    ax.scatter(0, 0, marker="circle", radius=radii_circles, fill_color=None, line_color="grey")
 
     offset = np.max(hist * 1.05) * 0.15
     ticks_labels_pos_1 = np.max(hist * 1.05)
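Note: these call sites appear to track Bokeh's deprecation of marker-specific figure methods such as circle in favor of the generic scatter glyph. A minimal standalone sketch of the replacement pattern:

    from bokeh.plotting import figure

    p = figure()
    # Older code used marker-specific methods like p.circle(...); the generic
    # scatter glyph takes the marker name as an explicit argument instead.
    p.scatter(x=[0.0], y=[1.5], marker="circle", size=10, fill_color="orange")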
arviz/preview.py
CHANGED
@@ -1,17 +1,48 @@
-# pylint: disable=unused-import,unused-wildcard-import,wildcard-import
+# pylint: disable=unused-import,unused-wildcard-import,wildcard-import,invalid-name
 """Expose features from arviz-xyz refactored packages inside ``arviz.preview`` namespace."""
+import logging
+
+_log = logging.getLogger(__name__)
+
+info = ""
 
 try:
     from arviz_base import *
+
+    status = "arviz_base available, exposing its functions as part of arviz.preview"
+    _log.info(status)
 except ModuleNotFoundError:
-    … (1 removed line not shown)
+    status = "arviz_base not installed"
+    _log.info(status)
+except ImportError:
+    status = "Unable to import arviz_base"
+    _log.info(status, exc_info=True)
+
+info += status + "\n"
 
 try:
-    import …
+    from arviz_stats import *
+
+    status = "arviz_stats available, exposing its functions as part of arviz.preview"
+    _log.info(status)
 except ModuleNotFoundError:
-    … (1 removed line not shown)
+    status = "arviz_stats not installed"
+    _log.info(status)
+except ImportError:
+    status = "Unable to import arviz_stats"
+    _log.info(status, exc_info=True)
+info += status + "\n"
 
 try:
     from arviz_plots import *
+
+    status = "arviz_plots available, exposing its functions as part of arviz.preview"
+    _log.info(status)
 except ModuleNotFoundError:
-    … (1 removed line not shown)
+    status = "arviz_plots not installed"
+    _log.info(status)
+except ImportError:
+    status = "Unable to import arviz_plots"
+    _log.info(status, exc_info=True)
+
+info += status + "\n"
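Note: the module now accumulates a human-readable info string alongside the log records, so the availability of the preview namespace can be inspected at runtime, e.g.:

    from arviz import preview

    # One status line per arviz-* package: available, not installed, or broken.
    print(preview.info)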
arviz/stats/__init__.py
CHANGED
arviz/stats/density_utils.py
CHANGED
@@ -635,7 +635,7 @@ def _kde_circular(
     cumulative: bool, optional
         Whether return the PDF or the cumulative PDF. Defaults to False.
     grid_len: int, optional
-        The number of intervals used to bin the data
+        The number of intervals used to bin the data points i.e. the length of the grid used in the
         estimation. Defaults to 512.
     """
     # All values between -pi and pi
arviz/stats/diagnostics.py
CHANGED
@@ -744,8 +744,8 @@ def _ess_sd(ary, relative=False):
     ary = np.asarray(ary)
     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
         return np.nan
-    ary = …
-    return …
+    ary = (ary - ary.mean()) ** 2
+    return _ess(_split_chains(ary), relative=relative)
 
 
 def _ess_quantile(ary, prob, relative=False):
@@ -838,13 +838,15 @@ def _mcse_sd(ary):
     ary = np.asarray(ary)
     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
         return np.nan
-    … (1 removed line not shown)
+    sims_c2 = (ary - ary.mean()) ** 2
+    ess = _ess_mean(sims_c2)
+    evar = (sims_c2).mean()
+    varvar = ((sims_c2**2).mean() - evar**2) / ess
+    varsd = varvar / evar / 4
     if _numba_flag:
-        … (1 removed line not shown)
+        mcse_sd_value = float(_sqrt(np.ravel(varsd), np.zeros(1)))
     else:
-        … (1 removed line not shown)
-        fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
-        mcse_sd_value = sd * fac_mcse_sd
+        mcse_sd_value = np.sqrt(varsd)
     return mcse_sd_value
 
 
@@ -973,19 +975,21 @@ def _multichain_statistics(ary, focus="mean"):
     # ess mean
     ess_mean_value = _ess_mean(ary)
 
-    # ess sd
-    ess_sd_value = _ess_sd(ary)
-
     # mcse_mean
-    … (2 removed lines not shown)
+    sims_c2 = (ary - ary.mean()) ** 2
+    sims_c2_sum = sims_c2.sum()
+    var = sims_c2_sum / (sims_c2.size - 1)
+    mcse_mean_value = np.sqrt(var / ess_mean_value)
 
     # ess bulk
     ess_bulk_value = _ess(z_split)
 
     # mcse_sd
-    … (2 removed lines not shown)
+    evar = sims_c2_sum / sims_c2.size
+    ess_mean_sims = _ess_mean(sims_c2)
+    varvar = ((sims_c2**2).mean() - evar**2) / ess_mean_sims
+    varsd = varvar / evar / 4
+    mcse_sd_value = np.sqrt(varsd)
 
     return (
         mcse_mean_value,
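Note: the rewritten MCSE-of-SD drops the old fac_mcse_sd correction in favor of a delta-method estimate, Var(sd) ≈ Var(var) / (4 var), with Var(var) scaled by the ESS of the squared deviations. A self-contained sketch of the new computation (a stand-in ess argument replaces _ess_mean):

    import numpy as np

    def mcse_sd_sketch(draws, ess):
        # Squared deviations from the mean drive the estimate.
        sims_c2 = (draws - draws.mean()) ** 2
        evar = sims_c2.mean()                           # estimate of the variance
        varvar = ((sims_c2**2).mean() - evar**2) / ess  # Monte Carlo variance of the variance
        varsd = varvar / evar / 4                       # delta method: Var(sd) ~ Var(var) / (4 var)
        return np.sqrt(varsd)

    rng = np.random.default_rng(0)
    draws = rng.normal(size=4000)
    # For roughly normal draws this is close to sd / sqrt(2 * ess).
    print(mcse_sd_sketch(draws, ess=3500.0))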
arviz/stats/stats.py
CHANGED
@@ -27,6 +27,7 @@ from ..utils import Numba, _numba_var, _var_names, get_coords
 from .density_utils import get_bins as _get_bins
 from .density_utils import histogram as _histogram
 from .density_utils import kde as _kde
+from .density_utils import _kde_linear
 from .diagnostics import _mc_error, _multichain_statistics, ess
 from .stats_utils import ELPDData, _circular_standard_deviation, smooth_data
 from .stats_utils import get_log_likelihood as _get_log_likelihood
@@ -41,6 +42,7 @@ from ..labels import BaseLabeller
 
 __all__ = [
     "apply_test_function",
+    "bayes_factor",
     "compare",
     "hdi",
     "loo",
@@ -867,7 +869,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
     )
 
 
-def psislw(log_weights, reff=1.0):
+def psislw(log_weights, reff=1.0, normalize=True):
     """
     Pareto smoothed importance sampling (PSIS).
 
@@ -885,11 +887,13 @@ def psislw(log_weights, reff=1.0):
         Array of size (n_observations, n_samples)
     reff : float, default 1
         relative MCMC efficiency, ``ess / n``
+    normalize : bool, default True
+        return normalized log weights
 
     Returns
     -------
     lw_out : DataArray or (..., N) ndarray
-        Smoothed, truncated and normalized log weights.
+        Smoothed, truncated and possibly normalized log weights.
     kss : DataArray or (...) ndarray
         Estimates of the shape parameter *k* of the generalized Pareto
         distribution.
@@ -934,7 +938,12 @@ def psislw(log_weights, reff=1.0):
     out = np.empty_like(log_weights), np.empty(shape)
 
     # define kwargs
-    func_kwargs = {…
+    func_kwargs = {
+        "cutoff_ind": cutoff_ind,
+        "cutoffmin": cutoffmin,
+        "out": out,
+        "normalize": normalize,
+    }
     ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
     kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
     log_weights, pareto_shape = _wrap_xarray_ufunc(
@@ -951,7 +960,7 @@ def psislw(log_weights, reff=1.0):
     return log_weights, pareto_shape
 
 
-def _psislw(log_weights, cutoff_ind, cutoffmin):
+def _psislw(log_weights, cutoff_ind, cutoffmin, normalize):
     """
     Pareto smoothed importance sampling (PSIS) for a 1D vector.
 
@@ -961,7 +970,7 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
         Array of length n_observations
     cutoff_ind: int
     cutoffmin: float
-    … (1 removed line not shown)
+    normalize: bool
 
     Returns
     -------
@@ -973,7 +982,8 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
     x = np.asarray(log_weights)
 
     # improve numerical accuracy
-    … (1 removed line not shown)
+    max_x = np.max(x)
+    x -= max_x
     # sort the array
     x_sort_ind = np.argsort(x)
     # divide log weights into body and right tail
@@ -1005,8 +1015,12 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
     x[tailinds[x_tail_si]] = smoothed_tail
     # truncate smoothed values to the largest raw weight 0
     x[x > 0] = 0
+
     # renormalize weights
-    … (1 removed line not shown)
+    if normalize:
+        x -= _logsumexp(x)
+    else:
+        x += max_x
 
     return x, k
 
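Note: with the new normalize keyword, psislw can return smoothed and truncated log weights on their original scale instead of normalizing them with logsumexp. A short usage sketch (shapes follow the documented (n_observations, n_samples) layout):

    import numpy as np
    import arviz as az

    rng = np.random.default_rng(0)
    log_weights = rng.normal(size=(1, 4000))  # one observation, 4000 draws

    lw, k = az.psislw(log_weights.copy())     # default: normalized
    print(np.exp(lw).sum())                   # ~1.0

    # New in 0.22: keep the unnormalized (but still smoothed) scale.
    lw_raw, k = az.psislw(log_weights.copy(), normalize=False)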
@@ -2337,3 +2351,87 @@ def _cjs_dist(draws, weights):
     bound = cdf_p_int + cdf_q_int
 
     return np.sqrt((cjs_pq + cjs_qp) / bound)
+
+
+def bayes_factor(idata, var_name, ref_val=0, prior=None, return_ref_vals=False):
+    r"""Approximate Bayes factor for comparing hypotheses of two nested models.
+
+    The Bayes factor is estimated by comparing a model (H1) against a model in which the
+    parameter of interest has been restricted to be a point-null (H0). This computation
+    assumes the models are nested and thus H0 is a special case of H1.
+
+    Notes
+    -----
+    The Bayes factor is approximated as the Savage-Dickey density ratio
+    algorithm presented in [1]_.
+
+    Parameters
+    ----------
+    idata : InferenceData
+        Any object that can be converted to an :class:`arviz.InferenceData` object.
+        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
+    var_name : str, optional
+        Name of the variable we want to test.
+    ref_val : int, default 0
+        Point-null for Bayes factor estimation.
+    prior : numpy.array, optional
+        In case we want to use a different prior, for example for sensitivity analysis.
+    return_ref_vals : bool, optional
+        Whether to return the values of the prior and posterior at the reference value.
+        Used by :func:`arviz.plot_bf` to display the distribution comparison.
+
+    Returns
+    -------
+    dict : A dictionary with BF10 (the H1/H0 ratio) and BF01 (the H0/H1 ratio).
+
+    References
+    ----------
+    .. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio:
+       The case of computing Bayes factors for regression parameters.
+
+    Examples
+    --------
+    Moderate evidence indicating that the parameter "a" is different from zero.
+
+    .. ipython::
+
+        In [1]: import numpy as np
+           ...: import arviz as az
+           ...: idata = az.from_dict(posterior={"a": np.random.normal(1, 0.5, 5000)},
+           ...:                      prior={"a": np.random.normal(0, 1, 5000)})
+           ...: az.bayes_factor(idata, var_name="a", ref_val=0)
+
+    """
+    posterior = extract(idata, var_names=var_name).values
+
+    if ref_val > posterior.max() or ref_val < posterior.min():
+        _log.warning(
+            "The reference value is outside of the posterior. "
+            "This translates into infinite support for H1, which is most likely an overstatement."
+        )
+
+    if posterior.ndim > 1:
+        _log.warning("Posterior distribution has %s dimensions", posterior.ndim)
+
+    if prior is None:
+        prior = extract(idata, var_names=var_name, group="prior").values
+
+    if posterior.dtype.kind == "f":
+        posterior_grid, posterior_pdf, *_ = _kde_linear(posterior)
+        prior_grid, prior_pdf, *_ = _kde_linear(prior)
+        posterior_at_ref_val = np.interp(ref_val, posterior_grid, posterior_pdf)
+        prior_at_ref_val = np.interp(ref_val, prior_grid, prior_pdf)
+
+    elif posterior.dtype.kind == "i":
+        posterior_at_ref_val = (posterior == ref_val).mean()
+        prior_at_ref_val = (prior == ref_val).mean()
+
+    bf_10 = prior_at_ref_val / posterior_at_ref_val
+    bf = {"BF10": bf_10, "BF01": 1 / bf_10}
+
+    if return_ref_vals:
+        return (bf, {"prior": prior_at_ref_val, "posterior": posterior_at_ref_val})
+    else:
+        return bf
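Note: a usage sketch for the new return_ref_vals option, which also hands back the prior and posterior densities at the point-null (the values plot_bf uses to annotate the comparison):

    import numpy as np
    import arviz as az

    rng = np.random.default_rng(42)
    idata = az.from_dict(
        posterior={"a": rng.normal(1, 0.5, 5000)},
        prior={"a": rng.normal(0, 1, 5000)},
    )

    # Second return value holds the two densities evaluated at ref_val.
    bf, ref_vals = az.bayes_factor(idata, var_name="a", ref_val=0, return_ref_vals=True)
    print(bf["BF10"], bf["BF01"])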
arviz/tests/base_tests/test_data.py
CHANGED

@@ -3,6 +3,7 @@
 
 import importlib
 import os
+import warnings
 from collections import namedtuple
 from copy import deepcopy
 from html import escape
@@ -32,7 +33,13 @@ from ... import (
     extract,
 )
 
-from ...data.base import …
+from ...data.base import (
+    dict_to_dataset,
+    generate_dims_coords,
+    infer_stan_dtypes,
+    make_attrs,
+    numpy_to_data_array,
+)
 from ...data.datasets import LOCAL_DATASETS, REMOTE_DATASETS, RemoteFileMetadata
 from ..helpers import (  # pylint: disable=unused-import
     chains,
@@ -231,6 +238,17 @@ def test_dims_coords_skip_event_dims(shape):
     assert "z" not in coords
 
 
+@pytest.mark.parametrize("dims", [None, ["chain", "draw"], ["chain", "draw", None]])
+def test_numpy_to_data_array_with_dims(dims):
+    da = numpy_to_data_array(
+        np.empty((4, 500, 7)),
+        var_name="a",
+        dims=dims,
+        default_dims=["chain", "draw"],
+    )
+    assert list(da.dims) == ["chain", "draw", "a_dim_0"]
+
+
 def test_make_attrs():
     extra_attrs = {"key": "Value"}
     attrs = make_attrs(attrs=extra_attrs)
@@ -921,7 +939,7 @@ class TestInferenceData:  # pylint: disable=too-many-public-methods
         data = np.random.normal(size=(4, 500, 8))
         idata = data_random
         with pytest.warns(UserWarning, match="The group.+not defined in the InferenceData scheme"):
-            idata.add_groups({"new_group": idata.posterior})
+            idata.add_groups({"new_group": idata.posterior}, warn_on_custom_groups=True)
         with pytest.warns(UserWarning, match="the default dims.+will be added automatically"):
             idata.add_groups(constant_data={"a": data[..., 0], "b": data})
         assert idata.new_group.equals(idata.posterior)
@@ -962,8 +980,8 @@ class TestInferenceData:  # pylint: disable=too-many-public-methods
         with pytest.raises(ValueError, match="join must be either"):
             idata.extend(idata2, join="outer")
         idata2.add_groups(new_group=idata2.prior)
-        with pytest.warns(UserWarning):
-            idata.extend(idata2)
+        with pytest.warns(UserWarning, match="new_group"):
+            idata.extend(idata2, warn_on_custom_groups=True)
 
 
 class TestNumpyToDataArray:
@@ -1156,11 +1174,17 @@ def test_bad_inference_data():
         InferenceData(posterior=[1, 2, 3])
 
 
-… (1 removed line not shown)
+@pytest.mark.parametrize("warn", [True, False])
+def test_inference_data_other_groups(warn):
     datadict = {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}
     dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={"b": ["c"]})
-    … (2 removed lines not shown)
+    if warn:
+        with pytest.warns(UserWarning, match="not.+in.+InferenceData scheme"):
+            idata = InferenceData(other_group=dataset, warn_on_custom_groups=True)
+    else:
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")
+            idata = InferenceData(other_group=dataset, warn_on_custom_groups=False)
     fails = check_multiple_attrs({"other_group": ["a", "b"]}, idata)
     assert not fails
 
@@ -1477,10 +1501,6 @@ class TestJSON:
         assert not os.path.exists(filepath)
 
 
-@pytest.mark.skipif(
-    not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
-    reason="test requires xarray-datatree library",
-)
 class TestDataTree:
     def test_datatree(self):
         idata = load_arviz_data("centered_eight")
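Note: these tests exercise the new warn_on_custom_groups flag on InferenceData, add_groups, and extend. A hypothetical sketch of the two behaviours (group name is illustrative):

    import numpy as np
    from arviz import InferenceData, convert_to_dataset

    ds = convert_to_dataset({"a": np.random.randn(4, 100)})

    # warn_on_custom_groups=True emits the UserWarning about groups that are
    # not part of the InferenceData scheme; False keeps construction silent.
    InferenceData(my_custom_group=ds, warn_on_custom_groups=True)
    InferenceData(my_custom_group=ds, warn_on_custom_groups=False)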
arviz/tests/base_tests/test_diagnostics.py
CHANGED

@@ -120,10 +120,11 @@ class TestDiagnostics:
         ```
         Reference file:
 
-        Created: …
-        System: Ubuntu …
-        R version 4.…
-        posterior …
+        Created: 2024-12-20
+        System: Ubuntu 24.04.1 LTS
+        R version 4.4.2 (2024-10-31)
+        posterior version from https://github.com/stan-dev/posterior/pull/388
+        (after release 1.6.0 but before the fixes in the PR were released).
         """
         # download input files
         here = os.path.dirname(os.path.abspath(__file__))
arviz/tests/base_tests/test_plots_bokeh.py
CHANGED

@@ -8,6 +8,7 @@ from pandas import DataFrame  # pylint: disable=wrong-import-position
 from scipy.stats import norm  # pylint: disable=wrong-import-position
 
 from ...data import from_dict, load_arviz_data  # pylint: disable=wrong-import-position
+from ...labels import MapLabeller  # pylint: disable=wrong-import-position
 from ...plots import (  # pylint: disable=wrong-import-position
     plot_autocorr,
     plot_bpv,
@@ -773,7 +774,6 @@ def test_plot_mcse_no_divergences(models):
         {"divergences": True, "var_names": ["theta", "mu"]},
         {"kind": "kde", "var_names": ["theta"]},
         {"kind": "hexbin", "var_names": ["theta"]},
-        {"kind": "hexbin", "var_names": ["theta"]},
         {
             "kind": "hexbin",
             "var_names": ["theta"],
@@ -785,6 +785,21 @@ def test_plot_mcse_no_divergences(models):
             "reference_values": {"mu": 0, "tau": 0},
            "reference_values_kwargs": {"line_color": "blue"},
         },
+        {
+            "var_names": ["mu", "tau"],
+            "reference_values": {"mu": 0, "tau": 0},
+            "labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta$"}),
+        },
+        {
+            "var_names": ["theta"],
+            "reference_values": {"theta": [0.0] * 8},
+            "labeller": MapLabeller({"theta": r"$\theta$"}),
+        },
+        {
+            "var_names": ["theta"],
+            "reference_values": {"theta": np.zeros(8)},
+            "labeller": MapLabeller({"theta": r"$\theta$"}),
+        },
     ],
 )
 def test_plot_pair(models, kwargs):
@@ -1201,7 +1216,7 @@ def test_plot_dot_rotated(continuous_model, kwargs):
         },
     ],
 )
-def …
+def test_plot_lm_1d(models, kwargs):
     """Test functionality for 1D data."""
     idata = models.model_1
     if "constant_data" not in idata.groups():
@@ -1228,3 +1243,46 @@ def test_plot_lm_list():
     """Test the plots when input data is list or ndarray."""
     y = [1, 2, 3, 4, 5]
     assert plot_lm(y=y, x=np.arange(len(y)), show=False, backend="bokeh")
+
+
+def generate_lm_1d_data():
+    rng = np.random.default_rng()
+    return from_dict(
+        observed_data={"y": rng.normal(size=7)},
+        posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
+        posterior={"y_model": rng.normal(size=(4, 1000, 7))},
+        dims={"y": ["dim1"]},
+        coords={"dim1": range(7)},
+    )
+
+
+def generate_lm_2d_data():
+    rng = np.random.default_rng()
+    return from_dict(
+        observed_data={"y": rng.normal(size=(5, 7))},
+        posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
+        posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
+        dims={"y": ["dim1", "dim2"]},
+        coords={"dim1": range(5), "dim2": range(7)},
+    )
+
+
+@pytest.mark.parametrize("data", ("1d", "2d"))
+@pytest.mark.parametrize("kind", ("lines", "hdi"))
+@pytest.mark.parametrize("use_y_model", (True, False))
+def test_plot_lm(data, kind, use_y_model):
+    if data == "1d":
+        idata = generate_lm_1d_data()
+    else:
+        idata = generate_lm_2d_data()
+
+    kwargs = {"idata": idata, "y": "y", "kind_model": kind, "backend": "bokeh", "show": False}
+    if data == "2d":
+        kwargs["plot_dim"] = "dim1"
+    if use_y_model:
+        kwargs["y_model"] = "y_model"
+    if kind == "lines":
+        kwargs["num_samples"] = 50
+
+    ax = plot_lm(**kwargs)
+    assert ax is not None