arviz 0.18.0__py3-none-any.whl → 0.20.0__py3-none-any.whl
- arviz/__init__.py +2 -1
- arviz/data/base.py +18 -7
- arviz/data/converters.py +7 -3
- arviz/data/inference_data.py +8 -0
- arviz/data/io_cmdstan.py +4 -0
- arviz/data/io_numpyro.py +1 -1
- arviz/plots/backends/bokeh/ecdfplot.py +1 -2
- arviz/plots/backends/bokeh/khatplot.py +8 -3
- arviz/plots/backends/bokeh/pairplot.py +2 -6
- arviz/plots/backends/matplotlib/ecdfplot.py +1 -2
- arviz/plots/backends/matplotlib/khatplot.py +7 -3
- arviz/plots/backends/matplotlib/traceplot.py +1 -1
- arviz/plots/bpvplot.py +2 -2
- arviz/plots/compareplot.py +4 -4
- arviz/plots/densityplot.py +1 -1
- arviz/plots/dotplot.py +2 -2
- arviz/plots/ecdfplot.py +213 -89
- arviz/plots/essplot.py +2 -2
- arviz/plots/forestplot.py +3 -3
- arviz/plots/hdiplot.py +2 -2
- arviz/plots/kdeplot.py +9 -2
- arviz/plots/khatplot.py +23 -6
- arviz/plots/loopitplot.py +2 -2
- arviz/plots/mcseplot.py +3 -1
- arviz/plots/plot_utils.py +2 -4
- arviz/plots/posteriorplot.py +1 -1
- arviz/plots/rankplot.py +2 -2
- arviz/plots/violinplot.py +1 -1
- arviz/preview.py +17 -0
- arviz/rcparams.py +27 -2
- arviz/stats/diagnostics.py +13 -9
- arviz/stats/ecdf_utils.py +168 -10
- arviz/stats/stats.py +41 -20
- arviz/stats/stats_utils.py +8 -6
- arviz/tests/base_tests/test_data.py +11 -2
- arviz/tests/base_tests/test_data_zarr.py +0 -1
- arviz/tests/base_tests/test_diagnostics_numba.py +2 -7
- arviz/tests/base_tests/test_helpers.py +2 -2
- arviz/tests/base_tests/test_plot_utils.py +5 -13
- arviz/tests/base_tests/test_plots_matplotlib.py +95 -2
- arviz/tests/base_tests/test_rcparams.py +12 -0
- arviz/tests/base_tests/test_stats.py +1 -1
- arviz/tests/base_tests/test_stats_ecdf_utils.py +15 -2
- arviz/tests/base_tests/test_stats_numba.py +2 -7
- arviz/tests/base_tests/test_utils_numba.py +2 -5
- arviz/tests/external_tests/test_data_pystan.py +5 -5
- arviz/tests/helpers.py +17 -9
- arviz/utils.py +4 -0
- {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/METADATA +23 -19
- {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/RECORD +53 -52
- {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/WHEEL +1 -1
- {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/LICENSE +0 -0
- {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/top_level.txt +0 -0
arviz/preview.py
ADDED
@@ -0,0 +1,17 @@
+# pylint: disable=unused-import,unused-wildcard-import,wildcard-import
+"""Expose features from arviz-xyz refactored packages inside ``arviz.preview`` namespace."""
+
+try:
+    from arviz_base import *
+except ModuleNotFoundError:
+    pass
+
+try:
+    import arviz_stats
+except ModuleNotFoundError:
+    pass
+
+try:
+    from arviz_plots import *
+except ModuleNotFoundError:
+    pass

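Note: the new ``arviz.preview`` module only exposes whatever refactored packages happen to be installed. As a quick illustration (a sketch, not part of the diff; the attribute check mirrors the ``import arviz_stats`` above):

    # Sketch: probe the preview namespace before relying on optional features.
    import arviz.preview as preview

    if hasattr(preview, "arviz_stats"):
        print("arviz_stats features are available")
    else:
        print("install arviz-stats to expose its features here")
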
arviz/rcparams.py
CHANGED
@@ -26,6 +26,8 @@ _log = logging.getLogger(__name__)
 ScaleKeyword = Literal["log", "negative_log", "deviance"]
 ICKeyword = Literal["loo", "waic"]
 
+_identity = lambda x: x
+
 
 def _make_validate_choice(accepted_values, allow_none=False, typeof=str):
     """Validate value is in accepted_values.

@@ -300,7 +302,7 @@ defaultParams = { # pylint: disable=invalid-name
         lambda x: x,
     ),
     "plot.matplotlib.show": (False, _validate_boolean),
-    "stats.hdi_prob": (0.94, _validate_probability),
+    "stats.ci_prob": (0.94, _validate_probability),
     "stats.information_criterion": (
         "loo",
         _make_validate_choice({"loo", "waic"} if NO_GET_ARGS else set(get_args(ICKeyword))),

@@ -318,6 +320,9 @@ defaultParams = { # pylint: disable=invalid-name
     ),
 }
 
+# map from deprecated params to (version, new_param, fold2new, fnew2old)
+deprecated_map = {"stats.hdi_prob": ("0.18.0", "stats.ci_prob", _identity, _identity)}
+
 
 class RcParams(MutableMapping):
     """Class to contain ArviZ default parameters.

@@ -335,6 +340,15 @@ class RcParams(MutableMapping):
 
     def __setitem__(self, key, val):
         """Add validation to __setitem__ function."""
+        if key in deprecated_map:
+            version, key_new, fold2new, _ = deprecated_map[key]
+            warnings.warn(
+                f"{key} is deprecated since {version}, use {key_new} instead",
+                FutureWarning,
+            )
+            key = key_new
+            val = fold2new(val)
+
         try:
             try:
                 cval = self.validate[key](val)

@@ -349,7 +363,18 @@
 
     def __getitem__(self, key):
         """Use underlying dict's getitem method."""
-        return self._underlying_storage[key]
+        if key in deprecated_map:
+            version, key_new, _, fnew2old = deprecated_map[key]
+            warnings.warn(
+                f"{key} is deprecated since {version}, use {key_new} instead",
+                FutureWarning,
+            )
+            if key not in self._underlying_storage:
+                key = key_new
+        else:
+            fnew2old = _identity
+
+        return fnew2old(self._underlying_storage[key])
 
     def __delitem__(self, key):
         """Raise TypeError if someone ever tries to delete a key from RcParams."""

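Note: the net effect of the rcparams changes is a soft rename of ``stats.hdi_prob`` to ``stats.ci_prob``; reads and writes through the deprecated key still work but emit a ``FutureWarning`` and are redirected via ``deprecated_map``. A minimal check of that behavior (a sketch written against the code above):

    import warnings
    import arviz as az

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        az.rcParams["stats.hdi_prob"] = 0.9  # deprecated key, redirected to stats.ci_prob

    assert any(issubclass(w.category, FutureWarning) for w in caught)
    assert az.rcParams["stats.ci_prob"] == 0.9
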
arviz/stats/diagnostics.py
CHANGED
@@ -135,10 +135,11 @@ def ess(
 
     References
     ----------
-    * Vehtari et al. (…
-    …
-    …
-    * …
+    * Vehtari et al. (2021). Rank-normalization, folding, and
+      localization: An improved Rhat for assessing convergence of
+      MCMC. Bayesian analysis, 16(2):667-718.
+    * https://mc-stan.org/docs/reference-manual/analysis.html#effective-sample-size.section
+    * Gelman et al. BDA3 (2013) Formula 11.8
 
     See Also
     --------

@@ -246,7 +247,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
         Names of variables to include in the rhat report
     method : str
         Select R-hat method. Valid methods are:
-        - "rank"        # recommended by Vehtari et al. (…
+        - "rank"        # recommended by Vehtari et al. (2021)
         - "split"
         - "folded"
         - "z_scale"

@@ -269,7 +270,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
     -----
     The diagnostic is computed by:
 
-      .. math:: \hat{R} = \frac{\hat{V}}{W}
+      .. math:: \hat{R} = \sqrt{\frac{\hat{V}}{W}}
 
     where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
     estimate for the pooled rank-traces. This is the potential scale reduction factor, which

@@ -277,12 +278,15 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
     greater than one indicate that one or more chains have not yet converged.
 
     Rank values are calculated over all the chains with ``scipy.stats.rankdata``.
-    Each chain is split in two and normalized with the z-transform following …
+    Each chain is split in two and normalized with the z-transform following
+    Vehtari et al. (2021).
 
     References
     ----------
-    * Vehtari et al. (…
-    …
+    * Vehtari et al. (2021). Rank-normalization, folding, and
+      localization: An improved Rhat for assessing convergence of
+      MCMC. Bayesian analysis, 16(2):667-718.
+    * Gelman et al. BDA3 (2013)
     * Brooks and Gelman (1998)
     * Gelman and Rubin (1992)
 

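Note: the R-hat docstring fix only corrects the documentation; the code already took the square root. For intuition, a toy (non-rank-normalized) version of the potential scale reduction factor, assuming the BDA3 definitions of within- and between-chain variance:

    import numpy as np

    def toy_rhat(chains):
        # chains: array of shape (n_chains, n_draws); illustrative only,
        # arviz additionally splits and rank-normalizes the chains.
        n_draws = chains.shape[1]
        within = chains.var(axis=1, ddof=1).mean()           # W
        between = n_draws * chains.mean(axis=1).var(ddof=1)  # B
        var_hat = (n_draws - 1) / n_draws * within + between / n_draws
        return np.sqrt(var_hat / within)                     # note the square root
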
arviz/stats/ecdf_utils.py
CHANGED
@@ -1,10 +1,25 @@
 """Functions for evaluating ECDFs and their confidence bands."""
 
+import math
 from typing import Any, Callable, Optional, Tuple
 import warnings
 
 import numpy as np
 from scipy.stats import uniform, binom
+from scipy.optimize import minimize_scalar
+
+try:
+    from numba import jit, vectorize
+except ImportError:
+
+    def jit(*args, **kwargs):  # pylint: disable=unused-argument
+        return lambda f: f
+
+    def vectorize(*args, **kwargs):  # pylint: disable=unused-argument
+        return lambda f: f
+
+
+from ..utils import Numba
 
 
 def compute_ecdf(sample: np.ndarray, eval_points: np.ndarray) -> np.ndarray:

@@ -25,6 +40,13 @@ def _get_ecdf_points(
     return x, y
 
 
+def _call_rvs(rvs, ndraws, random_state):
+    if random_state is None:
+        return rvs(ndraws)
+    else:
+        return rvs(ndraws, random_state=random_state)
+
+
 def _simulate_ecdf(
     ndraws: int,
     eval_points: np.ndarray,

@@ -32,7 +54,7 @@ def _simulate_ecdf(
     random_state: Optional[Any] = None,
 ) -> np.ndarray:
     """Simulate ECDF at the `eval_points` using the given random variable sampler"""
-    sample = rvs(ndraws, random_state=random_state)
+    sample = _call_rvs(rvs, ndraws, random_state)
     sample.sort()
     return compute_ecdf(sample, eval_points)
 

@@ -66,7 +88,7 @@ def ecdf_confidence_band(
     eval_points: np.ndarray,
     cdf_at_eval_points: np.ndarray,
     prob: float = 0.95,
-    method="simulated",
+    method="optimized",
     **kwargs,
 ) -> Tuple[np.ndarray, np.ndarray]:
     """Compute the `prob`-level confidence band for the ECDF.

@@ -85,20 +107,17 @@
     method : string, default "simulated"
         The method used to compute the confidence band. Valid options are:
         - "pointwise": Compute the pointwise (i.e. marginal) confidence band.
+        - "optimized": Use optimization to estimate a simultaneous confidence band.
         - "simulated": Use Monte Carlo simulation to estimate a simultaneous confidence band.
           `rvs` must be provided.
     rvs: callable, optional
         A function that takes an integer `ndraws` and optionally the object passed to
         `random_state` and returns an array of `ndraws` samples from the same distribution
         as the original dataset. Required if `method` is "simulated" and variable is discrete.
-    num_trials : int, default …
+    num_trials : int, default 500
         The number of random ECDFs to generate for constructing simultaneous confidence bands
         (if `method` is "simulated").
-    random_state : …
-        `numpy.random.RandomState`}, optional
-        If `None`, the `numpy.random.RandomState` singleton is used. If an `int`, a new
-        ``numpy.random.RandomState`` instance is used, seeded with seed. If a `RandomState` or
-        `Generator` instance, the instance is used.
+    random_state : int, numpy.random.Generator or numpy.random.RandomState, optional
 
     Returns
     -------

@@ -112,12 +131,18 @@
 
     if method == "pointwise":
        prob_pointwise = prob
+    elif method == "optimized":
+        prob_pointwise = _optimize_simultaneous_ecdf_band_probability(
+            ndraws, eval_points, cdf_at_eval_points, prob=prob, **kwargs
+        )
     elif method == "simulated":
         prob_pointwise = _simulate_simultaneous_ecdf_band_probability(
             ndraws, eval_points, cdf_at_eval_points, prob=prob, **kwargs
         )
     else:
-        raise ValueError(…
+        raise ValueError(
+            f"Unknown method {method}. Valid options are 'pointwise', 'optimized', or 'simulated'."
+        )
 
     prob_lower, prob_upper = _get_pointwise_confidence_band(
         prob_pointwise, ndraws, cdf_at_eval_points

@@ -126,13 +151,146 @@
     return prob_lower, prob_upper
 
 
+def _update_ecdf_band_interior_probabilities(
+    prob_left: np.ndarray,
+    interval_left: np.ndarray,
+    interval_right: np.ndarray,
+    p: float,
+    ndraws: int,
+) -> np.ndarray:
+    """Update the probability that an ECDF has been within the envelope including at the current
+    point.
+
+    Arguments
+    ---------
+    prob_left : np.ndarray
+        For each point in the interior at the previous point, the joint probability that it and all
+        points before are in the interior.
+    interval_left : np.ndarray
+        The set of points in the interior at the previous point.
+    interval_right : np.ndarray
+        The set of points in the interior at the current point.
+    p : float
+        The probability of any given point found between the previous point and the current one.
+    ndraws : int
+        Number of draws in the original dataset.
+
+    Returns
+    -------
+    prob_right : np.ndarray
+        For each point in the interior at the current point, the joint probability that it and all
+        previous points are in the interior.
+    """
+    interval_left = interval_left[:, np.newaxis]
+    prob_conditional = binom.pmf(interval_right, ndraws - interval_left, p, loc=interval_left)
+    prob_right = prob_left.dot(prob_conditional)
+    return prob_right
+
+
+@vectorize(["float64(int64, int64, float64, int64)"])
+def _binom_pmf(k, n, p, loc):
+    k -= loc
+    if k < 0 or k > n:
+        return 0.0
+    if p == 0:
+        return 1.0 if k == 0 else 0.0
+    if p == 1:
+        return 1.0 if k == n else 0.0
+    if k == 0:
+        return (1 - p) ** n
+    if k == n:
+        return p**n
+    lbinom = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
+    return np.exp(lbinom + k * np.log(p) + (n - k) * np.log1p(-p))
+
+
+@jit(nopython=True)
+def _update_ecdf_band_interior_probabilities_numba(
+    prob_left: np.ndarray,
+    interval_left: np.ndarray,
+    interval_right: np.ndarray,
+    p: float,
+    ndraws: int,
+) -> np.ndarray:
+    interval_left = interval_left[:, np.newaxis]
+    prob_conditional = _binom_pmf(interval_right, ndraws - interval_left, p, interval_left)
+    prob_right = prob_left.dot(prob_conditional)
+    return prob_right
+
+
+def _ecdf_band_interior_probability(prob_between_points, ndraws, lower_count, upper_count):
+    interval_left = np.arange(1)
+    prob_interior = np.ones(1)
+    for i in range(prob_between_points.shape[0]):
+        interval_right = np.arange(lower_count[i], upper_count[i])
+        prob_interior = _update_ecdf_band_interior_probabilities(
+            prob_interior, interval_left, interval_right, prob_between_points[i], ndraws
+        )
+        interval_left = interval_right
+    return prob_interior.sum()
+
+
+@jit(nopython=True)
+def _ecdf_band_interior_probability_numba(prob_between_points, ndraws, lower_count, upper_count):
+    interval_left = np.arange(1)
+    prob_interior = np.ones(1)
+    for i in range(prob_between_points.shape[0]):
+        interval_right = np.arange(lower_count[i], upper_count[i])
+        prob_interior = _update_ecdf_band_interior_probabilities_numba(
+            prob_interior, interval_left, interval_right, prob_between_points[i], ndraws
+        )
+        interval_left = interval_right
+    return prob_interior.sum()
+
+
+def _ecdf_band_optimization_objective(
+    prob_pointwise: float,
+    cdf_at_eval_points: np.ndarray,
+    ndraws: int,
+    prob_target: float,
+) -> float:
+    """Objective function for optimizing the simultaneous confidence band probability."""
+    lower, upper = _get_pointwise_confidence_band(prob_pointwise, ndraws, cdf_at_eval_points)
+    lower_count = (lower * ndraws).astype(int)
+    upper_count = (upper * ndraws).astype(int) + 1
+    cdf_with_zero = np.insert(cdf_at_eval_points[:-1], 0, 0)
+    prob_between_points = (cdf_at_eval_points - cdf_with_zero) / (1 - cdf_with_zero)
+    if Numba.numba_flag:
+        prob_interior = _ecdf_band_interior_probability_numba(
+            prob_between_points, ndraws, lower_count, upper_count
+        )
+    else:
+        prob_interior = _ecdf_band_interior_probability(
+            prob_between_points, ndraws, lower_count, upper_count
+        )
+    return abs(prob_interior - prob_target)
+
+
+def _optimize_simultaneous_ecdf_band_probability(
+    ndraws: int,
+    eval_points: np.ndarray,  # pylint: disable=unused-argument
+    cdf_at_eval_points: np.ndarray,
+    prob: float = 0.95,
+    **kwargs,  # pylint: disable=unused-argument
+):
+    """Estimate probability for simultaneous confidence band using optimization.
+
+    This function simulates the pointwise probability needed to construct pointwise confidence bands
+    that form a `prob`-level confidence envelope for the ECDF of a sample.
+    """
+    cdf_at_eval_points = np.unique(cdf_at_eval_points)
+    objective = lambda p: _ecdf_band_optimization_objective(p, cdf_at_eval_points, ndraws, prob)
+    prob_pointwise = minimize_scalar(objective, bounds=(prob, 1), method="bounded").x
+    return prob_pointwise
+
+
 def _simulate_simultaneous_ecdf_band_probability(
     ndraws: int,
     eval_points: np.ndarray,
     cdf_at_eval_points: np.ndarray,
     prob: float = 0.95,
     rvs: Optional[Callable[[int, Optional[Any]], np.ndarray]] = None,
-    num_trials: int = …,
+    num_trials: int = 500,
     random_state: Optional[Any] = None,
 ) -> float:
     """Estimate probability for simultaneous confidence band using simulation.

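Note: the headline change is the new default ``method="optimized"``, which finds the pointwise probability by scalar minimization (``scipy.optimize.minimize_scalar`` over ``(prob, 1)``) instead of Monte Carlo simulation, with optional numba acceleration. A minimal sketch of calling it (the uniform sample is illustrative, and these are private helpers whose import path may change):

    import numpy as np
    from scipy.stats import uniform
    from arviz.stats.ecdf_utils import compute_ecdf, ecdf_confidence_band

    rng = np.random.default_rng(0)
    sample = np.sort(rng.uniform(size=200))
    eval_points = np.linspace(0.01, 0.99, 50)
    cdf_at_eval = uniform(0, 1).cdf(eval_points)

    # No rvs callable or num_trials needed with the optimized default.
    lower, upper = ecdf_confidence_band(len(sample), eval_points, cdf_at_eval, prob=0.95)
    ecdf = compute_ecdf(sample, eval_points)
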
arviz/stats/stats.py
CHANGED
@@ -270,12 +270,12 @@ def compare(
             weights[i] = u_weights / np.sum(u_weights)
 
         weights = weights.mean(axis=0)
-        ses = pd.Series(z_bs.std(axis=0), index=…
+        ses = pd.Series(z_bs.std(axis=0), index=ics.index)  # pylint: disable=no-member
 
     elif method.lower() == "pseudo-bma":
         min_ic = ics.iloc[0][f"elpd_{ic}"]
         z_rv = np.exp((ics[f"elpd_{ic}"] - min_ic) / scale_value)
-        weights = z_rv / np.sum(z_rv)
+        weights = (z_rv / np.sum(z_rv)).to_numpy()
         ses = ics["se"]
 
     if np.any(weights):

@@ -471,7 +471,7 @@ def hdi(
         Refer to documentation of :func:`arviz.convert_to_dataset` for details.
     hdi_prob: float, optional
         Prob for which the highest density interval will be computed. Defaults to
-        ``stats.hdi_prob`` rcParam.
+        ``stats.ci_prob`` rcParam.
     circular: bool, optional
         Whether to compute the hdi taking into account `x` is a circular variable
         (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).

@@ -553,7 +553,7 @@ def hdi(
 
     """
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
 

@@ -711,15 +711,19 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
     Returns
     -------
     ELPDData object (inherits from :class:`pandas.Series`) with the following row/attributes:
-    …
+    elpd_loo: approximated expected log pointwise predictive density (elpd)
     se: standard error of the elpd
     p_loo: effective number of parameters
-    …
-    …
-    …
-    …
+    n_samples: number of samples
+    n_data_points: number of data points
+    warning: bool
+        True if the estimated shape parameter of Pareto distribution is greater than
+        ``good_k``.
+    loo_i: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
+        only if pointwise=True
     pareto_k: array of Pareto shape values, only if pointwise True
     scale: scale of the elpd
+    good_k: For a sample size S, the threshold is computed as min(1 - 1/log10(S), 0.7)
 
     The returned object has a custom print method that overrides pd.Series method.
 

@@ -785,13 +789,15 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
     log_weights += log_likelihood
 
     warn_mg = False
-    if np.any(pareto_shape > 0.7):
+    good_k = min(1 - 1 / np.log10(n_samples), 0.7)
+
+    if np.any(pareto_shape > good_k):
         warnings.warn(
-            "Estimated shape parameter of Pareto distribution is greater than 0.7 for "
-            "one or more samples. You should consider using a more robust model, this is "
-            "importance sampling is less likely to work well if the marginal posterior "
-            "LOO posterior are very different. This is more likely to happen with a "
-            "model and highly influential observations."
+            f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} "
+            "for one or more samples. You should consider using a more robust model, this is "
+            "because importance sampling is less likely to work well if the marginal posterior "
+            "and LOO posterior are very different. This is more likely to happen with a "
+            "non-robust model and highly influential observations."
         )
         warn_mg = True
 

@@ -816,8 +822,17 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
 
     if not pointwise:
         return ELPDData(
-            data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale],
-            index=[…],
+            data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k],
+            index=[
+                "elpd_loo",
+                "se",
+                "p_loo",
+                "n_samples",
+                "n_data_points",
+                "warning",
+                "scale",
+                "good_k",
+            ],
         )
     if np.equal(loo_lppd, loo_lppd_i).all():  # pylint: disable=no-member
         warnings.warn(

@@ -835,6 +850,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
             loo_lppd_i.rename("loo_i"),
             pareto_shape,
             scale,
+            good_k,
         ],
         index=[
             "elpd_loo",

@@ -846,6 +862,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
             "loo_i",
             "pareto_k",
             "scale",
+            "good_k",
         ],
     )
 

@@ -879,7 +896,8 @@ def psislw(log_weights, reff=1.0):
 
     References
     ----------
-    * Vehtari et al. (…
+    * Vehtari et al. (2024). Pareto smoothed importance sampling. Journal of Machine
+      Learning Research, 25(72):1-58.
 
     See Also
     --------

@@ -899,6 +917,7 @@ def psislw(log_weights, reff=1.0):
     ...: az.psislw(-log_likelihood, reff=0.8)
 
     """
+    log_weights = deepcopy(log_weights)
     if hasattr(log_weights, "__sample__"):
         n_samples = len(log_weights.__sample__)
         shape = [

@@ -1322,7 +1341,7 @@ def summary(
     if labeller is None:
         labeller = BaseLabeller()
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
 

@@ -1565,7 +1584,9 @@ def waic(data, pointwise=None, var_name=None, scale=None, dask_kwargs=None):
     elpd_waic: approximated expected log pointwise predictive density (elpd)
     se: standard error of the elpd
     p_waic: effective number parameters
-    …
+    n_samples: number of samples
+    n_data_points: number of data points
+    warning: bool
         True if posterior variance of the log predictive densities exceeds 0.4
     waic_i: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
         only if pointwise=True

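Note: ``loo`` now derives the Pareto-k warning threshold from the sample size instead of using the fixed 0.7 cutoff. The arithmetic is small enough to show inline (sample sizes are illustrative):

    import numpy as np

    def good_k(n_samples):
        # Per the diff: min(1 - 1/log10(S), 0.7) for S posterior samples.
        return min(1 - 1 / np.log10(n_samples), 0.7)

    print(good_k(100))   # 0.5 -> stricter than 0.7 for small S
    print(good_k(4000))  # 0.7 -> 1 - 1/log10(4000) is about 0.72, capped at 0.7
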
arviz/stats/stats_utils.py
CHANGED
@@ -454,10 +454,9 @@ POINTWISE_LOO_FMT = """------
 
 Pareto k diagnostic values:
 {{0:>{0}}} {{1:>6}}
-(-Inf, 0.5]   (good)      {{2:{0}d}} {{6:6.1f}}%
- (0.5, 0.7]   (ok)        {{3:{0}d}} {{7:6.1f}}%
-   (0.7, 1]   (bad)       {{4:{0}d}} {{8:6.1f}}%
-   (1, Inf)   (very bad)  {{5:{0}d}} {{9:6.1f}}%
+(-Inf, {{8:.2f}}] (good)      {{2:{0}d}} {{5:6.1f}}%
+ ({{8:.2f}}, 1]   (bad)       {{3:{0}d}} {{6:6.1f}}%
+   (1, Inf)   (very bad)  {{4:{0}d}} {{7:6.1f}}%
 """
 SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}
 

@@ -488,11 +487,14 @@ class ELPDData(pd.Series): # pylint: disable=too-many-ancestors
             base += "\n\nThere has been a warning during the calculation. Please check the results."
 
         if kind == "loo" and "pareto_k" in self:
-            bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
+            bins = np.asarray([-np.inf, self.good_k, 1, np.inf])
             counts, *_ = _histogram(self.pareto_k.values, bins)
             extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
             extended = extended.format(
-                "Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)]
+                "Count",
+                "Pct.",
+                *[*counts, *(counts / np.sum(counts) * 100)],
+                self.good_k,
             )
             base = "\n".join([base, extended])
         return base

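Note: the diagnostic table now has three bins whose lower cutoff is ``good_k`` (format argument ``{8}``) rather than the old fixed 0.5/0.7 bins. A sketch of how the double-brace template renders (the counts and good_k value are made up):

    from arviz.stats.stats_utils import POINTWISE_LOO_FMT

    fmt = POINTWISE_LOO_FMT.format(4)  # first pass sets the count column width
    # args: two header labels, 3 bin counts, 3 percentages, then good_k
    print(fmt.format("Count", "Pct.", 120, 5, 0, 96.0, 4.0, 0.0, 0.68))
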
arviz/tests/base_tests/test_data.py
CHANGED

@@ -42,9 +42,12 @@ from ..helpers import ( # pylint: disable=unused-import
     draws,
     eight_schools_params,
     models,
-    running_on_ci,
 )
 
+# Check if dm-tree is installed
+dm_tree_installed = importlib.util.find_spec("tree") is not None  # pylint: disable=invalid-name
+skip_tests = (not dm_tree_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)
+
 
 @pytest.fixture(autouse=True)
 def no_remote_data(monkeypatch, tmpdir):

@@ -896,6 +899,11 @@ class TestInferenceData: # pylint: disable=too-many-public-methods
         assert escape(repr(idata)) in html
         xr.set_options(display_style=display_style)
 
+    def test_setitem(self, data_random):
+        data_random["new_group"] = data_random.posterior
+        assert "new_group" in data_random.groups()
+        assert hasattr(data_random, "new_group")
+
     def test_add_groups(self, data_random):
         data = np.random.normal(size=(4, 500, 8))
         idata = data_random

@@ -1077,6 +1085,7 @@ def test_dict_to_dataset():
     assert set(dataset.b.coords) == {"chain", "draw", "c"}
 
 
+@pytest.mark.skipif(skip_tests, reason="test requires dm-tree which is not installed")
 def test_nested_dict_to_dataset():
     datadict = {
         "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},

@@ -1469,7 +1478,7 @@
 
 
 @pytest.mark.skipif(
-    not (importlib.util.find_spec("datatree") or running_on_ci()),
+    not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
     reason="test requires xarray-datatree library",
 )
 class TestDataTree:

arviz/tests/base_tests/test_diagnostics_numba.py
CHANGED

@@ -1,7 +1,5 @@
 """Test Diagnostic methods"""
 
-import importlib
-
 # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
 import numpy as np
 import pytest

@@ -11,13 +9,10 @@ from ...rcparams import rcParams
 from ...stats import bfmi, mcse, rhat
 from ...stats.diagnostics import _mc_error, ks_summary
 from ...utils import Numba
-from ..helpers import running_on_ci
+from ..helpers import importorskip
 from .test_diagnostics import data  # pylint: disable=unused-import
 
-pytestmark = pytest.mark.skipif(
-    (importlib.util.find_spec("numba") is None) and not running_on_ci(),
-    reason="test requires numba which is not installed",
-)
+importorskip("numba")
 
 rcParams["data.load"] = "eager"
 

arviz/tests/base_tests/test_helpers.py
CHANGED

@@ -6,13 +6,13 @@ from ..helpers import importorskip
 
 def test_importorskip_local(monkeypatch):
     """Test ``importorskip`` run on local machine with non-existent module, which should skip."""
-    monkeypatch.delenv("ARVIZ_CI_MACHINE", raising=False)
+    monkeypatch.delenv("ARVIZ_REQUIRE_ALL_DEPS", raising=False)
     with pytest.raises(Skipped):
         importorskip("non-existent-function")
 
 
 def test_importorskip_ci(monkeypatch):
     """Test ``importorskip`` run on CI machine with non-existent module, which should fail."""
-    monkeypatch.setenv("ARVIZ_CI_MACHINE", 1)
+    monkeypatch.setenv("ARVIZ_REQUIRE_ALL_DEPS", 1)
     with pytest.raises(ModuleNotFoundError):
         importorskip("non-existent-function")