arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/data/io_pystan.py
DELETED
|
@@ -1,1095 +0,0 @@
|
|
|
1
|
-
# pylint: disable=too-many-instance-attributes,too-many-lines
|
|
2
|
-
"""PyStan-specific conversion code."""
|
|
3
|
-
import re
|
|
4
|
-
from collections import OrderedDict
|
|
5
|
-
from copy import deepcopy
|
|
6
|
-
from math import ceil
|
|
7
|
-
|
|
8
|
-
import numpy as np
|
|
9
|
-
import xarray as xr
|
|
10
|
-
|
|
11
|
-
from .. import _log
|
|
12
|
-
from ..rcparams import rcParams
|
|
13
|
-
from .base import dict_to_dataset, generate_dims_coords, infer_stan_dtypes, make_attrs, requires
|
|
14
|
-
from .inference_data import InferenceData
|
|
15
|
-
|
|
16
|
-
try:
|
|
17
|
-
import ujson as json
|
|
18
|
-
except ImportError:
|
|
19
|
-
# Can't find ujson using json
|
|
20
|
-
# mypy struggles with conditional imports expressed as catching ImportError:
|
|
21
|
-
# https://github.com/python/mypy/issues/1153
|
|
22
|
-
import json # type: ignore
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class PyStanConverter:
    """Encapsulate PyStan specific logic.

    Converts fits created with the ``pystan`` package (imported as ``pystan``)
    into the groups of an :class:`InferenceData` object. Fit objects are read
    through ``fit.sim["pars_oi"]`` (names of the saved parameters) and
    ``fit.data`` (data dictionary of the fit).

    Parameters
    ----------
    posterior, prior : fit object, optional
        Posterior and prior fits. Groups derived from a missing fit are skipped.
    posterior_predictive, predictions : str or sequence of str, optional
        Names of variables of ``posterior`` stored in the ``posterior_predictive``
        and ``predictions`` groups instead of in ``posterior``.
    prior_predictive : str or sequence of str, optional
        Names of variables of ``prior`` stored in ``prior_predictive`` instead of
        in ``prior``.
    observed_data, constant_data, predictions_constant_data : str or sequence of str, optional
        Keys of ``posterior.data`` stored in the corresponding groups.
    log_likelihood : bool, str, sequence of str or dict, optional
        Log-likelihood variables of ``posterior``. A dict maps observed variable
        names to log-likelihood variable names. ``True`` selects ``"log_lik"``
        when the posterior fit saved it; ``False`` (or ``True`` without a saved
        ``log_lik``) disables the group. Defaults to
        ``rcParams["data.log_likelihood"]``.
    coords, dims : mapping, optional
        Coordinates and per-variable dimension names forwarded to the dataset
        builders.
    save_warmup : bool, optional
        Whether warmup draws are extracted. Defaults to
        ``rcParams["data.save_warmup"]``.
    dtypes : optional
        Forwarded to ``get_draws``.
    """

    def __init__(
        self,
        *,
        posterior=None,
        posterior_predictive=None,
        predictions=None,
        prior=None,
        prior_predictive=None,
        observed_data=None,
        constant_data=None,
        predictions_constant_data=None,
        log_likelihood=None,
        coords=None,
        dims=None,
        save_warmup=None,
        dtypes=None,
    ):
        self.posterior = posterior
        self.posterior_predictive = posterior_predictive
        self.predictions = predictions
        self.prior = prior
        self.prior_predictive = prior_predictive
        self.observed_data = observed_data
        self.constant_data = constant_data
        self.predictions_constant_data = predictions_constant_data
        self.log_likelihood = (
            rcParams["data.log_likelihood"] if log_likelihood is None else log_likelihood
        )
        self.coords = coords
        self.dims = dims
        self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
        self.dtypes = dtypes

        if (
            self.log_likelihood is True
            and self.posterior is not None
            and "log_lik" in self.posterior.sim["pars_oi"]
        ):
            # ``True`` means: use the conventional ``log_lik`` variable if it was saved.
            self.log_likelihood = ["log_lik"]
        elif isinstance(self.log_likelihood, bool):
            # ``False``, or ``True`` without a saved ``log_lik``: no log_likelihood group.
            self.log_likelihood = None

        # Imported here rather than at module level, so importing this module
        # does not require pystan to be installed.
        import pystan  # pylint: disable=import-error

        self.pystan = pystan

    @staticmethod
    def _as_name_list(names):
        """Return ``None``, a single name, or an iterable of names as a new list."""
        if names is None:
            return []
        if isinstance(names, str):
            return [names]
        return list(names)

    @requires("posterior")
    def posterior_to_xarray(self):
        """Extract posterior samples from fit.

        Variables named in ``posterior_predictive``, ``predictions`` and
        ``log_likelihood`` are excluded because they go to their own groups;
        ``lp__`` is excluded because it is stored in ``sample_stats``.

        Returns
        -------
        tuple
            ``(draws, warmup_draws)`` datasets built with ``dict_to_dataset``.
        """
        posterior = self.posterior
        log_likelihood = self.log_likelihood
        if isinstance(log_likelihood, dict):
            # dict maps observed variable name -> log-likelihood variable name in the fit
            log_likelihood = list(log_likelihood.values())

        # Normalise each option to a list first: concatenating a list with a
        # user-supplied tuple would otherwise raise ``TypeError``.
        ignore = (
            self._as_name_list(self.posterior_predictive)
            + self._as_name_list(self.predictions)
            + self._as_name_list(log_likelihood)
            + ["lp__"]
        )

        data, data_warmup = get_draws(
            posterior, ignore=ignore, warmup=self.save_warmup, dtypes=self.dtypes
        )
        attrs = get_attrs(posterior)
        return (
            dict_to_dataset(
                data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
            ),
            dict_to_dataset(
                data_warmup, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
            ),
        )

    @requires("posterior")
    def sample_stats_to_xarray(self):
        """Extract sample_stats from posterior.

        ``lp__`` is read with ``get_draws`` and stored under the name ``lp``.

        Returns
        -------
        tuple
            ``(stats, warmup_stats)`` datasets built with ``dict_to_dataset``.
        """
        posterior = self.posterior

        data, data_warmup = get_sample_stats(posterior, warmup=self.save_warmup)

        # lp__
        stat_lp, stat_lp_warmup = get_draws(
            posterior, variables="lp__", warmup=self.save_warmup, dtypes=self.dtypes
        )
        data["lp"] = stat_lp["lp__"]
        if stat_lp_warmup:
            data_warmup["lp"] = stat_lp_warmup["lp__"]

        attrs = get_attrs(posterior)
        return (
            dict_to_dataset(
                data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
            ),
            dict_to_dataset(
                data_warmup, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
            ),
        )

    @requires("posterior")
    @requires("log_likelihood")
    def log_likelihood_to_xarray(self):
        """Store log_likelihood data in log_likelihood group.

        ``self.log_likelihood`` may be a single name, a list/tuple of names, or a
        dict mapping observed variable names to log-likelihood variable names.
        Names absent from the extracted draws are skipped.

        Returns
        -------
        tuple
            ``(draws, warmup_draws)`` datasets built with ``dict_to_dataset``.
        """
        fit = self.posterior

        # log_likelihood values
        log_likelihood = self.log_likelihood
        if isinstance(log_likelihood, str):
            log_likelihood = [log_likelihood]
        if isinstance(log_likelihood, (list, tuple)):
            # identity mapping: the group variable keeps the fit variable's name
            log_likelihood = {name: name for name in log_likelihood}
        log_likelihood_draws, log_likelihood_draws_warmup = get_draws(
            fit,
            variables=list(log_likelihood.values()),
            warmup=self.save_warmup,
            dtypes=self.dtypes,
        )
        data = {
            obs_var_name: log_likelihood_draws[log_like_name]
            for obs_var_name, log_like_name in log_likelihood.items()
            if log_like_name in log_likelihood_draws
        }

        data_warmup = {
            obs_var_name: log_likelihood_draws_warmup[log_like_name]
            for obs_var_name, log_like_name in log_likelihood.items()
            if log_like_name in log_likelihood_draws_warmup
        }

        return (
            dict_to_dataset(
                data, library=self.pystan, coords=self.coords, dims=self.dims, skip_event_dims=True
            ),
            dict_to_dataset(
                data_warmup,
                library=self.pystan,
                coords=self.coords,
                dims=self.dims,
                skip_event_dims=True,
            ),
        )

    @requires("posterior")
    @requires("posterior_predictive")
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray.

        Returns
        -------
        tuple
            ``(draws, warmup_draws)`` datasets of the ``posterior_predictive``
            variables of the posterior fit.
        """
        posterior = self.posterior
        posterior_predictive = self.posterior_predictive
        data, data_warmup = get_draws(
            posterior, variables=posterior_predictive, warmup=self.save_warmup, dtypes=self.dtypes
        )
        return (
            dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims),
            dict_to_dataset(data_warmup, library=self.pystan, coords=self.coords, dims=self.dims),
        )

    @requires("posterior")
    @requires("predictions")
    def predictions_to_xarray(self):
        """Convert predictions samples to xarray.

        Returns
        -------
        tuple
            ``(draws, warmup_draws)`` datasets of the ``predictions`` variables
            of the posterior fit.
        """
        posterior = self.posterior
        predictions = self.predictions
        data, data_warmup = get_draws(
            posterior, variables=predictions, warmup=self.save_warmup, dtypes=self.dtypes
        )
        return (
            dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims),
            dict_to_dataset(data_warmup, library=self.pystan, coords=self.coords, dims=self.dims),
        )

    @requires("prior")
    def prior_to_xarray(self):
        """Convert prior samples to xarray.

        Variables named in ``prior_predictive`` and ``lp__`` are excluded.
        Warmup draws are never extracted for the prior.

        Returns
        -------
        Dataset built with ``dict_to_dataset``.
        """
        prior = self.prior
        # filter prior_predictive variables (they get their own group) and lp__;
        # normalise to a list so a tuple does not break the concatenation
        ignore = self._as_name_list(self.prior_predictive) + ["lp__"]

        data, _ = get_draws(prior, ignore=ignore, warmup=False, dtypes=self.dtypes)
        attrs = get_attrs(prior)
        return dict_to_dataset(
            data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
        )

    @requires("prior")
    def sample_stats_prior_to_xarray(self):
        """Extract sample_stats_prior from prior.

        ``lp__`` is stored under the name ``lp``; warmup is not extracted.

        Returns
        -------
        Dataset built with ``dict_to_dataset``.
        """
        prior = self.prior
        data, _ = get_sample_stats(prior, warmup=False)

        # lp__
        stat_lp, _ = get_draws(prior, variables="lp__", warmup=False, dtypes=self.dtypes)
        data["lp"] = stat_lp["lp__"]

        attrs = get_attrs(prior)
        return dict_to_dataset(
            data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
        )

    @requires("prior")
    @requires("prior_predictive")
    def prior_predictive_to_xarray(self):
        """Convert prior_predictive samples to xarray (warmup is not extracted)."""
        prior = self.prior
        prior_predictive = self.prior_predictive
        data, _ = get_draws(prior, variables=prior_predictive, warmup=False, dtypes=self.dtypes)
        return dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims)

    @requires("posterior")
    @requires(["observed_data", "constant_data", "predictions_constant_data"])
    def data_to_xarray(self):
        """Convert observed, constant data and predictions constant data to xarray.

        Values are read from ``self.posterior.data``; each value is promoted to
        at least 1-D.

        Returns
        -------
        dict
            Group name -> ``xr.Dataset`` for every group whose names are not ``None``.
        """
        posterior = self.posterior
        dims = {} if self.dims is None else self.dims
        obs_const_dict = {}
        for group_name in ("observed_data", "constant_data", "predictions_constant_data"):
            names = getattr(self, group_name)
            if names is None:
                continue
            names = [names] if isinstance(names, str) else names
            data = OrderedDict()
            for key in names:
                vals = np.atleast_1d(posterior.data[key])
                val_dims = dims.get(key)
                val_dims, coords = generate_dims_coords(
                    vals.shape, key, dims=val_dims, coords=self.coords
                )
                data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
            obs_const_dict[group_name] = xr.Dataset(
                data_vars=data, attrs=make_attrs(library=self.pystan)
            )
        return obs_const_dict

    def to_inference_data(self):
        """Convert all available data to an InferenceData object.

        Note that if groups can not be created (i.e., there is no `fit`, so
        the `posterior` and `sample_stats` can not be extracted), then the InferenceData
        will not have those groups.
        """
        data_dict = self.data_to_xarray()
        return InferenceData(
            save_warmup=self.save_warmup,
            **{
                "posterior": self.posterior_to_xarray(),
                "sample_stats": self.sample_stats_to_xarray(),
                "log_likelihood": self.log_likelihood_to_xarray(),
                "posterior_predictive": self.posterior_predictive_to_xarray(),
                "predictions": self.predictions_to_xarray(),
                "prior": self.prior_to_xarray(),
                "sample_stats_prior": self.sample_stats_prior_to_xarray(),
                "prior_predictive": self.prior_predictive_to_xarray(),
                # data_to_xarray returns None when none of its groups were requested
                **({} if data_dict is None else data_dict),
            },
        )
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
class PyStan3Converter:
|
|
301
|
-
"""Encapsulate PyStan3 specific logic."""
|
|
302
|
-
|
|
303
|
-
# pylint: disable=too-many-instance-attributes
def __init__(
    self,
    *,
    posterior=None,
    posterior_model=None,
    posterior_predictive=None,
    predictions=None,
    prior=None,
    prior_model=None,
    prior_predictive=None,
    observed_data=None,
    constant_data=None,
    predictions_constant_data=None,
    log_likelihood=None,
    coords=None,
    dims=None,
    save_warmup=None,
    dtypes=None,
):
    """Store the PyStan 3 fits, their models and the conversion options.

    ``log_likelihood`` and ``save_warmup`` fall back to
    ``rcParams["data.log_likelihood"]`` / ``rcParams["data.save_warmup"]``
    when ``None``. A ``True`` log_likelihood becomes ``["log_lik"]`` if the
    posterior fit lists ``log_lik`` in ``param_names``; any remaining bool
    becomes ``None`` (no log_likelihood group).
    """
    # Fits and the models they were sampled from.
    self.posterior = posterior
    self.posterior_model = posterior_model
    self.prior = prior
    self.prior_model = prior_model

    # Variable / data names routed to dedicated groups.
    self.posterior_predictive = posterior_predictive
    self.predictions = predictions
    self.prior_predictive = prior_predictive
    self.observed_data = observed_data
    self.constant_data = constant_data
    self.predictions_constant_data = predictions_constant_data

    # Options with rcParams fallbacks.
    if log_likelihood is None:
        log_likelihood = rcParams["data.log_likelihood"]
    if save_warmup is None:
        save_warmup = rcParams["data.save_warmup"]

    # Labelling options.
    self.coords = coords
    self.dims = dims
    self.save_warmup = save_warmup
    self.dtypes = dtypes

    has_default_log_lik = (
        log_likelihood is True
        and posterior is not None
        and "log_lik" in posterior.param_names
    )
    if has_default_log_lik:
        log_likelihood = ["log_lik"]
    elif isinstance(log_likelihood, bool):
        log_likelihood = None
    self.log_likelihood = log_likelihood

    # Imported here so the module itself can be imported without pystan 3.
    import stan  # pylint: disable=import-error

    self.stan = stan
|
|
353
|
-
|
|
354
|
-
@requires("posterior")
def posterior_to_xarray(self):
    """Extract posterior samples from fit.

    Variables named in ``posterior_predictive``, ``predictions`` and
    ``log_likelihood`` are excluded because they are stored in their own
    groups. Each option may be ``None``, a single name, or a sequence of names
    (lists and tuples are both accepted); a ``log_likelihood`` dict
    contributes its values.

    Returns
    -------
    tuple
        ``(draws, warmup_draws)`` datasets built with ``dict_to_dataset``.
    """
    posterior = self.posterior
    posterior_model = self.posterior_model

    log_likelihood = self.log_likelihood
    if isinstance(log_likelihood, dict):
        # dict maps observed variable name -> log-likelihood variable name in the fit
        log_likelihood = list(log_likelihood.values())

    # Normalise every option into one list. The previous ``list + list`` form
    # raised ``TypeError`` when any option was given as a tuple.
    ignore = []
    for names in (self.posterior_predictive, self.predictions, log_likelihood):
        if names is None:
            continue
        if isinstance(names, str):
            ignore.append(names)
        else:
            ignore.extend(names)

    data, data_warmup = get_draws_stan3(
        posterior,
        model=posterior_model,
        ignore=ignore,
        warmup=self.save_warmup,
        dtypes=self.dtypes,
    )
    attrs = get_attrs_stan3(posterior, model=posterior_model)
    return (
        dict_to_dataset(
            data, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
        ),
        dict_to_dataset(
            data_warmup, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
        ),
    )
|
|
396
|
-
|
|
397
|
-
@requires("posterior")
def sample_stats_to_xarray(self):
    """Extract sample_stats from posterior.

    All sampler statistics except ``lp__`` are fetched first. ``lp__`` is then
    fetched on its own, and the entry the helper returns under ``lp`` is added
    to both the draws and, when present, the warmup statistics.

    Returns
    -------
    tuple
        ``(stats, warmup_stats)`` datasets built with ``dict_to_dataset``.
    """
    fit = self.posterior
    model = self.posterior_model
    stats, stats_warmup = get_sample_stats_stan3(
        fit, ignore="lp__", warmup=self.save_warmup, dtypes=self.dtypes
    )
    lp_stats, lp_stats_warmup = get_sample_stats_stan3(
        fit, variables="lp__", warmup=self.save_warmup
    )
    stats["lp"] = lp_stats["lp"]
    # Warmup lp is only merged when warmup statistics were actually returned.
    if lp_stats_warmup:
        stats_warmup["lp"] = lp_stats_warmup["lp"]

    attrs = get_attrs_stan3(fit, model=model)
    return tuple(
        dict_to_dataset(group, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims)
        for group in (stats, stats_warmup)
    )
|
|
421
|
-
|
|
422
|
-
@requires("posterior")
@requires("log_likelihood")
def log_likelihood_to_xarray(self):
    """Store log_likelihood data in log_likelihood group.

    ``self.log_likelihood`` may be a single name, a list/tuple of names, or a
    dict mapping observed variable names to log-likelihood variable names in
    the posterior fit. Names absent from the extracted draws are skipped.

    Returns
    -------
    tuple
        ``(draws, warmup_draws)`` datasets built with ``dict_to_dataset``.
    """
    fit = self.posterior

    log_likelihood = self.log_likelihood
    model = self.posterior_model
    if isinstance(log_likelihood, str):
        log_likelihood = [log_likelihood]
    if isinstance(log_likelihood, (list, tuple)):
        # identity mapping: the group variable keeps the fit variable's name
        log_likelihood = {name: name for name in log_likelihood}
    log_likelihood_draws, log_likelihood_draws_warmup = get_draws_stan3(
        fit,
        model=model,
        variables=list(log_likelihood.values()),
        warmup=self.save_warmup,
        dtypes=self.dtypes,
    )
    data = {
        obs_var_name: log_likelihood_draws[log_like_name]
        for obs_var_name, log_like_name in log_likelihood.items()
        if log_like_name in log_likelihood_draws
    }
    data_warmup = {
        obs_var_name: log_likelihood_draws_warmup[log_like_name]
        for obs_var_name, log_like_name in log_likelihood.items()
        if log_like_name in log_likelihood_draws_warmup
    }

    # ``skip_event_dims=True`` matches PyStanConverter.log_likelihood_to_xarray,
    # which passes the same flag for the same group. The apparent intent is that
    # ``dims`` entries keyed by observed variable name still apply when the
    # log-likelihood array lacks some of that variable's dimensions.
    # NOTE(review): confirm against dict_to_dataset.
    return (
        dict_to_dataset(
            data, library=self.stan, coords=self.coords, dims=self.dims, skip_event_dims=True
        ),
        dict_to_dataset(
            data_warmup,
            library=self.stan,
            coords=self.coords,
            dims=self.dims,
            skip_event_dims=True,
        ),
    )
|
|
456
|
-
|
|
457
|
-
@requires("posterior")
@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
    """Convert posterior_predictive samples to xarray.

    Returns
    -------
    tuple
        ``(draws, warmup_draws)`` datasets holding the variables named in
        ``self.posterior_predictive``, read from the posterior fit.
    """
    draws, warmup_draws = get_draws_stan3(
        self.posterior,
        model=self.posterior_model,
        variables=self.posterior_predictive,
        warmup=self.save_warmup,
        dtypes=self.dtypes,
    )
    return tuple(
        dict_to_dataset(group, library=self.stan, coords=self.coords, dims=self.dims)
        for group in (draws, warmup_draws)
    )
|
|
475
|
-
|
|
476
|
-
@requires("posterior")
@requires("predictions")
def predictions_to_xarray(self):
    """Convert predictions samples to xarray.

    Returns
    -------
    tuple
        ``(draws, warmup_draws)`` datasets holding the variables named in
        ``self.predictions``, read from the posterior fit.
    """
    draws, warmup_draws = get_draws_stan3(
        self.posterior,
        model=self.posterior_model,
        variables=self.predictions,
        warmup=self.save_warmup,
        dtypes=self.dtypes,
    )
    return tuple(
        dict_to_dataset(group, library=self.stan, coords=self.coords, dims=self.dims)
        for group in (draws, warmup_draws)
    )
|
|
494
|
-
|
|
495
|
-
@requires("prior")
def prior_to_xarray(self):
    """Convert prior samples to xarray.

    Variables named in ``self.prior_predictive`` (``None``, one name, or a
    collection of names) are excluded, since they belong to the
    ``prior_predictive`` group.

    Returns
    -------
    tuple
        ``(draws, warmup_draws)`` datasets built with ``dict_to_dataset``.
    """
    prior_fit = self.prior
    model = self.prior_model

    # Exclude the prior_predictive variables from the prior group.
    excluded = self.prior_predictive
    if isinstance(excluded, str):
        excluded = [excluded]
    elif excluded is None:
        excluded = []

    draws, warmup_draws = get_draws_stan3(
        prior_fit, model=model, ignore=excluded, warmup=self.save_warmup, dtypes=self.dtypes
    )
    attrs = get_attrs_stan3(prior_fit, model=model)
    return tuple(
        dict_to_dataset(group, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims)
        for group in (draws, warmup_draws)
    )
|
|
521
|
-
|
|
522
|
-
@requires("prior")
def sample_stats_prior_to_xarray(self):
    """Extract sample_stats_prior from prior.

    Returns
    -------
    tuple
        ``(stats, warmup_stats)`` datasets built with ``dict_to_dataset``.
    """
    prior_fit = self.prior
    model = self.prior_model
    stats, stats_warmup = get_sample_stats_stan3(
        prior_fit, warmup=self.save_warmup, dtypes=self.dtypes
    )
    attrs = get_attrs_stan3(prior_fit, model=model)
    return tuple(
        dict_to_dataset(group, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims)
        for group in (stats, stats_warmup)
    )
|
|
539
|
-
|
|
540
|
-
@requires("prior")
@requires("prior_predictive")
def prior_predictive_to_xarray(self):
    """Convert prior_predictive samples to xarray.

    Returns
    -------
    tuple
        ``(draws, warmup_draws)`` datasets holding the variables named in
        ``self.prior_predictive``, read from the prior fit.
    """
    draws, warmup_draws = get_draws_stan3(
        self.prior,
        model=self.prior_model,
        variables=self.prior_predictive,
        warmup=self.save_warmup,
        dtypes=self.dtypes,
    )
    return tuple(
        dict_to_dataset(group, library=self.stan, coords=self.coords, dims=self.dims)
        for group in (draws, warmup_draws)
    )
|
|
558
|
-
|
|
559
|
-
@requires("posterior_model")
@requires(["observed_data", "constant_data"])
def observed_and_constant_data_to_xarray(self):
    """Convert observed data and constant data to xarray.

    Every requested name is looked up in ``self.posterior_model.data`` and
    promoted to at least 1-D.

    Returns
    -------
    dict
        Group name (``"observed_data"`` / ``"constant_data"``) -> ``xr.Dataset``,
        with an entry only for groups whose names are not ``None``.
    """
    model = self.posterior_model
    dims_map = self.dims if self.dims is not None else {}
    groups = {}
    for group_name in ("observed_data", "constant_data"):
        requested = getattr(self, group_name)
        if requested is None:
            continue
        if isinstance(requested, str):
            requested = [requested]
        variables = OrderedDict()
        for name in requested:
            values = np.atleast_1d(model.data[name])
            var_dims, var_coords = generate_dims_coords(
                values.shape, name, dims=dims_map.get(name), coords=self.coords
            )
            variables[name] = xr.DataArray(values, dims=var_dims, coords=var_coords)
        groups[group_name] = xr.Dataset(data_vars=variables, attrs=make_attrs(library=self.stan))
    return groups
|
|
583
|
-
|
|
584
|
-
@requires("posterior_model")
@requires("predictions_constant_data")
def predictions_constant_data_to_xarray(self):
    """Build the ``predictions_constant_data`` group.

    Values are read from ``self.posterior_model.data`` for every name listed in
    ``self.predictions_constant_data`` (a str or an iterable of str). Scalars
    are promoted to 1-D.

    Returns
    -------
    xarray.Dataset
    """
    posterior_model = self.posterior_model
    dims = self.dims if self.dims is not None else {}
    var_names = self.predictions_constant_data
    if isinstance(var_names, str):
        var_names = [var_names]
    data_vars = OrderedDict()
    for var_name in var_names:
        values = np.atleast_1d(posterior_model.data[var_name])
        var_dims, var_coords = generate_dims_coords(
            values.shape, var_name, dims=dims.get(var_name), coords=self.coords
        )
        data_vars[var_name] = xr.DataArray(values, dims=var_dims, coords=var_coords)
    return xr.Dataset(data_vars=data_vars, attrs=make_attrs(library=self.stan))
|
|
601
|
-
|
|
602
|
-
def to_inference_data(self):
    """Convert all available data to an InferenceData object.

    Every group converter is called. A converter whose inputs are unavailable
    presumably yields ``None`` (via ``requires``), and that group is then
    absent from the result (e.g. without a ``fit`` there is no ``posterior``
    or ``sample_stats``).
    """
    obs_const_dict = self.observed_and_constant_data_to_xarray()
    predictions_const_data = self.predictions_constant_data_to_xarray()
    groups = {
        "posterior": self.posterior_to_xarray(),
        "sample_stats": self.sample_stats_to_xarray(),
        "log_likelihood": self.log_likelihood_to_xarray(),
        "posterior_predictive": self.posterior_predictive_to_xarray(),
        "predictions": self.predictions_to_xarray(),
        "prior": self.prior_to_xarray(),
        "sample_stats_prior": self.sample_stats_prior_to_xarray(),
        "prior_predictive": self.prior_predictive_to_xarray(),
    }
    if obs_const_dict is not None:
        groups.update(obs_const_dict)
    if predictions_const_data is not None:
        groups["predictions_constant_data"] = predictions_const_data
    return InferenceData(save_warmup=self.save_warmup, **groups)
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
def get_draws(fit, variables=None, ignore=None, warmup=False, dtypes=None):
    """Extract draws from a PyStan 2 fit.

    Parameters
    ----------
    fit : StanFit4Model
        PyStan 2 fit; its ``sim`` dict supplies names, shapes and per-chain samples.
    variables : str or iterable of str, optional
        Variables to extract; defaults to ``fit.sim["pars_oi"]``. Requested
        variables whose declared shape has zero elements are dropped.
    ignore : iterable of str, optional
        Variables to skip.
    warmup : bool
        Also return warmup draws. Forced off when no chain saved warmup draws.
    dtypes : dict, optional
        Overrides for the dtypes inferred from the Stan code by ``infer_dtypes``.

    Returns
    -------
    tuple of OrderedDict
        ``(data, data_warmup)`` mapping each variable to an array of shape
        ``(chain, draw, *var_shape)``.

    Raises
    ------
    AttributeError
        If the fit is in ``test_grad`` mode or holds no samples.
    """
    if ignore is None:
        ignore = []
    if fit.mode == 1:
        raise AttributeError("Model in mode 'test_grad'. Sampling is not conducted.")
    if fit.mode == 2 or fit.sim.get("samples") is None:
        raise AttributeError("Fit doesn't contain samples.")

    dtypes = {**infer_dtypes(fit), **({} if dtypes is None else dtypes)}
    sim = fit.sim

    if variables is None:
        variables = sim["pars_oi"]
    elif isinstance(variables, str):
        variables = [variables]
    variables = list(variables)

    # Drop requested variables whose declared shape contains no elements.
    for par_name, par_dims in zip(sim["pars_oi"], sim["dims_oi"]):
        if par_name in variables and np.prod(par_dims) == 0:
            variables.remove(par_name)

    warmup_counts = sim["warmup2"]
    if max(warmup_counts) == 0:
        warmup = False
    draw_counts = [n_saved - n_warm for n_saved, n_warm in zip(sim["n_save"], warmup_counts)]
    n_chains = len(sim["samples"])

    # Flat names are 0-based up to Stan 2.17 and 1-based from 2.18 on; detect which
    # from the smallest index of the lowest-ranked non-empty parameter.
    shift = 1
    nonempty = [
        (par_dims, idx)
        for idx, par_dims in enumerate(sim["dims_oi"])
        if par_dims and np.prod(par_dims) != 0
    ]
    if nonempty:
        par_idx = min(nonempty)[1]
        offset = int(sum(map(np.prod, sim["dims_oi"][:par_idx])))
        n_flat = int(np.prod(sim["dims_oi"][par_idx]))
        flat_names = sim["fnames_oi"][offset : offset + n_flat]
        shift = len(flat_names)
        for flat_name in flat_names:
            _, index_str = flat_name.replace("]", "").split("[")
            shift = min(shift, min(int(part) for part in index_str.split(",")))
        # A minimum above 1 presumably means Stan stored only part of the
        # parameter (sparse structure); assume dims still describe the full shape.
        shift = int(min(shift, 1))

    # Map each parameter to its flat names and their zero-based array positions.
    var_keys = OrderedDict((par_name, []) for par_name in sim["pars_oi"])
    for flat_name in sim["fnames_oi"]:
        base_name, *tails = flat_name.split("[")
        loc = [Ellipsis]
        for tail in tails:
            loc = [int(part) - shift for part in tail[:-1].split(",")]
        var_keys[base_name].append((flat_name, loc))

    shapes = dict(zip(sim["pars_oi"], sim["dims_oi"]))
    variables = [name for name in variables if name not in ignore]

    data = OrderedDict()
    data_warmup = OrderedDict()
    for name in variables:
        if name in data:
            continue
        keys_locs = var_keys.get(name, [(name, [Ellipsis])])
        shape = shapes.get(name, [])
        dtype = dtypes.get(name)

        ary = np.empty([n_chains, max(draw_counts)] + shape, dtype=dtype, order="F")
        if warmup:
            ary_warmup = np.empty(
                [n_chains, max(warmup_counts)] + shape, dtype=dtype, order="F"
            )

        for chain, (holder, n_draw, n_warm) in enumerate(
            zip(sim["samples"], draw_counts, warmup_counts)
        ):
            for key, loc in keys_locs:
                target = (chain, slice(None), *loc)
                chain_values = holder.chains[key]
                # Post-warmup draws are the tail of each chain, warmup the head.
                ary[target] = chain_values[-n_draw:]
                if warmup:
                    ary_warmup[target] = chain_values[:n_warm]
        data[name] = ary
        if warmup:
            data_warmup[name] = ary_warmup
    return data, data_warmup
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
def get_sample_stats(fit, warmup=False, dtypes=None):
    """Extract sampler statistics from a PyStan 2 fit.

    Stat names lose their trailing ``__`` and some are renamed to ArviZ
    conventions (e.g. ``divergent`` -> ``diverging``).

    Parameters
    ----------
    fit : StanFit4Model
        PyStan 2 fit; ``fit.sim["samples"]`` holds one holder per chain.
    warmup : bool
        Also return warmup statistics. Forced off when no chain saved warmup.
    dtypes : dict, optional
        Per-stat dtype overrides, keyed by the raw Stan name (e.g. ``"divergent__"``).

    Returns
    -------
    tuple of OrderedDict
        ``(data, data_warmup)`` mapping stat names to ``(chain, draw)`` arrays.
    """
    dtypes = {
        "divergent__": bool,
        "n_leapfrog__": np.int64,
        "treedepth__": np.int64,
        **({} if dtypes is None else dtypes),
    }
    rename_dict = {
        "divergent": "diverging",
        "n_leapfrog": "n_steps",
        "treedepth": "tree_depth",
        "stepsize": "step_size",
        "accept_stat": "acceptance_rate",
    }

    warmup_counts = fit.sim["warmup2"]
    if max(warmup_counts) == 0:
        warmup = False
    draw_counts = [n_saved - n_warm for n_saved, n_warm in zip(fit.sim["n_save"], warmup_counts)]

    sampled = OrderedDict()
    sampled_warmup = OrderedDict()
    for chain, (holder, n_draw, n_warm) in enumerate(
        zip(fit.sim["samples"], draw_counts, warmup_counts)
    ):
        names = holder["sampler_param_names"]
        if chain == 0:
            # The first chain defines the set of stats collected.
            for key in names:
                sampled[key] = []
                if warmup:
                    sampled_warmup[key] = []
        for key, chain_values in zip(names, holder["sampler_params"]):
            sampled[key].append(chain_values[-n_draw:])
            if warmup:
                sampled_warmup[key].append(chain_values[:n_warm])

    def _to_arrays(collected):
        # Stack chains, apply dtypes and translate names.
        arrays = OrderedDict()
        for key, chunks in collected.items():
            stacked = np.stack(chunks, axis=0).astype(dtypes.get(key))
            base = re.sub("__$", "", key)
            arrays[rename_dict.get(base, base)] = stacked
        return arrays

    data = _to_arrays(sampled)
    data_warmup = _to_arrays(sampled_warmup) if warmup else OrderedDict()
    return data, data_warmup
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
def get_attrs(fit):
    """Collect sampler metadata from a PyStan 2 fit as InferenceData attributes.

    Parameters
    ----------
    fit : StanFit4Model
        PyStan 2 fit; ``fit.sim["samples"]`` holds one holder per chain.

    Returns
    -------
    dict
        ``args`` and ``inits`` as JSON strings (left out, with a logged warning,
        if they cannot be read); ``step_size`` (one float per chain, NaN when the
        adaptation info has no step size); ``metric`` (one of ``"diag_e"``,
        ``"dense_e"``, ``"unit_e"`` per chain); ``inv_metric`` (JSON string);
        ``adaptation_info`` and ``stan_code``.
    """
    attrs = {}

    try:
        attrs["args"] = [deepcopy(holder.args) for holder in fit.sim["samples"]]
    except Exception as exp:  # pylint: disable=broad-except
        _log.warning("Failed to fetch args from fit: %s", exp)
    if "args" in attrs:
        for arg in attrs["args"]:
            # ``init`` may be stored as bytes, which json cannot serialize. Use
            # ``get`` so args without an ``init`` entry do not abort conversion.
            if isinstance(arg.get("init"), bytes):
                arg["init"] = arg["init"].decode("utf-8")
        attrs["args"] = json.dumps(attrs["args"])
    try:
        attrs["inits"] = [holder.inits for holder in fit.sim["samples"]]
    except Exception as exp:  # pylint: disable=broad-except
        _log.warning("Failed to fetch `inits` from fit: %s", exp)
    else:
        attrs["inits"] = json.dumps(attrs["inits"])

    # Accept integers, decimals and scientific notation ("1", "0.8", "1.2e-05").
    # The earlier pattern left "." unescaped, required two digits and dropped
    # exponents, silently truncating tiny step sizes.
    step_size_pattern = re.compile(
        r"step\s*size\s*=\s*([0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?)", flags=re.IGNORECASE
    )

    attrs["step_size"] = []
    attrs["metric"] = []
    attrs["inv_metric"] = []
    for holder in fit.sim["samples"]:
        adaptation_info = holder.adaptation_info
        step_size_match = step_size_pattern.search(adaptation_info)
        attrs["step_size"].append(
            float(step_size_match.group(1)) if step_size_match else np.nan
        )

        # Everything after "mass matrix:" is the (inverse) metric printed by Stan.
        inv_metric_match = re.search(
            r"mass matrix:\s*(.*)\s*$", adaptation_info, flags=re.DOTALL
        )
        if inv_metric_match:
            inv_metric_str = inv_metric_match.group(1)
            if "Diagonal elements of inverse mass matrix" in adaptation_info:
                metric = "diag_e"
                inv_metric = [float(item) for item in inv_metric_str.strip(" #\n").split(",")]
            else:
                metric = "dense_e"
                inv_metric = [
                    list(map(float, item.split(",")))
                    for item in re.sub(r"#\s", "", inv_metric_str).splitlines()
                ]
        else:
            metric = "unit_e"
            inv_metric = None

        attrs["metric"].append(metric)
        attrs["inv_metric"].append(inv_metric)
    attrs["inv_metric"] = json.dumps(attrs["inv_metric"])

    if not attrs["step_size"]:
        del attrs["step_size"]

    attrs["adaptation_info"] = fit.get_adaptation_info()
    attrs["stan_code"] = fit.get_stancode()

    return attrs
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
def get_draws_stan3(fit, model=None, variables=None, ignore=None, warmup=False, dtypes=None):
    """Extract draws from a PyStan 3 fit.

    Parameters
    ----------
    fit : stan.fit.Fit
        PyStan 3 fit.
    model : stan.model.Model, optional
        When given, dtypes are inferred from its Stan code via ``infer_dtypes``.
    variables : str or iterable of str, optional
        Variables to extract; defaults to ``fit.param_names``.
    ignore : str or iterable of str, optional
        Variables to skip (matched by exact name).
    warmup : bool
        Also return warmup draws. Forced off when the fit did not save warmup.
    dtypes : dict, optional
        Per-variable dtype overrides, taking precedence over inferred ones.

    Returns
    -------
    tuple of OrderedDict
        ``(data, data_warmup)`` mapping each variable to an array of shape
        ``(chain, draw, *var_dims)``. Variables with a zero-sized dimension
        are skipped.
    """
    if ignore is None:
        ignore = []
    elif isinstance(ignore, str):
        # Avoid substring matching (``"m" in "mu"``) for a single ignored name.
        ignore = [ignore]

    if dtypes is None:
        dtypes = {}

    if model is not None:
        dtypes = {**infer_dtypes(fit, model), **dtypes}

    if not fit.save_warmup:
        warmup = False

    # Number of saved (thinned) warmup draws at the start of every chain.
    num_warmup = ceil((fit.num_warmup * fit.save_warmup) / fit.num_thin)

    if variables is None:
        variables = fit.param_names
    elif isinstance(variables, str):
        variables = [variables]
    variables = list(variables)

    data = OrderedDict()
    data_warmup = OrderedDict()

    for var in variables:
        if var in ignore or var in data:
            continue
        dtype = dtypes.get(var)

        new_shape = (*fit.dims[fit.param_names.index(var)], -1, fit.num_chains)
        if 0 in new_shape:
            continue
        values = fit._draws[fit._parameter_indexes(var), :]  # pylint: disable=protected-access
        values = values.reshape(new_shape, order="F")
        # (*dims, draw, chain) -> (chain, draw, *dims)
        values = np.moveaxis(values, [-2, -1], [1, 0])
        values = values.astype(dtype)
        if warmup:
            # Warmup draws are the first ``num_warmup`` entries of each chain.
            data_warmup[var] = values[:, :num_warmup]
        data[var] = values[:, num_warmup:]

    return data, data_warmup
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
def get_sample_stats_stan3(fit, variables=None, ignore=None, warmup=False, dtypes=None):
    """Extract sampler statistics from a PyStan 3 fit.

    Stat names lose their trailing ``__`` and some are renamed to ArviZ
    conventions (e.g. ``divergent`` -> ``diverging``).

    Parameters
    ----------
    fit : stan.fit.Fit
        PyStan 3 fit.
    variables : str or iterable of str, optional
        Raw stat names to keep; an empty or missing value keeps all.
    ignore : str or iterable of str, optional
        Raw stat names to drop.
    warmup : bool
        Also return warmup statistics. Forced off when the fit did not save warmup.
    dtypes : dict, optional
        Per-stat dtype overrides keyed by raw name (e.g. ``"divergent__"``).

    Returns
    -------
    tuple of OrderedDict
        ``(data, data_warmup)`` mapping stat names to ``(chain, draw)`` arrays.
    """
    dtypes = {
        "divergent__": bool,
        "n_leapfrog__": np.int64,
        "treedepth__": np.int64,
        **({} if dtypes is None else dtypes),
    }
    rename_dict = {
        "divergent": "diverging",
        "n_leapfrog": "n_steps",
        "treedepth": "tree_depth",
        "stepsize": "step_size",
        "accept_stat": "acceptance_rate",
    }

    if isinstance(variables, str):
        variables = [variables]
    if isinstance(ignore, str):
        ignore = [ignore]

    if not fit.save_warmup:
        warmup = False

    n_warmup_saved = ceil((fit.num_warmup * fit.save_warmup) / fit.num_thin)

    data = OrderedDict()
    data_warmup = OrderedDict()
    for key in fit.sample_and_sampler_param_names:
        if variables and key not in variables:
            continue
        if ignore and key in ignore:
            continue
        flat = fit._draws[fit._parameter_indexes(key)]  # pylint: disable=protected-access
        # (draw, chain) in Fortran order -> (chain, draw)
        stat = np.moveaxis(flat.reshape((-1, fit.num_chains), order="F"), [-2, -1], [1, 0])
        stat = stat.astype(dtypes.get(key))
        base = re.sub("__$", "", key)
        name = rename_dict.get(base, base)
        if warmup:
            data_warmup[name] = stat[:, :n_warmup_saved]
        data[name] = stat[:, n_warmup_saved:]

    return data, data_warmup
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
def get_attrs_stan3(fit, model=None):
    """Collect run settings from a PyStan 3 fit and, optionally, its model.

    Attributes that cannot be read are skipped with a logged warning.

    Parameters
    ----------
    fit : stan.fit.Fit
    model : stan.model.Model, optional

    Returns
    -------
    dict
        ``num_chains``, ``num_samples``, ``num_thin``, ``num_warmup`` and
        ``save_warmup`` read from ``fit``, plus ``model_name``, ``program_code``
        and ``random_seed`` read from ``model`` when it is given.
    """
    lookups = [
        (
            fit,
            ("num_chains", "num_samples", "num_thin", "num_warmup", "save_warmup"),
            "Failed to access attribute %s in fit object %s",
        )
    ]
    if model is not None:
        lookups.append(
            (
                model,
                ("model_name", "program_code", "random_seed"),
                "Failed to access attribute %s in model object %s",
            )
        )

    attrs = {}
    for source, names, failure_msg in lookups:
        for name in names:
            try:
                attrs[name] = getattr(source, name)
            except AttributeError as exp:
                _log.warning(failure_msg, name, exp)
    return attrs
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
def infer_dtypes(fit, model=None):
    """Infer per-parameter dtypes from the Stan program code.

    The code comes from ``fit.get_stancode()`` for PyStan 2 fits, or from
    ``model.program_code`` when a PyStan 3 ``model`` is given. The parsing
    itself is delegated to ``infer_stan_dtypes``.

    Returns
    -------
    dict
        The ``infer_stan_dtypes`` result restricted to names that are model
        parameters (``fit.model_pars`` in PyStan 2, ``fit.param_names`` in PyStan 3).
    """
    if model is not None:
        stan_code, model_pars = model.program_code, fit.param_names
    else:
        stan_code, model_pars = fit.get_stancode(), fit.model_pars

    inferred = infer_stan_dtypes(stan_code)
    return {name: dtype for name, dtype in inferred.items() if name in model_pars}
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
def from_pystan(
    posterior=None,
    *,
    posterior_predictive=None,
    predictions=None,
    prior=None,
    prior_predictive=None,
    observed_data=None,
    constant_data=None,
    predictions_constant_data=None,
    log_likelihood=None,
    coords=None,
    dims=None,
    posterior_model=None,
    prior_model=None,
    save_warmup=None,
    dtypes=None,
):
    """Convert PyStan data into an InferenceData object.

    If ``posterior`` or ``prior`` is a PyStan 3 fit (its type is defined in the
    ``stan.fit`` module), ``PyStan3Converter`` is used. Otherwise the inputs are
    treated as PyStan 2 objects and ``PyStanConverter`` is used.

    For a usage example read the
    :ref:`Creating InferenceData section on from_pystan <creating_InferenceData>`

    Parameters
    ----------
    posterior : StanFit4Model or stan.fit.Fit
        PyStan fit object for the posterior.
    posterior_predictive : str or list of str
        Posterior predictive samples for the posterior.
    predictions : str or list of str
        Out-of-sample predictions for the posterior.
    prior : StanFit4Model or stan.fit.Fit
        PyStan fit object for the prior.
    prior_predictive : str or list of str
        Prior predictive samples for the prior.
    observed_data : str or list of str
        Observed data used in the sampling, extracted from ``posterior.data``.
        PyStan 3 needs the model object for the extraction; see ``posterior_model``.
    constant_data : str or list of str
        Constants relevant to the model (e.g. x values in a linear regression).
    predictions_constant_data : str or list of str
        Constants relevant to the model predictions (e.g. new x values in a
        linear regression).
    log_likelihood : dict of {str: str}, list of str or str, optional
        Pointwise log likelihood for the data, extracted from the posterior.
        Prefer a dictionary whose keys are observed variable names and whose
        values are the Stan variables storing the log likelihood arrays; other
        forms are turned into a dictionary with keys equal to its values. By
        default, a ``log_lik`` variable present in the Stan model is retrieved
        as pointwise log likelihood. Use ``False`` or set
        ``data.log_likelihood`` to false to avoid this behaviour.
    coords : dict of {str: iterable}
        Index values per dimension: keys are dimension names, values are the
        index values.
    dims : dict of {str: list of str}
        Maps variable names to the coordinate names of their dimensions.
    posterior_model : stan.model.Model
        PyStan 3 model object, needed for automatic dtype parsing and for the
        extraction of observed data. Not forwarded for PyStan 2 fits.
    prior_model : stan.model.Model
        PyStan 3 model object, needed for automatic dtype parsing. Not
        forwarded for PyStan 2 fits.
    save_warmup : bool
        Save warmup iterations into the InferenceData object. If not defined,
        the default from rcParams is used.
    dtypes : dict
        dtype information (int, float) for parameters. By default it is
        extracted from the model code, which comes from the fit object in
        PyStan 2 and from the model object in PyStan 3.

    Returns
    -------
    InferenceData object
    """
    shared_kwargs = {
        "posterior": posterior,
        "posterior_predictive": posterior_predictive,
        "predictions": predictions,
        "prior": prior,
        "prior_predictive": prior_predictive,
        "observed_data": observed_data,
        "constant_data": constant_data,
        "predictions_constant_data": predictions_constant_data,
        "log_likelihood": log_likelihood,
        "coords": coords,
        "dims": dims,
        "save_warmup": save_warmup,
        "dtypes": dtypes,
    }
    is_stan3 = any(
        fit is not None and type(fit).__module__ == "stan.fit" for fit in (posterior, prior)
    )
    if is_stan3:
        converter = PyStan3Converter(
            posterior_model=posterior_model, prior_model=prior_model, **shared_kwargs
        )
    else:
        converter = PyStanConverter(**shared_kwargs)
    return converter.to_inference_data()
|