arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/data/io_pyro.py
DELETED
|
@@ -1,333 +0,0 @@
|
|
|
1
|
-
"""Pyro-specific conversion code."""
|
|
2
|
-
|
|
3
|
-
import logging
|
|
4
|
-
from typing import Callable, Optional
|
|
5
|
-
import warnings
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
from packaging import version
|
|
9
|
-
|
|
10
|
-
from .. import utils
|
|
11
|
-
from ..rcparams import rcParams
|
|
12
|
-
from .base import dict_to_dataset, requires
|
|
13
|
-
from .inference_data import InferenceData
|
|
14
|
-
|
|
15
|
-
# Module-level logger; the converter uses it to warn when posterior predictive /
# predictions arrays have a shape that does not match the number of chains and draws.
_log = logging.getLogger(__name__)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class PyroConverter:
    """Encapsulate Pyro specific logic.

    Holds the Pyro objects passed to :func:`from_pyro` and converts each of them
    into the matching group of an :class:`InferenceData` object.
    """

    # pylint: disable=too-many-instance-attributes

    model = None  # type: Optional[Callable]
    nchains = None  # type: int
    ndraws = None  # type: int

    def __init__(
        self,
        *,
        posterior=None,
        prior=None,
        posterior_predictive=None,
        log_likelihood=None,
        predictions=None,
        constant_data=None,
        predictions_constant_data=None,
        coords=None,
        dims=None,
        pred_dims=None,
        num_chains=1,
    ):
        """Convert Pyro data into an InferenceData object.

        Parameters
        ----------
        posterior : pyro.infer.MCMC
            Fitted MCMC object from Pyro
        prior: dict
            Prior samples from a Pyro model
        posterior_predictive : dict
            Posterior predictive samples for the posterior
        log_likelihood : bool, optional
            Calculate and store pointwise log likelihood values. Defaults to the value
            of rcParam ``data.log_likelihood``.
        predictions: dict
            Out of sample predictions
        constant_data: dict
            Dictionary containing constant data variables mapped to their values.
        predictions_constant_data: dict
            Constant data used for out-of-sample predictions.
        coords : dict[str] -> list[str]
            Map of dimensions to coordinates
        dims : dict[str] -> list[str]
            Map variable names to their coordinates
        pred_dims: dict
            Dims for predictions data. Map variable names to their coordinates.
        num_chains: int
            Number of chains used for sampling. Ignored if posterior is present.

        Raises
        ------
        ValueError
            If no posterior, prior, posterior_predictive, predictions,
            constant_data or predictions_constant_data is given.
        """
        self.posterior = posterior
        self.prior = prior
        self.posterior_predictive = posterior_predictive
        self.log_likelihood = (
            rcParams["data.log_likelihood"] if log_likelihood is None else log_likelihood
        )
        self.predictions = predictions
        self.constant_data = constant_data
        self.predictions_constant_data = predictions_constant_data
        self.coords = coords
        self.dims = {} if dims is None else dims
        self.pred_dims = {} if pred_dims is None else pred_dims
        import pyro

        def arbitrary_element(dct):
            return next(iter(dct.values()))

        self.pyro = pyro
        if posterior is not None:
            self.nchains, self.ndraws = posterior.num_chains, posterior.num_samples
            if version.parse(pyro.__version__) >= version.parse("1.0.0"):
                self.model = self.posterior.kernel.model
                # model arguments and keyword arguments
                self._args = self.posterior._args  # pylint: disable=protected-access
                self._kwargs = self.posterior._kwargs  # pylint: disable=protected-access
        else:
            # Without an MCMC object, chains come from the caller and draws are
            # inferred from the leading axis of one of the sample dictionaries.
            self.nchains = num_chains
            get_from = None
            if predictions is not None:
                get_from = predictions
            elif posterior_predictive is not None:
                get_from = posterior_predictive
            elif prior is not None:
                get_from = prior
            if get_from is None and constant_data is None and predictions_constant_data is None:
                raise ValueError(
                    "When constructing InferenceData must have at least"
                    " one of posterior, prior, posterior_predictive or predictions."
                )
            if get_from is not None:
                aelem = arbitrary_element(get_from)
                self.ndraws = aelem.shape[0] // self.nchains

        observations = {}
        if self.model is not None:
            # Trace the model once with the MCMC arguments to find observed sample sites.
            trace = pyro.poutine.trace(self.model).get_trace(  # pylint: disable=not-callable
                *self._args, **self._kwargs
            )
            observations = {
                name: site["value"].cpu()
                for name, site in trace.nodes.items()
                if site["type"] == "sample" and site["is_observed"]
            }
        self.observations = observations if observations else None

    @requires("posterior")
    def posterior_to_xarray(self):
        """Convert the posterior to an xarray dataset."""
        data = self.posterior.get_samples(group_by_chain=True)
        data = {k: v.detach().cpu().numpy() for k, v in data.items()}
        return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)

    @staticmethod
    def _chain_sort_key(chain_name):
        """Sort key ordering chain labels such as ``"chain 10"`` numerically.

        A plain ``sorted`` on ``"chain {i}"`` labels is lexicographic and would put
        ``"chain 10"`` before ``"chain 2"``. Labels without a trailing integer
        sort after numeric ones, by their string value.
        """
        prefix, _, suffix = str(chain_name).rpartition(" ")
        try:
            return (0, int(suffix), prefix)
        except ValueError:
            return (1, 0, str(chain_name))

    @requires("posterior")
    def sample_stats_to_xarray(self):
        """Extract sample_stats from Pyro posterior."""
        divergences = self.posterior.diagnostics()["divergences"]
        diverging = np.zeros((self.nchains, self.ndraws), dtype=bool)
        # Row i of ``diverging`` must correspond to chain i, so order labels numerically.
        for chain_idx, chain_name in enumerate(sorted(divergences, key=self._chain_sort_key)):
            diverging[chain_idx, divergences[chain_name]] = True
        data = {"diverging": diverging}
        return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=None)

    @requires("posterior")
    @requires("model")
    def log_likelihood_to_xarray(self):
        """Extract log likelihood from Pyro posterior.

        Returns None when log likelihood storage is disabled, or when the
        vectorized trace cannot be computed (a warning is emitted in that case).
        """
        if not self.log_likelihood:
            return None
        data = {}
        if self.observations is not None:
            try:
                samples = self.posterior.get_samples(group_by_chain=False)
                predictive = self.pyro.infer.Predictive(self.model, samples)
                vectorized_trace = predictive.get_vectorized_trace(*self._args, **self._kwargs)
                for obs_name in self.observations:
                    obs_site = vectorized_trace.nodes[obs_name]
                    log_like = obs_site["fn"].log_prob(obs_site["value"]).detach().cpu().numpy()
                    shape = (self.nchains, self.ndraws) + log_like.shape[1:]
                    data[obs_name] = np.reshape(log_like, shape)
            except Exception:  # pylint: disable=broad-except
                # Best effort: any model-side failure drops the group, but
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                warnings.warn(
                    "Could not get vectorized trace, log_likelihood group will be omitted. "
                    "Check your model vectorization or set log_likelihood=False"
                )
                return None
        return dict_to_dataset(
            data, library=self.pyro, coords=self.coords, dims=self.dims, skip_event_dims=True
        )

    def translate_posterior_predictive_dict_to_xarray(self, dct, dims):
        """Convert posterior_predictive or prediction samples to xarray.

        Arrays already shaped ``(chain, draw, ...)`` are kept, arrays with a flat
        leading ``chain * draw`` axis are reshaped, and anything else is passed
        through ``utils.expand_dims`` with a warning.
        """
        data = {}
        for k, ary in dct.items():
            ary = ary.detach().cpu().numpy()
            shape = ary.shape
            # Guard on rank so 0-d / 1-d arrays reach the fallback instead of IndexError.
            if len(shape) >= 2 and shape[0] == self.nchains and shape[1] == self.ndraws:
                data[k] = ary
            elif len(shape) >= 1 and shape[0] == self.nchains * self.ndraws:
                data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
            else:
                data[k] = utils.expand_dims(ary)
                _log.warning(
                    "posterior predictive shape not compatible with number of chains and draws. "
                    "This can mean that some draws or even whole chains are not represented."
                )
        return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=dims)

    @requires("posterior_predictive")
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray."""
        return self.translate_posterior_predictive_dict_to_xarray(
            self.posterior_predictive, self.dims
        )

    @requires("predictions")
    def predictions_to_xarray(self):
        """Convert predictions to xarray."""
        return self.translate_posterior_predictive_dict_to_xarray(self.predictions, self.pred_dims)

    def priors_to_xarray(self):
        """Convert prior samples (and if possible prior predictive too) to xarray.

        When a posterior is available, prior variables that are not posterior
        variables are placed in the ``prior_predictive`` group.
        """
        if self.prior is None:
            return {"prior": None, "prior_predictive": None}
        if self.posterior is not None:
            prior_vars = list(self.posterior.get_samples().keys())
            prior_predictive_vars = [key for key in self.prior.keys() if key not in prior_vars]
        else:
            prior_vars = self.prior.keys()
            prior_predictive_vars = None
        priors_dict = {
            group: (
                None
                if var_names is None
                else dict_to_dataset(
                    {
                        k: utils.expand_dims(np.squeeze(self.prior[k].detach().cpu().numpy()))
                        for k in var_names
                    },
                    library=self.pyro,
                    coords=self.coords,
                    dims=self.dims,
                )
            )
            for group, var_names in zip(
                ("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
            )
        }
        return priors_dict

    @requires("observations")
    @requires("model")
    def observed_data_to_xarray(self):
        """Convert observed data to xarray."""
        dims = {} if self.dims is None else self.dims
        return dict_to_dataset(
            self.observations, library=self.pyro, coords=self.coords, dims=dims, default_dims=[]
        )

    @requires("constant_data")
    def constant_data_to_xarray(self):
        """Convert constant_data to xarray."""
        return dict_to_dataset(
            self.constant_data,
            library=self.pyro,
            coords=self.coords,
            dims=self.dims,
            default_dims=[],
        )

    @requires("predictions_constant_data")
    def predictions_constant_data_to_xarray(self):
        """Convert predictions_constant_data to xarray."""
        return dict_to_dataset(
            self.predictions_constant_data,
            library=self.pyro,
            coords=self.coords,
            dims=self.pred_dims,
            default_dims=[],
        )

    def to_inference_data(self):
        """Convert all available data to an InferenceData object."""
        return InferenceData(
            **{
                "posterior": self.posterior_to_xarray(),
                "sample_stats": self.sample_stats_to_xarray(),
                "log_likelihood": self.log_likelihood_to_xarray(),
                "posterior_predictive": self.posterior_predictive_to_xarray(),
                "predictions": self.predictions_to_xarray(),
                "constant_data": self.constant_data_to_xarray(),
                "predictions_constant_data": self.predictions_constant_data_to_xarray(),
                **self.priors_to_xarray(),
                "observed_data": self.observed_data_to_xarray(),
            }
        )
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
def from_pyro(
    posterior=None,
    *,
    prior=None,
    posterior_predictive=None,
    log_likelihood=None,
    predictions=None,
    constant_data=None,
    predictions_constant_data=None,
    coords=None,
    dims=None,
    pred_dims=None,
    num_chains=1,
):
    """Convert Pyro data into an InferenceData object.

    For a usage example read the
    :ref:`Creating InferenceData section on from_pyro <creating_InferenceData>`


    Parameters
    ----------
    posterior : pyro.infer.MCMC
        Fitted MCMC object from Pyro
    prior: dict
        Prior samples from a Pyro model
    posterior_predictive : dict
        Posterior predictive samples for the posterior
    log_likelihood : bool, optional
        Calculate and store pointwise log likelihood values. Defaults to the value
        of rcParam ``data.log_likelihood``.
    predictions: dict
        Out of sample predictions
    constant_data: dict
        Dictionary containing constant data variables mapped to their values.
    predictions_constant_data: dict
        Constant data used for out-of-sample predictions.
    coords : dict[str] -> list[str]
        Map of dimensions to coordinates
    dims : dict[str] -> list[str]
        Map variable names to their coordinates
    pred_dims: dict
        Dims for predictions data. Map variable names to their coordinates.
    num_chains: int
        Number of chains used for sampling. Ignored if posterior is present.

    Returns
    -------
    InferenceData
        Object built by ``PyroConverter.to_inference_data`` from the given inputs.
    """
    # Every argument is forwarded unchanged; the converter does all the work.
    converter = PyroConverter(
        posterior=posterior,
        prior=prior,
        posterior_predictive=posterior_predictive,
        predictions=predictions,
        log_likelihood=log_likelihood,
        constant_data=constant_data,
        predictions_constant_data=predictions_constant_data,
        coords=coords,
        dims=dims,
        pred_dims=pred_dims,
        num_chains=num_chains,
    )
    inference_data = converter.to_inference_data()
    return inference_data
|