arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/tests/helpers.py
DELETED
|
@@ -1,677 +0,0 @@
|
|
|
1
|
-
# pylint: disable=redefined-outer-name, comparison-with-callable, protected-access
|
|
2
|
-
"""Test helper functions."""
|
|
3
|
-
import gzip
|
|
4
|
-
import importlib
|
|
5
|
-
import logging
|
|
6
|
-
import os
|
|
7
|
-
import sys
|
|
8
|
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
9
|
-
import warnings
|
|
10
|
-
from contextlib import contextmanager
|
|
11
|
-
|
|
12
|
-
import cloudpickle
|
|
13
|
-
import numpy as np
|
|
14
|
-
import pytest
|
|
15
|
-
from _pytest.outcomes import Skipped
|
|
16
|
-
from packaging.version import Version
|
|
17
|
-
|
|
18
|
-
from ..data import InferenceData, from_dict
|
|
19
|
-
|
|
20
|
-
# Module-level logger; ``load_cached_models`` reports model generation and cache loading through it.
_log = logging.getLogger(__name__)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class RandomVariableTestClass:
    """Minimal stand-in for a named random variable.

    An instance only remembers the ``name`` it was created with and uses that
    name verbatim as its ``repr``.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        """Represent the instance by the name given to the constructor."""
        rendered = self.name
        return rendered
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
@contextmanager
def does_not_warn(warning=Warning):
    """Fail if the managed block emits a warning of category ``warning``.

    Every warning raised inside the block is recorded; the filter is forced to
    ``"always"`` so nothing is deduplicated. Once the block finishes, the first
    recorded warning whose category is a subclass of ``warning`` triggers an
    ``AssertionError``. Warnings of unrelated categories are ignored.

    Parameters
    ----------
    warning : type, optional
        Warning class that must not be emitted. Defaults to ``Warning``, i.e. any warning.
    """
    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter("always")
        yield
        offending = next(
            (entry for entry in recorded if issubclass(entry.category, warning)), None
        )
        if offending is not None:
            raise AssertionError(
                f"Expected no {warning.__name__} but caught warning with message: {offending.message}"
            )
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
@pytest.fixture(scope="module")
def eight_schools_params():
    """Provide the eight schools data once per test module.

    Returns
    -------
    dict
        ``J`` (number of schools) plus the length-``J`` float arrays ``y``
        (observations) and ``sigma`` (their standard deviations).
    """
    observations = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
    std_devs = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
    return {"J": 8, "y": observations, "sigma": std_devs}
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
@pytest.fixture(scope="module")
def draws():
    """Default number of draws per chain, shared across a test module."""
    default_draw_count = 500
    return default_draw_count
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
@pytest.fixture(scope="module")
def chains():
    """Default number of chains, shared across a test module."""
    default_chain_count = 2
    return default_chain_count
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def create_model(seed=10, transpose=False):
    """Create an eight-schools-like InferenceData filled with fake random draws.

    The global NumPy random state is reseeded with ``seed`` and then consumed in
    a fixed order, so equal seeds reproduce equal objects.

    Parameters
    ----------
    seed : int
        Seed passed to ``np.random.seed``.
    transpose : bool
        When True, every group that has both ``chain`` and ``draw`` dimensions is
        reordered so that ``draw`` comes first.

    Returns
    -------
    InferenceData
        4 chains x 500 draws for posterior, prior, their predictives, sample
        stats and log likelihood, plus the observed ``y``.
    """
    np.random.seed(seed)
    nchains, ndraws = 4, 500
    shape = (nchains, ndraws)
    randn = np.random.randn
    data = {
        "J": 8,
        "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
        "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
    }
    n_schools = data["J"]
    n_obs = len(data["y"])

    # NOTE: the order of the ``randn`` calls below determines the generated values;
    # keep it unchanged so a given seed keeps producing the same object.
    posterior = {
        "mu": randn(*shape),
        "tau": abs(randn(*shape)),
        "eta": randn(*shape, n_schools),
        "theta": randn(*shape, n_schools),
    }
    posterior_predictive = {"y": randn(*shape, n_obs)}
    sample_stats = {
        "energy": randn(*shape),
        "diverging": randn(*shape) > 0.90,
        "max_depth": randn(*shape) > 0.90,
    }
    log_likelihood = {"y": randn(*shape, n_schools)}
    prior = {
        "mu": randn(*shape) / 2,
        "tau": abs(randn(*shape)) / 2,
        "eta": randn(*shape, n_schools) / 2,
        "theta": randn(*shape, n_schools) / 2,
    }
    prior_predictive = {"y": randn(*shape, n_obs) / 2}
    sample_stats_prior = {
        "energy": randn(*shape),
        "diverging": (randn(*shape) > 0.95).astype(int),
    }

    model = from_dict(
        posterior=posterior,
        posterior_predictive=posterior_predictive,
        sample_stats=sample_stats,
        log_likelihood=log_likelihood,
        prior=prior,
        prior_predictive=prior_predictive,
        sample_stats_prior=sample_stats_prior,
        observed_data={"y": data["y"]},
        dims={
            "y": ["obs_dim"],
            "log_likelihood": ["obs_dim"],
            "theta": ["school"],
            "eta": ["school"],
        },
        coords={"obs_dim": range(data["J"])},
    )

    if transpose:
        for group_name in model._groups:
            dataset = getattr(model, group_name)
            if "draw" in dataset.dims and "chain" in dataset.dims:
                setattr(model, group_name, dataset.transpose("draw", "chain", ...))
    return model
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
def create_multidimensional_model(seed=10, transpose=False):
    """Create an InferenceData whose variables carry two extra dimensions.

    The global NumPy random state is reseeded with ``seed`` and consumed in a
    fixed order, so equal seeds reproduce equal objects. Non-scalar variables
    have shape ``(chain, draw, 5, 7)`` with dims ``dim1`` and ``dim2`` where declared.

    Parameters
    ----------
    seed : int
        Seed passed to ``np.random.seed``.
    transpose : bool
        When True, groups having both ``chain`` and ``draw`` dimensions are
        reordered so that ``draw`` comes first.

    Returns
    -------
    InferenceData
    """
    np.random.seed(seed)
    nchains, ndraws = 4, 500
    ndim1, ndim2 = 5, 7
    sample_shape = (nchains, ndraws)
    var_shape = (ndim1, ndim2)
    randn = np.random.randn

    # NOTE: the order of the random calls below determines the generated values;
    # keep it unchanged so a given seed keeps producing the same object.
    data = {
        "y": np.random.normal(size=var_shape),
        "sigma": np.random.normal(size=var_shape),
    }
    posterior = {
        "mu": randn(*sample_shape),
        "tau": abs(randn(*sample_shape)),
        "eta": randn(*sample_shape, *var_shape),
        "theta": randn(*sample_shape, *var_shape),
    }
    posterior_predictive = {"y": randn(*sample_shape, *var_shape)}
    sample_stats = {
        "energy": randn(*sample_shape),
        "diverging": randn(*sample_shape) > 0.90,
    }
    log_likelihood = {"y": randn(*sample_shape, *var_shape)}
    prior = {
        "mu": randn(*sample_shape) / 2,
        "tau": abs(randn(*sample_shape)) / 2,
        "eta": randn(*sample_shape, *var_shape) / 2,
        "theta": randn(*sample_shape, *var_shape) / 2,
    }
    prior_predictive = {"y": randn(*sample_shape, *var_shape) / 2}
    sample_stats_prior = {
        "energy": randn(*sample_shape),
        "diverging": (randn(*sample_shape) > 0.95).astype(int),
    }

    model = from_dict(
        posterior=posterior,
        posterior_predictive=posterior_predictive,
        sample_stats=sample_stats,
        log_likelihood=log_likelihood,
        prior=prior,
        prior_predictive=prior_predictive,
        sample_stats_prior=sample_stats_prior,
        observed_data={"y": data["y"]},
        dims={"y": ["dim1", "dim2"], "log_likelihood": ["dim1", "dim2"]},
        coords={"dim1": range(ndim1), "dim2": range(ndim2)},
    )

    if transpose:
        for group_name in model._groups:
            dataset = getattr(model, group_name)
            if "draw" in dataset.dims and "chain" in dataset.dims:
                setattr(model, group_name, dataset.transpose("draw", "chain", ...))
    return model
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
def create_data_random(groups=None, seed=10):
    """Build an InferenceData filled with standard-normal draws.

    Parameters
    ----------
    groups : list of str, optional
        Names of the groups to include. Defaults to posterior, sample_stats,
        observed_data and posterior_predictive. Supported extra names are prior,
        prior_predictive, warmup_posterior, warmup_posterior_predictive and
        warmup_prior.
    seed : int
        Seed for ``np.random.default_rng``.

    Returns
    -------
    InferenceData
        Built with ``save_warmup=True``. Each sampled group holds ``a`` (shape
        ``(4, 500)``) and ``b`` (shape ``(4, 500, 8)``); observed_data holds only ``b``.
    """
    if groups is None:
        groups = ["posterior", "sample_stats", "observed_data", "posterior_predictive"]
    values = np.random.default_rng(seed).normal(size=(4, 500, 8))

    def _sampled_group():
        # Fresh dict with a fresh view per group, sharing memory with ``values``.
        return {"a": values[..., 0], "b": values}

    candidates = {
        "posterior": _sampled_group(),
        "sample_stats": _sampled_group(),
        "observed_data": {"b": values[0, 0, :]},
        "posterior_predictive": _sampled_group(),
        "prior": _sampled_group(),
        "prior_predictive": _sampled_group(),
        "warmup_posterior": _sampled_group(),
        "warmup_posterior_predictive": _sampled_group(),
        "warmup_prior": _sampled_group(),
    }
    selected = {name: content for name, content in candidates.items() if name in groups}
    return from_dict(**selected, save_warmup=True)
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
@pytest.fixture()
def data_random():
    """Provide a fresh random InferenceData (``create_data_random`` defaults) per test."""
    return create_data_random()
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
@pytest.fixture(scope="module")
def models():
    """Provide two mock InferenceData objects as attributes ``model_1`` and ``model_2``.

    ``model_1`` uses seed 10; ``model_2`` uses seed 11 and draw-first dimension order.
    """
    first = create_model(seed=10)
    second = create_model(seed=11, transpose=True)

    class Models:
        model_1 = first
        model_2 = second

    return Models()
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
@pytest.fixture(scope="module")
def multidim_models():
    """Provide two multidimensional mock InferenceData objects (``model_1``, ``model_2``).

    ``model_1`` uses seed 10; ``model_2`` uses seed 11 and draw-first dimension order.
    """
    first = create_multidimensional_model(seed=10)
    second = create_multidimensional_model(seed=11, transpose=True)

    class Models:
        model_1 = first
        model_2 = second

    return Models()
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
def check_multiple_attrs(
    test_dict: Dict[str, List[str]], parent: "InferenceData"
) -> List[Union[str, Tuple[str, str]]]:
    """Perform multiple hasattr checks on InferenceData objects.

    For every requested group (dataset) the function first checks whether
    ``parent`` has it. If the group is present, it then checks the requested
    variables of that group.

    Given the output of the function, all mismatches between expectation and reality can
    be retrieved: a single string indicates a group mismatch and a tuple of strings
    ``(group, var)`` indicates a mismatch in the variable ``var`` of ``group``.

    Parameters
    ----------
    test_dict: dict of {str : list of str}
        Its structure should be `{dataset1_name: [var1, var2], dataset2_name: [var]}`.
        A ``~`` at the beginning of a dataset or variable name indicates the name NOT
        being present must be asserted. The variable list of a negated dataset is
        ignored.
    parent: InferenceData
        InferenceData object on which to check the attributes. Only ``hasattr`` and
        ``getattr`` are used, so any object exposing groups and variables as
        attributes works.

    Returns
    -------
    list
        List containing the failed checks. It will contain either the dataset_name or a
        tuple (dataset_name, var) for every check that failed. Names keep their leading
        ``~`` when a negated check failed, i.e. the name was present.

    Examples
    --------
    The output below indicates two things. First, the ``posterior`` group was expected
    but not found. Second, variables ``a`` and ``b`` were expected in the ``prior``
    group but not found:

        ["posterior", ("prior", "a"), ("prior", "b")]

    Another example could be the following:

        [("posterior", "a"), "~observed_data", ("sample_stats", "~log_likelihood")]

    In this case, variable ``a`` was expected in ``posterior`` but was not found. The
    preceding ``~`` in the other two entries is kept from the input negation notation.
    It indicates that the ``observed_data`` group should not be present but was found
    in the InferenceData. It also indicates that the ``log_likelihood`` variable was
    found in ``sample_stats``, which was likewise not expected.

    """
    failed_attrs: List[Union[str, Tuple[str, str]]] = []
    for dataset_name, attributes in test_dict.items():
        if dataset_name.startswith("~"):
            # Negated group: it must be absent; its variable list is not inspected.
            if hasattr(parent, dataset_name[1:]):
                failed_attrs.append(dataset_name)
        elif hasattr(parent, dataset_name):
            dataset = getattr(parent, dataset_name)
            for attribute in attributes:
                if attribute.startswith("~"):
                    if hasattr(dataset, attribute[1:]):
                        failed_attrs.append((dataset_name, attribute))
                elif not hasattr(dataset, attribute):
                    failed_attrs.append((dataset_name, attribute))
        else:
            failed_attrs.append(dataset_name)
    return failed_attrs
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
def emcee_version():
    """Check emcee version.

    Returns
    -------
    int
        Major version number of the installed emcee.

    """
    import emcee

    # Parse the whole version string. Taking only its first character would
    # misreport any multi-digit major version (e.g. "10.0.0" -> 1).
    return Version(emcee.__version__).major
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
def needs_emcee3_func():
    """Return a ``pytest.mark.skipif`` marker that skips unless emcee >= 3 is installed."""
    emcee_too_old = emcee_version() < 3
    return pytest.mark.skipif(emcee_too_old, reason="emcee3 required")
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
def _emcee_lnprior(theta):
|
|
324
|
-
"""Proper function to allow pickling."""
|
|
325
|
-
mu, tau, eta = theta[0], theta[1], theta[2:]
|
|
326
|
-
# Half-cauchy prior, hwhm=25
|
|
327
|
-
if tau < 0:
|
|
328
|
-
return -np.inf
|
|
329
|
-
prior_tau = -np.log(tau**2 + 25**2)
|
|
330
|
-
prior_mu = -((mu / 10) ** 2) # normal prior, loc=0, scale=10
|
|
331
|
-
prior_eta = -np.sum(eta**2) # normal prior, loc=0, scale=1
|
|
332
|
-
return prior_mu + prior_tau + prior_eta
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
def _emcee_lnprob(theta, y, sigma):
    """Log-posterior of the non-centered eight schools model for emcee.

    Defined at module level rather than as a closure so that it can be pickled.

    Parameters
    ----------
    theta : array-like
        Parameter vector laid out as ``[mu, tau, eta_1, ..., eta_J]``.
    y, sigma : array-like
        Observations and their standard deviations, one entry per school.

    Returns
    -------
    tuple
        ``(log_posterior, (pointwise_log_likelihood, y_replicate))``. The second
        element is returned as emcee "blobs"; ``y_replicate`` draws from
        ``Normal(mu + tau * eta, sigma)`` using the global NumPy random state.
    """
    mu, tau, eta = theta[0], theta[1], theta[2:]
    prior = _emcee_lnprior(theta)
    school_means = mu + tau * eta
    # Full pointwise normal log-density log N(y | school_means, sigma), including
    # the 1/2 factor and the normalising terms. This matches ``normal_lpdf`` in the
    # Stan version of this model, so the blob is a genuine log-likelihood.
    like_vect = (
        -0.5 * ((school_means - y) / sigma) ** 2
        - np.log(sigma)
        - 0.5 * np.log(2 * np.pi)
    )
    like = np.sum(like_vect)
    return like + prior, (like_vect, np.random.normal(school_means, sigma))
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
def emcee_schools_model(data, draws, chains):
    """Schools model in emcee.

    Parameters
    ----------
    data : dict
        Must contain ``"y"``, ``"sigma"`` and ``"J"``.
    draws : int
        Number of steps run by every walker.
    chains : int
        Base chain count; ``10 * chains`` walkers are actually used.

    Returns
    -------
    emcee.EnsembleSampler
        The sampler after running. With emcee >= 3 the chain is stored through an
        HDF5 backend in ``saved_models/reader_testfile.h5`` next to this module.
    """
    import emcee

    nwalkers = 10 * chains  # emcee is sad with too few walkers
    y = data["y"]
    sigma = data["sigma"]
    J = data["J"]  # pylint: disable=invalid-name
    ndim = J + 2

    pos = np.random.normal(size=(nwalkers, ndim))
    # Column 1 is tau, whose prior is -inf for negative values: start it non-negative.
    pos[:, 1] = np.absolute(pos[:, 1])  # pylint: disable=unsupported-assignment-operation

    if emcee_version() < 3:
        sampler = emcee.EnsembleSampler(nwalkers, ndim, _emcee_lnprob, args=(y, sigma))
        sampler.run_mcmc(pos, draws)
    else:
        here = os.path.dirname(os.path.abspath(__file__))
        data_directory = os.path.join(here, "saved_models")
        # Make sure the directory exists before the HDF5 backend opens a file in it.
        os.makedirs(data_directory, exist_ok=True)
        filepath = os.path.join(data_directory, "reader_testfile.h5")
        backend = emcee.backends.HDFBackend(filepath)  # pylint: disable=no-member
        backend.reset(nwalkers, ndim)
        # pylint: disable=unexpected-keyword-arg
        sampler = emcee.EnsembleSampler(
            nwalkers, ndim, _emcee_lnprob, args=(y, sigma), backend=backend
        )
        # pylint: enable=unexpected-keyword-arg
        sampler.run_mcmc(pos, draws, store=True)
    return sampler
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
# pylint:disable=no-member,no-value-for-parameter,invalid-name
def _pyro_noncentered_model(J, sigma, y=None):
    """Non-centered eight schools model written for Pyro.

    ``y`` is passed as the observed value of the ``obs`` sample site.
    """
    import pyro
    from pyro import distributions as pdist

    mu = pyro.sample("mu", pdist.Normal(0, 5))
    tau = pyro.sample("tau", pdist.HalfCauchy(5))
    with pyro.plate("J", J):
        eta = pyro.sample("eta", pdist.Normal(0, 1))
        likelihood = pdist.Normal(mu + tau * eta, sigma)
        return pyro.sample("obs", likelihood, obs=y)
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
def pyro_noncentered_schools(data, draws, chains):
    """Run NUTS on the non-centered eight schools model with Pyro.

    ``draws`` is used both as the number of warmup steps and of retained samples.
    The sampler and the kernel's potential function are cleared before returning
    so the resulting ``MCMC`` object can be pickled.
    """
    import torch
    from pyro.infer import MCMC, NUTS

    def as_float_tensor(array):
        return torch.from_numpy(array).float()

    kernel = NUTS(_pyro_noncentered_model, jit_compile=True, ignore_jit_warnings=True)
    mcmc = MCMC(kernel, num_samples=draws, warmup_steps=draws, num_chains=chains)
    mcmc.run(data["J"], as_float_tensor(data["sigma"]), as_float_tensor(data["y"]))

    # Clear these references so the fitted object can be pickled.
    mcmc.sampler = None
    mcmc.kernel.potential_fn = None
    return mcmc
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
# pylint:disable=no-member,no-value-for-parameter,invalid-name
def _numpyro_noncentered_model(J, sigma, y=None):
    """Non-centered eight schools model written for NumPyro.

    ``y`` is passed as the observed value of the ``obs`` sample site.
    """
    import numpyro
    from numpyro import distributions as ndist

    mu = numpyro.sample("mu", ndist.Normal(0, 5))
    tau = numpyro.sample("tau", ndist.HalfCauchy(5))
    with numpyro.plate("J", J):
        eta = numpyro.sample("eta", ndist.Normal(0, 1))
        likelihood = ndist.Normal(mu + tau * eta, sigma)
        return numpyro.sample("obs", likelihood, obs=y)
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
def numpyro_schools_model(data, draws, chains):
    """Non-centered eight schools implementation in NumPyro.

    Runs NUTS on ``_numpyro_noncentered_model`` with ``draws`` warmup and ``draws``
    retained samples per chain. Chains run sequentially, and the ``num_steps`` and
    ``energy`` extra fields are recorded. ``data`` is unpacked as keyword arguments
    of the model, so it must provide ``J``, ``sigma`` and ``y``.

    Returns
    -------
    numpyro.infer.MCMC
        The fitted sampler, with its cached and compiled functions cleared so it
        can be pickled.
    """
    from jax.random import PRNGKey
    from numpyro.infer import MCMC, NUTS

    mcmc = MCMC(
        NUTS(_numpyro_noncentered_model),
        num_warmup=draws,
        num_samples=draws,
        num_chains=chains,
        chain_method="sequential",
    )
    mcmc.run(PRNGKey(0), extra_fields=("num_steps", "energy"), **data)

    # This block lets the posterior be pickled
    mcmc.sampler._sample_fn = None  # pylint: disable=protected-access
    mcmc.sampler._init_fn = None  # pylint: disable=protected-access
    mcmc.sampler._postprocess_fn = None  # pylint: disable=protected-access
    mcmc.sampler._potential_fn = None  # pylint: disable=protected-access
    mcmc.sampler._potential_fn_gen = None  # pylint: disable=protected-access
    mcmc._cache = {}  # pylint: disable=protected-access
    return mcmc
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
# Stan program for the non-centered eight schools model (text kept as in the original).
_PYSTAN_SCHOOLS_CODE = """
data {
int<lower=0> J;
array[J] real y;
array[J] real<lower=0> sigma;
}

parameters {
real mu;
real<lower=0> tau;
array[J] real eta;
}

transformed parameters {
array[J] real theta;
for (j in 1:J)
theta[j] = mu + tau * eta[j];
}

model {
mu ~ normal(0, 5);
tau ~ cauchy(0, 5);
eta ~ normal(0, 1);
y ~ normal(theta, sigma);
}

generated quantities {
array[J] real log_lik;
array[J] real y_hat;
for (j in 1:J) {
log_lik[j] = normal_lpdf(y[j] | theta[j], sigma[j]);
y_hat[j] = normal_rng(theta[j], sigma[j]);
}
}
"""


def pystan_noncentered_schools(data, draws, chains):
    """Build and sample the non-centered eight schools model with PyStan.

    The PyStan 2 API (``pystan``) is used when ``pystan_version()`` reports 2 and
    the PyStan 3 API (``stan``) otherwise. The warmup length is fixed at 500
    iterations in both cases.

    Returns
    -------
    tuple
        ``(stan_model, fit)``.
    """
    schools_code = _PYSTAN_SCHOOLS_CODE
    if pystan_version() != 2:
        import stan  # pylint: disable=import-error

        stan_model = stan.build(schools_code, data=data)
        fit = stan_model.sample(
            num_chains=chains, num_samples=draws, num_warmup=500, save_warmup=True
        )
        return stan_model, fit

    import pystan  # pylint: disable=import-error

    stan_model = pystan.StanModel(model_code=schools_code)
    fit = stan_model.sampling(
        data=data,
        iter=draws + 500,
        warmup=500,
        chains=chains,
        check_hmc_diagnostics=False,
        control={"adapt_engaged": False},
    )
    return stan_model, fit
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
def bm_schools_model(data, draws, chains):
    """Sample the non-centered eight schools model with Bean Machine.

    Both the prior and the posterior are sampled with ``GlobalNoUTurnSampler``.
    Each run uses 500 adaptive samples, ``draws`` retained samples and ``chains``
    chains, and queries ``mu``, ``tau`` and ``eta``.

    Returns
    -------
    tuple
        ``(model, prior, posterior)``.
    """
    import beanmachine.ppl as bm  # pylint: disable=import-error
    import torch
    import torch.distributions as dist

    class EightSchools:
        @bm.random_variable
        def mu(self):
            return dist.Normal(0, 5)

        @bm.random_variable
        def tau(self):
            return dist.HalfCauchy(5)

        @bm.random_variable
        def eta(self):
            return dist.Normal(0, 1).expand((data["J"],))

        @bm.functional
        def theta(self):
            return self.mu() + self.tau() * self.eta()

        @bm.random_variable
        def obs(self):
            return dist.Normal(self.theta(), torch.from_numpy(data["sigma"]).float())

    model = EightSchools()

    def run_nuts(observations):
        return bm.GlobalNoUTurnSampler().infer(
            queries=[model.mu(), model.tau(), model.eta()],
            observations=observations,
            num_samples=draws,
            num_adaptive_samples=500,
            num_chains=chains,
        )

    prior = run_nuts({})
    posterior = run_nuts({model.obs(): torch.from_numpy(data["y"]).float()})
    return model, prior, posterior
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
def library_handle(library):
    """Import the module named ``library`` and return it.

    ``"pystan"`` is special-cased: the ``pystan`` module is tried first and the
    ``stan`` module is returned if ``pystan`` cannot be imported.
    """
    if library != "pystan":
        return importlib.import_module(library)
    try:
        return importlib.import_module("pystan")
    except ImportError:
        return importlib.import_module("stan")
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
def load_cached_models(eight_schools_data, draws, chains, libs=None):
    """Load the eight schools test models, generating and caching them on first use.

    Supported libraries are pystan, emcee, pyro and numpyro (beanmachine is
    currently disabled). Each non-PyStan-3 model is pickled with
    ``cloudpickle`` into ``saved_models/`` next to this file. The cache file
    name encodes the Python version, library name and version, platform,
    draws and chains. The model is always returned as loaded back from that
    cache file.

    Parameters
    ----------
    eight_schools_data : dict
        Data passed to each model-building function.
    draws, chains : int
        Sampling settings passed to each model-building function; also part
        of the cache key.
    libs : str or list of str, optional
        Restrict loading to these library names. ``None`` loads every
        supported library.

    Returns
    -------
    dict
        Maps the imported module's ``__name__`` to the fitted model object.

    Raises
    ------
    AttributeError
        If generating or pickling a model raises ``AttributeError``.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    supported = (
        ("pystan", pystan_noncentered_schools),
        ("emcee", emcee_schools_model),
        ("pyro", pyro_noncentered_schools),
        ("numpyro", numpyro_schools_model),
        # ("beanmachine", bm_schools_model),  # ignore beanmachine until it supports torch>=2
    )
    data_directory = os.path.join(here, "saved_models")
    models = {}

    if isinstance(libs, str):
        libs = [libs]

    for library_name, func in supported:
        if libs is not None and library_name not in libs:
            continue
        library = library_handle(library_name)
        if library.__name__ == "stan":
            # PyStan3 does not support pickling
            # httpstan caches models automatically
            _log.info("Generating and loading stan model")
            models["pystan"] = func(eight_schools_data, draws, chains)
            continue

        py_version = sys.version_info
        fname = "{0.major}.{0.minor}_{1.__name__}_{1.__version__}_{2}_{3}_{4}.pkl.gzip".format(
            py_version, library, sys.platform, draws, chains
        )

        path = os.path.join(data_directory, fname)
        if not os.path.exists(path):
            os.makedirs(data_directory, exist_ok=True)
            # Write to a private temp file first. If generation or pickling
            # fails, no truncated cache is left at `path` that later runs
            # would try (and fail) to unpickle.
            tmp_path = f"{path}.{os.getpid()}.tmp"
            try:
                _log.info("Generating and caching %s", fname)
                fitted = func(eight_schools_data, draws, chains)
                with gzip.open(tmp_path, "wb") as buff:
                    cloudpickle.dump(fitted, buff)
                # atomic rename: readers see either no file or a complete one
                os.replace(tmp_path, path)
            except AttributeError as err:
                raise AttributeError(f"Failed caching {library_name}") from err
            finally:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)

        with gzip.open(path, "rb") as buff:
            _log.info("Loading %s from cache", fname)
            models[library.__name__] = cloudpickle.load(buff)

    return models
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
def pystan_version():
    """Check PyStan version.

    The ``pystan`` module is tried first; if it cannot be imported, the
    ``stan`` module (PyStan 3) is tried.

    Returns
    -------
    int or None
        Major version number, or ``None`` if neither module can be imported.

    """
    try:
        import pystan  # pylint: disable=import-error

        version_string = pystan.__version__
    except ImportError:
        try:
            import stan  # pylint: disable=import-error

            version_string = stan.__version__
        except ImportError:
            return None
    # Take everything before the first dot. Indexing the first character
    # would truncate multi-digit majors (e.g. "10.0" -> 1).
    return int(version_string.split(".", 1)[0])
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
def test_precompile_models(eight_schools_params, draws, chains):
    """Warm the model cache by loading (or generating) every supported model.

    The loaded models are discarded; the point is the side effect of
    ``load_cached_models`` writing any missing cache files.
    """
    load_cached_models(eight_schools_params, draws=draws, chains=chains)
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
def importorskip(
    modname: str, minversion: Optional[str] = None, reason: Optional[str] = None
) -> Any:
    """Import and return the requested module ``modname``.

    Adapted from ``pytest.importorskip``.

    * If the ``ARVIZ_REQUIRE_ALL_DEPS`` environment variable is absent, this
      simply delegates to ``pytest.importorskip``.
    * If it is present (its value is not inspected), a missing module is not
      turned into a skip: the ``ImportError`` from the import propagates and
      the test fails.

    Parameters
    ----------
    modname : str
        Name of the module to import.
    minversion : str, optional
        Minimum acceptable ``__version__``. If the module is older, or has no
        ``__version__``, the test is skipped (in both modes).
    reason : str, optional
        Message shown when the module cannot be imported. It is only forwarded
        to ``pytest.importorskip`` and is ignored when ``ARVIZ_REQUIRE_ALL_DEPS``
        is set.

    Returns
    -------
    module
        The imported module. Assign it to its canonical name.

    Example::

        docutils = importorskip("docutils")
    """
    # Without ARVIZ_REQUIRE_ALL_DEPS, tests requiring a missing dependency are
    # skipped; with it, missing optional dependencies make tests fail.
    if "ARVIZ_REQUIRE_ALL_DEPS" not in os.environ:
        return pytest.importorskip(modname=modname, minversion=minversion, reason=reason)

    # surface a malformed module name as a SyntaxError before importing
    compile(modname, "", "eval")

    with warnings.catch_warnings():
        # Ignore ImportWarnings raised when a directory with the same name
        # but without an __init__.py exists on the path.
        warnings.simplefilter("ignore")
        __import__(modname)
    module = sys.modules[modname]

    if minversion is not None:
        found = getattr(module, "__version__", None)
        if found is None or Version(found) < Version(minversion):
            raise Skipped(
                f"module {modname!r} has __version__ {found!r}, "
                f"required is: {minversion!r}",
                allow_module_level=True,
            )
    return module
|