arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/data/__init__.py
DELETED

@@ -1,55 +0,0 @@
-"""Code for loading and manipulating data structures."""
-
-from .base import CoordSpec, DimSpec, dict_to_dataset, numpy_to_data_array, pytree_to_dataset
-from .converters import convert_to_dataset, convert_to_inference_data
-from .datasets import clear_data_home, list_datasets, load_arviz_data
-from .inference_data import InferenceData, concat
-from .io_beanmachine import from_beanmachine
-from .io_cmdstan import from_cmdstan
-from .io_cmdstanpy import from_cmdstanpy
-from .io_datatree import from_datatree, to_datatree
-from .io_dict import from_dict, from_pytree
-from .io_emcee import from_emcee
-from .io_json import from_json, to_json
-from .io_netcdf import from_netcdf, to_netcdf
-from .io_numpyro import from_numpyro
-from .io_pyjags import from_pyjags
-from .io_pyro import from_pyro
-from .io_pystan import from_pystan
-from .io_zarr import from_zarr, to_zarr
-from .utils import extract, extract_dataset
-
-__all__ = [
-    "InferenceData",
-    "concat",
-    "load_arviz_data",
-    "list_datasets",
-    "clear_data_home",
-    "numpy_to_data_array",
-    "extract",
-    "extract_dataset",
-    "dict_to_dataset",
-    "convert_to_dataset",
-    "convert_to_inference_data",
-    "from_beanmachine",
-    "from_pyjags",
-    "from_pystan",
-    "from_emcee",
-    "from_cmdstan",
-    "from_cmdstanpy",
-    "from_datatree",
-    "from_dict",
-    "from_pytree",
-    "from_json",
-    "from_pyro",
-    "from_numpyro",
-    "from_netcdf",
-    "pytree_to_dataset",
-    "to_datatree",
-    "to_json",
-    "to_netcdf",
-    "from_zarr",
-    "to_zarr",
-    "CoordSpec",
-    "DimSpec",
-]
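For readers auditing what disappears from the public namespace, here is a minimal sketch of the removed 0.23.x entry points in use. It runs against arviz 0.23.3 as published; the array shapes and file name are illustrative only.

```python
import numpy as np
import arviz as az  # 0.23.3 re-exports the names listed in arviz.data.__all__ above

# Build an InferenceData from a flat dict of (chain, draw, ...) arrays.
idata = az.from_dict(posterior={"mu": np.random.randn(4, 100)})

# Round-trip through NetCDF via the removed io_netcdf helpers.
az.to_netcdf(idata, "idata.nc")
idata = az.from_netcdf("idata.nc")

# Stack chain/draw into a single sample dimension with the removed utils.extract.
mu = az.extract(idata, var_names="mu")
print(mu.sizes)  # e.g. Frozen({'sample': 400})
```

None of these names are re-exported by the 1.0.0rc0 wheel, whose `arviz/__init__.py` shrinks accordingly (+52 -367 above).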
arviz/data/base.py
DELETED

@@ -1,596 +0,0 @@
-"""Low level converters usually used by other functions."""
-
-import datetime
-import functools
-import importlib
-import re
-import warnings
-from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
-
-import numpy as np
-import xarray as xr
-
-try:
-    import tree
-except ImportError:
-    tree = None
-
-try:
-    import ujson as json
-except ImportError:
-    # mypy struggles with conditional imports expressed as catching ImportError:
-    # https://github.com/python/mypy/issues/1153
-    import json  # type: ignore
-
-from .. import __version__, utils
-from ..rcparams import rcParams
-
-CoordSpec = Dict[str, List[Any]]
-DimSpec = Dict[str, List[str]]
-RequiresArgTypeT = TypeVar("RequiresArgTypeT")
-RequiresReturnTypeT = TypeVar("RequiresReturnTypeT")
-
-
-class requires:  # pylint: disable=invalid-name
-    """Decorator to return None if an object does not have the required attribute.
-
-    If the decorator is called various times on the same function with different
-    attributes, it will return None if one of them is missing. If instead a list
-    of attributes is passed, it will return None if all attributes in the list are
-    missing. Both functionalities can be combined as desired.
-    """
-
-    def __init__(self, *props: Union[str, List[str]]) -> None:
-        self.props: Tuple[Union[str, List[str]], ...] = props
-
-    # Until typing.ParamSpec (https://www.python.org/dev/peps/pep-0612/) is available
-    # in all our supported Python versions, there is no way to simultaneously express
-    # the following two properties:
-    # - the input function may take arbitrary args/kwargs, and
-    # - the output function takes those same arbitrary args/kwargs, but has a different return type.
-    # We either have to limit the input function to e.g. only allowing a "self" argument,
-    # or we have to adopt the current approach of annotating the returned function as if
-    # it was defined as "def f(*args: Any, **kwargs: Any) -> Optional[RequiresReturnTypeT]".
-    #
-    # Since all functions decorated with @requires currently only accept a single argument,
-    # we choose to limit application of @requires to only functions of one argument.
-    # When typing.ParamSpec is available, this definition can be updated to use it.
-    # See https://github.com/arviz-devs/arviz/pull/1504 for more discussion.
-    def __call__(
-        self, func: Callable[[RequiresArgTypeT], RequiresReturnTypeT]
-    ) -> Callable[[RequiresArgTypeT], Optional[RequiresReturnTypeT]]:  # noqa: D202
-        """Wrap the decorated function."""
-
-        def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]:
-            """Return None if not all props are available."""
-            for prop in self.props:
-                prop = [prop] if isinstance(prop, str) else prop
-                if all((getattr(cls, prop_i) is None for prop_i in prop)):
-                    return None
-            return func(cls)
-
-        return wrapped
-
-
-def _yield_flat_up_to(shallow_tree, input_tree, path=()):
-    """Yields (path, value) pairs of input_tree flattened up to shallow_tree.
-
-    Adapted from dm-tree (https://github.com/google-deepmind/tree) to allow
-    lists as leaves.
-
-    Args:
-      shallow_tree: Nested structure. Traverse no further than its leaf nodes.
-      input_tree: Nested structure. Return the paths and values from this tree.
-        Must have the same upper structure as shallow_tree.
-      path: Tuple. Optional argument, only used when recursing. The path from the
-        root of the original shallow_tree, down to the root of the shallow_tree
-        arg of this recursive call.
-
-    Yields:
-      Pairs of (path, value), where path is the tuple path of a leaf node in
-      shallow_tree, and value is the value of the corresponding node in
-      input_tree.
-    """
-    # pylint: disable=protected-access
-    if tree is None:
-        raise ImportError("Missing optional dependency 'dm-tree'. Use pip or conda to install it")
-
-    if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
-        isinstance(shallow_tree, tree.collections_abc.Mapping)
-        or tree._is_namedtuple(shallow_tree)
-        or tree._is_attrs(shallow_tree)
-    ):
-        yield (path, input_tree)
-    else:
-        input_tree = dict(tree._yield_sorted_items(input_tree))
-        for shallow_key, shallow_subtree in tree._yield_sorted_items(shallow_tree):
-            subpath = path + (shallow_key,)
-            input_subtree = input_tree[shallow_key]
-            for leaf_path, leaf_value in _yield_flat_up_to(
-                shallow_subtree, input_subtree, path=subpath
-            ):
-                yield (leaf_path, leaf_value)
-    # pylint: enable=protected-access
-
-
-def _flatten_with_path(structure):
-    return list(_yield_flat_up_to(structure, structure))
-
-
-def generate_dims_coords(
-    shape,
-    var_name,
-    dims=None,
-    coords=None,
-    default_dims=None,
-    index_origin=None,
-    skip_event_dims=None,
-):
-    """Generate default dimensions and coordinates for a variable.
-
-    Parameters
-    ----------
-    shape : tuple[int]
-        Shape of the variable
-    var_name : str
-        Name of the variable. If no dimension name(s) is provided, ArviZ
-        will generate a default dimension name using ``var_name``, e.g.,
-        ``"foo_dim_0"`` for the first dimension if ``var_name`` is ``"foo"``.
-    dims : list
-        List of dimensions for the variable
-    coords : dict[str] -> list[str]
-        Map of dimensions to coordinates
-    default_dims : list[str]
-        Dimension names that are not part of the variable's shape. For example,
-        when manipulating Monte Carlo traces, the ``default_dims`` would be
-        ``["chain", "draw"]`` which ArviZ uses as its own names for dimensions
-        of MCMC traces.
-    index_origin : int, optional
-        Starting value of integer coordinate values. Defaults to the value in rcParam
-        ``data.index_origin``.
-    skip_event_dims : bool, default False
-
-    Returns
-    -------
-    list[str]
-        Default dims
-    dict[str] -> list[str]
-        Default coords
-    """
-    if index_origin is None:
-        index_origin = rcParams["data.index_origin"]
-    if default_dims is None:
-        default_dims = []
-    if dims is None:
-        dims = []
-    if skip_event_dims is None:
-        skip_event_dims = False
-
-    if coords is None:
-        coords = {}
-
-    coords = deepcopy(coords)
-    dims = deepcopy(dims)
-
-    ndims = len([dim for dim in dims if dim not in default_dims])
-    if ndims > len(shape):
-        if skip_event_dims:
-            dims = dims[: len(shape)]
-        else:
-            warnings.warn(
-                (
-                    "In variable {var_name}, there are "
-                    + "more dims ({dims_len}) given than exist ({shape_len}). "
-                    + "Passed array should have shape ({defaults}*shape)"
-                ).format(
-                    var_name=var_name,
-                    dims_len=len(dims),
-                    shape_len=len(shape),
-                    defaults=",".join(default_dims) + ", " if default_dims is not None else "",
-                ),
-                UserWarning,
-            )
-    if skip_event_dims:
-        # this is needed in case the reduction keeps the dimension with size 1
-        for i, (dim, dim_size) in enumerate(zip(dims, shape)):
-            if (dim in coords) and (dim_size != len(coords[dim])):
-                dims = dims[:i]
-                break
-
-    for i, dim_len in enumerate(shape):
-        idx = i + len([dim for dim in default_dims if dim in dims])
-        if len(dims) < idx + 1:
-            dim_name = f"{var_name}_dim_{i}"
-            dims.append(dim_name)
-        elif dims[idx] is None:
-            dim_name = f"{var_name}_dim_{i}"
-            dims[idx] = dim_name
-        dim_name = dims[idx]
-        if dim_name not in coords:
-            coords[dim_name] = np.arange(index_origin, dim_len + index_origin)
-    coords = {
-        key: coord
-        for key, coord in coords.items()
-        if any(key == dim for dim in dims + default_dims)
-    }
-    return dims, coords
-
-
-def numpy_to_data_array(
-    ary,
-    *,
-    var_name="data",
-    coords=None,
-    dims=None,
-    default_dims=None,
-    index_origin=None,
-    skip_event_dims=None,
-):
-    """Convert a numpy array to an xarray.DataArray.
-
-    By default, the first two dimensions will be (chain, draw), and any remaining
-    dimensions will be "shape".
-    * If the numpy array is 1d, this dimension is interpreted as draw
-    * If the numpy array is 2d, it is interpreted as (chain, draw)
-    * If the numpy array is 3 or more dimensions, the last dimensions are kept as shapes.
-
-    To modify this behaviour, use ``default_dims``.
-
-    Parameters
-    ----------
-    ary : np.ndarray
-        A numpy array. If it has 2 or more dimensions, the first dimension should be
-        independent chains from a simulation. Use `np.expand_dims(ary, 0)` to add a
-        single dimension to the front if there is only 1 chain.
-    var_name : str
-        If there are no dims passed, this string is used to name dimensions
-    coords : dict[str, iterable]
-        A dictionary containing the values that are used as index. The key
-        is the name of the dimension, the values are the index values.
-    dims : List(str)
-        A list of coordinate names for the variable
-    default_dims : list of str, optional
-        Passed to :py:func:`generate_dims_coords`. Defaults to ``["chain", "draw"]``, and
-        an empty list is accepted
-    index_origin : int, optional
-        Passed to :py:func:`generate_dims_coords`
-    skip_event_dims : bool
-
-    Returns
-    -------
-    xr.DataArray
-        Will have the same data as passed, but with coordinates and dimensions
-    """
-    # manage and transform copies
-    if default_dims is None:
-        default_dims = ["chain", "draw"]
-    if "chain" in default_dims and "draw" in default_dims:
-        ary = utils.two_de(ary)
-        n_chains, n_samples, *_ = ary.shape
-        if n_chains > n_samples:
-            warnings.warn(
-                "More chains ({n_chains}) than draws ({n_samples}). "
-                "Passed array should have shape (chains, draws, *shape)".format(
-                    n_chains=n_chains, n_samples=n_samples
-                ),
-                UserWarning,
-            )
-    else:
-        ary = utils.one_de(ary)
-
-    dims, coords = generate_dims_coords(
-        ary.shape[len(default_dims) :],
-        var_name,
-        dims=dims,
-        coords=coords,
-        default_dims=default_dims,
-        index_origin=index_origin,
-        skip_event_dims=skip_event_dims,
-    )
-
-    # reversed order for default dims: 'chain', 'draw'
-    if "draw" not in dims and "draw" in default_dims:
-        dims = ["draw"] + dims
-    if "chain" not in dims and "chain" in default_dims:
-        dims = ["chain"] + dims
-
-    index_origin = rcParams["data.index_origin"]
-    if "chain" not in coords and "chain" in default_dims:
-        coords["chain"] = np.arange(index_origin, n_chains + index_origin)
-    if "draw" not in coords and "draw" in default_dims:
-        coords["draw"] = np.arange(index_origin, n_samples + index_origin)
-
-    # filter coords based on the dims
-    coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
-    return xr.DataArray(ary, coords=coords, dims=dims)
-
-
-def dict_to_dataset(
-    data,
-    *,
-    attrs=None,
-    library=None,
-    coords=None,
-    dims=None,
-    default_dims=None,
-    index_origin=None,
-    skip_event_dims=None,
-):
-    """Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
-
-    ArviZ itself supports conversion of flat dictionaries.
-    Support for pytrees requires ``dm-tree`` which is an optional dependency.
-    See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
-    this includes at least dictionaries and tuple types.
-
-    Parameters
-    ----------
-    data : dict of {str : array_like or dict} or pytree
-        Data to convert. Keys are variable names.
-    attrs : dict, optional
-        Json serializable metadata to attach to the dataset, in addition to defaults.
-    library : module, optional
-        Library used for performing inference. Will be attached to the attrs metadata.
-    coords : dict of {str : ndarray}, optional
-        Coordinates for the dataset
-    dims : dict of {str : list of str}, optional
-        Dimensions of each variable. The keys are variable names, values are lists of
-        coordinates.
-    default_dims : list of str, optional
-        Passed to :py:func:`numpy_to_data_array`
-    index_origin : int, optional
-        Passed to :py:func:`numpy_to_data_array`
-    skip_event_dims : bool, optional
-        If True, cut extra dims whenever present to match the shape of the data.
-        Necessary for PPLs which have the same name in both observed data and log
-        likelihood groups, to account for their different shapes when observations are
-        multivariate.
-
-    Returns
-    -------
-    xarray.Dataset
-        In case of nested pytrees, the variable name will be a tuple of individual names.
-
-    Notes
-    -----
-    This function is available through two aliases: ``dict_to_dataset`` or ``pytree_to_dataset``.
-
-    Examples
-    --------
-    Convert a dictionary with two 2D variables to a Dataset.
-
-    .. ipython::
-
-        In [1]: import arviz as az
-           ...: import numpy as np
-           ...: az.dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
-
-    Note that unlike the :class:`xarray.Dataset` constructor, ArviZ has added extra
-    information to the generated Dataset such as default dimension names for sampled
-    dimensions and some attributes.
-
-    The function is also general enough to work on pytrees such as nested dictionaries:
-
-    .. ipython::
-
-        In [1]: az.pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})
-
-    which has two variables (as many as leaves) named ``('top', 'second')`` and ``top2``.
-
-    Dimensions and coordinates can be defined as usual:
-
-    .. ipython::
-
-        In [1]: datadict = {
-           ...:     "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
-           ...:     "d": np.random.randn(100),
-           ...: }
-           ...: az.dict_to_dataset(
-           ...:     datadict,
-           ...:     coords={"c": np.arange(10)},
-           ...:     dims={("top", "b"): ["c"]}
-           ...: )
-
-    """
-    if dims is None:
-        dims = {}
-
-    if tree is not None:
-        try:
-            data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
-        except TypeError:  # probably unsortable keys -- the function will still work if
-            pass  # it is an honest dictionary.
-
-    data_vars = {
-        key: numpy_to_data_array(
-            values,
-            var_name=key,
-            coords=coords,
-            dims=dims.get(key),
-            default_dims=default_dims,
-            index_origin=index_origin,
-            skip_event_dims=skip_event_dims,
-        )
-        for key, values in data.items()
-    }
-    return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
-
-
-pytree_to_dataset = dict_to_dataset
-
-
-def make_attrs(attrs=None, library=None):
-    """Make standard attributes to attach to xarray datasets.
-
-    Parameters
-    ----------
-    attrs : dict (optional)
-        Additional attributes to add or overwrite
-
-    Returns
-    -------
-    dict
-        attrs
-    """
-    default_attrs = {
-        "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
-        "arviz_version": __version__,
-    }
-    if library is not None:
-        library_name = library.__name__
-        default_attrs["inference_library"] = library_name
-        try:
-            version = importlib.metadata.version(library_name)
-            default_attrs["inference_library_version"] = version
-        except importlib.metadata.PackageNotFoundError:
-            if hasattr(library, "__version__"):
-                version = library.__version__
-                default_attrs["inference_library_version"] = version
-    if attrs is not None:
-        default_attrs.update(attrs)
-    return default_attrs
-
-
-def _extend_xr_method(func, doc=None, description="", examples="", see_also=""):
-    """Make wrapper to extend methods from xr.Dataset to InferenceData Class.
-
-    Parameters
-    ----------
-    func : callable
-        An xr.Dataset function
-    doc : str
-        docstring for the func
-    description : str
-        the description of the func to be added in docstring
-    examples : str
-        the examples of the func to be added in docstring
-    see_also : str, list
-        the similar methods of func to be included in See Also section of docstring
-
-    """
-    # pydocstyle requires a non empty line
-
-    @functools.wraps(func)
-    def wrapped(self, *args, **kwargs):
-        _filter = kwargs.pop("filter_groups", None)
-        _groups = kwargs.pop("groups", None)
-        _inplace = kwargs.pop("inplace", False)
-
-        out = self if _inplace else deepcopy(self)
-
-        groups = self._group_names(_groups, _filter)  # pylint: disable=protected-access
-        for group in groups:
-            xr_data = getattr(out, group)
-            xr_data = func(xr_data, *args, **kwargs)  # pylint: disable=not-callable
-            setattr(out, group, xr_data)
-
-        return None if _inplace else out
-
-    description_default = """{method_name} method is extended from xarray.Dataset methods.
-
-    {description}
-
-    For more info see :meth:`xarray:xarray.Dataset.{method_name}`.
-    In addition to the arguments available in the original method, the following
-    ones are added by ArviZ to adapt the method to being called on an ``InferenceData`` object.
-    """.format(
-        description=description, method_name=func.__name__  # pylint: disable=no-member
-    )
-    params = """
-    Other Parameters
-    ----------------
-    groups: str or list of str, optional
-        Groups where the selection is to be applied. Can either be group names
-        or metagroup names.
-    filter_groups: {None, "like", "regex"}, optional, default=None
-        If `None` (default), interpret groups as the real group or metagroup names.
-        If "like", interpret groups as substrings of the real group or metagroup names.
-        If "regex", interpret groups as regular expressions on the real group or
-        metagroup names. A la `pandas.filter`.
-    inplace: bool, optional
-        If ``True``, modify the InferenceData object inplace,
-        otherwise, return the modified copy.
-    """
-
-    if not isinstance(see_also, str):
-        see_also = "\n".join(see_also)
-    see_also_basic = """
-    See Also
-    --------
-    xarray.Dataset.{method_name}
-    {custom_see_also}
-    """.format(
-        method_name=func.__name__, custom_see_also=see_also  # pylint: disable=no-member
-    )
-    wrapped.__doc__ = (
-        description_default + params + examples + see_also_basic if doc is None else doc
-    )
-
-    return wrapped
-
-
-def _make_json_serializable(data: dict) -> dict:
-    """Convert `data` with numpy.ndarray-like values to JSON-serializable form."""
-    ret = {}
-    for key, value in data.items():
-        try:
-            json.dumps(value)
-        except (TypeError, OverflowError):
-            pass
-        else:
-            ret[key] = value
-            continue
-        if isinstance(value, dict):
-            ret[key] = _make_json_serializable(value)
-        elif isinstance(value, np.ndarray):
-            ret[key] = np.asarray(value).tolist()
-        else:
-            raise TypeError(
-                f"Value associated with variable `{type(value)}` is not JSON serializable."
-            )
-    return ret
-
-
-def infer_stan_dtypes(stan_code):
-    """Infer Stan integer variables from generated quantities block."""
-    # Remove old deprecated comments
-    stan_code = "\n".join(
-        line if "#" not in line else line[: line.find("#")] for line in stan_code.splitlines()
-    )
-    pattern_remove_comments = re.compile(
-        r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE
-    )
-    stan_code = re.sub(pattern_remove_comments, "", stan_code)
-
-    # Check generated quantities
-    if "generated quantities" not in stan_code:
-        return {}
-
-    # Extract generated quantities block
-    gen_quantities_location = stan_code.index("generated quantities")
-    block_start = gen_quantities_location + stan_code[gen_quantities_location:].index("{")
-
-    curly_bracket_count = 0
-    block_end = None
-    for block_end, char in enumerate(stan_code[block_start:], block_start + 1):
-        if char == "{":
-            curly_bracket_count += 1
-        elif char == "}":
-            curly_bracket_count -= 1
-
-            if curly_bracket_count == 0:
-                break
-
-    stan_code = stan_code[block_start:block_end]
-
-    stan_integer = r"int"
-    stan_limits = r"(?:\<[^\>]+\>)*"  # ignore group: 0 or more <....>
-    stan_param = r"([^;=\s\[]+)"  # capture group: ends= ";", "=", "[" or whitespace
-    stan_ws = r"\s*"  # 0 or more whitespace
-    stan_ws_one = r"\s+"  # 1 or more whitespace
-    pattern_int = re.compile(
-        "".join((stan_integer, stan_ws_one, stan_limits, stan_ws, stan_param)), re.IGNORECASE
-    )
-    dtypes = {key.strip(): "int" for key in re.findall(pattern_int, stan_code)}
-    return dtypes
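Because `base.py` concentrates the low-level conversion machinery removed in this release, a short hedged sketch of how its pieces fit together may help when auditing the removal. Both snippets run against arviz 0.23.3; the variable names and the Stan fragment are made up for illustration.

```python
import numpy as np
from arviz.data.base import dict_to_dataset, make_attrs

# dict_to_dataset routes each array through numpy_to_data_array, which calls
# generate_dims_coords to add "chain"/"draw" plus per-variable default dims
# such as "theta_dim_0" when no dims are supplied.
posterior = dict_to_dataset(
    {"mu": np.random.randn(4, 100), "theta": np.random.randn(4, 100, 8)},
    coords={"school": np.arange(8)},
    dims={"theta": ["school"]},
)
print(posterior["theta"].dims)  # ('chain', 'draw', 'school')

# make_attrs supplies the provenance metadata attached to every converted group.
print(sorted(make_attrs()))  # ['arviz_version', 'created_at']
```

And `infer_stan_dtypes`, used by the Stan converters to decide which generated quantities keep an integer dtype, can be exercised directly:

```python
from arviz.data.base import infer_stan_dtypes

stan_code = """
generated quantities {
    int y_hat;     // declared int, so its draws keep an integer dtype
    real log_lik;
}
"""
print(infer_stan_dtypes(stan_code))  # {'y_hat': 'int'}
```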