arviz 0.16.1__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. arviz/__init__.py +1 -1
  2. arviz/data/inference_data.py +34 -7
  3. arviz/data/io_beanmachine.py +6 -1
  4. arviz/data/io_cmdstanpy.py +439 -50
  5. arviz/data/io_pyjags.py +5 -2
  6. arviz/data/io_pystan.py +1 -2
  7. arviz/labels.py +2 -0
  8. arviz/plots/backends/bokeh/bpvplot.py +7 -2
  9. arviz/plots/backends/bokeh/compareplot.py +7 -4
  10. arviz/plots/backends/bokeh/densityplot.py +0 -1
  11. arviz/plots/backends/bokeh/distplot.py +0 -2
  12. arviz/plots/backends/bokeh/forestplot.py +3 -5
  13. arviz/plots/backends/bokeh/kdeplot.py +0 -2
  14. arviz/plots/backends/bokeh/pairplot.py +0 -4
  15. arviz/plots/backends/matplotlib/bfplot.py +0 -1
  16. arviz/plots/backends/matplotlib/bpvplot.py +3 -3
  17. arviz/plots/backends/matplotlib/compareplot.py +1 -1
  18. arviz/plots/backends/matplotlib/dotplot.py +1 -1
  19. arviz/plots/backends/matplotlib/forestplot.py +2 -4
  20. arviz/plots/backends/matplotlib/kdeplot.py +0 -1
  21. arviz/plots/backends/matplotlib/khatplot.py +0 -1
  22. arviz/plots/backends/matplotlib/lmplot.py +4 -5
  23. arviz/plots/backends/matplotlib/pairplot.py +0 -1
  24. arviz/plots/backends/matplotlib/ppcplot.py +8 -5
  25. arviz/plots/backends/matplotlib/traceplot.py +1 -2
  26. arviz/plots/bfplot.py +7 -6
  27. arviz/plots/bpvplot.py +7 -2
  28. arviz/plots/compareplot.py +2 -2
  29. arviz/plots/ecdfplot.py +37 -112
  30. arviz/plots/elpdplot.py +1 -1
  31. arviz/plots/essplot.py +2 -2
  32. arviz/plots/kdeplot.py +0 -1
  33. arviz/plots/pairplot.py +1 -1
  34. arviz/plots/plot_utils.py +0 -1
  35. arviz/plots/ppcplot.py +51 -45
  36. arviz/plots/separationplot.py +0 -1
  37. arviz/stats/__init__.py +2 -0
  38. arviz/stats/density_utils.py +2 -2
  39. arviz/stats/diagnostics.py +2 -3
  40. arviz/stats/ecdf_utils.py +165 -0
  41. arviz/stats/stats.py +241 -38
  42. arviz/stats/stats_utils.py +36 -7
  43. arviz/tests/base_tests/test_data.py +73 -5
  44. arviz/tests/base_tests/test_plots_bokeh.py +0 -1
  45. arviz/tests/base_tests/test_plots_matplotlib.py +24 -1
  46. arviz/tests/base_tests/test_stats.py +43 -1
  47. arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
  48. arviz/tests/base_tests/test_stats_utils.py +3 -3
  49. arviz/tests/external_tests/test_data_beanmachine.py +2 -0
  50. arviz/tests/external_tests/test_data_numpyro.py +3 -3
  51. arviz/tests/external_tests/test_data_pyjags.py +3 -1
  52. arviz/tests/external_tests/test_data_pyro.py +3 -3
  53. arviz/tests/helpers.py +8 -8
  54. arviz/utils.py +15 -7
  55. arviz/wrappers/wrap_pymc.py +1 -1
  56. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/METADATA +16 -15
  57. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/RECORD +60 -58
  58. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/WHEEL +1 -1
  59. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/LICENSE +0 -0
  60. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/top_level.txt +0 -0
arviz/stats/stats.py CHANGED
@@ -30,6 +30,7 @@ from .density_utils import kde as _kde
 from .diagnostics import _mc_error, _multichain_statistics, ess
 from .stats_utils import ELPDData, _circular_standard_deviation, smooth_data
 from .stats_utils import get_log_likelihood as _get_log_likelihood
+from .stats_utils import get_log_prior as _get_log_prior
 from .stats_utils import logsumexp as _logsumexp
 from .stats_utils import make_ufunc as _make_ufunc
 from .stats_utils import stats_variance_2d as svar
@@ -51,6 +52,7 @@ __all__ = [
     "waic",
     "weight_predictions",
     "_calculate_ics",
+    "psens",
 ]
 
 
@@ -144,6 +146,7 @@ def compare(
     Compare the centered and non centered models of the eight school problem:
 
     .. ipython::
+        :okwarning:
 
         In [1]: import arviz as az
            ...: data1 = az.load_arviz_data("non_centered_eight")
@@ -155,6 +158,7 @@ def compare(
     weights using the stacking method.
 
     .. ipython::
+        :okwarning:
 
         In [1]: az.compare(compare_dict, ic="loo", method="stacking", scale="log")
 
@@ -178,37 +182,19 @@ def compare(
     except Exception as e:
         raise e.__class__("Encountered error in ELPD computation of compare.") from e
     names = list(ics_dict.keys())
-    if ic == "loo":
-        df_comp = pd.DataFrame(
-            index=names,
-            columns=[
-                "rank",
-                "elpd_loo",
-                "p_loo",
-                "elpd_diff",
-                "weight",
-                "se",
-                "dse",
-                "warning",
-                "scale",
-            ],
-            dtype=np.float_,
-        )
-    elif ic == "waic":
+    if ic in {"loo", "waic"}:
         df_comp = pd.DataFrame(
-            index=names,
-            columns=[
-                "rank",
-                "elpd_waic",
-                "p_waic",
-                "elpd_diff",
-                "weight",
-                "se",
-                "dse",
-                "warning",
-                "scale",
-            ],
-            dtype=np.float_,
+            {
+                "rank": pd.Series(index=names, dtype="int"),
+                f"elpd_{ic}": pd.Series(index=names, dtype="float"),
+                f"p_{ic}": pd.Series(index=names, dtype="float"),
+                "elpd_diff": pd.Series(index=names, dtype="float"),
+                "weight": pd.Series(index=names, dtype="float"),
+                "se": pd.Series(index=names, dtype="float"),
+                "dse": pd.Series(index=names, dtype="float"),
+                "warning": pd.Series(index=names, dtype="boolean"),
+                "scale": pd.Series(index=names, dtype="str"),
+            }
         )
     else:
        raise NotImplementedError(f"The information criterion {ic} is not supported.")
@@ -630,7 +616,7 @@ def _hdi(ary, hdi_prob, circular, skipna):
     ary = np.sort(ary)
     interval_idx_inc = int(np.floor(hdi_prob * n))
     n_intervals = n - interval_idx_inc
-    interval_width = np.subtract(ary[interval_idx_inc:], ary[:n_intervals], dtype=np.float_)
+    interval_width = np.subtract(ary[interval_idx_inc:], ary[:n_intervals], dtype=np.float64)
 
     if len(interval_width) == 0:
        raise ValueError("Too few elements for interval calculation. ")
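Context for this one-character change: np.float_ was only ever an alias for np.float64 and was removed in NumPy 2.0, so the explicit name keeps integer input promoting to floating point. A small sketch of the same subtraction pattern, with made-up data:

import numpy as np

ary = np.sort(np.array([3, 1, 4, 1, 5, 9, 2, 6]))  # made-up integer samples
k = 5  # hypothetical interval length

# dtype=np.float64 forces floating-point widths even for integer input;
# np.float_ meant the same thing but no longer exists under NumPy 2.0
interval_width = np.subtract(ary[k:], ary[: len(ary) - k], dtype=np.float64)
print(interval_width.dtype)  # float64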
@@ -878,17 +864,18 @@ def psislw(log_weights, reff=1.0):
 
     Parameters
     ----------
-    log_weights: array
+    log_weights : DataArray or (..., N) array-like
         Array of size (n_observations, n_samples)
-    reff: float
+    reff : float, default 1
         relative MCMC efficiency, ``ess / n``
 
     Returns
     -------
-    lw_out: array
-        Smoothed log weights
-    kss: array
-        Pareto tail indices
+    lw_out : DataArray or (..., N) ndarray
+        Smoothed, truncated and normalized log weights.
+    kss : DataArray or (...) ndarray
+        Estimates of the shape parameter *k* of the generalized Pareto
+        distribution.
 
     References
     ----------
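The sharpened docstring spells out what psislw already returns: normalized log weights plus Pareto shape estimates. A quick sketch against synthetic raw log weights (the data here is invented):

import numpy as np
import arviz as az

rng = np.random.default_rng(0)
log_weights = rng.standard_t(df=5, size=500)  # invented raw importance log ratios

lw_smoothed, khat = az.psislw(log_weights, reff=0.8)
print(lw_smoothed.shape, float(khat))
# normalization means the smoothed weights sum to one
print(np.isclose(np.exp(lw_smoothed).sum(), 1.0))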
@@ -2093,7 +2080,7 @@ def weight_predictions(idatas, weights=None):
     weights /= weights.sum()
 
     len_idatas = [
-        idata.posterior_predictive.dims["chain"] * idata.posterior_predictive.dims["draw"]
+        idata.posterior_predictive.sizes["chain"] * idata.posterior_predictive.sizes["draw"]
         for idata in idatas
     ]
 
@@ -2113,3 +2100,219 @@ def weight_predictions(idatas, weights=None):
     )
 
     return weighted_samples
+
+
+def psens(
+    data,
+    *,
+    component="prior",
+    component_var_names=None,
+    component_coords=None,
+    var_names=None,
+    coords=None,
+    filter_vars=None,
+    delta=0.01,
+    dask_kwargs=None,
+):
+    """Compute power-scaling sensitivity diagnostic.
+
+    Power-scales the prior or likelihood and calculates how much the posterior is affected.
+
+    Parameters
+    ----------
+    data : obj
+        Any object that can be converted to an :class:`arviz.InferenceData` object.
+        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
+        For ndarray: shape = (chain, draw).
+        For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
+    component : {"prior", "likelihood"}, default "prior"
+        When `component` is "likelihood", the log likelihood values are retrieved
+        from the ``log_likelihood`` group as pointwise log likelihood and added
+        together. With "prior", the log prior values are retrieved from the
+        ``log_prior`` group.
+    component_var_names : str, optional
+        Names of the prior or log likelihood variables to use.
+    component_coords : dict, optional
+        Coordinates defining a subset over the component element for which to
+        compute the sensitivity diagnostic.
+    var_names : list of str, optional
+        Names of posterior variables to include in the power-scaling sensitivity diagnostic.
+    coords : dict, optional
+        Coordinates defining a subset over the posterior. Only this subset will
+        be used when computing the sensitivity diagnostic.
+    filter_vars : {None, "like", "regex"}, default None
+        If ``None`` (default), interpret var_names as the real variables names.
+        If "like", interpret var_names as substrings of the real variables names.
+        If "regex", interpret var_names as regular expressions on the real variables names.
+    delta : float
+        Value for finite difference derivative calculation.
+    dask_kwargs : dict, optional
+        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
+
+    Returns
+    -------
+    xarray.Dataset
+        Returns dataset of power-scaling sensitivity diagnostic values.
+        Higher sensitivity values indicate greater sensitivity.
+        Prior sensitivity above 0.05 indicates an informative prior.
+        Likelihood sensitivity below 0.05 indicates a weak or noninformative likelihood.
+
+    Examples
+    --------
+    Compute the likelihood sensitivity for the non centered eight schools model:
+
+    .. ipython::
+
+        In [1]: import arviz as az
+           ...: data = az.load_arviz_data("non_centered_eight")
+           ...: az.psens(data, component="likelihood")
+
+    To compute the prior sensitivity, we first need to compute the log prior
+    at each posterior sample. In our case, we know mu has a normal prior :math:`N(0, 5)`,
+    tau has a half-Cauchy prior with scale/beta parameter 5,
+    and theta has a standard normal prior.
+    We add this information to the ``log_prior`` group before computing the power-scaling
+    check with ``psens``:
+
+    .. ipython::
+
+        In [1]: from xarray_einstats.stats import XrContinuousRV
+           ...: from scipy.stats import norm, halfcauchy
+           ...: post = data.posterior
+           ...: log_prior = {
+           ...:     "mu": XrContinuousRV(norm, 0, 5).logpdf(post["mu"]),
+           ...:     "tau": XrContinuousRV(halfcauchy, scale=5).logpdf(post["tau"]),
+           ...:     "theta_t": XrContinuousRV(norm, 0, 1).logpdf(post["theta_t"]),
+           ...: }
+           ...: data.add_groups({"log_prior": log_prior})
+           ...: az.psens(data, component="prior")
+
+    Notes
+    -----
+    The diagnostic is computed by power-scaling the specified component (prior or likelihood)
+    and determining the degree to which the posterior changes as described in [1]_.
+    It uses Pareto-smoothed importance sampling to avoid refitting the model.
+
+    References
+    ----------
+    .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
+       power-scaling*, 2022, https://arxiv.org/abs/2107.14054
+
+    """
+    dataset = extract(data, var_names=var_names, filter_vars=filter_vars, group="posterior")
+    if coords is not None:
+        dataset = dataset.sel(coords)
+
+    if component == "likelihood":
+        component_draws = _get_log_likelihood(data, var_name=component_var_names, single_var=False)
+    elif component == "prior":
+        component_draws = _get_log_prior(data, var_names=component_var_names)
+    else:
+        raise ValueError("Value for `component` argument not recognized")
+
+    component_draws = component_draws.stack(__sample__=("chain", "draw"))
+    if component_coords is not None:
+        component_draws = component_draws.sel(component_coords)
+    if isinstance(component_draws, xr.DataArray):
+        component_draws = component_draws.to_dataset()
+    if len(component_draws.dims):
+        component_draws = component_draws.to_stacked_array(
+            "latent-obs_var", sample_dims=("__sample__",)
+        ).sum("latent-obs_var")
+    # from here component_draws is a 1d object with dimensions (__sample__,)
+
+    # calculate lower and upper alpha values
+    lower_alpha = 1 / (1 + delta)
+    upper_alpha = 1 + delta
+
+    # calculate importance sampling weights for lower and upper alpha power-scaling
+    lower_w = np.exp(_powerscale_lw(component_draws=component_draws, alpha=lower_alpha))
+    lower_w = lower_w / np.sum(lower_w)
+
+    upper_w = np.exp(_powerscale_lw(component_draws=component_draws, alpha=upper_alpha))
+    upper_w = upper_w / np.sum(upper_w)
+
+    ufunc_kwargs = {"n_dims": 1, "ravel": False}
+    func_kwargs = {"lower_weights": lower_w.values, "upper_weights": upper_w.values, "delta": delta}
+
+    # calculate the sensitivity diagnostic based on the importance weights and draws
+    return _wrap_xarray_ufunc(
+        _powerscale_sens,
+        dataset,
+        ufunc_kwargs=ufunc_kwargs,
+        func_kwargs=func_kwargs,
+        dask_kwargs=dask_kwargs,
+        input_core_dims=[["sample"]],
+    )
+
+
+def _powerscale_sens(draws, *, lower_weights=None, upper_weights=None, delta=0.01):
+    """
+    Calculate power-scaling sensitivity by finite difference
+    second derivative of CJS.
+    """
+    lower_cjs = max(
+        _cjs_dist(draws=draws, weights=lower_weights),
+        _cjs_dist(draws=-1 * draws, weights=lower_weights),
+    )
+    upper_cjs = max(
+        _cjs_dist(draws=draws, weights=upper_weights),
+        _cjs_dist(draws=-1 * draws, weights=upper_weights),
+    )
+    logdiffsquare = 2 * np.log2(1 + delta)
+    grad = (lower_cjs + upper_cjs) / logdiffsquare
+
+    return grad
+
+
+def _powerscale_lw(alpha, component_draws):
+    """
+    Calculate log weights for power-scaling component by alpha.
+    """
+    log_weights = (alpha - 1) * component_draws
+    log_weights = psislw(log_weights)[0]
+
+    return log_weights
+
+
+def _cjs_dist(draws, weights):
+    """
+    Calculate the cumulative Jensen-Shannon distance between original draws and weighted draws.
+    """
+    # sort draws and weights
+    order = np.argsort(draws)
+    draws = draws[order]
+    weights = weights[order]
+
+    binwidth = np.diff(draws)
+
+    # ecdfs
+    cdf_p = np.linspace(1 / len(draws), 1 - 1 / len(draws), len(draws) - 1)
+    cdf_q = np.cumsum(weights / np.sum(weights))[:-1]
+
+    # integrals of ecdfs
+    cdf_p_int = np.dot(cdf_p, binwidth)
+    cdf_q_int = np.dot(cdf_q, binwidth)
+
+    # cjs calculation
+    pq_numer = np.log2(cdf_p, out=np.zeros_like(cdf_p), where=cdf_p != 0)
+    qp_numer = np.log2(cdf_q, out=np.zeros_like(cdf_q), where=cdf_q != 0)
+
+    denom = 0.5 * (cdf_p + cdf_q)
+    denom = np.log2(denom, out=np.zeros_like(denom), where=denom != 0)
+
+    cjs_pq = np.sum(binwidth * (cdf_p * (pq_numer - denom))) + 0.5 / np.log(2) * (
+        cdf_q_int - cdf_p_int
+    )
+
+    cjs_qp = np.sum(binwidth * (cdf_q * (qp_numer - denom))) + 0.5 / np.log(2) * (
+        cdf_p_int - cdf_q_int
+    )
+
+    cjs_pq = max(0, cjs_pq)
+    cjs_qp = max(0, cjs_qp)
+
+    bound = cdf_p_int + cdf_q_int
+
+    return np.sqrt((cjs_pq + cjs_qp) / bound)
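To see the new diagnostic end to end, the sketch below runs the likelihood check from the docstring on the bundled dataset and then reproduces the power-scaling weights by hand; the stand-in log likelihood array is invented:

import numpy as np
import arviz as az

idata = az.load_arviz_data("non_centered_eight")
print(az.psens(idata, component="likelihood"))  # Dataset of sensitivity values

# the importance weights follow directly from power-scaling: raising a
# component to the power alpha adds (alpha - 1) * log_component to each
# draw's log weight, which psens then smooths with psislw
delta = 0.01
log_lik = np.random.default_rng(1).normal(size=1000)  # invented stand-in
lw = (1 / (1 + delta) - 1) * log_lik
lw_smoothed, _ = az.psislw(lw)
print(np.isclose(np.exp(lw_smoothed).sum(), 1.0))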
arviz/stats/stats_utils.py CHANGED
@@ -16,7 +16,7 @@ from ..utils import conditional_jit, conditional_vect, conditional_dask
 from .density_utils import histogram as _histogram
 
 
-__all__ = ["autocorr", "autocov", "ELPDData", "make_ufunc", "wrap_xarray_ufunc"]
+__all__ = ["autocorr", "autocov", "ELPDData", "make_ufunc", "smooth_data", "wrap_xarray_ufunc"]
 
 
 def autocov(ary, axis=-1):
@@ -409,7 +409,7 @@ def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwargs=None):
     return nan_error | chain_error | draw_error
 
 
-def get_log_likelihood(idata, var_name=None):
+def get_log_likelihood(idata, var_name=None, single_var=True):
     """Retrieve the log likelihood dataarray of a given variable."""
     if (
         not hasattr(idata, "log_likelihood")
@@ -426,9 +426,11 @@ def get_log_likelihood(idata, var_name=None):
     if var_name is None:
         var_names = list(idata.log_likelihood.data_vars)
         if len(var_names) > 1:
-            raise TypeError(
-                f"Found several log likelihood arrays {var_names}, var_name cannot be None"
-            )
+            if single_var:
+                raise TypeError(
+                    f"Found several log likelihood arrays {var_names}, var_name cannot be None"
+                )
+            return idata.log_likelihood[var_names]
         return idata.log_likelihood[var_names[0]]
     else:
        try:
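The single_var=False path returns a Dataset of all pointwise log likelihood variables instead of raising, which is what psens needs in order to sum them. A self-contained sketch of the multi-variable case it unlocks (variable names invented):

import numpy as np
import arviz as az

idata = az.from_dict(
    posterior={"mu": np.random.randn(4, 100)},
    log_likelihood={
        "obs_a": np.random.randn(4, 100, 10),  # invented variable names
        "obs_b": np.random.randn(4, 100, 10),
    },
)

# selecting a list of names yields a Dataset; this mirrors what
# get_log_likelihood(idata, single_var=False) now returns instead of TypeError
var_names = list(idata.log_likelihood.data_vars)
print(type(idata.log_likelihood[var_names]).__name__)  # Dataset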
@@ -482,7 +484,7 @@ class ELPDData(pd.Series):  # pylint: disable=too-many-ancestors
            base += "\n\nThere has been a warning during the calculation. Please check the results."
 
        if kind == "loo" and "pareto_k" in self:
-           bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf])
+           bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
            counts, *_ = _histogram(self.pareto_k.values, bins)
            extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
            extended = extended.format(
@@ -562,7 +564,25 @@ def _circular_standard_deviation(samples, high=2 * np.pi, low=0, skipna=False, axis=None):
 
 
 def smooth_data(obs_vals, pp_vals):
-    """Smooth data, helper function for discrete data in plot_pbv, loo_pit and plot_loo_pit."""
+    """Smooth data using a cubic spline.
+
+    Helper function for discrete data in plot_pbv, loo_pit and plot_loo_pit.
+
+    Parameters
+    ----------
+    obs_vals : (N) array-like
+        Observed data
+    pp_vals : (S, N) array-like
+        Posterior predictive samples. ``N`` is the number of observations,
+        and ``S`` is the number of samples (generally n_chains*n_draws).
+
+    Returns
+    -------
+    obs_vals : (N) ndarray
+        Smoothed observed data
+    pp_vals : (S, N) ndarray
+        Smoothed posterior predictive samples
+    """
     x = np.linspace(0, 1, len(obs_vals))
     csi = CubicSpline(x, obs_vals)
    obs_vals = csi(np.linspace(0.01, 0.99, len(obs_vals)))
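The expanded docstring pins down the shapes the spline code relies on. A short sketch, assuming the helper is imported from arviz.stats.stats_utils (it is now listed in that module's __all__):

import numpy as np
from arviz.stats.stats_utils import smooth_data

rng = np.random.default_rng(0)
obs = rng.poisson(5, size=50)        # (N,) discrete observations, invented
pp = rng.poisson(5, size=(200, 50))  # (S, N) posterior predictive draws

obs_s, pp_s = smooth_data(obs, pp)
print(obs_s.shape, pp_s.shape)  # shapes are preserved: (50,) (200, 50)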
@@ -572,3 +592,12 @@ def smooth_data(obs_vals, pp_vals):
     pp_vals = csi(np.linspace(0.01, 0.99, pp_vals.shape[1]))
 
     return obs_vals, pp_vals
+
+
+def get_log_prior(idata, var_names=None):
+    """Retrieve the log prior dataarray of a given variable."""
+    if not hasattr(idata, "log_prior"):
+        raise TypeError("log prior not found in inference data object")
+    if var_names is None:
+        var_names = list(idata.log_prior.data_vars)
+    return idata.log_prior[var_names]
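Because get_log_prior raises unless a log_prior group exists, callers have to attach one first, exactly as the psens docstring example does. A condensed sketch, assuming the same internal import path:

import arviz as az
from arviz.stats.stats_utils import get_log_prior
from scipy.stats import norm
from xarray_einstats.stats import XrContinuousRV

idata = az.load_arviz_data("non_centered_eight")
post = idata.posterior

# attach a log_prior group first; without it get_log_prior raises TypeError
idata.add_groups({"log_prior": {"mu": XrContinuousRV(norm, 0, 5).logpdf(post["mu"])}})

log_prior = get_log_prior(idata, var_names=["mu"])
print(type(log_prior).__name__)  # Dataset (list selection returns a Dataset)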
arviz/tests/base_tests/test_data.py CHANGED
@@ -496,7 +496,7 @@ class TestInferenceData:  # pylint: disable=too-many-public-methods
         with pytest.raises(KeyError):
             idata.sel(inplace=False, chain_prior=True, chain=[0, 1, 3])
 
-    @pytest.mark.parametrize("use", ("del", "delattr"))
+    @pytest.mark.parametrize("use", ("del", "delattr", "delitem"))
     def test_del(self, use):
         # create inference data object
         data = np.random.normal(size=(4, 500, 8))
@@ -523,6 +523,8 @@ class TestInferenceData:  # pylint: disable=too-many-public-methods
         # Use del method
         if use == "del":
             del idata.sample_stats
+        elif use == "delitem":
+            del idata["sample_stats"]
         else:
             delattr(idata, "sample_stats")
 
@@ -763,6 +765,69 @@ class TestInferenceData:  # pylint: disable=too-many-public-methods
         )
         assert all(item in test_data.columns for item in ("chain", "draw"))
 
+    @pytest.mark.parametrize(
+        "kwargs",
+        (
+            {
+                "var_names": ["parameter_1", "parameter_2", "variable_1", "variable_2"],
+                "filter_vars": None,
+                "var_results": [
+                    ("posterior", "parameter_1"),
+                    ("posterior", "parameter_2"),
+                    ("prior", "parameter_1"),
+                    ("prior", "parameter_2"),
+                    ("posterior", "variable_1"),
+                    ("posterior", "variable_2"),
+                ],
+            },
+            {
+                "var_names": "parameter",
+                "filter_vars": "like",
+                "groups": "posterior",
+                "var_results": ["parameter_1", "parameter_2"],
+            },
+            {
+                "var_names": "~parameter",
+                "filter_vars": "like",
+                "groups": "posterior",
+                "var_results": ["variable_1", "variable_2", "custom_name"],
+            },
+            {
+                "var_names": [".+_2$", "custom_name"],
+                "filter_vars": "regex",
+                "groups": "posterior",
+                "var_results": ["parameter_2", "variable_2", "custom_name"],
+            },
+            {
+                "var_names": ["lp"],
+                "filter_vars": "regex",
+                "groups": "sample_stats",
+                "var_results": ["lp"],
+            },
+        ),
+    )
+    def test_to_dataframe_selection(self, kwargs):
+        results = kwargs.pop("var_results")
+        idata = from_dict(
+            posterior={
+                "parameter_1": np.random.randn(4, 100),
+                "parameter_2": np.random.randn(4, 100),
+                "variable_1": np.random.randn(4, 100),
+                "variable_2": np.random.randn(4, 100),
+                "custom_name": np.random.randn(4, 100),
+            },
+            prior={
+                "parameter_1": np.random.randn(4, 100),
+                "parameter_2": np.random.randn(4, 100),
+            },
+            sample_stats={
+                "lp": np.random.randn(4, 100),
+            },
+        )
+        test_data = idata.to_dataframe(**kwargs)
+        assert not test_data.empty
+        assert set(test_data.columns).symmetric_difference(results) == set(["chain", "draw"])
+
     def test_to_dataframe_bad(self):
         idata = from_dict(
             posterior={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
@@ -781,6 +846,9 @@ class TestInferenceData:  # pylint: disable=too-many-public-methods
         with pytest.raises(KeyError):
             idata.to_dataframe(groups=["invalid_group"])
 
+        with pytest.raises(ValueError):
+            idata.to_dataframe(var_names=["c"])
+
     @pytest.mark.parametrize("use", (None, "args", "kwargs"))
     def test_map(self, use):
         idata = load_arviz_data("centered_eight")
@@ -1173,7 +1241,7 @@ class TestDataDict:
         self.check_var_names_coords_dims(inference_data.prior_predictive)
         self.check_var_names_coords_dims(inference_data.sample_stats_prior)
 
-        pred_dims = inference_data.predictions.dims["school_pred"]
+        pred_dims = inference_data.predictions.sizes["school_pred"]
         assert pred_dims == 8
 
     def test_inference_data_warmup(self, data, eight_schools_params):
@@ -1518,8 +1586,8 @@ class TestExtractDataset:
         idata = load_arviz_data("centered_eight")
         post = extract(idata, combined=False)
         assert "sample" not in post.dims
-        assert post.dims["chain"] == 4
-        assert post.dims["draw"] == 500
+        assert post.sizes["chain"] == 4
+        assert post.sizes["draw"] == 500
 
     def test_var_name_group(self):
         idata = load_arviz_data("centered_eight")
@@ -1539,5 +1607,5 @@ class TestExtractDataset:
     def test_subset_samples(self):
         idata = load_arviz_data("centered_eight")
         post = extract(idata, num_samples=10)
-        assert post.dims["sample"] == 10
+        assert post.sizes["sample"] == 10
         assert post.attrs == idata.posterior.attrs
arviz/tests/base_tests/test_plots_bokeh.py CHANGED
@@ -327,7 +327,6 @@ def test_plot_autocorr_var_names(models, var_names):
     "kwargs", [{"insample_dev": False}, {"plot_standard_error": False}, {"plot_ic_diff": False}]
 )
 def test_plot_compare(models, kwargs):
-
     model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2})
 
     axes = plot_compare(model_compare, backend="bokeh", show=False, **kwargs)
arviz/tests/base_tests/test_plots_matplotlib.py CHANGED
@@ -9,6 +9,7 @@ import pytest
 from matplotlib import animation
 from pandas import DataFrame
 from scipy.stats import gaussian_kde, norm
+import xarray as xr
 
 from ...data import from_dict, load_arviz_data
 from ...plots import (
@@ -732,6 +733,28 @@ def test_plot_ppc(models, kind, alpha, animated, observed, observed_rug):
     assert axes
 
 
+def test_plot_ppc_transposed():
+    idata = load_arviz_data("rugby")
+    idata.map(
+        lambda ds: ds.assign(points=xr.concat((ds.home_points, ds.away_points), "field")),
+        groups="observed_vars",
+        inplace=True,
+    )
+    assert idata.posterior_predictive.points.dims == ("field", "chain", "draw", "match")
+    ax = plot_ppc(
+        idata,
+        kind="scatter",
+        var_names="points",
+        flatten=["field"],
+        coords={"match": ["Wales Italy"]},
+        random_seed=3,
+        num_pp_samples=8,
+    )
+    x, y = ax.get_lines()[2].get_data()
+    assert not np.isclose(y[0], 0)
+    assert np.all(np.array([40, 43, 10, 9]) == x)
+
+
 @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
 @pytest.mark.parametrize("jitter", [None, 0, 0.1, 1, 3])
 @pytest.mark.parametrize("animated", [False, True])
@@ -1898,7 +1921,7 @@ def test_plot_ts(kwargs):
         dims={"y": ["obs_dim"], "z": ["pred_dim"]},
     )
 
-    ax = plot_ts(idata=idata, y="y", show=True, **kwargs)
+    ax = plot_ts(idata=idata, y="y", **kwargs)
     assert np.all(ax)
 
 
arviz/tests/base_tests/test_stats.py CHANGED
@@ -10,8 +10,9 @@ from numpy.testing import (
     assert_array_equal,
 )
 from scipy.special import logsumexp
-from scipy.stats import linregress
+from scipy.stats import linregress, norm, halfcauchy
 from xarray import DataArray, Dataset
+from xarray_einstats.stats import XrContinuousRV
 
 from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
 from ...rcparams import rcParams
@@ -22,6 +23,7 @@ from ...stats import (
     hdi,
     loo,
     loo_pit,
+    psens,
     psislw,
     r2_score,
     summary,
@@ -829,3 +831,43 @@ def test_weight_predictions():
     assert_almost_equal(new.posterior_predictive["a"].mean(), 0, decimal=1)
     new = weight_predictions([idata0, idata1], weights=[0.9, 0.1])
     assert_almost_equal(new.posterior_predictive["a"].mean(), -0.8, decimal=1)
+
+
+@pytest.fixture(scope="module")
+def psens_data():
+    non_centered_eight = load_arviz_data("non_centered_eight")
+    post = non_centered_eight.posterior
+    log_prior = {
+        "mu": XrContinuousRV(norm, 0, 5).logpdf(post["mu"]),
+        "tau": XrContinuousRV(halfcauchy, scale=5).logpdf(post["tau"]),
+        "theta_t": XrContinuousRV(norm, 0, 1).logpdf(post["theta_t"]),
+    }
+    non_centered_eight.add_groups({"log_prior": log_prior})
+    return non_centered_eight
+
+
+@pytest.mark.parametrize("component", ("prior", "likelihood"))
+def test_priorsens_global(psens_data, component):
+    result = psens(psens_data, component=component)
+    assert "mu" in result
+    assert "theta" in result
+    assert "school" in result.theta_t.dims
+
+
+def test_priorsens_var_names(psens_data):
+    result1 = psens(
+        psens_data, component="prior", component_var_names=["mu", "tau"], var_names=["mu", "tau"]
+    )
+    result2 = psens(psens_data, component="prior", var_names=["mu", "tau"])
+    for result in (result1, result2):
+        assert "theta" not in result
+        assert "mu" in result
+        assert "tau" in result
+    assert not np.isclose(result1.mu, result2.mu)
+
+
+def test_priorsens_coords(psens_data):
+    result = psens(psens_data, component="likelihood", component_coords={"school": "Choate"})
+    assert "mu" in result
+    assert "theta" in result
+    assert "school" in result.theta_t.dims