arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
|
@@ -1,1679 +0,0 @@
|
|
|
1
|
-
# pylint: disable=no-member, invalid-name, redefined-outer-name
|
|
2
|
-
# pylint: disable=too-many-lines
|
|
3
|
-
|
|
4
|
-
import importlib
|
|
5
|
-
import os
|
|
6
|
-
import warnings
|
|
7
|
-
from collections import namedtuple
|
|
8
|
-
from copy import deepcopy
|
|
9
|
-
from html import escape
|
|
10
|
-
from typing import Dict
|
|
11
|
-
from tempfile import TemporaryDirectory
|
|
12
|
-
from urllib.parse import urlunsplit
|
|
13
|
-
|
|
14
|
-
import numpy as np
|
|
15
|
-
import pytest
|
|
16
|
-
import xarray as xr
|
|
17
|
-
from xarray.core.options import OPTIONS
|
|
18
|
-
from xarray.testing import assert_identical
|
|
19
|
-
|
|
20
|
-
from ... import (
|
|
21
|
-
InferenceData,
|
|
22
|
-
clear_data_home,
|
|
23
|
-
concat,
|
|
24
|
-
convert_to_dataset,
|
|
25
|
-
convert_to_inference_data,
|
|
26
|
-
from_datatree,
|
|
27
|
-
from_dict,
|
|
28
|
-
from_json,
|
|
29
|
-
from_netcdf,
|
|
30
|
-
list_datasets,
|
|
31
|
-
load_arviz_data,
|
|
32
|
-
to_netcdf,
|
|
33
|
-
extract,
|
|
34
|
-
)
|
|
35
|
-
|
|
36
|
-
from ...data.base import (
|
|
37
|
-
dict_to_dataset,
|
|
38
|
-
generate_dims_coords,
|
|
39
|
-
infer_stan_dtypes,
|
|
40
|
-
make_attrs,
|
|
41
|
-
numpy_to_data_array,
|
|
42
|
-
)
|
|
43
|
-
from ...data.datasets import LOCAL_DATASETS, REMOTE_DATASETS, RemoteFileMetadata
|
|
44
|
-
from ..helpers import ( # pylint: disable=unused-import
|
|
45
|
-
chains,
|
|
46
|
-
check_multiple_attrs,
|
|
47
|
-
create_data_random,
|
|
48
|
-
data_random,
|
|
49
|
-
draws,
|
|
50
|
-
eight_schools_params,
|
|
51
|
-
models,
|
|
52
|
-
)
|
|
53
|
-
|
|
54
|
-
# Check if dm-tree is installed
|
|
55
|
-
dm_tree_installed = importlib.util.find_spec("tree") is not None # pylint: disable=invalid-name
|
|
56
|
-
skip_tests = (not dm_tree_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
@pytest.fixture(autouse=True)
|
|
60
|
-
def no_remote_data(monkeypatch, tmpdir):
|
|
61
|
-
"""Delete all remote data and replace it with a local dataset."""
|
|
62
|
-
keys = list(REMOTE_DATASETS)
|
|
63
|
-
for key in keys:
|
|
64
|
-
monkeypatch.delitem(REMOTE_DATASETS, key)
|
|
65
|
-
|
|
66
|
-
centered = LOCAL_DATASETS["centered_eight"]
|
|
67
|
-
filename = os.path.join(str(tmpdir), os.path.basename(centered.filename))
|
|
68
|
-
|
|
69
|
-
url = urlunsplit(("file", "", centered.filename, "", ""))
|
|
70
|
-
|
|
71
|
-
monkeypatch.setitem(
|
|
72
|
-
REMOTE_DATASETS,
|
|
73
|
-
"test_remote",
|
|
74
|
-
RemoteFileMetadata(
|
|
75
|
-
name="test_remote",
|
|
76
|
-
filename=filename,
|
|
77
|
-
url=url,
|
|
78
|
-
checksum="8efc3abafe0c796eb9aea7b69490d4e2400a33c57504ef4932e1c7105849176f",
|
|
79
|
-
description=centered.description,
|
|
80
|
-
),
|
|
81
|
-
)
|
|
82
|
-
monkeypatch.setitem(
|
|
83
|
-
REMOTE_DATASETS,
|
|
84
|
-
"bad_checksum",
|
|
85
|
-
RemoteFileMetadata(
|
|
86
|
-
name="bad_checksum",
|
|
87
|
-
filename=filename,
|
|
88
|
-
url=url,
|
|
89
|
-
checksum="bad!",
|
|
90
|
-
description=centered.description,
|
|
91
|
-
),
|
|
92
|
-
)
|
|
93
|
-
UnknownFileMetaData = namedtuple(
|
|
94
|
-
"UnknownFileMetaData", ["filename", "url", "checksum", "description"]
|
|
95
|
-
)
|
|
96
|
-
monkeypatch.setitem(
|
|
97
|
-
REMOTE_DATASETS,
|
|
98
|
-
"test_unknown",
|
|
99
|
-
UnknownFileMetaData(
|
|
100
|
-
filename=filename,
|
|
101
|
-
url=url,
|
|
102
|
-
checksum="9ae00c83654b3f061d32c882ec0a270d10838fa36515ecb162b89a290e014849",
|
|
103
|
-
description="Test bad REMOTE_DATASET",
|
|
104
|
-
),
|
|
105
|
-
)
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def test_load_local_arviz_data():
|
|
109
|
-
inference_data = load_arviz_data("centered_eight")
|
|
110
|
-
assert isinstance(inference_data, InferenceData)
|
|
111
|
-
assert set(inference_data.observed_data.obs.coords["school"].values) == {
|
|
112
|
-
"Hotchkiss",
|
|
113
|
-
"Mt. Hermon",
|
|
114
|
-
"Choate",
|
|
115
|
-
"Deerfield",
|
|
116
|
-
"Phillips Andover",
|
|
117
|
-
"St. Paul's",
|
|
118
|
-
"Lawrenceville",
|
|
119
|
-
"Phillips Exeter",
|
|
120
|
-
}
|
|
121
|
-
assert inference_data.posterior["theta"].dims == ("chain", "draw", "school")
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
@pytest.mark.parametrize("fill_attrs", [True, False])
|
|
125
|
-
def test_local_save(fill_attrs):
|
|
126
|
-
inference_data = load_arviz_data("centered_eight")
|
|
127
|
-
assert isinstance(inference_data, InferenceData)
|
|
128
|
-
|
|
129
|
-
if fill_attrs:
|
|
130
|
-
inference_data.attrs["test"] = 1
|
|
131
|
-
with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:
|
|
132
|
-
path = os.path.join(tmp_dir, "test_file.nc")
|
|
133
|
-
inference_data.to_netcdf(path)
|
|
134
|
-
|
|
135
|
-
inference_data2 = from_netcdf(path)
|
|
136
|
-
if fill_attrs:
|
|
137
|
-
assert "test" in inference_data2.attrs
|
|
138
|
-
assert inference_data2.attrs["test"] == 1
|
|
139
|
-
# pylint: disable=protected-access
|
|
140
|
-
assert all(group in inference_data2 for group in inference_data._groups_all)
|
|
141
|
-
# pylint: enable=protected-access
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def test_clear_data_home():
|
|
145
|
-
resource = REMOTE_DATASETS["test_remote"]
|
|
146
|
-
assert not os.path.exists(resource.filename)
|
|
147
|
-
load_arviz_data("test_remote")
|
|
148
|
-
assert os.path.exists(resource.filename)
|
|
149
|
-
clear_data_home(data_home=os.path.dirname(resource.filename))
|
|
150
|
-
assert not os.path.exists(resource.filename)
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
def test_load_remote_arviz_data():
|
|
154
|
-
assert load_arviz_data("test_remote")
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
def test_bad_checksum():
|
|
158
|
-
with pytest.raises(IOError):
|
|
159
|
-
load_arviz_data("bad_checksum")
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
def test_missing_dataset():
|
|
163
|
-
with pytest.raises(ValueError):
|
|
164
|
-
load_arviz_data("does not exist")
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
def test_list_datasets():
|
|
168
|
-
dataset_string = list_datasets()
|
|
169
|
-
# make sure all the names of the data sets are in the dataset description
|
|
170
|
-
for key in (
|
|
171
|
-
"centered_eight",
|
|
172
|
-
"non_centered_eight",
|
|
173
|
-
"test_remote",
|
|
174
|
-
"bad_checksum",
|
|
175
|
-
"test_unknown",
|
|
176
|
-
):
|
|
177
|
-
assert key in dataset_string
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
def test_dims_coords():
|
|
181
|
-
shape = 4, 20, 5
|
|
182
|
-
var_name = "x"
|
|
183
|
-
dims, coords = generate_dims_coords(shape, var_name)
|
|
184
|
-
assert "x_dim_0" in dims
|
|
185
|
-
assert "x_dim_1" in dims
|
|
186
|
-
assert "x_dim_2" in dims
|
|
187
|
-
assert len(coords["x_dim_0"]) == 4
|
|
188
|
-
assert len(coords["x_dim_1"]) == 20
|
|
189
|
-
assert len(coords["x_dim_2"]) == 5
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
@pytest.mark.parametrize(
|
|
193
|
-
"in_dims", (["dim1", "dim2"], ["draw", "dim1", "dim2"], ["chain", "draw", "dim1", "dim2"])
|
|
194
|
-
)
|
|
195
|
-
def test_dims_coords_default_dims(in_dims):
|
|
196
|
-
shape = 4, 7
|
|
197
|
-
var_name = "x"
|
|
198
|
-
dims, coords = generate_dims_coords(
|
|
199
|
-
shape,
|
|
200
|
-
var_name,
|
|
201
|
-
dims=in_dims,
|
|
202
|
-
coords={"chain": ["a", "b", "c"]},
|
|
203
|
-
default_dims=["chain", "draw"],
|
|
204
|
-
)
|
|
205
|
-
assert "dim1" in dims
|
|
206
|
-
assert "dim2" in dims
|
|
207
|
-
assert ("chain" in dims) == ("chain" in in_dims)
|
|
208
|
-
assert ("draw" in dims) == ("draw" in in_dims)
|
|
209
|
-
assert len(coords["dim1"]) == 4
|
|
210
|
-
assert len(coords["dim2"]) == 7
|
|
211
|
-
assert len(coords["chain"]) == 3
|
|
212
|
-
assert "draw" not in coords
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
def test_dims_coords_extra_dims():
|
|
216
|
-
shape = 4, 20
|
|
217
|
-
var_name = "x"
|
|
218
|
-
with pytest.warns(UserWarning):
|
|
219
|
-
dims, coords = generate_dims_coords(shape, var_name, dims=["xx", "xy", "xz"])
|
|
220
|
-
assert "xx" in dims
|
|
221
|
-
assert "xy" in dims
|
|
222
|
-
assert "xz" in dims
|
|
223
|
-
assert len(coords["xx"]) == 4
|
|
224
|
-
assert len(coords["xy"]) == 20
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
@pytest.mark.parametrize("shape", [(4, 20), (4, 20, 1)])
|
|
228
|
-
def test_dims_coords_skip_event_dims(shape):
|
|
229
|
-
coords = {"x": np.arange(4), "y": np.arange(20), "z": np.arange(5)}
|
|
230
|
-
dims, coords = generate_dims_coords(
|
|
231
|
-
shape, "name", dims=["x", "y", "z"], coords=coords, skip_event_dims=True
|
|
232
|
-
)
|
|
233
|
-
assert "x" in dims
|
|
234
|
-
assert "y" in dims
|
|
235
|
-
assert "z" not in dims
|
|
236
|
-
assert len(coords["x"]) == 4
|
|
237
|
-
assert len(coords["y"]) == 20
|
|
238
|
-
assert "z" not in coords
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
@pytest.mark.parametrize("dims", [None, ["chain", "draw"], ["chain", "draw", None]])
|
|
242
|
-
def test_numpy_to_data_array_with_dims(dims):
|
|
243
|
-
da = numpy_to_data_array(
|
|
244
|
-
np.empty((4, 500, 7)),
|
|
245
|
-
var_name="a",
|
|
246
|
-
dims=dims,
|
|
247
|
-
default_dims=["chain", "draw"],
|
|
248
|
-
)
|
|
249
|
-
assert list(da.dims) == ["chain", "draw", "a_dim_0"]
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
def test_make_attrs():
|
|
253
|
-
extra_attrs = {"key": "Value"}
|
|
254
|
-
attrs = make_attrs(attrs=extra_attrs)
|
|
255
|
-
assert "key" in attrs
|
|
256
|
-
assert attrs["key"] == "Value"
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
@pytest.mark.parametrize("copy", [True, False])
|
|
260
|
-
@pytest.mark.parametrize("inplace", [True, False])
|
|
261
|
-
@pytest.mark.parametrize("sequence", [True, False])
|
|
262
|
-
def test_concat_group(copy, inplace, sequence):
|
|
263
|
-
idata1 = from_dict(
|
|
264
|
-
posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)}
|
|
265
|
-
)
|
|
266
|
-
if copy and inplace:
|
|
267
|
-
original_idata1_posterior_id = id(idata1.posterior)
|
|
268
|
-
idata2 = from_dict(prior={"C": np.random.randn(2, 10, 2), "D": np.random.randn(2, 10, 5, 2)})
|
|
269
|
-
idata3 = from_dict(observed_data={"E": np.random.randn(100), "F": np.random.randn(2, 100)})
|
|
270
|
-
# basic case
|
|
271
|
-
assert concat(idata1, idata2, copy=True, inplace=False) is not None
|
|
272
|
-
if sequence:
|
|
273
|
-
new_idata = concat((idata1, idata2, idata3), copy=copy, inplace=inplace)
|
|
274
|
-
else:
|
|
275
|
-
new_idata = concat(idata1, idata2, idata3, copy=copy, inplace=inplace)
|
|
276
|
-
if inplace:
|
|
277
|
-
assert new_idata is None
|
|
278
|
-
new_idata = idata1
|
|
279
|
-
assert new_idata is not None
|
|
280
|
-
test_dict = {"posterior": ["A", "B"], "prior": ["C", "D"], "observed_data": ["E", "F"]}
|
|
281
|
-
fails = check_multiple_attrs(test_dict, new_idata)
|
|
282
|
-
assert not fails
|
|
283
|
-
if copy:
|
|
284
|
-
if inplace:
|
|
285
|
-
assert id(new_idata.posterior) == original_idata1_posterior_id
|
|
286
|
-
else:
|
|
287
|
-
assert id(new_idata.posterior) != id(idata1.posterior)
|
|
288
|
-
assert id(new_idata.prior) != id(idata2.prior)
|
|
289
|
-
assert id(new_idata.observed_data) != id(idata3.observed_data)
|
|
290
|
-
else:
|
|
291
|
-
assert id(new_idata.posterior) == id(idata1.posterior)
|
|
292
|
-
assert id(new_idata.prior) == id(idata2.prior)
|
|
293
|
-
assert id(new_idata.observed_data) == id(idata3.observed_data)
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
@pytest.mark.parametrize("dim", ["chain", "draw"])
|
|
297
|
-
@pytest.mark.parametrize("copy", [True, False])
|
|
298
|
-
@pytest.mark.parametrize("inplace", [True, False])
|
|
299
|
-
@pytest.mark.parametrize("sequence", [True, False])
|
|
300
|
-
@pytest.mark.parametrize("reset_dim", [True, False])
|
|
301
|
-
def test_concat_dim(dim, copy, inplace, sequence, reset_dim):
|
|
302
|
-
idata1 = from_dict(
|
|
303
|
-
posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)},
|
|
304
|
-
observed_data={"C": np.random.randn(100), "D": np.random.randn(2, 100)},
|
|
305
|
-
)
|
|
306
|
-
if inplace:
|
|
307
|
-
original_idata1_id = id(idata1)
|
|
308
|
-
idata2 = from_dict(
|
|
309
|
-
posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)},
|
|
310
|
-
observed_data={"C": np.random.randn(100), "D": np.random.randn(2, 100)},
|
|
311
|
-
)
|
|
312
|
-
idata3 = from_dict(
|
|
313
|
-
posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)},
|
|
314
|
-
observed_data={"C": np.random.randn(100), "D": np.random.randn(2, 100)},
|
|
315
|
-
)
|
|
316
|
-
# basic case
|
|
317
|
-
assert (
|
|
318
|
-
concat(idata1, idata2, dim=dim, copy=copy, inplace=False, reset_dim=reset_dim) is not None
|
|
319
|
-
)
|
|
320
|
-
if sequence:
|
|
321
|
-
new_idata = concat(
|
|
322
|
-
(idata1, idata2, idata3), copy=copy, dim=dim, inplace=inplace, reset_dim=reset_dim
|
|
323
|
-
)
|
|
324
|
-
else:
|
|
325
|
-
new_idata = concat(
|
|
326
|
-
idata1, idata2, idata3, dim=dim, copy=copy, inplace=inplace, reset_dim=reset_dim
|
|
327
|
-
)
|
|
328
|
-
if inplace:
|
|
329
|
-
assert new_idata is None
|
|
330
|
-
new_idata = idata1
|
|
331
|
-
assert new_idata is not None
|
|
332
|
-
test_dict = {"posterior": ["A", "B"], "observed_data": ["C", "D"]}
|
|
333
|
-
fails = check_multiple_attrs(test_dict, new_idata)
|
|
334
|
-
assert not fails
|
|
335
|
-
if inplace:
|
|
336
|
-
assert id(new_idata) == original_idata1_id
|
|
337
|
-
else:
|
|
338
|
-
assert id(new_idata) != id(idata1)
|
|
339
|
-
assert getattr(new_idata.posterior, dim).size == 6 if dim == "chain" else 30
|
|
340
|
-
if reset_dim:
|
|
341
|
-
assert np.all(
|
|
342
|
-
getattr(new_idata.posterior, dim).values
|
|
343
|
-
== (np.arange(6) if dim == "chain" else np.arange(30))
|
|
344
|
-
)
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
@pytest.mark.parametrize("copy", [True, False])
|
|
348
|
-
@pytest.mark.parametrize("inplace", [True, False])
|
|
349
|
-
@pytest.mark.parametrize("sequence", [True, False])
|
|
350
|
-
def test_concat_edgecases(copy, inplace, sequence):
|
|
351
|
-
idata = from_dict(posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)})
|
|
352
|
-
empty = concat()
|
|
353
|
-
assert empty is not None
|
|
354
|
-
if sequence:
|
|
355
|
-
new_idata = concat([idata], copy=copy, inplace=inplace)
|
|
356
|
-
else:
|
|
357
|
-
new_idata = concat(idata, copy=copy, inplace=inplace)
|
|
358
|
-
if inplace:
|
|
359
|
-
assert new_idata is None
|
|
360
|
-
new_idata = idata
|
|
361
|
-
else:
|
|
362
|
-
assert new_idata is not None
|
|
363
|
-
test_dict = {"posterior": ["A", "B"]}
|
|
364
|
-
fails = check_multiple_attrs(test_dict, new_idata)
|
|
365
|
-
assert not fails
|
|
366
|
-
if copy and not inplace:
|
|
367
|
-
assert id(new_idata.posterior) != id(idata.posterior)
|
|
368
|
-
else:
|
|
369
|
-
assert id(new_idata.posterior) == id(idata.posterior)
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
def test_concat_bad():
|
|
373
|
-
with pytest.raises(TypeError):
|
|
374
|
-
concat("hello", "hello")
|
|
375
|
-
idata = from_dict(posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)})
|
|
376
|
-
idata2 = from_dict(posterior={"A": np.random.randn(2, 10, 2)})
|
|
377
|
-
idata3 = from_dict(prior={"A": np.random.randn(2, 10, 2)})
|
|
378
|
-
with pytest.raises(TypeError):
|
|
379
|
-
concat(idata, np.array([1, 2, 3, 4, 5]))
|
|
380
|
-
with pytest.raises(TypeError):
|
|
381
|
-
concat(idata, idata, dim=None)
|
|
382
|
-
with pytest.raises(TypeError):
|
|
383
|
-
concat(idata, idata2, dim="chain")
|
|
384
|
-
with pytest.raises(TypeError):
|
|
385
|
-
concat(idata2, idata, dim="chain")
|
|
386
|
-
with pytest.raises(TypeError):
|
|
387
|
-
concat(idata, idata3, dim="chain")
|
|
388
|
-
with pytest.raises(TypeError):
|
|
389
|
-
concat(idata3, idata, dim="chain")
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
def test_inference_concat_keeps_all_fields():
|
|
393
|
-
"""From failures observed in issue #907"""
|
|
394
|
-
idata1 = from_dict(posterior={"A": [1, 2, 3, 4]}, sample_stats={"B": [2, 3, 4, 5]})
|
|
395
|
-
idata2 = from_dict(prior={"C": [1, 2, 3, 4]}, observed_data={"D": [2, 3, 4, 5]})
|
|
396
|
-
|
|
397
|
-
idata_c1 = concat(idata1, idata2)
|
|
398
|
-
idata_c2 = concat(idata2, idata1)
|
|
399
|
-
|
|
400
|
-
test_dict = {"posterior": ["A"], "sample_stats": ["B"], "prior": ["C"], "observed_data": ["D"]}
|
|
401
|
-
|
|
402
|
-
fails_c1 = check_multiple_attrs(test_dict, idata_c1)
|
|
403
|
-
assert not fails_c1
|
|
404
|
-
fails_c2 = check_multiple_attrs(test_dict, idata_c2)
|
|
405
|
-
assert not fails_c2
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
@pytest.mark.parametrize(
|
|
409
|
-
"model_code,expected",
|
|
410
|
-
[
|
|
411
|
-
("data {int y;} models {y ~ poisson(3);} generated quantities {int X;}", {"X": "int"}),
|
|
412
|
-
(
|
|
413
|
-
"data {real y;} models {y ~ normal(0,1);} generated quantities {int Y; real G;}",
|
|
414
|
-
{"Y": "int"},
|
|
415
|
-
),
|
|
416
|
-
],
|
|
417
|
-
)
|
|
418
|
-
def test_infer_stan_dtypes(model_code, expected):
|
|
419
|
-
"""Test different examples for dtypes in Stan models."""
|
|
420
|
-
res = infer_stan_dtypes(model_code)
|
|
421
|
-
assert res == expected
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
class TestInferenceData: # pylint: disable=too-many-public-methods
|
|
425
|
-
def test_addition(self):
|
|
426
|
-
idata1 = from_dict(
|
|
427
|
-
posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)}
|
|
428
|
-
)
|
|
429
|
-
idata2 = from_dict(
|
|
430
|
-
prior={"C": np.random.randn(2, 10, 2), "D": np.random.randn(2, 10, 5, 2)}
|
|
431
|
-
)
|
|
432
|
-
new_idata = idata1 + idata2
|
|
433
|
-
assert new_idata is not None
|
|
434
|
-
test_dict = {"posterior": ["A", "B"], "prior": ["C", "D"]}
|
|
435
|
-
fails = check_multiple_attrs(test_dict, new_idata)
|
|
436
|
-
assert not fails
|
|
437
|
-
|
|
438
|
-
def test_iter(self, models):
|
|
439
|
-
idata = models.model_1
|
|
440
|
-
for group in idata:
|
|
441
|
-
assert group in idata._groups_all # pylint: disable=protected-access
|
|
442
|
-
|
|
443
|
-
def test_groups(self, models):
|
|
444
|
-
idata = models.model_1
|
|
445
|
-
for group in idata.groups():
|
|
446
|
-
assert group in idata._groups_all # pylint: disable=protected-access
|
|
447
|
-
|
|
448
|
-
def test_values(self, models):
|
|
449
|
-
idata = models.model_1
|
|
450
|
-
datasets = idata.values()
|
|
451
|
-
for group in idata.groups():
|
|
452
|
-
assert group in idata._groups_all # pylint: disable=protected-access
|
|
453
|
-
dataset = getattr(idata, group)
|
|
454
|
-
assert dataset in datasets
|
|
455
|
-
|
|
456
|
-
def test_items(self, models):
|
|
457
|
-
idata = models.model_1
|
|
458
|
-
for group, dataset in idata.items():
|
|
459
|
-
assert group in idata._groups_all # pylint: disable=protected-access
|
|
460
|
-
assert dataset.equals(getattr(idata, group))
|
|
461
|
-
|
|
462
|
-
@pytest.mark.parametrize("inplace", [True, False])
|
|
463
|
-
def test_extend_xr_method(self, data_random, inplace):
|
|
464
|
-
idata = data_random
|
|
465
|
-
idata_copy = deepcopy(idata)
|
|
466
|
-
kwargs = {"groups": "posterior_groups"}
|
|
467
|
-
if inplace:
|
|
468
|
-
idata_copy.sum(dim="draw", inplace=inplace, **kwargs)
|
|
469
|
-
else:
|
|
470
|
-
idata2 = idata_copy.sum(dim="draw", inplace=inplace, **kwargs)
|
|
471
|
-
assert idata2 is not idata_copy
|
|
472
|
-
idata_copy = idata2
|
|
473
|
-
assert_identical(idata_copy.posterior, idata.posterior.sum(dim="draw"))
|
|
474
|
-
assert_identical(
|
|
475
|
-
idata_copy.posterior_predictive, idata.posterior_predictive.sum(dim="draw")
|
|
476
|
-
)
|
|
477
|
-
assert_identical(idata_copy.observed_data, idata.observed_data)
|
|
478
|
-
|
|
479
|
-
@pytest.mark.parametrize("inplace", [False, True])
|
|
480
|
-
def test_sel(self, data_random, inplace):
|
|
481
|
-
idata = data_random
|
|
482
|
-
original_groups = getattr(idata, "_groups")
|
|
483
|
-
ndraws = idata.posterior.draw.values.size
|
|
484
|
-
kwargs = {"draw": slice(200, None), "chain": slice(None, None, 2), "b_dim_0": [1, 2, 7]}
|
|
485
|
-
if inplace:
|
|
486
|
-
idata.sel(inplace=inplace, **kwargs)
|
|
487
|
-
else:
|
|
488
|
-
idata2 = idata.sel(inplace=inplace, **kwargs)
|
|
489
|
-
assert idata2 is not idata
|
|
490
|
-
idata = idata2
|
|
491
|
-
groups = getattr(idata, "_groups")
|
|
492
|
-
assert np.all(np.isin(groups, original_groups))
|
|
493
|
-
for group in groups:
|
|
494
|
-
dataset = getattr(idata, group)
|
|
495
|
-
assert "b_dim_0" in dataset.dims
|
|
496
|
-
assert np.all(dataset.b_dim_0.values == np.array(kwargs["b_dim_0"]))
|
|
497
|
-
if group != "observed_data":
|
|
498
|
-
assert np.all(np.isin(["chain", "draw"], dataset.dims))
|
|
499
|
-
assert np.all(dataset.chain.values == np.arange(0, 4, 2))
|
|
500
|
-
assert np.all(dataset.draw.values == np.arange(200, ndraws))
|
|
501
|
-
|
|
502
|
-
def test_sel_chain_prior(self):
|
|
503
|
-
idata = load_arviz_data("centered_eight")
|
|
504
|
-
original_groups = getattr(idata, "_groups")
|
|
505
|
-
idata_subset = idata.sel(inplace=False, chain_prior=False, chain=[0, 1, 3])
|
|
506
|
-
groups = getattr(idata_subset, "_groups")
|
|
507
|
-
assert np.all(np.isin(groups, original_groups))
|
|
508
|
-
for group in groups:
|
|
509
|
-
dataset_subset = getattr(idata_subset, group)
|
|
510
|
-
dataset = getattr(idata, group)
|
|
511
|
-
if "chain" in dataset.dims:
|
|
512
|
-
assert "chain" in dataset_subset.dims
|
|
513
|
-
if "prior" not in group:
|
|
514
|
-
assert np.all(dataset_subset.chain.values == np.array([0, 1, 3]))
|
|
515
|
-
else:
|
|
516
|
-
assert "chain" not in dataset_subset.dims
|
|
517
|
-
with pytest.raises(KeyError):
|
|
518
|
-
idata.sel(inplace=False, chain_prior=True, chain=[0, 1, 3])
|
|
519
|
-
|
|
520
|
-
@pytest.mark.parametrize("use", ("del", "delattr", "delitem"))
|
|
521
|
-
def test_del(self, use):
|
|
522
|
-
# create inference data object
|
|
523
|
-
data = np.random.normal(size=(4, 500, 8))
|
|
524
|
-
idata = from_dict(
|
|
525
|
-
posterior={"a": data[..., 0], "b": data},
|
|
526
|
-
sample_stats={"a": data[..., 0], "b": data},
|
|
527
|
-
observed_data={"b": data[0, 0, :]},
|
|
528
|
-
posterior_predictive={"a": data[..., 0], "b": data},
|
|
529
|
-
)
|
|
530
|
-
|
|
531
|
-
# assert inference data object has all attributes
|
|
532
|
-
test_dict = {
|
|
533
|
-
"posterior": ("a", "b"),
|
|
534
|
-
"sample_stats": ("a", "b"),
|
|
535
|
-
"observed_data": ["b"],
|
|
536
|
-
"posterior_predictive": ("a", "b"),
|
|
537
|
-
}
|
|
538
|
-
fails = check_multiple_attrs(test_dict, idata)
|
|
539
|
-
assert not fails
|
|
540
|
-
# assert _groups attribute contains all groups
|
|
541
|
-
groups = getattr(idata, "_groups")
|
|
542
|
-
assert all((group in groups for group in test_dict))
|
|
543
|
-
|
|
544
|
-
# Use del method
|
|
545
|
-
if use == "del":
|
|
546
|
-
del idata.sample_stats
|
|
547
|
-
elif use == "delitem":
|
|
548
|
-
del idata["sample_stats"]
|
|
549
|
-
else:
|
|
550
|
-
delattr(idata, "sample_stats")
|
|
551
|
-
|
|
552
|
-
# assert attribute has been removed
|
|
553
|
-
test_dict.pop("sample_stats")
|
|
554
|
-
fails = check_multiple_attrs(test_dict, idata)
|
|
555
|
-
assert not fails
|
|
556
|
-
assert not hasattr(idata, "sample_stats")
|
|
557
|
-
# assert _groups attribute has been updated
|
|
558
|
-
assert "sample_stats" not in getattr(idata, "_groups")
|
|
559
|
-
|
|
560
|
-
@pytest.mark.parametrize(
|
|
561
|
-
"args_res",
|
|
562
|
-
(
|
|
563
|
-
([("posterior", "sample_stats")], ("posterior", "sample_stats")),
|
|
564
|
-
(["posterior", "like"], ("posterior", "warmup_posterior", "posterior_predictive")),
|
|
565
|
-
(["^posterior", "regex"], ("posterior", "posterior_predictive")),
|
|
566
|
-
(
|
|
567
|
-
[("~^warmup", "~^obs"), "regex"],
|
|
568
|
-
("posterior", "sample_stats", "posterior_predictive"),
|
|
569
|
-
),
|
|
570
|
-
(
|
|
571
|
-
["~observed_vars"],
|
|
572
|
-
("posterior", "sample_stats", "warmup_posterior", "warmup_sample_stats"),
|
|
573
|
-
),
|
|
574
|
-
),
|
|
575
|
-
)
|
|
576
|
-
def test_group_names(self, args_res):
|
|
577
|
-
args, result = args_res
|
|
578
|
-
ds = dict_to_dataset({"a": np.random.normal(size=(3, 10))})
|
|
579
|
-
idata = InferenceData(
|
|
580
|
-
posterior=(ds, ds),
|
|
581
|
-
sample_stats=(ds, ds),
|
|
582
|
-
observed_data=ds,
|
|
583
|
-
posterior_predictive=ds,
|
|
584
|
-
)
|
|
585
|
-
group_names = idata._group_names(*args) # pylint: disable=protected-access
|
|
586
|
-
assert np.all([name in result for name in group_names])
|
|
587
|
-
|
|
588
|
-
def test_group_names_invalid_args(self):
|
|
589
|
-
ds = dict_to_dataset({"a": np.random.normal(size=(3, 10))})
|
|
590
|
-
idata = InferenceData(posterior=(ds, ds))
|
|
591
|
-
msg = r"^\'filter_groups\' can only be None, \'like\', or \'regex\', got: 'foo'$"
|
|
592
|
-
with pytest.raises(ValueError, match=msg):
|
|
593
|
-
idata._group_names( # pylint: disable=protected-access
|
|
594
|
-
("posterior",), filter_groups="foo"
|
|
595
|
-
)
|
|
596
|
-
|
|
597
|
-
@pytest.mark.parametrize("inplace", [False, True])
|
|
598
|
-
def test_isel(self, data_random, inplace):
|
|
599
|
-
idata = data_random
|
|
600
|
-
original_groups = getattr(idata, "_groups")
|
|
601
|
-
ndraws = idata.posterior.draw.values.size
|
|
602
|
-
kwargs = {"draw": slice(200, None), "chain": slice(None, None, 2), "b_dim_0": [1, 2, 7]}
|
|
603
|
-
if inplace:
|
|
604
|
-
idata.isel(inplace=inplace, **kwargs)
|
|
605
|
-
else:
|
|
606
|
-
idata2 = idata.isel(inplace=inplace, **kwargs)
|
|
607
|
-
assert idata2 is not idata
|
|
608
|
-
idata = idata2
|
|
609
|
-
groups = getattr(idata, "_groups")
|
|
610
|
-
assert np.all(np.isin(groups, original_groups))
|
|
611
|
-
for group in groups:
|
|
612
|
-
dataset = getattr(idata, group)
|
|
613
|
-
assert "b_dim_0" in dataset.dims
|
|
614
|
-
assert np.all(dataset.b_dim_0.values == np.array(kwargs["b_dim_0"]))
|
|
615
|
-
if group != "observed_data":
|
|
616
|
-
assert np.all(np.isin(["chain", "draw"], dataset.dims))
|
|
617
|
-
assert np.all(dataset.chain.values == np.arange(0, 4, 2))
|
|
618
|
-
assert np.all(dataset.draw.values == np.arange(200, ndraws))
|
|
619
|
-
|
|
620
|
-
def test_rename(self, data_random):
|
|
621
|
-
idata = data_random
|
|
622
|
-
original_groups = getattr(idata, "_groups")
|
|
623
|
-
renamed_idata = idata.rename({"b": "b_new"})
|
|
624
|
-
for group in original_groups:
|
|
625
|
-
xr_data = getattr(renamed_idata, group)
|
|
626
|
-
assert "b_new" in list(xr_data.data_vars)
|
|
627
|
-
assert "b" not in list(xr_data.data_vars)
|
|
628
|
-
|
|
629
|
-
renamed_idata = idata.rename({"b_dim_0": "b_new"})
|
|
630
|
-
for group in original_groups:
|
|
631
|
-
xr_data = getattr(renamed_idata, group)
|
|
632
|
-
assert "b_new" in list(xr_data.dims)
|
|
633
|
-
assert "b_dim_0" not in list(xr_data.dims)
|
|
634
|
-
|
|
635
|
-
def test_rename_vars(self, data_random):
|
|
636
|
-
idata = data_random
|
|
637
|
-
original_groups = getattr(idata, "_groups")
|
|
638
|
-
renamed_idata = idata.rename_vars({"b": "b_new"})
|
|
639
|
-
for group in original_groups:
|
|
640
|
-
xr_data = getattr(renamed_idata, group)
|
|
641
|
-
assert "b_new" in list(xr_data.data_vars)
|
|
642
|
-
assert "b" not in list(xr_data.data_vars)
|
|
643
|
-
|
|
644
|
-
renamed_idata = idata.rename_vars({"b_dim_0": "b_new"})
|
|
645
|
-
for group in original_groups:
|
|
646
|
-
xr_data = getattr(renamed_idata, group)
|
|
647
|
-
assert "b_new" not in list(xr_data.dims)
|
|
648
|
-
assert "b_dim_0" in list(xr_data.dims)
|
|
649
|
-
|
|
650
|
-
def test_rename_dims(self, data_random):
|
|
651
|
-
idata = data_random
|
|
652
|
-
original_groups = getattr(idata, "_groups")
|
|
653
|
-
renamed_idata = idata.rename_dims({"b_dim_0": "b_new"})
|
|
654
|
-
for group in original_groups:
|
|
655
|
-
xr_data = getattr(renamed_idata, group)
|
|
656
|
-
assert "b_new" in list(xr_data.dims)
|
|
657
|
-
assert "b_dim_0" not in list(xr_data.dims)
|
|
658
|
-
|
|
659
|
-
renamed_idata = idata.rename_dims({"b": "b_new"})
|
|
660
|
-
for group in original_groups:
|
|
661
|
-
xr_data = getattr(renamed_idata, group)
|
|
662
|
-
assert "b_new" not in list(xr_data.data_vars)
|
|
663
|
-
assert "b" in list(xr_data.data_vars)
|
|
664
|
-
|
|
665
|
-
def test_stack_unstack(self):
|
|
666
|
-
datadict = {
|
|
667
|
-
"a": np.random.randn(100),
|
|
668
|
-
"b": np.random.randn(1, 100, 10),
|
|
669
|
-
"c": np.random.randn(1, 100, 3, 4),
|
|
670
|
-
}
|
|
671
|
-
coords = {
|
|
672
|
-
"c1": np.arange(3),
|
|
673
|
-
"c99": np.arange(4),
|
|
674
|
-
"b1": np.arange(10),
|
|
675
|
-
}
|
|
676
|
-
dims = {"c": ["c1", "c99"], "b": ["b1"]}
|
|
677
|
-
dataset = from_dict(posterior=datadict, coords=coords, dims=dims)
|
|
678
|
-
assert_identical(
|
|
679
|
-
dataset.stack(z=["c1", "c99"]).posterior, dataset.posterior.stack(z=["c1", "c99"])
|
|
680
|
-
)
|
|
681
|
-
assert_identical(dataset.stack(z=["c1", "c99"]).unstack().posterior, dataset.posterior)
|
|
682
|
-
assert_identical(
|
|
683
|
-
dataset.stack(z=["c1", "c99"]).unstack(dim="z").posterior, dataset.posterior
|
|
684
|
-
)
|
|
685
|
-
|
|
686
|
-
def test_stack_bool(self):
|
|
687
|
-
datadict = {
|
|
688
|
-
"a": np.random.randn(100),
|
|
689
|
-
"b": np.random.randn(1, 100, 10),
|
|
690
|
-
"c": np.random.randn(1, 100, 3, 4),
|
|
691
|
-
}
|
|
692
|
-
coords = {
|
|
693
|
-
"c1": np.arange(3),
|
|
694
|
-
"c99": np.arange(4),
|
|
695
|
-
"b1": np.arange(10),
|
|
696
|
-
}
|
|
697
|
-
dims = {"c": ["c1", "c99"], "b": ["b1"]}
|
|
698
|
-
dataset = from_dict(posterior=datadict, coords=coords, dims=dims)
|
|
699
|
-
assert_identical(
|
|
700
|
-
dataset.stack(z=["c1", "c99"], create_index=False).posterior,
|
|
701
|
-
dataset.posterior.stack(z=["c1", "c99"], create_index=False),
|
|
702
|
-
)
|
|
703
|
-
|
|
704
|
-
def test_to_dict(self, models):
|
|
705
|
-
idata = models.model_1
|
|
706
|
-
test_data = from_dict(**idata.to_dict())
|
|
707
|
-
assert test_data
|
|
708
|
-
for group in idata._groups_all: # pylint: disable=protected-access
|
|
709
|
-
xr_data = getattr(idata, group)
|
|
710
|
-
test_xr_data = getattr(test_data, group)
|
|
711
|
-
assert xr_data.equals(test_xr_data)
|
|
712
|
-
|
|
713
|
-
def test_to_dict_warmup(self):
|
|
714
|
-
idata = create_data_random(
|
|
715
|
-
groups=[
|
|
716
|
-
"posterior",
|
|
717
|
-
"sample_stats",
|
|
718
|
-
"observed_data",
|
|
719
|
-
"warmup_posterior",
|
|
720
|
-
"warmup_posterior_predictive",
|
|
721
|
-
]
|
|
722
|
-
)
|
|
723
|
-
test_data = from_dict(**idata.to_dict(), save_warmup=True)
|
|
724
|
-
assert test_data
|
|
725
|
-
for group in idata._groups_all: # pylint: disable=protected-access
|
|
726
|
-
xr_data = getattr(idata, group)
|
|
727
|
-
test_xr_data = getattr(test_data, group)
|
|
728
|
-
assert xr_data.equals(test_xr_data)
|
|
729
|
-
|
|
730
|
-
@pytest.mark.parametrize(
|
|
731
|
-
"kwargs",
|
|
732
|
-
(
|
|
733
|
-
{
|
|
734
|
-
"groups": "posterior",
|
|
735
|
-
"include_coords": True,
|
|
736
|
-
"include_index": True,
|
|
737
|
-
"index_origin": 0,
|
|
738
|
-
},
|
|
739
|
-
{
|
|
740
|
-
"groups": ["posterior", "sample_stats"],
|
|
741
|
-
"include_coords": False,
|
|
742
|
-
"include_index": True,
|
|
743
|
-
"index_origin": 0,
|
|
744
|
-
},
|
|
745
|
-
{
|
|
746
|
-
"groups": "posterior_groups",
|
|
747
|
-
"include_coords": True,
|
|
748
|
-
"include_index": False,
|
|
749
|
-
"index_origin": 1,
|
|
750
|
-
},
|
|
751
|
-
),
|
|
752
|
-
)
|
|
753
|
-
def test_to_dataframe(self, kwargs):
|
|
754
|
-
idata = from_dict(
|
|
755
|
-
posterior={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
|
|
756
|
-
sample_stats={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
|
|
757
|
-
observed_data={"a": np.random.randn(3, 4, 5), "b": np.random.randn(4)},
|
|
758
|
-
)
|
|
759
|
-
test_data = idata.to_dataframe(**kwargs)
|
|
760
|
-
assert not test_data.empty
|
|
761
|
-
groups = kwargs.get("groups", idata._groups_all) # pylint: disable=protected-access
|
|
762
|
-
for group in idata._groups_all: # pylint: disable=protected-access
|
|
763
|
-
if "data" in group:
|
|
764
|
-
continue
|
|
765
|
-
assert test_data.shape == (
|
|
766
|
-
(4 * 100, 3 * 4 * 5 + 1 + 2)
|
|
767
|
-
if groups == "posterior"
|
|
768
|
-
else (4 * 100, (3 * 4 * 5 + 1) * 2 + 2)
|
|
769
|
-
)
|
|
770
|
-
if groups == "posterior":
|
|
771
|
-
if kwargs.get("include_coords", True) and kwargs.get("include_index", True):
|
|
772
|
-
assert any(
|
|
773
|
-
f"[{kwargs.get('index_origin', 0)}," in item[0]
|
|
774
|
-
for item in test_data.columns
|
|
775
|
-
if isinstance(item, tuple)
|
|
776
|
-
)
|
|
777
|
-
if kwargs.get("include_coords", True):
|
|
778
|
-
assert any(isinstance(item, tuple) for item in test_data.columns)
|
|
779
|
-
else:
|
|
780
|
-
assert not any(isinstance(item, tuple) for item in test_data.columns)
|
|
781
|
-
else:
|
|
782
|
-
if not kwargs.get("include_index", True):
|
|
783
|
-
assert all(
|
|
784
|
-
item in test_data.columns
|
|
785
|
-
for item in (("posterior", "a", 1, 1, 1), ("posterior", "b"))
|
|
786
|
-
)
|
|
787
|
-
assert all(item in test_data.columns for item in ("chain", "draw"))
|
|
788
|
-
|
|
789
|
-
@pytest.mark.parametrize(
|
|
790
|
-
"kwargs",
|
|
791
|
-
(
|
|
792
|
-
{
|
|
793
|
-
"var_names": ["parameter_1", "parameter_2", "variable_1", "variable_2"],
|
|
794
|
-
"filter_vars": None,
|
|
795
|
-
"var_results": [
|
|
796
|
-
("posterior", "parameter_1"),
|
|
797
|
-
("posterior", "parameter_2"),
|
|
798
|
-
("prior", "parameter_1"),
|
|
799
|
-
("prior", "parameter_2"),
|
|
800
|
-
("posterior", "variable_1"),
|
|
801
|
-
("posterior", "variable_2"),
|
|
802
|
-
],
|
|
803
|
-
},
|
|
804
|
-
{
|
|
805
|
-
"var_names": "parameter",
|
|
806
|
-
"filter_vars": "like",
|
|
807
|
-
"groups": "posterior",
|
|
808
|
-
"var_results": ["parameter_1", "parameter_2"],
|
|
809
|
-
},
|
|
810
|
-
{
|
|
811
|
-
"var_names": "~parameter",
|
|
812
|
-
"filter_vars": "like",
|
|
813
|
-
"groups": "posterior",
|
|
814
|
-
"var_results": ["variable_1", "variable_2", "custom_name"],
|
|
815
|
-
},
|
|
816
|
-
{
|
|
817
|
-
"var_names": [".+_2$", "custom_name"],
|
|
818
|
-
"filter_vars": "regex",
|
|
819
|
-
"groups": "posterior",
|
|
820
|
-
"var_results": ["parameter_2", "variable_2", "custom_name"],
|
|
821
|
-
},
|
|
822
|
-
{
|
|
823
|
-
"var_names": ["lp"],
|
|
824
|
-
"filter_vars": "regex",
|
|
825
|
-
"groups": "sample_stats",
|
|
826
|
-
"var_results": ["lp"],
|
|
827
|
-
},
|
|
828
|
-
),
|
|
829
|
-
)
|
|
830
|
-
def test_to_dataframe_selection(self, kwargs):
|
|
831
|
-
results = kwargs.pop("var_results")
|
|
832
|
-
idata = from_dict(
|
|
833
|
-
posterior={
|
|
834
|
-
"parameter_1": np.random.randn(4, 100),
|
|
835
|
-
"parameter_2": np.random.randn(4, 100),
|
|
836
|
-
"variable_1": np.random.randn(4, 100),
|
|
837
|
-
"variable_2": np.random.randn(4, 100),
|
|
838
|
-
"custom_name": np.random.randn(4, 100),
|
|
839
|
-
},
|
|
840
|
-
prior={
|
|
841
|
-
"parameter_1": np.random.randn(4, 100),
|
|
842
|
-
"parameter_2": np.random.randn(4, 100),
|
|
843
|
-
},
|
|
844
|
-
sample_stats={
|
|
845
|
-
"lp": np.random.randn(4, 100),
|
|
846
|
-
},
|
|
847
|
-
)
|
|
848
|
-
test_data = idata.to_dataframe(**kwargs)
|
|
849
|
-
assert not test_data.empty
|
|
850
|
-
assert set(test_data.columns).symmetric_difference(results) == set(["chain", "draw"])
|
|
851
|
-
|
|
852
|
-
def test_to_dataframe_bad(self):
|
|
853
|
-
idata = from_dict(
|
|
854
|
-
posterior={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
|
|
855
|
-
sample_stats={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
|
|
856
|
-
observed_data={"a": np.random.randn(3, 4, 5), "b": np.random.randn(4)},
|
|
857
|
-
)
|
|
858
|
-
with pytest.raises(TypeError):
|
|
859
|
-
idata.to_dataframe(index_origin=2)
|
|
860
|
-
|
|
861
|
-
with pytest.raises(TypeError):
|
|
862
|
-
idata.to_dataframe(include_coords=False, include_index=False)
|
|
863
|
-
|
|
864
|
-
with pytest.raises(TypeError):
|
|
865
|
-
idata.to_dataframe(groups=["observed_data"])
|
|
866
|
-
|
|
867
|
-
with pytest.raises(KeyError):
|
|
868
|
-
idata.to_dataframe(groups=["invalid_group"])
|
|
869
|
-
|
|
870
|
-
with pytest.raises(ValueError):
|
|
871
|
-
idata.to_dataframe(var_names=["c"])
|
|
872
|
-
|
|
873
|
-
@pytest.mark.parametrize("use", (None, "args", "kwargs"))
|
|
874
|
-
def test_map(self, use):
|
|
875
|
-
idata = load_arviz_data("centered_eight")
|
|
876
|
-
args = []
|
|
877
|
-
kwargs = {}
|
|
878
|
-
if use is None:
|
|
879
|
-
fun = lambda x: x + 3
|
|
880
|
-
elif use == "args":
|
|
881
|
-
fun = lambda x, a: x + a
|
|
882
|
-
args = [3]
|
|
883
|
-
else:
|
|
884
|
-
fun = lambda x, a: x + a
|
|
885
|
-
kwargs = {"a": 3}
|
|
886
|
-
groups = ("observed_data", "posterior_predictive")
|
|
887
|
-
idata_map = idata.map(fun, groups, args=args, **kwargs)
|
|
888
|
-
groups_map = idata_map._groups # pylint: disable=protected-access
|
|
889
|
-
assert groups_map == idata._groups # pylint: disable=protected-access
|
|
890
|
-
assert np.allclose(
|
|
891
|
-
idata_map.observed_data.obs, fun(idata.observed_data.obs, *args, **kwargs)
|
|
892
|
-
)
|
|
893
|
-
assert np.allclose(
|
|
894
|
-
idata_map.posterior_predictive.obs, fun(idata.posterior_predictive.obs, *args, **kwargs)
|
|
895
|
-
)
|
|
896
|
-
assert np.allclose(idata_map.posterior.mu, idata.posterior.mu)
|
|
897
|
-
|
|
898
|
-
def test_repr_html(self):
|
|
899
|
-
"""Test if the function _repr_html is generating html."""
|
|
900
|
-
idata = load_arviz_data("centered_eight")
|
|
901
|
-
display_style = OPTIONS["display_style"]
|
|
902
|
-
xr.set_options(display_style="html")
|
|
903
|
-
html = idata._repr_html_() # pylint: disable=protected-access
|
|
904
|
-
|
|
905
|
-
assert html is not None
|
|
906
|
-
assert "<div" in html
|
|
907
|
-
for group in idata._groups: # pylint: disable=protected-access
|
|
908
|
-
assert group in html
|
|
909
|
-
xr_data = getattr(idata, group)
|
|
910
|
-
for item, _ in xr_data.items():
|
|
911
|
-
assert item in html
|
|
912
|
-
specific_style = ".xr-wrap{width:700px!important;}"
|
|
913
|
-
assert specific_style in html
|
|
914
|
-
|
|
915
|
-
xr.set_options(display_style="text")
|
|
916
|
-
html = idata._repr_html_() # pylint: disable=protected-access
|
|
917
|
-
assert escape(repr(idata)) in html
|
|
918
|
-
xr.set_options(display_style=display_style)
|
|
919
|
-
|
|
920
|
-
def test_setitem(self, data_random):
|
|
921
|
-
data_random["new_group"] = data_random.posterior
|
|
922
|
-
assert "new_group" in data_random.groups()
|
|
923
|
-
assert hasattr(data_random, "new_group")
|
|
924
|
-
|
|
925
|
-
def test_add_groups(self, data_random):
|
|
926
|
-
data = np.random.normal(size=(4, 500, 8))
|
|
927
|
-
idata = data_random
|
|
928
|
-
idata.add_groups({"prior": {"a": data[..., 0], "b": data}})
|
|
929
|
-
assert "prior" in idata._groups # pylint: disable=protected-access
|
|
930
|
-
assert isinstance(idata.prior, xr.Dataset)
|
|
931
|
-
assert hasattr(idata, "prior")
|
|
932
|
-
|
|
933
|
-
idata.add_groups(warmup_posterior={"a": data[..., 0], "b": data})
|
|
934
|
-
assert "warmup_posterior" in idata._groups_all # pylint: disable=protected-access
|
|
935
|
-
assert isinstance(idata.warmup_posterior, xr.Dataset)
|
|
936
|
-
assert hasattr(idata, "warmup_posterior")
|
|
937
|
-
|
|
938
|
-
def test_add_groups_warning(self, data_random):
|
|
939
|
-
data = np.random.normal(size=(4, 500, 8))
|
|
940
|
-
idata = data_random
|
|
941
|
-
with pytest.warns(UserWarning, match="The group.+not defined in the InferenceData scheme"):
|
|
942
|
-
idata.add_groups({"new_group": idata.posterior}, warn_on_custom_groups=True)
|
|
943
|
-
with pytest.warns(UserWarning, match="the default dims.+will be added automatically"):
|
|
944
|
-
idata.add_groups(constant_data={"a": data[..., 0], "b": data})
|
|
945
|
-
assert idata.new_group.equals(idata.posterior)
|
|
946
|
-
|
|
947
|
-
def test_add_groups_error(self, data_random):
|
|
948
|
-
idata = data_random
|
|
949
|
-
with pytest.raises(ValueError, match="One of.+must be provided."):
|
|
950
|
-
idata.add_groups()
|
|
951
|
-
with pytest.raises(ValueError, match="Arguments.+xr.Dataset, xr.Dataarray or dicts"):
|
|
952
|
-
idata.add_groups({"new_group": "new_group"})
|
|
953
|
-
with pytest.raises(ValueError, match="group.+already exists"):
|
|
954
|
-
idata.add_groups({"posterior": idata.posterior})
|
|
955
|
-
|
|
956
|
-
def test_extend(self, data_random):
|
|
957
|
-
idata = data_random
|
|
958
|
-
idata2 = create_data_random(
|
|
959
|
-
groups=["prior", "prior_predictive", "observed_data", "warmup_posterior"], seed=7
|
|
960
|
-
)
|
|
961
|
-
idata.extend(idata2)
|
|
962
|
-
assert "prior" in idata._groups_all # pylint: disable=protected-access
|
|
963
|
-
assert "warmup_posterior" in idata._groups_all # pylint: disable=protected-access
|
|
964
|
-
assert hasattr(idata, "prior")
|
|
965
|
-
assert hasattr(idata, "prior_predictive")
|
|
966
|
-
assert idata.prior.equals(idata2.prior)
|
|
967
|
-
assert not idata.observed_data.equals(idata2.observed_data)
|
|
968
|
-
assert idata.prior_predictive.equals(idata2.prior_predictive)
|
|
969
|
-
|
|
970
|
-
idata.extend(idata2, join="right")
|
|
971
|
-
assert idata.prior.equals(idata2.prior)
|
|
972
|
-
assert idata.observed_data.equals(idata2.observed_data)
|
|
973
|
-
assert idata.prior_predictive.equals(idata2.prior_predictive)
|
|
974
|
-
|
|
975
|
-
def test_extend_errors_warnings(self, data_random):
|
|
976
|
-
idata = data_random
|
|
977
|
-
idata2 = create_data_random(groups=["prior", "prior_predictive", "observed_data"], seed=7)
|
|
978
|
-
with pytest.raises(ValueError, match="Extending.+InferenceData objects only."):
|
|
979
|
-
idata.extend("something")
|
|
980
|
-
with pytest.raises(ValueError, match="join must be either"):
|
|
981
|
-
idata.extend(idata2, join="outer")
|
|
982
|
-
idata2.add_groups(new_group=idata2.prior)
|
|
983
|
-
with pytest.warns(UserWarning, match="new_group"):
|
|
984
|
-
idata.extend(idata2, warn_on_custom_groups=True)
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
class TestNumpyToDataArray:
|
|
988
|
-
def test_1d_dataset(self):
|
|
989
|
-
size = 100
|
|
990
|
-
dataset = convert_to_dataset(np.random.randn(size))
|
|
991
|
-
assert len(dataset.data_vars) == 1
|
|
992
|
-
|
|
993
|
-
assert set(dataset.coords) == {"chain", "draw"}
|
|
994
|
-
assert dataset.chain.shape == (1,)
|
|
995
|
-
assert dataset.draw.shape == (size,)
|
|
996
|
-
|
|
997
|
-
def test_warns_bad_shape(self):
|
|
998
|
-
# Shape should be (chain, draw, *shape)
|
|
999
|
-
with pytest.warns(UserWarning):
|
|
1000
|
-
convert_to_dataset(np.random.randn(100, 4))
|
|
1001
|
-
|
|
1002
|
-
def test_nd_to_dataset(self):
|
|
1003
|
-
shape = (1, 2, 3, 4, 5)
|
|
1004
|
-
dataset = convert_to_dataset(np.random.randn(*shape))
|
|
1005
|
-
assert len(dataset.data_vars) == 1
|
|
1006
|
-
var_name = list(dataset.data_vars)[0]
|
|
1007
|
-
|
|
1008
|
-
assert len(dataset.coords) == len(shape)
|
|
1009
|
-
assert dataset.chain.shape == shape[:1]
|
|
1010
|
-
assert dataset.draw.shape == shape[1:2]
|
|
1011
|
-
assert dataset[var_name].shape == shape
|
|
1012
|
-
|
|
1013
|
-
def test_nd_to_inference_data(self):
|
|
1014
|
-
shape = (1, 2, 3, 4, 5)
|
|
1015
|
-
inference_data = convert_to_inference_data(np.random.randn(*shape), group="prior")
|
|
1016
|
-
assert hasattr(inference_data, "prior")
|
|
1017
|
-
assert len(inference_data.prior.data_vars) == 1
|
|
1018
|
-
var_name = list(inference_data.prior.data_vars)[0]
|
|
1019
|
-
|
|
1020
|
-
assert len(inference_data.prior.coords) == len(shape)
|
|
1021
|
-
assert inference_data.prior.chain.shape == shape[:1]
|
|
1022
|
-
assert inference_data.prior.draw.shape == shape[1:2]
|
|
1023
|
-
assert inference_data.prior[var_name].shape == shape
|
|
1024
|
-
assert repr(inference_data).startswith("Inference data with groups")
|
|
1025
|
-
|
|
1026
|
-
def test_more_chains_than_draws(self):
|
|
1027
|
-
shape = (10, 4)
|
|
1028
|
-
with pytest.warns(UserWarning):
|
|
1029
|
-
inference_data = convert_to_inference_data(np.random.randn(*shape), group="prior")
|
|
1030
|
-
assert hasattr(inference_data, "prior")
|
|
1031
|
-
assert len(inference_data.prior.data_vars) == 1
|
|
1032
|
-
var_name = list(inference_data.prior.data_vars)[0]
|
|
1033
|
-
|
|
1034
|
-
assert len(inference_data.prior.coords) == len(shape)
|
|
1035
|
-
assert inference_data.prior.chain.shape == shape[:1]
|
|
1036
|
-
assert inference_data.prior.draw.shape == shape[1:2]
|
|
1037
|
-
assert inference_data.prior[var_name].shape == shape
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
class TestConvertToDataset:
|
|
1041
|
-
@pytest.fixture(scope="class")
|
|
1042
|
-
def data(self):
|
|
1043
|
-
# pylint: disable=attribute-defined-outside-init
|
|
1044
|
-
class Data:
|
|
1045
|
-
datadict = {
|
|
1046
|
-
"a": np.random.randn(100),
|
|
1047
|
-
"b": np.random.randn(1, 100, 10),
|
|
1048
|
-
"c": np.random.randn(1, 100, 3, 4),
|
|
1049
|
-
}
|
|
1050
|
-
coords = {"c1": np.arange(3), "c2": np.arange(4), "b1": np.arange(10)}
|
|
1051
|
-
dims = {"b": ["b1"], "c": ["c1", "c2"]}
|
|
1052
|
-
|
|
1053
|
-
return Data
|
|
1054
|
-
|
|
1055
|
-
def test_use_all(self, data):
|
|
1056
|
-
dataset = convert_to_dataset(data.datadict, coords=data.coords, dims=data.dims)
|
|
1057
|
-
assert set(dataset.data_vars) == {"a", "b", "c"}
|
|
1058
|
-
assert set(dataset.coords) == {"chain", "draw", "c1", "c2", "b1"}
|
|
1059
|
-
|
|
1060
|
-
assert set(dataset.a.coords) == {"chain", "draw"}
|
|
1061
|
-
assert set(dataset.b.coords) == {"chain", "draw", "b1"}
|
|
1062
|
-
assert set(dataset.c.coords) == {"chain", "draw", "c1", "c2"}
|
|
1063
|
-
|
|
1064
|
-
def test_missing_coords(self, data):
|
|
1065
|
-
dataset = convert_to_dataset(data.datadict, coords=None, dims=data.dims)
|
|
1066
|
-
assert set(dataset.data_vars) == {"a", "b", "c"}
|
|
1067
|
-
assert set(dataset.coords) == {"chain", "draw", "c1", "c2", "b1"}
|
|
1068
|
-
|
|
1069
|
-
assert set(dataset.a.coords) == {"chain", "draw"}
|
|
1070
|
-
assert set(dataset.b.coords) == {"chain", "draw", "b1"}
|
|
1071
|
-
assert set(dataset.c.coords) == {"chain", "draw", "c1", "c2"}
|
|
1072
|
-
|
|
1073
|
-
def test_missing_dims(self, data):
|
|
1074
|
-
# missing dims
|
|
1075
|
-
coords = {"c_dim_0": np.arange(3), "c_dim_1": np.arange(4), "b_dim_0": np.arange(10)}
|
|
1076
|
-
dataset = convert_to_dataset(data.datadict, coords=coords, dims=None)
|
|
1077
|
-
assert set(dataset.data_vars) == {"a", "b", "c"}
|
|
1078
|
-
assert set(dataset.coords) == {"chain", "draw", "c_dim_0", "c_dim_1", "b_dim_0"}
|
|
1079
|
-
|
|
1080
|
-
assert set(dataset.a.coords) == {"chain", "draw"}
|
|
1081
|
-
assert set(dataset.b.coords) == {"chain", "draw", "b_dim_0"}
|
|
1082
|
-
assert set(dataset.c.coords) == {"chain", "draw", "c_dim_0", "c_dim_1"}
|
|
1083
|
-
|
|
1084
|
-
def test_skip_dim_0(self, data):
|
|
1085
|
-
dims = {"c": [None, "c2"]}
|
|
1086
|
-
coords = {"c_dim_0": np.arange(3), "c2": np.arange(4), "b_dim_0": np.arange(10)}
|
|
1087
|
-
dataset = convert_to_dataset(data.datadict, coords=coords, dims=dims)
|
|
1088
|
-
assert set(dataset.data_vars) == {"a", "b", "c"}
|
|
1089
|
-
assert set(dataset.coords) == {"chain", "draw", "c_dim_0", "c2", "b_dim_0"}
|
|
1090
|
-
|
|
1091
|
-
assert set(dataset.a.coords) == {"chain", "draw"}
|
|
1092
|
-
assert set(dataset.b.coords) == {"chain", "draw", "b_dim_0"}
|
|
1093
|
-
assert set(dataset.c.coords) == {"chain", "draw", "c_dim_0", "c2"}
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
def test_dict_to_dataset():
|
|
1097
|
-
datadict = {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}
|
|
1098
|
-
dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={"b": ["c"]})
|
|
1099
|
-
assert set(dataset.data_vars) == {"a", "b"}
|
|
1100
|
-
assert set(dataset.coords) == {"chain", "draw", "c"}
|
|
1101
|
-
|
|
1102
|
-
assert set(dataset.a.coords) == {"chain", "draw"}
|
|
1103
|
-
assert set(dataset.b.coords) == {"chain", "draw", "c"}
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
@pytest.mark.skipif(skip_tests, reason="test requires dm-tree which is not installed")
|
|
1107
|
-
def test_nested_dict_to_dataset():
|
|
1108
|
-
datadict = {
|
|
1109
|
-
"top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
|
|
1110
|
-
"d": np.random.randn(100),
|
|
1111
|
-
}
|
|
1112
|
-
dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={("top", "b"): ["c"]})
|
|
1113
|
-
assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"}
|
|
1114
|
-
assert set(dataset.coords) == {"chain", "draw", "c"}
|
|
1115
|
-
|
|
1116
|
-
assert set(dataset[("top", "a")].coords) == {"chain", "draw"}
|
|
1117
|
-
assert set(dataset[("top", "b")].coords) == {"chain", "draw", "c"}
|
|
1118
|
-
assert set(dataset.d.coords) == {"chain", "draw"}
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
def test_dict_to_dataset_event_dims_error():
|
|
1122
|
-
datadict = {"a": np.random.randn(1, 100, 10)}
|
|
1123
|
-
coords = {"b": np.arange(10), "c": ["x", "y", "z"]}
|
|
1124
|
-
msg = "different number of dimensions on data and dims"
|
|
1125
|
-
with pytest.raises(ValueError, match=msg):
|
|
1126
|
-
convert_to_dataset(datadict, coords=coords, dims={"a": ["b", "c"]})
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
def test_dict_to_dataset_with_tuple_coord():
|
|
1130
|
-
datadict = {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}
|
|
1131
|
-
dataset = convert_to_dataset(datadict, coords={"c": tuple(range(10))}, dims={"b": ["c"]})
|
|
1132
|
-
assert set(dataset.data_vars) == {"a", "b"}
|
|
1133
|
-
assert set(dataset.coords) == {"chain", "draw", "c"}
|
|
1134
|
-
|
|
1135
|
-
assert set(dataset.a.coords) == {"chain", "draw"}
|
|
1136
|
-
assert set(dataset.b.coords) == {"chain", "draw", "c"}
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
def test_convert_to_dataset_idempotent():
|
|
1140
|
-
first = convert_to_dataset(np.random.randn(100))
|
|
1141
|
-
second = convert_to_dataset(first)
|
|
1142
|
-
assert first.equals(second)
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
def test_convert_to_inference_data_idempotent():
|
|
1146
|
-
first = convert_to_inference_data(np.random.randn(100), group="prior")
|
|
1147
|
-
second = convert_to_inference_data(first)
|
|
1148
|
-
assert first.prior is second.prior
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
def test_convert_to_inference_data_from_file(tmpdir):
|
|
1152
|
-
first = convert_to_inference_data(np.random.randn(100), group="prior")
|
|
1153
|
-
filename = str(tmpdir.join("test_file.nc"))
|
|
1154
|
-
first.to_netcdf(filename)
|
|
1155
|
-
second = convert_to_inference_data(filename)
|
|
1156
|
-
assert first.prior.equals(second.prior)
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
def test_convert_to_inference_data_bad():
|
|
1160
|
-
with pytest.raises(ValueError):
|
|
1161
|
-
convert_to_inference_data(1)
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
def test_convert_to_dataset_bad(tmpdir):
|
|
1165
|
-
first = convert_to_inference_data(np.random.randn(100), group="prior")
|
|
1166
|
-
filename = str(tmpdir.join("test_file.nc"))
|
|
1167
|
-
first.to_netcdf(filename)
|
|
1168
|
-
with pytest.raises(ValueError):
|
|
1169
|
-
convert_to_dataset(filename, group="bar")
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
def test_bad_inference_data():
|
|
1173
|
-
with pytest.raises(ValueError):
|
|
1174
|
-
InferenceData(posterior=[1, 2, 3])
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
@pytest.mark.parametrize("warn", [True, False])
|
|
1178
|
-
def test_inference_data_other_groups(warn):
|
|
1179
|
-
datadict = {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}
|
|
1180
|
-
dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={"b": ["c"]})
|
|
1181
|
-
if warn:
|
|
1182
|
-
with pytest.warns(UserWarning, match="not.+in.+InferenceData scheme"):
|
|
1183
|
-
idata = InferenceData(other_group=dataset, warn_on_custom_groups=True)
|
|
1184
|
-
else:
|
|
1185
|
-
with warnings.catch_warnings():
|
|
1186
|
-
warnings.simplefilter("error")
|
|
1187
|
-
idata = InferenceData(other_group=dataset, warn_on_custom_groups=False)
|
|
1188
|
-
fails = check_multiple_attrs({"other_group": ["a", "b"]}, idata)
|
|
1189
|
-
assert not fails
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
class TestDataConvert:
|
|
1193
|
-
@pytest.fixture(scope="class")
|
|
1194
|
-
def data(self, draws, chains):
|
|
1195
|
-
class Data:
|
|
1196
|
-
# fake 8-school output
|
|
1197
|
-
obj = {}
|
|
1198
|
-
for key, shape in {"mu": [], "tau": [], "eta": [8], "theta": [8]}.items():
|
|
1199
|
-
obj[key] = np.random.randn(chains, draws, *shape)
|
|
1200
|
-
|
|
1201
|
-
return Data
|
|
1202
|
-
|
|
1203
|
-
def get_inference_data(self, data):
|
|
1204
|
-
return convert_to_inference_data(
|
|
1205
|
-
data.obj,
|
|
1206
|
-
group="posterior",
|
|
1207
|
-
coords={"school": np.arange(8)},
|
|
1208
|
-
dims={"theta": ["school"], "eta": ["school"]},
|
|
1209
|
-
)
|
|
1210
|
-
|
|
1211
|
-
def check_var_names_coords_dims(self, dataset):
|
|
1212
|
-
assert set(dataset.data_vars) == {"mu", "tau", "eta", "theta"}
|
|
1213
|
-
assert set(dataset.coords) == {"chain", "draw", "school"}
|
|
1214
|
-
|
|
1215
|
-
def test_convert_to_inference_data(self, data):
|
|
1216
|
-
inference_data = self.get_inference_data(data)
|
|
1217
|
-
assert hasattr(inference_data, "posterior")
|
|
1218
|
-
self.check_var_names_coords_dims(inference_data.posterior)
|
|
1219
|
-
|
|
1220
|
-
def test_convert_to_dataset(self, draws, chains, data):
|
|
1221
|
-
dataset = convert_to_dataset(
|
|
1222
|
-
data.obj,
|
|
1223
|
-
group="posterior",
|
|
1224
|
-
coords={"school": np.arange(8)},
|
|
1225
|
-
dims={"theta": ["school"], "eta": ["school"]},
|
|
1226
|
-
)
|
|
1227
|
-
assert dataset.draw.shape == (draws,)
|
|
1228
|
-
assert dataset.chain.shape == (chains,)
|
|
1229
|
-
assert dataset.school.shape == (8,)
|
|
1230
|
-
assert dataset.theta.shape == (chains, draws, 8)
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
class TestDataDict:
    """Exercise ``from_dict`` with dictionaries shaped like eight-schools output."""

    # Groups that are filled from the fixture's ``obj`` and therefore must
    # expose the mu/tau/eta/theta variables with chain/draw/school coordinates.
    _SAMPLED_GROUPS = (
        "posterior",
        "posterior_predictive",
        "sample_stats",
        "prior",
        "prior_predictive",
        "sample_stats_prior",
    )

    @pytest.fixture(scope="class")
    def data(self, draws, chains):
        """Return a holder class whose ``obj`` maps variable names to random arrays."""

        class Data:
            # fake 8-school output
            obj = {}
            for key, shape in (("mu", []), ("tau", []), ("eta", [8]), ("theta", [8])):
                obj[key] = np.random.randn(chains, draws, *shape)

        return Data

    @staticmethod
    def _expected_groups(*extra_groups):
        """Build a fresh group -> required-variables mapping for ``check_multiple_attrs``.

        ``extra_groups`` are appended, in order, with no required variables.
        """
        expected = {
            "posterior": [],
            "prior": [],
            "sample_stats": [],
            "posterior_predictive": [],
            "prior_predictive": [],
            "sample_stats_prior": [],
            "observed_data": ["J", "y", "sigma"],
        }
        for group in extra_groups:
            expected[group] = []
        return expected

    def check_var_names_coords_dims(self, dataset):
        """Assert that ``dataset`` holds exactly the expected variables and coordinates."""
        expected_vars = {"mu", "tau", "eta", "theta"}
        expected_coords = {"chain", "draw", "school"}
        assert set(dataset.data_vars) == expected_vars
        assert set(dataset.coords) == expected_coords

    def get_inference_data(self, data, eight_schools_params, save_warmup=False):
        """Call ``from_dict`` with ``data.obj`` for every sampled group.

        ``eight_schools_params`` is passed as ``observed_data``; ``save_warmup``
        is forwarded unchanged.
        """
        groups_from_obj = (
            "posterior",
            "posterior_predictive",
            "sample_stats",
            "prior",
            "prior_predictive",
            "sample_stats_prior",
            "warmup_posterior",
            "warmup_posterior_predictive",
            "predictions",
        )
        return from_dict(
            **dict.fromkeys(groups_from_obj, data.obj),
            observed_data=eight_schools_params,
            coords={"school": np.arange(8)},
            pred_coords={"school_pred": np.arange(8)},
            dims={"theta": ["school"], "eta": ["school"]},
            pred_dims={"theta": ["school_pred"], "eta": ["school_pred"]},
            save_warmup=save_warmup,
        )

    def test_inference_data(self, data, eight_schools_params):
        """Without warmup: all core groups exist and carry the expected layout."""
        idata = self.get_inference_data(data, eight_schools_params)
        assert not check_multiple_attrs(self._expected_groups(), idata)

        for group in self._SAMPLED_GROUPS:
            self.check_var_names_coords_dims(getattr(idata, group))

        # The predictions group uses the separate ``pred_coords``/``pred_dims``.
        assert idata.predictions.sizes["school_pred"] == 8

    def test_inference_data_warmup(self, data, eight_schools_params):
        """With ``save_warmup=True`` the warmup groups are present as well."""
        idata = self.get_inference_data(data, eight_schools_params, save_warmup=True)
        expected = self._expected_groups("warmup_posterior_predictive", "warmup_posterior")
        assert not check_multiple_attrs(expected, idata)

        checked = self._SAMPLED_GROUPS + ("warmup_posterior", "warmup_posterior_predictive")
        for group in checked:
            self.check_var_names_coords_dims(getattr(idata, group))

    def test_inference_data_edge_cases(self):
        """A ``log_likelihood`` variable in the posterior warns; ``dims=None`` is accepted."""
        # create data
        y_values = np.random.randn(4, 100)
        ll_values = np.random.randn(4, 100, 8)
        log_likelihood = {"y": y_values, "log_likelihood": ll_values}

        # log_likelihood to posterior
        with pytest.warns(UserWarning, match="log_likelihood.+in posterior"):
            assert from_dict(posterior=log_likelihood) is not None

        # dims == None
        assert from_dict(observed_data=log_likelihood, dims=None) is not None

    def test_inference_data_bad(self):
        """Passing a bare ndarray for any group raises ``TypeError``."""
        # create data
        x = np.random.randn(4, 100)

        # input ndarray
        groups = (
            "posterior",
            "posterior_predictive",
            "sample_stats",
            "prior",
            "prior_predictive",
            "sample_stats_prior",
            "observed_data",
        )
        for group in groups:
            with pytest.raises(TypeError):
                from_dict(**{group: x})

    def test_from_dict_warning(self):
        """A posterior dict containing ``log_likelihood`` emits a ``UserWarning``."""
        bad_posterior_dict = {"log_likelihood": np.ones((5, 1000, 2))}
        with pytest.warns(UserWarning):
            from_dict(posterior=bad_posterior_dict)
|
|
1356
|
-
|
|
1357
|
-
|
|
1358
|
-
class TestDataNetCDF:
    """Round-trip InferenceData through netCDF files via the function and method APIs.

    Every test writes into the ``saved_models`` directory next to the test
    package and removes its file in a ``finally`` clause, so a failing
    assertion cannot leave a stale file behind that would break later
    (parametrized) runs.
    """

    @pytest.fixture(scope="class")
    def data(self, draws, chains):
        """Return a holder class whose ``obj`` maps variable names to random arrays."""

        class Data:
            # fake 8-school output
            obj = {}
            for key, shape in {"mu": [], "tau": [], "eta": [8], "theta": [8]}.items():
                obj[key] = np.random.randn(chains, draws, *shape)

        return Data

    @staticmethod
    def _saved_models_path(filename):
        """Return the path of ``filename`` inside the sibling ``saved_models`` directory."""
        here = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(here, "..", "saved_models", filename)

    @staticmethod
    def _remove_if_present(filepath):
        """Delete ``filepath`` when it exists; used for cleanup in ``finally`` blocks."""
        if os.path.exists(filepath):
            os.remove(filepath)

    def get_inference_data(self, data, eight_schools_params):
        """Build InferenceData from ``data.obj`` with string-valued ``school`` coordinates."""
        return from_dict(
            posterior=data.obj,
            posterior_predictive=data.obj,
            sample_stats=data.obj,
            prior=data.obj,
            prior_predictive=data.obj,
            sample_stats_prior=data.obj,
            observed_data=eight_schools_params,
            coords={"school": np.array(["a" * i for i in range(8)], dtype="U")},
            dims={"theta": ["school"], "eta": ["school"]},
        )

    def test_io_function(self, data, eight_schools_params):
        """Save with ``to_netcdf`` and reload with ``from_netcdf`` (module-level functions)."""
        # create inference data and assert all attributes are present
        inference_data = self.get_inference_data(data, eight_schools_params)
        test_dict = {
            "posterior": ["eta", "theta", "mu", "tau"],
            "posterior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats": ["eta", "theta", "mu", "tau"],
            "prior": ["eta", "theta", "mu", "tau"],
            "prior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats_prior": ["eta", "theta", "mu", "tau"],
            "observed_data": ["J", "y", "sigma"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails

        filepath = self._saved_models_path("io_function_testfile.nc")
        try:
            # az -function
            to_netcdf(inference_data, filepath)

            # Assert InferenceData has been saved correctly
            assert os.path.exists(filepath)
            assert os.path.getsize(filepath) > 0
            inference_data2 = from_netcdf(filepath)
            fails = check_multiple_attrs(test_dict, inference_data2)
            assert not fails
        finally:
            # Always clean up, even when an assertion above failed.
            self._remove_if_present(filepath)
        assert not os.path.exists(filepath)

    @pytest.mark.parametrize("base_group", ["/", "test_group", "group/subgroup"])
    @pytest.mark.parametrize("groups_arg", [False, True])
    @pytest.mark.parametrize("compress", [True, False])
    @pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4"])
    def test_io_method(self, data, eight_schools_params, groups_arg, base_group, compress, engine):
        """Save and reload through the ``InferenceData`` methods for each parameter combination.

        The parametrized ``engine`` is forwarded to both ``to_netcdf`` and
        ``from_netcdf`` so each backend is actually exercised; the test is
        skipped when the backend's package cannot be imported.
        """
        # create InferenceData and check it has been properly created
        inference_data = self.get_inference_data(data, eight_schools_params)
        if engine == "h5netcdf":
            try:
                import h5netcdf  # pylint: disable=unused-import
            except ImportError:
                pytest.skip("h5netcdf not installed")
        elif engine == "netcdf4":
            try:
                import netCDF4  # pylint: disable=unused-import
            except ImportError:
                pytest.skip("netcdf4 not installed")
        test_dict = {
            "posterior": ["eta", "theta", "mu", "tau"],
            "posterior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats": ["eta", "theta", "mu", "tau"],
            "prior": ["eta", "theta", "mu", "tau"],
            "prior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats_prior": ["eta", "theta", "mu", "tau"],
            "observed_data": ["J", "y", "sigma"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails

        filepath = self._saved_models_path("io_method_testfile.nc")
        try:
            # check filename does not exist and use to_netcdf method; the check
            # sits inside the try so a stale file is cleaned up after failing once
            assert not os.path.exists(filepath)
            # InferenceData method
            inference_data.to_netcdf(
                filepath,
                groups=("posterior", "observed_data") if groups_arg else None,
                compress=compress,
                engine=engine,
                base_group=base_group,
            )

            # assert file has been saved correctly
            assert os.path.exists(filepath)
            assert os.path.getsize(filepath) > 0
            inference_data2 = InferenceData.from_netcdf(
                filepath, engine=engine, base_group=base_group
            )
            if groups_arg:  # if groups arg, update test dict to contain only saved groups
                test_dict = {
                    "posterior": ["eta", "theta", "mu", "tau"],
                    "observed_data": ["J", "y", "sigma"],
                }
                assert not hasattr(inference_data2, "sample_stats")
            fails = check_multiple_attrs(test_dict, inference_data2)
            assert not fails
        finally:
            self._remove_if_present(filepath)
        assert not os.path.exists(filepath)

    def test_empty_inference_data_object(self):
        """An empty ``InferenceData`` can still be written to a netCDF file."""
        inference_data = InferenceData()
        filepath = self._saved_models_path("empty_test_file.nc")
        try:
            assert not os.path.exists(filepath)
            inference_data.to_netcdf(filepath)
            assert os.path.exists(filepath)
        finally:
            self._remove_if_present(filepath)
        assert not os.path.exists(filepath)
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
class TestJSON:
    """Round-trip InferenceData through a JSON file."""

    def test_json_converters(self, models):
        """Write ``models.model_1`` with ``to_json`` and compare every group after ``from_json``.

        The file is created in the current working directory and removed in a
        ``finally`` clause so a failed comparison does not leave ``test.json``
        behind.
        """
        idata = models.model_1

        filepath = os.path.realpath("test.json")
        try:
            idata.to_json(filepath)

            idata_copy = from_json(filepath)
            for group in idata._groups_all:  # pylint: disable=protected-access
                xr_data = getattr(idata, group)
                test_xr_data = getattr(idata_copy, group)
                assert xr_data.equals(test_xr_data)
        finally:
            # Clean up regardless of the outcome above.
            if os.path.exists(filepath):
                os.remove(filepath)
        assert not os.path.exists(filepath)
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
class TestDataTree:
    """Round-trip InferenceData through ``to_datatree`` / ``from_datatree``."""

    def test_datatree(self):
        """Every group survives the round trip and appears among the tree's children."""
        original = load_arviz_data("centered_eight")
        tree = original.to_datatree()
        restored = from_datatree(tree)

        for group_name, dataset in original.items():
            assert_identical(dataset, restored[group_name])

        missing = [name for name in original.groups() if name not in tree.children]
        assert not missing

    def test_datatree_attrs(self):
        """Top-level attrs are carried onto the tree and back onto the restored object."""
        original = load_arviz_data("centered_eight")
        original.attrs = {"not": "empty"}
        assert original.attrs

        tree = original.to_datatree()
        restored = from_datatree(tree)

        assert tree.attrs == original.attrs
        assert restored.attrs == original.attrs
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
class TestConversions:
    """Check the converters on already-converted inputs and with explicit dims/coords."""

    @staticmethod
    def _school_names():
        """Return the set of school labels the ``centered_eight`` tests expect."""
        return {
            "Hotchkiss",
            "Mt. Hermon",
            "Choate",
            "Deerfield",
            "Phillips Andover",
            "St. Paul's",
            "Lawrenceville",
            "Phillips Exeter",
        }

    def test_id_conversion_idempotent(self):
        """Converting an InferenceData keeps its type, school labels and theta dims."""
        stored = load_arviz_data("centered_eight")
        inference_data = convert_to_inference_data(stored)

        assert isinstance(inference_data, InferenceData)
        schools = inference_data.observed_data.obs.coords["school"].values
        assert set(schools) == self._school_names()
        assert inference_data.posterior["theta"].dims == ("chain", "draw", "school")

    def test_dataset_conversion_idempotent(self):
        """Converting a posterior Dataset keeps its type, school labels and theta dims."""
        inference_data = load_arviz_data("centered_eight")
        data_set = convert_to_dataset(inference_data.posterior)

        assert isinstance(data_set, xr.Dataset)
        schools = data_set.coords["school"].values
        assert set(schools) == self._school_names()
        assert data_set["theta"].dims == ("chain", "draw", "school")

    def test_id_conversion_args(self):
        """A plain dict of arrays picks up the custom ``Ivies`` dimension and labels."""
        stored = load_arviz_data("centered_eight")
        ivies = ["Yale", "Harvard", "MIT", "Princeton", "Cornell", "Dartmouth", "Columbia", "Brown"]

        # test dictionary argument...
        # Rebuild a name -> ndarray dictionary from the stored posterior: each
        # variable's "data" entry is a list of chains (each a list of samples),
        # so every chain becomes an array and the chains are stacked.
        posterior_vars = stored.posterior.to_dict()["data_vars"]
        test_dict = {
            var_name: np.stack([np.array(chain) for chain in var["data"]])
            for var_name, var in posterior_vars.items()
        }

        inference_data = convert_to_inference_data(
            test_dict, dims={"theta": ["Ivies"]}, coords={"Ivies": ivies}
        )

        assert isinstance(inference_data, InferenceData)
        assert set(inference_data.posterior.coords["Ivies"].values) == set(ivies)
        assert inference_data.posterior["theta"].dims == ("chain", "draw", "Ivies")
|
|
1581
|
-
|
|
1582
|
-
|
|
1583
|
-
class TestDataArrayToDataset:
    """Convert ``xr.DataArray`` inputs to datasets and InferenceData."""

    def test_1d_dataset(self):
        """A named (chain, draw) array becomes a single-variable dataset."""
        size = 100
        named_array = xr.DataArray(np.random.randn(1, size), name="plot", dims=("chain", "draw"))
        dataset = convert_to_dataset(named_array)

        assert len(dataset.data_vars) == 1
        assert "plot" in dataset.data_vars
        assert dataset.chain.shape == (1,)
        assert dataset.draw.shape == (size,)

    def test_nd_to_dataset(self):
        """An unnamed 5-d array keeps its full shape after conversion to a dataset."""
        shape = (1, 2, 3, 4, 5)
        dim_names = ("chain", "draw", "dim_0", "dim_1", "dim_2")
        dataset = convert_to_dataset(xr.DataArray(np.random.randn(*shape), dims=dim_names))
        # The array is unnamed, so look up whatever name the variable received.
        var_name = tuple(dataset.data_vars)[0]

        assert len(dataset.data_vars) == 1
        assert dataset.chain.shape == (shape[0],)
        assert dataset.draw.shape == (shape[1],)
        assert dataset[var_name].shape == shape

    def test_nd_to_inference_data(self):
        """An unnamed 5-d array lands in the requested ``prior`` group unchanged in shape."""
        shape = (1, 2, 3, 4, 5)
        dim_names = ("chain", "draw", "dim_0", "dim_1", "dim_2")
        source = xr.DataArray(np.random.randn(*shape), dims=dim_names)
        inference_data = convert_to_inference_data(source, group="prior")
        var_name = tuple(inference_data.prior.data_vars)[0]

        assert hasattr(inference_data, "prior")
        prior = inference_data.prior
        assert len(prior.data_vars) == 1
        assert prior.chain.shape == (shape[0],)
        assert prior.draw.shape == (shape[1],)
        assert prior[var_name].shape == shape
|
|
1621
|
-
|
|
1622
|
-
|
|
1623
|
-
class TestExtractDataset:
    """Tests for ``extract`` on the ``centered_eight`` dataset."""

    def test_default(self):
        """Default call returns a Dataset with a ``sample`` dim and all theta values."""
        centered = load_arviz_data("centered_eight")
        extracted = extract(centered)

        assert isinstance(extracted, xr.Dataset)
        assert "sample" in extracted.dims
        assert extracted.theta.size == 4 * 500 * 8

    def test_seed(self):
        """The same ``rng`` yields matching ``sample`` coordinates across groups."""
        centered = load_arviz_data("centered_eight")
        from_posterior = extract(centered, rng=7)
        from_predictive = extract(centered, group="posterior_predictive", rng=7)

        assert all(from_posterior.sample == from_predictive.sample)

    def test_no_combine(self):
        """``combined=False`` keeps separate ``chain`` and ``draw`` dimensions."""
        centered = load_arviz_data("centered_eight")
        extracted = extract(centered, combined=False)

        assert "sample" not in extracted.dims
        assert extracted.sizes["chain"] == 4
        assert extracted.sizes["draw"] == 500

    def test_var_name_group(self):
        """A ``like`` filter on the prior yields a theta-named result with empty attrs."""
        centered = load_arviz_data("centered_eight")
        selected = extract(centered, group="prior", var_names="the", filter_vars="like")

        assert selected.attrs == {}
        assert "theta" in selected.name

    def test_keep_dataset(self):
        """``keep_dataset=True`` returns a dataset with the prior's attrs and only theta."""
        centered = load_arviz_data("centered_eight")
        selected = extract(
            centered,
            group="prior",
            var_names="the",
            filter_vars="like",
            keep_dataset=True,
        )

        assert selected.attrs == centered.prior.attrs
        assert "theta" in selected.data_vars
        assert "mu" not in selected.data_vars

    def test_subset_samples(self):
        """``num_samples`` limits the ``sample`` dimension and keeps posterior attrs."""
        centered = load_arviz_data("centered_eight")
        subset = extract(centered, num_samples=10)

        assert subset.sizes["sample"] == 10
        assert subset.attrs == centered.posterior.attrs
|
|
1664
|
-
|
|
1665
|
-
|
|
1666
|
-
def test_convert_to_inference_data_with_array_like():
    """An object that only exposes ``__array__`` converts to a posterior variable ``x``."""

    class ArrayLike:
        """Minimal wrapper whose sole array interface is ``__array__``."""

        def __init__(self, data):
            self._data = np.asarray(data)

        def __array__(self, dtype=None, copy=None):
            """Return the wrapped array following NumPy's ``__array__`` protocol.

            NumPy may pass ``dtype`` and, from NumPy 2.0 on, ``copy``; accepting
            both keeps the helper valid regardless of how the converter (or
            NumPy itself) invokes the hook.
            """
            if dtype is not None and np.dtype(dtype) != self._data.dtype:
                if copy is False:
                    # The protocol requires an error when a copy is unavoidable.
                    raise ValueError("a copy is required to change the dtype")
                return self._data.astype(dtype)
            return self._data.copy() if copy else self._data

    array_like = ArrayLike(np.random.randn(4, 100))
    idata = convert_to_inference_data(array_like, group="posterior")

    assert hasattr(idata, "posterior")
    assert "x" in idata.posterior.data_vars
    assert idata.posterior["x"].shape == (4, 100)
|