arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
@@ -1,119 +0,0 @@
1
- """Stats functions that require refitting the model."""
2
-
3
- import logging
4
- import warnings
5
-
6
- import numpy as np
7
-
8
- from .stats import loo
9
- from .stats_utils import logsumexp as _logsumexp
10
-
11
- __all__ = ["reloo"]
12
-
13
- _log = logging.getLogger(__name__)
14
-
15
-
16
def reloo(wrapper, loo_orig=None, k_thresh=0.7, scale=None, verbose=True):
    """Recalculate exact Leave-One-Out cross validation refitting where the approximation fails.

    ``az.loo`` estimates the values of Leave-One-Out (LOO) cross validation using Pareto
    Smoothed Importance Sampling (PSIS) to approximate its value. PSIS works well when
    the posterior and the posterior_i (excluding observation i from the data used to fit)
    are similar. In some cases, there are highly influential observations for which PSIS
    cannot approximate the LOO-CV, and a warning of a large Pareto shape is sent by ArviZ.
    These cases typically have a handful of bad or very bad Pareto shapes and a majority of
    good or ok shapes.

    Therefore, this may not indicate that the model is not robust enough
    nor that these observations are inherently bad, only that PSIS cannot approximate LOO-CV
    correctly. Thus, we can use PSIS for all observations where the Pareto shape is below a
    threshold and refit the model to perform exact cross validation for the handful of
    observations where PSIS cannot be used. This approach makes it possible to properly
    approximate LOO-CV with only a handful of refits, which in most cases is still much less
    computationally expensive than exact LOO-CV, which needs one refit per observation.

    Parameters
    ----------
    wrapper: SamplingWrapper-like
        Class (preferably a subclass of ``az.SamplingWrapper``, see :ref:`wrappers_api`
        for details) implementing the methods described
        in the SamplingWrapper docs. This allows ArviZ to call **any** sampling backend
        (like PyStan or emcee) using always the same syntax.
    loo_orig : ELPDData, optional
        ELPDData instance with pointwise loo results. The pareto_k attribute will be checked
        for values above the threshold. When omitted, it is computed by calling ``loo`` on
        ``wrapper.idata_orig`` with ``pointwise=True``.
    k_thresh : float, optional
        Pareto shape threshold. Each pareto shape value above ``k_thresh`` will trigger
        a refit excluding that observation.
    scale : str, optional
        Only taken into account when loo_orig is None. See ``az.loo`` for valid options.
    verbose : bool, optional
        If True, emit a ``UserWarning`` stating that ``reloo`` is experimental and log
        (at INFO level) every refit that is performed.

    Returns
    -------
    ELPDData
        ELPDData instance containing the PSIS approximation where possible and the exact
        LOO-CV result where PSIS failed. The Pareto shape of the observations where exact
        LOO-CV was performed is artificially set to 0, but as PSIS is not performed, it
        should be ignored. When no Pareto shape exceeds ``k_thresh`` the original
        ``loo_orig`` object is returned unchanged.

    Raises
    ------
    TypeError
        If ``wrapper`` does not implement every method ``reloo`` needs.
    ValueError
        If the scale stored in ``loo_orig`` is not one of ``"deviance"``, ``"log"`` or
        ``"negative_log"`` (case-insensitive).

    Notes
    -----
    It is strongly recommended to first compute ``az.loo`` on the inference results to
    confirm that the number of values above the threshold is small enough. Otherwise,
    prohibitive computation time may be needed to perform all required refits.

    As an extreme case, artificially assigning all ``pareto_k`` values to something
    larger than the threshold would make ``reloo`` perform the whole exact LOO-CV.
    This is not generally recommended
    nor intended, however, if needed, this function can be used to achieve the result.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """
    required_methods = ("sel_observations", "sample", "get_inference_data", "log_likelihood__i")
    not_implemented = wrapper.check_implemented_methods(required_methods)
    if not_implemented:
        raise TypeError(
            "Passed wrapper instance does not implement all methods required for reloo "
            f"to work. Check the documentation of SamplingWrapper. {not_implemented} must be "
            "implemented and were not found."
        )
    if loo_orig is None:
        loo_orig = loo(wrapper.idata_orig, pointwise=True, scale=scale)

    # The scale actually used is the one recorded in the LOO result, regardless of the
    # ``scale`` argument. Validate it up front so an unknown value fails with a clear
    # message instead of an unbound ``scale_value`` further down.
    scale = loo_orig.scale
    scale_factors = {"deviance": -2, "log": 1, "negative_log": -1}
    try:
        scale_value = scale_factors[scale.lower()]
    except KeyError:
        raise ValueError(
            f"Unrecognized scale {scale!r}. Valid scale values are "
            '"deviance", "log", "negative_log"'
        ) from None

    # Work on a copy so the caller's LOO result is not modified by the refits.
    loo_refitted = loo_orig.copy()
    khats = loo_refitted.pareto_k
    loo_i = loo_refitted.loo_i

    lppd_orig = loo_orig.p_loo + loo_orig.elpd_loo / scale_value
    n_data_points = loo_orig.n_data_points

    if verbose:
        warnings.warn("reloo is an experimental and untested feature", UserWarning)

    if not np.any(khats > k_thresh):
        _log.info("No problematic observations")
        return loo_orig

    for idx in np.argwhere(khats.values > k_thresh):
        if verbose:
            _log.info("Refitting model excluding observation %d", idx)
        # Refit without the problematic observation and evaluate its log likelihood
        # under the refitted posterior.
        new_obs, excluded_obs = wrapper.sel_observations(idx)
        fit = wrapper.sample(new_obs)
        idata_idx = wrapper.get_inference_data(fit)
        log_like_idx = wrapper.log_likelihood__i(excluded_obs, idata_idx).values.flatten()
        loo_lppd_idx = scale_value * _logsumexp(log_like_idx, b_inv=len(log_like_idx))
        # PSIS was not used for this observation; 0 is a placeholder Pareto shape.
        khats[idx] = 0
        loo_i[idx] = loo_lppd_idx

    # Recompute the aggregate statistics from the updated pointwise values.
    loo_refitted.elpd_loo = loo_i.values.sum()
    loo_refitted.se = (n_data_points * np.var(loo_i.values)) ** 0.5
    loo_refitted.p_loo = lppd_orig - loo_refitted.elpd_loo / scale_value
    return loo_refitted