arviz 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123) hide show
  1. arviz/__init__.py +3 -2
  2. arviz/data/__init__.py +5 -2
  3. arviz/data/base.py +102 -11
  4. arviz/data/converters.py +5 -0
  5. arviz/data/datasets.py +1 -0
  6. arviz/data/example_data/data_remote.json +10 -3
  7. arviz/data/inference_data.py +26 -25
  8. arviz/data/io_cmdstan.py +1 -3
  9. arviz/data/io_datatree.py +1 -0
  10. arviz/data/io_dict.py +5 -3
  11. arviz/data/io_emcee.py +1 -0
  12. arviz/data/io_numpyro.py +1 -0
  13. arviz/data/io_pyjags.py +1 -0
  14. arviz/data/io_pyro.py +1 -0
  15. arviz/data/io_pystan.py +1 -2
  16. arviz/data/utils.py +1 -0
  17. arviz/plots/__init__.py +1 -0
  18. arviz/plots/autocorrplot.py +1 -0
  19. arviz/plots/backends/bokeh/autocorrplot.py +1 -0
  20. arviz/plots/backends/bokeh/bpvplot.py +8 -2
  21. arviz/plots/backends/bokeh/compareplot.py +8 -4
  22. arviz/plots/backends/bokeh/densityplot.py +1 -0
  23. arviz/plots/backends/bokeh/distplot.py +1 -0
  24. arviz/plots/backends/bokeh/dotplot.py +1 -0
  25. arviz/plots/backends/bokeh/ecdfplot.py +1 -0
  26. arviz/plots/backends/bokeh/elpdplot.py +1 -0
  27. arviz/plots/backends/bokeh/energyplot.py +1 -0
  28. arviz/plots/backends/bokeh/forestplot.py +2 -4
  29. arviz/plots/backends/bokeh/hdiplot.py +1 -0
  30. arviz/plots/backends/bokeh/kdeplot.py +3 -3
  31. arviz/plots/backends/bokeh/khatplot.py +1 -0
  32. arviz/plots/backends/bokeh/lmplot.py +1 -0
  33. arviz/plots/backends/bokeh/loopitplot.py +1 -0
  34. arviz/plots/backends/bokeh/mcseplot.py +1 -0
  35. arviz/plots/backends/bokeh/pairplot.py +1 -0
  36. arviz/plots/backends/bokeh/parallelplot.py +1 -0
  37. arviz/plots/backends/bokeh/posteriorplot.py +1 -0
  38. arviz/plots/backends/bokeh/ppcplot.py +1 -0
  39. arviz/plots/backends/bokeh/rankplot.py +1 -0
  40. arviz/plots/backends/bokeh/separationplot.py +1 -0
  41. arviz/plots/backends/bokeh/traceplot.py +1 -0
  42. arviz/plots/backends/bokeh/violinplot.py +1 -0
  43. arviz/plots/backends/matplotlib/autocorrplot.py +1 -0
  44. arviz/plots/backends/matplotlib/bpvplot.py +1 -0
  45. arviz/plots/backends/matplotlib/compareplot.py +2 -1
  46. arviz/plots/backends/matplotlib/densityplot.py +1 -0
  47. arviz/plots/backends/matplotlib/distcomparisonplot.py +2 -3
  48. arviz/plots/backends/matplotlib/distplot.py +1 -0
  49. arviz/plots/backends/matplotlib/dotplot.py +1 -0
  50. arviz/plots/backends/matplotlib/ecdfplot.py +1 -0
  51. arviz/plots/backends/matplotlib/elpdplot.py +1 -0
  52. arviz/plots/backends/matplotlib/energyplot.py +1 -0
  53. arviz/plots/backends/matplotlib/essplot.py +6 -5
  54. arviz/plots/backends/matplotlib/forestplot.py +3 -4
  55. arviz/plots/backends/matplotlib/hdiplot.py +1 -0
  56. arviz/plots/backends/matplotlib/kdeplot.py +5 -3
  57. arviz/plots/backends/matplotlib/khatplot.py +1 -0
  58. arviz/plots/backends/matplotlib/lmplot.py +1 -0
  59. arviz/plots/backends/matplotlib/loopitplot.py +1 -0
  60. arviz/plots/backends/matplotlib/mcseplot.py +11 -10
  61. arviz/plots/backends/matplotlib/pairplot.py +2 -1
  62. arviz/plots/backends/matplotlib/parallelplot.py +1 -0
  63. arviz/plots/backends/matplotlib/posteriorplot.py +1 -0
  64. arviz/plots/backends/matplotlib/ppcplot.py +1 -0
  65. arviz/plots/backends/matplotlib/rankplot.py +1 -0
  66. arviz/plots/backends/matplotlib/separationplot.py +1 -0
  67. arviz/plots/backends/matplotlib/traceplot.py +2 -1
  68. arviz/plots/backends/matplotlib/tsplot.py +1 -0
  69. arviz/plots/backends/matplotlib/violinplot.py +2 -1
  70. arviz/plots/bfplot.py +7 -6
  71. arviz/plots/bpvplot.py +3 -2
  72. arviz/plots/compareplot.py +3 -2
  73. arviz/plots/densityplot.py +1 -0
  74. arviz/plots/distcomparisonplot.py +1 -0
  75. arviz/plots/dotplot.py +1 -0
  76. arviz/plots/ecdfplot.py +38 -112
  77. arviz/plots/elpdplot.py +2 -1
  78. arviz/plots/energyplot.py +1 -0
  79. arviz/plots/essplot.py +3 -2
  80. arviz/plots/forestplot.py +1 -0
  81. arviz/plots/hdiplot.py +1 -0
  82. arviz/plots/khatplot.py +1 -0
  83. arviz/plots/lmplot.py +1 -0
  84. arviz/plots/loopitplot.py +1 -0
  85. arviz/plots/mcseplot.py +1 -0
  86. arviz/plots/pairplot.py +2 -1
  87. arviz/plots/parallelplot.py +1 -0
  88. arviz/plots/plot_utils.py +1 -0
  89. arviz/plots/posteriorplot.py +1 -0
  90. arviz/plots/ppcplot.py +11 -5
  91. arviz/plots/rankplot.py +1 -0
  92. arviz/plots/separationplot.py +1 -0
  93. arviz/plots/traceplot.py +1 -0
  94. arviz/plots/tsplot.py +1 -0
  95. arviz/plots/violinplot.py +1 -0
  96. arviz/rcparams.py +1 -0
  97. arviz/sel_utils.py +1 -0
  98. arviz/static/css/style.css +2 -1
  99. arviz/stats/density_utils.py +4 -3
  100. arviz/stats/diagnostics.py +4 -4
  101. arviz/stats/ecdf_utils.py +166 -0
  102. arviz/stats/stats.py +16 -32
  103. arviz/stats/stats_refitting.py +1 -0
  104. arviz/stats/stats_utils.py +6 -2
  105. arviz/tests/base_tests/test_data.py +18 -4
  106. arviz/tests/base_tests/test_diagnostics.py +1 -0
  107. arviz/tests/base_tests/test_diagnostics_numba.py +1 -0
  108. arviz/tests/base_tests/test_labels.py +1 -0
  109. arviz/tests/base_tests/test_plots_matplotlib.py +6 -5
  110. arviz/tests/base_tests/test_stats.py +4 -4
  111. arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
  112. arviz/tests/base_tests/test_stats_utils.py +4 -3
  113. arviz/tests/base_tests/test_utils.py +3 -2
  114. arviz/tests/external_tests/test_data_numpyro.py +3 -3
  115. arviz/tests/external_tests/test_data_pyro.py +3 -3
  116. arviz/tests/helpers.py +1 -1
  117. arviz/wrappers/__init__.py +1 -0
  118. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/METADATA +10 -9
  119. arviz-0.18.0.dist-info/RECORD +182 -0
  120. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/WHEEL +1 -1
  121. arviz-0.17.0.dist-info/RECORD +0 -180
  122. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/LICENSE +0 -0
  123. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/top_level.txt +0 -0
arviz/plots/ecdfplot.py CHANGED
@@ -1,8 +1,10 @@
1
1
  """Plot ecdf or ecdf-difference plot with confidence bands."""
2
+
2
3
  import numpy as np
3
- from scipy.stats import uniform, binom
4
+ from scipy.stats import uniform
4
5
 
5
6
  from ..rcparams import rcParams
7
+ from ..stats.ecdf_utils import compute_ecdf, ecdf_confidence_band, _get_ecdf_points
6
8
  from .plot_utils import get_plotting_function
7
9
 
8
10
 
@@ -26,7 +28,7 @@ def plot_ecdf(
26
28
  show=None,
27
29
  backend=None,
28
30
  backend_kwargs=None,
29
- **kwargs
31
+ **kwargs,
30
32
  ):
31
33
  r"""Plot ECDF or ECDF-Difference Plot with Confidence bands.
32
34
 
@@ -48,6 +50,7 @@ def plot_ecdf(
48
50
  Values to compare to the original sample.
49
51
  cdf : callable, optional
50
52
  Cumulative distribution function of the distribution to compare the original sample.
53
+ The function must take as input a numpy array of draws from the distribution.
51
54
  difference : bool, default False
52
55
  If True then plot ECDF-difference plot otherwise ECDF plot.
53
56
  pit : bool, default False
@@ -180,75 +183,47 @@ def plot_ecdf(
180
183
  values = np.ravel(values)
181
184
  values.sort()
182
185
 
183
- ## This block computes gamma and uses it to get the upper and lower confidence bands
184
- ## Here we check if we want confidence bands or not
185
- if confidence_bands:
186
- ## If plotting PIT then we find the PIT values of sample.
187
- ## Basically here we generate the evaluation points(x) and find the PIT values.
188
- ## z is the evaluation point for our uniform distribution in compute_gamma()
189
- if pit:
190
- x = np.linspace(1 / npoints, 1, npoints)
191
- z = x
192
- ## Finding PIT for our sample
193
- probs = cdf(values) if cdf else compute_ecdf(values2, values) / len(values2)
194
- else:
195
- ## If not PIT use sample for plots and for evaluation points(x) use equally spaced
196
- ## points between minimum and maximum of sample
197
- ## For z we have used cdf(x)
198
- x = np.linspace(values[0], values[-1], npoints)
199
- z = cdf(x) if cdf else compute_ecdf(values2, x)
200
- probs = values
201
-
202
- n = len(values) # number of samples
203
- ## Computing gamma
204
- gamma = fpr if pointwise else compute_gamma(n, z, npoints, num_trials, fpr)
205
- ## Using gamma to get the confidence intervals
206
- lower, higher = get_lims(gamma, n, z)
207
-
208
- ## This block is for whether to plot ECDF or ECDF-difference
209
- if not difference:
210
- ## We store the coordinates of our ecdf in x_coord, y_coord
211
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
186
+ if pit:
187
+ eval_points = np.linspace(1 / npoints, 1, npoints)
188
+ if cdf:
189
+ sample = cdf(values)
212
190
  else:
213
- ## Here we subtract the ecdf value as here we are plotting the ECDF-difference
214
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
215
- for i, x_i in enumerate(x):
216
- y_coord[i] = y_coord[i] - (
217
- x_i if pit else cdf(x_i) if cdf else compute_ecdf(values2, x_i)
218
- )
219
-
220
- ## Similarly we subtract from the upper and lower bounds
221
- if pit:
222
- lower = lower - x
223
- higher = higher - x
224
- else:
225
- lower = lower - (cdf(x) if cdf else compute_ecdf(values2, x))
226
- higher = higher - (cdf(x) if cdf else compute_ecdf(values2, x))
227
-
191
+ sample = compute_ecdf(values2, values) / len(values2)
192
+ cdf_at_eval_points = eval_points
193
+ rvs = uniform(0, 1).rvs
228
194
  else:
229
- if pit:
230
- x = np.linspace(1 / npoints, 1, npoints)
231
- probs = cdf(values)
195
+ eval_points = np.linspace(values[0], values[-1], npoints)
196
+ sample = values
197
+ if confidence_bands or difference:
198
+ if cdf:
199
+ cdf_at_eval_points = cdf(eval_points)
200
+ else:
201
+ cdf_at_eval_points = compute_ecdf(values2, eval_points)
232
202
  else:
233
- x = np.linspace(values[0], values[-1], npoints)
234
- probs = values
203
+ cdf_at_eval_points = np.zeros_like(eval_points)
204
+ rvs = None
235
205
 
206
+ x_coord, y_coord = _get_ecdf_points(sample, eval_points, difference)
207
+
208
+ if difference:
209
+ y_coord -= cdf_at_eval_points
210
+
211
+ if confidence_bands:
212
+ ndraws = len(values)
213
+ band_kwargs = {"prob": 1 - fpr, "num_trials": num_trials, "rvs": rvs, "random_state": None}
214
+ band_kwargs["method"] = "pointwise" if pointwise else "simulated"
215
+ lower, higher = ecdf_confidence_band(ndraws, eval_points, cdf_at_eval_points, **band_kwargs)
216
+
217
+ if difference:
218
+ lower -= cdf_at_eval_points
219
+ higher -= cdf_at_eval_points
220
+ else:
236
221
  lower, higher = None, None
237
- ## This block is for whether to plot ECDF or ECDF-difference
238
- if not difference:
239
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
240
- else:
241
- ## Here we subtract the ecdf value as here we are plotting the ECDF-difference
242
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
243
- for i, x_i in enumerate(x):
244
- y_coord[i] = y_coord[i] - (
245
- x_i if pit else cdf(x_i) if cdf else compute_ecdf(values2, x_i)
246
- )
247
222
 
248
223
  ecdf_plot_args = dict(
249
224
  x_coord=x_coord,
250
225
  y_coord=y_coord,
251
- x_bands=x,
226
+ x_bands=eval_points,
252
227
  lower=lower,
253
228
  higher=higher,
254
229
  confidence_bands=confidence_bands,
@@ -260,7 +235,7 @@ def plot_ecdf(
260
235
  ax=ax,
261
236
  show=show,
262
237
  backend_kwargs=backend_kwargs,
263
- **kwargs
238
+ **kwargs,
264
239
  )
265
240
 
266
241
  if backend is None:
@@ -271,52 +246,3 @@ def plot_ecdf(
271
246
  ax = plot(**ecdf_plot_args)
272
247
 
273
248
  return ax
274
-
275
-
276
- def compute_ecdf(sample, z):
277
- """Compute ECDF.
278
-
279
- This function computes the ecdf value at the evaluation point
280
- or a sorted set of evaluation points.
281
- """
282
- return np.searchsorted(sample, z, side="right") / len(sample)
283
-
284
-
285
- def get_ecdf_points(x, probs, difference):
286
- """Compute the coordinates for the ecdf points using compute_ecdf."""
287
- y = compute_ecdf(probs, x)
288
-
289
- if not difference:
290
- x = np.insert(x, 0, x[0])
291
- y = np.insert(y, 0, 0)
292
- return x, y
293
-
294
-
295
- def compute_gamma(n, z, npoints=None, num_trials=1000, fpr=0.05):
296
- """Compute gamma for confidence interval calculation.
297
-
298
- This function simulates an adjusted value of gamma to account for multiplicity
299
- when forming an 1-fpr level confidence envelope for the ECDF of a sample.
300
- """
301
- if npoints is None:
302
- npoints = n
303
- gamma = []
304
- for _ in range(num_trials):
305
- unif_samples = uniform.rvs(0, 1, n)
306
- unif_samples = np.sort(unif_samples)
307
- gamma_m = 1000
308
- ## Can compute ecdf for all the z together or one at a time.
309
- f_z = compute_ecdf(unif_samples, z)
310
- f_z = compute_ecdf(unif_samples, z)
311
- gamma_m = 2 * min(
312
- np.amin(binom.cdf(n * f_z, n, z)), np.amin(1 - binom.cdf(n * f_z - 1, n, z))
313
- )
314
- gamma.append(gamma_m)
315
- return np.quantile(gamma, fpr)
316
-
317
-
318
- def get_lims(gamma, n, z):
319
- """Compute the simultaneous 1 - fpr level confidence bands."""
320
- lower = binom.ppf(gamma / 2, n, z)
321
- upper = binom.ppf(1 - gamma / 2, n, z)
322
- return lower / n, upper / n
arviz/plots/elpdplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot pointwise elpd estimations of inference data."""
2
+
2
3
  import numpy as np
3
4
 
4
5
  from ..rcparams import rcParams
@@ -98,7 +99,7 @@ def plot_elpd(
98
99
  References
99
100
  ----------
100
101
  .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
101
- cross-validation and WAIC https://arxiv.org/abs/1507.04544
102
+ cross-validation and WAIC https://arxiv.org/abs/1507.04544
102
103
 
103
104
  Examples
104
105
  --------
arviz/plots/energyplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot energy transition distribution in HMC inference."""
2
+
2
3
  import warnings
3
4
 
4
5
  from ..data import convert_to_dataset
arviz/plots/essplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot quantile or local effective sample sizes."""
2
+
2
3
  import numpy as np
3
4
  import xarray as xr
4
5
 
@@ -202,8 +203,8 @@ def plot_ess(
202
203
 
203
204
  data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
204
205
  var_names = _var_names(var_names, data, filter_vars)
205
- n_draws = data.dims["draw"]
206
- n_samples = n_draws * data.dims["chain"]
206
+ n_draws = data.sizes["draw"]
207
+ n_samples = n_draws * data.sizes["chain"]
207
208
 
208
209
  ess_tail_dataset = None
209
210
  mean_ess = None
arviz/plots/forestplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Forest plot."""
2
+
2
3
  from ..data import convert_to_dataset
3
4
  from ..labels import BaseLabeller, NoModelLabeller
4
5
  from ..rcparams import rcParams
arviz/plots/hdiplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot highest density intervals for regression data."""
2
+
2
3
  import warnings
3
4
 
4
5
  import numpy as np
arviz/plots/khatplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Pareto tail indices plot."""
2
+
2
3
  import logging
3
4
 
4
5
  import numpy as np
arviz/plots/lmplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot regression figure."""
2
+
2
3
  import warnings
3
4
  from numbers import Integral
4
5
  from itertools import repeat
arviz/plots/loopitplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot LOO-PIT predictive checks of inference data."""
2
+
2
3
  import numpy as np
3
4
  from scipy import stats
4
5
 
arviz/plots/mcseplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot quantile MC standard error."""
2
+
2
3
  import numpy as np
3
4
  import xarray as xr
4
5
 
arviz/plots/pairplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot a scatter, kde and/or hexbin of sampled parameters."""
2
+
2
3
  import warnings
3
4
  from typing import List, Optional, Union
4
5
 
@@ -229,7 +230,7 @@ def plot_pair(
229
230
  )
230
231
 
231
232
  if gridsize == "auto":
232
- gridsize = int(dataset.dims["draw"] ** 0.35)
233
+ gridsize = int(dataset.sizes["draw"] ** 0.35)
233
234
 
234
235
  numvars = len(flat_var_names)
235
236
 
@@ -1,4 +1,5 @@
1
1
  """Parallel coordinates plot showing posterior points with and without divergences marked."""
2
+
2
3
  import numpy as np
3
4
  from scipy.stats import rankdata
4
5
 
arviz/plots/plot_utils.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Utilities for plotting."""
2
+
2
3
  import importlib
3
4
  import warnings
4
5
  from typing import Any, Dict
@@ -1,4 +1,5 @@
1
1
  """Plot posterior densities."""
2
+
2
3
  from ..data import convert_to_dataset
3
4
  from ..labels import BaseLabeller
4
5
  from ..sel_utils import xarray_var_iter
arviz/plots/ppcplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Posterior/Prior predictive plot."""
2
+
2
3
  import logging
3
4
  import warnings
4
5
  from numbers import Integral
@@ -19,7 +20,7 @@ def plot_ppc(
19
20
  kind="kde",
20
21
  alpha=None,
21
22
  mean=True,
22
- observed=True,
23
+ observed=None,
23
24
  observed_rug=False,
24
25
  color=None,
25
26
  colors=None,
@@ -60,8 +61,9 @@ def plot_ppc(
60
61
  Defaults to 0.2 for ``kind = kde`` and cumulative, for scatter defaults to 0.7.
61
62
  mean : bool, default True
62
63
  Whether or not to plot the mean posterior/prior predictive distribution.
63
- observed : bool, default True
64
- Whether or not to plot the observed data.
64
+ observed : bool, optional
65
+ Whether or not to plot the observed data. Defaults to True for ``group = posterior``
66
+ and False for ``group = prior``.
65
67
  observed_rug : bool, default False
66
68
  Whether or not to plot a rug plot for the observed data. Only valid if `observed` is
67
69
  `True` and for kind `kde` or `cumulative`.
@@ -253,8 +255,12 @@ def plot_ppc(
253
255
 
254
256
  if group == "posterior":
255
257
  predictive_dataset = data.posterior_predictive
258
+ if observed is None:
259
+ observed = True
256
260
  elif group == "prior":
257
261
  predictive_dataset = data.prior_predictive
262
+ if observed is None:
263
+ observed = False
258
264
 
259
265
  if var_names is None:
260
266
  var_names = list(observed_data.data_vars)
@@ -264,11 +270,11 @@ def plot_ppc(
264
270
 
265
271
  if flatten_pp is None:
266
272
  if flatten is None:
267
- flatten_pp = list(predictive_dataset.dims.keys())
273
+ flatten_pp = list(predictive_dataset.dims)
268
274
  else:
269
275
  flatten_pp = flatten
270
276
  if flatten is None:
271
- flatten = list(observed_data.dims.keys())
277
+ flatten = list(observed_data.dims)
272
278
 
273
279
  if coords is None:
274
280
  coords = {}
arviz/plots/rankplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Histograms of ranked posterior draws, plotted for each chain."""
2
+
2
3
  from itertools import cycle
3
4
 
4
5
  import matplotlib.pyplot as plt
@@ -1,4 +1,5 @@
1
1
  """Separation plot for discrete outcome models."""
2
+
2
3
  import warnings
3
4
 
4
5
  import numpy as np
arviz/plots/traceplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot kde or histograms and values from MCMC samples."""
2
+
2
3
  import warnings
3
4
  from typing import Any, Callable, List, Mapping, Optional, Tuple, Union, Sequence
4
5
 
arviz/plots/tsplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot timeseries data."""
2
+
2
3
  import warnings
3
4
  import numpy as np
4
5
 
arviz/plots/violinplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot posterior traces as violin plot."""
2
+
2
3
  from ..data import convert_to_dataset
3
4
  from ..labels import BaseLabeller
4
5
  from ..sel_utils import xarray_var_iter
arviz/rcparams.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """ArviZ rcparams. Based on matplotlib's implementation."""
2
+
2
3
  import locale
3
4
  import logging
4
5
  import os
arviz/sel_utils.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Utilities for selecting and iterating on xarray objects."""
2
+
2
3
  from itertools import product, tee
3
4
 
4
5
  import numpy as np
@@ -302,7 +302,8 @@ dl.xr-attrs {
302
302
  grid-template-columns: 125px auto;
303
303
  }
304
304
 
305
- .xr-attrs dt, dd {
305
+ .xr-attrs dt,
306
+ .xr-attrs dd {
306
307
  padding: 0;
307
308
  margin: 0;
308
309
  float: left;
@@ -5,7 +5,8 @@ import warnings
5
5
  import numpy as np
6
6
  from scipy.fftpack import fft
7
7
  from scipy.optimize import brentq
8
- from scipy.signal import convolve, convolve2d, gaussian # pylint: disable=no-name-in-module
8
+ from scipy.signal import convolve, convolve2d
9
+ from scipy.signal.windows import gaussian
9
10
  from scipy.sparse import coo_matrix
10
11
  from scipy.special import ive # pylint: disable=no-name-in-module
11
12
 
@@ -231,8 +232,8 @@ def _fixed_point(t, N, k_sq, a_sq):
231
232
  Z. I. Botev, J. F. Grotowski, and D. P. Kroese.
232
233
  Ann. Statist. 38 (2010), no. 5, 2916--2957.
233
234
  """
234
- k_sq = np.asfarray(k_sq, dtype=np.float64)
235
- a_sq = np.asfarray(a_sq, dtype=np.float64)
235
+ k_sq = np.asarray(k_sq, dtype=np.float64)
236
+ a_sq = np.asarray(a_sq, dtype=np.float64)
236
237
 
237
238
  l = 7
238
239
  f = np.sum(np.power(k_sq, l) * a_sq * np.exp(-k_sq * np.pi**2 * t))
@@ -457,10 +457,10 @@ def ks_summary(pareto_tail_indices):
457
457
  """
458
458
  _numba_flag = Numba.numba_flag
459
459
  if _numba_flag:
460
- bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf])
460
+ bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
461
461
  kcounts, *_ = _histogram(pareto_tail_indices, bins)
462
462
  else:
463
- kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.Inf, 0.5, 0.7, 1, np.Inf])
463
+ kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.inf, 0.5, 0.7, 1, np.inf])
464
464
  kprop = kcounts / len(pareto_tail_indices) * 100
465
465
  df_k = pd.DataFrame(
466
466
  dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
@@ -836,7 +836,7 @@ def _mcse_sd(ary):
836
836
  return np.nan
837
837
  ess = _ess_sd(ary)
838
838
  if _numba_flag:
839
- sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)))
839
+ sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)).item())
840
840
  else:
841
841
  sd = np.std(ary, ddof=1)
842
842
  fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
@@ -904,7 +904,7 @@ def _mc_error(ary, batches=5, circular=False):
904
904
  else:
905
905
  std = stats.circstd(ary, high=np.pi, low=-np.pi)
906
906
  elif _numba_flag:
907
- std = float(_sqrt(svar(ary), np.zeros(1)))
907
+ std = float(_sqrt(svar(ary), np.zeros(1)).item())
908
908
  else:
909
909
  std = np.std(ary)
910
910
  return std / np.sqrt(len(ary))
@@ -0,0 +1,166 @@
1
+ """Functions for evaluating ECDFs and their confidence bands."""
2
+
3
+ from typing import Any, Callable, Optional, Tuple
4
+ import warnings
5
+
6
+ import numpy as np
7
+ from scipy.stats import uniform, binom
8
+
9
+
10
+ def compute_ecdf(sample: np.ndarray, eval_points: np.ndarray) -> np.ndarray:
11
+ """Compute ECDF of the sorted `sample` at the evaluation points."""
12
+ return np.searchsorted(sample, eval_points, side="right") / len(sample)
13
+
14
+
15
+ def _get_ecdf_points(
16
+ sample: np.ndarray, eval_points: np.ndarray, difference: bool
17
+ ) -> Tuple[np.ndarray, np.ndarray]:
18
+ """Compute the coordinates for the ecdf points using compute_ecdf."""
19
+ x = eval_points
20
+ y = compute_ecdf(sample, eval_points)
21
+
22
+ if not difference and y[0] > 0:
23
+ x = np.insert(x, 0, x[0])
24
+ y = np.insert(y, 0, 0)
25
+ return x, y
26
+
27
+
28
+ def _simulate_ecdf(
29
+ ndraws: int,
30
+ eval_points: np.ndarray,
31
+ rvs: Callable[[int, Optional[Any]], np.ndarray],
32
+ random_state: Optional[Any] = None,
33
+ ) -> np.ndarray:
34
+ """Simulate ECDF at the `eval_points` using the given random variable sampler"""
35
+ sample = rvs(ndraws, random_state=random_state)
36
+ sample.sort()
37
+ return compute_ecdf(sample, eval_points)
38
+
39
+
40
+ def _fit_pointwise_band_probability(
41
+ ndraws: int,
42
+ ecdf_at_eval_points: np.ndarray,
43
+ cdf_at_eval_points: np.ndarray,
44
+ ) -> float:
45
+ """Compute the smallest marginal probability of a pointwise confidence band that
46
+ contains the ECDF."""
47
+ ecdf_scaled = (ndraws * ecdf_at_eval_points).astype(int)
48
+ prob_lower_tail = np.amin(binom.cdf(ecdf_scaled, ndraws, cdf_at_eval_points))
49
+ prob_upper_tail = np.amin(binom.sf(ecdf_scaled - 1, ndraws, cdf_at_eval_points))
50
+ prob_pointwise = 1 - 2 * min(prob_lower_tail, prob_upper_tail)
51
+ return prob_pointwise
52
+
53
+
54
+ def _get_pointwise_confidence_band(
55
+ prob: float, ndraws: int, cdf_at_eval_points: np.ndarray
56
+ ) -> Tuple[np.ndarray, np.ndarray]:
57
+ """Compute the `prob`-level pointwise confidence band."""
58
+ count_lower, count_upper = binom.interval(prob, ndraws, cdf_at_eval_points)
59
+ prob_lower = count_lower / ndraws
60
+ prob_upper = count_upper / ndraws
61
+ return prob_lower, prob_upper
62
+
63
+
64
+ def ecdf_confidence_band(
65
+ ndraws: int,
66
+ eval_points: np.ndarray,
67
+ cdf_at_eval_points: np.ndarray,
68
+ prob: float = 0.95,
69
+ method="simulated",
70
+ **kwargs,
71
+ ) -> Tuple[np.ndarray, np.ndarray]:
72
+ """Compute the `prob`-level confidence band for the ECDF.
73
+
74
+ Arguments
75
+ ---------
76
+ ndraws : int
77
+ Number of samples in the original dataset.
78
+ eval_points : np.ndarray
79
+ Points at which the ECDF is evaluated. If these are dependent on the sample
80
+ values, simultaneous confidence bands may not be correctly calibrated.
81
+ cdf_at_eval_points : np.ndarray
82
+ CDF values at the evaluation points.
83
+ prob : float, default 0.95
84
+ The target probability that a true ECDF lies within the confidence band.
85
+ method : string, default "simulated"
86
+ The method used to compute the confidence band. Valid options are:
87
+ - "pointwise": Compute the pointwise (i.e. marginal) confidence band.
88
+ - "simulated": Use Monte Carlo simulation to estimate a simultaneous confidence band.
89
+ `rvs` must be provided.
90
+ rvs: callable, optional
91
+ A function that takes an integer `ndraws` and optionally the object passed to
92
+ `random_state` and returns an array of `ndraws` samples from the same distribution
93
+ as the original dataset. Required if `method` is "simulated" and variable is discrete.
94
+ num_trials : int, default 1000
95
+ The number of random ECDFs to generate for constructing simultaneous confidence bands
96
+ (if `method` is "simulated").
97
+ random_state : {None, int, `numpy.random.Generator`,
98
+ `numpy.random.RandomState`}, optional
99
+ If `None`, the `numpy.random.RandomState` singleton is used. If an `int`, a new
100
+ ``numpy.random.RandomState`` instance is used, seeded with seed. If a `RandomState` or
101
+ `Generator` instance, the instance is used.
102
+
103
+ Returns
104
+ -------
105
+ prob_lower : np.ndarray
106
+ Lower confidence band for the ECDF at the evaluation points.
107
+ prob_upper : np.ndarray
108
+ Upper confidence band for the ECDF at the evaluation points.
109
+ """
110
+ if not 0 < prob < 1:
111
+ raise ValueError(f"Invalid value for `prob`. Expected 0 < prob < 1, but got {prob}.")
112
+
113
+ if method == "pointwise":
114
+ prob_pointwise = prob
115
+ elif method == "simulated":
116
+ prob_pointwise = _simulate_simultaneous_ecdf_band_probability(
117
+ ndraws, eval_points, cdf_at_eval_points, prob=prob, **kwargs
118
+ )
119
+ else:
120
+ raise ValueError(f"Unknown method {method}. Valid options are 'pointwise' or 'simulated'.")
121
+
122
+ prob_lower, prob_upper = _get_pointwise_confidence_band(
123
+ prob_pointwise, ndraws, cdf_at_eval_points
124
+ )
125
+
126
+ return prob_lower, prob_upper
127
+
128
+
129
+ def _simulate_simultaneous_ecdf_band_probability(
130
+ ndraws: int,
131
+ eval_points: np.ndarray,
132
+ cdf_at_eval_points: np.ndarray,
133
+ prob: float = 0.95,
134
+ rvs: Optional[Callable[[int, Optional[Any]], np.ndarray]] = None,
135
+ num_trials: int = 1000,
136
+ random_state: Optional[Any] = None,
137
+ ) -> float:
138
+ """Estimate probability for simultaneous confidence band using simulation.
139
+
140
+ This function simulates the pointwise probability needed to construct pointwise
141
+ confidence bands that form a `prob`-level confidence envelope for the ECDF
142
+ of a sample.
143
+ """
144
+ if rvs is None:
145
+ warnings.warn(
146
+ "Assuming variable is continuous for calibration of pointwise bands. "
147
+ "If the variable is discrete, specify random variable sampler `rvs`.",
148
+ UserWarning,
149
+ )
150
+ # if variable continuous, we can calibrate the confidence band using a uniform
151
+ # distribution
152
+ rvs = uniform(0, 1).rvs
153
+ eval_points_sim = cdf_at_eval_points
154
+ else:
155
+ eval_points_sim = eval_points
156
+
157
+ probs_pointwise = np.empty(num_trials)
158
+ for i in range(num_trials):
159
+ ecdf_at_eval_points = _simulate_ecdf(
160
+ ndraws, eval_points_sim, rvs, random_state=random_state
161
+ )
162
+ prob_pointwise = _fit_pointwise_band_probability(
163
+ ndraws, ecdf_at_eval_points, cdf_at_eval_points
164
+ )
165
+ probs_pointwise[i] = prob_pointwise
166
+ return np.quantile(probs_pointwise, prob)