arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/stats/diagnostics.py
DELETED
@@ -1,1013 +0,0 @@
-# pylint: disable=too-many-lines, too-many-function-args, redefined-outer-name
-"""Diagnostic functions for ArviZ."""
-import warnings
-from collections.abc import Sequence
-
-import numpy as np
-import packaging
-import pandas as pd
-import scipy
-from scipy import stats
-
-from ..data import convert_to_dataset
-from ..utils import Numba, _numba_var, _stack, _var_names
-from .density_utils import histogram as _histogram
-from .stats_utils import _circular_standard_deviation, _sqrt
-from .stats_utils import autocov as _autocov
-from .stats_utils import not_valid as _not_valid
-from .stats_utils import quantile as _quantile
-from .stats_utils import stats_variance_2d as svar
-from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc
-
-__all__ = ["bfmi", "ess", "rhat", "mcse"]
-
-
-def bfmi(data):
-    r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
-
-    BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
-    information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
-    values smaller than 0.3 indicate poor sampling. However, this threshold is
-    provisional and may change. See
-    `pystan_workflow <http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html>`_
-    for more information.
-
-    Parameters
-    ----------
-    data : obj
-        Any object that can be converted to an :class:`arviz.InferenceData` object.
-        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-        If InferenceData, energy variable needs to be found.
-
-    Returns
-    -------
-    z : array
-        The Bayesian fraction of missing information of the model and trace. One element per
-        chain in the trace.
-
-    See Also
-    --------
-    plot_energy : Plot energy transition distribution and marginal energy
-        distribution in HMC algorithms.
-
-    Examples
-    --------
-    Compute the BFMI of an InferenceData object
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data('radon')
-           ...: az.bfmi(data)
-
-    """
-    if isinstance(data, np.ndarray):
-        return _bfmi(data)
-
-    dataset = convert_to_dataset(data, group="sample_stats")
-    if not hasattr(dataset, "energy"):
-        raise TypeError("Energy variable was not found.")
-    return _bfmi(dataset.energy.transpose("chain", "draw"))
-
-
-def ess(
-    data,
-    *,
-    var_names=None,
-    method="bulk",
-    relative=False,
-    prob=None,
-    dask_kwargs=None,
-):
-    r"""Calculate estimate of the effective sample size (ess).
-
-    Parameters
-    ----------
-    data : obj
-        Any object that can be converted to an :class:`arviz.InferenceData` object.
-        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-        For ndarray: shape = (chain, draw).
-        For n-dimensional ndarray transform first to dataset with :func:`arviz.convert_to_dataset`.
-    var_names : str or list of str
-        Names of variables to include in the return value Dataset.
-    method : str, optional, default "bulk"
-        Select ess method. Valid methods are:
-
-        - "bulk"
-        - "tail"  # prob, optional
-        - "quantile"  # prob
-        - "mean" (old ess)
-        - "sd"
-        - "median"
-        - "mad" (mean absolute deviance)
-        - "z_scale"
-        - "folded"
-        - "identity"
-        - "local"
-    relative : bool
-        Return relative ess
-        ``ress = ess / n``
-    prob : float, or tuple of two floats, optional
-        probability value for "tail", "quantile" or "local" ess functions.
-    dask_kwargs : dict, optional
-        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
-
-    Returns
-    -------
-    xarray.Dataset
-        Return the effective sample size, :math:`\hat{N}_{eff}`
-
-    Notes
-    -----
-    The basic ess (:math:`N_{\mathit{eff}}`) diagnostic is computed by:
-
-    .. math:: \hat{N}_{\mathit{eff}} = \frac{MN}{\hat{\tau}}
-
-    .. math:: \hat{\tau} = -1 + 2 \sum_{t'=0}^K \hat{P}_{t'}
-
-    where :math:`M` is the number of chains, :math:`N` the number of draws,
-    :math:`\hat{\rho}_t` is the estimated autocorrelation at lag :math:`t`, and
-    :math:`K` is the last integer for which :math:`\hat{P}_{K} = \hat{\rho}_{2K} +
-    \hat{\rho}_{2K+1}` is still positive.
-
-    The current implementation is similar to Stan, which uses Geyer's initial monotone sequence
-    criterion (Geyer, 1992; Geyer, 2011).
-
-    References
-    ----------
-    * Vehtari et al. (2021). Rank-normalization, folding, and
-      localization: An improved Rhat for assessing convergence of
-      MCMC. Bayesian analysis, 16(2):667-718.
-    * https://mc-stan.org/docs/reference-manual/analysis.html#effective-sample-size.section
-    * Gelman et al. BDA3 (2013) Formula 11.8
-
-    See Also
-    --------
-    arviz.rhat : Compute estimate of rank normalized splitR-hat for a set of traces.
-    arviz.mcse : Calculate Markov Chain Standard Error statistic.
-    plot_ess : Plot quantile, local or evolution of effective sample sizes (ESS).
-    arviz.summary : Create a data frame with summary statistics.
-
-    Examples
-    --------
-    Calculate the effective_sample_size using the default arguments:
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data('non_centered_eight')
-           ...: az.ess(data)
-
-    Calculate the ress of some of the variables
-
-    .. ipython::
-
-        In [1]: az.ess(data, relative=True, var_names=["mu", "theta_t"])
-
-    Calculate the ess using the "tail" method, leaving the `prob` argument at its default
-    value.
-
-    .. ipython::
-
-        In [1]: az.ess(data, method="tail")
-
-    """
-    methods = {
-        "bulk": _ess_bulk,
-        "tail": _ess_tail,
-        "quantile": _ess_quantile,
-        "mean": _ess_mean,
-        "sd": _ess_sd,
-        "median": _ess_median,
-        "mad": _ess_mad,
-        "z_scale": _ess_z_scale,
-        "folded": _ess_folded,
-        "identity": _ess_identity,
-        "local": _ess_local,
-    }
-
-    if method not in methods:
-        raise TypeError(f"ess method {method} not found. Valid methods are:\n{', '.join(methods)}")
-    ess_func = methods[method]
-
-    if (method == "quantile") and prob is None:
-        raise TypeError("Quantile (prob) information needs to be defined.")
-
-    if isinstance(data, np.ndarray):
-        data = np.atleast_2d(data)
-        if len(data.shape) < 3:
-            if prob is not None:
-                return ess_func(  # pylint: disable=unexpected-keyword-arg
-                    data, prob=prob, relative=relative
-                )
-
-            return ess_func(data, relative=relative)
-
-        msg = (
-            "Only uni-dimensional ndarray variables are supported."
-            " Please transform first to dataset with `az.convert_to_dataset`."
-        )
-        raise TypeError(msg)
-
-    dataset = convert_to_dataset(data, group="posterior")
-    var_names = _var_names(var_names, dataset)
-
-    dataset = dataset if var_names is None else dataset[var_names]
-
-    ufunc_kwargs = {"ravel": False}
-    func_kwargs = {"relative": relative} if prob is None else {"prob": prob, "relative": relative}
-    return _wrap_xarray_ufunc(
-        ess_func,
-        dataset,
-        ufunc_kwargs=ufunc_kwargs,
-        func_kwargs=func_kwargs,
-        dask_kwargs=dask_kwargs,
-    )
-
-
-def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
-    r"""Compute estimate of rank normalized splitR-hat for a set of traces.
-
-    The rank normalized R-hat diagnostic tests for lack of convergence by comparing the variance
-    between multiple chains to the variance within each chain. If convergence has been achieved,
-    the between-chain and within-chain variances should be identical. To be most effective in
-    detecting evidence for nonconvergence, each chain should have been initialized to starting
-    values that are dispersed relative to the target distribution.
-
-    Parameters
-    ----------
-    data : obj
-        Any object that can be converted to an :class:`arviz.InferenceData` object.
-        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-        At least 2 posterior chains are needed to compute this diagnostic of one or more
-        stochastic parameters.
-        For ndarray: shape = (chain, draw).
-        For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
-    var_names : list
-        Names of variables to include in the rhat report
-    method : str
-        Select R-hat method. Valid methods are:
-        - "rank"  # recommended by Vehtari et al. (2021)
-        - "split"
-        - "folded"
-        - "z_scale"
-        - "identity"
-    dask_kwargs : dict, optional
-        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
-
-    Returns
-    -------
-    xarray.Dataset
-        Returns dataset of the potential scale reduction factors, :math:`\hat{R}`
-
-    See Also
-    --------
-    ess : Calculate estimate of the effective sample size (ess).
-    mcse : Calculate Markov Chain Standard Error statistic.
-    plot_forest : Forest plot to compare HDI intervals from a number of distributions.
-
-    Notes
-    -----
-    The diagnostic is computed by:
-
-    .. math:: \hat{R} = \sqrt{\frac{\hat{V}}{W}}
-
-    where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
-    estimate for the pooled rank-traces. This is the potential scale reduction factor, which
-    converges to unity when each of the traces is a sample from the target posterior. Values
-    greater than one indicate that one or more chains have not yet converged.
-
-    Rank values are calculated over all the chains with ``scipy.stats.rankdata``.
-    Each chain is split in two and normalized with the z-transform following
-    Vehtari et al. (2021).
-
-    References
-    ----------
-    * Vehtari et al. (2021). Rank-normalization, folding, and
-      localization: An improved Rhat for assessing convergence of
-      MCMC. Bayesian analysis, 16(2):667-718.
-    * Gelman et al. BDA3 (2013)
-    * Brooks and Gelman (1998)
-    * Gelman and Rubin (1992)
-
-    Examples
-    --------
-    Calculate the R-hat using the default arguments:
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data("non_centered_eight")
-           ...: az.rhat(data)
-
-    Calculate the R-hat of some variables using the folded method:
-
-    .. ipython::
-
-        In [1]: az.rhat(data, var_names=["mu", "theta_t"], method="folded")
-
-    """
-    methods = {
-        "rank": _rhat_rank,
-        "split": _rhat_split,
-        "folded": _rhat_folded,
-        "z_scale": _rhat_z_scale,
-        "identity": _rhat_identity,
-    }
-    if method not in methods:
-        raise TypeError(
-            f"R-hat method {method} not found. Valid methods are:\n{', '.join(methods)}"
-        )
-    rhat_func = methods[method]
-
-    if isinstance(data, np.ndarray):
-        data = np.atleast_2d(data)
-        if len(data.shape) < 3:
-            return rhat_func(data)
-
-        msg = (
-            "Only uni-dimensional ndarray variables are supported."
-            " Please transform first to dataset with `az.convert_to_dataset`."
-        )
-        raise TypeError(msg)
-
-    dataset = convert_to_dataset(data, group="posterior")
-    var_names = _var_names(var_names, dataset)
-
-    dataset = dataset if var_names is None else dataset[var_names]
-
-    ufunc_kwargs = {"ravel": False}
-    func_kwargs = {}
-    return _wrap_xarray_ufunc(
-        rhat_func,
-        dataset,
-        ufunc_kwargs=ufunc_kwargs,
-        func_kwargs=func_kwargs,
-        dask_kwargs=dask_kwargs,
-    )
-
-
-def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
-    """Calculate Markov Chain Standard Error statistic.
-
-    Parameters
-    ----------
-    data : obj
-        Any object that can be converted to an :class:`arviz.InferenceData` object.
-        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-        For ndarray: shape = (chain, draw).
-        For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
-    var_names : list
-        Names of variables to include in the mcse report
-    method : str
-        Select mcse method. Valid methods are:
-        - "mean"
-        - "sd"
-        - "median"
-        - "quantile"
-
-    prob : float
-        Quantile information.
-    dask_kwargs : dict, optional
-        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
-
-    Returns
-    -------
-    xarray.Dataset
-        Return the mcse dataset
-
-    See Also
-    --------
-    ess : Calculate estimate of the effective sample size (ess).
-    summary : Create a data frame with summary statistics.
-    plot_mcse : Plot quantile or local Monte Carlo Standard Error.
-
-    Examples
-    --------
-    Calculate the Markov Chain Standard Error using the default arguments:
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: data = az.load_arviz_data("non_centered_eight")
-           ...: az.mcse(data)
-
-    Calculate the Markov Chain Standard Error using the quantile method:
-
-    .. ipython::
-
-        In [1]: az.mcse(data, method="quantile", prob=0.7)
-
-    """
-    methods = {
-        "mean": _mcse_mean,
-        "sd": _mcse_sd,
-        "median": _mcse_median,
-        "quantile": _mcse_quantile,
-    }
-    if method not in methods:
-        raise TypeError(
-            "mcse method {} not found. Valid methods are:\n{}".format(
-                method, "\n ".join(methods)
-            )
-        )
-    mcse_func = methods[method]
-
-    if method == "quantile" and prob is None:
-        raise TypeError("Quantile (prob) information needs to be defined.")
-
-    if isinstance(data, np.ndarray):
-        data = np.atleast_2d(data)
-        if len(data.shape) < 3:
-            if prob is not None:
-                return mcse_func(data, prob=prob)  # pylint: disable=unexpected-keyword-arg
-
-            return mcse_func(data)
-
-        msg = (
-            "Only uni-dimensional ndarray variables are supported."
-            " Please transform first to dataset with `az.convert_to_dataset`."
-        )
-        raise TypeError(msg)
-
-    dataset = convert_to_dataset(data, group="posterior")
-    var_names = _var_names(var_names, dataset)
-
-    dataset = dataset if var_names is None else dataset[var_names]
-
-    ufunc_kwargs = {"ravel": False}
-    func_kwargs = {} if prob is None else {"prob": prob}
-    return _wrap_xarray_ufunc(
-        mcse_func,
-        dataset,
-        ufunc_kwargs=ufunc_kwargs,
-        func_kwargs=func_kwargs,
-        dask_kwargs=dask_kwargs,
-    )
-
-
-def ks_summary(pareto_tail_indices):
-    """Display a summary of Pareto tail indices.
-
-    Parameters
-    ----------
-    pareto_tail_indices : array
-        Pareto tail indices.
-
-    Returns
-    -------
-    df_k : dataframe
-        Dataframe containing k diagnostic values.
-    """
-    _numba_flag = Numba.numba_flag
-    if _numba_flag:
-        bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
-        kcounts, *_ = _histogram(pareto_tail_indices, bins)
-    else:
-        kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.inf, 0.5, 0.7, 1, np.inf])
-    kprop = kcounts / len(pareto_tail_indices) * 100
-    df_k = pd.DataFrame(
-        dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
-    ).rename(index={0: "(-Inf, 0.5]", 1: " (0.5, 0.7]", 2: " (0.7, 1]", 3: " (1, Inf)"})
-
-    if np.sum(kcounts[1:]) == 0:
-        warnings.warn("All Pareto k estimates are good (k < 0.5)")
-    elif np.sum(kcounts[2:]) == 0:
-        warnings.warn("All Pareto k estimates are ok (k < 0.7)")
-
-    return df_k
-
-
-def _bfmi(energy):
-    r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
-
-    BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
-    information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
-    values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
-    change. See http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html for more
-    information.
-
-    Parameters
-    ----------
-    energy : NumPy array
-        Should be extracted from a gradient based sampler, such as in Stan or PyMC3. Typically,
-        after converting a trace or fit to InferenceData, the energy will be in
-        `data.sample_stats.energy`.
-
-    Returns
-    -------
-    z : array
-        The Bayesian fraction of missing information of the model and trace. One element per
-        chain in the trace.
-    """
-    energy_mat = np.atleast_2d(energy)
-    num = np.square(np.diff(energy_mat, axis=1)).mean(axis=1)  # pylint: disable=no-member
-    if energy_mat.ndim == 2:
-        den = _numba_var(svar, np.var, energy_mat, axis=1, ddof=1)
-    else:
-        den = np.var(energy, axis=1, ddof=1)
-    return num / den
-
-
-def _backtransform_ranks(arr, c=3 / 8):  # pylint: disable=invalid-name
-    """Backtransformation of ranks.
-
-    Parameters
-    ----------
-    arr : np.ndarray
-        Ranks array
-    c : float
-        Fractional offset. Defaults to c = 3/8 as recommended by Blom (1958).
-
-    Returns
-    -------
-    np.ndarray
-
-    References
-    ----------
-    Blom, G. (1958). Statistical Estimates and Transformed Beta-Variables. Wiley; New York.
-    """
-    arr = np.asarray(arr)
-    size = arr.size
-    return (arr - c) / (size - 2 * c + 1)
-
-
-def _z_scale(ary):
-    """Calculate z_scale.
-
-    Parameters
-    ----------
-    ary : np.ndarray
-
-    Returns
-    -------
-    np.ndarray
-    """
-    ary = np.asarray(ary)
-    if packaging.version.parse(scipy.__version__) < packaging.version.parse("1.10.0.dev0"):
-        rank = stats.rankdata(ary, method="average")
-    else:
-        # the .ravel part is only needed to overcome a bug in scipy 1.10.0.rc1
-        rank = stats.rankdata(  # pylint: disable=unexpected-keyword-arg
-            ary, method="average", nan_policy="omit"
-        )
-    rank = _backtransform_ranks(rank)
-    z = stats.norm.ppf(rank)
-    z = z.reshape(ary.shape)
-    return z
-
-
-def _split_chains(ary):
-    """Split and stack chains."""
-    ary = np.asarray(ary)
-    if len(ary.shape) <= 1:
-        ary = np.atleast_2d(ary)
-    _, n_draw = ary.shape
-    half = n_draw // 2
-    return _stack(ary[:, :half], ary[:, -half:])
-
-
-def _z_fold(ary):
-    """Fold and z-scale values."""
-    ary = np.asarray(ary)
-    ary = abs(ary - np.median(ary))
-    ary = _z_scale(ary)
-    return ary
-
-
-def _rhat(ary):
-    """Compute the rhat for a 2d array."""
-    _numba_flag = Numba.numba_flag
-    ary = np.asarray(ary, dtype=float)
-    if _not_valid(ary, check_shape=False):
-        return np.nan
-    _, num_samples = ary.shape
-
-    # Calculate chain mean
-    chain_mean = np.mean(ary, axis=1)
-    # Calculate chain variance
-    chain_var = _numba_var(svar, np.var, ary, axis=1, ddof=1)
-    # Calculate between-chain variance
-    between_chain_variance = num_samples * _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)
-    # Calculate within-chain variance
-    within_chain_variance = np.mean(chain_var)
-    # Estimate of marginal posterior variance
-    rhat_value = np.sqrt(
-        (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
-    )
-    return rhat_value
-
-
-def _rhat_rank(ary):
-    """Compute the rank normalized rhat for 2d array.
-
-    Computation follows https://arxiv.org/abs/1903.08008
-    """
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-        return np.nan
-    split_ary = _split_chains(ary)
-    rhat_bulk = _rhat(_z_scale(split_ary))
-
-    split_ary_folded = abs(split_ary - np.median(split_ary))
-    rhat_tail = _rhat(_z_scale(split_ary_folded))
-
-    rhat_rank = max(rhat_bulk, rhat_tail)
-    return rhat_rank
-
-
-def _rhat_folded(ary):
-    """Calculate split-Rhat for folded z-values."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-        return np.nan
-    ary = _z_fold(_split_chains(ary))
-    return _rhat(ary)
-
-
-def _rhat_z_scale(ary):
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-        return np.nan
-    return _rhat(_z_scale(_split_chains(ary)))
-
-
-def _rhat_split(ary):
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-        return np.nan
-    return _rhat(_split_chains(ary))
-
-
-def _rhat_identity(ary):
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-        return np.nan
-    return _rhat(ary)
-
-
-def _ess(ary, relative=False):
-    """Compute the effective sample size for a 2D array."""
-    _numba_flag = Numba.numba_flag
-    ary = np.asarray(ary, dtype=float)
-    if _not_valid(ary, check_shape=False):
-        return np.nan
-    if (np.max(ary) - np.min(ary)) < np.finfo(float).resolution:  # pylint: disable=no-member
-        return ary.size
-    if len(ary.shape) < 2:
-        ary = np.atleast_2d(ary)
-    n_chain, n_draw = ary.shape
-    acov = _autocov(ary, axis=1)
-    chain_mean = ary.mean(axis=1)
-    mean_var = np.mean(acov[:, 0]) * n_draw / (n_draw - 1.0)
-    var_plus = mean_var * (n_draw - 1.0) / n_draw
-    if n_chain > 1:
-        var_plus += _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)
-
-    rho_hat_t = np.zeros(n_draw)
-    rho_hat_even = 1.0
-    rho_hat_t[0] = rho_hat_even
-    rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, 1])) / var_plus
-    rho_hat_t[1] = rho_hat_odd
-
-    # Geyer's initial positive sequence
-    t = 1
-    while t < (n_draw - 3) and (rho_hat_even + rho_hat_odd) > 0.0:
-        rho_hat_even = 1.0 - (mean_var - np.mean(acov[:, t + 1])) / var_plus
-        rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, t + 2])) / var_plus
-        if (rho_hat_even + rho_hat_odd) >= 0:
-            rho_hat_t[t + 1] = rho_hat_even
-            rho_hat_t[t + 2] = rho_hat_odd
-        t += 2
-
-    max_t = t - 2
-    # improve estimation
-    if rho_hat_even > 0:
-        rho_hat_t[max_t + 1] = rho_hat_even
-    # Geyer's initial monotone sequence
-    t = 1
-    while t <= max_t - 2:
-        if (rho_hat_t[t + 1] + rho_hat_t[t + 2]) > (rho_hat_t[t - 1] + rho_hat_t[t]):
-            rho_hat_t[t + 1] = (rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0
-            rho_hat_t[t + 2] = rho_hat_t[t + 1]
-        t += 2
-
-    ess = n_chain * n_draw
-    tau_hat = -1.0 + 2.0 * np.sum(rho_hat_t[: max_t + 1]) + np.sum(rho_hat_t[max_t + 1 : max_t + 2])
-    tau_hat = max(tau_hat, 1 / np.log10(ess))
-    ess = (1 if relative else ess) / tau_hat
-    if np.isnan(rho_hat_t).any():
-        ess = np.nan
-    return ess
-
-
-def _ess_bulk(ary, relative=False):
-    """Compute the effective sample size for the bulk."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    z_scaled = _z_scale(_split_chains(ary))
-    ess_bulk = _ess(z_scaled, relative=relative)
-    return ess_bulk
-
-
-def _ess_tail(ary, prob=None, relative=False):
-    """Compute the effective sample size for the tail.
-
-    If `prob` defined, ess = min(qess(prob), qess(1-prob))
-    """
-    if prob is None:
-        prob = (0.05, 0.95)
-    elif not isinstance(prob, Sequence):
-        prob = (prob, 1 - prob)
-
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-
-    prob_low, prob_high = prob
-    quantile_low_ess = _ess_quantile(ary, prob_low, relative=relative)
-    quantile_high_ess = _ess_quantile(ary, prob_high, relative=relative)
-    return min(quantile_low_ess, quantile_high_ess)
-
-
-def _ess_mean(ary, relative=False):
-    """Compute the effective sample size for the mean."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    return _ess(_split_chains(ary), relative=relative)
-
-
-def _ess_sd(ary, relative=False):
-    """Compute the effective sample size for the sd."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    ary = (ary - ary.mean()) ** 2
-    return _ess(_split_chains(ary), relative=relative)
-
-
-def _ess_quantile(ary, prob, relative=False):
-    """Compute the effective sample size for the specific residual."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    if prob is None:
-        raise TypeError("Prob not defined.")
-    (quantile,) = _quantile(ary, prob)
-    iquantile = ary <= quantile
-    return _ess(_split_chains(iquantile), relative=relative)
-
-
-def _ess_local(ary, prob, relative=False):
-    """Compute the effective sample size for the specific residual."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    if prob is None:
-        raise TypeError("Prob not defined.")
-    if len(prob) != 2:
-        raise ValueError("Prob argument in ess local must be upper and lower bound")
-    quantile = _quantile(ary, prob)
-    iquantile = (quantile[0] <= ary) & (ary <= quantile[1])
-    return _ess(_split_chains(iquantile), relative=relative)
-
-
-def _ess_z_scale(ary, relative=False):
-    """Calculate ess for z-scale."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    return _ess(_z_scale(_split_chains(ary)), relative=relative)
-
-
-def _ess_folded(ary, relative=False):
-    """Calculate split-ess for folded data."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    return _ess(_z_fold(_split_chains(ary)), relative=relative)
-
-
-def _ess_median(ary, relative=False):
-    """Calculate split-ess for median."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    return _ess_quantile(ary, 0.5, relative=relative)
-
-
-def _ess_mad(ary, relative=False):
-    """Calculate split-ess for mean absolute deviance."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    ary = abs(ary - np.median(ary))
-    ary = ary <= np.median(ary)
-    ary = _z_scale(_split_chains(ary))
-    return _ess(ary, relative=relative)
-
-
-def _ess_identity(ary, relative=False):
-    """Calculate ess."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    return _ess(ary, relative=relative)
-
-
-def _mcse_mean(ary):
-    """Compute the Markov Chain mean error."""
-    _numba_flag = Numba.numba_flag
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    ess = _ess_mean(ary)
-    if _numba_flag:
-        sd = _sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1))
-    else:
-        sd = np.std(ary, ddof=1)
-    mcse_mean_value = sd / np.sqrt(ess)
-    return mcse_mean_value
-
-
-def _mcse_sd(ary):
-    """Compute the Markov Chain sd error."""
-    _numba_flag = Numba.numba_flag
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    sims_c2 = (ary - ary.mean()) ** 2
-    ess = _ess_mean(sims_c2)
-    evar = (sims_c2).mean()
-    varvar = ((sims_c2**2).mean() - evar**2) / ess
-    varsd = varvar / evar / 4
-    if _numba_flag:
-        mcse_sd_value = float(_sqrt(np.ravel(varsd), np.zeros(1)))
-    else:
-        mcse_sd_value = np.sqrt(varsd)
-    return mcse_sd_value
-
-
-def _mcse_median(ary):
-    """Compute the Markov Chain median error."""
-    return _mcse_quantile(ary, 0.5)
-
-
-def _mcse_quantile(ary, prob):
-    """Compute the Markov Chain quantile error at quantile=prob."""
-    ary = np.asarray(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        return np.nan
-    ess = _ess_quantile(ary, prob)
-    probability = [0.1586553, 0.8413447]
-    with np.errstate(invalid="ignore"):
-        ppf = stats.beta.ppf(probability, ess * prob + 1, ess * (1 - prob) + 1)
-    sorted_ary = np.sort(ary.ravel())
-    size = sorted_ary.size
-    ppf_size = ppf * size - 1
-    th1 = sorted_ary[int(np.floor(np.nanmax((ppf_size[0], 0))))]
-    th2 = sorted_ary[int(np.ceil(np.nanmin((ppf_size[1], size - 1))))]
-    return (th2 - th1) / 2
-
-
-def _mc_error(ary, batches=5, circular=False):
-    """Calculate the simulation standard error, accounting for non-independent samples.
-
-    The trace is divided into batches, and the standard deviation of the batch
-    means is calculated.
-
-    Parameters
-    ----------
-    ary : Numpy array
-        An array containing MCMC samples
-    batches : integer
-        Number of batches
-    circular : bool
-        Whether to compute the error taking into account `ary` is a circular variable
-        (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e. non-circular variables).
-
-    Returns
-    -------
-    mc_error : float
-        Simulation standard error
-    """
-    _numba_flag = Numba.numba_flag
-    if ary.ndim > 1:
-        dims = np.shape(ary)
-        trace = np.transpose([t.ravel() for t in ary])
-
-        return np.reshape([_mc_error(t, batches) for t in trace], dims[1:])
-
-    else:
-        if _not_valid(ary, check_shape=False):
-            return np.nan
-        if batches == 1:
-            if circular:
-                if _numba_flag:
-                    std = _circular_standard_deviation(ary, high=np.pi, low=-np.pi)
-                else:
-                    std = stats.circstd(ary, high=np.pi, low=-np.pi)
-            elif _numba_flag:
-                std = float(_sqrt(svar(ary), np.zeros(1)).item())
-            else:
-                std = np.std(ary)
-            return std / np.sqrt(len(ary))
-
-        batched_traces = np.resize(ary, (batches, int(len(ary) / batches)))
-
-        if circular:
-            means = stats.circmean(batched_traces, high=np.pi, low=-np.pi, axis=1)
-            if _numba_flag:
-                std = _circular_standard_deviation(means, high=np.pi, low=-np.pi)
-            else:
-                std = stats.circstd(means, high=np.pi, low=-np.pi)
-        else:
-            means = np.mean(batched_traces, 1)
-            std = _sqrt(svar(means), np.zeros(1)) if _numba_flag else np.std(means)
-        return std / np.sqrt(batches)
-
-
-def _multichain_statistics(ary, focus="mean"):
-    """Efficiently calculate multichain statistics for summary.
-
-    Parameters
-    ----------
-    ary : numpy.ndarray
-    focus : select focus for the statistics. Default is mean.
-
-    Returns
-    -------
-    tuple
-        Order of return parameters is
-            If focus equals "mean"
-                - mcse_mean, mcse_sd, ess_bulk, ess_tail, r_hat
-            Else if focus equals "median"
-                - mcse_median, ess_median, ess_tail, r_hat
-    """
-    ary = np.atleast_2d(ary)
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-        if focus == "mean":
-            return np.nan, np.nan, np.nan, np.nan, np.nan
-        return np.nan, np.nan, np.nan, np.nan
-
-    z_split = _z_scale(_split_chains(ary))
-
-    # ess tail
-    quantile05, quantile95 = _quantile(ary, [0.05, 0.95])
-    iquantile05 = ary <= quantile05
-    quantile05_ess = _ess(_split_chains(iquantile05))
-    iquantile95 = ary <= quantile95
-    quantile95_ess = _ess(_split_chains(iquantile95))
-    ess_tail_value = min(quantile05_ess, quantile95_ess)
-
-    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-        rhat_value = np.nan
-    else:
-        # r_hat
-        rhat_bulk = _rhat(z_split)
-        ary_folded = np.abs(ary - np.median(ary))
-        rhat_tail = _rhat(_z_scale(_split_chains(ary_folded)))
-        rhat_value = max(rhat_bulk, rhat_tail)
-
-    if focus == "mean":
-        # ess mean
-        ess_mean_value = _ess_mean(ary)
-
-        # mcse_mean
-        sims_c2 = (ary - ary.mean()) ** 2
-        sims_c2_sum = sims_c2.sum()
-        var = sims_c2_sum / (sims_c2.size - 1)
-        mcse_mean_value = np.sqrt(var / ess_mean_value)
-
-        # ess bulk
-        ess_bulk_value = _ess(z_split)
-
-        # mcse_sd
-        evar = sims_c2_sum / sims_c2.size
-        ess_mean_sims = _ess_mean(sims_c2)
-        varvar = ((sims_c2**2).mean() - evar**2) / ess_mean_sims
-        varsd = varvar / evar / 4
-        mcse_sd_value = np.sqrt(varsd)
-
-        return (
-            mcse_mean_value,
-            mcse_sd_value,
-            ess_bulk_value,
-            ess_tail_value,
-            rhat_value,
-        )
-
-    # ess median
-    ess_median_value = _ess_median(ary)
-
-    # mcse_median
-    mcse_median_value = _mcse_median(ary)
-
-    return (
-        mcse_median_value,
-        ess_median_value,
-        ess_tail_value,
-        rhat_value,
-    )
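
For context, bfmi, ess, rhat, and mcse above were the top-level diagnostics entry points of the 0.x series. Below is a minimal usage sketch, assuming an arviz 0.x install; per the deleted docstrings, plain ndarrays of shape (chain, draw) are accepted directly:

import numpy as np
import arviz as az  # assumes arviz < 1.0, where these entry points still exist

rng = np.random.default_rng(0)
draws = rng.normal(size=(4, 1000))  # 4 chains, 1000 draws each

print(az.ess(draws))                                # bulk ESS (default method)
print(az.ess(draws, method="tail"))                 # tail ESS; prob defaults to (0.05, 0.95)
print(az.rhat(draws))                               # rank-normalized split R-hat (default method)
print(az.mcse(draws, method="quantile", prob=0.7))  # MCSE of the 0.7 quantile
print(az.bfmi(draws))                               # ndarray input is treated as per-chain energy

With InferenceData input these functions instead return an xarray.Dataset with one value per variable, as the docstrings describe.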
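The split-R-hat formula quoted in the rhat docstring can also be checked with plain NumPy. The sketch below mirrors only the arithmetic of the deleted _rhat and _split_chains helpers; it omits the rank-normalization, folding, and validity checks that the default "rank" method layers on top:

import numpy as np

def split_rhat(samples):
    # samples: array of shape (n_chains, n_draws)
    # Split each chain in half and stack the halves, doubling the chain count,
    # as _split_chains does above.
    _, n_draw = samples.shape
    half = n_draw // 2
    split = np.vstack((samples[:, :half], samples[:, -half:]))
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()     # W: mean within-chain variance
    between = n * split.mean(axis=1).var(ddof=1)  # B: n * variance of chain means
    # V-hat = (n - 1)/n * W + B/n, so R-hat = sqrt(V-hat / W), matching _rhat above
    return np.sqrt((between / within + n - 1) / n)

rng = np.random.default_rng(0)
print(split_rhat(rng.normal(size=(4, 1000))))  # well-mixed chains give values close to 1.0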