arviz 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +1 -1
- arviz/data/base.py +20 -9
- arviz/data/converters.py +7 -3
- arviz/data/inference_data.py +28 -7
- arviz/plots/backends/__init__.py +8 -7
- arviz/plots/backends/bokeh/bpvplot.py +4 -3
- arviz/plots/backends/bokeh/densityplot.py +5 -1
- arviz/plots/backends/bokeh/dotplot.py +5 -2
- arviz/plots/backends/bokeh/essplot.py +4 -2
- arviz/plots/backends/bokeh/forestplot.py +11 -4
- arviz/plots/backends/bokeh/khatplot.py +4 -2
- arviz/plots/backends/bokeh/lmplot.py +9 -3
- arviz/plots/backends/bokeh/mcseplot.py +2 -2
- arviz/plots/backends/bokeh/pairplot.py +10 -5
- arviz/plots/backends/bokeh/ppcplot.py +2 -1
- arviz/plots/backends/bokeh/rankplot.py +2 -1
- arviz/plots/backends/bokeh/traceplot.py +2 -1
- arviz/plots/backends/bokeh/violinplot.py +2 -1
- arviz/plots/backends/matplotlib/bpvplot.py +2 -1
- arviz/plots/bfplot.py +9 -26
- arviz/plots/bpvplot.py +10 -1
- arviz/plots/compareplot.py +4 -4
- arviz/plots/ecdfplot.py +16 -8
- arviz/plots/forestplot.py +2 -2
- arviz/plots/hdiplot.py +5 -0
- arviz/plots/kdeplot.py +9 -2
- arviz/plots/plot_utils.py +5 -3
- arviz/preview.py +36 -5
- arviz/stats/__init__.py +1 -0
- arviz/stats/diagnostics.py +18 -14
- arviz/stats/ecdf_utils.py +157 -2
- arviz/stats/stats.py +99 -7
- arviz/tests/base_tests/test_data.py +41 -7
- arviz/tests/base_tests/test_diagnostics.py +5 -4
- arviz/tests/base_tests/test_plots_matplotlib.py +32 -13
- arviz/tests/base_tests/test_stats.py +11 -0
- arviz/tests/base_tests/test_stats_ecdf_utils.py +15 -2
- arviz/utils.py +4 -0
- {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/METADATA +22 -22
- {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/RECORD +43 -43
- {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/WHEEL +1 -1
- {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/LICENSE +0 -0
- {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/top_level.txt +0 -0
arviz/stats/stats.py
CHANGED
@@ -27,6 +27,7 @@ from ..utils import Numba, _numba_var, _var_names, get_coords
 from .density_utils import get_bins as _get_bins
 from .density_utils import histogram as _histogram
 from .density_utils import kde as _kde
+from .density_utils import _kde_linear
 from .diagnostics import _mc_error, _multichain_statistics, ess
 from .stats_utils import ELPDData, _circular_standard_deviation, smooth_data
 from .stats_utils import get_log_likelihood as _get_log_likelihood
@@ -41,6 +42,7 @@ from ..labels import BaseLabeller
 
 __all__ = [
     "apply_test_function",
+    "bayes_factor",
     "compare",
     "hdi",
     "loo",
@@ -711,16 +713,19 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
     Returns
     -------
     ELPDData object (inherits from :class:`pandas.Series`) with the following row/attributes:
-
+    elpd_loo: approximated expected log pointwise predictive density (elpd)
     se: standard error of the elpd
     p_loo: effective number of parameters
-
-
-
-
-
+    n_samples: number of samples
+    n_data_points: number of data points
+    warning: bool
+        True if the estimated shape parameter of Pareto distribution is greater than
+        ``good_k``.
+    loo_i: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
+        only if pointwise=True
     pareto_k: array of Pareto shape values, only if pointwise True
     scale: scale of the elpd
+    good_k: For a sample size S, the thresold is compute as min(1 - 1/log10(S), 0.7)
 
     The returned object has a custom print method that overrides pd.Series method.
 
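The new `good_k` row documented above can be read straight off the returned `ELPDData` object, in the same way the existing `pareto_k` row is. A minimal sketch (the bundled `centered_eight` example dataset is used here purely for illustration):

```python
import numpy as np
import arviz as az

idata = az.load_arviz_data("centered_eight")
loo_res = az.loo(idata, pointwise=True)

# good_k is documented above as min(1 - 1/log10(S), 0.7) for S posterior draws
s = loo_res.n_samples
print(loo_res.good_k, min(1 - 1 / np.log10(s), 0.7))

# count observations whose Pareto shape estimate exceeds the threshold
print(int((loo_res.pareto_k > loo_res.good_k).sum()))
```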
@@ -914,6 +919,7 @@ def psislw(log_weights, reff=1.0):
        ...: az.psislw(-log_likelihood, reff=0.8)
 
     """
+    log_weights = deepcopy(log_weights)
     if hasattr(log_weights, "__sample__"):
         n_samples = len(log_weights.__sample__)
         shape = [
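The added `deepcopy` means `psislw` no longer modifies the caller's weights in place. A minimal sketch of what that buys, mirroring the docstring example shown in the context lines above (the final assertion reflects the intended 0.21.0 behaviour, not a separately documented guarantee):

```python
import arviz as az

data = az.load_arviz_data("non_centered_eight")
log_like = data.log_likelihood["obs"].stack(__sample__=["chain", "draw"])

raw = -log_like
saved = raw.copy(deep=True)
smoothed, khat = az.psislw(raw, reff=0.8)

# with the deepcopy added in 0.21.0 the input array is left untouched
assert saved.equals(raw)
```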
@@ -1580,7 +1586,9 @@ def waic(data, pointwise=None, var_name=None, scale=None, dask_kwargs=None):
     elpd_waic: approximated expected log pointwise predictive density (elpd)
     se: standard error of the elpd
     p_waic: effective number parameters
-
+    n_samples: number of samples
+    n_data_points: number of data points
+    warning: bool
         True if posterior variance of the log predictive densities exceeds 0.4
     waic_i: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
         only if pointwise=True
@@ -2331,3 +2339,87 @@ def _cjs_dist(draws, weights):
     bound = cdf_p_int + cdf_q_int
 
     return np.sqrt((cjs_pq + cjs_qp) / bound)
+
+
+def bayes_factor(idata, var_name, ref_val=0, prior=None, return_ref_vals=False):
+    r"""Approximated Bayes Factor for comparing hypothesis of two nested models.
+
+    The Bayes factor is estimated by comparing a model (H1) against a model in which the
+    parameter of interest has been restricted to be a point-null (H0). This computation
+    assumes the models are nested and thus H0 is a special case of H1.
+
+    Notes
+    -----
+    The bayes Factor is approximated as the Savage-Dickey density ratio
+    algorithm presented in [1]_.
+
+    Parameters
+    ----------
+    idata : InferenceData
+        Any object that can be converted to an :class:`arviz.InferenceData` object
+        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
+    var_name : str, optional
+        Name of variable we want to test.
+    ref_val : int, default 0
+        Point-null for Bayes factor estimation.
+    prior : numpy.array, optional
+        In case we want to use different prior, for example for sensitivity analysis.
+    return_ref_vals : bool, optional
+        Whether to return the values of the prior and posterior at the reference value.
+        Used by :func:`arviz.plot_bf` to display the distribution comparison.
+
+
+    Returns
+    -------
+    dict : A dictionary with BF10 (Bayes Factor 10 (H1/H0 ratio), and BF01 (H0/H1 ratio).
+
+    References
+    ----------
+    .. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio:
+       The case of computing Bayes factors for regression parameters.
+
+    Examples
+    --------
+    Moderate evidence indicating that the parameter "a" is different from zero.
+
+    .. ipython::
+
+        In [1]: import numpy as np
+           ...: import arviz as az
+           ...: idata = az.from_dict(posterior={"a":np.random.normal(1, 0.5, 5000)},
+           ...:     prior={"a":np.random.normal(0, 1, 5000)})
+           ...: az.bayes_factor(idata, var_name="a", ref_val=0)
+
+    """
+
+    posterior = extract(idata, var_names=var_name).values
+
+    if ref_val > posterior.max() or ref_val < posterior.min():
+        _log.warning(
+            "The reference value is outside of the posterior. "
+            "This translate into infinite support for H1, which is most likely an overstatement."
+        )
+
+    if posterior.ndim > 1:
+        _log.warning("Posterior distribution has {posterior.ndim} dimensions")
+
+    if prior is None:
+        prior = extract(idata, var_names=var_name, group="prior").values
+
+    if posterior.dtype.kind == "f":
+        posterior_grid, posterior_pdf, *_ = _kde_linear(posterior)
+        prior_grid, prior_pdf, *_ = _kde_linear(prior)
+        posterior_at_ref_val = np.interp(ref_val, posterior_grid, posterior_pdf)
+        prior_at_ref_val = np.interp(ref_val, prior_grid, prior_pdf)
+
+    elif posterior.dtype.kind == "i":
+        posterior_at_ref_val = (posterior == ref_val).mean()
+        prior_at_ref_val = (prior == ref_val).mean()
+
+    bf_10 = prior_at_ref_val / posterior_at_ref_val
+    bf = {"BF10": bf_10, "BF01": 1 / bf_10}
+
+    if return_ref_vals:
+        return (bf, {"prior": prior_at_ref_val, "posterior": posterior_at_ref_val})
+    else:
+        return bf
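Beyond the docstring example embedded in the hunk above, the new keyword `return_ref_vals` also exposes the two densities that form the Savage-Dickey ratio. A minimal sketch of calling the new function from the top-level namespace:

```python
import numpy as np
import arviz as az

idata = az.from_dict(
    posterior={"a": np.random.normal(1, 0.5, 5000)},
    prior={"a": np.random.normal(0, 1, 5000)},
)

bf, ref_vals = az.bayes_factor(idata, var_name="a", ref_val=0, return_ref_vals=True)
print(bf["BF10"], bf["BF01"])                    # Savage-Dickey ratio and its inverse
print(ref_vals["prior"], ref_vals["posterior"])  # densities evaluated at ref_val
```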
arviz/tests/base_tests/test_data.py
CHANGED
@@ -3,6 +3,7 @@
 
 import importlib
 import os
+import warnings
 from collections import namedtuple
 from copy import deepcopy
 from html import escape
@@ -32,7 +33,13 @@ from ... import (
     extract,
 )
 
-from ...data.base import
+from ...data.base import (
+    dict_to_dataset,
+    generate_dims_coords,
+    infer_stan_dtypes,
+    make_attrs,
+    numpy_to_data_array,
+)
 from ...data.datasets import LOCAL_DATASETS, REMOTE_DATASETS, RemoteFileMetadata
 from ..helpers import (  # pylint: disable=unused-import
     chains,
@@ -44,6 +51,10 @@ from ..helpers import (  # pylint: disable=unused-import
     models,
 )
 
+# Check if dm-tree is installed
+dm_tree_installed = importlib.util.find_spec("tree") is not None  # pylint: disable=invalid-name
+skip_tests = (not dm_tree_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)
+
 
 @pytest.fixture(autouse=True)
 def no_remote_data(monkeypatch, tmpdir):
@@ -227,6 +238,17 @@ def test_dims_coords_skip_event_dims(shape):
     assert "z" not in coords
 
 
+@pytest.mark.parametrize("dims", [None, ["chain", "draw"], ["chain", "draw", None]])
+def test_numpy_to_data_array_with_dims(dims):
+    da = numpy_to_data_array(
+        np.empty((4, 500, 7)),
+        var_name="a",
+        dims=dims,
+        default_dims=["chain", "draw"],
+    )
+    assert list(da.dims) == ["chain", "draw", "a_dim_0"]
+
+
 def test_make_attrs():
     extra_attrs = {"key": "Value"}
     attrs = make_attrs(attrs=extra_attrs)
@@ -895,6 +917,11 @@ class TestInferenceData:  # pylint: disable=too-many-public-methods
         assert escape(repr(idata)) in html
         xr.set_options(display_style=display_style)
 
+    def test_setitem(self, data_random):
+        data_random["new_group"] = data_random.posterior
+        assert "new_group" in data_random.groups()
+        assert hasattr(data_random, "new_group")
+
     def test_add_groups(self, data_random):
         data = np.random.normal(size=(4, 500, 8))
         idata = data_random
@@ -912,7 +939,7 @@ class TestInferenceData:  # pylint: disable=too-many-public-methods
         data = np.random.normal(size=(4, 500, 8))
         idata = data_random
         with pytest.warns(UserWarning, match="The group.+not defined in the InferenceData scheme"):
-            idata.add_groups({"new_group": idata.posterior})
+            idata.add_groups({"new_group": idata.posterior}, warn_on_custom_groups=True)
         with pytest.warns(UserWarning, match="the default dims.+will be added automatically"):
             idata.add_groups(constant_data={"a": data[..., 0], "b": data})
         assert idata.new_group.equals(idata.posterior)
@@ -953,8 +980,8 @@ class TestInferenceData:  # pylint: disable=too-many-public-methods
         with pytest.raises(ValueError, match="join must be either"):
             idata.extend(idata2, join="outer")
         idata2.add_groups(new_group=idata2.prior)
-        with pytest.warns(UserWarning):
-            idata.extend(idata2)
+        with pytest.warns(UserWarning, match="new_group"):
+            idata.extend(idata2, warn_on_custom_groups=True)
 
 
 class TestNumpyToDataArray:
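The updated tests above show that custom (non-schema) group names no longer warn by default; the warning is now opt-in via `warn_on_custom_groups`. A minimal sketch of the new keyword, with the behaviour inferred from these tests rather than from separate documentation (the group names are illustrative):

```python
import numpy as np
import arviz as az

idata = az.from_dict(posterior={"a": np.random.normal(size=(4, 500))})

# silent by default in 0.21.0
idata.add_groups({"my_custom_group": idata.posterior})

# pass warn_on_custom_groups=True to get the
# "not defined in the InferenceData scheme" UserWarning back
idata.add_groups({"another_group": idata.posterior}, warn_on_custom_groups=True)
```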
@@ -1076,6 +1103,7 @@ def test_dict_to_dataset():
     assert set(dataset.b.coords) == {"chain", "draw", "c"}
 
 
+@pytest.mark.skipif(skip_tests, reason="test requires dm-tree which is not installed")
 def test_nested_dict_to_dataset():
     datadict = {
         "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
@@ -1146,11 +1174,17 @@ def test_bad_inference_data():
         InferenceData(posterior=[1, 2, 3])
 
 
-
+@pytest.mark.parametrize("warn", [True, False])
+def test_inference_data_other_groups(warn):
     datadict = {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}
     dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={"b": ["c"]})
-
-
+    if warn:
+        with pytest.warns(UserWarning, match="not.+in.+InferenceData scheme"):
+            idata = InferenceData(other_group=dataset, warn_on_custom_groups=True)
+    else:
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")
+            idata = InferenceData(other_group=dataset, warn_on_custom_groups=False)
     fails = check_multiple_attrs({"other_group": ["a", "b"]}, idata)
     assert not fails
 
arviz/tests/base_tests/test_diagnostics.py
CHANGED
@@ -120,10 +120,11 @@ class TestDiagnostics:
        ```
        Reference file:
 
-       Created:
-       System: Ubuntu
-       R version 4.
-       posterior
+       Created: 2024-12-20
+       System: Ubuntu 24.04.1 LTS
+       R version 4.4.2 (2024-10-31)
+       posterior version from https://github.com/stan-dev/posterior/pull/388
+       (after release 1.6.0 but before the fixes in the PR were released).
        """
        # download input files
        here = os.path.dirname(os.path.abspath(__file__))
arviz/tests/base_tests/test_plots_matplotlib.py
CHANGED
@@ -2,21 +2,22 @@
 
 # pylint: disable=redefined-outer-name,too-many-lines
 import os
+import re
 from copy import deepcopy
 
 import matplotlib.pyplot as plt
 import numpy as np
 import pytest
+import xarray as xr
 from matplotlib import animation
 from pandas import DataFrame
 from scipy.stats import gaussian_kde, norm
-import xarray as xr
 
 from ...data import from_dict, load_arviz_data
 from ...plots import (
     plot_autocorr,
-    plot_bpv,
     plot_bf,
+    plot_bpv,
     plot_compare,
     plot_density,
     plot_dist,
@@ -43,20 +44,20 @@ from ...plots import (
     plot_ts,
     plot_violin,
 )
+from ...plots.dotplot import wilkinson_algorithm
+from ...plots.plot_utils import plot_point_interval
 from ...rcparams import rc_context, rcParams
 from ...stats import compare, hdi, loo, waic
 from ...stats.density_utils import kde as _kde
-from ...utils import
-from ...plots.plot_utils import plot_point_interval
-from ...plots.dotplot import wilkinson_algorithm
+from ...utils import BehaviourChangeWarning, _cov
 from ..helpers import (  # pylint: disable=unused-import
+    RandomVariableTestClass,
     create_model,
     create_multidimensional_model,
     does_not_warn,
     eight_schools_params,
     models,
     multidim_models,
-    RandomVariableTestClass,
 )
 
 rcParams["data.load"] = "eager"
@@ -1236,6 +1237,23 @@ def test_plot_hdi_dataset_error(models):
         plot_hdi(np.arange(8), hdi_data=hdi_data)
 
 
+def test_plot_hdi_string_error():
+    """Check x as type string raises an error."""
+    x_data = ["a", "b", "c", "d"]
+    y_data = np.random.normal(0, 5, (1, 200, len(x_data)))
+    hdi_data = hdi(y_data)
+    with pytest.raises(
+        NotImplementedError,
+        match=re.escape(
+            (
+                "The `arviz.plot_hdi()` function does not support categorical data. "
+                "Consider using `arviz.plot_forest()`."
+            )
+        ),
+    ):
+        plot_hdi(x=x_data, y=y_data, hdi_data=hdi_data)
+
+
 def test_plot_hdi_datetime_error():
     """Check x as datetime raises an error."""
     x_data = np.arange(start="2022-01-01", stop="2022-03-01", dtype=np.datetime64)
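As the new error message suggests, categorical x values belong in `plot_forest` rather than `plot_hdi`. A minimal sketch of that workaround (the coordinate names are illustrative, not part of this diff):

```python
import numpy as np
import arviz as az

y = np.random.normal(0, 5, (1, 200, 4))
idata = az.from_dict(
    posterior={"y": y},
    coords={"category": ["a", "b", "c", "d"]},
    dims={"y": ["category"]},
)
# one credible interval per category instead of a continuous HDI band
az.plot_forest(idata, var_names=["y"])
```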
@@ -1285,10 +1303,11 @@ def test_plot_ecdf_eval_points():
     assert axes is not None
 
 
-@pytest.mark.parametrize("confidence_bands", [True, "pointwise", "simulated"])
-
+@pytest.mark.parametrize("confidence_bands", [True, "pointwise", "optimized", "simulated"])
+@pytest.mark.parametrize("ndraws", [100, 10_000])
+def test_plot_ecdf_confidence_bands(confidence_bands, ndraws):
     """Check that all confidence_bands values correctly accepted"""
-    data = np.random.randn(4,
+    data = np.random.randn(4, ndraws // 4)
     axes = plot_ecdf(data, confidence_bands=confidence_bands, cdf=norm(0, 1).cdf)
     assert axes is not None
 
@@ -1326,6 +1345,8 @@ def test_plot_ecdf_error():
     # contradictory confidence band types
     with pytest.raises(ValueError):
         plot_ecdf(data, cdf=dist.cdf, confidence_bands="simulated", pointwise=True)
+    with pytest.raises(ValueError):
+        plot_ecdf(data, cdf=dist.cdf, confidence_bands="optimized", pointwise=True)
     plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, pointwise=True)
     plot_ecdf(data, cdf=dist.cdf, confidence_bands="pointwise")
 
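The tests above introduce "optimized" as an accepted `confidence_bands` value alongside "pointwise" and "simulated", and show it cannot be combined with `pointwise=True`. A minimal sketch of calling it directly (the standard-normal reference CDF is arbitrary):

```python
import numpy as np
from scipy.stats import norm
import arviz as az

data = np.random.randn(4, 1000)
az.plot_ecdf(data, cdf=norm(0, 1).cdf, confidence_bands="optimized")
```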
@@ -2079,7 +2100,5 @@ def test_plot_bf():
     idata = from_dict(
         posterior={"a": np.random.normal(1, 0.5, 5000)}, prior={"a": np.random.normal(0, 1, 5000)}
     )
-
-
-    assert bf_dict0["BF10"] > bf_dict0["BF01"]
-    assert bf_dict1["BF10"] < bf_dict1["BF01"]
+    _, bf_plot = plot_bf(idata, var_name="a", ref_val=0)
+    assert bf_plot is not None
arviz/tests/base_tests/test_stats.py
CHANGED
@@ -18,6 +18,7 @@ from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
 from ...rcparams import rcParams
 from ...stats import (
     apply_test_function,
+    bayes_factor,
     compare,
     ess,
     hdi,
@@ -871,3 +872,13 @@ def test_priorsens_coords(psens_data):
     assert "mu" in result
     assert "theta" in result
     assert "school" in result.theta_t.dims
+
+
+def test_bayes_factor():
+    idata = from_dict(
+        posterior={"a": np.random.normal(1, 0.5, 5000)}, prior={"a": np.random.normal(0, 1, 5000)}
+    )
+    bf_dict0 = bayes_factor(idata, var_name="a", ref_val=0)
+    bf_dict1 = bayes_factor(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
+    assert bf_dict0["BF10"] > bf_dict0["BF01"]
+    assert bf_dict1["BF10"] < bf_dict1["BF01"]
arviz/tests/base_tests/test_stats_ecdf_utils.py
CHANGED
@@ -10,6 +10,13 @@ from ...stats.ecdf_utils import (
     _get_pointwise_confidence_band,
 )
 
+try:
+    import numba  # pylint: disable=unused-import
+
+    numba_options = [True, False]
+except ImportError:
+    numba_options = [False]
+
 
 def test_compute_ecdf():
     """Test compute_ecdf function."""
@@ -109,9 +116,15 @@ def test_get_pointwise_confidence_band(dist, prob, ndraws, num_trials=1_000, seed=57):
     ids=["continuous", "continuous default rvs", "discrete"],
 )
 @pytest.mark.parametrize("ndraws", [10_000])
-@pytest.mark.parametrize("method", ["pointwise", "simulated"])
-
+@pytest.mark.parametrize("method", ["pointwise", "optimized", "simulated"])
+@pytest.mark.parametrize("use_numba", numba_options)
+def test_ecdf_confidence_band(
+    dist, rvs, prob, ndraws, method, use_numba, num_trials=1_000, seed=57
+):
     """Test test_ecdf_confidence_band."""
+    if use_numba and method != "optimized":
+        pytest.skip("Numba only used in optimized method")
+
     eval_points = np.linspace(*dist.interval(0.99), 10)
     cdf_at_eval_points = dist.cdf(eval_points)
     random_state = np.random.default_rng(seed)
arviz/utils.py
CHANGED
@@ -330,6 +330,7 @@ class Numba:
     """A class to toggle numba states."""
 
     numba_flag = numba_check()
+    """bool: Indicates whether Numba optimizations are enabled. Defaults to False."""
 
     @classmethod
     def disable_numba(cls):
@@ -732,7 +733,10 @@ class Dask:
     """
 
     dask_flag = False
+    """bool: Enables Dask parallelization when set to True. Defaults to False."""
    dask_kwargs = None
+    """dict: Additional keyword arguments for Dask configuration.
+    Defaults to an empty dictionary."""
 
     @classmethod
     def enable_dask(cls, dask_kwargs=None):
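The new attribute docstrings describe the two module-level toggles. A minimal sketch of how they are typically flipped; only `disable_numba` and `enable_dask(dask_kwargs=...)` appear in the hunks above, the enable/disable counterparts and the specific `dask_kwargs` values are assumptions for illustration:

```python
from arviz.utils import Dask, Numba

# turn the numba-accelerated code paths off and back on
Numba.disable_numba()
Numba.enable_numba()

# opt in to Dask-parallelized computations, then opt out again
Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]})
Dask.disable_dask()
```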
{arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: arviz
-Version: 0.19.0
+Version: 0.21.0
 Summary: Exploratory analysis of Bayesian models
 Home-page: http://github.com/arviz-devs/arviz
 Author: ArviZ Developers
@@ -21,30 +21,30 @@ Classifier: Topic :: Scientific/Engineering :: Mathematics
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: setuptools
-Requires-Dist: matplotlib
-Requires-Dist: numpy
-Requires-Dist: scipy
+Requires-Dist: setuptools>=60.0.0
+Requires-Dist: matplotlib>=3.5
+Requires-Dist: numpy>=1.23.0
+Requires-Dist: scipy>=1.9.0
 Requires-Dist: packaging
-Requires-Dist: pandas
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist: xarray-einstats >=0.3
+Requires-Dist: pandas>=1.5.0
+Requires-Dist: xarray>=2022.6.0
+Requires-Dist: h5netcdf>=1.0.2
+Requires-Dist: typing-extensions>=4.1.0
+Requires-Dist: xarray-einstats>=0.3
 Provides-Extra: all
-Requires-Dist: numba
-Requires-Dist: netcdf4
-Requires-Dist: bokeh
-Requires-Dist: contourpy
-Requires-Dist: ujson
-Requires-Dist: dask[distributed]
-Requires-Dist: zarr
-Requires-Dist: xarray-datatree
+Requires-Dist: numba; extra == "all"
+Requires-Dist: netcdf4; extra == "all"
+Requires-Dist: bokeh>=3; extra == "all"
+Requires-Dist: contourpy; extra == "all"
+Requires-Dist: ujson; extra == "all"
+Requires-Dist: dask[distributed]; extra == "all"
+Requires-Dist: zarr<3,>=2.5.0; extra == "all"
+Requires-Dist: xarray-datatree; extra == "all"
+Requires-Dist: dm-tree>=0.1.8; extra == "all"
 Provides-Extra: preview
-Requires-Dist: arviz-base[h5netcdf]
-Requires-Dist: arviz-stats[xarray]
-Requires-Dist: arviz-plots
+Requires-Dist: arviz-base[h5netcdf]; extra == "preview"
+Requires-Dist: arviz-stats[xarray]; extra == "preview"
+Requires-Dist: arviz-plots; extra == "preview"
 
 <img src="https://raw.githubusercontent.com/arviz-devs/arviz-project/main/arviz_logos/ArviZ.png#gh-light-mode-only" width=200></img>
 <img src="https://raw.githubusercontent.com/arviz-devs/arviz-project/main/arviz_logos/ArviZ_white.png#gh-dark-mode-only" width=200></img>