arviz 0.18.0__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +2 -1
- arviz/data/io_cmdstan.py +4 -0
- arviz/data/io_numpyro.py +1 -1
- arviz/plots/backends/bokeh/ecdfplot.py +1 -2
- arviz/plots/backends/bokeh/khatplot.py +8 -3
- arviz/plots/backends/bokeh/pairplot.py +2 -6
- arviz/plots/backends/matplotlib/ecdfplot.py +1 -2
- arviz/plots/backends/matplotlib/khatplot.py +7 -3
- arviz/plots/backends/matplotlib/traceplot.py +1 -1
- arviz/plots/bpvplot.py +2 -2
- arviz/plots/densityplot.py +1 -1
- arviz/plots/dotplot.py +2 -2
- arviz/plots/ecdfplot.py +205 -89
- arviz/plots/essplot.py +2 -2
- arviz/plots/forestplot.py +1 -1
- arviz/plots/hdiplot.py +2 -2
- arviz/plots/khatplot.py +23 -6
- arviz/plots/loopitplot.py +2 -2
- arviz/plots/mcseplot.py +3 -1
- arviz/plots/plot_utils.py +2 -4
- arviz/plots/posteriorplot.py +1 -1
- arviz/plots/rankplot.py +2 -2
- arviz/plots/violinplot.py +1 -1
- arviz/preview.py +17 -0
- arviz/rcparams.py +27 -2
- arviz/stats/diagnostics.py +13 -9
- arviz/stats/ecdf_utils.py +11 -8
- arviz/stats/stats.py +31 -16
- arviz/stats/stats_utils.py +8 -6
- arviz/tests/base_tests/test_data.py +1 -2
- arviz/tests/base_tests/test_data_zarr.py +0 -1
- arviz/tests/base_tests/test_diagnostics_numba.py +2 -7
- arviz/tests/base_tests/test_helpers.py +2 -2
- arviz/tests/base_tests/test_plot_utils.py +5 -13
- arviz/tests/base_tests/test_plots_matplotlib.py +92 -2
- arviz/tests/base_tests/test_rcparams.py +12 -0
- arviz/tests/base_tests/test_stats.py +1 -1
- arviz/tests/base_tests/test_stats_numba.py +2 -7
- arviz/tests/base_tests/test_utils_numba.py +2 -5
- arviz/tests/external_tests/test_data_pystan.py +5 -5
- arviz/tests/helpers.py +17 -9
- arviz/utils.py +4 -0
- {arviz-0.18.0.dist-info → arviz-0.19.0.dist-info}/METADATA +8 -4
- {arviz-0.18.0.dist-info → arviz-0.19.0.dist-info}/RECORD +47 -46
- {arviz-0.18.0.dist-info → arviz-0.19.0.dist-info}/LICENSE +0 -0
- {arviz-0.18.0.dist-info → arviz-0.19.0.dist-info}/WHEEL +0 -0
- {arviz-0.18.0.dist-info → arviz-0.19.0.dist-info}/top_level.txt +0 -0
arviz/plots/forestplot.py
CHANGED
@@ -246,7 +246,7 @@ def plot_forest(
         width_ratios.append(1)

     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

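The same rename recurs in each plotting function below; its visible effect is that the default interval width now comes from the new rcParam key. A minimal usage sketch (the bundled example dataset name is an assumption, not part of this diff):

    import arviz as az

    az.rcParams["stats.ci_prob"] = 0.90  # new key introduced in this release
    idata = az.load_arviz_data("centered_eight")
    az.plot_forest(idata)  # hdi_prob=None now resolves to 0.90 via stats.ci_prob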
arviz/plots/hdiplot.py
CHANGED
@@ -42,7 +42,7 @@ def plot_hdi(
     hdi_data : array_like, optional
         Precomputed HDI values to use. Assumed shape is ``(*x.shape, 2)``.
     hdi_prob : float, optional
-        Probability for the highest density interval. Defaults to ``stats.hdi_prob`` rcParam.
+        Probability for the highest density interval. Defaults to ``stats.ci_prob`` rcParam.
         See :ref:`this section <common_hdi_prob>` for usage examples.
     color : str, default "C1"
         Color used for the limits of the HDI and fill. Should be a valid matplotlib color.
@@ -155,7 +155,7 @@ def plot_hdi(
     else:
         y = np.asarray(y)
         if hdi_prob is None:
-            hdi_prob = rcParams["stats.hdi_prob"]
+            hdi_prob = rcParams["stats.ci_prob"]
         elif not 1 >= hdi_prob > 0:
            raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
        hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs)
arviz/plots/khatplot.py
CHANGED
@@ -1,6 +1,7 @@
 """Pareto tail indices plot."""

 import logging
+import warnings

 import numpy as np
 from xarray import DataArray
@@ -40,10 +41,8 @@ def plot_khat(

     Parameters
     ----------
-    khats : ELPDData
-        The input Pareto tail indices to be plotted.
-        Pareto shapes or an array. In this second case, all the values in the array are interpreted
-        as Pareto tail indices.
+    khats : ELPDData
+        The input Pareto tail indices to be plotted.
     color : str or array_like, default "C0"
         Colors of the scatter plot, if color is a str all dots will have the same color,
         if it is the size of the observations, each dot will have the specified color,
@@ -149,8 +148,9 @@ def plot_khat(

     References
     ----------
-    .. [1] Vehtari, A., Simpson, D., Gelman, A., Yao, Y., Gabry, J
-
+    .. [1] Vehtari, A., Simpson, D., Gelman, A., Yao, Y., Gabry, J. (2024).
+       Pareto Smoothed Importance Sampling. Journal of Machine Learning
+       Research, 25(72):1-58.

     """
     if annotate:
@@ -164,13 +164,29 @@ def plot_khat(
         color = "C0"

     if isinstance(khats, np.ndarray):
+        warnings.warn(
+            "support for arrays will be deprecated, please use ELPDData. "
+            "The reason for this is that we need to know the number of draws "
+            "sampled from the posterior",
+            FutureWarning,
+        )
         khats = khats.flatten()
         xlabels = False
         legend = False
         dims = []
+        good_k = None
     else:
         if isinstance(khats, ELPDData):
+            good_k = khats.good_k
             khats = khats.pareto_k
+        else:
+            good_k = None
+            warnings.warn(
+                "support for DataArrays will be deprecated, please use ELPDData. "
+                "The reason for this is that we need to know the number of draws "
+                "sampled from the posterior",
+                FutureWarning,
+            )
         if not isinstance(khats, DataArray):
             raise ValueError("Incorrect khat data input. Check the documentation")

@@ -191,6 +207,7 @@ def plot_khat(
         figsize=figsize,
         xdata=xdata,
         khats=khats,
+        good_k=good_k,
         kwargs=kwargs,
         threshold=threshold,
         coord_labels=coord_labels,
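Under the new input handling, the preferred call passes the full ELPDData object so that good_k can be read from it; arrays and bare DataArrays still work but emit a FutureWarning. A short usage sketch (example dataset name is an assumption):

    import arviz as az

    idata = az.load_arviz_data("centered_eight")
    loo_res = az.loo(idata, pointwise=True)  # ELPDData carrying pareto_k and good_k
    az.plot_khat(loo_res)  # preferred input; raw arrays/DataArrays now warn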
arviz/plots/loopitplot.py
CHANGED
@@ -55,7 +55,7 @@ def plot_loo_pit(
         In this case, instead of overlaying uniform distributions, the beta ``hdi_prob``
         around the theoretical uniform CDF is shown. This approximation only holds
         for large S and ECDF values not very close to 0 nor 1. For more information, see
-        `Vehtari et al. (
+        `Vehtari et al. (2021)`, `Appendix G <https://avehtari.github.io/rhat_ess/rhat_ess.html>`_.
     ecdf_fill : bool, optional
         Use :meth:`matplotlib.axes.Axes.fill_between` to mark the area
         inside the credible interval. Otherwise, plot the
@@ -159,7 +159,7 @@ def plot_loo_pit(
     x_vals = None

     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

arviz/plots/mcseplot.py
CHANGED
@@ -109,7 +109,9 @@ def plot_mcse(

     References
     ----------
-
+    .. [1] Vehtari et al. (2021). Rank-normalization, folding, and
+       localization: An improved Rhat for assessing convergence of
+       MCMC. Bayesian analysis, 16(2):667-718.

     Examples
     --------
arviz/plots/plot_utils.py
CHANGED
@@ -245,10 +245,8 @@ def format_coords_as_labels(dataarray, skip_dims=None):
     coord_labels = coord_labels.values
     if isinstance(coord_labels[0], tuple):
         fmt = ", ".join(["{}" for _ in coord_labels[0]])
-        coord_labels[:] = [fmt.format(*x) for x in coord_labels]
-    else:
-        coord_labels[:] = [f"{s}" for s in coord_labels]
-    return coord_labels
+        return np.array([fmt.format(*x) for x in coord_labels])
+    return np.array([f"{s}" for s in coord_labels])


 def set_xticklabels(ax, coord_labels):
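A standalone sketch of the corrected branch (not the full helper): tuple-valued coordinate labels are joined with the generated format string and returned as a fresh array, instead of being mutated in place:

    import numpy as np

    coord_labels = [("a", 0), ("b", 1)]  # e.g. values from stacked chain/draw coords
    fmt = ", ".join(["{}" for _ in coord_labels[0]])
    print(np.array([fmt.format(*x) for x in coord_labels]))  # ['a, 0' 'b, 1']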
arviz/plots/posteriorplot.py
CHANGED
@@ -237,7 +237,7 @@ def plot_posterior(
         labeller = BaseLabeller()

     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif hdi_prob not in (None, "hide"):
         if not 1 >= hdi_prob > 0:
             raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
arviz/plots/rankplot.py
CHANGED
@@ -46,8 +46,8 @@ def plot_rank(
     indicates good mixing of the chains.

     This plot was introduced by Aki Vehtari, Andrew Gelman, Daniel Simpson, Bob Carpenter,
-    Paul-Christian Burkner (
-    for assessing convergence of MCMC.
+    Paul-Christian Burkner (2021): Rank-normalization, folding, and localization:
+    An improved R-hat for assessing convergence of MCMC. Bayesian analysis, 16(2):667-718.


     Parameters
arviz/plots/violinplot.py
CHANGED
@@ -152,7 +152,7 @@ def plot_violin(
     rows, cols = default_grid(len(plotters), grid=grid)

     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

arviz/preview.py
ADDED
@@ -0,0 +1,17 @@
+# pylint: disable=unused-import,unused-wildcard-import,wildcard-import
+"""Expose features from arviz-xyz refactored packages inside ``arviz.preview`` namespace."""
+
+try:
+    from arviz_base import *
+except ModuleNotFoundError:
+    pass
+
+try:
+    import arviz_stats
+except ModuleNotFoundError:
+    pass
+
+try:
+    from arviz_plots import *
+except ModuleNotFoundError:
+    pass
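Because every import is wrapped in try/except, the module degrades gracefully: names simply do not exist when the corresponding refactored package is absent. A hedged usage sketch (the re-exported name is an assumption about what arviz_base provides, not part of this diff):

    import arviz.preview as azp

    if hasattr(azp, "load_arviz_data"):  # re-exported from arviz_base when installed
        idata = azp.load_arviz_data("centered_eight")
    else:
        print("arviz_base is not installed; arviz.preview exposes nothing from it")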
arviz/rcparams.py
CHANGED
@@ -26,6 +26,8 @@ _log = logging.getLogger(__name__)
 ScaleKeyword = Literal["log", "negative_log", "deviance"]
 ICKeyword = Literal["loo", "waic"]

+_identity = lambda x: x
+

 def _make_validate_choice(accepted_values, allow_none=False, typeof=str):
     """Validate value is in accepted_values.
@@ -300,7 +302,7 @@ defaultParams = {  # pylint: disable=invalid-name
         lambda x: x,
     ),
     "plot.matplotlib.show": (False, _validate_boolean),
-    "stats.hdi_prob": (0.94, _validate_probability),
+    "stats.ci_prob": (0.94, _validate_probability),
     "stats.information_criterion": (
         "loo",
         _make_validate_choice({"loo", "waic"} if NO_GET_ARGS else set(get_args(ICKeyword))),
@@ -318,6 +320,9 @@ defaultParams = {  # pylint: disable=invalid-name
     ),
 }

+# map from deprecated params to (version, new_param, fold2new, fnew2old)
+deprecated_map = {"stats.hdi_prob": ("0.18.0", "stats.ci_prob", _identity, _identity)}
+

 class RcParams(MutableMapping):
     """Class to contain ArviZ default parameters.
@@ -335,6 +340,15 @@ class RcParams(MutableMapping):

     def __setitem__(self, key, val):
         """Add validation to __setitem__ function."""
+        if key in deprecated_map:
+            version, key_new, fold2new, _ = deprecated_map[key]
+            warnings.warn(
+                f"{key} is deprecated since {version}, use {key_new} instead",
+                FutureWarning,
+            )
+            key = key_new
+            val = fold2new(val)
+
         try:
             try:
                 cval = self.validate[key](val)
@@ -349,7 +363,18 @@ class RcParams(MutableMapping):

     def __getitem__(self, key):
         """Use underlying dict's getitem method."""
-        return self._underlying_storage[key]
+        if key in deprecated_map:
+            version, key_new, _, fnew2old = deprecated_map[key]
+            warnings.warn(
+                f"{key} is deprecated since {version}, use {key_new} instead",
+                FutureWarning,
+            )
+            if key not in self._underlying_storage:
+                key = key_new
+        else:
+            fnew2old = _identity
+
+        return fnew2old(self._underlying_storage[key])

     def __delitem__(self, key):
         """Raise TypeError if someone ever tries to delete a key from RcParams."""
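With the deprecation shim above, reads and writes through the old key are remapped to the new one after a FutureWarning. A minimal round-trip sketch:

    import warnings
    import arviz as az

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        az.rcParams["stats.hdi_prob"] = 0.9  # deprecated key, remapped on write
    assert any(issubclass(w.category, FutureWarning) for w in caught)
    print(az.rcParams["stats.ci_prob"])  # 0.9 -- the value landed under the new key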
arviz/stats/diagnostics.py
CHANGED
@@ -135,10 +135,11 @@ def ess(

     References
     ----------
-    * Vehtari et al. (
-
-
-    *
+    * Vehtari et al. (2021). Rank-normalization, folding, and
+      localization: An improved Rhat for assessing convergence of
+      MCMC. Bayesian analysis, 16(2):667-718.
+    * https://mc-stan.org/docs/reference-manual/analysis.html#effective-sample-size.section
+    * Gelman et al. BDA3 (2013) Formula 11.8

     See Also
     --------
@@ -246,7 +247,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
         Names of variables to include in the rhat report
     method : str
         Select R-hat method. Valid methods are:
-        - "rank"        # recommended by Vehtari et al. (
+        - "rank"        # recommended by Vehtari et al. (2021)
         - "split"
         - "folded"
         - "z_scale"
@@ -269,7 +270,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
     -----
     The diagnostic is computed by:

-      .. math:: \hat{R} = \frac{\hat{V}}{W}
+      .. math:: \hat{R} = \sqrt{\frac{\hat{V}}{W}}

     where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
     estimate for the pooled rank-traces. This is the potential scale reduction factor, which
@@ -277,12 +278,15 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
     greater than one indicate that one or more chains have not yet converged.

     Rank values are calculated over all the chains with ``scipy.stats.rankdata``.
-    Each chain is split in two and normalized with the z-transform following
+    Each chain is split in two and normalized with the z-transform following
+    Vehtari et al. (2021).

     References
     ----------
-    * Vehtari et al. (
-
+    * Vehtari et al. (2021). Rank-normalization, folding, and
+      localization: An improved Rhat for assessing convergence of
+      MCMC. Bayesian analysis, 16(2):667-718.
+    * Gelman et al. BDA3 (2013)
     * Brooks and Gelman (1998)
     * Gelman and Rubin (1992)

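For reference, a minimal split-R-hat sketch following the corrected formula (the plain split version, not ArviZ's default rank-normalized method):

    import numpy as np

    def split_rhat(chains):  # chains: array of shape (n_chains, n_draws)
        n_chains, n_draws = chains.shape
        half = n_draws // 2
        splits = chains[:, : 2 * half].reshape(2 * n_chains, half)
        w = splits.var(axis=1, ddof=1).mean()       # within-chain variance W
        b = half * splits.mean(axis=1).var(ddof=1)  # between-chain variance B
        vhat = (half - 1) / half * w + b / half     # pooled variance estimate Vhat
        return np.sqrt(vhat / w)                    # note the square root

    rng = np.random.default_rng(0)
    print(split_rhat(rng.normal(size=(4, 1000))))  # ~1.0 for well-mixed chains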
arviz/stats/ecdf_utils.py
CHANGED
@@ -25,6 +25,13 @@ def _get_ecdf_points(
     return x, y


+def _call_rvs(rvs, ndraws, random_state):
+    if random_state is None:
+        return rvs(ndraws)
+    else:
+        return rvs(ndraws, random_state=random_state)
+
+
 def _simulate_ecdf(
     ndraws: int,
     eval_points: np.ndarray,
@@ -32,7 +39,7 @@ def _simulate_ecdf(
     random_state: Optional[Any] = None,
 ) -> np.ndarray:
     """Simulate ECDF at the `eval_points` using the given random variable sampler"""
-    sample = rvs(ndraws, random_state=random_state)
+    sample = _call_rvs(rvs, ndraws, random_state)
     sample.sort()
     return compute_ecdf(sample, eval_points)

@@ -91,14 +98,10 @@ def ecdf_confidence_band(
         A function that takes an integer `ndraws` and optionally the object passed to
         `random_state` and returns an array of `ndraws` samples from the same distribution
         as the original dataset. Required if `method` is "simulated" and variable is discrete.
-    num_trials : int, default
+    num_trials : int, default 500
         The number of random ECDFs to generate for constructing simultaneous confidence bands
         (if `method` is "simulated").
-    random_state :
-        `numpy.random.RandomState`}, optional
-        If `None`, the `numpy.random.RandomState` singleton is used. If an `int`, a new
-        ``numpy.random.RandomState`` instance is used, seeded with seed. If a `RandomState` or
-        `Generator` instance, the instance is used.
+    random_state : int, numpy.random.Generator or numpy.random.RandomState, optional

     Returns
     -------
@@ -132,7 +135,7 @@ def _simulate_simultaneous_ecdf_band_probability(
     cdf_at_eval_points: np.ndarray,
     prob: float = 0.95,
     rvs: Optional[Callable[[int, Optional[Any]], np.ndarray]] = None,
-    num_trials: int =
+    num_trials: int = 500,
     random_state: Optional[Any] = None,
 ) -> float:
     """Estimate probability for simultaneous confidence band using simulation.
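The helper exists because not every sampler accepts a random_state keyword: SciPy frozen distributions do, while a plain callable may not. A self-contained sketch of the dispatch:

    import numpy as np
    from scipy.stats import norm

    def _call_rvs(rvs, ndraws, random_state):
        if random_state is None:
            return rvs(ndraws)  # plain callable: only a size argument
        return rvs(ndraws, random_state=random_state)  # seedable sampler

    rng = np.random.default_rng(0)
    a = _call_rvs(norm(0, 1).rvs, 100, rng)                  # forwards the seed
    b = _call_rvs(lambda n: rng.normal(size=n), 100, None)   # no keyword needed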
arviz/stats/stats.py
CHANGED
@@ -270,12 +270,12 @@ def compare(
             weights[i] = u_weights / np.sum(u_weights)

         weights = weights.mean(axis=0)
-        ses = pd.Series(z_bs.std(axis=0), index=
+        ses = pd.Series(z_bs.std(axis=0), index=ics.index)  # pylint: disable=no-member

     elif method.lower() == "pseudo-bma":
         min_ic = ics.iloc[0][f"elpd_{ic}"]
         z_rv = np.exp((ics[f"elpd_{ic}"] - min_ic) / scale_value)
-        weights = z_rv / np.sum(z_rv)
+        weights = (z_rv / np.sum(z_rv)).to_numpy()
         ses = ics["se"]

     if np.any(weights):
@@ -471,7 +471,7 @@ def hdi(
         Refer to documentation of :func:`arviz.convert_to_dataset` for details.
     hdi_prob: float, optional
         Prob for which the highest density interval will be computed. Defaults to
-        ``stats.hdi_prob`` rcParam.
+        ``stats.ci_prob`` rcParam.
     circular: bool, optional
         Whether to compute the hdi taking into account `x` is a circular variable
         (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).
@@ -553,7 +553,7 @@ def hdi(

     """
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

@@ -715,8 +715,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
         se: standard error of the elpd
         p_loo: effective number of parameters
         shape_warn: bool
-            True if the estimated shape parameter of
-
+            True if the estimated shape parameter of Pareto distribution is greater than a
+            threshold value for one or more samples. For a sample size S, the threshold is
+            computed as min(1 - 1/log10(S), 0.7)
         loo_i: array of pointwise predictive accuracy, only if pointwise True
         pareto_k: array of Pareto shape values, only if pointwise True
         scale: scale of the elpd
@@ -785,13 +786,15 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
     log_weights += log_likelihood

     warn_mg = False
-    if np.any(pareto_shape > 0.7):
+    good_k = min(1 - 1 / np.log10(n_samples), 0.7)
+
+    if np.any(pareto_shape > good_k):
         warnings.warn(
-            "Estimated shape parameter of Pareto distribution is greater than
-            "one or more samples. You should consider using a more robust model, this is
-            "importance sampling is less likely to work well if the marginal posterior
-            "LOO posterior are very different. This is more likely to happen with a
-            "model and highly influential observations."
+            f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} "
+            "for one or more samples. You should consider using a more robust model, this is "
+            "because importance sampling is less likely to work well if the marginal posterior "
+            "and LOO posterior are very different. This is more likely to happen with a "
+            "non-robust model and highly influential observations."
         )
         warn_mg = True

@@ -816,8 +819,17 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):

     if not pointwise:
         return ELPDData(
-            data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale],
-            index=[
+            data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k],
+            index=[
+                "elpd_loo",
+                "se",
+                "p_loo",
+                "n_samples",
+                "n_data_points",
+                "warning",
+                "scale",
+                "good_k",
+            ],
         )
     if np.equal(loo_lppd, loo_lppd_i).all():  # pylint: disable=no-member
         warnings.warn(
@@ -835,6 +847,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
             loo_lppd_i.rename("loo_i"),
             pareto_shape,
             scale,
+            good_k,
         ],
         index=[
             "elpd_loo",
@@ -846,6 +859,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
             "loo_i",
             "pareto_k",
             "scale",
+            "good_k",
         ],
     )

@@ -879,7 +893,8 @@ def psislw(log_weights, reff=1.0):

     References
     ----------
-    * Vehtari et al. (
+    * Vehtari et al. (2024). Pareto smoothed importance sampling. Journal of Machine
+      Learning Research, 25(72):1-58.

     See Also
     --------
@@ -1322,7 +1337,7 @@ def summary(
     if labeller is None:
         labeller = BaseLabeller()
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

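The new threshold is sample-size dependent and capped at the classic 0.7 value. A worked sketch of the formula from the loo changes above:

    import numpy as np

    def good_k(n_samples):
        return min(1 - 1 / np.log10(n_samples), 0.7)

    print(good_k(100))   # 0.5  -> stricter for few draws
    print(good_k(4000))  # 0.7  -> capped for large sample sizes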
arviz/stats/stats_utils.py
CHANGED
@@ -454,10 +454,9 @@ POINTWISE_LOO_FMT = """------

 Pareto k diagnostic values:
 {{0:>{0}}} {{1:>6}}
-(-Inf,
-
-(
-(1, Inf)   (very bad) {{5:{0}d}} {{9:6.1f}}%
+(-Inf, {{8:.2f}}]   (good)     {{2:{0}d}} {{5:6.1f}}%
+ ({{8:.2f}}, 1]   (bad)      {{3:{0}d}} {{6:6.1f}}%
+   (1, Inf)   (very bad) {{4:{0}d}} {{7:6.1f}}%
 """
 SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}

@@ -488,11 +487,14 @@ class ELPDData(pd.Series):  # pylint: disable=too-many-ancestors
            base += "\n\nThere has been a warning during the calculation. Please check the results."

        if kind == "loo" and "pareto_k" in self:
-            bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
+            bins = np.asarray([-np.inf, self.good_k, 1, np.inf])
            counts, *_ = _histogram(self.pareto_k.values, bins)
            extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
            extended = extended.format(
-                "Count",
+                "Count",
+                "Pct.",
+                *[*counts, *(counts / np.sum(counts) * 100)],
+                self.good_k,
            )
            base = "\n".join([base, extended])
        return base
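A standalone sketch of how the three-bin table above gets filled (values hypothetical), matching the new (-Inf, good_k], (good_k, 1], (1, Inf) layout:

    import numpy as np

    good_k, pareto_k = 0.69, np.array([0.1, 0.72, 1.3, 0.4])
    counts, _ = np.histogram(pareto_k, np.asarray([-np.inf, good_k, 1, np.inf]))
    labels = [f"(-Inf, {good_k:.2f}] (good)", f"({good_k:.2f}, 1] (bad)", "(1, Inf) (very bad)"]
    for label, count in zip(labels, counts):
        print(f"{label:<22} {count:>4} {count / counts.sum() * 100:>6.1f}%")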
arviz/tests/base_tests/test_data.py
CHANGED

@@ -42,7 +42,6 @@ from ..helpers import (  # pylint: disable=unused-import
     draws,
     eight_schools_params,
     models,
-    running_on_ci,
 )


@@ -1469,7 +1468,7 @@ class TestJSON:


 @pytest.mark.skipif(
-    not (importlib.util.find_spec("datatree") or running_on_ci()),
+    not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
     reason="test requires xarray-datatree library",
 )
 class TestDataTree:
arviz/tests/base_tests/test_diagnostics_numba.py
CHANGED

@@ -1,7 +1,5 @@
 """Test Diagnostic methods"""

-import importlib
-
 # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
 import numpy as np
 import pytest
@@ -11,13 +9,10 @@ from ...rcparams import rcParams
 from ...stats import bfmi, mcse, rhat
 from ...stats.diagnostics import _mc_error, ks_summary
 from ...utils import Numba
-from ..helpers import running_on_ci
+from ..helpers import importorskip
 from .test_diagnostics import data  # pylint: disable=unused-import

-pytestmark = pytest.mark.skipif(
-    (importlib.util.find_spec("numba") is None) and not running_on_ci(),
-    reason="test requires numba which is not installed",
-)
+importorskip("numba")

 rcParams["data.load"] = "eager"
arviz/tests/base_tests/test_helpers.py
CHANGED

@@ -6,13 +6,13 @@ from ..helpers import importorskip

 def test_importorskip_local(monkeypatch):
     """Test ``importorskip`` run on local machine with non-existent module, which should skip."""
-    monkeypatch.delenv("ARVIZ_CI_MACHINE", raising=False)
+    monkeypatch.delenv("ARVIZ_REQUIRE_ALL_DEPS", raising=False)
     with pytest.raises(Skipped):
         importorskip("non-existent-function")


 def test_importorskip_ci(monkeypatch):
     """Test ``importorskip`` run on CI machine with non-existent module, which should fail."""
-    monkeypatch.setenv("ARVIZ_CI_MACHINE", 1)
+    monkeypatch.setenv("ARVIZ_REQUIRE_ALL_DEPS", 1)
     with pytest.raises(ModuleNotFoundError):
         importorskip("non-existent-function")
arviz/tests/base_tests/test_plot_utils.py
CHANGED

@@ -1,5 +1,6 @@
 # pylint: disable=redefined-outer-name
 import importlib
+import os

 import numpy as np
 import pytest
@@ -20,10 +21,10 @@ from ...rcparams import rc_context
 from ...sel_utils import xarray_sel_iter, xarray_to_ndarray
 from ...stats.density_utils import get_bins
 from ...utils import get_coords
-from ..helpers import running_on_ci

 # Check if Bokeh is installed
 bokeh_installed = importlib.util.find_spec("bokeh") is not None  # pylint: disable=invalid-name
+skip_tests = (not bokeh_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)


 @pytest.mark.parametrize(
@@ -212,10 +213,7 @@ def test_filter_plotter_list_warning():
     assert len(plotters_filtered) == 5


-@pytest.mark.skipif(
-    not (bokeh_installed or running_on_ci()),
-    reason="test requires bokeh which is not installed",
-)
+@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
 def test_bokeh_import():
     """Tests that correct method is returned on bokeh import"""
     plot = get_plotting_function("plot_dist", "distplot", "bokeh")
@@ -290,10 +288,7 @@ def test_mpl_dealiase_sel_kwargs():
     assert res["line_color"] == "red"


-@pytest.mark.skipif(
-    not (bokeh_installed or running_on_ci()),
-    reason="test requires bokeh which is not installed",
-)
+@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
 def test_bokeh_dealiase_sel_kwargs():
     """Check bokeh dealiase_sel_kwargs behaviour.

@@ -315,10 +310,7 @@ def test_bokeh_dealiase_sel_kwargs():
     assert res["line_color"] == "red"


-@pytest.mark.skipif(
-    not (bokeh_installed or running_on_ci()),
-    reason="test requires bokeh which is not installed",
-)
+@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
 def test_set_bokeh_circular_ticks_labels():
     """Assert the axes returned after placing ticks and tick labels for circular plots."""
     import bokeh.plotting as bkp
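Taken together, the test changes replace the old running_on_ci() helper with an environment-variable gate. A hedged sketch of the pattern (not the exact helpers.py implementation): optional-dependency tests skip locally but fail hard when ARVIZ_REQUIRE_ALL_DEPS is set, e.g. on CI:

    import importlib.util
    import os

    import pytest

    def importorskip_like(modname):
        if importlib.util.find_spec(modname) is None:
            if "ARVIZ_REQUIRE_ALL_DEPS" in os.environ:
                raise ModuleNotFoundError(modname)  # CI must have all optional deps
            pytest.skip(f"test requires {modname} which is not installed")
        return importlib.import_module(modname)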