arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
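The files removed in 1.0.0rc0 listed above include the data converters, plotting backends and statistics modules that backed the 0.23.3 top-level API. For orientation before the per-file diffs, the short sketch below collects the usage shown in the docstring examples of the deleted arviz/stats/stats.py; it is a sketch that assumes the old 0.23.3 wheel (the left-hand side of this diff), not a guide to the 1.0.0rc0 API.

import arviz as az

# Calls taken from the docstring examples of the deleted arviz/stats/stats.py,
# run against the arviz 0.23.3 wheel.
centered = az.load_arviz_data("centered_eight")
non_centered = az.load_arviz_data("non_centered_eight")

az.compare({"non centered": non_centered, "centered": centered})  # ELPD-based model comparison
az.hdi(centered, var_names=["mu", "theta"])                        # highest density intervals
az.loo(centered, pointwise=True)                                   # PSIS-LOO-CV
az.summary(centered, var_names=["mu", "tau"])                      # summary statistics table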
arviz/stats/stats.py DELETED
@@ -1,2422 +0,0 @@
1
- # pylint: disable=too-many-lines
2
- """Statistical functions in ArviZ."""
3
-
4
- import warnings
5
- from copy import deepcopy
6
- from typing import List, Optional, Tuple, Union, Mapping, cast, Callable
7
-
8
- import numpy as np
9
- import pandas as pd
10
- import scipy.stats as st
11
- from xarray_einstats import stats
12
- import xarray as xr
13
- from scipy.optimize import minimize, LinearConstraint, Bounds
14
- from typing_extensions import Literal
15
-
16
- NO_GET_ARGS: bool = False # pylint: disable=invalid-name
17
- try:
18
- from typing_extensions import get_args
19
- except ImportError:
20
- NO_GET_ARGS = True # pylint: disable=invalid-name
21
-
22
- from .. import _log
23
- from ..data import InferenceData, convert_to_dataset, convert_to_inference_data, extract
24
- from ..rcparams import rcParams, ScaleKeyword, ICKeyword
25
- from ..utils import Numba, _numba_var, _var_names, get_coords
26
- from .density_utils import get_bins as _get_bins
27
- from .density_utils import histogram as _histogram
28
- from .density_utils import kde as _kde
29
- from .density_utils import _kde_linear
30
- from .diagnostics import _mc_error, _multichain_statistics, ess
31
- from .stats_utils import ELPDData, _circular_standard_deviation, smooth_data
32
- from .stats_utils import get_log_likelihood as _get_log_likelihood
33
- from .stats_utils import get_log_prior as _get_log_prior
34
- from .stats_utils import logsumexp as _logsumexp
35
- from .stats_utils import make_ufunc as _make_ufunc
36
- from .stats_utils import stats_variance_2d as svar
37
- from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc
38
- from ..sel_utils import xarray_var_iter
39
- from ..labels import BaseLabeller
40
-
41
-
42
- __all__ = [
43
- "apply_test_function",
44
- "bayes_factor",
45
- "compare",
46
- "hdi",
47
- "loo",
48
- "loo_pit",
49
- "psislw",
50
- "r2_samples",
51
- "r2_score",
52
- "summary",
53
- "waic",
54
- "weight_predictions",
55
- "_calculate_ics",
56
- "psens",
57
- ]
58
-
59
-
60
- def compare(
61
- compare_dict: Mapping[str, InferenceData],
62
- ic: Optional[ICKeyword] = None,
63
- method: Literal["stacking", "BB-pseudo-BMA", "pseudo-BMA"] = "stacking",
64
- b_samples: int = 1000,
65
- alpha: float = 1,
66
- seed=None,
67
- scale: Optional[ScaleKeyword] = None,
68
- var_name: Optional[str] = None,
69
- ):
70
- r"""Compare models based on their expected log pointwise predictive density (ELPD).
71
-
72
- The ELPD is estimated either by Pareto smoothed importance sampling leave-one-out
73
- cross-validation (LOO) or using the widely applicable information criterion (WAIC).
74
- We recommend loo. Read more theory here - in a paper by some of the
75
- leading authorities on model comparison dx.doi.org/10.1111/1467-9868.00353
76
-
77
- Parameters
78
- ----------
79
- compare_dict: dict of {str: InferenceData or ELPDData}
80
- A dictionary of model names and :class:`arviz.InferenceData` or ``ELPDData``.
81
- ic: str, optional
82
- Method to estimate the ELPD, available options are "loo" or "waic". Defaults to
83
- ``rcParams["stats.information_criterion"]``.
84
- method: str, optional
85
- Method used to estimate the weights for each model. Available options are:
86
-
87
- - 'stacking' : stacking of predictive distributions.
88
- - 'BB-pseudo-BMA' : pseudo-Bayesian Model averaging using Akaike-type
89
- weighting. The weights are stabilized using the Bayesian bootstrap.
90
- - 'pseudo-BMA': pseudo-Bayesian Model averaging using Akaike-type
91
- weighting, without Bootstrap stabilization (not recommended).
92
-
93
- For more information read https://arxiv.org/abs/1704.02030
94
- b_samples: int, optional default = 1000
95
- Number of samples taken by the Bayesian bootstrap estimation.
96
- Only useful when method = 'BB-pseudo-BMA'.
97
- Defaults to ``rcParams["stats.ic_compare_method"]``.
98
- alpha: float, optional
99
- The shape parameter in the Dirichlet distribution used for the Bayesian bootstrap. Only
100
- useful when method = 'BB-pseudo-BMA'. When alpha=1 (default), the distribution is uniform
101
- on the simplex. A smaller alpha keeps the final weights further away from 0 and 1.
102
- seed: int or np.random.RandomState instance, optional
103
- If int or RandomState, use it for seeding Bayesian bootstrap. Only
104
- useful when method = 'BB-pseudo-BMA'. If None (default), the global
105
- :mod:`numpy.random` state is used.
106
- scale: str, optional
107
- Output scale for IC. Available options are:
108
-
109
- - `log` : (default) log-score (after Vehtari et al. (2017))
110
- - `negative_log` : -1 * (log-score)
111
- - `deviance` : -2 * (log-score)
112
-
113
- A higher log-score (or a lower deviance) indicates a model with better predictive
114
- accuracy.
115
- var_name: str, optional
116
- If there is more than a single observed variable in the ``InferenceData``, which
117
- should be used as the basis for comparison.
118
-
119
- Returns
120
- -------
121
- A DataFrame, ordered from best to worst model (measured by the ELPD).
122
- The index reflects the key with which the models are passed to this function. The columns are:
123
- rank: The rank-order of the models. 0 is the best.
124
- elpd: ELPD estimated either using (PSIS-LOO-CV `elpd_loo` or WAIC `elpd_waic`).
125
- Higher ELPD indicates higher out-of-sample predictive fit ("better" model).
126
- If `scale` is `deviance` or `negative_log`, smaller values indicate
127
- higher out-of-sample predictive fit ("better" model).
128
- pIC: Estimated effective number of parameters.
129
- elpd_diff: The difference in ELPD between two models.
130
- If more than two models are compared, the difference is computed relative to the
131
- top-ranked model, which always has an elpd_diff of 0.
132
- weight: Relative weight for each model.
133
- This can be loosely interpreted as the probability of each model (among the compared models)
134
- given the data. By default the uncertainty in the weights estimation is considered using
135
- Bayesian bootstrap.
136
- SE: Standard error of the ELPD estimate.
137
- If method = BB-pseudo-BMA these values are estimated using Bayesian bootstrap.
138
- dSE: Standard error of the difference in ELPD between each model and the top-ranked model.
139
- It's always 0 for the top-ranked model.
140
- warning: A value of 1 indicates that the computation of the ELPD may not be reliable.
141
- This could be an indication of WAIC/LOO starting to fail; see
142
- http://arxiv.org/abs/1507.04544 for details.
143
- scale: Scale used for the ELPD.
144
-
145
- Examples
146
- --------
147
- Compare the centered and non-centered models of the eight schools problem:
148
-
149
- .. ipython::
150
- :okwarning:
151
-
152
- In [1]: import arviz as az
153
- ...: data1 = az.load_arviz_data("non_centered_eight")
154
- ...: data2 = az.load_arviz_data("centered_eight")
155
- ...: compare_dict = {"non centered": data1, "centered": data2}
156
- ...: az.compare(compare_dict)
157
-
158
- Compare the models using PSIS-LOO-CV, returning the ELPD in log scale and calculating the
159
- weights using the stacking method.
160
-
161
- .. ipython::
162
- :okwarning:
163
-
164
- In [1]: az.compare(compare_dict, ic="loo", method="stacking", scale="log")
165
-
166
- See Also
167
- --------
168
- loo :
169
- Compute the ELPD using the Pareto smoothed importance sampling Leave-one-out
170
- cross-validation method.
171
- waic : Compute the ELPD using the widely applicable information criterion.
172
- plot_compare : Summary plot for model comparison.
173
-
174
- References
175
- ----------
176
- .. [1] Vehtari, A., Gelman, A. & Gabry, J. Practical Bayesian model evaluation using
177
- leave-one-out cross-validation and WAIC. Stat Comput 27, 1413–1432 (2017)
178
- see https://doi.org/10.1007/s11222-016-9696-4
179
-
180
- """
181
- try:
182
- (ics_dict, scale, ic) = _calculate_ics(compare_dict, scale=scale, ic=ic, var_name=var_name)
183
- except Exception as e:
184
- raise e.__class__("Encountered error in ELPD computation of compare.") from e
185
- names = list(ics_dict.keys())
186
- if ic in {"loo", "waic"}:
187
- df_comp = pd.DataFrame(
188
- {
189
- "rank": pd.Series(index=names, dtype="int"),
190
- f"elpd_{ic}": pd.Series(index=names, dtype="float"),
191
- f"p_{ic}": pd.Series(index=names, dtype="float"),
192
- "elpd_diff": pd.Series(index=names, dtype="float"),
193
- "weight": pd.Series(index=names, dtype="float"),
194
- "se": pd.Series(index=names, dtype="float"),
195
- "dse": pd.Series(index=names, dtype="float"),
196
- "warning": pd.Series(index=names, dtype="boolean"),
197
- "scale": pd.Series(index=names, dtype="str"),
198
- }
199
- )
200
- else:
201
- raise NotImplementedError(f"The information criterion {ic} is not supported.")
202
-
203
- if scale == "log":
204
- scale_value = 1
205
- ascending = False
206
- else:
207
- if scale == "negative_log":
208
- scale_value = -1
209
- else:
210
- scale_value = -2
211
- ascending = True
212
-
213
- method = rcParams["stats.ic_compare_method"] if method is None else method
214
- if method.lower() not in ["stacking", "bb-pseudo-bma", "pseudo-bma"]:
215
- raise ValueError(f"The method {method}, to compute weights, is not supported.")
216
-
217
- p_ic = f"p_{ic}"
218
- ic_i = f"{ic}_i"
219
-
220
- ics = pd.DataFrame.from_dict(ics_dict, orient="index")
221
- ics.sort_values(by=f"elpd_{ic}", inplace=True, ascending=ascending)
222
- ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten())
223
-
224
- if method.lower() == "stacking":
225
- rows, cols, ic_i_val = _ic_matrix(ics, ic_i)
226
- exp_ic_i = np.exp(ic_i_val / scale_value)
227
-
228
- def log_score(weights):
229
- return -np.sum(np.log(exp_ic_i @ weights))
230
-
231
- def gradient(weights):
232
- denominator = exp_ic_i @ weights
233
- return -np.sum(exp_ic_i / denominator[:, np.newaxis], axis=0)
234
-
235
- theta = np.full(cols, 1.0 / cols)
236
- bounds = Bounds(lb=np.zeros(cols), ub=np.ones(cols))
237
- constraints = LinearConstraint(np.ones(cols), lb=1.0, ub=1.0)
238
-
239
- minimize_result = minimize(
240
- fun=log_score, x0=theta, jac=gradient, bounds=bounds, constraints=constraints
241
- )
242
-
243
- weights = minimize_result["x"]
244
- ses = ics["se"]
245
-
246
- elif method.lower() == "bb-pseudo-bma":
247
- rows, cols, ic_i_val = _ic_matrix(ics, ic_i)
248
- ic_i_val = ic_i_val * rows
249
-
250
- b_weighting = st.dirichlet.rvs(alpha=[alpha] * rows, size=b_samples, random_state=seed)
251
- weights = np.zeros((b_samples, cols))
252
- z_bs = np.zeros_like(weights)
253
- for i in range(b_samples):
254
- z_b = np.dot(b_weighting[i], ic_i_val)
255
- u_weights = np.exp((z_b - np.max(z_b)) / scale_value)
256
- z_bs[i] = z_b # pylint: disable=unsupported-assignment-operation
257
- weights[i] = u_weights / np.sum(u_weights)
258
-
259
- weights = weights.mean(axis=0)
260
- ses = pd.Series(z_bs.std(axis=0), index=ics.index) # pylint: disable=no-member
261
-
262
- elif method.lower() == "pseudo-bma":
263
- min_ic = ics.iloc[0][f"elpd_{ic}"]
264
- z_rv = np.exp((ics[f"elpd_{ic}"] - min_ic) / scale_value)
265
- weights = (z_rv / np.sum(z_rv)).to_numpy()
266
- ses = ics["se"]
267
-
268
- if np.any(weights):
269
- min_ic_i_val = ics[ic_i].iloc[0]
270
- for idx, val in enumerate(ics.index):
271
- res = ics.loc[val]
272
- if scale_value < 0:
273
- diff = res[ic_i] - min_ic_i_val
274
- else:
275
- diff = min_ic_i_val - res[ic_i]
276
- d_ic = np.sum(diff)
277
- d_std_err = np.sqrt(len(diff) * np.var(diff))
278
- std_err = ses.loc[val]
279
- weight = weights[idx]
280
- df_comp.loc[val] = (
281
- idx,
282
- res[f"elpd_{ic}"],
283
- res[p_ic],
284
- d_ic,
285
- weight,
286
- std_err,
287
- d_std_err,
288
- res["warning"],
289
- res["scale"],
290
- )
291
-
292
- df_comp["rank"] = df_comp["rank"].astype(int)
293
- df_comp["warning"] = df_comp["warning"].astype(bool)
294
- return df_comp.sort_values(by=f"elpd_{ic}", ascending=ascending)
295
-
296
-
297
- def _ic_matrix(ics, ic_i):
298
- """Store the previously computed pointwise predictive accuracy values (ics) in a 2D matrix."""
299
- cols, _ = ics.shape
300
- rows = len(ics[ic_i].iloc[0])
301
- ic_i_val = np.zeros((rows, cols))
302
-
303
- for idx, val in enumerate(ics.index):
304
- ic = ics.loc[val][ic_i]
305
-
306
- if len(ic) != rows:
307
- raise ValueError("The number of observations should be the same across all models")
308
-
309
- ic_i_val[:, idx] = ic
310
-
311
- return rows, cols, ic_i_val
312
-
313
-
314
- def _calculate_ics(
315
- compare_dict,
316
- scale: Optional[ScaleKeyword] = None,
317
- ic: Optional[ICKeyword] = None,
318
- var_name: Optional[str] = None,
319
- ):
320
- """Calculate LOO or WAIC only if necessary.
321
-
322
- It always calls the ic function with ``pointwise=True``.
323
-
324
- Parameters
325
- ----------
326
- compare_dict : dict of {str : InferenceData or ELPDData}
327
- A dictionary of model names and InferenceData or ELPDData objects
328
- scale : str, optional
329
- Output scale for IC. Available options are:
330
-
331
- - `log` : (default) log-score (after Vehtari et al. (2017))
332
- - `negative_log` : -1 * (log-score)
333
- - `deviance` : -2 * (log-score)
334
-
335
- A higher log-score (or a lower deviance) indicates a model with better predictive accuracy.
336
- ic : str, optional
337
- Information Criterion (PSIS-LOO `loo` or WAIC `waic`) used to compare models.
338
- Defaults to ``rcParams["stats.information_criterion"]``.
339
- var_name : str, optional
340
- Name of the variable storing pointwise log likelihood values in ``log_likelihood`` group.
341
-
342
-
343
- Returns
344
- -------
345
- compare_dict : dict of ELPDData
346
- scale : str
347
- ic : str
348
-
349
- """
350
- precomputed_elpds = {
351
- name: elpd_data
352
- for name, elpd_data in compare_dict.items()
353
- if isinstance(elpd_data, ELPDData)
354
- }
355
- precomputed_ic = None
356
- precomputed_scale = None
357
- if precomputed_elpds:
358
- _, arbitrary_elpd = precomputed_elpds.popitem()
359
- precomputed_ic = arbitrary_elpd.index[0].split("_")[1]
360
- precomputed_scale = arbitrary_elpd["scale"]
361
- raise_non_pointwise = f"{precomputed_ic}_i" not in arbitrary_elpd
362
- if any(
363
- elpd_data.index[0].split("_")[1] != precomputed_ic
364
- for elpd_data in precomputed_elpds.values()
365
- ):
366
- raise ValueError(
367
- "All information criteria to be compared must be the same "
368
- "but found both loo and waic."
369
- )
370
- if any(elpd_data["scale"] != precomputed_scale for elpd_data in precomputed_elpds.values()):
371
- raise ValueError("All information criteria to be compared must use the same scale")
372
- if (
373
- any(f"{precomputed_ic}_i" not in elpd_data for elpd_data in precomputed_elpds.values())
374
- or raise_non_pointwise
375
- ):
376
- raise ValueError("Not all provided ELPDData have been calculated with pointwise=True")
377
- if ic is not None and ic.lower() != precomputed_ic:
378
- warnings.warn(
379
- "Provided ic argument is incompatible with precomputed elpd data. "
380
- f"Using ic from precomputed elpddata: {precomputed_ic}"
381
- )
382
- ic = precomputed_ic
383
- if scale is not None and scale.lower() != precomputed_scale:
384
- warnings.warn(
385
- "Provided scale argument is incompatible with precomputed elpd data. "
386
- f"Using scale from precomputed elpddata: {precomputed_scale}"
387
- )
388
- scale = precomputed_scale
389
-
390
- if ic is None and precomputed_ic is None:
391
- ic = cast(ICKeyword, rcParams["stats.information_criterion"])
392
- elif ic is None:
393
- ic = precomputed_ic
394
- else:
395
- ic = cast(ICKeyword, ic.lower())
396
- allowable = ["loo", "waic"] if NO_GET_ARGS else get_args(ICKeyword)
397
- if ic not in allowable:
398
- raise ValueError(f"{ic} is not a valid value for ic: must be in {allowable}")
399
-
400
- if scale is None and precomputed_scale is None:
401
- scale = cast(ScaleKeyword, rcParams["stats.ic_scale"])
402
- elif scale is None:
403
- scale = precomputed_scale
404
- else:
405
- scale = cast(ScaleKeyword, scale.lower())
406
- allowable = ["log", "negative_log", "deviance"] if NO_GET_ARGS else get_args(ScaleKeyword)
407
- if scale not in allowable:
408
- raise ValueError(f"{scale} is not a valid value for scale: must be in {allowable}")
409
-
410
- if ic == "loo":
411
- ic_func: Callable = loo
412
- elif ic == "waic":
413
- ic_func = waic
414
- else:
415
- raise NotImplementedError(f"The information criterion {ic} is not supported.")
416
-
417
- compare_dict = deepcopy(compare_dict)
418
- for name, dataset in compare_dict.items():
419
- if not isinstance(dataset, ELPDData):
420
- try:
421
- compare_dict[name] = ic_func(
422
- convert_to_inference_data(dataset),
423
- pointwise=True,
424
- scale=scale,
425
- var_name=var_name,
426
- )
427
- except Exception as e:
428
- raise e.__class__(
429
- f"Encountered error trying to compute {ic} from model {name}."
430
- ) from e
431
- return (compare_dict, scale, ic)
432
-
433
-
434
- def hdi(
435
- ary,
436
- hdi_prob=None,
437
- circular=False,
438
- multimodal=False,
439
- skipna=False,
440
- group="posterior",
441
- var_names=None,
442
- filter_vars=None,
443
- coords=None,
444
- max_modes=10,
445
- dask_kwargs=None,
446
- **kwargs,
447
- ):
448
- """
449
- Calculate highest density interval (HDI) of array for given probability.
450
-
451
- The HDI is the minimum width Bayesian credible interval (BCI).
452
-
453
- Parameters
454
- ----------
455
- ary: obj
456
- object containing posterior samples.
457
- Any object that can be converted to an :class:`arviz.InferenceData` object.
458
- Refer to documentation of :func:`arviz.convert_to_dataset` for details.
459
- hdi_prob: float, optional
460
- Prob for which the highest density interval will be computed. Defaults to
461
- ``stats.ci_prob`` rcParam.
462
- circular: bool, optional
463
- Whether to compute the hdi taking into account `x` is a circular variable
464
- (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).
465
- Only works if multimodal is False.
466
- multimodal: bool, optional
467
- If true it may compute more than one hdi if the distribution is multimodal and the
468
- modes are well separated.
469
- skipna: bool, optional
470
- If true ignores nan values when computing the hdi. Defaults to false.
471
- group: str, optional
472
- Specifies which InferenceData group should be used to calculate hdi.
473
- Defaults to 'posterior'
474
- var_names: list, optional
475
- Names of variables to include in the hdi report. Prefix the variables by ``~``
476
- when you want to exclude them from the report: `["~beta"]` instead of `["beta"]`
477
- (see :func:`arviz.summary` for more details).
478
- filter_vars: {None, "like", "regex"}, optional, default=None
479
- If `None` (default), interpret var_names as the real variables names. If "like",
480
- interpret var_names as substrings of the real variables names. If "regex",
481
- interpret var_names as regular expressions on the real variables names. A la
482
- ``pandas.filter``.
483
- coords: mapping, optional
484
- Specifies the subset over which to calculate the hdi.
485
- max_modes: int, optional
486
- Specifies the maximum number of modes for multimodal case.
487
- dask_kwargs : dict, optional
488
- Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
489
- kwargs: dict, optional
490
- Additional keywords passed to :func:`~arviz.wrap_xarray_ufunc`.
491
-
492
- Returns
493
- -------
494
- np.ndarray or xarray.Dataset, depending upon input
495
- lower(s) and upper(s) values of the interval(s).
496
-
497
- See Also
498
- --------
499
- plot_hdi : Plot highest density intervals for regression data.
500
- xarray.Dataset.quantile : Calculate quantiles of array for given probabilities.
501
-
502
- Examples
503
- --------
504
- Calculate the HDI of a Normal random variable:
505
-
506
- .. ipython::
507
-
508
- In [1]: import arviz as az
509
- ...: import numpy as np
510
- ...: data = np.random.normal(size=2000)
511
- ...: az.hdi(data, hdi_prob=.68)
512
-
513
- Calculate the HDI of a dataset:
514
-
515
- .. ipython::
516
-
517
- In [1]: import arviz as az
518
- ...: data = az.load_arviz_data('centered_eight')
519
- ...: az.hdi(data)
520
-
521
- We can also calculate the HDI of some of the variables of dataset:
522
-
523
- .. ipython::
524
-
525
- In [1]: az.hdi(data, var_names=["mu", "theta"])
526
-
527
- By default, ``hdi`` is calculated over the ``chain`` and ``draw`` dimensions. We can use the
528
- ``input_core_dims`` argument of :func:`~arviz.wrap_xarray_ufunc` to change this. In this example
529
- we calculate the HDI also over the ``school`` dimension:
530
-
531
- .. ipython::
532
-
533
- In [1]: az.hdi(data, var_names="theta", input_core_dims = [["chain","draw", "school"]])
534
-
535
- We can also calculate the hdi over a particular selection:
536
-
537
- .. ipython::
538
-
539
- In [1]: az.hdi(data, coords={"chain":[0, 1, 3]}, input_core_dims = [["draw"]])
540
-
541
- """
542
- if hdi_prob is None:
543
- hdi_prob = rcParams["stats.ci_prob"]
544
- elif not 1 >= hdi_prob > 0:
545
- raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
546
-
547
- func_kwargs = {
548
- "hdi_prob": hdi_prob,
549
- "skipna": skipna,
550
- "out_shape": (max_modes, 2) if multimodal else (2,),
551
- }
552
- kwargs.setdefault("output_core_dims", [["mode", "hdi"] if multimodal else ["hdi"]])
553
- if not multimodal:
554
- func_kwargs["circular"] = circular
555
- else:
556
- func_kwargs["max_modes"] = max_modes
557
-
558
- func = _hdi_multimodal if multimodal else _hdi
559
-
560
- isarray = isinstance(ary, np.ndarray)
561
- if isarray and ary.ndim <= 1:
562
- func_kwargs.pop("out_shape")
563
- hdi_data = func(ary, **func_kwargs) # pylint: disable=unexpected-keyword-arg
564
- return hdi_data[~np.isnan(hdi_data).all(axis=1), :] if multimodal else hdi_data
565
-
566
- if isarray and ary.ndim == 2:
567
- warnings.warn(
568
- "hdi currently interprets 2d data as (draw, shape) but this will change in "
569
- "a future release to (chain, draw) for coherence with other functions",
570
- FutureWarning,
571
- stacklevel=2,
572
- )
573
- ary = np.expand_dims(ary, 0)
574
-
575
- ary = convert_to_dataset(ary, group=group)
576
- if coords is not None:
577
- ary = get_coords(ary, coords)
578
- var_names = _var_names(var_names, ary, filter_vars)
579
- ary = ary[var_names] if var_names else ary
580
-
581
- hdi_coord = xr.DataArray(["lower", "higher"], dims=["hdi"], attrs=dict(hdi_prob=hdi_prob))
582
- hdi_data = _wrap_xarray_ufunc(
583
- func, ary, func_kwargs=func_kwargs, dask_kwargs=dask_kwargs, **kwargs
584
- ).assign_coords({"hdi": hdi_coord})
585
- hdi_data = hdi_data.dropna("mode", how="all") if multimodal else hdi_data
586
- return hdi_data.x.values if isarray else hdi_data
587
-
588
-
589
- def _hdi(ary, hdi_prob, circular, skipna):
590
- """Compute hpi over the flattened array."""
591
- ary = ary.flatten()
592
- if skipna:
593
- nans = np.isnan(ary)
594
- if not nans.all():
595
- ary = ary[~nans]
596
- n = len(ary)
597
-
598
- if circular:
599
- mean = st.circmean(ary, high=np.pi, low=-np.pi)
600
- ary = ary - mean
601
- ary = np.arctan2(np.sin(ary), np.cos(ary))
602
-
603
- ary = np.sort(ary)
604
- interval_idx_inc = int(np.floor(hdi_prob * n))
605
- n_intervals = n - interval_idx_inc
606
- interval_width = np.subtract(ary[interval_idx_inc:], ary[:n_intervals], dtype=np.float64)
607
-
608
- if len(interval_width) == 0:
609
- raise ValueError("Too few elements for interval calculation. ")
610
-
611
- min_idx = np.argmin(interval_width)
612
- hdi_min = ary[min_idx]
613
- hdi_max = ary[min_idx + interval_idx_inc]
614
-
615
- if circular:
616
- hdi_min = hdi_min + mean
617
- hdi_max = hdi_max + mean
618
- hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
619
- hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))
620
-
621
- hdi_interval = np.array([hdi_min, hdi_max])
622
-
623
- return hdi_interval
624
-
625
-
626
- def _hdi_multimodal(ary, hdi_prob, skipna, max_modes):
627
- """Compute HDI if the distribution is multimodal."""
628
- ary = ary.flatten()
629
- if skipna:
630
- ary = ary[~np.isnan(ary)]
631
-
632
- if ary.dtype.kind == "f":
633
- bins, density = _kde(ary)
634
- lower, upper = bins[0], bins[-1]
635
- range_x = upper - lower
636
- dx = range_x / len(density)
637
- else:
638
- bins = _get_bins(ary)
639
- _, density, _ = _histogram(ary, bins=bins)
640
- dx = np.diff(bins)[0]
641
-
642
- density *= dx
643
-
644
- idx = np.argsort(-density)
645
- intervals = bins[idx][density[idx].cumsum() <= hdi_prob]
646
- intervals.sort()
647
-
648
- intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)
649
-
650
- hdi_intervals = np.full((max_modes, 2), np.nan)
651
- for i, interval in enumerate(intervals_splitted):
652
- if i == max_modes:
653
- warnings.warn(
654
- f"found more modes than {max_modes}, returning only the first {max_modes} modes"
655
- )
656
- break
657
- if interval.size == 0:
658
- hdi_intervals[i] = np.asarray([bins[0], bins[0]])
659
- else:
660
- hdi_intervals[i] = np.asarray([interval[0], interval[-1]])
661
-
662
- return np.array(hdi_intervals)
663
-
664
-
665
- def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
666
- """Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
667
-
668
- Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
669
- importance sampling leave-one-out cross-validation (PSIS-LOO-CV). Also calculates LOO's
670
- standard error and the effective number of parameters. Read more theory here
671
- https://arxiv.org/abs/1507.04544 and here https://arxiv.org/abs/1507.02646
672
-
673
- Parameters
674
- ----------
675
- data: obj
676
- Any object that can be converted to an :class:`arviz.InferenceData` object.
677
- Refer to documentation of
678
- :func:`arviz.convert_to_dataset` for details.
679
- pointwise: bool, optional
680
- If True the pointwise predictive accuracy will be returned. Defaults to
681
- ``stats.ic_pointwise`` rcParam.
682
- var_name : str, optional
683
- The name of the variable in log_likelihood groups storing the pointwise log
684
- likelihood data to use for loo computation.
685
- reff: float, optional
686
- Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number
687
- of actual samples. Computed from trace by default.
688
- scale: str
689
- Output scale for loo. Available options are:
690
-
691
- - ``log`` : (default) log-score
692
- - ``negative_log`` : -1 * log-score
693
- - ``deviance`` : -2 * log-score
694
-
695
- A higher log-score (or a lower deviance or negative log_score) indicates a model with
696
- better predictive accuracy.
697
-
698
- Returns
699
- -------
700
- ELPDData object (inherits from :class:`pandas.Series`) with the following row/attributes:
701
- elpd_loo: approximated expected log pointwise predictive density (elpd)
702
- se: standard error of the elpd
703
- p_loo: effective number of parameters
704
- n_samples: number of samples
705
- n_data_points: number of data points
706
- warning: bool
707
- True if the estimated shape parameter of Pareto distribution is greater than
708
- ``good_k``.
709
- loo_i: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
710
- only if pointwise=True
711
- pareto_k: array of Pareto shape values, only if pointwise True
712
- scale: scale of the elpd
713
- good_k: For a sample size S, the threshold is computed as min(1 - 1/log10(S), 0.7)
714
-
715
- The returned object has a custom print method that overrides pd.Series method.
716
-
717
- See Also
718
- --------
719
- compare : Compare models based on PSIS-LOO loo or WAIC waic cross-validation.
720
- waic : Compute the widely applicable information criterion.
721
- plot_compare : Summary plot for model comparison.
722
- plot_elpd : Plot pointwise elpd differences between two or more models.
723
- plot_khat : Plot Pareto tail indices for diagnosing convergence.
724
-
725
- Examples
726
- --------
727
- Calculate LOO of a model:
728
-
729
- .. ipython::
730
-
731
- In [1]: import arviz as az
732
- ...: data = az.load_arviz_data("centered_eight")
733
- ...: az.loo(data)
734
-
735
- Calculate LOO of a model and return the pointwise values:
736
-
737
- .. ipython::
738
-
739
- In [2]: data_loo = az.loo(data, pointwise=True)
740
- ...: data_loo.loo_i
741
- """
742
- inference_data = convert_to_inference_data(data)
743
- log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
744
- pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
745
-
746
- log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
747
- shape = log_likelihood.shape
748
- n_samples = shape[-1]
749
- n_data_points = np.prod(shape[:-1])
750
- scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
751
-
752
- if scale == "deviance":
753
- scale_value = -2
754
- elif scale == "log":
755
- scale_value = 1
756
- elif scale == "negative_log":
757
- scale_value = -1
758
- else:
759
- raise TypeError('Valid scale values are "deviance", "log", "negative_log"')
760
-
761
- if reff is None:
762
- if not hasattr(inference_data, "posterior"):
763
- raise TypeError("Must be able to extract a posterior group from data.")
764
- posterior = inference_data.posterior
765
- n_chains = len(posterior.chain)
766
- if n_chains == 1:
767
- reff = 1.0
768
- else:
769
- ess_p = ess(posterior, method="mean")
770
- # this mean is over all data variables
771
- reff = (
772
- np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
773
- )
774
-
775
- log_weights, pareto_shape = psislw(-log_likelihood, reff)
776
- log_weights += log_likelihood
777
-
778
- warn_mg = False
779
- good_k = min(1 - 1 / np.log10(n_samples), 0.7)
780
-
781
- if np.any(pareto_shape > good_k):
782
- warnings.warn(
783
- f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} "
784
- "for one or more samples. You should consider using a more robust model, this is "
785
- "because importance sampling is less likely to work well if the marginal posterior "
786
- "and LOO posterior are very different. This is more likely to happen with a "
787
- "non-robust model and highly influential observations."
788
- )
789
- warn_mg = True
790
-
791
- ufunc_kwargs = {"n_dims": 1, "ravel": False}
792
- kwargs = {"input_core_dims": [["__sample__"]]}
793
- loo_lppd_i = scale_value * _wrap_xarray_ufunc(
794
- _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, **kwargs
795
- )
796
- loo_lppd = loo_lppd_i.values.sum()
797
- loo_lppd_se = (n_data_points * np.var(loo_lppd_i.values)) ** 0.5
798
-
799
- lppd = np.sum(
800
- _wrap_xarray_ufunc(
801
- _logsumexp,
802
- log_likelihood,
803
- func_kwargs={"b_inv": n_samples},
804
- ufunc_kwargs=ufunc_kwargs,
805
- **kwargs,
806
- ).values
807
- )
808
- p_loo = lppd - loo_lppd / scale_value
809
-
810
- if not pointwise:
811
- return ELPDData(
812
- data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k],
813
- index=[
814
- "elpd_loo",
815
- "se",
816
- "p_loo",
817
- "n_samples",
818
- "n_data_points",
819
- "warning",
820
- "scale",
821
- "good_k",
822
- ],
823
- )
824
- if np.equal(loo_lppd, loo_lppd_i).all(): # pylint: disable=no-member
825
- warnings.warn(
826
- "The point-wise LOO is the same with the sum LOO, please double check "
827
- "the Observed RV in your model to make sure it returns element-wise logp."
828
- )
829
- return ELPDData(
830
- data=[
831
- loo_lppd,
832
- loo_lppd_se,
833
- p_loo,
834
- n_samples,
835
- n_data_points,
836
- warn_mg,
837
- loo_lppd_i.rename("loo_i"),
838
- pareto_shape,
839
- scale,
840
- good_k,
841
- ],
842
- index=[
843
- "elpd_loo",
844
- "se",
845
- "p_loo",
846
- "n_samples",
847
- "n_data_points",
848
- "warning",
849
- "loo_i",
850
- "pareto_k",
851
- "scale",
852
- "good_k",
853
- ],
854
- )
855
-
856
-
857
- def psislw(log_weights, reff=1.0, normalize=True):
858
- """
859
- Pareto smoothed importance sampling (PSIS).
860
-
861
- Notes
862
- -----
863
- If the ``log_weights`` input is an :class:`~xarray.DataArray` with a dimension
864
- named ``__sample__`` (recommended) ``psislw`` will interpret this dimension as samples,
865
- and all other dimensions as dimensions of the observed data, looping over them to
866
- calculate the psislw of each observation. If no ``__sample__`` dimension is present or
867
- the input is a numpy array, the last dimension will be interpreted as ``__sample__``.
868
-
869
- Parameters
870
- ----------
871
- log_weights : DataArray or (..., N) array-like
872
- Array of size (n_observations, n_samples)
873
- reff : float, default 1
874
- relative MCMC efficiency, ``ess / n``
875
- normalize : bool, default True
876
- return normalized log weights
877
-
878
- Returns
879
- -------
880
- lw_out : DataArray or (..., N) ndarray
881
- Smoothed, truncated and possibly normalized log weights.
882
- kss : DataArray or (...) ndarray
883
- Estimates of the shape parameter *k* of the generalized Pareto
884
- distribution.
885
-
886
- References
887
- ----------
888
- * Vehtari et al. (2024). Pareto smoothed importance sampling. Journal of Machine
889
- Learning Research, 25(72):1-58.
890
-
891
- See Also
892
- --------
893
- loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
894
-
895
- Examples
896
- --------
897
- Get Pareto smoothed importance sampling (PSIS) log weights:
898
-
899
- .. ipython::
900
-
901
- In [1]: import arviz as az
902
- ...: data = az.load_arviz_data("non_centered_eight")
903
- ...: log_likelihood = data.log_likelihood["obs"].stack(
904
- ...: __sample__=["chain", "draw"]
905
- ...: )
906
- ...: az.psislw(-log_likelihood, reff=0.8)
907
-
908
- """
909
- log_weights = deepcopy(log_weights)
910
- if hasattr(log_weights, "__sample__"):
911
- n_samples = len(log_weights.__sample__)
912
- shape = [
913
- size for size, dim in zip(log_weights.shape, log_weights.dims) if dim != "__sample__"
914
- ]
915
- else:
916
- n_samples = log_weights.shape[-1]
917
- shape = log_weights.shape[:-1]
918
- # precalculate constants
919
- cutoff_ind = -int(np.ceil(min(n_samples / 5.0, 3 * (n_samples / reff) ** 0.5))) - 1
920
- cutoffmin = np.log(np.finfo(float).tiny) # pylint: disable=no-member, assignment-from-no-return
921
-
922
- # create output array with proper dimensions
923
- out = np.empty_like(log_weights), np.empty(shape)
924
-
925
- # define kwargs
926
- func_kwargs = {
927
- "cutoff_ind": cutoff_ind,
928
- "cutoffmin": cutoffmin,
929
- "out": out,
930
- "normalize": normalize,
931
- }
932
- ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
933
- kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
934
- log_weights, pareto_shape = _wrap_xarray_ufunc(
935
- _psislw,
936
- log_weights,
937
- ufunc_kwargs=ufunc_kwargs,
938
- func_kwargs=func_kwargs,
939
- **kwargs,
940
- )
941
- if isinstance(log_weights, xr.DataArray):
942
- log_weights = log_weights.rename("log_weights")
943
- if isinstance(pareto_shape, xr.DataArray):
944
- pareto_shape = pareto_shape.rename("pareto_shape")
945
- return log_weights, pareto_shape
946
-
947
-
948
- def _psislw(log_weights, cutoff_ind, cutoffmin, normalize):
949
- """
950
- Pareto smoothed importance sampling (PSIS) for a 1D vector.
951
-
952
- Parameters
953
- ----------
954
- log_weights: array
955
- Array of length n_observations
956
- cutoff_ind: int
957
- cutoffmin: float
958
- normalize: bool
959
-
960
- Returns
961
- -------
962
- lw_out: array
963
- Smoothed log weights
964
- kss: float
965
- Pareto tail index
966
- """
967
- x = np.asarray(log_weights)
968
-
969
- # improve numerical accuracy
970
- max_x = np.max(x)
971
- x -= max_x
972
- # sort the array
973
- x_sort_ind = np.argsort(x)
974
- # divide log weights into body and right tail
975
- xcutoff = max(x[x_sort_ind[cutoff_ind]], cutoffmin)
976
-
977
- expxcutoff = np.exp(xcutoff)
978
- (tailinds,) = np.where(x > xcutoff) # pylint: disable=unbalanced-tuple-unpacking
979
- x_tail = x[tailinds]
980
- tail_len = len(x_tail)
981
- if tail_len <= 4:
982
- # not enough tail samples for gpdfit
983
- k = np.inf
984
- else:
985
- # order of tail samples
986
- x_tail_si = np.argsort(x_tail)
987
- # fit generalized Pareto distribution to the right tail samples
988
- x_tail = np.exp(x_tail) - expxcutoff
989
- k, sigma = _gpdfit(x_tail[x_tail_si])
990
-
991
- if np.isfinite(k):
992
- # no smoothing if GPD fit failed
993
- # compute ordered statistic for the fit
994
- sti = np.arange(0.5, tail_len) / tail_len
995
- smoothed_tail = _gpinv(sti, k, sigma)
996
- smoothed_tail = np.log( # pylint: disable=assignment-from-no-return
997
- smoothed_tail + expxcutoff
998
- )
999
- # place the smoothed tail into the output array
1000
- x[tailinds[x_tail_si]] = smoothed_tail
1001
- # truncate smoothed values to the largest raw weight 0
1002
- x[x > 0] = 0
1003
-
1004
- # renormalize weights
1005
- if normalize:
1006
- x -= _logsumexp(x)
1007
- else:
1008
- x += max_x
1009
-
1010
- return x, k
1011
-
1012
-
1013
- def _gpdfit(ary):
1014
- """Estimate the parameters for the Generalized Pareto Distribution (GPD).
1015
-
1016
- Empirical Bayes estimate for the parameters of the generalized Pareto
1017
- distribution given the data.
1018
-
1019
- Parameters
1020
- ----------
1021
- ary: array
1022
- sorted 1D data array
1023
-
1024
- Returns
1025
- -------
1026
- k: float
1027
- estimated shape parameter
1028
- sigma: float
1029
- estimated scale parameter
1030
- """
1031
- prior_bs = 3
1032
- prior_k = 10
1033
- n = len(ary)
1034
- m_est = 30 + int(n**0.5)
1035
-
1036
- b_ary = 1 - np.sqrt(m_est / (np.arange(1, m_est + 1, dtype=float) - 0.5))
1037
- b_ary /= prior_bs * ary[int(n / 4 + 0.5) - 1]
1038
- b_ary += 1 / ary[-1]
1039
-
1040
- k_ary = np.log1p(-b_ary[:, None] * ary).mean(axis=1) # pylint: disable=no-member
1041
- len_scale = n * (np.log(-(b_ary / k_ary)) - k_ary - 1)
1042
- weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
1043
-
1044
- # remove negligible weights
1045
- real_idxs = weights >= 10 * np.finfo(float).eps
1046
- if not np.all(real_idxs):
1047
- weights = weights[real_idxs]
1048
- b_ary = b_ary[real_idxs]
1049
- # normalise weights
1050
- weights /= weights.sum()
1051
-
1052
- # posterior mean for b
1053
- b_post = np.sum(b_ary * weights)
1054
- # estimate for k
1055
- k_post = np.log1p(-b_post * ary).mean() # pylint: disable=invalid-unary-operand-type,no-member
1056
- # add prior for k_post
1057
- sigma = -k_post / b_post
1058
- k_post = (n * k_post + prior_k * 0.5) / (n + prior_k)
1059
-
1060
- return k_post, sigma
1061
-
1062
-
1063
- def _gpinv(probs, kappa, sigma):
1064
- """Inverse Generalized Pareto distribution function."""
1065
- # pylint: disable=unsupported-assignment-operation, invalid-unary-operand-type
1066
- x = np.full_like(probs, np.nan)
1067
- if sigma <= 0:
1068
- return x
1069
- ok = (probs > 0) & (probs < 1)
1070
- if np.all(ok):
1071
- if np.abs(kappa) < np.finfo(float).eps:
1072
- x = -np.log1p(-probs)
1073
- else:
1074
- x = np.expm1(-kappa * np.log1p(-probs)) / kappa
1075
- x *= sigma
1076
- else:
1077
- if np.abs(kappa) < np.finfo(float).eps:
1078
- x[ok] = -np.log1p(-probs[ok])
1079
- else:
1080
- x[ok] = np.expm1(-kappa * np.log1p(-probs[ok])) / kappa
1081
- x *= sigma
1082
- x[probs == 0] = 0
1083
- x[probs == 1] = np.inf if kappa >= 0 else -sigma / kappa
1084
- return x
1085
-
1086
-
1087
- def r2_samples(y_true, y_pred):
1088
- """R² samples for Bayesian regression models. Only valid for linear models.
1089
-
1090
- Parameters
1091
- ----------
1092
- y_true: array-like of shape = (n_outputs,)
1093
- Ground truth (correct) target values.
1094
- y_pred: array-like of shape = (n_posterior_samples, n_outputs)
1095
- Estimated target values.
1096
-
1097
- Returns
1098
- -------
1099
- Pandas Series with the following indices:
1100
- Bayesian R² samples.
1101
-
1102
- See Also
1103
- --------
1104
- plot_lm : Posterior predictive and mean plots for regression-like data.
1105
-
1106
- Examples
1107
- --------
1108
- Calculate R² samples for Bayesian regression models :
1109
-
1110
- .. ipython::
1111
-
1112
- In [1]: import arviz as az
1113
- ...: data = az.load_arviz_data('regression1d')
1114
- ...: y_true = data.observed_data["y"].values
1115
- ...: y_pred = data.posterior_predictive.stack(sample=("chain", "draw"))["y"].values.T
1116
- ...: az.r2_samples(y_true, y_pred)
1117
-
1118
- """
1119
- _numba_flag = Numba.numba_flag
1120
- if y_pred.ndim == 1:
1121
- var_y_est = _numba_var(svar, np.var, y_pred)
1122
- var_e = _numba_var(svar, np.var, (y_true - y_pred))
1123
- else:
1124
- var_y_est = _numba_var(svar, np.var, y_pred, axis=1)
1125
- var_e = _numba_var(svar, np.var, (y_true - y_pred), axis=1)
1126
- r_squared = var_y_est / (var_y_est + var_e)
1127
-
1128
- return r_squared
1129
-
1130
-
1131
- def r2_score(y_true, y_pred):
1132
- """R² for Bayesian regression models. Only valid for linear models.
1133
-
1134
- Parameters
1135
- ----------
1136
- y_true: array-like of shape = (n_outputs,)
1137
- Ground truth (correct) target values.
1138
- y_pred: array-like of shape = (n_posterior_samples, n_outputs)
1139
- Estimated target values.
1140
-
1141
- Returns
1142
- -------
1143
- Pandas Series with the following indices:
1144
- r2: Bayesian R²
1145
- r2_std: standard deviation of the Bayesian R².
1146
-
1147
- See Also
1148
- --------
1149
- plot_lm : Posterior predictive and mean plots for regression-like data.
1150
-
1151
- Examples
1152
- --------
1153
- Calculate R² for Bayesian regression models :
1154
-
1155
- .. ipython::
1156
-
1157
- In [1]: import arviz as az
1158
- ...: data = az.load_arviz_data('regression1d')
1159
- ...: y_true = data.observed_data["y"].values
1160
- ...: y_pred = data.posterior_predictive.stack(sample=("chain", "draw"))["y"].values.T
1161
- ...: az.r2_score(y_true, y_pred)
1162
-
1163
- """
1164
- r_squared = r2_samples(y_true=y_true, y_pred=y_pred)
1165
- return pd.Series([np.mean(r_squared), np.std(r_squared)], index=["r2", "r2_std"])
1166
-
1167
-
1168
- def summary(
1169
- data,
1170
- var_names: Optional[List[str]] = None,
1171
- filter_vars=None,
1172
- group=None,
1173
- fmt: "Literal['wide', 'long', 'xarray']" = "wide",
1174
- kind: "Literal['all', 'stats', 'diagnostics']" = "all",
1175
- round_to=None,
1176
- circ_var_names=None,
1177
- stat_focus="mean",
1178
- stat_funcs=None,
1179
- extend=True,
1180
- hdi_prob=None,
1181
- skipna=False,
1182
- labeller=None,
1183
- coords=None,
1184
- index_origin=None,
1185
- order=None,
1186
- ) -> Union[pd.DataFrame, xr.Dataset]:
1187
- """Create a data frame with summary statistics.
1188
-
1189
- Parameters
1190
- ----------
1191
- data: obj
1192
- Any object that can be converted to an :class:`arviz.InferenceData` object
1193
- Refer to documentation of :func:`arviz.convert_to_dataset` for details
1194
- var_names: list
1195
- Names of variables to include in summary. Prefix the variables by ``~`` when you
1196
- want to exclude them from the summary: `["~beta"]` instead of `["beta"]` (see
1197
- examples below).
1198
- filter_vars: {None, "like", "regex"}, optional, default=None
1199
- If `None` (default), interpret var_names as the real variables names. If "like",
1200
- interpret var_names as substrings of the real variables names. If "regex",
1201
- interpret var_names as regular expressions on the real variables names. A la
1202
- ``pandas.filter``.
1203
- coords: Dict[str, List[Any]], optional
1204
- Coordinate subset for which to calculate the summary.
1205
- group: str
1206
- Select a group for summary. Defaults to "posterior", "prior" or first group
1207
- in that order, depending on which groups exist.
1208
- fmt: {'wide', 'long', 'xarray'}
1209
- Return format is either pandas.DataFrame {'wide', 'long'} or xarray.Dataset {'xarray'}.
1210
- kind: {'all', 'stats', 'diagnostics'}
1211
- Whether to include the `stats`: `mean`, `sd`, `hdi_3%`, `hdi_97%`, or the `diagnostics`:
1212
- `mcse_mean`, `mcse_sd`, `ess_bulk`, `ess_tail`, and `r_hat`. Default to include `all` of
1213
- them.
1214
- round_to: int
1215
- Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
1216
- circ_var_names: list
1217
- A list of circular variables to compute circular stats for
1218
- stat_focus : str, default "mean"
1219
- Select the focus for summary.
1220
- stat_funcs: dict
1221
- A list of functions or a dict of functions with function names as keys used to calculate
1222
- statistics. By default, the mean, standard deviation, simulation standard error, and
1223
- highest posterior density intervals are included.
1224
-
1225
- The functions will be given one argument, the samples for a variable as an nD array,
1226
- The functions should be in the style of a ufunc and return a single number. For example,
1227
- :func:`numpy.mean`, or ``scipy.stats.var`` would both work.
1228
- extend: boolean
1229
- If True, use the statistics returned by ``stat_funcs`` in addition to, rather than in place
1230
- of, the default statistics. This is only meaningful when ``stat_funcs`` is not None.
1231
- hdi_prob: float, optional
1232
- Highest density interval to compute. Defaults to 0.94. This is only meaningful when
1233
- ``stat_funcs`` is None.
1234
- skipna: bool
1235
- If true ignores nan values when computing the summary statistics, it does not affect the
1236
- behaviour of the functions passed to ``stat_funcs``. Defaults to false.
1237
- labeller : labeller instance, optional
1238
- Class providing the method `make_label_flat` to generate the labels in the plot titles.
1239
- For more details on ``labeller`` usage see :ref:`label_guide`
1240
- credible_interval: float, optional
1241
- deprecated: Please see hdi_prob
1242
- order
1243
- deprecated: order is now ignored.
1244
- index_origin
1245
- deprecated: index_origin is now ignored, modify the coordinate values to change the
1246
- value used in summary.
1247
-
1248
- Returns
1249
- -------
1250
- pandas.DataFrame or xarray.Dataset
1251
- Return type dictated by `fmt` argument.
1252
-
1253
- Return value will contain summary statistics for each variable. Default statistics depend on
1254
- the value of ``stat_focus``:
1255
-
1256
- ``stat_focus="mean"``: `mean`, `sd`, `hdi_3%`, `hdi_97%`, `mcse_mean`, `mcse_sd`,
1257
- `ess_bulk`, `ess_tail`, and `r_hat`
1258
-
1259
- ``stat_focus="median"``: `median`, `mad`, `eti_3%`, `eti_97%`, `mcse_median`, `ess_median`,
1260
- `ess_tail`, and `r_hat`
1261
-
1262
- `r_hat` is only computed for traces with 2 or more chains.
1263
-
1264
- See Also
1265
- --------
1266
- waic : Compute the widely applicable information criterion.
1267
- loo : Compute Pareto-smoothed importance sampling leave-one-out
1268
- cross-validation (PSIS-LOO-CV).
1269
- ess : Calculate estimate of the effective sample size (ess).
1270
- rhat : Compute estimate of rank-normalized split R-hat for a set of traces.
1271
- mcse : Calculate Markov Chain Standard Error statistic.
1272
-
1273
- Examples
1274
- --------
1275
- .. ipython::
1276
-
1277
- In [1]: import arviz as az
1278
- ...: data = az.load_arviz_data("centered_eight")
1279
- ...: az.summary(data, var_names=["mu", "tau"])
1280
-
1281
- You can use ``filter_vars`` to select variables without having to specify all the exact
1282
- names. Use ``filter_vars="like"`` to select based on partial naming:
1283
-
1284
- .. ipython::
1285
-
1286
- In [1]: az.summary(data, var_names=["the"], filter_vars="like")
1287
-
1288
- Use ``filter_vars="regex"`` to select based on regular expressions, and prefix the variables
1289
- you want to exclude by ``~``. Here, we exclude from the summary all the variables
1290
- starting with the letter t:
1291
-
1292
- .. ipython::
1293
-
1294
- In [1]: az.summary(data, var_names=["~^t"], filter_vars="regex")
1295
-
1296
- Other statistics can be calculated by passing a list of functions
1297
- or a dictionary with key, function pairs.
1298
-
1299
- .. ipython::
1300
-
1301
- In [1]: import numpy as np
1302
- ...: def median_sd(x):
1303
- ...: median = np.percentile(x, 50)
1304
- ...: sd = np.sqrt(np.mean((x-median)**2))
1305
- ...: return sd
1306
- ...:
1307
- ...: func_dict = {
1308
- ...: "std": np.std,
1309
- ...: "median_std": median_sd,
1310
- ...: "5%": lambda x: np.percentile(x, 5),
1311
- ...: "median": lambda x: np.percentile(x, 50),
1312
- ...: "95%": lambda x: np.percentile(x, 95),
1313
- ...: }
1314
- ...: az.summary(
1315
- ...: data,
1316
- ...: var_names=["mu", "tau"],
1317
- ...: stat_funcs=func_dict,
1318
- ...: extend=False
1319
- ...: )
1320
-
1321
- Use ``stat_focus`` to change the focus of the summary statistics obtained to the median:
1322
-
1323
- .. ipython::
1324
-
1325
- In [1]: az.summary(data, stat_focus="median")
1326
-
1327
- """
1328
- _log.cache = []
1329
-
1330
- if coords is None:
1331
- coords = {}
1332
-
1333
- if index_origin is not None:
1334
- warnings.warn(
1335
- "index_origin has been deprecated. summary now shows coordinate values, "
1336
- "to change the label shown, modify the coordinate values before calling summary",
1337
- DeprecationWarning,
1338
- )
1339
- index_origin = rcParams["data.index_origin"]
1340
- if labeller is None:
1341
- labeller = BaseLabeller()
1342
- if hdi_prob is None:
1343
- hdi_prob = rcParams["stats.ci_prob"]
1344
- elif not 1 >= hdi_prob > 0:
1345
- raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
1346
-
1347
- if isinstance(data, InferenceData):
1348
- if group is None:
1349
- if not data.groups():
1350
- raise TypeError("InferenceData does not contain any groups")
1351
- if "posterior" in data:
1352
- dataset = data["posterior"]
1353
- elif "prior" in data:
1354
- dataset = data["prior"]
1355
- else:
1356
- warnings.warn(f"Selecting first found group: {data.groups()[0]}")
1357
- dataset = data[data.groups()[0]]
1358
- elif group in data.groups():
1359
- dataset = data[group]
1360
- else:
1361
- raise TypeError(f"InferenceData does not contain group: {group}")
1362
- else:
1363
- dataset = convert_to_dataset(data, group="posterior")
1364
- var_names = _var_names(var_names, dataset, filter_vars)
1365
- dataset = dataset if var_names is None else dataset[var_names]
1366
- dataset = get_coords(dataset, coords)
1367
-
1368
- fmt_group = ("wide", "long", "xarray")
1369
- if not isinstance(fmt, str) or (fmt.lower() not in fmt_group):
1370
- raise TypeError(f"Invalid format: '{fmt}'. Formatting options are: {fmt_group}")
1371
-
1372
- kind_group = ("all", "stats", "diagnostics")
1373
- if not isinstance(kind, str) or kind not in kind_group:
1374
- raise TypeError(f"Invalid kind: '{kind}'. Kind options are: {kind_group}")
1375
-
1376
- focus_group = ("mean", "median")
1377
- if not isinstance(stat_focus, str) or (stat_focus not in focus_group):
1378
- raise TypeError(f"Invalid format: '{stat_focus}'. Focus options are: {focus_group}")
1379
-
1380
- if stat_focus != "mean" and circ_var_names is not None:
1381
- raise TypeError(f"Invalid format: Circular stats not supported for '{stat_focus}'")
1382
-
1383
- if order is not None:
1384
- warnings.warn(
1385
- "order has been deprecated. summary now shows coordinate values.", DeprecationWarning
1386
- )
1387
-
1388
- alpha = 1 - hdi_prob
1389
-
1390
- extra_metrics = []
1391
- extra_metric_names = []
1392
-
1393
- if stat_funcs is not None:
1394
- if isinstance(stat_funcs, dict):
1395
- for stat_func_name, stat_func in stat_funcs.items():
1396
- extra_metrics.append(
1397
- xr.apply_ufunc(
1398
- _make_ufunc(stat_func), dataset, input_core_dims=(("chain", "draw"),)
1399
- )
1400
- )
1401
- extra_metric_names.append(stat_func_name)
1402
- else:
1403
- for stat_func in stat_funcs:
1404
- extra_metrics.append(
1405
- xr.apply_ufunc(
1406
- _make_ufunc(stat_func), dataset, input_core_dims=(("chain", "draw"),)
1407
- )
1408
- )
1409
- extra_metric_names.append(stat_func.__name__)
1410
-
1411
- metrics: List[xr.Dataset] = []
1412
- metric_names: List[str] = []
1413
- if extend and kind in ["all", "stats"]:
1414
- if stat_focus == "mean":
1415
- mean = dataset.mean(dim=("chain", "draw"), skipna=skipna)
1416
-
1417
- sd = dataset.std(dim=("chain", "draw"), ddof=1, skipna=skipna)
1418
-
1419
- hdi_post = hdi(dataset, hdi_prob=hdi_prob, multimodal=False, skipna=skipna)
1420
- hdi_lower = hdi_post.sel(hdi="lower", drop=True)
1421
- hdi_higher = hdi_post.sel(hdi="higher", drop=True)
1422
- metrics.extend((mean, sd, hdi_lower, hdi_higher))
1423
- metric_names.extend(
1424
- ("mean", "sd", f"hdi_{100 * alpha / 2:g}%", f"hdi_{100 * (1 - alpha / 2):g}%")
1425
- )
1426
- elif stat_focus == "median":
1427
- median = dataset.median(dim=("chain", "draw"), skipna=skipna)
1428
-
1429
- mad = stats.median_abs_deviation(dataset, dims=("chain", "draw"))
1430
- eti_post = dataset.quantile(
1431
- (alpha / 2, 1 - alpha / 2), dim=("chain", "draw"), skipna=skipna
1432
- )
1433
- eti_lower = eti_post.isel(quantile=0, drop=True)
1434
- eti_higher = eti_post.isel(quantile=1, drop=True)
1435
- metrics.extend((median, mad, eti_lower, eti_higher))
1436
- metric_names.extend(
1437
- ("median", "mad", f"eti_{100 * alpha / 2:g}%", f"eti_{100 * (1 - alpha / 2):g}%")
1438
- )
1439
-
1440
- if circ_var_names:
1441
- nan_policy = "omit" if skipna else "propagate"
1442
- circ_mean = stats.circmean(
1443
- dataset, dims=["chain", "draw"], high=np.pi, low=-np.pi, nan_policy=nan_policy
1444
- )
1445
- _numba_flag = Numba.numba_flag
1446
- if _numba_flag:
1447
- circ_sd = xr.apply_ufunc(
1448
- _make_ufunc(_circular_standard_deviation),
1449
- dataset,
1450
- kwargs=dict(high=np.pi, low=-np.pi, skipna=skipna),
1451
- input_core_dims=(("chain", "draw"),),
1452
- )
1453
- else:
1454
- circ_sd = stats.circstd(
1455
- dataset, dims=["chain", "draw"], high=np.pi, low=-np.pi, nan_policy=nan_policy
1456
- )
1457
- circ_mcse = xr.apply_ufunc(
1458
- _make_ufunc(_mc_error),
1459
- dataset,
1460
- kwargs=dict(circular=True),
1461
- input_core_dims=(("chain", "draw"),),
1462
- )
1463
-
1464
- circ_hdi = hdi(dataset, hdi_prob=hdi_prob, circular=True, skipna=skipna)
1465
- circ_hdi_lower = circ_hdi.sel(hdi="lower", drop=True)
1466
- circ_hdi_higher = circ_hdi.sel(hdi="higher", drop=True)
1467
-
1468
- if kind in ["all", "diagnostics"] and extend:
1469
- diagnostics_names: Tuple[str, ...]
1470
- if stat_focus == "mean":
1471
- diagnostics = xr.apply_ufunc(
1472
- _make_ufunc(_multichain_statistics, n_output=5, ravel=False),
1473
- dataset,
1474
- input_core_dims=(("chain", "draw"),),
1475
- output_core_dims=tuple([] for _ in range(5)),
1476
- )
1477
- diagnostics_names = (
1478
- "mcse_mean",
1479
- "mcse_sd",
1480
- "ess_bulk",
1481
- "ess_tail",
1482
- "r_hat",
1483
- )
1484
-
1485
- elif stat_focus == "median":
1486
- diagnostics = xr.apply_ufunc(
1487
- _make_ufunc(_multichain_statistics, n_output=4, ravel=False),
1488
- dataset,
1489
- kwargs=dict(focus="median"),
1490
- input_core_dims=(("chain", "draw"),),
1491
- output_core_dims=tuple([] for _ in range(4)),
1492
- )
1493
- diagnostics_names = (
1494
- "mcse_median",
1495
- "ess_median",
1496
- "ess_tail",
1497
- "r_hat",
1498
- )
1499
- metrics.extend(diagnostics)
1500
- metric_names.extend(diagnostics_names)
1501
-
1502
- if circ_var_names and kind != "diagnostics" and stat_focus == "mean":
1503
- for metric, circ_stat in zip(
1504
- # Replace only the first 5 statistics for their circular equivalent
1505
- metrics[:5],
1506
- (circ_mean, circ_sd, circ_hdi_lower, circ_hdi_higher, circ_mcse),
1507
- ):
1508
- for circ_var in circ_var_names:
1509
- metric[circ_var] = circ_stat[circ_var]
1510
-
1511
- metrics.extend(extra_metrics)
1512
- metric_names.extend(extra_metric_names)
1513
- joined = (
1514
- xr.concat(metrics, dim="metric").assign_coords(metric=metric_names).reset_coords(drop=True)
1515
- )
1516
- n_metrics = len(metric_names)
1517
- n_vars = np.sum([joined[var].size // n_metrics for var in joined.data_vars])
1518
-
1519
- if fmt.lower() == "wide":
1520
- summary_df = pd.DataFrame(
1521
- (np.full((cast(int, n_vars), n_metrics), np.nan)), columns=metric_names
1522
- )
1523
- indices = []
1524
- for i, (var_name, sel, isel, values) in enumerate(
1525
- xarray_var_iter(joined, skip_dims={"metric"})
1526
- ):
1527
- summary_df.iloc[i] = values
1528
- indices.append(labeller.make_label_flat(var_name, sel, isel))
1529
- summary_df.index = indices
1530
- elif fmt.lower() == "long":
1531
- df = joined.to_dataframe().reset_index().set_index("metric")
1532
- df.index = list(df.index)
1533
- summary_df = df
1534
- else:
1535
- # format is 'xarray'
1536
- summary_df = joined
1537
- if (round_to is not None) and (round_to not in ("None", "none")):
1538
- summary_df = summary_df.round(round_to)
1539
- elif round_to not in ("None", "none") and (fmt.lower() in ("long", "wide")):
1540
- # Don't round xarray object by default (even with "none")
1541
- decimals = {
1542
- col: 3 if col not in {"ess_bulk", "ess_tail", "r_hat"} else 2 if col == "r_hat" else 0
1543
- for col in summary_df.columns
1544
- }
1545
- summary_df = summary_df.round(decimals)
1546
-
1547
- return summary_df
1548
-
1549
-
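As a minimal sketch of the per-column rounding applied above when ``round_to`` is left unset (3 decimals in general, none for the ESS columns, 2 for ``r_hat``): the column names follow the defaults used above, and the data values are invented purely for illustration::

    import pandas as pd

    df = pd.DataFrame(
        {"mean": [4.48629], "sd": [3.48821], "ess_bulk": [241.7], "r_hat": [1.0213]}
    )
    decimals = {
        col: 3 if col not in {"ess_bulk", "ess_tail", "r_hat"} else 2 if col == "r_hat" else 0
        for col in df.columns
    }
    # pandas accepts a dict of per-column decimals, which is what the rounding above relies on
    print(df.round(decimals))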
1550
- def waic(data, pointwise=None, var_name=None, scale=None, dask_kwargs=None):
1551
- """Compute the widely applicable information criterion.
1552
-
1553
- Estimates the expected log pointwise predictive density (elpd) using WAIC. Also calculates the
1554
- WAIC's standard error and the effective number of parameters.
1555
- Read more theory here https://arxiv.org/abs/1507.04544 and here https://arxiv.org/abs/1004.2316
1556
-
1557
- Parameters
1558
- ----------
1559
- data: obj
1560
- Any object that can be converted to an :class:`arviz.InferenceData` object.
1561
- Refer to documentation of :func:`arviz.convert_to_inference_data` for details.
1562
- pointwise: bool
1563
- If True the pointwise predictive accuracy will be returned. Defaults to
1564
- ``stats.ic_pointwise`` rcParam.
1565
- var_name : str, optional
1566
- The name of the variable in log_likelihood groups storing the pointwise log
1567
- likelihood data to use for waic computation.
1568
- scale: str
1569
- Output scale for WAIC. Available options are:
1570
-
1571
- - `log` : (default) log-score
1572
- - `negative_log` : -1 * log-score
1573
- - `deviance` : -2 * log-score
1574
-
1575
- A higher log-score (or a lower deviance or negative log_score) indicates a model with
1576
- better predictive accuracy.
1577
- dask_kwargs : dict, optional
1578
- Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
1579
-
1580
- Returns
1581
- -------
1582
- ELPDData object (inherits from :class:`pandas.Series`) with the following row/attributes:
1583
- elpd_waic: approximated expected log pointwise predictive density (elpd)
1584
- se: standard error of the elpd
1585
- p_waic: effective number of parameters
1586
- n_samples: number of samples
1587
- n_data_points: number of data points
1588
- warning: bool
1589
- True if posterior variance of the log predictive densities exceeds 0.4
1590
- waic_i: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
1591
- only if pointwise=True
1592
- scale: scale of the elpd
1593
-
1594
- The returned object has a custom print method that overrides the pd.Series method.
1595
-
1596
- See Also
1597
- --------
1598
- loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
1599
- compare : Compare models based on PSIS-LOO-CV or WAIC.
1600
- plot_compare : Summary plot for model comparison.
1601
-
1602
- Examples
1603
- --------
1604
- Calculate WAIC of a model:
1605
-
1606
- .. ipython::
1607
-
1608
- In [1]: import arviz as az
1609
- ...: data = az.load_arviz_data("centered_eight")
1610
- ...: az.waic(data)
1611
-
1612
- Calculate WAIC of a model and return the pointwise values:
1613
-
1614
- .. ipython::
1615
-
1616
- In [2]: data_waic = az.waic(data, pointwise=True)
1617
- ...: data_waic.waic_i
1618
- """
1619
- inference_data = convert_to_inference_data(data)
1620
- log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
1621
- scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
1622
- pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
1623
-
1624
- if scale == "deviance":
1625
- scale_value = -2
1626
- elif scale == "log":
1627
- scale_value = 1
1628
- elif scale == "negative_log":
1629
- scale_value = -1
1630
- else:
1631
- raise TypeError('Valid scale values are "deviance", "log", "negative_log"')
1632
-
1633
- log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
1634
- shape = log_likelihood.shape
1635
- n_samples = shape[-1]
1636
- n_data_points = np.prod(shape[:-1])
1637
-
1638
- ufunc_kwargs = {"n_dims": 1, "ravel": False}
1639
- kwargs = {"input_core_dims": [["__sample__"]]}
1640
- lppd_i = _wrap_xarray_ufunc(
1641
- _logsumexp,
1642
- log_likelihood,
1643
- func_kwargs={"b_inv": n_samples},
1644
- ufunc_kwargs=ufunc_kwargs,
1645
- dask_kwargs=dask_kwargs,
1646
- **kwargs,
1647
- )
1648
-
1649
- vars_lpd = log_likelihood.var(dim="__sample__")
1650
- warn_mg = False
1651
- if np.any(vars_lpd > 0.4):
1652
- warnings.warn(
1653
- (
1654
- "For one or more samples the posterior variance of the log predictive "
1655
- "densities exceeds 0.4. This could be indication of WAIC starting to fail. \n"
1656
- "See http://arxiv.org/abs/1507.04544 for details"
1657
- )
1658
- )
1659
- warn_mg = True
1660
-
1661
- waic_i = scale_value * (lppd_i - vars_lpd)
1662
- waic_se = (n_data_points * np.var(waic_i.values)) ** 0.5
1663
- waic_sum = np.sum(waic_i.values)
1664
- p_waic = np.sum(vars_lpd.values)
1665
-
1666
- if not pointwise:
1667
- return ELPDData(
1668
- data=[waic_sum, waic_se, p_waic, n_samples, n_data_points, warn_mg, scale],
1669
- index=[
1670
- "waic",
1671
- "se",
1672
- "p_waic",
1673
- "n_samples",
1674
- "n_data_points",
1675
- "warning",
1676
- "scale",
1677
- ],
1678
- )
1679
- if np.equal(waic_sum, waic_i).all(): # pylint: disable=no-member
1680
- warnings.warn(
1681
- """The point-wise WAIC is the same with the sum WAIC, please double check
1682
- the Observed RV in your model to make sure it returns element-wise logp.
1683
- """
1684
- )
1685
- return ELPDData(
1686
- data=[
1687
- waic_sum,
1688
- waic_se,
1689
- p_waic,
1690
- n_samples,
1691
- n_data_points,
1692
- warn_mg,
1693
- waic_i.rename("waic_i"),
1694
- scale,
1695
- ],
1696
- index=[
1697
- "elpd_waic",
1698
- "se",
1699
- "p_waic",
1700
- "n_samples",
1701
- "n_data_points",
1702
- "warning",
1703
- "waic_i",
1704
- "scale",
1705
- ],
1706
- )
1707
-
1708
-
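A minimal numpy sketch of the same WAIC estimator applied to a plain ``(samples, observations)`` log-likelihood array; the synthetic data are an assumption and this mirrors the computation above rather than the library API::

    import numpy as np
    from scipy.special import logsumexp

    rng = np.random.default_rng(0)
    log_lik = rng.normal(loc=-1.0, scale=0.3, size=(4000, 8))  # synthetic (samples, observations)

    # log pointwise predictive density and effective number of parameters
    lppd_i = logsumexp(log_lik, axis=0) - np.log(log_lik.shape[0])
    p_waic_i = np.var(log_lik, axis=0)
    elpd_i = lppd_i - p_waic_i

    elpd_waic = elpd_i.sum()
    se = np.sqrt(log_lik.shape[1] * np.var(elpd_i))
    print(elpd_waic, se)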
1709
- def loo_pit(idata=None, *, y=None, y_hat=None, log_weights=None):
1710
- """Compute leave one out (PSIS-LOO) probability integral transform (PIT) values.
1711
-
1712
- Parameters
1713
- ----------
1714
- idata: InferenceData
1715
- :class:`arviz.InferenceData` object.
1716
- y: array, DataArray or str
1717
- Observed data. If str, ``idata`` must be present and contain the observed data group
1718
- y_hat: array, DataArray or str
1719
- Posterior predictive samples for ``y``. It must have the same shape as y plus an
1720
- extra dimension at the end of size n_samples (chains and draws stacked). If str or
1721
- None, ``idata`` must contain the posterior predictive group. If None, y_hat is taken
1722
- equal to y; in that case, y must be a str too.
1723
- log_weights: array or DataArray
1724
- Smoothed log_weights. It must have the same shape as ``y_hat``
1725
- dask_kwargs : dict, optional
1726
- Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
1727
-
1728
- Returns
1729
- -------
1730
- loo_pit: array or DataArray
1731
- Value of the LOO-PIT at each observed data point.
1732
-
1733
- See Also
1734
- --------
1735
- plot_loo_pit : Plot Leave-One-Out probability integral transformation (PIT) predictive checks.
1736
- loo : Compute Pareto-smoothed importance sampling leave-one-out
1737
- cross-validation (PSIS-LOO-CV).
1738
- plot_elpd : Plot pointwise elpd differences between two or more models.
1739
- plot_khat : Plot Pareto tail indices for diagnosing convergence.
1740
-
1741
- Examples
1742
- --------
1743
- Calculate LOO-PIT values using as test quantity the observed values themselves.
1744
-
1745
- .. ipython::
1746
-
1747
- In [1]: import arviz as az
1748
- ...: data = az.load_arviz_data("centered_eight")
1749
- ...: az.loo_pit(idata=data, y="obs")
1750
-
1751
- Calculate LOO-PIT values using as test quantity the square of the difference between
1752
- each observation and `mu`. Both ``y`` and ``y_hat`` inputs will be array-like,
1753
- but ``idata`` will still be passed in order to calculate the ``log_weights`` from
1754
- there.
1755
-
1756
- .. ipython::
1757
-
1758
- In [1]: T = data.observed_data.obs - data.posterior.mu.median(dim=("chain", "draw"))
1759
- ...: T_hat = data.posterior_predictive.obs - data.posterior.mu
1760
- ...: T_hat = T_hat.stack(__sample__=("chain", "draw"))
1761
- ...: az.loo_pit(idata=data, y=T**2, y_hat=T_hat**2)
1762
-
1763
- """
1764
- y_str = ""
1765
- if idata is not None and not isinstance(idata, InferenceData):
1766
- raise ValueError("idata must be of type InferenceData or None")
1767
-
1768
- if idata is None:
1769
- if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat, log_weights)):
1770
- raise ValueError(
1771
- "all 3 y, y_hat and log_weights must be array or DataArray when idata is None "
1772
- f"but they are of types {[type(arg) for arg in (y, y_hat, log_weights)]}"
1773
- )
1774
-
1775
- else:
1776
- if y_hat is None and isinstance(y, str):
1777
- y_hat = y
1778
- elif y_hat is None:
1779
- raise ValueError("y_hat cannot be None if y is not a str")
1780
- if isinstance(y, str):
1781
- y_str = y
1782
- y = idata.observed_data[y].values
1783
- elif not isinstance(y, (np.ndarray, xr.DataArray)):
1784
- raise ValueError(f"y must be of types array, DataArray or str, not {type(y)}")
1785
- if isinstance(y_hat, str):
1786
- y_hat = idata.posterior_predictive[y_hat].stack(__sample__=("chain", "draw")).values
1787
- elif not isinstance(y_hat, (np.ndarray, xr.DataArray)):
1788
- raise ValueError(f"y_hat must be of types array, DataArray or str, not {type(y_hat)}")
1789
- if log_weights is None:
1790
- if y_str:
1791
- try:
1792
- log_likelihood = _get_log_likelihood(idata, var_name=y_str)
1793
- except TypeError:
1794
- log_likelihood = _get_log_likelihood(idata)
1795
- else:
1796
- log_likelihood = _get_log_likelihood(idata)
1797
- log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
1798
- posterior = convert_to_dataset(idata, group="posterior")
1799
- n_chains = len(posterior.chain)
1800
- n_samples = len(log_likelihood.__sample__)
1801
- ess_p = ess(posterior, method="mean")
1802
- # this mean is over all data variables
1803
- reff = (
1804
- (np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples)
1805
- if n_chains > 1
1806
- else 1
1807
- )
1808
- log_weights = psislw(-log_likelihood, reff=reff)[0].values
1809
- elif not isinstance(log_weights, (np.ndarray, xr.DataArray)):
1810
- raise ValueError(
1811
- f"log_weights must be None or of types array or DataArray, not {type(log_weights)}"
1812
- )
1813
-
1814
- if len(y.shape) + 1 != len(y_hat.shape):
1815
- raise ValueError(
1816
- f"y_hat must have 1 more dimension than y, but y_hat has {len(y_hat.shape)} dims and "
1817
- f"y has {len(y.shape)} dims"
1818
- )
1819
-
1820
- if y.shape != y_hat.shape[:-1]:
1821
- raise ValueError(
1822
- f"y has shape: {y.shape} which should be equal to y_hat shape (omitting the last "
1823
- f"dimension): {y_hat.shape}"
1824
- )
1825
-
1826
- if y_hat.shape != log_weights.shape:
1827
- raise ValueError(
1828
- "y_hat and log_weights must have the same shape but have shapes "
1829
- f"{y_hat.shape,} and {log_weights.shape}"
1830
- )
1831
-
1832
- kwargs = {
1833
- "input_core_dims": [[], ["__sample__"], ["__sample__"]],
1834
- "output_core_dims": [[]],
1835
- "join": "left",
1836
- }
1837
- ufunc_kwargs = {"n_dims": 1}
1838
-
1839
- if y.dtype.kind == "i" or y_hat.dtype.kind == "i":
1840
- y, y_hat = smooth_data(y, y_hat)
1841
-
1842
- return _wrap_xarray_ufunc(
1843
- _loo_pit,
1844
- y,
1845
- y_hat,
1846
- log_weights,
1847
- ufunc_kwargs=ufunc_kwargs,
1848
- **kwargs,
1849
- )
1850
-
1851
-
1852
- def _loo_pit(y, y_hat, log_weights):
1853
- """Compute LOO-PIT values."""
1854
- sel = y_hat <= y
1855
- if np.sum(sel) > 0:
1856
- value = np.exp(_logsumexp(log_weights[sel]))
1857
- return min(1, value)
1858
- else:
1859
- return 0
1860
-
1861
-
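A minimal numpy sketch of what ``_loo_pit`` computes for a single observation, using uniform (unsmoothed) weights and made-up draws as assumptions::

    import numpy as np
    from scipy.special import logsumexp

    rng = np.random.default_rng(0)
    y = 0.5                          # one observed value (made up)
    y_hat = rng.normal(size=2000)    # posterior predictive draws for that observation
    log_weights = np.full(y_hat.shape, -np.log(y_hat.size))  # uniform weights for the sketch

    sel = y_hat <= y
    pit = min(1.0, np.exp(logsumexp(log_weights[sel]))) if sel.any() else 0.0
    # with uniform weights this is just the weighted fraction of draws at or below y
    print(pit)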
1862
- def apply_test_function(
1863
- idata,
1864
- func,
1865
- group="both",
1866
- var_names=None,
1867
- pointwise=False,
1868
- out_data_shape=None,
1869
- out_pp_shape=None,
1870
- out_name_data="T",
1871
- out_name_pp=None,
1872
- func_args=None,
1873
- func_kwargs=None,
1874
- ufunc_kwargs=None,
1875
- wrap_data_kwargs=None,
1876
- wrap_pp_kwargs=None,
1877
- inplace=True,
1878
- overwrite=None,
1879
- ):
1880
- """Apply a Bayesian test function to an InferenceData object.
1881
-
1882
- Parameters
1883
- ----------
1884
- idata: InferenceData
1885
- :class:`arviz.InferenceData` object on which to apply the test function.
1886
- This function will add new variables to the InferenceData object
1887
- to store the result without modifying the existing ones.
1888
- func: callable
1889
- Callable that calculates the test function. It must have the following call signature
1890
- ``func(y, theta, *args, **kwargs)`` (where ``y`` is the observed data or posterior
1891
- predictive and ``theta`` the model parameters) even if not all the arguments are
1892
- used.
1893
- group: str, optional
1894
- Group on which to apply the test function. Can be observed_data, posterior_predictive
1895
- or both.
1896
- var_names: dict group -> var_names, optional
1897
- Mapping from group name to the variables to be passed to func. It can be a dict of
1898
- strings or lists of strings. There is also the option of using ``both`` as key,
1899
- in which case, the same variables are used in observed data and posterior predictive
1900
- groups
1901
- pointwise: bool, optional
1902
- If True, apply the test function to each observation and sample, otherwise, apply
1903
- test function to each sample.
1904
- out_data_shape, out_pp_shape: tuple, optional
1905
- Output shape of the test function applied to the observed/posterior predictive data.
1906
- If None, the default depends on the value of pointwise.
1907
- out_name_data, out_name_pp: str, optional
1908
- Name of the variables to add to the observed_data and posterior_predictive datasets
1909
- respectively. ``out_name_pp`` can be ``None``, in which case will be taken equal to
1910
- ``out_name_data``.
1911
- func_args: sequence, optional
1912
- Passed as is to ``func``
1913
- func_kwargs: mapping, optional
1914
- Passed as is to ``func``
1915
- wrap_data_kwargs, wrap_pp_kwargs: mapping, optional
1916
- kwargs passed to :func:`~arviz.wrap_xarray_ufunc`. By default, some suitable input_core_dims
1917
- are used.
1918
- inplace: bool, optional
1919
- If True, add the variables inplace, otherwise, return a copy of idata with the variables
1920
- added.
1921
- overwrite: bool, optional
1922
- Overwrite data in case ``out_name_data`` or ``out_name_pp`` are already variables in
1923
- dataset. If ``None`` it will be the opposite of inplace.
1924
-
1925
- Returns
1926
- -------
1927
- idata: InferenceData
1928
- Output InferenceData object. If ``inplace=True``, it is the same input object modified
1929
- inplace.
1930
-
1931
- See Also
1932
- --------
1933
- plot_bpv : Plot Bayesian p-value for observed data and Posterior/Prior predictive.
1934
-
1935
- Notes
1936
- -----
1937
- This function is provided as a convenience to wrap scalar functions or functions working on low
1938
- dimensions onto an InferenceData object. It is not optimized to be fast, nor as fast as vectorized
1939
- computations.
1940
-
1941
- Examples
1942
- --------
1943
- Use ``apply_test_function`` to wrap ``numpy.min`` for illustration purposes, and plot the
1944
- results.
1945
-
1946
- .. plot::
1947
- :context: close-figs
1948
-
1949
- >>> import numpy as np
- >>> import arviz as az
1950
- >>> idata = az.load_arviz_data("centered_eight")
1951
- >>> az.apply_test_function(idata, lambda y, theta: np.min(y))
1952
- >>> T = idata.observed_data.T.item()
1953
- >>> az.plot_posterior(idata, var_names=["T"], group="posterior_predictive", ref_val=T)
1954
-
1955
- """
1956
- out = idata if inplace else deepcopy(idata)
1957
-
1958
- valid_groups = ("observed_data", "posterior_predictive", "both")
1959
- if group not in valid_groups:
1960
- raise ValueError(f"Invalid group argument. Must be one of {valid_groups} not {group}.")
1961
- if overwrite is None:
1962
- overwrite = not inplace
1963
-
1964
- if out_name_pp is None:
1965
- out_name_pp = out_name_data
1966
-
1967
- if func_args is None:
1968
- func_args = tuple()
1969
-
1970
- if func_kwargs is None:
1971
- func_kwargs = {}
1972
-
1973
- if ufunc_kwargs is None:
1974
- ufunc_kwargs = {}
1975
- ufunc_kwargs.setdefault("check_shape", False)
1976
- ufunc_kwargs.setdefault("ravel", False)
1977
-
1978
- if wrap_data_kwargs is None:
1979
- wrap_data_kwargs = {}
1980
- if wrap_pp_kwargs is None:
1981
- wrap_pp_kwargs = {}
1982
- if var_names is None:
1983
- var_names = {}
1984
-
1985
- both_var_names = var_names.pop("both", None)
1986
- var_names.setdefault("posterior", list(out.posterior.data_vars))
1987
-
1988
- in_posterior = out.posterior[var_names["posterior"]]
1989
- if isinstance(in_posterior, xr.Dataset):
1990
- in_posterior = in_posterior.to_array().squeeze()
1991
-
1992
- groups = ("posterior_predictive", "observed_data") if group == "both" else [group]
1993
- for grp in groups:
1994
- out_group_shape = out_data_shape if grp == "observed_data" else out_pp_shape
1995
- out_name_group = out_name_data if grp == "observed_data" else out_name_pp
1996
- wrap_group_kwargs = wrap_data_kwargs if grp == "observed_data" else wrap_pp_kwargs
1997
- if not hasattr(out, grp):
1998
- raise ValueError(f"InferenceData object must have {grp} group")
1999
- if not overwrite and out_name_group in getattr(out, grp).data_vars:
2000
- raise ValueError(
2001
- f"Should overwrite: {out_name_group} variable present in group {grp},"
2002
- " but overwrite is False"
2003
- )
2004
- var_names.setdefault(
2005
- grp, list(getattr(out, grp).data_vars) if both_var_names is None else both_var_names
2006
- )
2007
- in_group = getattr(out, grp)[var_names[grp]]
2008
- if isinstance(in_group, xr.Dataset):
2009
- in_group = in_group.to_array(dim=f"{grp}_var").squeeze()
2010
-
2011
- if pointwise:
2012
- out_group_shape = in_group.shape if out_group_shape is None else out_group_shape
2013
- elif grp == "observed_data":
2014
- out_group_shape = () if out_group_shape is None else out_group_shape
2015
- elif grp == "posterior_predictive":
2016
- out_group_shape = in_group.shape[:2] if out_group_shape is None else out_group_shape
2017
- loop_dims = in_group.dims[: len(out_group_shape)]
2018
-
2019
- wrap_group_kwargs.setdefault(
2020
- "input_core_dims",
2021
- [
2022
- [dim for dim in dataset.dims if dim not in loop_dims]
2023
- for dataset in [in_group, in_posterior]
2024
- ],
2025
- )
2026
- func_kwargs["out"] = np.empty(out_group_shape)
2027
-
2028
- out_group = getattr(out, grp)
2029
- try:
2030
- out_group[out_name_group] = _wrap_xarray_ufunc(
2031
- func,
2032
- in_group.values,
2033
- in_posterior.values,
2034
- func_args=func_args,
2035
- func_kwargs=func_kwargs,
2036
- ufunc_kwargs=ufunc_kwargs,
2037
- **wrap_group_kwargs,
2038
- )
2039
- except IndexError:
2040
- excluded_dims = set(
2041
- wrap_group_kwargs["input_core_dims"][0] + wrap_group_kwargs["input_core_dims"][1]
2042
- )
2043
- out_group[out_name_group] = _wrap_xarray_ufunc(
2044
- func,
2045
- *xr.broadcast(in_group, in_posterior, exclude=excluded_dims),
2046
- func_args=func_args,
2047
- func_kwargs=func_kwargs,
2048
- ufunc_kwargs=ufunc_kwargs,
2049
- **wrap_group_kwargs,
2050
- )
2051
- setattr(out, grp, out_group)
2052
-
2053
- return out
2054
-
2055
-
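A minimal sketch of the same kind of posterior predictive check done by hand with xarray reductions, assuming the ``centered_eight`` example dataset keeps its ``obs`` variable and ``school`` dimension::

    import arviz as az

    idata = az.load_arviz_data("centered_eight")
    # Test statistic on the observed data: a single number.
    t_obs = idata.observed_data["obs"].min()
    # Same statistic on every posterior predictive draw: one value per (chain, draw).
    t_rep = idata.posterior_predictive["obs"].min(dim="school")
    # Posterior predictive p-value for T = min(y).
    p_value = (t_rep >= t_obs).mean()
    print(float(t_obs), float(p_value))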
2056
- def weight_predictions(idatas, weights=None):
2057
- """
2058
- Generate weighted posterior predictive samples from a list of InferenceData
2059
- and a set of weights.
2060
-
2061
- Parameters
2062
- ----------
2063
- idatas : list[InferenceData]
2064
- List of :class:`arviz.InferenceData` objects containing the groups `posterior_predictive`
2065
- and `observed_data`. Observations should be the same for all InferenceData objects.
2066
- weights : array-like, optional
2067
- Individual weights for each model. Weights should be positive. If they do not sum up to 1,
2068
- they will be normalized. By default, each model gets the same weight.
2069
- Weights can be computed using many different methods including those in
2070
- :func:`arviz.compare`.
2071
-
2072
- Returns
2073
- -------
2074
- idata: InferenceData
2075
- Output InferenceData object with the groups `posterior_predictive` and `observed_data`.
2076
-
2077
- See Also
2078
- --------
2079
- compare : Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation
2080
- """
2081
- if len(idatas) < 2:
2082
- raise ValueError("You should provide a list with at least two InferenceData objects")
2083
-
2084
- if not all("posterior_predictive" in idata.groups() for idata in idatas):
2085
- raise ValueError(
2086
- "All the InferenceData objects must contain the `posterior_predictive` group"
2087
- )
2088
-
2089
- if not all(idatas[0].observed_data.equals(idata.observed_data) for idata in idatas[1:]):
2090
- raise ValueError("The observed data should be the same for all InferenceData objects")
2091
-
2092
- if weights is None:
2093
- weights = np.ones(len(idatas)) / len(idatas)
2094
- elif len(idatas) != len(weights):
2095
- raise ValueError(
2096
- "The number of weights should be the same as the number of InferenceData objects"
2097
- )
2098
-
2099
- weights = np.array(weights, dtype=float)
2100
- weights /= weights.sum()
2101
-
2102
- len_idatas = [
2103
- idata.posterior_predictive.sizes["chain"] * idata.posterior_predictive.sizes["draw"]
2104
- for idata in idatas
2105
- ]
2106
-
2107
- if not all(len_idatas):
2108
- raise ValueError("At least one of your idatas has 0 samples")
2109
-
2110
- new_samples = (np.min(len_idatas) * weights).astype(int)
2111
-
2112
- new_idatas = [
2113
- extract(idata, group="posterior_predictive", num_samples=samples).reset_coords()
2114
- for samples, idata in zip(new_samples, idatas)
2115
- ]
2116
-
2117
- weighted_samples = InferenceData(
2118
- posterior_predictive=xr.concat(new_idatas, dim="sample"),
2119
- observed_data=idatas[0].observed_data,
2120
- )
2121
-
2122
- return weighted_samples
2123
-
2124
-
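A minimal usage sketch, assuming the two example datasets carry matching ``observed_data`` and ``posterior_predictive`` groups; the weights are fixed here for illustration, but stacking weights from ``az.compare(...)["weight"]`` would be used in practice::

    import arviz as az

    idata_a = az.load_arviz_data("centered_eight")
    idata_b = az.load_arviz_data("non_centered_eight")

    # Fixed weights purely for illustration.
    blended = az.weight_predictions([idata_a, idata_b], weights=[0.7, 0.3])
    print(blended.posterior_predictive)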
2125
- def psens(
2126
- data,
2127
- *,
2128
- component="prior",
2129
- component_var_names=None,
2130
- component_coords=None,
2131
- var_names=None,
2132
- coords=None,
2133
- filter_vars=None,
2134
- delta=0.01,
2135
- dask_kwargs=None,
2136
- ):
2137
- """Compute power-scaling sensitivity diagnostic.
2138
-
2139
- Power-scales the prior or likelihood and calculates how much the posterior is affected.
2140
-
2141
- Parameters
2142
- ----------
2143
- data : obj
2144
- Any object that can be converted to an :class:`arviz.InferenceData` object.
2145
- Refer to documentation of :func:`arviz.convert_to_dataset` for details.
2146
- For ndarray: shape = (chain, draw).
2147
- For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
2148
- component : {"prior", "likelihood"}, default "prior"
2149
- When `component` is "likelihood", the log likelihood values are retrieved
2150
- from the ``log_likelihood`` group as pointwise log likelihood and added
2151
- together. With "prior", the log prior values are retrieved from the
2152
- ``log_prior`` group.
2153
- component_var_names : str, optional
2154
- Name of the prior or log likelihood variables to use
2155
- component_coords : dict, optional
2156
- Coordinates defining a subset over the component element for which to
2157
- compute the prior sensitivity diagnostic.
2158
- var_names : list of str, optional
2159
- Names of posterior variables to include in the power scaling sensitivity diagnostic
2160
- coords : dict, optional
2161
- Coordinates defining a subset over the posterior. Only these variables will
2162
- be used when computing the prior sensitivity.
2163
- filter_vars: {None, "like", "regex"}, default None
2164
- If ``None`` (default), interpret var_names as the real variables names.
2165
- If "like", interpret var_names as substrings of the real variables names.
2166
- If "regex", interpret var_names as regular expressions on the real variables names.
2167
- delta : float
2168
- Value for finite difference derivative calculation.
2169
- dask_kwargs : dict, optional
2170
- Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
2171
-
2172
- Returns
2173
- -------
2174
- xarray.Dataset
2175
- Returns dataset of power-scaling sensitivity diagnostic values.
2176
- Higher sensitivity values indicate greater sensitivity.
2177
- Prior sensitivity above 0.05 indicates an informative prior.
2178
- Likelihood sensitivity below 0.05 indicates a weak or non-informative likelihood.
2179
-
2180
- Examples
2181
- --------
2182
- Compute the likelihood sensitivity for the non-centered eight model:
2183
-
2184
- .. ipython::
2185
-
2186
- In [1]: import arviz as az
2187
- ...: data = az.load_arviz_data("non_centered_eight")
2188
- ...: az.psens(data, component="likelihood")
2189
-
2190
- To compute the prior sensitivity, we need to first compute the log prior
2191
- at each posterior sample. In our case, we know mu has a normal prior :math:`N(0, 5)`,
2192
- tau has a half-Cauchy prior with scale (beta) parameter 5,
2193
- and theta has a standard normal prior.
2194
- We add this information to the ``log_prior`` group before computing the power-scaling
2195
- check with ``psens``:
2196
-
2197
- .. ipython::
2198
-
2199
- In [1]: from xarray_einstats.stats import XrContinuousRV
2200
- ...: from scipy.stats import norm, halfcauchy
2201
- ...: post = data.posterior
2202
- ...: log_prior = {
2203
- ...: "mu": XrContinuousRV(norm, 0, 5).logpdf(post["mu"]),
2204
- ...: "tau": XrContinuousRV(halfcauchy, scale=5).logpdf(post["tau"]),
2205
- ...: "theta_t": XrContinuousRV(norm, 0, 1).logpdf(post["theta_t"]),
2206
- ...: }
2207
- ...: data.add_groups({"log_prior": log_prior})
2208
- ...: az.psens(data, component="prior")
2209
-
2210
- Notes
2211
- -----
2212
- The diagnostic is computed by power-scaling the specified component (prior or likelihood)
2213
- and determining the degree to which the posterior changes as described in [1]_.
2214
- It uses Pareto-smoothed importance sampling to avoid refitting the model.
2215
-
2216
- References
2217
- ----------
2218
- .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
2219
- power-scaling*, 2022, https://arxiv.org/abs/2107.14054
2220
-
2221
- """
2222
- dataset = extract(data, var_names=var_names, filter_vars=filter_vars, group="posterior")
2223
- if coords is not None:
2224
- dataset = dataset.sel(coords)
2225
-
2226
- if component == "likelihood":
2227
- component_draws = _get_log_likelihood(data, var_name=component_var_names, single_var=False)
2228
- elif component == "prior":
2229
- component_draws = _get_log_prior(data, var_names=component_var_names)
2230
- else:
2231
- raise ValueError("Value for `component` argument not recognized")
2232
-
2233
- component_draws = component_draws.stack(__sample__=("chain", "draw"))
2234
- if component_coords is not None:
2235
- component_draws = component_draws.sel(component_coords)
2236
- if isinstance(component_draws, xr.DataArray):
2237
- component_draws = component_draws.to_dataset()
2238
- if len(component_draws.dims):
2239
- component_draws = component_draws.to_stacked_array(
2240
- "latent-obs_var", sample_dims=("__sample__",)
2241
- ).sum("latent-obs_var")
2242
- # from here component_draws is a 1d object with dimensions (sample,)
2243
-
2244
- # calculate lower and upper alpha values
2245
- lower_alpha = 1 / (1 + delta)
2246
- upper_alpha = 1 + delta
2247
-
2248
- # calculate importance sampling weights for lower and upper alpha power-scaling
2249
- lower_w = np.exp(_powerscale_lw(component_draws=component_draws, alpha=lower_alpha))
2250
- lower_w = lower_w / np.sum(lower_w)
2251
-
2252
- upper_w = np.exp(_powerscale_lw(component_draws=component_draws, alpha=upper_alpha))
2253
- upper_w = upper_w / np.sum(upper_w)
2254
-
2255
- ufunc_kwargs = {"n_dims": 1, "ravel": False}
2256
- func_kwargs = {"lower_weights": lower_w.values, "upper_weights": upper_w.values, "delta": delta}
2257
-
2258
- # calculate the sensitivity diagnostic based on the importance weights and draws
2259
- return _wrap_xarray_ufunc(
2260
- _powerscale_sens,
2261
- dataset,
2262
- ufunc_kwargs=ufunc_kwargs,
2263
- func_kwargs=func_kwargs,
2264
- dask_kwargs=dask_kwargs,
2265
- input_core_dims=[["sample"]],
2266
- )
2267
-
2268
-
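A minimal numpy sketch of the raw power-scaling importance weights used above (``log w = (alpha - 1) * log p``), with synthetic per-draw log-prior values and without the Pareto smoothing that ``psislw`` adds::

    import numpy as np
    from scipy.special import logsumexp

    rng = np.random.default_rng(0)
    log_prior = rng.normal(loc=-2.0, scale=0.5, size=4000)  # synthetic per-draw log prior

    delta = 0.01
    alpha = 1 + delta                      # upper power-scaling factor
    log_w = (alpha - 1) * log_prior        # raw log importance weights
    w = np.exp(log_w - logsumexp(log_w))   # self-normalised weights
    print(w.sum(), w.max())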
2269
- def _powerscale_sens(draws, *, lower_weights=None, upper_weights=None, delta=0.01):
2270
- """
2271
- Calculate power-scaling sensitivity by finite difference
2272
- second derivative of CJS
2273
- """
2274
- lower_cjs = max(
2275
- _cjs_dist(draws=draws, weights=lower_weights),
2276
- _cjs_dist(draws=-1 * draws, weights=lower_weights),
2277
- )
2278
- upper_cjs = max(
2279
- _cjs_dist(draws=draws, weights=upper_weights),
2280
- _cjs_dist(draws=-1 * draws, weights=upper_weights),
2281
- )
2282
- logdiffsquare = 2 * np.log2(1 + delta)
2283
- grad = (lower_cjs + upper_cjs) / logdiffsquare
2284
-
2285
- return grad
2286
-
2287
-
2288
- def _powerscale_lw(alpha, component_draws):
2289
- """
2290
- Calculate log weights for power-scaling component by alpha.
2291
- """
2292
- log_weights = (alpha - 1) * component_draws
2293
- log_weights = psislw(log_weights)[0]
2294
-
2295
- return log_weights
2296
-
2297
-
2298
- def _cjs_dist(draws, weights):
2299
- """
2300
- Calculate the cumulative Jensen-Shannon distance between original draws and weighted draws.
2301
- """
2302
-
2303
- # sort draws and weights
2304
- order = np.argsort(draws)
2305
- draws = draws[order]
2306
- weights = weights[order]
2307
-
2308
- binwidth = np.diff(draws)
2309
-
2310
- # ecdfs
2311
- cdf_p = np.linspace(1 / len(draws), 1 - 1 / len(draws), len(draws) - 1)
2312
- cdf_q = np.cumsum(weights / np.sum(weights))[:-1]
2313
-
2314
- # integrals of ecdfs
2315
- cdf_p_int = np.dot(cdf_p, binwidth)
2316
- cdf_q_int = np.dot(cdf_q, binwidth)
2317
-
2318
- # cjs calculation
2319
- pq_numer = np.log2(cdf_p, out=np.zeros_like(cdf_p), where=cdf_p != 0)
2320
- qp_numer = np.log2(cdf_q, out=np.zeros_like(cdf_q), where=cdf_q != 0)
2321
-
2322
- denom = 0.5 * (cdf_p + cdf_q)
2323
- denom = np.log2(denom, out=np.zeros_like(denom), where=denom != 0)
2324
-
2325
- cjs_pq = np.sum(binwidth * (cdf_p * (pq_numer - denom))) + 0.5 / np.log(2) * (
2326
- cdf_q_int - cdf_p_int
2327
- )
2328
-
2329
- cjs_qp = np.sum(binwidth * (cdf_q * (qp_numer - denom))) + 0.5 / np.log(2) * (
2330
- cdf_p_int - cdf_q_int
2331
- )
2332
-
2333
- cjs_pq = max(0, cjs_pq)
2334
- cjs_qp = max(0, cjs_qp)
2335
-
2336
- bound = cdf_p_int + cdf_q_int
2337
-
2338
- return np.sqrt((cjs_pq + cjs_qp) / bound)
2339
-
2340
-
2341
- def bayes_factor(idata, var_name, ref_val=0, prior=None, return_ref_vals=False):
2342
- r"""Approximated Bayes Factor for comparing hypothesis of two nested models.
2343
-
2344
- The Bayes factor is estimated by comparing a model (H1) against a model in which the
2345
- parameter of interest has been restricted to be a point-null (H0). This computation
2346
- assumes the models are nested and thus H0 is a special case of H1.
2347
-
2348
- Notes
2349
- -----
2350
- The Bayes factor is approximated using the Savage-Dickey density ratio
2351
- algorithm presented in [1]_.
2352
-
2353
- Parameters
2354
- ----------
2355
- idata : InferenceData
2356
- Any object that can be converted to an :class:`arviz.InferenceData` object
2357
- Refer to documentation of :func:`arviz.convert_to_dataset` for details.
2358
- var_name : str
2359
- Name of variable we want to test.
2360
- ref_val : int, default 0
2361
- Point-null for Bayes factor estimation.
2362
- prior : numpy.array, optional
2363
- In case we want to use a different prior, for example for sensitivity analysis.
2364
- return_ref_vals : bool, optional
2365
- Whether to return the values of the prior and posterior at the reference value.
2366
- Used by :func:`arviz.plot_bf` to display the distribution comparison.
2367
-
2368
-
2369
- Returns
2370
- -------
2371
- dict : A dictionary with BF10 (the H1/H0 ratio) and BF01 (the H0/H1 ratio).
2372
-
2373
- References
2374
- ----------
2375
- .. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio:
2376
- The case of computing Bayes factors for regression parameters.
2377
-
2378
- Examples
2379
- --------
2380
- Moderate evidence indicating that the parameter "a" is different from zero.
2381
-
2382
- .. ipython::
2383
-
2384
- In [1]: import numpy as np
2385
- ...: import arviz as az
2386
- ...: idata = az.from_dict(posterior={"a":np.random.normal(1, 0.5, 5000)},
2387
- ...: prior={"a":np.random.normal(0, 1, 5000)})
2388
- ...: az.bayes_factor(idata, var_name="a", ref_val=0)
2389
-
2390
- """
2391
-
2392
- posterior = extract(idata, var_names=var_name).values
2393
-
2394
- if ref_val > posterior.max() or ref_val < posterior.min():
2395
- _log.warning(
2396
- "The reference value is outside of the posterior. "
2397
- "This translate into infinite support for H1, which is most likely an overstatement."
2398
- )
2399
-
2400
- if posterior.ndim > 1:
2401
- _log.warning(f"Posterior distribution has {posterior.ndim} dimensions")
2402
-
2403
- if prior is None:
2404
- prior = extract(idata, var_names=var_name, group="prior").values
2405
-
2406
- if posterior.dtype.kind == "f":
2407
- posterior_grid, posterior_pdf, *_ = _kde_linear(posterior)
2408
- prior_grid, prior_pdf, *_ = _kde_linear(prior)
2409
- posterior_at_ref_val = np.interp(ref_val, posterior_grid, posterior_pdf)
2410
- prior_at_ref_val = np.interp(ref_val, prior_grid, prior_pdf)
2411
-
2412
- elif posterior.dtype.kind == "i":
2413
- posterior_at_ref_val = (posterior == ref_val).mean()
2414
- prior_at_ref_val = (prior == ref_val).mean()
2415
-
2416
- bf_10 = prior_at_ref_val / posterior_at_ref_val
2417
- bf = {"BF10": bf_10, "BF01": 1 / bf_10}
2418
-
2419
- if return_ref_vals:
2420
- return (bf, {"prior": prior_at_ref_val, "posterior": posterior_at_ref_val})
2421
- else:
2422
- return bf
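A minimal sketch of the Savage-Dickey density ratio using :func:`scipy.stats.gaussian_kde` in place of the internal KDE, reusing synthetic draws like those in the docstring example above::

    import numpy as np
    from scipy.stats import gaussian_kde

    rng = np.random.default_rng(0)
    posterior = rng.normal(1, 0.5, 5000)   # synthetic posterior draws, as in the example
    prior = rng.normal(0, 1, 5000)         # synthetic prior draws
    ref_val = 0.0

    # density of the prior and the posterior at the point-null value
    bf_10 = gaussian_kde(prior)(ref_val)[0] / gaussian_kde(posterior)(ref_val)[0]
    print({"BF10": bf_10, "BF01": 1 / bf_10})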