arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/stats/stats_refitting.py
DELETED
|
@@ -1,119 +0,0 @@
|
|
|
1
|
-
"""Stats functions that require refitting the model."""
|
|
2
|
-
|
|
3
|
-
import logging
|
|
4
|
-
import warnings
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
|
|
8
|
-
from .stats import loo
|
|
9
|
-
from .stats_utils import logsumexp as _logsumexp
|
|
10
|
-
|
|
11
|
-
# Names exported by ``from ... import *``; ``reloo`` is the only public function here.
__all__ = ["reloo"]

# Module-level logger; ``reloo`` reports refitting progress through it at INFO level.
_log = logging.getLogger(__name__)
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def reloo(wrapper, loo_orig=None, k_thresh=0.7, scale=None, verbose=True):
    """Recalculate exact Leave-One-Out cross validation refitting where the approximation fails.

    ``az.loo`` estimates the values of Leave-One-Out (LOO) cross validation using Pareto
    Smoothed Importance Sampling (PSIS) to approximate its value. PSIS works well when
    the posterior and the posterior_i (excluding observation i from the data used to fit)
    are similar. In some cases, there are highly influential observations for which PSIS
    cannot approximate the LOO-CV, and a warning of a large Pareto shape is sent by ArviZ.
    These cases typically have a handful of bad or very bad Pareto shapes and a majority of
    good or ok shapes.

    Therefore, this may not indicate that the model is not robust enough
    nor that these observations are inherently bad, only that PSIS cannot approximate LOO-CV
    correctly. Thus, we can use PSIS for all observations where the Pareto shape is below a
    threshold and refit the model to perform exact cross validation for the handful of
    observations where PSIS cannot be used. This approach allows one to properly approximate
    LOO-CV with only a handful of refits, which in most cases is still much less computationally
    expensive than exact LOO-CV, which needs one refit per observation.

    Parameters
    ----------
    wrapper : SamplingWrapper-like
        Class (preferably a subclass of ``az.SamplingWrapper``, see :ref:`wrappers_api`
        for details) implementing the methods described
        in the SamplingWrapper docs. This allows ArviZ to call **any** sampling backend
        (like PyStan or emcee) using always the same syntax.
    loo_orig : ELPDData, optional
        ELPDData instance with pointwise loo results. The pareto_k attribute will be checked
        for values above the threshold. When None, it is computed with
        ``loo(wrapper.idata_orig, pointwise=True, scale=scale)``.
    k_thresh : float, optional
        Pareto shape threshold. Each pareto shape value above ``k_thresh`` will trigger
        a refit excluding that observation. Defaults to 0.7.
    scale : str, optional
        Only taken into account when loo_orig is None. See ``az.loo`` for valid options.
    verbose : bool, optional
        If True (default), emit a ``UserWarning`` flagging ``reloo`` as experimental and
        log, at INFO level, every observation that is refitted.

    Returns
    -------
    ELPDData
        ELPDData instance containing the PSIS approximation where possible and the exact
        LOO-CV result where PSIS failed. The Pareto shape of the observations where exact
        LOO-CV was performed is artificially set to 0, but as PSIS is not performed, it
        should be ignored. If no Pareto shape exceeds ``k_thresh``, ``loo_orig`` itself
        is returned.

    Raises
    ------
    TypeError
        If ``wrapper`` does not implement ``sel_observations``, ``sample``,
        ``get_inference_data`` and ``log_likelihood__i``.
    ValueError
        If the scale stored in ``loo_orig`` is not one of ``"log"``, ``"negative_log"``
        or ``"deviance"``.

    Notes
    -----
    It is strongly recommended to first compute ``az.loo`` on the inference results to
    confirm that the number of values above the threshold is small enough. Otherwise,
    prohibitive computation time may be needed to perform all required refits.

    As an extreme case, artificially assigning all ``pareto_k`` values to something
    larger than the threshold would make ``reloo`` perform the whole exact LOO-CV.
    This is not generally recommended
    nor intended, however, if needed, this function can be used to achieve the result.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """
    required_methods = ("sel_observations", "sample", "get_inference_data", "log_likelihood__i")
    not_implemented = wrapper.check_implemented_methods(required_methods)
    if not_implemented:
        raise TypeError(
            "Passed wrapper instance does not implement all methods required for reloo "
            f"to work. Check the documentation of SamplingWrapper. {not_implemented} must be "
            "implemented and were not found."
        )
    if loo_orig is None:
        loo_orig = loo(wrapper.idata_orig, pointwise=True, scale=scale)
    loo_refitted = loo_orig.copy()
    khats = loo_refitted.pareto_k
    loo_i = loo_refitted.loo_i
    scale = loo_orig.scale

    # Multiplier that maps a log predictive density onto the scale used by ``loo_orig``.
    scale_values = {"deviance": -2, "log": 1, "negative_log": -1}
    scale_value = scale_values.get(scale.lower())
    if scale_value is None:
        # Previously this fell through and crashed later with UnboundLocalError.
        raise ValueError(
            f"Unrecognized scale {scale!r} in loo_orig; valid options are "
            "'log', 'negative_log' and 'deviance'."
        )
    # Recover the lppd so p_loo can be recomputed after elpd_loo changes.
    lppd_orig = loo_orig.p_loo + loo_orig.elpd_loo / scale_value
    n_data_points = loo_orig.n_data_points

    if verbose:
        warnings.warn("reloo is an experimental and untested feature", UserWarning)

    if not np.any(khats > k_thresh):
        _log.info("No problematic observations")
        return loo_orig

    # ``np.argwhere`` yields one row of positional coordinates per flagged observation,
    # with one entry per dimension of ``pareto_k``.
    for idx in np.argwhere(khats.values > k_thresh):
        # Index with a tuple so every dimension is addressed at once; indexing with the
        # raw coordinate array would select several positions along the first dimension
        # whenever the observations span more than one dimension.
        obs_position = tuple(int(coord) for coord in idx)
        if verbose:
            _log.info(
                "Refitting model excluding observation %s",
                obs_position[0] if len(obs_position) == 1 else obs_position,
            )
        # The wrapper keeps receiving the raw coordinate array, as before.
        new_obs, excluded_obs = wrapper.sel_observations(idx)
        fit = wrapper.sample(new_obs)
        idata_idx = wrapper.get_inference_data(fit)
        log_like_idx = wrapper.log_likelihood__i(excluded_obs, idata_idx).values.flatten()
        # Exact LOO predictive density: log of the mean likelihood over refit draws.
        loo_lppd_idx = scale_value * _logsumexp(log_like_idx, b_inv=len(log_like_idx))
        khats[obs_position] = 0
        loo_i[obs_position] = loo_lppd_idx
    loo_refitted.elpd_loo = loo_i.values.sum()
    loo_refitted.se = (n_data_points * np.var(loo_i.values)) ** 0.5
    loo_refitted.p_loo = lppd_orig - loo_refitted.elpd_loo / scale_value
    return loo_refitted
|