arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/stats/stats.py
DELETED
@@ -1,2422 +0,0 @@
-# pylint: disable=too-many-lines
-"""Statistical functions in ArviZ."""
-
-import warnings
-from copy import deepcopy
-from typing import List, Optional, Tuple, Union, Mapping, cast, Callable
-
-import numpy as np
-import pandas as pd
-import scipy.stats as st
-from xarray_einstats import stats
-import xarray as xr
-from scipy.optimize import minimize, LinearConstraint, Bounds
-from typing_extensions import Literal
-
-NO_GET_ARGS: bool = False  # pylint: disable=invalid-name
-try:
-    from typing_extensions import get_args
-except ImportError:
-    NO_GET_ARGS = True  # pylint: disable=invalid-name
-
-from .. import _log
-from ..data import InferenceData, convert_to_dataset, convert_to_inference_data, extract
-from ..rcparams import rcParams, ScaleKeyword, ICKeyword
-from ..utils import Numba, _numba_var, _var_names, get_coords
-from .density_utils import get_bins as _get_bins
-from .density_utils import histogram as _histogram
-from .density_utils import kde as _kde
-from .density_utils import _kde_linear
-from .diagnostics import _mc_error, _multichain_statistics, ess
-from .stats_utils import ELPDData, _circular_standard_deviation, smooth_data
-from .stats_utils import get_log_likelihood as _get_log_likelihood
-from .stats_utils import get_log_prior as _get_log_prior
-from .stats_utils import logsumexp as _logsumexp
-from .stats_utils import make_ufunc as _make_ufunc
-from .stats_utils import stats_variance_2d as svar
-from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc
-from ..sel_utils import xarray_var_iter
-from ..labels import BaseLabeller
-
-
-__all__ = [
-    "apply_test_function",
-    "bayes_factor",
-    "compare",
-    "hdi",
-    "loo",
-    "loo_pit",
-    "psislw",
-    "r2_samples",
-    "r2_score",
-    "summary",
-    "waic",
-    "weight_predictions",
-    "_calculate_ics",
-    "psens",
-]
-
-
-def compare(
-    compare_dict: Mapping[str, InferenceData],
-    ic: Optional[ICKeyword] = None,
-    method: Literal["stacking", "BB-pseudo-BMA", "pseudo-BMA"] = "stacking",
-    b_samples: int = 1000,
-    alpha: float = 1,
-    seed=None,
-    scale: Optional[ScaleKeyword] = None,
-    var_name: Optional[str] = None,
-):
-    r"""Compare models based on their expected log pointwise predictive density (ELPD).
-
-    The ELPD is estimated either by Pareto smoothed importance sampling leave-one-out
-    cross-validation (LOO) or using the widely applicable information criterion (WAIC).
-    We recommend loo. Read more theory here - in a paper by some of the
-    leading authorities on model comparison dx.doi.org/10.1111/1467-9868.00353
-
-    Parameters
-    ----------
-    compare_dict: dict of {str: InferenceData or ELPDData}
-        A dictionary of model names and :class:`arviz.InferenceData` or ``ELPDData``.
-    ic: str, optional
-        Method to estimate the ELPD, available options are "loo" or "waic". Defaults to
-        ``rcParams["stats.information_criterion"]``.
-    method: str, optional
-        Method used to estimate the weights for each model. Available options are:
-
-        - 'stacking' : stacking of predictive distributions.
-        - 'BB-pseudo-BMA' : pseudo-Bayesian Model averaging using Akaike-type
-          weighting. The weights are stabilized using the Bayesian bootstrap.
-        - 'pseudo-BMA': pseudo-Bayesian Model averaging using Akaike-type
-          weighting, without Bootstrap stabilization (not recommended).
-
-        For more information read https://arxiv.org/abs/1704.02030
-    b_samples: int, optional default = 1000
-        Number of samples taken by the Bayesian bootstrap estimation.
-        Only useful when method = 'BB-pseudo-BMA'.
-        Defaults to ``rcParams["stats.ic_compare_method"]``.
-    alpha: float, optional
-        The shape parameter in the Dirichlet distribution used for the Bayesian bootstrap. Only
-        useful when method = 'BB-pseudo-BMA'. When alpha=1 (default), the distribution is uniform
-        on the simplex. A smaller alpha will keeps the final weights more away from 0 and 1.
-    seed: int or np.random.RandomState instance, optional
-        If int or RandomState, use it for seeding Bayesian bootstrap. Only
-        useful when method = 'BB-pseudo-BMA'. Default None the global
-        :mod:`numpy.random` state is used.
-    scale: str, optional
-        Output scale for IC. Available options are:
-
-        - `log` : (default) log-score (after Vehtari et al. (2017))
-        - `negative_log` : -1 * (log-score)
-        - `deviance` : -2 * (log-score)
-
-        A higher log-score (or a lower deviance) indicates a model with better predictive
-        accuracy.
-    var_name: str, optional
-        If there is more than a single observed variable in the ``InferenceData``, which
-        should be used as the basis for comparison.
-
-    Returns
-    -------
-    A DataFrame, ordered from best to worst model (measured by the ELPD).
-    The index reflects the key with which the models are passed to this function. The columns are:
-    rank: The rank-order of the models. 0 is the best.
-    elpd: ELPD estimated either using (PSIS-LOO-CV `elpd_loo` or WAIC `elpd_waic`).
-        Higher ELPD indicates higher out-of-sample predictive fit ("better" model).
-        If `scale` is `deviance` or `negative_log` smaller values indicates
-        higher out-of-sample predictive fit ("better" model).
-    pIC: Estimated effective number of parameters.
-    elpd_diff: The difference in ELPD between two models.
-        If more than two models are compared, the difference is computed relative to the
-        top-ranked model, that always has a elpd_diff of 0.
-    weight: Relative weight for each model.
-        This can be loosely interpreted as the probability of each model (among the compared model)
-        given the data. By default the uncertainty in the weights estimation is considered using
-        Bayesian bootstrap.
-    SE: Standard error of the ELPD estimate.
-        If method = BB-pseudo-BMA these values are estimated using Bayesian bootstrap.
-    dSE: Standard error of the difference in ELPD between each model and the top-ranked model.
-        It's always 0 for the top-ranked model.
-    warning: A value of 1 indicates that the computation of the ELPD may not be reliable.
-        This could be indication of WAIC/LOO starting to fail see
-        http://arxiv.org/abs/1507.04544 for details.
-    scale: Scale used for the ELPD.
-
-    Examples
-    --------
-    Compare the centered and non centered models of the eight school problem:
-
-    .. ipython::
-        :okwarning:
-
-        In [1]: import arviz as az
-           ...: data1 = az.load_arviz_data("non_centered_eight")
-           ...: data2 = az.load_arviz_data("centered_eight")
-           ...: compare_dict = {"non centered": data1, "centered": data2}
-           ...: az.compare(compare_dict)
-
-    Compare the models using PSIS-LOO-CV, returning the ELPD in log scale and calculating the
-    weights using the stacking method.
-
-    .. ipython::
-        :okwarning:
-
-        In [1]: az.compare(compare_dict, ic="loo", method="stacking", scale="log")
-
-    See Also
-    --------
-    loo :
-        Compute the ELPD using the Pareto smoothed importance sampling Leave-one-out
-        cross-validation method.
-    waic : Compute the ELPD using the widely applicable information criterion.
-    plot_compare : Summary plot for model comparison.
-
-    References
-    ----------
-    .. [1] Vehtari, A., Gelman, A. & Gabry, J. Practical Bayesian model evaluation using
-        leave-one-out cross-validation and WAIC. Stat Comput 27, 1413–1432 (2017)
-        see https://doi.org/10.1007/s11222-016-9696-4
-
-    """
-    try:
-        (ics_dict, scale, ic) = _calculate_ics(compare_dict, scale=scale, ic=ic, var_name=var_name)
-    except Exception as e:
-        raise e.__class__("Encountered error in ELPD computation of compare.") from e
-    names = list(ics_dict.keys())
-    if ic in {"loo", "waic"}:
-        df_comp = pd.DataFrame(
-            {
-                "rank": pd.Series(index=names, dtype="int"),
-                f"elpd_{ic}": pd.Series(index=names, dtype="float"),
-                f"p_{ic}": pd.Series(index=names, dtype="float"),
-                "elpd_diff": pd.Series(index=names, dtype="float"),
-                "weight": pd.Series(index=names, dtype="float"),
-                "se": pd.Series(index=names, dtype="float"),
-                "dse": pd.Series(index=names, dtype="float"),
-                "warning": pd.Series(index=names, dtype="boolean"),
-                "scale": pd.Series(index=names, dtype="str"),
-            }
-        )
-    else:
-        raise NotImplementedError(f"The information criterion {ic} is not supported.")
-
-    if scale == "log":
-        scale_value = 1
-        ascending = False
-    else:
-        if scale == "negative_log":
-            scale_value = -1
-        else:
-            scale_value = -2
-        ascending = True
-
-    method = rcParams["stats.ic_compare_method"] if method is None else method
-    if method.lower() not in ["stacking", "bb-pseudo-bma", "pseudo-bma"]:
-        raise ValueError(f"The method {method}, to compute weights, is not supported.")
-
-    p_ic = f"p_{ic}"
-    ic_i = f"{ic}_i"
-
-    ics = pd.DataFrame.from_dict(ics_dict, orient="index")
-    ics.sort_values(by=f"elpd_{ic}", inplace=True, ascending=ascending)
-    ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten())
-
-    if method.lower() == "stacking":
-        rows, cols, ic_i_val = _ic_matrix(ics, ic_i)
-        exp_ic_i = np.exp(ic_i_val / scale_value)
-
-        def log_score(weights):
-            return -np.sum(np.log(exp_ic_i @ weights))
-
-        def gradient(weights):
-            denominator = exp_ic_i @ weights
-            return -np.sum(exp_ic_i / denominator[:, np.newaxis], axis=0)
-
-        theta = np.full(cols, 1.0 / cols)
-        bounds = Bounds(lb=np.zeros(cols), ub=np.ones(cols))
-        constraints = LinearConstraint(np.ones(cols), lb=1.0, ub=1.0)
-
-        minimize_result = minimize(
-            fun=log_score, x0=theta, jac=gradient, bounds=bounds, constraints=constraints
-        )
-
-        weights = minimize_result["x"]
-        ses = ics["se"]
-
-    elif method.lower() == "bb-pseudo-bma":
-        rows, cols, ic_i_val = _ic_matrix(ics, ic_i)
-        ic_i_val = ic_i_val * rows
-
-        b_weighting = st.dirichlet.rvs(alpha=[alpha] * rows, size=b_samples, random_state=seed)
-        weights = np.zeros((b_samples, cols))
-        z_bs = np.zeros_like(weights)
-        for i in range(b_samples):
-            z_b = np.dot(b_weighting[i], ic_i_val)
-            u_weights = np.exp((z_b - np.max(z_b)) / scale_value)
-            z_bs[i] = z_b  # pylint: disable=unsupported-assignment-operation
-            weights[i] = u_weights / np.sum(u_weights)
-
-        weights = weights.mean(axis=0)
-        ses = pd.Series(z_bs.std(axis=0), index=ics.index)  # pylint: disable=no-member
-
-    elif method.lower() == "pseudo-bma":
-        min_ic = ics.iloc[0][f"elpd_{ic}"]
-        z_rv = np.exp((ics[f"elpd_{ic}"] - min_ic) / scale_value)
-        weights = (z_rv / np.sum(z_rv)).to_numpy()
-        ses = ics["se"]
-
-    if np.any(weights):
-        min_ic_i_val = ics[ic_i].iloc[0]
-        for idx, val in enumerate(ics.index):
-            res = ics.loc[val]
-            if scale_value < 0:
-                diff = res[ic_i] - min_ic_i_val
-            else:
-                diff = min_ic_i_val - res[ic_i]
-            d_ic = np.sum(diff)
-            d_std_err = np.sqrt(len(diff) * np.var(diff))
-            std_err = ses.loc[val]
-            weight = weights[idx]
-            df_comp.loc[val] = (
-                idx,
-                res[f"elpd_{ic}"],
-                res[p_ic],
-                d_ic,
-                weight,
-                std_err,
-                d_std_err,
-                res["warning"],
-                res["scale"],
-            )
-
-    df_comp["rank"] = df_comp["rank"].astype(int)
-    df_comp["warning"] = df_comp["warning"].astype(bool)
-    return df_comp.sort_values(by=f"elpd_{ic}", ascending=ascending)
-
-
-def _ic_matrix(ics, ic_i):
-    """Store the previously computed pointwise predictive accuracy values (ics) in a 2D matrix."""
-    cols, _ = ics.shape
-    rows = len(ics[ic_i].iloc[0])
-    ic_i_val = np.zeros((rows, cols))
-
-    for idx, val in enumerate(ics.index):
-        ic = ics.loc[val][ic_i]
-
-        if len(ic) != rows:
-            raise ValueError("The number of observations should be the same across all models")
-
-        ic_i_val[:, idx] = ic
-
-    return rows, cols, ic_i_val
-
-
-def _calculate_ics(
-    compare_dict,
-    scale: Optional[ScaleKeyword] = None,
-    ic: Optional[ICKeyword] = None,
-    var_name: Optional[str] = None,
-):
-    """Calculate LOO or WAIC only if necessary.
-
-    It always calls the ic function with ``pointwise=True``.
-
-    Parameters
-    ----------
-    compare_dict : dict of {str : InferenceData or ELPDData}
-        A dictionary of model names and InferenceData or ELPDData objects
-    scale : str, optional
-        Output scale for IC. Available options are:
-
-        - `log` : (default) log-score (after Vehtari et al. (2017))
-        - `negative_log` : -1 * (log-score)
-        - `deviance` : -2 * (log-score)
-
-        A higher log-score (or a lower deviance) indicates a model with better predictive accuracy.
-    ic : str, optional
-        Information Criterion (PSIS-LOO `loo` or WAIC `waic`) used to compare models.
-        Defaults to ``rcParams["stats.information_criterion"]``.
-    var_name : str, optional
-        Name of the variable storing pointwise log likelihood values in ``log_likelihood`` group.
-
-
-    Returns
-    -------
-    compare_dict : dict of ELPDData
-    scale : str
-    ic : str
-
-    """
-    precomputed_elpds = {
-        name: elpd_data
-        for name, elpd_data in compare_dict.items()
-        if isinstance(elpd_data, ELPDData)
-    }
-    precomputed_ic = None
-    precomputed_scale = None
-    if precomputed_elpds:
-        _, arbitrary_elpd = precomputed_elpds.popitem()
-        precomputed_ic = arbitrary_elpd.index[0].split("_")[1]
-        precomputed_scale = arbitrary_elpd["scale"]
-        raise_non_pointwise = f"{precomputed_ic}_i" not in arbitrary_elpd
-        if any(
-            elpd_data.index[0].split("_")[1] != precomputed_ic
-            for elpd_data in precomputed_elpds.values()
-        ):
-            raise ValueError(
-                "All information criteria to be compared must be the same "
-                "but found both loo and waic."
-            )
-        if any(elpd_data["scale"] != precomputed_scale for elpd_data in precomputed_elpds.values()):
-            raise ValueError("All information criteria to be compared must use the same scale")
-        if (
-            any(f"{precomputed_ic}_i" not in elpd_data for elpd_data in precomputed_elpds.values())
-            or raise_non_pointwise
-        ):
-            raise ValueError("Not all provided ELPDData have been calculated with pointwise=True")
-        if ic is not None and ic.lower() != precomputed_ic:
-            warnings.warn(
-                "Provided ic argument is incompatible with precomputed elpd data. "
-                f"Using ic from precomputed elpddata: {precomputed_ic}"
-            )
-            ic = precomputed_ic
-        if scale is not None and scale.lower() != precomputed_scale:
-            warnings.warn(
-                "Provided scale argument is incompatible with precomputed elpd data. "
-                f"Using scale from precomputed elpddata: {precomputed_scale}"
-            )
-            scale = precomputed_scale
-
-    if ic is None and precomputed_ic is None:
-        ic = cast(ICKeyword, rcParams["stats.information_criterion"])
-    elif ic is None:
-        ic = precomputed_ic
-    else:
-        ic = cast(ICKeyword, ic.lower())
-    allowable = ["loo", "waic"] if NO_GET_ARGS else get_args(ICKeyword)
-    if ic not in allowable:
-        raise ValueError(f"{ic} is not a valid value for ic: must be in {allowable}")
-
-    if scale is None and precomputed_scale is None:
-        scale = cast(ScaleKeyword, rcParams["stats.ic_scale"])
-    elif scale is None:
-        scale = precomputed_scale
-    else:
-        scale = cast(ScaleKeyword, scale.lower())
-    allowable = ["log", "negative_log", "deviance"] if NO_GET_ARGS else get_args(ScaleKeyword)
-    if scale not in allowable:
-        raise ValueError(f"{scale} is not a valid value for scale: must be in {allowable}")
-
-    if ic == "loo":
-        ic_func: Callable = loo
-    elif ic == "waic":
-        ic_func = waic
-    else:
-        raise NotImplementedError(f"The information criterion {ic} is not supported.")
-
-    compare_dict = deepcopy(compare_dict)
-    for name, dataset in compare_dict.items():
-        if not isinstance(dataset, ELPDData):
-            try:
-                compare_dict[name] = ic_func(
-                    convert_to_inference_data(dataset),
-                    pointwise=True,
-                    scale=scale,
-                    var_name=var_name,
-                )
-            except Exception as e:
-                raise e.__class__(
-                    f"Encountered error trying to compute {ic} from model {name}."
-                ) from e
-    return (compare_dict, scale, ic)
-
-
-def hdi(
-    ary,
-    hdi_prob=None,
-    circular=False,
-    multimodal=False,
-    skipna=False,
-    group="posterior",
-    var_names=None,
-    filter_vars=None,
-    coords=None,
-    max_modes=10,
-    dask_kwargs=None,
-    **kwargs,
-):
-    """
-    Calculate highest density interval (HDI) of array for given probability.
-
-    The HDI is the minimum width Bayesian credible interval (BCI).
-
-    Parameters
-    ----------
-    ary: obj
-        object containing posterior samples.
-        Any object that can be converted to an :class:`arviz.InferenceData` object.
-        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-    hdi_prob: float, optional
-        Prob for which the highest density interval will be computed. Defaults to
-        ``stats.ci_prob`` rcParam.
-    circular: bool, optional
-        Whether to compute the hdi taking into account `x` is a circular variable
-        (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).
-        Only works if multimodal is False.
-    multimodal: bool, optional
-        If true it may compute more than one hdi if the distribution is multimodal and the
-        modes are well separated.
-    skipna: bool, optional
-        If true ignores nan values when computing the hdi. Defaults to false.
-    group: str, optional
-        Specifies which InferenceData group should be used to calculate hdi.
-        Defaults to 'posterior'
-    var_names: list, optional
-        Names of variables to include in the hdi report. Prefix the variables by ``~``
-        when you want to exclude them from the report: `["~beta"]` instead of `["beta"]`
-        (see :func:`arviz.summary` for more details).
-    filter_vars: {None, "like", "regex"}, optional, default=None
-        If `None` (default), interpret var_names as the real variables names. If "like",
-        interpret var_names as substrings of the real variables names. If "regex",
-        interpret var_names as regular expressions on the real variables names. A la
-        ``pandas.filter``.
-    coords: mapping, optional
-        Specifies the subset over to calculate hdi.
-    max_modes: int, optional
-        Specifies the maximum number of modes for multimodal case.
-    dask_kwargs : dict, optional
-        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
-    kwargs: dict, optional
-        Additional keywords passed to :func:`~arviz.wrap_xarray_ufunc`.
-
-    Returns
-    -------
-    np.ndarray or xarray.Dataset, depending upon input
-        lower(s) and upper(s) values of the interval(s).
-
-    See Also
-    --------
-    plot_hdi : Plot highest density intervals for regression data.
-    xarray.Dataset.quantile : Calculate quantiles of array for given probabilities.
-
-    Examples
-    --------
-    Calculate the HDI of a Normal random variable:
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: import numpy as np
-           ...: data = np.random.normal(size=2000)
-           ...: az.hdi(data, hdi_prob=.68)
-
-    Calculate the HDI of a dataset:
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data('centered_eight')
-           ...: az.hdi(data)
-
-    We can also calculate the HDI of some of the variables of dataset:
-
-    .. ipython::
-
-        In [1]: az.hdi(data, var_names=["mu", "theta"])
-
-    By default, ``hdi`` is calculated over the ``chain`` and ``draw`` dimensions. We can use the
-    ``input_core_dims`` argument of :func:`~arviz.wrap_xarray_ufunc` to change this. In this example
-    we calculate the HDI also over the ``school`` dimension:
-
-    .. ipython::
-
-        In [1]: az.hdi(data, var_names="theta", input_core_dims = [["chain","draw", "school"]])
-
-    We can also calculate the hdi over a particular selection:
-
-    .. ipython::
-
-        In [1]: az.hdi(data, coords={"chain":[0, 1, 3]}, input_core_dims = [["draw"]])
-
-    """
-    if hdi_prob is None:
-        hdi_prob = rcParams["stats.ci_prob"]
-    elif not 1 >= hdi_prob > 0:
-        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
-
-    func_kwargs = {
-        "hdi_prob": hdi_prob,
-        "skipna": skipna,
-        "out_shape": (max_modes, 2) if multimodal else (2,),
-    }
-    kwargs.setdefault("output_core_dims", [["mode", "hdi"] if multimodal else ["hdi"]])
-    if not multimodal:
-        func_kwargs["circular"] = circular
-    else:
-        func_kwargs["max_modes"] = max_modes
-
-    func = _hdi_multimodal if multimodal else _hdi
-
-    isarray = isinstance(ary, np.ndarray)
-    if isarray and ary.ndim <= 1:
-        func_kwargs.pop("out_shape")
-        hdi_data = func(ary, **func_kwargs)  # pylint: disable=unexpected-keyword-arg
-        return hdi_data[~np.isnan(hdi_data).all(axis=1), :] if multimodal else hdi_data
-
-    if isarray and ary.ndim == 2:
-        warnings.warn(
-            "hdi currently interprets 2d data as (draw, shape) but this will change in "
-            "a future release to (chain, draw) for coherence with other functions",
-            FutureWarning,
-            stacklevel=2,
-        )
-        ary = np.expand_dims(ary, 0)
-
-    ary = convert_to_dataset(ary, group=group)
-    if coords is not None:
-        ary = get_coords(ary, coords)
-    var_names = _var_names(var_names, ary, filter_vars)
-    ary = ary[var_names] if var_names else ary
-
-    hdi_coord = xr.DataArray(["lower", "higher"], dims=["hdi"], attrs=dict(hdi_prob=hdi_prob))
-    hdi_data = _wrap_xarray_ufunc(
-        func, ary, func_kwargs=func_kwargs, dask_kwargs=dask_kwargs, **kwargs
-    ).assign_coords({"hdi": hdi_coord})
-    hdi_data = hdi_data.dropna("mode", how="all") if multimodal else hdi_data
-    return hdi_data.x.values if isarray else hdi_data
-
-
-def _hdi(ary, hdi_prob, circular, skipna):
-    """Compute hpi over the flattened array."""
-    ary = ary.flatten()
-    if skipna:
-        nans = np.isnan(ary)
-        if not nans.all():
-            ary = ary[~nans]
-    n = len(ary)
-
-    if circular:
-        mean = st.circmean(ary, high=np.pi, low=-np.pi)
-        ary = ary - mean
-        ary = np.arctan2(np.sin(ary), np.cos(ary))
-
-    ary = np.sort(ary)
-    interval_idx_inc = int(np.floor(hdi_prob * n))
-    n_intervals = n - interval_idx_inc
-    interval_width = np.subtract(ary[interval_idx_inc:], ary[:n_intervals], dtype=np.float64)
-
-    if len(interval_width) == 0:
-        raise ValueError("Too few elements for interval calculation. ")
-
-    min_idx = np.argmin(interval_width)
-    hdi_min = ary[min_idx]
-    hdi_max = ary[min_idx + interval_idx_inc]
-
-    if circular:
-        hdi_min = hdi_min + mean
-        hdi_max = hdi_max + mean
-        hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
-        hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))
-
-    hdi_interval = np.array([hdi_min, hdi_max])
-
-    return hdi_interval
-
-
-def _hdi_multimodal(ary, hdi_prob, skipna, max_modes):
-    """Compute HDI if the distribution is multimodal."""
-    ary = ary.flatten()
-    if skipna:
-        ary = ary[~np.isnan(ary)]
-
-    if ary.dtype.kind == "f":
-        bins, density = _kde(ary)
-        lower, upper = bins[0], bins[-1]
-        range_x = upper - lower
-        dx = range_x / len(density)
-    else:
-        bins = _get_bins(ary)
-        _, density, _ = _histogram(ary, bins=bins)
-        dx = np.diff(bins)[0]
-
-    density *= dx
-
-    idx = np.argsort(-density)
-    intervals = bins[idx][density[idx].cumsum() <= hdi_prob]
-    intervals.sort()
-
-    intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)
-
-    hdi_intervals = np.full((max_modes, 2), np.nan)
-    for i, interval in enumerate(intervals_splitted):
-        if i == max_modes:
-            warnings.warn(
-                f"found more modes than {max_modes}, returning only the first {max_modes} modes"
-            )
-            break
-        if interval.size == 0:
-            hdi_intervals[i] = np.asarray([bins[0], bins[0]])
-        else:
-            hdi_intervals[i] = np.asarray([interval[0], interval[-1]])
-
-    return np.array(hdi_intervals)
-
-
-def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
-    """Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
-
-    Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
-    importance sampling leave-one-out cross-validation (PSIS-LOO-CV). Also calculates LOO's
-    standard error and the effective number of parameters. Read more theory here
-    https://arxiv.org/abs/1507.04544 and here https://arxiv.org/abs/1507.02646
-
-    Parameters
-    ----------
-    data: obj
-        Any object that can be converted to an :class:`arviz.InferenceData` object.
-        Refer to documentation of
-        :func:`arviz.convert_to_dataset` for details.
-    pointwise: bool, optional
-        If True the pointwise predictive accuracy will be returned. Defaults to
-        ``stats.ic_pointwise`` rcParam.
-    var_name : str, optional
-        The name of the variable in log_likelihood groups storing the pointwise log
-        likelihood data to use for loo computation.
-    reff: float, optional
-        Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number
-        of actual samples. Computed from trace by default.
-    scale: str
-        Output scale for loo. Available options are:
-
-        - ``log`` : (default) log-score
-        - ``negative_log`` : -1 * log-score
-        - ``deviance`` : -2 * log-score
-
-        A higher log-score (or a lower deviance or negative log_score) indicates a model with
-        better predictive accuracy.
-
-    Returns
-    -------
-    ELPDData object (inherits from :class:`pandas.Series`) with the following row/attributes:
-    elpd_loo: approximated expected log pointwise predictive density (elpd)
-    se: standard error of the elpd
-    p_loo: effective number of parameters
-    n_samples: number of samples
-    n_data_points: number of data points
-    warning: bool
-        True if the estimated shape parameter of Pareto distribution is greater than
-        ``good_k``.
-    loo_i: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
-        only if pointwise=True
-    pareto_k: array of Pareto shape values, only if pointwise True
-    scale: scale of the elpd
-    good_k: For a sample size S, the thresold is compute as min(1 - 1/log10(S), 0.7)
-
-    The returned object has a custom print method that overrides pd.Series method.
-
-    See Also
-    --------
-    compare : Compare models based on PSIS-LOO loo or WAIC waic cross-validation.
-    waic : Compute the widely applicable information criterion.
-    plot_compare : Summary plot for model comparison.
-    plot_elpd : Plot pointwise elpd differences between two or more models.
-    plot_khat : Plot Pareto tail indices for diagnosing convergence.
-
-    Examples
-    --------
-    Calculate LOO of a model:
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data("centered_eight")
-           ...: az.loo(data)
-
-    Calculate LOO of a model and return the pointwise values:
-
-    .. ipython::
-
-        In [2]: data_loo = az.loo(data, pointwise=True)
-           ...: data_loo.loo_i
-    """
-    inference_data = convert_to_inference_data(data)
-    log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
-    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
-
-    log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
-    shape = log_likelihood.shape
-    n_samples = shape[-1]
-    n_data_points = np.prod(shape[:-1])
-    scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
-
-    if scale == "deviance":
-        scale_value = -2
-    elif scale == "log":
-        scale_value = 1
-    elif scale == "negative_log":
-        scale_value = -1
-    else:
-        raise TypeError('Valid scale values are "deviance", "log", "negative_log"')
-
-    if reff is None:
-        if not hasattr(inference_data, "posterior"):
-            raise TypeError("Must be able to extract a posterior group from data.")
-        posterior = inference_data.posterior
-        n_chains = len(posterior.chain)
-        if n_chains == 1:
-            reff = 1.0
-        else:
-            ess_p = ess(posterior, method="mean")
-            # this mean is over all data variables
-            reff = (
-                np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
-            )
-
-    log_weights, pareto_shape = psislw(-log_likelihood, reff)
-    log_weights += log_likelihood
-
-    warn_mg = False
-    good_k = min(1 - 1 / np.log10(n_samples), 0.7)
-
-    if np.any(pareto_shape > good_k):
-        warnings.warn(
-            f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} "
-            "for one or more samples. You should consider using a more robust model, this is "
-            "because importance sampling is less likely to work well if the marginal posterior "
-            "and LOO posterior are very different. This is more likely to happen with a "
-            "non-robust model and highly influential observations."
-        )
-        warn_mg = True
-
-    ufunc_kwargs = {"n_dims": 1, "ravel": False}
-    kwargs = {"input_core_dims": [["__sample__"]]}
-    loo_lppd_i = scale_value * _wrap_xarray_ufunc(
-        _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, **kwargs
-    )
-    loo_lppd = loo_lppd_i.values.sum()
-    loo_lppd_se = (n_data_points * np.var(loo_lppd_i.values)) ** 0.5
-
-    lppd = np.sum(
-        _wrap_xarray_ufunc(
-            _logsumexp,
-            log_likelihood,
-            func_kwargs={"b_inv": n_samples},
-            ufunc_kwargs=ufunc_kwargs,
-            **kwargs,
-        ).values
-    )
-    p_loo = lppd - loo_lppd / scale_value
-
-    if not pointwise:
-        return ELPDData(
-            data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k],
-            index=[
-                "elpd_loo",
-                "se",
-                "p_loo",
-                "n_samples",
-                "n_data_points",
-                "warning",
-                "scale",
-                "good_k",
-            ],
-        )
-    if np.equal(loo_lppd, loo_lppd_i).all():  # pylint: disable=no-member
-        warnings.warn(
-            "The point-wise LOO is the same with the sum LOO, please double check "
-            "the Observed RV in your model to make sure it returns element-wise logp."
-        )
-    return ELPDData(
-        data=[
-            loo_lppd,
-            loo_lppd_se,
-            p_loo,
-            n_samples,
-            n_data_points,
-            warn_mg,
-            loo_lppd_i.rename("loo_i"),
-            pareto_shape,
-            scale,
-            good_k,
-        ],
-        index=[
-            "elpd_loo",
-            "se",
-            "p_loo",
-            "n_samples",
-            "n_data_points",
-            "warning",
-            "loo_i",
-            "pareto_k",
-            "scale",
-            "good_k",
-        ],
-    )
-
-
-def psislw(log_weights, reff=1.0, normalize=True):
-    """
-    Pareto smoothed importance sampling (PSIS).
-
-    Notes
-    -----
-    If the ``log_weights`` input is an :class:`~xarray.DataArray` with a dimension
-    named ``__sample__`` (recommended) ``psislw`` will interpret this dimension as samples,
-    and all other dimensions as dimensions of the observed data, looping over them to
-    calculate the psislw of each observation. If no ``__sample__`` dimension is present or
-    the input is a numpy array, the last dimension will be interpreted as ``__sample__``.
-
-    Parameters
-    ----------
-    log_weights : DataArray or (..., N) array-like
-        Array of size (n_observations, n_samples)
-    reff : float, default 1
-        relative MCMC efficiency, ``ess / n``
-    normalize : bool, default True
-        return normalized log weights
-
-    Returns
-    -------
-    lw_out : DataArray or (..., N) ndarray
-        Smoothed, truncated and possibly normalized log weights.
-    kss : DataArray or (...) ndarray
-        Estimates of the shape parameter *k* of the generalized Pareto
-        distribution.
-
-    References
-    ----------
-    * Vehtari et al. (2024). Pareto smoothed importance sampling. Journal of Machine
-      Learning Research, 25(72):1-58.
-
-    See Also
-    --------
-    loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
-
-    Examples
-    --------
-    Get Pareto smoothed importance sampling (PSIS) log weights:
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data("non_centered_eight")
-           ...: log_likelihood = data.log_likelihood["obs"].stack(
-           ...:     __sample__=["chain", "draw"]
-           ...: )
-           ...: az.psislw(-log_likelihood, reff=0.8)
-
-    """
-    log_weights = deepcopy(log_weights)
-    if hasattr(log_weights, "__sample__"):
-        n_samples = len(log_weights.__sample__)
-        shape = [
-            size for size, dim in zip(log_weights.shape, log_weights.dims) if dim != "__sample__"
-        ]
-    else:
-        n_samples = log_weights.shape[-1]
-        shape = log_weights.shape[:-1]
-    # precalculate constants
-    cutoff_ind = -int(np.ceil(min(n_samples / 5.0, 3 * (n_samples / reff) ** 0.5))) - 1
-    cutoffmin = np.log(np.finfo(float).tiny)  # pylint: disable=no-member, assignment-from-no-return
-
-    # create output array with proper dimensions
-    out = np.empty_like(log_weights), np.empty(shape)
-
-    # define kwargs
-    func_kwargs = {
-        "cutoff_ind": cutoff_ind,
-        "cutoffmin": cutoffmin,
-        "out": out,
-        "normalize": normalize,
-    }
-    ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
-    kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
-    log_weights, pareto_shape = _wrap_xarray_ufunc(
-        _psislw,
-        log_weights,
-        ufunc_kwargs=ufunc_kwargs,
-        func_kwargs=func_kwargs,
-        **kwargs,
-    )
-    if isinstance(log_weights, xr.DataArray):
-        log_weights = log_weights.rename("log_weights")
-    if isinstance(pareto_shape, xr.DataArray):
-        pareto_shape = pareto_shape.rename("pareto_shape")
-    return log_weights, pareto_shape
-
-
-def _psislw(log_weights, cutoff_ind, cutoffmin, normalize):
-    """
-    Pareto smoothed importance sampling (PSIS) for a 1D vector.
-
-    Parameters
-    ----------
-    log_weights: array
-        Array of length n_observations
-    cutoff_ind: int
-    cutoffmin: float
-    normalize: bool
-
-    Returns
-    -------
-    lw_out: array
-        Smoothed log weights
-    kss: float
-        Pareto tail index
-    """
-    x = np.asarray(log_weights)
-
-    # improve numerical accuracy
-    max_x = np.max(x)
-    x -= max_x
-    # sort the array
-    x_sort_ind = np.argsort(x)
-    # divide log weights into body and right tail
-    xcutoff = max(x[x_sort_ind[cutoff_ind]], cutoffmin)
-
-    expxcutoff = np.exp(xcutoff)
-    (tailinds,) = np.where(x > xcutoff)  # pylint: disable=unbalanced-tuple-unpacking
-    x_tail = x[tailinds]
-    tail_len = len(x_tail)
-    if tail_len <= 4:
-        # not enough tail samples for gpdfit
-        k = np.inf
-    else:
-        # order of tail samples
-        x_tail_si = np.argsort(x_tail)
-        # fit generalized Pareto distribution to the right tail samples
-        x_tail = np.exp(x_tail) - expxcutoff
-        k, sigma = _gpdfit(x_tail[x_tail_si])
-
-        if np.isfinite(k):
-            # no smoothing if GPD fit failed
-            # compute ordered statistic for the fit
-            sti = np.arange(0.5, tail_len) / tail_len
-            smoothed_tail = _gpinv(sti, k, sigma)
-            smoothed_tail = np.log(  # pylint: disable=assignment-from-no-return
-                smoothed_tail + expxcutoff
-            )
-            # place the smoothed tail into the output array
-            x[tailinds[x_tail_si]] = smoothed_tail
-            # truncate smoothed values to the largest raw weight 0
-            x[x > 0] = 0
-
-    # renormalize weights
-    if normalize:
-        x -= _logsumexp(x)
-    else:
-        x += max_x
-
-    return x, k
-
-
-def _gpdfit(ary):
-    """Estimate the parameters for the Generalized Pareto Distribution (GPD).
-
-    Empirical Bayes estimate for the parameters of the generalized Pareto
-    distribution given the data.
-
-    Parameters
-    ----------
-    ary: array
-        sorted 1D data array
-
-    Returns
-    -------
-    k: float
-        estimated shape parameter
-    sigma: float
-        estimated scale parameter
-    """
-    prior_bs = 3
-    prior_k = 10
-    n = len(ary)
-    m_est = 30 + int(n**0.5)
-
-    b_ary = 1 - np.sqrt(m_est / (np.arange(1, m_est + 1, dtype=float) - 0.5))
-    b_ary /= prior_bs * ary[int(n / 4 + 0.5) - 1]
-    b_ary += 1 / ary[-1]
-
-    k_ary = np.log1p(-b_ary[:, None] * ary).mean(axis=1)  # pylint: disable=no-member
-    len_scale = n * (np.log(-(b_ary / k_ary)) - k_ary - 1)
-    weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
-
-    # remove negligible weights
-    real_idxs = weights >= 10 * np.finfo(float).eps
-    if not np.all(real_idxs):
-        weights = weights[real_idxs]
-        b_ary = b_ary[real_idxs]
-    # normalise weights
-    weights /= weights.sum()
-
-    # posterior mean for b
-    b_post = np.sum(b_ary * weights)
-    # estimate for k
-    k_post = np.log1p(-b_post * ary).mean()  # pylint: disable=invalid-unary-operand-type,no-member
-    # add prior for k_post
-    sigma = -k_post / b_post
-    k_post = (n * k_post + prior_k * 0.5) / (n + prior_k)
-
-    return k_post, sigma
-
-
-def _gpinv(probs, kappa, sigma):
-    """Inverse Generalized Pareto distribution function."""
-    # pylint: disable=unsupported-assignment-operation, invalid-unary-operand-type
-    x = np.full_like(probs, np.nan)
-    if sigma <= 0:
-        return x
-    ok = (probs > 0) & (probs < 1)
-    if np.all(ok):
-        if np.abs(kappa) < np.finfo(float).eps:
-            x = -np.log1p(-probs)
-        else:
-            x = np.expm1(-kappa * np.log1p(-probs)) / kappa
-        x *= sigma
-    else:
-        if np.abs(kappa) < np.finfo(float).eps:
-            x[ok] = -np.log1p(-probs[ok])
-        else:
-            x[ok] = np.expm1(-kappa * np.log1p(-probs[ok])) / kappa
-        x *= sigma
-        x[probs == 0] = 0
-        x[probs == 1] = np.inf if kappa >= 0 else -sigma / kappa
-    return x
-
-
-def r2_samples(y_true, y_pred):
-    """R² samples for Bayesian regression models. Only valid for linear models.
-
-    Parameters
-    ----------
-    y_true: array-like of shape = (n_outputs,)
-        Ground truth (correct) target values.
-    y_pred: array-like of shape = (n_posterior_samples, n_outputs)
-        Estimated target values.
-
-    Returns
-    -------
-    Pandas Series with the following indices:
-    Bayesian R² samples.
-
-    See Also
-    --------
-    plot_lm : Posterior predictive and mean plots for regression-like data.
-
-    Examples
-    --------
-    Calculate R² samples for Bayesian regression models :
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data('regression1d')
-           ...: y_true = data.observed_data["y"].values
-           ...: y_pred = data.posterior_predictive.stack(sample=("chain", "draw"))["y"].values.T
-           ...: az.r2_samples(y_true, y_pred)
-
-    """
-    _numba_flag = Numba.numba_flag
-    if y_pred.ndim == 1:
-        var_y_est = _numba_var(svar, np.var, y_pred)
-        var_e = _numba_var(svar, np.var, (y_true - y_pred))
-    else:
-        var_y_est = _numba_var(svar, np.var, y_pred, axis=1)
-        var_e = _numba_var(svar, np.var, (y_true - y_pred), axis=1)
-    r_squared = var_y_est / (var_y_est + var_e)
-
-    return r_squared
-
-
def r2_score(y_true, y_pred):
|
|
1132
|
-
"""R² for Bayesian regression models. Only valid for linear models.
|
|
1133
|
-
|
|
1134
|
-
Parameters
|
|
1135
|
-
----------
|
|
1136
|
-
y_true: array-like of shape = (n_outputs,)
|
|
1137
|
-
Ground truth (correct) target values.
|
|
1138
|
-
y_pred: array-like of shape = (n_posterior_samples, n_outputs)
|
|
1139
|
-
Estimated target values.
|
|
1140
|
-
|
|
1141
|
-
Returns
|
|
1142
|
-
-------
|
|
1143
|
-
Pandas Series with the following indices:
|
|
1144
|
-
r2: Bayesian R²
|
|
1145
|
-
r2_std: standard deviation of the Bayesian R².
|
|
1146
|
-
|
|
1147
|
-
See Also
|
|
1148
|
-
--------
|
|
1149
|
-
plot_lm : Posterior predictive and mean plots for regression-like data.
|
|
1150
|
-
|
|
1151
|
-
Examples
|
|
1152
|
-
--------
|
|
1153
|
-
Calculate R² for Bayesian regression models :
|
|
1154
|
-
|
|
1155
|
-
.. ipython::
|
|
1156
|
-
|
|
1157
|
-
In [1]: import arviz as az
|
|
1158
|
-
...: data = az.load_arviz_data('regression1d')
|
|
1159
|
-
...: y_true = data.observed_data["y"].values
|
|
1160
|
-
...: y_pred = data.posterior_predictive.stack(sample=("chain", "draw"))["y"].values.T
|
|
1161
|
-
...: az.r2_score(y_true, y_pred)
|
|
1162
|
-
|
|
1163
|
-
"""
|
|
1164
|
-
r_squared = r2_samples(y_true=y_true, y_pred=y_pred)
|
|
1165
|
-
return pd.Series([np.mean(r_squared), np.std(r_squared)], index=["r2", "r2_std"])
|
|
1166
|
-
|
|
1167
|
-
|
|
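A standalone version of the R² example from the docstrings above (the `regression1d` dataset ships with ArviZ):

```python
import arviz as az

data = az.load_arviz_data("regression1d")
y_true = data.observed_data["y"].values
y_pred = data.posterior_predictive.stack(sample=("chain", "draw"))["y"].values.T

print(az.r2_samples(y_true, y_pred)[:5])  # one Bayesian R² value per posterior sample
print(az.r2_score(y_true, y_pred))        # pandas Series with r2 and r2_std
```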
-def summary(
-    data,
-    var_names: Optional[List[str]] = None,
-    filter_vars=None,
-    group=None,
-    fmt: "Literal['wide', 'long', 'xarray']" = "wide",
-    kind: "Literal['all', 'stats', 'diagnostics']" = "all",
-    round_to=None,
-    circ_var_names=None,
-    stat_focus="mean",
-    stat_funcs=None,
-    extend=True,
-    hdi_prob=None,
-    skipna=False,
-    labeller=None,
-    coords=None,
-    index_origin=None,
-    order=None,
-) -> Union[pd.DataFrame, xr.Dataset]:
-    """Create a data frame with summary statistics.
-
-    Parameters
-    ----------
-    data: obj
-        Any object that can be converted to an :class:`arviz.InferenceData` object
-        Refer to documentation of :func:`arviz.convert_to_dataset` for details
-    var_names: list
-        Names of variables to include in summary. Prefix the variables by ``~`` when you
-        want to exclude them from the summary: `["~beta"]` instead of `["beta"]` (see
-        examples below).
-    filter_vars: {None, "like", "regex"}, optional, default=None
-        If `None` (default), interpret var_names as the real variables names. If "like",
-        interpret var_names as substrings of the real variables names. If "regex",
-        interpret var_names as regular expressions on the real variables names. A la
-        ``pandas.filter``.
-    coords: Dict[str, List[Any]], optional
-        Coordinate subset for which to calculate the summary.
-    group: str
-        Select a group for summary. Defaults to "posterior", "prior" or first group
-        in that order, depending on what groups exist.
-    fmt: {'wide', 'long', 'xarray'}
-        Return format is either pandas.DataFrame {'wide', 'long'} or xarray.Dataset {'xarray'}.
-    kind: {'all', 'stats', 'diagnostics'}
-        Whether to include the `stats`: `mean`, `sd`, `hdi_3%`, `hdi_97%`, or the `diagnostics`:
-        `mcse_mean`, `mcse_sd`, `ess_bulk`, `ess_tail`, and `r_hat`. Default to include `all` of
-        them.
-    round_to: int
-        Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
-    circ_var_names: list
-        A list of circular variables to compute circular stats for
-    stat_focus : str, default "mean"
-        Select the focus for summary.
-    stat_funcs: dict
-        A list of functions or a dict of functions with function names as keys used to calculate
-        statistics. By default, the mean, standard deviation, simulation standard error, and
-        highest posterior density intervals are included.
-
-        The functions will be given one argument, the samples for a variable as an nD array,
-        The functions should be in the style of a ufunc and return a single number. For example,
-        :func:`numpy.mean`, or ``scipy.stats.var`` would both work.
-    extend: boolean
-        If True, use the statistics returned by ``stat_funcs`` in addition to, rather than in place
-        of, the default statistics. This is only meaningful when ``stat_funcs`` is not None.
-    hdi_prob: float, optional
-        Highest density interval to compute. Defaults to 0.94. This is only meaningful when
-        ``stat_funcs`` is None.
-    skipna: bool
-        If true ignores nan values when computing the summary statistics, it does not affect the
-        behaviour of the functions passed to ``stat_funcs``. Defaults to false.
-    labeller : labeller instance, optional
-        Class providing the method `make_label_flat` to generate the labels in the plot titles.
-        For more details on ``labeller`` usage see :ref:`label_guide`
-    credible_interval: float, optional
-        deprecated: Please see hdi_prob
-    order
-        deprecated: order is now ignored.
-    index_origin
-        deprecated: index_origin is now ignored, modify the coordinate values to change the
-        value used in summary.
-
-    Returns
-    -------
-    pandas.DataFrame or xarray.Dataset
-        Return type dictated by `fmt` argument.
-
-        Return value will contain summary statistics for each variable. Default statistics depend on
-        the value of ``stat_focus``:
-
-        ``stat_focus="mean"``: `mean`, `sd`, `hdi_3%`, `hdi_97%`, `mcse_mean`, `mcse_sd`,
-        `ess_bulk`, `ess_tail`, and `r_hat`
-
-        ``stat_focus="median"``: `median`, `mad`, `eti_3%`, `eti_97%`, `mcse_median`, `ess_median`,
-        `ess_tail`, and `r_hat`
-
-        `r_hat` is only computed for traces with 2 or more chains.
-
-    See Also
-    --------
-    waic : Compute the widely applicable information criterion.
-    loo : Compute Pareto-smoothed importance sampling leave-one-out
-        cross-validation (PSIS-LOO-CV).
-    ess : Calculate estimate of the effective sample size (ess).
-    rhat : Compute estimate of rank normalized split R-hat for a set of traces.
-    mcse : Calculate Markov Chain Standard Error statistic.
-
-    Examples
-    --------
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data("centered_eight")
-           ...: az.summary(data, var_names=["mu", "tau"])
-
-    You can use ``filter_vars`` to select variables without having to specify all the exact
-    names. Use ``filter_vars="like"`` to select based on partial naming:
-
-    .. ipython::
-
-        In [1]: az.summary(data, var_names=["the"], filter_vars="like")
-
-    Use ``filter_vars="regex"`` to select based on regular expressions, and prefix the variables
-    you want to exclude by ``~``. Here, we exclude from the summary all the variables
-    starting with the letter t:
-
-    .. ipython::
-
-        In [1]: az.summary(data, var_names=["~^t"], filter_vars="regex")
-
-    Other statistics can be calculated by passing a list of functions
-    or a dictionary with key, function pairs.
-
-    .. ipython::
-
-        In [1]: import numpy as np
-           ...: def median_sd(x):
-           ...:     median = np.percentile(x, 50)
-           ...:     sd = np.sqrt(np.mean((x-median)**2))
-           ...:     return sd
-           ...:
-           ...: func_dict = {
-           ...:     "std": np.std,
-           ...:     "median_std": median_sd,
-           ...:     "5%": lambda x: np.percentile(x, 5),
-           ...:     "median": lambda x: np.percentile(x, 50),
-           ...:     "95%": lambda x: np.percentile(x, 95),
-           ...: }
-           ...: az.summary(
-           ...:     data,
-           ...:     var_names=["mu", "tau"],
-           ...:     stat_funcs=func_dict,
-           ...:     extend=False
-           ...: )
-
-    Use ``stat_focus`` to change the focus of summary statistics obtained to median:
-
-    .. ipython::
-
-        In [1]: az.summary(data, stat_focus="median")
-
-    """
-    _log.cache = []
-
-    if coords is None:
-        coords = {}
-
-    if index_origin is not None:
-        warnings.warn(
-            "index_origin has been deprecated. summary now shows coordinate values, "
-            "to change the label shown, modify the coordinate values before calling summary",
-            DeprecationWarning,
-        )
-        index_origin = rcParams["data.index_origin"]
-    if labeller is None:
-        labeller = BaseLabeller()
-    if hdi_prob is None:
-        hdi_prob = rcParams["stats.ci_prob"]
-    elif not 1 >= hdi_prob > 0:
-        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
-
-    if isinstance(data, InferenceData):
-        if group is None:
-            if not data.groups():
-                raise TypeError("InferenceData does not contain any groups")
-            if "posterior" in data:
-                dataset = data["posterior"]
-            elif "prior" in data:
-                dataset = data["prior"]
-            else:
-                warnings.warn(f"Selecting first found group: {data.groups()[0]}")
-                dataset = data[data.groups()[0]]
-        elif group in data.groups():
-            dataset = data[group]
-        else:
-            raise TypeError(f"InferenceData does not contain group: {group}")
-    else:
-        dataset = convert_to_dataset(data, group="posterior")
-    var_names = _var_names(var_names, dataset, filter_vars)
-    dataset = dataset if var_names is None else dataset[var_names]
-    dataset = get_coords(dataset, coords)
-
-    fmt_group = ("wide", "long", "xarray")
-    if not isinstance(fmt, str) or (fmt.lower() not in fmt_group):
-        raise TypeError(f"Invalid format: '{fmt}'. Formatting options are: {fmt_group}")
-
-    kind_group = ("all", "stats", "diagnostics")
-    if not isinstance(kind, str) or kind not in kind_group:
-        raise TypeError(f"Invalid kind: '{kind}'. Kind options are: {kind_group}")
-
-    focus_group = ("mean", "median")
-    if not isinstance(stat_focus, str) or (stat_focus not in focus_group):
-        raise TypeError(f"Invalid format: '{stat_focus}'. Focus options are: {focus_group}")
-
-    if stat_focus != "mean" and circ_var_names is not None:
-        raise TypeError(f"Invalid format: Circular stats not supported for '{stat_focus}'")
-
-    if order is not None:
-        warnings.warn(
-            "order has been deprecated. summary now shows coordinate values.", DeprecationWarning
-        )
-
-    alpha = 1 - hdi_prob
-
-    extra_metrics = []
-    extra_metric_names = []
-
-    if stat_funcs is not None:
-        if isinstance(stat_funcs, dict):
-            for stat_func_name, stat_func in stat_funcs.items():
-                extra_metrics.append(
-                    xr.apply_ufunc(
-                        _make_ufunc(stat_func), dataset, input_core_dims=(("chain", "draw"),)
-                    )
-                )
-                extra_metric_names.append(stat_func_name)
-        else:
-            for stat_func in stat_funcs:
-                extra_metrics.append(
-                    xr.apply_ufunc(
-                        _make_ufunc(stat_func), dataset, input_core_dims=(("chain", "draw"),)
-                    )
-                )
-                extra_metric_names.append(stat_func.__name__)
-
-    metrics: List[xr.Dataset] = []
-    metric_names: List[str] = []
-    if extend and kind in ["all", "stats"]:
-        if stat_focus == "mean":
-            mean = dataset.mean(dim=("chain", "draw"), skipna=skipna)
-
-            sd = dataset.std(dim=("chain", "draw"), ddof=1, skipna=skipna)
-
-            hdi_post = hdi(dataset, hdi_prob=hdi_prob, multimodal=False, skipna=skipna)
-            hdi_lower = hdi_post.sel(hdi="lower", drop=True)
-            hdi_higher = hdi_post.sel(hdi="higher", drop=True)
-            metrics.extend((mean, sd, hdi_lower, hdi_higher))
-            metric_names.extend(
-                ("mean", "sd", f"hdi_{100 * alpha / 2:g}%", f"hdi_{100 * (1 - alpha / 2):g}%")
-            )
-        elif stat_focus == "median":
-            median = dataset.median(dim=("chain", "draw"), skipna=skipna)
-
-            mad = stats.median_abs_deviation(dataset, dims=("chain", "draw"))
-            eti_post = dataset.quantile(
-                (alpha / 2, 1 - alpha / 2), dim=("chain", "draw"), skipna=skipna
-            )
-            eti_lower = eti_post.isel(quantile=0, drop=True)
-            eti_higher = eti_post.isel(quantile=1, drop=True)
-            metrics.extend((median, mad, eti_lower, eti_higher))
-            metric_names.extend(
-                ("median", "mad", f"eti_{100 * alpha / 2:g}%", f"eti_{100 * (1 - alpha / 2):g}%")
-            )
-
-    if circ_var_names:
-        nan_policy = "omit" if skipna else "propagate"
-        circ_mean = stats.circmean(
-            dataset, dims=["chain", "draw"], high=np.pi, low=-np.pi, nan_policy=nan_policy
-        )
-        _numba_flag = Numba.numba_flag
-        if _numba_flag:
-            circ_sd = xr.apply_ufunc(
-                _make_ufunc(_circular_standard_deviation),
-                dataset,
-                kwargs=dict(high=np.pi, low=-np.pi, skipna=skipna),
-                input_core_dims=(("chain", "draw"),),
-            )
-        else:
-            circ_sd = stats.circstd(
-                dataset, dims=["chain", "draw"], high=np.pi, low=-np.pi, nan_policy=nan_policy
-            )
-        circ_mcse = xr.apply_ufunc(
-            _make_ufunc(_mc_error),
-            dataset,
-            kwargs=dict(circular=True),
-            input_core_dims=(("chain", "draw"),),
-        )
-
-        circ_hdi = hdi(dataset, hdi_prob=hdi_prob, circular=True, skipna=skipna)
-        circ_hdi_lower = circ_hdi.sel(hdi="lower", drop=True)
-        circ_hdi_higher = circ_hdi.sel(hdi="higher", drop=True)
-
-    if kind in ["all", "diagnostics"] and extend:
-        diagnostics_names: Tuple[str, ...]
-        if stat_focus == "mean":
-            diagnostics = xr.apply_ufunc(
-                _make_ufunc(_multichain_statistics, n_output=5, ravel=False),
-                dataset,
-                input_core_dims=(("chain", "draw"),),
-                output_core_dims=tuple([] for _ in range(5)),
-            )
-            diagnostics_names = (
-                "mcse_mean",
-                "mcse_sd",
-                "ess_bulk",
-                "ess_tail",
-                "r_hat",
-            )
-
-        elif stat_focus == "median":
-            diagnostics = xr.apply_ufunc(
-                _make_ufunc(_multichain_statistics, n_output=4, ravel=False),
-                dataset,
-                kwargs=dict(focus="median"),
-                input_core_dims=(("chain", "draw"),),
-                output_core_dims=tuple([] for _ in range(4)),
-            )
-            diagnostics_names = (
-                "mcse_median",
-                "ess_median",
-                "ess_tail",
-                "r_hat",
-            )
-        metrics.extend(diagnostics)
-        metric_names.extend(diagnostics_names)
-
-    if circ_var_names and kind != "diagnostics" and stat_focus == "mean":
-        for metric, circ_stat in zip(
-            # Replace only the first 5 statistics for their circular equivalent
-            metrics[:5],
-            (circ_mean, circ_sd, circ_hdi_lower, circ_hdi_higher, circ_mcse),
-        ):
-            for circ_var in circ_var_names:
-                metric[circ_var] = circ_stat[circ_var]
-
-    metrics.extend(extra_metrics)
-    metric_names.extend(extra_metric_names)
-    joined = (
-        xr.concat(metrics, dim="metric").assign_coords(metric=metric_names).reset_coords(drop=True)
-    )
-    n_metrics = len(metric_names)
-    n_vars = np.sum([joined[var].size // n_metrics for var in joined.data_vars])
-
-    if fmt.lower() == "wide":
-        summary_df = pd.DataFrame(
-            (np.full((cast(int, n_vars), n_metrics), np.nan)), columns=metric_names
-        )
-        indices = []
-        for i, (var_name, sel, isel, values) in enumerate(
-            xarray_var_iter(joined, skip_dims={"metric"})
-        ):
-            summary_df.iloc[i] = values
-            indices.append(labeller.make_label_flat(var_name, sel, isel))
-        summary_df.index = indices
-    elif fmt.lower() == "long":
-        df = joined.to_dataframe().reset_index().set_index("metric")
-        df.index = list(df.index)
-        summary_df = df
-    else:
-        # format is 'xarray'
-        summary_df = joined
-    if (round_to is not None) and (round_to not in ("None", "none")):
-        summary_df = summary_df.round(round_to)
-    elif round_to not in ("None", "none") and (fmt.lower() in ("long", "wide")):
-        # Don't round xarray object by default (even with "none")
-        decimals = {
-            col: 3 if col not in {"ess_bulk", "ess_tail", "r_hat"} else 2 if col == "r_hat" else 0
-            for col in summary_df.columns
-        }
-        summary_df = summary_df.round(decimals)
-
-    return summary_df
-
-
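Condensing the docstring examples above into one runnable snippet:

```python
import numpy as np
import arviz as az

data = az.load_arviz_data("centered_eight")

# Default mean-focused summary for a subset of variables
print(az.summary(data, var_names=["mu", "tau"]))

# Custom statistics via stat_funcs, replacing the defaults instead of extending them
func_dict = {
    "median": lambda x: np.percentile(x, 50),
    "90%": lambda x: np.percentile(x, 90),
}
print(az.summary(data, var_names=["mu", "tau"], stat_funcs=func_dict, extend=False))
```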
-def waic(data, pointwise=None, var_name=None, scale=None, dask_kwargs=None):
-    """Compute the widely applicable information criterion.
-
-    Estimates the expected log pointwise predictive density (elpd) using WAIC. Also calculates the
-    WAIC's standard error and the effective number of parameters.
-    Read more theory here https://arxiv.org/abs/1507.04544 and here https://arxiv.org/abs/1004.2316
-
-    Parameters
-    ----------
-    data: obj
-        Any object that can be converted to an :class:`arviz.InferenceData` object.
-        Refer to documentation of :func:`arviz.convert_to_inference_data` for details.
-    pointwise: bool
-        If True the pointwise predictive accuracy will be returned. Defaults to
-        ``stats.ic_pointwise`` rcParam.
-    var_name : str, optional
-        The name of the variable in log_likelihood groups storing the pointwise log
-        likelihood data to use for waic computation.
-    scale: str
-        Output scale for WAIC. Available options are:
-
-        - `log` : (default) log-score
-        - `negative_log` : -1 * log-score
-        - `deviance` : -2 * log-score
-
-        A higher log-score (or a lower deviance or negative log_score) indicates a model with
-        better predictive accuracy.
-    dask_kwargs : dict, optional
-        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
-
-    Returns
-    -------
-    ELPDData object (inherits from :class:`pandas.Series`) with the following row/attributes:
-    elpd_waic: approximated expected log pointwise predictive density (elpd)
-    se: standard error of the elpd
-    p_waic: effective number parameters
-    n_samples: number of samples
-    n_data_points: number of data points
-    warning: bool
-        True if posterior variance of the log predictive densities exceeds 0.4
-    waic_i: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
-        only if pointwise=True
-    scale: scale of the elpd
-
-    The returned object has a custom print method that overrides pd.Series method.
-
-    See Also
-    --------
-    loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
-    compare : Compare models based on PSIS-LOO-CV or WAIC.
-    plot_compare : Summary plot for model comparison.
-
-    Examples
-    --------
-    Calculate WAIC of a model:
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data("centered_eight")
-           ...: az.waic(data)
-
-    Calculate WAIC of a model and return the pointwise values:
-
-    .. ipython::
-
-        In [2]: data_waic = az.waic(data, pointwise=True)
-           ...: data_waic.waic_i
-    """
-    inference_data = convert_to_inference_data(data)
-    log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
-    scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
-    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
-
-    if scale == "deviance":
-        scale_value = -2
-    elif scale == "log":
-        scale_value = 1
-    elif scale == "negative_log":
-        scale_value = -1
-    else:
-        raise TypeError('Valid scale values are "deviance", "log", "negative_log"')
-
-    log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
-    shape = log_likelihood.shape
-    n_samples = shape[-1]
-    n_data_points = np.prod(shape[:-1])
-
-    ufunc_kwargs = {"n_dims": 1, "ravel": False}
-    kwargs = {"input_core_dims": [["__sample__"]]}
-    lppd_i = _wrap_xarray_ufunc(
-        _logsumexp,
-        log_likelihood,
-        func_kwargs={"b_inv": n_samples},
-        ufunc_kwargs=ufunc_kwargs,
-        dask_kwargs=dask_kwargs,
-        **kwargs,
-    )
-
-    vars_lpd = log_likelihood.var(dim="__sample__")
-    warn_mg = False
-    if np.any(vars_lpd > 0.4):
-        warnings.warn(
-            (
-                "For one or more samples the posterior variance of the log predictive "
-                "densities exceeds 0.4. This could be indication of WAIC starting to fail. \n"
-                "See http://arxiv.org/abs/1507.04544 for details"
-            )
-        )
-        warn_mg = True
-
-    waic_i = scale_value * (lppd_i - vars_lpd)
-    waic_se = (n_data_points * np.var(waic_i.values)) ** 0.5
-    waic_sum = np.sum(waic_i.values)
-    p_waic = np.sum(vars_lpd.values)
-
-    if not pointwise:
-        return ELPDData(
-            data=[waic_sum, waic_se, p_waic, n_samples, n_data_points, warn_mg, scale],
-            index=[
-                "waic",
-                "se",
-                "p_waic",
-                "n_samples",
-                "n_data_points",
-                "warning",
-                "scale",
-            ],
-        )
-    if np.equal(waic_sum, waic_i).all():  # pylint: disable=no-member
-        warnings.warn(
-            """The point-wise WAIC is the same with the sum WAIC, please double check
-            the Observed RV in your model to make sure it returns element-wise logp.
-            """
-        )
-    return ELPDData(
-        data=[
-            waic_sum,
-            waic_se,
-            p_waic,
-            n_samples,
-            n_data_points,
-            warn_mg,
-            waic_i.rename("waic_i"),
-            scale,
-        ],
-        index=[
-            "elpd_waic",
-            "se",
-            "p_waic",
-            "n_samples",
-            "n_data_points",
-            "warning",
-            "waic_i",
-            "scale",
-        ],
-    )
-
-
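The WAIC usage from the docstring above, as a self-contained script:

```python
import arviz as az

data = az.load_arviz_data("centered_eight")

waic_summary = az.waic(data)                    # ELPDData with elpd_waic, se, p_waic, ...
waic_pointwise = az.waic(data, pointwise=True)  # additionally carries the waic_i DataArray
print(waic_summary)
print(waic_pointwise.waic_i)
```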
-def loo_pit(idata=None, *, y=None, y_hat=None, log_weights=None):
-    """Compute leave one out (PSIS-LOO) probability integral transform (PIT) values.
-
-    Parameters
-    ----------
-    idata: InferenceData
-        :class:`arviz.InferenceData` object.
-    y: array, DataArray or str
-        Observed data. If str, ``idata`` must be present and contain the observed data group
-    y_hat: array, DataArray or str
-        Posterior predictive samples for ``y``. It must have the same shape as y plus an
-        extra dimension at the end of size n_samples (chains and draws stacked). If str or
-        None, ``idata`` must contain the posterior predictive group. If None, y_hat is taken
-        equal to y, thus, y must be str too.
-    log_weights: array or DataArray
-        Smoothed log_weights. It must have the same shape as ``y_hat``
-    dask_kwargs : dict, optional
-        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
-
-    Returns
-    -------
-    loo_pit: array or DataArray
-        Value of the LOO-PIT at each observed data point.
-
-    See Also
-    --------
-    plot_loo_pit : Plot Leave-One-Out probability integral transformation (PIT) predictive checks.
-    loo : Compute Pareto-smoothed importance sampling leave-one-out
-        cross-validation (PSIS-LOO-CV).
-    plot_elpd : Plot pointwise elpd differences between two or more models.
-    plot_khat : Plot Pareto tail indices for diagnosing convergence.
-
-    Examples
-    --------
-    Calculate LOO-PIT values using as test quantity the observed values themselves.
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data("centered_eight")
-           ...: az.loo_pit(idata=data, y="obs")
-
-    Calculate LOO-PIT values using as test quantity the square of the difference between
-    each observation and `mu`. Both ``y`` and ``y_hat`` inputs will be array-like,
-    but ``idata`` will still be passed in order to calculate the ``log_weights`` from
-    there.
-
-    .. ipython::
-
-        In [1]: T = data.observed_data.obs - data.posterior.mu.median(dim=("chain", "draw"))
-           ...: T_hat = data.posterior_predictive.obs - data.posterior.mu
-           ...: T_hat = T_hat.stack(__sample__=("chain", "draw"))
-           ...: az.loo_pit(idata=data, y=T**2, y_hat=T_hat**2)
-
-    """
-    y_str = ""
-    if idata is not None and not isinstance(idata, InferenceData):
-        raise ValueError("idata must be of type InferenceData or None")
-
-    if idata is None:
-        if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat, log_weights)):
-            raise ValueError(
-                "all 3 y, y_hat and log_weights must be array or DataArray when idata is None "
-                f"but they are of types {[type(arg) for arg in (y, y_hat, log_weights)]}"
-            )
-
-    else:
-        if y_hat is None and isinstance(y, str):
-            y_hat = y
-        elif y_hat is None:
-            raise ValueError("y_hat cannot be None if y is not a str")
-        if isinstance(y, str):
-            y_str = y
-            y = idata.observed_data[y].values
-        elif not isinstance(y, (np.ndarray, xr.DataArray)):
-            raise ValueError(f"y must be of types array, DataArray or str, not {type(y)}")
-        if isinstance(y_hat, str):
-            y_hat = idata.posterior_predictive[y_hat].stack(__sample__=("chain", "draw")).values
-        elif not isinstance(y_hat, (np.ndarray, xr.DataArray)):
-            raise ValueError(f"y_hat must be of types array, DataArray or str, not {type(y_hat)}")
-        if log_weights is None:
-            if y_str:
-                try:
-                    log_likelihood = _get_log_likelihood(idata, var_name=y_str)
-                except TypeError:
-                    log_likelihood = _get_log_likelihood(idata)
-            else:
-                log_likelihood = _get_log_likelihood(idata)
-            log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
-            posterior = convert_to_dataset(idata, group="posterior")
-            n_chains = len(posterior.chain)
-            n_samples = len(log_likelihood.__sample__)
-            ess_p = ess(posterior, method="mean")
-            # this mean is over all data variables
-            reff = (
-                (np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples)
-                if n_chains > 1
-                else 1
-            )
-            log_weights = psislw(-log_likelihood, reff=reff)[0].values
-        elif not isinstance(log_weights, (np.ndarray, xr.DataArray)):
-            raise ValueError(
-                f"log_weights must be None or of types array or DataArray, not {type(log_weights)}"
-            )
-
-    if len(y.shape) + 1 != len(y_hat.shape):
-        raise ValueError(
-            f"y_hat must have 1 more dimension than y, but y_hat has {len(y_hat.shape)} dims and "
-            f"y has {len(y.shape)} dims"
-        )
-
-    if y.shape != y_hat.shape[:-1]:
-        raise ValueError(
-            f"y has shape: {y.shape} which should be equal to y_hat shape (omitting the last "
-            f"dimension): {y_hat.shape}"
-        )
-
-    if y_hat.shape != log_weights.shape:
-        raise ValueError(
-            "y_hat and log_weights must have the same shape but have shapes "
-            f"{y_hat.shape,} and {log_weights.shape}"
-        )
-
-    kwargs = {
-        "input_core_dims": [[], ["__sample__"], ["__sample__"]],
-        "output_core_dims": [[]],
-        "join": "left",
-    }
-    ufunc_kwargs = {"n_dims": 1}
-
-    if y.dtype.kind == "i" or y_hat.dtype.kind == "i":
-        y, y_hat = smooth_data(y, y_hat)
-
-    return _wrap_xarray_ufunc(
-        _loo_pit,
-        y,
-        y_hat,
-        log_weights,
-        ufunc_kwargs=ufunc_kwargs,
-        **kwargs,
-    )
-
-
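The first docstring example above, runnable on its own:

```python
import arviz as az

data = az.load_arviz_data("centered_eight")

# LOO-PIT using the observed values themselves as the test quantity;
# well-calibrated predictions give values roughly uniform on (0, 1)
pit_vals = az.loo_pit(idata=data, y="obs")
print(pit_vals)
```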
-def _loo_pit(y, y_hat, log_weights):
-    """Compute LOO-PIT values."""
-    sel = y_hat <= y
-    if np.sum(sel) > 0:
-        value = np.exp(_logsumexp(log_weights[sel]))
-        return min(1, value)
-    else:
-        return 0
-
-
-def apply_test_function(
-    idata,
-    func,
-    group="both",
-    var_names=None,
-    pointwise=False,
-    out_data_shape=None,
-    out_pp_shape=None,
-    out_name_data="T",
-    out_name_pp=None,
-    func_args=None,
-    func_kwargs=None,
-    ufunc_kwargs=None,
-    wrap_data_kwargs=None,
-    wrap_pp_kwargs=None,
-    inplace=True,
-    overwrite=None,
-):
-    """Apply a Bayesian test function to an InferenceData object.
-
-    Parameters
-    ----------
-    idata: InferenceData
-        :class:`arviz.InferenceData` object on which to apply the test function.
-        This function will add new variables to the InferenceData object
-        to store the result without modifying the existing ones.
-    func: callable
-        Callable that calculates the test function. It must have the following call signature
-        ``func(y, theta, *args, **kwargs)`` (where ``y`` is the observed data or posterior
-        predictive and ``theta`` the model parameters) even if not all the arguments are
-        used.
-    group: str, optional
-        Group on which to apply the test function. Can be observed_data, posterior_predictive
-        or both.
-    var_names: dict group -> var_names, optional
-        Mapping from group name to the variables to be passed to func. It can be a dict of
-        strings or lists of strings. There is also the option of using ``both`` as key,
-        in which case, the same variables are used in observed data and posterior predictive
-        groups
-    pointwise: bool, optional
-        If True, apply the test function to each observation and sample, otherwise, apply
-        test function to each sample.
-    out_data_shape, out_pp_shape: tuple, optional
-        Output shape of the test function applied to the observed/posterior predictive data.
-        If None, the default depends on the value of pointwise.
-    out_name_data, out_name_pp: str, optional
-        Name of the variables to add to the observed_data and posterior_predictive datasets
-        respectively. ``out_name_pp`` can be ``None``, in which case will be taken equal to
-        ``out_name_data``.
-    func_args: sequence, optional
-        Passed as is to ``func``
-    func_kwargs: mapping, optional
-        Passed as is to ``func``
-    wrap_data_kwargs, wrap_pp_kwargs: mapping, optional
-        kwargs passed to :func:`~arviz.wrap_xarray_ufunc`. By default, some suitable input_core_dims
-        are used.
-    inplace: bool, optional
-        If True, add the variables inplace, otherwise, return a copy of idata with the variables
-        added.
-    overwrite: bool, optional
-        Overwrite data in case ``out_name_data`` or ``out_name_pp`` are already variables in
-        dataset. If ``None`` it will be the opposite of inplace.
-
-    Returns
-    -------
-    idata: InferenceData
-        Output InferenceData object. If ``inplace=True``, it is the same input object modified
-        inplace.
-
-    See Also
-    --------
-    plot_bpv : Plot Bayesian p-value for observed data and Posterior/Prior predictive.
-
-    Notes
-    -----
-    This function is provided for convenience to wrap scalar or functions working on low
-    dims to inference data object. It is not optimized to be faster nor as fast as vectorized
-    computations.
-
-    Examples
-    --------
-    Use ``apply_test_function`` to wrap ``numpy.min`` for illustration purposes. And plot the
-    results.
-
-    .. plot::
-        :context: close-figs
-
-        >>> import arviz as az
-        >>> idata = az.load_arviz_data("centered_eight")
-        >>> az.apply_test_function(idata, lambda y, theta: np.min(y))
-        >>> T = idata.observed_data.T.item()
-        >>> az.plot_posterior(idata, var_names=["T"], group="posterior_predictive", ref_val=T)
-
-    """
-    out = idata if inplace else deepcopy(idata)
-
-    valid_groups = ("observed_data", "posterior_predictive", "both")
-    if group not in valid_groups:
-        raise ValueError(f"Invalid group argument. Must be one of {valid_groups} not {group}.")
-    if overwrite is None:
-        overwrite = not inplace
-
-    if out_name_pp is None:
-        out_name_pp = out_name_data
-
-    if func_args is None:
-        func_args = tuple()
-
-    if func_kwargs is None:
-        func_kwargs = {}
-
-    if ufunc_kwargs is None:
-        ufunc_kwargs = {}
-    ufunc_kwargs.setdefault("check_shape", False)
-    ufunc_kwargs.setdefault("ravel", False)
-
-    if wrap_data_kwargs is None:
-        wrap_data_kwargs = {}
-    if wrap_pp_kwargs is None:
-        wrap_pp_kwargs = {}
-    if var_names is None:
-        var_names = {}
-
-    both_var_names = var_names.pop("both", None)
-    var_names.setdefault("posterior", list(out.posterior.data_vars))
-
-    in_posterior = out.posterior[var_names["posterior"]]
-    if isinstance(in_posterior, xr.Dataset):
-        in_posterior = in_posterior.to_array().squeeze()
-
-    groups = ("posterior_predictive", "observed_data") if group == "both" else [group]
-    for grp in groups:
-        out_group_shape = out_data_shape if grp == "observed_data" else out_pp_shape
-        out_name_group = out_name_data if grp == "observed_data" else out_name_pp
-        wrap_group_kwargs = wrap_data_kwargs if grp == "observed_data" else wrap_pp_kwargs
-        if not hasattr(out, grp):
-            raise ValueError(f"InferenceData object must have {grp} group")
-        if not overwrite and out_name_group in getattr(out, grp).data_vars:
-            raise ValueError(
-                f"Should overwrite: {out_name_group} variable present in group {grp},"
-                " but overwrite is False"
-            )
-        var_names.setdefault(
-            grp, list(getattr(out, grp).data_vars) if both_var_names is None else both_var_names
-        )
-        in_group = getattr(out, grp)[var_names[grp]]
-        if isinstance(in_group, xr.Dataset):
-            in_group = in_group.to_array(dim=f"{grp}_var").squeeze()
-
-        if pointwise:
-            out_group_shape = in_group.shape if out_group_shape is None else out_group_shape
-        elif grp == "observed_data":
-            out_group_shape = () if out_group_shape is None else out_group_shape
-        elif grp == "posterior_predictive":
-            out_group_shape = in_group.shape[:2] if out_group_shape is None else out_group_shape
-        loop_dims = in_group.dims[: len(out_group_shape)]
-
-        wrap_group_kwargs.setdefault(
-            "input_core_dims",
-            [
-                [dim for dim in dataset.dims if dim not in loop_dims]
-                for dataset in [in_group, in_posterior]
-            ],
-        )
-        func_kwargs["out"] = np.empty(out_group_shape)
-
-        out_group = getattr(out, grp)
-        try:
-            out_group[out_name_group] = _wrap_xarray_ufunc(
-                func,
-                in_group.values,
-                in_posterior.values,
-                func_args=func_args,
-                func_kwargs=func_kwargs,
-                ufunc_kwargs=ufunc_kwargs,
-                **wrap_group_kwargs,
-            )
-        except IndexError:
-            excluded_dims = set(
-                wrap_group_kwargs["input_core_dims"][0] + wrap_group_kwargs["input_core_dims"][1]
-            )
-            out_group[out_name_group] = _wrap_xarray_ufunc(
-                func,
-                *xr.broadcast(in_group, in_posterior, exclude=excluded_dims),
-                func_args=func_args,
-                func_kwargs=func_kwargs,
-                ufunc_kwargs=ufunc_kwargs,
-                **wrap_group_kwargs,
-            )
-        setattr(out, grp, out_group)
-
-    return out
-
-
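The docstring example above made self-contained (it relies on NumPy, which the docstring snippet leaves implicit):

```python
import numpy as np
import arviz as az

idata = az.load_arviz_data("centered_eight")

# Add the test statistic T = min(y) to observed_data and posterior_predictive
az.apply_test_function(idata, lambda y, theta: np.min(y))

T = idata.observed_data.T.item()
az.plot_posterior(idata, var_names=["T"], group="posterior_predictive", ref_val=T)
```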
-def weight_predictions(idatas, weights=None):
-    """
-    Generate weighted posterior predictive samples from a list of InferenceData
-    and a set of weights.
-
-    Parameters
-    ----------
-    idatas : list[InferenceData]
-        List of :class:`arviz.InferenceData` objects containing the groups `posterior_predictive`
-        and `observed_data`. Observations should be the same for all InferenceData objects.
-    weights : array-like, optional
-        Individual weights for each model. Weights should be positive. If they do not sum up to 1,
-        they will be normalized. Default, same weight for each model.
-        Weights can be computed using many different methods including those in
-        :func:`arviz.compare`.
-
-    Returns
-    -------
-    idata: InferenceData
-        Output InferenceData object with the groups `posterior_predictive` and `observed_data`.
-
-    See Also
-    --------
-    compare : Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation
-    """
-    if len(idatas) < 2:
-        raise ValueError("You should provide a list with at least two InferenceData objects")
-
-    if not all("posterior_predictive" in idata.groups() for idata in idatas):
-        raise ValueError(
-            "All the InferenceData objects must contain the `posterior_predictive` group"
-        )
-
-    if not all(idatas[0].observed_data.equals(idata.observed_data) for idata in idatas[1:]):
-        raise ValueError("The observed data should be the same for all InferenceData objects")
-
-    if weights is None:
-        weights = np.ones(len(idatas)) / len(idatas)
-    elif len(idatas) != len(weights):
-        raise ValueError(
-            "The number of weights should be the same as the number of InferenceData objects"
-        )
-
-    weights = np.array(weights, dtype=float)
-    weights /= weights.sum()
-
-    len_idatas = [
-        idata.posterior_predictive.sizes["chain"] * idata.posterior_predictive.sizes["draw"]
-        for idata in idatas
-    ]
-
-    if not all(len_idatas):
-        raise ValueError("At least one of your idatas has 0 samples")
-
-    new_samples = (np.min(len_idatas) * weights).astype(int)
-
-    new_idatas = [
-        extract(idata, group="posterior_predictive", num_samples=samples).reset_coords()
-        for samples, idata in zip(new_samples, idatas)
-    ]
-
-    weighted_samples = InferenceData(
-        posterior_predictive=xr.concat(new_idatas, dim="sample"),
-        observed_data=idatas[0].observed_data,
-    )
-
-    return weighted_samples
-
-
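A hypothetical sketch of how `weight_predictions` could be called, assuming the two bundled eight-schools datasets both contain a `posterior_predictive` group and identical observed data (a requirement enforced by the function); swap in your own fitted models in practice:

```python
import arviz as az

# Assumption: both example datasets satisfy weight_predictions' requirements
idata1 = az.load_arviz_data("centered_eight")
idata2 = az.load_arviz_data("non_centered_eight")

# Mix predictions 30% / 70%; weights are normalized internally if they do not sum to 1
mixed = az.weight_predictions([idata1, idata2], weights=[0.3, 0.7])
print(mixed.posterior_predictive)
```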
-def psens(
-    data,
-    *,
-    component="prior",
-    component_var_names=None,
-    component_coords=None,
-    var_names=None,
-    coords=None,
-    filter_vars=None,
-    delta=0.01,
-    dask_kwargs=None,
-):
-    """Compute power-scaling sensitivity diagnostic.
-
-    Power-scales the prior or likelihood and calculates how much the posterior is affected.
-
-    Parameters
-    ----------
-    data : obj
-        Any object that can be converted to an :class:`arviz.InferenceData` object.
-        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-        For ndarray: shape = (chain, draw).
-        For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
-    component : {"prior", "likelihood"}, default "prior"
-        When `component` is "likelihood", the log likelihood values are retrieved
-        from the ``log_likelihood`` group as pointwise log likelihood and added
-        together. With "prior", the log prior values are retrieved from the
-        ``log_prior`` group.
-    component_var_names : str, optional
-        Name of the prior or log likelihood variables to use
-    component_coords : dict, optional
-        Coordinates defining a subset over the component element for which to
-        compute the prior sensitivity diagnostic.
-    var_names : list of str, optional
-        Names of posterior variables to include in the power scaling sensitivity diagnostic
-    coords : dict, optional
-        Coordinates defining a subset over the posterior. Only these variables will
-        be used when computing the prior sensitivity.
-    filter_vars: {None, "like", "regex"}, default None
-        If ``None`` (default), interpret var_names as the real variables names.
-        If "like", interpret var_names as substrings of the real variables names.
-        If "regex", interpret var_names as regular expressions on the real variables names.
-    delta : float
-        Value for finite difference derivative calculation.
-    dask_kwargs : dict, optional
-        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
-
-    Returns
-    -------
-    xarray.Dataset
-        Returns dataset of power-scaling sensitivity diagnostic values.
-        Higher sensitivity values indicate greater sensitivity.
-        Prior sensitivity above 0.05 indicates informative prior.
-        Likelihood sensitivity below 0.05 indicates weak or noninformative likelihood.
-
-    Examples
-    --------
-    Compute the likelihood sensitivity for the non centered eight model:
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data("non_centered_eight")
-           ...: az.psens(data, component="likelihood")
-
-    To compute the prior sensitivity, we need to first compute the log prior
-    at each posterior sample. In our case, we know mu has a normal prior :math:`N(0, 5)`,
-    tau is a half cauchy prior with scale/beta parameter 5,
-    and theta has a standard normal as prior.
-    We add this information to the ``log_prior`` group before computing powerscaling
-    check with ``psens``
-
-    .. ipython::
-
-        In [1]: from xarray_einstats.stats import XrContinuousRV
-           ...: from scipy.stats import norm, halfcauchy
-           ...: post = data.posterior
-           ...: log_prior = {
-           ...:     "mu": XrContinuousRV(norm, 0, 5).logpdf(post["mu"]),
-           ...:     "tau": XrContinuousRV(halfcauchy, scale=5).logpdf(post["tau"]),
-           ...:     "theta_t": XrContinuousRV(norm, 0, 1).logpdf(post["theta_t"]),
-           ...: }
-           ...: data.add_groups({"log_prior": log_prior})
-           ...: az.psens(data, component="prior")
-
-    Notes
-    -----
-    The diagnostic is computed by power-scaling the specified component (prior or likelihood)
-    and determining the degree to which the posterior changes as described in [1]_.
-    It uses Pareto-smoothed importance sampling to avoid refitting the model.
-
-    References
-    ----------
-    .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
-       power-scaling*, 2022, https://arxiv.org/abs/2107.14054
-
-    """
-    dataset = extract(data, var_names=var_names, filter_vars=filter_vars, group="posterior")
-    if coords is None:
-        dataset = dataset.sel(coords)
-
-    if component == "likelihood":
-        component_draws = _get_log_likelihood(data, var_name=component_var_names, single_var=False)
-    elif component == "prior":
-        component_draws = _get_log_prior(data, var_names=component_var_names)
-    else:
-        raise ValueError("Value for `component` argument not recognized")
-
-    component_draws = component_draws.stack(__sample__=("chain", "draw"))
-    if component_coords is None:
-        component_draws = component_draws.sel(component_coords)
-    if isinstance(component_draws, xr.DataArray):
-        component_draws = component_draws.to_dataset()
-    if len(component_draws.dims):
-        component_draws = component_draws.to_stacked_array(
-            "latent-obs_var", sample_dims=("__sample__",)
-        ).sum("latent-obs_var")
-    # from here component_draws is a 1d object with dimensions (sample,)
-
-    # calculate lower and upper alpha values
-    lower_alpha = 1 / (1 + delta)
-    upper_alpha = 1 + delta
-
-    # calculate importance sampling weights for lower and upper alpha power-scaling
-    lower_w = np.exp(_powerscale_lw(component_draws=component_draws, alpha=lower_alpha))
-    lower_w = lower_w / np.sum(lower_w)
-
-    upper_w = np.exp(_powerscale_lw(component_draws=component_draws, alpha=upper_alpha))
-    upper_w = upper_w / np.sum(upper_w)
-
-    ufunc_kwargs = {"n_dims": 1, "ravel": False}
-    func_kwargs = {"lower_weights": lower_w.values, "upper_weights": upper_w.values, "delta": delta}
-
-    # calculate the sensitivity diagnostic based on the importance weights and draws
-    return _wrap_xarray_ufunc(
-        _powerscale_sens,
-        dataset,
-        ufunc_kwargs=ufunc_kwargs,
-        func_kwargs=func_kwargs,
-        dask_kwargs=dask_kwargs,
-        input_core_dims=[["sample"]],
-    )
-
-
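The likelihood-sensitivity example from the docstring above, as a standalone snippet:

```python
import arviz as az

data = az.load_arviz_data("non_centered_eight")

# Power-scaling sensitivity of the posterior to the likelihood;
# per the docstring, values above roughly 0.05 flag sensitive parameters
print(az.psens(data, component="likelihood"))
```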
-def _powerscale_sens(draws, *, lower_weights=None, upper_weights=None, delta=0.01):
-    """
-    Calculate power-scaling sensitivity by finite difference
-    second derivative of CJS
-    """
-    lower_cjs = max(
-        _cjs_dist(draws=draws, weights=lower_weights),
-        _cjs_dist(draws=-1 * draws, weights=lower_weights),
-    )
-    upper_cjs = max(
-        _cjs_dist(draws=draws, weights=upper_weights),
-        _cjs_dist(draws=-1 * draws, weights=upper_weights),
-    )
-    logdiffsquare = 2 * np.log2(1 + delta)
-    grad = (lower_cjs + upper_cjs) / logdiffsquare
-
-    return grad
-
-
-def _powerscale_lw(alpha, component_draws):
-    """
-    Calculate log weights for power-scaling component by alpha.
-    """
-    log_weights = (alpha - 1) * component_draws
-    log_weights = psislw(log_weights)[0]
-
-    return log_weights
-
-
-def _cjs_dist(draws, weights):
-    """
-    Calculate the cumulative Jensen-Shannon distance between original draws and weighted draws.
-    """
-
-    # sort draws and weights
-    order = np.argsort(draws)
-    draws = draws[order]
-    weights = weights[order]
-
-    binwidth = np.diff(draws)
-
-    # ecdfs
-    cdf_p = np.linspace(1 / len(draws), 1 - 1 / len(draws), len(draws) - 1)
-    cdf_q = np.cumsum(weights / np.sum(weights))[:-1]
-
-    # integrals of ecdfs
-    cdf_p_int = np.dot(cdf_p, binwidth)
-    cdf_q_int = np.dot(cdf_q, binwidth)
-
-    # cjs calculation
-    pq_numer = np.log2(cdf_p, out=np.zeros_like(cdf_p), where=cdf_p != 0)
-    qp_numer = np.log2(cdf_q, out=np.zeros_like(cdf_q), where=cdf_q != 0)
-
-    denom = 0.5 * (cdf_p + cdf_q)
-    denom = np.log2(denom, out=np.zeros_like(denom), where=denom != 0)
-
-    cjs_pq = np.sum(binwidth * (cdf_p * (pq_numer - denom))) + 0.5 / np.log(2) * (
-        cdf_q_int - cdf_p_int
-    )
-
-    cjs_qp = np.sum(binwidth * (cdf_q * (qp_numer - denom))) + 0.5 / np.log(2) * (
-        cdf_p_int - cdf_q_int
-    )
-
-    cjs_pq = max(0, cjs_pq)
-    cjs_qp = max(0, cjs_qp)
-
-    bound = cdf_p_int + cdf_q_int
-
-    return np.sqrt((cjs_pq + cjs_qp) / bound)
-
-
-def bayes_factor(idata, var_name, ref_val=0, prior=None, return_ref_vals=False):
-    r"""Approximated Bayes Factor for comparing hypothesis of two nested models.
-
-    The Bayes factor is estimated by comparing a model (H1) against a model in which the
-    parameter of interest has been restricted to be a point-null (H0). This computation
-    assumes the models are nested and thus H0 is a special case of H1.
-
-    Notes
-    -----
-    The Bayes factor is approximated with the Savage-Dickey density ratio
-    algorithm presented in [1]_.
-
-    Parameters
-    ----------
-    idata : InferenceData
-        Any object that can be converted to an :class:`arviz.InferenceData` object
-        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-    var_name : str, optional
-        Name of variable we want to test.
-    ref_val : int, default 0
-        Point-null for Bayes factor estimation.
-    prior : numpy.array, optional
-        In case we want to use different prior, for example for sensitivity analysis.
-    return_ref_vals : bool, optional
-        Whether to return the values of the prior and posterior at the reference value.
-        Used by :func:`arviz.plot_bf` to display the distribution comparison.
-
-
-    Returns
-    -------
-    dict : A dictionary with BF10 (Bayes Factor 10, the H1/H0 ratio) and BF01 (the H0/H1 ratio).
-
-    References
-    ----------
-    .. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio:
-       The case of computing Bayes factors for regression parameters.
-
-    Examples
-    --------
-    Moderate evidence indicating that the parameter "a" is different from zero.
-
-    .. ipython::
-
-        In [1]: import numpy as np
-           ...: import arviz as az
-           ...: idata = az.from_dict(posterior={"a":np.random.normal(1, 0.5, 5000)},
-           ...:     prior={"a":np.random.normal(0, 1, 5000)})
-           ...: az.bayes_factor(idata, var_name="a", ref_val=0)
-
-    """
-
-    posterior = extract(idata, var_names=var_name).values
-
-    if ref_val > posterior.max() or ref_val < posterior.min():
-        _log.warning(
-            "The reference value is outside of the posterior. "
-            "This translate into infinite support for H1, which is most likely an overstatement."
-        )
-
-    if posterior.ndim > 1:
-        _log.warning("Posterior distribution has {posterior.ndim} dimensions")
-
-    if prior is None:
-        prior = extract(idata, var_names=var_name, group="prior").values
-
-    if posterior.dtype.kind == "f":
-        posterior_grid, posterior_pdf, *_ = _kde_linear(posterior)
-        prior_grid, prior_pdf, *_ = _kde_linear(prior)
-        posterior_at_ref_val = np.interp(ref_val, posterior_grid, posterior_pdf)
-        prior_at_ref_val = np.interp(ref_val, prior_grid, prior_pdf)
-
-    elif posterior.dtype.kind == "i":
-        posterior_at_ref_val = (posterior == ref_val).mean()
-        prior_at_ref_val = (prior == ref_val).mean()
-
-    bf_10 = prior_at_ref_val / posterior_at_ref_val
-    bf = {"BF10": bf_10, "BF01": 1 / bf_10}
-
-    if return_ref_vals:
-        return (bf, {"prior": prior_at_ref_val, "posterior": posterior_at_ref_val})
-    else:
-        return bf
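The Savage-Dickey example from the docstring above, runnable as a script:

```python
import numpy as np
import arviz as az

# Posterior concentrated away from zero, prior centered at zero
idata = az.from_dict(
    posterior={"a": np.random.normal(1, 0.5, 5000)},
    prior={"a": np.random.normal(0, 1, 5000)},
)

# BF10 is the prior/posterior density ratio at ref_val; BF01 is its reciprocal
print(az.bayes_factor(idata, var_name="a", ref_val=0))
```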