arviz 0.18.0__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (53)
  1. arviz/__init__.py +2 -1
  2. arviz/data/base.py +18 -7
  3. arviz/data/converters.py +7 -3
  4. arviz/data/inference_data.py +8 -0
  5. arviz/data/io_cmdstan.py +4 -0
  6. arviz/data/io_numpyro.py +1 -1
  7. arviz/plots/backends/bokeh/ecdfplot.py +1 -2
  8. arviz/plots/backends/bokeh/khatplot.py +8 -3
  9. arviz/plots/backends/bokeh/pairplot.py +2 -6
  10. arviz/plots/backends/matplotlib/ecdfplot.py +1 -2
  11. arviz/plots/backends/matplotlib/khatplot.py +7 -3
  12. arviz/plots/backends/matplotlib/traceplot.py +1 -1
  13. arviz/plots/bpvplot.py +2 -2
  14. arviz/plots/compareplot.py +4 -4
  15. arviz/plots/densityplot.py +1 -1
  16. arviz/plots/dotplot.py +2 -2
  17. arviz/plots/ecdfplot.py +213 -89
  18. arviz/plots/essplot.py +2 -2
  19. arviz/plots/forestplot.py +3 -3
  20. arviz/plots/hdiplot.py +2 -2
  21. arviz/plots/kdeplot.py +9 -2
  22. arviz/plots/khatplot.py +23 -6
  23. arviz/plots/loopitplot.py +2 -2
  24. arviz/plots/mcseplot.py +3 -1
  25. arviz/plots/plot_utils.py +2 -4
  26. arviz/plots/posteriorplot.py +1 -1
  27. arviz/plots/rankplot.py +2 -2
  28. arviz/plots/violinplot.py +1 -1
  29. arviz/preview.py +17 -0
  30. arviz/rcparams.py +27 -2
  31. arviz/stats/diagnostics.py +13 -9
  32. arviz/stats/ecdf_utils.py +168 -10
  33. arviz/stats/stats.py +41 -20
  34. arviz/stats/stats_utils.py +8 -6
  35. arviz/tests/base_tests/test_data.py +11 -2
  36. arviz/tests/base_tests/test_data_zarr.py +0 -1
  37. arviz/tests/base_tests/test_diagnostics_numba.py +2 -7
  38. arviz/tests/base_tests/test_helpers.py +2 -2
  39. arviz/tests/base_tests/test_plot_utils.py +5 -13
  40. arviz/tests/base_tests/test_plots_matplotlib.py +95 -2
  41. arviz/tests/base_tests/test_rcparams.py +12 -0
  42. arviz/tests/base_tests/test_stats.py +1 -1
  43. arviz/tests/base_tests/test_stats_ecdf_utils.py +15 -2
  44. arviz/tests/base_tests/test_stats_numba.py +2 -7
  45. arviz/tests/base_tests/test_utils_numba.py +2 -5
  46. arviz/tests/external_tests/test_data_pystan.py +5 -5
  47. arviz/tests/helpers.py +17 -9
  48. arviz/utils.py +4 -0
  49. {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/METADATA +23 -19
  50. {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/RECORD +53 -52
  51. {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/WHEEL +1 -1
  52. {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/LICENSE +0 -0
  53. {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/top_level.txt +0 -0
arviz/preview.py ADDED
@@ -0,0 +1,17 @@
+# pylint: disable=unused-import,unused-wildcard-import,wildcard-import
+"""Expose features from arviz-xyz refactored packages inside ``arviz.preview`` namespace."""
+
+try:
+    from arviz_base import *
+except ModuleNotFoundError:
+    pass
+
+try:
+    import arviz_stats
+except ModuleNotFoundError:
+    pass
+
+try:
+    from arviz_plots import *
+except ModuleNotFoundError:
+    pass
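``arviz.preview`` is a thin shim: each refactored package is re-exported only when it happens to be installed, and missing packages are silently skipped. A quick way to see what the namespace actually exposes in a given environment (a sketch, assuming arviz 0.20.0; the result depends on which arviz-xyz packages are present):

    # Everything picked up from arviz_base / arviz_plots ends up as a public name here;
    # an empty list simply means none of the optional packages are installed.
    from arviz import preview

    print(sorted(name for name in dir(preview) if not name.startswith("_")))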
arviz/rcparams.py CHANGED
@@ -26,6 +26,8 @@ _log = logging.getLogger(__name__)
 ScaleKeyword = Literal["log", "negative_log", "deviance"]
 ICKeyword = Literal["loo", "waic"]
 
+_identity = lambda x: x
+
 
 def _make_validate_choice(accepted_values, allow_none=False, typeof=str):
     """Validate value is in accepted_values.
@@ -300,7 +302,7 @@ defaultParams = {  # pylint: disable=invalid-name
         lambda x: x,
     ),
     "plot.matplotlib.show": (False, _validate_boolean),
-    "stats.hdi_prob": (0.94, _validate_probability),
+    "stats.ci_prob": (0.94, _validate_probability),
     "stats.information_criterion": (
         "loo",
         _make_validate_choice({"loo", "waic"} if NO_GET_ARGS else set(get_args(ICKeyword))),
@@ -318,6 +320,9 @@ defaultParams = {  # pylint: disable=invalid-name
     ),
 }
 
+# map from deprecated params to (version, new_param, fold2new, fnew2old)
+deprecated_map = {"stats.hdi_prob": ("0.18.0", "stats.ci_prob", _identity, _identity)}
+
 
 class RcParams(MutableMapping):
     """Class to contain ArviZ default parameters.
@@ -335,6 +340,15 @@ class RcParams(MutableMapping):
 
     def __setitem__(self, key, val):
         """Add validation to __setitem__ function."""
+        if key in deprecated_map:
+            version, key_new, fold2new, _ = deprecated_map[key]
+            warnings.warn(
+                f"{key} is deprecated since {version}, use {key_new} instead",
+                FutureWarning,
+            )
+            key = key_new
+            val = fold2new(val)
+
         try:
             try:
                 cval = self.validate[key](val)
@@ -349,7 +363,18 @@ class RcParams(MutableMapping):
 
     def __getitem__(self, key):
         """Use underlying dict's getitem method."""
-        return self._underlying_storage[key]
+        if key in deprecated_map:
+            version, key_new, _, fnew2old = deprecated_map[key]
+            warnings.warn(
+                f"{key} is deprecated since {version}, use {key_new} instead",
+                FutureWarning,
+            )
+            if key not in self._underlying_storage:
+                key = key_new
+        else:
+            fnew2old = _identity
+
+        return fnew2old(self._underlying_storage[key])
 
     def __delitem__(self, key):
         """Raise TypeError if someone ever tries to delete a key from RcParams."""
arviz/stats/diagnostics.py CHANGED
@@ -135,10 +135,11 @@ def ess(
 
     References
     ----------
-    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
-    * https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html
-      Section 15.4.2
-    * Gelman et al. BDA (2014) Formula 11.8
+    * Vehtari et al. (2021). Rank-normalization, folding, and
+      localization: An improved Rhat for assessing convergence of
+      MCMC. Bayesian analysis, 16(2):667-718.
+    * https://mc-stan.org/docs/reference-manual/analysis.html#effective-sample-size.section
+    * Gelman et al. BDA3 (2013) Formula 11.8
 
     See Also
     --------
@@ -246,7 +247,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
         Names of variables to include in the rhat report
     method : str
         Select R-hat method. Valid methods are:
-        - "rank"  # recommended by Vehtari et al. (2019)
+        - "rank"  # recommended by Vehtari et al. (2021)
         - "split"
        - "folded"
        - "z_scale"
@@ -269,7 +270,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
     -----
     The diagnostic is computed by:
 
-    .. math:: \hat{R} = \frac{\hat{V}}{W}
+    .. math:: \hat{R} = \sqrt{\frac{\hat{V}}{W}}
 
     where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
     estimate for the pooled rank-traces. This is the potential scale reduction factor, which
@@ -277,12 +278,15 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
     greater than one indicate that one or more chains have not yet converged.
 
     Rank values are calculated over all the chains with ``scipy.stats.rankdata``.
-    Each chain is split in two and normalized with the z-transform following Vehtari et al. (2019).
+    Each chain is split in two and normalized with the z-transform following
+    Vehtari et al. (2021).
 
     References
     ----------
-    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
-    * Gelman et al. BDA (2014)
+    * Vehtari et al. (2021). Rank-normalization, folding, and
+      localization: An improved Rhat for assessing convergence of
+      MCMC. Bayesian analysis, 16(2):667-718.
+    * Gelman et al. BDA3 (2013)
     * Brooks and Gelman (1998)
     * Gelman and Rubin (1992)
 
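The corrected formula only affects the documentation; the computation is unchanged, and for well-mixed chains the rank-normalized split-R-hat should still sit very close to 1. A toy check (my own example, not part of the diff; ``az.rhat`` names the variable ``x`` when given a bare array):

    import numpy as np
    import arviz as az

    # Four independent chains drawn from the same distribution.
    chains = np.random.default_rng(0).normal(size=(4, 1000))
    print(float(az.rhat(chains)["x"]))  # expected to be close to 1.0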
arviz/stats/ecdf_utils.py CHANGED
@@ -1,10 +1,25 @@
 """Functions for evaluating ECDFs and their confidence bands."""
 
+import math
 from typing import Any, Callable, Optional, Tuple
 import warnings
 
 import numpy as np
 from scipy.stats import uniform, binom
+from scipy.optimize import minimize_scalar
+
+try:
+    from numba import jit, vectorize
+except ImportError:
+
+    def jit(*args, **kwargs):  # pylint: disable=unused-argument
+        return lambda f: f
+
+    def vectorize(*args, **kwargs):  # pylint: disable=unused-argument
+        return lambda f: f
+
+
+from ..utils import Numba
 
 
 def compute_ecdf(sample: np.ndarray, eval_points: np.ndarray) -> np.ndarray:
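The try/except above gives ``jit`` and ``vectorize`` no-op fallbacks so that the decorated helpers added later in this file stay importable without numba. The pattern in isolation (my own toy example, not arviz code):

    try:
        from numba import jit
    except ImportError:
        # Fallback decorator factory: @jit(...) returns the function unchanged.
        def jit(*args, **kwargs):  # pylint: disable=unused-argument
            return lambda f: f

    @jit(nopython=True)
    def add(a, b):
        return a + b

    print(add(2, 3))  # 5, with or without numba installed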
@@ -25,6 +40,13 @@ def _get_ecdf_points(
     return x, y
 
 
+def _call_rvs(rvs, ndraws, random_state):
+    if random_state is None:
+        return rvs(ndraws)
+    else:
+        return rvs(ndraws, random_state=random_state)
+
+
 def _simulate_ecdf(
     ndraws: int,
     eval_points: np.ndarray,
@@ -32,7 +54,7 @@
     random_state: Optional[Any] = None,
 ) -> np.ndarray:
     """Simulate ECDF at the `eval_points` using the given random variable sampler"""
-    sample = rvs(ndraws, random_state=random_state)
+    sample = _call_rvs(rvs, ndraws, random_state)
     sample.sort()
     return compute_ecdf(sample, eval_points)
 
@@ -66,7 +88,7 @@
     eval_points: np.ndarray,
     cdf_at_eval_points: np.ndarray,
     prob: float = 0.95,
-    method="simulated",
+    method="optimized",
     **kwargs,
 ) -> Tuple[np.ndarray, np.ndarray]:
     """Compute the `prob`-level confidence band for the ECDF.
@@ -85,20 +107,17 @@
     method : string, default "simulated"
         The method used to compute the confidence band. Valid options are:
         - "pointwise": Compute the pointwise (i.e. marginal) confidence band.
+        - "optimized": Use optimization to estimate a simultaneous confidence band.
         - "simulated": Use Monte Carlo simulation to estimate a simultaneous confidence band.
          `rvs` must be provided.
     rvs: callable, optional
         A function that takes an integer `ndraws` and optionally the object passed to
         `random_state` and returns an array of `ndraws` samples from the same distribution
         as the original dataset. Required if `method` is "simulated" and variable is discrete.
-    num_trials : int, default 1000
+    num_trials : int, default 500
         The number of random ECDFs to generate for constructing simultaneous confidence bands
         (if `method` is "simulated").
-    random_state : {None, int, `numpy.random.Generator`,
-        `numpy.random.RandomState`}, optional
-        If `None`, the `numpy.random.RandomState` singleton is used. If an `int`, a new
-        ``numpy.random.RandomState`` instance is used, seeded with seed. If a `RandomState` or
-        `Generator` instance, the instance is used.
+    random_state : int, numpy.random.Generator or numpy.random.RandomState, optional
 
     Returns
     -------
@@ -112,12 +131,18 @@
 
     if method == "pointwise":
         prob_pointwise = prob
+    elif method == "optimized":
+        prob_pointwise = _optimize_simultaneous_ecdf_band_probability(
+            ndraws, eval_points, cdf_at_eval_points, prob=prob, **kwargs
+        )
    elif method == "simulated":
        prob_pointwise = _simulate_simultaneous_ecdf_band_probability(
            ndraws, eval_points, cdf_at_eval_points, prob=prob, **kwargs
        )
    else:
-        raise ValueError(f"Unknown method {method}. Valid options are 'pointwise' or 'simulated'.")
+        raise ValueError(
+            f"Unknown method {method}. Valid options are 'pointwise', 'optimized', or 'simulated'."
+        )
 
     prob_lower, prob_upper = _get_pointwise_confidence_band(
         prob_pointwise, ndraws, cdf_at_eval_points
@@ -126,13 +151,146 @@
 
     return prob_lower, prob_upper
 
+def _update_ecdf_band_interior_probabilities(
+    prob_left: np.ndarray,
+    interval_left: np.ndarray,
+    interval_right: np.ndarray,
+    p: float,
+    ndraws: int,
+) -> np.ndarray:
+    """Update the probability that an ECDF has been within the envelope including at the current
+    point.
+
+    Arguments
+    ---------
+    prob_left : np.ndarray
+        For each point in the interior at the previous point, the joint probability that it and all
+        points before are in the interior.
+    interval_left : np.ndarray
+        The set of points in the interior at the previous point.
+    interval_right : np.ndarray
+        The set of points in the interior at the current point.
+    p : float
+        The probability of any given point found between the previous point and the current one.
+    ndraws : int
+        Number of draws in the original dataset.
+
+    Returns
+    -------
+    prob_right : np.ndarray
+        For each point in the interior at the current point, the joint probability that it and all
+        previous points are in the interior.
+    """
+    interval_left = interval_left[:, np.newaxis]
+    prob_conditional = binom.pmf(interval_right, ndraws - interval_left, p, loc=interval_left)
+    prob_right = prob_left.dot(prob_conditional)
+    return prob_right
+
+
+@vectorize(["float64(int64, int64, float64, int64)"])
+def _binom_pmf(k, n, p, loc):
+    k -= loc
+    if k < 0 or k > n:
+        return 0.0
+    if p == 0:
+        return 1.0 if k == 0 else 0.0
+    if p == 1:
+        return 1.0 if k == n else 0.0
+    if k == 0:
+        return (1 - p) ** n
+    if k == n:
+        return p**n
+    lbinom = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
+    return np.exp(lbinom + k * np.log(p) + (n - k) * np.log1p(-p))
+
+
+@jit(nopython=True)
+def _update_ecdf_band_interior_probabilities_numba(
+    prob_left: np.ndarray,
+    interval_left: np.ndarray,
+    interval_right: np.ndarray,
+    p: float,
+    ndraws: int,
+) -> np.ndarray:
+    interval_left = interval_left[:, np.newaxis]
+    prob_conditional = _binom_pmf(interval_right, ndraws - interval_left, p, interval_left)
+    prob_right = prob_left.dot(prob_conditional)
+    return prob_right
+
+
+def _ecdf_band_interior_probability(prob_between_points, ndraws, lower_count, upper_count):
+    interval_left = np.arange(1)
+    prob_interior = np.ones(1)
+    for i in range(prob_between_points.shape[0]):
+        interval_right = np.arange(lower_count[i], upper_count[i])
+        prob_interior = _update_ecdf_band_interior_probabilities(
+            prob_interior, interval_left, interval_right, prob_between_points[i], ndraws
+        )
+        interval_left = interval_right
+    return prob_interior.sum()
+
+
+@jit(nopython=True)
+def _ecdf_band_interior_probability_numba(prob_between_points, ndraws, lower_count, upper_count):
+    interval_left = np.arange(1)
+    prob_interior = np.ones(1)
+    for i in range(prob_between_points.shape[0]):
+        interval_right = np.arange(lower_count[i], upper_count[i])
+        prob_interior = _update_ecdf_band_interior_probabilities_numba(
+            prob_interior, interval_left, interval_right, prob_between_points[i], ndraws
+        )
+        interval_left = interval_right
+    return prob_interior.sum()
+
+
+def _ecdf_band_optimization_objective(
+    prob_pointwise: float,
+    cdf_at_eval_points: np.ndarray,
+    ndraws: int,
+    prob_target: float,
+) -> float:
+    """Objective function for optimizing the simultaneous confidence band probability."""
+    lower, upper = _get_pointwise_confidence_band(prob_pointwise, ndraws, cdf_at_eval_points)
+    lower_count = (lower * ndraws).astype(int)
+    upper_count = (upper * ndraws).astype(int) + 1
+    cdf_with_zero = np.insert(cdf_at_eval_points[:-1], 0, 0)
+    prob_between_points = (cdf_at_eval_points - cdf_with_zero) / (1 - cdf_with_zero)
+    if Numba.numba_flag:
+        prob_interior = _ecdf_band_interior_probability_numba(
+            prob_between_points, ndraws, lower_count, upper_count
+        )
+    else:
+        prob_interior = _ecdf_band_interior_probability(
+            prob_between_points, ndraws, lower_count, upper_count
+        )
+    return abs(prob_interior - prob_target)
+
+
+def _optimize_simultaneous_ecdf_band_probability(
+    ndraws: int,
+    eval_points: np.ndarray,  # pylint: disable=unused-argument
+    cdf_at_eval_points: np.ndarray,
+    prob: float = 0.95,
+    **kwargs,  # pylint: disable=unused-argument
+):
+    """Estimate probability for simultaneous confidence band using optimization.
+
+    This function simulates the pointwise probability needed to construct pointwise confidence bands
+    that form a `prob`-level confidence envelope for the ECDF of a sample.
+    """
+    cdf_at_eval_points = np.unique(cdf_at_eval_points)
+    objective = lambda p: _ecdf_band_optimization_objective(p, cdf_at_eval_points, ndraws, prob)
+    prob_pointwise = minimize_scalar(objective, bounds=(prob, 1), method="bounded").x
+    return prob_pointwise
+
+
 def _simulate_simultaneous_ecdf_band_probability(
     ndraws: int,
     eval_points: np.ndarray,
     cdf_at_eval_points: np.ndarray,
     prob: float = 0.95,
     rvs: Optional[Callable[[int, Optional[Any]], np.ndarray]] = None,
-    num_trials: int = 1000,
+    num_trials: int = 500,
     random_state: Optional[Any] = None,
 ) -> float:
     """Estimate probability for simultaneous confidence band using simulation.
arviz/stats/stats.py CHANGED
@@ -270,12 +270,12 @@ def compare(
             weights[i] = u_weights / np.sum(u_weights)
 
         weights = weights.mean(axis=0)
-        ses = pd.Series(z_bs.std(axis=0), index=names)  # pylint: disable=no-member
+        ses = pd.Series(z_bs.std(axis=0), index=ics.index)  # pylint: disable=no-member
 
     elif method.lower() == "pseudo-bma":
         min_ic = ics.iloc[0][f"elpd_{ic}"]
         z_rv = np.exp((ics[f"elpd_{ic}"] - min_ic) / scale_value)
-        weights = z_rv / np.sum(z_rv)
+        weights = (z_rv / np.sum(z_rv)).to_numpy()
         ses = ics["se"]
 
     if np.any(weights):
@@ -471,7 +471,7 @@
         Refer to documentation of :func:`arviz.convert_to_dataset` for details.
     hdi_prob: float, optional
         Prob for which the highest density interval will be computed. Defaults to
-        ``stats.hdi_prob`` rcParam.
+        ``stats.ci_prob`` rcParam.
     circular: bool, optional
         Whether to compute the hdi taking into account `x` is a circular variable
         (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).
@@ -553,7 +553,7 @@
 
     """
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
 
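Because ``hdi`` (and ``summary`` further down) now read the renamed rcParam, the default credible-interval level can be set once through ``stats.ci_prob``. A short sketch (assuming arviz 0.20.0):

    import numpy as np
    import arviz as az

    az.rcParams["stats.ci_prob"] = 0.9            # global default interval level
    draws = np.random.default_rng(1).normal(size=(4, 2000))
    print(az.hdi(draws))                          # 90% HDI, no hdi_prob argument needed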
@@ -711,15 +711,19 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
     Returns
     -------
     ELPDData object (inherits from :class:`pandas.Series`) with the following row/attributes:
-    elpd: approximated expected log pointwise predictive density (elpd)
+    elpd_loo: approximated expected log pointwise predictive density (elpd)
     se: standard error of the elpd
     p_loo: effective number of parameters
-    shape_warn: bool
-        True if the estimated shape parameter of
-        Pareto distribution is greater than 0.7 for one or more samples
-    loo_i: array of pointwise predictive accuracy, only if pointwise True
+    n_samples: number of samples
+    n_data_points: number of data points
+    warning: bool
+        True if the estimated shape parameter of Pareto distribution is greater than
+        ``good_k``.
+    loo_i: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
+        only if pointwise=True
     pareto_k: array of Pareto shape values, only if pointwise True
     scale: scale of the elpd
+    good_k: For a sample size S, the thresold is compute as min(1 - 1/log10(S), 0.7)
 
     The returned object has a custom print method that overrides pd.Series method.
 
@@ -785,13 +789,15 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
     log_weights += log_likelihood
 
     warn_mg = False
-    if np.any(pareto_shape > 0.7):
+    good_k = min(1 - 1 / np.log10(n_samples), 0.7)
+
+    if np.any(pareto_shape > good_k):
         warnings.warn(
-            "Estimated shape parameter of Pareto distribution is greater than 0.7 for "
-            "one or more samples. You should consider using a more robust model, this is because "
-            "importance sampling is less likely to work well if the marginal posterior and "
-            "LOO posterior are very different. This is more likely to happen with a non-robust "
-            "model and highly influential observations."
+            f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} "
+            "for one or more samples. You should consider using a more robust model, this is "
+            "because importance sampling is less likely to work well if the marginal posterior "
+            "and LOO posterior are very different. This is more likely to happen with a "
+            "non-robust model and highly influential observations."
         )
         warn_mg = True
 
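The warning threshold is no longer a fixed 0.7 but shrinks for small sample sizes S. Working the formula through a few values (my own arithmetic, not from the diff):

    import numpy as np

    for n_samples in (100, 1000, 4000):
        good_k = min(1 - 1 / np.log10(n_samples), 0.7)
        print(n_samples, round(good_k, 3))
    # 100 -> 0.5, 1000 -> 0.667, 4000 -> 0.7 (capped)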
@@ -816,8 +822,17 @@
 
     if not pointwise:
         return ELPDData(
-            data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale],
-            index=["elpd_loo", "se", "p_loo", "n_samples", "n_data_points", "warning", "scale"],
+            data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k],
+            index=[
+                "elpd_loo",
+                "se",
+                "p_loo",
+                "n_samples",
+                "n_data_points",
+                "warning",
+                "scale",
+                "good_k",
+            ],
         )
     if np.equal(loo_lppd, loo_lppd_i).all():  # pylint: disable=no-member
         warnings.warn(
@@ -835,6 +850,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
             loo_lppd_i.rename("loo_i"),
             pareto_shape,
             scale,
+            good_k,
         ],
         index=[
             "elpd_loo",
@@ -846,6 +862,7 @@
             "loo_i",
             "pareto_k",
             "scale",
+            "good_k",
         ],
     )
 
@@ -879,7 +896,8 @@ def psislw(log_weights, reff=1.0):
 
     References
     ----------
-    * Vehtari et al. (2015) see https://arxiv.org/abs/1507.02646
+    * Vehtari et al. (2024). Pareto smoothed importance sampling. Journal of Machine
+      Learning Research, 25(72):1-58.
 
     See Also
     --------
@@ -899,6 +917,7 @@
        ...: az.psislw(-log_likelihood, reff=0.8)
 
     """
+    log_weights = deepcopy(log_weights)
     if hasattr(log_weights, "__sample__"):
         n_samples = len(log_weights.__sample__)
         shape = [
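The added ``deepcopy`` indicates that the caller's log-weights should no longer be modified in place during smoothing. A sketch of that expectation (assuming arviz 0.20.0 and a plain array with draws on the last axis):

    import numpy as np
    import arviz as az

    log_weights = np.random.default_rng(2).normal(size=(8, 1000))  # 8 observations, 1000 draws
    original = log_weights.copy()
    smoothed, khat = az.psislw(log_weights)
    print(np.array_equal(log_weights, original))  # True: the input array is left untouched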
@@ -1322,7 +1341,7 @@
     if labeller is None:
         labeller = BaseLabeller()
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
 
@@ -1565,7 +1584,9 @@ def waic(data, pointwise=None, var_name=None, scale=None, dask_kwargs=None):
     elpd_waic: approximated expected log pointwise predictive density (elpd)
     se: standard error of the elpd
     p_waic: effective number parameters
-    var_warn: bool
+    n_samples: number of samples
+    n_data_points: number of data points
+    warning: bool
         True if posterior variance of the log predictive densities exceeds 0.4
     waic_i: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
         only if pointwise=True
arviz/stats/stats_utils.py CHANGED
@@ -454,10 +454,9 @@ POINTWISE_LOO_FMT = """------
 
 Pareto k diagnostic values:
 {{0:>{0}}} {{1:>6}}
-(-Inf, 0.5]   (good)      {{2:{0}d}} {{6:6.1f}}%
- (0.5, 0.7]   (ok)        {{3:{0}d}} {{7:6.1f}}%
-   (0.7, 1]   (bad)       {{4:{0}d}} {{8:6.1f}}%
-   (1, Inf)   (very bad)  {{5:{0}d}} {{9:6.1f}}%
+(-Inf, {{8:.2f}}]   (good)     {{2:{0}d}} {{5:6.1f}}%
+   ({{8:.2f}}, 1]   (bad)      {{3:{0}d}} {{6:6.1f}}%
+   (1, Inf)   (very bad) {{4:{0}d}} {{7:6.1f}}%
 """
 SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}
 
@@ -488,11 +487,14 @@ class ELPDData(pd.Series):  # pylint: disable=too-many-ancestors
         base += "\n\nThere has been a warning during the calculation. Please check the results."
 
         if kind == "loo" and "pareto_k" in self:
-            bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
+            bins = np.asarray([-np.inf, self.good_k, 1, np.inf])
             counts, *_ = _histogram(self.pareto_k.values, bins)
             extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
             extended = extended.format(
-                "Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)]
+                "Count",
+                "Pct.",
+                *[*counts, *(counts / np.sum(counts) * 100)],
+                self.good_k,
             )
             base = "\n".join([base, extended])
         return base
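The printed Pareto-k table therefore collapses the old good/ok split into a single good bin bounded by ``good_k``. To see the new layout (a sketch; ``load_arviz_data`` downloads a small example dataset on first use):

    import arviz as az

    idata = az.load_arviz_data("centered_eight")
    print(az.loo(idata, pointwise=True))  # table now reports (-Inf, good_k], (good_k, 1], (1, Inf)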
arviz/tests/base_tests/test_data.py CHANGED
@@ -42,9 +42,12 @@ from ..helpers import (  # pylint: disable=unused-import
     draws,
     eight_schools_params,
     models,
-    running_on_ci,
 )
 
+# Check if dm-tree is installed
+dm_tree_installed = importlib.util.find_spec("tree") is not None  # pylint: disable=invalid-name
+skip_tests = (not dm_tree_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)
+
 
 @pytest.fixture(autouse=True)
 def no_remote_data(monkeypatch, tmpdir):
@@ -896,6 +899,11 @@ class TestInferenceData:  # pylint: disable=too-many-public-methods
         assert escape(repr(idata)) in html
         xr.set_options(display_style=display_style)
 
+    def test_setitem(self, data_random):
+        data_random["new_group"] = data_random.posterior
+        assert "new_group" in data_random.groups()
+        assert hasattr(data_random, "new_group")
+
     def test_add_groups(self, data_random):
         data = np.random.normal(size=(4, 500, 8))
         idata = data_random
@@ -1077,6 +1085,7 @@ def test_dict_to_dataset():
     assert set(dataset.b.coords) == {"chain", "draw", "c"}
 
 
+@pytest.mark.skipif(skip_tests, reason="test requires dm-tree which is not installed")
 def test_nested_dict_to_dataset():
     datadict = {
         "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
@@ -1469,7 +1478,7 @@ class TestJSON:
 
 
 @pytest.mark.skipif(
-    not (importlib.util.find_spec("datatree") or running_on_ci()),
+    not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
     reason="test requires xarray-datatree library",
 )
 class TestDataTree:
arviz/tests/base_tests/test_data_zarr.py CHANGED
@@ -16,7 +16,6 @@ from ..helpers import (  # pylint: disable=unused-import
     draws,
     eight_schools_params,
     importorskip,
-    running_on_ci,
 )
 
 zarr = importorskip("zarr")  # pylint: disable=invalid-name
arviz/tests/base_tests/test_diagnostics_numba.py CHANGED
@@ -1,7 +1,5 @@
 """Test Diagnostic methods"""
 
-import importlib
-
 # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
 import numpy as np
 import pytest
@@ -11,13 +9,10 @@ from ...rcparams import rcParams
 from ...stats import bfmi, mcse, rhat
 from ...stats.diagnostics import _mc_error, ks_summary
 from ...utils import Numba
-from ..helpers import running_on_ci
+from ..helpers import importorskip
 from .test_diagnostics import data  # pylint: disable=unused-import
 
-pytestmark = pytest.mark.skipif(  # pylint: disable=invalid-name
-    (importlib.util.find_spec("numba") is None) and not running_on_ci(),
-    reason="test requires numba which is not installed",
-)
+importorskip("numba")
 
 rcParams["data.load"] = "eager"
 
arviz/tests/base_tests/test_helpers.py CHANGED
@@ -6,13 +6,13 @@ from ..helpers import importorskip
 
 
 def test_importorskip_local(monkeypatch):
     """Test ``importorskip`` run on local machine with non-existent module, which should skip."""
-    monkeypatch.delenv("ARVIZ_CI_MACHINE", raising=False)
+    monkeypatch.delenv("ARVIZ_REQUIRE_ALL_DEPS", raising=False)
     with pytest.raises(Skipped):
         importorskip("non-existent-function")
 
 
 def test_importorskip_ci(monkeypatch):
     """Test ``importorskip`` run on CI machine with non-existent module, which should fail."""
-    monkeypatch.setenv("ARVIZ_CI_MACHINE", 1)
+    monkeypatch.setenv("ARVIZ_REQUIRE_ALL_DEPS", 1)
     with pytest.raises(ModuleNotFoundError):
         importorskip("non-existent-function")
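Across the test suite the old ``ARVIZ_CI_MACHINE`` / ``running_on_ci`` check is replaced by the ``ARVIZ_REQUIRE_ALL_DEPS`` environment variable: when it is set, ``importorskip`` raises instead of skipping. A sketch of the gating behaviour (assuming the arviz 0.20.0 test helpers):

    import os

    os.environ["ARVIZ_REQUIRE_ALL_DEPS"] = "1"

    from arviz.tests.helpers import importorskip

    # Skips the calling test when the variable is unset; raises ModuleNotFoundError
    # (failing the test) when it is set and the module is missing.
    numba = importorskip("numba")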