arviz 0.16.1__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. arviz/__init__.py +1 -1
  2. arviz/data/inference_data.py +34 -7
  3. arviz/data/io_beanmachine.py +6 -1
  4. arviz/data/io_cmdstanpy.py +439 -50
  5. arviz/data/io_pyjags.py +5 -2
  6. arviz/data/io_pystan.py +1 -2
  7. arviz/labels.py +2 -0
  8. arviz/plots/backends/bokeh/bpvplot.py +7 -2
  9. arviz/plots/backends/bokeh/compareplot.py +7 -4
  10. arviz/plots/backends/bokeh/densityplot.py +0 -1
  11. arviz/plots/backends/bokeh/distplot.py +0 -2
  12. arviz/plots/backends/bokeh/forestplot.py +3 -5
  13. arviz/plots/backends/bokeh/kdeplot.py +0 -2
  14. arviz/plots/backends/bokeh/pairplot.py +0 -4
  15. arviz/plots/backends/matplotlib/bfplot.py +0 -1
  16. arviz/plots/backends/matplotlib/bpvplot.py +3 -3
  17. arviz/plots/backends/matplotlib/compareplot.py +1 -1
  18. arviz/plots/backends/matplotlib/dotplot.py +1 -1
  19. arviz/plots/backends/matplotlib/forestplot.py +2 -4
  20. arviz/plots/backends/matplotlib/kdeplot.py +0 -1
  21. arviz/plots/backends/matplotlib/khatplot.py +0 -1
  22. arviz/plots/backends/matplotlib/lmplot.py +4 -5
  23. arviz/plots/backends/matplotlib/pairplot.py +0 -1
  24. arviz/plots/backends/matplotlib/ppcplot.py +8 -5
  25. arviz/plots/backends/matplotlib/traceplot.py +1 -2
  26. arviz/plots/bfplot.py +7 -6
  27. arviz/plots/bpvplot.py +7 -2
  28. arviz/plots/compareplot.py +2 -2
  29. arviz/plots/ecdfplot.py +37 -112
  30. arviz/plots/elpdplot.py +1 -1
  31. arviz/plots/essplot.py +2 -2
  32. arviz/plots/kdeplot.py +0 -1
  33. arviz/plots/pairplot.py +1 -1
  34. arviz/plots/plot_utils.py +0 -1
  35. arviz/plots/ppcplot.py +51 -45
  36. arviz/plots/separationplot.py +0 -1
  37. arviz/stats/__init__.py +2 -0
  38. arviz/stats/density_utils.py +2 -2
  39. arviz/stats/diagnostics.py +2 -3
  40. arviz/stats/ecdf_utils.py +165 -0
  41. arviz/stats/stats.py +241 -38
  42. arviz/stats/stats_utils.py +36 -7
  43. arviz/tests/base_tests/test_data.py +73 -5
  44. arviz/tests/base_tests/test_plots_bokeh.py +0 -1
  45. arviz/tests/base_tests/test_plots_matplotlib.py +24 -1
  46. arviz/tests/base_tests/test_stats.py +43 -1
  47. arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
  48. arviz/tests/base_tests/test_stats_utils.py +3 -3
  49. arviz/tests/external_tests/test_data_beanmachine.py +2 -0
  50. arviz/tests/external_tests/test_data_numpyro.py +3 -3
  51. arviz/tests/external_tests/test_data_pyjags.py +3 -1
  52. arviz/tests/external_tests/test_data_pyro.py +3 -3
  53. arviz/tests/helpers.py +8 -8
  54. arviz/utils.py +15 -7
  55. arviz/wrappers/wrap_pymc.py +1 -1
  56. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/METADATA +16 -15
  57. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/RECORD +60 -58
  58. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/WHEEL +1 -1
  59. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/LICENSE +0 -0
  60. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/top_level.txt +0 -0
arviz/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # pylint: disable=wildcard-import,invalid-name,wrong-import-position
  """ArviZ is a library for exploratory analysis of Bayesian models."""
- __version__ = "0.16.1"
+ __version__ = "0.17.1"

  import logging
  import os
arviz/data/inference_data.py CHANGED
@@ -1,5 +1,6 @@
  # pylint: disable=too-many-lines,too-many-public-methods
  """Data structure for using netcdf groups with xarray."""
+ import os
  import re
  import sys
  import uuid
@@ -23,14 +24,13 @@ from typing import (
  Union,
  overload,
  )
- import os

  import numpy as np
  import xarray as xr
  from packaging import version

  from ..rcparams import rcParams
- from ..utils import HtmlTemplate, _subset_list, either_dict_or_kwargs
+ from ..utils import HtmlTemplate, _subset_list, _var_names, either_dict_or_kwargs
  from .base import _extend_xr_method, _make_json_serializable, dict_to_dataset

  if sys.version_info[:2] >= (3, 9):
@@ -56,6 +56,7 @@ SUPPORTED_GROUPS = [
  "posterior_predictive",
  "predictions",
  "log_likelihood",
+ "log_prior",
  "sample_stats",
  "prior",
  "prior_predictive",
@@ -63,6 +64,8 @@ SUPPORTED_GROUPS = [
  "observed_data",
  "constant_data",
  "predictions_constant_data",
+ "unconstrained_posterior",
+ "unconstrained_prior",
  ]

  WARMUP_TAG = "warmup_"
@@ -73,6 +76,7 @@ SUPPORTED_GROUPS_WARMUP = [
  f"{WARMUP_TAG}predictions",
  f"{WARMUP_TAG}sample_stats",
  f"{WARMUP_TAG}log_likelihood",
+ f"{WARMUP_TAG}log_prior",
  ]

  SUPPORTED_GROUPS_ALL = SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP
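With "log_prior", "unconstrained_posterior", "unconstrained_prior" and the warmup counterpart "warmup_log_prior" added to the supported group names, these groups can be attached to an InferenceData like any other group. A minimal sketch with made-up values (the variable name and shapes are hypothetical):

    import numpy as np
    import arviz as az

    idata = az.from_dict(posterior={"mu": np.random.normal(size=(4, 100))})
    # In 0.17 "log_prior" is a recognized group name, so add_groups stores it
    # as a first-class group alongside the posterior.
    idata.add_groups(log_prior={"mu": np.random.normal(size=(4, 100))})
    print(idata.groups())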
@@ -236,6 +240,10 @@ class InferenceData(Mapping[str, xr.Dataset]):
  self._groups_warmup.remove(group)
  object.__delattr__(self, group)

+ def __delitem__(self, key: str) -> None:
+ """Delete an item from the InferenceData object using del idata[key]."""
+ self.__delattr__(key)
+
  @property
  def _groups_all(self) -> List[str]:
  return self._groups + self._groups_warmup
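The new `__delitem__` simply forwards to the existing `__delattr__`, so a group can now be dropped with dict-style `del idata[...]` as well as `del idata.group`. A minimal sketch:

    import numpy as np
    import arviz as az

    idata = az.from_dict(
        posterior={"mu": np.random.normal(size=(4, 100))},
        prior={"mu": np.random.normal(size=(4, 100))},
    )
    del idata["prior"]            # new in 0.17: dict-style deletion
    assert "prior" not in idata   # the group is gone, "posterior" remains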
@@ -246,8 +254,7 @@ class InferenceData(Mapping[str, xr.Dataset]):

  def __iter__(self) -> Iterator[str]:
  """Iterate over groups in InferenceData object."""
- for group in self._groups_all:
- yield group
+ yield from self._groups_all

  def __contains__(self, key: object) -> bool:
  """Return True if the named item is present, and False otherwise."""
@@ -620,6 +627,8 @@ class InferenceData(Mapping[str, xr.Dataset]):
  self,
  groups=None,
  filter_groups=None,
+ var_names=None,
+ filter_vars=None,
  include_coords=True,
  include_index=True,
  index_origin=None,
@@ -635,6 +644,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
  skipped implicitly.

  Raises TypeError if no valid groups are found.
+ Raises ValueError if no data are selected.

  Parameters
  ----------
@@ -646,6 +656,15 @@ class InferenceData(Mapping[str, xr.Dataset]):
  If "like", interpret groups as substrings of the real group or metagroup names.
  If "regex", interpret groups as regular expressions on the real group or
  metagroup names. A la `pandas.filter`.
+ var_names : str or list of str, optional
+ Variables to be extracted. Prefix the variables by `~` when you want to exclude them.
+ filter_vars: {None, "like", "regex"}, optional
+ If `None` (default), interpret var_names as the real variables names. If "like",
+ interpret var_names as substrings of the real variables names. If "regex",
+ interpret var_names as regular expressions on the real variables names. A la
+ `pandas.filter`.
+ Like with plotting, sometimes it's easier to subset saying what to exclude
+ instead of what to include
  include_coords: bool
  Add coordinate values to column name (tuple).
  include_index: bool
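Together with the signature change above, `to_dataframe` can now subset variables the same way the plotting functions do. A minimal sketch with hypothetical variable names:

    import numpy as np
    import arviz as az

    idata = az.from_dict(
        posterior={
            "mu": np.random.normal(size=(4, 100)),
            "tau": np.random.normal(size=(4, 100)),
        }
    )
    # keep only "mu" from the posterior group ...
    df_mu = idata.to_dataframe(groups="posterior", var_names=["mu"])
    # ... or say what to exclude instead, using the "~" prefix
    df_no_tau = idata.to_dataframe(groups="posterior", var_names=["~tau"])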
@@ -677,6 +696,11 @@ class InferenceData(Mapping[str, xr.Dataset]):
  dfs = {}
  for group in group_names:
  dataset = self[group]
+ group_var_names = _var_names(var_names, dataset, filter_vars, "ignore")
+ if (group_var_names is not None) and not group_var_names:
+ continue
+ if group_var_names is not None:
+ dataset = dataset[[var_name for var_name in group_var_names if var_name in dataset]]
  df = None
  coords_to_idx = {
  name: dict(map(reversed, enumerate(dataset.coords[name].values, index_origin)))
@@ -712,8 +736,11 @@ class InferenceData(Mapping[str, xr.Dataset]):
  df = dataframe
  continue
  df = df.join(dataframe, how="outer")
- df = df.reset_index()
- dfs[group] = df
+ if df is not None:
+ df = df.reset_index()
+ dfs[group] = df
+ if not dfs:
+ raise ValueError("No data selected for the dataframe.")
  if len(dfs) > 1:
  for group, df in dfs.items():
  df.columns = [
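The guard above means a selection that matches nothing now fails early with a clear message instead of producing an empty or malformed dataframe. A sketch of the expected behaviour (assuming the requested name exists in no group):

    import numpy as np
    import arviz as az

    idata = az.from_dict(posterior={"mu": np.random.normal(size=(2, 50))})
    try:
        idata.to_dataframe(var_names=["theta"])  # "theta" exists in no group
    except ValueError as err:
        print(err)  # expected: "No data selected for the dataframe."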
@@ -1466,7 +1493,7 @@ class InferenceData(Mapping[str, xr.Dataset]):

  import numpy as np
  rng = np.random.default_rng(73)
- ary = rng.normal(size=(post.dims["chain"], post.dims["draw"], obs.dims["match"]))
+ ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
  idata.add_groups(
  log_likelihood={"home_points": ary},
  dims={"home_points": ["match"]},
arviz/data/io_beanmachine.py CHANGED
@@ -18,7 +18,7 @@ class BMConverter:
  self.coords = coords
  self.dims = dims

- import beanmachine.ppl as bm
+ import beanmachine.ppl as bm # pylint: disable=import-error
  self.beanm = bm


@@ -99,6 +99,11 @@ def from_beanmachine(
  Map of dimensions to coordinates
  dims : dict of {str : list of str}
  Map variable names to their coordinates
+
+ Warnings
+ --------
+ `beanmachine` is no longer under active development, and therefore, it
+ is not possible to test this converter in ArviZ's CI.
  """
  return BMConverter(
  sampler=sampler,