arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -357
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.1.dist-info/METADATA +0 -263
- arviz-0.23.1.dist-info/RECORD +0 -183
- arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/data/io_emcee.py
DELETED
|
@@ -1,317 +0,0 @@
|
|
|
1
|
-
"""emcee-specific conversion code."""
|
|
2
|
-
|
|
3
|
-
import warnings
|
|
4
|
-
from collections import OrderedDict
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
import xarray as xr
|
|
8
|
-
|
|
9
|
-
from .. import utils
|
|
10
|
-
from .base import dict_to_dataset, generate_dims_coords, make_attrs
|
|
11
|
-
from .inference_data import InferenceData
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def _verify_names(sampler, var_names, arg_names, slices):
    """Make sure var_names, arg_names and slices are assigned reasonably.

    This is meant to run before loading emcee objects into InferenceData.
    Any of ``var_names``, ``arg_names`` or ``slices`` that is None is replaced
    by a default. Values that are provided are checked against the sampler.

    Parameters
    ----------
    sampler : emcee.EnsembleSampler
        Fitted emcee sampler
    var_names : list[str] or None
        Names for the emcee parameters
    arg_names : list[str] or None
        Names for the args/observations provided to emcee
    slices : list[seq] or None
        slices to select the variables (used for multidimensional variables)

    Returns
    -------
    list[str], list[str], list[seq]
        var_names, arg_names and slices, with defaults filled in where the
        input was None

    Raises
    ------
    ValueError
        If the number of ``var_names`` differs from the number of variables
        (``len(slices)``, or the chain's last dimension when ``slices`` is
        None), or the number of ``arg_names`` differs from the number of
        sampler args.

    Warns
    -----
    UserWarning
        If ``slices`` does not cover every parameter index exactly once
        (some indices missing, or some indices repeated).
    """
    # There are 3 possible cases: emcee2, emcee3 and sampler read from h5 file (emcee3 only)
    if hasattr(sampler, "args"):
        ndim = sampler.chain.shape[-1]
        num_args = len(sampler.args)
    elif hasattr(sampler, "log_prob_fn"):
        ndim = sampler.get_chain().shape[-1]
        num_args = len(sampler.log_prob_fn.args)
    else:
        ndim = sampler.get_chain().shape[-1]
        num_args = 0  # emcee only stores the posterior samples

    if slices is None:
        slices = utils.arange(ndim)
        num_vars = ndim
    else:
        num_vars = len(slices)
        indices = utils.arange(ndim)
        slicing_try = np.concatenate([utils.one_de(indices[idx]) for idx in slices])
        # Count distinct indices: duplicates coming from overlapping slices must
        # not be reported as additional captured parameters.
        num_captured = len(set(slicing_try))
        if num_captured != ndim:
            warnings.warn(
                "Check slices: Not all parameters in chain captured. "
                f"{ndim} are present, and {num_captured} have been captured.",
                UserWarning,
            )
        if len(slicing_try) != num_captured:
            warnings.warn(f"Overlapping slices. Check the index present: {slicing_try}", UserWarning)

    if var_names is None:
        var_names = [f"var_{idx}" for idx in range(num_vars)]
    if arg_names is None:
        arg_names = [f"arg_{idx}" for idx in range(num_args)]

    # The mismatch can go either way (too few or too many names), so the
    # messages report the counts without claiming which side is short.
    if len(var_names) != num_vars:
        raise ValueError(
            f"The sampler has {num_vars} variables, "
            f"but {len(var_names)} var_names were provided!"
        )

    if len(arg_names) != num_args:
        raise ValueError(
            f"The sampler has {num_args} args, "
            f"but {len(arg_names)} arg_names were provided!"
        )
    return var_names, arg_names, slices
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
# pylint: disable=too-many-instance-attributes
|
|
86
|
-
class EmceeConverter:
    """Encapsulate emcee specific logic.

    Holds a fitted emcee sampler together with the naming/grouping options and
    exposes one method per InferenceData group, plus ``to_inference_data`` to
    assemble them.
    """

    def __init__(
        self,
        sampler,
        var_names=None,
        slices=None,
        arg_names=None,
        arg_groups=None,
        blob_names=None,
        blob_groups=None,
        index_origin=None,
        coords=None,
        dims=None,
    ):
        # Fill in defaults and validate lengths before anything is stored.
        var_names, arg_names, slices = _verify_names(sampler, var_names, arg_names, slices)
        self.sampler = sampler
        self.var_names = var_names
        self.slices = slices
        self.arg_names = arg_names
        self.arg_groups = arg_groups
        self.blob_names = blob_names
        self.blob_groups = blob_groups
        self.index_origin = index_origin
        self.coords = coords
        self.dims = dims
        # Imported lazily so that emcee is only required when this converter is used.
        import emcee

        self.emcee = emcee

    def posterior_to_xarray(self):
        """Convert the posterior to an xarray dataset."""
        # Use emcee3 syntax, else use emcee2
        if hasattr(self.sampler, "get_chain"):
            # NOTE(review): presumably get_chain() is (draw, walker, ndim) and the
            # swap puts walkers first to match emcee2's ``chain`` — confirm.
            samples_ary = self.sampler.get_chain().swapaxes(0, 1)
        else:
            samples_ary = self.sampler.chain

        # Each variable picks its own index/slice along the last (parameter) axis.
        data = {
            var_name: (samples_ary[(..., idx)])
            for idx, var_name in zip(self.slices, self.var_names)
        }
        return dict_to_dataset(
            data,
            library=self.emcee,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        )

    def args_to_xarray(self):
        """Convert emcee args to observed and constant_data xarray Datasets.

        Returns
        -------
        dict of {str: xarray.Dataset}
            One dataset per group name present in ``arg_groups``.

        Raises
        ------
        ValueError
            If ``arg_names`` and ``arg_groups`` have different lengths.
        SyntaxError
            If a group other than ``observed_data``/``constant_data`` is given.
        """
        dims = {} if self.dims is None else self.dims
        if self.arg_groups is None:
            self.arg_groups = ["observed_data" for _ in self.arg_names]
        if len(self.arg_names) != len(self.arg_groups):
            raise ValueError(
                "arg_names and arg_groups must have the same length, or arg_groups be None"
            )
        arg_groups_set = set(self.arg_groups)
        bad_groups = [
            group for group in arg_groups_set if group not in ("observed_data", "constant_data")
        ]
        if bad_groups:
            # NOTE(review): SyntaxError is unusual for argument validation; it is kept
            # so that callers catching it keep working.
            raise SyntaxError(
                "all arg_groups values should be either 'observed_data' or 'constant_data' , "
                f"not {bad_groups}"
            )
        obs_const_dict = {group: OrderedDict() for group in arg_groups_set}
        for idx, (arg_name, group) in enumerate(zip(self.arg_names, self.arg_groups)):
            # Use emcee3 syntax, else use emcee2
            arg_array = np.atleast_1d(
                self.sampler.log_prob_fn.args[idx]
                if hasattr(self.sampler, "log_prob_fn")
                else self.sampler.args[idx]
            )
            arg_dims = dims.get(arg_name)
            arg_dims, coords = generate_dims_coords(
                arg_array.shape,
                arg_name,
                dims=arg_dims,
                coords=self.coords,
                index_origin=self.index_origin,
            )
            # filter coords based on the dims
            coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in arg_dims}
            obs_const_dict[group][arg_name] = xr.DataArray(arg_array, dims=arg_dims, coords=coords)
        for key, values in obs_const_dict.items():
            obs_const_dict[key] = xr.Dataset(data_vars=values, attrs=make_attrs(library=self.emcee))
        return obs_const_dict

    def blobs_to_dict(self):
        """Convert blobs to dictionary {groupname: xr.Dataset}.

        It also stores lp values in sample_stats group.

        Raises
        ------
        ValueError
            If ``blob_names`` and ``blob_groups`` differ in length, if blob names
            are given but the sampler holds no blobs, or if several blob names
            are given and their number does not match the blobs in the sampler.
        SyntaxError
            If a blob group would overwrite ``posterior``, ``observed_data`` or
            ``constant_data``.
        """
        self.blob_names = [] if self.blob_names is None else self.blob_names
        if self.blob_groups is None:
            self.blob_groups = ["log_likelihood" for _ in self.blob_names]
        if len(self.blob_names) != len(self.blob_groups):
            raise ValueError(
                "blob_names and blob_groups must have the same length, or blob_groups be None"
            )
        # Blobs are only read when at least one name was requested. An empty
        # ``blob_names`` behaves like None instead of failing on a sampler
        # without blobs. ``len`` is used so array-like names also work.
        if len(self.blob_names) > 0:
            # Parse the whole major component; indexing the first character of
            # the version string would read "10.0" as major version 1.
            emcee_major = int(self.emcee.__version__.split(".")[0])
            if emcee_major >= 3:
                blobs = self.sampler.get_blobs()
            else:
                blobs = np.array(self.sampler.blobs, dtype=object)
            if blobs is None or blobs.size == 0:
                raise ValueError("No blobs in sampler, blob_names must be None")
            if len(blobs.shape) == 2:
                # single blob per sample: add a trailing blob axis
                blobs = np.expand_dims(blobs, axis=-1)
            # assumes blobs are laid out as (draw, walker, blob, ...) — TODO confirm
            blobs = blobs.swapaxes(0, 2)
            nblobs, nwalkers, ndraws, *_ = blobs.shape
            if len(self.blob_names) != nblobs and len(self.blob_names) > 1:
                raise ValueError(
                    "Incorrect number of blob names. "
                    f"Expected {nblobs}, found {len(self.blob_names)}"
                )
        blob_groups_set = set(self.blob_groups)
        blob_groups_set.add("sample_stats")
        idata_groups = ("posterior", "observed_data", "constant_data")
        if np.any(np.isin(list(blob_groups_set), idata_groups)):
            raise SyntaxError(
                f"{idata_groups} groups should not come from blobs. "
                "Using them here would overwrite their actual values"
            )
        blob_dict = {group: OrderedDict() for group in blob_groups_set}
        if len(self.blob_names) == 1:
            # A single name receives the whole blob array, with the axis swap
            # above undone and then the first two axes exchanged.
            blob_dict[self.blob_groups[0]][self.blob_names[0]] = blobs.swapaxes(0, 2).swapaxes(0, 1)
        else:
            for i_blob, (name, group) in enumerate(zip(self.blob_names, self.blob_groups)):
                # for coherent blobs (all having the same dimensions) one line is enough
                blob = blobs[i_blob]
                # for blobs of different size, we get an array of arrays, which we convert
                # to an ndarray per blob_name
                if blob.dtype == object:
                    blob = blob.reshape(-1)
                    blob = np.stack(blob)
                    blob = blob.reshape((nwalkers, ndraws, -1))
                blob_dict[group][name] = np.squeeze(blob)

        # store lp in sample_stats group
        blob_dict["sample_stats"]["lp"] = (
            self.sampler.get_log_prob().swapaxes(0, 1)
            if hasattr(self.sampler, "get_log_prob")
            else self.sampler.lnprobability
        )
        for key, values in blob_dict.items():
            blob_dict[key] = dict_to_dataset(
                values,
                library=self.emcee,
                coords=self.coords,
                dims=self.dims,
                index_origin=self.index_origin,
            )
        return blob_dict

    def to_inference_data(self):
        """Convert all available data to an InferenceData object."""
        blobs_dict = self.blobs_to_dict()
        obs_const_dict = self.args_to_xarray()
        # Later entries win on key clashes; blobs_to_dict already rejects the
        # posterior/observed_data/constant_data group names.
        return InferenceData(
            **{"posterior": self.posterior_to_xarray(), **obs_const_dict, **blobs_dict}
        )
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
def from_emcee(
    sampler=None,
    var_names=None,
    slices=None,
    arg_names=None,
    arg_groups=None,
    blob_names=None,
    blob_groups=None,
    index_origin=None,
    coords=None,
    dims=None,
):
    """Build an InferenceData object from a fitted emcee sampler.

    For a usage example read :ref:`emcee_conversion`

    Every argument is handed unchanged to ``EmceeConverter``, whose
    ``to_inference_data`` method produces the result.

    Parameters
    ----------
    sampler : emcee.EnsembleSampler
        Fitted sampler from emcee.
    var_names : list of str, optional
        Names for the variables in the sampler.
    slices : list of array-like or slice, optional
        Indexes of each variable. Only needed for multidimensional variables.
    arg_names : list of str, optional
        Names for the args in the sampler.
    arg_groups : list of str, optional
        Group (``observed_data`` or ``constant_data``) in which each arg is
        stored. If None, every arg goes to the observed data group.
    blob_names : list of str, optional
        Names for the blobs in the sampler. When None, blobs are omitted
        whether or not the sampler contains any.
    blob_groups : list of str, optional
        Group to which each entry of ``blob_names`` is assigned. If
        ``blob_names`` is given and ``blob_groups`` is None, all blobs are
        assigned to the log_likelihood group.
    index_origin : int, optional
        Forwarded to the converter. Presumably the starting index used for
        automatically generated coordinates; confirm against
        ``dict_to_dataset``.
    coords : dict of {str : array_like}, optional
        Map of dimensions to coordinates.
    dims : dict of {str : list of str}, optional
        Map variable names to their coordinates.

    Returns
    -------
    arviz.InferenceData

    """
    converter = EmceeConverter(
        sampler=sampler,
        var_names=var_names,
        slices=slices,
        arg_names=arg_names,
        arg_groups=arg_groups,
        blob_names=blob_names,
        blob_groups=blob_groups,
        index_origin=index_origin,
        coords=coords,
        dims=dims,
    )
    return converter.to_inference_data()
|
arviz/data/io_json.py
DELETED
|
@@ -1,54 +0,0 @@
|
|
|
1
|
-
"""Input and output support for data."""
|
|
2
|
-
|
|
3
|
-
from .io_dict import from_dict
|
|
4
|
-
|
|
5
|
-
try:
|
|
6
|
-
import ujson as json
|
|
7
|
-
except ImportError:
|
|
8
|
-
# Can't find ujson using json
|
|
9
|
-
# mypy struggles with conditional imports expressed as catching ImportError:
|
|
10
|
-
# https://github.com/python/mypy/issues/1153
|
|
11
|
-
import json # type: ignore
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def from_json(filename):
    """Load an InferenceData object from a json file.

    Decoding uses whichever module is bound to the name ``json`` in this
    file: the faster `ujson` (https://github.com/ultrajson/ultrajson) when it
    is installed, otherwise the standard library.

    Parameters
    ----------
    filename : str
        location of json file

    Returns
    -------
    InferenceData object
    """
    with open(filename, "rb") as json_file:
        groups = json.load(json_file)

    # Every top-level key of the file becomes a keyword argument of from_dict.
    return from_dict(save_warmup=True, **groups)
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def to_json(idata, filename):
    """Write ``idata`` to a json file.

    This delegates to ``idata.to_json``; the note about the faster `ujson`
    (https://github.com/ultrajson/ultrajson) being used when available is
    carried over from the previous docstring and applies to that method.

    WARNING: Only idempotent in case `idata` is InferenceData.

    Parameters
    ----------
    idata : InferenceData
        Object to be saved
    filename : str
        name or path of the file to save to

    Returns
    -------
    str
        filename saved to
    """
    return idata.to_json(filename)
|
arviz/data/io_netcdf.py
DELETED
|
@@ -1,68 +0,0 @@
|
|
|
1
|
-
"""Input and output support for data."""
|
|
2
|
-
|
|
3
|
-
from .converters import convert_to_inference_data
|
|
4
|
-
from .inference_data import InferenceData
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def from_netcdf(filename, *, engine="h5netcdf", group_kwargs=None, regex=False):
    """Read a netcdf file into an arviz.InferenceData.

    Thin wrapper around ``InferenceData.from_netcdf``.

    Parameters
    ----------
    filename : str
        name or path of the file to load trace
    engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
        Library used to read the netcdf file.
    group_kwargs : dict of {str: dict}, optional
        Keyword arguments to be passed into each call of :func:`xarray.open_dataset`.
        The keys of the outer dict should be group names or regex matching group
        names, the inner dicts are passed to ``open_dataset``. None is treated as
        an empty dict.
        This feature is currently experimental
    regex : bool, default False
        Specifies whether regex search should be used to extend the keyword arguments.

    Returns
    -------
    InferenceData object

    Notes
    -----
    By default, the datasets of the InferenceData object will be lazily loaded instead
    of loaded into memory. This behaviour is regulated by the value of
    ``az.rcParams["data.load"]``.
    """
    per_group_kwargs = {} if group_kwargs is None else group_kwargs
    return InferenceData.from_netcdf(
        filename,
        engine=engine,
        group_kwargs=per_group_kwargs,
        regex=regex,
    )
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
def to_netcdf(data, filename, *, group="posterior", engine="h5netcdf", coords=None, dims=None):
    """Write ``data`` to a netcdf file.

    ``data`` is first passed through ``convert_to_inference_data``; the
    converted object's ``to_netcdf`` method then writes the file.

    WARNING: Only idempotent in case `data` is InferenceData

    Parameters
    ----------
    data : InferenceData, or any object accepted by `convert_to_inference_data`
        Object to be saved
    filename : str
        name or path of the file to save to
    group : str (optional)
        In case `data` is not InferenceData, this is the group it will be saved to
    engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
        Library forwarded to the converted object's ``to_netcdf`` method.
    coords : dict (optional)
        See `convert_to_inference_data`
    dims : dict (optional)
        See `convert_to_inference_data`

    Returns
    -------
    str
        filename saved to
    """
    converted = convert_to_inference_data(data, group=group, coords=coords, dims=dims)
    return converted.to_netcdf(filename, engine=engine)
|