arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
|
@@ -1,398 +0,0 @@
|
|
|
1
|
-
# pylint: disable=no-member, invalid-name, redefined-outer-name
|
|
2
|
-
# pylint: disable=too-many-lines
|
|
3
|
-
import os
|
|
4
|
-
|
|
5
|
-
import numpy as np
|
|
6
|
-
import pytest
|
|
7
|
-
|
|
8
|
-
from ... import from_cmdstan
|
|
9
|
-
|
|
10
|
-
from ..helpers import check_multiple_attrs
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class TestDataCmdStan:
    """Tests for building InferenceData from CmdStan CSV output via ``from_cmdstan``."""

    @pytest.fixture(scope="session")
    def data_directory(self):
        """Return the ``saved_models`` directory next to this test package."""
        here = os.path.dirname(os.path.abspath(__file__))
        data_directory = os.path.join(here, "..", "saved_models")
        return data_directory

    @pytest.fixture(scope="class")
    def paths(self, data_directory):
        """Map a scenario name to CmdStan output file(s): a list of files or a glob."""
        paths = {
            "no_warmup": [
                os.path.join(data_directory, "cmdstan/output_no_warmup1.csv"),
                os.path.join(data_directory, "cmdstan/output_no_warmup2.csv"),
                os.path.join(data_directory, "cmdstan/output_no_warmup3.csv"),
                os.path.join(data_directory, "cmdstan/output_no_warmup4.csv"),
            ],
            "warmup": [
                os.path.join(data_directory, "cmdstan/output_warmup1.csv"),
                os.path.join(data_directory, "cmdstan/output_warmup2.csv"),
                os.path.join(data_directory, "cmdstan/output_warmup3.csv"),
                os.path.join(data_directory, "cmdstan/output_warmup4.csv"),
            ],
            "no_warmup_glob": os.path.join(data_directory, "cmdstan/output_no_warmup[0-9].csv"),
            "warmup_glob": os.path.join(data_directory, "cmdstan/output_warmup[0-9].csv"),
            "eight_schools_glob": os.path.join(
                data_directory, "cmdstan/eight_schools_output[0-9].csv"
            ),
            "eight_schools": [
                os.path.join(data_directory, "cmdstan/eight_schools_output1.csv"),
                os.path.join(data_directory, "cmdstan/eight_schools_output2.csv"),
                os.path.join(data_directory, "cmdstan/eight_schools_output3.csv"),
                os.path.join(data_directory, "cmdstan/eight_schools_output4.csv"),
            ],
        }
        return paths

    @pytest.fixture(scope="class")
    def observed_data_paths(self, data_directory):
        """Return data files.

        Index 0 is the eight-schools Rdump file. Index 1 is the example Rdump
        file and index 2 is its JSON equivalent.
        """
        observed_data_paths = [
            os.path.join(data_directory, "cmdstan/eight_schools.data.R"),
            os.path.join(data_directory, "cmdstan/example_stan.data.R"),
            os.path.join(data_directory, "cmdstan/example_stan.json"),
        ]

        return observed_data_paths

    @staticmethod
    def _eight_schools_paths(paths):
        """Return the eight-schools entries of ``paths`` (file list and glob)."""
        return [path for key, path in paths.items() if "eight" in key]

    def get_inference_data(self, posterior, **kwargs):
        """Convert ``posterior`` (plus any extra groups in ``kwargs``) with ``from_cmdstan``."""
        return from_cmdstan(posterior=posterior, **kwargs)

    def test_sample_stats(self, paths):
        """Every scenario exposes sample_stats tagged with CmdStan's step-size name."""
        for path in paths.values():
            inference_data = self.get_inference_data(path)
            assert hasattr(inference_data, "sample_stats")
            assert "step_size" in inference_data.sample_stats.attrs
            assert inference_data.sample_stats.attrs["step_size"] == "stepsize"

    def test_inference_data_shapes(self, paths):
        """Assert that shapes are transformed correctly"""
        dims = ["chain", "draw"]
        y_mean_true = 0
        x_mean_true = np.array([1, 2, 3])
        Z_mean_true = np.array([1, 2, 3, 4])
        for key, path in paths.items():
            # The eight-schools model has different variables; it is covered elsewhere.
            if "eight" in key:
                continue
            inference_data = self.get_inference_data(path)
            test_dict = {"posterior": ["x", "y", "Z"]}
            fails = check_multiple_attrs(test_dict, inference_data)
            assert not fails
            assert inference_data.posterior["y"].shape == (4, 100)
            assert inference_data.posterior["x"].shape == (4, 100, 3)
            assert inference_data.posterior["Z"].shape == (4, 100, 4, 6)
            y_mean = inference_data.posterior["y"].mean(dim=dims)
            assert np.isclose(y_mean, y_mean_true, atol=1e-1)
            x_mean = inference_data.posterior["x"].mean(dim=dims)
            assert np.isclose(x_mean, x_mean_true, atol=1e-1).all()
            Z_mean = inference_data.posterior["Z"].mean(dim=dims).mean(axis=1)
            assert np.isclose(Z_mean, Z_mean_true, atol=7e-1).all()
            assert "comments" in inference_data.posterior.attrs

    def test_inference_data_input_types1(self, paths, observed_data_paths):
        """Check input types

        posterior --> str, list of str
        prior --> str, list of str
        posterior_predictive --> str, variable in posterior
        observed_data --> Rdump format
        observed_data_var --> str, variable
        log_likelihood --> str
        coords --> one to many
        dims --> one to many
        """
        for path in self._eight_schools_paths(paths):
            inference_data = self.get_inference_data(
                posterior=path,
                posterior_predictive="y_hat",
                predictions="y_hat",
                prior=path,
                prior_predictive="y_hat",
                observed_data=observed_data_paths[0],
                observed_data_var="y",
                constant_data=observed_data_paths[0],
                constant_data_var="y",
                predictions_constant_data=observed_data_paths[0],
                predictions_constant_data_var="y",
                log_likelihood="log_lik",
                coords={"school": np.arange(8)},
                dims={
                    "theta": ["school"],
                    "y": ["school"],
                    "log_lik": ["school"],
                    "y_hat": ["school"],
                    "eta": ["school"],
                },
            )
            # The former ``"output_warmup" in path`` branch was dead: it can never
            # match an eight-schools list or glob. It was removed so the test no
            # longer implies that warmup groups are checked here.
            test_dict = {
                "posterior": ["mu", "tau", "theta_tilde", "theta"],
                "posterior_predictive": ["y_hat"],
                "predictions": ["y_hat"],
                "prior": ["mu", "tau", "theta_tilde", "theta"],
                "prior_predictive": ["y_hat"],
                "sample_stats": ["diverging"],
                "observed_data": ["y"],
                "constant_data": ["y"],
                "predictions_constant_data": ["y"],
                "log_likelihood": ["log_lik"],
            }
            fails = check_multiple_attrs(test_dict, inference_data)
            assert not fails

    def test_inference_data_input_types2(self, paths, observed_data_paths):
        """Check input types (change, see earlier)

        posterior_predictive --> List[str], variable in posterior
        observed_data_var --> List[str], variable
        """
        for path in self._eight_schools_paths(paths):
            inference_data = self.get_inference_data(
                posterior=path,
                posterior_predictive=["y_hat"],
                predictions=["y_hat"],
                prior=path,
                prior_predictive=["y_hat"],
                observed_data=observed_data_paths[0],
                observed_data_var=["y"],
                constant_data=observed_data_paths[0],
                constant_data_var=["y"],
                predictions_constant_data=observed_data_paths[0],
                predictions_constant_data_var=["y"],
                coords={"school": np.arange(8)},
                dims={
                    "theta": ["school"],
                    "y": ["school"],
                    "log_lik": ["school"],
                    "y_hat": ["school"],
                    "eta": ["school"],
                },
                dtypes={"theta": np.int64},
            )
            test_dict = {
                "posterior": ["mu", "tau", "theta_tilde", "theta"],
                "posterior_predictive": ["y_hat"],
                "predictions": ["y_hat"],
                "prior": ["mu", "tau", "theta_tilde", "theta"],
                "prior_predictive": ["y_hat"],
                "sample_stats": ["diverging"],
                "observed_data": ["y"],
                "constant_data": ["y"],
                "predictions_constant_data": ["y"],
                "log_likelihood": ["log_lik"],
            }
            fails = check_multiple_attrs(test_dict, inference_data)
            assert not fails
            # ``dtypes`` must be honoured: theta was requested as an integer type.
            assert isinstance(inference_data.posterior.theta.data.flat[0], np.integer)

    def test_inference_data_input_types3(self, paths, observed_data_paths):
        """Check input types (change, see earlier)

        posterior_predictive --> str, csv file
        coords --> one to many + one to one (default dim)
        dims --> one to many
        """
        post_pred = paths["eight_schools_glob"]
        for path in self._eight_schools_paths(paths):
            inference_data = self.get_inference_data(
                posterior=path,
                posterior_predictive=post_pred,
                prior=path,
                prior_predictive=post_pred,
                observed_data=observed_data_paths[0],
                observed_data_var=["y"],
                log_likelihood=["log_lik", "y_hat"],
                coords={
                    "school": np.arange(8),
                    "log_lik_dim_0": np.arange(8),
                    "y_hat": np.arange(8),
                },
                dims={"theta": ["school"], "y": ["school"], "y_hat": ["school"], "eta": ["school"]},
            )
            test_dict = {
                "posterior": ["mu", "tau", "theta_tilde", "theta"],
                "sample_stats": ["diverging"],
                "prior": ["mu", "tau", "theta_tilde", "theta"],
                "prior_predictive": ["y_hat"],
                "observed_data": ["y"],
                "posterior_predictive": ["y_hat"],
                "log_likelihood": ["log_lik", "y_hat"],
            }
            fails = check_multiple_attrs(test_dict, inference_data)
            assert not fails

    def test_inference_data_input_types4(self, paths):
        """Check input types (change, see earlier)

        coords --> one to many + one to one (non-default dim)
        dims --> one to many + one to one
        """

        paths_ = paths["no_warmup"]
        # Exercise both a list of chain files and a single file path.
        for path in [paths_, paths_[0]]:
            inference_data = self.get_inference_data(
                posterior=path,
                posterior_predictive=path,
                prior=path,
                prior_predictive=path,
                observed_data=None,
                observed_data_var=None,
                log_likelihood=False,
                coords={"rand": np.arange(3)},
                dims={"x": ["rand"]},
            )
            test_dict = {
                "posterior": ["x", "y", "Z"],
                "prior": ["x", "y", "Z"],
                "prior_predictive": ["x", "y", "Z"],
                "sample_stats": ["lp"],
                "sample_stats_prior": ["lp"],
                "posterior_predictive": ["x", "y", "Z"],
                "~log_likelihood": [""],
            }
            fails = check_multiple_attrs(test_dict, inference_data)
            assert not fails

    def test_inference_data_input_types5(self, paths, observed_data_paths):
        """Check input types (change, see earlier)

        posterior_predictive is None
        prior_predictive is None
        """
        for path in self._eight_schools_paths(paths):
            inference_data = self.get_inference_data(
                posterior=path,
                posterior_predictive=None,
                prior=path,
                prior_predictive=None,
                observed_data=observed_data_paths[0],
                observed_data_var=["y"],
                log_likelihood=["y_hat"],
                coords={"school": np.arange(8), "log_lik_dim": np.arange(8)},
                dims={
                    "theta": ["school"],
                    "y": ["school"],
                    "log_lik": ["log_lik_dim"],
                    "y_hat": ["school"],
                    "eta": ["school"],
                },
            )
            # log_lik was not requested as a log-likelihood variable, so it must
            # stay in the posterior and be absent from log_likelihood.
            test_dict = {
                "posterior": ["mu", "tau", "theta_tilde", "theta", "log_lik"],
                "prior": ["mu", "tau", "theta_tilde", "theta"],
                "log_likelihood": ["y_hat", "~log_lik"],
                "observed_data": ["y"],
                "sample_stats_prior": ["lp"],
            }
            fails = check_multiple_attrs(test_dict, inference_data)
            assert not fails

    def test_inference_data_input_types6(self, paths, observed_data_paths):
        """Check input types (change, see earlier)

        log_likelihood --> dict
        """
        post_pred = paths["eight_schools_glob"]
        for path in self._eight_schools_paths(paths):
            inference_data = self.get_inference_data(
                posterior=path,
                posterior_predictive=post_pred,
                prior=path,
                prior_predictive=post_pred,
                observed_data=observed_data_paths[0],
                observed_data_var=["y"],
                log_likelihood={"y": "log_lik"},
                coords={
                    "school": np.arange(8),
                    "log_lik_dim_0": np.arange(8),
                    "y_hat": np.arange(8),
                },
                dims={"theta": ["school"], "y": ["school"], "y_hat": ["school"], "eta": ["school"]},
            )
            # The dict renames log_lik to y inside the log_likelihood group.
            test_dict = {
                "posterior": ["mu", "tau", "theta_tilde", "theta"],
                "sample_stats": ["diverging"],
                "prior": ["mu", "tau", "theta_tilde", "theta"],
                "prior_predictive": ["y_hat"],
                "observed_data": ["y"],
                "posterior_predictive": ["y_hat"],
                "log_likelihood": ["y", "~log_lik"],
            }
            fails = check_multiple_attrs(test_dict, inference_data)
            assert not fails

    def test_inference_data_observed_data1(self, observed_data_paths):
        """Read Rdump/JSON, check shapes are correct

        All variables
        """
        # Check the Rdump (idx=1) and equivalent JSON data file (idx=2)
        for path in observed_data_paths[1:3]:
            inference_data = self.get_inference_data(posterior=None, observed_data=path)
            assert hasattr(inference_data, "observed_data")
            assert len(inference_data.observed_data.data_vars) == 3
            assert inference_data.observed_data["x"].shape == (1,)
            assert inference_data.observed_data["x"][0] == 1
            assert inference_data.observed_data["y"].shape == (3,)
            assert inference_data.observed_data["Z"].shape == (4, 5)

    def test_inference_data_observed_data2(self, observed_data_paths):
        """Read Rdump/JSON, check shapes are correct

        One variable as str
        """
        # Check the Rdump (idx=1) and equivalent JSON data file (idx=2)
        for path in observed_data_paths[1:3]:
            inference_data = self.get_inference_data(
                posterior=None, observed_data=path, observed_data_var="x"
            )
            assert hasattr(inference_data, "observed_data")
            assert len(inference_data.observed_data.data_vars) == 1
            assert inference_data.observed_data["x"].shape == (1,)

    def test_inference_data_observed_data3(self, observed_data_paths):
        """Read Rdump/JSON, check shapes are correct

        One variable as a list
        """
        # Check the Rdump (idx=1) and equivalent JSON data file (idx=2)
        for path in observed_data_paths[1:3]:
            inference_data = self.get_inference_data(
                posterior=None, observed_data=path, observed_data_var=["x"]
            )
            assert hasattr(inference_data, "observed_data")
            assert len(inference_data.observed_data.data_vars) == 1
            assert inference_data.observed_data["x"].shape == (1,)

    def test_inference_data_observed_data4(self, observed_data_paths):
        """Read Rdump/JSON, check shapes are correct

        Many variables as list
        """
        # Check the Rdump (idx=1) and equivalent JSON data file (idx=2)
        for path in observed_data_paths[1:3]:
            inference_data = self.get_inference_data(
                posterior=None, observed_data=path, observed_data_var=["y", "Z"]
            )
            assert hasattr(inference_data, "observed_data")
            assert len(inference_data.observed_data.data_vars) == 2
            assert inference_data.observed_data["y"].shape == (3,)
            assert inference_data.observed_data["Z"].shape == (4, 5)
|