arviz 0.20.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. arviz/__init__.py +8 -3
  2. arviz/data/base.py +2 -2
  3. arviz/data/inference_data.py +57 -26
  4. arviz/data/io_datatree.py +2 -2
  5. arviz/data/io_numpyro.py +112 -4
  6. arviz/plots/autocorrplot.py +12 -2
  7. arviz/plots/backends/__init__.py +8 -7
  8. arviz/plots/backends/bokeh/bpvplot.py +4 -3
  9. arviz/plots/backends/bokeh/densityplot.py +5 -1
  10. arviz/plots/backends/bokeh/dotplot.py +5 -2
  11. arviz/plots/backends/bokeh/essplot.py +4 -2
  12. arviz/plots/backends/bokeh/forestplot.py +11 -4
  13. arviz/plots/backends/bokeh/hdiplot.py +7 -6
  14. arviz/plots/backends/bokeh/khatplot.py +4 -2
  15. arviz/plots/backends/bokeh/lmplot.py +28 -6
  16. arviz/plots/backends/bokeh/mcseplot.py +2 -2
  17. arviz/plots/backends/bokeh/pairplot.py +27 -52
  18. arviz/plots/backends/bokeh/ppcplot.py +2 -1
  19. arviz/plots/backends/bokeh/rankplot.py +2 -1
  20. arviz/plots/backends/bokeh/traceplot.py +2 -1
  21. arviz/plots/backends/bokeh/violinplot.py +2 -1
  22. arviz/plots/backends/matplotlib/bpvplot.py +2 -1
  23. arviz/plots/backends/matplotlib/khatplot.py +8 -1
  24. arviz/plots/backends/matplotlib/lmplot.py +13 -7
  25. arviz/plots/backends/matplotlib/pairplot.py +14 -22
  26. arviz/plots/bfplot.py +9 -26
  27. arviz/plots/bpvplot.py +10 -1
  28. arviz/plots/hdiplot.py +5 -0
  29. arviz/plots/kdeplot.py +4 -4
  30. arviz/plots/lmplot.py +41 -14
  31. arviz/plots/pairplot.py +10 -3
  32. arviz/plots/plot_utils.py +5 -3
  33. arviz/preview.py +36 -5
  34. arviz/stats/__init__.py +1 -0
  35. arviz/stats/density_utils.py +1 -1
  36. arviz/stats/diagnostics.py +18 -14
  37. arviz/stats/stats.py +105 -7
  38. arviz/tests/base_tests/test_data.py +31 -11
  39. arviz/tests/base_tests/test_diagnostics.py +5 -4
  40. arviz/tests/base_tests/test_plots_bokeh.py +60 -2
  41. arviz/tests/base_tests/test_plots_matplotlib.py +103 -11
  42. arviz/tests/base_tests/test_stats.py +53 -1
  43. arviz/tests/external_tests/test_data_numpyro.py +130 -3
  44. arviz/utils.py +4 -0
  45. arviz/wrappers/base.py +1 -1
  46. arviz/wrappers/wrap_stan.py +1 -1
  47. {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/METADATA +7 -7
  48. {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/RECORD +51 -51
  49. {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/WHEEL +1 -1
  50. {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/LICENSE +0 -0
  51. {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/top_level.txt +0 -0
arviz/plots/lmplot.py CHANGED
@@ -300,20 +300,47 @@ def plot_lm(
  # Filter out the required values to generate plotters
  if y_model is not None:
  if kind_model == "lines":
- y_model = y_model.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]
-
- y_model = [
- tup
- for _, tup in zip(
- range(len_y),
- xarray_var_iter(
- y_model,
- skip_dims=set(y_model.dims),
- combined=True,
- ),
- )
- ]
- y_model = _repeat_flatten_list(y_model, len_x)
+ var_name = y_model.name if y_model.name else "y_model"
+ data = y_model.values
+
+ total_samples = data.shape[0] * data.shape[1]
+ data = data.reshape(total_samples, *data.shape[2:])
+
+ if pp_sample_ix is not None:
+ data = data[pp_sample_ix]
+
+ if plot_dim is not None:
+ # For plot_dim case, transpose to get dimension first
+ data = data.transpose(1, 0, 2)[..., 0]
+
+ # Create plotter tuple(s)
+ if plot_dim is not None:
+ y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
+ else:
+ y_model = [(var_name, {}, {}, data)]
+ y_model = _repeat_flatten_list(y_model, len_x)
+
+ elif kind_model == "hdi":
+ var_name = y_model.name if y_model.name else "y_model"
+ data = y_model.values
+
+ if plot_dim is not None:
+ # First transpose to get plot_dim first
+ data = data.transpose(2, 0, 1, 3)
+ # For plot_dim case, we just want HDI for first dimension
+ data = data[..., 0]
+
+ # Reshape to (samples, points)
+ data = data.transpose(1, 2, 0).reshape(-1, data.shape[0])
+ y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
+
+ else:
+ data = data.reshape(-1, data.shape[-1])
+ y_model = [(var_name, {}, {}, data)]
+ y_model = _repeat_flatten_list(y_model, len_x)
+
+ if len(y_model) == 1:
+ y_model = _repeat_flatten_list(y_model, len_x)

  rows, cols = default_grid(length_plotters)
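For orientation, this is the call pattern the new branches serve; a minimal sketch mirroring the plot_lm tests added further down in this diff (toy data, bokeh backend):

import numpy as np
import arviz as az

rng = np.random.default_rng(0)
# Toy 2D observations (5 x 7) with matching posterior and posterior predictive draws.
idata = az.from_dict(
    observed_data={"y": rng.normal(size=(5, 7))},
    posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
    posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
    dims={"y": ["dim1", "dim2"]},
    coords={"dim1": range(5), "dim2": range(7)},
)
# kind_model="lines" draws num_samples posterior draws of y_model;
# kind_model="hdi" shades an HDI band instead.
az.plot_lm(
    idata=idata, y="y", y_model="y_model", plot_dim="dim1",
    kind_model="lines", num_samples=50, backend="bokeh", show=False,
)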
arviz/plots/pairplot.py CHANGED
@@ -196,9 +196,14 @@ def plot_pair(
  get_coords(dataset, coords), var_names=var_names, skip_dims=combine_dims, combined=True
  )
  )
- flat_var_names = [
- labeller.make_label_vert(var_name, sel, isel) for var_name, sel, isel, _ in plotters
- ]
+ flat_var_names = []
+ flat_ref_slices = []
+ flat_var_labels = []
+ for var_name, sel, isel, _ in plotters:
+ dims = [dim for dim in dataset[var_name].dims if dim not in ["chain", "draw"]]
+ flat_var_names.append(var_name)
+ flat_ref_slices.append(tuple(isel[dim] if dim in isel else slice(None) for dim in dims))
+ flat_var_labels.append(labeller.make_label_vert(var_name, sel, isel))

  divergent_data = None
  diverging_mask = None
@@ -253,6 +258,8 @@ def plot_pair(
  diverging_mask=diverging_mask,
  divergences_kwargs=divergences_kwargs,
  flat_var_names=flat_var_names,
+ flat_ref_slices=flat_ref_slices,
+ flat_var_labels=flat_var_labels,
  backend_kwargs=backend_kwargs,
  marginal_kwargs=marginal_kwargs,
  show=show,
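The per-variable index slices collected above feed the reference-value handling exercised by the new bokeh tests further down: reference_values can map a vector-valued variable to one value per flattened element. A small sketch of that call (centered_eight is a bundled example dataset):

import numpy as np
import arviz as az

idata = az.load_arviz_data("centered_eight")
az.plot_pair(
    idata,
    var_names=["theta"],
    reference_values={"theta": np.zeros(8)},  # one reference value per school
    backend="bokeh",
    show=False,
)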
arviz/plots/plot_utils.py CHANGED
@@ -482,16 +482,18 @@ def plot_point_interval(
  if point_estimate:
  point_value = calculate_point_estimate(point_estimate, values)
  if rotated:
- ax.circle(
+ ax.scatter(
  x=0,
  y=point_value,
+ marker="circle",
  size=markersize,
  fill_color=markercolor,
  )
  else:
- ax.circle(
+ ax.scatter(
  x=point_value,
  y=0,
+ marker="circle",
  size=markersize,
  fill_color=markercolor,
  )
@@ -534,7 +536,7 @@ def set_bokeh_circular_ticks_labels(ax, hist, labels):
  )

  radii_circles = np.linspace(0, np.max(hist) * 1.1, 4)
- ax.circle(0, 0, radius=radii_circles, fill_color=None, line_color="grey")
+ ax.scatter(0, 0, marker="circle", radius=radii_circles, fill_color=None, line_color="grey")

  offset = np.max(hist * 1.05) * 0.15
  ticks_labels_pos_1 = np.max(hist * 1.05)
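The ax.circle calls are swapped for ax.scatter with an explicit marker; in isolation the two Bokeh call styles look like this (a sketch against the bokeh.plotting API, not arviz code):

from bokeh.plotting import figure

fig = figure()
# 0.20.0 style:
# fig.circle(x=0.5, y=0.5, size=10, fill_color="navy")
# 0.22.0 style, as used above:
fig.scatter(x=0.5, y=0.5, marker="circle", size=10, fill_color="navy")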
arviz/preview.py CHANGED
@@ -1,17 +1,48 @@
- # pylint: disable=unused-import,unused-wildcard-import,wildcard-import
+ # pylint: disable=unused-import,unused-wildcard-import,wildcard-import,invalid-name
  """Expose features from arviz-xyz refactored packages inside ``arviz.preview`` namespace."""
+ import logging
+
+ _log = logging.getLogger(__name__)
+
+ info = ""

  try:
  from arviz_base import *
+
+ status = "arviz_base available, exposing its functions as part of arviz.preview"
+ _log.info(status)
  except ModuleNotFoundError:
- pass
+ status = "arviz_base not installed"
+ _log.info(status)
+ except ImportError:
+ status = "Unable to import arviz_base"
+ _log.info(status, exc_info=True)
+
+ info += status + "\n"

  try:
- import arviz_stats
+ from arviz_stats import *
+
+ status = "arviz_stats available, exposing its functions as part of arviz.preview"
+ _log.info(status)
  except ModuleNotFoundError:
- pass
+ status = "arviz_stats not installed"
+ _log.info(status)
+ except ImportError:
+ status = "Unable to import arviz_stats"
+ _log.info(status, exc_info=True)
+ info += status + "\n"

  try:
  from arviz_plots import *
+
+ status = "arviz_plots available, exposing its functions as part of arviz.preview"
+ _log.info(status)
  except ModuleNotFoundError:
- pass
+ status = "arviz_plots not installed"
+ _log.info(status)
+ except ImportError:
+ status = "Unable to import arviz_plots"
+ _log.info(status, exc_info=True)
+
+ info += status + "\n"
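A quick way to see what arviz.preview picked up at import time (a sketch; the optional arviz_base, arviz_stats and arviz_plots packages may or may not be installed):

import logging

logging.basicConfig(level=logging.INFO)  # surface the _log.info() status messages

from arviz import preview

# `info` accumulates one status line per optional package, e.g.
# "arviz_base not installed" or "arviz_stats available, ...".
print(preview.info)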
arviz/stats/__init__.py CHANGED
@@ -9,6 +9,7 @@ from .stats_utils import *

  __all__ = [
  "apply_test_function",
+ "bayes_factor",
  "bfmi",
  "compare",
  "hdi",
arviz/stats/density_utils.py CHANGED
@@ -635,7 +635,7 @@ def _kde_circular(
  cumulative: bool, optional
  Whether return the PDF or the cumulative PDF. Defaults to False.
  grid_len: int, optional
- The number of intervals used to bin the data pointa i.e. the length of the grid used in the
+ The number of intervals used to bin the data point i.e. the length of the grid used in the
  estimation. Defaults to 512.
  """
  # All values between -pi and pi
arviz/stats/diagnostics.py CHANGED
@@ -744,8 +744,8 @@ def _ess_sd(ary, relative=False):
  ary = np.asarray(ary)
  if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
  return np.nan
- ary = _split_chains(ary)
- return min(_ess(ary, relative=relative), _ess(ary**2, relative=relative))
+ ary = (ary - ary.mean()) ** 2
+ return _ess(_split_chains(ary), relative=relative)


  def _ess_quantile(ary, prob, relative=False):
@@ -838,13 +838,15 @@ def _mcse_sd(ary):
  ary = np.asarray(ary)
  if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
  return np.nan
- ess = _ess_sd(ary)
+ sims_c2 = (ary - ary.mean()) ** 2
+ ess = _ess_mean(sims_c2)
+ evar = (sims_c2).mean()
+ varvar = ((sims_c2**2).mean() - evar**2) / ess
+ varsd = varvar / evar / 4
  if _numba_flag:
- sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)).item())
+ mcse_sd_value = float(_sqrt(np.ravel(varsd), np.zeros(1)))
  else:
- sd = np.std(ary, ddof=1)
- fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
- mcse_sd_value = sd * fac_mcse_sd
+ mcse_sd_value = np.sqrt(varsd)
  return mcse_sd_value


@@ -973,19 +975,21 @@ def _multichain_statistics(ary, focus="mean"):
  # ess mean
  ess_mean_value = _ess_mean(ary)

- # ess sd
- ess_sd_value = _ess_sd(ary)
-
  # mcse_mean
- sd = np.std(ary, ddof=1)
- mcse_mean_value = sd / np.sqrt(ess_mean_value)
+ sims_c2 = (ary - ary.mean()) ** 2
+ sims_c2_sum = sims_c2.sum()
+ var = sims_c2_sum / (sims_c2.size - 1)
+ mcse_mean_value = np.sqrt(var / ess_mean_value)

  # ess bulk
  ess_bulk_value = _ess(z_split)

  # mcse_sd
- fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess_sd_value) ** (ess_sd_value - 1) - 1)
- mcse_sd_value = sd * fac_mcse_sd
+ evar = sims_c2_sum / sims_c2.size
+ ess_mean_sims = _ess_mean(sims_c2)
+ varvar = ((sims_c2**2).mean() - evar**2) / ess_mean_sims
+ varsd = varvar / evar / 4
+ mcse_sd_value = np.sqrt(varsd)

  return (
  mcse_mean_value,
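The MCSE-of-SD computation changes from the old exp-based correction factor to estimating the Monte Carlo variance of the sample variance and propagating it to the standard deviation via a delta-method step (Var[s] is roughly Var[s^2] / (4 Var)). A standalone sketch of the arithmetic, using the public az.ess in place of the private _ess_mean helper, so numbers may differ slightly from the library internals:

import numpy as np
import arviz as az

def mcse_sd_sketch(ary):
    """Rough reimplementation of the new MCSE-of-SD formula; ary has shape (chain, draw)."""
    sims_c2 = (ary - ary.mean()) ** 2                 # centered squared deviations
    ess = az.ess(sims_c2, method="mean")              # ESS of those squared deviations
    evar = sims_c2.mean()                             # plug-in variance estimate
    varvar = ((sims_c2**2).mean() - evar**2) / ess    # MC variance of the variance
    varsd = varvar / evar / 4                         # delta method: Var[s] ~ Var[s^2] / (4 s^2)
    return np.sqrt(varsd)

rng = np.random.default_rng(0)
print(mcse_sd_sketch(rng.normal(size=(4, 1000))))     # toy draws: 4 chains x 1000 draws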
arviz/stats/stats.py CHANGED
@@ -27,6 +27,7 @@ from ..utils import Numba, _numba_var, _var_names, get_coords
  from .density_utils import get_bins as _get_bins
  from .density_utils import histogram as _histogram
  from .density_utils import kde as _kde
+ from .density_utils import _kde_linear
  from .diagnostics import _mc_error, _multichain_statistics, ess
  from .stats_utils import ELPDData, _circular_standard_deviation, smooth_data
  from .stats_utils import get_log_likelihood as _get_log_likelihood
@@ -41,6 +42,7 @@ from ..labels import BaseLabeller

  __all__ = [
  "apply_test_function",
+ "bayes_factor",
  "compare",
  "hdi",
  "loo",
@@ -867,7 +869,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
  )


- def psislw(log_weights, reff=1.0):
+ def psislw(log_weights, reff=1.0, normalize=True):
  """
  Pareto smoothed importance sampling (PSIS).

@@ -885,11 +887,13 @@ def psislw(log_weights, reff=1.0):
  Array of size (n_observations, n_samples)
  reff : float, default 1
  relative MCMC efficiency, ``ess / n``
+ normalize : bool, default True
+ return normalized log weights

  Returns
  -------
  lw_out : DataArray or (..., N) ndarray
- Smoothed, truncated and normalized log weights.
+ Smoothed, truncated and possibly normalized log weights.
  kss : DataArray or (...) ndarray
  Estimates of the shape parameter *k* of the generalized Pareto
  distribution.
@@ -934,7 +938,12 @@ def psislw(log_weights, reff=1.0):
  out = np.empty_like(log_weights), np.empty(shape)

  # define kwargs
- func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out}
+ func_kwargs = {
+ "cutoff_ind": cutoff_ind,
+ "cutoffmin": cutoffmin,
+ "out": out,
+ "normalize": normalize,
+ }
  ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
  kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
  log_weights, pareto_shape = _wrap_xarray_ufunc(
@@ -951,7 +960,7 @@ def psislw(log_weights, reff=1.0):
  return log_weights, pareto_shape


- def _psislw(log_weights, cutoff_ind, cutoffmin):
+ def _psislw(log_weights, cutoff_ind, cutoffmin, normalize):
  """
  Pareto smoothed importance sampling (PSIS) for a 1D vector.

@@ -961,7 +970,7 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
  Array of length n_observations
  cutoff_ind: int
  cutoffmin: float
- k_min: float
+ normalize: bool

  Returns
  -------
@@ -973,7 +982,8 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
  x = np.asarray(log_weights)

  # improve numerical accuracy
- x -= np.max(x)
+ max_x = np.max(x)
+ x -= max_x
  # sort the array
  x_sort_ind = np.argsort(x)
  # divide log weights into body and right tail
@@ -1005,8 +1015,12 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
  x[tailinds[x_tail_si]] = smoothed_tail
  # truncate smoothed values to the largest raw weight 0
  x[x > 0] = 0
+
  # renormalize weights
- x -= _logsumexp(x)
+ if normalize:
+ x -= _logsumexp(x)
+ else:
+ x += max_x

  return x, k
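psislw gains a normalize flag: by default the smoothed log weights are still normalized by their log-sum-exp, while normalize=False keeps them on the original scale (the subtracted maximum is added back). A minimal sketch with made-up log weights:

import numpy as np
import arviz as az

rng = np.random.default_rng(0)
log_weights = rng.normal(size=(8, 4000))        # (n_observations, n_samples), toy values

lw_norm, khat = az.psislw(log_weights, reff=0.7)                  # default: normalized
lw_raw, khat = az.psislw(log_weights, reff=0.7, normalize=False)  # new: unnormalized

# normalized weights sum to one within each observation
np.testing.assert_allclose(np.exp(lw_norm).sum(axis=1), 1.0)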
@@ -2337,3 +2351,87 @@ def _cjs_dist(draws, weights):
  bound = cdf_p_int + cdf_q_int

  return np.sqrt((cjs_pq + cjs_qp) / bound)
+
+
+ def bayes_factor(idata, var_name, ref_val=0, prior=None, return_ref_vals=False):
+ r"""Approximated Bayes Factor for comparing hypothesis of two nested models.
+
+ The Bayes factor is estimated by comparing a model (H1) against a model in which the
+ parameter of interest has been restricted to be a point-null (H0). This computation
+ assumes the models are nested and thus H0 is a special case of H1.
+
+ Notes
+ -----
+ The bayes Factor is approximated as the Savage-Dickey density ratio
+ algorithm presented in [1]_.
+
+ Parameters
+ ----------
+ idata : InferenceData
+ Any object that can be converted to an :class:`arviz.InferenceData` object
+ Refer to documentation of :func:`arviz.convert_to_dataset` for details.
+ var_name : str, optional
+ Name of variable we want to test.
+ ref_val : int, default 0
+ Point-null for Bayes factor estimation.
+ prior : numpy.array, optional
+ In case we want to use different prior, for example for sensitivity analysis.
+ return_ref_vals : bool, optional
+ Whether to return the values of the prior and posterior at the reference value.
+ Used by :func:`arviz.plot_bf` to display the distribution comparison.
+
+
+ Returns
+ -------
+ dict : A dictionary with BF10 (Bayes Factor 10 (H1/H0 ratio), and BF01 (H0/H1 ratio).
+
+ References
+ ----------
+ .. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio:
+ The case of computing Bayes factors for regression parameters.
+
+ Examples
+ --------
+ Moderate evidence indicating that the parameter "a" is different from zero.
+
+ .. ipython::
+
+ In [1]: import numpy as np
+ ...: import arviz as az
+ ...: idata = az.from_dict(posterior={"a":np.random.normal(1, 0.5, 5000)},
+ ...: prior={"a":np.random.normal(0, 1, 5000)})
+ ...: az.bayes_factor(idata, var_name="a", ref_val=0)
+
+ """
+
+ posterior = extract(idata, var_names=var_name).values
+
+ if ref_val > posterior.max() or ref_val < posterior.min():
+ _log.warning(
+ "The reference value is outside of the posterior. "
+ "This translate into infinite support for H1, which is most likely an overstatement."
+ )
+
+ if posterior.ndim > 1:
+ _log.warning("Posterior distribution has {posterior.ndim} dimensions")
+
+ if prior is None:
+ prior = extract(idata, var_names=var_name, group="prior").values
+
+ if posterior.dtype.kind == "f":
+ posterior_grid, posterior_pdf, *_ = _kde_linear(posterior)
+ prior_grid, prior_pdf, *_ = _kde_linear(prior)
+ posterior_at_ref_val = np.interp(ref_val, posterior_grid, posterior_pdf)
+ prior_at_ref_val = np.interp(ref_val, prior_grid, prior_pdf)
+
+ elif posterior.dtype.kind == "i":
+ posterior_at_ref_val = (posterior == ref_val).mean()
+ prior_at_ref_val = (prior == ref_val).mean()
+
+ bf_10 = prior_at_ref_val / posterior_at_ref_val
+ bf = {"BF10": bf_10, "BF01": 1 / bf_10}
+
+ if return_ref_vals:
+ return (bf, {"prior": prior_at_ref_val, "posterior": posterior_at_ref_val})
+ else:
+ return bf
arviz/tests/base_tests/test_data.py CHANGED
@@ -3,6 +3,7 @@

  import importlib
  import os
+ import warnings
  from collections import namedtuple
  from copy import deepcopy
  from html import escape
@@ -32,7 +33,13 @@ from ... import (
  extract,
  )

- from ...data.base import dict_to_dataset, generate_dims_coords, infer_stan_dtypes, make_attrs
+ from ...data.base import (
+ dict_to_dataset,
+ generate_dims_coords,
+ infer_stan_dtypes,
+ make_attrs,
+ numpy_to_data_array,
+ )
  from ...data.datasets import LOCAL_DATASETS, REMOTE_DATASETS, RemoteFileMetadata
  from ..helpers import ( # pylint: disable=unused-import
  chains,
@@ -231,6 +238,17 @@ def test_dims_coords_skip_event_dims(shape):
  assert "z" not in coords


+ @pytest.mark.parametrize("dims", [None, ["chain", "draw"], ["chain", "draw", None]])
+ def test_numpy_to_data_array_with_dims(dims):
+ da = numpy_to_data_array(
+ np.empty((4, 500, 7)),
+ var_name="a",
+ dims=dims,
+ default_dims=["chain", "draw"],
+ )
+ assert list(da.dims) == ["chain", "draw", "a_dim_0"]
+
+
  def test_make_attrs():
  extra_attrs = {"key": "Value"}
  attrs = make_attrs(attrs=extra_attrs)
@@ -921,7 +939,7 @@ class TestInferenceData: # pylint: disable=too-many-public-methods
  data = np.random.normal(size=(4, 500, 8))
  idata = data_random
  with pytest.warns(UserWarning, match="The group.+not defined in the InferenceData scheme"):
- idata.add_groups({"new_group": idata.posterior})
+ idata.add_groups({"new_group": idata.posterior}, warn_on_custom_groups=True)
  with pytest.warns(UserWarning, match="the default dims.+will be added automatically"):
  idata.add_groups(constant_data={"a": data[..., 0], "b": data})
  assert idata.new_group.equals(idata.posterior)
@@ -962,8 +980,8 @@ class TestInferenceData: # pylint: disable=too-many-public-methods
  with pytest.raises(ValueError, match="join must be either"):
  idata.extend(idata2, join="outer")
  idata2.add_groups(new_group=idata2.prior)
- with pytest.warns(UserWarning):
- idata.extend(idata2)
+ with pytest.warns(UserWarning, match="new_group"):
+ idata.extend(idata2, warn_on_custom_groups=True)


  class TestNumpyToDataArray:
@@ -1156,11 +1174,17 @@ def test_bad_inference_data():
  InferenceData(posterior=[1, 2, 3])


- def test_inference_data_other_groups():
+ @pytest.mark.parametrize("warn", [True, False])
+ def test_inference_data_other_groups(warn):
  datadict = {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}
  dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={"b": ["c"]})
- with pytest.warns(UserWarning, match="not.+in.+InferenceData scheme"):
- idata = InferenceData(other_group=dataset)
+ if warn:
+ with pytest.warns(UserWarning, match="not.+in.+InferenceData scheme"):
+ idata = InferenceData(other_group=dataset, warn_on_custom_groups=True)
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ idata = InferenceData(other_group=dataset, warn_on_custom_groups=False)
  fails = check_multiple_attrs({"other_group": ["a", "b"]}, idata)
  assert not fails

@@ -1477,10 +1501,6 @@ class TestJSON:
  assert not os.path.exists(filepath)


- @pytest.mark.skipif(
- not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
- reason="test requires xarray-datatree library",
- )
  class TestDataTree:
  def test_datatree(self):
  idata = load_arviz_data("centered_eight")
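These test updates reflect the new warn_on_custom_groups flag on InferenceData, add_groups and extend: warnings about non-schema group names are emitted only when the flag is set. A small sketch of the behaviour the tests pin down (toy data):

import numpy as np
import arviz as az

idata = az.from_dict(posterior={"a": np.random.normal(size=(4, 500))})

# Accepted silently when the flag is off ...
idata.add_groups({"my_group": idata.posterior}, warn_on_custom_groups=False)

# ... and raises the familiar UserWarning about the InferenceData scheme when it is on.
idata2 = az.from_dict(posterior={"a": np.random.normal(size=(4, 500))})
idata2.add_groups({"my_group": idata2.posterior}, warn_on_custom_groups=True)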
arviz/tests/base_tests/test_diagnostics.py CHANGED
@@ -120,10 +120,11 @@ class TestDiagnostics:
  ```
  Reference file:

- Created: 2020-08-31
- System: Ubuntu 18.04.5 LTS
- R version 4.0.2 (2020-06-22)
- posterior 0.1.2
+ Created: 2024-12-20
+ System: Ubuntu 24.04.1 LTS
+ R version 4.4.2 (2024-10-31)
+ posterior version from https://github.com/stan-dev/posterior/pull/388
+ (after release 1.6.0 but before the fixes in the PR were released).
  """
  # download input files
  here = os.path.dirname(os.path.abspath(__file__))
arviz/tests/base_tests/test_plots_bokeh.py CHANGED
@@ -8,6 +8,7 @@ from pandas import DataFrame # pylint: disable=wrong-import-position
  from scipy.stats import norm # pylint: disable=wrong-import-position

  from ...data import from_dict, load_arviz_data # pylint: disable=wrong-import-position
+ from ...labels import MapLabeller # pylint: disable=wrong-import-position
  from ...plots import ( # pylint: disable=wrong-import-position
  plot_autocorr,
  plot_bpv,
@@ -773,7 +774,6 @@ def test_plot_mcse_no_divergences(models):
  {"divergences": True, "var_names": ["theta", "mu"]},
  {"kind": "kde", "var_names": ["theta"]},
  {"kind": "hexbin", "var_names": ["theta"]},
- {"kind": "hexbin", "var_names": ["theta"]},
  {
  "kind": "hexbin",
  "var_names": ["theta"],
@@ -785,6 +785,21 @@ def test_plot_mcse_no_divergences(models):
  "reference_values": {"mu": 0, "tau": 0},
  "reference_values_kwargs": {"line_color": "blue"},
  },
+ {
+ "var_names": ["mu", "tau"],
+ "reference_values": {"mu": 0, "tau": 0},
+ "labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
+ },
+ {
+ "var_names": ["theta"],
+ "reference_values": {"theta": [0.0] * 8},
+ "labeller": MapLabeller({"theta": r"$\theta$"}),
+ },
+ {
+ "var_names": ["theta"],
+ "reference_values": {"theta": np.zeros(8)},
+ "labeller": MapLabeller({"theta": r"$\theta$"}),
+ },
  ],
  )
  def test_plot_pair(models, kwargs):
@@ -1201,7 +1216,7 @@ def test_plot_dot_rotated(continuous_model, kwargs):
  },
  ],
  )
- def test_plot_lm(models, kwargs):
+ def test_plot_lm_1d(models, kwargs):
  """Test functionality for 1D data."""
  idata = models.model_1
  if "constant_data" not in idata.groups():
@@ -1228,3 +1243,46 @@ def test_plot_lm_list():
  """Test the plots when input data is list or ndarray."""
  y = [1, 2, 3, 4, 5]
  assert plot_lm(y=y, x=np.arange(len(y)), show=False, backend="bokeh")
+
+
+ def generate_lm_1d_data():
+ rng = np.random.default_rng()
+ return from_dict(
+ observed_data={"y": rng.normal(size=7)},
+ posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
+ posterior={"y_model": rng.normal(size=(4, 1000, 7))},
+ dims={"y": ["dim1"]},
+ coords={"dim1": range(7)},
+ )
+
+
+ def generate_lm_2d_data():
+ rng = np.random.default_rng()
+ return from_dict(
+ observed_data={"y": rng.normal(size=(5, 7))},
+ posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
+ posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
+ dims={"y": ["dim1", "dim2"]},
+ coords={"dim1": range(5), "dim2": range(7)},
+ )
+
+
+ @pytest.mark.parametrize("data", ("1d", "2d"))
+ @pytest.mark.parametrize("kind", ("lines", "hdi"))
+ @pytest.mark.parametrize("use_y_model", (True, False))
+ def test_plot_lm(data, kind, use_y_model):
+ if data == "1d":
+ idata = generate_lm_1d_data()
+ else:
+ idata = generate_lm_2d_data()
+
+ kwargs = {"idata": idata, "y": "y", "kind_model": kind, "backend": "bokeh", "show": False}
+ if data == "2d":
+ kwargs["plot_dim"] = "dim1"
+ if use_y_model:
+ kwargs["y_model"] = "y_model"
+ if kind == "lines":
+ kwargs["num_samples"] = 50
+
+ ax = plot_lm(**kwargs)
+ assert ax is not None