arviz 0.17.1__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. arviz/__init__.py +4 -2
  2. arviz/data/__init__.py +5 -2
  3. arviz/data/base.py +102 -11
  4. arviz/data/converters.py +5 -0
  5. arviz/data/datasets.py +1 -0
  6. arviz/data/example_data/data_remote.json +10 -3
  7. arviz/data/inference_data.py +20 -22
  8. arviz/data/io_cmdstan.py +5 -3
  9. arviz/data/io_datatree.py +1 -0
  10. arviz/data/io_dict.py +5 -3
  11. arviz/data/io_emcee.py +1 -0
  12. arviz/data/io_numpyro.py +2 -1
  13. arviz/data/io_pyjags.py +1 -0
  14. arviz/data/io_pyro.py +1 -0
  15. arviz/data/utils.py +1 -0
  16. arviz/plots/__init__.py +1 -0
  17. arviz/plots/autocorrplot.py +1 -0
  18. arviz/plots/backends/bokeh/autocorrplot.py +1 -0
  19. arviz/plots/backends/bokeh/bpvplot.py +1 -0
  20. arviz/plots/backends/bokeh/compareplot.py +1 -0
  21. arviz/plots/backends/bokeh/densityplot.py +1 -0
  22. arviz/plots/backends/bokeh/distplot.py +1 -0
  23. arviz/plots/backends/bokeh/dotplot.py +1 -0
  24. arviz/plots/backends/bokeh/ecdfplot.py +2 -2
  25. arviz/plots/backends/bokeh/elpdplot.py +1 -0
  26. arviz/plots/backends/bokeh/energyplot.py +1 -0
  27. arviz/plots/backends/bokeh/hdiplot.py +1 -0
  28. arviz/plots/backends/bokeh/kdeplot.py +3 -3
  29. arviz/plots/backends/bokeh/khatplot.py +9 -3
  30. arviz/plots/backends/bokeh/lmplot.py +1 -0
  31. arviz/plots/backends/bokeh/loopitplot.py +1 -0
  32. arviz/plots/backends/bokeh/mcseplot.py +1 -0
  33. arviz/plots/backends/bokeh/pairplot.py +3 -6
  34. arviz/plots/backends/bokeh/parallelplot.py +1 -0
  35. arviz/plots/backends/bokeh/posteriorplot.py +1 -0
  36. arviz/plots/backends/bokeh/ppcplot.py +1 -0
  37. arviz/plots/backends/bokeh/rankplot.py +1 -0
  38. arviz/plots/backends/bokeh/separationplot.py +1 -0
  39. arviz/plots/backends/bokeh/traceplot.py +1 -0
  40. arviz/plots/backends/bokeh/violinplot.py +1 -0
  41. arviz/plots/backends/matplotlib/autocorrplot.py +1 -0
  42. arviz/plots/backends/matplotlib/bpvplot.py +1 -0
  43. arviz/plots/backends/matplotlib/compareplot.py +1 -0
  44. arviz/plots/backends/matplotlib/densityplot.py +1 -0
  45. arviz/plots/backends/matplotlib/distcomparisonplot.py +2 -3
  46. arviz/plots/backends/matplotlib/distplot.py +1 -0
  47. arviz/plots/backends/matplotlib/dotplot.py +1 -0
  48. arviz/plots/backends/matplotlib/ecdfplot.py +2 -2
  49. arviz/plots/backends/matplotlib/elpdplot.py +1 -0
  50. arviz/plots/backends/matplotlib/energyplot.py +1 -0
  51. arviz/plots/backends/matplotlib/essplot.py +6 -5
  52. arviz/plots/backends/matplotlib/forestplot.py +1 -0
  53. arviz/plots/backends/matplotlib/hdiplot.py +1 -0
  54. arviz/plots/backends/matplotlib/kdeplot.py +5 -3
  55. arviz/plots/backends/matplotlib/khatplot.py +8 -3
  56. arviz/plots/backends/matplotlib/lmplot.py +1 -0
  57. arviz/plots/backends/matplotlib/loopitplot.py +1 -0
  58. arviz/plots/backends/matplotlib/mcseplot.py +11 -10
  59. arviz/plots/backends/matplotlib/pairplot.py +2 -1
  60. arviz/plots/backends/matplotlib/parallelplot.py +1 -0
  61. arviz/plots/backends/matplotlib/posteriorplot.py +1 -0
  62. arviz/plots/backends/matplotlib/ppcplot.py +1 -0
  63. arviz/plots/backends/matplotlib/rankplot.py +1 -0
  64. arviz/plots/backends/matplotlib/separationplot.py +1 -0
  65. arviz/plots/backends/matplotlib/traceplot.py +2 -1
  66. arviz/plots/backends/matplotlib/tsplot.py +1 -0
  67. arviz/plots/backends/matplotlib/violinplot.py +2 -1
  68. arviz/plots/bpvplot.py +3 -2
  69. arviz/plots/compareplot.py +1 -0
  70. arviz/plots/densityplot.py +2 -1
  71. arviz/plots/distcomparisonplot.py +1 -0
  72. arviz/plots/dotplot.py +3 -2
  73. arviz/plots/ecdfplot.py +206 -89
  74. arviz/plots/elpdplot.py +1 -0
  75. arviz/plots/energyplot.py +1 -0
  76. arviz/plots/essplot.py +3 -2
  77. arviz/plots/forestplot.py +2 -1
  78. arviz/plots/hdiplot.py +3 -2
  79. arviz/plots/khatplot.py +24 -6
  80. arviz/plots/lmplot.py +1 -0
  81. arviz/plots/loopitplot.py +3 -2
  82. arviz/plots/mcseplot.py +4 -1
  83. arviz/plots/pairplot.py +1 -0
  84. arviz/plots/parallelplot.py +1 -0
  85. arviz/plots/plot_utils.py +3 -4
  86. arviz/plots/posteriorplot.py +2 -1
  87. arviz/plots/ppcplot.py +1 -0
  88. arviz/plots/rankplot.py +3 -2
  89. arviz/plots/separationplot.py +1 -0
  90. arviz/plots/traceplot.py +1 -0
  91. arviz/plots/tsplot.py +1 -0
  92. arviz/plots/violinplot.py +2 -1
  93. arviz/preview.py +17 -0
  94. arviz/rcparams.py +28 -2
  95. arviz/sel_utils.py +1 -0
  96. arviz/static/css/style.css +2 -1
  97. arviz/stats/density_utils.py +2 -1
  98. arviz/stats/diagnostics.py +15 -11
  99. arviz/stats/ecdf_utils.py +12 -8
  100. arviz/stats/stats.py +31 -16
  101. arviz/stats/stats_refitting.py +1 -0
  102. arviz/stats/stats_utils.py +13 -7
  103. arviz/tests/base_tests/test_data.py +15 -2
  104. arviz/tests/base_tests/test_data_zarr.py +0 -1
  105. arviz/tests/base_tests/test_diagnostics.py +1 -0
  106. arviz/tests/base_tests/test_diagnostics_numba.py +2 -6
  107. arviz/tests/base_tests/test_helpers.py +2 -2
  108. arviz/tests/base_tests/test_labels.py +1 -0
  109. arviz/tests/base_tests/test_plot_utils.py +5 -13
  110. arviz/tests/base_tests/test_plots_matplotlib.py +98 -7
  111. arviz/tests/base_tests/test_rcparams.py +12 -0
  112. arviz/tests/base_tests/test_stats.py +5 -5
  113. arviz/tests/base_tests/test_stats_numba.py +2 -7
  114. arviz/tests/base_tests/test_stats_utils.py +1 -0
  115. arviz/tests/base_tests/test_utils.py +3 -2
  116. arviz/tests/base_tests/test_utils_numba.py +2 -5
  117. arviz/tests/external_tests/test_data_pystan.py +5 -5
  118. arviz/tests/helpers.py +18 -10
  119. arviz/utils.py +4 -0
  120. arviz/wrappers/__init__.py +1 -0
  121. {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/METADATA +13 -9
  122. arviz-0.19.0.dist-info/RECORD +183 -0
  123. arviz-0.17.1.dist-info/RECORD +0 -182
  124. {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/LICENSE +0 -0
  125. {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/WHEEL +0 -0
  126. {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/top_level.txt +0 -0
arviz/plots/rankplot.py CHANGED
@@ -1,4 +1,5 @@
  """Histograms of ranked posterior draws, plotted for each chain."""
+
  from itertools import cycle

  import matplotlib.pyplot as plt
@@ -45,8 +46,8 @@ def plot_rank(
      indicates good mixing of the chains.

      This plot was introduced by Aki Vehtari, Andrew Gelman, Daniel Simpson, Bob Carpenter,
-     Paul-Christian Burkner (2019): Rank-normalization, folding, and localization: An improved R-hat
-     for assessing convergence of MCMC. arXiv preprint https://arxiv.org/abs/1903.08008
+     Paul-Christian Burkner (2021): Rank-normalization, folding, and localization:
+     An improved R-hat for assessing convergence of MCMC. Bayesian analysis, 16(2):667-718.


      Parameters
arviz/plots/separationplot.py CHANGED
@@ -1,4 +1,5 @@
  """Separation plot for discrete outcome models."""
+
  import warnings

  import numpy as np
arviz/plots/traceplot.py CHANGED
@@ -1,4 +1,5 @@
  """Plot kde or histograms and values from MCMC samples."""
+
  import warnings
  from typing import Any, Callable, List, Mapping, Optional, Tuple, Union, Sequence

arviz/plots/tsplot.py CHANGED
@@ -1,4 +1,5 @@
  """Plot timeseries data."""
+
  import warnings
  import numpy as np

arviz/plots/violinplot.py CHANGED
@@ -1,4 +1,5 @@
  """Plot posterior traces as violin plot."""
+
  from ..data import convert_to_dataset
  from ..labels import BaseLabeller
  from ..sel_utils import xarray_var_iter
@@ -151,7 +152,7 @@ def plot_violin(
      rows, cols = default_grid(len(plotters), grid=grid)

      if hdi_prob is None:
-         hdi_prob = rcParams["stats.hdi_prob"]
+         hdi_prob = rcParams["stats.ci_prob"]
      elif not 1 >= hdi_prob > 0:
          raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

arviz/preview.py ADDED
@@ -0,0 +1,17 @@
+ # pylint: disable=unused-import,unused-wildcard-import,wildcard-import
+ """Expose features from arviz-xyz refactored packages inside ``arviz.preview`` namespace."""
+
+ try:
+     from arviz_base import *
+ except ModuleNotFoundError:
+     pass
+
+ try:
+     import arviz_stats
+ except ModuleNotFoundError:
+     pass
+
+ try:
+     from arviz_plots import *
+ except ModuleNotFoundError:
+     pass
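The module makes each refactored package optional: importing arviz.preview never fails, it just exposes whatever subset is installed. A minimal usage sketch (it assumes at least one of the optional arviz-base, arviz-stats, or arviz-plots packages is present; with none installed the namespace is simply empty):

    from arviz import preview

    # Names re-exported from arviz_base / arviz_plots become attributes of the
    # preview module; missing optional packages are silently skipped.
    print([name for name in dir(preview) if not name.startswith("_")])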
arviz/rcparams.py CHANGED
@@ -1,4 +1,5 @@
  """ArviZ rcparams. Based on matplotlib's implementation."""
+
  import locale
  import logging
  import os
@@ -25,6 +26,8 @@ _log = logging.getLogger(__name__)
  ScaleKeyword = Literal["log", "negative_log", "deviance"]
  ICKeyword = Literal["loo", "waic"]

+ _identity = lambda x: x
+

  def _make_validate_choice(accepted_values, allow_none=False, typeof=str):
      """Validate value is in accepted_values.
@@ -299,7 +302,7 @@ defaultParams = {  # pylint: disable=invalid-name
          lambda x: x,
      ),
      "plot.matplotlib.show": (False, _validate_boolean),
-     "stats.hdi_prob": (0.94, _validate_probability),
+     "stats.ci_prob": (0.94, _validate_probability),
      "stats.information_criterion": (
          "loo",
          _make_validate_choice({"loo", "waic"} if NO_GET_ARGS else set(get_args(ICKeyword))),
@@ -317,6 +320,9 @@ defaultParams = {  # pylint: disable=invalid-name
      ),
  }

+ # map from deprecated params to (version, new_param, fold2new, fnew2old)
+ deprecated_map = {"stats.hdi_prob": ("0.18.0", "stats.ci_prob", _identity, _identity)}
+

  class RcParams(MutableMapping):
      """Class to contain ArviZ default parameters.
@@ -334,6 +340,15 @@ class RcParams(MutableMapping):

      def __setitem__(self, key, val):
          """Add validation to __setitem__ function."""
+         if key in deprecated_map:
+             version, key_new, fold2new, _ = deprecated_map[key]
+             warnings.warn(
+                 f"{key} is deprecated since {version}, use {key_new} instead",
+                 FutureWarning,
+             )
+             key = key_new
+             val = fold2new(val)
+
          try:
              try:
                  cval = self.validate[key](val)
@@ -348,7 +363,18 @@ class RcParams(MutableMapping):

      def __getitem__(self, key):
          """Use underlying dict's getitem method."""
-         return self._underlying_storage[key]
+         if key in deprecated_map:
+             version, key_new, _, fnew2old = deprecated_map[key]
+             warnings.warn(
+                 f"{key} is deprecated since {version}, use {key_new} instead",
+                 FutureWarning,
+             )
+             if key not in self._underlying_storage:
+                 key = key_new
+             else:
+                 fnew2old = _identity
+
+         return fnew2old(self._underlying_storage[key])

      def __delitem__(self, key):
          """Raise TypeError if someone ever tries to delete a key from RcParams."""
arviz/sel_utils.py CHANGED
@@ -1,4 +1,5 @@
  """Utilities for selecting and iterating on xarray objects."""
+
  from itertools import product, tee

  import numpy as np
arviz/static/css/style.css CHANGED
@@ -302,7 +302,8 @@ dl.xr-attrs {
      grid-template-columns: 125px auto;
  }

- .xr-attrs dt, dd {
+ .xr-attrs dt,
+ .xr-attrs dd {
      padding: 0;
      margin: 0;
      float: left;
arviz/stats/density_utils.py CHANGED
@@ -5,7 +5,8 @@ import warnings
  import numpy as np
  from scipy.fftpack import fft
  from scipy.optimize import brentq
- from scipy.signal import convolve, convolve2d, gaussian  # pylint: disable=no-name-in-module
+ from scipy.signal import convolve, convolve2d
+ from scipy.signal.windows import gaussian
  from scipy.sparse import coo_matrix
  from scipy.special import ive  # pylint: disable=no-name-in-module

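Context for this change: the window functions were deprecated in the top-level scipy.signal namespace and have since been removed in recent SciPy releases, while scipy.signal.windows has provided them for years, so the new import works across supported versions. A quick check:

    from scipy.signal.windows import gaussian

    window = gaussian(M=9, std=2.0)  # 9-point Gaussian window, peak at the center
    assert window.argmax() == 4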
arviz/stats/diagnostics.py CHANGED
@@ -135,10 +135,11 @@ def ess(

      References
      ----------
-     * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
-     * https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html
-       Section 15.4.2
-     * Gelman et al. BDA (2014) Formula 11.8
+     * Vehtari et al. (2021). Rank-normalization, folding, and
+       localization: An improved Rhat for assessing convergence of
+       MCMC. Bayesian analysis, 16(2):667-718.
+     * https://mc-stan.org/docs/reference-manual/analysis.html#effective-sample-size.section
+     * Gelman et al. BDA3 (2013) Formula 11.8

      See Also
      --------
@@ -246,7 +247,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
          Names of variables to include in the rhat report
      method : str
          Select R-hat method. Valid methods are:
-         - "rank"  # recommended by Vehtari et al. (2019)
+         - "rank"  # recommended by Vehtari et al. (2021)
          - "split"
          - "folded"
          - "z_scale"
@@ -269,7 +270,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
      -----
      The diagnostic is computed by:

-     .. math:: \hat{R} = \frac{\hat{V}}{W}
+     .. math:: \hat{R} = \sqrt{\frac{\hat{V}}{W}}

      where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
      estimate for the pooled rank-traces. This is the potential scale reduction factor, which
@@ -277,12 +278,15 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
      greater than one indicate that one or more chains have not yet converged.

      Rank values are calculated over all the chains with ``scipy.stats.rankdata``.
-     Each chain is split in two and normalized with the z-transform following Vehtari et al. (2019).
+     Each chain is split in two and normalized with the z-transform following
+     Vehtari et al. (2021).

      References
      ----------
-     * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
-     * Gelman et al. BDA (2014)
+     * Vehtari et al. (2021). Rank-normalization, folding, and
+       localization: An improved Rhat for assessing convergence of
+       MCMC. Bayesian analysis, 16(2):667-718.
+     * Gelman et al. BDA3 (2013)
      * Brooks and Gelman (1998)
      * Gelman and Rubin (1992)

@@ -836,7 +840,7 @@ def _mcse_sd(ary):
          return np.nan
      ess = _ess_sd(ary)
      if _numba_flag:
-         sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)))
+         sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)).item())
      else:
          sd = np.std(ary, ddof=1)
      fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
@@ -904,7 +908,7 @@ def _mc_error(ary, batches=5, circular=False):
          else:
              std = stats.circstd(ary, high=np.pi, low=-np.pi)
      elif _numba_flag:
-         std = float(_sqrt(svar(ary), np.zeros(1)))
+         std = float(_sqrt(svar(ary), np.zeros(1)).item())
      else:
          std = np.std(ary)
      return std / np.sqrt(len(ary))
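The corrected docstring formula includes the square root. For intuition, here is a minimal, non-rank-normalized split-R-hat sketch that implements exactly R-hat = sqrt(V-hat / W); ArviZ's default "rank" method additionally rank-normalizes the draws first:

    import numpy as np

    def split_rhat(draws):
        """Split-R-hat for draws with shape (chains, draws) -- illustration only."""
        chains, n = draws.shape
        half = n // 2
        split = draws[:, : 2 * half].reshape(chains * 2, half)  # split each chain in two
        w = split.var(axis=1, ddof=1).mean()       # within-chain variance W
        b = half * split.mean(axis=1).var(ddof=1)  # between-chain variance B
        v_hat = (half - 1) / half * w + b / half   # pooled posterior variance estimate
        return np.sqrt(v_hat / w)

    rng = np.random.default_rng(0)
    print(split_rhat(rng.normal(size=(4, 1000))))  # ~1.0 for well-mixed chains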
arviz/stats/ecdf_utils.py CHANGED
@@ -1,4 +1,5 @@
  """Functions for evaluating ECDFs and their confidence bands."""
+
  from typing import Any, Callable, Optional, Tuple
  import warnings

@@ -24,6 +25,13 @@ def _get_ecdf_points(
      return x, y


+ def _call_rvs(rvs, ndraws, random_state):
+     if random_state is None:
+         return rvs(ndraws)
+     else:
+         return rvs(ndraws, random_state=random_state)
+
+
  def _simulate_ecdf(
      ndraws: int,
      eval_points: np.ndarray,
@@ -31,7 +39,7 @@ def _simulate_ecdf(
      random_state: Optional[Any] = None,
  ) -> np.ndarray:
      """Simulate ECDF at the `eval_points` using the given random variable sampler"""
-     sample = rvs(ndraws, random_state=random_state)
+     sample = _call_rvs(rvs, ndraws, random_state)
      sample.sort()
      return compute_ecdf(sample, eval_points)

@@ -90,14 +98,10 @@ def ecdf_confidence_band(
          A function that takes an integer `ndraws` and optionally the object passed to
          `random_state` and returns an array of `ndraws` samples from the same distribution
          as the original dataset. Required if `method` is "simulated" and variable is discrete.
-     num_trials : int, default 1000
+     num_trials : int, default 500
          The number of random ECDFs to generate for constructing simultaneous confidence bands
          (if `method` is "simulated").
-     random_state : {None, int, `numpy.random.Generator`,
-         `numpy.random.RandomState`}, optional
-         If `None`, the `numpy.random.RandomState` singleton is used. If an `int`, a new
-         ``numpy.random.RandomState`` instance is used, seeded with seed. If a `RandomState` or
-         `Generator` instance, the instance is used.
+     random_state : int, numpy.random.Generator or numpy.random.RandomState, optional

      Returns
      -------
@@ -131,7 +135,7 @@ def _simulate_simultaneous_ecdf_band_probability(
      cdf_at_eval_points: np.ndarray,
      prob: float = 0.95,
      rvs: Optional[Callable[[int, Optional[Any]], np.ndarray]] = None,
-     num_trials: int = 1000,
+     num_trials: int = 500,
      random_state: Optional[Any] = None,
  ) -> float:
      """Estimate probability for simultaneous confidence band using simulation.
arviz/stats/stats.py CHANGED
@@ -270,12 +270,12 @@ def compare(
              weights[i] = u_weights / np.sum(u_weights)

          weights = weights.mean(axis=0)
-         ses = pd.Series(z_bs.std(axis=0), index=names)  # pylint: disable=no-member
+         ses = pd.Series(z_bs.std(axis=0), index=ics.index)  # pylint: disable=no-member

      elif method.lower() == "pseudo-bma":
          min_ic = ics.iloc[0][f"elpd_{ic}"]
          z_rv = np.exp((ics[f"elpd_{ic}"] - min_ic) / scale_value)
-         weights = z_rv / np.sum(z_rv)
+         weights = (z_rv / np.sum(z_rv)).to_numpy()
          ses = ics["se"]

      if np.any(weights):
@@ -471,7 +471,7 @@ def hdi(
          Refer to documentation of :func:`arviz.convert_to_dataset` for details.
      hdi_prob: float, optional
          Prob for which the highest density interval will be computed. Defaults to
-         ``stats.hdi_prob`` rcParam.
+         ``stats.ci_prob`` rcParam.
      circular: bool, optional
          Whether to compute the hdi taking into account `x` is a circular variable
          (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).
@@ -553,7 +553,7 @@ def hdi(

      """
      if hdi_prob is None:
-         hdi_prob = rcParams["stats.hdi_prob"]
+         hdi_prob = rcParams["stats.ci_prob"]
      elif not 1 >= hdi_prob > 0:
          raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

@@ -715,8 +715,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
      se: standard error of the elpd
      p_loo: effective number of parameters
      shape_warn: bool
-         True if the estimated shape parameter of
-         Pareto distribution is greater than 0.7 for one or more samples
+         True if the estimated shape parameter of the Pareto distribution is greater than a
+         threshold value for one or more samples. For a sample size S, the threshold is
+         computed as min(1 - 1/log10(S), 0.7)
      loo_i: array of pointwise predictive accuracy, only if pointwise True
      pareto_k: array of Pareto shape values, only if pointwise True
      scale: scale of the elpd
@@ -785,13 +786,15 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
          log_weights += log_likelihood

      warn_mg = False
-     if np.any(pareto_shape > 0.7):
+     good_k = min(1 - 1 / np.log10(n_samples), 0.7)
+
+     if np.any(pareto_shape > good_k):
          warnings.warn(
-             "Estimated shape parameter of Pareto distribution is greater than 0.7 for "
-             "one or more samples. You should consider using a more robust model, this is because "
-             "importance sampling is less likely to work well if the marginal posterior and "
-             "LOO posterior are very different. This is more likely to happen with a non-robust "
-             "model and highly influential observations."
+             f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} "
+             "for one or more samples. You should consider using a more robust model, this is "
+             "because importance sampling is less likely to work well if the marginal posterior "
+             "and LOO posterior are very different. This is more likely to happen with a "
+             "non-robust model and highly influential observations."
          )
          warn_mg = True

@@ -816,8 +819,17 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):

      if not pointwise:
          return ELPDData(
-             data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale],
-             index=["elpd_loo", "se", "p_loo", "n_samples", "n_data_points", "warning", "scale"],
+             data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k],
+             index=[
+                 "elpd_loo",
+                 "se",
+                 "p_loo",
+                 "n_samples",
+                 "n_data_points",
+                 "warning",
+                 "scale",
+                 "good_k",
+             ],
          )
      if np.equal(loo_lppd, loo_lppd_i).all():  # pylint: disable=no-member
          warnings.warn(
@@ -835,6 +847,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
              loo_lppd_i.rename("loo_i"),
              pareto_shape,
              scale,
+             good_k,
          ],
          index=[
              "elpd_loo",
@@ -846,6 +859,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
              "loo_i",
              "pareto_k",
              "scale",
+             "good_k",
          ],
      )

@@ -879,7 +893,8 @@ def psislw(log_weights, reff=1.0):

      References
      ----------
-     * Vehtari et al. (2015) see https://arxiv.org/abs/1507.02646
+     * Vehtari et al. (2024). Pareto smoothed importance sampling. Journal of Machine
+       Learning Research, 25(72):1-58.

      See Also
      --------
@@ -1322,7 +1337,7 @@ def summary(
      if labeller is None:
          labeller = BaseLabeller()
      if hdi_prob is None:
-         hdi_prob = rcParams["stats.hdi_prob"]
+         hdi_prob = rcParams["stats.ci_prob"]
      elif not 1 >= hdi_prob > 0:
          raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

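The new sample-size-dependent threshold tightens the diagnostic for small sample sizes and caps it at the old 0.7 for large ones. Evaluating min(1 - 1/log10(S), 0.7) for a few sample sizes:

    import numpy as np

    for n_samples in (100, 1000, 4000):
        good_k = min(1 - 1 / np.log10(n_samples), 0.7)
        print(n_samples, round(good_k, 3))
    # 100 -> 0.5, 1000 -> 0.667, 4000 -> 0.7 (capped)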
arviz/stats/stats_refitting.py CHANGED
@@ -1,4 +1,5 @@
  """Stats functions that require refitting the model."""
+
  import logging
  import warnings

arviz/stats/stats_utils.py CHANGED
@@ -1,4 +1,5 @@
  """Stats-utility functions for ArviZ."""
+
  import warnings
  from collections.abc import Sequence
  from copy import copy as _copy
@@ -134,7 +135,10 @@ def make_ufunc(
              raise TypeError(msg)
          for idx in np.ndindex(out.shape[:n_dims_out]):
              arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
-             out[idx] = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
+             out_idx = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
+             if n_dims_out is None:
+                 out_idx = out_idx.item()
+             out[idx] = out_idx
          return out

      def _multi_ufunc(*args, out=None, out_shape=None, **kwargs):
@@ -450,10 +454,9 @@ POINTWISE_LOO_FMT = """------

  Pareto k diagnostic values:
                           {{0:>{0}}} {{1:>6}}
- (-Inf, 0.5]   (good)     {{2:{0}d}} {{6:6.1f}}%
- (0.5, 0.7]    (ok)       {{3:{0}d}} {{7:6.1f}}%
- (0.7, 1]      (bad)      {{4:{0}d}} {{8:6.1f}}%
- (1, Inf)      (very bad) {{5:{0}d}} {{9:6.1f}}%
+ (-Inf, {{8:.2f}}]  (good)     {{2:{0}d}} {{5:6.1f}}%
+ ({{8:.2f}}, 1]    (bad)      {{3:{0}d}} {{6:6.1f}}%
+ (1, Inf)      (very bad) {{4:{0}d}} {{7:6.1f}}%
  """
  SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}

@@ -484,11 +487,14 @@ class ELPDData(pd.Series):  # pylint: disable=too-many-ancestors
          base += "\n\nThere has been a warning during the calculation. Please check the results."

      if kind == "loo" and "pareto_k" in self:
-         bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
+         bins = np.asarray([-np.inf, self.good_k, 1, np.inf])
          counts, *_ = _histogram(self.pareto_k.values, bins)
          extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
          extended = extended.format(
-             "Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)]
+             "Count",
+             "Pct.",
+             *[*counts, *(counts / np.sum(counts) * 100)],
+             self.good_k,
          )
          base = "\n".join([base, extended])
      return base
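The diagnostic table thus goes from four bins to three, with the good/bad boundary at good_k instead of the fixed 0.5/0.7 cut points. The counts come from a plain histogram; a sketch with hypothetical k values:

    import numpy as np

    pareto_k = np.array([0.12, 0.43, 0.68, 0.81, 1.30])  # hypothetical values
    good_k = 0.7  # min(1 - 1/log10(S), 0.7) for the fit at hand
    bins = np.asarray([-np.inf, good_k, 1, np.inf])
    counts, _ = np.histogram(pareto_k, bins)
    print(dict(zip(["good", "bad", "very bad"], counts)))
    # {'good': 3, 'bad': 1, 'very bad': 1}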
arviz/tests/base_tests/test_data.py CHANGED
@@ -42,7 +42,6 @@ from ..helpers import (  # pylint: disable=unused-import
      draws,
      eight_schools_params,
      models,
-     running_on_ci,
  )


@@ -1077,6 +1076,20 @@ def test_dict_to_dataset():
      assert set(dataset.b.coords) == {"chain", "draw", "c"}


+ def test_nested_dict_to_dataset():
+     datadict = {
+         "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
+         "d": np.random.randn(100),
+     }
+     dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={("top", "b"): ["c"]})
+     assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"}
+     assert set(dataset.coords) == {"chain", "draw", "c"}
+
+     assert set(dataset[("top", "a")].coords) == {"chain", "draw"}
+     assert set(dataset[("top", "b")].coords) == {"chain", "draw", "c"}
+     assert set(dataset.d.coords) == {"chain", "draw"}
+
+
  def test_dict_to_dataset_event_dims_error():
      datadict = {"a": np.random.randn(1, 100, 10)}
      coords = {"b": np.arange(10), "c": ["x", "y", "z"]}
@@ -1455,7 +1468,7 @@ class TestJSON:


  @pytest.mark.skipif(
-     not (importlib.util.find_spec("datatree") or running_on_ci()),
+     not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
      reason="test requires xarray-datatree library",
  )
  class TestDataTree:
arviz/tests/base_tests/test_data_zarr.py CHANGED
@@ -16,7 +16,6 @@ from ..helpers import (  # pylint: disable=unused-import
      draws,
      eight_schools_params,
      importorskip,
-     running_on_ci,
  )

  zarr = importorskip("zarr")  # pylint: disable=invalid-name
arviz/tests/base_tests/test_diagnostics.py CHANGED
@@ -1,4 +1,5 @@
  """Test Diagnostic methods"""
+
  # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
  import os

arviz/tests/base_tests/test_diagnostics_numba.py CHANGED
@@ -1,5 +1,4 @@
  """Test Diagnostic methods"""
- import importlib

  # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
  import numpy as np
@@ -10,13 +9,10 @@ from ...rcparams import rcParams
  from ...stats import bfmi, mcse, rhat
  from ...stats.diagnostics import _mc_error, ks_summary
  from ...utils import Numba
- from ..helpers import running_on_ci
+ from ..helpers import importorskip
  from .test_diagnostics import data  # pylint: disable=unused-import

- pytestmark = pytest.mark.skipif(  # pylint: disable=invalid-name
-     (importlib.util.find_spec("numba") is None) and not running_on_ci(),
-     reason="test requires numba which is not installed",
- )
+ importorskip("numba")

  rcParams["data.load"] = "eager"

arviz/tests/base_tests/test_helpers.py CHANGED
@@ -6,13 +6,13 @@ from ..helpers import importorskip

  def test_importorskip_local(monkeypatch):
      """Test ``importorskip`` run on local machine with non-existent module, which should skip."""
-     monkeypatch.delenv("ARVIZ_CI_MACHINE", raising=False)
+     monkeypatch.delenv("ARVIZ_REQUIRE_ALL_DEPS", raising=False)
      with pytest.raises(Skipped):
          importorskip("non-existent-function")


  def test_importorskip_ci(monkeypatch):
      """Test ``importorskip`` run on CI machine with non-existent module, which should fail."""
-     monkeypatch.setenv("ARVIZ_CI_MACHINE", 1)
+     monkeypatch.setenv("ARVIZ_REQUIRE_ALL_DEPS", 1)
      with pytest.raises(ModuleNotFoundError):
          importorskip("non-existent-function")
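These two tests pin down the contract of the renamed environment variable: with ARVIZ_REQUIRE_ALL_DEPS set, a missing optional dependency is a hard error; without it, the test is skipped. A sketch of the helper they exercise (the real implementation lives in arviz/tests/helpers.py):

    import importlib
    import os
    import pytest

    def importorskip(modname):
        """Import and return a module, or skip the test -- sketch of the contract."""
        if "ARVIZ_REQUIRE_ALL_DEPS" in os.environ:
            return importlib.import_module(modname)  # raises ModuleNotFoundError
        return pytest.importorskip(modname)          # skips if module is missing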
arviz/tests/base_tests/test_labels.py CHANGED
@@ -1,4 +1,5 @@
  """Tests for labeller classes."""
+
  import pytest

  from ...labels import (
arviz/tests/base_tests/test_plot_utils.py CHANGED
@@ -1,5 +1,6 @@
  # pylint: disable=redefined-outer-name
  import importlib
+ import os

  import numpy as np
  import pytest
@@ -20,10 +21,10 @@ from ...rcparams import rc_context
  from ...sel_utils import xarray_sel_iter, xarray_to_ndarray
  from ...stats.density_utils import get_bins
  from ...utils import get_coords
- from ..helpers import running_on_ci

  # Check if Bokeh is installed
  bokeh_installed = importlib.util.find_spec("bokeh") is not None  # pylint: disable=invalid-name
+ skip_tests = (not bokeh_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)


  @pytest.mark.parametrize(
@@ -212,10 +213,7 @@ def test_filter_plotter_list_warning():
      assert len(plotters_filtered) == 5


- @pytest.mark.skipif(
-     not (bokeh_installed or running_on_ci()),
-     reason="test requires bokeh which is not installed",
- )
+ @pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
  def test_bokeh_import():
      """Tests that correct method is returned on bokeh import"""
      plot = get_plotting_function("plot_dist", "distplot", "bokeh")
@@ -290,10 +288,7 @@ def test_mpl_dealiase_sel_kwargs():
      assert res["line_color"] == "red"


- @pytest.mark.skipif(
-     not (bokeh_installed or running_on_ci()),
-     reason="test requires bokeh which is not installed",
- )
+ @pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
  def test_bokeh_dealiase_sel_kwargs():
      """Check bokeh dealiase_sel_kwargs behaviour.

@@ -315,10 +310,7 @@ def test_bokeh_dealiase_sel_kwargs():
      assert res["line_color"] == "red"


- @pytest.mark.skipif(
-     not (bokeh_installed or running_on_ci()),
-     reason="test requires bokeh which is not installed",
- )
+ @pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
  def test_set_bokeh_circular_ticks_labels():
      """Assert the axes returned after placing ticks and tick labels for circular plots."""
      import bokeh.plotting as bkp