arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/stats/stats_utils.py
DELETED
|
@@ -1,609 +0,0 @@
|
|
|
1
|
-
"""Stats-utility functions for ArviZ."""
|
|
2
|
-
|
|
3
|
-
import warnings
|
|
4
|
-
from collections.abc import Sequence
|
|
5
|
-
from copy import copy as _copy
|
|
6
|
-
from copy import deepcopy as _deepcopy
|
|
7
|
-
|
|
8
|
-
import numpy as np
|
|
9
|
-
import pandas as pd
|
|
10
|
-
from scipy.fftpack import next_fast_len
|
|
11
|
-
from scipy.interpolate import CubicSpline
|
|
12
|
-
from scipy.stats.mstats import mquantiles
|
|
13
|
-
from xarray import apply_ufunc
|
|
14
|
-
|
|
15
|
-
from .. import _log
|
|
16
|
-
from ..utils import conditional_jit, conditional_vect, conditional_dask
|
|
17
|
-
from .density_utils import histogram as _histogram
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
__all__ = ["autocorr", "autocov", "ELPDData", "make_ufunc", "smooth_data", "wrap_xarray_ufunc"]
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def autocov(ary, axis=-1):
    """Compute autocovariance estimates for every lag for the input array.

    The series is centred along ``axis`` and the autocovariance is obtained via
    zero-padded FFTs, so the value at lag ``k`` is
    ``sum_t (x_t - mean) * (x_{t+k} - mean) / n``.

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    axis : int, optional
        Axis along which the autocovariance is computed. Negative values count
        from the end. Defaults to the last axis.

    Returns
    -------
    acov: Numpy array same size as the input array
    """
    # Normalize negative axes. The comparison must be ``>= 0``: with ``> 0`` an
    # explicit axis=0 was turned into ``ary.ndim``, which is out of range.
    axis = axis if axis >= 0 else ary.ndim + axis
    n = ary.shape[axis]
    # Padding to at least 2n keeps the circular FFT correlation from wrapping around.
    m = next_fast_len(2 * n)

    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ifft_ary = np.fft.rfft(ary, n=m, axis=axis)
        ifft_ary *= np.conjugate(ifft_ary)

        # Keep only the first n lags along ``axis``; other dimensions are untouched.
        shape = tuple(slice(None) if dim != axis else slice(0, n) for dim in range(ary.ndim))
        cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[shape]
        cov /= n

    return cov
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
def autocorr(ary, axis=-1):
    """Compute autocorrelation using FFT for every lag for the input array.

    See https://en.wikipedia.org/wiki/autocorrelation#Efficient_computation

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    axis : int, optional
        Axis along which the autocorrelation is computed. Negative values count
        from the end. Defaults to the last axis.

    Returns
    -------
    acorr: Numpy array same size as the input array
        Autocovariance divided by its lag-0 value along ``axis``.
    """
    corr = autocov(ary, axis=axis)
    # Normalize negative axes; ``>= 0`` keeps an explicit axis=0 on the first axis.
    axis = axis if axis >= 0 else corr.ndim + axis
    # Select lag 0 as a length-1 slice so the division broadcasts along ``axis``.
    norm = tuple(
        slice(None, None) if dim != axis else slice(None, 1) for dim in range(corr.ndim)
    )
    # A constant series has zero variance; 0/0 yields NaN without a RuntimeWarning.
    with np.errstate(invalid="ignore"):
        corr /= corr[norm]
    return corr
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
def make_ufunc(
    func, n_dims=2, n_output=1, n_input=1, index=Ellipsis, ravel=True, check_shape=None
):  # noqa: D202
    """Make ufunc from a function taking 1D array input.

    Parameters
    ----------
    func : callable
        Function called on each slice of the input array(s).
    n_dims : int, optional
        Number of core dimensions not broadcasted. Dimensions are skipped from the end.
        At minimum n_dims > 0.
    n_output : int, optional
        Select number of results returned by `func`.
        If n_output > 1, ufunc returns a tuple of objects else returns an object.
    n_input : int, optional
        Number of **array** inputs to func, i.e. ``n_input=2`` means that func is called
        with ``func(ary1, ary2, *args, **kwargs)``
    index : int, optional
        Slice ndarray with `index`. Defaults to `Ellipsis`.
    ravel : bool, optional
        If true, ravel the ndarray before calling `func`.
    check_shape: bool, optional
        If false, do not check if the shape of the output is compatible with n_dims and
        n_output. By default, True only for n_input=1. If n_input is larger than 1, the last
        input array is used to check the shape, however, shape checking with multiple inputs
        may not be correct.

    Returns
    -------
    callable
        ufunc wrapper for `func`.

    Raises
    ------
    TypeError
        If ``n_dims`` is smaller than 1.
    """
    if n_dims < 1:
        raise TypeError("n_dims must be one or higher.")

    if check_shape is None:
        # Shape checking is only reliable with a single array input.
        check_shape = n_input == 1

    def _ufunc(*args, out=None, out_shape=None, **kwargs):
        """General ufunc for single-output function."""
        arys = args[:n_input]
        # None: each call of ``func`` fills a single element of ``out``.
        # Negative int: each call fills that many trailing dimensions of ``out``.
        n_dims_out = None
        if out is None:
            if out_shape is None:
                out = np.empty(arys[-1].shape[:-n_dims])
            else:
                out = np.empty((*arys[-1].shape[:-n_dims], *out_shape))
                # An empty ``out_shape`` must not become 0: ``shape[:0]`` is ``()``
                # and would collapse the loop below into one call on the whole array.
                n_dims_out = -len(out_shape) if out_shape else None
        elif check_shape:
            if out.shape != arys[-1].shape[:-n_dims]:
                msg = f"Shape incorrect for `out`: {out.shape}."
                msg += f" Correct shape is {arys[-1].shape[:-n_dims]}"
                raise TypeError(msg)
        for idx in np.ndindex(out.shape[:n_dims_out]):
            arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
            out_idx = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
            if n_dims_out is None:
                out_idx = out_idx.item()
            out[idx] = out_idx
        return out

    def _multi_ufunc(*args, out=None, out_shape=None, **kwargs):
        """General ufunc for multi-output function."""
        arys = args[:n_input]
        element_shape = arys[-1].shape[:-n_dims]
        if out is None:
            if out_shape is None:
                out = tuple(np.empty(element_shape) for _ in range(n_output))
            else:
                out = tuple(np.empty((*element_shape, *out_shape[i])) for i in range(n_output))

        elif check_shape:
            correct_shape = tuple(element_shape for _ in range(n_output))
            if isinstance(out, tuple):
                got_shape = tuple(item.shape for item in out)
                raise_error = got_shape != correct_shape
            else:
                raise_error = True
                # f-string so the actual type of ``out`` is reported in the message.
                got_shape = f"not tuple, type={type(out)}"
            if raise_error:
                msg = f"Shapes incorrect for `out`: {got_shape}."
                msg += f" Correct shapes are {correct_shape}"
                raise TypeError(msg)
        for idx in np.ndindex(element_shape):
            arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
            results = func(*arys_idx, *args[n_input:], **kwargs)
            for i, res in enumerate(results):
                out[i][idx] = np.asarray(res)[index]
        return out

    ufunc = _multi_ufunc if n_output > 1 else _ufunc

    update_docstring(ufunc, func, n_output)
    return ufunc
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
@conditional_dask
def wrap_xarray_ufunc(
    ufunc,
    *datasets,
    ufunc_kwargs=None,
    func_args=None,
    func_kwargs=None,
    dask_kwargs=None,
    **kwargs,
):
    """Wrap make_ufunc with xarray.apply_ufunc.

    Parameters
    ----------
    ufunc : callable
    *datasets : xarray.Dataset
    ufunc_kwargs : dict
        Keyword arguments passed to `make_ufunc`. The caller's dict is not modified.
            - 'n_dims', int, by default the number of input core dims of the last input
            - 'n_output', int, by default 1
            - 'n_input', int, by default len(datasets)
            - 'index', slice, by default Ellipsis
            - 'ravel', bool, by default True
    func_args : tuple
        Arguments passed to 'ufunc'.
    func_kwargs : dict
        Keyword arguments passed to 'ufunc'.
            - 'out_shape', tuple, by default None
    dask_kwargs : dict
        Dask related kwargs passed to :func:`xarray:xarray.apply_ufunc`.
        Use ``enable_dask`` method of :class:`arviz.Dask` to set default kwargs.
    **kwargs
        Passed to :func:`xarray.apply_ufunc`. ``input_core_dims`` defaults to
        ``("chain", "draw")`` for every input and ``output_core_dims`` to empty
        core dims for every output.

    Returns
    -------
    xarray.Dataset
    """
    # Copy so the ``setdefault`` calls below never leak defaults into a dict the
    # caller owns (and may reuse with a different number of datasets).
    ufunc_kwargs = {} if ufunc_kwargs is None else dict(ufunc_kwargs)
    ufunc_kwargs.setdefault("n_input", len(datasets))
    if func_args is None:
        func_args = tuple()
    if func_kwargs is None:
        func_kwargs = {}
    if dask_kwargs is None:
        dask_kwargs = {}

    kwargs.setdefault(
        "input_core_dims", tuple(("chain", "draw") for _ in range(len(func_args) + len(datasets)))
    )
    ufunc_kwargs.setdefault("n_dims", len(kwargs["input_core_dims"][-1]))
    kwargs.setdefault("output_core_dims", tuple([] for _ in range(ufunc_kwargs.get("n_output", 1))))

    callable_ufunc = make_ufunc(ufunc, **ufunc_kwargs)

    return apply_ufunc(
        callable_ufunc, *datasets, *func_args, kwargs=func_kwargs, **dask_kwargs, **kwargs
    )
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
def update_docstring(ufunc, func, n_output=1):
    """Update ArviZ generated ufunc docstring.

    Appends to ``ufunc.__doc__`` a note on how to call the ufunc through
    ``xr.apply_ufunc``, followed by the docstring of ``func`` when it has one.

    Parameters
    ----------
    ufunc : callable
        Generated ufunc whose ``__doc__`` is extended in place.
    func : callable
        Wrapped function. Its ``__module__``, ``__name__`` and ``__doc__`` are used
        when available.
    n_output : int, optional
        Number of outputs of ``func``. When larger than 1, the example call also
        shows ``output_core_dims``.
    """
    module = getattr(func, "__module__", None)
    module = module if isinstance(module, str) else ""
    name = getattr(func, "__name__", "")
    docstring = getattr(func, "__doc__", None)
    docstring = docstring if isinstance(docstring, str) else ""
    # Join only the non-empty parts so a missing module does not yield ".name".
    qualname = ".".join(part for part in (module, name) if part)

    # ``__doc__`` may be None (e.g. under ``python -OO``); start from an empty string.
    parts = [ufunc.__doc__ or "", "\n\n"]
    if qualname:
        parts.append(f"This function is a ufunc wrapper for {qualname}\n")
    parts.append('Call ufunc with n_args from xarray against "chain" and "draw" dimensions:')
    parts.append("\n\n")
    input_core_dims = 'tuple(("chain", "draw") for _ in range(n_args))'
    msg = f"xr.apply_ufunc(ufunc, dataset, input_core_dims={input_core_dims}"
    if n_output > 1:
        msg += f", output_core_dims=tuple([] for _ in range({n_output}))"
    msg += ")"
    parts.append(msg)
    parts.append("\n\n")
    parts.append("For example: np.std(data, ddof=1) --> n_args=2")
    if docstring:
        parts.extend(["\n\n", qualname, " docstring:", "\n\n", docstring])
    ufunc.__doc__ = "".join(parts)
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
def logsumexp(ary, *, b=None, b_inv=None, axis=None, keepdims=False, out=None, copy=True):
    """Stable logsumexp when b >= 0 and b is scalar.

    b_inv overwrites b unless b_inv is None.

    Computes ``log(b * sum(exp(ary)))`` (or ``log(sum(exp(ary)) / b_inv)``) along
    ``axis``, subtracting the maximum before exponentiating to avoid overflow.

    Parameters
    ----------
    ary : array_like
        Input values. Integer, unsigned and boolean input is promoted to float64.
    b, b_inv : scalar, optional
        Scale factor, or its inverse. ``b_inv`` takes precedence when not None.
        ``b_inv == 0`` returns +inf and ``b == 0`` returns -inf immediately.
    axis : int or sequence of int, optional
        Axis or axes to reduce over. All axes when None.
    keepdims : bool, default False
        Keep the reduced axes with length one.
    out : ndarray, optional
        Output array; allocated when None.
    copy : bool, default True
        If False, a float ``ary`` is used as scratch space and overwritten.

    Returns
    -------
    ndarray or scalar
        The result array, or a scalar of the input float type when it is 0-d.
    """
    # check dimensions for result arrays
    ary = np.asarray(ary)
    # exp() results are written in place, so non-float input must be promoted.
    if ary.dtype.kind in "iub":
        ary = ary.astype(np.float64)
    dtype = ary.dtype.type
    shape = ary.shape
    shape_len = len(shape)
    if isinstance(axis, Sequence):
        axis = tuple(axis_i if axis_i >= 0 else shape_len + axis_i for axis_i in axis)
        agroup = axis
    else:
        axis = axis if (axis is None) or (axis >= 0) else shape_len + axis
        agroup = (axis,)
    shape_max = (
        tuple(1 for _ in shape)
        if axis is None
        else tuple(1 if i in agroup else d for i, d in enumerate(shape))
    )
    # create result arrays
    if out is None:
        if not keepdims:
            out_shape = (
                tuple()
                if axis is None
                else tuple(d for i, d in enumerate(shape) if i not in agroup)
            )
        else:
            out_shape = shape_max
        out = np.empty(out_shape, dtype=dtype)
    if b_inv == 0:
        return np.full_like(out, np.inf, dtype=dtype) if out.shape else np.inf
    if b_inv is None and b == 0:
        return np.full_like(out, -np.inf) if out.shape else -np.inf
    ary_max = np.empty(shape_max, dtype=dtype)
    # calculations
    ary.max(axis=axis, keepdims=True, out=ary_max)
    # A non-finite maximum (all -inf, or a +inf present) would make ``ary - ary_max``
    # NaN; shifting by 0 instead yields the correct -inf / +inf result.
    np.copyto(ary_max, 0, where=~np.isfinite(ary_max))
    if copy:
        ary = ary.copy()
    ary -= ary_max
    np.exp(ary, out=ary)
    ary.sum(axis=axis, keepdims=keepdims, out=out)
    # log(0) -> -inf is the intended result for all -inf inputs; silence the warning.
    with np.errstate(divide="ignore"):
        np.log(out, out=out)
    if b_inv is not None:
        ary_max -= np.log(b_inv)
    elif b:
        ary_max += np.log(b)
    out += ary_max if keepdims else ary_max.squeeze()
    # transform to scalar if possible
    return out if out.shape else dtype(out)
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
def quantile(ary, q, axis=None, limit=None):
    """Use same quantile function as R (Type 7).

    Thin wrapper over ``scipy.stats.mstats.mquantiles`` with the plotting
    positions ``alphap=1, betap=1``; ``limit=None`` means no value limits.
    """
    value_limits = tuple() if limit is None else limit
    return mquantiles(ary, q, alphap=1, betap=1, axis=axis, limit=value_limits)
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwargs=None):
    """Validate ndarray.

    Parameters
    ----------
    ary : numpy.ndarray
    check_nan : bool
        Check if any value contains NaN.
    check_shape : bool
        Check if array has correct shape. Assumes dimensions in order (chain, draw, *shape).
        For 1D arrays (shape = (n,)) assumes chain equals 1.
    nan_kwargs : dict
        Valid kwargs are:
            axis : int,
                Defaults to None.
            how : str, {"all", "any"}
                Default to "any".
    shape_kwargs : dict
        Valid kwargs are:
            min_chains : int
                Defaults to 2.
            min_draws : int
                Defaults to 4.

    Returns
    -------
    bool
        True if any enabled check failed. When ``nan_kwargs["axis"]`` is given the
        NaN check is per-slice and a boolean array is returned instead.
    """
    ary = np.asarray(ary)

    nan_error = False
    draw_error = False
    chain_error = False

    if check_nan:
        if nan_kwargs is None:
            nan_kwargs = {}

        isnan = np.isnan(ary)
        axis = nan_kwargs.get("axis", None)
        if nan_kwargs.get("how", "any").lower() == "all":
            nan_error = isnan.all(axis)
        else:
            nan_error = isnan.any(axis)

        # Works for both a scalar result and a per-slice boolean array.
        if np.any(nan_error):
            _log.warning("Array contains NaN-value.")

    if check_shape:
        shape = ary.shape

        if shape_kwargs is None:
            shape_kwargs = {}

        min_chains = shape_kwargs.get("min_chains", 2)
        min_draws = shape_kwargs.get("min_draws", 4)
        error_msg = f"Shape validation failed: input_shape: {shape}, "
        error_msg += f"minimum_shape: (chains={min_chains}, draws={min_draws})"

        if not shape:
            # A 0-d array has neither chains nor draws (``shape[0]`` would raise).
            chain_error = draw_error = True
        else:
            chain_error = ((min_chains > 1) and (len(shape) < 2)) or (shape[0] < min_chains)
            draw_error = ((len(shape) < 2) and (shape[0] < min_draws)) or (
                (len(shape) > 1) and (shape[1] < min_draws)
            )

        if chain_error or draw_error:
            _log.warning(error_msg)

    return nan_error | chain_error | draw_error
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
def get_log_likelihood(idata, var_name=None, single_var=True):
    """Retrieve the log likelihood dataarray of a given variable.

    Parameters
    ----------
    idata : InferenceData-like
        Object expected to expose a ``log_likelihood`` group (or, deprecated,
        ``sample_stats.log_likelihood``).
    var_name : str, optional
        Name of the log likelihood variable. When None, the only variable is
        returned, or all of them when ``single_var`` is False.
    single_var : bool, default True
        Raise instead of returning several variables when ``var_name`` is None.

    Returns
    -------
    ``idata.log_likelihood[...]`` for the selected variable(s), or
    ``idata.sample_stats.log_likelihood`` in the deprecated layout.

    Raises
    ------
    TypeError
        If no log likelihood is found, ``var_name`` does not exist, or several
        variables exist while ``single_var`` is True and ``var_name`` is None.
    """
    if (
        not hasattr(idata, "log_likelihood")
        and hasattr(idata, "sample_stats")
        and hasattr(idata.sample_stats, "log_likelihood")
    ):
        # stacklevel=2 attributes the warning to the caller instead of this helper.
        warnings.warn(
            "Storing the log_likelihood in sample_stats groups has been deprecated",
            DeprecationWarning,
            stacklevel=2,
        )
        return idata.sample_stats.log_likelihood
    if not hasattr(idata, "log_likelihood"):
        raise TypeError("log likelihood not found in inference data object")
    if var_name is not None:
        try:
            return idata.log_likelihood[var_name]
        except KeyError as err:
            raise TypeError(f"No log likelihood data named {var_name} found") from err

    var_names = list(idata.log_likelihood.data_vars)
    if not var_names:
        # Keep the documented TypeError instead of an IndexError on ``var_names[0]``.
        raise TypeError("No log likelihood data found in inference data object")
    if len(var_names) > 1:
        if single_var:
            raise TypeError(
                f"Found several log likelihood arrays {var_names}, var_name cannot be None"
            )
        return idata.log_likelihood[var_names]
    return idata.log_likelihood[var_names[0]]
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
# Summary header template for ELPDData.__str__. It is formatted twice: first with
# the column widths (positional ``{0}`` and ``{1}``), then with the values; the
# doubled braces survive the first pass as single-brace fields for the second.
BASE_FMT = """Computed from {{n_samples}} posterior samples and \
{{n_points}} observations log-likelihood matrix.

{{0:{0}}} Estimate SE
{{scale}}_{{kind}} {{1:8.2f}} {{2:7.2f}}
p_{{kind:{1}}} {{3:8.2f}} -"""
# Pareto k table appended to LOO summaries. Also formatted twice: first with the
# width of the count column, then with the headers, the bin counts, the bin
# percentages and the "good k" threshold (field 8).
POINTWISE_LOO_FMT = """------

Pareto k diagnostic values:
{{0:>{0}}} {{1:>6}}
(-Inf, {{8:.2f}}] (good) {{2:{0}d}} {{5:6.1f}}%
({{8:.2f}}, 1] (bad) {{3:{0}d}} {{6:6.1f}}%
(1, Inf) (very bad) {{4:{0}d}} {{7:6.1f}}%
"""
# Maps the stored ``scale`` value to the label printed in the summary header.
SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
class ELPDData(pd.Series):  # pylint: disable=too-many-ancestors
    """Class to contain the data from elpd information criterion like waic or loo."""

    def __str__(self):
        """Print elpd data in a user friendly way.

        Raises
        ------
        ValueError
            If the first index label does not name a "loo" or "waic" criterion.
        """
        # The first label is expected to look like "<scale>_<kind>" (e.g. "elpd_loo").
        # Malformed labels (no "_") or an empty series fall through to the
        # ValueError below instead of raising an IndexError.
        label_parts = str(self.index[0]).split("_") if len(self.index) else []
        kind = label_parts[1] if len(label_parts) > 1 else ""

        if kind not in ("loo", "waic"):
            raise ValueError("Invalid ELPDData object")

        scale_str = SCALE_DICT[self["scale"]]
        padding = len(scale_str) + len(kind) + 1
        # First pass fixes the column widths, second pass fills in the values.
        base = BASE_FMT.format(padding, padding - 2)
        base = base.format(
            "",
            kind=kind,
            scale=scale_str,
            n_samples=self.n_samples,
            n_points=self.n_data_points,
            *self.values,
        )

        if self.warning:
            base += "\n\nThere has been a warning during the calculation. Please check the results."

        if kind == "loo" and "pareto_k" in self:
            # Bin Pareto k values into good / bad / very bad using the stored threshold.
            bins = np.asarray([-np.inf, self.good_k, 1, np.inf])
            counts, *_ = _histogram(self.pareto_k.values, bins)
            extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
            extended = extended.format(
                "Count",
                "Pct.",
                *[*counts, *(counts / np.sum(counts) * 100)],
                self.good_k,
            )
            base = "\n".join([base, extended])
        return base

    def __repr__(self):
        """Alias to ``__str__``."""
        return self.__str__()

    def copy(self, deep=True):  # pylint:disable=overridden-final-method
        """Perform a pandas deep copy of the ELPDData plus a copy of the stored data.

        The Series itself is always copied deeply so that replacing entries below
        cannot affect ``self``; ``deep`` only selects deep or shallow copies of
        the stored values.
        """
        copied_obj = pd.Series.copy(self)
        for key in copied_obj.keys():
            if deep:
                copied_obj[key] = _deepcopy(copied_obj[key])
            else:
                copied_obj[key] = _copy(copied_obj[key])
        return ELPDData(copied_obj)
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
@conditional_jit(nopython=True)
def stats_variance_1d(data, ddof=0):
    """Compute the variance of a 1D array with ``ddof`` delta degrees of freedom.

    Uses two passes (mean first, then the sum of squared deviations). This avoids
    the catastrophic cancellation of the one-pass ``E[x^2] - E[x]^2`` formula when
    the mean is large relative to the spread, e.g. MCMC draws far from zero.
    Plain loops are kept so the function can still be compiled by
    ``conditional_jit(nopython=True)``.
    """
    n_obs = len(data)
    mean = 0.0
    for value in data:
        mean += value
    mean /= n_obs
    sq_dev = 0.0
    for value in data:
        sq_dev += (value - mean) * (value - mean)
    return sq_dev / (n_obs - ddof)
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
def stats_variance_2d(data, ddof=0, axis=1):
    """Compute variances of a 2D array with ``stats_variance_1d``.

    1D input is forwarded directly to ``stats_variance_1d``. With ``axis=1`` one
    variance per row is returned; any other ``axis`` value gives one per column.
    """
    if data.ndim == 1:
        return stats_variance_1d(data, ddof=ddof)
    n_rows, n_cols = data.shape
    if axis != 1:
        col_var = np.zeros(n_cols)
        for col in range(n_cols):
            col_var[col] = stats_variance_1d(data[:, col], ddof=ddof)
        return col_var
    row_var = np.zeros(n_rows)
    for row in range(n_rows):
        row_var[row] = stats_variance_1d(data[row], ddof=ddof)
    return row_var
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
@conditional_vect
def _sqrt(a_a, b_b):
    """Return ``(a_a + b_b) ** 0.5``."""
    total = a_a + b_b
    return total**0.5
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
def _circfunc(samples, high, low, skipna):
|
|
550
|
-
samples = np.asarray(samples)
|
|
551
|
-
if skipna:
|
|
552
|
-
samples = samples[~np.isnan(samples)]
|
|
553
|
-
if samples.size == 0:
|
|
554
|
-
return np.nan
|
|
555
|
-
return _angle(samples, low, high, np.pi)
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
@conditional_vect
def _angle(samples, low, high, p_i=np.pi):
    """Rescale ``samples`` linearly so that the interval ``high - low`` spans ``2 * p_i``."""
    span = high - low
    return (samples - low) * 2.0 * p_i / span
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
def _circular_standard_deviation(samples, high=2 * np.pi, low=0, skipna=False, axis=None):
    """Compute the circular standard deviation of ``samples``.

    Samples are converted to angles with ``_circfunc``. The mean resultant length
    ``R`` of those angles gives ``sqrt(-2 * log(R))``, which is rescaled back to
    the units of the ``[low, high]`` interval.
    """
    ang = _circfunc(samples, high, low, skipna)
    s_s = np.sin(ang).mean(axis=axis)
    c_c = np.cos(ang).mean(axis=axis)
    # Rounding can push R to 1 + eps for (nearly) identical angles; log(R) would
    # then be positive and the sqrt NaN instead of ~0, so cap R at 1.
    r_r = np.minimum(np.hypot(s_s, c_c), 1.0)
    return ((high - low) / 2.0 / np.pi) * np.sqrt(-2 * np.log(r_r))
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
def smooth_data(obs_vals, pp_vals):
    """Smooth data using a cubic spline.

    Helper function for discrete data in plot_bpv, loo_pit and plot_loo_pit.

    Each series is interpolated with a cubic spline through evenly spaced points
    in ``[0, 1]`` and re-evaluated on the same number of evenly spaced points in
    ``[0.01, 0.99]``.

    Parameters
    ----------
    obs_vals : (N) array-like
        Observed data
    pp_vals : (S, N) array-like
        Posterior predictive samples. ``N`` is the number of observations,
        and ``S`` is the number of samples (generally n_chains*n_draws).

    Returns
    -------
    obs_vals : (N) ndarray
        Smoothed observed data
    pp_vals : (S, N) ndarray
        Smoothed posterior predictive samples
    """
    # Accept any array-like as documented; ``pp_vals.shape`` is needed below.
    obs_vals = np.asarray(obs_vals)
    pp_vals = np.asarray(pp_vals)

    x = np.linspace(0, 1, len(obs_vals))
    csi = CubicSpline(x, obs_vals)
    obs_vals = csi(np.linspace(0.01, 0.99, len(obs_vals)))

    x = np.linspace(0, 1, pp_vals.shape[1])
    csi = CubicSpline(x, pp_vals, axis=1)
    pp_vals = csi(np.linspace(0.01, 0.99, pp_vals.shape[1]))

    return obs_vals, pp_vals
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
def get_log_prior(idata, var_names=None):
    """Retrieve the log prior dataarray of a given variable.

    Parameters
    ----------
    idata : InferenceData-like
        Object expected to expose a ``log_prior`` group.
    var_names : optional
        Selection passed to ``idata.log_prior[...]``; defaults to every data
        variable of the group.

    Raises
    ------
    TypeError
        If ``idata`` has no ``log_prior`` attribute.
    """
    if not hasattr(idata, "log_prior"):
        raise TypeError("log prior not found in inference data object")
    group = idata.log_prior
    selection = list(group.data_vars) if var_names is None else var_names
    return group[selection]
|
arviz/tests/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
"""Test suite."""
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
"""Base test suite."""
|