arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff shows the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- arviz/__init__.py +52 -357
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.1.dist-info/METADATA +0 -263
- arviz-0.23.1.dist-info/RECORD +0 -183
- arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/data/io_cmdstan.py
DELETED
@@ -1,1036 +0,0 @@
-# pylint: disable=too-many-lines
-"""CmdStan-specific conversion code."""
-try:
-    import ujson as json
-except ImportError:
-    # Can't find ujson using json
-    # mypy struggles with conditional imports expressed as catching ImportError:
-    # https://github.com/python/mypy/issues/1153
-    import json  # type: ignore
-import logging
-import os
-import re
-from collections import defaultdict
-from glob import glob
-from pathlib import Path
-from typing import Dict, List, Optional, Union
-
-import numpy as np
-
-from .. import utils
-from ..rcparams import rcParams
-from .base import CoordSpec, DimSpec, dict_to_dataset, infer_stan_dtypes, requires
-from .inference_data import InferenceData
-
-_log = logging.getLogger(__name__)
-
-
-def check_glob(path, group, disable_glob):
-    """Find files with glob."""
-    if isinstance(path, str) and (not disable_glob):
-        path_glob = glob(path)
-        if path_glob:
-            path = sorted(path_glob)
-            msg = "\n".join(f"{i}: {os.path.normpath(fpath)}" for i, fpath in enumerate(path, 1))
-            len_p = len(path)
-            _log.info("glob found %d files for '%s':\n%s", len_p, group, msg)
-    return path
-
-
-class CmdStanConverter:
-    """Encapsulate CmdStan specific logic."""
-
-    # pylint: disable=too-many-instance-attributes
-
-    def __init__(
-        self,
-        *,
-        posterior=None,
-        posterior_predictive=None,
-        predictions=None,
-        prior=None,
-        prior_predictive=None,
-        observed_data=None,
-        observed_data_var=None,
-        constant_data=None,
-        constant_data_var=None,
-        predictions_constant_data=None,
-        predictions_constant_data_var=None,
-        log_likelihood=None,
-        index_origin=None,
-        coords=None,
-        dims=None,
-        disable_glob=False,
-        save_warmup=None,
-        dtypes=None,
-    ):
-        self.posterior_ = check_glob(posterior, "posterior", disable_glob)
-        self.posterior_predictive = check_glob(
-            posterior_predictive, "posterior_predictive", disable_glob
-        )
-        self.predictions = check_glob(predictions, "predictions", disable_glob)
-        self.prior_ = check_glob(prior, "prior", disable_glob)
-        self.prior_predictive = check_glob(prior_predictive, "prior_predictive", disable_glob)
-        self.log_likelihood = check_glob(log_likelihood, "log_likelihood", disable_glob)
-        self.observed_data = observed_data
-        self.observed_data_var = observed_data_var
-        self.constant_data = constant_data
-        self.constant_data_var = constant_data_var
-        self.predictions_constant_data = predictions_constant_data
-        self.predictions_constant_data_var = predictions_constant_data_var
-        self.coords = coords if coords is not None else {}
-        self.dims = dims if dims is not None else {}
-
-        self.posterior = None
-        self.prior = None
-        self.attrs = None
-        self.attrs_prior = None
-
-        self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
-        self.index_origin = index_origin
-
-        if dtypes is None:
-            dtypes = {}
-        elif isinstance(dtypes, str):
-            dtypes_path = Path(dtypes)
-            if dtypes_path.exists():
-                with dtypes_path.open("r", encoding="UTF-8") as f_obj:
-                    model_code = f_obj.read()
-            else:
-                model_code = dtypes
-
-            dtypes = infer_stan_dtypes(model_code)
-
-        self.dtypes = dtypes
-
-        # populate posterior and sample_stats
-        self._parse_posterior()
-        self._parse_prior()
-
-        if (
-            self.log_likelihood is None
-            and self.posterior_ is not None
-            and any(name.split(".")[0] == "log_lik" for name in self.posterior_columns)
-        ):
-            self.log_likelihood = ["log_lik"]
-        elif isinstance(self.log_likelihood, bool):
-            self.log_likelihood = None
-
-    @requires("posterior_")
-    def _parse_posterior(self):
-        """Read csv paths to list of ndarrays."""
-        paths = self.posterior_
-        if isinstance(paths, str):
-            paths = [paths]
-
-        chain_data = []
-        columns = None
-        for path in paths:
-            output_data = _read_output(path)
-            chain_data.append(output_data)
-            if columns is None:
-                columns = output_data
-
-        self.posterior = (
-            [item["sample"] for item in chain_data],
-            [item["sample_warmup"] for item in chain_data],
-        )
-        self.posterior_columns = columns["sample_columns"]
-        self.sample_stats_columns = columns["sample_stats_columns"]
-
-        attrs = {}
-        for item in chain_data:
-            for key, value in item["configuration_info"].items():
-                if key not in attrs:
-                    attrs[key] = []
-                attrs[key].append(value)
-        self.attrs = attrs
-
-    @requires("prior_")
-    def _parse_prior(self):
-        """Read csv paths to list of ndarrays."""
-        paths = self.prior_
-        if isinstance(paths, str):
-            paths = [paths]
-
-        chain_data = []
-        columns = None
-        for path in paths:
-            output_data = _read_output(path)
-            chain_data.append(output_data)
-            if columns is None:
-                columns = output_data
-
-        self.prior = (
-            [item["sample"] for item in chain_data],
-            [item["sample_warmup"] for item in chain_data],
-        )
-        self.prior_columns = columns["sample_columns"]
-        self.sample_stats_prior_columns = columns["sample_stats_columns"]
-
-        attrs = {}
-        for item in chain_data:
-            for key, value in item["configuration_info"].items():
-                if key not in attrs:
-                    attrs[key] = []
-                attrs[key].append(value)
-        self.attrs_prior = attrs
-
-    @requires("posterior")
-    def posterior_to_xarray(self):
-        """Extract posterior samples from output csv."""
-        columns = self.posterior_columns
-
-        # filter posterior_predictive, predictions and log_likelihood
-        posterior_predictive = self.posterior_predictive
-        if posterior_predictive is None or (
-            isinstance(posterior_predictive, str) and posterior_predictive.lower().endswith(".csv")
-        ):
-            posterior_predictive = []
-        elif isinstance(posterior_predictive, str):
-            posterior_predictive = [
-                col for col in columns if posterior_predictive == col.split(".")[0]
-            ]
-        else:
-            posterior_predictive = [
-                col
-                for col in columns
-                if any(item == col.split(".")[0] for item in posterior_predictive)
-            ]
-
-        predictions = self.predictions
-        if predictions is None or (
-            isinstance(predictions, str) and predictions.lower().endswith(".csv")
-        ):
-            predictions = []
-        elif isinstance(predictions, str):
-            predictions = [col for col in columns if predictions == col.split(".")[0]]
-        else:
-            predictions = [
-                col for col in columns if any(item == col.split(".")[0] for item in predictions)
-            ]
-
-        log_likelihood = self.log_likelihood
-        if log_likelihood is None or (
-            isinstance(log_likelihood, str) and log_likelihood.lower().endswith(".csv")
-        ):
-            log_likelihood = []
-        elif isinstance(log_likelihood, str):
-            log_likelihood = [col for col in columns if log_likelihood == col.split(".")[0]]
-        elif isinstance(log_likelihood, dict):
-            log_likelihood = [
-                col
-                for col in columns
-                if any(item == col.split(".")[0] for item in log_likelihood.values())
-            ]
-        else:
-            log_likelihood = [
-                col for col in columns if any(item == col.split(".")[0] for item in log_likelihood)
-            ]
-
-        invalid_cols = posterior_predictive + predictions + log_likelihood
-        valid_cols = {col: idx for col, idx in columns.items() if col not in invalid_cols}
-        data = _unpack_ndarrays(self.posterior[0], valid_cols, self.dtypes)
-        data_warmup = _unpack_ndarrays(self.posterior[1], valid_cols, self.dtypes)
-        return (
-            dict_to_dataset(
-                data,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=self.attrs,
-                index_origin=self.index_origin,
-            ),
-            dict_to_dataset(
-                data_warmup,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=self.attrs,
-                index_origin=self.index_origin,
-            ),
-        )
-
-    @requires("posterior")
-    @requires("sample_stats_columns")
-    def sample_stats_to_xarray(self):
-        """Extract sample_stats from fit."""
-        dtypes = {"diverging": bool, "n_steps": np.int64, "tree_depth": np.int64, **self.dtypes}
-        rename_dict = {
-            "divergent": "diverging",
-            "n_leapfrog": "n_steps",
-            "treedepth": "tree_depth",
-            "stepsize": "step_size",
-            "accept_stat": "acceptance_rate",
-        }
-
-        columns_new = {}
-        for key, idx in self.sample_stats_columns.items():
-            name = re.sub("__$", "", key)
-            name = rename_dict.get(name, name)
-            columns_new[name] = idx
-
-        data = _unpack_ndarrays(self.posterior[0], columns_new, dtypes)
-        data_warmup = _unpack_ndarrays(self.posterior[1], columns_new, dtypes)
-        return (
-            dict_to_dataset(
-                data,
-                coords=self.coords,
-                dims=self.dims,
-                attrs={item: key for key, item in rename_dict.items()},
-                index_origin=self.index_origin,
-            ),
-            dict_to_dataset(
-                data_warmup,
-                coords=self.coords,
-                dims=self.dims,
-                attrs={item: key for key, item in rename_dict.items()},
-                index_origin=self.index_origin,
-            ),
-        )
-
-    @requires("posterior")
-    @requires("posterior_predictive")
-    def posterior_predictive_to_xarray(self):
-        """Convert posterior_predictive samples to xarray."""
-        posterior_predictive = self.posterior_predictive
-
-        if (
-            isinstance(posterior_predictive, (tuple, list))
-            and posterior_predictive[0].endswith(".csv")
-        ) or (isinstance(posterior_predictive, str) and posterior_predictive.endswith(".csv")):
-            if isinstance(posterior_predictive, str):
-                posterior_predictive = [posterior_predictive]
-            chain_data = []
-            chain_data_warmup = []
-            columns = None
-            attrs = {}
-            for path in posterior_predictive:
-                parsed_output = _read_output(path)
-                chain_data.append(parsed_output["sample"])
-                chain_data_warmup.append(parsed_output["sample_warmup"])
-                if columns is None:
-                    columns = parsed_output["sample_columns"]
-
-                for key, value in parsed_output["configuration_info"].items():
-                    if key not in attrs:
-                        attrs[key] = []
-                    attrs[key].append(value)
-
-            data = _unpack_ndarrays(chain_data, columns, self.dtypes)
-            data_warmup = _unpack_ndarrays(chain_data_warmup, columns, self.dtypes)
-
-        else:
-            if isinstance(posterior_predictive, str):
-                posterior_predictive = [posterior_predictive]
-            columns = {
-                col: idx
-                for col, idx in self.posterior_columns.items()
-                if any(item == col.split(".")[0] for item in posterior_predictive)
-            }
-            data = _unpack_ndarrays(self.posterior[0], columns, self.dtypes)
-            data_warmup = _unpack_ndarrays(self.posterior[1], columns, self.dtypes)
-
-            attrs = None
-        return (
-            dict_to_dataset(
-                data,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=attrs,
-                index_origin=self.index_origin,
-            ),
-            dict_to_dataset(
-                data_warmup,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=attrs,
-                index_origin=self.index_origin,
-            ),
-        )
-
-    @requires("posterior")
-    @requires("predictions")
-    def predictions_to_xarray(self):
-        """Convert out of sample predictions samples to xarray."""
-        predictions = self.predictions
-
-        if (isinstance(predictions, (tuple, list)) and predictions[0].endswith(".csv")) or (
-            isinstance(predictions, str) and predictions.endswith(".csv")
-        ):
-            if isinstance(predictions, str):
-                predictions = [predictions]
-            chain_data = []
-            chain_data_warmup = []
-            columns = None
-            attrs = {}
-            for path in predictions:
-                parsed_output = _read_output(path)
-                chain_data.append(parsed_output["sample"])
-                chain_data_warmup.append(parsed_output["sample_warmup"])
-                if columns is None:
-                    columns = parsed_output["sample_columns"]
-
-                for key, value in parsed_output["configuration_info"].items():
-                    if key not in attrs:
-                        attrs[key] = []
-                    attrs[key].append(value)
-
-            data = _unpack_ndarrays(chain_data, columns, self.dtypes)
-            data_warmup = _unpack_ndarrays(chain_data_warmup, columns, self.dtypes)
-        else:
-            if isinstance(predictions, str):
-                predictions = [predictions]
-            columns = {
-                col: idx
-                for col, idx in self.posterior_columns.items()
-                if any(item == col.split(".")[0] for item in predictions)
-            }
-            data = _unpack_ndarrays(self.posterior[0], columns, self.dtypes)
-            data_warmup = _unpack_ndarrays(self.posterior[1], columns, self.dtypes)
-
-            attrs = None
-        return (
-            dict_to_dataset(
-                data,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=attrs,
-                index_origin=self.index_origin,
-            ),
-            dict_to_dataset(
-                data_warmup,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=attrs,
-                index_origin=self.index_origin,
-            ),
-        )
-
-    @requires("posterior")
-    @requires("log_likelihood")
-    def log_likelihood_to_xarray(self):
-        """Convert elementwise log_likelihood samples to xarray."""
-        log_likelihood = self.log_likelihood
-
-        if (isinstance(log_likelihood, (tuple, list)) and log_likelihood[0].endswith(".csv")) or (
-            isinstance(log_likelihood, str) and log_likelihood.endswith(".csv")
-        ):
-            if isinstance(log_likelihood, str):
-                log_likelihood = [log_likelihood]
-
-            chain_data = []
-            chain_data_warmup = []
-            columns = None
-            attrs = {}
-            for path in log_likelihood:
-                parsed_output = _read_output(path)
-                chain_data.append(parsed_output["sample"])
-                chain_data_warmup.append(parsed_output["sample_warmup"])
-
-                if columns is None:
-                    columns = parsed_output["sample_columns"]
-
-                for key, value in parsed_output["configuration_info"].items():
-                    if key not in attrs:
-                        attrs[key] = []
-                    attrs[key].append(value)
-            data = _unpack_ndarrays(chain_data, columns, self.dtypes)
-            data_warmup = _unpack_ndarrays(chain_data_warmup, columns, self.dtypes)
-        else:
-            if isinstance(log_likelihood, dict):
-                log_lik_to_obs_name = {v: k for k, v in log_likelihood.items()}
-                columns = {
-                    col.replace(col_name, log_lik_to_obs_name[col_name]): idx
-                    for col, col_name, idx in (
-                        (col, col.split(".")[0], idx) for col, idx in self.posterior_columns.items()
-                    )
-                    if any(item == col_name for item in log_likelihood.values())
-                }
-            else:
-                if isinstance(log_likelihood, str):
-                    log_likelihood = [log_likelihood]
-                columns = {
-                    col: idx
-                    for col, idx in self.posterior_columns.items()
-                    if any(item == col.split(".")[0] for item in log_likelihood)
-                }
-            data = _unpack_ndarrays(self.posterior[0], columns, self.dtypes)
-            data_warmup = _unpack_ndarrays(self.posterior[1], columns, self.dtypes)
-            attrs = None
-        return (
-            dict_to_dataset(
-                data,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=attrs,
-                index_origin=self.index_origin,
-                skip_event_dims=True,
-            ),
-            dict_to_dataset(
-                data_warmup,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=attrs,
-                index_origin=self.index_origin,
-                skip_event_dims=True,
-            ),
-        )
-
-    @requires("prior")
-    def prior_to_xarray(self):
-        """Convert prior samples to xarray."""
-        # filter prior_predictive
-        prior_predictive = self.prior_predictive
-
-        columns = self.prior_columns
-
-        if prior_predictive is None or (
-            isinstance(prior_predictive, str) and prior_predictive.lower().endswith(".csv")
-        ):
-            prior_predictive = []
-        elif isinstance(prior_predictive, str):
-            prior_predictive = [col for col in columns if prior_predictive == col.split(".")[0]]
-        else:
-            prior_predictive = [
-                col
-                for col in columns
-                if any(item == col.split(".")[0] for item in prior_predictive)
-            ]
-
-        invalid_cols = prior_predictive
-        valid_cols = {col: idx for col, idx in columns.items() if col not in invalid_cols}
-        data = _unpack_ndarrays(self.prior[0], valid_cols, self.dtypes)
-        data_warmup = _unpack_ndarrays(self.prior[1], valid_cols, self.dtypes)
-        return (
-            dict_to_dataset(
-                data,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=self.attrs_prior,
-                index_origin=self.index_origin,
-            ),
-            dict_to_dataset(
-                data_warmup,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=self.attrs_prior,
-                index_origin=self.index_origin,
-            ),
-        )
-
-    @requires("prior")
-    @requires("sample_stats_prior_columns")
-    def sample_stats_prior_to_xarray(self):
-        """Extract sample_stats from fit."""
-        dtypes = {"diverging": bool, "n_steps": np.int64, "tree_depth": np.int64, **self.dtypes}
-        rename_dict = {
-            "divergent": "diverging",
-            "n_leapfrog": "n_steps",
-            "treedepth": "tree_depth",
-            "stepsize": "step_size",
-            "accept_stat": "acceptance_rate",
-        }
-
-        columns_new = {}
-        for key, idx in self.sample_stats_prior_columns.items():
-            name = re.sub("__$", "", key)
-            name = rename_dict.get(name, name)
-            columns_new[name] = idx
-
-        data = _unpack_ndarrays(self.posterior[0], columns_new, dtypes)
-        data_warmup = _unpack_ndarrays(self.posterior[1], columns_new, dtypes)
-        return (
-            dict_to_dataset(
-                data,
-                coords=self.coords,
-                dims=self.dims,
-                attrs={item: key for key, item in rename_dict.items()},
-                index_origin=self.index_origin,
-            ),
-            dict_to_dataset(
-                data_warmup,
-                coords=self.coords,
-                dims=self.dims,
-                attrs={item: key for key, item in rename_dict.items()},
-                index_origin=self.index_origin,
-            ),
-        )
-
-    @requires("prior")
-    @requires("prior_predictive")
-    def prior_predictive_to_xarray(self):
-        """Convert prior_predictive samples to xarray."""
-        prior_predictive = self.prior_predictive
-
-        if (
-            isinstance(prior_predictive, (tuple, list)) and prior_predictive[0].endswith(".csv")
-        ) or (isinstance(prior_predictive, str) and prior_predictive.endswith(".csv")):
-            if isinstance(prior_predictive, str):
-                prior_predictive = [prior_predictive]
-            chain_data = []
-            chain_data_warmup = []
-            columns = None
-            attrs = {}
-            for path in prior_predictive:
-                parsed_output = _read_output(path)
-                chain_data.append(parsed_output["sample"])
-                chain_data_warmup.append(parsed_output["sample_warmup"])
-                if columns is None:
-                    columns = parsed_output["sample_columns"]
-                for key, value in parsed_output["configuration_info"].items():
-                    if key not in attrs:
-                        attrs[key] = []
-                    attrs[key].append(value)
-            data = _unpack_ndarrays(chain_data, columns, self.dtypes)
-            data_warmup = _unpack_ndarrays(chain_data_warmup, columns, self.dtypes)
-        else:
-            if isinstance(prior_predictive, str):
-                prior_predictive = [prior_predictive]
-            columns = {
-                col: idx
-                for col, idx in self.prior_columns.items()
-                if any(item == col.split(".")[0] for item in prior_predictive)
-            }
-            data = _unpack_ndarrays(self.prior[0], columns, self.dtypes)
-            data_warmup = _unpack_ndarrays(self.prior[1], columns, self.dtypes)
-            attrs = None
-        return (
-            dict_to_dataset(
-                data,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=attrs,
-                index_origin=self.index_origin,
-            ),
-            dict_to_dataset(
-                data_warmup,
-                coords=self.coords,
-                dims=self.dims,
-                attrs=attrs,
-                index_origin=self.index_origin,
-            ),
-        )
-
-    @requires("observed_data")
-    def observed_data_to_xarray(self):
-        """Convert observed data to xarray."""
-        observed_data_raw = _read_data(self.observed_data)
-        variables = self.observed_data_var
-        if isinstance(variables, str):
-            variables = [variables]
-        observed_data = {
-            key: utils.one_de(vals)
-            for key, vals in observed_data_raw.items()
-            if variables is None or key in variables
-        }
-        return dict_to_dataset(
-            observed_data,
-            coords=self.coords,
-            dims=self.dims,
-            default_dims=[],
-            index_origin=self.index_origin,
-        )
-
-    @requires("constant_data")
-    def constant_data_to_xarray(self):
-        """Convert constant data to xarray."""
-        constant_data_raw = _read_data(self.constant_data)
-        variables = self.constant_data_var
-        if isinstance(variables, str):
-            variables = [variables]
-        constant_data = {
-            key: utils.one_de(vals)
-            for key, vals in constant_data_raw.items()
-            if variables is None or key in variables
-        }
-        return dict_to_dataset(
-            constant_data,
-            coords=self.coords,
-            dims=self.dims,
-            default_dims=[],
-            index_origin=self.index_origin,
-        )
-
-    @requires("predictions_constant_data")
-    def predictions_constant_data_to_xarray(self):
-        """Convert predictions constant data to xarray."""
-        predictions_constant_data_raw = _read_data(self.predictions_constant_data)
-        variables = self.predictions_constant_data_var
-        if isinstance(variables, str):
-            variables = [variables]
-        predictions_constant_data = {}
-        for key, vals in predictions_constant_data_raw.items():
-            if variables is not None and key not in variables:
-                continue
-            vals = utils.one_de(vals)
-            predictions_constant_data[key] = utils.one_de(vals)
-        return dict_to_dataset(
-            predictions_constant_data,
-            coords=self.coords,
-            dims=self.dims,
-            default_dims=[],
-            index_origin=self.index_origin,
-        )
-
-    def to_inference_data(self):
-        """Convert all available data to an InferenceData object.
-
-        Note that if groups can not be created (i.e., there is no `output`, so
-        the `posterior` and `sample_stats` can not be extracted), then the InferenceData
-        will not have those groups.
-        """
-        return InferenceData(
-            save_warmup=self.save_warmup,
-            **{
-                "posterior": self.posterior_to_xarray(),
-                "sample_stats": self.sample_stats_to_xarray(),
-                "log_likelihood": self.log_likelihood_to_xarray(),
-                "posterior_predictive": self.posterior_predictive_to_xarray(),
-                "prior": self.prior_to_xarray(),
-                "sample_stats_prior": self.sample_stats_prior_to_xarray(),
-                "prior_predictive": self.prior_predictive_to_xarray(),
-                "observed_data": self.observed_data_to_xarray(),
-                "constant_data": self.constant_data_to_xarray(),
-                "predictions": self.predictions_to_xarray(),
-                "predictions_constant_data": self.predictions_constant_data_to_xarray(),
-            },
-        )
-
-
-def _process_configuration(comments):
-    """Extract sampling information."""
-    results = {
-        "comments": "\n".join(comments),
-        "stan_version": {},
-    }
-
-    comments_gen = iter(comments)
-
-    for comment in comments_gen:
-        comment = re.sub(r"^\s*#\s*|\s*\(Default\)\s*$", "", comment).strip()
-        if comment.startswith("stan_version_"):
-            key, val = re.sub(r"^\s*stan_version_", "", comment).split("=")
-            results["stan_version"][key.strip()] = val.strip()
-        elif comment.startswith("Step size"):
-            _, val = comment.split("=")
-            results["step_size"] = float(val.strip())
-        elif "inverse mass matrix" in comment:
-            comment = re.sub(r"^\s*#\s*", "", next(comments_gen)).strip()
-            results["inverse_mass_matrix"] = [float(item) for item in comment.split(",")]
-        elif ("seconds" in comment) and any(
-            item in comment for item in ("(Warm-up)", "(Sampling)", "(Total)")
-        ):
-            value = re.sub(
-                (
-                    r"^Elapsed\s*Time:\s*|"
-                    r"\s*seconds\s*\(Warm-up\)\s*|"
-                    r"\s*seconds\s*\(Sampling\)\s*|"
-                    r"\s*seconds\s*\(Total\)\s*"
-                ),
-                "",
-                comment,
-            )
-            key = (
-                "warmup_time_seconds"
-                if "(Warm-up)" in comment
-                else "sampling_time_seconds" if "(Sampling)" in comment else "total_time_seconds"
-            )
-            results[key] = float(value)
-        elif "=" in comment:
-            match_int = re.search(r"^(\S+)\s*=\s*([-+]?[0-9]+)$", comment)
-            match_float = re.search(r"^(\S+)\s*=\s*([-+]?[0-9]+\.[0-9]+)$", comment)
-            match_str_bool = re.search(r"^(\S+)\s*=\s*(true|false)$", comment)
-            match_str = re.search(r"^(\S+)\s*=\s*(\S+)$", comment)
-            match_empty = re.search(r"^(\S+)\s*=\s*$", comment)
-            if match_int:
-                key, value = match_int.group(1), match_int.group(2)
-                results[key] = int(value)
-            elif match_float:
-                key, value = match_float.group(1), match_float.group(2)
-                results[key] = float(value)
-            elif match_str_bool:
-                key, value = match_str_bool.group(1), match_str_bool.group(2)
-                results[key] = int(value == "true")
-            elif match_str:
-                key, value = match_str.group(1), match_str.group(2)
-                results[key] = value
-            elif match_empty:
-                key = match_empty.group(1)
-                results[key] = None
-
-    results = {key: str(results[key]) for key in sorted(results)}
-    return results
-
-
-def _read_output_file(path):
-    """Read Stan csv file to ndarray."""
-    comments = []
-    data = []
-    columns = None
-    with open(path, "rb") as f_obj:
-        # read header
-        for line in f_obj:
-            if line.startswith(b"#"):
-                comments.append(line.strip().decode("utf-8"))
-                continue
-            columns = {key: idx for idx, key in enumerate(line.strip().decode("utf-8").split(","))}
-            break
-        # read data
-        for line in f_obj:
-            line = line.strip()
-            if line.startswith(b"#"):
-                comments.append(line.decode("utf-8"))
-                continue
-            if line:
-                data.append(np.array(line.split(b","), dtype=np.float64))
-
-    return columns, np.array(data, dtype=np.float64), comments
-
-
-def _read_output(path):
-    """Read CmdStan output csv file.
-
-    Parameters
-    ----------
-    path : str
-
-    Returns
-    -------
-    Dict[str, Any]
-    """
-    # Read data
-    columns, data, comments = _read_output_file(path)
-
-    pconf = _process_configuration(comments)
-
-    # split dataframe to warmup and draws
-    saved_warmup = (
-        int(pconf.get("save_warmup", 0))
-        * int(pconf.get("num_warmup", 0))
-        // int(pconf.get("thin", 1))
-    )
-
-    data_warmup = data[:saved_warmup]
-    data = data[saved_warmup:]
-
-    # Split data to sample_stats and sample
-    sample_stats_columns = {col: idx for col, idx in columns.items() if col.endswith("__")}
-    sample_columns = {col: idx for col, idx in columns.items() if col not in sample_stats_columns}
-
-    return {
-        "sample": data,
-        "sample_warmup": data_warmup,
-        "sample_columns": sample_columns,
-        "sample_stats_columns": sample_stats_columns,
-        "configuration_info": pconf,
-    }
-
-
-def _process_data_var(string):
-    """Transform datastring to key, values pair.
-
-    All values are transformed to floating point values.
-
-    Parameters
-    ----------
-    string : str
-
-    Returns
-    -------
-    Tuple[Str, Str]
-        key, values pair
-    """
-    key, var = string.split("<-")
-    if "structure" in var:
-        var, dim = var.replace("structure(", "").replace(",", "").split(".Dim")
-        # dtype = int if '.' not in var and 'e' not in var.lower() else float
-        dtype = float
-        var = var.replace("c(", "").replace(")", "").strip().split()
-        dim = dim.replace("=", "").replace("c(", "").replace(")", "").strip().split()
-        dim = tuple(map(int, dim))
-        var = np.fromiter(map(dtype, var), dtype).reshape(dim, order="F")
-    elif "c(" in var:
-        # dtype = int if '.' not in var and 'e' not in var.lower() else float
-        dtype = float
-        var = var.replace("c(", "").replace(")", "").split(",")
-        var = np.fromiter(map(dtype, var), dtype)
-    else:
-        # dtype = int if '.' not in var and 'e' not in var.lower() else float
-        dtype = float
-        var = dtype(var)
-    return key.strip(), var
-
-
-def _read_data(path):
-    """Read Rdump or JSON output to dictionary.
-
-    Parameters
-    ----------
-    path : str
-
-    Returns
-    -------
-    Dict
-        key, values pairs from Rdump/JSON formatted data.
-    """
-    data = {}
-    with open(path, "r", encoding="utf8") as f_obj:
-        if path.lower().endswith(".json"):
-            return json.load(f_obj)
-        var = ""
-        for line in f_obj:
-            if "<-" in line:
-                if len(var):
-                    key, var = _process_data_var(var)
-                    data[key] = var
-                var = ""
-            var += f" {line.strip()}"
-        if len(var):
-            key, var = _process_data_var(var)
-            data[key] = var
-    return data
-
-
-def _unpack_ndarrays(arrays, columns, dtypes=None):
-    """Transform a list of ndarrays to dictionary containing ndarrays.
-
-    Parameters
-    ----------
-    arrays : List[np.ndarray]
-    columns: Dict[str, int]
-    dtypes: Dict[str, Any]
-
-    Returns
-    -------
-    Dict
-        key, values pairs. Values are formatted to shape = (nchain, ndraws, *shape)
-    """
-    col_groups = defaultdict(list)
-    for col, col_idx in columns.items():
-        key, *loc = col.split(".")
-        loc = tuple(int(i) - 1 for i in loc)
-        col_groups[key].append((col_idx, loc))
-
-    chains = len(arrays)
-    draws = len(arrays[0])
-    sample = {}
-    if draws:
-        for key, cols_locs in col_groups.items():
-            ndim = np.array([loc for _, loc in cols_locs]).max(0) + 1
-            dtype = dtypes.get(key, np.float64)
-            sample[key] = np.zeros((chains, draws, *ndim), dtype=dtype)
-            for col, loc in cols_locs:
-                for chain_id, arr in enumerate(arrays):
-                    draw = arr[:, col]
-                    if loc == ():
-                        sample[key][chain_id, :] = draw
-                    else:
-                        axis1_all = range(sample[key].shape[1])
-                        slicer = (chain_id, axis1_all, *loc)
-                        sample[key][slicer] = draw
-    return sample
-
-
-def from_cmdstan(
-    posterior: Optional[Union[str, List[str]]] = None,
-    *,
-    posterior_predictive: Optional[Union[str, List[str]]] = None,
-    predictions: Optional[Union[str, List[str]]] = None,
-    prior: Optional[Union[str, List[str]]] = None,
-    prior_predictive: Optional[Union[str, List[str]]] = None,
-    observed_data: Optional[str] = None,
-    observed_data_var: Optional[Union[str, List[str]]] = None,
-    constant_data: Optional[str] = None,
-    constant_data_var: Optional[Union[str, List[str]]] = None,
-    predictions_constant_data: Optional[str] = None,
-    predictions_constant_data_var: Optional[Union[str, List[str]]] = None,
-    log_likelihood: Optional[Union[str, List[str]]] = None,
-    index_origin: Optional[int] = None,
-    coords: Optional[CoordSpec] = None,
-    dims: Optional[DimSpec] = None,
-    disable_glob: Optional[bool] = False,
-    save_warmup: Optional[bool] = None,
-    dtypes: Optional[Dict] = None,
-) -> InferenceData:
-    """Convert CmdStan data into an InferenceData object.
-
-    For a usage example read the
-    :ref:`Creating InferenceData section on from_cmdstan <creating_InferenceData>`
-
-    Parameters
-    ----------
-    posterior : str or list of str, optional
-        List of paths to output.csv files.
-    posterior_predictive : str or list of str, optional
-        Posterior predictive samples for the fit. If endswith ".csv" assumes file.
-    predictions : str or list of str, optional
-        Out of sample predictions samples for the fit. If endswith ".csv" assumes file.
-    prior : str or list of str, optional
-        List of paths to output.csv files
-    prior_predictive : str or list of str, optional
-        Prior predictive samples for the fit. If endswith ".csv" assumes file.
-    observed_data : str, optional
-        Observed data used in the sampling. Path to data file in Rdump or JSON format.
-    observed_data_var : str or list of str, optional
-        Variable(s) used for slicing observed_data. If not defined, all
-        data variables are imported.
-    constant_data : str, optional
-        Constant data used in the sampling. Path to data file in Rdump or JSON format.
-    constant_data_var : str or list of str, optional
-        Variable(s) used for slicing constant_data. If not defined, all
-        data variables are imported.
-    predictions_constant_data : str, optional
-        Constant data for predictions used in the sampling.
-        Path to data file in Rdump or JSON format.
-    predictions_constant_data_var : str or list of str, optional
-        Variable(s) used for slicing predictions_constant_data.
-        If not defined, all data variables are imported.
-    log_likelihood : dict of {str: str}, list of str or str, optional
-        Pointwise log_likelihood for the data. log_likelihood is extracted from the
-        posterior. It is recommended to use this argument as a dictionary whose keys
-        are observed variable names and its values are the variables storing log
-        likelihood arrays in the Stan code. In other cases, a dictionary with keys
-        equal to its values is used. By default, if a variable ``log_lik`` is
-        present in the Stan model, it will be retrieved as pointwise log
-        likelihood values. Use ``False`` to avoid this behaviour.
-    index_origin : int, optional
-        Starting value of integer coordinate values. Defaults to the value in rcParam
-        ``data.index_origin``.
-    coords : dict of {str: array_like}, optional
-        A dictionary containing the values that are used as index. The key
-        is the name of the dimension, the values are the index values.
-    dims : dict of {str: list of str}, optional
-        A mapping from variables to a list of coordinate names for the variable.
-    disable_glob : bool
-        Don't use glob for string input. This means that all string input is
-        assumed to be variable names (samples) or a path (data).
-    save_warmup : bool
-        Save warmup iterations into InferenceData object, if found in the input files.
-        If not defined, use default defined by the rcParams.
-    dtypes : dict or str
-        A dictionary containing dtype information (int, float) for parameters.
-        If input is a string, it is assumed to be a model code or path to model code file.
-
-    Returns
-    -------
-    InferenceData object
-    """
-    return CmdStanConverter(
-        posterior=posterior,
-        posterior_predictive=posterior_predictive,
-        predictions=predictions,
-        prior=prior,
-        prior_predictive=prior_predictive,
-        observed_data=observed_data,
-        observed_data_var=observed_data_var,
-        constant_data=constant_data,
-        constant_data_var=constant_data_var,
-        predictions_constant_data=predictions_constant_data,
-        predictions_constant_data_var=predictions_constant_data_var,
-        log_likelihood=log_likelihood,
-        index_origin=index_origin,
-        coords=coords,
-        dims=dims,
-        disable_glob=disable_glob,
-        save_warmup=save_warmup,
-        dtypes=dtypes,
-    ).to_inference_data()
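The `from_cmdstan` function above was this module's public entry point in 0.23.1 and is removed along with the module in 1.0.0rc0. For reference, a minimal usage sketch of the deleted converter; the file paths, the "school" coordinate, and the `y`/`log_lik` variable names are hypothetical placeholders, not taken from this diff:

import arviz as az  # arviz 0.23.1; from_cmdstan no longer exists in 1.0.0rc0

idata = az.from_cmdstan(
    posterior="output_*.csv",            # glob over per-chain CmdStan CSV files
    observed_data="data.json",           # data file in Rdump or JSON format
    log_likelihood={"y": "log_lik"},     # observed variable -> log-likelihood variable
    coords={"school": ["A", "B", "C"]},  # index values for the "school" dimension
    dims={"y": ["school"]},              # coordinate names for each variable
)
print(idata.groups())  # lists the InferenceData groups created from the inputs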