arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/stats/diagnostics.py
@@ -1,1013 +0,0 @@
- # pylint: disable=too-many-lines, too-many-function-args, redefined-outer-name
- """Diagnostic functions for ArviZ."""
- import warnings
- from collections.abc import Sequence
-
- import numpy as np
- import packaging
- import pandas as pd
- import scipy
- from scipy import stats
-
- from ..data import convert_to_dataset
- from ..utils import Numba, _numba_var, _stack, _var_names
- from .density_utils import histogram as _histogram
- from .stats_utils import _circular_standard_deviation, _sqrt
- from .stats_utils import autocov as _autocov
- from .stats_utils import not_valid as _not_valid
- from .stats_utils import quantile as _quantile
- from .stats_utils import stats_variance_2d as svar
- from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc
-
- __all__ = ["bfmi", "ess", "rhat", "mcse"]
-
-
- def bfmi(data):
-     r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
-
-     BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
-     information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
-     values smaller than 0.3 indicate poor sampling. However, this threshold is
-     provisional and may change. See
-     `pystan_workflow <http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html>`_
-     for more information.
-
-     Parameters
-     ----------
-     data : obj
-         Any object that can be converted to an :class:`arviz.InferenceData` object.
-         Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-         If InferenceData, energy variable needs to be found.
-
-     Returns
-     -------
-     z : array
-         The Bayesian fraction of missing information of the model and trace. One element per
-         chain in the trace.
-
-     See Also
-     --------
-     plot_energy : Plot energy transition distribution and marginal energy
-         distribution in HMC algorithms.
-
-     Examples
-     --------
-     Compute the BFMI of an InferenceData object
-
-     .. ipython::
-
-         In [1]: import arviz as az
-            ...: data = az.load_arviz_data('radon')
-            ...: az.bfmi(data)
-
-     """
-     if isinstance(data, np.ndarray):
-         return _bfmi(data)
-
-     dataset = convert_to_dataset(data, group="sample_stats")
-     if not hasattr(dataset, "energy"):
-         raise TypeError("Energy variable was not found.")
-     return _bfmi(dataset.energy.transpose("chain", "draw"))
-
-
- def ess(
-     data,
-     *,
-     var_names=None,
-     method="bulk",
-     relative=False,
-     prob=None,
-     dask_kwargs=None,
- ):
-     r"""Calculate estimate of the effective sample size (ess).
-
-     Parameters
-     ----------
-     data : obj
-         Any object that can be converted to an :class:`arviz.InferenceData` object.
-         Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-         For ndarray: shape = (chain, draw).
-         For n-dimensional ndarray transform first to dataset with :func:`arviz.convert_to_dataset`.
-     var_names : str or list of str
-         Names of variables to include in the return value Dataset.
-     method : str, optional, default "bulk"
-         Select ess method. Valid methods are:
-
-         - "bulk"
-         - "tail"     # prob, optional
-         - "quantile" # prob
-         - "mean" (old ess)
-         - "sd"
-         - "median"
-         - "mad" (median absolute deviation)
-         - "z_scale"
-         - "folded"
-         - "identity"
-         - "local"
-     relative : bool
-         Return relative ess
-         ``ress = ess / n``
-     prob : float, or tuple of two floats, optional
-         Probability value for "tail", "quantile" or "local" ess functions.
-     dask_kwargs : dict, optional
-         Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
-
-     Returns
-     -------
-     xarray.Dataset
-         Return the effective sample size, :math:`\hat{N}_{eff}`
-
-     Notes
-     -----
-     The basic ess (:math:`N_{\mathit{eff}}`) diagnostic is computed by:
-
-     .. math:: \hat{N}_{\mathit{eff}} = \frac{MN}{\hat{\tau}}
-
-     .. math:: \hat{\tau} = -1 + 2 \sum_{t'=0}^K \hat{P}_{t'}
-
-     where :math:`M` is the number of chains, :math:`N` the number of draws,
-     :math:`\hat{\rho}_t` is the estimated autocorrelation at lag :math:`t`, and
-     :math:`K` is the last integer for which :math:`\hat{P}_{K} = \hat{\rho}_{2K} +
-     \hat{\rho}_{2K+1}` is still positive.
-
-     The current implementation is similar to Stan, which uses Geyer's initial monotone sequence
-     criterion (Geyer, 1992; Geyer, 2011).
-
-     References
-     ----------
-     * Vehtari et al. (2021). Rank-normalization, folding, and
-       localization: An improved Rhat for assessing convergence of
-       MCMC. Bayesian analysis, 16(2):667-718.
-     * https://mc-stan.org/docs/reference-manual/analysis.html#effective-sample-size.section
-     * Gelman et al. BDA3 (2013) Formula 11.8
-
-     See Also
-     --------
-     arviz.rhat : Compute estimate of rank normalized split R-hat for a set of traces.
-     arviz.mcse : Calculate Markov Chain Standard Error statistic.
-     plot_ess : Plot quantile, local or evolution of effective sample sizes (ESS).
-     arviz.summary : Create a data frame with summary statistics.
-
-     Examples
-     --------
-     Calculate the effective sample size using the default arguments:
-
-     .. ipython::
-
-         In [1]: import arviz as az
-            ...: data = az.load_arviz_data('non_centered_eight')
-            ...: az.ess(data)
-
-     Calculate the relative ess (ress) of some of the variables
-
-     .. ipython::
-
-         In [1]: az.ess(data, relative=True, var_names=["mu", "theta_t"])
-
-     Calculate the ess using the "tail" method, leaving the `prob` argument at its default
-     value.
-
-     .. ipython::
-
-         In [1]: az.ess(data, method="tail")
-
-     """
-     methods = {
-         "bulk": _ess_bulk,
-         "tail": _ess_tail,
-         "quantile": _ess_quantile,
-         "mean": _ess_mean,
-         "sd": _ess_sd,
-         "median": _ess_median,
-         "mad": _ess_mad,
-         "z_scale": _ess_z_scale,
-         "folded": _ess_folded,
-         "identity": _ess_identity,
-         "local": _ess_local,
-     }
-
-     if method not in methods:
-         raise TypeError(f"ess method {method} not found. Valid methods are:\n{', '.join(methods)}")
-     ess_func = methods[method]
-
-     if (method == "quantile") and prob is None:
-         raise TypeError("Quantile (prob) information needs to be defined.")
-
-     if isinstance(data, np.ndarray):
-         data = np.atleast_2d(data)
-         if len(data.shape) < 3:
-             if prob is not None:
-                 return ess_func(  # pylint: disable=unexpected-keyword-arg
-                     data, prob=prob, relative=relative
-                 )
-
-             return ess_func(data, relative=relative)
-
-         msg = (
-             "Only uni-dimensional ndarray variables are supported."
-             " Please transform first to dataset with `az.convert_to_dataset`."
-         )
-         raise TypeError(msg)
-
-     dataset = convert_to_dataset(data, group="posterior")
-     var_names = _var_names(var_names, dataset)
-
-     dataset = dataset if var_names is None else dataset[var_names]
-
-     ufunc_kwargs = {"ravel": False}
-     func_kwargs = {"relative": relative} if prob is None else {"prob": prob, "relative": relative}
-     return _wrap_xarray_ufunc(
-         ess_func,
-         dataset,
-         ufunc_kwargs=ufunc_kwargs,
-         func_kwargs=func_kwargs,
-         dask_kwargs=dask_kwargs,
-     )
-
-
- def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
-     r"""Compute estimate of rank normalized split R-hat for a set of traces.
-
-     The rank normalized R-hat diagnostic tests for lack of convergence by comparing the variance
-     between multiple chains to the variance within each chain. If convergence has been achieved,
-     the between-chain and within-chain variances should be identical. To be most effective in
-     detecting evidence for nonconvergence, each chain should have been initialized to starting
-     values that are dispersed relative to the target distribution.
-
-     Parameters
-     ----------
-     data : obj
-         Any object that can be converted to an :class:`arviz.InferenceData` object.
-         Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-         At least 2 posterior chains are needed to compute this diagnostic of one or more
-         stochastic parameters.
-         For ndarray: shape = (chain, draw).
-         For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
-     var_names : list
-         Names of variables to include in the rhat report
-     method : str
-         Select R-hat method. Valid methods are:
-         - "rank"        # recommended by Vehtari et al. (2021)
-         - "split"
-         - "folded"
-         - "z_scale"
-         - "identity"
-     dask_kwargs : dict, optional
-         Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
-
-     Returns
-     -------
-     xarray.Dataset
-         Returns dataset of the potential scale reduction factors, :math:`\hat{R}`
-
-     See Also
-     --------
-     ess : Calculate estimate of the effective sample size (ess).
-     mcse : Calculate Markov Chain Standard Error statistic.
-     plot_forest : Forest plot to compare HDI intervals from a number of distributions.
-
-     Notes
-     -----
-     The diagnostic is computed by:
-
-     .. math:: \hat{R} = \sqrt{\frac{\hat{V}}{W}}
-
-     where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
-     estimate for the pooled rank-traces. This is the potential scale reduction factor, which
-     converges to unity when each of the traces is a sample from the target posterior. Values
-     greater than one indicate that one or more chains have not yet converged.
-
-     Rank values are calculated over all the chains with ``scipy.stats.rankdata``.
-     Each chain is split in two and normalized with the z-transform following
-     Vehtari et al. (2021).
-
-     References
-     ----------
-     * Vehtari et al. (2021). Rank-normalization, folding, and
-       localization: An improved Rhat for assessing convergence of
-       MCMC. Bayesian analysis, 16(2):667-718.
-     * Gelman et al. BDA3 (2013)
-     * Brooks and Gelman (1998)
-     * Gelman and Rubin (1992)
-
-     Examples
-     --------
-     Calculate the R-hat using the default arguments:
-
-     .. ipython::
-
-         In [1]: import arviz as az
-            ...: data = az.load_arviz_data("non_centered_eight")
-            ...: az.rhat(data)
-
-     Calculate the R-hat of some variables using the folded method:
-
-     .. ipython::
-
-         In [1]: az.rhat(data, var_names=["mu", "theta_t"], method="folded")
-
-     """
-     methods = {
-         "rank": _rhat_rank,
-         "split": _rhat_split,
-         "folded": _rhat_folded,
-         "z_scale": _rhat_z_scale,
-         "identity": _rhat_identity,
-     }
-     if method not in methods:
-         raise TypeError(
-             f"R-hat method {method} not found. Valid methods are:\n{', '.join(methods)}"
-         )
-     rhat_func = methods[method]
-
-     if isinstance(data, np.ndarray):
-         data = np.atleast_2d(data)
-         if len(data.shape) < 3:
-             return rhat_func(data)
-
-         msg = (
-             "Only uni-dimensional ndarray variables are supported."
-             " Please transform first to dataset with `az.convert_to_dataset`."
-         )
-         raise TypeError(msg)
-
-     dataset = convert_to_dataset(data, group="posterior")
-     var_names = _var_names(var_names, dataset)
-
-     dataset = dataset if var_names is None else dataset[var_names]
-
-     ufunc_kwargs = {"ravel": False}
-     func_kwargs = {}
-     return _wrap_xarray_ufunc(
-         rhat_func,
-         dataset,
-         ufunc_kwargs=ufunc_kwargs,
-         func_kwargs=func_kwargs,
-         dask_kwargs=dask_kwargs,
-     )
-
-
- def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
-     """Calculate Markov Chain Standard Error statistic.
-
-     Parameters
-     ----------
-     data : obj
-         Any object that can be converted to an :class:`arviz.InferenceData` object.
-         Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-         For ndarray: shape = (chain, draw).
-         For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
-     var_names : list
-         Names of variables to include in the mcse report
-     method : str
-         Select mcse method. Valid methods are:
-         - "mean"
-         - "sd"
-         - "median"
-         - "quantile"
-
-     prob : float
-         Quantile information.
-     dask_kwargs : dict, optional
-         Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
-
-     Returns
-     -------
-     xarray.Dataset
-         Return the mcse dataset
-
-     See Also
-     --------
-     ess : Calculate estimate of the effective sample size (ess).
-     summary : Create a data frame with summary statistics.
-     plot_mcse : Plot quantile or local Monte Carlo Standard Error.
-
-     Examples
-     --------
-     Calculate the Markov Chain Standard Error using the default arguments:
-
-     .. ipython::
-
-         In [1]: import arviz as az
-            ...: data = az.load_arviz_data("non_centered_eight")
-            ...: az.mcse(data)
-
-     Calculate the Markov Chain Standard Error using the quantile method:
-
-     .. ipython::
-
-         In [1]: az.mcse(data, method="quantile", prob=0.7)
-
-     """
-     methods = {
-         "mean": _mcse_mean,
-         "sd": _mcse_sd,
-         "median": _mcse_median,
-         "quantile": _mcse_quantile,
-     }
-     if method not in methods:
-         raise TypeError(
-             "mcse method {} not found. Valid methods are:\n{}".format(
-                 method, "\n ".join(methods)
-             )
-         )
-     mcse_func = methods[method]
-
-     if method == "quantile" and prob is None:
-         raise TypeError("Quantile (prob) information needs to be defined.")
-
-     if isinstance(data, np.ndarray):
-         data = np.atleast_2d(data)
-         if len(data.shape) < 3:
-             if prob is not None:
-                 return mcse_func(data, prob=prob)  # pylint: disable=unexpected-keyword-arg
-
-             return mcse_func(data)
-
-         msg = (
-             "Only uni-dimensional ndarray variables are supported."
-             " Please transform first to dataset with `az.convert_to_dataset`."
-         )
-         raise TypeError(msg)
-
-     dataset = convert_to_dataset(data, group="posterior")
-     var_names = _var_names(var_names, dataset)
-
-     dataset = dataset if var_names is None else dataset[var_names]
-
-     ufunc_kwargs = {"ravel": False}
-     func_kwargs = {} if prob is None else {"prob": prob}
-     return _wrap_xarray_ufunc(
-         mcse_func,
-         dataset,
-         ufunc_kwargs=ufunc_kwargs,
-         func_kwargs=func_kwargs,
-         dask_kwargs=dask_kwargs,
-     )
-
-
- def ks_summary(pareto_tail_indices):
-     """Display a summary of Pareto tail indices.
-
-     Parameters
-     ----------
-     pareto_tail_indices : array
-         Pareto tail indices.
-
-     Returns
-     -------
-     df_k : dataframe
-         Dataframe containing k diagnostic values.
-     """
-     _numba_flag = Numba.numba_flag
-     if _numba_flag:
-         bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
-         kcounts, *_ = _histogram(pareto_tail_indices, bins)
-     else:
-         kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.inf, 0.5, 0.7, 1, np.inf])
-     kprop = kcounts / len(pareto_tail_indices) * 100
-     df_k = pd.DataFrame(
-         dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
-     ).rename(index={0: "(-Inf, 0.5]", 1: " (0.5, 0.7]", 2: " (0.7, 1]", 3: " (1, Inf)"})
-
-     if np.sum(kcounts[1:]) == 0:
-         warnings.warn("All Pareto k estimates are good (k < 0.5)")
-     elif np.sum(kcounts[2:]) == 0:
-         warnings.warn("All Pareto k estimates are ok (k < 0.7)")
-
-     return df_k
-
-
- def _bfmi(energy):
-     r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
-
-     BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
-     information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
-     values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
-     change. See http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html for more
-     information.
-
-     Parameters
-     ----------
-     energy : NumPy array
-         Should be extracted from a gradient based sampler, such as in Stan or PyMC3. Typically,
-         after converting a trace or fit to InferenceData, the energy will be in
-         `data.sample_stats.energy`.
-
-     Returns
-     -------
-     z : array
-         The Bayesian fraction of missing information of the model and trace. One element per
-         chain in the trace.
-     """
-     energy_mat = np.atleast_2d(energy)
-     num = np.square(np.diff(energy_mat, axis=1)).mean(axis=1)  # pylint: disable=no-member
-     if energy_mat.ndim == 2:
-         den = _numba_var(svar, np.var, energy_mat, axis=1, ddof=1)
-     else:
-         den = np.var(energy, axis=1, ddof=1)
-     return num / den
-
-
- def _backtransform_ranks(arr, c=3 / 8):  # pylint: disable=invalid-name
-     """Backtransformation of ranks.
-
-     Parameters
-     ----------
-     arr : np.ndarray
-         Ranks array
-     c : float
-         Fractional offset. Defaults to c = 3/8 as recommended by Blom (1958).
-
-     Returns
-     -------
-     np.ndarray
-
-     References
-     ----------
-     Blom, G. (1958). Statistical Estimates and Transformed Beta-Variables. Wiley; New York.
-     """
-     arr = np.asarray(arr)
-     size = arr.size
-     return (arr - c) / (size - 2 * c + 1)
-
-
- def _z_scale(ary):
-     """Calculate z_scale.
-
-     Parameters
-     ----------
-     ary : np.ndarray
-
-     Returns
-     -------
-     np.ndarray
-     """
-     ary = np.asarray(ary)
-     if packaging.version.parse(scipy.__version__) < packaging.version.parse("1.10.0.dev0"):
-         rank = stats.rankdata(ary, method="average")
-     else:
-         # the .ravel part is only needed to overcome a bug in scipy 1.10.0.rc1
-         rank = stats.rankdata(  # pylint: disable=unexpected-keyword-arg
-             ary, method="average", nan_policy="omit"
-         )
-     rank = _backtransform_ranks(rank)
-     z = stats.norm.ppf(rank)
-     z = z.reshape(ary.shape)
-     return z
-
-
- def _split_chains(ary):
-     """Split and stack chains."""
-     ary = np.asarray(ary)
-     if len(ary.shape) <= 1:
-         ary = np.atleast_2d(ary)
-     _, n_draw = ary.shape
-     half = n_draw // 2
-     return _stack(ary[:, :half], ary[:, -half:])
-
-
- def _z_fold(ary):
-     """Fold and z-scale values."""
-     ary = np.asarray(ary)
-     ary = abs(ary - np.median(ary))
-     ary = _z_scale(ary)
-     return ary
-
-
- def _rhat(ary):
-     """Compute the rhat for a 2d array."""
-     _numba_flag = Numba.numba_flag
-     ary = np.asarray(ary, dtype=float)
-     if _not_valid(ary, check_shape=False):
-         return np.nan
-     _, num_samples = ary.shape
-
-     # Calculate chain mean
-     chain_mean = np.mean(ary, axis=1)
-     # Calculate chain variance
-     chain_var = _numba_var(svar, np.var, ary, axis=1, ddof=1)
-     # Calculate between-chain variance
-     between_chain_variance = num_samples * _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)
-     # Calculate within-chain variance
-     within_chain_variance = np.mean(chain_var)
-     # Estimate of marginal posterior variance
-     rhat_value = np.sqrt(
-         (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
-     )
-     return rhat_value
-
-
- def _rhat_rank(ary):
-     """Compute the rank normalized rhat for 2d array.
-
-     Computation follows https://arxiv.org/abs/1903.08008
-     """
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-         return np.nan
-     split_ary = _split_chains(ary)
-     rhat_bulk = _rhat(_z_scale(split_ary))
-
-     split_ary_folded = abs(split_ary - np.median(split_ary))
-     rhat_tail = _rhat(_z_scale(split_ary_folded))
-
-     rhat_rank = max(rhat_bulk, rhat_tail)
-     return rhat_rank
-
-
- def _rhat_folded(ary):
-     """Calculate split-Rhat for folded z-values."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-         return np.nan
-     ary = _z_fold(_split_chains(ary))
-     return _rhat(ary)
-
-
- def _rhat_z_scale(ary):
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-         return np.nan
-     return _rhat(_z_scale(_split_chains(ary)))
-
-
- def _rhat_split(ary):
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-         return np.nan
-     return _rhat(_split_chains(ary))
-
-
- def _rhat_identity(ary):
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-         return np.nan
-     return _rhat(ary)
-
-
- def _ess(ary, relative=False):
-     """Compute the effective sample size for a 2D array."""
-     _numba_flag = Numba.numba_flag
-     ary = np.asarray(ary, dtype=float)
-     if _not_valid(ary, check_shape=False):
-         return np.nan
-     if (np.max(ary) - np.min(ary)) < np.finfo(float).resolution:  # pylint: disable=no-member
-         return ary.size
-     if len(ary.shape) < 2:
-         ary = np.atleast_2d(ary)
-     n_chain, n_draw = ary.shape
-     acov = _autocov(ary, axis=1)
-     chain_mean = ary.mean(axis=1)
-     mean_var = np.mean(acov[:, 0]) * n_draw / (n_draw - 1.0)
-     var_plus = mean_var * (n_draw - 1.0) / n_draw
-     if n_chain > 1:
-         var_plus += _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)
-
-     rho_hat_t = np.zeros(n_draw)
-     rho_hat_even = 1.0
-     rho_hat_t[0] = rho_hat_even
-     rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, 1])) / var_plus
-     rho_hat_t[1] = rho_hat_odd
-
-     # Geyer's initial positive sequence
-     t = 1
-     while t < (n_draw - 3) and (rho_hat_even + rho_hat_odd) > 0.0:
-         rho_hat_even = 1.0 - (mean_var - np.mean(acov[:, t + 1])) / var_plus
-         rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, t + 2])) / var_plus
-         if (rho_hat_even + rho_hat_odd) >= 0:
-             rho_hat_t[t + 1] = rho_hat_even
-             rho_hat_t[t + 2] = rho_hat_odd
-         t += 2
-
-     max_t = t - 2
-     # improve estimation
-     if rho_hat_even > 0:
-         rho_hat_t[max_t + 1] = rho_hat_even
-     # Geyer's initial monotone sequence
-     t = 1
-     while t <= max_t - 2:
-         if (rho_hat_t[t + 1] + rho_hat_t[t + 2]) > (rho_hat_t[t - 1] + rho_hat_t[t]):
-             rho_hat_t[t + 1] = (rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0
-             rho_hat_t[t + 2] = rho_hat_t[t + 1]
-         t += 2
-
-     ess = n_chain * n_draw
-     tau_hat = -1.0 + 2.0 * np.sum(rho_hat_t[: max_t + 1]) + np.sum(rho_hat_t[max_t + 1 : max_t + 2])
-     tau_hat = max(tau_hat, 1 / np.log10(ess))
-     ess = (1 if relative else ess) / tau_hat
-     if np.isnan(rho_hat_t).any():
-         ess = np.nan
-     return ess
-
-
- def _ess_bulk(ary, relative=False):
-     """Compute the effective sample size for the bulk."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     z_scaled = _z_scale(_split_chains(ary))
-     ess_bulk = _ess(z_scaled, relative=relative)
-     return ess_bulk
-
-
- def _ess_tail(ary, prob=None, relative=False):
-     """Compute the effective sample size for the tail.
-
-     If ``prob`` is defined, ess = min(qess(prob), qess(1-prob))
-     """
-     if prob is None:
-         prob = (0.05, 0.95)
-     elif not isinstance(prob, Sequence):
-         prob = (prob, 1 - prob)
-
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-
-     prob_low, prob_high = prob
-     quantile_low_ess = _ess_quantile(ary, prob_low, relative=relative)
-     quantile_high_ess = _ess_quantile(ary, prob_high, relative=relative)
-     return min(quantile_low_ess, quantile_high_ess)
-
-
- def _ess_mean(ary, relative=False):
-     """Compute the effective sample size for the mean."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     return _ess(_split_chains(ary), relative=relative)
-
-
- def _ess_sd(ary, relative=False):
-     """Compute the effective sample size for the sd."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     ary = (ary - ary.mean()) ** 2
-     return _ess(_split_chains(ary), relative=relative)
-
-
- def _ess_quantile(ary, prob, relative=False):
-     """Compute the effective sample size for a specific quantile."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     if prob is None:
-         raise TypeError("Prob not defined.")
-     (quantile,) = _quantile(ary, prob)
-     iquantile = ary <= quantile
-     return _ess(_split_chains(iquantile), relative=relative)
-
-
- def _ess_local(ary, prob, relative=False):
-     """Compute the effective sample size for a specific probability interval."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     if prob is None:
-         raise TypeError("Prob not defined.")
-     if len(prob) != 2:
-         raise ValueError("Prob argument in ess local must be upper and lower bound")
-     quantile = _quantile(ary, prob)
-     iquantile = (quantile[0] <= ary) & (ary <= quantile[1])
-     return _ess(_split_chains(iquantile), relative=relative)
-
-
- def _ess_z_scale(ary, relative=False):
-     """Calculate ess for z-scale."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     return _ess(_z_scale(_split_chains(ary)), relative=relative)
-
-
- def _ess_folded(ary, relative=False):
-     """Calculate split-ess for folded data."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     return _ess(_z_fold(_split_chains(ary)), relative=relative)
-
-
- def _ess_median(ary, relative=False):
-     """Calculate split-ess for median."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     return _ess_quantile(ary, 0.5, relative=relative)
-
-
- def _ess_mad(ary, relative=False):
-     """Calculate split-ess for the median absolute deviation."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     ary = abs(ary - np.median(ary))
-     ary = ary <= np.median(ary)
-     ary = _z_scale(_split_chains(ary))
-     return _ess(ary, relative=relative)
-
-
- def _ess_identity(ary, relative=False):
-     """Calculate ess."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     return _ess(ary, relative=relative)
-
-
- def _mcse_mean(ary):
-     """Compute the Markov Chain mean error."""
-     _numba_flag = Numba.numba_flag
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     ess = _ess_mean(ary)
-     if _numba_flag:
-         sd = _sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1))
-     else:
-         sd = np.std(ary, ddof=1)
-     mcse_mean_value = sd / np.sqrt(ess)
-     return mcse_mean_value
-
-
- def _mcse_sd(ary):
-     """Compute the Markov Chain sd error."""
-     _numba_flag = Numba.numba_flag
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     sims_c2 = (ary - ary.mean()) ** 2
-     ess = _ess_mean(sims_c2)
-     evar = (sims_c2).mean()
-     varvar = ((sims_c2**2).mean() - evar**2) / ess
-     varsd = varvar / evar / 4
-     if _numba_flag:
-         mcse_sd_value = float(_sqrt(np.ravel(varsd), np.zeros(1)))
-     else:
-         mcse_sd_value = np.sqrt(varsd)
-     return mcse_sd_value
-
-
- def _mcse_median(ary):
-     """Compute the Markov Chain median error."""
-     return _mcse_quantile(ary, 0.5)
-
-
- def _mcse_quantile(ary, prob):
-     """Compute the Markov Chain quantile error at quantile=prob."""
-     ary = np.asarray(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         return np.nan
-     ess = _ess_quantile(ary, prob)
-     probability = [0.1586553, 0.8413447]
-     with np.errstate(invalid="ignore"):
-         ppf = stats.beta.ppf(probability, ess * prob + 1, ess * (1 - prob) + 1)
-     sorted_ary = np.sort(ary.ravel())
-     size = sorted_ary.size
-     ppf_size = ppf * size - 1
-     th1 = sorted_ary[int(np.floor(np.nanmax((ppf_size[0], 0))))]
-     th2 = sorted_ary[int(np.ceil(np.nanmin((ppf_size[1], size - 1))))]
-     return (th2 - th1) / 2
-
-
- def _mc_error(ary, batches=5, circular=False):
-     """Calculate the simulation standard error, accounting for non-independent samples.
-
-     The trace is divided into batches, and the standard deviation of the batch
-     means is calculated.
-
-     Parameters
-     ----------
-     ary : Numpy array
-         An array containing MCMC samples
-     batches : integer
-         Number of batches
-     circular : bool
-         Whether to compute the error taking into account `ary` is a circular variable
-         (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e., non-circular variables).
-
-     Returns
-     -------
-     mc_error : float
-         Simulation standard error
-     """
-     _numba_flag = Numba.numba_flag
-     if ary.ndim > 1:
-         dims = np.shape(ary)
-         trace = np.transpose([t.ravel() for t in ary])
-
-         return np.reshape([_mc_error(t, batches) for t in trace], dims[1:])
-
-     else:
-         if _not_valid(ary, check_shape=False):
-             return np.nan
-         if batches == 1:
-             if circular:
-                 if _numba_flag:
-                     std = _circular_standard_deviation(ary, high=np.pi, low=-np.pi)
-                 else:
-                     std = stats.circstd(ary, high=np.pi, low=-np.pi)
-             elif _numba_flag:
-                 std = float(_sqrt(svar(ary), np.zeros(1)).item())
-             else:
-                 std = np.std(ary)
-             return std / np.sqrt(len(ary))
-
-         batched_traces = np.resize(ary, (batches, int(len(ary) / batches)))
-
-         if circular:
-             means = stats.circmean(batched_traces, high=np.pi, low=-np.pi, axis=1)
-             if _numba_flag:
-                 std = _circular_standard_deviation(means, high=np.pi, low=-np.pi)
-             else:
-                 std = stats.circstd(means, high=np.pi, low=-np.pi)
-         else:
-             means = np.mean(batched_traces, 1)
-             std = _sqrt(svar(means), np.zeros(1)) if _numba_flag else np.std(means)
-         return std / np.sqrt(batches)
-
-
- def _multichain_statistics(ary, focus="mean"):
-     """Efficiently calculate multichain statistics for summary.
-
-     Parameters
-     ----------
-     ary : numpy.ndarray
-     focus : select focus for the statistics. Default is "mean".
-
-     Returns
-     -------
-     tuple
-         Order of return parameters is
-             If focus equals "mean"
-                 - mcse_mean, mcse_sd, ess_bulk, ess_tail, r_hat
-             Else if focus equals "median"
-                 - mcse_median, ess_median, ess_tail, r_hat
-     """
-     ary = np.atleast_2d(ary)
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
-         if focus == "mean":
-             return np.nan, np.nan, np.nan, np.nan, np.nan
-         return np.nan, np.nan, np.nan, np.nan
-
-     z_split = _z_scale(_split_chains(ary))
-
-     # ess tail
-     quantile05, quantile95 = _quantile(ary, [0.05, 0.95])
-     iquantile05 = ary <= quantile05
-     quantile05_ess = _ess(_split_chains(iquantile05))
-     iquantile95 = ary <= quantile95
-     quantile95_ess = _ess(_split_chains(iquantile95))
-     ess_tail_value = min(quantile05_ess, quantile95_ess)
-
-     if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
-         rhat_value = np.nan
-     else:
-         # r_hat
-         rhat_bulk = _rhat(z_split)
-         ary_folded = np.abs(ary - np.median(ary))
-         rhat_tail = _rhat(_z_scale(_split_chains(ary_folded)))
-         rhat_value = max(rhat_bulk, rhat_tail)
-
-     if focus == "mean":
-         # ess mean
-         ess_mean_value = _ess_mean(ary)
-
-         # mcse_mean
-         sims_c2 = (ary - ary.mean()) ** 2
-         sims_c2_sum = sims_c2.sum()
-         var = sims_c2_sum / (sims_c2.size - 1)
-         mcse_mean_value = np.sqrt(var / ess_mean_value)
-
-         # ess bulk
-         ess_bulk_value = _ess(z_split)
-
-         # mcse_sd
-         evar = sims_c2_sum / sims_c2.size
-         ess_mean_sims = _ess_mean(sims_c2)
-         varvar = ((sims_c2**2).mean() - evar**2) / ess_mean_sims
-         varsd = varvar / evar / 4
-         mcse_sd_value = np.sqrt(varsd)
-
-         return (
-             mcse_mean_value,
-             mcse_sd_value,
-             ess_bulk_value,
-             ess_tail_value,
-             rhat_value,
-         )
-
-     # ess median
-     ess_median_value = _ess_median(ary)
-
-     # mcse_median
-     mcse_median_value = _mcse_median(ary)
-
-     return (
-         mcse_median_value,
-         ess_median_value,
-         ess_tail_value,
-         rhat_value,
-     )
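
The hunk above removes the 0.x diagnostics module wholesale. For reference, a minimal usage sketch of the four public entry points it deletes, assembled from the docstring examples in the removed file (requires arviz < 1.0; these top-level functions are gone in 1.0.0rc0):

    import arviz as az

    # example dataset used throughout the removed docstrings
    data = az.load_arviz_data("non_centered_eight")

    az.ess(data, method="bulk")                 # bulk effective sample size
    az.ess(data, method="tail")                 # min of the 0.05 and 0.95 quantile ESS
    az.rhat(data, method="rank")                # rank-normalized split R-hat
    az.mcse(data, method="quantile", prob=0.7)  # quantile Monte Carlo standard error
    az.bfmi(data)                               # needs sample_stats.energy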
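The BFMI estimate that `bfmi`/`_bfmi` compute is never written out in the docstrings, only in code; spelled out, with :math:`E_{c,i}` the energy of draw :math:`i` in chain :math:`c` and :math:`N` the number of draws, the `np.diff`/`np.var` computation in `_bfmi` corresponds to

.. math:: \widehat{\mathrm{BFMI}}_c = \frac{\frac{1}{N-1}\sum_{i=2}^{N} (E_{c,i} - E_{c,i-1})^2}{\widehat{\operatorname{Var}}(E_{c,\cdot})}

computed per chain, with the denominator being the ddof=1 sample variance.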
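How the private helpers compose: `_rhat_rank` is `_split_chains`, then `_z_scale`, then `_rhat`, taking the max of the bulk and folded (tail) statistics. Below is a condensed, self-contained sketch of that pipeline, assuming a 2-D `(chain, draw)` float array with no NaNs and skipping the validity checks and scipy-version branching of the removed code:

    import numpy as np
    from scipy import stats

    def split_rank_rhat(ary):
        ary = np.asarray(ary, dtype=float)
        half = ary.shape[1] // 2
        # split each chain in two, as _split_chains does
        split = np.vstack((ary[:, :half], ary[:, -half:]))

        def z_scale(x):
            # rank over all chains, Blom (1958) backtransform, normal quantiles
            rank = stats.rankdata(x, method="average").reshape(x.shape)
            return stats.norm.ppf((rank - 3 / 8) / (x.size + 1 / 4))

        def rhat(x):
            n = x.shape[1]
            within = np.mean(np.var(x, axis=1, ddof=1))
            between = n * np.var(np.mean(x, axis=1), ddof=1)
            return np.sqrt((between / within + n - 1) / n)

        bulk = rhat(z_scale(split))
        tail = rhat(z_scale(np.abs(split - np.median(split))))  # folded
        return max(bulk, tail)

    rng = np.random.default_rng(0)
    print(split_rank_rhat(rng.normal(size=(4, 1000))))  # ~1.0 for well-mixed chains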
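One non-obvious detail in `_mcse_quantile` above: the hard-coded `probability = [0.1586553, 0.8413447]` values are the standard normal CDF at -1 and +1, so `(th2 - th1) / 2` is a one-standard-deviation half-width for the quantile estimate. A quick check of that claim:

    from scipy import stats
    print(stats.norm.cdf([-1, 1]))  # [0.15865525 0.84134475]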