arviz 0.18.0__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (47)
  1. arviz/__init__.py +2 -1
  2. arviz/data/io_cmdstan.py +4 -0
  3. arviz/data/io_numpyro.py +1 -1
  4. arviz/plots/backends/bokeh/ecdfplot.py +1 -2
  5. arviz/plots/backends/bokeh/khatplot.py +8 -3
  6. arviz/plots/backends/bokeh/pairplot.py +2 -6
  7. arviz/plots/backends/matplotlib/ecdfplot.py +1 -2
  8. arviz/plots/backends/matplotlib/khatplot.py +7 -3
  9. arviz/plots/backends/matplotlib/traceplot.py +1 -1
  10. arviz/plots/bpvplot.py +2 -2
  11. arviz/plots/densityplot.py +1 -1
  12. arviz/plots/dotplot.py +2 -2
  13. arviz/plots/ecdfplot.py +205 -89
  14. arviz/plots/essplot.py +2 -2
  15. arviz/plots/forestplot.py +1 -1
  16. arviz/plots/hdiplot.py +2 -2
  17. arviz/plots/khatplot.py +23 -6
  18. arviz/plots/loopitplot.py +2 -2
  19. arviz/plots/mcseplot.py +3 -1
  20. arviz/plots/plot_utils.py +2 -4
  21. arviz/plots/posteriorplot.py +1 -1
  22. arviz/plots/rankplot.py +2 -2
  23. arviz/plots/violinplot.py +1 -1
  24. arviz/preview.py +17 -0
  25. arviz/rcparams.py +27 -2
  26. arviz/stats/diagnostics.py +13 -9
  27. arviz/stats/ecdf_utils.py +11 -8
  28. arviz/stats/stats.py +31 -16
  29. arviz/stats/stats_utils.py +8 -6
  30. arviz/tests/base_tests/test_data.py +1 -2
  31. arviz/tests/base_tests/test_data_zarr.py +0 -1
  32. arviz/tests/base_tests/test_diagnostics_numba.py +2 -7
  33. arviz/tests/base_tests/test_helpers.py +2 -2
  34. arviz/tests/base_tests/test_plot_utils.py +5 -13
  35. arviz/tests/base_tests/test_plots_matplotlib.py +92 -2
  36. arviz/tests/base_tests/test_rcparams.py +12 -0
  37. arviz/tests/base_tests/test_stats.py +1 -1
  38. arviz/tests/base_tests/test_stats_numba.py +2 -7
  39. arviz/tests/base_tests/test_utils_numba.py +2 -5
  40. arviz/tests/external_tests/test_data_pystan.py +5 -5
  41. arviz/tests/helpers.py +17 -9
  42. arviz/utils.py +4 -0
  43. {arviz-0.18.0.dist-info → arviz-0.19.0.dist-info}/METADATA +8 -4
  44. {arviz-0.18.0.dist-info → arviz-0.19.0.dist-info}/RECORD +47 -46
  45. {arviz-0.18.0.dist-info → arviz-0.19.0.dist-info}/LICENSE +0 -0
  46. {arviz-0.18.0.dist-info → arviz-0.19.0.dist-info}/WHEEL +0 -0
  47. {arviz-0.18.0.dist-info → arviz-0.19.0.dist-info}/top_level.txt +0 -0
arviz/plots/forestplot.py CHANGED
@@ -246,7 +246,7 @@ def plot_forest(
         width_ratios.append(1)
 
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
 
arviz/plots/hdiplot.py CHANGED
@@ -42,7 +42,7 @@ def plot_hdi(
     hdi_data : array_like, optional
         Precomputed HDI values to use. Assumed shape is ``(*x.shape, 2)``.
     hdi_prob : float, optional
-        Probability for the highest density interval. Defaults to ``stats.hdi_prob`` rcParam.
+        Probability for the highest density interval. Defaults to ``stats.ci_prob`` rcParam.
         See :ref:`this section <common_hdi_prob>` for usage examples.
     color : str, default "C1"
         Color used for the limits of the HDI and fill. Should be a valid matplotlib color.
@@ -155,7 +155,7 @@ def plot_hdi(
     else:
         y = np.asarray(y)
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
     hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs)
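
With this change, leaving `hdi_prob` unset makes `plot_hdi` fall back to the renamed `stats.ci_prob` rcParam. A minimal usage sketch (the data below is fabricated for illustration):

    import numpy as np
    import arviz as az

    x = np.linspace(0, 1, 50)
    # fake posterior predictive draws, shape (chain, draw, len(x))
    y = np.random.default_rng(0).normal(x, 0.1, size=(4, 500, 50))

    az.plot_hdi(x, y)  # hdi_prob=None -> rcParams["stats.ci_prob"] (default 0.94)
    az.plot_hdi(x, y, hdi_prob=az.rcParams["stats.ci_prob"])  # equivalent, explicit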
arviz/plots/khatplot.py CHANGED
@@ -1,6 +1,7 @@
 """Pareto tail indices plot."""
 
 import logging
+import warnings
 
 import numpy as np
 from xarray import DataArray
@@ -40,10 +41,8 @@ def plot_khat(
 
     Parameters
     ----------
-    khats : ELPDData or array-like
-        The input Pareto tail indices to be plotted. It can be an ``ELPDData`` object containing
-        Pareto shapes or an array. In this second case, all the values in the array are interpreted
-        as Pareto tail indices.
+    khats : ELPDData
+        The input Pareto tail indices to be plotted.
     color : str or array_like, default "C0"
         Colors of the scatter plot, if color is a str all dots will have the same color,
         if it is the size of the observations, each dot will have the specified color,
@@ -149,8 +148,9 @@ def plot_khat(
 
     References
     ----------
-    .. [1] Vehtari, A., Simpson, D., Gelman, A., Yao, Y., Gabry, J.,
-        2019. Pareto Smoothed Importance Sampling. arXiv:1507.02646 [stat].
+    .. [1] Vehtari, A., Simpson, D., Gelman, A., Yao, Y., Gabry, J. (2024).
+        Pareto Smoothed Importance Sampling. Journal of Machine Learning
+        Research, 25(72):1-58.
 
     """
     if annotate:
@@ -164,13 +164,29 @@ def plot_khat(
         color = "C0"
 
     if isinstance(khats, np.ndarray):
+        warnings.warn(
+            "support for arrays will be deprecated, please use ELPDData."
+            "The reason for this, is that we need to know the numbers of draws"
+            "sampled from the posterior",
+            FutureWarning,
+        )
         khats = khats.flatten()
         xlabels = False
         legend = False
         dims = []
+        good_k = None
     else:
         if isinstance(khats, ELPDData):
+            good_k = khats.good_k
             khats = khats.pareto_k
+        else:
+            good_k = None
+            warnings.warn(
+                "support for DataArrays will be deprecated, please use ELPDData."
+                "The reason for this, is that we need to know the numbers of draws"
+                "sampled from the posterior",
+                FutureWarning,
+            )
         if not isinstance(khats, DataArray):
             raise ValueError("Incorrect khat data input. Check the documentation")
 
@@ -191,6 +207,7 @@ def plot_khat(
         figsize=figsize,
         xdata=xdata,
         khats=khats,
+        good_k=good_k,
         kwargs=kwargs,
         threshold=threshold,
         coord_labels=coord_labels,
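
Taken together, these hunks steer `plot_khat` toward the `ELPDData` object returned by `az.loo`, which now carries the sample-size-dependent `good_k` threshold; raw arrays and `DataArray`s still work but emit a `FutureWarning`. A sketch of the recommended call, assuming the bundled example dataset can be loaded:

    import arviz as az

    idata = az.load_arviz_data("centered_eight")
    loo_result = az.loo(idata, pointwise=True)  # ELPDData with pareto_k and good_k
    az.plot_khat(loo_result)  # no FutureWarning; good_k is passed to the backend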
arviz/plots/loopitplot.py CHANGED
@@ -55,7 +55,7 @@ def plot_loo_pit(
         In this case, instead of overlaying uniform distributions, the beta ``hdi_prob``
         around the theoretical uniform CDF is shown. This approximation only holds
         for large S and ECDF values not very close to 0 nor 1. For more information, see
-        `Vehtari et al. (2019)`, `Appendix G <https://avehtari.github.io/rhat_ess/rhat_ess.html>`_.
+        `Vehtari et al. (2021)`, `Appendix G <https://avehtari.github.io/rhat_ess/rhat_ess.html>`_.
     ecdf_fill : bool, optional
         Use :meth:`matplotlib.axes.Axes.fill_between` to mark the area
         inside the credible interval. Otherwise, plot the
@@ -159,7 +159,7 @@ def plot_loo_pit(
         x_vals = None
 
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
 
arviz/plots/mcseplot.py CHANGED
@@ -109,7 +109,9 @@ def plot_mcse(
 
     References
     ----------
-    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
+    .. [1] Vehtari et al. (2021). Rank-normalization, folding, and
+        localization: An improved Rhat for assessing convergence of
+        MCMC. Bayesian analysis, 16(2):667-718.
 
     Examples
     --------
arviz/plots/plot_utils.py CHANGED
@@ -245,10 +245,8 @@ def format_coords_as_labels(dataarray, skip_dims=None):
         coord_labels = coord_labels.values
     if isinstance(coord_labels[0], tuple):
         fmt = ", ".join(["{}" for _ in coord_labels[0]])
-        coord_labels[:] = [fmt.format(*x) for x in coord_labels]
-    else:
-        coord_labels[:] = [f"{s}" for s in coord_labels]
-    return coord_labels
+        return np.array([fmt.format(*x) for x in coord_labels])
+    return np.array([f"{s}" for s in coord_labels])
 
 
 def set_xticklabels(ax, coord_labels):
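
The `format_coords_as_labels` rewrite replaces in-place assignment with freshly built arrays. The diff does not state the motivation, but a plausible one is a known NumPy pitfall: writing formatted strings back into an existing fixed-width string array silently truncates them, whereas a new `np.array` infers a wide-enough dtype:

    import numpy as np

    labels = np.array(["a", "b"])  # dtype "<U1": one character per element
    labels[:] = [f"{s}_long" for s in labels]
    print(labels)  # ['a' 'b'] -- silently truncated back to one character

    fixed = np.array([f"{s}_long" for s in labels])
    print(fixed)  # ['a_long' 'b_long'] with dtype "<U6"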
arviz/plots/posteriorplot.py CHANGED
@@ -237,7 +237,7 @@ def plot_posterior(
         labeller = BaseLabeller()
 
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif hdi_prob not in (None, "hide"):
         if not 1 >= hdi_prob > 0:
             raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
arviz/plots/rankplot.py CHANGED
@@ -46,8 +46,8 @@ def plot_rank(
     indicates good mixing of the chains.
 
     This plot was introduced by Aki Vehtari, Andrew Gelman, Daniel Simpson, Bob Carpenter,
-    Paul-Christian Burkner (2019): Rank-normalization, folding, and localization: An improved R-hat
-    for assessing convergence of MCMC. arXiv preprint https://arxiv.org/abs/1903.08008
+    Paul-Christian Burkner (2021): Rank-normalization, folding, and localization:
+    An improved R-hat for assessing convergence of MCMC. Bayesian analysis, 16(2):667-718.
 
 
     Parameters
arviz/plots/violinplot.py CHANGED
@@ -152,7 +152,7 @@ def plot_violin(
     rows, cols = default_grid(len(plotters), grid=grid)
 
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
    elif not 1 >= hdi_prob > 0:
        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
 
arviz/preview.py ADDED
@@ -0,0 +1,17 @@
+# pylint: disable=unused-import,unused-wildcard-import,wildcard-import
+"""Expose features from arviz-xyz refactored packages inside ``arviz.preview`` namespace."""
+
+try:
+    from arviz_base import *
+except ModuleNotFoundError:
+    pass
+
+try:
+    import arviz_stats
+except ModuleNotFoundError:
+    pass
+
+try:
+    from arviz_plots import *
+except ModuleNotFoundError:
+    pass
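
The new module is importable unconditionally; whatever the optional `arviz_base` and `arviz_plots` packages export simply appears (or not) in the namespace:

    import arviz.preview

    # Names wildcard-imported from arviz_base and arviz_plots, if those
    # packages are installed; otherwise the list is essentially empty.
    print([name for name in dir(arviz.preview) if not name.startswith("_")])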
arviz/rcparams.py CHANGED
@@ -26,6 +26,8 @@ _log = logging.getLogger(__name__)
 ScaleKeyword = Literal["log", "negative_log", "deviance"]
 ICKeyword = Literal["loo", "waic"]
 
+_identity = lambda x: x
+
 
 def _make_validate_choice(accepted_values, allow_none=False, typeof=str):
     """Validate value is in accepted_values.
@@ -300,7 +302,7 @@ defaultParams = {  # pylint: disable=invalid-name
         lambda x: x,
     ),
     "plot.matplotlib.show": (False, _validate_boolean),
-    "stats.hdi_prob": (0.94, _validate_probability),
+    "stats.ci_prob": (0.94, _validate_probability),
     "stats.information_criterion": (
         "loo",
         _make_validate_choice({"loo", "waic"} if NO_GET_ARGS else set(get_args(ICKeyword))),
@@ -318,6 +320,9 @@ defaultParams = {  # pylint: disable=invalid-name
     ),
 }
 
+# map from deprecated params to (version, new_param, fold2new, fnew2old)
+deprecated_map = {"stats.hdi_prob": ("0.18.0", "stats.ci_prob", _identity, _identity)}
+
 
 class RcParams(MutableMapping):
     """Class to contain ArviZ default parameters.
@@ -335,6 +340,15 @@ class RcParams(MutableMapping):
 
     def __setitem__(self, key, val):
         """Add validation to __setitem__ function."""
+        if key in deprecated_map:
+            version, key_new, fold2new, _ = deprecated_map[key]
+            warnings.warn(
+                f"{key} is deprecated since {version}, use {key_new} instead",
+                FutureWarning,
+            )
+            key = key_new
+            val = fold2new(val)
+
         try:
             try:
                 cval = self.validate[key](val)
@@ -349,7 +363,18 @@ class RcParams(MutableMapping):
 
     def __getitem__(self, key):
         """Use underlying dict's getitem method."""
-        return self._underlying_storage[key]
+        if key in deprecated_map:
+            version, key_new, _, fnew2old = deprecated_map[key]
+            warnings.warn(
+                f"{key} is deprecated since {version}, use {key_new} instead",
+                FutureWarning,
+            )
+            if key not in self._underlying_storage:
+                key = key_new
+            else:
+                fnew2old = _identity
+
+        return fnew2old(self._underlying_storage[key])
 
     def __delitem__(self, key):
         """Raise TypeError if someone ever tries to delete a key from RcParams."""
arviz/stats/diagnostics.py CHANGED
@@ -135,10 +135,11 @@ def ess(
 
     References
     ----------
-    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
-    * https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html
-      Section 15.4.2
-    * Gelman et al. BDA (2014) Formula 11.8
+    * Vehtari et al. (2021). Rank-normalization, folding, and
+      localization: An improved Rhat for assessing convergence of
+      MCMC. Bayesian analysis, 16(2):667-718.
+    * https://mc-stan.org/docs/reference-manual/analysis.html#effective-sample-size.section
+    * Gelman et al. BDA3 (2013) Formula 11.8
 
     See Also
     --------
@@ -246,7 +247,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
         Names of variables to include in the rhat report
     method : str
         Select R-hat method. Valid methods are:
-        - "rank"  # recommended by Vehtari et al. (2019)
+        - "rank"  # recommended by Vehtari et al. (2021)
         - "split"
         - "folded"
         - "z_scale"
@@ -269,7 +270,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
     -----
     The diagnostic is computed by:
 
-    .. math:: \hat{R} = \frac{\hat{V}}{W}
+    .. math:: \hat{R} = \sqrt{\frac{\hat{V}}{W}}
 
     where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
     estimate for the pooled rank-traces. This is the potential scale reduction factor, which
@@ -277,12 +278,15 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
     greater than one indicate that one or more chains have not yet converged.
 
     Rank values are calculated over all the chains with ``scipy.stats.rankdata``.
-    Each chain is split in two and normalized with the z-transform following Vehtari et al. (2019).
+    Each chain is split in two and normalized with the z-transform following
+    Vehtari et al. (2021).
 
     References
     ----------
-    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
-    * Gelman et al. BDA (2014)
+    * Vehtari et al. (2021). Rank-normalization, folding, and
+      localization: An improved Rhat for assessing convergence of
+      MCMC. Bayesian analysis, 16(2):667-718.
+    * Gelman et al. BDA3 (2013)
     * Brooks and Gelman (1998)
     * Gelman and Rubin (1992)
 
arviz/stats/ecdf_utils.py CHANGED
@@ -25,6 +25,13 @@ def _get_ecdf_points(
     return x, y
 
 
+def _call_rvs(rvs, ndraws, random_state):
+    if random_state is None:
+        return rvs(ndraws)
+    else:
+        return rvs(ndraws, random_state=random_state)
+
+
 def _simulate_ecdf(
     ndraws: int,
     eval_points: np.ndarray,
@@ -32,7 +39,7 @@ def _simulate_ecdf(
     random_state: Optional[Any] = None,
 ) -> np.ndarray:
     """Simulate ECDF at the `eval_points` using the given random variable sampler"""
-    sample = rvs(ndraws, random_state=random_state)
+    sample = _call_rvs(rvs, ndraws, random_state)
     sample.sort()
     return compute_ecdf(sample, eval_points)
 
@@ -91,14 +98,10 @@ def ecdf_confidence_band(
         A function that takes an integer `ndraws` and optionally the object passed to
         `random_state` and returns an array of `ndraws` samples from the same distribution
         as the original dataset. Required if `method` is "simulated" and variable is discrete.
-    num_trials : int, default 1000
+    num_trials : int, default 500
         The number of random ECDFs to generate for constructing simultaneous confidence bands
         (if `method` is "simulated").
-    random_state : {None, int, `numpy.random.Generator`,
-        `numpy.random.RandomState`}, optional
-        If `None`, the `numpy.random.RandomState` singleton is used. If an `int`, a new
-        ``numpy.random.RandomState`` instance is used, seeded with seed. If a `RandomState` or
-        `Generator` instance, the instance is used.
+    random_state : int, numpy.random.Generator or numpy.random.RandomState, optional
 
     Returns
     -------
@@ -132,7 +135,7 @@ def _simulate_simultaneous_ecdf_band_probability(
     cdf_at_eval_points: np.ndarray,
     prob: float = 0.95,
     rvs: Optional[Callable[[int, Optional[Any]], np.ndarray]] = None,
-    num_trials: int = 1000,
+    num_trials: int = 500,
     random_state: Optional[Any] = None,
 ) -> float:
     """Estimate probability for simultaneous confidence band using simulation.
arviz/stats/stats.py CHANGED
@@ -270,12 +270,12 @@ def compare(
             weights[i] = u_weights / np.sum(u_weights)
 
         weights = weights.mean(axis=0)
-        ses = pd.Series(z_bs.std(axis=0), index=names)  # pylint: disable=no-member
+        ses = pd.Series(z_bs.std(axis=0), index=ics.index)  # pylint: disable=no-member
 
     elif method.lower() == "pseudo-bma":
         min_ic = ics.iloc[0][f"elpd_{ic}"]
         z_rv = np.exp((ics[f"elpd_{ic}"] - min_ic) / scale_value)
-        weights = z_rv / np.sum(z_rv)
+        weights = (z_rv / np.sum(z_rv)).to_numpy()
         ses = ics["se"]
 
     if np.any(weights):
@@ -471,7 +471,7 @@ def hdi(
         Refer to documentation of :func:`arviz.convert_to_dataset` for details.
     hdi_prob: float, optional
         Prob for which the highest density interval will be computed. Defaults to
-        ``stats.hdi_prob`` rcParam.
+        ``stats.ci_prob`` rcParam.
     circular: bool, optional
         Whether to compute the hdi taking into account `x` is a circular variable
         (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).
@@ -553,7 +553,7 @@ def hdi(
 
     """
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
 
@@ -715,8 +715,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
         se: standard error of the elpd
         p_loo: effective number of parameters
         shape_warn: bool
-            True if the estimated shape parameter of
-            Pareto distribution is greater than 0.7 for one or more samples
+            True if the estimated shape parameter of Pareto distribution is greater than a thresold
+            value for one or more samples. For a sample size S, the thresold is compute as
+            min(1 - 1/log10(S), 0.7)
         loo_i: array of pointwise predictive accuracy, only if pointwise True
         pareto_k: array of Pareto shape values, only if pointwise True
         scale: scale of the elpd
@@ -785,13 +786,15 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
     log_weights += log_likelihood
 
     warn_mg = False
-    if np.any(pareto_shape > 0.7):
+    good_k = min(1 - 1 / np.log10(n_samples), 0.7)
+
+    if np.any(pareto_shape > good_k):
         warnings.warn(
-            "Estimated shape parameter of Pareto distribution is greater than 0.7 for "
-            "one or more samples. You should consider using a more robust model, this is because "
-            "importance sampling is less likely to work well if the marginal posterior and "
-            "LOO posterior are very different. This is more likely to happen with a non-robust "
-            "model and highly influential observations."
+            f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} "
+            "for one or more samples. You should consider using a more robust model, this is "
+            "because importance sampling is less likely to work well if the marginal posterior "
+            "and LOO posterior are very different. This is more likely to happen with a "
+            "non-robust model and highly influential observations."
         )
         warn_mg = True
 
@@ -816,8 +819,17 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
 
     if not pointwise:
         return ELPDData(
-            data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale],
-            index=["elpd_loo", "se", "p_loo", "n_samples", "n_data_points", "warning", "scale"],
+            data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k],
+            index=[
+                "elpd_loo",
+                "se",
+                "p_loo",
+                "n_samples",
+                "n_data_points",
+                "warning",
+                "scale",
+                "good_k",
+            ],
         )
     if np.equal(loo_lppd, loo_lppd_i).all():  # pylint: disable=no-member
         warnings.warn(
@@ -835,6 +847,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
             loo_lppd_i.rename("loo_i"),
             pareto_shape,
             scale,
+            good_k,
         ],
         index=[
             "elpd_loo",
@@ -846,6 +859,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
             "loo_i",
             "pareto_k",
             "scale",
+            "good_k",
         ],
     )
 
@@ -879,7 +893,8 @@ def psislw(log_weights, reff=1.0):
 
     References
     ----------
-    * Vehtari et al. (2015) see https://arxiv.org/abs/1507.02646
+    * Vehtari et al. (2024). Pareto smoothed importance sampling. Journal of Machine
+      Learning Research, 25(72):1-58.
 
     See Also
     --------
@@ -1322,7 +1337,7 @@ def summary(
     if labeller is None:
         labeller = BaseLabeller()
     if hdi_prob is None:
-        hdi_prob = rcParams["stats.hdi_prob"]
+        hdi_prob = rcParams["stats.ci_prob"]
     elif not 1 >= hdi_prob > 0:
         raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
 
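
The fixed 0.7 cutoff is thus replaced by the sample-size-aware threshold min(1 - 1/log10(S), 0.7), stored as `good_k`. Its behavior for a few hypothetical sample sizes:

    import numpy as np

    for n_samples in (100, 1000, 4000):
        print(n_samples, round(min(1 - 1 / np.log10(n_samples), 0.7), 3))
    # 100  -> 0.5    (1 - 1/2)
    # 1000 -> 0.667  (1 - 1/3)
    # 4000 -> 0.7    (capped: 1 - 1/log10(4000) is about 0.722)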
arviz/stats/stats_utils.py CHANGED
@@ -454,10 +454,9 @@ POINTWISE_LOO_FMT = """------
 
 Pareto k diagnostic values:
 {{0:>{0}}} {{1:>6}}
-(-Inf, 0.5]   (good)      {{2:{0}d}}  {{6:6.1f}}%
-  (0.5, 0.7]  (ok)        {{3:{0}d}}  {{7:6.1f}}%
-    (0.7, 1]  (bad)       {{4:{0}d}}  {{8:6.1f}}%
-    (1, Inf)  (very bad)  {{5:{0}d}}  {{9:6.1f}}%
+(-Inf, {{8:.2f}}]   (good)     {{2:{0}d}}  {{5:6.1f}}%
+   ({{8:.2f}}, 1]   (bad)      {{3:{0}d}}  {{6:6.1f}}%
+   (1, Inf)   (very bad) {{4:{0}d}}  {{7:6.1f}}%
 """
 SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}
 
@@ -488,11 +487,14 @@ class ELPDData(pd.Series):  # pylint: disable=too-many-ancestors
         base += "\n\nThere has been a warning during the calculation. Please check the results."
 
         if kind == "loo" and "pareto_k" in self:
-            bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
+            bins = np.asarray([-np.inf, self.good_k, 1, np.inf])
             counts, *_ = _histogram(self.pareto_k.values, bins)
             extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
             extended = extended.format(
-                "Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)]
+                "Count",
+                "Pct.",
+                *[*counts, *(counts / np.sum(counts) * 100)],
+                self.good_k,
             )
             base = "\n".join([base, extended])
         return base
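
The printed Pareto-k table therefore collapses from four bands to three, with `good_k` replacing the fixed 0.5/0.7 edges. The binning can be approximated with NumPy's histogram (the k values below are made up):

    import numpy as np

    pareto_k = np.array([0.1, 0.65, 0.72, 1.3])
    good_k = 0.7
    counts, _ = np.histogram(pareto_k, bins=[-np.inf, good_k, 1, np.inf])
    print(counts)  # [2 1 1] -> good, bad, very bad; the old "ok" band is gone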
arviz/tests/base_tests/test_data.py CHANGED
@@ -42,7 +42,6 @@ from ..helpers import (  # pylint: disable=unused-import
     draws,
     eight_schools_params,
     models,
-    running_on_ci,
 )
 
 
@@ -1469,7 +1468,7 @@ class TestJSON:
 
 
 @pytest.mark.skipif(
-    not (importlib.util.find_spec("datatree") or running_on_ci()),
+    not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
     reason="test requires xarray-datatree library",
 )
 class TestDataTree:
arviz/tests/base_tests/test_data_zarr.py CHANGED
@@ -16,7 +16,6 @@ from ..helpers import (  # pylint: disable=unused-import
     draws,
     eight_schools_params,
     importorskip,
-    running_on_ci,
 )
 
 zarr = importorskip("zarr")  # pylint: disable=invalid-name
arviz/tests/base_tests/test_diagnostics_numba.py CHANGED
@@ -1,7 +1,5 @@
 """Test Diagnostic methods"""
 
-import importlib
-
 # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
 import numpy as np
 import pytest
@@ -11,13 +9,10 @@ from ...rcparams import rcParams
 from ...stats import bfmi, mcse, rhat
 from ...stats.diagnostics import _mc_error, ks_summary
 from ...utils import Numba
-from ..helpers import running_on_ci
+from ..helpers import importorskip
 from .test_diagnostics import data  # pylint: disable=unused-import
 
-pytestmark = pytest.mark.skipif(  # pylint: disable=invalid-name
-    (importlib.util.find_spec("numba") is None) and not running_on_ci(),
-    reason="test requires numba which is not installed",
-)
+importorskip("numba")
 
 rcParams["data.load"] = "eager"
 
arviz/tests/base_tests/test_helpers.py CHANGED
@@ -6,13 +6,13 @@ from ..helpers import importorskip
 
 def test_importorskip_local(monkeypatch):
     """Test ``importorskip`` run on local machine with non-existent module, which should skip."""
-    monkeypatch.delenv("ARVIZ_CI_MACHINE", raising=False)
+    monkeypatch.delenv("ARVIZ_REQUIRE_ALL_DEPS", raising=False)
     with pytest.raises(Skipped):
         importorskip("non-existent-function")
 
 
 def test_importorskip_ci(monkeypatch):
     """Test ``importorskip`` run on CI machine with non-existent module, which should fail."""
-    monkeypatch.setenv("ARVIZ_CI_MACHINE", 1)
+    monkeypatch.setenv("ARVIZ_REQUIRE_ALL_DEPS", 1)
     with pytest.raises(ModuleNotFoundError):
         importorskip("non-existent-function")
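
These tests pin down the semantics of the renamed environment variable: with `ARVIZ_REQUIRE_ALL_DEPS` unset, a missing optional dependency skips the test; with it set, the import error propagates. A sketch of an `importorskip` with those semantics (the real helper lives in arviz/tests/helpers.py and may differ in detail):

    import importlib
    import os

    import pytest

    def importorskip(modname):
        if "ARVIZ_REQUIRE_ALL_DEPS" in os.environ:
            # strict runs must have every optional dependency installed,
            # so a missing module raises ModuleNotFoundError
            return importlib.import_module(modname)
        return pytest.importorskip(modname)  # local runs just skip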
arviz/tests/base_tests/test_plot_utils.py CHANGED
@@ -1,5 +1,6 @@
 # pylint: disable=redefined-outer-name
 import importlib
+import os
 
 import numpy as np
 import pytest
@@ -20,10 +21,10 @@ from ...rcparams import rc_context
 from ...sel_utils import xarray_sel_iter, xarray_to_ndarray
 from ...stats.density_utils import get_bins
 from ...utils import get_coords
-from ..helpers import running_on_ci
 
 # Check if Bokeh is installed
 bokeh_installed = importlib.util.find_spec("bokeh") is not None  # pylint: disable=invalid-name
+skip_tests = (not bokeh_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)
 
 
 @pytest.mark.parametrize(
@@ -212,10 +213,7 @@ def test_filter_plotter_list_warning():
     assert len(plotters_filtered) == 5
 
 
-@pytest.mark.skipif(
-    not (bokeh_installed or running_on_ci()),
-    reason="test requires bokeh which is not installed",
-)
+@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
 def test_bokeh_import():
     """Tests that correct method is returned on bokeh import"""
     plot = get_plotting_function("plot_dist", "distplot", "bokeh")
@@ -290,10 +288,7 @@ def test_mpl_dealiase_sel_kwargs():
     assert res["line_color"] == "red"
 
 
-@pytest.mark.skipif(
-    not (bokeh_installed or running_on_ci()),
-    reason="test requires bokeh which is not installed",
-)
+@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
 def test_bokeh_dealiase_sel_kwargs():
     """Check bokeh dealiase_sel_kwargs behaviour.
 
@@ -315,10 +310,7 @@ def test_bokeh_dealiase_sel_kwargs():
     assert res["line_color"] == "red"
 
 
-@pytest.mark.skipif(
-    not (bokeh_installed or running_on_ci()),
-    reason="test requires bokeh which is not installed",
-)
+@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
 def test_set_bokeh_circular_ticks_labels():
     """Assert the axes returned after placing ticks and tick labels for circular plots."""
     import bokeh.plotting as bkp