arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/data/io_pyjags.py
DELETED
|
@@ -1,378 +0,0 @@
|
|
|
1
|
-
"""Convert PyJAGS sample dictionaries to ArviZ inference data objects."""
|
|
2
|
-
|
|
3
|
-
import typing as tp
|
|
4
|
-
from collections import OrderedDict
|
|
5
|
-
from collections.abc import Iterable
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
import xarray
|
|
9
|
-
|
|
10
|
-
from .inference_data import InferenceData
|
|
11
|
-
|
|
12
|
-
from ..rcparams import rcParams
|
|
13
|
-
from .base import dict_to_dataset
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class PyJAGSConverter:
    """Encapsulate PyJAGS specific logic.

    Stores PyJAGS sample dictionaries (variable name -> array with shape
    ``(parameter_dimension, chain_length, number_of_chains)``) and converts
    them into the groups of an :class:`InferenceData` object.
    """

    def __init__(
        self,
        *,
        posterior: tp.Optional[tp.Mapping[str, np.ndarray]] = None,
        prior: tp.Optional[tp.Mapping[str, np.ndarray]] = None,
        log_likelihood: tp.Optional[
            tp.Union[str, tp.List[str], tp.Tuple[str, ...], tp.Mapping[str, str]]
        ] = None,
        coords=None,
        dims=None,
        save_warmup: tp.Optional[bool] = None,
        warmup_iterations: int = 0,
    ) -> None:
        """Store the samples and the conversion options.

        Parameters
        ----------
        posterior, prior : mapping of str to ndarray, optional
            PyJAGS sample dictionaries.
        log_likelihood : str, list/tuple of str or mapping of str to str, optional
            Names of posterior variables holding pointwise log likelihood
            values. A mapping is read as
            ``{observed_variable_name: jags_variable_name}``; a string or a
            list/tuple uses each name as both key and value. These variables
            are moved out of the stored posterior (the caller's mapping is not
            modified). Ignored when ``posterior`` is None.
        coords, dims : dict, optional
            Forwarded to ``dict_to_dataset``.
        save_warmup : bool, optional
            Whether to keep warmup draws. Defaults to
            ``rcParams["data.save_warmup"]``.
        warmup_iterations : int
            Number of leading draws of each chain treated as warmup.

        Raises
        ------
        KeyError
            If a log likelihood variable name is not a key of ``posterior``.
        ImportError
            If ``pyjags`` is not installed.
        """
        self.posterior: tp.Optional[tp.Mapping[str, np.ndarray]]
        self.log_likelihood: tp.Optional[tp.Dict[str, np.ndarray]]
        if log_likelihood is not None and posterior is not None:
            # Shallow copy so that popping the log likelihood variables below
            # does not mutate the caller's dictionary.
            posterior_copy = dict(posterior)

            # Normalise to {observed_variable_name: jags_variable_name}.
            if isinstance(log_likelihood, str):
                log_likelihood = [log_likelihood]
            if isinstance(log_likelihood, (list, tuple)):
                log_likelihood = {name: name for name in log_likelihood}

            self.log_likelihood = {
                obs_var_name: posterior_copy.pop(log_like_name)
                for obs_var_name, log_like_name in log_likelihood.items()
            }
            self.posterior = posterior_copy
        else:
            self.posterior = posterior
            self.log_likelihood = None
        self.prior = prior
        self.coords = coords
        self.dims = dims
        self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
        self.warmup_iterations = warmup_iterations

        # Imported here rather than at module level so that this module can be
        # imported without pyjags installed; the module object is handed to
        # ``dict_to_dataset`` as the ``library`` argument.
        import pyjags  # pylint: disable=import-error

        self.pyjags = pyjags

    def _pyjags_samples_to_xarray(
        self, pyjags_samples: tp.Mapping[str, np.ndarray]
    ) -> tp.Tuple[xarray.Dataset, xarray.Dataset]:
        """Convert a PyJAGS sample dictionary to ``(samples, warmup)`` datasets.

        The warmup dataset is built from an empty dictionary unless
        ``save_warmup`` is set and ``warmup_iterations`` is positive.
        """
        data, data_warmup = get_draws(
            pyjags_samples=pyjags_samples,
            warmup_iterations=self.warmup_iterations,
            warmup=self.save_warmup,
        )

        return (
            dict_to_dataset(data, library=self.pyjags, coords=self.coords, dims=self.dims),
            dict_to_dataset(
                data_warmup,
                library=self.pyjags,
                coords=self.coords,
                dims=self.dims,
            ),
        )

    def posterior_to_xarray(self) -> tp.Optional[tp.Tuple[xarray.Dataset, xarray.Dataset]]:
        """Extract posterior samples from fit; return None if there is no posterior."""
        if self.posterior is None:
            return None

        return self._pyjags_samples_to_xarray(self.posterior)

    def prior_to_xarray(self) -> tp.Optional[tp.Tuple[xarray.Dataset, xarray.Dataset]]:
        """Extract prior samples from fit; return None if there is no prior."""
        if self.prior is None:
            return None

        return self._pyjags_samples_to_xarray(self.prior)

    def log_likelihood_to_xarray(self) -> tp.Optional[tp.Tuple[xarray.Dataset, xarray.Dataset]]:
        """Extract log likelihood samples from fit; return None if none were requested."""
        if self.log_likelihood is None:
            return None

        return self._pyjags_samples_to_xarray(self.log_likelihood)

    def to_inference_data(self):
        """Convert all available data to an InferenceData object."""
        # Warmup groups are only meaningful when warmup draws were actually
        # split off, so require a positive number of warmup iterations too.
        save_warmup = self.save_warmup and self.warmup_iterations > 0

        idata_dict = {
            "posterior": self.posterior_to_xarray(),
            "prior": self.prior_to_xarray(),
            "log_likelihood": self.log_likelihood_to_xarray(),
            "save_warmup": save_warmup,
        }

        return InferenceData(**idata_dict)
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
def get_draws(
    pyjags_samples: tp.Mapping[str, np.ndarray],
    variables: tp.Optional[tp.Union[str, tp.Iterable[str]]] = None,
    warmup: bool = False,
    warmup_iterations: int = 0,
) -> tp.Tuple[tp.Mapping[str, np.ndarray], tp.Mapping[str, np.ndarray]]:
    """
    Convert PyJAGS samples dictionary to ArviZ format and split warmup samples.

    Parameters
    ----------
    pyjags_samples: a dictionary mapping variable names to NumPy arrays of MCMC
                    chains of samples with shape
                    (parameter_dimension, chain_length, number_of_chains)

    variables: the variables to extract from the samples dictionary; a single
               name, an iterable of names, or None for every variable
    warmup: whether or not to return warmup draws in data_warmup
    warmup_iterations: the number of leading warmup draws per chain, if any

    Returns
    -------
    A tuple ``(data, data_warmup)`` of samples dictionaries in ArviZ format.
    ``data_warmup`` is empty unless ``warmup`` is true and
    ``warmup_iterations`` is positive.

    Raises
    ------
    TypeError
        If ``variables`` is neither None, a string, nor an iterable.
    """
    # Normalise the variable selection into a tuple of names.
    if variables is None:
        variables = list(pyjags_samples.keys())
    elif isinstance(variables, str):
        variables = [variables]
    if not isinstance(variables, Iterable):
        raise TypeError("variables must be of type Sequence or str")
    selected = tuple(variables)

    if warmup_iterations <= 0:
        # Nothing to split off: every draw is a regular sample.
        everything = _convert_pyjags_dict_to_arviz_dict(
            samples=pyjags_samples, variable_names=selected
        )
        return everything, OrderedDict()

    warmup_part, sampling_part = _split_pyjags_dict_in_warmup_and_actual_samples(
        pyjags_samples=pyjags_samples,
        warmup_iterations=warmup_iterations,
        variable_names=selected,
    )
    data = _convert_pyjags_dict_to_arviz_dict(samples=sampling_part, variable_names=selected)
    data_warmup = (
        _convert_pyjags_dict_to_arviz_dict(samples=warmup_part, variable_names=selected)
        if warmup
        else OrderedDict()
    )
    return data, data_warmup
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
def _split_pyjags_dict_in_warmup_and_actual_samples(
    pyjags_samples: tp.Mapping[str, np.ndarray],
    warmup_iterations: int,
    variable_names: tp.Optional[tp.Tuple[str, ...]] = None,
) -> tp.Tuple[tp.Mapping[str, np.ndarray], tp.Mapping[str, np.ndarray]]:
    """
    Split a PyJAGS samples dictionary into warmup samples and actual samples.

    Parameters
    ----------
    pyjags_samples: a dictionary mapping variable names to NumPy arrays of MCMC
                    chains of samples with shape
                    (parameter_dimension, chain_length, number_of_chains)

    warmup_iterations: the number of leading draws of each chain (along the
                       chain_length axis) to split off as warmup
    variable_names: the variables in the dictionary to use; if None use all

    Returns
    -------
    A tuple ``(warmup_samples, actual_samples)`` of samples dictionaries in
    PyJAGS format. Variables not listed in ``variable_names`` are omitted from
    both. For NumPy inputs the returned arrays are basic slices (views) of the
    input arrays.
    """
    if variable_names is None:
        variable_names = tuple(pyjags_samples.keys())

    warmup_samples: tp.Dict[str, np.ndarray] = {}
    actual_samples: tp.Dict[str, np.ndarray] = {}

    # Iterate the input mapping (not variable_names) so the output keeps the
    # input's key order.
    for variable_name, chains in pyjags_samples.items():
        if variable_name in variable_names:
            warmup_samples[variable_name] = chains[:, :warmup_iterations, :]
            actual_samples[variable_name] = chains[:, warmup_iterations:, :]

    return warmup_samples, actual_samples
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
def _convert_pyjags_dict_to_arviz_dict(
    samples: tp.Mapping[str, np.ndarray],
    variable_names: tp.Optional[tp.Tuple[str, ...]] = None,
) -> tp.Mapping[str, np.ndarray]:
    """
    Convert a PyJAGS dictionary to an ArviZ dictionary.

    Takes a python dictionary of samples that has been generated by the sample
    method of a model instance and returns a dictionary of samples in ArviZ
    format.

    Parameters
    ----------
    samples: a dictionary mapping variable names to array_like objects
             (typically NumPy arrays) with shape
             (parameter_dimension, chain_length, number_of_chains)
    variable_names: the variables in the dictionary to convert; if None
                    convert all of them

    Returns
    -------
    a dictionary mapping variable names to NumPy arrays with shape
    (number_of_chains, chain_length, parameter_dimension), or
    (number_of_chains, chain_length) for variables whose
    parameter_dimension is 1

    Raises
    ------
    ValueError
        If a selected variable's array is not three-dimensional.
    """
    # pyjags returns a dictionary of NumPy arrays with shape
    # (parameter_dimension, chain_length, number_of_chains)
    # but arviz expects samples with shape
    # (number_of_chains, chain_length, parameter_dimension)

    variable_name_to_samples_map = {}

    if variable_names is None:
        variable_names = tuple(samples.keys())

    for variable_name, chains in samples.items():
        if variable_name not in variable_names:
            continue
        # Accept array_like input; a no-op for ndarrays.
        chains = np.asarray(chains)
        parameter_dimension, _, _ = chains.shape
        if parameter_dimension == 1:
            # Scalar parameter: drop the length-1 axis, giving
            # (number_of_chains, chain_length).
            variable_name_to_samples_map[variable_name] = chains[0, :, :].transpose()
        else:
            # (dim, draws, chains) -> (chains, draws, dim)
            variable_name_to_samples_map[variable_name] = np.swapaxes(chains, 0, 2)

    return variable_name_to_samples_map
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
def _extract_arviz_dict_from_inference_data(
    idata,
) -> tp.Mapping[str, np.ndarray]:
    """
    Extract the samples dictionary from an ArviZ inference data object.

    Extracts a dictionary mapping parameter names to NumPy arrays of samples
    with shape (number_of_chains, chain_length, parameter_dimension) from an
    ArviZ inference data object.

    Parameters
    ----------
    idata: InferenceData
        Only its ``posterior`` group is read.

    Returns
    -------
    a dictionary mapping variable names to NumPy arrays with shape
    (number_of_chains, chain_length, parameter_dimension)

    The arrays are copies, so modifying them does not affect ``idata``.
    """
    # Read each variable's array directly instead of round-tripping through
    # ``Dataset.to_dict()``, which first serialises every array into nested
    # Python lists. ``np.array`` makes a copy, so the result does not alias
    # the InferenceData storage.
    variable_name_to_samples_map = {
        name: np.array(data_array.values)
        for name, data_array in idata.posterior.data_vars.items()
    }

    return variable_name_to_samples_map
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
def _convert_arviz_dict_to_pyjags_dict(
    samples: tp.Mapping[str, np.ndarray],
) -> tp.Mapping[str, np.ndarray]:
    """
    Convert an ArviZ dictionary to a PyJAGS dictionary.

    Takes a python dictionary of samples in ArviZ format and returns the samples
    as a dictionary in PyJAGS format.

    Parameters
    ----------
    samples: dict of {str : array_like}
        a dictionary mapping variable names to arrays with shape
        (number_of_chains, chain_length, parameter_dimension), or
        (number_of_chains, chain_length) for scalar parameters

    Returns
    -------
    a dictionary mapping variable names to NumPy arrays with shape
    (parameter_dimension, chain_length, number_of_chains); scalar parameters
    get a parameter_dimension of 1

    """
    # pyjags returns a dictionary of NumPy arrays with shape
    # (parameter_dimension, chain_length, number_of_chains)
    # but arviz expects samples with shape
    # (number_of_chains, chain_length, parameter_dimension)

    variable_name_to_samples_map = {}

    for variable_name, chains in samples.items():
        # The documented input is array_like; ``.ndim`` needs an ndarray.
        chains = np.asarray(chains)
        if chains.ndim == 2:
            # Scalar parameter: add a trailing parameter axis of length 1.
            chains = chains[:, :, np.newaxis]

        # (chains, draws, dim) -> (dim, draws, chains)
        variable_name_to_samples_map[variable_name] = np.swapaxes(chains, 0, 2)

    return variable_name_to_samples_map
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
def from_pyjags(
    posterior: tp.Optional[tp.Mapping[str, np.ndarray]] = None,
    prior: tp.Optional[tp.Mapping[str, np.ndarray]] = None,
    log_likelihood: tp.Optional[
        tp.Union[str, tp.List[str], tp.Tuple[str, ...], tp.Mapping[str, str]]
    ] = None,
    coords=None,
    dims=None,
    save_warmup: tp.Optional[bool] = None,
    warmup_iterations: int = 0,
) -> InferenceData:
    """
    Convert PyJAGS posterior samples to an ArviZ inference data object.

    Takes a python dictionary of samples that has been generated by the sample
    method of a model instance and returns an ArviZ inference data object.
    For a usage example read the
    :ref:`Creating InferenceData section on from_pyjags <creating_InferenceData>`

    Parameters
    ----------
    posterior: dict of {str : array_like}, optional
        a dictionary mapping variable names to NumPy arrays containing
        posterior samples with shape
        (parameter_dimension, chain_length, number_of_chains)

    prior: dict of {str : array_like}, optional
        a dictionary mapping variable names to NumPy arrays containing
        prior samples with shape
        (parameter_dimension, chain_length, number_of_chains)

    log_likelihood: dict of {str: str}, list of str or str, optional
        Pointwise log_likelihood for the data. log_likelihood is extracted from the
        posterior. It is recommended to use this argument as a dictionary whose keys
        are observed variable names and its values are the variables storing log
        likelihood arrays in the JAGS code. In other cases, a dictionary with keys
        equal to its values is used. Ignored when ``posterior`` is None.

    coords: dict[str, iterable]
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.

    dims: dict[str, List(str)]
        A mapping from variables to a list of coordinate names for the variable.

    save_warmup : bool, optional
        Save warmup iterations in InferenceData. If not defined, use default defined by the rcParams.

    warmup_iterations: int, optional
        Number of warmup iterations at the start of each chain. These draws are
        removed from the posterior/prior samples and are only kept (as warmup
        data) when ``save_warmup`` is enabled.

    Returns
    -------
    InferenceData
    """
    return PyJAGSConverter(
        posterior=posterior,
        prior=prior,
        log_likelihood=log_likelihood,
        dims=dims,
        coords=coords,
        save_warmup=save_warmup,
        warmup_iterations=warmup_iterations,
    ).to_inference_data()
|