arviz-0.23.3-py3-none-any.whl → arviz-1.0.0rc0-py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
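
Most of this release is deletion: the 0.23.x modules are removed from the wheel, and the largest single removal, arviz/data/inference_data.py, is reproduced below. For orientation, here is a minimal sketch of the core workflow that file implemented — building an InferenceData from xarray Datasets and round-tripping it through netCDF, as described by its own docstrings. This is a usage sketch against the old package, assuming arviz==0.23.3 (not the 1.0.0rc0 wheel) is installed:

import arviz as az
import numpy as np
import xarray as xr

# A posterior group; "chain" and "draw" are the sampling dimensions
# expected by the InferenceData schema.
posterior = xr.Dataset(
    {"b": (["chain", "draw"], np.random.normal(size=(4, 100)))},
    coords={"chain": np.arange(4), "draw": np.arange(100)},
)
idata = az.InferenceData(posterior=posterior)

# Round-trip through netCDF: to_netcdf returns the file location and
# from_netcdf is a static method that rebuilds the object from it.
path = idata.to_netcdf("idata.nc")
idata2 = az.InferenceData.from_netcdf(path)
print(idata2.groups())  # ['posterior']

In the 1.0.0rc0 wheel these modules no longer exist, so the sketch above only runs against the 0.23.x package.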
arviz/data/inference_data.py
DELETED
|
@@ -1,2386 +0,0 @@
|
|
|
1
|
-
# pylint: disable=too-many-lines,too-many-public-methods
|
|
2
|
-
"""Data structure for using netcdf groups with xarray."""
|
|
3
|
-
import os
|
|
4
|
-
import re
|
|
5
|
-
import sys
|
|
6
|
-
import uuid
|
|
7
|
-
import warnings
|
|
8
|
-
from collections import OrderedDict, defaultdict
|
|
9
|
-
from collections.abc import MutableMapping, Sequence
|
|
10
|
-
from copy import copy as ccopy
|
|
11
|
-
from copy import deepcopy
|
|
12
|
-
import datetime
|
|
13
|
-
from html import escape
|
|
14
|
-
from typing import (
|
|
15
|
-
TYPE_CHECKING,
|
|
16
|
-
Any,
|
|
17
|
-
Iterable,
|
|
18
|
-
Iterator,
|
|
19
|
-
List,
|
|
20
|
-
Mapping,
|
|
21
|
-
Optional,
|
|
22
|
-
Tuple,
|
|
23
|
-
TypeVar,
|
|
24
|
-
Union,
|
|
25
|
-
overload,
|
|
26
|
-
)
|
|
27
|
-
|
|
28
|
-
import numpy as np
|
|
29
|
-
import xarray as xr
|
|
30
|
-
from packaging import version
|
|
31
|
-
|
|
32
|
-
from ..rcparams import rcParams
|
|
33
|
-
from ..utils import HtmlTemplate, _subset_list, _var_names, either_dict_or_kwargs
|
|
34
|
-
from .base import _extend_xr_method, _make_json_serializable, dict_to_dataset
|
|
35
|
-
|
|
36
|
-
if sys.version_info[:2] >= (3, 9):
|
|
37
|
-
# As of 3.9, collections.abc types support generic parameters themselves.
|
|
38
|
-
from collections.abc import ItemsView, ValuesView
|
|
39
|
-
else:
|
|
40
|
-
# These typing imports are deprecated in 3.9, and moved to collections.abc instead.
|
|
41
|
-
from typing import ItemsView, ValuesView
|
|
42
|
-
|
|
43
|
-
if TYPE_CHECKING:
|
|
44
|
-
from typing_extensions import Literal
|
|
45
|
-
|
|
46
|
-
try:
|
|
47
|
-
import ujson as json
|
|
48
|
-
except ImportError:
|
|
49
|
-
# mypy struggles with conditional imports expressed as catching ImportError:
|
|
50
|
-
# https://github.com/python/mypy/issues/1153
|
|
51
|
-
import json # type: ignore
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
SUPPORTED_GROUPS = [
|
|
55
|
-
"posterior",
|
|
56
|
-
"posterior_predictive",
|
|
57
|
-
"predictions",
|
|
58
|
-
"log_likelihood",
|
|
59
|
-
"log_prior",
|
|
60
|
-
"sample_stats",
|
|
61
|
-
"prior",
|
|
62
|
-
"prior_predictive",
|
|
63
|
-
"sample_stats_prior",
|
|
64
|
-
"observed_data",
|
|
65
|
-
"constant_data",
|
|
66
|
-
"predictions_constant_data",
|
|
67
|
-
"unconstrained_posterior",
|
|
68
|
-
"unconstrained_prior",
|
|
69
|
-
]
|
|
70
|
-
|
|
71
|
-
WARMUP_TAG = "warmup_"
|
|
72
|
-
|
|
73
|
-
SUPPORTED_GROUPS_WARMUP = [
|
|
74
|
-
f"{WARMUP_TAG}posterior",
|
|
75
|
-
f"{WARMUP_TAG}posterior_predictive",
|
|
76
|
-
f"{WARMUP_TAG}predictions",
|
|
77
|
-
f"{WARMUP_TAG}sample_stats",
|
|
78
|
-
f"{WARMUP_TAG}log_likelihood",
|
|
79
|
-
f"{WARMUP_TAG}log_prior",
|
|
80
|
-
]
|
|
81
|
-
|
|
82
|
-
SUPPORTED_GROUPS_ALL = SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP
|
|
83
|
-
|
|
84
|
-
InferenceDataT = TypeVar("InferenceDataT", bound="InferenceData")
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
def _compressible_dtype(dtype):
|
|
88
|
-
"""Check basic dtypes for automatic compression."""
|
|
89
|
-
if dtype.kind == "V":
|
|
90
|
-
return all(_compressible_dtype(item) for item, _ in dtype.fields.values())
|
|
91
|
-
return dtype.kind in {"b", "i", "u", "f", "c", "S"}
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
class InferenceData(Mapping[str, xr.Dataset]):
|
|
95
|
-
"""Container for inference data storage using xarray.
|
|
96
|
-
|
|
97
|
-
For a detailed introduction to ``InferenceData`` objects and their usage, see
|
|
98
|
-
:ref:`xarray_for_arviz`. This page provides help and documentation
|
|
99
|
-
on ``InferenceData`` methods and their low level implementation.
|
|
100
|
-
"""
|
|
101
|
-
|
|
102
|
-
def __init__(
|
|
103
|
-
self,
|
|
104
|
-
attrs: Union[None, Mapping[Any, Any]] = None,
|
|
105
|
-
warn_on_custom_groups: bool = False,
|
|
106
|
-
**kwargs: Union[xr.Dataset, List[xr.Dataset], Tuple[xr.Dataset, xr.Dataset]],
|
|
107
|
-
) -> None:
|
|
108
|
-
"""Initialize InferenceData object from keyword xarray datasets.
|
|
109
|
-
|
|
110
|
-
Parameters
|
|
111
|
-
----------
|
|
112
|
-
attrs : dict
|
|
113
|
-
sets global attribute for InferenceData object.
|
|
114
|
-
warn_on_custom_groups : bool, default False
|
|
115
|
-
Emit a warning when custom groups are present in the InferenceData.
|
|
116
|
-
"custom group" means any group whose name isn't defined in :ref:`schema`
|
|
117
|
-
kwargs :
|
|
118
|
-
Keyword arguments of xarray datasets
|
|
119
|
-
|
|
120
|
-
Examples
|
|
121
|
-
--------
|
|
122
|
-
Initiate an InferenceData object from scratch, not recommended. InferenceData
|
|
123
|
-
objects should be initialized using ``from_xyz`` methods, see :ref:`data_api` for more
|
|
124
|
-
details.
|
|
125
|
-
|
|
126
|
-
.. ipython::
|
|
127
|
-
|
|
128
|
-
In [1]: import arviz as az
|
|
129
|
-
...: import numpy as np
|
|
130
|
-
...: import xarray as xr
|
|
131
|
-
...: dataset = xr.Dataset(
|
|
132
|
-
...: {
|
|
133
|
-
...: "a": (["chain", "draw", "a_dim"], np.random.normal(size=(4, 100, 3))),
|
|
134
|
-
...: "b": (["chain", "draw"], np.random.normal(size=(4, 100))),
|
|
135
|
-
...: },
|
|
136
|
-
...: coords={
|
|
137
|
-
...: "chain": (["chain"], np.arange(4)),
|
|
138
|
-
...: "draw": (["draw"], np.arange(100)),
|
|
139
|
-
...: "a_dim": (["a_dim"], ["x", "y", "z"]),
|
|
140
|
-
...: }
|
|
141
|
-
...: )
|
|
142
|
-
...: idata = az.InferenceData(posterior=dataset, prior=dataset)
|
|
143
|
-
...: idata
|
|
144
|
-
|
|
145
|
-
We have created an ``InferenceData`` object with two groups. Now we can check its
|
|
146
|
-
contents:
|
|
147
|
-
|
|
148
|
-
.. ipython::
|
|
149
|
-
|
|
150
|
-
In [1]: idata.posterior
|
|
151
|
-
|
|
152
|
-
"""
|
|
153
|
-
self._groups: List[str] = []
|
|
154
|
-
self._groups_warmup: List[str] = []
|
|
155
|
-
self._attrs: Union[None, dict] = dict(attrs) if attrs is not None else None
|
|
156
|
-
save_warmup = kwargs.pop("save_warmup", False)
|
|
157
|
-
key_list = [key for key in SUPPORTED_GROUPS_ALL if key in kwargs]
|
|
158
|
-
for key in kwargs:
|
|
159
|
-
if key not in SUPPORTED_GROUPS_ALL:
|
|
160
|
-
key_list.append(key)
|
|
161
|
-
if warn_on_custom_groups:
|
|
162
|
-
warnings.warn(
|
|
163
|
-
f"{key} group is not defined in the InferenceData scheme", UserWarning
|
|
164
|
-
)
|
|
165
|
-
for key in key_list:
|
|
166
|
-
dataset = kwargs[key]
|
|
167
|
-
dataset_warmup = None
|
|
168
|
-
if dataset is None:
|
|
169
|
-
continue
|
|
170
|
-
elif isinstance(dataset, (list, tuple)):
|
|
171
|
-
dataset, dataset_warmup = dataset
|
|
172
|
-
elif not isinstance(dataset, xr.Dataset):
|
|
173
|
-
raise ValueError(
|
|
174
|
-
"Arguments to InferenceData must be xarray Datasets "
|
|
175
|
-
f"(argument '{key}' was type '{type(dataset)}')"
|
|
176
|
-
)
|
|
177
|
-
if not key.startswith(WARMUP_TAG):
|
|
178
|
-
if dataset:
|
|
179
|
-
setattr(self, key, dataset)
|
|
180
|
-
self._groups.append(key)
|
|
181
|
-
elif key.startswith(WARMUP_TAG):
|
|
182
|
-
if dataset:
|
|
183
|
-
setattr(self, key, dataset)
|
|
184
|
-
self._groups_warmup.append(key)
|
|
185
|
-
if save_warmup and dataset_warmup is not None and dataset_warmup:
|
|
186
|
-
key = f"{WARMUP_TAG}{key}"
|
|
187
|
-
setattr(self, key, dataset_warmup)
|
|
188
|
-
self._groups_warmup.append(key)
|
|
189
|
-
|
|
190
|
-
@property
|
|
191
|
-
def attrs(self) -> dict:
|
|
192
|
-
"""Attributes of InferenceData object."""
|
|
193
|
-
if self._attrs is None:
|
|
194
|
-
self._attrs = {}
|
|
195
|
-
return self._attrs
|
|
196
|
-
|
|
197
|
-
@attrs.setter
|
|
198
|
-
def attrs(self, value) -> None:
|
|
199
|
-
self._attrs = dict(value)
|
|
200
|
-
|
|
201
|
-
def __repr__(self) -> str:
|
|
202
|
-
"""Make string representation of InferenceData object."""
|
|
203
|
-
msg = "Inference data with groups:\n\t> {options}".format(
|
|
204
|
-
options="\n\t> ".join(self._groups)
|
|
205
|
-
)
|
|
206
|
-
if self._groups_warmup:
|
|
207
|
-
msg += f"\n\nWarmup iterations saved ({WARMUP_TAG}*)."
|
|
208
|
-
return msg
|
|
209
|
-
|
|
210
|
-
def _repr_html_(self) -> str:
|
|
211
|
-
"""Make html representation of InferenceData object."""
|
|
212
|
-
try:
|
|
213
|
-
from xarray.core.options import OPTIONS
|
|
214
|
-
|
|
215
|
-
display_style = OPTIONS["display_style"]
|
|
216
|
-
if display_style == "text":
|
|
217
|
-
html_repr = f"<pre>{escape(repr(self))}</pre>"
|
|
218
|
-
else:
|
|
219
|
-
elements = "".join(
|
|
220
|
-
[
|
|
221
|
-
HtmlTemplate.element_template.format(
|
|
222
|
-
group_id=group + str(uuid.uuid4()),
|
|
223
|
-
group=group,
|
|
224
|
-
xr_data=getattr( # pylint: disable=protected-access
|
|
225
|
-
self, group
|
|
226
|
-
)._repr_html_(),
|
|
227
|
-
)
|
|
228
|
-
for group in self._groups_all
|
|
229
|
-
]
|
|
230
|
-
)
|
|
231
|
-
formatted_html_template = ( # pylint: disable=possibly-unused-variable
|
|
232
|
-
HtmlTemplate.html_template.format(elements)
|
|
233
|
-
)
|
|
234
|
-
css_template = HtmlTemplate.css_template # pylint: disable=possibly-unused-variable
|
|
235
|
-
html_repr = f"{locals()['formatted_html_template']}{locals()['css_template']}"
|
|
236
|
-
except: # pylint: disable=bare-except
|
|
237
|
-
html_repr = f"<pre>{escape(repr(self))}</pre>"
|
|
238
|
-
return html_repr
|
|
239
|
-
|
|
240
|
-
def __delattr__(self, group: str) -> None:
|
|
241
|
-
"""Delete a group from the InferenceData object."""
|
|
242
|
-
if group in self._groups:
|
|
243
|
-
self._groups.remove(group)
|
|
244
|
-
elif group in self._groups_warmup:
|
|
245
|
-
self._groups_warmup.remove(group)
|
|
246
|
-
object.__delattr__(self, group)
|
|
247
|
-
|
|
248
|
-
def __delitem__(self, key: str) -> None:
|
|
249
|
-
"""Delete an item from the InferenceData object using del idata[key]."""
|
|
250
|
-
self.__delattr__(key)
|
|
251
|
-
|
|
252
|
-
@property
|
|
253
|
-
def _groups_all(self) -> List[str]:
|
|
254
|
-
return self._groups + self._groups_warmup
|
|
255
|
-
|
|
256
|
-
def __len__(self) -> int:
|
|
257
|
-
"""Return the number of groups in this InferenceData object."""
|
|
258
|
-
return len(self._groups_all)
|
|
259
|
-
|
|
260
|
-
def __iter__(self) -> Iterator[str]:
|
|
261
|
-
"""Iterate over groups in InferenceData object."""
|
|
262
|
-
yield from self._groups_all
|
|
263
|
-
|
|
264
|
-
def __contains__(self, key: object) -> bool:
|
|
265
|
-
"""Return True if the named item is present, and False otherwise."""
|
|
266
|
-
return key in self._groups_all
|
|
267
|
-
|
|
268
|
-
def __getitem__(self, key: str) -> xr.Dataset:
|
|
269
|
-
"""Get item by key."""
|
|
270
|
-
if key not in self._groups_all:
|
|
271
|
-
raise KeyError(key)
|
|
272
|
-
return getattr(self, key)
|
|
273
|
-
|
|
274
|
-
def __setitem__(self, key: str, value: xr.Dataset):
|
|
275
|
-
"""Set item by key and update group list accordingly."""
|
|
276
|
-
if key.startswith(WARMUP_TAG):
|
|
277
|
-
self._groups_warmup.append(key)
|
|
278
|
-
else:
|
|
279
|
-
self._groups.append(key)
|
|
280
|
-
setattr(self, key, value)
|
|
281
|
-
|
|
282
|
-
def groups(self) -> List[str]:
|
|
283
|
-
"""Return all groups present in InferenceData object."""
|
|
284
|
-
return self._groups_all
|
|
285
|
-
|
|
286
|
-
class InferenceDataValuesView(ValuesView[xr.Dataset]):
|
|
287
|
-
"""ValuesView implementation for InferenceData, to allow it to implement Mapping."""
|
|
288
|
-
|
|
289
|
-
def __init__( # pylint: disable=super-init-not-called
|
|
290
|
-
self, parent: "InferenceData"
|
|
291
|
-
) -> None:
|
|
292
|
-
"""Create a new InferenceDataValuesView from an InferenceData object."""
|
|
293
|
-
self.parent = parent
|
|
294
|
-
|
|
295
|
-
def __len__(self) -> int:
|
|
296
|
-
"""Return the number of groups in the parent InferenceData."""
|
|
297
|
-
return len(self.parent._groups_all)
|
|
298
|
-
|
|
299
|
-
def __iter__(self) -> Iterator[xr.Dataset]:
|
|
300
|
-
"""Iterate through the Xarray datasets present in the InferenceData object."""
|
|
301
|
-
parent = self.parent
|
|
302
|
-
for group in parent._groups_all:
|
|
303
|
-
yield getattr(parent, group)
|
|
304
|
-
|
|
305
|
-
def __contains__(self, key: object) -> bool:
|
|
306
|
-
"""Return True if the given Xarray dataset is one of the values, and False otherwise."""
|
|
307
|
-
if not isinstance(key, xr.Dataset):
|
|
308
|
-
return False
|
|
309
|
-
|
|
310
|
-
for dataset in self:
|
|
311
|
-
if dataset.equals(key):
|
|
312
|
-
return True
|
|
313
|
-
|
|
314
|
-
return False
|
|
315
|
-
|
|
316
|
-
def values(self) -> "InferenceData.InferenceDataValuesView":
|
|
317
|
-
"""Return a view over the Xarray Datasets present in the InferenceData object."""
|
|
318
|
-
return InferenceData.InferenceDataValuesView(self)
|
|
319
|
-
|
|
320
|
-
class InferenceDataItemsView(ItemsView[str, xr.Dataset]):
|
|
321
|
-
"""ItemsView implementation for InferenceData, to allow it to implement Mapping."""
|
|
322
|
-
|
|
323
|
-
def __init__( # pylint: disable=super-init-not-called
|
|
324
|
-
self, parent: "InferenceData"
|
|
325
|
-
) -> None:
|
|
326
|
-
"""Create a new InferenceDataItemsView from an InferenceData object."""
|
|
327
|
-
self.parent = parent
|
|
328
|
-
|
|
329
|
-
def __len__(self) -> int:
|
|
330
|
-
"""Return the number of groups in the parent InferenceData."""
|
|
331
|
-
return len(self.parent._groups_all)
|
|
332
|
-
|
|
333
|
-
def __iter__(self) -> Iterator[Tuple[str, xr.Dataset]]:
|
|
334
|
-
"""Iterate through the groups and corresponding Xarray datasets in the InferenceData."""
|
|
335
|
-
parent = self.parent
|
|
336
|
-
for group in parent._groups_all:
|
|
337
|
-
yield group, getattr(parent, group)
|
|
338
|
-
|
|
339
|
-
def __contains__(self, key: object) -> bool:
|
|
340
|
-
"""Return True if the (group, dataset) tuple is present, and False otherwise."""
|
|
341
|
-
parent = self.parent
|
|
342
|
-
if not isinstance(key, tuple) or len(key) != 2:
|
|
343
|
-
return False
|
|
344
|
-
|
|
345
|
-
group, dataset = key
|
|
346
|
-
if group not in parent._groups_all:
|
|
347
|
-
return False
|
|
348
|
-
|
|
349
|
-
if not isinstance(dataset, xr.Dataset):
|
|
350
|
-
return False
|
|
351
|
-
|
|
352
|
-
existing_dataset = getattr(parent, group)
|
|
353
|
-
return existing_dataset.equals(dataset)
|
|
354
|
-
|
|
355
|
-
def items(self) -> "InferenceData.InferenceDataItemsView":
|
|
356
|
-
"""Return a view over the groups and datasets present in the InferenceData object."""
|
|
357
|
-
return InferenceData.InferenceDataItemsView(self)
|
|
358
|
-
|
|
359
|
-
@staticmethod
|
|
360
|
-
def from_netcdf(
|
|
361
|
-
filename,
|
|
362
|
-
*,
|
|
363
|
-
engine="h5netcdf",
|
|
364
|
-
group_kwargs=None,
|
|
365
|
-
regex=False,
|
|
366
|
-
base_group: str = "/",
|
|
367
|
-
) -> "InferenceData":
|
|
368
|
-
"""Initialize object from a netcdf file.
|
|
369
|
-
|
|
370
|
-
Expects that the file will have groups, each of which can be loaded by xarray.
|
|
371
|
-
By default, the datasets of the InferenceData object will be lazily loaded instead
|
|
372
|
-
of being loaded into memory. This
|
|
373
|
-
behaviour is regulated by the value of ``az.rcParams["data.load"]``.
|
|
374
|
-
|
|
375
|
-
Parameters
|
|
376
|
-
----------
|
|
377
|
-
filename : str
|
|
378
|
-
location of netcdf file
|
|
379
|
-
engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
|
|
380
|
-
Library used to read the netcdf file.
|
|
381
|
-
group_kwargs : dict of {str: dict}, optional
|
|
382
|
-
Keyword arguments to be passed into each call of :func:`xarray.open_dataset`.
|
|
383
|
-
The keys of the higher level should be group names or regex matching group
|
|
384
|
-
names, the inner dicts re passed to ``open_dataset``
|
|
385
|
-
This feature is currently experimental.
|
|
386
|
-
regex : bool, default False
|
|
387
|
-
Specifies where regex search should be used to extend the keyword arguments.
|
|
388
|
-
This feature is currently experimental.
|
|
389
|
-
base_group : str, default "/"
|
|
390
|
-
The group in the netCDF file where the InferenceData is stored. By default,
|
|
391
|
-
assumes that the file only contains an InferenceData object.
|
|
392
|
-
|
|
393
|
-
Returns
|
|
394
|
-
-------
|
|
395
|
-
InferenceData
|
|
396
|
-
"""
|
|
397
|
-
groups = {}
|
|
398
|
-
attrs = {}
|
|
399
|
-
|
|
400
|
-
if engine == "h5netcdf":
|
|
401
|
-
import h5netcdf
|
|
402
|
-
elif engine == "netcdf4":
|
|
403
|
-
import netCDF4 as nc
|
|
404
|
-
else:
|
|
405
|
-
raise ValueError(
|
|
406
|
-
f"Invalid value for engine: {engine}. Valid options are: h5netcdf or netcdf4"
|
|
407
|
-
)
|
|
408
|
-
|
|
409
|
-
try:
|
|
410
|
-
with (
|
|
411
|
-
h5netcdf.File(filename, mode="r")
|
|
412
|
-
if engine == "h5netcdf"
|
|
413
|
-
else nc.Dataset(filename, mode="r")
|
|
414
|
-
) as file_handle:
|
|
415
|
-
if base_group == "/":
|
|
416
|
-
data = file_handle
|
|
417
|
-
else:
|
|
418
|
-
data = file_handle[base_group]
|
|
419
|
-
|
|
420
|
-
data_groups = list(data.groups)
|
|
421
|
-
|
|
422
|
-
for group in data_groups:
|
|
423
|
-
group_kws = {}
|
|
424
|
-
|
|
425
|
-
group_kws = {}
|
|
426
|
-
if group_kwargs is not None and regex is False:
|
|
427
|
-
group_kws = group_kwargs.get(group, {})
|
|
428
|
-
if group_kwargs is not None and regex is True:
|
|
429
|
-
for key, kws in group_kwargs.items():
|
|
430
|
-
if re.search(key, group):
|
|
431
|
-
group_kws = kws
|
|
432
|
-
group_kws.setdefault("engine", engine)
|
|
433
|
-
data = xr.open_dataset(filename, group=f"{base_group}/{group}", **group_kws)
|
|
434
|
-
if rcParams["data.load"] == "eager":
|
|
435
|
-
with data:
|
|
436
|
-
groups[group] = data.load()
|
|
437
|
-
else:
|
|
438
|
-
groups[group] = data
|
|
439
|
-
|
|
440
|
-
with xr.open_dataset(filename, engine=engine, group=base_group) as data:
|
|
441
|
-
attrs.update(data.load().attrs)
|
|
442
|
-
|
|
443
|
-
return InferenceData(attrs=attrs, **groups)
|
|
444
|
-
except OSError as err:
|
|
445
|
-
if err.errno == -101:
|
|
446
|
-
raise type(err)(
|
|
447
|
-
str(err)
|
|
448
|
-
+ (
|
|
449
|
-
" while reading a NetCDF file. This is probably an error in HDF5, "
|
|
450
|
-
"which happens because your OS does not support HDF5 file locking. See "
|
|
451
|
-
"https://stackoverflow.com/questions/49317927/"
|
|
452
|
-
"errno-101-netcdf-hdf-error-when-opening-netcdf-file#49317928"
|
|
453
|
-
" for a possible solution."
|
|
454
|
-
)
|
|
455
|
-
) from err
|
|
456
|
-
raise err
|
|
457
|
-
|
|
458
|
-
def to_netcdf(
|
|
459
|
-
self,
|
|
460
|
-
filename: str,
|
|
461
|
-
compress: bool = True,
|
|
462
|
-
groups: Optional[List[str]] = None,
|
|
463
|
-
engine: str = "h5netcdf",
|
|
464
|
-
base_group: str = "/",
|
|
465
|
-
overwrite_existing: bool = True,
|
|
466
|
-
) -> str:
|
|
467
|
-
"""Write InferenceData to netcdf4 file.
|
|
468
|
-
|
|
469
|
-
Parameters
|
|
470
|
-
----------
|
|
471
|
-
filename : str
|
|
472
|
-
Location to write to
|
|
473
|
-
compress : bool, optional
|
|
474
|
-
Whether to compress result. Note this saves disk space, but may make
|
|
475
|
-
saving and loading somewhat slower (default: True).
|
|
476
|
-
groups : list, optional
|
|
477
|
-
Write only these groups to netcdf file.
|
|
478
|
-
engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
|
|
479
|
-
Library used to read the netcdf file.
|
|
480
|
-
base_group : str, default "/"
|
|
481
|
-
The group in the netCDF file where the InferenceData is will be stored.
|
|
482
|
-
By default, will write to the root of the netCDF file
|
|
483
|
-
overwrite_existing : bool, default True
|
|
484
|
-
Whether to overwrite the existing file or append to it.
|
|
485
|
-
|
|
486
|
-
Returns
|
|
487
|
-
-------
|
|
488
|
-
str
|
|
489
|
-
Location of netcdf file
|
|
490
|
-
"""
|
|
491
|
-
if base_group is None:
|
|
492
|
-
base_group = "/"
|
|
493
|
-
|
|
494
|
-
if os.path.exists(filename) and not overwrite_existing:
|
|
495
|
-
mode = "a"
|
|
496
|
-
else:
|
|
497
|
-
mode = "w" # overwrite first, then append
|
|
498
|
-
|
|
499
|
-
if self._attrs:
|
|
500
|
-
xr.Dataset(attrs=self._attrs).to_netcdf(
|
|
501
|
-
filename, mode=mode, engine=engine, group=base_group
|
|
502
|
-
)
|
|
503
|
-
mode = "a"
|
|
504
|
-
|
|
505
|
-
if self._groups_all: # check's whether a group is present or not.
|
|
506
|
-
if groups is None:
|
|
507
|
-
groups = self._groups_all
|
|
508
|
-
else:
|
|
509
|
-
groups = [group for group in self._groups_all if group in groups]
|
|
510
|
-
|
|
511
|
-
for group in groups:
|
|
512
|
-
data = getattr(self, group)
|
|
513
|
-
kwargs = {"engine": engine}
|
|
514
|
-
if compress:
|
|
515
|
-
kwargs["encoding"] = {
|
|
516
|
-
var_name: {"zlib": True}
|
|
517
|
-
for var_name, values in data.variables.items()
|
|
518
|
-
if _compressible_dtype(values.dtype)
|
|
519
|
-
}
|
|
520
|
-
data.to_netcdf(filename, mode=mode, group=f"{base_group}/{group}", **kwargs)
|
|
521
|
-
data.close()
|
|
522
|
-
mode = "a"
|
|
523
|
-
elif not self._attrs: # creates a netcdf file for an empty InferenceData object.
|
|
524
|
-
if engine == "h5netcdf":
|
|
525
|
-
import h5netcdf
|
|
526
|
-
|
|
527
|
-
empty_netcdf_file = h5netcdf.File(filename, mode="w")
|
|
528
|
-
elif engine == "netcdf4":
|
|
529
|
-
import netCDF4 as nc
|
|
530
|
-
|
|
531
|
-
empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4")
|
|
532
|
-
empty_netcdf_file.close()
|
|
533
|
-
return filename
|
|
534
|
-
|
|
535
|
-
def to_datatree(self):
|
|
536
|
-
"""Convert InferenceData object to a :class:`~xarray.DataTree`."""
|
|
537
|
-
try:
|
|
538
|
-
from xarray import DataTree
|
|
539
|
-
except ImportError as err:
|
|
540
|
-
raise ImportError(
|
|
541
|
-
"xarray must be have DataTree in order to use InferenceData.to_datatree. "
|
|
542
|
-
"Update to xarray>=2024.11.0"
|
|
543
|
-
) from err
|
|
544
|
-
dt = DataTree.from_dict({group: ds for group, ds in self.items()})
|
|
545
|
-
dt.attrs = self.attrs
|
|
546
|
-
return dt
|
|
547
|
-
|
|
548
|
-
@staticmethod
|
|
549
|
-
def from_datatree(datatree):
|
|
550
|
-
"""Create an InferenceData object from a :class:`~xarray.DataTree`.
|
|
551
|
-
|
|
552
|
-
Parameters
|
|
553
|
-
----------
|
|
554
|
-
datatree : DataTree
|
|
555
|
-
"""
|
|
556
|
-
return InferenceData(
|
|
557
|
-
attrs=datatree.attrs,
|
|
558
|
-
**{group: child.to_dataset() for group, child in datatree.children.items()},
|
|
559
|
-
)
|
|
560
|
-
|
|
561
|
-
def to_dict(self, groups=None, filter_groups=None):
|
|
562
|
-
"""Convert InferenceData to a dictionary following xarray naming conventions.
|
|
563
|
-
|
|
564
|
-
Parameters
|
|
565
|
-
----------
|
|
566
|
-
groups : list, optional
|
|
567
|
-
Groups where the transformation is to be applied. Can either be group names
|
|
568
|
-
or metagroup names.
|
|
569
|
-
filter_groups: {None, "like", "regex"}, optional, default=None
|
|
570
|
-
If `None` (default), interpret groups as the real group or metagroup names.
|
|
571
|
-
If "like", interpret groups as substrings of the real group or metagroup names.
|
|
572
|
-
If "regex", interpret groups as regular expressions on the real group or
|
|
573
|
-
metagroup names. A la `pandas.filter`.
|
|
574
|
-
|
|
575
|
-
Returns
|
|
576
|
-
-------
|
|
577
|
-
dict
|
|
578
|
-
A dictionary containing all groups of InferenceData object.
|
|
579
|
-
When `data=False` return just the schema.
|
|
580
|
-
"""
|
|
581
|
-
ret = defaultdict(dict)
|
|
582
|
-
if self._groups_all: # check's whether a group is present or not.
|
|
583
|
-
if groups is None:
|
|
584
|
-
groups = self._group_names(groups, filter_groups)
|
|
585
|
-
else:
|
|
586
|
-
groups = [group for group in self._groups_all if group in groups]
|
|
587
|
-
|
|
588
|
-
for group in groups:
|
|
589
|
-
dataset = getattr(self, group)
|
|
590
|
-
data = {}
|
|
591
|
-
for var_name, dataarray in dataset.items():
|
|
592
|
-
data[var_name] = dataarray.values
|
|
593
|
-
dims = []
|
|
594
|
-
for coord_name, coord_values in dataarray.coords.items():
|
|
595
|
-
if coord_name not in ("chain", "draw") and not coord_name.startswith(
|
|
596
|
-
f"{var_name}_dim_"
|
|
597
|
-
):
|
|
598
|
-
dims.append(coord_name)
|
|
599
|
-
ret["coords"][coord_name] = coord_values.values
|
|
600
|
-
|
|
601
|
-
if group in (
|
|
602
|
-
"predictions",
|
|
603
|
-
"predictions_constant_data",
|
|
604
|
-
):
|
|
605
|
-
dims_key = "pred_dims"
|
|
606
|
-
else:
|
|
607
|
-
dims_key = "dims"
|
|
608
|
-
if len(dims) > 0:
|
|
609
|
-
ret[dims_key][var_name] = dims
|
|
610
|
-
ret[group] = data
|
|
611
|
-
ret[f"{group}_attrs"] = dataset.attrs
|
|
612
|
-
|
|
613
|
-
ret["attrs"] = self.attrs
|
|
614
|
-
return ret
|
|
615
|
-
|
|
616
|
-
def to_json(self, filename, groups=None, filter_groups=None, **kwargs):
|
|
617
|
-
"""Write InferenceData to a json file.
|
|
618
|
-
|
|
619
|
-
Parameters
|
|
620
|
-
----------
|
|
621
|
-
filename : str
|
|
622
|
-
Location to write to
|
|
623
|
-
groups : list, optional
|
|
624
|
-
Groups where the transformation is to be applied. Can either be group names
|
|
625
|
-
or metagroup names.
|
|
626
|
-
filter_groups: {None, "like", "regex"}, optional, default=None
|
|
627
|
-
If `None` (default), interpret groups as the real group or metagroup names.
|
|
628
|
-
If "like", interpret groups as substrings of the real group or metagroup names.
|
|
629
|
-
If "regex", interpret groups as regular expressions on the real group or
|
|
630
|
-
metagroup names. A la `pandas.filter`.
|
|
631
|
-
kwargs : dict
|
|
632
|
-
kwargs passed to json.dump()
|
|
633
|
-
|
|
634
|
-
Returns
|
|
635
|
-
-------
|
|
636
|
-
str
|
|
637
|
-
Location of json file
|
|
638
|
-
"""
|
|
639
|
-
idata_dict = _make_json_serializable(
|
|
640
|
-
self.to_dict(groups=groups, filter_groups=filter_groups)
|
|
641
|
-
)
|
|
642
|
-
|
|
643
|
-
with open(filename, "w", encoding="utf8") as file:
|
|
644
|
-
json.dump(idata_dict, file, **kwargs)
|
|
645
|
-
|
|
646
|
-
return filename
|
|
647
|
-
|
|
648
|
-
def to_dataframe(
|
|
649
|
-
self,
|
|
650
|
-
groups=None,
|
|
651
|
-
filter_groups=None,
|
|
652
|
-
var_names=None,
|
|
653
|
-
filter_vars=None,
|
|
654
|
-
include_coords=True,
|
|
655
|
-
include_index=True,
|
|
656
|
-
index_origin=None,
|
|
657
|
-
):
|
|
658
|
-
"""Convert InferenceData to a :class:`pandas.DataFrame` following xarray naming conventions.
|
|
659
|
-
|
|
660
|
-
This returns dataframe in a "wide" -format, where each item in ndimensional array is
|
|
661
|
-
unpacked. To access "tidy" -format, use xarray functionality found for each dataset.
|
|
662
|
-
|
|
663
|
-
In case of a multiple groups, function adds a group identification to the var name.
|
|
664
|
-
|
|
665
|
-
Data groups ("observed_data", "constant_data", "predictions_constant_data") are
|
|
666
|
-
skipped implicitly.
|
|
667
|
-
|
|
668
|
-
Raises TypeError if no valid groups are found.
|
|
669
|
-
Raises ValueError if no data are selected.
|
|
670
|
-
|
|
671
|
-
Parameters
|
|
672
|
-
----------
|
|
673
|
-
groups: str or list of str, optional
|
|
674
|
-
Groups where the transformation is to be applied. Can either be group names
|
|
675
|
-
or metagroup names.
|
|
676
|
-
filter_groups: {None, "like", "regex"}, optional, default=None
|
|
677
|
-
If `None` (default), interpret groups as the real group or metagroup names.
|
|
678
|
-
If "like", interpret groups as substrings of the real group or metagroup names.
|
|
679
|
-
If "regex", interpret groups as regular expressions on the real group or
|
|
680
|
-
metagroup names. A la `pandas.filter`.
|
|
681
|
-
var_names : str or list of str, optional
|
|
682
|
-
Variables to be extracted. Prefix the variables by `~` when you want to exclude them.
|
|
683
|
-
filter_vars: {None, "like", "regex"}, optional
|
|
684
|
-
If `None` (default), interpret var_names as the real variables names. If "like",
|
|
685
|
-
interpret var_names as substrings of the real variables names. If "regex",
|
|
686
|
-
interpret var_names as regular expressions on the real variables names. A la
|
|
687
|
-
`pandas.filter`.
|
|
688
|
-
Like with plotting, sometimes it's easier to subset saying what to exclude
|
|
689
|
-
instead of what to include
|
|
690
|
-
include_coords: bool
|
|
691
|
-
Add coordinate values to column name (tuple).
|
|
692
|
-
include_index: bool
|
|
693
|
-
Add index information for multidimensional arrays.
|
|
694
|
-
index_origin: {0, 1}, optional
|
|
695
|
-
Starting index for multidimensional objects. 0- or 1-based.
|
|
696
|
-
Defaults to rcParams["data.index_origin"].
|
|
697
|
-
|
|
698
|
-
Returns
|
|
699
|
-
-------
|
|
700
|
-
pandas.DataFrame
|
|
701
|
-
A pandas DataFrame containing all selected groups of InferenceData object.
|
|
702
|
-
"""
|
|
703
|
-
# pylint: disable=too-many-nested-blocks
|
|
704
|
-
if not include_coords and not include_index:
|
|
705
|
-
raise TypeError("Both include_coords and include_index can not be False.")
|
|
706
|
-
if index_origin is None:
|
|
707
|
-
index_origin = rcParams["data.index_origin"]
|
|
708
|
-
if index_origin not in [0, 1]:
|
|
709
|
-
raise TypeError(f"index_origin must be 0 or 1, saw {index_origin}")
|
|
710
|
-
|
|
711
|
-
group_names = list(
|
|
712
|
-
filter(lambda x: "data" not in x, self._group_names(groups, filter_groups))
|
|
713
|
-
)
|
|
714
|
-
|
|
715
|
-
if not group_names:
|
|
716
|
-
raise TypeError(f"No valid groups found: {groups}")
|
|
717
|
-
|
|
718
|
-
dfs = {}
|
|
719
|
-
for group in group_names:
|
|
720
|
-
dataset = self[group]
|
|
721
|
-
group_var_names = _var_names(var_names, dataset, filter_vars, "ignore")
|
|
722
|
-
if (group_var_names is not None) and not group_var_names:
|
|
723
|
-
continue
|
|
724
|
-
if group_var_names is not None:
|
|
725
|
-
dataset = dataset[[var_name for var_name in group_var_names if var_name in dataset]]
|
|
726
|
-
df = None
|
|
727
|
-
coords_to_idx = {
|
|
728
|
-
name: dict(map(reversed, enumerate(dataset.coords[name].values, index_origin)))
|
|
729
|
-
for name in list(filter(lambda x: x not in ("chain", "draw"), dataset.coords))
|
|
730
|
-
}
|
|
731
|
-
for data_array in dataset.values():
|
|
732
|
-
dataframe = data_array.to_dataframe()
|
|
733
|
-
if list(filter(lambda x: x not in ("chain", "draw"), data_array.dims)):
|
|
734
|
-
levels = [
|
|
735
|
-
idx
|
|
736
|
-
for idx, dim in enumerate(data_array.dims)
|
|
737
|
-
if dim not in ("chain", "draw")
|
|
738
|
-
]
|
|
739
|
-
dataframe = dataframe.unstack(level=levels)
|
|
740
|
-
tuple_columns = []
|
|
741
|
-
for name, *coords in dataframe.columns:
|
|
742
|
-
if include_index:
|
|
743
|
-
idxs = []
|
|
744
|
-
for coordname, coorditem in zip(dataframe.columns.names[1:], coords):
|
|
745
|
-
idxs.append(coords_to_idx[coordname][coorditem])
|
|
746
|
-
if include_coords:
|
|
747
|
-
tuple_columns.append(
|
|
748
|
-
(f"{name}[{','.join(map(str, idxs))}]", *coords)
|
|
749
|
-
)
|
|
750
|
-
else:
|
|
751
|
-
tuple_columns.append(f"{name}[{','.join(map(str, idxs))}]")
|
|
752
|
-
else:
|
|
753
|
-
tuple_columns.append((name, *coords))
|
|
754
|
-
|
|
755
|
-
dataframe.columns = tuple_columns
|
|
756
|
-
dataframe.sort_index(axis=1, inplace=True)
|
|
757
|
-
if df is None:
|
|
758
|
-
df = dataframe
|
|
759
|
-
continue
|
|
760
|
-
df = df.join(dataframe, how="outer")
|
|
761
|
-
if df is not None:
|
|
762
|
-
df = df.reset_index()
|
|
763
|
-
dfs[group] = df
|
|
764
|
-
if not dfs:
|
|
765
|
-
raise ValueError("No data selected for the dataframe.")
|
|
766
|
-
if len(dfs) > 1:
|
|
767
|
-
for group, df in dfs.items():
|
|
768
|
-
df.columns = [
|
|
769
|
-
(
|
|
770
|
-
col
|
|
771
|
-
if col in ("draw", "chain")
|
|
772
|
-
else (group, *col) if isinstance(col, tuple) else (group, col)
|
|
773
|
-
)
|
|
774
|
-
for col in df.columns
|
|
775
|
-
]
|
|
776
|
-
dfs, *dfs_tail = list(dfs.values())
|
|
777
|
-
for df in dfs_tail:
|
|
778
|
-
dfs = dfs.merge(df, how="outer", copy=False)
|
|
779
|
-
else:
|
|
780
|
-
(dfs,) = dfs.values() # pylint: disable=unbalanced-dict-unpacking
|
|
781
|
-
return dfs
|
|
782
|
-
|
|
783
|
-
def to_zarr(self, store=None):
|
|
784
|
-
"""Convert InferenceData to a :class:`zarr.hierarchy.Group`.
|
|
785
|
-
|
|
786
|
-
The zarr storage is using the same group names as the InferenceData.
|
|
787
|
-
|
|
788
|
-
Raises
|
|
789
|
-
------
|
|
790
|
-
TypeError
|
|
791
|
-
If no valid store is found.
|
|
792
|
-
|
|
793
|
-
Parameters
|
|
794
|
-
----------
|
|
795
|
-
store: zarr.storage i.e MutableMapping or str, optional
|
|
796
|
-
Zarr storage class or path to desired DirectoryStore.
|
|
797
|
-
|
|
798
|
-
Returns
|
|
799
|
-
-------
|
|
800
|
-
zarr.hierarchy.group
|
|
801
|
-
A zarr hierarchy group containing the InferenceData.
|
|
802
|
-
|
|
803
|
-
References
|
|
804
|
-
----------
|
|
805
|
-
https://zarr.readthedocs.io/
|
|
806
|
-
"""
|
|
807
|
-
try:
|
|
808
|
-
import zarr
|
|
809
|
-
except ImportError as err:
|
|
810
|
-
raise ImportError("'to_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
|
|
811
|
-
if version.parse(zarr.__version__) < version.parse("2.5.0"):
|
|
812
|
-
raise ImportError(
|
|
813
|
-
"Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'to_zarr'"
|
|
814
|
-
)
|
|
815
|
-
if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
|
|
816
|
-
raise ImportError(
|
|
817
|
-
"Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
|
|
818
|
-
"'dt = InferenceData.to_datatree' followed by 'dt.to_zarr()' "
|
|
819
|
-
"(needs xarray>=2024.11.0)"
|
|
820
|
-
)
|
|
821
|
-
|
|
822
|
-
# Check store type and create store if necessary
|
|
823
|
-
if store is None:
|
|
824
|
-
store = zarr.storage.TempStore(suffix="arviz")
|
|
825
|
-
elif isinstance(store, str):
|
|
826
|
-
store = zarr.storage.DirectoryStore(path=store)
|
|
827
|
-
elif not isinstance(store, MutableMapping):
|
|
828
|
-
raise TypeError(f"No valid store found: {store}")
|
|
829
|
-
|
|
830
|
-
groups = self.groups()
|
|
831
|
-
|
|
832
|
-
if not groups:
|
|
833
|
-
raise TypeError("No valid groups found!")
|
|
834
|
-
|
|
835
|
-
# order matters here, saving attrs after the groups will erase the groups.
|
|
836
|
-
if self.attrs:
|
|
837
|
-
xr.Dataset(attrs=self.attrs).to_zarr(store=store, mode="w")
|
|
838
|
-
|
|
839
|
-
for group in groups:
|
|
840
|
-
# Create zarr group in store with same group name
|
|
841
|
-
getattr(self, group).to_zarr(store=store, group=group, mode="w")
|
|
842
|
-
|
|
843
|
-
return zarr.open(store) # Open store to get overarching group
|
|
844
|
-
|
|
845
|
-
@staticmethod
|
|
846
|
-
def from_zarr(store) -> "InferenceData":
|
|
847
|
-
"""Initialize object from a zarr store or path.
|
|
848
|
-
|
|
849
|
-
Expects that the zarr store will have groups, each of which can be loaded by xarray.
|
|
850
|
-
By default, the datasets of the InferenceData object will be lazily loaded instead
|
|
851
|
-
of being loaded into memory. This
|
|
852
|
-
behaviour is regulated by the value of ``az.rcParams["data.load"]``.
|
|
853
|
-
|
|
854
|
-
Parameters
|
|
855
|
-
----------
|
|
856
|
-
store: MutableMapping or zarr.hierarchy.Group or str.
|
|
857
|
-
Zarr storage class or path to desired Store.
|
|
858
|
-
|
|
859
|
-
Returns
|
|
860
|
-
-------
|
|
861
|
-
InferenceData object
|
|
862
|
-
|
|
863
|
-
References
|
|
864
|
-
----------
|
|
865
|
-
https://zarr.readthedocs.io/
|
|
866
|
-
"""
|
|
867
|
-
try:
|
|
868
|
-
import zarr
|
|
869
|
-
except ImportError as err:
|
|
870
|
-
raise ImportError("'from_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
|
|
871
|
-
if version.parse(zarr.__version__) < version.parse("2.5.0"):
|
|
872
|
-
raise ImportError(
|
|
873
|
-
"Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'from_zarr'"
|
|
874
|
-
)
|
|
875
|
-
if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
|
|
876
|
-
raise ImportError(
|
|
877
|
-
"Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
|
|
878
|
-
"'xarray.open_datatree' followed by 'arviz.InferenceData.from_datatree' "
|
|
879
|
-
"(needs xarray>=2024.11.0)"
|
|
880
|
-
)
|
|
881
|
-
|
|
882
|
-
# Check store type and create store if necessary
|
|
883
|
-
if isinstance(store, str):
|
|
884
|
-
store = zarr.storage.DirectoryStore(path=store)
|
|
885
|
-
elif isinstance(store, zarr.hierarchy.Group):
|
|
886
|
-
store = store.store
|
|
887
|
-
elif not isinstance(store, MutableMapping):
|
|
888
|
-
raise TypeError(f"No valid store found: {store}")
|
|
889
|
-
|
|
890
|
-
groups = {}
|
|
891
|
-
zarr_handle = zarr.open(store, mode="r")
|
|
892
|
-
|
|
893
|
-
# Open each group via xarray method
|
|
894
|
-
for key_group, _ in zarr_handle.groups():
|
|
895
|
-
with xr.open_zarr(store=store, group=key_group) as data:
|
|
896
|
-
groups[key_group] = data.load() if rcParams["data.load"] == "eager" else data
|
|
897
|
-
|
|
898
|
-
with xr.open_zarr(store=store) as root:
|
|
899
|
-
attrs = root.attrs
|
|
900
|
-
|
|
901
|
-
return InferenceData(attrs=attrs, **groups)
|
|
902
|
-
|
|
903
|
-
def __add__(self, other: "InferenceData") -> "InferenceData":
|
|
904
|
-
"""Concatenate two InferenceData objects."""
|
|
905
|
-
return concat(self, other, copy=True, inplace=False)
|
|
906
|
-
|
|
907
|
-
def sel(
|
|
908
|
-
self: InferenceDataT,
|
|
909
|
-
groups: Optional[Union[str, List[str]]] = None,
|
|
910
|
-
filter_groups: Optional["Literal['like', 'regex']"] = None,
|
|
911
|
-
inplace: bool = False,
|
|
912
|
-
chain_prior: Optional[bool] = None,
|
|
913
|
-
**kwargs: Any,
|
|
914
|
-
) -> Optional[InferenceDataT]:
|
|
915
|
-
"""Perform an xarray selection on all groups.
|
|
916
|
-
|
|
917
|
-
Loops groups to perform Dataset.sel(key=item)
|
|
918
|
-
for every kwarg if key is a dimension of the dataset.
|
|
919
|
-
One example could be performing a burn in cut on the InferenceData object
|
|
920
|
-
or discarding a chain. The selection is performed on all relevant groups (like
|
|
921
|
-
posterior, prior, sample stats) while non relevant groups like observed data are
|
|
922
|
-
omitted. See :meth:`xarray.Dataset.sel <xarray:xarray.Dataset.sel>`
|
|
923
|
-
|
|
924
|
-
Parameters
|
|
925
|
-
----------
|
|
926
|
-
groups : str or list of str, optional
|
|
927
|
-
Groups where the selection is to be applied. Can either be group names
|
|
928
|
-
or metagroup names.
|
|
929
|
-
filter_groups : {None, "like", "regex"}, optional, default=None
|
|
930
|
-
If `None` (default), interpret groups as the real group or metagroup names.
|
|
931
|
-
If "like", interpret groups as substrings of the real group or metagroup names.
|
|
932
|
-
If "regex", interpret groups as regular expressions on the real group or
|
|
933
|
-
metagroup names. A la `pandas.filter`.
|
|
934
|
-
inplace : bool, optional
|
|
935
|
-
If ``True``, modify the InferenceData object inplace,
|
|
936
|
-
otherwise, return the modified copy.
|
|
937
|
-
chain_prior : bool, optional, deprecated
|
|
938
|
-
If ``False``, do not select prior related groups using ``chain`` dim.
|
|
939
|
-
Otherwise, use selection on ``chain`` if present. Default=False
|
|
940
|
-
kwargs : dict, optional
|
|
941
|
-
It must be accepted by Dataset.sel().
|
|
942
|
-
|
|
943
|
-
Returns
|
|
944
|
-
-------
|
|
945
|
-
InferenceData
|
|
946
|
-
A new InferenceData object by default.
|
|
947
|
-
When `inplace==True` perform selection in-place and return `None`
|
|
948
|
-
|
|
949
|
-
Examples
|
|
950
|
-
--------
|
|
951
|
-
Use ``sel`` to discard one chain of the InferenceData object. We first check the
|
|
952
|
-
dimensions of the original object:
|
|
953
|
-
|
|
954
|
-
.. jupyter-execute::
|
|
955
|
-
|
|
956
|
-
import arviz as az
|
|
957
|
-
idata = az.load_arviz_data("centered_eight")
|
|
958
|
-
idata
|
|
959
|
-
|
|
960
|
-
In order to remove the third chain:
|
|
961
|
-
|
|
962
|
-
.. jupyter-execute::
|
|
963
|
-
|
|
964
|
-
idata_subset = idata.sel(chain=[0, 1, 3], groups="posterior_groups")
|
|
965
|
-
idata_subset
|
|
966
|
-
|
|
967
|
-
See Also
|
|
968
|
-
--------
|
|
969
|
-
xarray.Dataset.sel :
|
|
970
|
-
Returns a new dataset with each array indexed by tick labels along the specified
|
|
971
|
-
dimension(s).
|
|
972
|
-
isel : Returns a new dataset with each array indexed along the specified dimension(s).
|
|
973
|
-
"""
|
|
974
|
-
if chain_prior is not None:
|
|
975
|
-
warnings.warn(
|
|
976
|
-
"chain_prior has been deprecated. Use groups argument and "
|
|
977
|
-
"rcParams['data.metagroups'] instead.",
|
|
978
|
-
DeprecationWarning,
|
|
979
|
-
)
|
|
980
|
-
else:
|
|
981
|
-
chain_prior = False
|
|
982
|
-
group_names = self._group_names(groups, filter_groups)
|
|
983
|
-
|
|
984
|
-
out = self if inplace else deepcopy(self)
|
|
985
|
-
for group in group_names:
|
|
986
|
-
dataset = getattr(self, group)
|
|
987
|
-
valid_keys = set(kwargs.keys()).intersection(dataset.dims)
|
|
988
|
-
if not chain_prior and "prior" in group:
|
|
989
|
-
valid_keys -= {"chain"}
|
|
990
|
-
dataset = dataset.sel(**{key: kwargs[key] for key in valid_keys})
|
|
991
|
-
setattr(out, group, dataset)
|
|
992
|
-
if inplace:
|
|
993
|
-
return None
|
|
994
|
-
else:
|
|
995
|
-
return out
|
|
996
|
-
|
|
997
|
-
def isel(
|
|
998
|
-
self: InferenceDataT,
|
|
999
|
-
groups: Optional[Union[str, List[str]]] = None,
|
|
1000
|
-
filter_groups: Optional["Literal['like', 'regex']"] = None,
|
|
1001
|
-
inplace: bool = False,
|
|
1002
|
-
**kwargs: Any,
|
|
1003
|
-
) -> Optional[InferenceDataT]:
|
|
1004
|
-
"""Perform an xarray selection on all groups.
|
|
1005
|
-
|
|
1006
|
-
Loops groups to perform Dataset.isel(key=item)
|
|
1007
|
-
for every kwarg if key is a dimension of the dataset.
|
|
1008
|
-
One example could be performing a burn in cut on the InferenceData object
|
|
1009
|
-
or discarding a chain. The selection is performed on all relevant groups (like
|
|
1010
|
-
posterior, prior, sample stats) while non relevant groups like observed data are
|
|
1011
|
-
omitted. See :meth:`xarray:xarray.Dataset.isel`
|
|
1012
|
-
|
|
1013
|
-
Parameters
|
|
1014
|
-
----------
|
|
1015
|
-
groups : str or list of str, optional
|
|
1016
|
-
Groups where the selection is to be applied. Can either be group names
|
|
1017
|
-
or metagroup names.
|
|
1018
|
-
filter_groups : {None, "like", "regex"}, optional
|
|
1019
|
-
If `None` (default), interpret groups as the real group or metagroup names.
|
|
1020
|
-
If "like", interpret groups as substrings of the real group or metagroup names.
|
|
1021
|
-
If "regex", interpret groups as regular expressions on the real group or
|
|
1022
|
-
metagroup names. A la `pandas.filter`.
|
|
1023
|
-
inplace : bool, optional
|
|
1024
|
-
If ``True``, modify the InferenceData object inplace,
|
|
1025
|
-
otherwise, return the modified copy.
|
|
1026
|
-
kwargs : dict, optional
|
|
1027
|
-
It must be accepted by :meth:`xarray:xarray.Dataset.isel`.
|
|
1028
|
-
|
|
1029
|
-
Returns
|
|
1030
|
-
-------
|
|
1031
|
-
InferenceData
|
|
1032
|
-
A new InferenceData object by default.
|
|
1033
|
-
When `inplace==True` perform selection in-place and return `None`
|
|
1034
|
-
|
|
1035
|
-
Examples
|
|
1036
|
-
--------
|
|
1037
|
-
Use ``isel`` to discard one chain of the InferenceData object. We first check the
|
|
1038
|
-
dimensions of the original object:
|
|
1039
|
-
|
|
1040
|
-
.. jupyter-execute::
|
|
1041
|
-
|
|
1042
|
-
import arviz as az
|
|
1043
|
-
idata = az.load_arviz_data("centered_eight")
|
|
1044
|
-
idata
|
|
1045
|
-
|
|
1046
|
-
In order to remove the third chain:
|
|
1047
|
-
|
|
1048
|
-
.. jupyter-execute::
|
|
1049
|
-
|
|
1050
|
-
idata_subset = idata.isel(chain=[0, 1, 3], groups="posterior_groups")
|
|
1051
|
-
idata_subset
|
|
1052
|
-
|
|
1053
|
-
You can expand the groups and coords in each group to see how now only the chains 0, 1 and
|
|
1054
|
-
3 are present.
|
|
1055
|
-
|
|
1056
|
-
See Also
|
|
1057
|
-
--------
|
|
1058
|
-
xarray.Dataset.isel :
|
|
1059
|
-
Returns a new dataset with each array indexed along the specified dimension(s).
|
|
1060
|
-
sel :
|
|
1061
|
-
Returns a new dataset with each array indexed by tick labels along the specified
|
|
1062
|
-
dimension(s).
|
|
1063
|
-
"""
|
|
1064
|
-
group_names = self._group_names(groups, filter_groups)
|
|
1065
|
-
|
|
1066
|
-
out = self if inplace else deepcopy(self)
|
|
1067
|
-
for group in group_names:
|
|
1068
|
-
dataset = getattr(self, group)
|
|
1069
|
-
valid_keys = set(kwargs.keys()).intersection(dataset.dims)
|
|
1070
|
-
dataset = dataset.isel(**{key: kwargs[key] for key in valid_keys})
|
|
1071
|
-
setattr(out, group, dataset)
|
|
1072
|
-
if inplace:
|
|
1073
|
-
return None
|
|
1074
|
-
else:
|
|
1075
|
-
return out
|
|
1076
|
-
|
|
1077
|
-
def stack(
|
|
1078
|
-
self,
|
|
1079
|
-
dimensions=None,
|
|
1080
|
-
groups=None,
|
|
1081
|
-
filter_groups=None,
|
|
1082
|
-
inplace=False,
|
|
1083
|
-
**kwargs,
|
|
1084
|
-
):
|
|
1085
|
-
"""Perform an xarray stacking on all groups.
|
|
1086
|
-
|
|
1087
|
-
Stack any number of existing dimensions into a single new dimension.
|
|
1088
|
-
Loops groups to perform Dataset.stack(key=value)
|
|
1089
|
-
for every kwarg if value is a dimension of the dataset.
|
|
1090
|
-
The selection is performed on all relevant groups (like
|
|
1091
|
-
posterior, prior, sample stats) while non relevant groups like observed data are
|
|
1092
|
-
omitted. See :meth:`xarray:xarray.Dataset.stack`
|
|
1093
|
-
|
|
1094
|
-
Parameters
|
|
1095
|
-
----------
|
|
1096
|
-
dimensions : dict, optional
|
|
1097
|
-
Names of new dimensions, and the existing dimensions that they replace.
|
|
1098
|
-
groups: str or list of str, optional
|
|
1099
|
-
Groups where the selection is to be applied. Can either be group names
|
|
1100
|
-
or metagroup names.
|
|
1101
|
-
filter_groups : {None, "like", "regex"}, optional
|
|
1102
|
-
If `None` (default), interpret groups as the real group or metagroup names.
|
|
1103
|
-
If "like", interpret groups as substrings of the real group or metagroup names.
|
|
1104
|
-
If "regex", interpret groups as regular expressions on the real group or
|
|
1105
|
-
metagroup names. A la `pandas.filter`.
|
|
1106
|
-
inplace : bool, optional
|
|
1107
|
-
If ``True``, modify the InferenceData object inplace,
|
|
1108
|
-
otherwise, return the modified copy.
|
|
1109
|
-
kwargs : dict, optional
|
|
1110
|
-
It must be accepted by :meth:`xarray:xarray.Dataset.stack`.
|
|
1111
|
-
|
|
1112
|
-
Returns
|
|
1113
|
-
-------
|
|
1114
|
-
InferenceData
|
|
1115
|
-
A new InferenceData object by default.
|
|
1116
|
-
When `inplace==True` perform selection in-place and return `None`
|
|
1117
|
-
|
|
1118
|
-
Examples
|
|
1119
|
-
--------
|
|
1120
|
-
Use ``stack`` to stack any number of existing dimensions into a single new dimension.
|
|
1121
|
-
We first check the original object:
|
|
1122
|
-
|
|
1123
|
-
.. jupyter-execute::
|
|
1124
|
-
|
|
1125
|
-
import arviz as az
|
|
1126
|
-
idata = az.load_arviz_data("rugby")
|
|
1127
|
-
idata
|
|
1128
|
-
|
|
1129
|
-
In order to stack two dimensions ``chain`` and ``draw`` to ``sample``, we can use:
|
|
1130
|
-
|
|
1131
|
-
.. jupyter-execute::
|
|
1132
|
-
|
|
1133
|
-
idata.stack(sample=["chain", "draw"], inplace=True)
|
|
1134
|
-
idata
|
|
1135
|
-
|
|
1136
|
-
We can also take the example of custom InferenceData object and perform stacking. We first
|
|
1137
|
-
check the original object:
|
|
1138
|
-
|
|
1139
|
-
.. jupyter-execute::
|
|
1140
|
-
|
|
1141
|
-
import numpy as np
|
|
1142
|
-
datadict = {
|
|
1143
|
-
"a": np.random.randn(100),
|
|
1144
|
-
"b": np.random.randn(1, 100, 10),
|
|
1145
|
-
"c": np.random.randn(1, 100, 3, 4),
|
|
1146
|
-
}
|
|
1147
|
-
coords = {
|
|
1148
|
-
"c1": np.arange(3),
|
|
1149
|
-
"c99": np.arange(4),
|
|
1150
|
-
"b1": np.arange(10),
|
|
1151
|
-
}
|
|
1152
|
-
dims = {"c": ["c1", "c99"], "b": ["b1"]}
|
|
1153
|
-
idata = az.from_dict(
|
|
1154
|
-
posterior=datadict, posterior_predictive=datadict, coords=coords, dims=dims
|
|
1155
|
-
)
|
|
1156
|
-
idata
|
|
1157
|
-
|
|
1158
|
-
In order to stack two dimensions ``c1`` and ``c99`` to ``z``, we can use:
|
|
1159
|
-
|
|
1160
|
-
.. jupyter-execute::
|
|
1161
|
-
|
|
1162
|
-
idata.stack(z=["c1", "c99"], inplace=True)
|
|
1163
|
-
idata
|
|
1164
|
-
|
|
1165
|
-
See Also
|
|
1166
|
-
--------
|
|
1167
|
-
xarray.Dataset.stack : Stack any number of existing dimensions into a single new dimension.
|
|
1168
|
-
unstack : Perform an xarray unstacking on all groups of InferenceData object.
|
|
1169
|
-
"""
|
|
1170
|
-
groups = self._group_names(groups, filter_groups)
|
|
1171
|
-
|
|
1172
|
-
dimensions = {} if dimensions is None else dimensions
|
|
1173
|
-
dimensions.update(kwargs)
|
|
1174
|
-
out = self if inplace else deepcopy(self)
|
|
1175
|
-
for group in groups:
|
|
1176
|
-
dataset = getattr(self, group)
|
|
1177
|
-
kwarg_dict = {}
|
|
1178
|
-
for key, value in dimensions.items():
|
|
1179
|
-
try:
|
|
1180
|
-
if not set(value).difference(dataset.dims):
|
|
1181
|
-
kwarg_dict[key] = value
|
|
1182
|
-
except TypeError:
|
|
1183
|
-
kwarg_dict[key] = value
|
|
1184
|
-
dataset = dataset.stack(**kwarg_dict)
|
|
1185
|
-
setattr(out, group, dataset)
|
|
1186
|
-
if inplace:
|
|
1187
|
-
return None
|
|
1188
|
-
else:
|
|
1189
|
-
return out
|
|
1190
|
-
|
|
1191
|
-
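
    # A minimal usage sketch of the stacking behaviour above; it assumes the
    # "centered_eight" example dataset bundled with the 0.23.x releases:
    #
    #     import arviz as az
    #
    #     idata = az.load_arviz_data("centered_eight")
    #     stacked = idata.stack(sample=["chain", "draw"], inplace=False)
    #     assert "sample" in stacked.posterior.dims
    #     # observed_data has neither chain nor draw, so it passes through unchanged
    #     assert "sample" not in stacked.observed_data.dims
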
    def unstack(self, dim=None, groups=None, filter_groups=None, inplace=False):
        """Perform an xarray unstacking on all groups.

        Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions.
        Loops over the groups to perform ``Dataset.unstack(key=value)``.
        The unstacking is performed on all relevant groups (like posterior, prior,
        sample stats) while non-relevant groups like observed data are omitted.
        See :meth:`xarray:xarray.Dataset.unstack`

        Parameters
        ----------
        dim : Hashable or iterable of Hashable, optional
            Dimension(s) over which to unstack. By default unstacks all MultiIndexes.
        groups : str or list of str, optional
            Groups where the unstacking is to be applied. Can either be group names
            or metagroup names.
        filter_groups : {None, "like", "regex"}, optional
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace : bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform unstacking in place and return `None`.

        Examples
        --------
        Use ``unstack`` to unstack existing dimensions corresponding to MultiIndexes into
        multiple new dimensions. We first stack the two dimensions ``c1`` and ``c99`` to ``z``:

        .. jupyter-execute::

            import arviz as az
            import numpy as np
            datadict = {
                "a": np.random.randn(100),
                "b": np.random.randn(1, 100, 10),
                "c": np.random.randn(1, 100, 3, 4),
            }
            coords = {
                "c1": np.arange(3),
                "c99": np.arange(4),
                "b1": np.arange(10),
            }
            dims = {"c": ["c1", "c99"], "b": ["b1"]}
            idata = az.from_dict(
                posterior=datadict, posterior_predictive=datadict, coords=coords, dims=dims
            )
            idata.stack(z=["c1", "c99"], inplace=True)
            idata

        In order to unstack the dimension ``z``, we use:

        .. jupyter-execute::

            idata.unstack(inplace=True)
            idata

        See Also
        --------
        xarray.Dataset.unstack :
            Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions.
        stack : Perform an xarray stacking on all groups of InferenceData object.
        """
        groups = self._group_names(groups, filter_groups)
        if isinstance(dim, str):
            dim = [dim]

        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            # only unstack the requested dimensions that exist in this group
            valid_dims = set(dim).intersection(dataset.dims) if dim is not None else dim
            dataset = dataset.unstack(dim=valid_dims)
            setattr(out, group, dataset)
        if inplace:
            return None
        else:
            return out
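
    # A round-trip sketch under the same assumptions: stacking into a MultiIndex
    # and then unstacking restores the original sampling dimensions:
    #
    #     import arviz as az
    #
    #     idata = az.load_arviz_data("centered_eight")
    #     idata.stack(sample=["chain", "draw"], inplace=True)
    #     idata.unstack(inplace=True)  # unstacks all MultiIndexes by default
    #     assert {"chain", "draw"} <= set(idata.posterior.dims)
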
    def rename(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
        """Perform xarray renaming of variables and dimensions on all groups.

        Loops over the groups to perform ``Dataset.rename(name_dict)``
        for every key in name_dict that is a dimension or data variable of the dataset.
        The renaming is performed on all relevant groups (like
        posterior, prior, sample stats) while non-relevant groups like observed data are
        omitted. See :meth:`xarray:xarray.Dataset.rename`

        Parameters
        ----------
        name_dict : dict
            Dictionary whose keys are current variable or dimension names
            and whose values are the desired names.
        groups : str or list of str, optional
            Groups where the renaming is to be applied. Can either be group names
            or metagroup names.
        filter_groups : {None, "like", "regex"}, optional
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace : bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform renaming in-place and return `None`.

        Examples
        --------
        Use ``rename`` to rename variables and dimensions on all groups of the InferenceData
        object. We first check the original object:

        .. jupyter-execute::

            import arviz as az
            idata = az.load_arviz_data("rugby")
            idata

        In order to rename the dimensions and variables, we use:

        .. jupyter-execute::

            idata.rename({"team": "team_new", "match": "match_new"}, inplace=True)
            idata

        See Also
        --------
        xarray.Dataset.rename : Returns a new object with renamed variables and dimensions.
        rename_vars :
            Perform xarray renaming of variable or coordinate names on all groups of an
            InferenceData object.
        rename_dims : Perform xarray renaming of dimensions on all groups of InferenceData object.
        """
        groups = self._group_names(groups, filter_groups)
        if "chain" in name_dict.keys() or "draw" in name_dict.keys():
            raise KeyError("'chain' or 'draw' dimensions can't be renamed")
        out = self if inplace else deepcopy(self)

        for group in groups:
            dataset = getattr(self, group)
            expected_keys = list(dataset.data_vars) + list(dataset.dims)
            valid_keys = set(name_dict.keys()).intersection(expected_keys)
            dataset = dataset.rename({key: name_dict[key] for key in valid_keys})
            setattr(out, group, dataset)
        if inplace:
            return None
        else:
            return out
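
    # A short sketch of the guard above: ordinary names rename per group while the
    # protected sampling dimensions raise. Variable and dimension names assume the
    # "centered_eight" example dataset:
    #
    #     import arviz as az
    #
    #     idata = az.load_arviz_data("centered_eight")
    #     renamed = idata.rename({"theta": "effect", "school": "unit"})
    #     assert "effect" in renamed.posterior.data_vars
    #     try:
    #         idata.rename({"chain": "walker"})
    #     except KeyError as err:
    #         print(err)  # 'chain' or 'draw' dimensions can't be renamed
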
    def rename_vars(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
        """Perform xarray renaming of variable or coordinate names on all groups.

        Loops over the groups to perform ``Dataset.rename_vars(name_dict)``
        for every key in name_dict that is a variable or coordinate name of the dataset.
        The renaming is performed on all relevant groups (like
        posterior, prior, sample stats) while non-relevant groups like observed data are
        omitted. See :meth:`xarray:xarray.Dataset.rename_vars`

        Parameters
        ----------
        name_dict : dict
            Dictionary whose keys are current variable or coordinate names
            and whose values are the desired names.
        groups : str or list of str, optional
            Groups where the renaming is to be applied. Can either be group names
            or metagroup names.
        filter_groups : {None, "like", "regex"}, optional
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace : bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object with renamed variables including coordinates by default.
            When `inplace==True` perform renaming in-place and return `None`.

        Examples
        --------
        Use ``rename_vars`` to rename variables and coordinates on all groups of the
        InferenceData object. We first check the data variables of the original object:

        .. jupyter-execute::

            import arviz as az
            idata = az.load_arviz_data("rugby")
            idata

        In order to rename the data variables, we use:

        .. jupyter-execute::

            idata.rename_vars({"home": "home_new"}, inplace=True)
            idata

        See Also
        --------
        xarray.Dataset.rename_vars :
            Returns a new object with renamed variables including coordinates.
        rename :
            Perform xarray renaming of variable and dimensions on all groups of an InferenceData
            object.
        rename_dims : Perform xarray renaming of dimensions on all groups of InferenceData object.
        """
        groups = self._group_names(groups, filter_groups)

        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            valid_keys = set(name_dict.keys()).intersection(dataset.data_vars)
            dataset = dataset.rename_vars({key: name_dict[key] for key in valid_keys})
            setattr(out, group, dataset)
        if inplace:
            return None
        else:
            return out
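
    # Note that the loop above intersects name_dict with each group's data
    # variables, so keys that do not match a data variable of a given group are
    # silently skipped. A sketch, again assuming the "centered_eight" dataset:
    #
    #     import arviz as az
    #
    #     idata = az.load_arviz_data("centered_eight")
    #     renamed = idata.rename_vars({"theta": "school_effect", "not_a_var": "x"})
    #     assert "school_effect" in renamed.posterior.data_vars
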
    def rename_dims(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
        """Perform xarray renaming of dimensions on all groups.

        Loops over the groups to perform ``Dataset.rename_dims(name_dict)``
        for every key in name_dict that is a dimension of the dataset.
        The renaming is performed on all relevant groups (like
        posterior, prior, sample stats) while non-relevant groups like observed data are
        omitted. See :meth:`xarray:xarray.Dataset.rename_dims`

        Parameters
        ----------
        name_dict : dict
            Dictionary whose keys are current dimension names and whose values are the desired
            names.
        groups : str or list of str, optional
            Groups where the renaming is to be applied. Can either be group names
            or metagroup names.
        filter_groups : {None, "like", "regex"}, optional
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace : bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object with renamed dimensions by default.
            When `inplace==True` perform renaming in-place and return `None`.

        Examples
        --------
        Use ``rename_dims`` to rename dimensions on all groups of the InferenceData
        object. We first check the dimensions of the original object:

        .. jupyter-execute::

            import arviz as az
            idata = az.load_arviz_data("rugby")
            idata

        In order to rename the dimensions, we use:

        .. jupyter-execute::

            idata.rename_dims({"team": "team_new"}, inplace=True)
            idata

        See Also
        --------
        xarray.Dataset.rename_dims : Returns a new object with renamed dimensions only.
        rename :
            Perform xarray renaming of variable and dimensions on all groups of an InferenceData
            object.
        rename_vars :
            Perform xarray renaming of variable or coordinate names on all groups of an
            InferenceData object.
        """
        groups = self._group_names(groups, filter_groups)
        if "chain" in name_dict.keys() or "draw" in name_dict.keys():
            raise KeyError("'chain' or 'draw' dimensions can't be renamed")

        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            valid_keys = set(name_dict.keys()).intersection(dataset.dims)
            dataset = dataset.rename_dims({key: name_dict[key] for key in valid_keys})
            setattr(out, group, dataset)
        if inplace:
            return None
        else:
            return out
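
    # A sketch of the dimension-only rename (same example dataset assumed); the
    # coordinate variable keeps its old name, which is why ``rename`` exists to
    # change both at once:
    #
    #     import arviz as az
    #
    #     idata = az.load_arviz_data("centered_eight")
    #     renamed = idata.rename_dims({"school": "unit"})
    #     assert "unit" in renamed.posterior.dims
    #     assert "school" in renamed.posterior.coords  # coordinate name unchanged
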
    def add_groups(
        self, group_dict=None, coords=None, dims=None, warn_on_custom_groups=False, **kwargs
    ):
        """Add new groups to InferenceData object.

        Parameters
        ----------
        group_dict : dict of {str : dict or xarray.Dataset}, optional
            Groups to be added.
        coords : dict of {str : array_like}, optional
            Coordinates for the dataset.
        dims : dict of {str : list of str}, optional
            Dimensions of each variable. The keys are variable names, values are lists of
            coordinates.
        warn_on_custom_groups : bool, default False
            Emit a warning when custom groups are present in the InferenceData.
            "custom group" means any group whose name isn't defined in :ref:`schema`
        kwargs : dict, optional
            The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.

        Examples
        --------
        Add a ``log_likelihood`` group to the "rugby" example InferenceData after loading.

        .. jupyter-execute::

            import arviz as az
            idata = az.load_arviz_data("rugby")
            del idata.log_likelihood
            idata2 = idata.copy()
            post = idata.posterior
            obs = idata.observed_data
            idata

        Knowing the model, we could compute the log likelihood manually. In this case however,
        we will generate random samples with the right shape.

        .. jupyter-execute::

            import numpy as np
            rng = np.random.default_rng(73)
            ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
            idata.add_groups(
                log_likelihood={"home_points": ary},
                dims={"home_points": ["match"]},
            )
            idata

        This is fine if we have raw data, but a bit inconvenient if we start with labeled
        data already. Why provide dims and coords manually again?
        Let's generate a fake log likelihood (it doesn't match the model but serves just
        the same for illustration purposes here) working from the posterior and
        observed_data groups manually:

        .. jupyter-execute::

            import xarray as xr
            from xarray_einstats.stats import XrDiscreteRV
            from scipy.stats import poisson
            dist = XrDiscreteRV(poisson, np.exp(post["atts"]))
            log_lik = dist.logpmf(obs["home_points"]).to_dataset(name="home_points")
            idata2.add_groups({"log_likelihood": log_lik})
            idata2

        Note that in the first example we have used the ``kwargs`` argument
        and in the second we have used the ``group_dict`` one.

        See Also
        --------
        extend : Extend InferenceData with groups from another InferenceData.
        concat : Concatenate InferenceData objects.
        """
        group_dict = either_dict_or_kwargs(group_dict, kwargs, "add_groups")
        if not group_dict:
            raise ValueError("One of group_dict or kwargs must be provided.")
        repeated_groups = [group for group in group_dict.keys() if group in self._groups]
        if repeated_groups:
            raise ValueError(f"{repeated_groups} group(s) already exists.")
        for group, dataset in group_dict.items():
            if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
                warnings.warn(
                    f"The group {group} is not defined in the InferenceData scheme",
                    UserWarning,
                )
            if dataset is None:
                continue
            elif isinstance(dataset, dict):
                if (
                    group in ("observed_data", "constant_data", "predictions_constant_data")
                    or group not in SUPPORTED_GROUPS_ALL
                ):
                    warnings.warn(
                        "the default dims 'chain' and 'draw' will be added automatically",
                        UserWarning,
                    )
                dataset = dict_to_dataset(dataset, coords=coords, dims=dims)
            elif isinstance(dataset, xr.DataArray):
                if dataset.name is None:
                    dataset.name = "x"
                dataset = dataset.to_dataset()
            elif not isinstance(dataset, xr.Dataset):
                raise ValueError(
                    "Arguments to add_groups() must be xr.Dataset, xr.DataArray or dicts "
                    f"(argument '{group}' was type '{type(dataset)}')"
                )
            if dataset:
                setattr(self, group, dataset)
                if group.startswith(WARMUP_TAG):
                    # keep the warmup group list in the schema-defined order when possible
                    supported_order = [
                        key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup
                    ]
                    if (supported_order == self._groups_warmup) and (group in SUPPORTED_GROUPS_ALL):
                        group_order = [
                            key
                            for key in SUPPORTED_GROUPS_ALL
                            if key in self._groups_warmup + [group]
                        ]
                        group_idx = group_order.index(group)
                        self._groups_warmup.insert(group_idx, group)
                    else:
                        self._groups_warmup.append(group)
                else:
                    supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups]
                    if (supported_order == self._groups) and (group in SUPPORTED_GROUPS_ALL):
                        group_order = [
                            key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group]
                        ]
                        group_idx = group_order.index(group)
                        self._groups.insert(group_idx, group)
                    else:
                        self._groups.append(group)
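
    # A minimal sketch of the two input forms accepted above: raw dicts get the
    # default chain/draw dimensions via dict_to_dataset, while Datasets and
    # DataArrays are stored as given. The shapes below are illustrative only:
    #
    #     import numpy as np
    #     import arviz as az
    #
    #     idata = az.from_dict(posterior={"mu": np.random.randn(4, 500)})
    #     # kwargs form with a raw dict; (4, 500) is read as (chain, draw)
    #     idata.add_groups(prior={"mu": np.random.randn(4, 500)})
    #     assert "prior" in idata.groups()
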
    def extend(self, other, join="left", warn_on_custom_groups=False):
        """Extend InferenceData with groups from another InferenceData.

        Parameters
        ----------
        other : InferenceData
            InferenceData to be added.
        join : {'left', 'right'}, default 'left'
            Defines which group to keep when the same group is
            present in both objects. 'left' will discard the group in ``other`` whereas 'right'
            will keep the group in ``other`` and discard the one in ``self``.
        warn_on_custom_groups : bool, default False
            Emit a warning when custom groups are present in the InferenceData.
            "custom group" means any group whose name isn't defined in :ref:`schema`

        Examples
        --------
        Take two InferenceData objects, and extend the first with the groups it doesn't have
        but that are present in the second InferenceData object.

        First InferenceData:

        .. jupyter-execute::

            import arviz as az
            idata = az.load_arviz_data("radon")

        Second InferenceData:

        .. jupyter-execute::

            other_idata = az.load_arviz_data("rugby")

        Call the ``extend`` method:

        .. jupyter-execute::

            idata.extend(other_idata)
            idata

        See how the first InferenceData now has more groups, with the data from the
        second one, but the groups it originally had have not been modified,
        even if they are also present in the second InferenceData.

        See Also
        --------
        add_groups : Add new groups to InferenceData object.
        concat : Concatenate InferenceData objects.

        """
        if not isinstance(other, InferenceData):
            raise ValueError("Extending is possible between two InferenceData objects only.")
        if join not in ("left", "right"):
            raise ValueError(f"join must be either 'left' or 'right', found {join}")
        for group in other._groups_all:  # pylint: disable=protected-access
            if hasattr(self, group) and join == "left":
                continue
            if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
                warnings.warn(
                    f"{group} group is not defined in the InferenceData scheme", UserWarning
                )
            dataset = getattr(other, group)
            setattr(self, group, dataset)
            if group.startswith(WARMUP_TAG):
                if group not in self._groups_warmup:
                    supported_order = [
                        key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup
                    ]
                    if (supported_order == self._groups_warmup) and (group in SUPPORTED_GROUPS_ALL):
                        group_order = [
                            key
                            for key in SUPPORTED_GROUPS_ALL
                            if key in self._groups_warmup + [group]
                        ]
                        group_idx = group_order.index(group)
                        self._groups_warmup.insert(group_idx, group)
                    else:
                        self._groups_warmup.append(group)
            elif group not in self._groups:
                supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups]
                if (supported_order == self._groups) and (group in SUPPORTED_GROUPS_ALL):
                    group_order = [
                        key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group]
                    ]
                    group_idx = group_order.index(group)
                    self._groups.insert(group_idx, group)
                else:
                    self._groups.append(group)
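
    # A sketch of the join semantics above, with toy data:
    #
    #     import numpy as np
    #     import arviz as az
    #
    #     idata = az.from_dict(posterior={"mu": np.random.randn(4, 500)})
    #     other = az.from_dict(
    #         posterior={"mu": np.zeros((4, 500))},
    #         prior={"mu": np.random.randn(4, 500)},
    #     )
    #     idata.extend(other)  # join="left" keeps idata's own posterior
    #     assert "prior" in idata.groups()
    #     idata.extend(other, join="right")  # now other's posterior wins
    #     assert float(idata.posterior["mu"].mean()) == 0.0
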
    set_index = _extend_xr_method(xr.Dataset.set_index, see_also="reset_index")
    get_index = _extend_xr_method(xr.Dataset.get_index)
    reset_index = _extend_xr_method(xr.Dataset.reset_index, see_also="set_index")
    set_coords = _extend_xr_method(xr.Dataset.set_coords, see_also="reset_coords")
    reset_coords = _extend_xr_method(xr.Dataset.reset_coords, see_also="set_coords")
    assign = _extend_xr_method(xr.Dataset.assign)
    assign_coords = _extend_xr_method(xr.Dataset.assign_coords)
    sortby = _extend_xr_method(xr.Dataset.sortby)
    chunk = _extend_xr_method(xr.Dataset.chunk)
    unify_chunks = _extend_xr_method(xr.Dataset.unify_chunks)
    load = _extend_xr_method(xr.Dataset.load)
    compute = _extend_xr_method(xr.Dataset.compute)
    persist = _extend_xr_method(xr.Dataset.persist)
    quantile = _extend_xr_method(xr.Dataset.quantile)
    close = _extend_xr_method(xr.Dataset.close)

    # The following lines use methods on xr.Dataset that are dynamically defined and attached.
    # As a result mypy cannot see them, so we have to suppress the resulting mypy errors.
    mean = _extend_xr_method(xr.Dataset.mean, see_also="median")  # type: ignore[attr-defined]
    median = _extend_xr_method(xr.Dataset.median, see_also="mean")  # type: ignore[attr-defined]
    min = _extend_xr_method(xr.Dataset.min, see_also=["max", "sum"])  # type: ignore[attr-defined]
    max = _extend_xr_method(xr.Dataset.max, see_also=["min", "sum"])  # type: ignore[attr-defined]
    cumsum = _extend_xr_method(xr.Dataset.cumsum, see_also="sum")  # type: ignore[attr-defined]
    sum = _extend_xr_method(xr.Dataset.sum, see_also="cumsum")  # type: ignore[attr-defined]
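
    # Each alias above wraps the corresponding xr.Dataset method so it is applied
    # group-wise, in the same spirit as map(). A sketch, assuming the wrappers
    # generated by _extend_xr_method accept the usual groups/filter_groups/inplace
    # keywords in addition to the wrapped method's own arguments:
    #
    #     import arviz as az
    #
    #     idata = az.load_arviz_data("centered_eight")
    #     posterior_means = idata.mean(dim=["chain", "draw"], groups="posterior")
    #     print(posterior_means.posterior)
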
    def _group_names(
        self,
        groups: Optional[Union[str, List[str]]],
        filter_groups: Optional["Literal['like', 'regex']"] = None,
    ) -> List[str]:
        """Handle expansion of group names input across arviz.

        Parameters
        ----------
        groups : str, list of str or None
            Group or metagroup names.
        filter_groups : {None, "like", "regex"}, optional
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.

        Returns
        -------
        groups : list
            Expanded and filtered list of group names.
        """
        if filter_groups not in {None, "like", "regex"}:
            raise ValueError(
                f"'filter_groups' can only be None, 'like', or 'regex', got: '{filter_groups}'"
            )

        all_groups = self._groups_all
        if groups is None:
            return all_groups
        if isinstance(groups, str):
            groups = [groups]
        sel_groups = []
        metagroups = rcParams["data.metagroups"]
        for group in groups:
            if group[0] == "~":
                # negated metagroups expand to negated real group names
                sel_groups.extend(
                    [f"~{item}" for item in metagroups[group[1:]] if item in all_groups]
                    if group[1:] in metagroups
                    else [group]
                )
            else:
                sel_groups.extend(
                    [item for item in metagroups[group] if item in all_groups]
                    if group in metagroups
                    else [group]
                )

        try:
            group_names = _subset_list(sel_groups, all_groups, filter_items=filter_groups)
        except KeyError as err:
            msg = " ".join(("groups:", f"{err}", "in InferenceData"))
            raise KeyError(msg) from err
        return group_names
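
    # The expansion above is what lets the public methods accept metagroup names
    # and "~"-negations. A sketch of both, calling the private helper directly
    # purely for illustration and relying on the "data.metagroups" rcParams entry:
    #
    #     import arviz as az
    #
    #     idata = az.load_arviz_data("centered_eight")
    #     # a metagroup expands to the real groups present in the object
    #     print(idata._group_names("posterior_groups", filter_groups=None))
    #     # a "~" prefix selects everything except the named group
    #     print(idata._group_names("~posterior", filter_groups=None))
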
    def map(self, fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs):
        """Apply a function to multiple groups.

        Applies ``fun`` groupwise to the selected ``InferenceData`` groups and overwrites the
        group with the result of the function.

        Parameters
        ----------
        fun : callable
            Function to be applied to each group. Assumes the function is called as
            ``fun(dataset, *args, **kwargs)``.
        groups : str or list of str, optional
            Groups where the function is to be applied. Can either be group names
            or metagroup names.
        filter_groups : {None, "like", "regex"}, optional
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace : bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        args : array_like, optional
            Positional arguments passed to ``fun``.
        **kwargs : mapping, optional
            Keyword arguments passed to ``fun``.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` apply the function in place and return `None`.

        Examples
        --------
        Shift observed_data, prior_predictive and posterior_predictive.

        .. jupyter-execute::

            import arviz as az
            import numpy as np
            idata = az.load_arviz_data("non_centered_eight")
            idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
            idata_shifted_obs

        Rename and update the coordinate values in both posterior and prior groups.

        .. jupyter-execute::

            idata = az.load_arviz_data("radon")
            idata = idata.map(
                lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
                    uranium_coefs=["intercept", "u_slope"]
                ),
                groups=["posterior", "prior"]
            )
            idata

        Add extra coordinates to all groups containing observed variables:

        .. jupyter-execute::

            idata = az.load_arviz_data("rugby")
            home_team, away_team = np.array([
                m.split() for m in idata.observed_data.match.values
            ]).T
            idata = idata.map(
                lambda ds, **kwargs: ds.assign_coords(**kwargs),
                groups="observed_vars",
                home_team=("match", home_team),
                away_team=("match", away_team),
            )
            idata

        """
        if args is None:
            args = []
        groups = self._group_names(groups, filter_groups)

        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            dataset = fun(dataset, *args, **kwargs)
            setattr(out, group, dataset)
        if inplace:
            return None
        else:
            return out
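
    # One more sketch of the pattern above, standardizing every posterior
    # variable while leaving the other groups untouched:
    #
    #     import arviz as az
    #
    #     idata = az.load_arviz_data("centered_eight")
    #     standardized = idata.map(
    #         lambda ds: (ds - ds.mean(("chain", "draw"))) / ds.std(("chain", "draw")),
    #         groups="posterior",
    #     )
    #     print(standardized.posterior["mu"].mean().values.round(2))  # ~0.0
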
    def _wrap_xarray_method(
        self, method, groups=None, filter_groups=None, inplace=False, args=None, **kwargs
    ):
        """Extend an xarray.Dataset method to the InferenceData object.

        Parameters
        ----------
        method : str
            Method to be extended. Must be an ``xarray.Dataset`` method.
        groups : str or list of str, optional
            Groups where the method is to be applied. Can either be group names
            or metagroup names.
        inplace : bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        **kwargs : mapping, optional
            Keyword arguments passed to the xarray Dataset method.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` apply the method in place and return `None`.

        Examples
        --------
        Compute the mean of `posterior_groups`:

        .. ipython::

            In [1]: import arviz as az
               ...: idata = az.load_arviz_data("non_centered_eight")
               ...: idata_means = idata._wrap_xarray_method("mean", groups="latent_vars")
               ...: print(idata_means.posterior)
               ...: print(idata_means.observed_data)

        .. ipython::

            In [1]: idata_stack = idata._wrap_xarray_method(
               ...:     "stack",
               ...:     groups=["posterior_groups", "prior_groups"],
               ...:     sample=["chain", "draw"]
               ...: )
               ...: print(idata_stack.posterior)
               ...: print(idata_stack.prior)
               ...: print(idata_stack.observed_data)

        """
        if args is None:
            args = []
        groups = self._group_names(groups, filter_groups)

        method = getattr(xr.Dataset, method)

        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            dataset = method(dataset, *args, **kwargs)
            setattr(out, group, dataset)
        if inplace:
            return None
        else:
            return out

    def copy(self) -> "InferenceData":
        """Return a fresh copy of the ``InferenceData`` object."""
        return deepcopy(self)

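
# copy() being a deepcopy matters for the in-place methods above: a copied
# object can be mutated freely without touching the original, e.g.:
#
#     import arviz as az
#
#     idata = az.load_arviz_data("centered_eight")
#     idata2 = idata.copy()
#     idata2.posterior["mu"] += 1.0  # does not affect idata.posterior
#     assert float(idata2.posterior["mu"].mean()) != float(idata.posterior["mu"].mean())
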
@overload
def concat(
    *args,
    dim: Optional[str] = None,
    copy: bool = True,
    inplace: "Literal[True]",
    reset_dim: bool = True,
) -> None: ...


@overload
def concat(
    *args,
    dim: Optional[str] = None,
    copy: bool = True,
    inplace: "Literal[False]",
    reset_dim: bool = True,
) -> InferenceData: ...


@overload
def concat(
    ids: Iterable[InferenceData],
    dim: Optional[str] = None,
    *,
    copy: bool = True,
    inplace: "Literal[False]",
    reset_dim: bool = True,
) -> InferenceData: ...


@overload
def concat(
    ids: Iterable[InferenceData],
    dim: Optional[str] = None,
    *,
    copy: bool = True,
    inplace: "Literal[True]",
    reset_dim: bool = True,
) -> None: ...


@overload
def concat(
    ids: Iterable[InferenceData],
    dim: Optional[str] = None,
    *,
    copy: bool = True,
    inplace: bool = False,
    reset_dim: bool = True,
) -> Optional[InferenceData]: ...

# pylint: disable=protected-access, inconsistent-return-statements
def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
    """Concatenate InferenceData objects.

    Concatenates over `group`, `chain` or `draw`.
    By default concatenates over unique groups.
    To concatenate over `chain` or `draw`, the function
    needs identical groups and variables.

    The variables in the data groups are merged if `dim` is not found.

    Parameters
    ----------
    *args : InferenceData
        Variable length InferenceData list or
        Sequence of InferenceData.
    dim : str, optional
        Dimension over which the objects are concatenated.
        If None, concatenates over unique groups.
    copy : bool
        If True, groups are copied to the new InferenceData object.
        Used only if `dim` is None.
    inplace : bool
        If True, merge args to first object.
    reset_dim : bool
        Valid only if dim is not None.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` merge args to first arg and return `None`.

    See Also
    --------
    add_groups : Add new groups to InferenceData object.
    extend : Extend InferenceData with groups from another InferenceData.

    Examples
    --------
    Use the ``concat`` function to concatenate InferenceData objects. This concatenates over
    unique groups by default. We first create an ``InferenceData`` object:

    .. ipython::

        In [1]: import arviz as az
           ...: import numpy as np
           ...: data = {
           ...:     "a": np.random.normal(size=(4, 100, 3)),
           ...:     "b": np.random.normal(size=(4, 100)),
           ...: }
           ...: coords = {"a_dim": ["x", "y", "z"]}
           ...: dataA = az.from_dict(data, coords=coords, dims={"a": ["a_dim"]})
           ...: dataA

    We have created an ``InferenceData`` object with the default group 'posterior'. Now, we will
    create another ``InferenceData`` object:

    .. ipython::

        In [1]: dataB = az.from_dict(prior=data, coords=coords, dims={"a": ["a_dim"]})
           ...: dataB

    We have created another ``InferenceData`` object with the group 'prior'. Now, we will
    concatenate these two ``InferenceData`` objects:

    .. ipython::

        In [1]: az.concat(dataA, dataB)

    Now, we will concatenate over chain (or draw). It requires identical groups and variables.
    Here we are concatenating two identical ``InferenceData`` objects over the chain dimension:

    .. ipython::

        In [1]: az.concat(dataA, dataA, dim="chain")

    It will create an ``InferenceData`` with the original group 'posterior'. In a similar way,
    we can also concatenate over draws.

    """
    # pylint: disable=undefined-loop-variable, too-many-nested-blocks
    if len(args) == 0:
        if inplace:
            return
        return InferenceData()

    if len(args) == 1 and isinstance(args[0], Sequence):
        args = args[0]

    # assert that all args are InferenceData
    for i, arg in enumerate(args):
        if not isinstance(arg, InferenceData):
            raise TypeError(
                "Concatenating is supported only "
                "between InferenceData objects. Input arg {} is {}".format(i, type(arg))
            )

    if dim is not None and dim.lower() not in {"group", "chain", "draw"}:
        msg = f'Invalid `dim`: {dim}. Valid `dim` are {{"group", "chain", "draw"}}'
        raise TypeError(msg)
    dim = dim.lower() if dim is not None else dim

    if len(args) == 1 and isinstance(args[0], InferenceData):
        if inplace:
            return None
        else:
            if copy:
                return deepcopy(args[0])
            else:
                return args[0]

    current_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
    combined_attr = defaultdict(list)
    for idata in args:
        for key, val in idata.attrs.items():
            combined_attr[key].append(val)

    # collapse attribute values shared by every input back to a scalar
    for key, val in combined_attr.items():
        all_same = True
        for indx in range(len(val) - 1):
            if val[indx] != val[indx + 1]:
                all_same = False
                break
        if all_same:
            combined_attr[key] = val[0]
    if inplace:
        setattr(args[0], "_attrs", dict(combined_attr))

    if not inplace:
        # Keep order for python 3.5
        inference_data_dict = OrderedDict()

    if dim is None:
        arg0 = args[0]
        arg0_groups = ccopy(arg0._groups_all)
        args_groups = {}
        # check if groups are independent
        # Concat over unique groups
        for arg in args[1:]:
            for group in arg._groups_all:
                if group in args_groups or group in arg0_groups:
                    msg = (
                        "Concatenating overlapping groups is not supported unless `dim` is defined."
                        " Valid dimensions are `chain` and `draw`. Alternatively, use extend to"
                        " combine InferenceData with overlapping groups"
                    )
                    raise TypeError(msg)
                group_data = getattr(arg, group)
                args_groups[group] = deepcopy(group_data) if copy else group_data
        # add arg0 to args_groups if inplace is False;
        # otherwise args_groups will be merged into the arg0
        # InferenceData object
        if not inplace:
            for group in arg0_groups:
                group_data = getattr(arg0, group)
                args_groups[group] = deepcopy(group_data) if copy else group_data

        other_groups = [group for group in args_groups if group not in SUPPORTED_GROUPS_ALL]

        for group in SUPPORTED_GROUPS_ALL + other_groups:
            if group not in args_groups:
                continue
            if inplace:
                if group.startswith(WARMUP_TAG):
                    arg0._groups_warmup.append(group)
                else:
                    arg0._groups.append(group)
                setattr(arg0, group, args_groups[group])
            else:
                inference_data_dict[group] = args_groups[group]
        if inplace:
            other_groups = [
                group for group in arg0_groups if group not in SUPPORTED_GROUPS_ALL
            ] + other_groups
            sorted_groups = [
                group for group in SUPPORTED_GROUPS + other_groups if group in arg0._groups
            ]
            setattr(arg0, "_groups", sorted_groups)
            sorted_groups_warmup = [
                group
                for group in SUPPORTED_GROUPS_WARMUP + other_groups
                if group in arg0._groups_warmup
            ]
            setattr(arg0, "_groups_warmup", sorted_groups_warmup)
    else:
        arg0 = args[0]
        arg0_groups = arg0._groups_all
        for arg in args[1:]:
            for group0 in arg0_groups:
                if group0 not in arg._groups_all:
                    if group0 == "observed_data":
                        continue
                    msg = "Mismatch between the groups."
                    raise TypeError(msg)
            for group in arg._groups_all:
                # handle data groups separately
                if group not in ["observed_data", "constant_data", "predictions_constant_data"]:
                    # assert that groups are equal
                    if group not in arg0_groups:
                        msg = "Mismatch between the groups."
                        raise TypeError(msg)

                    # assert that variables are equal
                    group_data = getattr(arg, group)
                    group_vars = group_data.data_vars

                    if not inplace and group in inference_data_dict:
                        group0_data = inference_data_dict[group]
                    else:
                        group0_data = getattr(arg0, group)
                    group0_vars = group0_data.data_vars

                    for var in group0_vars:
                        if var not in group_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)

                    for var in group_vars:
                        if var not in group0_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)
                        var_dims = group_data[var].dims
                        var0_dims = group0_data[var].dims
                        if var_dims != var0_dims:
                            msg = "Mismatch between the dimensions."
                            raise TypeError(msg)

                        if dim not in var_dims or dim not in var0_dims:
                            msg = f"Dimension {dim} missing."
                            raise TypeError(msg)

                    # concatenate the whole group datasets along dim
                    concatenated_group = xr.concat((group0_data, group_data), dim=dim)
                    if reset_dim:
                        concatenated_group[dim] = range(concatenated_group[dim].size)

                    # handle attrs
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = deepcopy(getattr(group0_data, "attrs"))
                    else:
                        group0_attrs = OrderedDict()

                    if hasattr(group_data, "attrs"):
                        group_attrs = getattr(group_data, "attrs")
                    else:
                        group_attrs = {}

                    # gather attrs results to group0_attrs
                    for attr_key, attr_values in group_attrs.items():
                        group0_attr_values = group0_attrs.get(attr_key, None)
                        equality = attr_values == group0_attr_values
                        if hasattr(equality, "__iter__"):
                            equality = np.all(equality)
                        if equality:
                            continue
                        # handle special cases:
                        if attr_key in ("created_at", "previous_created_at"):
                            # check the defaults
                            if "previous_created_at" not in group0_attrs:
                                group0_attrs["previous_created_at"] = []
                                if group0_attr_values is not None:
                                    group0_attrs["previous_created_at"].append(group0_attr_values)
                            # check previous values
                            if attr_key == "previous_created_at":
                                if not isinstance(attr_values, list):
                                    attr_values = [attr_values]
                                group0_attrs["previous_created_at"].extend(attr_values)
                                continue
                            # update "created_at"
                            if group0_attr_values != current_time:
                                group0_attrs[attr_key] = current_time
                                group0_attrs["previous_created_at"].append(attr_values)

                        elif attr_key in group0_attrs:
                            combined_key = f"combined_{attr_key}"
                            if combined_key not in group0_attrs:
                                group0_attrs[combined_key] = [group0_attr_values]
                            group0_attrs[combined_key].append(attr_values)
                        else:
                            group0_attrs[attr_key] = attr_values
                    # update attrs
                    setattr(concatenated_group, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, concatenated_group)
                    else:
                        inference_data_dict[group] = concatenated_group
                else:
                    # observed_data, constant_data, predictions_constant_data
                    group_data = getattr(arg, group)
                    if group not in arg0_groups:
                        setattr(arg0, group, deepcopy(group_data) if copy else group_data)
                        arg0._groups.append(group)
                        continue

                    # assert that variables are equal
                    group_vars = group_data.data_vars

                    group0_data = getattr(arg0, group)
                    if not inplace:
                        group0_data = deepcopy(group0_data)
                    group0_vars = group0_data.data_vars

                    for var in group_vars:
                        if var not in group0_vars:
                            var_data = group_data[var]
                            getattr(arg0, group)[var] = var_data
                        else:
                            var_data = group_data[var]
                            var0_data = group0_data[var]
                            if dim in var_data.dims and dim in var0_data.dims:
                                concatenated_var = xr.concat((var_data, var0_data), dim=dim)
                                group0_data[var] = concatenated_var

                    # handle attrs
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = getattr(group0_data, "attrs")
                    else:
                        group0_attrs = OrderedDict()

                    if hasattr(group_data, "attrs"):
                        group_attrs = getattr(group_data, "attrs")
                    else:
                        group_attrs = {}

                    # gather attrs results to group0_attrs
                    for attr_key, attr_values in group_attrs.items():
                        group0_attr_values = group0_attrs.get(attr_key, None)
                        equality = attr_values == group0_attr_values
                        if hasattr(equality, "__iter__"):
                            equality = np.all(equality)
                        if equality:
                            continue
                        # handle special cases:
                        if attr_key in ("created_at", "previous_created_at"):
                            # check the defaults
                            if "previous_created_at" not in group0_attrs:
                                group0_attrs["previous_created_at"] = []
                                if group0_attr_values is not None:
                                    group0_attrs["previous_created_at"].append(group0_attr_values)
                            # check previous values
                            if attr_key == "previous_created_at":
                                if not isinstance(attr_values, list):
                                    attr_values = [attr_values]
                                group0_attrs["previous_created_at"].extend(attr_values)
                                continue
                            # update "created_at"
                            if group0_attr_values != current_time:
                                group0_attrs[attr_key] = current_time
                                group0_attrs["previous_created_at"].append(attr_values)

                        elif attr_key in group0_attrs:
                            combined_key = f"combined_{attr_key}"
                            if combined_key not in group0_attrs:
                                group0_attrs[combined_key] = [group0_attr_values]
                            group0_attrs[combined_key].append(attr_values)
                        else:
                            group0_attrs[attr_key] = attr_values
                    # update attrs
                    setattr(group0_data, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, group0_data)
                    else:
                        inference_data_dict[group] = group0_data

    if not inplace:
        inference_data_dict["attrs"] = combined_attr

    return None if inplace else InferenceData(**inference_data_dict)
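
# A sketch exercising both concat modes end to end, with toy data:
#
#     import numpy as np
#     import arviz as az
#
#     post = {"mu": np.random.randn(4, 500)}
#     dataA = az.from_dict(posterior=post)
#     dataB = az.from_dict(prior=post)
#     # group-wise concatenation: disjoint groups merged into one object
#     merged = az.concat(dataA, dataB)
#     assert {"posterior", "prior"} <= set(merged.groups())
#     # dimension-wise concatenation needs identical groups and variables
#     doubled = az.concat(dataA, dataA, dim="chain")
#     assert doubled.posterior.sizes["chain"] == 8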