arviz 0.16.1__py3-none-any.whl → 0.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +1 -1
- arviz/data/inference_data.py +34 -7
- arviz/data/io_beanmachine.py +6 -1
- arviz/data/io_cmdstanpy.py +439 -50
- arviz/data/io_pyjags.py +5 -2
- arviz/data/io_pystan.py +1 -2
- arviz/labels.py +2 -0
- arviz/plots/backends/bokeh/bpvplot.py +7 -2
- arviz/plots/backends/bokeh/compareplot.py +7 -4
- arviz/plots/backends/bokeh/densityplot.py +0 -1
- arviz/plots/backends/bokeh/distplot.py +0 -2
- arviz/plots/backends/bokeh/forestplot.py +3 -5
- arviz/plots/backends/bokeh/kdeplot.py +0 -2
- arviz/plots/backends/bokeh/pairplot.py +0 -4
- arviz/plots/backends/matplotlib/bfplot.py +0 -1
- arviz/plots/backends/matplotlib/bpvplot.py +3 -3
- arviz/plots/backends/matplotlib/compareplot.py +1 -1
- arviz/plots/backends/matplotlib/dotplot.py +1 -1
- arviz/plots/backends/matplotlib/forestplot.py +2 -4
- arviz/plots/backends/matplotlib/kdeplot.py +0 -1
- arviz/plots/backends/matplotlib/khatplot.py +0 -1
- arviz/plots/backends/matplotlib/lmplot.py +4 -5
- arviz/plots/backends/matplotlib/pairplot.py +0 -1
- arviz/plots/backends/matplotlib/ppcplot.py +8 -5
- arviz/plots/backends/matplotlib/traceplot.py +1 -2
- arviz/plots/bfplot.py +7 -6
- arviz/plots/bpvplot.py +7 -2
- arviz/plots/compareplot.py +2 -2
- arviz/plots/ecdfplot.py +37 -112
- arviz/plots/elpdplot.py +1 -1
- arviz/plots/essplot.py +2 -2
- arviz/plots/kdeplot.py +0 -1
- arviz/plots/pairplot.py +1 -1
- arviz/plots/plot_utils.py +0 -1
- arviz/plots/ppcplot.py +51 -45
- arviz/plots/separationplot.py +0 -1
- arviz/stats/__init__.py +2 -0
- arviz/stats/density_utils.py +2 -2
- arviz/stats/diagnostics.py +2 -3
- arviz/stats/ecdf_utils.py +165 -0
- arviz/stats/stats.py +241 -38
- arviz/stats/stats_utils.py +36 -7
- arviz/tests/base_tests/test_data.py +73 -5
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1
- arviz/tests/base_tests/test_plots_matplotlib.py +24 -1
- arviz/tests/base_tests/test_stats.py +43 -1
- arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
- arviz/tests/base_tests/test_stats_utils.py +3 -3
- arviz/tests/external_tests/test_data_beanmachine.py +2 -0
- arviz/tests/external_tests/test_data_numpyro.py +3 -3
- arviz/tests/external_tests/test_data_pyjags.py +3 -1
- arviz/tests/external_tests/test_data_pyro.py +3 -3
- arviz/tests/helpers.py +8 -8
- arviz/utils.py +15 -7
- arviz/wrappers/wrap_pymc.py +1 -1
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/METADATA +16 -15
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/RECORD +60 -58
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/WHEEL +1 -1
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/LICENSE +0 -0
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/top_level.txt +0 -0
arviz/stats/stats.py
CHANGED
@@ -30,6 +30,7 @@ from .density_utils import kde as _kde
 from .diagnostics import _mc_error, _multichain_statistics, ess
 from .stats_utils import ELPDData, _circular_standard_deviation, smooth_data
 from .stats_utils import get_log_likelihood as _get_log_likelihood
+from .stats_utils import get_log_prior as _get_log_prior
 from .stats_utils import logsumexp as _logsumexp
 from .stats_utils import make_ufunc as _make_ufunc
 from .stats_utils import stats_variance_2d as svar
@@ -51,6 +52,7 @@ __all__ = [
     "waic",
     "weight_predictions",
     "_calculate_ics",
+    "psens",
 ]


@@ -144,6 +146,7 @@ def compare(
     Compare the centered and non centered models of the eight school problem:

     .. ipython::
+        :okwarning:

         In [1]: import arviz as az
            ...: data1 = az.load_arviz_data("non_centered_eight")
@@ -155,6 +158,7 @@ def compare(
     weights using the stacking method.

     .. ipython::
+        :okwarning:

         In [1]: az.compare(compare_dict, ic="loo", method="stacking", scale="log")

@@ -178,37 +182,19 @@ def compare(
     except Exception as e:
         raise e.__class__("Encountered error in ELPD computation of compare.") from e
     names = list(ics_dict.keys())
-    if ic == "loo":
-        df_comp = pd.DataFrame(
-            index=names,
-            columns=[
-                "rank",
-                "elpd_loo",
-                "p_loo",
-                "elpd_diff",
-                "weight",
-                "se",
-                "dse",
-                "warning",
-                "scale",
-            ],
-            dtype=np.float_,
-        )
-    elif ic == "waic":
+    if ic in {"loo", "waic"}:
         df_comp = pd.DataFrame(
-            index=names,
-            columns=[
-                "rank",
-                "elpd_waic",
-                "p_waic",
-                "elpd_diff",
-                "weight",
-                "se",
-                "dse",
-                "warning",
-                "scale",
-            ],
-            dtype=np.float_,
+            {
+                "rank": pd.Series(index=names, dtype="int"),
+                f"elpd_{ic}": pd.Series(index=names, dtype="float"),
+                f"p_{ic}": pd.Series(index=names, dtype="float"),
+                "elpd_diff": pd.Series(index=names, dtype="float"),
+                "weight": pd.Series(index=names, dtype="float"),
+                "se": pd.Series(index=names, dtype="float"),
+                "dse": pd.Series(index=names, dtype="float"),
+                "warning": pd.Series(index=names, dtype="boolean"),
+                "scale": pd.Series(index=names, dtype="str"),
+            }
         )
     else:
         raise NotImplementedError(f"The information criterion {ic} is not supported.")
@@ -630,7 +616,7 @@ def _hdi(ary, hdi_prob, circular, skipna):
     ary = np.sort(ary)
     interval_idx_inc = int(np.floor(hdi_prob * n))
     n_intervals = n - interval_idx_inc
-    interval_width = np.subtract(ary[interval_idx_inc:], ary[:n_intervals], dtype=np.float_)
+    interval_width = np.subtract(ary[interval_idx_inc:], ary[:n_intervals], dtype=np.float64)

     if len(interval_width) == 0:
         raise ValueError("Too few elements for interval calculation. ")
@@ -878,17 +864,18 @@ def psislw(log_weights, reff=1.0):

     Parameters
     ----------
-    log_weights: array
+    log_weights : DataArray or (..., N) array-like
         Array of size (n_observations, n_samples)
-    reff: float
+    reff : float, default 1
         relative MCMC efficiency, ``ess / n``

     Returns
     -------
-    lw_out:
-        Smoothed log weights
-    kss:
-        Pareto tail indices
+    lw_out : DataArray or (..., N) ndarray
+        Smoothed, truncated and normalized log weights.
+    kss : DataArray or (...) ndarray
+        Estimates of the shape parameter *k* of the generalized Pareto
+        distribution.

     References
     ----------
@@ -2093,7 +2080,7 @@ def weight_predictions(idatas, weights=None):
     weights /= weights.sum()

     len_idatas = [
-        idata.posterior_predictive.dims["chain"] * idata.posterior_predictive.dims["draw"]
+        idata.posterior_predictive.sizes["chain"] * idata.posterior_predictive.sizes["draw"]
         for idata in idatas
     ]

@@ -2113,3 +2100,219 @@ def weight_predictions(idatas, weights=None):
     )

     return weighted_samples
+
+
+def psens(
+    data,
+    *,
+    component="prior",
+    component_var_names=None,
+    component_coords=None,
+    var_names=None,
+    coords=None,
+    filter_vars=None,
+    delta=0.01,
+    dask_kwargs=None,
+):
+    """Compute power-scaling sensitivity diagnostic.
+
+    Power-scales the prior or likelihood and calculates how much the posterior is affected.
+
+    Parameters
+    ----------
+    data : obj
+        Any object that can be converted to an :class:`arviz.InferenceData` object.
+        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
+        For ndarray: shape = (chain, draw).
+        For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
+    component : {"prior", "likelihood"}, default "prior"
+        When `component` is "likelihood", the log likelihood values are retrieved
+        from the ``log_likelihood`` group as pointwise log likelihood and added
+        together. With "prior", the log prior values are retrieved from the
+        ``log_prior`` group.
+    component_var_names : str, optional
+        Name of the prior or log likelihood variables to use
+    component_coords : dict, optional
+        Coordinates defining a subset over the component element for which to
+        compute the prior sensitivity diagnostic.
+    var_names : list of str, optional
+        Names of posterior variables to include in the power scaling sensitivity diagnostic
+    coords : dict, optional
+        Coordinates defining a subset over the posterior. Only these variables will
+        be used when computing the prior sensitivity.
+    filter_vars: {None, "like", "regex"}, default None
+        If ``None`` (default), interpret var_names as the real variables names.
+        If "like", interpret var_names as substrings of the real variables names.
+        If "regex", interpret var_names as regular expressions on the real variables names.
+    delta : float
+        Value for finite difference derivative calculation.
+    dask_kwargs : dict, optional
+        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
+
+    Returns
+    -------
+    xarray.Dataset
+        Returns dataset of power-scaling sensitivity diagnostic values.
+        Higher sensitivity values indicate greater sensitivity.
+        Prior sensitivity above 0.05 indicates informative prior.
+        Likelihood sensitivity below 0.05 indicates weak or noninformative likelihood.
+
+    Examples
+    --------
+    Compute the likelihood sensitivity for the non centered eight model:
+
+    .. ipython::
+
+        In [1]: import arviz as az
+           ...: data = az.load_arviz_data("non_centered_eight")
+           ...: az.psens(data, component="likelihood")
+
+    To compute the prior sensitivity, we need to first compute the log prior
+    at each posterior sample. In our case, we know mu has a normal prior :math:`N(0, 5)`,
+    tau is a half cauchy prior with scale/beta parameter 5,
+    and theta has a standard normal as prior.
+    We add this information to the ``log_prior`` group before computing powerscaling
+    check with ``psens``
+
+    .. ipython::
+
+        In [1]: from xarray_einstats.stats import XrContinuousRV
+           ...: from scipy.stats import norm, halfcauchy
+           ...: post = data.posterior
+           ...: log_prior = {
+           ...:     "mu": XrContinuousRV(norm, 0, 5).logpdf(post["mu"]),
+           ...:     "tau": XrContinuousRV(halfcauchy, scale=5).logpdf(post["tau"]),
+           ...:     "theta_t": XrContinuousRV(norm, 0, 1).logpdf(post["theta_t"]),
+           ...: }
+           ...: data.add_groups({"log_prior": log_prior})
+           ...: az.psens(data, component="prior")
+
+    Notes
+    -----
+    The diagnostic is computed by power-scaling the specified component (prior or likelihood)
+    and determining the degree to which the posterior changes as described in [1]_.
+    It uses Pareto-smoothed importance sampling to avoid refitting the model.
+
+    References
+    ----------
+    .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
+       power-scaling*, 2022, https://arxiv.org/abs/2107.14054
+
+    """
+    dataset = extract(data, var_names=var_names, filter_vars=filter_vars, group="posterior")
+    if coords is None:
+        dataset = dataset.sel(coords)
+
+    if component == "likelihood":
+        component_draws = _get_log_likelihood(data, var_name=component_var_names, single_var=False)
+    elif component == "prior":
+        component_draws = _get_log_prior(data, var_names=component_var_names)
+    else:
+        raise ValueError("Value for `component` argument not recognized")
+
+    component_draws = component_draws.stack(__sample__=("chain", "draw"))
+    if component_coords is None:
+        component_draws = component_draws.sel(component_coords)
+    if isinstance(component_draws, xr.DataArray):
+        component_draws = component_draws.to_dataset()
+    if len(component_draws.dims):
+        component_draws = component_draws.to_stacked_array(
+            "latent-obs_var", sample_dims=("__sample__",)
+        ).sum("latent-obs_var")
+    # from here component_draws is a 1d object with dimensions (sample,)
+
+    # calculate lower and upper alpha values
+    lower_alpha = 1 / (1 + delta)
+    upper_alpha = 1 + delta
+
+    # calculate importance sampling weights for lower and upper alpha power-scaling
+    lower_w = np.exp(_powerscale_lw(component_draws=component_draws, alpha=lower_alpha))
+    lower_w = lower_w / np.sum(lower_w)
+
+    upper_w = np.exp(_powerscale_lw(component_draws=component_draws, alpha=upper_alpha))
+    upper_w = upper_w / np.sum(upper_w)
+
+    ufunc_kwargs = {"n_dims": 1, "ravel": False}
+    func_kwargs = {"lower_weights": lower_w.values, "upper_weights": upper_w.values, "delta": delta}
+
+    # calculate the sensitivity diagnostic based on the importance weights and draws
+    return _wrap_xarray_ufunc(
+        _powerscale_sens,
+        dataset,
+        ufunc_kwargs=ufunc_kwargs,
+        func_kwargs=func_kwargs,
+        dask_kwargs=dask_kwargs,
+        input_core_dims=[["sample"]],
+    )
+
+
+def _powerscale_sens(draws, *, lower_weights=None, upper_weights=None, delta=0.01):
+    """
+    Calculate power-scaling sensitivity by finite difference
+    second derivative of CJS
+    """
+    lower_cjs = max(
+        _cjs_dist(draws=draws, weights=lower_weights),
+        _cjs_dist(draws=-1 * draws, weights=lower_weights),
+    )
+    upper_cjs = max(
+        _cjs_dist(draws=draws, weights=upper_weights),
+        _cjs_dist(draws=-1 * draws, weights=upper_weights),
+    )
+    logdiffsquare = 2 * np.log2(1 + delta)
+    grad = (lower_cjs + upper_cjs) / logdiffsquare
+
+    return grad
+
+
+def _powerscale_lw(alpha, component_draws):
+    """
+    Calculate log weights for power-scaling component by alpha.
+    """
+    log_weights = (alpha - 1) * component_draws
+    log_weights = psislw(log_weights)[0]
+
+    return log_weights
+
+
+def _cjs_dist(draws, weights):
+    """
+    Calculate the cumulative Jensen-Shannon distance between original draws and weighted draws.
+    """
+
+    # sort draws and weights
+    order = np.argsort(draws)
+    draws = draws[order]
+    weights = weights[order]
+
+    binwidth = np.diff(draws)
+
+    # ecdfs
+    cdf_p = np.linspace(1 / len(draws), 1 - 1 / len(draws), len(draws) - 1)
+    cdf_q = np.cumsum(weights / np.sum(weights))[:-1]
+
+    # integrals of ecdfs
+    cdf_p_int = np.dot(cdf_p, binwidth)
+    cdf_q_int = np.dot(cdf_q, binwidth)
+
+    # cjs calculation
+    pq_numer = np.log2(cdf_p, out=np.zeros_like(cdf_p), where=cdf_p != 0)
+    qp_numer = np.log2(cdf_q, out=np.zeros_like(cdf_q), where=cdf_q != 0)
+
+    denom = 0.5 * (cdf_p + cdf_q)
+    denom = np.log2(denom, out=np.zeros_like(denom), where=denom != 0)
+
+    cjs_pq = np.sum(binwidth * (cdf_p * (pq_numer - denom))) + 0.5 / np.log(2) * (
+        cdf_q_int - cdf_p_int
+    )
+
+    cjs_qp = np.sum(binwidth * (cdf_q * (qp_numer - denom))) + 0.5 / np.log(2) * (
+        cdf_p_int - cdf_q_int
+    )
+
+    cjs_pq = max(0, cjs_pq)
+    cjs_qp = max(0, cjs_qp)
+
+    bound = cdf_p_int + cdf_q_int
+
+    return np.sqrt((cjs_pq + cjs_qp) / bound)
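
For reference (not part of the diff): the new psens entry point added above can be exercised roughly as in its docstring example. A minimal sketch, assuming the bundled "non_centered_eight" example dataset and the optional xarray_einstats dependency are available:

    import arviz as az
    from scipy.stats import halfcauchy, norm
    from xarray_einstats.stats import XrContinuousRV

    data = az.load_arviz_data("non_centered_eight")

    # likelihood sensitivity only needs the log_likelihood group already stored in the dataset
    likelihood_sens = az.psens(data, component="likelihood")

    # prior sensitivity needs pointwise log prior values, added as a "log_prior" group
    post = data.posterior
    log_prior = {
        "mu": XrContinuousRV(norm, 0, 5).logpdf(post["mu"]),
        "tau": XrContinuousRV(halfcauchy, scale=5).logpdf(post["tau"]),
        "theta_t": XrContinuousRV(norm, 0, 1).logpdf(post["theta_t"]),
    }
    data.add_groups({"log_prior": log_prior})
    prior_sens = az.psens(data, component="prior")

Both calls return an xarray.Dataset of per-variable sensitivity values; per the docstring, prior sensitivity above 0.05 suggests an informative prior.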
arviz/stats/stats_utils.py
CHANGED
@@ -16,7 +16,7 @@ from ..utils import conditional_jit, conditional_vect, conditional_dask
 from .density_utils import histogram as _histogram


-__all__ = ["autocorr", "autocov", "ELPDData", "make_ufunc", "wrap_xarray_ufunc"]
+__all__ = ["autocorr", "autocov", "ELPDData", "make_ufunc", "smooth_data", "wrap_xarray_ufunc"]


 def autocov(ary, axis=-1):
@@ -409,7 +409,7 @@ def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwar
     return nan_error | chain_error | draw_error


-def get_log_likelihood(idata, var_name=None):
+def get_log_likelihood(idata, var_name=None, single_var=True):
     """Retrieve the log likelihood dataarray of a given variable."""
     if (
         not hasattr(idata, "log_likelihood")
@@ -426,9 +426,11 @@ def get_log_likelihood(idata, var_name=None):
     if var_name is None:
         var_names = list(idata.log_likelihood.data_vars)
         if len(var_names) > 1:
-            raise TypeError(
-                f"Found several log likelihood arrays {var_names}, var_name cannot be None"
-            )
+            if single_var:
+                raise TypeError(
+                    f"Found several log likelihood arrays {var_names}, var_name cannot be None"
+                )
+            return idata.log_likelihood[var_names]
         return idata.log_likelihood[var_names[0]]
     else:
         try:
@@ -482,7 +484,7 @@ class ELPDData(pd.Series): # pylint: disable=too-many-ancestors
         base += "\n\nThere has been a warning during the calculation. Please check the results."

         if kind == "loo" and "pareto_k" in self:
-            bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf])
+            bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
             counts, *_ = _histogram(self.pareto_k.values, bins)
             extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
             extended = extended.format(
@@ -562,7 +564,25 @@ def _circular_standard_deviation(samples, high=2 * np.pi, low=0, skipna=False, a


 def smooth_data(obs_vals, pp_vals):
-    """Smooth data, helper function for discrete data in plot_pbv, loo_pit and plot_loo_pit."""
+    """Smooth data using a cubic spline.
+
+    Helper function for discrete data in plot_pbv, loo_pit and plot_loo_pit.
+
+    Parameters
+    ----------
+    obs_vals : (N) array-like
+        Observed data
+    pp_vals : (S, N) array-like
+        Posterior predictive samples. ``N`` is the number of observations,
+        and ``S`` is the number of samples (generally n_chains*n_draws).
+
+    Returns
+    -------
+    obs_vals : (N) ndarray
+        Smoothed observed data
+    pp_vals : (S, N) ndarray
+        Smoothed posterior predictive samples
+    """
     x = np.linspace(0, 1, len(obs_vals))
     csi = CubicSpline(x, obs_vals)
     obs_vals = csi(np.linspace(0.01, 0.99, len(obs_vals)))
@@ -572,3 +592,12 @@ def smooth_data(obs_vals, pp_vals):
     pp_vals = csi(np.linspace(0.01, 0.99, pp_vals.shape[1]))

     return obs_vals, pp_vals
+
+
+def get_log_prior(idata, var_names=None):
+    """Retrieve the log prior dataarray of a given variable."""
+    if not hasattr(idata, "log_prior"):
+        raise TypeError("log prior not found in inference data object")
+    if var_names is None:
+        var_names = list(idata.log_prior.data_vars)
+    return idata.log_prior[var_names]
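
For reference (not part of the diff): psens relies on the new single_var flag and the new get_log_prior helper above to pull pointwise log likelihood and log prior values out of an InferenceData object. A rough sketch of the behaviour on made-up toy data; these are private helpers in arviz.stats.stats_utils, so treat the calling convention as an internal detail:

    import numpy as np
    from arviz import from_dict
    from arviz.stats.stats_utils import get_log_likelihood, get_log_prior

    rng = np.random.default_rng(0)
    idata = from_dict(
        posterior={"mu": rng.normal(size=(4, 100))},
        log_likelihood={
            "y1": rng.normal(size=(4, 100, 8)),
            "y2": rng.normal(size=(4, 100, 5)),
        },
    )
    idata.add_groups({"log_prior": {"mu": rng.normal(size=(4, 100))}})

    # 0.16.1 raised TypeError here because several log likelihood arrays exist and
    # var_name is None; with single_var=False the whole Dataset is returned instead
    log_lik = get_log_likelihood(idata, var_name=None, single_var=False)

    # new helper mirroring get_log_likelihood for the log_prior group used by psens
    log_prior = get_log_prior(idata)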
arviz/tests/base_tests/test_data.py
CHANGED

@@ -496,7 +496,7 @@ class TestInferenceData: # pylint: disable=too-many-public-methods
         with pytest.raises(KeyError):
             idata.sel(inplace=False, chain_prior=True, chain=[0, 1, 3])

-    @pytest.mark.parametrize("use", ("del", "delattr"))
+    @pytest.mark.parametrize("use", ("del", "delattr", "delitem"))
     def test_del(self, use):
         # create inference data object
         data = np.random.normal(size=(4, 500, 8))
@@ -523,6 +523,8 @@ class TestInferenceData: # pylint: disable=too-many-public-methods
         # Use del method
         if use == "del":
             del idata.sample_stats
+        elif use == "delitem":
+            del idata["sample_stats"]
         else:
             delattr(idata, "sample_stats")

@@ -763,6 +765,69 @@ class TestInferenceData: # pylint: disable=too-many-public-methods
         )
         assert all(item in test_data.columns for item in ("chain", "draw"))

+    @pytest.mark.parametrize(
+        "kwargs",
+        (
+            {
+                "var_names": ["parameter_1", "parameter_2", "variable_1", "variable_2"],
+                "filter_vars": None,
+                "var_results": [
+                    ("posterior", "parameter_1"),
+                    ("posterior", "parameter_2"),
+                    ("prior", "parameter_1"),
+                    ("prior", "parameter_2"),
+                    ("posterior", "variable_1"),
+                    ("posterior", "variable_2"),
+                ],
+            },
+            {
+                "var_names": "parameter",
+                "filter_vars": "like",
+                "groups": "posterior",
+                "var_results": ["parameter_1", "parameter_2"],
+            },
+            {
+                "var_names": "~parameter",
+                "filter_vars": "like",
+                "groups": "posterior",
+                "var_results": ["variable_1", "variable_2", "custom_name"],
+            },
+            {
+                "var_names": [".+_2$", "custom_name"],
+                "filter_vars": "regex",
+                "groups": "posterior",
+                "var_results": ["parameter_2", "variable_2", "custom_name"],
+            },
+            {
+                "var_names": ["lp"],
+                "filter_vars": "regex",
+                "groups": "sample_stats",
+                "var_results": ["lp"],
+            },
+        ),
+    )
+    def test_to_dataframe_selection(self, kwargs):
+        results = kwargs.pop("var_results")
+        idata = from_dict(
+            posterior={
+                "parameter_1": np.random.randn(4, 100),
+                "parameter_2": np.random.randn(4, 100),
+                "variable_1": np.random.randn(4, 100),
+                "variable_2": np.random.randn(4, 100),
+                "custom_name": np.random.randn(4, 100),
+            },
+            prior={
+                "parameter_1": np.random.randn(4, 100),
+                "parameter_2": np.random.randn(4, 100),
+            },
+            sample_stats={
+                "lp": np.random.randn(4, 100),
+            },
+        )
+        test_data = idata.to_dataframe(**kwargs)
+        assert not test_data.empty
+        assert set(test_data.columns).symmetric_difference(results) == set(["chain", "draw"])
+
     def test_to_dataframe_bad(self):
         idata = from_dict(
             posterior={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
@@ -781,6 +846,9 @@ class TestInferenceData: # pylint: disable=too-many-public-methods
         with pytest.raises(KeyError):
             idata.to_dataframe(groups=["invalid_group"])

+        with pytest.raises(ValueError):
+            idata.to_dataframe(var_names=["c"])
+
     @pytest.mark.parametrize("use", (None, "args", "kwargs"))
     def test_map(self, use):
         idata = load_arviz_data("centered_eight")
@@ -1173,7 +1241,7 @@ class TestDataDict:
         self.check_var_names_coords_dims(inference_data.prior_predictive)
         self.check_var_names_coords_dims(inference_data.sample_stats_prior)

-        pred_dims = inference_data.predictions.dims["school_pred"]
+        pred_dims = inference_data.predictions.sizes["school_pred"]
         assert pred_dims == 8

     def test_inference_data_warmup(self, data, eight_schools_params):
@@ -1518,8 +1586,8 @@ class TestExtractDataset:
         idata = load_arviz_data("centered_eight")
         post = extract(idata, combined=False)
         assert "sample" not in post.dims
-        assert post.dims["chain"] == 4
-        assert post.dims["draw"] == 500
+        assert post.sizes["chain"] == 4
+        assert post.sizes["draw"] == 500

     def test_var_name_group(self):
         idata = load_arviz_data("centered_eight")
@@ -1539,5 +1607,5 @@ class TestExtractDataset:
     def test_subset_samples(self):
         idata = load_arviz_data("centered_eight")
         post = extract(idata, num_samples=10)
-        assert post.dims["sample"] == 10
+        assert post.sizes["sample"] == 10
         assert post.attrs == idata.posterior.attrs
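
For reference (not part of the diff): several of the test updates above, like the weight_predictions change in stats.py, replace mapping-style Dataset.dims[...] lookups with Dataset.sizes[...], following xarray's move away from treating .dims on a Dataset as a mapping. A small sketch of the new spelling, using the sizes asserted in the tests and assuming the bundled "centered_eight" example dataset:

    import arviz as az

    idata = az.load_arviz_data("centered_eight")

    post = az.extract(idata, combined=False)
    # .sizes is the mapping-style accessor; Dataset.dims is becoming set-like
    assert post.sizes["chain"] == 4
    assert post.sizes["draw"] == 500

    sub = az.extract(idata, num_samples=10)
    assert sub.sizes["sample"] == 10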
arviz/tests/base_tests/test_plots_bokeh.py
CHANGED

@@ -327,7 +327,6 @@ def test_plot_autocorr_var_names(models, var_names):
     "kwargs", [{"insample_dev": False}, {"plot_standard_error": False}, {"plot_ic_diff": False}]
 )
 def test_plot_compare(models, kwargs):
-
     model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2})

     axes = plot_compare(model_compare, backend="bokeh", show=False, **kwargs)
arviz/tests/base_tests/test_plots_matplotlib.py
CHANGED

@@ -9,6 +9,7 @@ import pytest
 from matplotlib import animation
 from pandas import DataFrame
 from scipy.stats import gaussian_kde, norm
+import xarray as xr

 from ...data import from_dict, load_arviz_data
 from ...plots import (
@@ -732,6 +733,28 @@ def test_plot_ppc(models, kind, alpha, animated, observed, observed_rug):
     assert axes


+def test_plot_ppc_transposed():
+    idata = load_arviz_data("rugby")
+    idata.map(
+        lambda ds: ds.assign(points=xr.concat((ds.home_points, ds.away_points), "field")),
+        groups="observed_vars",
+        inplace=True,
+    )
+    assert idata.posterior_predictive.points.dims == ("field", "chain", "draw", "match")
+    ax = plot_ppc(
+        idata,
+        kind="scatter",
+        var_names="points",
+        flatten=["field"],
+        coords={"match": ["Wales Italy"]},
+        random_seed=3,
+        num_pp_samples=8,
+    )
+    x, y = ax.get_lines()[2].get_data()
+    assert not np.isclose(y[0], 0)
+    assert np.all(np.array([40, 43, 10, 9]) == x)
+
+
 @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
 @pytest.mark.parametrize("jitter", [None, 0, 0.1, 1, 3])
 @pytest.mark.parametrize("animated", [False, True])
@@ -1898,7 +1921,7 @@ def test_plot_ts(kwargs):
         dims={"y": ["obs_dim"], "z": ["pred_dim"]},
     )

-    ax = plot_ts(idata=idata, y="y",
+    ax = plot_ts(idata=idata, y="y", **kwargs)
     assert np.all(ax)

arviz/tests/base_tests/test_stats.py
CHANGED

@@ -10,8 +10,9 @@ from numpy.testing import (
     assert_array_equal,
 )
 from scipy.special import logsumexp
-from scipy.stats import linregress
+from scipy.stats import linregress, norm, halfcauchy
 from xarray import DataArray, Dataset
+from xarray_einstats.stats import XrContinuousRV

 from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
 from ...rcparams import rcParams
@@ -22,6 +23,7 @@ from ...stats import (
     hdi,
     loo,
     loo_pit,
+    psens,
     psislw,
     r2_score,
     summary,
@@ -829,3 +831,43 @@ def test_weight_predictions():
     assert_almost_equal(new.posterior_predictive["a"].mean(), 0, decimal=1)
     new = weight_predictions([idata0, idata1], weights=[0.9, 0.1])
     assert_almost_equal(new.posterior_predictive["a"].mean(), -0.8, decimal=1)
+
+
+@pytest.fixture(scope="module")
+def psens_data():
+    non_centered_eight = load_arviz_data("non_centered_eight")
+    post = non_centered_eight.posterior
+    log_prior = {
+        "mu": XrContinuousRV(norm, 0, 5).logpdf(post["mu"]),
+        "tau": XrContinuousRV(halfcauchy, scale=5).logpdf(post["tau"]),
+        "theta_t": XrContinuousRV(norm, 0, 1).logpdf(post["theta_t"]),
+    }
+    non_centered_eight.add_groups({"log_prior": log_prior})
+    return non_centered_eight
+
+
+@pytest.mark.parametrize("component", ("prior", "likelihood"))
+def test_priorsens_global(psens_data, component):
+    result = psens(psens_data, component=component)
+    assert "mu" in result
+    assert "theta" in result
+    assert "school" in result.theta_t.dims
+
+
+def test_priorsens_var_names(psens_data):
+    result1 = psens(
+        psens_data, component="prior", component_var_names=["mu", "tau"], var_names=["mu", "tau"]
+    )
+    result2 = psens(psens_data, component="prior", var_names=["mu", "tau"])
+    for result in (result1, result2):
+        assert "theta" not in result
+        assert "mu" in result
+        assert "tau" in result
+    assert not np.isclose(result1.mu, result2.mu)
+
+
+def test_priorsens_coords(psens_data):
+    result = psens(psens_data, component="likelihood", component_coords={"school": "Choate"})
+    assert "mu" in result
+    assert "theta" in result
+    assert "school" in result.theta_t.dims