arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/data/io_cmdstanpy.py
DELETED
|
@@ -1,1233 +0,0 @@
|
|
|
1
|
-
# pylint: disable=too-many-lines
|
|
2
|
-
"""CmdStanPy-specific conversion code."""
|
|
3
|
-
import logging
|
|
4
|
-
import re
|
|
5
|
-
from collections import defaultdict
|
|
6
|
-
from copy import deepcopy
|
|
7
|
-
from pathlib import Path
|
|
8
|
-
|
|
9
|
-
import numpy as np
|
|
10
|
-
|
|
11
|
-
from ..rcparams import rcParams
|
|
12
|
-
from .base import dict_to_dataset, infer_stan_dtypes, make_attrs, requires
|
|
13
|
-
from .inference_data import InferenceData
|
|
14
|
-
|
|
15
|
-
_log = logging.getLogger(__name__)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class CmdStanPyConverter:
|
|
19
|
-
"""Encapsulate CmdStanPy specific logic."""
|
|
20
|
-
|
|
21
|
-
# pylint: disable=too-many-instance-attributes
|
|
22
|
-
|
|
23
|
-
def __init__(
|
|
24
|
-
self,
|
|
25
|
-
*,
|
|
26
|
-
posterior=None,
|
|
27
|
-
posterior_predictive=None,
|
|
28
|
-
predictions=None,
|
|
29
|
-
prior=None,
|
|
30
|
-
prior_predictive=None,
|
|
31
|
-
observed_data=None,
|
|
32
|
-
constant_data=None,
|
|
33
|
-
predictions_constant_data=None,
|
|
34
|
-
log_likelihood=None,
|
|
35
|
-
index_origin=None,
|
|
36
|
-
coords=None,
|
|
37
|
-
dims=None,
|
|
38
|
-
save_warmup=None,
|
|
39
|
-
dtypes=None,
|
|
40
|
-
):
|
|
41
|
-
self.posterior = posterior # CmdStanPy CmdStanMCMC object
|
|
42
|
-
self.posterior_predictive = posterior_predictive
|
|
43
|
-
self.predictions = predictions
|
|
44
|
-
self.prior = prior
|
|
45
|
-
self.prior_predictive = prior_predictive
|
|
46
|
-
self.observed_data = observed_data
|
|
47
|
-
self.constant_data = constant_data
|
|
48
|
-
self.predictions_constant_data = predictions_constant_data
|
|
49
|
-
self.log_likelihood = (
|
|
50
|
-
rcParams["data.log_likelihood"] if log_likelihood is None else log_likelihood
|
|
51
|
-
)
|
|
52
|
-
self.index_origin = index_origin
|
|
53
|
-
self.coords = coords
|
|
54
|
-
self.dims = dims
|
|
55
|
-
|
|
56
|
-
self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
|
|
57
|
-
|
|
58
|
-
import cmdstanpy # pylint: disable=import-error
|
|
59
|
-
|
|
60
|
-
if dtypes is None:
|
|
61
|
-
dtypes = {}
|
|
62
|
-
elif isinstance(dtypes, cmdstanpy.model.CmdStanModel):
|
|
63
|
-
model_code = dtypes.code()
|
|
64
|
-
dtypes = infer_stan_dtypes(model_code)
|
|
65
|
-
elif isinstance(dtypes, str):
|
|
66
|
-
dtypes_path = Path(dtypes)
|
|
67
|
-
if dtypes_path.exists():
|
|
68
|
-
with dtypes_path.open("r", encoding="UTF-8") as f_obj:
|
|
69
|
-
model_code = f_obj.read()
|
|
70
|
-
else:
|
|
71
|
-
model_code = dtypes
|
|
72
|
-
dtypes = infer_stan_dtypes(model_code)
|
|
73
|
-
|
|
74
|
-
self.dtypes = dtypes
|
|
75
|
-
|
|
76
|
-
if hasattr(self.posterior, "metadata") and hasattr(
|
|
77
|
-
self.posterior.metadata, "stan_vars_cols"
|
|
78
|
-
):
|
|
79
|
-
if self.log_likelihood is True and "log_lik" in self.posterior.metadata.stan_vars_cols:
|
|
80
|
-
self.log_likelihood = ["log_lik"]
|
|
81
|
-
elif hasattr(self.posterior, "metadata") and hasattr(
|
|
82
|
-
self.posterior.metadata, "stan_vars_cols"
|
|
83
|
-
):
|
|
84
|
-
if self.log_likelihood is True and "log_lik" in self.posterior.metadata.stan_vars_cols:
|
|
85
|
-
self.log_likelihood = ["log_lik"]
|
|
86
|
-
elif hasattr(self.posterior, "stan_vars_cols"):
|
|
87
|
-
if self.log_likelihood is True and "log_lik" in self.posterior.stan_vars_cols:
|
|
88
|
-
self.log_likelihood = ["log_lik"]
|
|
89
|
-
elif hasattr(self.posterior, "metadata") and hasattr(self.posterior.metadata, "stan_vars"):
|
|
90
|
-
if self.log_likelihood is True and "log_lik" in self.posterior.metadata.stan_vars:
|
|
91
|
-
self.log_likelihood = ["log_lik"]
|
|
92
|
-
elif (
|
|
93
|
-
self.log_likelihood is True
|
|
94
|
-
and self.posterior is not None
|
|
95
|
-
and hasattr(self.posterior, "column_names")
|
|
96
|
-
and any(name.split("[")[0] == "log_lik" for name in self.posterior.column_names)
|
|
97
|
-
):
|
|
98
|
-
self.log_likelihood = ["log_lik"]
|
|
99
|
-
|
|
100
|
-
if isinstance(self.log_likelihood, bool):
|
|
101
|
-
self.log_likelihood = None
|
|
102
|
-
|
|
103
|
-
self.cmdstanpy = cmdstanpy
|
|
104
|
-
|
|
105
|
-
@requires("posterior")
|
|
106
|
-
def posterior_to_xarray(self):
|
|
107
|
-
"""Extract posterior samples from output csv."""
|
|
108
|
-
if not (hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols")):
|
|
109
|
-
return self.posterior_to_xarray_pre_v_0_9_68()
|
|
110
|
-
if (
|
|
111
|
-
hasattr(self.posterior, "metadata")
|
|
112
|
-
and hasattr(self.posterior.metadata, "stan_vars_cols")
|
|
113
|
-
) or hasattr(self.posterior, "stan_vars_cols"):
|
|
114
|
-
return self.posterior_to_xarray_pre_v_1_0_0()
|
|
115
|
-
if hasattr(self.posterior, "metadata") and hasattr(
|
|
116
|
-
self.posterior.metadata, "stan_vars_cols"
|
|
117
|
-
):
|
|
118
|
-
return self.posterior_to_xarray_pre_v_1_2_0()
|
|
119
|
-
|
|
120
|
-
items = list(self.posterior.metadata.stan_vars)
|
|
121
|
-
if self.posterior_predictive is not None:
|
|
122
|
-
try:
|
|
123
|
-
items = _filter(items, self.posterior_predictive)
|
|
124
|
-
except ValueError:
|
|
125
|
-
pass
|
|
126
|
-
if self.predictions is not None:
|
|
127
|
-
try:
|
|
128
|
-
items = _filter(items, self.predictions)
|
|
129
|
-
except ValueError:
|
|
130
|
-
pass
|
|
131
|
-
if self.log_likelihood is not None:
|
|
132
|
-
try:
|
|
133
|
-
items = _filter(items, self.log_likelihood)
|
|
134
|
-
except ValueError:
|
|
135
|
-
pass
|
|
136
|
-
|
|
137
|
-
valid_cols = []
|
|
138
|
-
for item in items:
|
|
139
|
-
if hasattr(self.posterior, "metadata"):
|
|
140
|
-
if item in self.posterior.metadata.stan_vars:
|
|
141
|
-
valid_cols.append(item)
|
|
142
|
-
|
|
143
|
-
data, data_warmup = _unpack_fit(
|
|
144
|
-
self.posterior,
|
|
145
|
-
items,
|
|
146
|
-
self.save_warmup,
|
|
147
|
-
self.dtypes,
|
|
148
|
-
)
|
|
149
|
-
|
|
150
|
-
dims = deepcopy(self.dims) if self.dims is not None else {}
|
|
151
|
-
coords = deepcopy(self.coords) if self.coords is not None else {}
|
|
152
|
-
|
|
153
|
-
return (
|
|
154
|
-
dict_to_dataset(
|
|
155
|
-
data,
|
|
156
|
-
library=self.cmdstanpy,
|
|
157
|
-
coords=coords,
|
|
158
|
-
dims=dims,
|
|
159
|
-
index_origin=self.index_origin,
|
|
160
|
-
),
|
|
161
|
-
dict_to_dataset(
|
|
162
|
-
data_warmup,
|
|
163
|
-
library=self.cmdstanpy,
|
|
164
|
-
coords=coords,
|
|
165
|
-
dims=dims,
|
|
166
|
-
index_origin=self.index_origin,
|
|
167
|
-
),
|
|
168
|
-
)
|
|
169
|
-
|
|
170
|
-
@requires("posterior")
|
|
171
|
-
def sample_stats_to_xarray(self):
|
|
172
|
-
"""Extract sample_stats from prosterior fit."""
|
|
173
|
-
return self.stats_to_xarray(self.posterior)
|
|
174
|
-
|
|
175
|
-
@requires("prior")
|
|
176
|
-
def sample_stats_prior_to_xarray(self):
|
|
177
|
-
"""Extract sample_stats from prior fit."""
|
|
178
|
-
return self.stats_to_xarray(self.prior)
|
|
179
|
-
|
|
180
|
-
def stats_to_xarray(self, fit):
|
|
181
|
-
"""Extract sample_stats from fit."""
|
|
182
|
-
if not (hasattr(fit, "metadata") or hasattr(fit, "sampler_vars_cols")):
|
|
183
|
-
return self.sample_stats_to_xarray_pre_v_0_9_68(fit)
|
|
184
|
-
if (hasattr(fit, "metadata") and hasattr(fit.metadata, "stan_vars_cols")) or hasattr(
|
|
185
|
-
fit, "stan_vars_cols"
|
|
186
|
-
):
|
|
187
|
-
return self.sample_stats_to_xarray_pre_v_1_0_0(fit)
|
|
188
|
-
if hasattr(fit, "metadata") and hasattr(fit.metadata, "stan_vars_cols"):
|
|
189
|
-
return self.sample_stats_to_xarray_pre_v_1_2_0(fit)
|
|
190
|
-
|
|
191
|
-
dtypes = {
|
|
192
|
-
"divergent__": bool,
|
|
193
|
-
"n_leapfrog__": np.int64,
|
|
194
|
-
"treedepth__": np.int64,
|
|
195
|
-
**self.dtypes,
|
|
196
|
-
}
|
|
197
|
-
|
|
198
|
-
items = list(fit.method_variables()) # pylint: disable=protected-access
|
|
199
|
-
|
|
200
|
-
rename_dict = {
|
|
201
|
-
"divergent": "diverging",
|
|
202
|
-
"n_leapfrog": "n_steps",
|
|
203
|
-
"treedepth": "tree_depth",
|
|
204
|
-
"stepsize": "step_size",
|
|
205
|
-
"accept_stat": "acceptance_rate",
|
|
206
|
-
}
|
|
207
|
-
|
|
208
|
-
data, data_warmup = _unpack_fit(
|
|
209
|
-
fit,
|
|
210
|
-
items,
|
|
211
|
-
self.save_warmup,
|
|
212
|
-
self.dtypes,
|
|
213
|
-
)
|
|
214
|
-
for item in items:
|
|
215
|
-
name = re.sub("__$", "", item)
|
|
216
|
-
name = rename_dict.get(name, name)
|
|
217
|
-
data[name] = data.pop(item).astype(dtypes.get(item, float))
|
|
218
|
-
if data_warmup:
|
|
219
|
-
data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))
|
|
220
|
-
|
|
221
|
-
return (
|
|
222
|
-
dict_to_dataset(
|
|
223
|
-
data,
|
|
224
|
-
library=self.cmdstanpy,
|
|
225
|
-
coords=self.coords,
|
|
226
|
-
dims=self.dims,
|
|
227
|
-
index_origin=self.index_origin,
|
|
228
|
-
),
|
|
229
|
-
dict_to_dataset(
|
|
230
|
-
data_warmup,
|
|
231
|
-
library=self.cmdstanpy,
|
|
232
|
-
coords=self.coords,
|
|
233
|
-
dims=self.dims,
|
|
234
|
-
index_origin=self.index_origin,
|
|
235
|
-
),
|
|
236
|
-
)
|
|
237
|
-
|
|
238
|
-
@requires("posterior")
|
|
239
|
-
@requires("posterior_predictive")
|
|
240
|
-
def posterior_predictive_to_xarray(self):
|
|
241
|
-
"""Convert posterior_predictive samples to xarray."""
|
|
242
|
-
return self.predictive_to_xarray(self.posterior_predictive, self.posterior)
|
|
243
|
-
|
|
244
|
-
@requires("prior")
|
|
245
|
-
@requires("prior_predictive")
|
|
246
|
-
def prior_predictive_to_xarray(self):
|
|
247
|
-
"""Convert prior_predictive samples to xarray."""
|
|
248
|
-
return self.predictive_to_xarray(self.prior_predictive, self.prior)
|
|
249
|
-
|
|
250
|
-
def predictive_to_xarray(self, names, fit):
|
|
251
|
-
"""Convert predictive samples to xarray."""
|
|
252
|
-
predictive = _as_set(names)
|
|
253
|
-
|
|
254
|
-
if not (hasattr(fit, "metadata") or hasattr(fit, "stan_vars_cols")): # pre_v_0_9_68
|
|
255
|
-
valid_cols = _filter_columns(fit.column_names, predictive)
|
|
256
|
-
data, data_warmup = _unpack_frame(
|
|
257
|
-
fit,
|
|
258
|
-
fit.column_names,
|
|
259
|
-
valid_cols,
|
|
260
|
-
self.save_warmup,
|
|
261
|
-
self.dtypes,
|
|
262
|
-
)
|
|
263
|
-
elif (hasattr(fit, "metadata") and hasattr(fit.metadata, "sample_vars_cols")) or hasattr(
|
|
264
|
-
fit, "stan_vars_cols"
|
|
265
|
-
): # pre_v_1_0_0
|
|
266
|
-
data, data_warmup = _unpack_fit_pre_v_1_0_0(
|
|
267
|
-
fit,
|
|
268
|
-
predictive,
|
|
269
|
-
self.save_warmup,
|
|
270
|
-
self.dtypes,
|
|
271
|
-
)
|
|
272
|
-
elif hasattr(fit, "metadata") and hasattr(fit.metadata, "stan_vars_cols"): # pre_v_1_2_0
|
|
273
|
-
data, data_warmup = _unpack_fit_pre_v_1_2_0(
|
|
274
|
-
fit,
|
|
275
|
-
predictive,
|
|
276
|
-
self.save_warmup,
|
|
277
|
-
self.dtypes,
|
|
278
|
-
)
|
|
279
|
-
else:
|
|
280
|
-
data, data_warmup = _unpack_fit(
|
|
281
|
-
fit,
|
|
282
|
-
predictive,
|
|
283
|
-
self.save_warmup,
|
|
284
|
-
self.dtypes,
|
|
285
|
-
)
|
|
286
|
-
|
|
287
|
-
return (
|
|
288
|
-
dict_to_dataset(
|
|
289
|
-
data,
|
|
290
|
-
library=self.cmdstanpy,
|
|
291
|
-
coords=self.coords,
|
|
292
|
-
dims=self.dims,
|
|
293
|
-
index_origin=self.index_origin,
|
|
294
|
-
),
|
|
295
|
-
dict_to_dataset(
|
|
296
|
-
data_warmup,
|
|
297
|
-
library=self.cmdstanpy,
|
|
298
|
-
coords=self.coords,
|
|
299
|
-
dims=self.dims,
|
|
300
|
-
index_origin=self.index_origin,
|
|
301
|
-
),
|
|
302
|
-
)
|
|
303
|
-
|
|
304
|
-
@requires("posterior")
|
|
305
|
-
@requires("predictions")
|
|
306
|
-
def predictions_to_xarray(self):
|
|
307
|
-
"""Convert out of sample predictions samples to xarray."""
|
|
308
|
-
predictions = _as_set(self.predictions)
|
|
309
|
-
|
|
310
|
-
if not (
|
|
311
|
-
hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols")
|
|
312
|
-
): # pre_v_0_9_68
|
|
313
|
-
columns = self.posterior.column_names
|
|
314
|
-
valid_cols = _filter_columns(columns, predictions)
|
|
315
|
-
data, data_warmup = _unpack_frame(
|
|
316
|
-
self.posterior,
|
|
317
|
-
columns,
|
|
318
|
-
valid_cols,
|
|
319
|
-
self.save_warmup,
|
|
320
|
-
self.dtypes,
|
|
321
|
-
)
|
|
322
|
-
elif (
|
|
323
|
-
hasattr(self.posterior, "metadata")
|
|
324
|
-
and hasattr(self.posterior.metadata, "sample_vars_cols")
|
|
325
|
-
) or hasattr(
|
|
326
|
-
self.posterior, "stan_vars_cols"
|
|
327
|
-
): # pre_v_1_0_0
|
|
328
|
-
data, data_warmup = _unpack_fit_pre_v_1_0_0(
|
|
329
|
-
self.posterior,
|
|
330
|
-
predictions,
|
|
331
|
-
self.save_warmup,
|
|
332
|
-
self.dtypes,
|
|
333
|
-
)
|
|
334
|
-
elif hasattr(self.posterior, "metadata") and hasattr(
|
|
335
|
-
self.posterior.metadata, "stan_vars_cols"
|
|
336
|
-
): # pre_v_1_2_0
|
|
337
|
-
data, data_warmup = _unpack_fit_pre_v_1_2_0(
|
|
338
|
-
self.posterior,
|
|
339
|
-
predictions,
|
|
340
|
-
self.save_warmup,
|
|
341
|
-
self.dtypes,
|
|
342
|
-
)
|
|
343
|
-
else:
|
|
344
|
-
data, data_warmup = _unpack_fit(
|
|
345
|
-
self.posterior,
|
|
346
|
-
predictions,
|
|
347
|
-
self.save_warmup,
|
|
348
|
-
self.dtypes,
|
|
349
|
-
)
|
|
350
|
-
|
|
351
|
-
return (
|
|
352
|
-
dict_to_dataset(
|
|
353
|
-
data,
|
|
354
|
-
library=self.cmdstanpy,
|
|
355
|
-
coords=self.coords,
|
|
356
|
-
dims=self.dims,
|
|
357
|
-
index_origin=self.index_origin,
|
|
358
|
-
),
|
|
359
|
-
dict_to_dataset(
|
|
360
|
-
data_warmup,
|
|
361
|
-
library=self.cmdstanpy,
|
|
362
|
-
coords=self.coords,
|
|
363
|
-
dims=self.dims,
|
|
364
|
-
index_origin=self.index_origin,
|
|
365
|
-
),
|
|
366
|
-
)
|
|
367
|
-
|
|
368
|
-
@requires("posterior")
|
|
369
|
-
@requires("log_likelihood")
|
|
370
|
-
def log_likelihood_to_xarray(self):
|
|
371
|
-
"""Convert elementwise log likelihood samples to xarray."""
|
|
372
|
-
log_likelihood = _as_set(self.log_likelihood)
|
|
373
|
-
|
|
374
|
-
if not (
|
|
375
|
-
hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols")
|
|
376
|
-
): # pre_v_0_9_68
|
|
377
|
-
columns = self.posterior.column_names
|
|
378
|
-
valid_cols = _filter_columns(columns, log_likelihood)
|
|
379
|
-
data, data_warmup = _unpack_frame(
|
|
380
|
-
self.posterior,
|
|
381
|
-
columns,
|
|
382
|
-
valid_cols,
|
|
383
|
-
self.save_warmup,
|
|
384
|
-
self.dtypes,
|
|
385
|
-
)
|
|
386
|
-
elif (
|
|
387
|
-
hasattr(self.posterior, "metadata")
|
|
388
|
-
and hasattr(self.posterior.metadata, "sample_vars_cols")
|
|
389
|
-
) or hasattr(
|
|
390
|
-
self.posterior, "stan_vars_cols"
|
|
391
|
-
): # pre_v_1_0_0
|
|
392
|
-
data, data_warmup = _unpack_fit_pre_v_1_0_0(
|
|
393
|
-
self.posterior,
|
|
394
|
-
log_likelihood,
|
|
395
|
-
self.save_warmup,
|
|
396
|
-
self.dtypes,
|
|
397
|
-
)
|
|
398
|
-
elif hasattr(self.posterior, "metadata") and hasattr(
|
|
399
|
-
self.posterior.metadata, "stan_vars_cols"
|
|
400
|
-
): # pre_v_1_2_0
|
|
401
|
-
data, data_warmup = _unpack_fit_pre_v_1_2_0(
|
|
402
|
-
self.posterior,
|
|
403
|
-
log_likelihood,
|
|
404
|
-
self.save_warmup,
|
|
405
|
-
self.dtypes,
|
|
406
|
-
)
|
|
407
|
-
else:
|
|
408
|
-
data, data_warmup = _unpack_fit(
|
|
409
|
-
self.posterior,
|
|
410
|
-
log_likelihood,
|
|
411
|
-
self.save_warmup,
|
|
412
|
-
self.dtypes,
|
|
413
|
-
)
|
|
414
|
-
|
|
415
|
-
if isinstance(self.log_likelihood, dict):
|
|
416
|
-
data = {obs_name: data[lik_name] for obs_name, lik_name in self.log_likelihood.items()}
|
|
417
|
-
if data_warmup:
|
|
418
|
-
data_warmup = {
|
|
419
|
-
obs_name: data_warmup[lik_name]
|
|
420
|
-
for obs_name, lik_name in self.log_likelihood.items()
|
|
421
|
-
}
|
|
422
|
-
return (
|
|
423
|
-
dict_to_dataset(
|
|
424
|
-
data,
|
|
425
|
-
library=self.cmdstanpy,
|
|
426
|
-
coords=self.coords,
|
|
427
|
-
dims=self.dims,
|
|
428
|
-
index_origin=self.index_origin,
|
|
429
|
-
skip_event_dims=True,
|
|
430
|
-
),
|
|
431
|
-
dict_to_dataset(
|
|
432
|
-
data_warmup,
|
|
433
|
-
library=self.cmdstanpy,
|
|
434
|
-
coords=self.coords,
|
|
435
|
-
dims=self.dims,
|
|
436
|
-
index_origin=self.index_origin,
|
|
437
|
-
skip_event_dims=True,
|
|
438
|
-
),
|
|
439
|
-
)
|
|
440
|
-
|
|
441
|
-
@requires("prior")
|
|
442
|
-
def prior_to_xarray(self):
|
|
443
|
-
"""Convert prior samples to xarray."""
|
|
444
|
-
if not (
|
|
445
|
-
hasattr(self.prior, "metadata") or hasattr(self.prior, "stan_vars_cols")
|
|
446
|
-
): # pre_v_0_9_68
|
|
447
|
-
columns = self.prior.column_names
|
|
448
|
-
prior_predictive = _as_set(self.prior_predictive)
|
|
449
|
-
prior_predictive = _filter_columns(columns, prior_predictive)
|
|
450
|
-
|
|
451
|
-
invalid_cols = set(prior_predictive + [col for col in columns if col.endswith("__")])
|
|
452
|
-
valid_cols = [col for col in columns if col not in invalid_cols]
|
|
453
|
-
|
|
454
|
-
data, data_warmup = _unpack_frame(
|
|
455
|
-
self.prior,
|
|
456
|
-
columns,
|
|
457
|
-
valid_cols,
|
|
458
|
-
self.save_warmup,
|
|
459
|
-
self.dtypes,
|
|
460
|
-
)
|
|
461
|
-
elif (
|
|
462
|
-
hasattr(self.prior, "metadata") and hasattr(self.prior.metadata, "sample_vars_cols")
|
|
463
|
-
) or hasattr(
|
|
464
|
-
self.prior, "stan_vars_cols"
|
|
465
|
-
): # pre_v_1_0_0
|
|
466
|
-
if hasattr(self.prior, "metadata"):
|
|
467
|
-
items = list(self.prior.metadata.stan_vars_cols.keys())
|
|
468
|
-
else:
|
|
469
|
-
items = list(self.prior.stan_vars_cols.keys())
|
|
470
|
-
if self.prior_predictive is not None:
|
|
471
|
-
try:
|
|
472
|
-
items = _filter(items, self.prior_predictive)
|
|
473
|
-
except ValueError:
|
|
474
|
-
pass
|
|
475
|
-
data, data_warmup = _unpack_fit_pre_v_1_0_0(
|
|
476
|
-
self.prior,
|
|
477
|
-
items,
|
|
478
|
-
self.save_warmup,
|
|
479
|
-
self.dtypes,
|
|
480
|
-
)
|
|
481
|
-
elif hasattr(self.prior, "metadata") and hasattr(
|
|
482
|
-
self.prior.metadata, "stan_vars_cols"
|
|
483
|
-
): # pre_v_1_2_0
|
|
484
|
-
items = list(self.prior.metadata.stan_vars_cols.keys())
|
|
485
|
-
if self.prior_predictive is not None:
|
|
486
|
-
try:
|
|
487
|
-
items = _filter(items, self.prior_predictive)
|
|
488
|
-
except ValueError:
|
|
489
|
-
pass
|
|
490
|
-
data, data_warmup = _unpack_fit_pre_v_1_2_0(
|
|
491
|
-
self.prior,
|
|
492
|
-
items,
|
|
493
|
-
self.save_warmup,
|
|
494
|
-
self.dtypes,
|
|
495
|
-
)
|
|
496
|
-
else:
|
|
497
|
-
items = list(self.prior.metadata.stan_vars.keys())
|
|
498
|
-
if self.prior_predictive is not None:
|
|
499
|
-
try:
|
|
500
|
-
items = _filter(items, self.prior_predictive)
|
|
501
|
-
except ValueError:
|
|
502
|
-
pass
|
|
503
|
-
data, data_warmup = _unpack_fit(
|
|
504
|
-
self.prior,
|
|
505
|
-
items,
|
|
506
|
-
self.save_warmup,
|
|
507
|
-
self.dtypes,
|
|
508
|
-
)
|
|
509
|
-
|
|
510
|
-
return (
|
|
511
|
-
dict_to_dataset(
|
|
512
|
-
data,
|
|
513
|
-
library=self.cmdstanpy,
|
|
514
|
-
coords=self.coords,
|
|
515
|
-
dims=self.dims,
|
|
516
|
-
index_origin=self.index_origin,
|
|
517
|
-
),
|
|
518
|
-
dict_to_dataset(
|
|
519
|
-
data_warmup,
|
|
520
|
-
library=self.cmdstanpy,
|
|
521
|
-
coords=self.coords,
|
|
522
|
-
dims=self.dims,
|
|
523
|
-
index_origin=self.index_origin,
|
|
524
|
-
),
|
|
525
|
-
)
|
|
526
|
-
|
|
527
|
-
@requires("observed_data")
|
|
528
|
-
def observed_data_to_xarray(self):
|
|
529
|
-
"""Convert observed data to xarray."""
|
|
530
|
-
return dict_to_dataset(
|
|
531
|
-
self.observed_data,
|
|
532
|
-
library=self.cmdstanpy,
|
|
533
|
-
coords=self.coords,
|
|
534
|
-
dims=self.dims,
|
|
535
|
-
default_dims=[],
|
|
536
|
-
index_origin=self.index_origin,
|
|
537
|
-
)
|
|
538
|
-
|
|
539
|
-
@requires("constant_data")
|
|
540
|
-
def constant_data_to_xarray(self):
|
|
541
|
-
"""Convert constant data to xarray."""
|
|
542
|
-
return dict_to_dataset(
|
|
543
|
-
self.constant_data,
|
|
544
|
-
library=self.cmdstanpy,
|
|
545
|
-
coords=self.coords,
|
|
546
|
-
dims=self.dims,
|
|
547
|
-
default_dims=[],
|
|
548
|
-
index_origin=self.index_origin,
|
|
549
|
-
)
|
|
550
|
-
|
|
551
|
-
@requires("predictions_constant_data")
|
|
552
|
-
def predictions_constant_data_to_xarray(self):
|
|
553
|
-
"""Convert constant data to xarray."""
|
|
554
|
-
return dict_to_dataset(
|
|
555
|
-
self.predictions_constant_data,
|
|
556
|
-
library=self.cmdstanpy,
|
|
557
|
-
coords=self.coords,
|
|
558
|
-
dims=self.dims,
|
|
559
|
-
attrs=make_attrs(library=self.cmdstanpy),
|
|
560
|
-
default_dims=[],
|
|
561
|
-
index_origin=self.index_origin,
|
|
562
|
-
)
|
|
563
|
-
|
|
564
|
-
def to_inference_data(self):
|
|
565
|
-
"""Convert all available data to an InferenceData object.
|
|
566
|
-
|
|
567
|
-
Note that if groups can not be created (i.e., there is no `output`, so
|
|
568
|
-
the `posterior` and `sample_stats` can not be extracted), then the InferenceData
|
|
569
|
-
will not have those groups.
|
|
570
|
-
"""
|
|
571
|
-
return InferenceData(
|
|
572
|
-
save_warmup=self.save_warmup,
|
|
573
|
-
**{
|
|
574
|
-
"posterior": self.posterior_to_xarray(),
|
|
575
|
-
"sample_stats": self.sample_stats_to_xarray(),
|
|
576
|
-
"posterior_predictive": self.posterior_predictive_to_xarray(),
|
|
577
|
-
"predictions": self.predictions_to_xarray(),
|
|
578
|
-
"prior": self.prior_to_xarray(),
|
|
579
|
-
"sample_stats_prior": self.sample_stats_prior_to_xarray(),
|
|
580
|
-
"prior_predictive": self.prior_predictive_to_xarray(),
|
|
581
|
-
"observed_data": self.observed_data_to_xarray(),
|
|
582
|
-
"constant_data": self.constant_data_to_xarray(),
|
|
583
|
-
"predictions_constant_data": self.predictions_constant_data_to_xarray(),
|
|
584
|
-
"log_likelihood": self.log_likelihood_to_xarray(),
|
|
585
|
-
},
|
|
586
|
-
)
|
|
587
|
-
|
|
588
|
-
def posterior_to_xarray_pre_v_1_2_0(self):
|
|
589
|
-
items = list(self.posterior.metadata.stan_vars_cols)
|
|
590
|
-
if self.posterior_predictive is not None:
|
|
591
|
-
try:
|
|
592
|
-
items = _filter(items, self.posterior_predictive)
|
|
593
|
-
except ValueError:
|
|
594
|
-
pass
|
|
595
|
-
if self.predictions is not None:
|
|
596
|
-
try:
|
|
597
|
-
items = _filter(items, self.predictions)
|
|
598
|
-
except ValueError:
|
|
599
|
-
pass
|
|
600
|
-
if self.log_likelihood is not None:
|
|
601
|
-
try:
|
|
602
|
-
items = _filter(items, self.log_likelihood)
|
|
603
|
-
except ValueError:
|
|
604
|
-
pass
|
|
605
|
-
|
|
606
|
-
valid_cols = []
|
|
607
|
-
for item in items:
|
|
608
|
-
if hasattr(self.posterior, "metadata"):
|
|
609
|
-
if item in self.posterior.metadata.stan_vars_cols:
|
|
610
|
-
valid_cols.append(item)
|
|
611
|
-
|
|
612
|
-
data, data_warmup = _unpack_fit_pre_v_1_2_0(
|
|
613
|
-
self.posterior,
|
|
614
|
-
items,
|
|
615
|
-
self.save_warmup,
|
|
616
|
-
self.dtypes,
|
|
617
|
-
)
|
|
618
|
-
|
|
619
|
-
dims = deepcopy(self.dims) if self.dims is not None else {}
|
|
620
|
-
coords = deepcopy(self.coords) if self.coords is not None else {}
|
|
621
|
-
|
|
622
|
-
return (
|
|
623
|
-
dict_to_dataset(
|
|
624
|
-
data,
|
|
625
|
-
library=self.cmdstanpy,
|
|
626
|
-
coords=coords,
|
|
627
|
-
dims=dims,
|
|
628
|
-
index_origin=self.index_origin,
|
|
629
|
-
),
|
|
630
|
-
dict_to_dataset(
|
|
631
|
-
data_warmup,
|
|
632
|
-
library=self.cmdstanpy,
|
|
633
|
-
coords=coords,
|
|
634
|
-
dims=dims,
|
|
635
|
-
index_origin=self.index_origin,
|
|
636
|
-
),
|
|
637
|
-
)
|
|
638
|
-
|
|
639
|
-
@requires("posterior")
|
|
640
|
-
def posterior_to_xarray_pre_v_1_0_0(self):
|
|
641
|
-
if hasattr(self.posterior, "metadata"):
|
|
642
|
-
items = list(self.posterior.metadata.stan_vars_cols.keys())
|
|
643
|
-
else:
|
|
644
|
-
items = list(self.posterior.stan_vars_cols.keys())
|
|
645
|
-
if self.posterior_predictive is not None:
|
|
646
|
-
try:
|
|
647
|
-
items = _filter(items, self.posterior_predictive)
|
|
648
|
-
except ValueError:
|
|
649
|
-
pass
|
|
650
|
-
if self.predictions is not None:
|
|
651
|
-
try:
|
|
652
|
-
items = _filter(items, self.predictions)
|
|
653
|
-
except ValueError:
|
|
654
|
-
pass
|
|
655
|
-
if self.log_likelihood is not None:
|
|
656
|
-
try:
|
|
657
|
-
items = _filter(items, self.log_likelihood)
|
|
658
|
-
except ValueError:
|
|
659
|
-
pass
|
|
660
|
-
|
|
661
|
-
valid_cols = []
|
|
662
|
-
for item in items:
|
|
663
|
-
if hasattr(self.posterior, "metadata"):
|
|
664
|
-
valid_cols.extend(self.posterior.metadata.stan_vars_cols[item])
|
|
665
|
-
else:
|
|
666
|
-
valid_cols.extend(self.posterior.stan_vars_cols[item])
|
|
667
|
-
|
|
668
|
-
data, data_warmup = _unpack_fit_pre_v_1_0_0(
|
|
669
|
-
self.posterior,
|
|
670
|
-
items,
|
|
671
|
-
self.save_warmup,
|
|
672
|
-
self.dtypes,
|
|
673
|
-
)
|
|
674
|
-
|
|
675
|
-
dims = deepcopy(self.dims) if self.dims is not None else {}
|
|
676
|
-
coords = deepcopy(self.coords) if self.coords is not None else {}
|
|
677
|
-
|
|
678
|
-
return (
|
|
679
|
-
dict_to_dataset(
|
|
680
|
-
data,
|
|
681
|
-
library=self.cmdstanpy,
|
|
682
|
-
coords=coords,
|
|
683
|
-
dims=dims,
|
|
684
|
-
index_origin=self.index_origin,
|
|
685
|
-
),
|
|
686
|
-
dict_to_dataset(
|
|
687
|
-
data_warmup,
|
|
688
|
-
library=self.cmdstanpy,
|
|
689
|
-
coords=coords,
|
|
690
|
-
dims=dims,
|
|
691
|
-
index_origin=self.index_origin,
|
|
692
|
-
),
|
|
693
|
-
)
|
|
694
|
-
|
|
695
|
-
@requires("posterior")
|
|
696
|
-
def posterior_to_xarray_pre_v_0_9_68(self):
|
|
697
|
-
"""Extract posterior samples from output csv."""
|
|
698
|
-
columns = self.posterior.column_names
|
|
699
|
-
|
|
700
|
-
# filter posterior_predictive, predictions and log_likelihood
|
|
701
|
-
posterior_predictive = self.posterior_predictive
|
|
702
|
-
if posterior_predictive is None:
|
|
703
|
-
posterior_predictive = []
|
|
704
|
-
elif isinstance(posterior_predictive, str):
|
|
705
|
-
posterior_predictive = [
|
|
706
|
-
col for col in columns if posterior_predictive == col.split("[")[0].split(".")[0]
|
|
707
|
-
]
|
|
708
|
-
else:
|
|
709
|
-
posterior_predictive = [
|
|
710
|
-
col
|
|
711
|
-
for col in columns
|
|
712
|
-
if any(item == col.split("[")[0].split(".")[0] for item in posterior_predictive)
|
|
713
|
-
]
|
|
714
|
-
|
|
715
|
-
predictions = self.predictions
|
|
716
|
-
if predictions is None:
|
|
717
|
-
predictions = []
|
|
718
|
-
elif isinstance(predictions, str):
|
|
719
|
-
predictions = [col for col in columns if predictions == col.split("[")[0].split(".")[0]]
|
|
720
|
-
else:
|
|
721
|
-
predictions = [
|
|
722
|
-
col
|
|
723
|
-
for col in columns
|
|
724
|
-
if any(item == col.split("[")[0].split(".")[0] for item in predictions)
|
|
725
|
-
]
|
|
726
|
-
|
|
727
|
-
log_likelihood = self.log_likelihood
|
|
728
|
-
if log_likelihood is None:
|
|
729
|
-
log_likelihood = []
|
|
730
|
-
elif isinstance(log_likelihood, str):
|
|
731
|
-
log_likelihood = [
|
|
732
|
-
col for col in columns if log_likelihood == col.split("[")[0].split(".")[0]
|
|
733
|
-
]
|
|
734
|
-
else:
|
|
735
|
-
log_likelihood = [
|
|
736
|
-
col
|
|
737
|
-
for col in columns
|
|
738
|
-
if any(item == col.split("[")[0].split(".")[0] for item in log_likelihood)
|
|
739
|
-
]
|
|
740
|
-
|
|
741
|
-
invalid_cols = set(
|
|
742
|
-
posterior_predictive
|
|
743
|
-
+ predictions
|
|
744
|
-
+ log_likelihood
|
|
745
|
-
+ [col for col in columns if col.endswith("__")]
|
|
746
|
-
)
|
|
747
|
-
valid_cols = [col for col in columns if col not in invalid_cols]
|
|
748
|
-
data, data_warmup = _unpack_frame(
|
|
749
|
-
self.posterior,
|
|
750
|
-
columns,
|
|
751
|
-
valid_cols,
|
|
752
|
-
self.save_warmup,
|
|
753
|
-
self.dtypes,
|
|
754
|
-
)
|
|
755
|
-
|
|
756
|
-
return (
|
|
757
|
-
dict_to_dataset(
|
|
758
|
-
data,
|
|
759
|
-
library=self.cmdstanpy,
|
|
760
|
-
coords=self.coords,
|
|
761
|
-
dims=self.dims,
|
|
762
|
-
index_origin=self.index_origin,
|
|
763
|
-
),
|
|
764
|
-
dict_to_dataset(
|
|
765
|
-
data_warmup,
|
|
766
|
-
library=self.cmdstanpy,
|
|
767
|
-
coords=self.coords,
|
|
768
|
-
dims=self.dims,
|
|
769
|
-
index_origin=self.index_origin,
|
|
770
|
-
),
|
|
771
|
-
)
|
|
772
|
-
|
|
773
|
-
    def sample_stats_to_xarray_pre_v_1_2_0(self, fit):
        """Extract sample_stats from fit (cmdstanpy < 1.2.0)."""
        # Diagnostics that should not remain float; user-supplied dtypes
        # (self.dtypes) take precedence over these defaults.
        dtypes = {
            "divergent__": bool,
            "n_leapfrog__": np.int64,
            "treedepth__": np.int64,
            **self.dtypes,
        }

        items = list(fit.metadata.method_vars_cols.keys())  # pylint: disable=protected-access

        # Map cmdstanpy sampler-variable names to ArviZ naming conventions.
        rename_dict = {
            "divergent": "diverging",
            "n_leapfrog": "n_steps",
            "treedepth": "tree_depth",
            "stepsize": "step_size",
            "accept_stat": "acceptance_rate",
        }

        data, data_warmup = _unpack_fit_pre_v_1_2_0(
            fit,
            items,
            self.save_warmup,
            self.dtypes,
        )
        for item in items:
            # Strip the trailing "__", apply the ArviZ rename, and cast
            # with the merged dtype table (keyed by the original name).
            name = re.sub("__$", "", item)
            name = rename_dict.get(name, name)
            data[name] = data.pop(item).astype(dtypes.get(item, float))
            if data_warmup:
                data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))

        return (
            dict_to_dataset(
                data,
                library=self.cmdstanpy,
                coords=self.coords,
                dims=self.dims,
                index_origin=self.index_origin,
            ),
            dict_to_dataset(
                data_warmup,
                library=self.cmdstanpy,
                coords=self.coords,
                dims=self.dims,
                index_origin=self.index_origin,
            ),
        )
|
|
820
|
-
|
|
821
|
-
    def sample_stats_to_xarray_pre_v_1_0_0(self, fit):
        """Extract sample_stats from fit (cmdstanpy < 1.0.0)."""
        # Diagnostics that should not remain float; user-supplied dtypes
        # (self.dtypes) take precedence over these defaults.
        dtypes = {
            "divergent__": bool,
            "n_leapfrog__": np.int64,
            "treedepth__": np.int64,
            **self.dtypes,
        }
        # Older fits may or may not carry a `metadata` attribute.
        if hasattr(fit, "metadata"):
            items = list(fit.metadata._method_vars_cols.keys())  # pylint: disable=protected-access
        else:
            items = list(fit.sampler_vars_cols.keys())
        # Map cmdstanpy sampler-variable names to ArviZ naming conventions.
        rename_dict = {
            "divergent": "diverging",
            "n_leapfrog": "n_steps",
            "treedepth": "tree_depth",
            "stepsize": "step_size",
            "accept_stat": "acceptance_rate",
        }

        data, data_warmup = _unpack_fit_pre_v_1_0_0(
            fit,
            items,
            self.save_warmup,
            self.dtypes,
        )
        for item in items:
            # Strip the trailing "__", apply the ArviZ rename, and cast
            # with the merged dtype table (keyed by the original name).
            name = re.sub("__$", "", item)
            name = rename_dict.get(name, name)
            data[name] = data.pop(item).astype(dtypes.get(item, float))
            if data_warmup:
                data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))
        return (
            dict_to_dataset(
                data,
                library=self.cmdstanpy,
                coords=self.coords,
                dims=self.dims,
                index_origin=self.index_origin,
            ),
            dict_to_dataset(
                data_warmup,
                library=self.cmdstanpy,
                coords=self.coords,
                dims=self.dims,
                index_origin=self.index_origin,
            ),
        )
|
|
869
|
-
|
|
870
|
-
    def sample_stats_to_xarray_pre_v_0_9_68(self, fit):
        """Extract sample_stats from fit (cmdstanpy < 0.9.68)."""
        # Diagnostics that should not remain float.
        dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}
        columns = fit.column_names
        # Sampler diagnostic columns are the ones ending in "__".
        valid_cols = [col for col in columns if col.endswith("__")]
        data, data_warmup = _unpack_frame(
            fit,
            columns,
            valid_cols,
            self.save_warmup,
            self.dtypes,
        )
        for s_param in list(data.keys()):
            # Base name only: "X.1.2" -> "X"; then strip "__" and rename
            # "divergent" to ArviZ's "diverging".
            s_param_, *_ = s_param.split(".")
            name = re.sub("__$", "", s_param_)
            name = "diverging" if name == "divergent" else name
            data[name] = data.pop(s_param).astype(dtypes.get(s_param, float))
            if data_warmup:
                data_warmup[name] = data_warmup.pop(s_param).astype(dtypes.get(s_param, float))
        return (
            dict_to_dataset(
                data,
                library=self.cmdstanpy,
                coords=self.coords,
                dims=self.dims,
                index_origin=self.index_origin,
            ),
            dict_to_dataset(
                data_warmup,
                library=self.cmdstanpy,
                coords=self.coords,
                dims=self.dims,
                index_origin=self.index_origin,
            ),
        )
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
def _as_set(spec):
|
|
908
|
-
"""Uniform representation for args which be name or list of names."""
|
|
909
|
-
if spec is None:
|
|
910
|
-
return []
|
|
911
|
-
if isinstance(spec, str):
|
|
912
|
-
return [spec]
|
|
913
|
-
try:
|
|
914
|
-
return set(spec.values())
|
|
915
|
-
except AttributeError:
|
|
916
|
-
return set(spec)
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
def _filter(names, spec):
|
|
920
|
-
"""Remove names from list of names."""
|
|
921
|
-
if isinstance(spec, str):
|
|
922
|
-
names.remove(spec)
|
|
923
|
-
elif isinstance(spec, list):
|
|
924
|
-
for item in spec:
|
|
925
|
-
names.remove(item)
|
|
926
|
-
elif isinstance(spec, dict):
|
|
927
|
-
for item in spec.values():
|
|
928
|
-
names.remove(item)
|
|
929
|
-
return names
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
def _filter_columns(columns, spec):
|
|
933
|
-
"""Parse variable name from column label, removing element index, if any."""
|
|
934
|
-
return [col for col in columns if col.split("[")[0].split(".")[0] in spec]
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
def _unpack_fit(fit, items, save_warmup, dtypes):
    """Transform a CmdStanMCMC fit into dicts of (chain, draw, *shape) arrays.

    Parameters
    ----------
    fit : cmdstanpy.CmdStanMCMC
    items : list of str
        Variable names to extract (Stan variables or sampler method variables).
    save_warmup : bool
    dtypes : dict
        Optional per-variable target dtypes.

    Returns
    -------
    tuple of dict
        ``(sample, sample_warmup)``; ``sample_warmup`` is empty unless
        warmup draws were both requested and stored in the fit.
    """
    num_warmup = 0
    if save_warmup:
        # Warmup can only be returned if the fit actually stored it.
        if not fit._save_warmup:  # pylint: disable=protected-access
            save_warmup = False
        else:
            num_warmup = fit.num_draws_warmup

    nchains = fit.chains
    sample = {}
    sample_warmup = {}
    stan_vars_cols = list(fit.metadata.stan_vars)
    sampler_vars = fit.method_variables()
    for item in items:
        if item in stan_vars_cols:
            # stan_variable flattens draws across chains; reshape with
            # Fortran order and swap axes to recover (chain, draw, *shape).
            raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
            raw_draws = np.swapaxes(
                raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
            )
        elif item in sampler_vars:
            # method_variables arrays come as (draw, chain); reorder.
            raw_draws = np.swapaxes(sampler_vars[item], 0, 1)
        else:
            raise ValueError(f"fit data, unknown variable: {item}")
        # dtypes.get(item) may be None, which numpy treats as float64.
        raw_draws = raw_draws.astype(dtypes.get(item))
        if save_warmup:
            # Warmup draws precede the sampling draws along the draw axis.
            sample_warmup[item] = raw_draws[:, :num_warmup, ...]
            sample[item] = raw_draws[:, num_warmup:, ...]
        else:
            sample[item] = raw_draws

    return sample, sample_warmup
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
def _unpack_fit_pre_v_1_2_0(fit, items, save_warmup, dtypes):
    """Transform a cmdstanpy < 1.2.0 fit into dicts of (chain, draw, *shape) arrays.

    Identical to ``_unpack_fit`` except that it reads variable names from
    ``metadata.stan_vars_cols`` (the pre-1.2.0 attribute) rather than
    ``metadata.stan_vars``.

    Returns
    -------
    tuple of dict
        ``(sample, sample_warmup)``; ``sample_warmup`` is empty unless
        warmup draws were both requested and stored in the fit.
    """
    num_warmup = 0
    if save_warmup:
        # Warmup can only be returned if the fit actually stored it.
        if not fit._save_warmup:  # pylint: disable=protected-access
            save_warmup = False
        else:
            num_warmup = fit.num_draws_warmup

    nchains = fit.chains
    sample = {}
    sample_warmup = {}
    stan_vars_cols = list(fit.metadata.stan_vars_cols)
    sampler_vars = fit.method_variables()
    for item in items:
        if item in stan_vars_cols:
            # stan_variable flattens draws across chains; reshape with
            # Fortran order and swap axes to recover (chain, draw, *shape).
            raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
            raw_draws = np.swapaxes(
                raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
            )
        elif item in sampler_vars:
            # method_variables arrays come as (draw, chain); reorder.
            raw_draws = np.swapaxes(sampler_vars[item], 0, 1)
        else:
            raise ValueError(f"fit data, unknown variable: {item}")
        # dtypes.get(item) may be None, which numpy treats as float64.
        raw_draws = raw_draws.astype(dtypes.get(item))
        if save_warmup:
            # Warmup draws precede the sampling draws along the draw axis.
            sample_warmup[item] = raw_draws[:, :num_warmup, ...]
            sample[item] = raw_draws[:, num_warmup:, ...]
        else:
            sample[item] = raw_draws

    return sample, sample_warmup
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
def _unpack_fit_pre_v_1_0_0(fit, items, save_warmup, dtypes):
    """Transform a cmdstanpy < 1.0.0 fit to dictionaries of ndarrays.

    Parameters
    ----------
    fit : cmdstanpy.CmdStanMCMC
    items : list
    save_warmup : bool
    dtypes : dict

    Returns
    -------
    tuple of dict
        ``(sample, sample_warmup)`` mappings; values are formatted to
        shape = (chains, draws, *shape).
    """
    num_warmup = 0
    if save_warmup:
        # Warmup can only be returned if the fit actually stored it.
        if not fit._save_warmup:  # pylint: disable=protected-access
            save_warmup = False
        else:
            num_warmup = fit.num_draws_warmup

    nchains = fit.chains
    # Raw draws come as (draw, chain, column); reorder to (chain, draw, column).
    draws = np.swapaxes(fit.draws(inc_warmup=save_warmup), 0, 1)
    sample = {}
    sample_warmup = {}

    # Pre-1.0.0 fits may or may not carry a `metadata` attribute.
    stan_vars_cols = fit.metadata.stan_vars_cols if hasattr(fit, "metadata") else fit.stan_vars_cols
    sampler_vars_cols = (
        fit.metadata._method_vars_cols  # pylint: disable=protected-access
        if hasattr(fit, "metadata")
        else fit.sampler_vars_cols
    )
    for item in items:
        if item in stan_vars_cols:
            # NOTE(review): col_idxs is assigned here but not used in this
            # branch; raw_draws is taken from stan_variable instead.
            col_idxs = stan_vars_cols[item]
            raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
            raw_draws = np.swapaxes(
                raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
            )
        elif item in sampler_vars_cols:
            # Sampler variables are scalar per draw; take their single column.
            col_idxs = sampler_vars_cols[item]
            raw_draws = draws[..., col_idxs[0]]
        else:
            raise ValueError(f"fit data, unknown variable: {item}")
        # dtypes.get(item) may be None, which numpy treats as float64.
        raw_draws = raw_draws.astype(dtypes.get(item))
        if save_warmup:
            # Warmup draws precede the sampling draws along the draw axis.
            sample_warmup[item] = raw_draws[:, :num_warmup, ...]
            sample[item] = raw_draws[:, num_warmup:, ...]
        else:
            sample[item] = raw_draws

    return sample, sample_warmup
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
def _unpack_frame(fit, columns, valid_cols, save_warmup, dtypes):
    """Transform fit to dictionaries containing ndarrays.

    Called when fit object created by cmdstanpy version < 0.9.68.

    Parameters
    ----------
    fit : cmdstanpy.CmdStanMCMC
    columns : list
        All flat column labels of the fit.
    valid_cols : list
        Subset of columns to extract.
    save_warmup : bool
    dtypes : dict

    Returns
    -------
    tuple of dict
        ``(sample, sample_warmup)`` mappings; values are formatted to
        shape = (chains, draws, *shape).
    """
    if save_warmup and not fit._save_warmup:  # pylint: disable=protected-access
        save_warmup = False
    if hasattr(fit, "draws"):
        data = fit.draws(inc_warmup=save_warmup)
        if save_warmup:
            # Split the leading warmup draws off the combined array.
            num_warmup = fit._draws_warmup  # pylint: disable=protected-access
            data_warmup = data[:num_warmup]
            data = data[num_warmup:]
    else:
        data = fit.sample
        if save_warmup:
            # NOTE(review): truncates warmup to the post-warmup draw count;
            # presumably matches the old cmdstanpy layout — confirm.
            data_warmup = fit.warmup[: data.shape[0]]

    # Arrays are (draw, chain, column) at this point.
    draws, chains, *_ = data.shape
    if save_warmup:
        draws_warmup, *_ = data_warmup.shape

    column_groups = defaultdict(list)
    column_locs = defaultdict(list)
    # iterate flat column names
    for i, col in enumerate(columns):
        if "." in col:
            # parse parameter names e.g. X.1.2 --> X, (1,2)
            col_base, *col_tail = col.split(".")
        else:
            # parse parameter names e.g. X[1,2] --> X, (1,2)
            col_base, *col_tail = col.replace("]", "").replace("[", ",").split(",")
        if len(col_tail):
            # gather nD array locations
            column_groups[col_base].append(tuple(map(int, col_tail)))
        # gather raw data locations for each parameter
        column_locs[col_base].append(i)
    # gather parameter dimensions (assumes dense arrays, so the max index
    # along each axis is the axis length)
    dims = {
        colname: tuple(np.array(col_dims).max(0)) for colname, col_dims in column_groups.items()
    }
    sample = {}
    sample_warmup = {}
    valid_base_cols = []
    # get list of parameters for extraction (basename) X.1.2 --> X
    for col in valid_cols:
        base_col = col.split("[")[0].split(".")[0]
        if base_col not in valid_base_cols:
            valid_base_cols.append(base_col)

    # extract each wanted parameter to ndarray with correct shape
    for key in valid_base_cols:
        ndim = dims.get(key, None)
        shape_location = column_groups.get(key, None)
        if ndim is not None:
            # Pre-allocate with NaN; element columns are filled in below.
            sample[key] = np.full((chains, draws, *ndim), np.nan)
            if save_warmup:
                sample_warmup[key] = np.full((chains, draws_warmup, *ndim), np.nan)
        if shape_location is None:
            # Scalar parameter: exactly one source column.
            # reorder draw, chain -> chain, draw
            (i,) = column_locs[key]
            sample[key] = np.swapaxes(data[..., i], 0, 1)
            if save_warmup:
                sample_warmup[key] = np.swapaxes(data_warmup[..., i], 0, 1)
        else:
            for i, shape_loc in zip(column_locs[key], shape_location):
                # location to insert extracted array (Stan indices are
                # 1-based, numpy's are 0-based)
                shape_loc = tuple([Ellipsis] + [j - 1 for j in shape_loc])
                # reorder draw, chain -> chain, draw and insert to ndarray
                sample[key][shape_loc] = np.swapaxes(data[..., i], 0, 1)
                if save_warmup:
                    sample_warmup[key][shape_loc] = np.swapaxes(data_warmup[..., i], 0, 1)

    # Apply user-requested dtypes after assembly.
    for key, dtype in dtypes.items():
        if key in sample:
            sample[key] = sample[key].astype(dtype)
        if save_warmup and key in sample_warmup:
            sample_warmup[key] = sample_warmup[key].astype(dtype)
    return sample, sample_warmup
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
def from_cmdstanpy(
    posterior=None,
    *,
    posterior_predictive=None,
    predictions=None,
    prior=None,
    prior_predictive=None,
    observed_data=None,
    constant_data=None,
    predictions_constant_data=None,
    log_likelihood=None,
    index_origin=None,
    coords=None,
    dims=None,
    save_warmup=None,
    dtypes=None,
):
    """Convert CmdStanPy data into an InferenceData object.

    For a usage example read the
    :ref:`Creating InferenceData section on from_cmdstanpy <creating_InferenceData>`

    Parameters
    ----------
    posterior : CmdStanMCMC object
        CmdStanPy CmdStanMCMC
    posterior_predictive : str, list of str
        Posterior predictive samples for the fit.
    predictions : str, list of str
        Out of sample prediction samples for the fit.
    prior : CmdStanMCMC
        CmdStanPy CmdStanMCMC
    prior_predictive : str, list of str
        Prior predictive samples for the fit.
    observed_data : dict
        Observed data used in the sampling.
    constant_data : dict
        Constant data used in the sampling.
    predictions_constant_data : dict
        Constant data for predictions used in the sampling.
    log_likelihood : str, list of str, dict of {str: str}, optional
        Pointwise log_likelihood for the data. If a dict, its keys should represent var_names
        from the corresponding observed data and its values the stan variable where the
        data is stored. By default, if a variable ``log_lik`` is present in the Stan model,
        it will be retrieved as pointwise log likelihood values. Use ``False``
        or set ``data.log_likelihood`` to false to avoid this behaviour.
    index_origin : int, optional
        Starting value of integer coordinate values. Defaults to the value in rcParam
        ``data.index_origin``.
    coords : dict of str or dict of iterable
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict of str or list of str
        A mapping from variables to a list of coordinate names for the variable.
    save_warmup : bool
        Save warmup iterations into InferenceData object, if found in the input files.
        If not defined, use default defined by the rcParams.
    dtypes : dict or str or cmdstanpy.CmdStanModel
        A dictionary containing dtype information (int, float) for parameters.
        If input is a string, it is assumed to be a model code or path to model code file.
        Model code can be extracted from a cmdstanpy.CmdStanModel object.

    Returns
    -------
    InferenceData object
    """
    # Thin convenience wrapper: all conversion logic lives in CmdStanPyConverter.
    return CmdStanPyConverter(
        posterior=posterior,
        posterior_predictive=posterior_predictive,
        predictions=predictions,
        prior=prior,
        prior_predictive=prior_predictive,
        observed_data=observed_data,
        constant_data=constant_data,
        predictions_constant_data=predictions_constant_data,
        log_likelihood=log_likelihood,
        index_origin=index_origin,
        coords=coords,
        dims=dims,
        save_warmup=save_warmup,
        dtypes=dtypes,
    ).to_inference_data()
|