arviz 0.16.1__py3-none-any.whl → 0.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +1 -1
- arviz/data/inference_data.py +34 -7
- arviz/data/io_beanmachine.py +6 -1
- arviz/data/io_cmdstanpy.py +439 -50
- arviz/data/io_pyjags.py +5 -2
- arviz/data/io_pystan.py +1 -2
- arviz/labels.py +2 -0
- arviz/plots/backends/bokeh/bpvplot.py +7 -2
- arviz/plots/backends/bokeh/compareplot.py +7 -4
- arviz/plots/backends/bokeh/densityplot.py +0 -1
- arviz/plots/backends/bokeh/distplot.py +0 -2
- arviz/plots/backends/bokeh/forestplot.py +3 -5
- arviz/plots/backends/bokeh/kdeplot.py +0 -2
- arviz/plots/backends/bokeh/pairplot.py +0 -4
- arviz/plots/backends/matplotlib/bfplot.py +0 -1
- arviz/plots/backends/matplotlib/bpvplot.py +3 -3
- arviz/plots/backends/matplotlib/compareplot.py +1 -1
- arviz/plots/backends/matplotlib/dotplot.py +1 -1
- arviz/plots/backends/matplotlib/forestplot.py +2 -4
- arviz/plots/backends/matplotlib/kdeplot.py +0 -1
- arviz/plots/backends/matplotlib/khatplot.py +0 -1
- arviz/plots/backends/matplotlib/lmplot.py +4 -5
- arviz/plots/backends/matplotlib/pairplot.py +0 -1
- arviz/plots/backends/matplotlib/ppcplot.py +8 -5
- arviz/plots/backends/matplotlib/traceplot.py +1 -2
- arviz/plots/bfplot.py +7 -6
- arviz/plots/bpvplot.py +7 -2
- arviz/plots/compareplot.py +2 -2
- arviz/plots/ecdfplot.py +37 -112
- arviz/plots/elpdplot.py +1 -1
- arviz/plots/essplot.py +2 -2
- arviz/plots/kdeplot.py +0 -1
- arviz/plots/pairplot.py +1 -1
- arviz/plots/plot_utils.py +0 -1
- arviz/plots/ppcplot.py +51 -45
- arviz/plots/separationplot.py +0 -1
- arviz/stats/__init__.py +2 -0
- arviz/stats/density_utils.py +2 -2
- arviz/stats/diagnostics.py +2 -3
- arviz/stats/ecdf_utils.py +165 -0
- arviz/stats/stats.py +241 -38
- arviz/stats/stats_utils.py +36 -7
- arviz/tests/base_tests/test_data.py +73 -5
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1
- arviz/tests/base_tests/test_plots_matplotlib.py +24 -1
- arviz/tests/base_tests/test_stats.py +43 -1
- arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
- arviz/tests/base_tests/test_stats_utils.py +3 -3
- arviz/tests/external_tests/test_data_beanmachine.py +2 -0
- arviz/tests/external_tests/test_data_numpyro.py +3 -3
- arviz/tests/external_tests/test_data_pyjags.py +3 -1
- arviz/tests/external_tests/test_data_pyro.py +3 -3
- arviz/tests/helpers.py +8 -8
- arviz/utils.py +15 -7
- arviz/wrappers/wrap_pymc.py +1 -1
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/METADATA +16 -15
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/RECORD +60 -58
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/WHEEL +1 -1
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/LICENSE +0 -0
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/top_level.txt +0 -0
arviz/__init__.py
CHANGED
arviz/data/inference_data.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# pylint: disable=too-many-lines,too-many-public-methods
|
|
2
2
|
"""Data structure for using netcdf groups with xarray."""
|
|
3
|
+
import os
|
|
3
4
|
import re
|
|
4
5
|
import sys
|
|
5
6
|
import uuid
|
|
@@ -23,14 +24,13 @@ from typing import (
|
|
|
23
24
|
Union,
|
|
24
25
|
overload,
|
|
25
26
|
)
|
|
26
|
-
import os
|
|
27
27
|
|
|
28
28
|
import numpy as np
|
|
29
29
|
import xarray as xr
|
|
30
30
|
from packaging import version
|
|
31
31
|
|
|
32
32
|
from ..rcparams import rcParams
|
|
33
|
-
from ..utils import HtmlTemplate, _subset_list, either_dict_or_kwargs
|
|
33
|
+
from ..utils import HtmlTemplate, _subset_list, _var_names, either_dict_or_kwargs
|
|
34
34
|
from .base import _extend_xr_method, _make_json_serializable, dict_to_dataset
|
|
35
35
|
|
|
36
36
|
if sys.version_info[:2] >= (3, 9):
|
|
@@ -56,6 +56,7 @@ SUPPORTED_GROUPS = [
|
|
|
56
56
|
"posterior_predictive",
|
|
57
57
|
"predictions",
|
|
58
58
|
"log_likelihood",
|
|
59
|
+
"log_prior",
|
|
59
60
|
"sample_stats",
|
|
60
61
|
"prior",
|
|
61
62
|
"prior_predictive",
|
|
@@ -63,6 +64,8 @@ SUPPORTED_GROUPS = [
|
|
|
63
64
|
"observed_data",
|
|
64
65
|
"constant_data",
|
|
65
66
|
"predictions_constant_data",
|
|
67
|
+
"unconstrained_posterior",
|
|
68
|
+
"unconstrained_prior",
|
|
66
69
|
]
|
|
67
70
|
|
|
68
71
|
WARMUP_TAG = "warmup_"
|
|
@@ -73,6 +76,7 @@ SUPPORTED_GROUPS_WARMUP = [
|
|
|
73
76
|
f"{WARMUP_TAG}predictions",
|
|
74
77
|
f"{WARMUP_TAG}sample_stats",
|
|
75
78
|
f"{WARMUP_TAG}log_likelihood",
|
|
79
|
+
f"{WARMUP_TAG}log_prior",
|
|
76
80
|
]
|
|
77
81
|
|
|
78
82
|
SUPPORTED_GROUPS_ALL = SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP
|
|
@@ -236,6 +240,10 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
236
240
|
self._groups_warmup.remove(group)
|
|
237
241
|
object.__delattr__(self, group)
|
|
238
242
|
|
|
243
|
+
def __delitem__(self, key: str) -> None:
|
|
244
|
+
"""Delete an item from the InferenceData object using del idata[key]."""
|
|
245
|
+
self.__delattr__(key)
|
|
246
|
+
|
|
239
247
|
@property
|
|
240
248
|
def _groups_all(self) -> List[str]:
|
|
241
249
|
return self._groups + self._groups_warmup
|
|
@@ -246,8 +254,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
246
254
|
|
|
247
255
|
def __iter__(self) -> Iterator[str]:
|
|
248
256
|
"""Iterate over groups in InferenceData object."""
|
|
249
|
-
|
|
250
|
-
yield group
|
|
257
|
+
yield from self._groups_all
|
|
251
258
|
|
|
252
259
|
def __contains__(self, key: object) -> bool:
|
|
253
260
|
"""Return True if the named item is present, and False otherwise."""
|
|
@@ -620,6 +627,8 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
620
627
|
self,
|
|
621
628
|
groups=None,
|
|
622
629
|
filter_groups=None,
|
|
630
|
+
var_names=None,
|
|
631
|
+
filter_vars=None,
|
|
623
632
|
include_coords=True,
|
|
624
633
|
include_index=True,
|
|
625
634
|
index_origin=None,
|
|
@@ -635,6 +644,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
635
644
|
skipped implicitly.
|
|
636
645
|
|
|
637
646
|
Raises TypeError if no valid groups are found.
|
|
647
|
+
Raises ValueError if no data are selected.
|
|
638
648
|
|
|
639
649
|
Parameters
|
|
640
650
|
----------
|
|
@@ -646,6 +656,15 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
646
656
|
If "like", interpret groups as substrings of the real group or metagroup names.
|
|
647
657
|
If "regex", interpret groups as regular expressions on the real group or
|
|
648
658
|
metagroup names. A la `pandas.filter`.
|
|
659
|
+
var_names : str or list of str, optional
|
|
660
|
+
Variables to be extracted. Prefix the variables by `~` when you want to exclude them.
|
|
661
|
+
filter_vars: {None, "like", "regex"}, optional
|
|
662
|
+
If `None` (default), interpret var_names as the real variables names. If "like",
|
|
663
|
+
interpret var_names as substrings of the real variables names. If "regex",
|
|
664
|
+
interpret var_names as regular expressions on the real variables names. A la
|
|
665
|
+
`pandas.filter`.
|
|
666
|
+
Like with plotting, sometimes it's easier to subset saying what to exclude
|
|
667
|
+
instead of what to include
|
|
649
668
|
include_coords: bool
|
|
650
669
|
Add coordinate values to column name (tuple).
|
|
651
670
|
include_index: bool
|
|
@@ -677,6 +696,11 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
677
696
|
dfs = {}
|
|
678
697
|
for group in group_names:
|
|
679
698
|
dataset = self[group]
|
|
699
|
+
group_var_names = _var_names(var_names, dataset, filter_vars, "ignore")
|
|
700
|
+
if (group_var_names is not None) and not group_var_names:
|
|
701
|
+
continue
|
|
702
|
+
if group_var_names is not None:
|
|
703
|
+
dataset = dataset[[var_name for var_name in group_var_names if var_name in dataset]]
|
|
680
704
|
df = None
|
|
681
705
|
coords_to_idx = {
|
|
682
706
|
name: dict(map(reversed, enumerate(dataset.coords[name].values, index_origin)))
|
|
@@ -712,8 +736,11 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
712
736
|
df = dataframe
|
|
713
737
|
continue
|
|
714
738
|
df = df.join(dataframe, how="outer")
|
|
715
|
-
df
|
|
716
|
-
|
|
739
|
+
if df is not None:
|
|
740
|
+
df = df.reset_index()
|
|
741
|
+
dfs[group] = df
|
|
742
|
+
if not dfs:
|
|
743
|
+
raise ValueError("No data selected for the dataframe.")
|
|
717
744
|
if len(dfs) > 1:
|
|
718
745
|
for group, df in dfs.items():
|
|
719
746
|
df.columns = [
|
|
@@ -1466,7 +1493,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1466
1493
|
|
|
1467
1494
|
import numpy as np
|
|
1468
1495
|
rng = np.random.default_rng(73)
|
|
1469
|
-
ary = rng.normal(size=(post.
|
|
1496
|
+
ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
|
|
1470
1497
|
idata.add_groups(
|
|
1471
1498
|
log_likelihood={"home_points": ary},
|
|
1472
1499
|
dims={"home_points": ["match"]},
|
arviz/data/io_beanmachine.py
CHANGED
|
@@ -18,7 +18,7 @@ class BMConverter:
|
|
|
18
18
|
self.coords = coords
|
|
19
19
|
self.dims = dims
|
|
20
20
|
|
|
21
|
-
import beanmachine.ppl as bm
|
|
21
|
+
import beanmachine.ppl as bm # pylint: disable=import-error
|
|
22
22
|
|
|
23
23
|
self.beanm = bm
|
|
24
24
|
|
|
@@ -99,6 +99,11 @@ def from_beanmachine(
|
|
|
99
99
|
Map of dimensions to coordinates
|
|
100
100
|
dims : dict of {str : list of str}
|
|
101
101
|
Map variable names to their coordinates
|
|
102
|
+
|
|
103
|
+
Warnings
|
|
104
|
+
--------
|
|
105
|
+
`beanmachine` is no longer under active development, and therefore, it
|
|
106
|
+
is not possible to test this converter in ArviZ's CI.
|
|
102
107
|
"""
|
|
103
108
|
return BMConverter(
|
|
104
109
|
sampler=sampler,
|