arviz 0.17.1__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +4 -2
- arviz/data/__init__.py +5 -2
- arviz/data/base.py +102 -11
- arviz/data/converters.py +5 -0
- arviz/data/datasets.py +1 -0
- arviz/data/example_data/data_remote.json +10 -3
- arviz/data/inference_data.py +20 -22
- arviz/data/io_cmdstan.py +5 -3
- arviz/data/io_datatree.py +1 -0
- arviz/data/io_dict.py +5 -3
- arviz/data/io_emcee.py +1 -0
- arviz/data/io_numpyro.py +2 -1
- arviz/data/io_pyjags.py +1 -0
- arviz/data/io_pyro.py +1 -0
- arviz/data/utils.py +1 -0
- arviz/plots/__init__.py +1 -0
- arviz/plots/autocorrplot.py +1 -0
- arviz/plots/backends/bokeh/autocorrplot.py +1 -0
- arviz/plots/backends/bokeh/bpvplot.py +1 -0
- arviz/plots/backends/bokeh/compareplot.py +1 -0
- arviz/plots/backends/bokeh/densityplot.py +1 -0
- arviz/plots/backends/bokeh/distplot.py +1 -0
- arviz/plots/backends/bokeh/dotplot.py +1 -0
- arviz/plots/backends/bokeh/ecdfplot.py +2 -2
- arviz/plots/backends/bokeh/elpdplot.py +1 -0
- arviz/plots/backends/bokeh/energyplot.py +1 -0
- arviz/plots/backends/bokeh/hdiplot.py +1 -0
- arviz/plots/backends/bokeh/kdeplot.py +3 -3
- arviz/plots/backends/bokeh/khatplot.py +9 -3
- arviz/plots/backends/bokeh/lmplot.py +1 -0
- arviz/plots/backends/bokeh/loopitplot.py +1 -0
- arviz/plots/backends/bokeh/mcseplot.py +1 -0
- arviz/plots/backends/bokeh/pairplot.py +3 -6
- arviz/plots/backends/bokeh/parallelplot.py +1 -0
- arviz/plots/backends/bokeh/posteriorplot.py +1 -0
- arviz/plots/backends/bokeh/ppcplot.py +1 -0
- arviz/plots/backends/bokeh/rankplot.py +1 -0
- arviz/plots/backends/bokeh/separationplot.py +1 -0
- arviz/plots/backends/bokeh/traceplot.py +1 -0
- arviz/plots/backends/bokeh/violinplot.py +1 -0
- arviz/plots/backends/matplotlib/autocorrplot.py +1 -0
- arviz/plots/backends/matplotlib/bpvplot.py +1 -0
- arviz/plots/backends/matplotlib/compareplot.py +1 -0
- arviz/plots/backends/matplotlib/densityplot.py +1 -0
- arviz/plots/backends/matplotlib/distcomparisonplot.py +2 -3
- arviz/plots/backends/matplotlib/distplot.py +1 -0
- arviz/plots/backends/matplotlib/dotplot.py +1 -0
- arviz/plots/backends/matplotlib/ecdfplot.py +2 -2
- arviz/plots/backends/matplotlib/elpdplot.py +1 -0
- arviz/plots/backends/matplotlib/energyplot.py +1 -0
- arviz/plots/backends/matplotlib/essplot.py +6 -5
- arviz/plots/backends/matplotlib/forestplot.py +1 -0
- arviz/plots/backends/matplotlib/hdiplot.py +1 -0
- arviz/plots/backends/matplotlib/kdeplot.py +5 -3
- arviz/plots/backends/matplotlib/khatplot.py +8 -3
- arviz/plots/backends/matplotlib/lmplot.py +1 -0
- arviz/plots/backends/matplotlib/loopitplot.py +1 -0
- arviz/plots/backends/matplotlib/mcseplot.py +11 -10
- arviz/plots/backends/matplotlib/pairplot.py +2 -1
- arviz/plots/backends/matplotlib/parallelplot.py +1 -0
- arviz/plots/backends/matplotlib/posteriorplot.py +1 -0
- arviz/plots/backends/matplotlib/ppcplot.py +1 -0
- arviz/plots/backends/matplotlib/rankplot.py +1 -0
- arviz/plots/backends/matplotlib/separationplot.py +1 -0
- arviz/plots/backends/matplotlib/traceplot.py +2 -1
- arviz/plots/backends/matplotlib/tsplot.py +1 -0
- arviz/plots/backends/matplotlib/violinplot.py +2 -1
- arviz/plots/bpvplot.py +3 -2
- arviz/plots/compareplot.py +1 -0
- arviz/plots/densityplot.py +2 -1
- arviz/plots/distcomparisonplot.py +1 -0
- arviz/plots/dotplot.py +3 -2
- arviz/plots/ecdfplot.py +206 -89
- arviz/plots/elpdplot.py +1 -0
- arviz/plots/energyplot.py +1 -0
- arviz/plots/essplot.py +3 -2
- arviz/plots/forestplot.py +2 -1
- arviz/plots/hdiplot.py +3 -2
- arviz/plots/khatplot.py +24 -6
- arviz/plots/lmplot.py +1 -0
- arviz/plots/loopitplot.py +3 -2
- arviz/plots/mcseplot.py +4 -1
- arviz/plots/pairplot.py +1 -0
- arviz/plots/parallelplot.py +1 -0
- arviz/plots/plot_utils.py +3 -4
- arviz/plots/posteriorplot.py +2 -1
- arviz/plots/ppcplot.py +1 -0
- arviz/plots/rankplot.py +3 -2
- arviz/plots/separationplot.py +1 -0
- arviz/plots/traceplot.py +1 -0
- arviz/plots/tsplot.py +1 -0
- arviz/plots/violinplot.py +2 -1
- arviz/preview.py +17 -0
- arviz/rcparams.py +28 -2
- arviz/sel_utils.py +1 -0
- arviz/static/css/style.css +2 -1
- arviz/stats/density_utils.py +2 -1
- arviz/stats/diagnostics.py +15 -11
- arviz/stats/ecdf_utils.py +12 -8
- arviz/stats/stats.py +31 -16
- arviz/stats/stats_refitting.py +1 -0
- arviz/stats/stats_utils.py +13 -7
- arviz/tests/base_tests/test_data.py +15 -2
- arviz/tests/base_tests/test_data_zarr.py +0 -1
- arviz/tests/base_tests/test_diagnostics.py +1 -0
- arviz/tests/base_tests/test_diagnostics_numba.py +2 -6
- arviz/tests/base_tests/test_helpers.py +2 -2
- arviz/tests/base_tests/test_labels.py +1 -0
- arviz/tests/base_tests/test_plot_utils.py +5 -13
- arviz/tests/base_tests/test_plots_matplotlib.py +98 -7
- arviz/tests/base_tests/test_rcparams.py +12 -0
- arviz/tests/base_tests/test_stats.py +5 -5
- arviz/tests/base_tests/test_stats_numba.py +2 -7
- arviz/tests/base_tests/test_stats_utils.py +1 -0
- arviz/tests/base_tests/test_utils.py +3 -2
- arviz/tests/base_tests/test_utils_numba.py +2 -5
- arviz/tests/external_tests/test_data_pystan.py +5 -5
- arviz/tests/helpers.py +18 -10
- arviz/utils.py +4 -0
- arviz/wrappers/__init__.py +1 -0
- {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/METADATA +13 -9
- arviz-0.19.0.dist-info/RECORD +183 -0
- arviz-0.17.1.dist-info/RECORD +0 -182
- {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/LICENSE +0 -0
- {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/WHEEL +0 -0
- {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/top_level.txt +0 -0
arviz/plots/rankplot.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Histograms of ranked posterior draws, plotted for each chain."""
|
|
2
|
+
|
|
2
3
|
from itertools import cycle
|
|
3
4
|
|
|
4
5
|
import matplotlib.pyplot as plt
|
|
@@ -45,8 +46,8 @@ def plot_rank(
|
|
|
45
46
|
indicates good mixing of the chains.
|
|
46
47
|
|
|
47
48
|
This plot was introduced by Aki Vehtari, Andrew Gelman, Daniel Simpson, Bob Carpenter,
|
|
48
|
-
Paul-Christian Burkner (
|
|
49
|
-
for assessing convergence of MCMC.
|
|
49
|
+
Paul-Christian Burkner (2021): Rank-normalization, folding, and localization:
|
|
50
|
+
An improved R-hat for assessing convergence of MCMC. Bayesian analysis, 16(2):667-718.
|
|
50
51
|
|
|
51
52
|
|
|
52
53
|
Parameters
|
arviz/plots/separationplot.py
CHANGED
arviz/plots/traceplot.py
CHANGED
arviz/plots/tsplot.py
CHANGED
arviz/plots/violinplot.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Plot posterior traces as violin plot."""
|
|
2
|
+
|
|
2
3
|
from ..data import convert_to_dataset
|
|
3
4
|
from ..labels import BaseLabeller
|
|
4
5
|
from ..sel_utils import xarray_var_iter
|
|
@@ -151,7 +152,7 @@ def plot_violin(
|
|
|
151
152
|
rows, cols = default_grid(len(plotters), grid=grid)
|
|
152
153
|
|
|
153
154
|
if hdi_prob is None:
|
|
154
|
-
hdi_prob = rcParams["stats.
|
|
155
|
+
hdi_prob = rcParams["stats.ci_prob"]
|
|
155
156
|
elif not 1 >= hdi_prob > 0:
|
|
156
157
|
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
|
|
157
158
|
|
arviz/preview.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# pylint: disable=unused-import,unused-wildcard-import,wildcard-import
|
|
2
|
+
"""Expose features from arviz-xyz refactored packages inside ``arviz.preview`` namespace."""
|
|
3
|
+
|
|
4
|
+
try:
|
|
5
|
+
from arviz_base import *
|
|
6
|
+
except ModuleNotFoundError:
|
|
7
|
+
pass
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import arviz_stats
|
|
11
|
+
except ModuleNotFoundError:
|
|
12
|
+
pass
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from arviz_plots import *
|
|
16
|
+
except ModuleNotFoundError:
|
|
17
|
+
pass
|
arviz/rcparams.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""ArviZ rcparams. Based on matplotlib's implementation."""
|
|
2
|
+
|
|
2
3
|
import locale
|
|
3
4
|
import logging
|
|
4
5
|
import os
|
|
@@ -25,6 +26,8 @@ _log = logging.getLogger(__name__)
|
|
|
25
26
|
ScaleKeyword = Literal["log", "negative_log", "deviance"]
|
|
26
27
|
ICKeyword = Literal["loo", "waic"]
|
|
27
28
|
|
|
29
|
+
_identity = lambda x: x
|
|
30
|
+
|
|
28
31
|
|
|
29
32
|
def _make_validate_choice(accepted_values, allow_none=False, typeof=str):
|
|
30
33
|
"""Validate value is in accepted_values.
|
|
@@ -299,7 +302,7 @@ defaultParams = { # pylint: disable=invalid-name
|
|
|
299
302
|
lambda x: x,
|
|
300
303
|
),
|
|
301
304
|
"plot.matplotlib.show": (False, _validate_boolean),
|
|
302
|
-
"stats.
|
|
305
|
+
"stats.ci_prob": (0.94, _validate_probability),
|
|
303
306
|
"stats.information_criterion": (
|
|
304
307
|
"loo",
|
|
305
308
|
_make_validate_choice({"loo", "waic"} if NO_GET_ARGS else set(get_args(ICKeyword))),
|
|
@@ -317,6 +320,9 @@ defaultParams = { # pylint: disable=invalid-name
|
|
|
317
320
|
),
|
|
318
321
|
}
|
|
319
322
|
|
|
323
|
+
# map from deprecated params to (version, new_param, fold2new, fnew2old)
|
|
324
|
+
deprecated_map = {"stats.hdi_prob": ("0.18.0", "stats.ci_prob", _identity, _identity)}
|
|
325
|
+
|
|
320
326
|
|
|
321
327
|
class RcParams(MutableMapping):
|
|
322
328
|
"""Class to contain ArviZ default parameters.
|
|
@@ -334,6 +340,15 @@ class RcParams(MutableMapping):
|
|
|
334
340
|
|
|
335
341
|
def __setitem__(self, key, val):
|
|
336
342
|
"""Add validation to __setitem__ function."""
|
|
343
|
+
if key in deprecated_map:
|
|
344
|
+
version, key_new, fold2new, _ = deprecated_map[key]
|
|
345
|
+
warnings.warn(
|
|
346
|
+
f"{key} is deprecated since {version}, use {key_new} instead",
|
|
347
|
+
FutureWarning,
|
|
348
|
+
)
|
|
349
|
+
key = key_new
|
|
350
|
+
val = fold2new(val)
|
|
351
|
+
|
|
337
352
|
try:
|
|
338
353
|
try:
|
|
339
354
|
cval = self.validate[key](val)
|
|
@@ -348,7 +363,18 @@ class RcParams(MutableMapping):
|
|
|
348
363
|
|
|
349
364
|
def __getitem__(self, key):
|
|
350
365
|
"""Use underlying dict's getitem method."""
|
|
351
|
-
|
|
366
|
+
if key in deprecated_map:
|
|
367
|
+
version, key_new, _, fnew2old = deprecated_map[key]
|
|
368
|
+
warnings.warn(
|
|
369
|
+
f"{key} is deprecated since {version}, use {key_new} instead",
|
|
370
|
+
FutureWarning,
|
|
371
|
+
)
|
|
372
|
+
if key not in self._underlying_storage:
|
|
373
|
+
key = key_new
|
|
374
|
+
else:
|
|
375
|
+
fnew2old = _identity
|
|
376
|
+
|
|
377
|
+
return fnew2old(self._underlying_storage[key])
|
|
352
378
|
|
|
353
379
|
def __delitem__(self, key):
|
|
354
380
|
"""Raise TypeError if someone ever tries to delete a key from RcParams."""
|
arviz/sel_utils.py
CHANGED
arviz/static/css/style.css
CHANGED
arviz/stats/density_utils.py
CHANGED
|
@@ -5,7 +5,8 @@ import warnings
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from scipy.fftpack import fft
|
|
7
7
|
from scipy.optimize import brentq
|
|
8
|
-
from scipy.signal import convolve, convolve2d
|
|
8
|
+
from scipy.signal import convolve, convolve2d
|
|
9
|
+
from scipy.signal.windows import gaussian
|
|
9
10
|
from scipy.sparse import coo_matrix
|
|
10
11
|
from scipy.special import ive # pylint: disable=no-name-in-module
|
|
11
12
|
|
arviz/stats/diagnostics.py
CHANGED
|
@@ -135,10 +135,11 @@ def ess(
|
|
|
135
135
|
|
|
136
136
|
References
|
|
137
137
|
----------
|
|
138
|
-
* Vehtari et al. (
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
*
|
|
138
|
+
* Vehtari et al. (2021). Rank-normalization, folding, and
|
|
139
|
+
localization: An improved Rhat for assessing convergence of
|
|
140
|
+
MCMC. Bayesian analysis, 16(2):667-718.
|
|
141
|
+
* https://mc-stan.org/docs/reference-manual/analysis.html#effective-sample-size.section
|
|
142
|
+
* Gelman et al. BDA3 (2013) Formula 11.8
|
|
142
143
|
|
|
143
144
|
See Also
|
|
144
145
|
--------
|
|
@@ -246,7 +247,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
|
|
|
246
247
|
Names of variables to include in the rhat report
|
|
247
248
|
method : str
|
|
248
249
|
Select R-hat method. Valid methods are:
|
|
249
|
-
- "rank" # recommended by Vehtari et al. (
|
|
250
|
+
- "rank" # recommended by Vehtari et al. (2021)
|
|
250
251
|
- "split"
|
|
251
252
|
- "folded"
|
|
252
253
|
- "z_scale"
|
|
@@ -269,7 +270,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
|
|
|
269
270
|
-----
|
|
270
271
|
The diagnostic is computed by:
|
|
271
272
|
|
|
272
|
-
.. math:: \hat{R} = \frac{\hat{V}}{W}
|
|
273
|
+
.. math:: \hat{R} = \sqrt{\frac{\hat{V}}{W}}
|
|
273
274
|
|
|
274
275
|
where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
|
|
275
276
|
estimate for the pooled rank-traces. This is the potential scale reduction factor, which
|
|
@@ -277,12 +278,15 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
|
|
|
277
278
|
greater than one indicate that one or more chains have not yet converged.
|
|
278
279
|
|
|
279
280
|
Rank values are calculated over all the chains with ``scipy.stats.rankdata``.
|
|
280
|
-
Each chain is split in two and normalized with the z-transform following
|
|
281
|
+
Each chain is split in two and normalized with the z-transform following
|
|
282
|
+
Vehtari et al. (2021).
|
|
281
283
|
|
|
282
284
|
References
|
|
283
285
|
----------
|
|
284
|
-
* Vehtari et al. (
|
|
285
|
-
|
|
286
|
+
* Vehtari et al. (2021). Rank-normalization, folding, and
|
|
287
|
+
localization: An improved Rhat for assessing convergence of
|
|
288
|
+
MCMC. Bayesian analysis, 16(2):667-718.
|
|
289
|
+
* Gelman et al. BDA3 (2013)
|
|
286
290
|
* Brooks and Gelman (1998)
|
|
287
291
|
* Gelman and Rubin (1992)
|
|
288
292
|
|
|
@@ -836,7 +840,7 @@ def _mcse_sd(ary):
|
|
|
836
840
|
return np.nan
|
|
837
841
|
ess = _ess_sd(ary)
|
|
838
842
|
if _numba_flag:
|
|
839
|
-
sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)))
|
|
843
|
+
sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)).item())
|
|
840
844
|
else:
|
|
841
845
|
sd = np.std(ary, ddof=1)
|
|
842
846
|
fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
|
|
@@ -904,7 +908,7 @@ def _mc_error(ary, batches=5, circular=False):
|
|
|
904
908
|
else:
|
|
905
909
|
std = stats.circstd(ary, high=np.pi, low=-np.pi)
|
|
906
910
|
elif _numba_flag:
|
|
907
|
-
std = float(_sqrt(svar(ary), np.zeros(1)))
|
|
911
|
+
std = float(_sqrt(svar(ary), np.zeros(1)).item())
|
|
908
912
|
else:
|
|
909
913
|
std = np.std(ary)
|
|
910
914
|
return std / np.sqrt(len(ary))
|
arviz/stats/ecdf_utils.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Functions for evaluating ECDFs and their confidence bands."""
|
|
2
|
+
|
|
2
3
|
from typing import Any, Callable, Optional, Tuple
|
|
3
4
|
import warnings
|
|
4
5
|
|
|
@@ -24,6 +25,13 @@ def _get_ecdf_points(
|
|
|
24
25
|
return x, y
|
|
25
26
|
|
|
26
27
|
|
|
28
|
+
def _call_rvs(rvs, ndraws, random_state):
|
|
29
|
+
if random_state is None:
|
|
30
|
+
return rvs(ndraws)
|
|
31
|
+
else:
|
|
32
|
+
return rvs(ndraws, random_state=random_state)
|
|
33
|
+
|
|
34
|
+
|
|
27
35
|
def _simulate_ecdf(
|
|
28
36
|
ndraws: int,
|
|
29
37
|
eval_points: np.ndarray,
|
|
@@ -31,7 +39,7 @@ def _simulate_ecdf(
|
|
|
31
39
|
random_state: Optional[Any] = None,
|
|
32
40
|
) -> np.ndarray:
|
|
33
41
|
"""Simulate ECDF at the `eval_points` using the given random variable sampler"""
|
|
34
|
-
sample = rvs
|
|
42
|
+
sample = _call_rvs(rvs, ndraws, random_state)
|
|
35
43
|
sample.sort()
|
|
36
44
|
return compute_ecdf(sample, eval_points)
|
|
37
45
|
|
|
@@ -90,14 +98,10 @@ def ecdf_confidence_band(
|
|
|
90
98
|
A function that takes an integer `ndraws` and optionally the object passed to
|
|
91
99
|
`random_state` and returns an array of `ndraws` samples from the same distribution
|
|
92
100
|
as the original dataset. Required if `method` is "simulated" and variable is discrete.
|
|
93
|
-
num_trials : int, default
|
|
101
|
+
num_trials : int, default 500
|
|
94
102
|
The number of random ECDFs to generate for constructing simultaneous confidence bands
|
|
95
103
|
(if `method` is "simulated").
|
|
96
|
-
random_state :
|
|
97
|
-
`numpy.random.RandomState`}, optional
|
|
98
|
-
If `None`, the `numpy.random.RandomState` singleton is used. If an `int`, a new
|
|
99
|
-
``numpy.random.RandomState`` instance is used, seeded with seed. If a `RandomState` or
|
|
100
|
-
`Generator` instance, the instance is used.
|
|
104
|
+
random_state : int, numpy.random.Generator or numpy.random.RandomState, optional
|
|
101
105
|
|
|
102
106
|
Returns
|
|
103
107
|
-------
|
|
@@ -131,7 +135,7 @@ def _simulate_simultaneous_ecdf_band_probability(
|
|
|
131
135
|
cdf_at_eval_points: np.ndarray,
|
|
132
136
|
prob: float = 0.95,
|
|
133
137
|
rvs: Optional[Callable[[int, Optional[Any]], np.ndarray]] = None,
|
|
134
|
-
num_trials: int =
|
|
138
|
+
num_trials: int = 500,
|
|
135
139
|
random_state: Optional[Any] = None,
|
|
136
140
|
) -> float:
|
|
137
141
|
"""Estimate probability for simultaneous confidence band using simulation.
|
arviz/stats/stats.py
CHANGED
|
@@ -270,12 +270,12 @@ def compare(
|
|
|
270
270
|
weights[i] = u_weights / np.sum(u_weights)
|
|
271
271
|
|
|
272
272
|
weights = weights.mean(axis=0)
|
|
273
|
-
ses = pd.Series(z_bs.std(axis=0), index=
|
|
273
|
+
ses = pd.Series(z_bs.std(axis=0), index=ics.index) # pylint: disable=no-member
|
|
274
274
|
|
|
275
275
|
elif method.lower() == "pseudo-bma":
|
|
276
276
|
min_ic = ics.iloc[0][f"elpd_{ic}"]
|
|
277
277
|
z_rv = np.exp((ics[f"elpd_{ic}"] - min_ic) / scale_value)
|
|
278
|
-
weights = z_rv / np.sum(z_rv)
|
|
278
|
+
weights = (z_rv / np.sum(z_rv)).to_numpy()
|
|
279
279
|
ses = ics["se"]
|
|
280
280
|
|
|
281
281
|
if np.any(weights):
|
|
@@ -471,7 +471,7 @@ def hdi(
|
|
|
471
471
|
Refer to documentation of :func:`arviz.convert_to_dataset` for details.
|
|
472
472
|
hdi_prob: float, optional
|
|
473
473
|
Prob for which the highest density interval will be computed. Defaults to
|
|
474
|
-
``stats.
|
|
474
|
+
``stats.ci_prob`` rcParam.
|
|
475
475
|
circular: bool, optional
|
|
476
476
|
Whether to compute the hdi taking into account `x` is a circular variable
|
|
477
477
|
(in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).
|
|
@@ -553,7 +553,7 @@ def hdi(
|
|
|
553
553
|
|
|
554
554
|
"""
|
|
555
555
|
if hdi_prob is None:
|
|
556
|
-
hdi_prob = rcParams["stats.
|
|
556
|
+
hdi_prob = rcParams["stats.ci_prob"]
|
|
557
557
|
elif not 1 >= hdi_prob > 0:
|
|
558
558
|
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
|
|
559
559
|
|
|
@@ -715,8 +715,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
|
|
|
715
715
|
se: standard error of the elpd
|
|
716
716
|
p_loo: effective number of parameters
|
|
717
717
|
shape_warn: bool
|
|
718
|
-
True if the estimated shape parameter of
|
|
719
|
-
|
|
718
|
+
True if the estimated shape parameter of Pareto distribution is greater than a thresold
|
|
719
|
+
value for one or more samples. For a sample size S, the thresold is compute as
|
|
720
|
+
min(1 - 1/log10(S), 0.7)
|
|
720
721
|
loo_i: array of pointwise predictive accuracy, only if pointwise True
|
|
721
722
|
pareto_k: array of Pareto shape values, only if pointwise True
|
|
722
723
|
scale: scale of the elpd
|
|
@@ -785,13 +786,15 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
|
|
|
785
786
|
log_weights += log_likelihood
|
|
786
787
|
|
|
787
788
|
warn_mg = False
|
|
788
|
-
|
|
789
|
+
good_k = min(1 - 1 / np.log10(n_samples), 0.7)
|
|
790
|
+
|
|
791
|
+
if np.any(pareto_shape > good_k):
|
|
789
792
|
warnings.warn(
|
|
790
|
-
"Estimated shape parameter of Pareto distribution is greater than
|
|
791
|
-
"one or more samples. You should consider using a more robust model, this is
|
|
792
|
-
"importance sampling is less likely to work well if the marginal posterior
|
|
793
|
-
"LOO posterior are very different. This is more likely to happen with a
|
|
794
|
-
"model and highly influential observations."
|
|
793
|
+
f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} "
|
|
794
|
+
"for one or more samples. You should consider using a more robust model, this is "
|
|
795
|
+
"because importance sampling is less likely to work well if the marginal posterior "
|
|
796
|
+
"and LOO posterior are very different. This is more likely to happen with a "
|
|
797
|
+
"non-robust model and highly influential observations."
|
|
795
798
|
)
|
|
796
799
|
warn_mg = True
|
|
797
800
|
|
|
@@ -816,8 +819,17 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
|
|
|
816
819
|
|
|
817
820
|
if not pointwise:
|
|
818
821
|
return ELPDData(
|
|
819
|
-
data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale],
|
|
820
|
-
index=[
|
|
822
|
+
data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k],
|
|
823
|
+
index=[
|
|
824
|
+
"elpd_loo",
|
|
825
|
+
"se",
|
|
826
|
+
"p_loo",
|
|
827
|
+
"n_samples",
|
|
828
|
+
"n_data_points",
|
|
829
|
+
"warning",
|
|
830
|
+
"scale",
|
|
831
|
+
"good_k",
|
|
832
|
+
],
|
|
821
833
|
)
|
|
822
834
|
if np.equal(loo_lppd, loo_lppd_i).all(): # pylint: disable=no-member
|
|
823
835
|
warnings.warn(
|
|
@@ -835,6 +847,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
|
|
|
835
847
|
loo_lppd_i.rename("loo_i"),
|
|
836
848
|
pareto_shape,
|
|
837
849
|
scale,
|
|
850
|
+
good_k,
|
|
838
851
|
],
|
|
839
852
|
index=[
|
|
840
853
|
"elpd_loo",
|
|
@@ -846,6 +859,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
|
|
|
846
859
|
"loo_i",
|
|
847
860
|
"pareto_k",
|
|
848
861
|
"scale",
|
|
862
|
+
"good_k",
|
|
849
863
|
],
|
|
850
864
|
)
|
|
851
865
|
|
|
@@ -879,7 +893,8 @@ def psislw(log_weights, reff=1.0):
|
|
|
879
893
|
|
|
880
894
|
References
|
|
881
895
|
----------
|
|
882
|
-
* Vehtari et al. (
|
|
896
|
+
* Vehtari et al. (2024). Pareto smoothed importance sampling. Journal of Machine
|
|
897
|
+
Learning Research, 25(72):1-58.
|
|
883
898
|
|
|
884
899
|
See Also
|
|
885
900
|
--------
|
|
@@ -1322,7 +1337,7 @@ def summary(
|
|
|
1322
1337
|
if labeller is None:
|
|
1323
1338
|
labeller = BaseLabeller()
|
|
1324
1339
|
if hdi_prob is None:
|
|
1325
|
-
hdi_prob = rcParams["stats.
|
|
1340
|
+
hdi_prob = rcParams["stats.ci_prob"]
|
|
1326
1341
|
elif not 1 >= hdi_prob > 0:
|
|
1327
1342
|
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
|
|
1328
1343
|
|
arviz/stats/stats_refitting.py
CHANGED
arviz/stats/stats_utils.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Stats-utility functions for ArviZ."""
|
|
2
|
+
|
|
2
3
|
import warnings
|
|
3
4
|
from collections.abc import Sequence
|
|
4
5
|
from copy import copy as _copy
|
|
@@ -134,7 +135,10 @@ def make_ufunc(
|
|
|
134
135
|
raise TypeError(msg)
|
|
135
136
|
for idx in np.ndindex(out.shape[:n_dims_out]):
|
|
136
137
|
arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
|
|
137
|
-
|
|
138
|
+
out_idx = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
|
|
139
|
+
if n_dims_out is None:
|
|
140
|
+
out_idx = out_idx.item()
|
|
141
|
+
out[idx] = out_idx
|
|
138
142
|
return out
|
|
139
143
|
|
|
140
144
|
def _multi_ufunc(*args, out=None, out_shape=None, **kwargs):
|
|
@@ -450,10 +454,9 @@ POINTWISE_LOO_FMT = """------
|
|
|
450
454
|
|
|
451
455
|
Pareto k diagnostic values:
|
|
452
456
|
{{0:>{0}}} {{1:>6}}
|
|
453
|
-
(-Inf,
|
|
454
|
-
|
|
455
|
-
(
|
|
456
|
-
(1, Inf) (very bad) {{5:{0}d}} {{9:6.1f}}%
|
|
457
|
+
(-Inf, {{8:.2f}}] (good) {{2:{0}d}} {{5:6.1f}}%
|
|
458
|
+
({{8:.2f}}, 1] (bad) {{3:{0}d}} {{6:6.1f}}%
|
|
459
|
+
(1, Inf) (very bad) {{4:{0}d}} {{7:6.1f}}%
|
|
457
460
|
"""
|
|
458
461
|
SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}
|
|
459
462
|
|
|
@@ -484,11 +487,14 @@ class ELPDData(pd.Series): # pylint: disable=too-many-ancestors
|
|
|
484
487
|
base += "\n\nThere has been a warning during the calculation. Please check the results."
|
|
485
488
|
|
|
486
489
|
if kind == "loo" and "pareto_k" in self:
|
|
487
|
-
bins = np.asarray([-np.inf,
|
|
490
|
+
bins = np.asarray([-np.inf, self.good_k, 1, np.inf])
|
|
488
491
|
counts, *_ = _histogram(self.pareto_k.values, bins)
|
|
489
492
|
extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
|
|
490
493
|
extended = extended.format(
|
|
491
|
-
"Count",
|
|
494
|
+
"Count",
|
|
495
|
+
"Pct.",
|
|
496
|
+
*[*counts, *(counts / np.sum(counts) * 100)],
|
|
497
|
+
self.good_k,
|
|
492
498
|
)
|
|
493
499
|
base = "\n".join([base, extended])
|
|
494
500
|
return base
|
|
@@ -42,7 +42,6 @@ from ..helpers import ( # pylint: disable=unused-import
|
|
|
42
42
|
draws,
|
|
43
43
|
eight_schools_params,
|
|
44
44
|
models,
|
|
45
|
-
running_on_ci,
|
|
46
45
|
)
|
|
47
46
|
|
|
48
47
|
|
|
@@ -1077,6 +1076,20 @@ def test_dict_to_dataset():
|
|
|
1077
1076
|
assert set(dataset.b.coords) == {"chain", "draw", "c"}
|
|
1078
1077
|
|
|
1079
1078
|
|
|
1079
|
+
def test_nested_dict_to_dataset():
|
|
1080
|
+
datadict = {
|
|
1081
|
+
"top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
|
|
1082
|
+
"d": np.random.randn(100),
|
|
1083
|
+
}
|
|
1084
|
+
dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={("top", "b"): ["c"]})
|
|
1085
|
+
assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"}
|
|
1086
|
+
assert set(dataset.coords) == {"chain", "draw", "c"}
|
|
1087
|
+
|
|
1088
|
+
assert set(dataset[("top", "a")].coords) == {"chain", "draw"}
|
|
1089
|
+
assert set(dataset[("top", "b")].coords) == {"chain", "draw", "c"}
|
|
1090
|
+
assert set(dataset.d.coords) == {"chain", "draw"}
|
|
1091
|
+
|
|
1092
|
+
|
|
1080
1093
|
def test_dict_to_dataset_event_dims_error():
|
|
1081
1094
|
datadict = {"a": np.random.randn(1, 100, 10)}
|
|
1082
1095
|
coords = {"b": np.arange(10), "c": ["x", "y", "z"]}
|
|
@@ -1455,7 +1468,7 @@ class TestJSON:
|
|
|
1455
1468
|
|
|
1456
1469
|
|
|
1457
1470
|
@pytest.mark.skipif(
|
|
1458
|
-
not (importlib.util.find_spec("datatree") or
|
|
1471
|
+
not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
|
|
1459
1472
|
reason="test requires xarray-datatree library",
|
|
1460
1473
|
)
|
|
1461
1474
|
class TestDataTree:
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
"""Test Diagnostic methods"""
|
|
2
|
-
import importlib
|
|
3
2
|
|
|
4
3
|
# pylint: disable=redefined-outer-name, no-member, too-many-public-methods
|
|
5
4
|
import numpy as np
|
|
@@ -10,13 +9,10 @@ from ...rcparams import rcParams
|
|
|
10
9
|
from ...stats import bfmi, mcse, rhat
|
|
11
10
|
from ...stats.diagnostics import _mc_error, ks_summary
|
|
12
11
|
from ...utils import Numba
|
|
13
|
-
from ..helpers import
|
|
12
|
+
from ..helpers import importorskip
|
|
14
13
|
from .test_diagnostics import data # pylint: disable=unused-import
|
|
15
14
|
|
|
16
|
-
|
|
17
|
-
(importlib.util.find_spec("numba") is None) and not running_on_ci(),
|
|
18
|
-
reason="test requires numba which is not installed",
|
|
19
|
-
)
|
|
15
|
+
importorskip("numba")
|
|
20
16
|
|
|
21
17
|
rcParams["data.load"] = "eager"
|
|
22
18
|
|
|
@@ -6,13 +6,13 @@ from ..helpers import importorskip
|
|
|
6
6
|
|
|
7
7
|
def test_importorskip_local(monkeypatch):
|
|
8
8
|
"""Test ``importorskip`` run on local machine with non-existent module, which should skip."""
|
|
9
|
-
monkeypatch.delenv("
|
|
9
|
+
monkeypatch.delenv("ARVIZ_REQUIRE_ALL_DEPS", raising=False)
|
|
10
10
|
with pytest.raises(Skipped):
|
|
11
11
|
importorskip("non-existent-function")
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def test_importorskip_ci(monkeypatch):
|
|
15
15
|
"""Test ``importorskip`` run on CI machine with non-existent module, which should fail."""
|
|
16
|
-
monkeypatch.setenv("
|
|
16
|
+
monkeypatch.setenv("ARVIZ_REQUIRE_ALL_DEPS", 1)
|
|
17
17
|
with pytest.raises(ModuleNotFoundError):
|
|
18
18
|
importorskip("non-existent-function")
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# pylint: disable=redefined-outer-name
|
|
2
2
|
import importlib
|
|
3
|
+
import os
|
|
3
4
|
|
|
4
5
|
import numpy as np
|
|
5
6
|
import pytest
|
|
@@ -20,10 +21,10 @@ from ...rcparams import rc_context
|
|
|
20
21
|
from ...sel_utils import xarray_sel_iter, xarray_to_ndarray
|
|
21
22
|
from ...stats.density_utils import get_bins
|
|
22
23
|
from ...utils import get_coords
|
|
23
|
-
from ..helpers import running_on_ci
|
|
24
24
|
|
|
25
25
|
# Check if Bokeh is installed
|
|
26
26
|
bokeh_installed = importlib.util.find_spec("bokeh") is not None # pylint: disable=invalid-name
|
|
27
|
+
skip_tests = (not bokeh_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
@pytest.mark.parametrize(
|
|
@@ -212,10 +213,7 @@ def test_filter_plotter_list_warning():
|
|
|
212
213
|
assert len(plotters_filtered) == 5
|
|
213
214
|
|
|
214
215
|
|
|
215
|
-
@pytest.mark.skipif(
|
|
216
|
-
not (bokeh_installed or running_on_ci()),
|
|
217
|
-
reason="test requires bokeh which is not installed",
|
|
218
|
-
)
|
|
216
|
+
@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
|
|
219
217
|
def test_bokeh_import():
|
|
220
218
|
"""Tests that correct method is returned on bokeh import"""
|
|
221
219
|
plot = get_plotting_function("plot_dist", "distplot", "bokeh")
|
|
@@ -290,10 +288,7 @@ def test_mpl_dealiase_sel_kwargs():
|
|
|
290
288
|
assert res["line_color"] == "red"
|
|
291
289
|
|
|
292
290
|
|
|
293
|
-
@pytest.mark.skipif(
|
|
294
|
-
not (bokeh_installed or running_on_ci()),
|
|
295
|
-
reason="test requires bokeh which is not installed",
|
|
296
|
-
)
|
|
291
|
+
@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
|
|
297
292
|
def test_bokeh_dealiase_sel_kwargs():
|
|
298
293
|
"""Check bokeh dealiase_sel_kwargs behaviour.
|
|
299
294
|
|
|
@@ -315,10 +310,7 @@ def test_bokeh_dealiase_sel_kwargs():
|
|
|
315
310
|
assert res["line_color"] == "red"
|
|
316
311
|
|
|
317
312
|
|
|
318
|
-
@pytest.mark.skipif(
|
|
319
|
-
not (bokeh_installed or running_on_ci()),
|
|
320
|
-
reason="test requires bokeh which is not installed",
|
|
321
|
-
)
|
|
313
|
+
@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
|
|
322
314
|
def test_set_bokeh_circular_ticks_labels():
|
|
323
315
|
"""Assert the axes returned after placing ticks and tick labels for circular plots."""
|
|
324
316
|
import bokeh.plotting as bkp
|